# Experimental Walkthrough: Understanding Surrogate KV Cache Priming

This notebook walks through a single example from start to finish to build intuition for the experimental setup. We'll examine:

1. **What is a data point?** - A passage, query, and answer from MS MARCO
2. **What are static vs. generated surrogates?** - The two approaches we're comparing
3. **How do we build KV caches?** - The caching mechanism at the heart of the experiment
4. **How do we run generation/scoring?** - Using the cache to score answers
5. **What is the evaluation metric?** - Negative Log-Likelihood (NLL) explained

---

## The Big Picture

In production RAG systems, we retrieve a document and use it to answer user queries. The **KV cache** stores the model's internal representations of the document, allowing us to reuse computation across queries.

**The core idea**: What if we "prime" the cache by prepending a surrogate query before the document? This might help the model understand what kind of questions it should expect.

We compare two approaches:
- **Static surrogates**: 5 fixed queries (same for every document), covering common intent categories
- **Generated surrogates**: 5 document-specific queries, generated by an LLM

```
BASELINE CACHE:                    SURROGATE-PRIMED CACHE:
┌─────────────────────┐            ┌─────────────────────────────────────┐
│ Document: ...       │            │ This doc may answer: {surrogate}    │
│                     │            │                                     │
│ [KV Cache]          │            │ Document: ...                       │
└─────────────────────┘            │                                     │
                                   │ [KV Cache]                          │
                                   └─────────────────────────────────────┘
```

## Step 1: Setup and Dependencies

In [1]:
# Install dependencies (run once)
!pip install transformers torch datasets bitsandbytes accelerate sentence-transformers -q

In [2]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import warnings
warnings.filterwarnings('ignore')

print("Imports complete.")

Imports complete.


In [3]:
# Configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

print(f"Device: {DEVICE}")
print(f"Model: {MODEL_NAME}")

# Load the language model (4-bit quantized for memory efficiency)
print("\nLoading language model...")
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
tokenizer.padding_side = "right"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model.eval()
print(f"Language model loaded: {model.num_parameters():,} parameters")

# Load embedding model for similarity computation
print("\nLoading embedding model...")
embed_model = SentenceTransformer('all-MiniLM-L6-v2')
print("Embedding model loaded.")

Device: cuda
Model: mistralai/Mistral-7B-Instruct-v0.2

Loading language model...


Loading weights:   0%|          | 0/291 [00:00<?, ?it/s]

Language model loaded: 7,241,732,096 parameters

Loading embedding model...




Loading weights:   0%|          | 0/103 [00:00<?, ?it/s]

BertModel LOAD REPORT from: sentence-transformers/all-MiniLM-L6-v2
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


Embedding model loaded.


---

## Step 2: Understanding a Data Point from MS MARCO

MS MARCO (Microsoft Machine Reading Comprehension) is a large-scale dataset for question answering. Each data point contains:

- **Passage**: A document (typically a paragraph from the web)
- **Query**: A natural language question from a real user
- **Answer**: The correct answer to the question

Let's load the dataset and examine one example in detail.

In [4]:
# Load MS MARCO validation set
print("Loading MS MARCO dataset...")
dataset = load_dataset("ms_marco", "v1.1", split="validation")
print(f"Dataset loaded: {len(dataset)} samples")

Loading MS MARCO dataset...


Using the latest cached version of the dataset since ms_marco couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'v1.1' at /home/petergrabowski_google_com/.cache/huggingface/datasets/ms_marco/v1.1/0.0.0/a47ee7aae8d7d466ba15f9f0bfac3b3681087b3a (last modified on Mon Jan 26 15:46:33 2026).


Dataset loaded: 10047 samples


In [5]:
def find_good_example(dataset, min_words=50, max_words=200):
    """
    Find an example with:
    - A passage of reasonable length
    - A clear answer
    - An interesting query
    """
    for item in dataset:
        passages = item.get('passages', {})
        passage_texts = passages.get('passage_text', [])
        is_selected = passages.get('is_selected', [])
        
        query = item.get('query', '')
        answers = item.get('answers', [])
        well_formed = item.get('wellFormedAnswers', [])
        
        if not passage_texts or not query:
            continue
        
        # Get best answer (prefer well-formed answers)
        if well_formed and len(well_formed) > 0 and well_formed[0] != '[]':
            answer = well_formed[0]
        elif answers and len(answers) > 0 and answers[0] != 'No Answer Present.':
            answer = answers[0]
        else:
            continue
        
        # Find the selected passage with reasonable length
        for i, passage in enumerate(passage_texts):
            word_count = len(passage.split())
            if min_words <= word_count <= max_words:
                if is_selected and i < len(is_selected) and is_selected[i] == 1:
                    return {
                        'passage': passage,
                        'query': query,
                        'answer': answer
                    }
    return None

# Find our example
example = find_good_example(dataset)
print("Found example!")

Found example!


In [6]:
# Display the example in a clear format
print("="*80)
print("EXAMPLE DATA POINT FROM MS MARCO")
print("="*80)

print("\n" + "-"*40)
print("PASSAGE (the document)")
print("-"*40)
print(example['passage'])

print("\n" + "-"*40)
print("QUERY (what the user is asking)")
print("-"*40)
print(example['query'])

print("\n" + "-"*40)
print("ANSWER (the correct response)")
print("-"*40)
print(example['answer'])

print("\n" + "="*80)
print(f"Passage length: {len(example['passage'].split())} words")
print(f"Query length: {len(example['query'].split())} words")
print(f"Answer length: {len(example['answer'].split())} words")

EXAMPLE DATA POINT FROM MS MARCO

----------------------------------------
PASSAGE (the document)
----------------------------------------
The average Walgreens salary ranges from approximately $15,000 per year for Customer Service Associate / Cashier to $179,900 per year for District Manager. Average Walgreens hourly pay ranges from approximately $7.35 per hour for Laboratory Technician to $68.90 per hour for Pharmacy Manager. Salary information comes from 7,810 data points collected directly from employees, users, and jobs on Indeed.

----------------------------------------
QUERY (what the user is asking)
----------------------------------------
walgreens store sales average

----------------------------------------
ANSWER (the correct response)
----------------------------------------
Approximately $15,000 per year.

Passage length: 59 words
Query length: 4 words
Answer length: 4 words


---

## Step 3: Static vs. Generated Surrogates

Now we define the two types of surrogate queries:

### Static Surrogates
These are **5 fixed queries** that cover the main intent categories. They're the same for every document:

