# Experiment 21: Back to Basics

**Date:** 2025-02-05

## Purpose

Previous experiments show KV cache priming often hurts or has no effect. This experiment
isolates exactly WHERE the benefit disappears by testing the theory step-by-step.

## The Theory

1. **Ideal**: Build cache for `[query + document]` → document tokens attend to query during
   encoding → representations are optimized for answering that query
2. **Problem**: Can't precompute for all O(queries × documents) pairs
3. **Hope**: Priming with a surrogate query approximates the ideal

## Experimental Conditions

| Condition | Cache Built From | Query at Scoring | What It Tests |
|-----------|------------------|------------------|---------------|
| **A: Bare** | `[doc]` | Provided fresh | Baseline |
| **B: Full Oracle** | `[query + doc]` | Already in cache | Does query-in-context help? |
| **C: Full + Repeat** | `[query + doc]` | Provided again | Upper bound |
| **D: Truncated Oracle** | `[query + doc]` → keep `[doc]` | Provided fresh | Does truncation preserve benefit? |
| **E: Truncated Random** | `[random + doc]` → keep `[doc]` | Provided fresh | Semantic signal test |

## Expected Outcomes (if theory is correct)

- B ≥ A (query context helps)
- C ≥ B (redundancy OK)
- **D ≥ A (truncation preserves benefit)** ← KEY TEST
- D > E (semantic match matters)

## Failure Diagnosis

- If B ≈ A: Theory is wrong (query context doesn't help this model)
- If D < A but B > A: Truncation destroys the benefit
- If D ≈ E: No semantic signal (priming is just noise)

In [1]:
# Cell 1: Setup
import os
os.umask(0o000)

import sys
sys.path.insert(0, '/home/jupyter/research/directed_kvcache')

import json
import random
import numpy as np
import torch
from tqdm.auto import tqdm
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from scipy import stats

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, DynamicCache
from datasets import load_dataset

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

OUTPUT_DIR = '/home/jupyter/research/directed_kvcache/results/exp21'
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")

PyTorch: 2.10.0+cu128
CUDA: True


In [2]:
# Cell 2: Load Model
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()

print(f"Model loaded on {model.device}")
print(f"Model dtype: {model.dtype}")

`torch_dtype` is deprecated! Use `dtype` instead!


Loading weights:   0%|          | 0/291 [00:00<?, ?it/s]

Model loaded on cuda:0
Model dtype: torch.bfloat16


In [3]:
# Cell 3: Core Utility Functions

from lib.kv_cache import (
    deepcopy_cache,
    _get_cache_keys,
    _get_cache_values,
    _set_cache_keys,
    _set_cache_values,
    _ensure_dynamic_cache,
    # RoPE correction
    correct_rope_positions_with_bos,
    extract_and_truncate_cache_with_bos,
)


def get_cache_len(cache: DynamicCache) -> int:
    """Get sequence length from cache."""
    cache = _ensure_dynamic_cache(cache)
    return _get_cache_keys(cache, 0).shape[2]


def score_answer_nll(cache: DynamicCache, prompt: str, answer: str) -> float:
    """
    Score P(answer | cache, prompt) as negative log-likelihood.
    
    Lower NLL = higher probability = better.
    """
    cache = _ensure_dynamic_cache(cache)
    cache_len = get_cache_len(cache)
    
    # Tokenize prompt and answer
    prompt_ids = tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=False).to(model.device)
    answer_ids = tokenizer.encode(answer, return_tensors='pt', add_special_tokens=False).to(model.device)
    
    # Combine prompt + answer
    input_ids = torch.cat([prompt_ids, answer_ids], dim=1)
    
    # Attention mask for full sequence
    total_len = cache_len + input_ids.shape[1]
    attention_mask = torch.ones((1, total_len), device=model.device)
    
    # Forward pass
    cache_copy = deepcopy_cache(cache)
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=cache_copy,
            use_cache=True,
            return_dict=True
        )
    
    # Score only the answer tokens
    logits = outputs.logits  # [1, prompt_len + answer_len, vocab]
    prompt_len = prompt_ids.shape[1]
    answer_len = answer_ids.shape[1]
    
    if answer_len == 0:
        return 0.0
    
    # logits[prompt_len-1] predicts answer[0], logits[prompt_len] predicts answer[1], etc.
    # So logits[prompt_len-1 : prompt_len+answer_len-1] predicts answer[0:answer_len]
    answer_logits = logits[:, prompt_len-1:prompt_len+answer_len-1, :]
    
    # Compute cross-entropy
    loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')
    loss = loss_fct(
        answer_logits.view(-1, answer_logits.size(-1)),
        answer_ids.view(-1)
    )
    
    return loss.item()


