# Surrogate-Primed KV Caching Experiment

## Hypothesis
Prepending a domain-specific "surrogate query" to a document before caching will produce KV states that are more receptive to the actual user query, compared to a generic system prompt.

## Metric
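We measure the **Conditional Perplexity Delta**, where `NLL` is the mean per-token negative log-likelihood (i.e., log-perplexity) of the ground-truth query: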
- `Delta = NLL(Query | Baseline Cache) - NLL(Query | Surrogate Cache)`
- **Positive Delta** = Surrogate reduced surprise (success)
- **Negative Delta** = Surrogate confused the model (failure)

## Step 1: Configuration & Setup

In [None]:
# Install dependencies if needed
# !pip install transformers torch datasets tqdm scipy bitsandbytes accelerate matplotlib

In [None]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
from tqdm.auto import tqdm
from scipy import stats
import matplotlib.pyplot as plt
import numpy as np
import json
from dataclasses import dataclass
from typing import Optional, List, Dict, Tuple
import warnings
warnings.filterwarnings('ignore')

In [None]:
@dataclass
class ExperimentConfig:
    """
    All tunable settings for the Surrogate-Primed KV Caching experiment.

    Fields are grouped as: model, dataset, surrogate generation, prompt
    text, and runtime (seed / device). Every field has a default, so
    ``ExperimentConfig()`` yields the standard run.
    """

    # --- Model ---------------------------------------------------------
    # Model identifier passed to ``from_pretrained``.
    # A possible alternative: "mistralai/Mistral-7B-Instruct-v0.2"
    model_name: str = "meta-llama/Meta-Llama-3-8B-Instruct"

    # Load the weights 4-bit quantized to reduce memory (fits on T4/L4 GPU).
    use_4bit: bool = True

    # --- Dataset -------------------------------------------------------
    dataset_name: str = "ms_marco"
    dataset_config: str = "v1.1"
    dataset_split: str = "validation"
    # Number of (passage, query) pairs to evaluate.
    num_samples: int = 100
    # Inclusive bounds on passage length, in whitespace-separated words.
    min_passage_words: int = 50
    max_passage_words: int = 200

    # --- Surrogate generation ------------------------------------------
    # 0.0 selects greedy decoding, which keeps runs reproducible.
    surrogate_temperature: float = 0.0
    # Upper bound on newly generated tokens per surrogate query.
    surrogate_max_tokens: int = 15

    # --- Prompt text ---------------------------------------------------
    # Generic prefix placed before the document in the baseline condition.
    baseline_prompt: str = "System: Read the following document carefully.\n\n"
    # Instruction given to the model when asking it to write a surrogate query.
    surrogate_generation_prompt: str = (
        "You are a search engine optimization expert. "
        "Read the following text. "
        "Write a single, short search query that a user would type "
        "to find this text. "
        "Output ONLY the query, nothing else."
    )
    # Prefix for the surrogate condition; ``{surrogate}`` is filled via str.format.
    surrogate_prefix_template: str = "System: Answer the following query: {surrogate}\n\nDocument:\n"

    # --- Runtime -------------------------------------------------------
    # Seed applied to the torch and NumPy random generators.
    seed: int = 42

    # Evaluated once, when the class is defined: GPU if available, else CPU.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


# Build the configuration with all defaults and echo the settings that
# determine how the model is loaded further down.
config = ExperimentConfig()
print(f"Running on device: {config.device}")
print(f"Model: {config.model_name}")
print(f"4-bit quantization: {config.use_4bit}")

In [None]:
# Seed both RNGs used in this notebook so repeated runs match.
torch.manual_seed(config.seed)
np.random.seed(config.seed)

# Load model and tokenizer
print("Loading model and tokenizer...")

# Arguments shared by both loading modes; the precision-specific one is
# added below depending on whether 4-bit quantization is requested.
load_kwargs = {
    "device_map": "auto",
    "trust_remote_code": True,
}

if config.use_4bit:
    # NF4 weights with double quantization, computing in float16.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4"
    )
    load_kwargs["quantization_config"] = quantization_config
else:
    # Unquantized weights in half precision.
    load_kwargs["torch_dtype"] = torch.float16

model = AutoModelForCausalLM.from_pretrained(config.model_name, **load_kwargs)

tokenizer = AutoTokenizer.from_pretrained(config.model_name, trust_remote_code=True)

# Pad on the right (critical for causal LM), and fall back to EOS as the
# pad token when the tokenizer does not define one.
tokenizer.padding_side = "right"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Inference only: switch off training-mode behaviour.
model.eval()
print(f"Model loaded successfully. Parameters: {model.num_parameters():,}")

## Step 2: Dataset Loading & Filtering

In [None]:
def count_words(text: str) -> int:
    """Return how many whitespace-separated tokens ``text`` contains."""
    words = text.split()
    return len(words)


def load_and_filter_dataset(config: ExperimentConfig) -> List[Dict]:
    """
    Build the evaluation set of (passage, query) pairs from MS MARCO.

    For each dataset item, the first passage whose word count lies within
    ``[config.min_passage_words, config.max_passage_words]`` is kept if it
    is flagged as selected, or if the item has no selected passage at all.
    Items without passages or without a query are skipped. Scanning stops
    once twice ``config.num_samples`` candidates are collected; the pool is
    then shuffled with NumPy's global RNG and cut to ``config.num_samples``.

    Args:
        config: Experiment configuration (dataset id/split, sample count,
            word-count bounds).

    Returns:
        List of dicts with 'passage' and 'query' keys.
    """
    print(f"Loading {config.dataset_name} dataset...")

    dataset = load_dataset(
        config.dataset_name,
        config.dataset_config,
        split=config.dataset_split,
        trust_remote_code=True
    )

    print(f"Total samples in {config.dataset_split}: {len(dataset)}")

    # Collect a pool twice the requested size so the shuffle below has
    # something to choose from.
    pool_target = config.num_samples * 2
    filtered_samples = []

    for item in tqdm(dataset, desc="Filtering passages"):
        passages = item.get('passages', {})
        passage_texts = passages.get('passage_text', [])
        is_selected = passages.get('is_selected', [])
        query = item.get('query', '')

        if not passage_texts or not query:
            continue

        chosen = None
        for idx, text in enumerate(passage_texts):
            n_words = count_words(text)
            if n_words < config.min_passage_words or n_words > config.max_passage_words:
                continue

            # Accept a passage flagged as selected; if the item carries no
            # selection at all, accept the first passage of suitable length.
            if (is_selected and idx < len(is_selected) and is_selected[idx] == 1) \
                    or not any(is_selected):
                chosen = text
                break

        if chosen is not None:
            filtered_samples.append({'passage': chosen, 'query': query})

        if len(filtered_samples) >= pool_target:
            break

    # Shuffle in place, then keep only the requested number of samples.
    np.random.shuffle(filtered_samples)
    filtered_samples = filtered_samples[:config.num_samples]

    print(f"Selected {len(filtered_samples)} samples meeting criteria")
    return filtered_samples

In [None]:
# Build the evaluation set.
samples = load_and_filter_dataset(config)

# Show the first sample: word count, passage text (cut to 500 characters
# when longer), and its ground-truth query.
preview_passage = samples[0]['passage']
preview_rule = "="*80

print("\n" + preview_rule)
print("SAMPLE PREVIEW")
print(preview_rule)
print(f"\nPassage ({count_words(preview_passage)} words):")
if len(preview_passage) > 500:
    print(preview_passage[:500] + "...")
else:
    print(preview_passage)
print(f"\nGround Truth Query: {samples[0]['query']}")

## Step 3: Surrogate Query Generator

This function generates a "surrogate query" from a document. The surrogate mimics what a user might search for to find this document. 

**Critical**: The generator NEVER sees the ground truth query - it only sees the document (mimicking "indexing time").

In [None]:
def generate_surrogate(
    doc_text: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    config: ExperimentConfig
) -> str:
    """
    Generate a surrogate query for a document.

    The surrogate is what we hypothesize a user might search to find this
    document. Only the document text is shown to the model; the ground
    truth query is never part of the prompt.

    Args:
        doc_text: The document/passage text
        model: The language model
        tokenizer: The tokenizer (must provide a chat template)
        config: Experiment configuration; supplies the instruction text,
            ``surrogate_max_tokens`` and ``surrogate_temperature``
            (a value of 0 or less selects greedy decoding)

    Returns:
        Generated surrogate query string, stripped of surrounding whitespace
    """
    # Build the prompt using the chat template for instruct models
    messages = [
        {
            "role": "user",
            "content": f"{config.surrogate_generation_prompt}\n\nText:\n{doc_text}"
        }
    ]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Place inputs on the model's own device; with device_map="auto" this
    # is not guaranteed to equal config.device.
    device = getattr(model, "device", config.device)

    # The rendered template already contains the special tokens it needs,
    # so the tokenizer must not add them again (that would duplicate BOS).
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False
    ).to(device)
    prompt_len = inputs['input_ids'].shape[1]

    # Stop on the tokenizer's EOS as well as any EOS ids declared in the
    # model's generation config. Passing only tokenizer.eos_token_id would
    # override the latter, and some chat models end a turn with a
    # different token than the tokenizer's EOS.
    gen_config = getattr(model, "generation_config", None)
    config_eos = getattr(gen_config, "eos_token_id", None)
    if not isinstance(config_eos, (list, tuple)):
        config_eos = [config_eos]
    stop_ids = []
    for token_id in [tokenizer.eos_token_id, *config_eos]:
        if token_id is not None and token_id not in stop_ids:
            stop_ids.append(token_id)

    use_sampling = config.surrogate_temperature > 0
    generate_kwargs = {
        "max_new_tokens": config.surrogate_max_tokens,
        # None (rather than 0.0) when decoding greedily, as in the sampling-off case
        "temperature": config.surrogate_temperature if use_sampling else None,
        "do_sample": use_sampling,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if stop_ids:
        generate_kwargs["eos_token_id"] = stop_ids

    with torch.no_grad():
        outputs = model.generate(**inputs, **generate_kwargs)

    # Decode only the generated tokens (exclude the prompt)
    generated_tokens = outputs[0][prompt_len:]
    surrogate = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    return surrogate

In [None]:
# Smoke-test the generator on the first sample and show the result next to
# the ground-truth query for a quick visual comparison.
print("Testing surrogate generation...")
first_passage = samples[0]['passage']
test_surrogate = generate_surrogate(first_passage, model, tokenizer, config)
print(f"\nDocument preview: {first_passage[:200]}...")
print(f"\nGenerated Surrogate: '{test_surrogate}'")
print(f"Ground Truth Query: '{samples[0]['query']}'")

## Step 4: Scoring Engine - KV Cache Surgery

This is the critical component. We perform "KV Cache Surgery":

1. **Prefill Phase**: Forward pass the context (prompt + document) to generate the KV cache
2. **Decode Phase**: Forward pass the target query using the frozen KV cache
3. **Scoring**: Compute NLL only on the target query tokens

### Mathematical Foundation

For a sequence `[context, target]`, the model computes:
- Keys: `K = W_k @ [context, target]`  
- Values: `V = W_v @ [context, target]`

The KV cache stores `K_context` and `V_context` from the prefill. During decoding, attention is computed as:

```
Attention(Q_target, [K_context; K_target], [V_context; V_target])
```

By priming with different contexts, we change `K_context` and `V_context`, potentially making them more "receptive" to `Q_target`.

In [None]:
def compute_conditional_nll(
    context_prefix: str,
    document: str,
    target_query: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    config: ExperimentConfig
) -> Tuple[float, int]:
    """
    Compute the Negative Log-Likelihood of target_query conditioned on
    the KV cache generated from (context_prefix + document).

    This implements "KV Cache Surgery":
    1. Prefill: Generate KV cache from context
    2. Decode: Score target using that cache
    3. Loss: Compute NLL over every target token

    Every target token is scored, including the first one: its predictive
    distribution is taken from the last position of the prefill pass. The
    remaining tokens are predicted by the decode pass.

    Args:
        context_prefix: The prefix prompt (baseline or surrogate-based)
        document: The document/passage text
        target_query: The ground truth query to score
        model: The language model
        tokenizer: The tokenizer
        config: Experiment configuration (``device`` is used only as a
            fallback when the model does not expose its own device)

    Returns:
        Tuple of (mean NLL per target token, number of target tokens scored)

    Raises:
        ValueError: If target_query tokenizes to zero tokens.
    """
    # Place inputs on the model's own device; with device_map="auto" this
    # is not guaranteed to equal config.device.
    device = getattr(model, "device", config.device)

    # =========================================================================
    # STEP 1: Tokenize context and target separately
    # =========================================================================
    # NOTE(review): the target is appended directly after the document with
    # no separator between them. Both conditions share this, so the delta
    # stays comparable, but it affects the absolute NLL values.
    context_str = context_prefix + document

    # Context gets the tokenizer's special tokens (e.g. BOS, if it uses one)
    context_encoding = tokenizer(
        context_str,
        return_tensors="pt",
        add_special_tokens=True,
        padding=False,
        truncation=False
    )
    context_ids = context_encoding['input_ids'].to(device)
    context_len = context_ids.shape[1]

    # Target gets no special tokens - it continues from the context
    target_encoding = tokenizer(
        target_query,
        return_tensors="pt",
        add_special_tokens=False,
        padding=False,
        truncation=False
    )
    target_ids = target_encoding['input_ids'].to(device)
    target_len = target_ids.shape[1]

    if target_len == 0:
        # Nothing to score. Raising (instead of returning 0.0) keeps such a
        # sample from being counted as a spurious tie by the caller.
        raise ValueError("target_query produced no tokens to score")

    # =========================================================================
    # STEP 2: Prefill - Generate KV cache from context
    # =========================================================================
    with torch.no_grad():
        prefill_outputs = model(
            input_ids=context_ids,
            attention_mask=torch.ones_like(context_ids),
            use_cache=True,
            return_dict=True
        )
    past_key_values = prefill_outputs.past_key_values

    # Logits at the last context position are the prediction for the FIRST
    # target token; keep them so that token is scored as well.
    first_token_logits = prefill_outputs.logits[:, -1:, :]

    # =========================================================================
    # STEP 3: Decode - Forward pass target on top of the cached context
    # =========================================================================
    # The attention mask must cover BOTH the cached context and the target.
    attention_mask = torch.ones(
        (1, context_len + target_len),
        dtype=torch.long,
        device=device
    )

    # NOTE(review): depending on the transformers version, the cache object
    # may be extended in place by this call; it is used only once here.
    with torch.no_grad():
        decode_outputs = model(
            input_ids=target_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=False,  # Don't need to extend cache further
            return_dict=True
        )

    # =========================================================================
    # STEP 4: Compute NLL over all target tokens
    # =========================================================================
    # decode logits at position i predict target token i+1, so the final
    # position predicts a token past the end of the target and is dropped.
    decode_logits = decode_outputs.logits  # (1, target_len, vocab_size)
    predictive_logits = torch.cat(
        [first_token_logits.to(decode_logits.device), decode_logits[:, :-1, :]],
        dim=1
    )  # (1, target_len, vocab_size), aligned one-to-one with target_ids

    # Upcast before the softmax: half-precision logits lose accuracy there.
    flat_logits = predictive_logits.float().reshape(-1, predictive_logits.size(-1))
    flat_labels = target_ids.reshape(-1).to(flat_logits.device)

    loss_fct = torch.nn.CrossEntropyLoss(reduction='sum')
    nll = loss_fct(flat_logits, flat_labels).item()

    # Return mean NLL per token for comparability across query lengths
    return nll / target_len, target_len

In [None]:
# Smoke-test the scoring engine on the first sample under both conditions.
print("Testing scoring engine...")

test_sample = samples[0]


def _score_test_sample(prefix):
    """Score the test sample's ground-truth query behind the given prefix."""
    return compute_conditional_nll(
        context_prefix=prefix,
        document=test_sample['passage'],
        target_query=test_sample['query'],
        model=model,
        tokenizer=tokenizer,
        config=config
    )


# Baseline condition: generic prompt in front of the document.
baseline_nll, baseline_tokens = _score_test_sample(config.baseline_prompt)

# Surrogate condition: generated query in front of the document.
surrogate = generate_surrogate(test_sample['passage'], model, tokenizer, config)
surrogate_prefix = config.surrogate_prefix_template.format(surrogate=surrogate)
surrogate_nll, surrogate_tokens = _score_test_sample(surrogate_prefix)

# Positive delta means the surrogate prefix lowered the query's NLL.
delta = baseline_nll - surrogate_nll
verdict = 'Surrogate WINS' if delta > 0 else 'Baseline WINS'

print(f"\nBaseline NLL: {baseline_nll:.4f} ({baseline_tokens} tokens)")
print(f"Surrogate NLL: {surrogate_nll:.4f} ({surrogate_tokens} tokens)")
print(f"Delta (Baseline - Surrogate): {delta:.4f}")
print(f"Result: {verdict}")

## Step 5: Main Experiment Loop

In [None]:
def run_experiment(
    samples: List[Dict],
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    config: ExperimentConfig
) -> List[Dict]:
    """
    Run the full experiment comparing baseline vs surrogate-primed KV caching.

    For each sample a surrogate query is generated from the passage alone,
    then the ground-truth query is scored twice: once behind the baseline
    prompt and once behind the surrogate-based prefix. A sample that raises
    during generation or scoring is reported and skipped (best effort), so
    the returned list may be shorter than ``samples``.

    Args:
        samples: List of {passage, query} dicts
        model: The language model
        tokenizer: The tokenizer
        config: Experiment configuration

    Returns:
        List of result dicts with keys 'sample_idx', 'query', 'surrogate',
        'baseline_nll', 'surrogate_nll', 'delta' (baseline minus surrogate;
        positive means the surrogate is better), 'num_tokens' and
        'passage_preview'
    """
    results = []
    cuda_available = torch.cuda.is_available()

    for i, sample in enumerate(tqdm(samples, desc="Running experiment")):
        passage = sample['passage']
        query = sample['query']

        try:
            # Generate surrogate query (only sees document, NOT ground truth query)
            surrogate = generate_surrogate(passage, model, tokenizer, config)

            # Condition A: Baseline
            baseline_nll, baseline_tokens = compute_conditional_nll(
                context_prefix=config.baseline_prompt,
                document=passage,
                target_query=query,
                model=model,
                tokenizer=tokenizer,
                config=config
            )

            # Condition B: Surrogate-primed
            surrogate_prefix = config.surrogate_prefix_template.format(surrogate=surrogate)
            surrogate_nll, _ = compute_conditional_nll(
                context_prefix=surrogate_prefix,
                document=passage,
                target_query=query,
                model=model,
                tokenizer=tokenizer,
                config=config
            )
        except Exception as e:
            # Deliberate best effort: report the failure and move on.
            print(f"\nError processing sample {i}: {e}")
            # A failed sample (e.g. out of memory) is exactly when cached
            # GPU blocks should be released before the next attempt.
            if cuda_available:
                torch.cuda.empty_cache()
            continue

        results.append({
            'sample_idx': i,
            'query': query,
            'surrogate': surrogate,
            'baseline_nll': baseline_nll,
            'surrogate_nll': surrogate_nll,
            # Positive delta means the surrogate is better
            'delta': baseline_nll - surrogate_nll,
            # The same target is scored in both conditions, so the baseline
            # token count describes both.
            'num_tokens': baseline_tokens,
            'passage_preview': passage[:100] + '...' if len(passage) > 100 else passage
        })

        # Clear GPU cache periodically
        if cuda_available and i % 20 == 0:
            torch.cuda.empty_cache()

    return results

In [None]:
# Run the experiment over every prepared sample.
print(f"Running experiment on {len(samples)} samples...")
print("="*80)

# run_experiment skips samples that raise, so `results` can be shorter
# than `samples`; the count printed below shows how many succeeded.
results = run_experiment(samples, model, tokenizer, config)

print(f"\nCompleted {len(results)} samples successfully")

## Step 6: Analysis & Visualization

In [None]:
def analyze_results(results: List[Dict]) -> Dict:
    """
    Compute summary statistics and perform statistical tests.

    All returned values are native Python ``int``/``float`` objects so the
    dictionary can be serialized to JSON as numbers.

    Args:
        results: List of result dicts from experiment; each must provide
            'delta', 'baseline_nll' and 'surrogate_nll'

    Returns:
        Dictionary with sample count, win/loss/tie counts and win rate,
        mean/std/median of the deltas (std is the population value,
        ddof=0), mean NLL per condition, paired t-test statistic and
        p-value, Cohen's d for paired samples (sample std, ddof=1), and the
        Wilcoxon signed-rank statistic and p-value. The Wilcoxon values are
        NaN when every paired difference is zero.

    Raises:
        ValueError: If ``results`` is empty.
    """
    if not results:
        raise ValueError("analyze_results requires at least one result")

    n_samples = len(results)
    deltas = np.array([r['delta'] for r in results], dtype=float)
    baseline_nlls = np.array([r['baseline_nll'] for r in results], dtype=float)
    surrogate_nlls = np.array([r['surrogate_nll'] for r in results], dtype=float)

    # Win rate: proportion where surrogate is better (delta > 0).
    # int() turns the NumPy counts into plain Python integers.
    wins = int(np.sum(deltas > 0))
    losses = int(np.sum(deltas < 0))
    ties = int(np.sum(deltas == 0))
    win_rate = wins / n_samples

    # Paired t-test: Are the NLLs significantly different?
    t_stat, p_value = stats.ttest_rel(baseline_nlls, surrogate_nlls)

    # Effect size (Cohen's d for paired samples). The sample standard
    # deviation needs at least two observations; d is reported as 0 when
    # it is undefined or the differences do not vary.
    diff = baseline_nlls - surrogate_nlls
    diff_sd = np.std(diff, ddof=1) if n_samples > 1 else 0.0
    cohens_d = float(np.mean(diff) / diff_sd) if diff_sd > 0 else 0.0

    # Wilcoxon signed-rank test (non-parametric alternative). The default
    # zero handling discards zero differences, so the test is undefined
    # when all of them are zero; report NaN instead of calling it.
    if np.any(diff != 0):
        wilcoxon_stat, wilcoxon_p = stats.wilcoxon(
            baseline_nlls, surrogate_nlls, alternative='two-sided'
        )
    else:
        wilcoxon_stat, wilcoxon_p = float('nan'), float('nan')

    analysis = {
        'n_samples': n_samples,
        'win_rate': win_rate,
        'wins': wins,
        'losses': losses,
        'ties': ties,
        'mean_delta': float(np.mean(deltas)),
        'std_delta': float(np.std(deltas)),
        'median_delta': float(np.median(deltas)),
        'mean_baseline_nll': float(np.mean(baseline_nlls)),
        'mean_surrogate_nll': float(np.mean(surrogate_nlls)),
        't_statistic': float(t_stat),
        'p_value': float(p_value),
        'cohens_d': cohens_d,
        'wilcoxon_stat': float(wilcoxon_stat),
        'wilcoxon_p': float(wilcoxon_p),
    }

    return analysis

In [None]:
# Compute the summary statistics for the collected results.
analysis = analyze_results(results)

# Percentages shown next to the win/loss counts.
surrogate_win_pct = analysis['win_rate']*100
baseline_win_pct = (analysis['losses']/analysis['n_samples'])*100

# Report header
print("="*80)
print("EXPERIMENT RESULTS SUMMARY")
print("="*80)
print(f"\nSamples analyzed: {analysis['n_samples']}")

# Win/loss record
print("\n--- Win/Loss Record ---")
print(f"Surrogate Wins:  {analysis['wins']} ({surrogate_win_pct:.1f}%)")
print(f"Baseline Wins:   {analysis['losses']} ({baseline_win_pct:.1f}%)")
print(f"Ties:            {analysis['ties']}")

# NLL statistics
print("\n--- NLL Statistics ---")
print(f"Mean Baseline NLL:  {analysis['mean_baseline_nll']:.4f}")
print(f"Mean Surrogate NLL: {analysis['mean_surrogate_nll']:.4f}")
print(f"Mean Delta:         {analysis['mean_delta']:.4f} (positive = surrogate better)")
print(f"Std Delta:          {analysis['std_delta']:.4f}")
print(f"Median Delta:       {analysis['median_delta']:.4f}")

# Statistical tests
print("\n--- Statistical Tests ---")
print(f"Paired t-test: t={analysis['t_statistic']:.4f}, p={analysis['p_value']:.6f}")
print(f"Wilcoxon test: W={analysis['wilcoxon_stat']:.4f}, p={analysis['wilcoxon_p']:.6f}")
print(f"Cohen's d (effect size): {analysis['cohens_d']:.4f}")

# Verdict: significance is judged on the paired t-test p-value, and the
# direction on the sign of the mean delta.
significance_level = 0.05
is_significant = analysis['p_value'] < significance_level
if not is_significant:
    print(f"\n*** RESULT: No statistically significant difference (p >= {significance_level}) ***")
else:
    print(f"\n*** RESULT: Statistically significant difference (p < {significance_level}) ***")
    if analysis['mean_delta'] > 0:
        print("*** CONCLUSION: Surrogate-primed KV caching IMPROVES query prediction ***")
    else:
        print("*** CONCLUSION: Surrogate-primed KV caching HURTS query prediction ***")

In [None]:
# Visualization: three side-by-side views of the per-sample results.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Plot 1: Histogram of Deltas, with reference lines at zero and at the mean
deltas = [r['delta'] for r in results]
ax1 = axes[0]
ax1.hist(deltas, bins=30, edgecolor='black', alpha=0.7, color='steelblue')
ax1.axvline(x=0, color='red', linestyle='--', linewidth=2, label='Zero (no difference)')
ax1.axvline(x=np.mean(deltas), color='green', linestyle='-', linewidth=2, label=f'Mean ({np.mean(deltas):.3f})')
ax1.set_xlabel('Delta (Baseline NLL - Surrogate NLL)')
ax1.set_ylabel('Frequency')
ax1.set_title('Distribution of NLL Deltas\n(Positive = Surrogate Better)')
ax1.legend()

# Plot 2: Baseline vs Surrogate NLL scatter, with the y=x diagonal drawn
# across the combined range of both conditions
ax2 = axes[1]
baseline_nlls = [r['baseline_nll'] for r in results]
surrogate_nlls = [r['surrogate_nll'] for r in results]
ax2.scatter(baseline_nlls, surrogate_nlls, alpha=0.6, edgecolor='black', linewidth=0.5)
min_val = min(min(baseline_nlls), min(surrogate_nlls))
max_val = max(max(baseline_nlls), max(surrogate_nlls))
ax2.plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2, label='y=x (no difference)')
ax2.set_xlabel('Baseline NLL')
ax2.set_ylabel('Surrogate NLL')
ax2.set_title('Baseline vs Surrogate NLL\n(Points below line = Surrogate better)')
ax2.legend()

# Plot 3: Box plot comparison
ax3 = axes[2]
box_data = [baseline_nlls, surrogate_nlls]
bp = ax3.boxplot(box_data, patch_artist=True)
# Label the two boxes through the axis rather than boxplot's `labels`
# keyword, which recent Matplotlib releases deprecate; this form works
# regardless of the installed version.
ax3.set_xticklabels(['Baseline', 'Surrogate'])
bp['boxes'][0].set_facecolor('lightcoral')
bp['boxes'][1].set_facecolor('lightgreen')
ax3.set_ylabel('NLL')
ax3.set_title('NLL Distribution by Condition')

plt.tight_layout()
plt.savefig('experiment_results.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nFigure saved to: experiment_results.png")

In [None]:
# Show examples: best and worst cases, ordered from largest to smallest delta
sorted_results = sorted(results, key=lambda x: x['delta'], reverse=True)


def _print_cases(heading, cases, leading=""):
    """Print a banner with `heading`, then one summary entry per case."""
    print(leading + "="*80)
    print(heading)
    print("="*80)
    for r in cases:
        print(f"\nDelta: {r['delta']:.4f}")
        print(f"  Ground Truth Query: {r['query']}")
        print(f"  Generated Surrogate: {r['surrogate']}")
        print(f"  Baseline NLL: {r['baseline_nll']:.4f} | Surrogate NLL: {r['surrogate_nll']:.4f}")


# The first five entries have the largest deltas, the last five the smallest.
_print_cases("TOP 5 SURROGATE WINS (Largest Positive Delta)", sorted_results[:5])
_print_cases("TOP 5 BASELINE WINS (Largest Negative Delta)", sorted_results[-5:], leading="\n")

## Step 7: Save Results

In [None]:
# Save detailed results: the settings that define the run, the summary
# statistics, and the per-sample records.
output_data = {
    'config': {
        'model_name': config.model_name,
        'num_samples': config.num_samples,
        'baseline_prompt': config.baseline_prompt,
        'surrogate_prefix_template': config.surrogate_prefix_template,
        'seed': config.seed,
    },
    'analysis': analysis,
    'results': results
}


def _json_default(obj):
    """Fallback for objects json cannot encode natively.

    NumPy scalars and arrays become native numbers/lists, so counts and
    statistics are written as JSON numbers rather than strings. Anything
    else is stringified as a last resort.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return str(obj)


with open('experiment_results.json', 'w') as f:
    json.dump(output_data, f, indent=2, default=_json_default)

print("Results saved to: experiment_results.json")

## Appendix: Ablation Studies

Run additional experiments with different configurations.

In [None]:
def run_ablation_different_prompts(
    samples: List[Dict],
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    config: ExperimentConfig,
    prompt_variants: Dict[str, str],
    max_samples: int = 20
) -> Dict[str, Dict]:
    """
    Run ablation study with different surrogate prompt templates.

    Each variant temporarily replaces ``config.surrogate_prefix_template``,
    reruns the experiment on the first ``max_samples`` samples and analyzes
    the outcome. The original template is always restored on exit, even if
    a variant raises.

    Args:
        samples: Dataset samples
        model: Language model
        tokenizer: Tokenizer
        config: Base configuration (mutated during the run, restored after)
        prompt_variants: Dict mapping variant name to surrogate prefix template
        max_samples: Number of leading samples used for each variant

    Returns:
        Dict mapping variant name to a dict with the 'template' used and
        its 'analysis' results
    """
    ablation_results = {}

    # Use a subset for ablation
    ablation_samples = samples[:max_samples]

    original_template = config.surrogate_prefix_template
    try:
        for variant_name, template in prompt_variants.items():
            print(f"\nRunning ablation: {variant_name}")

            # Temporarily modify config for this variant
            config.surrogate_prefix_template = template

            variant_results = run_experiment(ablation_samples, model, tokenizer, config)
            variant_analysis = analyze_results(variant_results)

            ablation_results[variant_name] = {
                'template': template,
                'analysis': variant_analysis
            }
    finally:
        # Restore the caller's template no matter how the loop ended, so a
        # failing variant cannot leave the shared config modified.
        config.surrogate_prefix_template = original_template

    return ablation_results

In [None]:
# Example ablation: different prompt phrasings
# Uncomment to run

# prompt_variants = {
#     'query_focused': "System: Answer the following query: {surrogate}\n\nDocument:\n",
#     'question_focused': "System: Answer this question: {surrogate}\n\nText:\n",
#     'search_focused': "System: User searched for: {surrogate}\n\nResult:\n",
#     'minimal': "{surrogate}\n\n",
# }

# ablation_results = run_ablation_different_prompts(
#     samples, model, tokenizer, config, prompt_variants
# )

# for name, data in ablation_results.items():
#     print(f"\n{name}: Win Rate = {data['analysis']['win_rate']*100:.1f}%, Mean Delta = {data['analysis']['mean_delta']:.4f}")

---

## Notes on Interpretation

**If Win Rate > 50% and p < 0.05:**
- Surrogate priming successfully "steers" the KV cache toward the query space
- The document's cached key/value states become more aligned with typical query patterns

**If Win Rate < 50% or not significant:**
- The generic prompt may already provide sufficient context
- Surrogate generation quality may be a bottleneck
- The effect may be domain-dependent

**Potential confounds:**
- Surrogate length differs from baseline prompt length (position effects)
- Surrogate quality varies by passage complexity
- Some queries may be too different from any reasonable surrogate