| # | Query | Intent Covered |
|---|-------|----------------|
| 1 | "What is this and what does it mean?" | Definitional |
| 2 | "How do I do this step by step?" | Procedural |
| 3 | "How much does this cost or how long does it take?" | Quantitative |
| 4 | "What are the key facts I need to know?" | Factual |
| 5 | "What problem does this solve?" | Problem/Solution |

### Generated Surrogates
These are **5 document-specific queries** generated by the LLM:

| # | Type | Description |
|---|------|-------------|
| 1 | Target Question | The ideal question this doc answers |
| 2 | Keyword Query | How users actually search (no grammar) |
| 3 | Symptom Query | The problem/symptom leading to this doc |
| 4 | Misconception Query | Concerns or "what NOT to do" questions |
| 5 | Messy Query | Real-world typing: abbreviations, urgency |

In [7]:
# Define the 5 static surrogate queries
STATIC_SURROGATES = {
    'definitional': {
        'query': 'What is this and what does it mean?',
        'covers': 'what is, define, meaning, explanation',
    },
    'procedural': {
        'query': 'How do I do this step by step?',
        'covers': 'how to, instructions, guide',
    },
    'quantitative': {
        'query': 'How much does this cost or how long does it take?',
        'covers': 'how much, how many, cost, duration',
    },
    'factual': {
        'query': 'What are the key facts I need to know?',
        'covers': 'who, when, where, facts',
    },
    'problem': {
        'query': 'What problem does this solve?',
        'covers': 'why, troubleshooting, help',
    },
}

print("STATIC SURROGATES (same for every document)")
print("="*60)
for name, info in STATIC_SURROGATES.items():
    print(f"\n{name.upper()}")
    print(f"  Query: \"{info['query']}\"")
    print(f"  Covers: {info['covers']}")

STATIC SURROGATES (same for every document)

DEFINITIONAL
  Query: "What is this and what does it mean?"
  Covers: what is, define, meaning, explanation

PROCEDURAL
  Query: "How do I do this step by step?"
  Covers: how to, instructions, guide

QUANTITATIVE
  Query: "How much does this cost or how long does it take?"
  Covers: how much, how many, cost, duration

FACTUAL
  Query: "What are the key facts I need to know?"
  Covers: who, when, where, facts

PROBLEM
  Query: "What problem does this solve?"
  Covers: why, troubleshooting, help


In [8]:
# Define templates for generating document-specific surrogates
GENERATED_TEMPLATES = {
    'target_question': {
        'name': 'Target Question',
        'prompt': (
            "You are helping index a document for search. Write the single most likely "
            "natural language question that a user would ask that this document perfectly answers. "
            "The question should be grammatically correct, clear, and specific. "
            "Output only the question (5-12 words), nothing else.\n\n"
            "Document:"
        ),
    },
    'keyword_query': {
        'name': 'Keyword Query',
        'prompt': (
            "You are helping index a document for search. Write a search query the way "
            "real users type into Google: just keywords, no complete sentences, no question marks. "
            "Think of someone quickly typing a few relevant words. "
            "Output only the keyword query (3-6 words), nothing else.\n\n"
            "Document:"
        ),
    },
    'symptom_scenario': {
        'name': 'Symptom Query',
        'prompt': (
            "You are helping index a document for search. This document contains a solution or answer. "
            "Write a query that describes the PROBLEM or SYMPTOM that would lead someone to need this document. "
            "Focus on what the user is experiencing, not what they want to learn. "
            "Output only the problem-focused query (4-10 words), nothing else.\n\n"
            "Document:"
        ),
    },
    'misconception_negative': {
        'name': 'Misconception Query',
        'prompt': (
            "You are helping index a document for search. Write a query that reflects "
            "a common misconception, concern, or 'what NOT to do' question related to this topic. "
            "Think of someone who is worried, skeptical, or wants to avoid mistakes. "
            "Output only the concern/negative query (4-10 words), nothing else.\n\n"
            "Document:"
        ),
    },
    'messy_realworld': {
        'name': 'Messy Query',
        'prompt': (
            "You are helping index a document for search. Write a messy, realistic search query "
            "like someone would actually type in a hurry: use common abbreviations, "
            "internet slang, or urgent language (help, asap, need, plz). "
            "Output only the messy query (3-8 words), nothing else.\n\n"
            "Document:"
        ),
    },
}

print("GENERATED SURROGATE TEMPLATES (document-specific)")
print("="*60)
for name, info in GENERATED_TEMPLATES.items():
    print(f"\n{name.upper()}: {info['name']}")

GENERATED SURROGATE TEMPLATES (document-specific)

TARGET_QUESTION: Target Question

KEYWORD_QUERY: Keyword Query

SYMPTOM_SCENARIO: Symptom Query

MISCONCEPTION_NEGATIVE: Misconception Query

MESSY_REALWORLD: Messy Query


In [9]:
def generate_surrogate(doc_text, template_prompt, max_tokens=45):
    """
    Generate a single surrogate query for a document using a template.
    """
    messages = [
        {"role": "user", "content": f"{template_prompt}\n\nText:\n{doc_text}"}
    ]
    
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.3,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    
    generated_tokens = outputs[0][inputs['input_ids'].shape[1]:]
    surrogate = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
    surrogate = surrogate.strip('"\'')
    surrogate = surrogate.split('\n')[0].strip()  # Take first line only
    
    return surrogate

# Generate all 5 surrogates for our example document
print("Generating document-specific surrogates...")
print("(This uses the LLM to create queries tailored to this specific passage)")
print()

generated_surrogates = {}
for key, template in GENERATED_TEMPLATES.items():
    print(f"Generating {template['name']}...", end=" ")
    generated_surrogates[key] = generate_surrogate(
        example['passage'], 
        template['prompt']
    )
    print("Done!")

print("\nAll surrogates generated!")

Generating document-specific surrogates...
(This uses the LLM to create queries tailored to this specific passage)

Generating Target Question... Done!
Generating Keyword Query... Done!
Generating Symptom Query... Done!
Generating Misconception Query... Done!
Generating Messy Query... Done!

All surrogates generated!


In [10]:
# Display both static and generated surrogates side by side
print("="*80)
print("COMPARISON: STATIC vs GENERATED SURROGATES")
print("="*80)

print(f"\nActual user query: \"{example['query']}\"")

print("\n" + "-"*40)
print("STATIC SURROGATES (identical for all docs)")
print("-"*40)
for name, info in STATIC_SURROGATES.items():
    print(f"  {name:<15}: \"{info['query']}\"")