print("Utility functions defined.")

Utility functions defined.


In [4]:
# Cell 4: Cache Building Functions for Each Condition

def build_cache_A_bare(doc: str) -> Tuple[DynamicCache, str]:
    """
    Condition A: Bare document cache.
    
    Cache: [BOS, doc_tokens]
    Query: Provided fresh at scoring time
    """
    ids = tokenizer.encode(doc, return_tensors='pt', add_special_tokens=True).to(model.device)
    with torch.no_grad():
        out = model(ids, use_cache=True)
    return _ensure_dynamic_cache(out.past_key_values), "bare"


def build_cache_B_full_oracle(query: str, doc: str) -> Tuple[DynamicCache, str]:
    """
    Condition B: Full context with query (not truncated).
    
    Cache: [BOS, query_tokens, doc_tokens]
    Query: Already in cache, not repeated at scoring
    
    This tests: Does having query in context during doc encoding help?
    """
    full_text = query + " " + doc
    ids = tokenizer.encode(full_text, return_tensors='pt', add_special_tokens=True).to(model.device)
    with torch.no_grad():
        out = model(ids, use_cache=True)
    return _ensure_dynamic_cache(out.past_key_values), "full_oracle"


def build_cache_C_full_repeat(query: str, doc: str) -> Tuple[DynamicCache, str]:
    """
    Condition C: Full context, query will be repeated at scoring.
    
    Cache: [BOS, query_tokens, doc_tokens]
    Query: Will be provided again at scoring (redundant but potentially helpful)
    
    This is the same cache as B, but scoring differs.
    """
    full_text = query + " " + doc
    ids = tokenizer.encode(full_text, return_tensors='pt', add_special_tokens=True).to(model.device)
    with torch.no_grad():
        out = model(ids, use_cache=True)
    return _ensure_dynamic_cache(out.past_key_values), "full_repeat"


def build_cache_D_truncated_oracle(query: str, doc: str) -> Tuple[DynamicCache, str]:
    """
    Condition D: Build with query, then truncate to doc only.
    
    Build: [BOS, query_tokens, doc_tokens]
    Truncate to: [BOS, doc_tokens] (with RoPE correction)
    Query: Provided fresh at scoring
    
    This tests: Does truncation preserve the benefit of query-priming?
    THIS IS THE KEY TEST.
    """
    # Tokenize to find boundaries
    query_with_space = query + " "
    query_ids = tokenizer.encode(query_with_space, return_tensors='pt', add_special_tokens=True)
    query_len = query_ids.shape[1]  # includes BOS
    
    full_text = query_with_space + doc
    full_ids = tokenizer.encode(full_text, return_tensors='pt', add_special_tokens=True).to(model.device)
    full_len = full_ids.shape[1]
    doc_len = full_len - query_len  # tokens after query (not including BOS)
    
    # Build full cache
    with torch.no_grad():
        out = model(full_ids, use_cache=True)
    full_cache = _ensure_dynamic_cache(out.past_key_values)
    
    # Use lib functions for truncation and RoPE correction
    # extract_and_truncate_cache_with_bos keeps BOS + last doc_len tokens
    truncated_cache = extract_and_truncate_cache_with_bos(full_cache, doc_len)
    
    # RoPE correction: doc tokens were at positions [query_len, ...], now at [1, ...]
    # Offset = query_len - 1 (subtract this from original positions)
    surrogate_offset = query_len - 1
    correct_rope_positions_with_bos(truncated_cache, surrogate_offset, model)
    
    return truncated_cache, "truncated_oracle"


