# Experiment 11: Surrogate Quality Gradient & Ranking Evaluation

## Motivation

Across experiments 01-10, we have established:

1. **Oracle surrogates help dramatically** (NLL ~0.4 vs ~1.5 for the bare baseline): the KV cache mechanism CAN be directed
2. **Generated surrogates show no semantic signal** (r=0.924 correlation with shuffled text, Exp 06)
3. **Any prefix helps equally** regardless of content — the benefit is positional/structural, not semantic
4. **Truncation mechanics are correct** after bug fixes in Exp 05 (RoPE dimension pairing, BOS preservation, BPE boundaries)
5. **Value contamination is the channel**, but we can't cleanly isolate individual layers (Exp 10)

The fundamental unresolved question: **Is the lack of semantic signal because generated surrogates are too low quality (~0.66 similarity), or because the KV cache mechanism fundamentally cannot transmit semantic information through value contamination?**

## Approach: Bridge the Oracle Gap with Real Queries

Previous experiments used either model-generated surrogates (similarity ~0.66) or the oracle query itself (similarity 1.0), leaving a large unexplored range between the two. We fill it by using **real MS MARCO queries** as surrogates: queries that were asked about OTHER passages but happen to be semantically similar to the target query to varying degrees.

This is also closer to the production scenario: in an ad-serving system, we'd have a corpus of historical queries we could use as surrogates.

## Three Investigations

### Investigation A: Quality Gradient (N=300)
For each sample, find real queries in 5 similarity bins (0.0-0.3, 0.3-0.5, 0.5-0.7, 0.7-0.85, 0.85-1.0) and use them as surrogates, alongside the oracle and bare baselines. This tells us whether there is a quality threshold above which semantic priming starts working.
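The bins are half-open intervals `[low, high)`, matching the `(sims >= sim_low) & (sims < sim_high)` mask in `find_surrogate_at_similarity`. The minimal sketch below (the `assign_bin` helper is illustrative, not part of `lib`) makes two consequences explicit: a candidate with similarity exactly 1.0 (an identical query) falls outside every bin, including `very_high`, and so does a candidate with negative cosine similarity.

```python
from typing import List, Optional, Tuple

# Same bin edges as SIMILARITY_BINS in the configuration cell
SIMILARITY_BINS: List[Tuple[float, float, str]] = [
    (0.00, 0.30, 'very_low'),
    (0.30, 0.50, 'low'),
    (0.50, 0.70, 'medium'),
    (0.70, 0.85, 'high'),
    (0.85, 1.00, 'very_high'),
]

def assign_bin(sim: float) -> Optional[str]:
    """Return the bin name for a cosine similarity, or None if it lies outside all bins.

    Bins are half-open [low, high), mirroring the mask used for surrogate selection.
    """
    for low, high, name in SIMILARITY_BINS:
        if low <= sim < high:
            return name
    return None

for s in [-0.05, 0.0, 0.30, 0.69, 0.85, 0.999, 1.0]:
    print(f"{s:>6.3f} -> {assign_bin(s)}")
```

Edge values land in the upper bin (0.30 is `low`, 0.85 is `very_high`), so each candidate belongs to at most one bin.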

### Investigation B: Ranking Task (N=300)
The NLL metric measures absolute quality, but in ad serving what matters is **ranking**: does priming help the model rank the relevant ad above distractors? For each passage, we score 5 queries (1 correct + 4 distractors) and measure MRR and Hit@1.
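Lower NLL means the model finds a candidate more plausible, so ranking sorts candidates by ascending NLL. A minimal sketch of the two metrics, assuming each sample's scores are stored as a list of NLLs with the correct query at index 0 (this layout and the `rank_of_correct` / `mrr_and_hit1` helpers are assumptions for illustration, not the notebook's code):

```python
import numpy as np

def rank_of_correct(nlls, correct_idx: int = 0) -> int:
    """1-based rank of the correct candidate when sorting by ascending NLL.

    Ties are resolved optimistically: only strictly lower NLLs push the correct item down.
    """
    nlls = np.asarray(nlls, dtype=float)
    return 1 + int(np.sum(nlls < nlls[correct_idx]))

def mrr_and_hit1(per_sample_nlls, correct_idx: int = 0):
    """Mean reciprocal rank and Hit@1 over samples."""
    ranks = np.array([rank_of_correct(n, correct_idx) for n in per_sample_nlls])
    mrr = float(np.mean(1.0 / ranks))
    hit1 = float(np.mean(ranks == 1))
    return mrr, hit1

# Three toy samples, each laid out as [correct, d1, d2, d3, d4]
toy = [
    [0.8, 1.5, 1.9, 2.2, 1.1],  # correct has the lowest NLL -> rank 1
    [1.4, 1.2, 2.0, 1.8, 1.6],  # one distractor is lower -> rank 2
    [2.0, 1.0, 1.5, 1.7, 1.9],  # all four distractors are lower -> rank 5
]
mrr, hit1 = mrr_and_hit1(toy)
print(f"MRR={mrr:.3f}, Hit@1={hit1:.3f}")
```

With 5 candidates, a random ranker has an expected MRR of (1 + 1/2 + 1/3 + 1/4 + 1/5) / 5 ≈ 0.457 and Hit@1 of 0.2, which is a useful floor when reading the bare vs. primed results. Counting ties against the correct query instead would give a pessimistic variant.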

### Investigation C: Same-Passage Surrogates (N=200)
MS MARCO contains passages that were retrieved for multiple different queries. We use one query as the surrogate and evaluate with another. This is the most realistic production scenario: we know which queries a document has been relevant to in the past.
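A minimal sketch of the pairing, mirroring the Investigation C cell: records are grouped by the first 200 characters of the passage (a cheap near-duplicate key), and every ordered pair of distinct queries on the same passage becomes one (surrogate, test) sample. The `build_same_passage_pairs` helper and the toy records are illustrative only.

```python
from collections import defaultdict
from itertools import permutations

def build_same_passage_pairs(records, key_chars: int = 200):
    """Group (passage, query, answer) records by passage prefix and emit ordered
    (surrogate_query, test_query, test_answer) samples for distinct query pairs."""
    by_passage = defaultdict(list)
    for rec in records:
        by_passage[rec['passage'].strip()[:key_chars]].append(rec)

    pairs = []
    for items in by_passage.values():
        # Ordered pairs: (A as surrogate, B as test) and (B as surrogate, A as test)
        for surr, test in permutations(items, 2):
            if surr['query'] != test['query']:
                pairs.append({
                    'passage': surr['passage'],
                    'surrogate_query': surr['query'],
                    'test_query': test['query'],
                    'test_answer': test['answer'],
                })
    return pairs

toy_records = [
    {'passage': 'Paris is the capital of France.', 'query': 'capital of france', 'answer': 'Paris'},
    {'passage': 'Paris is the capital of France.', 'query': 'what city is the french capital', 'answer': 'Paris'},
    {'passage': 'Water boils at 100 C at sea level.', 'query': 'boiling point of water', 'answer': '100 C'},
]
pairs = build_same_passage_pairs(toy_records)
print(len(pairs))
```

A passage with k distinct queries contributes k(k-1) ordered samples, so passages with many queries can dominate the pool unless sampling is capped per passage.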

## Key Controls
- All caches use truncation + RoPE correction (production-relevant setting)
- BPE-matched bare/truncated comparison via `build_matched_bare_and_truncated`
- Irrelevant real-query baseline (analogous to the Exp 06 random-passage control, but using real queries)
- All conditions run on same samples for paired comparison
- Bare document baseline uses no framing (confirmed better than "Document:\n" in Exp 06)
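The first control relies on rotary position embeddings composing additively: a key cached at position `p + offset` can be moved back to position `p` by rotating it through `-offset`. The NumPy sketch below illustrates that math only; it is not the implementation of `correct_rope_positions_with_bos`. It assumes the half-split (`rotate_half`) pairing used by Llama-style Hugging Face models and a base of 10000 (the model's actual `rope_theta` may differ).

```python
import numpy as np

def rope_angles(positions, head_dim: int, base: float = 10000.0):
    """Angles theta[p, i] = p * base^(-2i/d) for i in [0, d/2)."""
    inv_freq = base ** (-np.arange(0, head_dim, 2) / head_dim)  # (d/2,)
    return np.outer(positions, inv_freq)                          # (n, d/2)

def apply_rope(x, positions, base: float = 10000.0):
    """Rotate x (n, d) with the half-split convention: dim i pairs with dim i + d/2."""
    d = x.shape[-1]
    ang = rope_angles(positions, d, base)
    cos = np.concatenate([np.cos(ang), np.cos(ang)], axis=-1)
    sin = np.concatenate([np.sin(ang), np.sin(ang)], axis=-1)
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    rotate_half = np.concatenate([-x2, x1], axis=-1)
    return x * cos + rotate_half * sin

rng = np.random.default_rng(0)
n_doc, head_dim, offset = 6, 8, 11   # offset plays the role of prefix_len - 1
raw_keys = rng.standard_normal((n_doc, head_dim))

doc_positions = np.arange(1, n_doc + 1)                   # positions right after BOS
cached = apply_rope(raw_keys, doc_positions + offset)     # keys as cached after the prefix
corrected = apply_rope(cached, np.full(n_doc, -offset))   # rotate back by -offset
reference = apply_rope(raw_keys, doc_positions)           # keys a bare cache would hold

print(np.allclose(corrected, reference))
```

The correction only realigns positions. In a real model the pre-rotation keys and values also differ because the document tokens attended to the prefix; that residual influence is exactly the "value contamination" channel the experiment is probing. Using the wrong dimension pairing (adjacent vs. half-split) would break the identity above, which is the class of bug fixed in Exp 05.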

## Bug Log (from initial run)

### Bug 11.1: BPE Token Mismatch (CRITICAL — same class as Exp 05 bug)
- `build_bare_cache_no_framing()` tokenized the passage independently
- `build_truncated_cache_from_prefix()` tokenized `prefix + passage` together, producing different BPE tokens at the join boundary
- `build_matched_bare_and_truncated()` was defined correctly but NEVER CALLED in eval loops
- **Fix**: All eval loops now use `build_matched_bare_and_truncated()` which extracts document token IDs from the concatenated encoding and builds bare cache from those exact IDs
- **Impact**: Bare and truncated caches were being compared on different token sequences, invalidating all delta measurements
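The mismatch can be reproduced with a toy tokenizer. The sketch below uses greedy longest-match over a tiny hypothetical vocabulary (not real BPE and not the model's tokenizer) in which a space can merge with the following character. Tokenizing the passage on its own gives a different sequence than the document slice of the joint encoding; building the bare sequence from that slice, as `build_matched_bare_and_truncated()` does, makes the two identical by construction.

```python
from typing import List, Set

BOS = "<s>"

def greedy_tokenize(text: str, vocab: Set[str]) -> List[str]:
    """Toy tokenizer: repeatedly take the longest vocabulary entry matching at the cursor."""
    max_len = max(len(v) for v in vocab)
    tokens, i = [], 0
    while i < len(text):
        for length in range(min(max_len, len(text) - i), 0, -1):
            piece = text[i:i + length]
            if piece in vocab:
                tokens.append(piece)
                i += length
                break
        else:
            raise ValueError(f"cannot tokenize {text[i]!r}")
    return tokens

def encode(text: str, vocab: Set[str]) -> List[str]:
    return [BOS] + greedy_tokenize(text, vocab)  # mimics add_special_tokens=True

vocab = {"x", " ", " a", "a", "b", "ab"}
prefix_with_sep, passage = "x ", "ab"

prefix_len = len(encode(prefix_with_sep, vocab))     # <s>, "x", " "  -> 3
full_ids = encode(prefix_with_sep + passage, vocab)  # <s>, "x", " a", "b"
doc_slice = full_ids[prefix_len:]                    # ["b"]

independent_bare = encode(passage, vocab)            # Bug 11.1 pattern: <s>, "ab"
matched_bare = [BOS] + doc_slice                     # fixed pattern: <s>, "b"
truncated = [BOS] + full_ids[prefix_len:]            # what truncation keeps

print(independent_bare, matched_bare, truncated)
```

The toy also exposes a subtler hazard: because `" a"` straddles the join, the slice starts at `"b"` and part of the passage is counted as prefix. Matching guarantees that the bare and truncated caches agree with each other, not that the slice covers the whole passage, so decoding the slice and comparing it with the passage is a cheap extra check.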

### Bug 11.2: Cache Mutation in Investigation B (CRITICAL)
- `score_answer_with_cache()` extends the cache in-place via `use_cache=True`
- Investigation B scored 5 queries against the same cache object sequentially
- Queries 2-5 saw a cache contaminated by previous queries' KV entries
- **Fix**: Deep-copy (`copy.deepcopy()`) the cache before each `score_answer_with_cache()` call
- **Impact**: Ranking results for bare/oracle/medium caches were all corrupted
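The failure mode can be reproduced without a model. In the sketch below, `ToyCache` and `score_with_cache` are hypothetical stand-ins: scoring appends the candidate's entries to the cache in place, as a forward pass with `use_cache=True` extends the cache object it is given. Sharing one cache lets each candidate see the previous candidates' entries; deep-copying before every call restores an identical starting context.

```python
import copy
from typing import List

class ToyCache:
    """Hypothetical stand-in for a KV cache: a growing list of per-token entries."""
    def __init__(self, entries: List[str]):
        self.entries = list(entries)

def score_with_cache(cache: ToyCache, candidate: List[str]) -> int:
    """Pretend scorer: appends the candidate's tokens in place (like use_cache=True)
    and returns how many context entries preceded the candidate."""
    context_len = len(cache.entries)
    cache.entries.extend(candidate)
    return context_len

doc_cache = ToyCache(["<s>", "doc1", "doc2"])
candidates = [["q1a", "q1b"], ["q2a"], ["q3a", "q3b", "q3c"]]

# Buggy: one shared cache, so later candidates see earlier candidates' entries
shared = copy.deepcopy(doc_cache)
buggy = [score_with_cache(shared, c) for c in candidates]

# Fixed: deep-copy the document cache before every call
fixed = [score_with_cache(copy.deepcopy(doc_cache), c) for c in candidates]

print("buggy context lengths:", buggy)
print("fixed context lengths:", fixed)
```

Deep-copying a real cache duplicates every layer's key and value tensors, so its cost grows with passage length and the number of candidates. If the installed `transformers` version provides `DynamicCache.crop`, cropping back to the original length after each call is a possible cheaper alternative.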

### Bug 11.3: Variable Cache Lengths Across Conditions
- Different prefix lengths cause different BPE splits, leading to different `keep_len` values
- **Fix**: `build_matched_bare_and_truncated()` asserts `bare_len == keep_len` for each condition; per-condition matched pairs ensure fair comparison

In [None]:
import sys
import os
import json
import copy
import time
import datetime
from typing import Dict, List, Any, Optional, Tuple

import torch
import numpy as np
from tqdm.auto import tqdm
from scipy import stats
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['figure.dpi'] = 120

sys.path.insert(0, '.')

from lib import (
    ExperimentConfig,
    build_kv_cache,
    score_answer_with_cache,
    build_truncated_kv_cache_corrected,
    extract_and_truncate_cache_with_bos,
    correct_rope_positions_with_bos,
    load_evaluation_samples,
    load_ms_marco,
    _ensure_dynamic_cache,
)
from lib.kv_cache import _get_cache_keys, _get_cache_values

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, DynamicCache
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from datasets import load_dataset

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# ============================================================
# Configuration
# ============================================================

config = ExperimentConfig(
    num_samples=5000,  # Evaluation samples to load; the surrogate query pool is built from the full dataset below
    min_passage_words=50,
    max_passage_words=300,
    seed=42,
)

N_INVESTIGATION_A = 300
N_INVESTIGATION_B = 300
N_INVESTIGATION_C = 200

# Similarity bins for Investigation A
SIMILARITY_BINS = [
    (0.00, 0.30, 'very_low'),
    (0.30, 0.50, 'low'),
    (0.50, 0.70, 'medium'),
    (0.70, 0.85, 'high'),
    (0.85, 1.00, 'very_high'),
]

torch.manual_seed(config.seed)
np.random.seed(config.seed)

print(f"Investigation A: {N_INVESTIGATION_A} samples x {len(SIMILARITY_BINS)+2} conditions")
print(f"Investigation B: {N_INVESTIGATION_B} samples x ranking evaluation")
print(f"Investigation C: {N_INVESTIGATION_C} samples (same-passage surrogates)")

In [None]:
# ============================================================
# Load Model
# ============================================================

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
)

print(f"Loading {config.model_name}...")
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    quantization_config=bnb_config,
    device_map="auto",
)
model.eval()
print(f"Model loaded. Layers: {model.config.num_hidden_layers}")

# Load embedding model for similarity computation
print("Loading embedding model...")
embed_model = SentenceTransformer('all-MiniLM-L6-v2')
print("Embedding model loaded.")

In [None]:
# ============================================================
# Load Dataset and Build Query Pool
# ============================================================

print("Loading MS MARCO dataset...")
dataset = load_dataset(
    config.dataset_name,
    config.dataset_config,
    split=config.dataset_split,
    trust_remote_code=True
)
print(f"Dataset loaded: {len(dataset)} samples")

# Load evaluation samples (with answers)
all_samples = load_evaluation_samples(dataset, config, require_answer=True)
print(f"Loaded {len(all_samples)} evaluation samples with answers")

# Build query pool from the FULL dataset for maximum coverage
# in high-similarity bins
print("\nBuilding query pool (full dataset)...")
query_pool = []
seen_queries = set()
for item in tqdm(dataset, desc="Extracting queries"):
    q = item.get('query', '').strip()
    if q and q not in seen_queries and len(q) > 10:
        query_pool.append(q)
        seen_queries.add(q)

print(f"Query pool size: {len(query_pool)}")

# Embed all pool queries
print("Embedding query pool...")
pool_embeddings = embed_model.encode(query_pool, show_progress_bar=True, batch_size=256)
print(f"Pool embeddings shape: {pool_embeddings.shape}")

In [None]:
# ============================================================
# Build Same-Passage Query Pairs for Investigation C
# ============================================================

# Find passages that appear with multiple different queries
print("Finding passages with multiple queries...")
passage_to_queries = {}  # passage_text -> list of (query, answer) tuples

for item in tqdm(dataset, desc="Building passage-query map"):
    passages = item.get('passages', {})
    passage_texts = passages.get('passage_text', [])
    is_selected = passages.get('is_selected', [])
    query = item.get('query', '').strip()
    answers = item.get('answers', [])
    well_formed = item.get('wellFormedAnswers', [])

    if not query:
        continue

    # Get best answer
    answer = None
    if well_formed and len(well_formed) > 0 and well_formed[0] != '[]':
        answer = well_formed[0]
    elif answers and len(answers) > 0 and answers[0] != 'No Answer Present.':
        answer = answers[0]

    if answer is None:
        continue

    for i, passage in enumerate(passage_texts):
        if is_selected and i < len(is_selected) and is_selected[i] == 1:
            word_count = len(passage.split())
            if config.min_passage_words <= word_count <= config.max_passage_words:
                key = passage.strip()[:200]  # Use prefix as key to handle near-dupes
                if key not in passage_to_queries:
                    passage_to_queries[key] = []
                passage_to_queries[key].append({
                    'passage': passage,
                    'query': query,
                    'answer': answer
                })
            break

# Filter to passages with 2+ distinct queries
multi_query_passages = {
    k: v for k, v in passage_to_queries.items()
    if len(v) >= 2 and len(set(item['query'] for item in v)) >= 2
}

print(f"Passages with 2+ queries: {len(multi_query_passages)}")
print(f"Ordered query pairs available (upper bound, before de-duplicating queries): {sum(len(v) * (len(v)-1) for v in multi_query_passages.values())}")

# Build Investigation C samples: (passage, surrogate_query, test_query, test_answer)
inv_c_samples = []
for key, items in multi_query_passages.items():
    # Enumerate every ordered (surrogate, test) pair of distinct queries for this passage
    for i in range(len(items)):
        for j in range(len(items)):
            if i != j and items[i]['query'] != items[j]['query']:
                inv_c_samples.append({
                    'passage': items[i]['passage'],
                    'surrogate_query': items[i]['query'],
                    'test_query': items[j]['query'],
                    'test_answer': items[j]['answer'],
                })
    if len(inv_c_samples) >= N_INVESTIGATION_C * 3:
        break

np.random.shuffle(inv_c_samples)
inv_c_samples = inv_c_samples[:N_INVESTIGATION_C]
print(f"Investigation C samples: {len(inv_c_samples)}")

# Compute similarity between surrogate and test queries
if inv_c_samples:
    surr_qs = [s['surrogate_query'] for s in inv_c_samples]
    test_qs = [s['test_query'] for s in inv_c_samples]
    surr_embs = embed_model.encode(surr_qs)
    test_embs = embed_model.encode(test_qs)
    pair_sims = [float(cosine_similarity([surr_embs[i]], [test_embs[i]])[0][0])
                 for i in range(len(inv_c_samples))]
    for i, s in enumerate(inv_c_samples):
        s['surrogate_similarity'] = pair_sims[i]
    print(f"Same-passage pair similarity: mean={np.mean(pair_sims):.3f}, "
          f"std={np.std(pair_sims):.3f}, range=[{np.min(pair_sims):.3f}, {np.max(pair_sims):.3f}]")

In [None]:
# ============================================================
# Helper Functions
# ============================================================

def find_surrogate_at_similarity(
    target_query: str,
    target_embedding: np.ndarray,
    sim_low: float,
    sim_high: float,
    pool_queries: list,
    pool_embs: np.ndarray,
    rng: np.random.RandomState,
) -> Optional[Tuple[str, float]]:
    """Find a real query from the pool within the specified similarity range.
    
    Returns (query, similarity) or None if no match found.
    """
    sims = cosine_similarity([target_embedding], pool_embs)[0]
    mask = (sims >= sim_low) & (sims < sim_high)
    # Exclude exact match
    for idx in np.where(mask)[0]:
        if pool_queries[idx].strip().lower() == target_query.strip().lower():
            mask[idx] = False
    
    candidates = np.where(mask)[0]
    if len(candidates) == 0:
        return None
    
    # Pick one near the middle of the bin for stability
    mid = (sim_low + sim_high) / 2
    distances_to_mid = np.abs(sims[candidates] - mid)
    # Pick from the closest 10% to the bin midpoint (or at least 1)
    n_pick = max(1, len(candidates) // 10)
    best_idxs = candidates[np.argsort(distances_to_mid)[:n_pick]]
    chosen = rng.choice(best_idxs)
    return pool_queries[chosen], float(sims[chosen])


def build_bare_cache_no_framing(passage, model, tokenizer, config):
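    """Build bare document cache WITHOUT 'Document:\n' framing.
    
    Exp 06 confirmed bare passage (no framing) is the correct baseline.
    NOTE: tokenizes the passage independently, so its tokens can differ
    from a truncated cache's document tokens (Bug 11.1). Eval loops use
    build_matched_bare_and_truncated() instead.
    """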
    length, cache = build_kv_cache(passage, model, tokenizer, config)
    return length, _ensure_dynamic_cache(cache)


def build_truncated_cache_from_prefix(
    prefix_text: str,
    passage: str,
    model,
    tokenizer,
    config,
) -> Tuple[int, DynamicCache]:
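    """Build a truncated+corrected cache from arbitrary prefix text.
    
    Uses the passage directly (no 'Document:\n' framing) to match
    the bare baseline. NOTE: a bare cache built separately (e.g. with
    build_bare_cache_no_framing) does not share this function's document
    tokens at the join boundary (Bug 11.1); eval loops use
    build_matched_bare_and_truncated() instead.
    
    Returns: (keep_len, corrected_cache)
    """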
    prefix_with_sep = prefix_text.strip() + " "
    
    # Tokenize prefix with BOS to get exact prefix length
    prefix_encoding = tokenizer(
        prefix_with_sep, return_tensors="pt", add_special_tokens=True,
        padding=False, truncation=False
    )
    prefix_len = prefix_encoding['input_ids'].shape[1]
    
    # Tokenize full context
    full_context = prefix_with_sep + passage
    full_encoding = tokenizer(
        full_context, return_tensors="pt", add_special_tokens=True,
        padding=False, truncation=False
    )
    full_ids = full_encoding['input_ids'].to(config.device)
    doc_len = full_ids.shape[1] - prefix_len
    
    # Generate full KV cache
    with torch.no_grad():
        outputs = model(
            input_ids=full_ids,
            attention_mask=torch.ones_like(full_ids),
            use_cache=True,
            return_dict=True
        )
    
    # Truncate: BOS + document portion
    truncated = extract_and_truncate_cache_with_bos(
        outputs.past_key_values, doc_len
    )
    keep_len = 1 + doc_len
    
    # RoPE correction
    surrogate_offset = prefix_len - 1
    correct_rope_positions_with_bos(truncated, surrogate_offset, model)
    
    return keep_len, truncated


def build_matched_bare_and_truncated(
    prefix_text: str,
    passage: str,
    model,
    tokenizer,
    config,
) -> Tuple[int, DynamicCache, int, DynamicCache]:
    """Build BPE-matched bare and truncated caches.
    
    Ensures both caches see identical document tokens by extracting
    the document token IDs from the full [prefix + passage] encoding
    and building the bare cache from those exact IDs.
    
    No 'Document:\n' framing (bare passage only).
    
    Returns: (bare_len, bare_cache, trunc_len, trunc_cache)
    """
    prefix_with_sep = prefix_text.strip() + " "
    
    # Tokenize prefix with BOS
    prefix_encoding = tokenizer(
        prefix_with_sep, return_tensors="pt", add_special_tokens=True,
        padding=False, truncation=False
    )
    prefix_len = prefix_encoding['input_ids'].shape[1]
    
    # Tokenize full context
    full_context = prefix_with_sep + passage
    full_encoding = tokenizer(
        full_context, return_tensors="pt", add_special_tokens=True,
        padding=False, truncation=False
    )
    full_ids = full_encoding['input_ids'].to(config.device)
    doc_len = full_ids.shape[1] - prefix_len
    
    # Extract exact document token IDs
    doc_token_ids = full_ids[:, prefix_len:]  # (1, doc_len)
    bos_id = full_ids[:, :1]
    bare_ids = torch.cat([bos_id, doc_token_ids], dim=1)  # (1, 1+doc_len)
    bare_len = bare_ids.shape[1]
    
    # Build bare cache
    with torch.no_grad():
        bare_out = model(
            input_ids=bare_ids,
            attention_mask=torch.ones_like(bare_ids),
            use_cache=True,
            return_dict=True
        )
    bare_cache = _ensure_dynamic_cache(bare_out.past_key_values)
    
    # Build truncated cache from full context
    with torch.no_grad():
        full_out = model(
            input_ids=full_ids,
            attention_mask=torch.ones_like(full_ids),
            use_cache=True,
            return_dict=True
        )
    
    truncated = extract_and_truncate_cache_with_bos(
        full_out.past_key_values, doc_len
    )
    keep_len = 1 + doc_len
    
    assert bare_len == keep_len, f"Length mismatch: {bare_len} vs {keep_len}"
    
    # RoPE correction
    surrogate_offset = prefix_len - 1
    correct_rope_positions_with_bos(truncated, surrogate_offset, model)
    
    return bare_len, bare_cache, keep_len, truncated

In [None]:
# ============================================================
# Investigation A: Surrogate Quality Gradient
# ============================================================
#
# For each sample, we find real queries at 5 similarity levels,
# build truncated+corrected caches, and compare against bare.
#
# Conditions:
#   1. bare       — no prefix, no framing
#   2. oracle     — actual target query as prefix
#   3-7. sim bins — real queries at varying similarity levels
#   8. irrelevant: a real query with sim < 0.1 (random baseline);
#      NOTE: not computed in the evaluation loop below
# ============================================================

print("="*80)
print("INVESTIGATION A: SURROGATE QUALITY GRADIENT")
print("="*80)

samples_a = all_samples[:N_INVESTIGATION_A]

# Pre-embed all target queries
print("Embedding target queries...")
target_queries_a = [s['query'] for s in samples_a]
target_embeddings_a = embed_model.encode(target_queries_a, batch_size=128)
print(f"Embedded {len(target_queries_a)} target queries")

# Pre-select surrogates for each sample at each similarity level
print("\nSelecting surrogates at each similarity level...")
rng = np.random.RandomState(config.seed)

sample_surrogates = []  # list of dicts: {bin_name: (query, sim), ...}
skipped_bins = {b[2]: 0 for b in SIMILARITY_BINS}

for i in tqdm(range(len(samples_a)), desc="Selecting surrogates"):
    surr = {}
    for sim_low, sim_high, bin_name in SIMILARITY_BINS:
        result = find_surrogate_at_similarity(
            target_queries_a[i], target_embeddings_a[i],
            sim_low, sim_high,
            query_pool, pool_embeddings, rng
        )
        if result is not None:
            surr[bin_name] = result
        else:
            skipped_bins[bin_name] += 1
    sample_surrogates.append(surr)

print("\nSurrogate coverage per bin:")
for sim_low, sim_high, bin_name in SIMILARITY_BINS:
    n_found = N_INVESTIGATION_A - skipped_bins[bin_name]
    print(f"  {bin_name} ({sim_low:.2f}-{sim_high:.2f}): {n_found}/{N_INVESTIGATION_A} samples")

In [None]:
# ============================================================
# Investigation A: Run Evaluation
# ============================================================
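# FIX 11.1: Use build_matched_bare_and_truncated for BPE-matched
# comparison. The oracle call also supplies the reference bare cache.
# Each bin's cache is built with the same helper (its own bare cache
# is discarded), so bin deltas are taken against the oracle-matched
# bare; this assumes every prefix yields the same document tokens.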
# ============================================================

results_a = []
skipped_a = 0
errors_a = 0
start_a = time.time()

CHECKPOINT_PATH_A = 'results/exp11/11_checkpoint_a.json'
os.makedirs(os.path.dirname(CHECKPOINT_PATH_A), exist_ok=True)  # also used for figure output
start_idx_a = 0
if os.path.exists(CHECKPOINT_PATH_A):
    with open(CHECKPOINT_PATH_A) as f:
        ckpt = json.load(f)
    results_a = ckpt['results']
    skipped_a = ckpt['skipped']
    errors_a = ckpt['errors']
    start_idx_a = ckpt['next_idx']
    print(f"Resumed from checkpoint: {len(results_a)} results, starting at idx {start_idx_a}")

for idx in tqdm(range(start_idx_a, len(samples_a)), desc="Inv A",
                initial=start_idx_a, total=len(samples_a)):
    sample = samples_a[idx]
    passage = sample['passage']
    query = sample['query']
    answer = sample['answer']
    
    # Skip short answers
    answer_ids = tokenizer(answer, return_tensors='pt', add_special_tokens=False)['input_ids']
    if answer_ids.shape[1] < 2:
        skipped_a += 1
        continue
    
    query_prompt = config.query_template.format(query=query)
    surrogates = sample_surrogates[idx]
    
    try:
        result = {'idx': idx, 'query': query}
        
        # --- Oracle (also provides BPE-matched bare baseline) ---
        oracle_prefix = f"This document answers: {query}"
        bare_len, bare_cache, oracle_len, oracle_cache = build_matched_bare_and_truncated(
            oracle_prefix, passage, model, tokenizer, config
        )
        
        nll_bare = score_answer_with_cache(
            bare_cache, bare_len, query_prompt, answer,
            model, tokenizer, config
        )
        result['nll_bare'] = nll_bare
        
        nll_oracle = score_answer_with_cache(
            oracle_cache, oracle_len, query_prompt, answer,
            model, tokenizer, config
        )
        result['nll_oracle'] = nll_oracle
        
        # --- Similarity bins: matched pairs for each ---
        for sim_low, sim_high, bin_name in SIMILARITY_BINS:
            if bin_name in surrogates:
                surr_query, surr_sim = surrogates[bin_name]
                surr_prefix = f"This document answers: {surr_query}"
                _, _, surr_len, surr_cache = build_matched_bare_and_truncated(
                    surr_prefix, passage, model, tokenizer, config
                )
                nll_surr = score_answer_with_cache(
                    surr_cache, surr_len, query_prompt, answer,
                    model, tokenizer, config
                )
                result[f'nll_{bin_name}'] = nll_surr
                result[f'sim_{bin_name}'] = surr_sim
                result[f'surr_{bin_name}'] = surr_query
            else:
                result[f'nll_{bin_name}'] = None
                result[f'sim_{bin_name}'] = None
        
        results_a.append(result)
        
    except Exception as e:
        errors_a += 1
        if errors_a <= 5:
            print(f"\n  Error on sample {idx}: {e}")
        continue
    finally:
        torch.cuda.empty_cache()
    
    # Checkpoint
    if len(results_a) % 25 == 0:
        with open(CHECKPOINT_PATH_A, 'w') as f:
            json.dump({
                'results': results_a, 'skipped': skipped_a,
                'errors': errors_a, 'next_idx': idx + 1
            }, f)
        elapsed = time.time() - start_a
        rate = len(results_a) / (elapsed / 60)
        print(f"\n  [{len(results_a)} done | {elapsed/60:.0f}m | {rate:.1f} samples/min]")

elapsed_a = time.time() - start_a
print(f"\nDone. {len(results_a)} evaluated, {skipped_a} skipped, {errors_a} errors.")
print(f"Time: {elapsed_a/60:.1f} min")

In [None]:
# ============================================================
# Investigation A: Analysis
# ============================================================

print("="*80)
print("INVESTIGATION A RESULTS: SURROGATE QUALITY GRADIENT")
print("="*80)

bare_nlls_a = np.array([r['nll_bare'] for r in results_a])
oracle_nlls_a = np.array([r['nll_oracle'] for r in results_a])

# Oracle stats
oracle_deltas = bare_nlls_a - oracle_nlls_a
oracle_wr = np.mean(oracle_deltas > 0) * 100
t_oracle, p_oracle = stats.ttest_rel(bare_nlls_a, oracle_nlls_a)
d_oracle = np.mean(oracle_deltas) / np.std(oracle_deltas, ddof=1)

print(f"\n{'Condition':<25} {'N':>5} {'Mean NLL':>10} {'Win%':>8} {'Delta':>10} {'Cohen d':>10} {'p-value':>10}")
print("-" * 80)
print(f"{'Bare (baseline)':<25} {len(results_a):>5} {np.mean(bare_nlls_a):>10.4f} {'--':>8} {'--':>10} {'--':>10} {'--':>10}")
print(f"{'Oracle':<25} {len(results_a):>5} {np.mean(oracle_nlls_a):>10.4f} {oracle_wr:>7.1f}% {np.mean(oracle_deltas):>+10.4f} {d_oracle:>10.3f} {p_oracle:>10.6f}")

# Per-bin stats
bin_stats = {}
for sim_low, sim_high, bin_name in SIMILARITY_BINS:
    valid = [r for r in results_a if r.get(f'nll_{bin_name}') is not None]
    if len(valid) < 10:
        print(f"{bin_name:<25} {len(valid):>5} -- insufficient data")
        continue
    
    nlls = np.array([r[f'nll_{bin_name}'] for r in valid])
    bares = np.array([r['nll_bare'] for r in valid])
    sims = np.array([r[f'sim_{bin_name}'] for r in valid])
    deltas = bares - nlls
    wr = np.mean(deltas > 0) * 100
    t, p = stats.ttest_rel(bares, nlls)
    d = np.mean(deltas) / np.std(deltas, ddof=1) if np.std(deltas) > 0 else 0
    
    bin_stats[bin_name] = {
        'n': len(valid),
        'mean_nll': float(np.mean(nlls)),
        'mean_sim': float(np.mean(sims)),
        'win_rate': float(wr),
        'mean_delta': float(np.mean(deltas)),
        'cohens_d': float(d),
        'p_value': float(p),
    }
    
    label = f"{bin_name} ({np.mean(sims):.2f})"
    print(f"{label:<25} {len(valid):>5} {np.mean(nlls):>10.4f} {wr:>7.1f}% {np.mean(deltas):>+10.4f} {d:>10.3f} {p:>10.6f}")

# Key test: does similarity predict delta?
print("\n" + "="*80)
print("CRITICAL TEST: Does surrogate quality predict improvement?")
print("="*80)

# Aggregate all (sim, delta) pairs across bins
all_sims = []
all_deltas_a = []
for r in results_a:
    for _, _, bin_name in SIMILARITY_BINS:
        if r.get(f'nll_{bin_name}') is not None and r.get(f'sim_{bin_name}') is not None:
            all_sims.append(r[f'sim_{bin_name}'])
            all_deltas_a.append(r['nll_bare'] - r[f'nll_{bin_name}'])

all_sims = np.array(all_sims)
all_deltas_a = np.array(all_deltas_a)

r_pearson, p_pearson = stats.pearsonr(all_sims, all_deltas_a)
r_spearman, p_spearman = stats.spearmanr(all_sims, all_deltas_a)

print(f"\nTotal (sim, delta) pairs: {len(all_sims)}")
print(f"Pearson r = {r_pearson:.4f}, p = {p_pearson:.6f}")
print(f"Spearman rho = {r_spearman:.4f}, p = {p_spearman:.6f}")

if r_pearson > 0.1 and p_pearson < 0.01:
    print("\n>>> SEMANTIC SIGNAL DETECTED: Higher similarity surrogates produce better caches.")
    print("    This would be the first evidence of a genuine semantic priming effect.")
elif r_pearson > 0.05 and p_pearson < 0.05:
    print("\n>>> WEAK SIGNAL: Marginal positive correlation. May warrant larger sample.")
else:
    print("\n>>> NO SEMANTIC SIGNAL: Surrogate quality does not predict improvement.")
    print("    Consistent with prior experiments (Exp 06: r=0.924 with shuffled).")

# Monotonicity test: do the bin means increase with similarity?
if len(bin_stats) >= 3:
    ordered_bins = [bn for _, _, bn in SIMILARITY_BINS if bn in bin_stats]
    bin_sims_ordered = [bin_stats[b]['mean_sim'] for b in ordered_bins]
    bin_deltas_ordered = [bin_stats[b]['mean_delta'] for b in ordered_bins]
    rho_bins, p_bins = stats.spearmanr(bin_sims_ordered, bin_deltas_ordered)
    print(f"\nBin-level monotonicity (Spearman on bin means): rho={rho_bins:.3f}, p={p_bins:.4f}")
    if rho_bins > 0.8 and p_bins < 0.1:
        print("  -> Clear monotonic trend: better surrogates = better caches")
    else:
        print("  -> No clear monotonic trend across bins")

In [None]:
# ============================================================
# Investigation A: Visualization
# ============================================================

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Plot 1: Bar chart of win rates by condition
ax = axes[0]
conditions = ['bare']
win_rates_plot = [50.0]  # bare vs. itself: plotted at chance (50%) as a reference
for _, _, bn in SIMILARITY_BINS:
    if bn in bin_stats:
        conditions.append(f"{bn}\n({bin_stats[bn]['mean_sim']:.2f})")
        win_rates_plot.append(bin_stats[bn]['win_rate'])
conditions.append('oracle')
win_rates_plot.append(oracle_wr)

colors = ['#888888'] + ['#4c72b0'] * len(bin_stats) + ['#c44e52']
bars = ax.bar(range(len(conditions)), win_rates_plot, color=colors)
ax.axhline(50, color='gray', linestyle='--', linewidth=0.8)
ax.set_xticks(range(len(conditions)))
ax.set_xticklabels(conditions, fontsize=7, rotation=30, ha='right')
ax.set_ylabel('Win Rate vs Bare (%)')
ax.set_title('Win Rate by Surrogate Quality')
for i, wr in enumerate(win_rates_plot):
    ax.text(i, wr + 1, f'{wr:.0f}%', ha='center', fontsize=7)

# Plot 2: Scatter of similarity vs delta
ax = axes[1]
ax.scatter(all_sims, all_deltas_a, alpha=0.1, s=5, color='#4c72b0')
# Bin means
for _, _, bn in SIMILARITY_BINS:
    if bn in bin_stats:
        ax.scatter(bin_stats[bn]['mean_sim'], bin_stats[bn]['mean_delta'],
                  s=100, color='#c44e52', zorder=5, edgecolor='black')
ax.axhline(0, color='gray', linestyle='--', linewidth=0.8)
# Regression line
z = np.polyfit(all_sims, all_deltas_a, 1)
p_fit = np.poly1d(z)
x_line = np.linspace(all_sims.min(), all_sims.max(), 100)
ax.plot(x_line, p_fit(x_line), 'r-', linewidth=2, label=f'r={r_pearson:.3f}')
ax.set_xlabel('Surrogate-Query Similarity')
ax.set_ylabel('Delta NLL (positive = priming helped)')
ax.set_title(f'Similarity vs Improvement\n(Pearson r={r_pearson:.3f}, p={p_pearson:.4f})')
ax.legend()

# Plot 3: Mean NLL by condition (gradient from bare to oracle)
ax = axes[2]
cond_labels = ['bare']
cond_nlls = [np.mean(bare_nlls_a)]
cond_errs = [np.std(bare_nlls_a) / np.sqrt(len(bare_nlls_a))]
for _, _, bn in SIMILARITY_BINS:
    if bn in bin_stats:
        cond_labels.append(f"{bn}\n({bin_stats[bn]['mean_sim']:.2f})")
        cond_nlls.append(bin_stats[bn]['mean_nll'])
        valid = [r for r in results_a if r.get(f'nll_{bn}') is not None]
        cond_errs.append(np.std([r[f'nll_{bn}'] for r in valid]) / np.sqrt(len(valid)))
cond_labels.append('oracle')
cond_nlls.append(np.mean(oracle_nlls_a))
cond_errs.append(np.std(oracle_nlls_a) / np.sqrt(len(oracle_nlls_a)))

ax.errorbar(range(len(cond_labels)), cond_nlls, yerr=cond_errs,
           fmt='o-', color='#4c72b0', capsize=3)
ax.set_xticks(range(len(cond_labels)))
ax.set_xticklabels(cond_labels, fontsize=7, rotation=30, ha='right')
ax.set_ylabel('Mean NLL (lower = better)')
ax.set_title('NLL Gradient from Bare to Oracle')

plt.tight_layout()
plt.savefig('results/exp11/11_investigation_a.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 11_investigation_a.png')

In [None]:
# ============================================================
# Investigation B: Ranking Task
# ============================================================
#
# For ad serving, what matters is RANKING: given a page and
# multiple candidate ads (queries), does priming help rank the
# relevant ad higher?
#
# Setup:
#   - Each sample has 1 correct query + 4 distractor queries
#   - Score each query using bare cache and primed cache
#   - Measure MRR and Hit@1 for both conditions
#   - Priming uses the ORACLE query (best case) and a medium-sim
#     real query (realistic case)
# ============================================================

# Ensure query pool is available (built in cell-4)
assert 'query_pool' in dir() and len(query_pool) > 0, (
    "query_pool not found. Please run cell-4 (Load Dataset and Build Query Pool) first."
)
assert 'pool_embeddings' in dir(), (
    "pool_embeddings not found. Please run cell-4 first."
)

print("="*80)
print("INVESTIGATION B: RANKING TASK")
print("="*80)

samples_b = all_samples[:N_INVESTIGATION_B]

# For each sample, select 4 distractor queries ranging from off-topic
# to moderately related (similarity 0.1-0.6 to the target)
print("Selecting distractor queries...")
target_queries_b = [s['query'] for s in samples_b]
target_embeddings_b = embed_model.encode(target_queries_b, batch_size=128)

rng_b = np.random.RandomState(config.seed + 100)

sample_distractors = []  # list of [query, ...] x 4
for i in tqdm(range(len(samples_b)), desc="Selecting distractors"):
    sims = cosine_similarity([target_embeddings_b[i]], pool_embeddings)[0]
    
    # Pick 4 distractors with varying similarity:
    # 2 low similarity (0.1-0.3) and 2 medium similarity (0.3-0.6)
    # This makes the ranking task non-trivial
    distractors = []
    for sim_lo, sim_hi, n_pick in [(0.1, 0.3, 2), (0.3, 0.6, 2)]:
        mask = (sims >= sim_lo) & (sims < sim_hi)
        candidates = np.where(mask)[0]
        if len(candidates) >= n_pick:
            chosen = rng_b.choice(candidates, size=n_pick, replace=False)
            for c in chosen:
                distractors.append(query_pool[c])
        else:
            # Fall back to whatever is available
            for c in candidates[:n_pick]:
                distractors.append(query_pool[c])
    
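    # Pad with random pool queries if a band ran short, skipping
    # duplicates and the target query itself
    while len(distractors) < 4:
        rand_idx = rng_b.randint(0, len(query_pool))
        cand = query_pool[rand_idx]
        if cand != target_queries_b[i] and cand not in distractors:
            distractors.append(cand)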
    
    sample_distractors.append(distractors[:4])

print(f"Selected distractors for {len(sample_distractors)} samples")
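
The band-based selection above can be factored into a standalone helper so that its exclusion and de-duplication rules are explicit and testable. This is a minimal sketch, not the notebook's code: `sample_band_distractors` is a hypothetical name, it works on pool indices rather than query strings, and it assumes `sims` is a 1-D array of cosine similarities between one target query and every pool query (as `cosine_similarity(...)[0]` produces). The `exclude` argument is meant for indices that must never be picked, such as the target itself if the pool can contain it.

```python
import numpy as np

def sample_band_distractors(sims, bands, n_total, rng, exclude=()):
    """Pick pool indices whose similarity falls in each [lo, hi) band.

    sims:    1-D array of similarities between the target and each pool item
    bands:   list of (lo, hi, n_pick) tuples
    n_total: number of indices to return (padded with random picks if needed)
    rng:     np.random.RandomState used for all random choices
    exclude: pool indices that must never be chosen
    """
    excluded = set(exclude)
    chosen = []
    for lo, hi, n_pick in bands:
        in_band = np.where((sims >= lo) & (sims < hi))[0]
        candidates = [i for i in in_band if i not in excluded and i not in chosen]
        take = min(n_pick, len(candidates))
        if take > 0:
            chosen.extend(rng.choice(candidates, size=take, replace=False).tolist())
    # If a band ran dry, pad with random indices that are neither excluded
    # nor already chosen (so no duplicates can appear)
    remaining = [i for i in range(len(sims)) if i not in excluded and i not in chosen]
    while len(chosen) < n_total and remaining:
        chosen.append(remaining.pop(rng.randint(len(remaining))))
    return chosen[:n_total]

rng = np.random.RandomState(0)
sims = np.array([0.95, 0.15, 0.25, 0.4, 0.5, 0.05, 0.35])
print(sample_band_distractors(sims, [(0.1, 0.3, 2), (0.3, 0.6, 2)], 4, rng, exclude=[0]))
```

Building `remaining` scans the whole pool once per call, which is fine for a few hundred samples; a much larger pool would call for rejection sampling instead.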

In [None]:
# ============================================================
# Investigation B: Run Ranking Evaluation
# ============================================================
# FIX 11.1: Use build_matched_bare_and_truncated for BPE matching
# FIX 11.2: Deep-copy cache before each score_answer_with_cache call
#           to prevent cache mutation across queries
# ============================================================

results_b = []
skipped_b = 0
errors_b = 0
start_b = time.time()

CHECKPOINT_PATH_B = 'results/exp11/11_checkpoint_b.json'
start_idx_b = 0
if os.path.exists(CHECKPOINT_PATH_B):
    with open(CHECKPOINT_PATH_B) as f:
        ckpt = json.load(f)
    results_b = ckpt['results']
    skipped_b = ckpt['skipped']
    errors_b = ckpt['errors']
    start_idx_b = ckpt['next_idx']
    print(f"Resumed from checkpoint: {len(results_b)} results")

for idx in tqdm(range(start_idx_b, len(samples_b)), desc="Inv B",
                initial=start_idx_b, total=len(samples_b)):
    sample = samples_b[idx]
    passage = sample['passage']
    query = sample['query']
    answer = sample['answer']
    distractors = sample_distractors[idx]
    
    answer_ids = tokenizer(answer, return_tensors='pt', add_special_tokens=False)['input_ids']
    if answer_ids.shape[1] < 2:
        skipped_b += 1
        continue
    
    # All candidate queries (correct first, then distractors)
    all_queries = [query] + distractors
    
    try:
        result = {'idx': idx}
        
        # Build BPE-matched caches using oracle prefix
        oracle_prefix = f"This document answers: {query}"
        bare_len, bare_cache, oracle_len, oracle_cache = build_matched_bare_and_truncated(
            oracle_prefix, passage, model, tokenizer, config
        )
        
        # Find a medium-similarity real query for this sample, reusing the
        # Inv A surrogates (only available for idx < len(sample_surrogates))
        surrogates_for_b = sample_surrogates[idx] if idx < len(sample_surrogates) else {}
        medium_surr = None
        for bn in ['medium', 'high', 'low']:  # prefer medium, fall back
            if bn in surrogates_for_b:
                medium_surr = surrogates_for_b[bn][0]
                break
        
        medium_cache = None
        medium_len = None
        if medium_surr:
            medium_prefix = f"This document answers: {medium_surr}"
            _, _, medium_len, medium_cache = build_matched_bare_and_truncated(
                medium_prefix, passage, model, tokenizer, config
            )
        
        # Score each candidate query under each cache condition
        # CRITICAL: deep-copy cache before each call to prevent mutation
        bare_scores = []
        oracle_scores = []
        medium_scores = []
        
        for q in all_queries:
            q_prompt = config.query_template.format(query=q)
            
            nll_b = score_answer_with_cache(
                copy.deepcopy(bare_cache), bare_len, q_prompt, answer,
                model, tokenizer, config
            )
            bare_scores.append(nll_b)
            
            nll_o = score_answer_with_cache(
                copy.deepcopy(oracle_cache), oracle_len, q_prompt, answer,
                model, tokenizer, config
            )
            oracle_scores.append(nll_o)
            
            if medium_cache is not None:
                nll_m = score_answer_with_cache(
                    copy.deepcopy(medium_cache), medium_len, q_prompt, answer,
                    model, tokenizer, config
                )
                medium_scores.append(nll_m)
        
        # Compute rankings (rank by ascending NLL — lower = better match)
        bare_rank = int(np.argsort(bare_scores).tolist().index(0)) + 1
        oracle_rank = int(np.argsort(oracle_scores).tolist().index(0)) + 1
        medium_rank = int(np.argsort(medium_scores).tolist().index(0)) + 1 if medium_scores else None
        
        result['bare_scores'] = bare_scores
        result['oracle_scores'] = oracle_scores
        result['medium_scores'] = medium_scores
        result['bare_rank'] = bare_rank
        result['oracle_rank'] = oracle_rank
        result['medium_rank'] = medium_rank
        result['has_medium'] = medium_cache is not None
        
        results_b.append(result)
        
    except Exception as e:
        errors_b += 1
        if errors_b <= 5:
            print(f"\n  Error on sample {idx}: {e}")
        continue
    finally:
        torch.cuda.empty_cache()
    
    if len(results_b) % 25 == 0:
        with open(CHECKPOINT_PATH_B, 'w') as f:
            json.dump({
                'results': results_b, 'skipped': skipped_b,
                'errors': errors_b, 'next_idx': idx + 1
            }, f)
        elapsed = time.time() - start_b
        print(f"\n  [{len(results_b)} done | {elapsed/60:.0f}m]")

elapsed_b = time.time() - start_b
print(f"\nDone. {len(results_b)} evaluated, {skipped_b} skipped, {errors_b} errors.")
print(f"Time: {elapsed_b/60:.1f} min")
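
The checkpoint in the loop above is written with `open(path, 'w')` directly, so an interruption during `json.dump` could leave a truncated file that fails to load on resume. A common remedy is to write to a temporary file in the same directory and swap it in with `os.replace`, which is atomic on POSIX filesystems. The helpers below (`save_checkpoint_atomic`, `load_checkpoint`) are hypothetical names, sketched under the assumption that the payload is the same JSON-serializable dict the notebook already writes.

```python
import json
import os
import tempfile

def save_checkpoint_atomic(path, payload):
    """Write JSON to a temp file next to `path`, then atomically replace `path`.

    If the process dies mid-write, the previous checkpoint stays intact.
    """
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix='.tmp')
    try:
        with os.fdopen(fd, 'w') as f:
            json.dump(payload, f)
            f.flush()
            os.fsync(f.fileno())      # make sure bytes hit disk before the swap
        os.replace(tmp_path, path)    # atomic rename on POSIX
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def load_checkpoint(path, default):
    """Return the saved payload, or `default` if no checkpoint exists yet."""
    if not os.path.exists(path):
        return default
    with open(path) as f:
        return json.load(f)
```

Calling the save helper once more after the loop finishes would also capture the final partial batch, which the every-25-results rule can leave out of the checkpoint.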

In [None]:
# ============================================================
# Investigation B: Analysis
# ============================================================

print("="*80)
print("INVESTIGATION B RESULTS: RANKING TASK")
print("="*80)

n_b = len(results_b)

# MRR and Hit@1
bare_ranks = np.array([r['bare_rank'] for r in results_b])
oracle_ranks = np.array([r['oracle_rank'] for r in results_b])
medium_results = [r for r in results_b if r['has_medium']]
medium_ranks = np.array([r['medium_rank'] for r in medium_results]) if medium_results else np.array([])

bare_mrr = np.mean(1.0 / bare_ranks)
bare_hit1 = np.mean(bare_ranks == 1)

oracle_mrr = np.mean(1.0 / oracle_ranks)
oracle_hit1 = np.mean(oracle_ranks == 1)

print(f"\n{'Condition':<20} {'N':>5} {'MRR':>8} {'Hit@1':>8} {'Hit@3':>8} {'Mean Rank':>10}")
print("-" * 65)
print(f"{'Bare':<20} {n_b:>5} {bare_mrr:>8.3f} {bare_hit1:>7.1%} {np.mean(bare_ranks <= 3):>7.1%} {np.mean(bare_ranks):>10.2f}")
print(f"{'Oracle primed':<20} {n_b:>5} {oracle_mrr:>8.3f} {oracle_hit1:>7.1%} {np.mean(oracle_ranks <= 3):>7.1%} {np.mean(oracle_ranks):>10.2f}")

if len(medium_ranks) > 0:
    medium_mrr = np.mean(1.0 / medium_ranks)
    medium_hit1 = np.mean(medium_ranks == 1)
    print(f"{'Medium primed':<20} {len(medium_results):>5} {medium_mrr:>8.3f} {medium_hit1:>7.1%} {np.mean(medium_ranks <= 3):>7.1%} {np.mean(medium_ranks):>10.2f}")

# Statistical tests
print("\nStatistical tests (Wilcoxon signed-rank on ranks):")
# Bare vs Oracle
if not np.all(bare_ranks == oracle_ranks):
    stat_bo, p_bo = stats.wilcoxon(bare_ranks, oracle_ranks)
    print(f"  Bare vs Oracle: W={stat_bo:.0f}, p={p_bo:.6f}")
else:
    print(f"  Bare vs Oracle: identical rankings")

# Bare vs Medium (paired on the subset of samples that had a medium surrogate)
if len(medium_ranks) > 0:
    bare_sub = np.array([r['bare_rank'] for r in medium_results])
    if not np.all(bare_sub == medium_ranks):
        stat_bm, p_bm = stats.wilcoxon(bare_sub, medium_ranks)
        print(f"  Bare vs Medium: W={stat_bm:.0f}, p={p_bm:.6f}")
    else:
        print("  Bare vs Medium: identical rankings")

# Interpretation
print("\nInterpretation:")
if oracle_mrr > bare_mrr + 0.05:
    print(f"  Oracle priming improves ranking (MRR: {bare_mrr:.3f} -> {oracle_mrr:.3f})")
    if len(medium_ranks) > 0 and medium_mrr > bare_mrr + 0.02:
        print(f"  Medium-sim surrogates also help ranking (MRR: {medium_mrr:.3f})")
        print("  -> Practical benefit for ad serving with historical query surrogates")
    else:
        print("  But medium-sim surrogates do NOT improve ranking.")
        print("  -> Oracle effect does not transfer to realistic surrogates.")
else:
    print(f"  No ranking improvement from priming (MRR bare={bare_mrr:.3f}, oracle={oracle_mrr:.3f})")
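
The ranking metrics above reduce to two small steps: find the 1-based rank of the correct query when candidates are sorted by ascending NLL, then average `1/rank` (MRR) or `rank <= k` (Hit@k). The sketch below restates them as hypothetical standalone helpers. One assumption worth flagging: `rank_of_correct` counts how many distractors have strictly lower NLL, which matches the notebook's `argsort` rank when there are no ties but makes tie handling explicit (ties are otherwise broken arbitrarily by the default sort). `chance_mrr` gives the random-ordering reference: with 1 correct query and 4 distractors, a random ranking yields MRR of about 0.457 and Hit@1 of 0.2, which is a useful floor when reading the bare-cache numbers.

```python
import numpy as np

def rank_of_correct(scores, correct_idx=0, pessimistic_ties=False):
    """1-based rank of the correct candidate under ascending NLL.

    Optimistic ties (default): distractors with equal NLL do not outrank it.
    Pessimistic ties: equal NLL counts as being outranked.
    """
    scores = np.asarray(scores, dtype=float)
    target = scores[correct_idx]
    others = np.delete(scores, correct_idx)
    beaten_by = np.sum(others <= target) if pessimistic_ties else np.sum(others < target)
    return int(beaten_by) + 1

def ranking_metrics(ranks, ks=(1, 3)):
    """MRR and Hit@k over an array of 1-based ranks."""
    ranks = np.asarray(ranks, dtype=float)
    out = {'mrr': float(np.mean(1.0 / ranks))}
    for k in ks:
        out[f'hit@{k}'] = float(np.mean(ranks <= k))
    return out

def chance_mrr(n_candidates):
    """Expected MRR when the correct item's rank is uniform over 1..n."""
    return float(np.mean(1.0 / np.arange(1, n_candidates + 1)))

print(rank_of_correct([1.2, 0.8, 1.5, 2.0, 1.1]))   # two distractors score lower
print(ranking_metrics([1, 2, 4]))
print(chance_mrr(5))
```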

In [None]:
# ============================================================
# Investigation B: Visualization
# ============================================================

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Plot 1: Rank distribution
ax = axes[0]
rank_labels = [1, 2, 3, 4, 5]
bare_dist = [np.sum(bare_ranks == r) / n_b * 100 for r in rank_labels]
oracle_dist = [np.sum(oracle_ranks == r) / n_b * 100 for r in rank_labels]
x = np.arange(len(rank_labels))
w = 0.35
ax.bar(x - w/2, bare_dist, w, label='Bare', color='#4c72b0')
ax.bar(x + w/2, oracle_dist, w, label='Oracle', color='#c44e52')
ax.set_xticks(x)
ax.set_xticklabels(rank_labels)
ax.set_xlabel('Rank of Correct Query')
ax.set_ylabel('% of Samples')
ax.set_title('Rank Distribution')
ax.legend()

# Plot 2: MRR comparison
ax = axes[1]
conditions_mrr = ['Bare', 'Oracle']
mrrs = [bare_mrr, oracle_mrr]
if len(medium_ranks) > 0:
    conditions_mrr.append('Medium')
    mrrs.append(medium_mrr)
colors_mrr = ['#4c72b0', '#c44e52', '#55a868'][:len(conditions_mrr)]
ax.bar(range(len(conditions_mrr)), mrrs, color=colors_mrr)
ax.set_xticks(range(len(conditions_mrr)))
ax.set_xticklabels(conditions_mrr)
ax.set_ylabel('Mean Reciprocal Rank')
ax.set_title('MRR by Cache Condition')
ax.set_ylim(0, 1)
for i, m in enumerate(mrrs):
    ax.text(i, m + 0.02, f'{m:.3f}', ha='center', fontsize=10)

# Plot 3: Per-sample rank change (bare -> oracle)
ax = axes[2]
rank_change = bare_ranks - oracle_ranks  # positive = oracle improved ranking
change_vals, change_counts = np.unique(rank_change, return_counts=True)
colors_change = ['#c44e52' if v < 0 else '#55a868' if v > 0 else '#888888' for v in change_vals]
ax.bar(change_vals, change_counts / n_b * 100, color=colors_change)
ax.set_xlabel('Rank Change (positive = oracle better)')
ax.set_ylabel('% of Samples')
ax.set_title(f'Rank Change: Bare -> Oracle\n(improved: {np.mean(rank_change > 0)*100:.0f}%, '
             f'worse: {np.mean(rank_change < 0)*100:.0f}%)')

plt.tight_layout()
plt.savefig('results/exp11/11_investigation_b.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 11_investigation_b.png')

In [None]:
# ============================================================
# Investigation C: Same-Passage Surrogates
# ============================================================
# FIX 11.1: Use build_matched_bare_and_truncated for BPE matching
# Each condition builds a matched pair; bare from oracle match
# is the reference baseline.
# ============================================================

# Ensure prerequisites
assert 'query_pool' in dir() and len(query_pool) > 0, (
    "query_pool not found. Please run cell-4 first."
)
assert 'inv_c_samples' in dir() and len(inv_c_samples) > 0, (
    "inv_c_samples not found. Please run cell-5 first."
)

print("="*80)
print("INVESTIGATION C: SAME-PASSAGE SURROGATES")
print("="*80)

results_c = []
skipped_c = 0
errors_c = 0
start_c = time.time()

CHECKPOINT_PATH_C = 'results/exp11/11_checkpoint_c.json'
start_idx_c = 0
if os.path.exists(CHECKPOINT_PATH_C):
    with open(CHECKPOINT_PATH_C) as f:
        ckpt = json.load(f)
    results_c = ckpt['results']
    skipped_c = ckpt['skipped']
    errors_c = ckpt['errors']
    start_idx_c = ckpt['next_idx']
    print(f"Resumed from checkpoint: {len(results_c)} results")

# Pre-select irrelevant queries for control
rng_c = np.random.RandomState(config.seed + 200)
irrelevant_queries_c = [query_pool[rng_c.randint(0, len(query_pool))]
                        for _ in range(len(inv_c_samples))]

for idx in tqdm(range(start_idx_c, len(inv_c_samples)), desc="Inv C",
                initial=start_idx_c, total=len(inv_c_samples)):
    sample = inv_c_samples[idx]
    passage = sample['passage']
    surrogate_query = sample['surrogate_query']
    test_query = sample['test_query']
    test_answer = sample['test_answer']
    surr_sim = sample['surrogate_similarity']
    
    answer_ids = tokenizer(test_answer, return_tensors='pt', add_special_tokens=False)['input_ids']
    if answer_ids.shape[1] < 2:
        skipped_c += 1
        continue
    
    query_prompt = config.query_template.format(query=test_query)
    
    try:
        result = {
            'idx': idx,
            'surrogate_query': surrogate_query,
            'test_query': test_query,
            'surrogate_similarity': surr_sim,
        }
        
        # 1. Oracle (also provides BPE-matched bare baseline)
        oracle_prefix = f"This document answers: {test_query}"
        bare_len, bare_cache, oracle_len, oracle_cache = build_matched_bare_and_truncated(
            oracle_prefix, passage, model, tokenizer, config
        )
        
        nll_bare = score_answer_with_cache(
            bare_cache, bare_len, query_prompt, test_answer,
            model, tokenizer, config
        )
        result['nll_bare'] = nll_bare
        
        nll_oracle = score_answer_with_cache(
            oracle_cache, oracle_len, query_prompt, test_answer,
            model, tokenizer, config
        )
        result['nll_oracle'] = nll_oracle
        
        # 2. Same-passage primed
        surr_prefix = f"This document answers: {surrogate_query}"
        _, _, surr_len, surr_cache = build_matched_bare_and_truncated(
            surr_prefix, passage, model, tokenizer, config
        )
        nll_surr = score_answer_with_cache(
            surr_cache, surr_len, query_prompt, test_answer,
            model, tokenizer, config
        )
        result['nll_same_passage'] = nll_surr
        
        # 3. Irrelevant control
        irrel_prefix = f"This document answers: {irrelevant_queries_c[idx]}"
        _, _, irrel_len, irrel_cache = build_matched_bare_and_truncated(
            irrel_prefix, passage, model, tokenizer, config
        )
        nll_irrel = score_answer_with_cache(
            irrel_cache, irrel_len, query_prompt, test_answer,
            model, tokenizer, config
        )
        result['nll_irrelevant'] = nll_irrel
        
        results_c.append(result)
        
    except Exception as e:
        errors_c += 1
        if errors_c <= 5:
            print(f"\n  Error on sample {idx}: {e}")
        continue
    finally:
        torch.cuda.empty_cache()
    
    if len(results_c) % 25 == 0:
        with open(CHECKPOINT_PATH_C, 'w') as f:
            json.dump({
                'results': results_c, 'skipped': skipped_c,
                'errors': errors_c, 'next_idx': idx + 1
            }, f)
        elapsed = time.time() - start_c
        print(f"\n  [{len(results_c)} done | {elapsed/60:.0f}m]")

elapsed_c = time.time() - start_c
print(f"\nDone. {len(results_c)} evaluated, {skipped_c} skipped, {errors_c} errors.")
print(f"Time: {elapsed_c/60:.1f} min")

In [None]:
# ============================================================
# Investigation C: Analysis
# ============================================================

print("="*80)
print("INVESTIGATION C RESULTS: SAME-PASSAGE SURROGATES")
print("="*80)

n_c = len(results_c)
bare_nlls_c = np.array([r['nll_bare'] for r in results_c])
surr_nlls_c = np.array([r['nll_same_passage'] for r in results_c])
oracle_nlls_c = np.array([r['nll_oracle'] for r in results_c])
irrel_nlls_c = np.array([r['nll_irrelevant'] for r in results_c])
surr_sims_c = np.array([r['surrogate_similarity'] for r in results_c])

conditions_c = [
    ('Bare', bare_nlls_c),
    ('Same-passage', surr_nlls_c),
    ('Oracle', oracle_nlls_c),
    ('Irrelevant', irrel_nlls_c),
]

print(f"\nSame-passage surrogate similarity: mean={np.mean(surr_sims_c):.3f}, "
      f"std={np.std(surr_sims_c):.3f}")

print(f"\n{'Condition':<20} {'N':>5} {'Mean NLL':>10} {'Win% vs Bare':>14} {'Delta':>10} {'Cohen d':>10} {'p-value':>10}")
print("-" * 85)

for label, nlls in conditions_c:
    if label == 'Bare':
        print(f"{label:<20} {n_c:>5} {np.mean(nlls):>10.4f} {'--':>14} {'--':>10} {'--':>10} {'--':>10}")
    else:
        deltas = bare_nlls_c - nlls
        wr = np.mean(deltas > 0) * 100
        t, p = stats.ttest_rel(bare_nlls_c, nlls)
        d = np.mean(deltas) / np.std(deltas, ddof=1) if np.std(deltas) > 0 else 0
        print(f"{label:<20} {n_c:>5} {np.mean(nlls):>10.4f} {wr:>13.1f}% {np.mean(deltas):>+10.4f} {d:>10.3f} {p:>10.6f}")

# Key comparison: same-passage vs irrelevant
print("\n--- Key Comparisons ---")
t_si, p_si = stats.ttest_rel(surr_nlls_c, irrel_nlls_c)
surr_beats_irrel = np.mean(surr_nlls_c < irrel_nlls_c) * 100
print(f"Same-passage vs Irrelevant: t={t_si:.3f}, p={p_si:.4f}, "
      f"same-passage wins {surr_beats_irrel:.1f}%")

if p_si < 0.05 and surr_beats_irrel > 55:
    print("  >>> SEMANTIC SIGNAL: Same-passage surrogates significantly beat irrelevant ones.")
    print("      This means the content of the surrogate matters when it's a real query")
    print("      that was previously relevant to the same document.")
else:
    print("  >>> NO SEMANTIC SIGNAL even with real same-passage queries.")

# Correlation: does surrogate-test similarity predict delta?
delta_surr_c = bare_nlls_c - surr_nlls_c
r_c, p_rc = stats.pearsonr(surr_sims_c, delta_surr_c)
print(f"\nSimilarity-Delta correlation: r={r_c:.4f}, p={p_rc:.6f}")
if r_c > 0.1 and p_rc < 0.05:
    print("  -> More similar surrogates produce better caches")
else:
    print("  -> No correlation between surrogate quality and cache quality")
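
The effect size reported in the table above, `mean(deltas) / std(deltas, ddof=1)`, is the paired-samples variant usually written Cohen's d_z. It is tied directly to the paired t statistic by `t = d_z * sqrt(n)`, which gives a quick consistency check on the printed numbers. The helper below is a hypothetical sketch that bundles the per-condition quantities; it assumes lower values are better (as with NLL), so positive deltas mean the primed cache helped.

```python
import numpy as np
from scipy import stats

def paired_summary(baseline, treated):
    """Paired comparison where lower is better (e.g. NLL).

    Returns mean delta (baseline - treated), win rate, Cohen's d_z
    (mean delta / sample std of deltas), and the paired t-test p-value.
    """
    baseline = np.asarray(baseline, dtype=float)
    treated = np.asarray(treated, dtype=float)
    deltas = baseline - treated                 # positive = treatment lowered NLL
    sd = np.std(deltas, ddof=1)
    d_z = float(np.mean(deltas) / sd) if sd > 0 else 0.0
    _, p = stats.ttest_rel(baseline, treated)
    return {
        'mean_delta': float(np.mean(deltas)),
        'win_rate': float(np.mean(deltas > 0)),
        'd_z': d_z,
        'p': float(p),
    }

bare = [2.0, 1.5, 1.8, 2.2, 1.9]
primed = [1.7, 1.6, 1.5, 2.0, 1.4]
print(paired_summary(bare, primed))
```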

In [None]:
# ============================================================
# Investigation C: Visualization
# ============================================================

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Plot 1: Bar chart of conditions
ax = axes[0]
labels = ['Bare', 'Irrelevant', 'Same-Passage', 'Oracle']
nlls_means = [np.mean(bare_nlls_c), np.mean(irrel_nlls_c),
              np.mean(surr_nlls_c), np.mean(oracle_nlls_c)]
nlls_errs = [np.std(bare_nlls_c)/np.sqrt(n_c), np.std(irrel_nlls_c)/np.sqrt(n_c),
             np.std(surr_nlls_c)/np.sqrt(n_c), np.std(oracle_nlls_c)/np.sqrt(n_c)]
colors_c = ['#888888', '#8c564b', '#4c72b0', '#c44e52']
ax.bar(range(len(labels)), nlls_means, yerr=nlls_errs,
       color=colors_c, capsize=3)
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels, fontsize=9)
ax.set_ylabel('Mean NLL (lower = better)')
ax.set_title('Same-Passage Surrogates')

# Plot 2: Similarity vs Delta scatter
ax = axes[1]
ax.scatter(surr_sims_c, delta_surr_c, alpha=0.3, s=20, color='#4c72b0')
ax.axhline(0, color='gray', linestyle='--', linewidth=0.8)
z_c = np.polyfit(surr_sims_c, delta_surr_c, 1)
p_c = np.poly1d(z_c)
x_c = np.linspace(surr_sims_c.min(), surr_sims_c.max(), 100)
ax.plot(x_c, p_c(x_c), 'r-', linewidth=2)
ax.set_xlabel('Surrogate-Test Query Similarity')
ax.set_ylabel('Delta NLL (positive = priming helped)')
ax.set_title(f'Similarity vs Improvement\n(r={r_c:.3f}, p={p_rc:.4f})')

# Plot 3: Distribution of deltas
ax = axes[2]
ax.hist(delta_surr_c, bins=30, alpha=0.6, label='Same-Passage', color='#4c72b0')
ax.hist(bare_nlls_c - irrel_nlls_c, bins=30, alpha=0.6, label='Irrelevant', color='#8c564b')
ax.axvline(0, color='gray', linestyle='--')
ax.set_xlabel('Delta NLL vs Bare')
ax.set_ylabel('Count')
ax.set_title('Delta Distribution')
ax.legend()

plt.tight_layout()
plt.savefig('results/exp11/11_investigation_c.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 11_investigation_c.png')
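
Investigation A reports both Pearson and Spearman correlations between surrogate similarity and NLL delta, while Investigation C reports only Pearson. Spearman is rank-based, so it can pick up a monotone but nonlinear relationship and is less sensitive to a few large deltas. A minimal sketch of a helper reporting both, assuming `similarities` and `deltas` are aligned 1-D sequences such as `surr_sims_c` and `delta_surr_c`; the function name is hypothetical.

```python
import numpy as np
from scipy import stats

def similarity_delta_correlation(similarities, deltas):
    """Pearson and Spearman correlation between surrogate similarity and NLL delta."""
    similarities = np.asarray(similarities, dtype=float)
    deltas = np.asarray(deltas, dtype=float)
    r, p_r = stats.pearsonr(similarities, deltas)
    rho, p_rho = stats.spearmanr(similarities, deltas)
    return {'pearson_r': float(r), 'pearson_p': float(p_r),
            'spearman_rho': float(rho), 'spearman_p': float(p_rho)}

# A monotone but nonlinear toy relationship: Spearman is 1, Pearson is below 1
sims = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
print(similarity_delta_correlation(sims, sims ** 3))
```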

In [None]:
# ============================================================
# Comprehensive Summary
# ============================================================

print("="*80)
print("EXPERIMENT 11: COMPREHENSIVE SUMMARY")
print("="*80)

print("\n--- Investigation A: Surrogate Quality Gradient ---")
print(f"  Oracle win rate: {oracle_wr:.1f}% (ceiling)")
for _, _, bn in SIMILARITY_BINS:
    if bn in bin_stats:
        s = bin_stats[bn]
        print(f"  {bn} (sim ~{s['mean_sim']:.2f}): win%={s['win_rate']:.1f}%, d={s['cohens_d']:.3f}")
print(f"  Similarity-Delta correlation: r={r_pearson:.4f}, p={p_pearson:.6f}")

print("\n--- Investigation B: Ranking Task ---")
print(f"  Bare MRR: {bare_mrr:.3f}, Hit@1: {bare_hit1:.1%}")
print(f"  Oracle MRR: {oracle_mrr:.3f}, Hit@1: {oracle_hit1:.1%}")
if len(medium_ranks) > 0:
    print(f"  Medium MRR: {medium_mrr:.3f}, Hit@1: {medium_hit1:.1%}")

print("\n--- Investigation C: Same-Passage Surrogates ---")
print(f"  Same-passage vs Bare: win%={np.mean(bare_nlls_c > surr_nlls_c)*100:.1f}%")
print(f"  Same-passage vs Irrelevant: win%={surr_beats_irrel:.1f}% (p={p_si:.4f})")
print(f"  Similarity-Delta correlation: r={r_c:.4f}")

print("\n" + "="*80)
print("OVERALL VERDICT")
print("="*80)

# Determine overall conclusion
semantic_signal_a = r_pearson > 0.1 and p_pearson < 0.01
semantic_signal_c = p_si < 0.05 and surr_beats_irrel > 55
ranking_benefit = oracle_mrr > bare_mrr + 0.05

if semantic_signal_a and semantic_signal_c:
    print("\n  SEMANTIC PRIMING EFFECT CONFIRMED")
    print("  Higher-quality surrogates produce better KV caches.")
    print("  The effect was previously masked by low-quality generated surrogates.")
    if len(medium_ranks) > 0 and medium_mrr > bare_mrr + 0.02:
        print("  The effect transfers to ranking, making it practically useful.")
    else:
        print("  However, the effect may not be large enough for practical ranking benefit.")
elif semantic_signal_a or semantic_signal_c:
    print("\n  PARTIAL SEMANTIC SIGNAL")
    print("  Evidence is mixed across investigations.")
    if semantic_signal_a:
        print("  Inv A (quality gradient) shows a signal.")
    if semantic_signal_c:
        print("  Inv C (same-passage) shows a signal.")
else:
    print("\n  NO SEMANTIC PRIMING EFFECT")
    print("  Even with real queries at high similarity, no semantic signal emerges.")
    print("  This suggests the KV cache mechanism does not transmit semantic")
    print("  information through value contamination during truncated caching.")
    print("")
    print("  The oracle cache is truncated too, so its benefit likely depends on")
    print("  the prefix being the exact test query; whatever it leaves in the")
    print("  document cache does not carry over to merely similar surrogates.")

if ranking_benefit:
    print("\n  RANKING NOTE: Oracle priming does improve ranking,")
    print("  confirming the mechanism works in principle for ad serving.")
    if not (len(medium_ranks) > 0 and medium_mrr > bare_mrr + 0.02):
        print("  The challenge remains: practical surrogates are not close enough to oracle.")

In [None]:
# ============================================================
# Save All Results
# ============================================================

output = {
    'metadata': {
        'experiment': '11_surrogate_quality_gradient',
        'timestamp': datetime.datetime.now().isoformat(),
        'model_name': config.model_name,
        'seed': config.seed,
        'n_investigation_a': N_INVESTIGATION_A,
        'n_investigation_b': N_INVESTIGATION_B,
        'n_investigation_c': N_INVESTIGATION_C,
    },
    'investigation_a': {
        'n_evaluated': len(results_a),
        'n_skipped': skipped_a,
        'n_errors': errors_a,
        'bin_stats': bin_stats,
        'oracle_win_rate': float(oracle_wr),
        'pearson_r': float(r_pearson),
        'pearson_p': float(p_pearson),
        'spearman_rho': float(r_spearman),
        'spearman_p': float(p_spearman),
        'results': results_a,
    },
    'investigation_b': {
        'n_evaluated': len(results_b),
        'n_skipped': skipped_b,
        'n_errors': errors_b,
        'bare_mrr': float(bare_mrr),
        'oracle_mrr': float(oracle_mrr),
        'bare_hit1': float(bare_hit1),
        'oracle_hit1': float(oracle_hit1),
        'medium_mrr': float(medium_mrr) if len(medium_ranks) > 0 else None,
        'medium_hit1': float(medium_hit1) if len(medium_ranks) > 0 else None,
        'results': results_b,
    },
    'investigation_c': {
        'n_evaluated': len(results_c),
        'n_skipped': skipped_c,
        'n_errors': errors_c,
        'same_passage_vs_irrelevant_p': float(p_si),
        'same_passage_beats_irrelevant': float(surr_beats_irrel),
        'similarity_delta_r': float(r_c),
        'similarity_delta_p': float(p_rc),
        'results': results_c,
    },
    'semantic_signal_a': bool(semantic_signal_a),
    'semantic_signal_c': bool(semantic_signal_c),
    'ranking_benefit': bool(ranking_benefit),
}

output_path = 'results/exp11/11_results.json'
with open(output_path, 'w') as f:
    json.dump(output, f, indent=2, default=str)
print(f"Results saved to {output_path}")
print(f"File size: {os.path.getsize(output_path) / 1e6:.1f} MB")

# Summary visualization
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# A: Quality gradient
ax = axes[0]
sims_plot = [0.0]  # bare
deltas_plot = [0.0]
for _, _, bn in SIMILARITY_BINS:
    if bn in bin_stats:
        sims_plot.append(bin_stats[bn]['mean_sim'])
        deltas_plot.append(bin_stats[bn]['mean_delta'])
sims_plot.append(1.0)
deltas_plot.append(float(np.mean(oracle_deltas)))
ax.plot(sims_plot, deltas_plot, 'o-', color='#4c72b0', linewidth=2, markersize=8)
ax.axhline(0, color='gray', linestyle='--', linewidth=0.8)
ax.set_xlabel('Surrogate Similarity to Target Query')
ax.set_ylabel('Mean Delta NLL (positive = better)')
ax.set_title('A: Quality Gradient\n(0=bare, 1=oracle)')

# B: Ranking
ax = axes[1]
conds_plot = ['Bare', 'Oracle']
mrrs_plot = [bare_mrr, oracle_mrr]
colors_plot = ['#888888', '#c44e52']
if len(medium_ranks) > 0:
    conds_plot.insert(1, 'Medium')
    mrrs_plot.insert(1, medium_mrr)
    colors_plot.insert(1, '#55a868')
# Colors are kept aligned with labels whether or not 'Medium' is present
ax.bar(range(len(conds_plot)), mrrs_plot, color=colors_plot)
ax.set_xticks(range(len(conds_plot)))
ax.set_xticklabels(conds_plot)
ax.set_ylabel('MRR')
ax.set_title('B: Ranking Performance')
ax.set_ylim(0, 1)

# C: Same-passage
ax = axes[2]
c_labels = ['Bare', 'Irrelevant', 'Same-\nPassage', 'Oracle']
c_wins = [
    50.0,  # Bare vs itself: plotted at chance level as a reference
    np.mean(bare_nlls_c > irrel_nlls_c) * 100,
    np.mean(bare_nlls_c > surr_nlls_c) * 100,
    np.mean(bare_nlls_c > oracle_nlls_c) * 100,
]
ax.bar(range(len(c_labels)), c_wins, color=['#888888', '#8c564b', '#4c72b0', '#c44e52'])
ax.axhline(50, color='gray', linestyle='--')
ax.set_xticks(range(len(c_labels)))
ax.set_xticklabels(c_labels)
ax.set_ylabel('Win Rate vs Bare (%)')
ax.set_title('C: Same-Passage Surrogates')
for i, w in enumerate(c_wins):
    ax.text(i, w + 1, f'{w:.0f}%', ha='center', fontsize=9)

plt.tight_layout()
plt.savefig('results/exp11/11_summary.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 11_summary.png')
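
The results file is written with `json.dump(..., default=str)`. `np.float64` subclasses Python `float` and serializes as a number, but other NumPy scalars (for example `np.float32`, `np.int64`, `np.bool_`) fall through to `default` and would be written as strings, which then reload as strings. If any such values end up in `bin_stats` or the per-sample results, a type-aware `default` hook keeps them numeric. This is a sketch; `to_jsonable` is a hypothetical name.

```python
import json
import numpy as np

def to_jsonable(obj):
    """`default=` hook for json.dump that keeps NumPy values numeric."""
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

record = {'rank': np.int64(3), 'mrr': np.float32(0.5), 'ok': np.bool_(True)}
print(json.dumps(record, default=to_jsonable))   # numbers and booleans stay typed
print(json.dumps(record, default=str))           # same values become strings
```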