print("\n" + "-"*40)
print("GENERATED SURROGATES (specific to this document)")
print("-"*40)
for key, surrogate in generated_surrogates.items():
    name = GENERATED_TEMPLATES[key]['name']
    print(f"  {name:<20}: \"{surrogate}\"")

COMPARISON: STATIC vs GENERATED SURROGATES

Actual user query: "walgreens store sales average"

----------------------------------------
STATIC SURROGATES (identical for all docs)
----------------------------------------
  definitional   : "What is this and what does it mean?"
  procedural     : "How do I do this step by step?"
  quantitative   : "How much does this cost or how long does it take?"
  factual        : "What are the key facts I need to know?"
  problem        : "What problem does this solve?"

----------------------------------------
GENERATED SURROGATES (specific to this document)
----------------------------------------
  Target Question     : "What is the salary range for different positions at Walgreens, based on 7,810 data points?"
  Keyword Query       : "Walgreens salary ranges"
  Symptom Query       : "What is the average Walgreens salary or hourly pay for specific positions?"
  Misconception Query : "What not to earn working at Walgreens: min. $15,000 or max. $17

In [11]:
# Compute semantic similarity between each surrogate and the actual query
def compute_similarity(text1, text2):
    """Compute cosine similarity between two texts using embeddings."""
    embeddings = embed_model.encode([text1, text2])
    return float(cosine_similarity([embeddings[0]], [embeddings[1]])[0][0])

print("="*80)
print("SEMANTIC SIMILARITY TO ACTUAL QUERY")
print("="*80)
print(f"\nActual query: \"{example['query']}\"")
print("\nHigher similarity = surrogate is more semantically related to the actual query")

print("\n" + "-"*40)
print("STATIC SURROGATES")
print("-"*40)
static_similarities = {}
for name, info in STATIC_SURROGATES.items():
    sim = compute_similarity(info['query'], example['query'])
    static_similarities[name] = sim
    print(f"  {name:<15}: {sim:.4f}")

print("\n" + "-"*40)
print("GENERATED SURROGATES")
print("-"*40)
generated_similarities = {}
for key, surrogate in generated_surrogates.items():
    sim = compute_similarity(surrogate, example['query'])
    generated_similarities[key] = sim
    name = GENERATED_TEMPLATES[key]['name']
    print(f"  {name:<20}: {sim:.4f}")

# Find best matches
best_static = max(static_similarities.items(), key=lambda x: x[1])
best_generated = max(generated_similarities.items(), key=lambda x: x[1])

print("\n" + "="*80)
print(f"Best static match: {best_static[0]} (sim={best_static[1]:.4f})")
print(f"Best generated match: {GENERATED_TEMPLATES[best_generated[0]]['name']} (sim={best_generated[1]:.4f})")
print("="*80)

SEMANTIC SIMILARITY TO ACTUAL QUERY

Actual query: "walgreens store sales average"

Higher similarity = surrogate is more semantically related to the actual query

----------------------------------------
STATIC SURROGATES
----------------------------------------
  definitional   : -0.0774
  procedural     : 0.0368
  quantitative   : 0.0101
  factual        : -0.0916
  problem        : 0.0027

----------------------------------------
GENERATED SURROGATES
----------------------------------------
  Target Question     : 0.5695
  Keyword Query       : 0.7552
  Symptom Query       : 0.6214
  Misconception Query : 0.5299
  Messy Query         : 0.4897

Best static match: procedural (sim=0.0368)
Best generated match: Keyword Query (sim=0.7552)


---

## Step 4: Building KV Caches

Now we'll see how the KV caches are constructed. A **KV (Key-Value) cache** stores the intermediate computations from the transformer's attention mechanism. When we process a context once, we can reuse these computations for subsequent queries.

### Cache Templates

```
BASELINE CACHE:
┌─────────────────────────────────┐
│ Document:                       │
│ {passage text}                  │
└─────────────────────────────────┘

SURROGATE-PRIMED CACHE:
┌─────────────────────────────────────────────────┐
│ This document may be relevant to queries like:  │
│ {surrogate query}                               │
│                                                 │
│ Document:                                       │
│ {passage text}                                  │
└─────────────────────────────────────────────────┘
```

The hypothesis is that the surrogate-primed cache will be better prepared to answer queries similar to the surrogate.

In [12]:
# Define the cache templates
BASELINE_TEMPLATE = "Document:\n{document}"

SURROGATE_TEMPLATE = (
    "This document may be relevant to queries like: {surrogate}\n\n"
    "Document:\n{document}"
)

QUERY_TEMPLATE = "\n\nQuery: {query}\n\nAnswer:"

print("CACHE TEMPLATES")
print("="*60)

print("\nBASELINE TEMPLATE:")
print("-"*40)
print(BASELINE_TEMPLATE.format(document="[passage text]"))

print("\nSURROGATE-PRIMED TEMPLATE:")
print("-"*40)
print(SURROGATE_TEMPLATE.format(surrogate="[surrogate query]", document="[passage text]"))

print("\nQUERY TEMPLATE (appended at query time):")
print("-"*40)
print(QUERY_TEMPLATE.format(query="[user query]"))

CACHE TEMPLATES

BASELINE TEMPLATE:
----------------------------------------
Document:
[passage text]

SURROGATE-PRIMED TEMPLATE:
----------------------------------------
This document may be relevant to queries like: [surrogate query]

Document:
[passage text]

QUERY TEMPLATE (appended at query time):
----------------------------------------


Query: [user query]

Answer:


In [13]:
# Show the actual contexts for our example
baseline_context = BASELINE_TEMPLATE.format(document=example['passage'])

# Pick the best-matching surrogate for demonstration
best_static_query = STATIC_SURROGATES[best_static[0]]['query']
best_generated_query = generated_surrogates[best_generated[0]]

static_context = SURROGATE_TEMPLATE.format(
    surrogate=best_static_query, 
    document=example['passage']
)

generated_context = SURROGATE_TEMPLATE.format(
    surrogate=best_generated_query,
    document=example['passage']
)

print("="*80)
print("ACTUAL CACHE CONTEXTS FOR OUR EXAMPLE")
print("="*80)

print("\n" + "-"*40)
print("BASELINE CONTEXT")
print("-"*40)
print(baseline_context[:500] + "..." if len(baseline_context) > 500 else baseline_context)

print("\n" + "-"*40)
print(f"STATIC SURROGATE CONTEXT (using '{best_static[0]}')")
print("-"*40)
print(static_context[:600] + "..." if len(static_context) > 600 else static_context)