def build_cache_E_truncated_random(random_query: str, doc: str) -> Tuple[DynamicCache, str]:
    """
    Condition E: Build with RANDOM query, then truncate.
    
    Same as D but with a mismatched query.
    
    This tests: Is there semantic signal, or is any priming equivalent?
    """
    # Reuse the truncated oracle logic with random query
    cache, _ = build_cache_D_truncated_oracle(random_query, doc)
    return cache, "truncated_random"


print("Cache building functions defined.")

Cache building functions defined.


---
## Part 1: Sanity Check with Synthetic Example

Before running on real data, let's verify the setup with a simple example where we KNOW what should happen.

In [5]:
# Cell 6: Synthetic Sanity Check

print("="*70)
print("SANITY CHECK: Synthetic Example")
print("="*70)

# Simple factoid that the model should know
doc = "The capital of France is Paris. It is known for the Eiffel Tower."
query = "What is the capital of France?"
answer = " Paris"
wrong_query = "What is the population of Japan?"

print(f"\nDocument: {doc}")
print(f"Query: {query}")
print(f"Expected answer: {answer}")
print(f"Wrong query (for E): {wrong_query}")

# Build all caches
cache_A, _ = build_cache_A_bare(doc)
cache_B, _ = build_cache_B_full_oracle(query, doc)
cache_C, _ = build_cache_C_full_repeat(query, doc)  # Same as B
cache_D, _ = build_cache_D_truncated_oracle(query, doc)
cache_E, _ = build_cache_E_truncated_random(wrong_query, doc)

# Scoring prompts differ by condition
prompt_with_query = f"\n\nQuestion: {query}\nAnswer:"
prompt_no_query = "\n\nAnswer:"  # Query already in cache for B

# Score each condition
print("\n" + "-"*50)
print("Scoring P(answer | cache, prompt)")
print("-"*50)

nll_A = score_answer_nll(cache_A, prompt_with_query, answer)
print(f"A (bare, query at scoring):        NLL = {nll_A:.4f}")

nll_B = score_answer_nll(cache_B, prompt_no_query, answer)
print(f"B (full oracle, no query repeat):  NLL = {nll_B:.4f}")

nll_C = score_answer_nll(cache_C, prompt_with_query, answer)
print(f"C (full oracle, query repeated):   NLL = {nll_C:.4f}")

nll_D = score_answer_nll(cache_D, prompt_with_query, answer)
print(f"D (truncated oracle):              NLL = {nll_D:.4f}")

nll_E = score_answer_nll(cache_E, prompt_with_query, answer)
print(f"E (truncated random):              NLL = {nll_E:.4f}")

print("\n" + "-"*50)
print("Analysis")
print("-"*50)
print(f"B vs A (does query context help?):     {nll_A - nll_B:+.4f} (positive = B better)")
print(f"C vs A (upper bound with redundancy):  {nll_A - nll_C:+.4f} (positive = C better)")
print(f"D vs A (does truncation preserve?):    {nll_A - nll_D:+.4f} (positive = D better) <- KEY")
print(f"D vs E (semantic signal?):             {nll_E - nll_D:+.4f} (positive = D better)")

SANITY CHECK: Synthetic Example

Document: The capital of France is Paris. It is known for the Eiffel Tower.
Query: What is the capital of France?
Expected answer:  Paris
Wrong query (for E): What is the population of Japan?

--------------------------------------------------
Scoring P(answer | cache, prompt)
--------------------------------------------------
A (bare, query at scoring):        NLL = 0.0486
B (full oracle, no query repeat):  NLL = 1.2891
C (full oracle, query repeated):   NLL = 0.2012
D (truncated oracle):              NLL = 0.1797
E (truncated random):              NLL = 0.0698

