# Production Simulation Diagnostic: Why Don't Surrogates Help?

## Background

Previous experiments showed that surrogate-primed KV caches **hurt** performance in most cases:
- Win rate vs baseline: only 9% for generated surrogates
- Surrogates help ONLY when baseline is poor (NLL > 3.0)
- When baseline is good, surrogates make predictions 3-4x worse

## Definitions

- **Baseline**: KV cache built from the document in isolation (no surrogate, no priming).
  This is `Document:\n{document}` processed through the model alone.
- **Surrogate condition**: Any approach that incorporates a surrogate query during KV cache construction.
  The surrogate may be generated, static, or even the actual query (perfect surrogate test).
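
Every condition in this notebook is scored by the NLL of the reference answer, so it helps to pin down the scale. The sketch below assumes the score is a mean per-token negative log-likelihood in nats, which the notebook does not state. `mean_nll` and `avg_token_probability` are hypothetical helpers for illustration, not the library's `score_answer_with_cache`.

```python
import math
from typing import Sequence


def mean_nll(token_probs: Sequence[float]) -> float:
    """Mean negative log-likelihood (nats) over the answer tokens.

    token_probs[i] is the probability the model assigned to the i-th
    reference answer token. Assumption: per-token averaging.
    """
    return -sum(math.log(p) for p in token_probs) / len(token_probs)


def avg_token_probability(nll: float) -> float:
    """Geometric-mean per-token probability implied by a mean NLL."""
    return math.exp(-nll)


# Made-up probabilities, for illustration only.
confident = mean_nll([0.9, 0.8, 0.95])  # well under 0.5
poor = mean_nll([0.05, 0.02, 0.1])      # above 3.0
```

Under this assumption, the "poor baseline" threshold of NLL > 3.0 corresponds to an average per-token probability below `exp(-3)`, roughly 0.05.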

### Terminology Note: "Oracle" vs "Perfect Surrogate"

In **other notebooks** (e.g., `production_simulation_experiment.ipynb`), "oracle" refers to
**best-of-N hindsight selection**: given N surrogate candidates, pick the one that happened to
produce the lowest NLL. This is oracle *selection*.

In **this notebook**, we put the **actual query** in the surrogate slot. This is a
different concept: a *perfect surrogate*, the ideal surrogate you could generate if you
knew the query in advance. We name this condition `perfect_surrogate` to avoid confusion
with the oracle selection used elsewhere.
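
The two concepts can be contrasted in a few lines. This is a sketch only: `oracle_select` and `perfect_surrogate_context` are hypothetical names, the NLL values are made up, and the priming wording is copied from the template used later in this notebook.

```python
from typing import Dict, Tuple

# Priming wording copied from the surrogate template used in this notebook.
PRIMING_TEMPLATE = (
    "This document may be relevant to queries like: {surrogate}\n\n"
    "Document:\n{document}"
)


def oracle_select(candidate_nlls: Dict[str, float]) -> Tuple[str, float]:
    """Oracle *selection*: pick, in hindsight, the candidate with the lowest NLL."""
    best = min(candidate_nlls, key=candidate_nlls.get)
    return best, candidate_nlls[best]


def perfect_surrogate_context(query: str, document: str) -> str:
    """Perfect surrogate: the real query goes into the surrogate slot."""
    return PRIMING_TEMPLATE.format(surrogate=query, document=document)


# Toy NLLs for three generated candidates (made-up numbers).
toy_nlls = {
    "what is a kv cache": 1.8,
    "how do transformers work": 2.4,
    "kv cache size": 1.2,
}
best_surrogate, best_nll = oracle_select(toy_nlls)
context = perfect_surrogate_context("why cache keys and values", "A KV cache stores ...")
```

Oracle selection needs N scored candidates and knowledge of the outcome. The perfect surrogate needs only the query itself and involves no selection step.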

## Hypotheses to Test

1. **Competing Query Signal**: The surrogate acts as a distractor that competes with the actual query
2. **Template Framing Issue**: The specific phrasing of the surrogate template may be problematic
3. **Attention Dilution**: Extra tokens dilute attention away from key document content

## Experiments in This Notebook

All experiments compare against the **baseline (document in isolation)**:

1. **Template Ablation**: Compare surrogate framings (including the perfect surrogate and a bare-document control) against the baseline
2. **Analysis by Baseline Quality**: Do surrogates help more when the baseline is poor?
3. **Perfect Surrogate Deep Dive**: Use the ACTUAL query as surrogate (upper bound on surrogate quality)
4. **Truncated Cache Test**: Does removing the surrogate at inference help?
5. **Poor Baseline Deep Dive**: Closer look at cases where the baseline NLL exceeds 3.0

## Setup

In [None]:
# Install dependencies
!pip install transformers torch datasets tqdm scipy bitsandbytes accelerate matplotlib sentence-transformers -q

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import json
from typing import Dict, List, Tuple, Any
from dataclasses import dataclass
from tqdm.auto import tqdm
from scipy import stats
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Import from our shared library
from lib import (
    ExperimentConfig,
    build_kv_cache,
    score_answer_with_cache,
    build_truncated_kv_cache,
    TOP_5_SURROGATE_TEMPLATES,
    STATIC_SURROGATE_QUERIES,
    generate_surrogate_with_template,
    generate_all_5_surrogates,
    compute_similarity,
    count_words,
)

print("Imports complete.")

In [None]:
# Configuration
config = ExperimentConfig(
    num_samples=500,  # Enough for statistical power, fast enough for diagnostics
    min_passage_words=50,
    max_passage_words=300,
)

print(f"Device: {config.device}")
print(f"Model: {config.model_name}")
print(f"Samples: {config.num_samples}")

In [None]:
# Load models
torch.manual_seed(config.seed)
np.random.seed(config.seed)

print("Loading language model...")
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)
model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    quantization_config=quantization_config,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(config.model_name, trust_remote_code=True)
tokenizer.padding_side = "right"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model.eval()
print(f"Language model loaded: {model.num_parameters():,} parameters")

print(f"\nLoading embedding model: {config.embedding_model_name}")
embed_model = SentenceTransformer(config.embedding_model_name)
print("Embedding model loaded.")

In [None]:
# Load dataset
print("Loading MS MARCO dataset...")
full_dataset = load_dataset(
    config.dataset_name,
    config.dataset_config,
    split=config.dataset_split,
)
print(f"Dataset loaded: {len(full_dataset)} samples")

# Filter samples
def load_samples_with_answers(dataset, config):
    """Load samples with passage, query, and answer."""
    filtered = []
    for item in tqdm(dataset, desc="Filtering"):
        passages = item.get('passages', {})
        passage_texts = passages.get('passage_text', [])
        is_selected = passages.get('is_selected', [])
        query = item.get('query', '')
        answers = item.get('answers', [])
        well_formed = item.get('wellFormedAnswers', [])
        
        if not passage_texts or not query:
            continue
        
        # Get answer
        if well_formed and len(well_formed) > 0 and well_formed[0] != '[]':
            answer = well_formed[0]
        elif answers and len(answers) > 0 and answers[0] != 'No Answer Present.':
            answer = answers[0]
        else:
            continue
        
        # Find valid passage
        for i, passage in enumerate(passage_texts):
            wc = count_words(passage)
            if config.min_passage_words <= wc <= config.max_passage_words:
                if is_selected and i < len(is_selected) and is_selected[i] == 1:
                    filtered.append({'passage': passage, 'query': query, 'answer': answer})
                    break
        
        if len(filtered) >= config.num_samples * 2:
            break
    
    np.random.shuffle(filtered)
    return filtered[:config.num_samples]

samples = load_samples_with_answers(full_dataset, config)
print(f"\nLoaded {len(samples)} samples")
print(f"Sample query: {samples[0]['query']}")

## Experiment 1: Template Ablation Study

Compare the **baseline (document in isolation)** against multiple surrogate conditions.

Each surrogate condition uses a different template for how the surrogate is incorporated into the context.
All are compared back to the same baseline: the document's KV cache computed in isolation.
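
All comparisons in this notebook use the same sign convention: delta = baseline NLL minus condition NLL, so a positive delta means the condition beats the baseline. A minimal sketch on made-up numbers follows; `delta_and_win_rate` is a hypothetical helper, not part of the shared library.

```python
from statistics import mean
from typing import List, Tuple


def delta_and_win_rate(
    baseline_nlls: List[float], condition_nlls: List[float]
) -> Tuple[float, float]:
    """Return (mean delta, win rate) for paired NLL scores.

    delta = baseline - condition, so positive means the condition wins.
    """
    deltas = [b - c for b, c in zip(baseline_nlls, condition_nlls)]
    win_rate = sum(d > 0 for d in deltas) / len(deltas)
    return mean(deltas), win_rate


# Toy paired scores: the condition loses on three samples and wins on one.
baseline = [0.5, 1.0, 4.0, 0.25]
surrogate = [1.0, 1.5, 3.0, 0.75]
mean_delta, win_rate = delta_and_win_rate(baseline, surrogate)
```

Mean delta and win rate can disagree: one large win on a poor-baseline sample can pull the mean up while the win rate stays low, which is why both are reported.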

In [None]:
def evaluate_single_sample(
    sample: Dict,
    model,
    tokenizer,
    embed_model,
    config: ExperimentConfig,
    surrogate_templates: Dict[str, str],
) -> Dict:
    """
    Evaluate a sample: build baseline cache (document in isolation) and
    compare against multiple surrogate conditions.
    
    The baseline is ALWAYS the document's KV cache computed in isolation:
        "Document:\n{document}"
    
    Each surrogate_template is a different way of incorporating a surrogate
    into the context window. Templates can use:
        {document} - the document text
        {surrogate} - a generated surrogate query
        {query} - the actual query (for the perfect surrogate test)
    """
    passage = sample['passage']
    query = sample['query']
    answer = sample['answer']
    query_prompt = config.query_template.format(query=query)
    
    results = {
        'query': query,
        'answer': answer[:50],
        'passage_len': count_words(passage),
    }
    
    # ===== BASELINE: Document KV cache in isolation =====
    baseline_context = config.baseline_cache_template.format(document=passage)
    baseline_len, baseline_cache = build_kv_cache(baseline_context, model, tokenizer, config)
    baseline_nll = score_answer_with_cache(
        baseline_cache, baseline_len, query_prompt, answer, model, tokenizer, config
    )
    results['baseline_nll'] = baseline_nll
    results['baseline_len'] = baseline_len
    
    # ===== SURROGATE CONDITIONS =====
    # Generate a surrogate once and reuse across templates that need it
    surrogate = generate_surrogate_with_template(
        passage,
        TOP_5_SURROGATE_TEMPLATES['target_question']['prompt'],
        model, tokenizer, config
    )
    results['generated_surrogate'] = surrogate
    results['surrogate_similarity'] = compute_similarity(surrogate, query, embed_model)
    
    for name, template in surrogate_templates.items():
        # Format context based on what placeholders the template uses
        if '{surrogate}' in template and '{document}' in template:
            context = template.format(document=passage, surrogate=surrogate)
        elif '{query}' in template and '{document}' in template:
            context = template.format(document=passage, query=query)
        elif '{document}' in template:
            context = template.format(document=passage)
        else:
            context = template
        
        # Build cache and score
        cache_len, cache = build_kv_cache(context, model, tokenizer, config)
        nll = score_answer_with_cache(
            cache, cache_len, query_prompt, answer, model, tokenizer, config
        )
        results[f'{name}_nll'] = nll
        results[f'{name}_len'] = cache_len
    
    return results

In [None]:
# =====================================================================
# SURROGATE CONDITIONS TO TEST
# =====================================================================
# Each is compared against the BASELINE (document KV cache in isolation).
# The baseline is always: "Document:\n{document}"
#
# NOTE: "perfect_surrogate" uses the ACTUAL query in the surrogate slot.
# This is different from "oracle selection" in other notebooks, which means
# picking the best-of-N surrogates in hindsight. See terminology note above.

SURROGATE_CONDITIONS = {
    # ----- Original surrogate template (from production simulation) -----
    'surr_original': (
        'This document may be relevant to queries like: {surrogate}\n\n'
        'Document:\n{document}'
    ),
    
    # ----- Perfect surrogate: use the ACTUAL query (upper bound on surrogate quality) -----
    'perfect_surrogate': (
        'This document may be relevant to queries like: {query}\n\n'
        'Document:\n{document}'
    ),
    
    # ----- Simpler framing (test if the verbose template is the problem) -----
    'surr_simple_prefix': 'Query hint: {surrogate}\n\n{document}',
    'surr_minimal': '{surrogate}\n\n{document}',
    
    # ----- Position test: surrogate AFTER document -----
    'surr_suffix': 'Document:\n{document}\n\nRelevant queries: {surrogate}',
    
    # ----- Control: bare document (no "Document:" label at all) -----
    'bare_doc_no_label': '{document}',
}

print("Baseline: Document KV cache in isolation")
print("  Template: \"Document:\\n{document}\"")
print()
print("Surrogate conditions to compare against baseline:")
for name, template in SURROGATE_CONDITIONS.items():
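    # repr() keeps embedded newlines on one line; add "..." only when truncated
    print(f"  {name}: {template[:60]!r}{'...' if len(template) > 60 else ''}")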

In [None]:
# Run diagnostic experiment
print(f"Running diagnostic experiment on {len(samples)} samples...")
print(f"Baseline: document KV cache in isolation")
print(f"Testing {len(SURROGATE_CONDITIONS)} surrogate conditions against baseline")

diagnostic_results = []

for i, sample in enumerate(tqdm(samples, desc="Evaluating")):
    try:
        result = evaluate_single_sample(
            sample, model, tokenizer, embed_model, config,
            SURROGATE_CONDITIONS
        )
        diagnostic_results.append(result)
        
        if (i + 1) % 50 == 0:
            recent = diagnostic_results[-50:]
            baseline_avg = np.mean([r['baseline_nll'] for r in recent])
            perfect_avg = np.mean([r['perfect_surrogate_nll'] for r in recent])
            print(f"  [{i+1}] Baseline: {baseline_avg:.3f}, Perfect surrogate: {perfect_avg:.3f}")
            
    except Exception as e:
        print(f"Error on sample {i}: {e}")
        continue

print(f"\nCompleted {len(diagnostic_results)} samples")

In [None]:
# Analyze results
print("=" * 80)
print("ALL CONDITIONS vs BASELINE (document in isolation)")
print("=" * 80)

condition_names = list(SURROGATE_CONDITIONS.keys())

print(f"\n{'Condition':<25} {'Mean NLL':>10} {'Std':>10} {'vs Baseline':>12} {'Win Rate':>10}")
print("-" * 70)

baseline_nlls = np.array([r['baseline_nll'] for r in diagnostic_results])

# Print baseline first
print(f"{'BASELINE (doc only)':<25} {np.mean(baseline_nlls):>10.3f} {np.std(baseline_nlls):>10.3f} {'---':>12} {'---':>10}")
print("-" * 70)

for name in condition_names:
    nlls = np.array([r[f'{name}_nll'] for r in diagnostic_results])
    delta = baseline_nlls - nlls  # positive = this condition beats baseline
    win_rate = np.mean(delta > 0)
    
    mean_delta = np.mean(delta)
    
    # The '+' format flag keeps the sign attached to the number and the column aligned
    print(f"{name:<25} {np.mean(nlls):>10.3f} {np.std(nlls):>10.3f} {mean_delta:>+12.3f} {win_rate*100:>9.1f}%")

print("\n(Positive 'vs Baseline' = better than baseline, higher win rate = better)")

In [None]:
# Statistical significance tests
print("\n" + "=" * 80)
print("STATISTICAL TESTS: Each condition vs BASELINE (doc in isolation)")
print("=" * 80)

for name in condition_names:
    nlls = np.array([r[f'{name}_nll'] for r in diagnostic_results])
    t_stat, p_value = stats.ttest_rel(baseline_nlls, nlls)
    
    direction = "BETTER" if t_stat > 0 else "WORSE"
    sig = "***" if p_value < 0.001 else "**" if p_value < 0.01 else "*" if p_value < 0.05 else ""
    
    print(f"{name:<25} t={t_stat:>7.3f}, p={p_value:.4f} {sig} ({direction} than baseline)")

## Experiment 2: Analysis by Baseline Quality

Do surrogates help more when the baseline is poor?
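
A minimal sketch of the binning logic on made-up (baseline, condition) NLL pairs. The bin edges match the ones used in this experiment; `bin_label` and `deltas_by_bin` are hypothetical helper names.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# Same edges as this experiment; intervals are half-open: low <= nll < high.
BINS = [
    (0.0, 0.1, "Perfect"),
    (0.1, 0.5, "Excellent"),
    (0.5, 1.5, "Good"),
    (1.5, 3.0, "Medium"),
    (3.0, float("inf"), "Poor"),
]


def bin_label(baseline_nll: float) -> str:
    """Return the quality label of the bin containing baseline_nll."""
    for low, high, label in BINS:
        if low <= baseline_nll < high:
            return label
    raise ValueError(f"NLL out of range: {baseline_nll}")


def deltas_by_bin(pairs: List[Tuple[float, float]]) -> Dict[str, List[float]]:
    """Group delta = baseline - condition by the baseline's quality bin."""
    grouped = defaultdict(list)
    for baseline_nll, condition_nll in pairs:
        grouped[bin_label(baseline_nll)].append(baseline_nll - condition_nll)
    return dict(grouped)


# Toy pairs: the condition hurts good baselines and helps poor ones.
toy_pairs = [(0.25, 0.75), (0.25, 1.0), (3.0, 2.0), (4.5, 4.0)]
grouped = deltas_by_bin(toy_pairs)
```

Because the intervals are half-open, a baseline NLL of exactly 3.0 lands in "Poor" here, whereas the strict `> 3.0` filter in Experiment 5 would exclude it.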

In [None]:
# Bin analysis by baseline quality
print("=" * 80)
print("SURROGATE PERFORMANCE BY BASELINE QUALITY")
print("(Baseline = document KV cache in isolation)")
print("=" * 80)

bins = [
    (0, 0.1, "Perfect (0-0.1)"),
    (0.1, 0.5, "Excellent (0.1-0.5)"),
    (0.5, 1.5, "Good (0.5-1.5)"),
    (1.5, 3.0, "Medium (1.5-3.0)"),
    (3.0, float('inf'), "Poor (3.0+)"),
]

for name in ['perfect_surrogate', 'surr_original', 'surr_simple_prefix', 'bare_doc_no_label']:
    print(f"\n--- {name} vs Baseline ---")
    print(f"{'Baseline Quality':<20} {'N':>5} {'Avg Delta':>12} {'Win Rate':>10}")
    print("-" * 50)
    
    for low, high, label in bins:
        subset = [r for r in diagnostic_results if low <= r['baseline_nll'] < high]
        if not subset:
            continue
        
        deltas = [r['baseline_nll'] - r[f'{name}_nll'] for r in subset]
        win_rate = np.mean([d > 0 for d in deltas])
        
        print(f"{label:<20} {len(subset):>5} {np.mean(deltas):>+12.3f} {win_rate*100:>9.1f}%")

In [None]:
# Correlation analysis
print("\n" + "=" * 80)
print("CORRELATION: Baseline NLL vs Delta")
print("(Does the surrogate condition help more when the baseline is worse?)")
print("=" * 80)

for name in condition_names:
    deltas = np.array([r['baseline_nll'] - r[f'{name}_nll'] for r in diagnostic_results])
    
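    # Filter out near-zero baselines for cleaner correlation.
    # Caveat: delta = baseline - condition contains the baseline term itself,
    # so r can be positive from that coupling alone; interpret with care.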
    mask = baseline_nlls > 0.01
    corr = np.corrcoef(baseline_nlls[mask], deltas[mask])[0, 1]
    
    interpretation = "helps more when baseline worse" if corr > 0.1 else "helps more when baseline better" if corr < -0.1 else "no clear pattern"
    print(f"{name:<25} r={corr:>6.3f} ({interpretation})")

## Experiment 3: Perfect Surrogate Deep Dive

The "perfect surrogate" test places the **actual query** into the surrogate slot.
If even this doesn't beat the baseline (document in isolation),
the problem is fundamental to prepending any query text, not surrogate quality.

Note: This is different from "oracle selection" used in `production_simulation_experiment.ipynb`,
which refers to picking the best-of-N surrogate candidates in hindsight.
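
The summary at the end of this notebook reads the perfect-surrogate win rate against 50%. One way to check whether a win rate is distinguishable from 50% is an exact sign test. This is a suggested sketch, not something the notebook runs, and `sign_test_p_value` is a hypothetical helper. Note that the notebook counts ties as non-wins, while the sign test drops them.

```python
from math import comb
from typing import Sequence


def sign_test_p_value(deltas: Sequence[float]) -> float:
    """Two-sided exact sign test of H0: wins and losses are equally likely.

    Ties (delta == 0) are dropped, which is the usual convention.
    """
    wins = sum(d > 0 for d in deltas)
    losses = sum(d < 0 for d in deltas)
    n = wins + losses
    if n == 0:
        return 1.0
    k = min(wins, losses)
    # P(X <= k) for X ~ Binomial(n, 0.5), doubled for the two-sided test.
    tail = sum(comb(n, i) for i in range(k + 1)) / 2 ** n
    return min(1.0, 2 * tail)


# Toy deltas (baseline - condition): five losses, no wins.
p_all_losses = sign_test_p_value([-0.5, -0.2, -1.0, -0.1, -0.3])
# Balanced wins and losses give no evidence against H0.
p_balanced = sign_test_p_value([1.0, -1.0, 0.5, -0.5])
```

Unlike the paired t-test used in this notebook, the sign test ignores effect size, so it is insensitive to a few samples with very large NLL swings.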

In [None]:
# Detailed perfect surrogate analysis
print("=" * 80)
print("PERFECT SURROGATE ANALYSIS")
print("=" * 80)
print("\nQuestion: Does placing the ACTUAL query before the document help")
print("compared to the baseline (document in isolation)?")

perfect_deltas = [r['baseline_nll'] - r['perfect_surrogate_nll'] for r in diagnostic_results]

print(f"\nUsing ACTUAL QUERY in surrogate position:")
print(f"  Mean delta vs baseline: {np.mean(perfect_deltas):+.4f}")
print(f"  Win rate vs baseline: {np.mean([d > 0 for d in perfect_deltas])*100:.1f}%")
print(f"  Median delta: {np.median(perfect_deltas):+.4f}")

# Cases where even perfect surrogate hurts
perfect_hurt = [r for r in diagnostic_results if r['baseline_nll'] - r['perfect_surrogate_nll'] < -0.5]
print(f"\nCases where perfect surrogate HURT by >0.5 NLL vs baseline: {len(perfect_hurt)} ({len(perfect_hurt)/len(diagnostic_results)*100:.1f}%)")

if perfect_hurt:
    print("\nExamples where using exact query HURT vs document-in-isolation baseline:")
    for r in sorted(perfect_hurt, key=lambda x: x['baseline_nll'] - x['perfect_surrogate_nll'])[:5]:
        delta = r['baseline_nll'] - r['perfect_surrogate_nll']
        print(f"  Query: {r['query'][:50]}...")
        print(f"    Baseline: {r['baseline_nll']:.3f}, Perfect surrogate: {r['perfect_surrogate_nll']:.3f}, Delta: {delta:+.3f}")

In [None]:
# Compare context lengths
print("\n" + "=" * 80)
print("CONTEXT LENGTH ANALYSIS")
print("(How many extra tokens does each surrogate condition add vs baseline?)")
print("=" * 80)

print(f"\n{'Condition':<25} {'Mean Tokens':>12} {'Extra vs Baseline':>18}")
print("-" * 60)

baseline_lens = np.array([r['baseline_len'] for r in diagnostic_results])
print(f"{'BASELINE (doc only)':<25} {np.mean(baseline_lens):>12.1f} {'---':>18}")

for name in condition_names:
    lens = np.array([r[f'{name}_len'] for r in diagnostic_results])
    extra = np.mean(lens - baseline_lens)
    print(f"{name:<25} {np.mean(lens):>12.1f} {extra:>+18.1f}")

# Correlation between extra tokens and performance hit
print("\nDo more tokens mean worse performance?")
for name in ['perfect_surrogate', 'surr_original']:
    lens = np.array([r[f'{name}_len'] for r in diagnostic_results])
    deltas = np.array([r['baseline_nll'] - r[f'{name}_nll'] for r in diagnostic_results])
    extra_tokens = lens - baseline_lens
    corr = np.corrcoef(extra_tokens, deltas)[0, 1]
    print(f"  {name}: r={corr:.3f} (negative = more tokens hurt more)")

## Experiment 4: Truncated Cache Test

Test if removing the surrogate at inference time helps.
This isolates whether the problem is:
- (A) The surrogate competing with the query at inference time, or
- (B) The surrogate fundamentally not helping document representations

All conditions are compared against the same **baseline: document KV cache in isolation**.
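
The idea behind truncation can be shown with a toy stand-in for a causal model. This is a conceptual sketch under loud assumptions: `toy_causal_states` is hypothetical, integers stand in for key/value tensors, and position handling is not modelled. It does not describe how the library's `build_truncated_kv_cache` is implemented.

```python
from itertools import accumulate
from typing import List


def toy_causal_states(token_values: List[int]) -> List[int]:
    """Toy causal 'model': each state summarizes all tokens up to that position."""
    return list(accumulate(token_values))


surrogate_tokens = [5, 7]
document_tokens = [1, 2, 3]

# Baseline: document processed in isolation.
baseline_cache = toy_causal_states(document_tokens)

# Full context: surrogate entries stay in the cache at inference time.
full_cache = toy_causal_states(surrogate_tokens + document_tokens)

# Truncated: drop the surrogate positions and keep only the document states,
# which were still computed with the surrogate in view.
truncated_cache = full_cache[len(surrogate_tokens):]
```

The truncated cache has the same length as the baseline cache but different contents. If truncation still fails to beat the baseline, explanation (A) is less likely, because the query no longer has surrogate positions to compete with.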

In [None]:
# Run truncated cache experiment on a subset
print("Running truncated cache experiment...")
print("(Surrogate used during cache generation, then REMOVED before inference)")

truncated_results = []
test_samples = samples[:100]  # Smaller subset for this diagnostic

for i, sample in enumerate(tqdm(test_samples, desc="Truncated cache test")):
    try:
        passage = sample['passage']
        query = sample['query']
        answer = sample['answer']
        query_prompt = config.query_template.format(query=query)
        
        # Baseline (same template as in evaluate_single_sample)
        baseline_context = config.baseline_cache_template.format(document=passage)
        baseline_len, baseline_cache = build_kv_cache(baseline_context, model, tokenizer, config)
        baseline_nll = score_answer_with_cache(
            baseline_cache, baseline_len, query_prompt, answer, model, tokenizer, config
        )
        
        # Generate surrogate
        surrogate = generate_surrogate_with_template(
            passage,
            TOP_5_SURROGATE_TEMPLATES['target_question']['prompt'],
            model, tokenizer, config
        )
        
        # Full context (surrogate visible at inference)
        full_context = f"This document may be relevant to queries like: {surrogate}\n\nDocument:\n{passage}"
        full_len, full_cache = build_kv_cache(full_context, model, tokenizer, config)
        full_nll = score_answer_with_cache(
            full_cache, full_len, query_prompt, answer, model, tokenizer, config
        )
        
        # Truncated (surrogate used for generation, then removed)
        trunc_len, trunc_cache = build_truncated_kv_cache(
            surrogate, passage, model, tokenizer, config
        )
        trunc_nll = score_answer_with_cache(
            trunc_cache, trunc_len, query_prompt, answer, model, tokenizer, config
        )
        
        # Perfect surrogate truncated (use actual query, then remove)
        perfect_trunc_len, perfect_trunc_cache = build_truncated_kv_cache(
            query, passage, model, tokenizer, config
        )
        perfect_trunc_nll = score_answer_with_cache(
            perfect_trunc_cache, perfect_trunc_len, query_prompt, answer, model, tokenizer, config
        )
        
        truncated_results.append({
            'query': query,
            'surrogate': surrogate,
            'baseline_nll': baseline_nll,
            'full_nll': full_nll,
            'truncated_nll': trunc_nll,
            'perfect_surrogate_truncated_nll': perfect_trunc_nll,
        })
        
    except Exception as e:
        print(f"Error on sample {i}: {e}")
        continue

print(f"\nCompleted {len(truncated_results)} samples")

In [None]:
# Analyze truncated cache results
print("=" * 80)
print("TRUNCATED CACHE RESULTS")
print("(All conditions compared against baseline: document in isolation)")
print("=" * 80)

conditions = [
    ('baseline', 'baseline_nll', 'BASELINE: Document in isolation'),
    ('full', 'full_nll', 'Full context (surrogate visible)'),
    ('truncated', 'truncated_nll', 'Truncated (surrogate used then removed)'),
    ('perfect_trunc', 'perfect_surrogate_truncated_nll', 'Perfect surr. truncated (query used then removed)'),
]

print(f"\n{'Condition':<45} {'Mean NLL':>10} {'vs Baseline':>12} {'Win Rate':>10}")
print("-" * 80)

baseline_nlls_t = np.array([r['baseline_nll'] for r in truncated_results])

for name, key, desc in conditions:
    nlls = np.array([r[key] for r in truncated_results])
    delta = baseline_nlls_t - nlls
    win_rate = np.mean(delta > 0)
    
    if name == 'baseline':
        print(f"{desc:<45} {np.mean(nlls):>10.3f} {'---':>12} {'---':>10}")
    else:
        print(f"{desc:<45} {np.mean(nlls):>10.3f} {np.mean(delta):>+12.3f} {win_rate*100:>9.1f}%")

# Key comparison: full vs truncated
print("\n--- Key Comparison ---")
full_nlls = np.array([r['full_nll'] for r in truncated_results])
trunc_nlls = np.array([r['truncated_nll'] for r in truncated_results])

print(f"Does removing surrogate at inference help?")
print(f"  Truncated better than full context: {np.mean(trunc_nlls < full_nlls)*100:.1f}%")
print(f"  Mean improvement over full: {np.mean(full_nlls - trunc_nlls):+.3f}")

t_stat, p_value = stats.ttest_rel(full_nlls, trunc_nlls)
print(f"  t={t_stat:.3f}, p={p_value:.4f}")

## Experiment 5: Deep Dive on Poor Baseline Cases

Analyze the cases where the baseline performs poorly (NLL > 3.0), since these are where surrogates might actually help.

In [None]:
# Identify poor baseline cases
poor_baseline = [r for r in diagnostic_results if r['baseline_nll'] > 3.0]

print("=" * 80)
print(f"POOR BASELINE CASES (NLL > 3.0): {len(poor_baseline)} samples")
print("(These are cases where the document alone does not predict the answer well)")
print("=" * 80)

if poor_baseline:
    print(f"\n{'Condition':<25} {'Mean NLL':>10} {'vs Baseline':>12} {'Win Rate':>10}")
    print("-" * 60)
    
    baseline_poor = np.array([r['baseline_nll'] for r in poor_baseline])
    print(f"{'BASELINE (doc only)':<25} {np.mean(baseline_poor):>10.3f} {'---':>12} {'---':>10}")
    
    for name in condition_names:
        nlls = np.array([r[f'{name}_nll'] for r in poor_baseline])
        delta = baseline_poor - nlls
        win_rate = np.mean(delta > 0)
        
        print(f"{name:<25} {np.mean(nlls):>10.3f} {np.mean(delta):>+12.3f} {win_rate*100:>9.1f}%")

    # Show examples
    print("\n--- Examples of Poor Baseline Cases ---")
    for r in sorted(poor_baseline, key=lambda x: x['baseline_nll'], reverse=True)[:5]:
        print(f"\nQuery: {r['query']}")
        print(f"  Baseline NLL: {r['baseline_nll']:.3f}")
        print(f"  Perfect surrogate NLL: {r['perfect_surrogate_nll']:.3f} (delta: {r['baseline_nll']-r['perfect_surrogate_nll']:+.3f})")
        print(f"  Surrogate NLL: {r['surr_original_nll']:.3f} (delta: {r['baseline_nll']-r['surr_original_nll']:+.3f})")

## Summary and Conclusions

In [None]:
print("=" * 80)
print("SUMMARY OF FINDINGS")
print("(Baseline = document KV cache computed in isolation)")
print("=" * 80)

# Key metrics
baseline_mean = np.mean([r['baseline_nll'] for r in diagnostic_results])
perfect_mean = np.mean([r['perfect_surrogate_nll'] for r in diagnostic_results])
surrogate_mean = np.mean([r['surr_original_nll'] for r in diagnostic_results])
bare_mean = np.mean([r['bare_doc_no_label_nll'] for r in diagnostic_results])

perfect_win = np.mean([r['baseline_nll'] > r['perfect_surrogate_nll'] for r in diagnostic_results])
surrogate_win = np.mean([r['baseline_nll'] > r['surr_original_nll'] for r in diagnostic_results])

print(f"\n1. BASELINE (document in isolation)")
print(f"   Template: \"Document:\\n{{document}}\"")
print(f"   Mean NLL: {baseline_mean:.3f}")
print(f"   Control (bare doc, no label): {bare_mean:.3f}")

print(f"\n2. PERFECT SURROGATE (actual query placed in surrogate slot)")
print(f"   (NOT oracle selection — this tests the best possible surrogate content)")
print(f"   Mean NLL: {perfect_mean:.3f}")
print(f"   Win rate vs baseline: {perfect_win*100:.1f}%")
if perfect_win > 0.5:
    print(f"   -> Perfect surrogate beats baseline. Surrogate quality is the bottleneck.")
else:
    print(f"   -> Even the actual query doesn't beat baseline.")
    print(f"   -> Problem is fundamental to the approach, not surrogate quality.")

print(f"\n3. GENERATED SURROGATE (original template)")
print(f"   Mean NLL: {surrogate_mean:.3f}")
print(f"   Win rate vs baseline: {surrogate_win*100:.1f}%")

print(f"\n4. BY BASELINE QUALITY")
poor_wins = [r for r in diagnostic_results if r['baseline_nll'] > 3.0 and r['baseline_nll'] > r['perfect_surrogate_nll']]
good_wins = [r for r in diagnostic_results if r['baseline_nll'] < 1.0 and r['baseline_nll'] > r['perfect_surrogate_nll']]
poor_total = len([r for r in diagnostic_results if r['baseline_nll'] > 3.0])
good_total = len([r for r in diagnostic_results if r['baseline_nll'] < 1.0])
print(f"   Poor baseline (>3.0): Perfect surr. wins {len(poor_wins)}/{poor_total} ({len(poor_wins)/max(poor_total,1)*100:.1f}%)")
print(f"   Good baseline (<1.0): Perfect surr. wins {len(good_wins)}/{good_total} ({len(good_wins)/max(good_total,1)*100:.1f}%)")

if truncated_results:
    print(f"\n5. TRUNCATED CACHE (surrogate removed after cache generation)")
    trunc_win = np.mean([r['baseline_nll'] > r['truncated_nll'] for r in truncated_results])
    full_win = np.mean([r['baseline_nll'] > r['full_nll'] for r in truncated_results])
    print(f"   Full context win rate vs baseline: {full_win*100:.1f}%")
    print(f"   Truncated cache win rate vs baseline: {trunc_win*100:.1f}%")
    if trunc_win > full_win:
        print(f"   -> Removing surrogate at inference helps!")
    else:
        print(f"   -> Truncation does not improve on the full-context condition")

In [None]:
# Save results
output = {
    'config': {
        'model_name': config.model_name,
        'num_samples': len(diagnostic_results),
        'conditions_tested': list(SURROGATE_CONDITIONS.keys()),
        'baseline': 'Document:\n{document} (document KV cache in isolation)',
        'terminology_note': (
            'perfect_surrogate = actual query used in surrogate slot (upper bound on surrogate quality). '
            'This is different from "oracle" in production_simulation_experiment.ipynb, which means '
            'best-of-N surrogate selection in hindsight.'
        ),
    },
    'diagnostic_results': diagnostic_results,
    'truncated_results': truncated_results,
    'summary': {
        'baseline_mean_nll': baseline_mean,
        'perfect_surrogate_mean_nll': perfect_mean,
        'perfect_surrogate_win_rate': perfect_win,
        'surrogate_mean_nll': surrogate_mean,
        'surrogate_win_rate': surrogate_win,
    }
}

with open('diagnostic_experiment_results.json', 'w') as f:
    json.dump(output, f, indent=2, default=str)

print("Results saved to: diagnostic_experiment_results.json")

In [None]:
# Visualization
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle("All conditions vs Baseline (document in isolation)", fontsize=13, fontweight='bold')

# Plot 1: NLL by condition
ax1 = axes[0, 0]
all_names = ['baseline'] + condition_names
all_means = [np.mean(baseline_nlls)] + [np.mean([r[f'{name}_nll'] for r in diagnostic_results]) for name in condition_names]
all_stds = [np.std(baseline_nlls)] + [np.std([r[f'{name}_nll'] for r in diagnostic_results]) for name in condition_names]
colors = ['green'] + ['steelblue'] * len(condition_names)
x = range(len(all_names))
ax1.bar(x, all_means, yerr=all_stds, capsize=3, alpha=0.7, color=colors)
ax1.set_xticks(x)
ax1.set_xticklabels(all_names, rotation=45, ha='right', fontsize=8)
ax1.set_ylabel('Mean NLL')
ax1.set_title('NLL by Condition (lower is better)')
ax1.axhline(y=all_means[0], color='r', linestyle='--', label='Baseline')
ax1.legend()

# Plot 2: Win rate vs baseline
ax2 = axes[0, 1]
win_rates = []
for name in condition_names:
    nlls = np.array([r[f'{name}_nll'] for r in diagnostic_results])
    win_rates.append(np.mean(baseline_nlls > nlls) * 100)
ax2.bar(range(len(condition_names)), win_rates, alpha=0.7, color='steelblue')
ax2.set_xticks(range(len(condition_names)))
ax2.set_xticklabels(condition_names, rotation=45, ha='right', fontsize=8)
ax2.set_ylabel('Win Rate vs Baseline (%)')
ax2.set_title('Win Rate vs Baseline (doc in isolation)')
ax2.axhline(y=50, color='r', linestyle='--', label='50% (no effect)')
ax2.legend()

# Plot 3: Baseline NLL vs Perfect Surrogate Delta
ax3 = axes[1, 0]
perfect_deltas = [r['baseline_nll'] - r['perfect_surrogate_nll'] for r in diagnostic_results]
ax3.scatter(baseline_nlls, perfect_deltas, alpha=0.5, s=10)
ax3.axhline(y=0, color='r', linestyle='--')
ax3.set_xlabel('Baseline NLL (doc in isolation)')
ax3.set_ylabel('Delta (Baseline - Perfect Surrogate)')
ax3.set_title('Does perfect surrogate help more when baseline is worse?\n(positive = perfect surrogate better than baseline)')

# Plot 4: Distribution of deltas vs baseline
ax4 = axes[1, 1]
surrogate_deltas = [r['baseline_nll'] - r['surr_original_nll'] for r in diagnostic_results]
ax4.hist(perfect_deltas, bins=30, alpha=0.5, label='Perfect Surr. vs Baseline', color='blue')
ax4.hist(surrogate_deltas, bins=30, alpha=0.5, label='Gen. Surrogate vs Baseline', color='orange')
ax4.axvline(x=0, color='r', linestyle='--')
ax4.set_xlabel('Delta vs Baseline (positive = beats baseline)')
ax4.set_ylabel('Count')
ax4.set_title('Distribution of Improvement over Baseline')
ax4.legend()

plt.tight_layout()
plt.savefig('diagnostic_results.png', dpi=150, bbox_inches='tight')
plt.show()
print("Saved: diagnostic_results.png")