print("\n" + "-"*40)
print(f"GENERATED SURROGATE CONTEXT (using '{GENERATED_TEMPLATES[best_generated[0]]['name']}')")
print("-"*40)
print(generated_context[:600] + "..." if len(generated_context) > 600 else generated_context)

ACTUAL CACHE CONTEXTS FOR OUR EXAMPLE

----------------------------------------
BASELINE CONTEXT
----------------------------------------
Document:
The average Walgreens salary ranges from approximately $15,000 per year for Customer Service Associate / Cashier to $179,900 per year for District Manager. Average Walgreens hourly pay ranges from approximately $7.35 per hour for Laboratory Technician to $68.90 per hour for Pharmacy Manager. Salary information comes from 7,810 data points collected directly from employees, users, and jobs on Indeed.

----------------------------------------
STATIC SURROGATE CONTEXT (using 'procedural')
----------------------------------------
This document may be relevant to queries like: How do I do this step by step?

Document:
The average Walgreens salary ranges from approximately $15,000 per year for Customer Service Associate / Cashier to $179,900 per year for District Manager. Average Walgreens hourly pay ranges from approximately $7.35 per hour for L

In [14]:
def build_kv_cache(context):
    """
    Build a KV cache from the given context.
    
    This runs a forward pass through the model, storing the key and value
    tensors from each attention layer. These can be reused for subsequent
    queries, avoiding redundant computation.
    
    Returns:
        context_length: Number of tokens in the context
        past_key_values: The cached key-value pairs for all layers
    """
    # Tokenize the context
    context_encoding = tokenizer(
        context, 
        return_tensors="pt", 
        add_special_tokens=True,
        padding=False, 
        truncation=False
    )
    context_ids = context_encoding['input_ids'].to(DEVICE)
    
    # Forward pass to build cache
    with torch.no_grad():
        outputs = model(
            input_ids=context_ids,
            attention_mask=torch.ones_like(context_ids),
            use_cache=True,  # This tells the model to return the KV cache
            return_dict=True
        )
    
    return context_ids.shape[1], outputs.past_key_values

print("Function defined: build_kv_cache()")
print("\nThis function:")
print("  1. Tokenizes the context")
print("  2. Runs a forward pass through the model")
print("  3. Returns the KV cache (stored attention states)")

Function defined: build_kv_cache()

This function:
  1. Tokenizes the context
  2. Runs a forward pass through the model
  3. Returns the KV cache (stored attention states)


In [15]:
# Build the caches for our example
print("Building KV caches...")
print()

print("Building baseline cache...", end=" ")
baseline_len, baseline_cache = build_kv_cache(baseline_context)
print(f"Done! ({baseline_len} tokens)")

print("Building static surrogate cache...", end=" ")
static_len, static_cache = build_kv_cache(static_context)
print(f"Done! ({static_len} tokens)")

print("Building generated surrogate cache...", end=" ")
generated_len, generated_cache = build_kv_cache(generated_context)
print(f"Done! ({generated_len} tokens)")

print("\n" + "="*60)
print("CACHE SUMMARY")
print("="*60)
print(f"Baseline cache:          {baseline_len:>4} tokens")
print(f"Static surrogate cache:  {static_len:>4} tokens")
print(f"Generated surrogate cache: {generated_len:>4} tokens")
print(f"\nExtra tokens from surrogate priming: ~{static_len - baseline_len} tokens")

Building KV caches...

Building baseline cache... Done! (107 tokens)
Building static surrogate cache... Done! (127 tokens)
Building generated surrogate cache... Done! (123 tokens)

CACHE SUMMARY
Baseline cache:           107 tokens
Static surrogate cache:   127 tokens
Generated surrogate cache:  123 tokens

Extra tokens from surrogate priming: ~20 tokens


In [19]:
# Explain the structure of the KV cache
print("="*60)
print("UNDERSTANDING THE KV CACHE STRUCTURE")
print("="*60)

# Check what type of cache we have and how to access it
cache_type = type(baseline_cache).__name__
print(f"\nCache type: {cache_type}")

# Try to access keys/values based on the cache type
try:
    # Try DynamicCache style access first
    if hasattr(baseline_cache, 'key_cache') and baseline_cache.key_cache:
        num_layers = len(baseline_cache.key_cache)
        layer_0_key = baseline_cache.key_cache[0]
        layer_0_value = baseline_cache.value_cache[0]
        print(f"The cache contains {num_layers} layers (one per transformer layer).")
    elif hasattr(baseline_cache, 'get_seq_length'):
        # DynamicCache with different structure - iterate to get first layer
        # Access via to_legacy_cache() if available
        if hasattr(baseline_cache, 'to_legacy_cache'):
            legacy_cache = baseline_cache.to_legacy_cache()
            num_layers = len(legacy_cache)
            layer_0_key = legacy_cache[0][0]
            layer_0_value = legacy_cache[0][1]
        else:
            # Get the internal _data attribute or similar
            print(f"Cache attributes: {[a for a in dir(baseline_cache) if not a.startswith('_')]}")
            raise AttributeError("Cannot access cache internals")
        print(f"The cache contains {num_layers} layers (one per transformer layer).")
    else:
        # Tuple-style cache
        num_layers = len(baseline_cache)
        layer_0_key = baseline_cache[0][0]
        layer_0_value = baseline_cache[0][1]
        print(f"The cache is a tuple of {num_layers} layers.")
    
    print("Each layer stores key and value tensors for attention.")
    print(f"\nLayer 0 key shape:   {list(layer_0_key.shape)}")
    print(f"Layer 0 value shape: {list(layer_0_value.shape)}")

    print(f"\nDimensions:")
    print(f"  - Batch size: {layer_0_key.shape[0]}")
    print(f"  - Num attention heads: {layer_0_key.shape[1]}")
    print(f"  - Sequence length: {layer_0_key.shape[2]} (= context tokens)")
    print(f"  - Head dimension: {layer_0_key.shape[3]}")

except Exception as e:
    print(f"\nNote: Could not inspect cache internals ({e})")
    print(f"The cache has {len(baseline_cache)} layers.")
    print("Each layer contains key and value tensors for the attention mechanism.")

print("\n" + "-"*60)
print("KEY INSIGHT:")
print("-"*60)
print("The KV cache stores the attention states for every token in the context.")
print("When we add a query, we only need to compute attention for the new tokens,")
print("while reusing the cached states for the context. This is the efficiency gain.")