--------------------------------------------------
Analysis
--------------------------------------------------
B vs A (does query context help?):     -1.2405 (positive = B better)
C vs A (upper bound with redundancy):  -0.1526 (positive = C better)
D vs A (does truncation preserve?):    -0.1311 (positive = D better) <- KEY
D vs E (semantic signal?):             -0.1099 (positive = D better)


---
## Part 2: MS MARCO Evaluation

Now test on real data to see if the patterns hold.

In [6]:
# Cell 8: Load MS MARCO Data

print("Loading MS MARCO...")
msmarco = load_dataset("ms_marco", "v1.1", split="train")

# Build evaluation samples
samples = []
all_queries = []  # For random query sampling

for item in msmarco:
    if len(samples) >= 300:  # Buffer for 200
        break
    
    query = item.get('query', '')
    passages = item.get('passages', {}).get('passage_text', [])
    answers = item.get('answers', [])
    
    if query:
        all_queries.append(query)
    
    if not query or not passages or not passages[0] or not answers or not answers[0]:
        continue
    
    passage = passages[0]
    answer = answers[0]
    
    # Filter reasonable lengths
    if len(passage.split()) < 20 or len(passage.split()) > 150:
        continue
    if len(answer.split()) < 2 or len(answer.split()) > 30:
        continue
    
    samples.append({
        'query': query,
        'passage': passage,
        'answer': answer,
    })

print(f"Built {len(samples)} evaluation samples")
print(f"Query pool size: {len(all_queries)}")

print(f"\nExample:")
print(f"  Query: {samples[0]['query']}")
print(f"  Passage: {samples[0]['passage'][:100]}...")
print(f"  Answer: {samples[0]['answer']}")

Loading MS MARCO...




Built 300 evaluation samples
Query pool size: 418

Example:
  Query: what is rba
  Passage: Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal. Th...
  Answer: Results-Based Accountability is a disciplined way of thinking and taking action that communities can use to improve the lives of children, youth, families, adults and the community as a whole.


In [7]:
# Cell 9: Run Full Evaluation

print("="*70)
print("MS MARCO EVALUATION")
print("="*70)

N_SAMPLES = 200

results = {
    'A_bare': [],
    'B_full_oracle': [],
    'C_full_repeat': [],
    'D_truncated_oracle': [],
    'E_truncated_random': [],
}

for sample in tqdm(samples[:N_SAMPLES], desc="Evaluating"):
    query = sample['query']
    passage = sample['passage']
    answer = " " + sample['answer']  # Leading space for tokenization
    
    # Random query for condition E (different from actual query)
    random_query = random.choice([q for q in all_queries if q != query])
    
    # Prompts
    prompt_with_query = f"\n\nQuestion: {query}\nAnswer:"
    prompt_no_query = "\n\nAnswer:"
    
    # Condition A: Bare
    cache_A, _ = build_cache_A_bare(passage)
    nll_A = score_answer_nll(cache_A, prompt_with_query, answer)
    results['A_bare'].append(nll_A)
    
    # Condition B: Full Oracle (query in cache, not repeated)
    cache_B, _ = build_cache_B_full_oracle(query, passage)
    nll_B = score_answer_nll(cache_B, prompt_no_query, answer)
    results['B_full_oracle'].append(nll_B)
    
    # Condition C: Full Oracle (query in cache, repeated at scoring)
    cache_C, _ = build_cache_C_full_repeat(query, passage)
    nll_C = score_answer_nll(cache_C, prompt_with_query, answer)
    results['C_full_repeat'].append(nll_C)
    
    # Condition D: Truncated Oracle
    cache_D, _ = build_cache_D_truncated_oracle(query, passage)
    nll_D = score_answer_nll(cache_D, prompt_with_query, answer)
    results['D_truncated_oracle'].append(nll_D)
    
    # Condition E: Truncated Random
    cache_E, _ = build_cache_E_truncated_random(random_query, passage)
    nll_E = score_answer_nll(cache_E, prompt_with_query, answer)
    results['E_truncated_random'].append(nll_E)

print("\nEvaluation complete.")

MS MARCO EVALUATION


Evaluating:   0%|          | 0/200 [00:00<?, ?it/s]


Evaluation complete.


