# Experiment 07: Suffix Priming and Advanced Cache Strategies

## Motivation

Experiments 01-06 showed that **prefix priming** (surrogate before document, then truncate) produces a content-independent artifact: all prefixes help/hurt together (r=0.924 between gen_routed and shuffled deltas). The RoPE correction fixes positions but values carry content-independent contamination from the forward pass.

**The suffix approach sidesteps this entirely.** In a causal model, appending a surrogate AFTER the document means document tokens never attend to the suffix. Document KV entries are byte-identical to bare. Any improvement must come from the query attending to suffix tokens that have "read" the full document — a clean semantic signal.

## 18 Conditions

- **Group A (2)**: Baselines — bare, bare_padded
- **Group B (7)**: Suffix semantic isolation — gen_routed, gen_oracle, perfect, irrelevant, shuffled, rand_passage, rand_tokens
- **Group C (3)**: Suffix format — raw, newline, multi_q
- **Group D (3)**: Prefix comparison — pfx_trunc_routed, pfx_trunc_perfect, pfx_full_ctx
- **Group E (3)**: Suffix routing pool and summary — pool_cosine, pool_oracle, summary

## Experimental Notes

- Motivated by Exp 06's finding that prefix contamination (the forward pass through the surrogate contaminates values, not just keys) is unavoidable with prefix placement.
- The suffix approach guarantees byte-identical document KV entries because document tokens are processed before the suffix tokens, so their keys and values are unaffected.
- **Sanity check confirmed:** document KV entries are identical between suffix-primed and bare-document caches.
- **Result:** Content-independent effects persisted even with suffix placement, ruling out value contamination as the sole explanation.

## Setup

In [None]:
import sys
import os
import copy
import json
import time
import datetime
import random
from typing import Dict, List, Any, Optional, Tuple

import torch
import numpy as np
from tqdm.auto import tqdm
from scipy import stats
from scipy.stats import rankdata
import matplotlib.pyplot as plt

sys.path.insert(0, '.')

from lib import (
    ExperimentConfig,
    build_kv_cache,
    build_suffix_kv_cache,
    score_answer_with_cache,
    build_truncated_kv_cache_corrected,
    generate_all_5_surrogates,
    generate_summary,
    compute_similarity,
    load_evaluation_samples,
    load_ms_marco,
    TOP_5_SURROGATE_TEMPLATES,
    STATIC_SURROGATE_QUERIES,
)
from lib.analysis import cohens_d

# Report the torch build and, when present, the first CUDA device.
print(f"PyTorch: {torch.__version__}")
cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU Memory: {gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
# Configuration — same seed=42 and filtering as exp 06
config = ExperimentConfig(
    seed=42,
    num_samples=800,
    # passage length window, in words
    min_passage_words=50,
    max_passage_words=300,
    # surrogate generation settings
    surrogate_max_tokens=45,
    surrogate_temperature=0.3,
)

for _label, _value in (
    ("Model", config.model_name),
    ("Samples requested", config.num_samples),
    ("Passage words", f"{config.min_passage_words}-{config.max_passage_words}"),
    ("Device", config.device),
):
    print(f"{_label}: {_value}")

## Model Loading

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer

# Seed the torch, NumPy and Python RNGs from the config.
torch.manual_seed(config.seed)
np.random.seed(config.seed)
random.seed(config.seed)

# Load model (4-bit NF4 quantization with double quantization, fp16 compute)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

print(f"Loading {config.model_name}...")
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.eval()
# NOTE: summing p.numel() over model.parameters() under-counts a 4-bit model, because
# bitsandbytes packs two 4-bit values per stored element. PreTrainedModel.num_parameters()
# applies that correction, so the reported count reflects the real model size.
print(f"Model loaded. Parameters: {model.num_parameters():,}")

# Embedding model for routing
embed_model = SentenceTransformer(config.embedding_model_name)
print(f"Embedding model loaded: {config.embedding_model_name}")

## Dataset

In [None]:
dataset = load_ms_marco(config)
raw_samples = load_evaluation_samples(dataset, config, require_answer=True)
print(f"Raw samples after basic filtering: {len(raw_samples)}")

# Same additional filters as exp 06:
#   - drop samples whose answer is more than half the passage length (in characters);
#   - drop answers shorter than 2 tokens (tokenized without special tokens).
filtered_samples = []
excluded_ratio = 0
excluded_short_answer = 0

for s in raw_samples:
    if len(s['answer']) / max(len(s['passage']), 1) > 0.5:
        excluded_ratio += 1
        continue
    # Only the token count is needed, so skip building a tensor.
    n_answer_tokens = len(tokenizer(s['answer'], add_special_tokens=False)['input_ids'])
    if n_answer_tokens < 2:
        excluded_short_answer += 1
        continue
    filtered_samples.append(s)

samples = filtered_samples
print(f"\nFiltering stats:")
print(f"  Excluded (answer/passage ratio > 0.5): {excluded_ratio}")
print(f"  Excluded (answer < 2 tokens):          {excluded_short_answer}")
print(f"  Remaining samples:                     {len(samples)}")

# Fail loudly here (after the stats are visible) rather than with an IndexError below.
if not samples:
    raise RuntimeError("No samples left after filtering; nothing to evaluate.")

s = samples[0]
print(f"\nExample sample:")
print(f"  Query:   {s['query'][:100]}...")
print(f"  Passage: {s['passage'][:100]}...")
print(f"  Answer:  {s['answer'][:100]}")

## Sanity Check: Suffix KV Identity

Verify that document KV entries in a suffix cache are byte-identical to the bare cache.
In a causal model, passage tokens cannot attend to suffix tokens, so their KV representations must be identical.

In [None]:
# Sanity check: build a bare cache and a suffix-primed cache for the same passage;
# the passage portion of both must be identical.
test_suffix = "What is the main topic discussed?"
test_passage = samples[0]['passage']

bare_len, bare_cache = build_kv_cache(test_passage, model, tokenizer, config)
sfx_len, sfx_cache = build_suffix_kv_cache(test_passage, test_suffix, model, tokenizer, config)
added_tokens = sfx_len - bare_len

print(f"Bare cache length: {bare_len}")
print(f"Suffix cache length: {sfx_len}")
print(f"Suffix added {added_tokens} tokens")

def _get_kv(cache, layer_idx):
    """Get (keys, values) tensors from a cache layer, compatible with all DynamicCache versions."""
    if hasattr(cache, 'key_cache'):
        return cache.key_cache[layer_idx], cache.value_cache[layer_idx]
    return cache.layers[layer_idx].keys, cache.layers[layer_idx].values

# Compare the first bare_len positions of every layer; stop at the first mismatch.
n_layers = len(bare_cache)
all_match = True
for layer_idx in range(n_layers):
    (bare_k, bare_v), (sfx_k, sfx_v) = _get_kv(bare_cache, layer_idx), _get_kv(sfx_cache, layer_idx)
    # Same as [:, :, :bare_len, :] — restrict to the passage positions.
    prefix = (slice(None), slice(None), slice(None, bare_len), slice(None))
    k_match = torch.equal(bare_k[prefix], sfx_k[prefix])
    v_match = torch.equal(bare_v[prefix], sfx_v[prefix])

    if not (k_match and v_match):
        all_match = False
        print(f"  Layer {layer_idx}: keys_match={k_match}, values_match={v_match}")
        break

if not all_match:
    print(f"\nFAILED: Passage KV entries differ! Something is wrong.")
else:
    print(f"\nPASSED: All {n_layers} layers verified — passage KV entries are byte-identical.")
    print("  Document tokens never attend to suffix. Causal isolation confirmed.")

# Release the test caches before the main run.
del bare_cache, sfx_cache
torch.cuda.empty_cache()

## Helper Functions

In [None]:
def shuffle_text(text: str, rng: random.Random) -> str:
    """Randomly permute the whitespace-separated words of ``text`` using ``rng``.

    Runs of whitespace collapse to single spaces in the result.
    """
    tokens = text.split()
    rng.shuffle(tokens)  # in-place permutation; the only use of rng
    return " ".join(tokens)


def generate_random_tokens(tokenizer, n_tokens: int, rng: random.Random) -> str:
    """Decode ``n_tokens`` random token ids into a string.

    Ids are drawn with ``rng.randint`` from ``[100, tokenizer.vocab_size - 1]`` (the
    lowest 100 ids are never drawn); special tokens are skipped when decoding.
    """
    upper_id = tokenizer.vocab_size - 1
    token_ids = []
    for _ in range(n_tokens):
        token_ids.append(rng.randint(100, upper_id))
    return tokenizer.decode(token_ids, skip_special_tokens=True)


def _pick_other_index(n_items: int, current_idx: int, rng: random.Random) -> int:
    """Return an index in ``[0, n_items)`` different from ``current_idx``, drawn via ``rng``.

    Uses rejection sampling with ``rng.randint`` ("redraw until different"), so the
    RNG stream consumed is the same as the original inline loops.

    Raises:
        ValueError: if no index other than ``current_idx`` exists. Previously a
            single-sample pool made the redraw loop spin forever.
    """
    n_candidates = n_items - (1 if 0 <= current_idx < n_items else 0)
    if n_candidates < 1:
        raise ValueError(
            f"cannot pick an index other than {current_idx} from {n_items} item(s)"
        )
    other_idx = current_idx
    while other_idx == current_idx:
        other_idx = rng.randint(0, n_items - 1)
    return other_idx


def get_irrelevant_query(samples: list, current_idx: int, rng: random.Random) -> str:
    """Return the 'query' of a randomly chosen sample at an index other than ``current_idx``."""
    return samples[_pick_other_index(len(samples), current_idx, rng)]['query']


def get_random_passage(samples: list, current_idx: int, rng: random.Random) -> str:
    """Return the 'passage' of a randomly chosen sample at an index other than ``current_idx``."""
    return samples[_pick_other_index(len(samples), current_idx, rng)]['passage']


def deep_copy_cache(cache):
    """Return a fully independent deep copy of ``cache``.

    Callers copy before running the model so the original cache can be scored again.
    """
    duplicate = copy.deepcopy(cache)
    return duplicate


def score_query_nll(cache, cache_len, query_text, model, tokenizer, config):
    """Return the mean per-token NLL (nats) of ``query_text`` conditioned on ``cache``.

    Args:
        cache: KV cache holding ``cache_len`` positions; it is deep-copied, not modified.
        cache_len: number of positions already in ``cache`` (sizes the attention mask).
        query_text: text to score; tokenized without special tokens.
        model, tokenizer, config: the LM, its tokenizer, and the experiment config
            (``config.device`` is used).

    Returns:
        Mean cross-entropy over the predicted query tokens (tokens 2..n), or
        ``float('inf')`` when the query has fewer than 2 tokens.
    """
    query_ids = tokenizer(
        query_text, return_tensors="pt", add_special_tokens=False
    )['input_ids'].to(config.device)
    query_len = query_ids.shape[1]
    if query_len < 2:
        return float('inf')

    # Copy only once the forward pass will actually run: with use_cache=True the model
    # presumably extends the cache it is given, and the caller's cache must stay reusable.
    cache_copy = deep_copy_cache(cache)
    attention_mask = torch.ones((1, cache_len + query_len), device=config.device)
    with torch.no_grad():
        outputs = model(
            input_ids=query_ids,
            attention_mask=attention_mask,
            past_key_values=cache_copy,
            use_cache=True,
            return_dict=True,
        )
    # Upcast before the loss: with fp16 compute the logits may be half precision, and a
    # summed cross-entropy in fp16 loses accuracy.
    logits = outputs.logits.float()
    shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
    shift_labels = query_ids[:, 1:].reshape(-1)
    nll = torch.nn.functional.cross_entropy(shift_logits, shift_labels, reduction='sum').item()
    return nll / (query_len - 1)


print("Helper functions defined.")

## Resume from Checkpoint

In [None]:
checkpoint_path = '07_checkpoint.json'
results = []
start_from = 0

if os.path.exists(checkpoint_path):
    with open(checkpoint_path) as f:
        ckpt = json.load(f)
    results = ckpt['results']
    # Resume at a *sample index*. 'n_done' counts successful results only, so using it
    # directly re-evaluates (and duplicates) samples whenever earlier samples errored.
    # Prefer an explicit 'next_idx' if the checkpoint has one; otherwise continue after
    # the highest sample index already recorded.
    start_from = ckpt.get(
        'next_idx', max((r['idx'] for r in results), default=-1) + 1
    )
    print(
        f"Resuming from checkpoint: {len(results)} results saved; "
        f"continuing at sample {start_from}."
    )
else:
    print("No checkpoint found. Starting from scratch.")

## Pipeline Verification

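Run the conditions on a single sample to verify the pipeline end to end. The routing-pool conditions (16–17) need the full cache pool and are exercised only inside `evaluate_sample`.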

In [None]:
# Verify on the first filtered sample. test_rng uses config.seed, matching the
# per-sample seed evaluate_sample uses for idx=0 (config.seed + idx).
test_idx = 0
test_sample = samples[test_idx]
test_rng = random.Random(config.seed)
passage, query, answer = (test_sample[field] for field in ('passage', 'query', 'answer'))
query_prompt = config.query_template.format(query=query)

for _line in (f"Passage: {passage[:100]}...", f"Query:   {query}", f"Answer:  {answer[:80]}", ""):
    print(_line)

# --- Group A: Baselines ---
# 1. Bare
bare_len, bare_cache = build_kv_cache(passage, model, tokenizer, config)
bare_nll = score_answer_with_cache(bare_cache, bare_len, query_prompt, answer, model, tokenizer, config)
print(f" 1. bare              NLL: {bare_nll:.4f}")

# 2. Bare padded (match suffix length with newlines)
# Surrogates are generated here because the padding length depends on the routed one;
# later verification groups reuse test_surrogates / routed_surr.
test_surrogates = generate_all_5_surrogates(passage, model, tokenizer, config)
# Route: pick the generated surrogate whose embedding is most similar to the real query.
test_sims = {k: compute_similarity(v, query, embed_model) for k, v in test_surrogates.items()}
routed_key = max(test_sims, key=test_sims.get)
routed_surr = test_surrogates[routed_key]
# Estimate suffix token count for padding
# NOTE(review): assumes build_suffix_kv_cache's default separator is "\n\nRelated question: ",
# and the suffix is tokenized in isolation. The tokenizer may merge consecutive newlines, so
# the printed count is newline characters requested, not necessarily padded_len - bare_len.
sfx_text = "\n\nRelated question: " + routed_surr
sfx_tok_len = len(tokenizer(sfx_text, add_special_tokens=False)['input_ids'])
padding = "\n" * sfx_tok_len
padded_len, padded_cache = build_kv_cache(passage + padding, model, tokenizer, config)
padded_nll = score_answer_with_cache(padded_cache, padded_len, query_prompt, answer, model, tokenizer, config)
print(f" 2. bare_padded       NLL: {padded_nll:.4f} (added {sfx_tok_len} newline tokens)")

# --- Group B: Suffix Semantic Isolation ---
# test_rng is consumed in this order: irrelevant query (6), shuffle (7), random passage (8),
# random tokens (9) — the same order evaluate_sample uses. Do not reorder these conditions.
# 3. sfx_gen_routed
sfx_len, sfx_cache = build_suffix_kv_cache(passage, routed_surr, model, tokenizer, config)
sfx_gen_routed_nll = score_answer_with_cache(sfx_cache, sfx_len, query_prompt, answer, model, tokenizer, config)
print(f" 3. sfx_gen_routed    NLL: {sfx_gen_routed_nll:.4f}")

# 4. sfx_gen_oracle — score all 5
# Oracle = the generated surrogate giving the lowest answer NLL (uses the answer, so it is an upper bound).
sfx_gen_nlls = {}
for k, surr in test_surrogates.items():
    sl, sc = build_suffix_kv_cache(passage, surr, model, tokenizer, config)
    sfx_gen_nlls[k] = score_answer_with_cache(sc, sl, query_prompt, answer, model, tokenizer, config)
sfx_oracle_key = min(sfx_gen_nlls, key=sfx_gen_nlls.get)
print(f" 4. sfx_gen_oracle    NLL: {sfx_gen_nlls[sfx_oracle_key]:.4f} (key={sfx_oracle_key})")

# 5. sfx_perfect
# The real query itself as the suffix.
sl, sc = build_suffix_kv_cache(passage, query, model, tokenizer, config)
sfx_perfect_nll = score_answer_with_cache(sc, sl, query_prompt, answer, model, tokenizer, config)
print(f" 5. sfx_perfect       NLL: {sfx_perfect_nll:.4f}")

# 6. sfx_irrel
# Query of a different sample (first use of test_rng).
irrel_q = get_irrelevant_query(samples, test_idx, test_rng)
sl, sc = build_suffix_kv_cache(passage, irrel_q, model, tokenizer, config)
sfx_irrel_nll = score_answer_with_cache(sc, sl, query_prompt, answer, model, tokenizer, config)
print(f" 6. sfx_irrel         NLL: {sfx_irrel_nll:.4f}")

# 7. sfx_shuffled
# Same words as the routed surrogate in random order.
shuffled = shuffle_text(routed_surr, test_rng)
sl, sc = build_suffix_kv_cache(passage, shuffled, model, tokenizer, config)
sfx_shuffled_nll = score_answer_with_cache(sc, sl, query_prompt, answer, model, tokenizer, config)
print(f" 7. sfx_shuffled      NLL: {sfx_shuffled_nll:.4f}")

# 8. sfx_rand_passage
# First 200 characters of a different sample's passage.
rp = get_random_passage(samples, test_idx, test_rng)[:200]
sl, sc = build_suffix_kv_cache(passage, rp, model, tokenizer, config)
sfx_rand_passage_nll = score_answer_with_cache(sc, sl, query_prompt, answer, model, tokenizer, config)
print(f" 8. sfx_rand_passage  NLL: {sfx_rand_passage_nll:.4f}")

# 9. sfx_rand_tokens
# 20 random token ids decoded to text.
rt = generate_random_tokens(tokenizer, 20, test_rng)
sl, sc = build_suffix_kv_cache(passage, rt, model, tokenizer, config)
sfx_rand_tokens_nll = score_answer_with_cache(sc, sl, query_prompt, answer, model, tokenizer, config)
print(f" 9. sfx_rand_tokens   NLL: {sfx_rand_tokens_nll:.4f}")

# --- Group C: Suffix Format ---
# Same routed surrogate with different separators between passage and suffix.
# 10. sfx_raw
# separator="" — presumably the surrogate is appended directly after the passage.
sl, sc = build_suffix_kv_cache(passage, routed_surr, model, tokenizer, config, separator="")
sfx_raw_nll = score_answer_with_cache(sc, sl, query_prompt, answer, model, tokenizer, config)
print(f"10. sfx_raw           NLL: {sfx_raw_nll:.4f}")

# 11. sfx_newline
# Minimal separator: a blank line, no "Related question:" label.
sl, sc = build_suffix_kv_cache(passage, routed_surr, model, tokenizer, config, separator="\n\n")
sfx_newline_nll = score_answer_with_cache(sc, sl, query_prompt, answer, model, tokenizer, config)
print(f"11. sfx_newline       NLL: {sfx_newline_nll:.4f}")

# 12. sfx_multi_q (all 5 generated)
# Each surrogate on its own "Related question:" line, separated from the passage by a blank line.
multi_q_text = "\n".join(f"Related question: {v}" for v in test_surrogates.values())
sl, sc = build_suffix_kv_cache(passage, multi_q_text, model, tokenizer, config, separator="\n\n")
sfx_multi_q_nll = score_answer_with_cache(sc, sl, query_prompt, answer, model, tokenizer, config)
print(f"12. sfx_multi_q       NLL: {sfx_multi_q_nll:.4f}")

# --- Group D: Prefix Comparison ---
# Same surrogate / real query placed BEFORE the passage, for comparison with the suffix conditions.
# 13. pfx_trunc_routed
# Prefix is run through the model and then truncated (RoPE-corrected per the library name);
# dl is the cache length used for scoring.
dl, c = build_truncated_kv_cache_corrected(routed_surr, passage, model, tokenizer, config)
pfx_trunc_routed_nll = score_answer_with_cache(c, dl, query_prompt, answer, model, tokenizer, config)
print(f"13. pfx_trunc_routed  NLL: {pfx_trunc_routed_nll:.4f}")

# 14. pfx_trunc_perfect
# Real query as the truncated prefix.
dl, c = build_truncated_kv_cache_corrected(query, passage, model, tokenizer, config)
pfx_trunc_perfect_nll = score_answer_with_cache(c, dl, query_prompt, answer, model, tokenizer, config)
print(f"14. pfx_trunc_perfect NLL: {pfx_trunc_perfect_nll:.4f}")

# 15. pfx_full_ctx
# Surrogate + passage via config.surrogate_cache_template, kept fully visible (no truncation).
full_ctx = config.surrogate_cache_template.format(surrogate=routed_surr, document=passage)
fl, fc = build_kv_cache(full_ctx, model, tokenizer, config)
pfx_full_ctx_nll = score_answer_with_cache(fc, fl, query_prompt, answer, model, tokenizer, config)
print(f"15. pfx_full_ctx      NLL: {pfx_full_ctx_nll:.4f}")

# --- Group E: Suffix Routing Pool + Summary ---
# 16./17. sfx_pool_cosine / sfx_pool_oracle need the full routing pool, which is only
# built inside evaluate_sample(); they are not exercised here.
print(f"16. sfx_pool_cosine   (tested in evaluate_sample)")
print(f"17. sfx_pool_oracle   (tested in evaluate_sample)")

# 18. sfx_summary
summary = generate_summary(passage, model, tokenizer, config)
sl, sc = build_suffix_kv_cache(passage, summary, model, tokenizer, config, separator="\n\nSummary: ")
sfx_summary_nll = score_answer_with_cache(sc, sl, query_prompt, answer, model, tokenizer, config)
print(f"18. sfx_summary       NLL: {sfx_summary_nll:.4f}")
print(f"    Summary: {summary[:120]}...")

# Check the claim before printing it: every NLL computed above must be finite.
verification_nlls = {
    'bare': bare_nll,
    'bare_padded': padded_nll,
    'sfx_gen_routed': sfx_gen_routed_nll,
    **{f'sfx_gen[{k}]': v for k, v in sfx_gen_nlls.items()},
    'sfx_perfect': sfx_perfect_nll,
    'sfx_irrel': sfx_irrel_nll,
    'sfx_shuffled': sfx_shuffled_nll,
    'sfx_rand_passage': sfx_rand_passage_nll,
    'sfx_rand_tokens': sfx_rand_tokens_nll,
    'sfx_raw': sfx_raw_nll,
    'sfx_newline': sfx_newline_nll,
    'sfx_multi_q': sfx_multi_q_nll,
    'pfx_trunc_routed': pfx_trunc_routed_nll,
    'pfx_trunc_perfect': pfx_trunc_perfect_nll,
    'pfx_full_ctx': pfx_full_ctx_nll,
    'sfx_summary': sfx_summary_nll,
}
non_finite = [name for name, value in verification_nlls.items() if not np.isfinite(value)]
if non_finite:
    print(f"\nFAILED: non-finite NLLs for {', '.join(non_finite)}. Fix before running the experiment.")
else:
    print("\nAll conditions produce finite NLLs. Pipeline verified.")

## Evaluation Function

In [None]:
def evaluate_sample(
    sample: Dict,
    idx: int,
    all_samples: List[Dict],
    model,
    tokenizer,
    embed_model,
    config: ExperimentConfig,
) -> Optional[Dict]:
    """Evaluate a single sample across all 18 experimental conditions.

    Args:
        sample: dict with 'passage', 'query' and 'answer' strings.
        idx: index of ``sample`` in ``all_samples``; also seeds the per-sample RNG
            (``config.seed + idx``) used by the random controls.
        all_samples: pool from which the irrelevant query and random passage are drawn.
        model, tokenizer: causal LM used to build caches, generate and score answers.
        embed_model: sentence-embedding model used for cosine routing.
        config: experiment configuration.

    Returns:
        Dict of per-condition answer NLLs plus routing keys, similarities, the generated
        surrogates and summary, and two length-control diagnostics
        (``padded_extra_tokens`` / ``sfx_routed_extra_tokens``).

    Call order matters. ``rng`` is consumed by the irrelevant-query, shuffle,
    random-passage and random-token controls in exactly that order. Surrogate generation
    runs before summary generation, which keeps any global sampling RNG state consumed
    in the same sequence. Do not reorder.

    Caches are released as soon as they are scored; only one suffix cache is built at
    a time.
    """
    passage = sample['passage']
    query = sample['query']
    answer = sample['answer']
    query_prompt = config.query_template.format(query=query)
    rng = random.Random(config.seed + idx)

    def _answer_nll(cache_len: int, cache) -> float:
        """Score ``answer`` (after ``query_prompt``) against an already-built cache."""
        return score_answer_with_cache(
            cache, cache_len, query_prompt, answer, model, tokenizer, config
        )

    def _suffix_nll(suffix: str, **suffix_kwargs) -> Tuple[float, int]:
        """Build a passage + ``suffix`` cache, score it, and return (NLL, cache length).

        ``suffix_kwargs`` (e.g. ``separator``) are forwarded only when given, so the
        library's default separator applies otherwise. The cache is not retained.
        """
        cache_len, cache = build_suffix_kv_cache(
            passage, suffix, model, tokenizer, config, **suffix_kwargs
        )
        return _answer_nll(cache_len, cache), cache_len

    def _prefix_trunc_nll(prefix: str) -> float:
        """Score the truncated, RoPE-corrected ``prefix`` + passage cache."""
        doc_len, cache = build_truncated_kv_cache_corrected(
            prefix, passage, model, tokenizer, config
        )
        return _answer_nll(doc_len, cache)

    # ==================== GROUP A: BASELINES ====================

    # 1. Bare passage
    bare_len, bare_cache = build_kv_cache(passage, model, tokenizer, config)
    bare_nll = _answer_nll(bare_len, bare_cache)
    # The pool's 'bare' entry reuses bare_nll, so the cache itself is no longer needed.
    del bare_cache

    # ==================== GENERATE SURROGATES ====================
    generated_surrogates = generate_all_5_surrogates(passage, model, tokenizer, config)
    gen_similarities = {
        k: compute_similarity(v, query, embed_model)
        for k, v in generated_surrogates.items()
    }
    gen_routed_key = max(gen_similarities, key=gen_similarities.get)
    routed_surr = generated_surrogates[gen_routed_key]

    # 2. Bare padded (length control: newline padding sized to the routed suffix)
    # NOTE(review): the suffix is tokenized in isolation and assumes the default separator
    # is "\n\nRelated question: ". The tokenizer may also merge consecutive newlines. The
    # lengths actually added are returned as padded_extra_tokens / sfx_routed_extra_tokens
    # so the control can be audited.
    sfx_text_for_len = "\n\nRelated question: " + routed_surr
    sfx_tok_len = len(tokenizer(sfx_text_for_len, add_special_tokens=False)['input_ids'])
    padding = "\n" * sfx_tok_len
    padded_len, padded_cache = build_kv_cache(passage + padding, model, tokenizer, config)
    padded_nll = _answer_nll(padded_len, padded_cache)
    del padded_cache

    # ==================== GROUP B: SUFFIX SEMANTIC ISOLATION ====================

    # Score all 5 generated surrogates as suffixes (NLLs and lengths only; no caches kept)
    sfx_gen_nlls = {}
    sfx_gen_lens = {}
    for key, surrogate in generated_surrogates.items():
        sfx_gen_nlls[key], sfx_gen_lens[key] = _suffix_nll(surrogate)

    # 3. sfx_gen_routed
    sfx_gen_routed_nll = sfx_gen_nlls[gen_routed_key]

    # 4. sfx_gen_oracle
    sfx_gen_oracle_key = min(sfx_gen_nlls, key=sfx_gen_nlls.get)
    sfx_gen_oracle_nll = sfx_gen_nlls[sfx_gen_oracle_key]

    # 5. sfx_perfect
    sfx_perfect_nll, _ = _suffix_nll(query)

    # 6. sfx_irrel (1st rng use)
    irrel_query = get_irrelevant_query(all_samples, idx, rng)
    sfx_irrel_nll, _ = _suffix_nll(irrel_query)

    # 7. sfx_shuffled (2nd rng use)
    shuffled_text = shuffle_text(routed_surr, rng)
    sfx_shuffled_nll, _ = _suffix_nll(shuffled_text)

    # 8. sfx_rand_passage (3rd rng use)
    rand_passage_text = get_random_passage(all_samples, idx, rng)[:200]
    sfx_rand_passage_nll, _ = _suffix_nll(rand_passage_text)

    # 9. sfx_rand_tokens (4th rng use)
    rand_tokens_text = generate_random_tokens(tokenizer, 20, rng)
    sfx_rand_tokens_nll, _ = _suffix_nll(rand_tokens_text)

    # ==================== GROUP C: SUFFIX FORMAT ====================

    # 10. sfx_raw (no separator)
    sfx_raw_nll, _ = _suffix_nll(routed_surr, separator="")

    # 11. sfx_newline (minimal separator)
    sfx_newline_nll, _ = _suffix_nll(routed_surr, separator="\n\n")

    # 12. sfx_multi_q (all 5 generated)
    multi_q_text = "\n".join(f"Related question: {v}" for v in generated_surrogates.values())
    sfx_multi_q_nll, _ = _suffix_nll(multi_q_text, separator="\n\n")

    # ==================== GROUP D: PREFIX COMPARISON ====================

    # 13. pfx_trunc_routed (same surrogate, prefix placement, truncated+RoPE corrected)
    pfx_trunc_routed_nll = _prefix_trunc_nll(routed_surr)

    # 14. pfx_trunc_perfect (actual query as prefix)
    pfx_trunc_perfect_nll = _prefix_trunc_nll(query)

    # 15. pfx_full_ctx (prefix kept visible, no truncation)
    full_ctx = config.surrogate_cache_template.format(
        surrogate=routed_surr, document=passage
    )
    fl, fc = build_kv_cache(full_ctx, model, tokenizer, config)
    pfx_full_ctx_nll = _answer_nll(fl, fc)
    del fc

    # ==================== GROUP E: SUFFIX ROUTING POOL ====================

    # Pool = bare + generated suffixes + static suffixes. Bare and generated NLLs were
    # computed above. Static suffixes are built and scored one at a time instead of
    # holding every pool cache at once.
    static_nlls = {
        key: _suffix_nll(info['query'])[0]
        for key, info in STATIC_SURROGATE_QUERIES.items()
    }

    # Cosine similarities for pool routing. Key order (bare, gen_*, static_*) matches
    # pool_answer_nlls, so max/min tie-breaking is unchanged.
    pool_sims = {'bare': compute_similarity(passage, query, embed_model)}
    for k in generated_surrogates:
        pool_sims[f'gen_{k}'] = gen_similarities[k]
    for k, info in STATIC_SURROGATE_QUERIES.items():
        pool_sims[f'static_{k}'] = compute_similarity(info['query'], query, embed_model)

    # Answer NLL for every pool member
    pool_answer_nlls = {'bare': bare_nll}
    for k in generated_surrogates:
        pool_answer_nlls[f'gen_{k}'] = sfx_gen_nlls[k]
    for k, nll in static_nlls.items():
        pool_answer_nlls[f'static_{k}'] = nll

    # 16. sfx_pool_cosine
    pool_cosine_key = max(pool_sims, key=pool_sims.get)
    sfx_pool_cosine_nll = pool_answer_nlls[pool_cosine_key]

    # 17. sfx_pool_oracle
    pool_oracle_key = min(pool_answer_nlls, key=pool_answer_nlls.get)
    sfx_pool_oracle_nll = pool_answer_nlls[pool_oracle_key]

    # 18. sfx_summary
    summary = generate_summary(passage, model, tokenizer, config)
    sfx_summary_nll, _ = _suffix_nll(summary, separator="\n\nSummary: ")

    return {
        'idx': idx,
        'query': query,
        'answer_len': len(answer),
        'passage_len': len(passage),

        # Group A
        'bare_nll': bare_nll,
        'padded_nll': padded_nll,
        # Length-control diagnostics: cache positions added beyond the bare passage.
        'padded_extra_tokens': padded_len - bare_len,
        'sfx_routed_extra_tokens': sfx_gen_lens[gen_routed_key] - bare_len,

        # Group B
        'sfx_gen_routed_key': gen_routed_key,
        'sfx_gen_routed_nll': sfx_gen_routed_nll,
        'sfx_gen_routed_sim': gen_similarities[gen_routed_key],
        'sfx_gen_oracle_key': sfx_gen_oracle_key,
        'sfx_gen_oracle_nll': sfx_gen_oracle_nll,
        'sfx_gen_nlls': sfx_gen_nlls,
        'sfx_gen_similarities': gen_similarities,
        'generated_surrogates': generated_surrogates,
        'sfx_perfect_nll': sfx_perfect_nll,
        'sfx_irrel_nll': sfx_irrel_nll,
        'sfx_irrel_query': irrel_query,
        'sfx_shuffled_nll': sfx_shuffled_nll,
        'sfx_rand_passage_nll': sfx_rand_passage_nll,
        'sfx_rand_tokens_nll': sfx_rand_tokens_nll,

        # Group C
        'sfx_raw_nll': sfx_raw_nll,
        'sfx_newline_nll': sfx_newline_nll,
        'sfx_multi_q_nll': sfx_multi_q_nll,

        # Group D
        'pfx_trunc_routed_nll': pfx_trunc_routed_nll,
        'pfx_trunc_perfect_nll': pfx_trunc_perfect_nll,
        'pfx_full_ctx_nll': pfx_full_ctx_nll,

        # Group E
        'pool_sims': pool_sims,
        'pool_answer_nlls': pool_answer_nlls,
        'sfx_pool_cosine_key': pool_cosine_key,
        'sfx_pool_cosine_nll': sfx_pool_cosine_nll,
        'sfx_pool_oracle_key': pool_oracle_key,
        'sfx_pool_oracle_nll': sfx_pool_oracle_nll,
        'sfx_summary_nll': sfx_summary_nll,
        'summary_text': summary,
    }


print("evaluate_sample() defined — 18 conditions.")

## Run Experiment

In [None]:
errors = 0
start_time = time.time()
# Results restored from a checkpoint are excluded from this session's rate/ETA estimates.
n_results_at_start = len(results)
CHECKPOINT_EVERY = 10  # new results between checkpoint writes
REPORT_EVERY = 50      # results between progress printouts


def _save_checkpoint(path: str, all_results: List[Dict], n_errors: int,
                     elapsed: float, next_idx: int) -> None:
    """Atomically write the run state to ``path``.

    The JSON is written to a temporary file and then moved over ``path`` with
    ``os.replace``, so an interruption mid-write cannot leave a truncated checkpoint.
    ``next_idx`` is the first sample index not yet attempted. ``n_done`` (the result
    count) is kept for compatibility with older checkpoints.
    """
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump({
            'n_done': len(all_results),
            'next_idx': next_idx,
            'n_errors': n_errors,
            'elapsed': elapsed,
            'results': all_results,
        }, f, default=str)
    os.replace(tmp_path, path)


print("=" * 80)
print("RUNNING EXPERIMENT 07: SUFFIX PRIMING AND ADVANCED CACHE STRATEGIES")
print(f"Samples: {len(samples)}, Conditions per sample: 18")
if start_from > 0:
    print(f"Resuming from sample {start_from}")
print("=" * 80)

for idx in tqdm(range(start_from, len(samples)), desc="Evaluating"):
    sample = samples[idx]
    try:
        result = evaluate_sample(
            sample, idx, samples, model, tokenizer, embed_model, config
        )
    except Exception as e:
        errors += 1
        if errors <= 5:
            print(f"\n  Error on sample {idx}: {type(e).__name__}: {e}")
        # Release cached GPU blocks so one failure (e.g. OOM) is less likely to cascade.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        continue
    if result is None:
        continue
    results.append(result)

    # Progress + checkpoint every CHECKPOINT_EVERY results
    if len(results) % CHECKPOINT_EVERY == 0:
        elapsed = time.time() - start_time
        n_new = len(results) - n_results_at_start
        rate = elapsed / n_new if n_new > 0 else 1
        remaining = rate * (len(samples) - idx - 1)

        if len(results) % REPORT_EVERY == 0:
            recent = results[-CHECKPOINT_EVERY:]
            bare_mean = np.mean([r['bare_nll'] for r in recent])
            sfx_gen_mean = np.mean([r['sfx_gen_routed_nll'] for r in recent])
            sfx_irrel_mean = np.mean([r['sfx_irrel_nll'] for r in recent])
            wr_gen = np.mean([r['bare_nll'] - r['sfx_gen_routed_nll'] > 0 for r in recent]) * 100
            print(
                f"\n  [{len(results):>4d} done | {elapsed/60:.0f}m elapsed | ~{remaining/60:.0f}m left]"
                f"\n  Last batch: bare={bare_mean:.3f}  sfx_gen={sfx_gen_mean:.3f} ({wr_gen:.0f}% win)"
                f"  sfx_irrel={sfx_irrel_mean:.3f}"
            )

        _save_checkpoint(checkpoint_path, results, errors, time.time() - start_time, idx + 1)

# Final checkpoint, so results after the last multiple of CHECKPOINT_EVERY are not lost.
_save_checkpoint(checkpoint_path, results, errors, time.time() - start_time, len(samples))

elapsed_total = time.time() - start_time
n_new_total = len(results) - n_results_at_start
print(f"\nDone. {len(results)} evaluated, {errors} errors.")
print(f"Total time: {elapsed_total/60:.1f} minutes ({elapsed_total/n_new_total if n_new_total else 0:.1f}s per sample)")

## Primary Results

Summary table: all 18 conditions vs bare baseline.

In [None]:
bare_arr = np.array([r['bare_nll'] for r in results])
n = len(results)

conditions = [
    (' 1. bare (BASELINE)',         'bare_nll'),
    (' 2. bare_padded',             'padded_nll'),
    (' 3. sfx_gen_routed',          'sfx_gen_routed_nll'),
    (' 4. sfx_gen_oracle',          'sfx_gen_oracle_nll'),
    (' 5. sfx_perfect',             'sfx_perfect_nll'),
    (' 6. sfx_irrel',               'sfx_irrel_nll'),
    (' 7. sfx_shuffled',            'sfx_shuffled_nll'),
    (' 8. sfx_rand_passage',        'sfx_rand_passage_nll'),
    (' 9. sfx_rand_tokens',         'sfx_rand_tokens_nll'),
    ('10. sfx_raw',                 'sfx_raw_nll'),
    ('11. sfx_newline',             'sfx_newline_nll'),
    ('12. sfx_multi_q',             'sfx_multi_q_nll'),
    ('13. pfx_trunc_routed',        'pfx_trunc_routed_nll'),
    ('14. pfx_trunc_perfect',       'pfx_trunc_perfect_nll'),
    ('15. pfx_full_ctx',            'pfx_full_ctx_nll'),
    ('16. sfx_pool_cosine',         'sfx_pool_cosine_nll'),
    ('17. sfx_pool_oracle',         'sfx_pool_oracle_nll'),
    ('18. sfx_summary',             'sfx_summary_nll'),
]

print("=" * 130)
print(f"ALL 18 CONDITIONS vs BARE BASELINE  (N = {n})")
print("Positive delta = better than bare baseline.  Sorted by mean NLL.")
print("=" * 130)
header = f"{'Condition':<28} {'Mean NLL':>10} {'Std':>8} {'Delta':>10} {'Win%':>8} {'t-stat':>8} {'p-value':>12} {'Cohen d':>10}"
print(header)
print("-" * 130)

rows = []
for label, key in conditions:
    arr = np.array([r[key] for r in results])
    mean_nll = np.mean(arr)
    std_nll = np.std(arr)
    if key == 'bare_nll':
        rows.append((mean_nll, label, std_nll, '--', '--', '--', '--', '--'))
        continue
    delta = bare_arr - arr  # positive = condition has lower NLL than bare
    # ttest_rel(bare, cond) tests mean(bare - cond), so t has the same sign as Delta.
    t, p = stats.ttest_rel(bare_arr, arr)
    d = cohens_d(delta)
    win_rate = np.mean(delta > 0) * 100
    rows.append((mean_nll, label, std_nll, f"{np.mean(delta):+.4f}", f"{win_rate:.1f}%",
                 f"{t:.3f}", f"{p:.6f}", f"{d:.4f}"))

# Widths of the Delta, Win%, t-stat, p-value and Cohen d columns. These must mirror
# `header`; the previous fixed 10/12 widths left the last four columns misaligned.
_stat_col_widths = (10, 8, 8, 12, 10)

rows.sort(key=lambda x: x[0])
for mean_nll, label, std_nll, *rest in rows:
    vals = [f"{mean_nll:>10.4f}", f"{std_nll:>8.4f}"] + [
        f"{v:>{w}}" for v, w in zip(rest, _stat_col_widths)
    ]
    print(f"{label:<28} {' '.join(vals)}")

## Semantic Isolation Tests

**The decisive test**: correlation of per-sample deltas (vs. bare) between `sfx_gen_routed` and the content controls.
If r is well below 0.9 (compare r = 0.924 for prefixes in Exp 06), the suffix effect is content-dependent.

In [None]:
alpha = 0.01

# Pairwise tests. Defined before the Bonferroni correction so the family size
# follows the list instead of a hard-coded count.
pairwise = [
    ("sfx_gen_routed (3) vs bare (1)",          'sfx_gen_routed_nll', 'bare_nll',
     "Does suffix priming help at all?"),
    ("sfx_gen_routed (3) vs sfx_irrel (6)",     'sfx_gen_routed_nll', 'sfx_irrel_nll',
     "DECISIVE: Semantic or structural?"),
    ("sfx_gen_routed (3) vs sfx_shuffled (7)",  'sfx_gen_routed_nll', 'sfx_shuffled_nll',
     "Does word order matter in suffix?"),
    ("sfx_gen_routed (3) vs sfx_rand_pass (8)", 'sfx_gen_routed_nll', 'sfx_rand_passage_nll',
     "Does topic relevance matter?"),
    ("sfx_gen_routed (3) vs sfx_rand_tok (9)",  'sfx_gen_routed_nll', 'sfx_rand_tokens_nll',
     "Does coherence matter?"),
]
n_key_tests = len(pairwise)
bonferroni_alpha = alpha / n_key_tests

print("=" * 110)
print("SEMANTIC ISOLATION: KEY PAIRWISE COMPARISONS (SUFFIX)")
print(f"Bonferroni-corrected alpha = {alpha}/{n_key_tests} = {bonferroni_alpha:.4f}")
print("=" * 110)

print(f"\n{'Comparison':<46} {'Mean A':>8} {'Mean B':>8} {'Delta':>8} {'t':>8} {'p':>12} {'d':>8} {'Sig?':>6}")
print("-" * 110)

for label, key_a, key_b, question in pairwise:
    a = np.array([r[key_a] for r in results])
    b = np.array([r[key_b] for r in results])
    diff = b - a  # positive = A is better (lower NLL)
    # ttest_rel(b, a) tests mean(b - a) = mean(diff), so t has the same sign as Delta
    # (positive = A better), consistent with the primary results table.
    t, p = stats.ttest_rel(b, a)
    d = cohens_d(diff)
    sig = "YES" if p < bonferroni_alpha else "no"
    print(f"{label:<46} {np.mean(a):>8.4f} {np.mean(b):>8.4f} {np.mean(diff):>+8.4f} {t:>8.3f} {p:>12.6f} {d:>8.4f} {sig:>6}")
    print(f"  Q: {question}")

# THE KEY TEST: Correlation of per-sample deltas
print(f"\n{'=' * 110}")
print("CORRELATION OF PER-SAMPLE DELTAS (suffix vs prefix comparison)")
print("If suffix r << 0.9, the suffix effect is content-dependent (unlike prefix r=0.924)")
print("=" * 110)

sfx_gen_deltas = np.array([r['bare_nll'] - r['sfx_gen_routed_nll'] for r in results])
sfx_shuf_deltas = np.array([r['bare_nll'] - r['sfx_shuffled_nll'] for r in results])
sfx_irrel_deltas = np.array([r['bare_nll'] - r['sfx_irrel_nll'] for r in results])
sfx_rand_p_deltas = np.array([r['bare_nll'] - r['sfx_rand_passage_nll'] for r in results])
sfx_rand_t_deltas = np.array([r['bare_nll'] - r['sfx_rand_tokens_nll'] for r in results])

correlations = [
    ("sfx_gen_routed vs sfx_shuffled",     sfx_gen_deltas, sfx_shuf_deltas),
    ("sfx_gen_routed vs sfx_irrel",        sfx_gen_deltas, sfx_irrel_deltas),
    ("sfx_gen_routed vs sfx_rand_passage", sfx_gen_deltas, sfx_rand_p_deltas),
    ("sfx_gen_routed vs sfx_rand_tokens",  sfx_gen_deltas, sfx_rand_t_deltas),
]

print(f"\n{'Delta correlation':<42} {'r':>8} {'p':>12}")
print("-" * 65)
corr_r = {}  # label -> Pearson r, reused below instead of recomputing
for label, x, y in correlations:
    r, p = stats.pearsonr(x, y)
    corr_r[label] = r
    print(f"{label:<42} {r:>8.4f} {p:>12.6f}")

print(f"\nPrefix reference (exp 06): r=0.924 between gen_routed and shuffled deltas")
print(f"Suffix result:             r={corr_r['sfx_gen_routed vs sfx_shuffled']:.4f}")

## Suffix vs Prefix Comparison

Paired comparison: the same surrogate (routed or perfect) placed as a suffix versus as a prefix (truncated or full-context).

In [None]:
print("=" * 110)
print("SUFFIX vs PREFIX: SAME SURROGATE, DIFFERENT PLACEMENT")
print("=" * 110)

# (row label, NLL key for A = suffix condition, NLL key for B = prefix condition, question)
comparisons = [
    ("sfx_gen_routed (3) vs pfx_trunc_routed (13)", 'sfx_gen_routed_nll', 'pfx_trunc_routed_nll',
     "Suffix vs truncated prefix (same routed surrogate)"),
    ("sfx_perfect (5) vs pfx_trunc_perfect (14)",    'sfx_perfect_nll',    'pfx_trunc_perfect_nll',
     "Suffix vs prefix oracle (actual query)"),
    ("sfx_gen_routed (3) vs pfx_full_ctx (15)",      'sfx_gen_routed_nll', 'pfx_full_ctx_nll',
     "Suffix vs full-context prefix (strongest prefix)"),
]

print(f"\n{'Comparison':<52} {'Mean A':>8} {'Mean B':>8} {'Delta':>8} {'t':>8} {'p':>12} {'d':>8}")
print("-" * 110)

for label, key_a, key_b, question in comparisons:
    a = np.array([r[key_a] for r in results])
    b = np.array([r[key_b] for r in results])
    diff = b - a  # positive = A (suffix) is better
    # ttest_rel(x, y) gives t with the sign of mean(x - y); test (b, a) so t
    # and the printed Delta share a sign (both > 0 when the suffix wins).
    t, p = stats.ttest_rel(b, a)
    d = cohens_d(diff)
    print(f"{label:<52} {np.mean(a):>8.4f} {np.mean(b):>8.4f} {np.mean(diff):>+8.4f} {t:>8.3f} {p:>12.6f} {d:>8.4f}")
    print(f"  Q: {question}")

## Format Sensitivity

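Conditions 3, 10, 11, and 12. Does the model need a structural cue (a separator such as "Related question:") before the suffix, and does appending all 5 generated queries (multi_q) help?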

In [None]:
print("=" * 110)
print("FORMAT SENSITIVITY: Does the suffix separator matter?")
print("=" * 110)

# (display label, per-sample NLL key) for each suffix format variant.
format_conditions = [
    ("sfx_gen_routed (3: 'Related question:')", 'sfx_gen_routed_nll'),
    ("sfx_raw (10: no separator)",               'sfx_raw_nll'),
    ("sfx_newline (11: just newlines)",          'sfx_newline_nll'),
    ("sfx_multi_q (12: all 5 queries)",          'sfx_multi_q_nll'),
]

# The paired test against bare was previously computed and then discarded;
# report its p-value, as the pool-routing table does.
print(f"\n{'Condition':<46} {'Mean NLL':>10} {'Delta vs bare':>14} {'Win%':>8} {'d':>8} {'p':>12}")
print("-" * 103)

for label, key in format_conditions:
    arr = np.array([r[key] for r in results])
    delta = bare_arr - arr  # positive = condition has lower NLL than bare
    t, p = stats.ttest_rel(bare_arr, arr)
    d = cohens_d(delta)
    wr = np.mean(delta > 0) * 100
    print(f"{label:<46} {np.mean(arr):>10.4f} {np.mean(delta):>+14.4f} {wr:>7.1f}% {d:>8.4f} {p:>12.6f}")

# Pairwise format comparisons: positive delta = condition 3 has lower NLL.
print("\nPairwise format comparisons:")
fmt_pairs = [
    ("Related question (3) vs Raw (10)",     'sfx_gen_routed_nll', 'sfx_raw_nll'),
    ("Related question (3) vs Newline (11)", 'sfx_gen_routed_nll', 'sfx_newline_nll'),
    ("Related question (3) vs Multi-Q (12)", 'sfx_gen_routed_nll', 'sfx_multi_q_nll'),
]
for label, ka, kb in fmt_pairs:
    a = np.array([r[ka] for r in results])
    b = np.array([r[kb] for r in results])
    # ttest_rel(x, y) gives t with the sign of mean(x - y); (b, a) keeps the
    # sign of t equal to the sign of the printed delta.
    t, p = stats.ttest_rel(b, a)
    diff = np.mean(b - a)
    print(f"  {label}: delta={diff:+.4f}, t={t:.3f}, p={p:.6f}")

## Suffix Pool Routing and Summary Condition

In [None]:
banner = "=" * 110
print(banner)
print("SUFFIX POOL ROUTING")
print(banner)

# (display label, per-sample NLL key); the first entry is the bare baseline
# itself, which is printed with placeholder statistics.
pool_conditions = [
    ("bare (baseline)",      'bare_nll'),
    ("sfx_pool_cosine (16)", 'sfx_pool_cosine_nll'),
    ("sfx_pool_oracle (17)", 'sfx_pool_oracle_nll'),
    ("sfx_summary (18)",     'sfx_summary_nll'),
]

print(f"\n{'Condition':<30} {'Mean NLL':>10} {'Delta vs bare':>14} {'Win%':>8} {'d':>8} {'p':>12}")
print("-" * 90)

for label, key in pool_conditions:
    nll = np.array([rec[key] for rec in results])
    if key == 'bare_nll':
        # Nothing to compare the baseline against.
        print(f"{label:<30} {np.mean(nll):>10.4f} {'--':>14} {'--':>8} {'--':>8} {'--':>12}")
        continue
    gain = bare_arr - nll  # positive = lower NLL than bare
    _, p_value = stats.ttest_rel(bare_arr, nll)
    effect = cohens_d(gain)
    win_pct = np.mean(gain > 0) * 100
    print(f"{label:<30} {np.mean(nll):>10.4f} {np.mean(gain):>+14.4f} {win_pct:>7.1f}% {effect:>8.4f} {p_value:>12.6f}")

# Oracle ceiling: mean improvement over bare when the best pool entry is chosen.
oracle_nll = np.array([rec['sfx_pool_oracle_nll'] for rec in results])
oracle_delta = np.mean(bare_arr - oracle_nll)
print(f"\nSuffix pool oracle delta: {oracle_delta:+.4f}")
print("(Compare to exp 06 prefix pool oracle delta: ~0.60)")

## Stratified Analysis by Difficulty

In [None]:
# Bin samples by how hard they are for the bare cache (bare NLL quartiles).
quartiles = np.percentile(bare_arr, [25, 50, 75])
q25, q50, q75 = quartiles
difficulty_bins = [
    ('Q1 (easiest)', lambda v: v <= q25),
    ('Q2',           lambda v: q25 < v <= q50),
    ('Q3',           lambda v: q50 < v <= q75),
    ('Q4 (hardest)', lambda v: v > q75),
]

print("=" * 110)
print("STRATIFIED ANALYSIS BY DIFFICULTY QUARTILE")
print("=" * 110)

# (column header, per-sample NLL key); each cell is the win rate vs bare, in %.
key_conditions = [
    ('sfx_gen_routed', 'sfx_gen_routed_nll'),
    ('sfx_irrel',      'sfx_irrel_nll'),
    ('sfx_shuffled',   'sfx_shuffled_nll'),
    ('sfx_perfect',    'sfx_perfect_nll'),
    ('pfx_trunc_rou',  'pfx_trunc_routed_nll'),
]

print(f"{'Quartile':<16}" + "".join(f" {col:>16}" for col, _ in key_conditions))
print("-" * 110)

for bin_label, in_bin in difficulty_bins:
    members = [rec for rec in results if in_bin(rec['bare_nll'])]
    if not members:
        continue
    cells = []
    for _, key in key_conditions:
        wins = [rec['bare_nll'] - rec[key] > 0 for rec in members]
        cells.append(f" {np.mean(wins) * 100:>15.1f}%")
    print(f"{bin_label:<16}" + "".join(cells))

## Visualization

In [None]:
# Nine-panel summary figure for Experiment 07. Uses `results`, `bare_arr`,
# `difficulty_bins` and `TOP_5_SURROGATE_TEMPLATES` from earlier cells and
# writes 07_suffix_priming_results.png.
fig, axes = plt.subplots(3, 3, figsize=(22, 18))
fig.suptitle('Experiment 07: Suffix Priming and Advanced Cache Strategies', fontsize=16, fontweight='bold')

# --- (0,0) All suffix conditions bar chart ---
ax = axes[0, 0]
cond_order = [
    ('Bare', 'bare_nll'),
    ('Padded', 'padded_nll'),
    ('Sfx\nRouted', 'sfx_gen_routed_nll'),
    ('Sfx\nOracle', 'sfx_gen_oracle_nll'),
    ('Sfx\nPerfect', 'sfx_perfect_nll'),
    ('Sfx\nIrrel', 'sfx_irrel_nll'),
    ('Sfx\nShuf', 'sfx_shuffled_nll'),
    ('Sfx\nRndP', 'sfx_rand_passage_nll'),
    ('Sfx\nRndT', 'sfx_rand_tokens_nll'),
]
means = [np.mean([r[k] for r in results]) for _, k in cond_order]
sems = [stats.sem([r[k] for r in results]) for _, k in cond_order]
# One color per entry of cond_order, in the same order.
colors = ['#888888', '#aaaaaa', '#4c72b0', '#2a5298', '#55a868',
          '#c44e52', '#dd8452', '#8c564b', '#7f7f7f']
ax.bar(range(len(cond_order)), means, yerr=sems, color=colors, capsize=3)
ax.set_xticks(range(len(cond_order)))
ax.set_xticklabels([l for l, _ in cond_order], fontsize=7)
ax.set_ylabel('Mean NLL')
ax.set_title('Group A+B: Suffix Conditions')

# --- (0,1) Semantic isolation: delta correlations scatter ---
ax = axes[0, 1]
# Per-sample improvement over bare (positive = lower NLL than bare).
sfx_gen_deltas = np.array([r['bare_nll'] - r['sfx_gen_routed_nll'] for r in results])
sfx_shuf_deltas = np.array([r['bare_nll'] - r['sfx_shuffled_nll'] for r in results])
r_val, _ = stats.pearsonr(sfx_gen_deltas, sfx_shuf_deltas)
ax.scatter(sfx_gen_deltas, sfx_shuf_deltas, alpha=0.2, s=8, c='#4c72b0')
# y = x reference line across the combined data range.
lims = [min(sfx_gen_deltas.min(), sfx_shuf_deltas.min()), max(sfx_gen_deltas.max(), sfx_shuf_deltas.max())]
ax.plot(lims, lims, 'r--', linewidth=1, alpha=0.5)
ax.set_xlabel('Gen routed delta')
ax.set_ylabel('Shuffled delta')
ax.set_title(f'Semantic Isolation: r={r_val:.3f}\n(prefix ref: r=0.924)')

# --- (0,2) Suffix vs prefix ---
ax = axes[0, 2]
sfx_pfx_labels = ['Sfx\nRouted', 'Pfx\nTrunc', 'Pfx\nFull']
sfx_pfx_keys = ['sfx_gen_routed_nll', 'pfx_trunc_routed_nll', 'pfx_full_ctx_nll']
sfx_pfx_means = [np.mean([r[k] for r in results]) for k in sfx_pfx_keys]
sfx_pfx_sems = [stats.sem([r[k] for r in results]) for k in sfx_pfx_keys]
sfx_pfx_colors = ['#4c72b0', '#dd8452', '#55a868']
ax.bar(range(3), sfx_pfx_means, yerr=sfx_pfx_sems, color=sfx_pfx_colors, capsize=3)
ax.axhline(np.mean(bare_arr), color='#888888', linestyle='--', linewidth=1, label='Bare')
ax.set_xticks(range(3))
ax.set_xticklabels(sfx_pfx_labels, fontsize=8)
ax.set_ylabel('Mean NLL')
ax.set_title('Suffix vs Prefix (same surrogate)')
ax.legend(fontsize=7)

# --- (1,0) Format sensitivity ---
ax = axes[1, 0]
fmt_labels = ['Related Q\n(default)', 'Raw\n(no sep)', 'Newline\nonly', 'Multi-Q\n(all 5)']
fmt_keys = ['sfx_gen_routed_nll', 'sfx_raw_nll', 'sfx_newline_nll', 'sfx_multi_q_nll']
fmt_means = [np.mean([r[k] for r in results]) for k in fmt_keys]
fmt_sems = [stats.sem([r[k] for r in results]) for k in fmt_keys]
fmt_colors = ['#4c72b0', '#c44e52', '#dd8452', '#55a868']
ax.bar(range(4), fmt_means, yerr=fmt_sems, color=fmt_colors, capsize=3)
ax.axhline(np.mean(bare_arr), color='#888888', linestyle='--', linewidth=1, label='Bare')
ax.set_xticks(range(4))
ax.set_xticklabels(fmt_labels, fontsize=7)
ax.set_ylabel('Mean NLL')
ax.set_title('Group C: Format Sensitivity')
ax.legend(fontsize=7)

# --- (1,1) Pool routing ---
ax = axes[1, 1]
pool_labels_viz = ['Bare', 'Pool\nCosine', 'Pool\nOracle', 'Summary']
pool_keys_viz = ['bare_nll', 'sfx_pool_cosine_nll', 'sfx_pool_oracle_nll', 'sfx_summary_nll']
pool_means_viz = [np.mean([r[k] for r in results]) for k in pool_keys_viz]
pool_sems_viz = [stats.sem([r[k] for r in results]) for k in pool_keys_viz]
pool_colors_viz = ['#888888', '#4c72b0', '#2a5298', '#55a868']
ax.bar(range(4), pool_means_viz, yerr=pool_sems_viz, color=pool_colors_viz, capsize=3)
ax.set_xticks(range(4))
ax.set_xticklabels(pool_labels_viz, fontsize=8)
ax.set_ylabel('Mean NLL')
ax.set_title('Group E: Pool Routing + Summary')

# --- (1,2) Win rate by difficulty quartile ---
ax = axes[1, 2]
q_labels_list = []
q_sfx_gen_wr = []
q_sfx_irrel_wr = []
q_pfx_wr = []
for label, cond in difficulty_bins:
    subset = [r for r in results if cond(r['bare_nll'])]
    if not subset:
        continue
    q_labels_list.append(label.replace(' ', '\n'))
    q_sfx_gen_wr.append(np.mean([r['bare_nll'] - r['sfx_gen_routed_nll'] > 0 for r in subset]) * 100)
    q_sfx_irrel_wr.append(np.mean([r['bare_nll'] - r['sfx_irrel_nll'] > 0 for r in subset]) * 100)
    q_pfx_wr.append(np.mean([r['bare_nll'] - r['pfx_trunc_routed_nll'] > 0 for r in subset]) * 100)

# Grouped bars: three conditions side by side within each quartile.
x_pos = np.arange(len(q_labels_list))
w = 0.25
ax.bar(x_pos - w, q_sfx_gen_wr, w, label='Sfx Gen', color='#4c72b0')
ax.bar(x_pos, q_sfx_irrel_wr, w, label='Sfx Irrel', color='#c44e52')
ax.bar(x_pos + w, q_pfx_wr, w, label='Pfx Trunc', color='#dd8452')
ax.axhline(50, color='black', linestyle='--', linewidth=0.8, alpha=0.3)
ax.set_xticks(x_pos)
ax.set_xticklabels(q_labels_list, fontsize=7)
ax.set_ylabel('Win Rate vs Bare (%)')
ax.set_title('Win Rate by Difficulty')
ax.legend(fontsize=7)
ax.set_ylim(0, 100)

# --- (2,0) Similarity vs suffix delta scatter ---
ax = axes[2, 0]
gen_sims = np.array([r['sfx_gen_routed_sim'] for r in results])
gen_deltas_viz = sfx_gen_deltas  # computed for panel (0,1)
r_sim, _ = stats.pearsonr(gen_sims, gen_deltas_viz)
ax.scatter(gen_sims, gen_deltas_viz, alpha=0.2, s=8, c='#4c72b0')
# Least-squares linear trend line.
z = np.polyfit(gen_sims, gen_deltas_viz, 1)
p_line = np.poly1d(z)
x_range = np.linspace(gen_sims.min(), gen_sims.max(), 100)
ax.plot(x_range, p_line(x_range), 'r-', linewidth=1.5)
ax.axhline(0, color='black', linestyle='--', linewidth=0.8)
ax.set_xlabel('Surrogate-Query Cosine Similarity')
ax.set_ylabel('Suffix Delta NLL')
ax.set_title(f'Similarity vs Suffix Delta (r={r_sim:.3f})')

# --- (2,1) Per-template NLL (suffix) ---
ax = axes[2, 1]
gen_keys = list(TOP_5_SURROGATE_TEMPLATES.keys())
gen_template_means = [np.mean([r['sfx_gen_nlls'][k] for r in results]) for k in gen_keys]
gen_template_sems = [stats.sem([r['sfx_gen_nlls'][k] for r in results]) for k in gen_keys]
ax.bar(range(len(gen_keys)), gen_template_means, yerr=gen_template_sems,
       color='#4c72b0', capsize=3)
ax.axhline(np.mean(bare_arr), color='#888888', linestyle='--', linewidth=1, label='Bare')
ax.set_xticks(range(len(gen_keys)))
ax.set_xticklabels([k.replace('_', '\n') for k in gen_keys], fontsize=6)
ax.set_ylabel('Mean NLL')
ax.set_title('Per-Template: Suffix Generated')
ax.legend(fontsize=7)

# --- (2,2) Padded control ---
ax = axes[2, 2]
padded_delta = bare_arr - np.array([r['padded_nll'] for r in results])
sfx_gen_delta = bare_arr - np.array([r['sfx_gen_routed_nll'] for r in results])
bp = ax.boxplot([padded_delta, sfx_gen_delta], patch_artist=True)
# Tick labels are set explicitly rather than through boxplot's `labels=`
# keyword, which Matplotlib 3.9 renamed to `tick_labels` (deprecating the old
# name); this works on both old and new versions. Boxes sit at x = 1, 2.
ax.set_xticks([1, 2])
ax.set_xticklabels(['Padded\n(length ctrl)', 'Sfx Gen\nRouted'])
bp['boxes'][0].set_facecolor('#aaaaaa')
bp['boxes'][1].set_facecolor('#4c72b0')
for box in bp['boxes']:
    box.set_alpha(0.7)
ax.axhline(0, color='black', linestyle='--', linewidth=0.8)
ax.set_ylabel('Delta NLL vs Bare')
ax.set_title('Length Control: Padded vs Suffix')

plt.tight_layout()
plt.savefig('07_suffix_priming_results.png', dpi=150, bbox_inches='tight')
plt.show()
print("Saved: 07_suffix_priming_results.png")

## Save Results

In [None]:
os.makedirs('results', exist_ok=True)
# Capture a single instant so the filename timestamp and the metadata
# timestamp always describe the same moment.
run_time = datetime.datetime.now()
timestamp = run_time.strftime('%Y%m%d_%H%M%S')


def _json_default(obj):
    """Encode values the json module cannot serialize natively.

    Objects exposing ``tolist()`` (NumPy scalars and arrays, torch tensors)
    become native Python numbers/lists, so they stay numeric in the JSON
    instead of being stringified. Anything else falls back to ``str(obj)``.
    """
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    return str(obj)


output = {
    'metadata': {
        'experiment': '07_suffix_priming_experiment',
        'description': (
            'Suffix priming and advanced cache strategies: 18 conditions testing '
            'suffix placement (passage + suffix), format sensitivity, prefix comparison, '
            'suffix routing pool, and summary-as-suffix. Key test: semantic isolation '
            'via delta correlations (comparing to prefix r=0.924).'
        ),
        'timestamp': run_time.isoformat(),
        'model_name': config.model_name,
        'num_samples_requested': config.num_samples,
        'num_samples_evaluated': len(results),
        'num_errors': errors,
        'elapsed_seconds': elapsed_total,
        'seed': config.seed,
    },
    'results': results,
}

# Serialize once; both copies below get identical content.
serialized = json.dumps(output, indent=2, default=_json_default)

# Save timestamped copy
output_path_ts = f'results/07_suffix_priming_results_{timestamp}.json'
with open(output_path_ts, 'w', encoding='utf-8') as f:
    f.write(serialized)
print(f"Timestamped results: {output_path_ts}")

# Save canonical copy
output_path = '07_suffix_priming_results.json'
with open(output_path, 'w', encoding='utf-8') as f:
    f.write(serialized)
print(f"Canonical results:   {output_path}")
print(f"File size: {os.path.getsize(output_path) / 1e6:.1f} MB")

## Conclusions

In [None]:
print("=" * 80)
print("AUTOMATED VERDICTS")
print("=" * 80)

alpha = 0.01
n_hypothesis_tests = 5  # H1-H5 below
bonferroni_alpha = alpha / n_hypothesis_tests
verdicts = {}

# Convention for every paired test below: t is computed as ttest_rel(x, y)
# with x - y equal to the stored 'delta', so t and delta share a sign.

# H1: Does suffix priming help?
sfx_gen_arr = np.array([r['sfx_gen_routed_nll'] for r in results])
t, p = stats.ttest_rel(bare_arr, sfx_gen_arr)
verdicts['H1: Suffix priming helps'] = {
    'test': 'sfx_gen_routed (3) vs bare (1)',
    'delta': float(np.mean(bare_arr - sfx_gen_arr)),
    't': float(t), 'p': float(p),
    'verdict': 'SUPPORTED' if p < bonferroni_alpha and np.mean(sfx_gen_arr) < np.mean(bare_arr) else 'NOT SUPPORTED'
}

# H2: Suffix effect is semantic (gen vs irrel)
sfx_irrel_arr = np.array([r['sfx_irrel_nll'] for r in results])
t, p = stats.ttest_rel(sfx_irrel_arr, sfx_gen_arr)
verdicts['H2: Suffix semantic content matters'] = {
    'test': 'sfx_gen_routed (3) vs sfx_irrel (6)',
    'delta': float(np.mean(sfx_irrel_arr - sfx_gen_arr)),
    't': float(t), 'p': float(p),
    'verdict': 'SUPPORTED' if p < bonferroni_alpha and np.mean(sfx_gen_arr) < np.mean(sfx_irrel_arr) else 'NOT SUPPORTED'
}

# H3: Delta correlation is low (content-dependent, unlike prefix)
sfx_gen_deltas_v = np.array([r['bare_nll'] - r['sfx_gen_routed_nll'] for r in results])
sfx_shuf_deltas_v = np.array([r['bare_nll'] - r['sfx_shuffled_nll'] for r in results])
r_corr, p_corr = stats.pearsonr(sfx_gen_deltas_v, sfx_shuf_deltas_v)
verdicts['H3: Suffix effect is content-dependent (r < 0.5)'] = {
    'test': 'Correlation of per-sample deltas: sfx_gen_routed vs sfx_shuffled',
    'r': float(r_corr), 'p': float(p_corr),
    'prefix_reference': 0.924,
    'verdict': 'SUPPORTED' if r_corr < 0.5 else ('PARTIAL' if r_corr < 0.8 else 'NOT SUPPORTED')
}

# H4: sfx_perfect beats bare significantly (uncorrected alpha plus an effect-size floor)
sfx_perf_arr = np.array([r['sfx_perfect_nll'] for r in results])
t, p = stats.ttest_rel(bare_arr, sfx_perf_arr)
d = cohens_d(bare_arr - sfx_perf_arr)
verdicts['H4: sfx_perfect beats bare (d > 0.3)'] = {
    'test': 'sfx_perfect (5) vs bare (1)',
    'delta': float(np.mean(bare_arr - sfx_perf_arr)),
    'd': float(d), 't': float(t), 'p': float(p),
    'verdict': 'SUPPORTED' if p < 0.01 and d > 0.3 else 'NOT SUPPORTED'
}

# H5: Suffix vs prefix — suffix is better or comparable (judged on means only)
pfx_arr = np.array([r['pfx_trunc_routed_nll'] for r in results])
t, p = stats.ttest_rel(pfx_arr, sfx_gen_arr)
verdicts['H5: Suffix >= prefix (same surrogate)'] = {
    'test': 'sfx_gen_routed (3) vs pfx_trunc_routed (13)',
    'sfx_mean': float(np.mean(sfx_gen_arr)),
    'pfx_mean': float(np.mean(pfx_arr)),
    'delta': float(np.mean(pfx_arr - sfx_gen_arr)),
    't': float(t), 'p': float(p),
    'verdict': 'SUPPORTED' if np.mean(sfx_gen_arr) <= np.mean(pfx_arr) else 'NOT SUPPORTED'
}

# Print
for hyp, v in verdicts.items():
    # Only paired t-tests get a significance marker. H3's p-value tests whether
    # the correlation differs from zero; a significant *high* correlation is
    # evidence against H3, so starring it would be misleading.
    is_sig = 't' in v and v['p'] < 0.01
    sig_marker = "***" if is_sig else "   "
    print(f"\n{sig_marker} {hyp}: {v['verdict']}")
    print(f"    {v['test']}")
    if 'r' in v:
        print(f"    r={v['r']:.4f} (prefix ref: {v['prefix_reference']})")
    elif 'd' in v:
        print(f"    d={v['d']:.4f}, t={v['t']:.3f}, p={v['p']:.6f}")
    else:
        delta = v.get('delta')
        delta_str = f"{delta:+.4f}" if delta is not None else "N/A"
        print(f"    delta={delta_str}, t={v['t']:.3f}, p={v['p']:.6f}")

# Overall assessment
print(f"\n{'=' * 80}")
print("OVERALL ASSESSMENT")
print("=" * 80)
n_supported = sum(1 for v in verdicts.values() if v['verdict'] == 'SUPPORTED')
print(f"Hypotheses supported: {n_supported}/{len(verdicts)}")

# Success criteria check
print("\nSuccess criteria:")
h4_ok = verdicts['H4: sfx_perfect beats bare (d > 0.3)']['verdict'] == 'SUPPORTED'
h2_ok = verdicts['H2: Suffix semantic content matters']['verdict'] == 'SUPPORTED'
h3_ok = verdicts['H3: Suffix effect is content-dependent (r < 0.5)']['verdict'] == 'SUPPORTED'

if h4_ok and h2_ok and h3_ok:
    print("  STRONG SUCCESS: sfx_perfect helps AND gen_routed beats controls AND low correlation")
elif h4_ok and (h2_ok or h3_ok):
    print("  PARTIAL SUCCESS: sfx_perfect helps, some semantic signal")
elif h4_ok:
    print("  PARTIAL: sfx_perfect helps but no semantic separation")
else:
    print("  FAILURE: suffix approach does not produce meaningful improvements")