UNDERSTANDING THE KV CACHE STRUCTURE

Cache type: DynamicCache
Cache attributes: ['batch_repeat_interleave', 'batch_select_indices', 'crop', 'early_initialization', 'get_mask_sizes', 'get_max_cache_shape', 'get_seq_length', 'is_compileable', 'is_initialized', 'is_sliding', 'layer_class_to_replicate', 'layers', 'max_batch_size', 'max_cache_len', 'offload', 'offloading', 'prefetch', 'reorder_cache', 'reset', 'update']

Note: Could not inspect cache internals (Cannot access cache internals)
The cache has 32 layers.
Each layer contains key and value tensors for the attention mechanism.

------------------------------------------------------------
KEY INSIGHT:
------------------------------------------------------------
The KV cache stores the attention states for every token in the context.
When we add a query, we only need to compute attention for the new tokens,
while reusing the cached states for the context. This is the efficiency gain.


---

## Step 5: Running Generation/Scoring with the Cache

Now we'll see how to use the KV cache to score an answer. The key metric is **Negative Log-Likelihood (NLL)**:

$$\text{NLL} = -\frac{1}{N} \sum_{i=1}^{N} \log P(t_i | t_1, ..., t_{i-1}, \text{context})$$

Where:
- $N$ is the number of tokens in the answer
- $P(t_i | ...)$ is the probability the model assigns to token $t_i$ given all previous tokens

**Lower NLL = Model is more confident in the answer = Better cache quality**

### The Scoring Process

```
1. Start with pre-built cache (context already processed)
2. Append the query to the cache
3. For each token in the answer:
   - Get model's probability for that token
   - Compute log probability
4. Average the negative log probabilities
```

In [20]:
def score_answer_with_cache(past_key_values, context_len, query_prompt, answer):
    """
    Score an answer using a pre-built KV cache.
    
    The process:
    1. Extend the cache by processing the query
    2. Compute the probability of each token in the answer
    3. Return the mean negative log-likelihood
    
    Lower NLL = model is more confident = better cache priming
    
    Args:
        past_key_values: Pre-built KV cache from context
        context_len: Number of tokens in the cached context
        query_prompt: The query to append (formatted)
        answer: The answer to score
    
    Returns:
        Mean NLL for the answer tokens
    """
    # Step 1: Tokenize the query
    query_encoding = tokenizer(
        query_prompt, 
        return_tensors="pt", 
        add_special_tokens=False,
        padding=False, 
        truncation=False
    )
    query_ids = query_encoding['input_ids'].to(DEVICE)
    query_len = query_ids.shape[1]
    
    # Step 2: Tokenize the answer
    answer_encoding = tokenizer(
        answer, 
        return_tensors="pt", 
        add_special_tokens=False,
        padding=False, 
        truncation=False
    )
    answer_ids = answer_encoding['input_ids'].to(DEVICE)
    answer_len = answer_ids.shape[1]
    
    # Step 3: Extend cache with query
    combined_len = context_len + query_len
    attention_mask = torch.ones((1, combined_len), device=DEVICE)
    
    with torch.no_grad():
        query_outputs = model(
            input_ids=query_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,  # Reuse the cache!
            use_cache=True,
            return_dict=True
        )
        extended_cache = query_outputs.past_key_values
    
    # Step 4: Score the answer
    combined_len_final = context_len + query_len + answer_len
    attention_mask_final = torch.ones((1, combined_len_final), device=DEVICE)
    
    with torch.no_grad():
        answer_outputs = model(
            input_ids=answer_ids,
            attention_mask=attention_mask_final,
            past_key_values=extended_cache,
            use_cache=False,
            return_dict=True
        )
    
    # Step 5: Compute NLL
    # The logits predict the NEXT token, so we shift
    logits = answer_outputs.logits  # Shape: [1, answer_len, vocab_size]
    shift_logits = logits[:, :-1, :].contiguous()  # All but last
    shift_labels = answer_ids[:, 1:].contiguous()  # All but first
    
    # Flatten for cross-entropy
    shift_logits = shift_logits.view(-1, shift_logits.size(-1))
    shift_labels = shift_labels.view(-1)
    
    # Compute cross-entropy loss (= NLL)
    loss_fct = torch.nn.CrossEntropyLoss(reduction='sum')
    nll = loss_fct(shift_logits, shift_labels).item()
    
    # Average over number of tokens scored
    num_scored = answer_len - 1
    return nll / num_scored if num_scored > 0 else 0.0

print("Function defined: score_answer_with_cache()")

Function defined: score_answer_with_cache()


In [21]:
# Score the answer with each cache
query_prompt = QUERY_TEMPLATE.format(query=example['query'])

print("="*60)
print("SCORING THE ANSWER WITH DIFFERENT CACHES")
print("="*60)

print(f"\nQuery: \"{example['query']}\"")
print(f"Answer: \"{example['answer']}\"")

print("\nScoring with each cache...")

print("\nScoring with baseline cache...", end=" ")
baseline_nll = score_answer_with_cache(
    baseline_cache, baseline_len, query_prompt, example['answer']
)
print(f"NLL = {baseline_nll:.4f}")

print("Scoring with static surrogate cache...", end=" ")
static_nll = score_answer_with_cache(
    static_cache, static_len, query_prompt, example['answer']
)
print(f"NLL = {static_nll:.4f}")

print("Scoring with generated surrogate cache...", end=" ")
generated_nll = score_answer_with_cache(
    generated_cache, generated_len, query_prompt, example['answer']
)
print(f"NLL = {generated_nll:.4f}")

SCORING THE ANSWER WITH DIFFERENT CACHES

Query: "walgreens store sales average"
Answer: "Approximately $15,000 per year."

Scoring with each cache...

Scoring with baseline cache... NLL = 2.9583
Scoring with static surrogate cache... NLL = 2.4583
Scoring with generated surrogate cache... NLL = 2.7292


In [22]:
# Display results with interpretation
print("="*80)
print("RESULTS SUMMARY")
print("="*80)

print("\n" + "-"*40)
print("ANSWER NLL BY CACHE TYPE (lower = better)")
print("-"*40)
print(f"{'Cache Type':<30} {'NLL':>10} {'vs Baseline':>15}")
print("-"*55)
print(f"{'Baseline (document only)':<30} {baseline_nll:>10.4f} {'-':>15}")
print(f"{'Static Surrogate':<30} {static_nll:>10.4f} {baseline_nll - static_nll:>+15.4f}")
print(f"{'Generated Surrogate':<30} {generated_nll:>10.4f} {baseline_nll - generated_nll:>+15.4f}")