In [8]:
# Cell 10: Results Summary

print("\n" + "="*70)
print("RESULTS SUMMARY")
print("="*70)

print("\nMean NLL by Condition (lower = better):")
print("-"*50)
for cond, nlls in results.items():
    print(f"  {cond:25s}: {np.mean(nlls):.4f} (+/- {np.std(nlls):.4f})")

# Key comparisons
A = np.array(results['A_bare'])
B = np.array(results['B_full_oracle'])
C = np.array(results['C_full_repeat'])
D = np.array(results['D_truncated_oracle'])
E = np.array(results['E_truncated_random'])

print("\n" + "="*70)
print("KEY COMPARISONS")
print("="*70)

def compare(name, baseline, test, expect_positive=True):
    delta = baseline - test  # positive means test is better (lower NLL)
    win_rate = np.mean(delta > 0)
    t_stat, p_val = stats.ttest_rel(baseline, test)
    d = np.mean(delta) / np.std(delta) if np.std(delta) > 0 else 0
    
    status = ""
    if expect_positive:
        if delta.mean() > 0 and p_val < 0.05:
            status = "CONFIRMED"
        elif delta.mean() < 0 and p_val < 0.05:
            status = "REVERSED!"
        else:
            status = "no effect"
    
    print(f"\n{name}")
    print(f"  Delta: {np.mean(delta):+.4f} (positive = test better)")
    print(f"  Win rate: {win_rate*100:.1f}%")
    print(f"  Cohen's d: {d:+.3f}")
    print(f"  p-value: {p_val:.4f}")
    print(f"  Status: {status}")
    return delta.mean(), win_rate, d, p_val

print("\n### Test 1: Does query-in-context help at all? (B vs A) ###")
compare("B (full oracle) vs A (bare)", A, B)

print("\n### Test 2: Upper bound with query redundancy? (C vs A) ###")
compare("C (full + repeat) vs A (bare)", A, C)

print("\n### Test 3: KEY - Does truncation preserve benefit? (D vs A) ###")
d_vs_a = compare("D (truncated oracle) vs A (bare)", A, D)

print("\n### Test 4: Is there semantic signal? (D vs E) ###")
d_vs_e = compare("D (truncated oracle) vs E (truncated random)", E, D)

print("\n### Test 5: Does truncation hurt vs full context? (D vs B) ###")
compare("D (truncated) vs B (full)", B, D, expect_positive=False)


RESULTS SUMMARY

Mean NLL by Condition (lower = better):
--------------------------------------------------
  A_bare                   : 3.1649 (+/- 1.8655)
  B_full_oracle            : 3.1718 (+/- 1.7439)
  C_full_repeat            : 3.3721 (+/- 1.9134)
  D_truncated_oracle       : 3.1574 (+/- 1.8237)
  E_truncated_random       : 3.1839 (+/- 1.8390)

KEY COMPARISONS

### Test 1: Does query-in-context help at all? (B vs A) ###

B (full oracle) vs A (bare)
  Delta: -0.0069 (positive = test better)
  Win rate: 49.0%
  Cohen's d: -0.009
  p-value: 0.9007
  Status: no effect

### Test 2: Upper bound with query redundancy? (C vs A) ###

C (full + repeat) vs A (bare)
  Delta: -0.2072 (positive = test better)
  Win rate: 31.0%
  Cohen's d: -0.370
  p-value: 0.0000
  Status: REVERSED!

### Test 3: KEY - Does truncation preserve benefit? (D vs A) ###

D (truncated oracle) vs A (bare)
  Delta: +0.0075 (positive = test better)
  Win rate: 48.5%
  Cohen's d: +0.017
  p-value: 0.8153
  Status: no 

(0.0143817138671875, 0.495, 0.021894136136281748, 0.7577550859661021)

In [9]:
# Cell 11: Diagnosis

print("\n" + "="*70)
print("DIAGNOSIS")
print("="*70)

A_mean = np.mean(A)
B_mean = np.mean(B)
D_mean = np.mean(D)
E_mean = np.mean(E)

