# Experiment 10: Diagnosing the Truncation Mechanism

## The Mystery

Across experiments 05, 06, and 09, truncated+corrected KV caches consistently beat bare caches — even when the truncated prefix is **irrelevant** text. RoPE correction restores key positions, so the corrected cache should be equivalent to a bare cache. But it isn't.

## Three Competing Hypotheses

**H1 (Value Vector Contamination, leading):** During the forward pass through `[prefix][document]`, document tokens attend to prefix tokens, so the hidden states their **value vectors** are projected from carry prefix information. After truncation, the prefix KV entries are removed, but the document values retain this "fingerprint." RoPE correction only fixes keys; values are never corrected.
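To make the H1 mechanism concrete, here is a minimal toy sketch in NumPy. It is not the notebook's model or `lib` code: the two-layer single-head stack, the shared random weights, and the helper names `causal_layer` and `values_per_layer` are all assumptions made for illustration, and the toy has no positional encoding or BOS token. It compares the document rows' value vectors with and without a prefix in front.

```python
import numpy as np

rng = np.random.default_rng(0)
D_MODEL = 8
# One set of random projection weights, shared by both toy layers for brevity.
W_q, W_k, W_v = (
    rng.standard_normal((D_MODEL, D_MODEL)) / np.sqrt(D_MODEL) for _ in range(3)
)


def causal_layer(x):
    """Single-head causal self-attention with a residual connection.

    Returns the layer output and the value vectors computed from the input.
    """
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    scores = q @ k.T / np.sqrt(D_MODEL)
    future = np.triu(np.ones(scores.shape, dtype=bool), k=1)  # True above the diagonal
    scores = np.where(future, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return x + weights @ v, v


def values_per_layer(x, n_layers=2):
    """Run the toy stack and collect each layer's value vectors."""
    collected = []
    for _ in range(n_layers):
        x, v = causal_layer(x)
        collected.append(v)
    return collected


prefix = rng.standard_normal((3, D_MODEL))
document = rng.standard_normal((5, D_MODEL))

bare_values = values_per_layer(document)
full_values = values_per_layer(np.concatenate([prefix, document]))

# "Truncate": keep only the document rows of the prefixed run, then compare.
layer_gaps = [
    float(np.abs(full[len(prefix):] - bare).max())
    for full, bare in zip(full_values, bare_values)
]
print(layer_gaps)
```

In this toy, the first layer's values match because they depend only on the token inputs, while the second layer's values differ because their inputs already include attention over the prefix. Whether that difference helps or hurts downstream NLL is exactly what the investigations below try to measure.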

**H2 (Float16 RoPE Residual):** Applying RoPE(+S) and then RoPE(-S) in float16 is not an exact identity: the roundtrip introduces a max error of ~2e-3 per element. This key perturbation might act as beneficial noise.
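The roundtrip effect behind H2 can be reproduced in isolation. The sketch below is an assumption-laden illustration, not the notebook's `apply_rope_roundtrip_noise`: it uses NumPy instead of PyTorch, a "rotate-half" pairing, a base of 10000, random toy keys, and a hypothetical shift `SHIFT`. The exact error magnitude will depend on the real key distribution and head dimension.

```python
import numpy as np


def rope_rotate(x, position, dtype, base=10000.0):
    """Rotate every row of x by `position` rotary steps, computing in `dtype`.

    Uses the "rotate-half" pairing: element i is paired with element i + d/2.
    """
    half = x.shape[-1] // 2
    inv_freq = 1.0 / (base ** (np.arange(half) / half))
    angles = position * inv_freq
    cos = np.cos(angles).astype(dtype)
    sin = np.sin(angles).astype(dtype)
    x = x.astype(dtype)
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)


def roundtrip_error(keys, shift, dtype):
    """Max absolute element error after rotating by +shift and then by -shift."""
    start = keys.astype(dtype)
    back = rope_rotate(rope_rotate(start, shift, dtype), -shift, dtype)
    return float(np.abs(back.astype(np.float64) - start.astype(np.float64)).max())


rng = np.random.default_rng(0)
keys = rng.standard_normal((64, 128))  # toy keys: 64 positions, head_dim 128
SHIFT = 20  # hypothetical prefix offset S

err_fp16 = roundtrip_error(keys, SHIFT, np.float16)
err_fp64 = roundtrip_error(keys, SHIFT, np.float64)
print(f"float16 max roundtrip error: {err_fp16:.2e}")
print(f"float64 max roundtrip error: {err_fp64:.2e}")
```

The float64 roundtrip is an identity up to machine precision, while the float16 one leaves a visible residual. Condition 5 applies this kind of residual to otherwise bare keys to test whether it alone explains the benefit.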

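**H3 (BOS Contamination):** The BOS token (position 0) is preserved during truncation, and its KV entry comes from the `[prefix][document]` forward pass rather than from a bare one. If that entry differs from the bare BOS entry, the contaminated BOS could be the benefit channel. Note that under a strictly causal mask, position 0 cannot attend to the later prefix tokens, so any such difference would have to be numerical rather than semantic; Conditions 6 and 7 test this channel directly.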
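As a sanity check on H3, the toy sketch below compares the first token's key/value entries with and without a prefix following it. It is a NumPy illustration with random weights, not the notebook's model or `lib` code, and the helper names are hypothetical. It assumes a strictly causal mask, under which position 0 attends only to itself; a real model may still show small numerical differences from kernel-level effects.

```python
import numpy as np

rng = np.random.default_rng(0)
D_MODEL = 8
W_q, W_k, W_v = (
    rng.standard_normal((D_MODEL, D_MODEL)) / np.sqrt(D_MODEL) for _ in range(3)
)


def causal_layer(x):
    """Single-head causal self-attention with a residual; returns (output, keys, values)."""
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    scores = q @ k.T / np.sqrt(D_MODEL)
    future = np.triu(np.ones(scores.shape, dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return x + weights @ v, k, v


def first_token_kv(x, n_layers=3):
    """Collect the position-0 key and value at every layer of the toy stack."""
    entries = []
    for _ in range(n_layers):
        x, k, v = causal_layer(x)
        entries.append((k[0].copy(), v[0].copy()))
    return entries


bos = rng.standard_normal((1, D_MODEL))
prefix = rng.standard_normal((4, D_MODEL))
document = rng.standard_normal((5, D_MODEL))

bare_run = first_token_kv(np.concatenate([bos, document]))
prefixed_run = first_token_kv(np.concatenate([bos, prefix, document]))

# Largest key or value difference at position 0 across all layers.
max_bos_gap = max(
    max(float(np.abs(kb - kp).max()), float(np.abs(vb - vp).max()))
    for (kb, vb), (kp, vp) in zip(bare_run, prefixed_run)
)
print(f"max BOS key/value gap: {max_bos_gap:.2e}")
```

If the real caches behave like this toy, swapping the BOS entry between bare and truncated caches should change NLL very little, which is one way to read the Condition 6 and 7 results.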

## Discriminating Predictions

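Predicted win rates versus the bare cache if each hypothesis holds:

| Hypothesis | Cond 3 (value-only) | Cond 4 (key-only) | Cond 5 (RoPE noise) | Cond 6 (trunc, bare BOS) | Cond 7 (bare, trunc BOS) |
|---|---|---|---|---|---|
| H1 (values) | wins ~70% | ~50% | ~50% | | |
| H2 (RoPE noise) | ~50% | ~70% | ~70% | | |
| H3 (BOS) | | | | drops | ~70% |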

## Important Methodological Notes

**Investigation A limitation:** The hybrid cache surgery (mixing keys from one forward pass with values from another) is inherently destructive. Keys and values are co-adapted within a single forward pass. Mixing them destroys this co-adaptation, producing catastrophic NLL (~4.0 vs ~1.5 baseline). This means Investigation A **cannot discriminate between hypotheses** — the surgery artifact dominates any signal.

**Investigation C (layer ablation)** is a gentler test: replacing values at a *single layer* at a time, which preserves most of the co-adapted structure. This is the primary tool for understanding the mechanism.
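The two kinds of cache surgery described in these notes can be sketched on a toy cache represented as a list of per-layer `(keys, values)` arrays. The helper names `mix_keys_and_values` and `swap_values_at_layers` are hypothetical stand-ins; the real `build_hybrid_cache` and `replace_values_at_layers` in `lib` operate on `DynamicCache` objects and may differ in detail.

```python
import numpy as np


def make_toy_cache(n_layers, seq_len, head_dim, seed):
    """Build a toy cache: one (keys, values) pair of random arrays per layer."""
    rng = np.random.default_rng(seed)
    return [
        (rng.standard_normal((seq_len, head_dim)),
         rng.standard_normal((seq_len, head_dim)))
        for _ in range(n_layers)
    ]


def mix_keys_and_values(keys_from, values_from):
    """Investigation A style: every layer takes keys from one cache, values from another."""
    return [(k.copy(), v.copy()) for (k, _), (_, v) in zip(keys_from, values_from)]


def swap_values_at_layers(target, source, layers):
    """Investigation C style: copy `target`, replacing values only at the listed layers."""
    layers = set(layers)
    return [
        (k.copy(), (source[i][1] if i in layers else v).copy())
        for i, (k, v) in enumerate(target)
    ]


bare = make_toy_cache(n_layers=4, seq_len=6, head_dim=8, seed=0)
trunc = make_toy_cache(n_layers=4, seq_len=6, head_dim=8, seed=1)

hybrid = mix_keys_and_values(bare, trunc)       # bare keys + truncated values, all layers
ablated = swap_values_at_layers(trunc, bare, [2])  # truncated cache, bare values at layer 2

changed_layers = [
    i for i, ((_, v_new), (_, v_old)) in enumerate(zip(ablated, trunc))
    if not np.array_equal(v_new, v_old)
]
print(changed_layers)
```

The hybrid touches every layer at once, whereas the ablation changes a single layer and leaves the rest of the cache as one forward pass produced it. That difference is why the ablation is the gentler intervention.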

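**BPE matching:** Investigations A and C use `build_matched_caches` so that the bare and truncated caches contain identical token sequences, avoiding BPE boundary artifacts. Investigation B instead builds its bare cache from the standard template with `build_bare_cache`, so its bare and truncated token sequences are not guaranteed to match.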
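The boundary artifact is easy to see with a toy tokenizer. The sketch below uses a greedy longest-match tokenizer over a tiny made-up vocabulary; it is not real BPE and not the notebook's tokenizer, and `toy_tokenize` and `TOY_VOCAB` are hypothetical. It only illustrates why the same document text can tokenize differently once a prefix precedes it.

```python
TOY_VOCAB = {"fox", ".", " ", " Doc", "Doc", "ument", ":"}


def toy_tokenize(text, vocab=TOY_VOCAB):
    """Greedy longest-match tokenizer with single-character fallback."""
    max_len = max(len(piece) for piece in vocab)
    tokens = []
    i = 0
    while i < len(text):
        # Try the longest candidate first; a single character always matches.
        for size in range(min(max_len, len(text) - i), 0, -1):
            piece = text[i:i + size]
            if piece in vocab or size == 1:
                tokens.append(piece)
                i += size
                break
    return tokens


document = "Document:"
prefix = "fox. "

standalone = toy_tokenize(document)
in_context = toy_tokenize(prefix + document)

print("standalone:", standalone)
print("in context:", in_context)
print("document tokens match:", standalone == in_context[-len(standalone):])
```

Here the separator space merges with the first document token once the prefix is present, so the two token sequences disagree. Reusing the token IDs taken from the full encoding for the bare cache, as `build_matched_caches` does, sidesteps this mismatch.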

In [None]:
import sys
import os
import json
import time
import copy
import datetime
from typing import Dict, List, Any, Optional, Tuple

import torch
import numpy as np
from tqdm.auto import tqdm
from scipy import stats
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['figure.dpi'] = 120

sys.path.insert(0, '.')

from lib import (
    ExperimentConfig,
    build_kv_cache,
    score_answer_with_cache,
    score_answer_with_cache_and_attention,
    build_truncated_kv_cache_corrected,
    build_hybrid_cache,
    swap_bos_entry,
    apply_rope_roundtrip_noise,
    replace_values_at_layers,
    build_truncated_cache_variable_prefix,
    extract_and_truncate_cache_with_bos,
    correct_rope_positions_with_bos,
    load_evaluation_samples,
    load_ms_marco,
    _ensure_dynamic_cache,
)

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, DynamicCache

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Configuration
config = ExperimentConfig(
    num_samples=2500,
    min_passage_words=50,
    max_passage_words=300,
    seed=42,
)

# Set seeds
torch.manual_seed(config.seed)
np.random.seed(config.seed)

# Load model (4-bit quantized)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
)

print(f"Loading {config.model_name}...")
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    quantization_config=bnb_config,
    device_map="auto",
)
model.eval()
print(f"Model loaded. Layers: {model.config.num_hidden_layers}")
NUM_LAYERS = model.config.num_hidden_layers

In [None]:
# Load dataset
dataset = load_ms_marco(config)
all_samples = load_evaluation_samples(dataset, config, require_answer=True)
print(f"Loaded {len(all_samples)} evaluation samples")

# We'll use subsets for different investigations
samples_a = all_samples[:200]  # Investigation A: 200 samples
samples_b = all_samples[:100]  # Investigation B: 100 samples
samples_c_div = all_samples[:50]   # Investigation C divergence: 50 samples
samples_c_abl = all_samples[:100]  # Investigation C ablation: 100 samples
samples_d = all_samples[:30]   # Investigation D: 30 samples

print(f"Investigation A: {len(samples_a)} samples")
print(f"Investigation B: {len(samples_b)} samples")
print(f"Investigation C (divergence): {len(samples_c_div)} samples")
print(f"Investigation C (ablation): {len(samples_c_abl)} samples")
print(f"Investigation D: {len(samples_d)} samples")

In [None]:
# Irrelevant prefix text for truncation experiments
IRRELEVANT_PREFIX = (
    "The quick brown fox jumps over the lazy dog. "
    "Pack my box with five dozen liquor jugs. "
    "How vexingly quick daft zebras jump."
)

# Import cache accessors
from lib.kv_cache import _get_cache_keys, _get_cache_values


def build_bare_cache(passage, model, tokenizer, config):
    """Build bare document cache with standard framing. Returns (len, DynamicCache)."""
    ctx = config.baseline_cache_template.format(document=passage)
    length, cache = build_kv_cache(ctx, model, tokenizer, config)
    return length, _ensure_dynamic_cache(cache)


def build_trunc_corrected_cache(passage, model, tokenizer, config,
                                prefix_text=IRRELEVANT_PREFIX):
    """Build a truncated+corrected cache (standalone, no matched bare cache).

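    Intended for analyses that never swap tensors between a bare and a
    truncated cache and therefore don't need BPE-matched bare/trunc pairs.
    (Investigation C's divergence and ablation cells use
    build_matched_caches instead.)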

    Returns: (keep_len, trunc_corrected, trunc_uncorrected, offset)
    """
    document_text = f"Document:\n{passage}"
    prefix_with_sep = prefix_text + " "

    prefix_encoding = tokenizer(
        prefix_with_sep, return_tensors="pt", add_special_tokens=True,
        padding=False, truncation=False
    )
    prefix_len = prefix_encoding['input_ids'].shape[1]

    full_context = prefix_with_sep + document_text
    full_encoding = tokenizer(
        full_context, return_tensors="pt", add_special_tokens=True,
        padding=False, truncation=False
    )
    full_ids = full_encoding['input_ids'].to(config.device)
    full_len = full_ids.shape[1]
    doc_len = full_len - prefix_len

    with torch.no_grad():
        full_out = model(
            input_ids=full_ids,
            attention_mask=torch.ones_like(full_ids),
            use_cache=True,
            return_dict=True
        )

    truncated = extract_and_truncate_cache_with_bos(full_out.past_key_values, doc_len)
    keep_len = 1 + doc_len

    trunc_uncorrected = DynamicCache()
    for li in range(len(truncated)):
        trunc_uncorrected.update(
            _get_cache_keys(truncated, li).clone(),
            _get_cache_values(truncated, li).clone(),
            li
        )

    surrogate_offset = prefix_len - 1
    correct_rope_positions_with_bos(truncated, surrogate_offset, model)

    return keep_len, truncated, trunc_uncorrected, surrogate_offset


def build_matched_caches(passage, model, tokenizer, config,
                         prefix_text=IRRELEVANT_PREFIX):
    """Build bare and truncated+corrected caches with identical token sequences.

    To ensure both caches have exactly the same document tokens (avoiding BPE
    boundary mismatches), we:
    1. Tokenize the full [prefix + document] context
    2. Extract the document token IDs from that encoding
    3. Build the bare cache from [BOS] + those exact document token IDs
    4. Build the truncated cache from the full encoding, then truncate+correct

    Returns: (bare_len, bare_cache, keep_len, trunc_corrected, trunc_uncorrected, offset)
    """
    document_text = f"Document:\n{passage}"
    prefix_with_sep = prefix_text + " "

    # Tokenize prefix alone (with BOS) to get prefix length
    prefix_encoding = tokenizer(
        prefix_with_sep, return_tensors="pt", add_special_tokens=True,
        padding=False, truncation=False
    )
    prefix_len = prefix_encoding['input_ids'].shape[1]  # includes BOS

    # Tokenize full context
    full_context = prefix_with_sep + document_text
    full_encoding = tokenizer(
        full_context, return_tensors="pt", add_special_tokens=True,
        padding=False, truncation=False
    )
    full_ids = full_encoding['input_ids'].to(config.device)
    full_len = full_ids.shape[1]
    doc_len = full_len - prefix_len

    # Extract the exact document token IDs as they appear in the full encoding
    doc_token_ids = full_ids[:, prefix_len:]  # (1, doc_len)

    # Build bare cache from [BOS] + exact document tokens
    bos_id = full_ids[:, :1]  # BOS token
    bare_ids = torch.cat([bos_id, doc_token_ids], dim=1)  # (1, 1+doc_len)
    bare_len = bare_ids.shape[1]

    with torch.no_grad():
        bare_out = model(
            input_ids=bare_ids,
            attention_mask=torch.ones_like(bare_ids),
            use_cache=True,
            return_dict=True
        )
    bare_cache = _ensure_dynamic_cache(bare_out.past_key_values)

    # Build truncated cache from full context
    with torch.no_grad():
        full_out = model(
            input_ids=full_ids,
            attention_mask=torch.ones_like(full_ids),
            use_cache=True,
            return_dict=True
        )

    # Truncate: BOS + last doc_len positions
    truncated = extract_and_truncate_cache_with_bos(full_out.past_key_values, doc_len)
    keep_len = 1 + doc_len

    # Verify match
    assert bare_len == keep_len, f"Length mismatch: bare_len={bare_len}, keep_len={keep_len}"

    # Clone before correction
    trunc_uncorrected = DynamicCache()
    for li in range(len(truncated)):
        trunc_uncorrected.update(
            _get_cache_keys(truncated, li).clone(),
            _get_cache_values(truncated, li).clone(),
            li
        )

    # RoPE correction on keys
    surrogate_offset = prefix_len - 1
    correct_rope_positions_with_bos(truncated, surrogate_offset, model)

    return bare_len, bare_cache, keep_len, truncated, trunc_uncorrected, surrogate_offset


def evaluate_investigation_a(sample, model, tokenizer, config):
    """Evaluate one sample across all 7 Investigation A conditions.

    Returns dict of NLLs or None if sample should be skipped.
    """
    passage = sample['passage']
    query = sample['query']
    answer = sample['answer']

    answer_ids = tokenizer(answer, return_tensors='pt', add_special_tokens=False)['input_ids']
    if answer_ids.shape[1] < 2:
        return None

    query_prompt = config.query_template.format(query=query)

    # Build matched caches (same token IDs for key/value swaps)
    bare_len, bare_cache, keep_len, trunc_corrected, trunc_uncorrected, offset = \
        build_matched_caches(passage, model, tokenizer, config)

    # --- Condition 1: Bare ---
    nll_bare = score_answer_with_cache(
        bare_cache, bare_len, query_prompt, answer, model, tokenizer, config
    )

    # --- Condition 2: Full truncation (corrected keys + truncated values) ---
    nll_full_trunc = score_answer_with_cache(
        trunc_corrected, keep_len, query_prompt, answer, model, tokenizer, config
    )

    # --- Condition 3: Value-only (bare keys + truncated values) -> tests H1 ---
    hybrid_val_only = build_hybrid_cache(bare_cache, trunc_corrected)
    nll_value_only = score_answer_with_cache(
        hybrid_val_only, bare_len, query_prompt, answer, model, tokenizer, config
    )

    # --- Condition 4: Key-only (corrected trunc keys + bare values) -> tests H2 ---
    hybrid_key_only = build_hybrid_cache(trunc_corrected, bare_cache)
    nll_key_only = score_answer_with_cache(
        hybrid_key_only, keep_len, query_prompt, answer, model, tokenizer, config
    )

    # --- Condition 5: RoPE noise only (bare keys with roundtrip noise + bare values) -> tests H2 ---
    noisy_cache = DynamicCache()
    for li in range(len(bare_cache)):
        noisy_cache.update(
            _get_cache_keys(bare_cache, li).clone(),
            _get_cache_values(bare_cache, li).clone(),
            li
        )
    apply_rope_roundtrip_noise(noisy_cache, offset, model)
    nll_rope_noise = score_answer_with_cache(
        noisy_cache, bare_len, query_prompt, answer, model, tokenizer, config
    )

    # --- Condition 6: Full trunc but with bare BOS -> tests H3 ---
    trunc_bare_bos = swap_bos_entry(trunc_corrected, bare_cache)
    nll_trunc_bare_bos = score_answer_with_cache(
        trunc_bare_bos, keep_len, query_prompt, answer, model, tokenizer, config
    )

    # --- Condition 7: Bare but with truncated BOS -> tests H3 ---
    bare_trunc_bos = swap_bos_entry(bare_cache, trunc_corrected)
    nll_bare_trunc_bos = score_answer_with_cache(
        bare_trunc_bos, bare_len, query_prompt, answer, model, tokenizer, config
    )

    return {
        'nll_bare': nll_bare,
        'nll_full_trunc': nll_full_trunc,
        'nll_value_only': nll_value_only,
        'nll_key_only': nll_key_only,
        'nll_rope_noise': nll_rope_noise,
        'nll_trunc_bare_bos': nll_trunc_bare_bos,
        'nll_bare_trunc_bos': nll_bare_trunc_bos,
    }

In [None]:
# ============================================================
# Investigation A: Key vs Value Separation (200 samples, 7 conditions)
# ============================================================

results_a = []
skipped_a = 0
errors_a = 0
start_a = time.time()

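CHECKPOINT_PATH_A = 'results/exp10/10_checkpoint_a.json'
# Create the output directory if needed so checkpoint and figure writes don't fail
os.makedirs(os.path.dirname(CHECKPOINT_PATH_A), exist_ok=True)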

start_idx_a = 0
if os.path.exists(CHECKPOINT_PATH_A):
    with open(CHECKPOINT_PATH_A) as f:
        ckpt = json.load(f)
    results_a = ckpt['results']
    skipped_a = ckpt['skipped']
    errors_a = ckpt['errors']
    start_idx_a = ckpt['next_idx']
    print(f"Resumed from checkpoint: {len(results_a)} results")

print(f"Investigation A: {len(samples_a)} samples, 7 conditions each")

for idx in tqdm(range(start_idx_a, len(samples_a)), desc="Inv A", initial=start_idx_a, total=len(samples_a)):
    sample = samples_a[idx]
    try:
        result = evaluate_investigation_a(sample, model, tokenizer, config)
        if result is None:
            skipped_a += 1
            continue
        results_a.append(result)
    except Exception as e:
        errors_a += 1
        if errors_a <= 3:
            print(f"\n  Error on sample {idx}: {e}")
        continue

    if len(results_a) % 25 == 0:
        with open(CHECKPOINT_PATH_A, 'w') as f:
            json.dump({'results': results_a, 'skipped': skipped_a, 'errors': errors_a, 'next_idx': idx + 1}, f)
        elapsed = time.time() - start_a
        print(f"\n  [{len(results_a)} done | {elapsed/60:.0f}m]")

elapsed_a = time.time() - start_a
print(f"\nDone. {len(results_a)} evaluated, {skipped_a} skipped, {errors_a} errors. Time: {elapsed_a/60:.1f} min")

In [None]:
# ============================================================
# Investigation A: Results Table + Bar Chart
# ============================================================

conditions_a = [
    ('1. Bare (baseline)', 'nll_bare'),
    ('2. Full truncation', 'nll_full_trunc'),
    ('3. Value-only (H1)', 'nll_value_only'),
    ('4. Key-only (H2)', 'nll_key_only'),
    ('5. RoPE noise (H2)', 'nll_rope_noise'),
    ('6. Trunc, bare BOS (H3)', 'nll_trunc_bare_bos'),
    ('7. Bare, trunc BOS (H3)', 'nll_bare_trunc_bos'),
]

bare_nlls = np.array([r['nll_bare'] for r in results_a])

print('=' * 100)
print('INVESTIGATION A: KEY vs VALUE SEPARATION')
print('=' * 100)
print(f"{'#':<3} {'Condition':<30} {'Mean NLL':>10} {'Std':>8} {'Win% vs Bare':>14} {'Delta':>10} {'t-stat':>8} {'p-value':>10}")
print('-' * 100)

inv_a_summary = {}
for label, key in conditions_a:
    nlls = np.array([r[key] for r in results_a])
    deltas = bare_nlls - nlls
    win_rate = np.mean(deltas > 0) * 100
    # Label width 34 matches the header's '#' (3) + space + 'Condition' (30) columns
    if key == 'nll_bare':
        print(f"{label:<34} {np.mean(nlls):>10.4f} {np.std(nlls):>8.4f} {'--':>14} {'--':>10} {'--':>8} {'--':>10}")
    else:
        t, p = stats.ttest_rel(bare_nlls, nlls)
        print(f"{label:<34} {np.mean(nlls):>10.4f} {np.std(nlls):>8.4f} {win_rate:>13.1f}% {np.mean(deltas):>+10.4f} {t:>8.2f} {p:>10.4f}")
    inv_a_summary[key] = {
        'mean_nll': float(np.mean(nlls)),
        'std_nll': float(np.std(nlls)),
        'win_rate': float(win_rate / 100),
        'mean_delta': float(np.mean(deltas)) if key != 'nll_bare' else 0.0,
    }

# Bar chart
fig, ax = plt.subplots(figsize=(10, 5))
labels = [l for l, _ in conditions_a]
means = [inv_a_summary[k]['mean_nll'] for _, k in conditions_a]
colors = ['#888888', '#4c72b0', '#c44e52', '#55a868', '#8c564b', '#e377c2', '#ff7f0e']
bars = ax.bar(range(len(labels)), means, color=colors)
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels, rotation=30, ha='right', fontsize=8)
ax.set_ylabel('Mean NLL (lower = better)')
ax.set_title('Investigation A: 7 Conditions')

# Annotate win rates
for i, (_, k) in enumerate(conditions_a):
    wr = inv_a_summary[k]['win_rate'] * 100
    if k != 'nll_bare':
        ax.text(i, means[i], f'{wr:.0f}%', ha='center', va='bottom', fontsize=8)

plt.tight_layout()
plt.savefig('results/exp10/10_investigation_a.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 10_investigation_a.png')

In [None]:
# ============================================================
# Investigation A: Statistical Tests + Verdict
# ============================================================

print('=' * 80)
print('HYPOTHESIS DISCRIMINATION')
print('=' * 80)

# H1 test: value contamination
val_only_wr = inv_a_summary['nll_value_only']['win_rate'] * 100
key_only_wr = inv_a_summary['nll_key_only']['win_rate'] * 100
rope_noise_wr = inv_a_summary['nll_rope_noise']['win_rate'] * 100
full_trunc_wr = inv_a_summary['nll_full_trunc']['win_rate'] * 100
trunc_bare_bos_wr = inv_a_summary['nll_trunc_bare_bos']['win_rate'] * 100
bare_trunc_bos_wr = inv_a_summary['nll_bare_trunc_bos']['win_rate'] * 100

print(f'\nFull truncation win%: {full_trunc_wr:.1f}% (replication target: ~80%)')
print(f'\n--- H1: Value Contamination ---')
print(f'  Cond 3 (value-only) win%: {val_only_wr:.1f}% (predict ~70% if H1)')
print(f'  Cond 4 (key-only) win%:   {key_only_wr:.1f}% (predict ~50% if H1)')
print(f'  Cond 5 (RoPE noise) win%: {rope_noise_wr:.1f}% (predict ~50% if H1)')
h1_score = (val_only_wr > 60) + (key_only_wr < 55) + (rope_noise_wr < 55)

print(f'\n--- H2: RoPE Float16 Noise ---')
print(f'  Cond 3 (value-only) win%: {val_only_wr:.1f}% (predict ~50% if H2)')
print(f'  Cond 4 (key-only) win%:   {key_only_wr:.1f}% (predict ~70% if H2)')
print(f'  Cond 5 (RoPE noise) win%: {rope_noise_wr:.1f}% (predict ~70% if H2)')
h2_score = (val_only_wr < 55) + (key_only_wr > 60) + (rope_noise_wr > 60)

print(f'\n--- H3: BOS Contamination ---')
print(f'  Cond 6 (trunc, bare BOS) win%: {trunc_bare_bos_wr:.1f}% (predict drops if H3)')
print(f'  Cond 7 (bare, trunc BOS) win%: {bare_trunc_bos_wr:.1f}% (predict ~70% if H3)')
h3_score = (trunc_bare_bos_wr < full_trunc_wr - 10) + (bare_trunc_bos_wr > 60)

print(f'\n--- Verdict Scores (higher = more consistent) ---')
print(f'  H1 (value contamination): {h1_score}/3')
print(f'  H2 (RoPE noise):          {h2_score}/3')
print(f'  H3 (BOS contamination):   {h3_score}/2')

# Pairwise comparisons between key conditions
print('\n--- Direct pairwise tests ---')
for label_a, key_a, label_b, key_b in [
    ('value-only', 'nll_value_only', 'key-only', 'nll_key_only'),
    ('value-only', 'nll_value_only', 'rope-noise', 'nll_rope_noise'),
    ('full-trunc', 'nll_full_trunc', 'trunc-bare-bos', 'nll_trunc_bare_bos'),
]:
    a = np.array([r[key_a] for r in results_a])
    b = np.array([r[key_b] for r in results_a])
    t, p = stats.ttest_rel(a, b)
    print(f'  {label_a} vs {label_b}: t={t:.3f}, p={p:.4f}')

In [None]:
# ============================================================
# Investigation B: Prefix Length Sensitivity (100 samples x 6 lengths)
# ============================================================

PREFIX_LENGTHS = [5, 10, 20, 50, 100, 200]  # target token counts

# Generate random text of varying lengths
def generate_random_text(n_tokens, tokenizer, seed=42):
    """Generate random text of approximately n_tokens length."""
    rng = np.random.RandomState(seed)
    # Use common English words to generate somewhat natural text
    words = ['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'lazy', 'dog',
             'a', 'simple', 'test', 'of', 'random', 'text', 'generation',
             'with', 'various', 'common', 'english', 'words', 'that', 'form',
             'sentences', 'and', 'paragraphs', 'for', 'our', 'experiment',
             'is', 'was', 'were', 'been', 'being', 'have', 'has', 'had',
             'do', 'does', 'did', 'will', 'would', 'could', 'should', 'may',
             'might', 'can', 'shall', 'must', 'need', 'dare', 'ought', 'used']
    # Generate enough words, then trim to target token count
    text = ' '.join(rng.choice(words, size=n_tokens * 2))
    # Tokenize and trim
    tokens = tokenizer.encode(text, add_special_tokens=False)[:n_tokens]
    return tokenizer.decode(tokens)


results_b = []
skipped_b = 0
errors_b = 0
start_b = time.time()

CHECKPOINT_PATH_B = 'results/exp10/10_checkpoint_b.json'

start_idx_b = 0
if os.path.exists(CHECKPOINT_PATH_B):
    with open(CHECKPOINT_PATH_B) as f:
        ckpt = json.load(f)
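    # JSON stores the integer prefix-length keys as strings; convert them back
    # to int so the `n_tok in r['length_nlls']` lookups in the analysis still match.
    results_b = [
        {'nll_bare': r['nll_bare'],
         'length_nlls': {int(k): v for k, v in r['length_nlls'].items()}}
        for r in ckpt['results']
    ]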
    skipped_b = ckpt['skipped']
    errors_b = ckpt['errors']
    start_idx_b = ckpt['next_idx']
    print(f"Resumed from checkpoint: {len(results_b)} results")

print(f"Investigation B: {len(samples_b)} samples x {len(PREFIX_LENGTHS)} prefix lengths")

for idx in tqdm(range(start_idx_b, len(samples_b)), desc="Inv B", initial=start_idx_b, total=len(samples_b)):
    sample = samples_b[idx]
    passage = sample['passage']
    query = sample['query']
    answer = sample['answer']

    answer_ids = tokenizer(answer, return_tensors='pt', add_special_tokens=False)['input_ids']
    if answer_ids.shape[1] < 2:
        skipped_b += 1
        continue

    query_prompt = config.query_template.format(query=query)

    try:
        # Bare baseline
        bare_len, bare_cache = build_bare_cache(passage, model, tokenizer, config)
        nll_bare = score_answer_with_cache(
            bare_cache, bare_len, query_prompt, answer, model, tokenizer, config
        )

        length_nlls = {}
        for n_tok in PREFIX_LENGTHS:
            prefix = generate_random_text(n_tok, tokenizer, seed=config.seed + idx + n_tok)
            keep_len, trunc_cache, _ = build_truncated_cache_variable_prefix(
                prefix, passage, model, tokenizer, config
            )
            nll = score_answer_with_cache(
                trunc_cache, keep_len, query_prompt, answer, model, tokenizer, config
            )
            length_nlls[n_tok] = nll

        results_b.append({'nll_bare': nll_bare, 'length_nlls': length_nlls})

    except Exception as e:
        errors_b += 1
        if errors_b <= 3:
            print(f"\n  Error on sample {idx}: {e}")
        continue

    if len(results_b) % 25 == 0:
        with open(CHECKPOINT_PATH_B, 'w') as f:
            json.dump({'results': results_b, 'skipped': skipped_b, 'errors': errors_b, 'next_idx': idx + 1}, f)
        elapsed = time.time() - start_b
        print(f"\n  [{len(results_b)} done | {elapsed/60:.0f}m]")

elapsed_b = time.time() - start_b
print(f"\nDone. {len(results_b)} evaluated, {skipped_b} skipped, {errors_b} errors. Time: {elapsed_b/60:.1f} min")

In [None]:
# ============================================================
# Investigation B: Length Curve Plot + Trend Test
# ============================================================

print('=' * 80)
print('INVESTIGATION B: PREFIX LENGTH SENSITIVITY')
print('=' * 80)

# Compute win rates and mean deltas per length
length_stats = {}
for n_tok in PREFIX_LENGTHS:
    deltas = [r['nll_bare'] - r['length_nlls'][n_tok] for r in results_b
              if n_tok in r['length_nlls']]
    nlls = [r['length_nlls'][n_tok] for r in results_b if n_tok in r['length_nlls']]
    wr = np.mean([d > 0 for d in deltas]) * 100
    length_stats[n_tok] = {
        'mean_nll': np.mean(nlls),
        'mean_delta': np.mean(deltas),
        'win_rate': wr,
        'n': len(deltas),
    }
    print(f"  Prefix {n_tok:>3d} tokens: mean NLL={np.mean(nlls):.4f}, "
          f"delta={np.mean(deltas):+.4f}, win%={wr:.1f}% (n={len(deltas)})")

# Trend test: Spearman correlation between prefix length and delta
all_lengths = []
all_deltas = []
for r in results_b:
    for n_tok in PREFIX_LENGTHS:
        if n_tok in r['length_nlls']:
            all_lengths.append(n_tok)
            all_deltas.append(r['nll_bare'] - r['length_nlls'][n_tok])

rho, p_spearman = stats.spearmanr(all_lengths, all_deltas)
print(f"\nSpearman correlation (length vs delta): rho={rho:.3f}, p={p_spearman:.4f}")

# Per-sample Spearman (within each sample, does longer prefix = bigger delta?)
per_sample_rhos = []
for r in results_b:
    lens = []
    dels = []
    for n_tok in PREFIX_LENGTHS:
        if n_tok in r['length_nlls']:
            lens.append(n_tok)
            dels.append(r['nll_bare'] - r['length_nlls'][n_tok])
    if len(lens) >= 4:
        rho_i, _ = stats.spearmanr(lens, dels)
        if not np.isnan(rho_i):
            per_sample_rhos.append(rho_i)

print(f"Per-sample Spearman: mean={np.mean(per_sample_rhos):.3f}, "
      f"median={np.median(per_sample_rhos):.3f}, "
      f"% positive={np.mean(np.array(per_sample_rhos) > 0)*100:.1f}%")

# Interpretation
print('\nInterpretation:')
if rho > 0.1 and p_spearman < 0.05:
    print('  Benefit increases with prefix length -> supports H1 (more tokens = more value contamination)')
    print('  Also consistent with H2 if monotonic (larger offset = more noise)')
elif abs(rho) < 0.1:
    print('  No length effect -> supports H3 (BOS is always 1 token regardless of prefix length)')
else:
    print(f'  Ambiguous: rho={rho:.3f}, p={p_spearman:.4f}')

# Plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

lengths = sorted(length_stats.keys())
win_rates = [length_stats[l]['win_rate'] for l in lengths]
mean_deltas = [length_stats[l]['mean_delta'] for l in lengths]

ax1.plot(lengths, win_rates, 'o-', color='#4c72b0', linewidth=2)
ax1.axhline(50, color='gray', linestyle='--', linewidth=0.8)
ax1.set_xlabel('Prefix Length (tokens)')
ax1.set_ylabel('Win Rate vs Bare (%)')
ax1.set_title(f'Win Rate by Prefix Length\n(Spearman rho={rho:.3f}, p={p_spearman:.4f})')
ax1.set_ylim(40, 90)

ax2.plot(lengths, mean_deltas, 's-', color='#c44e52', linewidth=2)
ax2.axhline(0, color='gray', linestyle='--', linewidth=0.8)
ax2.set_xlabel('Prefix Length (tokens)')
ax2.set_ylabel('Mean Delta NLL (positive = trunc better)')
ax2.set_title('Mean Delta by Prefix Length')

plt.tight_layout()
plt.savefig('results/exp10/10_investigation_b.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 10_investigation_b.png')

In [None]:
# ============================================================
# Investigation C: Value Divergence (50 samples)
# Per-layer L2 distance between bare and truncated value vectors
# FIX: Use build_matched_caches to ensure identical token sequences
# ============================================================

print('Investigation C: Value Divergence (per-layer L2 distance)')
print('  Using build_matched_caches for proper BPE-matched comparison')

divergence_results = []  # list of (num_layers,) arrays
skipped_c = 0

for idx in tqdm(range(len(samples_c_div)), desc="Inv C divergence"):
    sample = samples_c_div[idx]
    passage = sample['passage']

    answer_ids = tokenizer(sample['answer'], return_tensors='pt', add_special_tokens=False)['input_ids']
    if answer_ids.shape[1] < 2:
        skipped_c += 1
        continue

    try:
        # Use build_matched_caches to ensure identical token sequences
        bare_len, bare_cache, keep_len, trunc_corrected, _, offset = \
            build_matched_caches(passage, model, tokenizer, config)

        assert bare_len == keep_len, f"Length mismatch: {bare_len} vs {keep_len}"

        # Per-layer L2 distance of value vectors (skip BOS, compare doc portion)
        layer_l2 = []
        for li in range(NUM_LAYERS):
            bv = _get_cache_values(bare_cache, li)
            tv = _get_cache_values(trunc_corrected, li)
            # Both should have identical seq_len now
            assert bv.shape[2] == tv.shape[2], \
                f"Layer {li}: seq_len mismatch {bv.shape[2]} vs {tv.shape[2]}"
            # Skip BOS at position 0
            bv_doc = bv[:, :, 1:, :]
            tv_doc = tv[:, :, 1:, :]
            l2 = torch.norm(bv_doc.float() - tv_doc.float()).item()
            # Divide by sqrt(number of elements) to get the RMS per-element difference
            l2_norm = l2 / bv_doc.numel()**0.5
            layer_l2.append(l2_norm)

        divergence_results.append(np.array(layer_l2))

    except Exception as e:
        # Report every failure with its sample index so none is dropped silently
        print(f"  Error on sample {idx}: {e}")
        continue

print(f"Computed divergence for {len(divergence_results)} samples")

# Aggregate
div_matrix = np.stack(divergence_results)  # (n_samples, n_layers)
mean_div = div_matrix.mean(axis=0)
std_div = div_matrix.std(axis=0)

print(f"\nPer-layer mean value divergence (L2 norm, normalized):")
for li in range(len(mean_div)):
    print(f"  Layer {li:>2d}: {mean_div[li]:.6f} +/- {std_div[li]:.6f}")

In [None]:
# ============================================================
# Investigation C: Layer Ablation (100 samples)
# Replace values one layer at a time: truncated -> bare
# FIX: Use build_matched_caches to ensure identical seq_len
#      so replace_values_at_layers doesn't hit dimension mismatch
# ============================================================

print(f'Investigation C: Layer Ablation ({len(samples_c_abl)} samples x {NUM_LAYERS} layers)')
print('  Using build_matched_caches for BPE-matched bare/trunc caches')

ablation_results = []  # list of dicts: {nll_bare, nll_trunc, layer_nlls: {layer_idx: nll}}
skipped_c_abl = 0
errors_c_abl = 0
start_c = time.time()

CHECKPOINT_PATH_C = 'results/exp10/10_checkpoint_c.json'
start_idx_c = 0
if os.path.exists(CHECKPOINT_PATH_C):
    with open(CHECKPOINT_PATH_C) as f:
        ckpt = json.load(f)
    results_loaded = ckpt['results']
    # JSON stores the integer layer indices in layer_nlls as strings;
    # convert them back to int keys for consistency with fresh results
    ablation_results = []
    for r in results_loaded:
        layer_nlls = {}
        for k, v in r['layer_nlls'].items():
            layer_nlls[int(k)] = v
        ablation_results.append({
            'nll_bare': r['nll_bare'],
            'nll_trunc': r['nll_trunc'],
            'layer_nlls': layer_nlls,
        })
    skipped_c_abl = ckpt['skipped']
    errors_c_abl = ckpt['errors']
    start_idx_c = ckpt['next_idx']
    print(f"Resumed from checkpoint: {len(ablation_results)} results")

for idx in tqdm(range(start_idx_c, len(samples_c_abl)), desc="Inv C ablation",
                initial=start_idx_c, total=len(samples_c_abl)):
    sample = samples_c_abl[idx]
    passage = sample['passage']
    query = sample['query']
    answer = sample['answer']

    answer_ids = tokenizer(answer, return_tensors='pt', add_special_tokens=False)['input_ids']
    if answer_ids.shape[1] < 2:
        skipped_c_abl += 1
        continue

    query_prompt = config.query_template.format(query=query)

    try:
        # FIX: Use build_matched_caches for identical token sequences
        bare_len, bare_cache, keep_len, trunc_corrected, _, offset = \
            build_matched_caches(passage, model, tokenizer, config)

        # Verify lengths match (critical for replace_values_at_layers)
        assert bare_len == keep_len, f"Length mismatch: bare_len={bare_len}, keep_len={keep_len}"
        for li in range(NUM_LAYERS):
            bv = _get_cache_values(bare_cache, li)
            tv = _get_cache_values(trunc_corrected, li)
            assert bv.shape == tv.shape, \
                f"Layer {li} shape mismatch: bare={bv.shape}, trunc={tv.shape}"

        nll_bare = score_answer_with_cache(
            bare_cache, bare_len, query_prompt, answer, model, tokenizer, config
        )
        nll_trunc = score_answer_with_cache(
            trunc_corrected, keep_len, query_prompt, answer, model, tokenizer, config
        )

        # For each layer, replace truncated values with bare values at that layer
        layer_nlls = {}
        for li in range(NUM_LAYERS):
            ablated = replace_values_at_layers(trunc_corrected, bare_cache, [li])
            nll_abl = score_answer_with_cache(
                ablated, keep_len, query_prompt, answer, model, tokenizer, config
            )
            layer_nlls[li] = nll_abl

        ablation_results.append({
            'nll_bare': nll_bare,
            'nll_trunc': nll_trunc,
            'layer_nlls': layer_nlls,
        })

    except Exception as e:
        errors_c_abl += 1
        if errors_c_abl <= 5:
            import traceback
            print(f"\n  Error on sample {idx}: {e}")
            traceback.print_exc()
        continue

    if len(ablation_results) % 10 == 0:
        with open(CHECKPOINT_PATH_C, 'w') as f:
            json.dump({
                'results': [{k: v if not isinstance(v, dict) else {str(kk): vv for kk, vv in v.items()}
                             for k, v in r.items()} for r in ablation_results],
                'skipped': skipped_c_abl, 'errors': errors_c_abl, 'next_idx': idx + 1
            }, f)
        elapsed = time.time() - start_c
        print(f"\n  [{len(ablation_results)} done | {elapsed/60:.0f}m]")

elapsed_c = time.time() - start_c
print(f"\nDone. {len(ablation_results)} evaluated, {skipped_c_abl} skipped, {errors_c_abl} errors.")
print(f"Time: {elapsed_c/60:.1f} min")
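
The checkpoint logic above writes JSON directly over the existing file, so an interruption mid-write could leave a truncated checkpoint. A minimal sketch of one way to harden this, assuming the same payload layout: write to a temporary file and rename it over the target. The helper names `save_checkpoint` and `load_checkpoint` are hypothetical and not part of the notebook, and the demo values are made up.

```python
import json
import os
import tempfile


def save_checkpoint(path, results, skipped, errors, next_idx):
    """Write ablation progress to `path` via a temp file plus rename."""
    payload = {
        # JSON object keys must be strings, so layer indices are stringified.
        'results': [
            {**r, 'layer_nlls': {str(k): v for k, v in r['layer_nlls'].items()}}
            for r in results
        ],
        'skipped': skipped,
        'errors': errors,
        'next_idx': next_idx,
    }
    directory = os.path.dirname(path) or '.'
    os.makedirs(directory, exist_ok=True)
    # Temp file lives in the same directory so the rename stays on one filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix='.tmp')
    try:
        with os.fdopen(fd, 'w') as f:
            json.dump(payload, f)
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def load_checkpoint(path):
    """Read a checkpoint and restore integer layer indices."""
    with open(path) as f:
        ckpt = json.load(f)
    for r in ckpt['results']:
        r['layer_nlls'] = {int(k): v for k, v in r['layer_nlls'].items()}
    return ckpt


# Round-trip demo with made-up numbers (not real results).
demo_dir = tempfile.mkdtemp()
demo_path = os.path.join(demo_dir, 'ckpt.json')
demo_results = [{'nll_bare': 1.5, 'nll_trunc': 1.4, 'layer_nlls': {0: 1.41, 1: 1.43}}]
save_checkpoint(demo_path, demo_results, skipped=0, errors=0, next_idx=1)
restored = load_checkpoint(demo_path)
print(restored['results'][0]['layer_nlls'])
```

Keeping the key conversion inside the save and load helpers means the rest of the notebook can always index `layer_nlls` by integer layer.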

In [None]:
# ============================================================
# Investigation C: Layer Heatmap + Correlation
# ============================================================

print('=' * 80)
print('INVESTIGATION C: LAYER-BY-LAYER ANALYSIS')
print('=' * 80)

# Per-layer: how much does replacing trunc values -> bare values at that layer hurt?
# "hurt" = NLL goes up (closer to bare). A layer that contributes more to the benefit
# will show a bigger NLL increase when its values are replaced.

per_layer_impact = np.zeros(NUM_LAYERS)  # mean NLL increase when layer ablated
per_layer_pvals = np.zeros(NUM_LAYERS)   # two-sided paired t-test p-value per layer

# The truncated-cache NLLs do not depend on the layer, so build the array once.
trunc_nlls = np.array([r['nll_trunc'] for r in ablation_results])
for li in range(NUM_LAYERS):
    ablated_nlls = np.array([r['layer_nlls'][li] for r in ablation_results])
    impact = ablated_nlls - trunc_nlls  # positive = ablation hurt (layer contributed)
    per_layer_impact[li] = np.mean(impact)
    _, per_layer_pvals[li] = stats.ttest_rel(trunc_nlls, ablated_nlls)

# Bonferroni correction: one test per layer, so the threshold is alpha / NUM_LAYERS
bonferroni_alpha = 0.05 / NUM_LAYERS

print(f"\nPer-layer ablation impact (replacing trunc values -> bare values):")
print(f"{'Layer':>6} {'Impact':>10} {'p-value':>10} {'Sig':>5}")
print('-' * 35)
for li in range(NUM_LAYERS):
    sig = '*' if per_layer_pvals[li] < bonferroni_alpha else ''
    print(f"{li:>6d} {per_layer_impact[li]:>+10.5f} {per_layer_pvals[li]:>10.4f} {sig:>5}")

# Correlation with value divergence
if len(divergence_results) > 0:
    mean_div_per_layer = div_matrix.mean(axis=0)
    # Ensure same number of layers
    n_compare = min(len(mean_div_per_layer), len(per_layer_impact))
    rho_div_impact, p_div_impact = stats.spearmanr(
        mean_div_per_layer[:n_compare], per_layer_impact[:n_compare]
    )
    print(f"\nCorrelation (value divergence vs ablation impact):")
    print(f"  Spearman rho={rho_div_impact:.3f}, p={p_div_impact:.4f}")
    if rho_div_impact > 0.3 and p_div_impact < 0.05:
        print("  -> Layers with more divergent values contribute more to the benefit.")
        print("     This supports H1 (value contamination).")

# Plots
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Layer ablation impact
ax = axes[0]
colors = ['#c44e52' if p < bonferroni_alpha else '#4c72b0' for p in per_layer_pvals]
ax.bar(range(NUM_LAYERS), per_layer_impact, color=colors)
ax.set_xlabel('Layer')
ax.set_ylabel('Mean NLL increase when ablated')
ax.set_title('Per-Layer Ablation Impact\n(red = significant after Bonferroni)')
ax.axhline(0, color='gray', linestyle='--', linewidth=0.8)

# Per-layer value divergence (bar chart) and divergence-vs-impact scatter
if len(divergence_results) > 0:
    ax = axes[1]
    ax.bar(range(len(mean_div_per_layer)), mean_div_per_layer, color='#55a868')
    ax.set_xlabel('Layer')
    ax.set_ylabel('Mean L2 Divergence (normalized)')
    ax.set_title('Per-Layer Value Divergence\n(bare vs truncated)')

    # Correlation scatter
    ax = axes[2]
    ax.scatter(mean_div_per_layer[:n_compare], per_layer_impact[:n_compare],
              c=range(n_compare), cmap='viridis', s=30)
    ax.set_xlabel('Value Divergence')
    ax.set_ylabel('Ablation Impact')
    ax.set_title(f'Divergence vs Impact\n(rho={rho_div_impact:.3f}, p={p_div_impact:.4f})')
    # Label a few extreme points
    for li in np.argsort(per_layer_impact)[-3:]:
        if li < n_compare:
            ax.annotate(f'L{li}', (mean_div_per_layer[li], per_layer_impact[li]),
                       fontsize=7, ha='left')

plt.tight_layout()
plt.savefig('results/exp10/10_investigation_c.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 10_investigation_c.png')
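
The per-layer analysis above rests on two small computations: the mean paired NLL change when one layer's values are swapped, and a Bonferroni threshold that divides alpha by the number of tests. A minimal pure-Python sketch of both follows. The helper names `mean_impact` and `bonferroni_flags` are hypothetical, and the numbers are toy values rather than experimental results.

```python
def mean_impact(trunc_nlls, ablated_nlls):
    """Mean paired NLL change (ablated - trunc); positive means ablation hurt."""
    diffs = [a - t for t, a in zip(trunc_nlls, ablated_nlls)]
    return sum(diffs) / len(diffs)


def bonferroni_flags(pvals, alpha=0.05):
    """Flag p-values below the Bonferroni-corrected threshold alpha / m."""
    threshold = alpha / len(pvals)
    return [p < threshold for p in pvals]


# Toy numbers for illustration only (not experimental results).
toy_trunc = [1.40, 1.52, 1.38]
toy_ablated = [1.45, 1.55, 1.44]
toy_pvals = [0.0004, 0.03, 0.2, 0.011]  # pretend there are 4 layers

impact = mean_impact(toy_trunc, toy_ablated)
flags = bonferroni_flags(toy_pvals)
print(f"impact={impact:+.4f}, threshold={0.05 / len(toy_pvals):.4f}, flags={flags}")
```

Bonferroni is conservative when there are many layers, so a layer that fails the corrected threshold may still contribute. The flags only mark the layers for which the evidence survives the strictest reading.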

In [None]:
# ============================================================
# Investigation D: Attention Pattern Analysis (30 samples)
# FIX: Process bare and trunc attention sequentially to avoid OOM.
#      Compute stats and free attention tensors before the second pass.
#      Use build_matched_caches for proper comparison.
# ============================================================

print(f'Investigation D: Attention Analysis ({len(samples_d)} samples)')
print('  Processing bare/trunc sequentially to avoid OOM')

attn_results = []
skipped_d = 0
errors_d = 0

for idx in tqdm(range(len(samples_d)), desc="Inv D"):
    sample = samples_d[idx]
    passage = sample['passage']
    query = sample['query']
    answer = sample['answer']

    answer_ids = tokenizer(answer, return_tensors='pt', add_special_tokens=False)['input_ids']
    if answer_ids.shape[1] < 2:
        skipped_d += 1
        continue

    query_prompt = config.query_template.format(query=query)

    try:
        # Build matched caches
        bare_len, bare_cache, keep_len, trunc_corrected, _, offset = \
            build_matched_caches(passage, model, tokenizer, config)

        # --- Process BARE first, compute stats, free memory ---
        nll_bare, attn_bare = score_answer_with_cache_and_attention(
            bare_cache, bare_len, query_prompt, answer, model, tokenizer, config
        )

        per_layer_entropy_bare = []
        per_layer_bos_mass_bare = []
        for li in range(NUM_LAYERS):
            ab = attn_bare[li][0].float()  # (n_heads, answer_len, seq_len)
            eps = 1e-10
            ent_b = -(ab * (ab + eps).log()).sum(dim=-1).mean().item()
            bos_b = ab[:, :, 0].mean().item()
            per_layer_entropy_bare.append(ent_b)
            per_layer_bos_mass_bare.append(bos_b)

        # Free bare attention immediately
        del attn_bare, bare_cache
        torch.cuda.empty_cache()

        # --- Now process TRUNCATED ---
        nll_trunc, attn_trunc = score_answer_with_cache_and_attention(
            trunc_corrected, keep_len, query_prompt, answer, model, tokenizer, config
        )

        per_layer_entropy_trunc = []
        per_layer_bos_mass_trunc = []
        for li in range(NUM_LAYERS):
            at = attn_trunc[li][0].float()
            eps = 1e-10
            ent_t = -(at * (at + eps).log()).sum(dim=-1).mean().item()
            bos_t = at[:, :, 0].mean().item()
            per_layer_entropy_trunc.append(ent_t)
            per_layer_bos_mass_trunc.append(bos_t)

        # Free trunc attention
        del attn_trunc, trunc_corrected
        torch.cuda.empty_cache()

        attn_results.append({
            'nll_bare': nll_bare,
            'nll_trunc': nll_trunc,
            'entropy_bare': per_layer_entropy_bare,
            'entropy_trunc': per_layer_entropy_trunc,
            'bos_mass_bare': per_layer_bos_mass_bare,
            'bos_mass_trunc': per_layer_bos_mass_trunc,
        })

    except Exception as e:
        errors_d += 1
        if errors_d <= 5:
            print(f"  Error on sample {idx}: {e}")
        # Clean up on error
        torch.cuda.empty_cache()
        continue

print(f"\nDone. {len(attn_results)} samples analyzed, {skipped_d} skipped, {errors_d} errors.")
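
The two statistics collected above are per-row quantities: Shannon entropy of an attention distribution and the weight placed on key position 0 (BOS). A minimal sketch on hand-written attention rows, assuming each row sums to 1, shows how they behave at the extremes. The function names are hypothetical and the rows are toy data.

```python
import math


def attention_entropy(row, eps=1e-10):
    """Shannon entropy (in nats) of one attention row; eps guards log(0)."""
    return -sum(p * math.log(p + eps) for p in row)


def bos_mass(row):
    """Attention weight placed on key position 0 (the BOS token)."""
    return row[0]


uniform = [0.25, 0.25, 0.25, 0.25]  # maximally diffuse over 4 keys
peaked = [0.97, 0.01, 0.01, 0.01]   # nearly all mass on BOS

for name, row in [('uniform', uniform), ('peaked', peaked)]:
    print(f"{name}: entropy={attention_entropy(row):.4f}, bos_mass={bos_mass(row):.2f}")
```

A uniform row over `n` keys has entropy `ln(n)`, the maximum for that length, so entropy comparisons between bare and truncated caches are only meaningful when both attend over the same number of positions, which the matched caches are meant to guarantee.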

In [None]:
# ============================================================
# Investigation D: Attention Entropy + BOS Attention Plots
# ============================================================

print('=' * 80)
print('INVESTIGATION D: ATTENTION PATTERN ANALYSIS')
print('=' * 80)

if len(attn_results) > 0:
    ent_bare = np.array([r['entropy_bare'] for r in attn_results])  # (n_samples, n_layers)
    ent_trunc = np.array([r['entropy_trunc'] for r in attn_results])
    bos_bare = np.array([r['bos_mass_bare'] for r in attn_results])
    bos_trunc = np.array([r['bos_mass_trunc'] for r in attn_results])

    # Per-layer paired t-tests
    print(f"\n{'Layer':>6} {'Ent Bare':>10} {'Ent Trunc':>10} {'p(ent)':>10} {'BOS Bare':>10} {'BOS Trunc':>10} {'p(bos)':>10}")
    print('-' * 70)
    for li in range(min(NUM_LAYERS, ent_bare.shape[1])):
        t_ent, p_ent = stats.ttest_rel(ent_bare[:, li], ent_trunc[:, li])
        t_bos, p_bos = stats.ttest_rel(bos_bare[:, li], bos_trunc[:, li])
        sig_ent = '*' if p_ent < 0.05 / NUM_LAYERS else ''
        sig_bos = '*' if p_bos < 0.05 / NUM_LAYERS else ''
        print(f"{li:>6d} {ent_bare[:, li].mean():>10.4f} {ent_trunc[:, li].mean():>10.4f} {p_ent:>9.4f}{sig_ent}"
              f" {bos_bare[:, li].mean():>10.6f} {bos_trunc[:, li].mean():>10.6f} {p_bos:>9.4f}{sig_bos}")

    # Overall: is truncated more uniform (higher entropy)?
    # Argument order (trunc, bare) makes the sign of t match the printed Trunc - Bare.
    mean_ent_diff = (ent_trunc - ent_bare).mean()
    t_overall, p_overall = stats.ttest_rel(ent_trunc.mean(axis=1), ent_bare.mean(axis=1))
    print(f"\nOverall entropy: bare={ent_bare.mean():.4f}, trunc={ent_trunc.mean():.4f}")
    print(f"  Trunc - Bare = {mean_ent_diff:+.4f}, paired t={t_overall:.3f}, p={p_overall:.4f}")

    # BOS overall (same sign convention: positive t means more BOS mass under truncation)
    mean_bos_diff = (bos_trunc - bos_bare).mean()
    t_bos_overall, p_bos_overall = stats.ttest_rel(bos_trunc.mean(axis=1), bos_bare.mean(axis=1))
    print(f"\nOverall BOS mass: bare={bos_bare.mean():.6f}, trunc={bos_trunc.mean():.6f}")
    print(f"  Trunc - Bare = {mean_bos_diff:+.6f}, paired t={t_bos_overall:.3f}, p={p_bos_overall:.4f}")

    # Plots
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    ax = axes[0]
    layers_x = range(ent_bare.shape[1])
    ax.plot(layers_x, ent_bare.mean(axis=0), 'o-', label='Bare', color='#4c72b0', markersize=3)
    ax.plot(layers_x, ent_trunc.mean(axis=0), 's-', label='Truncated', color='#c44e52', markersize=3)
    ax.set_xlabel('Layer')
    ax.set_ylabel('Mean Attention Entropy')
    ax.set_title('Attention Entropy by Layer')
    ax.legend()

    ax = axes[1]
    ax.plot(layers_x, bos_bare.mean(axis=0), 'o-', label='Bare', color='#4c72b0', markersize=3)
    ax.plot(layers_x, bos_trunc.mean(axis=0), 's-', label='Truncated', color='#c44e52', markersize=3)
    ax.set_xlabel('Layer')
    ax.set_ylabel('Mean Attention Mass on BOS')
    ax.set_title('BOS Attention by Layer')
    ax.legend()

    plt.tight_layout()
    plt.savefig('results/exp10/10_investigation_d.png', dpi=150, bbox_inches='tight')
    plt.show()
    print('Saved: 10_investigation_d.png')
else:
    print('No attention results to analyze.')
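
The sign of a paired t statistic depends on argument order: `stats.ttest_rel(a, b)` tests the mean of `a - b`, while the p-value is the same either way. A minimal hand-rolled sketch of the statistic, using only the standard library and toy values, makes the convention explicit. `paired_t` is a hypothetical helper, not a replacement for SciPy.

```python
import math
import statistics


def paired_t(a, b):
    """Paired t statistic for mean(a - b); positive when a tends to exceed b."""
    diffs = [x - y for x, y in zip(a, b)]
    n = len(diffs)
    mean_d = statistics.mean(diffs)
    sd_d = statistics.stdev(diffs)  # sample standard deviation (n - 1 denominator)
    return mean_d / (sd_d / math.sqrt(n))


# Toy per-sample mean entropies, for illustration only.
toy_bare = [2.10, 2.30, 2.00, 2.20]
toy_trunc = [2.20, 2.50, 2.05, 2.40]

t_trunc_minus_bare = paired_t(toy_trunc, toy_bare)
t_bare_minus_trunc = paired_t(toy_bare, toy_trunc)
print(f"t(trunc - bare)={t_trunc_minus_bare:.3f}, t(bare - trunc)={t_bare_minus_trunc:.3f}")
```

Swapping the arguments flips the sign and nothing else, so a reported t value is only interpretable next to a stated difference direction.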

In [None]:
# ============================================================
# Summary: All Results, Decision Matrix, Verdicts
# ============================================================

print('=' * 80)
print('EXPERIMENT 10: COMPREHENSIVE SUMMARY')
print('=' * 80)

# --- Investigation A Summary ---
print('\n--- INVESTIGATION A: Key vs Value Separation ---')
for label, key in conditions_a:
    s = inv_a_summary[key]
    wr_str = f"{s['win_rate']*100:.1f}%" if key != 'nll_bare' else '--'
    print(f"  {label:<33} NLL={s['mean_nll']:.4f}  Win%={wr_str}")

# --- Investigation B Summary ---
print('\n--- INVESTIGATION B: Prefix Length Sensitivity ---')
for n_tok in PREFIX_LENGTHS:
    s = length_stats[n_tok]
    print(f"  {n_tok:>3d} tokens: win%={s['win_rate']:.1f}%, delta={s['mean_delta']:+.4f}")
print(f"  Spearman rho={rho:.3f}, p={p_spearman:.4f}")

# --- Investigation C Summary ---
print('\n--- INVESTIGATION C: Layer Analysis ---')
top_layers = np.argsort(per_layer_impact)[-5:][::-1]
# .tolist() yields plain Python ints, which print cleanly under any NumPy version.
print(f"  Top contributing layers: {top_layers.tolist()}")
for li in top_layers:
    sig = '*' if per_layer_pvals[li] < bonferroni_alpha else ''
    print(f"    Layer {li}: impact={per_layer_impact[li]:+.5f} {sig}")
if len(divergence_results) > 0:
    print(f"  Divergence-Impact correlation: rho={rho_div_impact:.3f}, p={p_div_impact:.4f}")

# --- Investigation D Summary ---
if len(attn_results) > 0:
    print('\n--- INVESTIGATION D: Attention Patterns ---')
    print(f"  Entropy: bare={ent_bare.mean():.4f}, trunc={ent_trunc.mean():.4f}, p={p_overall:.4f}")
    print(f"  BOS mass: bare={bos_bare.mean():.6f}, trunc={bos_trunc.mean():.6f}, p={p_bos_overall:.4f}")

# --- Decision Matrix ---
print('\n' + '=' * 80)
print('DECISION MATRIX')
print('=' * 80)

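# Each cell states whether the observed outcome matches what that hypothesis
# predicts for the criterion; '--' means the hypothesis makes no prediction.
# All win rates here are percentages (0-100).
print(f"\n{'Criterion':<50} {'H1':>8} {'H2':>8} {'H3':>8}")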
print('-' * 80)

# Cond 3 (value-only) wins big?
val_big = 'YES' if val_only_wr > 60 else 'NO'
val_h2 = 'NO' if val_only_wr > 60 else 'YES'
print(f"{'Cond 3 (value-only) win% > 60%':<50} {val_big:>8} {val_h2:>8} {'--':>8}")

# Cond 4 (key-only) wins big?
key_big = 'NO' if key_only_wr > 60 else 'YES'
key_h2 = 'YES' if key_only_wr > 60 else 'NO'
print(f"{'Cond 4 (key-only) win% > 60%':<50} {key_big:>8} {key_h2:>8} {'--':>8}")

# Cond 5 (RoPE noise) wins big?
rope_big = 'NO' if rope_noise_wr > 60 else 'YES'
rope_h2 = 'YES' if rope_noise_wr > 60 else 'NO'
print(f"{'Cond 5 (RoPE noise) win% > 60%':<50} {rope_big:>8} {rope_h2:>8} {'--':>8}")

# BOS swap effects (bos_drop is a difference in win rate, in percentage points)
bos_drop = full_trunc_wr - trunc_bare_bos_wr
bos_h3 = 'YES' if bos_drop > 10 else 'NO'
print(f"{'Cond 6 drops >10% vs full trunc':<50} {'--':>8} {'--':>8} {bos_h3:>8}")
bos_gain = 'YES' if bare_trunc_bos_wr > 60 else 'NO'
print(f"{'Cond 7 (bare+trunc BOS) win% > 60%':<50} {'--':>8} {'--':>8} {bos_gain:>8}")

# Length trend
len_h1 = 'YES' if rho > 0.1 and p_spearman < 0.05 else 'NO'
len_h2 = 'YES' if rho > 0.1 and p_spearman < 0.05 else 'NO'
len_h3 = 'YES' if abs(rho) < 0.1 else 'NO'
print(f"{'Benefit increases with prefix length':<50} {len_h1:>8} {len_h2:>8} {len_h3:>8}")

# Divergence-impact correlation
if len(divergence_results) > 0:
    div_h1 = 'YES' if rho_div_impact > 0.3 and p_div_impact < 0.05 else 'NO'
    print(f"{'Layer divergence predicts ablation impact':<50} {div_h1:>8} {'--':>8} {'--':>8}")

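# These verdicts are computed from Investigation A win rates only. The
# methodological notes flag Investigation A's hybrid cache surgery as
# destructive, so read them alongside the Investigation C layer ablation.
print('\n--- VERDICT ---')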
verdicts = []
if val_only_wr > 60 and key_only_wr < 55 and rope_noise_wr < 55:
    verdicts.append('H1 (Value Contamination): STRONGLY SUPPORTED')
elif val_only_wr > 55:
    verdicts.append('H1 (Value Contamination): PARTIALLY SUPPORTED')
else:
    verdicts.append('H1 (Value Contamination): NOT SUPPORTED')

if key_only_wr > 60 and rope_noise_wr > 60:
    verdicts.append('H2 (RoPE Noise): STRONGLY SUPPORTED')
elif key_only_wr > 55 or rope_noise_wr > 55:
    verdicts.append('H2 (RoPE Noise): PARTIALLY SUPPORTED')
else:
    verdicts.append('H2 (RoPE Noise): NOT SUPPORTED')

if bos_drop > 10 and bare_trunc_bos_wr > 60:
    verdicts.append('H3 (BOS Contamination): STRONGLY SUPPORTED')
elif bos_drop > 5 or bare_trunc_bos_wr > 55:
    verdicts.append('H3 (BOS Contamination): PARTIALLY SUPPORTED')
else:
    verdicts.append('H3 (BOS Contamination): NOT SUPPORTED')

for v in verdicts:
    print(f'  {v}')

# Summary visualization
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Bar chart of Investigation A win rates
ax = axes[0]
labels_short = ['Bare', 'Full\nTrunc', 'Value\nOnly', 'Key\nOnly', 'RoPE\nNoise',
                'Trunc\nBare BOS', 'Bare\nTrunc BOS']
win_rates_a = [50] + [inv_a_summary[k]['win_rate']*100 for _, k in conditions_a[1:]]
colors_a = ['#888888', '#4c72b0', '#c44e52', '#55a868', '#8c564b', '#e377c2', '#ff7f0e']
ax.bar(range(len(labels_short)), win_rates_a, color=colors_a)
ax.axhline(50, color='gray', linestyle='--')
ax.set_xticks(range(len(labels_short)))
ax.set_xticklabels(labels_short, fontsize=7)
ax.set_ylabel('Win Rate vs Bare (%)')
ax.set_title('Investigation A: Win Rates')
for i, wr in enumerate(win_rates_a):
    ax.text(i, wr + 1, f'{wr:.0f}%', ha='center', fontsize=7)

# Prefix length curve
ax = axes[1]
sorted_lengths = sorted(length_stats.keys())
ax.plot(sorted_lengths,
        [length_stats[n]['win_rate'] for n in sorted_lengths],
        'o-', color='#4c72b0', linewidth=2)
ax.axhline(50, color='gray', linestyle='--')
ax.set_xlabel('Prefix Length (tokens)')
ax.set_ylabel('Win Rate vs Bare (%)')
ax.set_title(f'Investigation B: Length Curve\n(rho={rho:.3f})')

# Per-layer ablation impact (bar chart)
ax = axes[2]
ax.bar(range(NUM_LAYERS), per_layer_impact,
       color=['#c44e52' if p < bonferroni_alpha else '#4c72b0' for p in per_layer_pvals])
ax.set_xlabel('Layer')
ax.set_ylabel('Ablation Impact')
ax.set_title('Investigation C: Layer Impact')

plt.tight_layout()
plt.savefig('results/exp10/10_summary.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 10_summary.png')
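
The three verdict blocks above share one pattern: a strict criterion, a looser fallback, and a default. A minimal sketch of that pattern as a testable function, shown for H1 with the same thresholds as the cell above. The helper names are hypothetical, and the win rates passed in are invented for illustration.

```python
def support_level(strong, partial):
    """Map two boolean criteria onto the three verdict labels."""
    if strong:
        return 'STRONGLY SUPPORTED'
    if partial:
        return 'PARTIALLY SUPPORTED'
    return 'NOT SUPPORTED'


def h1_verdict(val_only_wr, key_only_wr, rope_noise_wr):
    """Win rates are percentages (0-100), matching the decision matrix."""
    strong = val_only_wr > 60 and key_only_wr < 55 and rope_noise_wr < 55
    partial = val_only_wr > 55
    return support_level(strong, partial)


# Hypothetical win rates, for illustration only.
print(h1_verdict(72.0, 51.0, 49.0))
print(h1_verdict(58.0, 57.0, 50.0))
print(h1_verdict(50.0, 50.0, 50.0))
```

Factoring the thresholds into functions makes it possible to check boundary cases, such as a win rate of exactly 60, without rerunning the experiment.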

In [None]:
# ============================================================
# Save All Results
# ============================================================

output = {
    'metadata': {
        'experiment': '10_truncation_mechanism_diagnostic',
        'timestamp': datetime.datetime.now().isoformat(),
        'model_name': config.model_name,
        'seed': config.seed,
    },
    'investigation_a': {
        'n_samples': len(results_a),
        'summary': inv_a_summary,
        'results': results_a,
    },
    'investigation_b': {
        'n_samples': len(results_b),
        'prefix_lengths': PREFIX_LENGTHS,
        'length_stats': {str(k): v for k, v in length_stats.items()},
        'spearman_rho': float(rho),
        'spearman_p': float(p_spearman),
        'results': results_b,
    },
    'investigation_c': {
        'n_divergence_samples': len(divergence_results),
        'n_ablation_samples': len(ablation_results),
        'per_layer_impact': per_layer_impact.tolist(),
        'per_layer_pvals': per_layer_pvals.tolist(),
        'mean_divergence': mean_div_per_layer.tolist() if len(divergence_results) > 0 else [],
        'divergence_impact_rho': float(rho_div_impact) if len(divergence_results) > 0 else None,
        'divergence_impact_p': float(p_div_impact) if len(divergence_results) > 0 else None,
        'ablation_results': [
            {k: v if not isinstance(v, dict) else {str(kk): vv for kk, vv in v.items()}
             for k, v in r.items()} for r in ablation_results
        ],
    },
    'investigation_d': {
        'n_samples': len(attn_results),
        'results': attn_results,
    } if len(attn_results) > 0 else {'n_samples': 0},
    'verdicts': verdicts,
}

output_path = 'results/exp10/10_diagnostic_results.json'
with open(output_path, 'w') as f:
    json.dump(output, f, indent=2, default=str)
print(f"Results saved to {output_path}")
print(f"File size: {os.path.getsize(output_path) / 1e6:.1f} MB")

# Checkpoints are deliberately left on disk (not deleted) so a run can be resumed.
for cp in [CHECKPOINT_PATH_A, CHECKPOINT_PATH_B, CHECKPOINT_PATH_C]:
    if os.path.exists(cp):
        print(f"  Checkpoint {cp} preserved for potential resume.")