print("\n" + "-"*40)
print("INTERPRETATION")
print("-"*40)

# Determine winner
nlls = {
    'Baseline': baseline_nll,
    'Static Surrogate': static_nll,
    'Generated Surrogate': generated_nll
}
winner = min(nlls.items(), key=lambda x: x[1])

print(f"\nBest cache for this example: {winner[0]} (NLL={winner[1]:.4f})")

static_delta = baseline_nll - static_nll
generated_delta = baseline_nll - generated_nll

if static_delta > 0:
    print(f"\nStatic surrogate IMPROVES over baseline by {static_delta:.4f} NLL")
else:
    print(f"\nStatic surrogate WORSE than baseline by {-static_delta:.4f} NLL")

if generated_delta > 0:
    print(f"Generated surrogate IMPROVES over baseline by {generated_delta:.4f} NLL")
else:
    print(f"Generated surrogate WORSE than baseline by {-generated_delta:.4f} NLL")

if generated_nll < static_nll:
    print(f"\nGenerated beats Static by {static_nll - generated_nll:.4f} NLL")
else:
    print(f"\nStatic beats Generated by {generated_nll - static_nll:.4f} NLL")

RESULTS SUMMARY

----------------------------------------
ANSWER NLL BY CACHE TYPE (lower = better)
----------------------------------------
Cache Type                            NLL     vs Baseline
-------------------------------------------------------
Baseline (document only)           2.9583               -
Static Surrogate                   2.4583         +0.5000
Generated Surrogate                2.7292         +0.2292

----------------------------------------
INTERPRETATION
----------------------------------------

Best cache for this example: Static Surrogate (NLL=2.4583)

Static surrogate IMPROVES over baseline by 0.5000 NLL
Generated surrogate IMPROVES over baseline by 0.2292 NLL

Static beats Generated by 0.2708 NLL


---

## Step 6: Understanding the Evaluation Metric (NLL)

Let's dig deeper into what NLL means and why we use it.

### What is Negative Log-Likelihood?

For each token in the answer, the model outputs a probability distribution over all possible next tokens. NLL measures how "surprised" the model is by the correct answer:

- **Low NLL**: Model assigns high probability to the correct tokens → Confident, good prediction
- **High NLL**: Model assigns low probability to the correct tokens → Uncertain, poor prediction

### Why NLL Instead of Generated Text?

1. **Deterministic**: Same input always gives same score (no sampling randomness)
2. **Fine-grained**: Captures subtle differences in model confidence
3. **Fast**: No need to generate full responses
4. **Comparable**: Easy to compare across conditions statistically

In [23]:
# Let's look at token-by-token probabilities for deeper understanding
def get_token_probabilities(past_key_values, context_len, query_prompt, answer):
    """
    Get per-token probabilities for the answer.
    Returns list of (token, probability, log_prob) tuples.
    """
    # Tokenize
    query_encoding = tokenizer(query_prompt, return_tensors="pt", add_special_tokens=False)
    query_ids = query_encoding['input_ids'].to(DEVICE)
    query_len = query_ids.shape[1]
    
    answer_encoding = tokenizer(answer, return_tensors="pt", add_special_tokens=False)
    answer_ids = answer_encoding['input_ids'].to(DEVICE)
    answer_len = answer_ids.shape[1]
    
    # Extend cache with query
    combined_len = context_len + query_len
    attention_mask = torch.ones((1, combined_len), device=DEVICE)
    
    with torch.no_grad():
        query_outputs = model(
            input_ids=query_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
            return_dict=True
        )
        extended_cache = query_outputs.past_key_values
    
    # Score answer
    combined_len_final = context_len + query_len + answer_len
    attention_mask_final = torch.ones((1, combined_len_final), device=DEVICE)
    
    with torch.no_grad():
        answer_outputs = model(
            input_ids=answer_ids,
            attention_mask=attention_mask_final,
            past_key_values=extended_cache,
            use_cache=False,
            return_dict=True
        )
    
    # Get per-token probabilities
    logits = answer_outputs.logits
    probs = torch.softmax(logits, dim=-1)
    
    token_info = []
    for i in range(answer_len - 1):
        next_token_id = answer_ids[0, i + 1].item()
        next_token = tokenizer.decode([next_token_id])
        prob = probs[0, i, next_token_id].item()
        log_prob = np.log(prob + 1e-10)
        token_info.append((next_token, prob, -log_prob))
    
    return token_info

print("Getting token-level probabilities for deeper analysis...")

Getting token-level probabilities for deeper analysis...


In [24]:
# Compare token probabilities between caches
print("="*80)
print("TOKEN-BY-TOKEN PROBABILITY ANALYSIS")
print("="*80)

print(f"\nAnswer being scored: \"{example['answer']}\"")

baseline_tokens = get_token_probabilities(
    baseline_cache, baseline_len, query_prompt, example['answer']
)
generated_tokens = get_token_probabilities(
    generated_cache, generated_len, query_prompt, example['answer']
)

print("\n" + "-"*70)
print(f"{'Token':<15} {'Baseline Prob':>15} {'Generated Prob':>15} {'Winner':>12}")
print("-"*70)

for (tok1, prob1, nll1), (tok2, prob2, nll2) in zip(baseline_tokens, generated_tokens):
    winner = "Generated" if prob2 > prob1 else "Baseline" if prob1 > prob2 else "Tie"
    tok_display = repr(tok1) if len(tok1.strip()) > 0 else "' '"
    print(f"{tok_display:<15} {prob1:>15.4f} {prob2:>15.4f} {winner:>12}")

print("-"*70)
avg_baseline = np.mean([x[1] for x in baseline_tokens])
avg_generated = np.mean([x[1] for x in generated_tokens])
print(f"{'AVERAGE':<15} {avg_baseline:>15.4f} {avg_generated:>15.4f}")

TOKEN-BY-TOKEN PROBABILITY ANALYSIS

Answer being scored: "Approximately $15,000 per year."

----------------------------------------------------------------------
Token             Baseline Prob  Generated Prob       Winner
----------------------------------------------------------------------
'xim'                    0.0140          0.1338    Generated
'ately'                  0.0008          0.0002     Baseline
'$'                      0.0030          0.0065    Generated
'1'                      0.9414          0.3926     Baseline
'5'                      0.0933          0.0369     Baseline
','                      0.3242          0.2129     Baseline
'0'                      0.0064          0.0157    Generated
'0'                      0.0306          0.0374    Generated
'0'                      0.1865          0.1348     Baseline
'per'                    0.0012          0.0004     Baseline
'year'                   0.9922          0.9883     Baseline
'.'                      0.0000  