print("\nChecking failure points...\n")

# Check 1: Does query context help at all?
if B_mean < A_mean - 0.05:
    print("[OK] Query-in-context HELPS (B < A)")
    print(f"     Full oracle improves by {A_mean - B_mean:.4f} NLL")
    theory_works = True
elif B_mean > A_mean + 0.05:
    print("[PROBLEM] Query-in-context HURTS (B > A)")
    print(f"     Full oracle is WORSE by {B_mean - A_mean:.4f} NLL")
    print("     -> The basic theory may be wrong for this model/task")
    theory_works = False
else:
    print("[UNCLEAR] Query-in-context has NO EFFECT (B ≈ A)")
    print(f"     Difference: {A_mean - B_mean:.4f} NLL")
    print("     -> Document may already contain enough signal")
    theory_works = None

# Check 2: Does truncation preserve benefit?
print()
if D_mean < A_mean - 0.02:
    print("[OK] Truncation PRESERVES benefit (D < A)")
    print(f"     Truncated oracle improves by {A_mean - D_mean:.4f} NLL")
    truncation_ok = True
elif D_mean > A_mean + 0.02:
    print("[PROBLEM] Truncation HURTS (D > A)")
    print(f"     Truncated oracle is WORSE by {D_mean - A_mean:.4f} NLL")
    truncation_ok = False
    if theory_works:
        print("     -> Truncation/RoPE correction is destroying the benefit!")
else:
    print("[UNCLEAR] Truncation has NO EFFECT (D ≈ A)")
    truncation_ok = None

# Check 3: Semantic signal?
print()
if D_mean < E_mean - 0.02:
    print("[OK] Semantic signal EXISTS (D < E)")
    print(f"     Oracle priming beats random by {E_mean - D_mean:.4f} NLL")
elif D_mean > E_mean + 0.02:
    print("[PROBLEM] INVERTED signal (D > E)")
    print(f"     Random priming beats oracle by {D_mean - E_mean:.4f} NLL")
    print("     -> This suggests interference, not semantic benefit")
else:
    print("[UNCLEAR] No semantic signal (D ≈ E)")
    print(f"     Difference: {E_mean - D_mean:.4f} NLL")
    print("     -> Priming effect is non-semantic (just noise)")

print("\n" + "="*70)
print("CONCLUSION")
print("="*70)


DIAGNOSIS

Checking failure points...

[UNCLEAR] Query-in-context has NO EFFECT (B ≈ A)
     Difference: -0.0069 NLL
     -> Document may already contain enough signal

[UNCLEAR] Truncation has NO EFFECT (D ≈ A)

[OK] Semantic signal EXISTS (D < E)
     Oracle priming beats random by 0.0265 NLL

CONCLUSION


In [10]:
# Cell 12: Save Results

output = {
    'n_samples': N_SAMPLES,
    'conditions': {
        'A_bare': {'mean': float(np.mean(A)), 'std': float(np.std(A))},
        'B_full_oracle': {'mean': float(np.mean(B)), 'std': float(np.std(B))},
        'C_full_repeat': {'mean': float(np.mean(C)), 'std': float(np.std(C))},
        'D_truncated_oracle': {'mean': float(np.mean(D)), 'std': float(np.std(D))},
        'E_truncated_random': {'mean': float(np.mean(E)), 'std': float(np.std(E))},
    },
    'comparisons': {
        'B_vs_A': {'delta': float(np.mean(A - B)), 'win_rate': float(np.mean(A > B))},
        'D_vs_A': {'delta': float(np.mean(A - D)), 'win_rate': float(np.mean(A > D))},
        'D_vs_E': {'delta': float(np.mean(E - D)), 'win_rate': float(np.mean(E > D))},
    },
    'raw_results': {k: [float(x) for x in v] for k, v in results.items()},
}

with open(f'{OUTPUT_DIR}/results.json', 'w') as f:
    json.dump(output, f, indent=2)

print(f"Results saved to {OUTPUT_DIR}/results.json")

Results saved to /home/jupyter/research/directed_kvcache/results/exp21/results.json