In [25]:
# Visual explanation of NLL
print("="*80)
print("HOW NLL IS COMPUTED")
print("="*80)

print("""
For each token in the answer, we:
1. Get the model's predicted probability for that token
2. Take the negative log: -log(probability)
3. Average across all tokens

Example calculation (baseline cache):
""")

print(f"{'Token':<15} {'Probability':>12} {'-log(prob)':>12}")
print("-"*42)
total_nll = 0
for i, (tok, prob, nll) in enumerate(baseline_tokens[:5]):  # Show first 5
    tok_display = repr(tok) if len(tok.strip()) > 0 else "' '"
    print(f"{tok_display:<15} {prob:>12.4f} {nll:>12.4f}")
    total_nll += nll

if len(baseline_tokens) > 5:
    print(f"{'...':<15} {'...':>12} {'...':>12}")

mean_nll = np.mean([x[2] for x in baseline_tokens])
print("-"*42)
print(f"Mean NLL = {mean_nll:.4f}")
print(f"(This is the score we use to compare caches)")

HOW NLL IS COMPUTED

For each token in the answer, we:
1. Get the model's predicted probability for that token
2. Take the negative log: -log(probability)
3. Average across all tokens

Example calculation (baseline cache):

Token            Probability   -log(prob)
------------------------------------------
'xim'                 0.0140       4.2703
'ately'               0.0008       7.1391
'$'                   0.0030       5.8072
'1'                   0.9414       0.0604
'5'                   0.0933       2.3723
...                      ...          ...
------------------------------------------
Mean NLL = 4.4645
(This is the score we use to compare caches)


---

## Step 7: The Routing Decision

In practice, we have **5 surrogate caches** for each document (not just 1). At query time, we need to **route** the query to the best-matching cache.

### The Routing Strategy

1. Embed the user's query using a sentence embedding model
2. Embed each of the 5 surrogates
3. Compute cosine similarity between query and each surrogate
4. Route to the cache with the highest similarity

```
User Query: "what temperature should it be to plant grass seeds"
                         │
                         ▼
    ┌────────────────────────────────────────────────┐
    │           Compute Similarity to:               │
    │  Surrogate 1: sim=0.45  ─────────────────────┐ │
    │  Surrogate 2: sim=0.62  ──────────────────┐  │ │
    │  Surrogate 3: sim=0.83  ←── BEST MATCH! ──┼──┼ │
    │  Surrogate 4: sim=0.31  ──────────────────┘  │ │
    │  Surrogate 5: sim=0.55  ─────────────────────┘ │
    └────────────────────────────────────────────────┘
                         │
                         ▼
                 Use Cache #3
```

In [26]:
# Build all 5 static and 5 generated caches
print("="*60)
print("BUILDING ALL SURROGATE CACHES")
print("="*60)

print("\nBuilding 5 static surrogate caches...")
all_static_caches = {}
all_static_lens = {}
for name, info in STATIC_SURROGATES.items():
    context = SURROGATE_TEMPLATE.format(surrogate=info['query'], document=example['passage'])
    cache_len, cache = build_kv_cache(context)
    all_static_caches[name] = cache
    all_static_lens[name] = cache_len
    print(f"  {name}: {cache_len} tokens")

print("\nBuilding 5 generated surrogate caches...")
all_generated_caches = {}
all_generated_lens = {}
for key, surrogate in generated_surrogates.items():
    context = SURROGATE_TEMPLATE.format(surrogate=surrogate, document=example['passage'])
    cache_len, cache = build_kv_cache(context)
    all_generated_caches[key] = cache
    all_generated_lens[key] = cache_len
    print(f"  {GENERATED_TEMPLATES[key]['name']}: {cache_len} tokens")

print("\nTotal caches built: 11 (1 baseline + 5 static + 5 generated)")

BUILDING ALL SURROGATE CACHES

Building 5 static surrogate caches...
  definitional: 127 tokens
  procedural: 127 tokens
  quantitative: 130 tokens
  factual: 128 tokens
  problem: 124 tokens

Building 5 generated surrogate caches...
  Target Question: 142 tokens
  Keyword Query: 123 tokens
  Symptom Query: 134 tokens
  Misconception Query: 162 tokens
  Messy Query: 152 tokens

Total caches built: 11 (1 baseline + 5 static + 5 generated)


In [27]:
# Demonstrate the routing decision
print("="*80)
print("ROUTING DEMONSTRATION")
print("="*80)

print(f"\nUser query: \"{example['query']}\"")

# Static routing
print("\n" + "-"*40)
print("STATIC SURROGATE ROUTING")
print("-"*40)
print(f"{'Surrogate':<15} {'Query':<45} {'Similarity':>10}")
print("-"*70)

static_sims = {}
for name, info in STATIC_SURROGATES.items():
    sim = compute_similarity(info['query'], example['query'])
    static_sims[name] = sim
    query_preview = info['query'][:42] + "..." if len(info['query']) > 45 else info['query']
    marker = " <-- ROUTE HERE" if sim == max(static_sims.values()) else ""
    print(f"{name:<15} \"{query_preview:<43}\" {sim:>10.4f}{marker}")

best_static_key = max(static_sims.keys(), key=lambda k: static_sims[k])
print(f"\nRouted to: {best_static_key} (similarity={static_sims[best_static_key]:.4f})")

# Generated routing
print("\n" + "-"*40)
print("GENERATED SURROGATE ROUTING")
print("-"*40)
print(f"{'Type':<20} {'Surrogate':<40} {'Similarity':>10}")
print("-"*72)

generated_sims = {}
for key, surrogate in generated_surrogates.items():
    sim = compute_similarity(surrogate, example['query'])
    generated_sims[key] = sim
    name = GENERATED_TEMPLATES[key]['name']
    surrogate_preview = surrogate[:37] + "..." if len(surrogate) > 40 else surrogate
    marker = " <--" if sim == max(generated_sims.values()) else ""
    print(f"{name:<20} \"{surrogate_preview:<38}\" {sim:>10.4f}{marker}")

best_generated_key = max(generated_sims.keys(), key=lambda k: generated_sims[k])
print(f"\nRouted to: {GENERATED_TEMPLATES[best_generated_key]['name']} (similarity={generated_sims[best_generated_key]:.4f})")

ROUTING DEMONSTRATION

User query: "walgreens store sales average"

----------------------------------------
STATIC SURROGATE ROUTING
----------------------------------------
Surrogate       Query                                         Similarity
----------------------------------------------------------------------
definitional    "What is this and what does it mean?        "    -0.0774 <-- ROUTE HERE
procedural      "How do I do this step by step?             "     0.0368 <-- ROUTE HERE
quantitative    "How much does this cost or how long does i..."     0.0101
factual         "What are the key facts I need to know?     "    -0.0916
problem         "What problem does this solve?              "     0.0027

Routed to: procedural (similarity=0.0368)

----------------------------------------
GENERATED SURROGATE ROUTING
----------------------------------------
Type                 Surrogate                                Similarity
---------------------------------------------------------

In [28]:
# Score with the routed caches and compare
print("="*80)
print("FINAL COMPARISON: ROUTED CACHES")
print("="*80)

# Score with routed static cache
routed_static_nll = score_answer_with_cache(
    all_static_caches[best_static_key],
    all_static_lens[best_static_key],
    query_prompt,
    example['answer']
)

# Score with routed generated cache
routed_generated_nll = score_answer_with_cache(
    all_generated_caches[best_generated_key],
    all_generated_lens[best_generated_key],
    query_prompt,
    example['answer']
)

print(f"\nQuery: \"{example['query']}\"")
print(f"Answer: \"{example['answer']}\"")

print("\n" + "-"*60)
print("RESULTS")
print("-"*60)
print(f"{'Method':<25} {'Routed To':<25} {'NLL':>10}")
print("-"*60)
print(f"{'Baseline':<25} {'(no routing)':<25} {baseline_nll:>10.4f}")
print(f"{'Static (routed)':<25} {best_static_key:<25} {routed_static_nll:>10.4f}")
print(f"{'Generated (routed)':<25} {GENERATED_TEMPLATES[best_generated_key]['name']:<25} {routed_generated_nll:>10.4f}")

print("\n" + "-"*60)
print("IMPROVEMENT OVER BASELINE")
print("-"*60)
static_improvement = baseline_nll - routed_static_nll
generated_improvement = baseline_nll - routed_generated_nll

print(f"Static:    {static_improvement:+.4f} NLL {'(better)' if static_improvement > 0 else '(worse)'}")
print(f"Generated: {generated_improvement:+.4f} NLL {'(better)' if generated_improvement > 0 else '(worse)'}")

print("\n" + "="*60)
print("WINNER FOR THIS EXAMPLE")
print("="*60)
all_nlls = {
    'Baseline': baseline_nll,
    f'Static ({best_static_key})': routed_static_nll,
    f'Generated ({GENERATED_TEMPLATES[best_generated_key]["name"]})': routed_generated_nll
}
winner = min(all_nlls.items(), key=lambda x: x[1])
print(f"\n{winner[0]} wins with NLL = {winner[1]:.4f}")

FINAL COMPARISON: ROUTED CACHES

Query: "walgreens store sales average"
Answer: "Approximately $15,000 per year."

------------------------------------------------------------
RESULTS
------------------------------------------------------------
Method                    Routed To                        NLL
------------------------------------------------------------
Baseline                  (no routing)                  2.9583
Static (routed)           procedural                    2.4583
Generated (routed)        Keyword Query                 2.7292

------------------------------------------------------------
IMPROVEMENT OVER BASELINE
------------------------------------------------------------
Static:    +0.5000 NLL (better)
Generated: +0.2292 NLL (better)

WINNER FOR THIS EXAMPLE

Static (procedural) wins with NLL = 2.4583


---

## Summary: What We Learned

### The Experimental Setup

1. **Data**: MS MARCO passages with queries and answers
2. **Caches**: 11 per document (1 baseline + 5 static + 5 generated)
3. **Routing**: Use embedding similarity to pick best cache
4. **Metric**: NLL (negative log-likelihood) - lower is better

### Key Components

| Component | Purpose |
|-----------|----------|
| Static Surrogates | 5 fixed intent queries (same for all docs) |
| Generated Surrogates | 5 doc-specific queries (LLM generates) |
| KV Cache | Stores attention states for reuse |
| Routing | Matches query to best surrogate |
| NLL Scoring | Measures answer probability |

### The Research Question

**Is document-specific surrogate generation worth the compute cost?**

- Generated surrogates require an LLM call per document
- Static surrogates are free (just fixed strings)
- If they perform similarly, static is more cost-effective

The full experiment runs this comparison across thousands of samples to get statistically significant results.

In [29]:
print("="*80)
print("EXPERIMENT WALKTHROUGH COMPLETE")
print("="*80)

print("""
You've now seen each step of the experiment:

1. DATA POINT: A passage, query, and answer from MS MARCO
   - The passage is the document we're caching
   - The query is what users ask
   - The answer is the ground truth

2. SURROGATES: Two approaches to "prime" the cache
   - Static: 5 fixed intent queries (cheap, generic)
   - Generated: 5 doc-specific queries (expensive, tailored)

3. KV CACHE: Pre-computed attention states
   - Built once per document (indexing time)
   - Reused for all queries (query time)
   - Surrogate priming adds context before the document

4. SCORING: NLL measures answer quality
   - Lower NLL = model more confident in answer
   - Compare across cache types to measure improvement

5. ROUTING: Pick the best surrogate cache
   - Embed query and surrogates
   - Route to highest-similarity match

The full experiment runs this for 2500 samples to determine:
- Do surrogates help at all?
- Is generated better than static?
- Is the extra compute worth it?
""")

EXPERIMENT WALKTHROUGH COMPLETE

You've now seen each step of the experiment:

1. DATA POINT: A passage, query, and answer from MS MARCO
   - The passage is the document we're caching
   - The query is what users ask
   - The answer is the ground truth

2. SURROGATES: Two approaches to "prime" the cache
   - Static: 5 fixed intent queries (cheap, generic)
   - Generated: 5 doc-specific queries (expensive, tailored)

3. KV CACHE: Pre-computed attention states
   - Built once per document (indexing time)
   - Reused for all queries (query time)
   - Surrogate priming adds context before the document

4. SCORING: NLL measures answer quality
   - Lower NLL = model more confident in answer
   - Compare across cache types to measure improvement

5. ROUTING: Pick the best surrogate cache
   - Embed query and surrogates
   - Route to highest-similarity match

The full experiment runs this for 2500 samples to determine:
- Do surrogates help at all?
- Is generated better than static?
- Is the ex