---
## Part 3: Deep Dive (if needed)

If the above reveals where the breakdown occurs, we can add more targeted tests here.

In [11]:
# Cell 14: Cache Shape Verification

print("Cache shape verification:")
print("-"*50)

sample = samples[0]
query = sample['query']
passage = sample['passage']

cache_A, _ = build_cache_A_bare(passage)
cache_D, _ = build_cache_D_truncated_oracle(query, passage)

print(f"Query: '{query[:50]}...'")
print(f"Passage: '{passage[:50]}...'")
print()

# Tokenize to check lengths
query_ids = tokenizer.encode(query + " ", add_special_tokens=True)
passage_ids = tokenizer.encode(passage, add_special_tokens=True)
full_ids = tokenizer.encode(query + " " + passage, add_special_tokens=True)

print(f"Query tokens (with BOS): {len(query_ids)}")
print(f"Passage tokens (with BOS): {len(passage_ids)}")
print(f"Full tokens (with BOS): {len(full_ids)}")
print()

cache_A_len = get_cache_len(cache_A)
cache_D_len = get_cache_len(cache_D)

print(f"Cache A (bare) length: {cache_A_len}")
print(f"Cache D (truncated) length: {cache_D_len}")
print()

# They should match if truncation is correct
if cache_A_len == cache_D_len:
    print("[OK] Cache lengths match")
else:
    print(f"[WARNING] Cache lengths differ by {abs(cache_A_len - cache_D_len)}")
    print("  This could indicate a tokenization boundary issue")

Cache shape verification:
--------------------------------------------------
Query: 'what is rba...'
Passage: 'Since 2007, the RBA's outstanding reputation has b...'

Query tokens (with BOS): 6
Passage tokens (with BOS): 121
Full tokens (with BOS): 125

Cache A (bare) length: 121
Cache D (truncated) length: 120

  This could indicate a tokenization boundary issue


In [12]:
# Cell 15: Value Vector Analysis

print("Value vector analysis:")
print("-"*50)

# Compare value vectors between bare and truncated caches
# If priming works, the value vectors should be DIFFERENT
# (that's the whole point - document representations change)

sample = samples[0]
cache_A, _ = build_cache_A_bare(sample['passage'])
cache_D, _ = build_cache_D_truncated_oracle(sample['query'], sample['passage'])

# Get value vectors from first and last layers
n_layers = len(cache_A)
layers_to_check = [0, n_layers // 2, n_layers - 1]

for layer in layers_to_check:
    v_A = _get_cache_values(cache_A, layer)  # [batch, heads, seq, head_dim]
    v_D = _get_cache_values(cache_D, layer)
    
    # Compare overlapping positions (skip BOS, compare doc tokens)
    min_len = min(v_A.shape[2], v_D.shape[2])
    v_A_doc = v_A[:, :, 1:min_len, :]  # Skip BOS
    v_D_doc = v_D[:, :, 1:min_len, :]
    
    # Cosine similarity
    v_A_flat = v_A_doc.reshape(-1).float()
    v_D_flat = v_D_doc.reshape(-1).float()
    cos_sim = torch.nn.functional.cosine_similarity(v_A_flat.unsqueeze(0), v_D_flat.unsqueeze(0)).item()
    
    # L2 distance
    l2_dist = torch.norm(v_A_flat - v_D_flat).item()
    
    print(f"Layer {layer}:")
    print(f"  Cosine similarity: {cos_sim:.6f}")
    print(f"  L2 distance: {l2_dist:.4f}")
    print(f"  -> {'DIFFERENT' if cos_sim < 0.999 else 'NEARLY IDENTICAL'}")

Value vector analysis:
--------------------------------------------------
Layer 0:
  Cosine similarity: 0.073697
  L2 distance: 11.2471
  -> DIFFERENT
Layer 16:
  Cosine similarity: 0.484364
  L2 distance: 148.6682
  -> DIFFERENT
Layer 31:
  Cosine similarity: 0.278531
  L2 distance: 396.0867
  -> DIFFERENT
