# Experiment 02: Suffix vs Truncated Priming

**Goal**: Determine whether semantic content in a suffix (appended AFTER the passage) improves NLL scoring. This is the cleanest test of semantic signal available here. Under causal masking, passage tokens cannot attend to the suffix, so passage KV entries are unchanged by it (identical to a bare cache built from the same passage tokenization). Any benefit must therefore come from query tokens attending to suffix KV entries.

## Motivation (from Exp 01)

Exp 01 found that, after truncation, a random prefix helps MORE than an oracle prefix (d=+0.091 vs. d=+0.023, n.s.). Compared directly against random, oracle actually *interferes* (d=-0.051, p=0.015). The truncation mechanism conflates two signals:
- **Structural value contamination** (beneficial): prefix alters passage value vectors
- **Semantic attention patterns** (harmful on average): oracle content creates specific interference

**Suffix priming** isolates the semantic signal cleanly: passage KV entries are unchanged, so any effect must come from query → suffix attention.
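The causal-masking argument can be checked on a toy model. The sketch below is plain NumPy, not the Mistral forward pass: it stacks a few single-head causal attention layers (no RoPE, MLP, or normalization; the latter two are position-wise, so they would not change the conclusion) and confirms that appending suffix rows leaves the passage rows unchanged. It compares with `np.allclose` rather than exact equality, because real GPU kernels may accumulate in a different order for different sequence lengths.

```python
import numpy as np

def causal_attention_layer(x, w_q, w_k, w_v):
    """One toy single-head attention layer with a causal mask and a residual connection."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])
    n = x.shape[0]
    future = np.triu(np.ones((n, n), dtype=bool), k=1)   # position j > i is masked out
    scores = np.where(future, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return x + weights @ v

def forward(x, layers):
    for w_q, w_k, w_v in layers:
        x = causal_attention_layer(x, w_q, w_k, w_v)
    return x

rng = np.random.default_rng(0)
d = 8
layers = [tuple(rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3)) for _ in range(4)]
passage = rng.standard_normal((5, d))   # 5 toy "passage" embeddings
suffix = rng.standard_normal((3, d))    # 3 toy "suffix" embeddings

h_bare = forward(passage, layers)
h_sfx = forward(np.vstack([passage, suffix]), layers)

# Passage rows (and hence the K/V each layer derives from them) ignore the suffix.
print(np.allclose(h_bare, h_sfx[:5]))  # True
```

Only the query positions that come after the suffix can see it, which is why P1 isolates query → suffix attention.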

## Five Conditions

| # | Condition | Cache Construction | Mechanism |
|---|-----------|-------------------|----------|
| 1 | **Bare** | `[BOS] + doc_ids` (matched tokenization) | Baseline |
| 2 | **Oracle-truncated** | `[BOS][query\n][doc_ids]` → truncate + RoPE correct | Value contamination (Exp 01 replication) |
| 3 | **Random-truncated** | `[BOS][random\n][doc_ids]` → truncate + RoPE correct | Structural control (Exp 01 replication) |
| 4 | **Oracle-suffix** | `build_suffix_kv_cache(passage, query, sep)` | Clean semantic signal (NEW) |
| 5 | **Random-suffix** | `build_suffix_kv_cache(passage, random_text, sep)` | Structural control for suffix (NEW) |
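The "truncate + RoPE correct" step relies on rotary embeddings composing additively: rotating a key that was encoded at position `p + offset` by `-offset` yields the key at position `p`. Below is a minimal NumPy sketch, assuming the half-split (`rotate_half`) convention used by Hugging Face's Llama/Mistral implementations. The `rope` helper and the base of 10000 are illustrative (the identity holds for any base), and this is not the code inside `correct_rope_positions_with_bos`. It only repairs positions: in the real model, the doc tokens' pre-RoPE keys and their values were computed while attending to the prefix, which is exactly the value contamination the truncated conditions measure.

```python
import numpy as np

def rope(x, positions, base=10000.0):
    """Apply rotary embeddings (half-split convention) to x of shape (n, head_dim)."""
    head_dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
    angles = np.outer(positions, inv_freq)                 # (n, head_dim / 2)
    cos = np.concatenate([np.cos(angles)] * 2, axis=-1)
    sin = np.concatenate([np.sin(angles)] * 2, axis=-1)
    x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2 :]
    rotated_half = np.concatenate([-x2, x1], axis=-1)
    return x * cos + rotated_half * sin

rng = np.random.default_rng(0)
head_dim, doc_len, offset = 16, 6, 4        # offset = prefix tokens, excluding BOS
raw_keys = rng.standard_normal((doc_len, head_dim))   # toy pre-RoPE doc keys

stale = rope(raw_keys, np.arange(1, doc_len + 1) + offset)   # as cached after BOS + prefix
corrected = rope(stale, np.full(doc_len, -offset))           # shift back by the prefix length
target = rope(raw_keys, np.arange(1, doc_len + 1))           # positions 1..doc_len after BOS

print(np.allclose(corrected, target))  # True
```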

## Six Comparisons

**Primary (3):**

| # | Comparison | Question |
|---|-----------|----------|
| P1 | Oracle-suffix vs Random-suffix | Is there semantic signal in suffix? (cleanest test) |
| P2 | Oracle-suffix vs Bare | Does suffix priming help at all? |
| P3 | Oracle-truncated vs Bare | Does Exp 01 replicate? (expect d~+0.023, ns) |

**Secondary (3):**

| # | Comparison | Question |
|---|-----------|----------|
| S1 | Random-suffix vs Bare | Does ANY suffix help? (structural attention benefit) |
| S2 | Oracle-suffix vs Oracle-truncated | Which mechanism is better? |
| S3 | Random-truncated vs Bare | Does Exp 01 random benefit replicate? (expect d~+0.091) |

**Multiple comparisons:** Bonferroni correction, alpha = 0.05/6 = 0.0083.
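For concreteness, here is one way to turn per-sample deltas into a mean delta, an effect size, and a Bonferroni decision. `paired_effect` is a hypothetical helper, not `lib.analysis.cohens_d` (whose exact definition isn't shown here). It uses one common paired convention, d = mean(Δ) / SD(Δ), and a normal approximation to the one-sample t-test, which is close to exact at N=2500; in the notebook, `scipy.stats.ttest_1samp` would give the exact p-value.

```python
import math
import numpy as np

N_COMPARISONS = 6
BONFERRONI_ALPHA = 0.05 / N_COMPARISONS   # about 0.0083

def paired_effect(deltas, alpha=BONFERRONI_ALPHA):
    """Summarize per-sample NLL deltas (positive = first condition better).

    Returns the mean delta, paired Cohen's d (mean / sample SD), the t statistic,
    a two-sided p-value from a normal approximation, and the Bonferroni decision.
    """
    deltas = np.asarray(deltas, dtype=float)
    n = deltas.size
    mean = deltas.mean()
    sd = deltas.std(ddof=1)
    d = mean / sd
    t = mean / (sd / math.sqrt(n))
    p = math.erfc(abs(t) / math.sqrt(2))   # two-sided normal tail probability
    return {"mean": mean, "d": d, "t": t, "p": p, "significant": p < alpha}

# Synthetic deltas only, to show usage; not experimental data.
demo = paired_effect(np.random.default_rng(0).normal(0.05, 0.2, size=2500))
print(demo["significant"])
```

For paired data t = d·√N, so at N=2500 even d ≈ 0.06 gives t ≈ 3, which is why small effect sizes can still clear the 0.0083 threshold.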

## Critical Design Details

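1. **Matched tokenization for truncated conditions**: Identical to Exp 01. Document token IDs are sliced from the oracle-prefixed joint encoding `[BOS][query\n][passage]`, and bare, oracle-truncated, and random-truncated all use these same `doc_ids`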
2. **Suffix uses `build_suffix_kv_cache()`**: Tokenizes `passage + separator + suffix` as one string
3. **Same random text** for both random-truncated and random-suffix per sample
4. **Separator**: `"\n\nRelated question: "` — identical for oracle-suffix and random-suffix
5. **Same dataset**: N=2500, SEED=42, identical samples as Exp 01
6. **All standard safeguards**: `deepcopy_cache()`, `os.umask(0o000)`, checkpoints, GPU cleanup
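Item 1 in miniature: bare and both truncated conditions take their document IDs from the joint encoding, and the RoPE shift equals the number of prefix tokens excluding BOS (the `oracle_prefix_len - 1` passed to `correct_rope_positions_with_bos` in the main loop). The token IDs are made up and `matched_slices` is a hypothetical helper; like the main loop, it assumes the prefix encoded on its own matches the first tokens of the joint encoding.

```python
def matched_slices(full_ids, prefix_len):
    """Split a [BOS][prefix][doc] encoding the way the truncated conditions do.

    full_ids:   token IDs of prefix + passage encoded together (BOS first)
    prefix_len: length of the prefix encoded on its own, INCLUDING BOS
    Returns (bare_ids, doc_ids, rope_offset).
    """
    bos, doc_ids = full_ids[:1], full_ids[prefix_len:]
    bare_ids = bos + doc_ids          # condition 1 input: [BOS] + doc_ids
    rope_offset = prefix_len - 1      # prefix tokens removed, excluding BOS
    return bare_ids, doc_ids, rope_offset

# Toy IDs: BOS=1, prefix "query\n" -> [10, 11, 12], passage -> [20, 21, 22, 23]
full_ids = [1, 10, 11, 12, 20, 21, 22, 23]
bare_ids, doc_ids, rope_offset = matched_slices(full_ids, prefix_len=4)
print(bare_ids, doc_ids, rope_offset)  # [1, 20, 21, 22, 23] [20, 21, 22, 23] 3
```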

In [1]:
# Cell 1: Setup — permissions, seeds, output directory
import os
os.umask(0o000)  # Required: two-user environment (jupyter + CLI user)

import sys
import json
import time
import numpy as np
import torch
from pathlib import Path

# Reproducibility
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Output directory
RESULTS_DIR = Path("results/exp02")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# Paths
CHECKPOINT_PATH = RESULTS_DIR / "checkpoint.json"
FINAL_RESULTS_PATH = RESULTS_DIR / "results.json"

print(f"SEED: {SEED}")
print(f"Results directory: {RESULTS_DIR}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

SEED: 42
Results directory: results/exp02
CUDA available: True
GPU: NVIDIA L4
GPU memory: 23.6 GB


In [2]:
# Cell 2: Load model (Mistral-7B 4-bit) — identical to Exp 01
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

print(f"Loading {MODEL_NAME} (4-bit)...")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.eval()

print(f"Model loaded. dtype={model.dtype}, device={model.device}")
print(f"Vocab size: {len(tokenizer)}")

Loading mistralai/Mistral-7B-Instruct-v0.2 (4-bit)...


`torch_dtype` is deprecated! Use `dtype` instead!



Model loaded. dtype=torch.float16, device=cuda:0
Vocab size: 32000


In [3]:
# Cell 3: Library imports + config + templates
sys.path.insert(0, ".")

from lib.config import ExperimentConfig
from lib.kv_cache import (
    build_kv_cache,
    build_suffix_kv_cache,
    score_answer_with_cache,
    deepcopy_cache,
    extract_and_truncate_cache_with_bos,
    correct_rope_positions_with_bos,
)
from lib.data import load_ms_marco, load_evaluation_samples
from lib.analysis import cohens_d
from scipy import stats
from tqdm.auto import tqdm

config = ExperimentConfig(
    model_name=MODEL_NAME,
    num_samples=2500,
    seed=SEED,
)

# Templates: no framing, to avoid a "Document:\n" artifact
SURROGATE_PREFIX_TEMPLATE = "{surrogate}\n"
DOCUMENT_TEMPLATE = "{document}"

# Query/answer format
QUERY_TEMPLATE = "\nQuery: {query}\nAnswer:"
ANSWER_TEMPLATE = " {answer}"  # Leading space for correct BPE of first word

# Suffix separator (used for both oracle-suffix and random-suffix)
SUFFIX_SEPARATOR = "\n\nRelated question: "

# Checkpoint frequency
CHECKPOINT_EVERY = 50

# Bonferroni-corrected alpha for 6 comparisons
BONFERRONI_ALPHA = 0.05 / 6

print("Config:")
print(f"  num_samples: {config.num_samples}")
print(f"  passage words: {config.min_passage_words}-{config.max_passage_words}")
print(f"  surrogate_prefix_template: {repr(SURROGATE_PREFIX_TEMPLATE)}")
print(f"  document_template: {repr(DOCUMENT_TEMPLATE)}")
print(f"  query_template: {repr(QUERY_TEMPLATE)}")
print(f"  answer_template: {repr(ANSWER_TEMPLATE)}")
print(f"  suffix_separator: {repr(SUFFIX_SEPARATOR)}")
print(f"  checkpoint_every: {CHECKPOINT_EVERY}")
print(f"  bonferroni_alpha: {BONFERRONI_ALPHA:.4f}")

Config:
  num_samples: 2500
  passage words: 50-300
  surrogate_prefix_template: '{surrogate}\n'
  document_template: '{document}'
  query_template: '\nQuery: {query}\nAnswer:'
  answer_template: ' {answer}'
  suffix_separator: '\n\nRelated question: '
  checkpoint_every: 50
  bonferroni_alpha: 0.0083


In [4]:
# Cell 4: Load dataset (seed immediately before for determinism)
dataset = load_ms_marco(config)

# CRITICAL: Set seed immediately before load_evaluation_samples
# to ensure deterministic sample selection regardless of prior random state
np.random.seed(SEED)
samples = load_evaluation_samples(dataset, config, require_answer=True)

print(f"\nLoaded {len(samples)} samples")
print(f"\nExample sample:")
print(f"  Query: {samples[0]['query'][:100]}...")
print(f"  Passage: {samples[0]['passage'][:100]}...")
print(f"  Answer: {samples[0]['answer'][:100]}...")

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'microsoft/ms_marco' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


Loading microsoft/ms_marco dataset...
Dataset loaded: 10047 samples
Filtering samples...



Selected 2500 samples

Loaded 2500 samples

Example sample:
  Query: what temperature should it be to plant grass seeds...
  Passage: Usually planted in the early fall, cool-season grass seeds prefer daytime temperatures ranging from ...
  Answer: Between 50deg and 65deg F...


In [5]:
# Cell 5: generate_random_prefix_text() — copy from Exp 01

def generate_random_prefix_text(target_text, tokenizer, seed):
    """
    Generate random text from vocabulary tokens that is length-matched
    (in tokens) to target_text.
    
    Uses a decode->re-encode verification loop to ensure the random
    prefix tokenizes to approximately the expected number of tokens.
    
    Args:
        target_text: Text to match in token length
        tokenizer: The tokenizer
        seed: Random seed for reproducibility
    
    Returns:
        Random text whose re-encoded token count usually equals that of
        target_text (exact equality is not guaranteed after the round-trip)
    """
    # Get target token count (no special tokens — we want content tokens only)
    target_ids = tokenizer.encode(target_text, add_special_tokens=False)
    target_len = len(target_ids)
    
    if target_len == 0:
        return ""
    
    rng = np.random.RandomState(seed)
    vocab_size = len(tokenizer)
    
    # Sample random token IDs from the full vocabulary, excluding special tokens
    # (IDs 0-2 are <unk>, <s>, </s> in the Mistral/Llama SentencePiece vocabulary)
    min_id = 3  # Skip UNK, BOS, EOS
    random_ids = rng.randint(min_id, vocab_size, size=target_len)
    
    # Decode to text
    random_text = tokenizer.decode(random_ids.tolist(), skip_special_tokens=True)
    
    # Verification: re-encode and check length
    reencoded = tokenizer.encode(random_text, add_special_tokens=False)
    
    # If lengths don't match after round-trip, truncate or pad
    if len(reencoded) != target_len:
        # Truncate the text and re-decode from exactly target_len tokens
        if len(reencoded) > target_len:
            random_text = tokenizer.decode(reencoded[:target_len], skip_special_tokens=True)
        else:
            # Pad with more random tokens
            extra_needed = target_len - len(reencoded)
            extra_ids = rng.randint(min_id, vocab_size, size=extra_needed)
            extra_text = tokenizer.decode(extra_ids.tolist(), skip_special_tokens=True)
            random_text = random_text + extra_text
            reencoded2 = tokenizer.encode(random_text, add_special_tokens=False)
            if len(reencoded2) > target_len:
                random_text = tokenizer.decode(reencoded2[:target_len], skip_special_tokens=True)
    
    return random_text


# Test the function
test_query = samples[0]['query']
test_random = generate_random_prefix_text(test_query, tokenizer, seed=SEED)

oracle_tokens = tokenizer.encode(test_query, add_special_tokens=False)
random_tokens = tokenizer.encode(test_random, add_special_tokens=False)

print(f"Oracle query: {repr(test_query)}")
print(f"Oracle tokens: {len(oracle_tokens)}")
print(f"Random prefix: {repr(test_random[:80])}...")
print(f"Random tokens: {len(random_tokens)}")
print(f"Length match: {len(oracle_tokens) == len(random_tokens)}")

Oracle query: 'what temperature should it be to plant grass seeds'
Oracle tokens: 9
Random prefix: 'Restaur Mars didova少 DATA luxwalkshine'...
Random tokens: 9
Length match: True
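The main loop seeds each sample's random text with `SEED + idx`, so the text depends only on the sample index, not on how many samples were processed before it. That is what lets a checkpoint resume reproduce the same random prefix/suffix. A small sketch contrasting this with a single shared generator; both helper names are hypothetical, and `VOCAB_SIZE`/`MIN_ID` mirror Cell 5.

```python
import numpy as np

SEED = 42
VOCAB_SIZE, MIN_ID = 32000, 3   # as in Cell 5: len(tokenizer) == 32000, IDs 0-2 skipped

def per_sample_ids(idx, n_tokens):
    """Hypothetical: draw sample idx's random IDs from a fresh RandomState(SEED + idx)."""
    rng = np.random.RandomState(SEED + idx)
    return rng.randint(MIN_ID, VOCAB_SIZE, size=n_tokens).tolist()

def shared_rng_ids(start, stop, n_tokens):
    """Hypothetical anti-pattern: one generator, seeded once, advanced sample by sample."""
    rng = np.random.RandomState(SEED)
    return [rng.randint(MIN_ID, VOCAB_SIZE, size=n_tokens).tolist() for _ in range(start, stop)]

# Uninterrupted run over samples 0..9 vs. a run that stops after sample 4 and restarts.
per_sample_full = [per_sample_ids(i, 5) for i in range(10)]
per_sample_resumed = [per_sample_ids(i, 5) for i in range(5)] + [per_sample_ids(i, 5) for i in range(5, 10)]

shared_full = shared_rng_ids(0, 10, 5)
shared_resumed = shared_rng_ids(0, 5, 5) + shared_rng_ids(5, 10, 5)   # restart re-seeds at SEED

print(per_sample_full == per_sample_resumed)  # True
print(shared_full == shared_resumed)          # False
```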


In [6]:
# Cell 6: BPE diagnostics
# (a) Truncated mismatch confirmation (same as Exp 01)
# (b) Suffix passage token consistency check — verify passage tokens are identical
#     regardless of suffix content

print("BPE Boundary Diagnostics")
print("=" * 60)

# --- (a) Truncated mismatch confirmation ---
print("\n(a) TRUNCATED: Independent vs concatenated tokenization")
n_mismatch = 0
n_total = min(100, len(samples))

for i in range(n_total):
    passage = samples[i]['passage']
    query = samples[i]['query']
    oracle_prefix = SURROGATE_PREFIX_TEMPLATE.format(surrogate=query)
    document_text = DOCUMENT_TEMPLATE.format(document=passage)
    full_text = oracle_prefix + document_text
    full_ids = tokenizer.encode(full_text, add_special_tokens=True)
    prefix_ids = tokenizer.encode(oracle_prefix, add_special_tokens=True)
    prefix_len = len(prefix_ids)
    doc_from_concat = full_ids[prefix_len:]
    doc_independent = tokenizer.encode(passage, add_special_tokens=True)[1:]  # strip BOS
    if doc_from_concat != doc_independent:
        n_mismatch += 1

print(f"  Tested {n_total} samples")
print(f"  BPE mismatches: {n_mismatch}/{n_total} ({100*n_mismatch/n_total:.0f}%)")
print(f"  → Matched tokenization required for truncated conditions.")

# --- (b) Suffix passage token consistency ---
print("\n(b) SUFFIX: Passage tokens consistency across different suffixes")
n_suffix_test = min(50, len(samples))
n_passage_consistent = 0

for i in range(n_suffix_test):
    passage = samples[i]['passage']
    query = samples[i]['query']
    random_text = generate_random_prefix_text(query, tokenizer, seed=SEED + i)

    # Tokenize passage alone (bare reference)
    bare_ids = tokenizer.encode(passage, add_special_tokens=True)

    # Tokenize passage + separator + oracle suffix
    oracle_suffix_text = passage + SUFFIX_SEPARATOR + query
    oracle_suffix_ids = tokenizer.encode(oracle_suffix_text, add_special_tokens=True)

    # Tokenize passage + separator + random suffix
    random_suffix_text = passage + SUFFIX_SEPARATOR + random_text
    random_suffix_ids = tokenizer.encode(random_suffix_text, add_special_tokens=True)

    # Check: do the first len(bare_ids) tokens match across all three?
    bare_len = len(bare_ids)
    oracle_passage_part = oracle_suffix_ids[:bare_len]
    random_passage_part = random_suffix_ids[:bare_len]

    if bare_ids == oracle_passage_part == random_passage_part:
        n_passage_consistent += 1

print(f"  Tested {n_suffix_test} samples")
print(f"  Passage tokens consistent: {n_passage_consistent}/{n_suffix_test} ({100*n_passage_consistent/n_suffix_test:.0f}%)")
if n_passage_consistent == n_suffix_test:
    print(f"  → PASS: Passage tokens are identical regardless of suffix content.")
    print(f"  → Suffix has NO effect on passage KV entries (causal masking guarantees this).")
    print(f"  → Bare baseline is fair for suffix comparisons.")
else:
    n_diff = n_suffix_test - n_passage_consistent
    print(f"  → WARNING: {n_diff} samples have different passage tokens with different suffixes.")
    print(f"  → This could indicate BPE boundary effects at the passage/separator boundary.")
    print(f"  → Check if the separator creates clean BPE boundaries.")

BPE Boundary Diagnostics

(a) TRUNCATED: Independent vs concatenated tokenization
  Tested 100 samples
  BPE mismatches: 100/100 (100%)
  → Matched tokenization required for truncated conditions.

(b) SUFFIX: Passage tokens consistency across different suffixes
  Tested 50 samples
  Passage tokens consistent: 50/50 (100%)
  → PASS: Passage tokens are identical regardless of suffix content.
  → Suffix has NO effect on passage KV entries (causal masking guarantees this).
  → Bare baseline is fair for suffix comparisons.
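Cell 6(b) counts consistent samples but does not say where a failure occurs. If a different separator or passage ever breaks the check, locating the first divergent token shows whether the break sits at the passage/separator boundary. A sketch with hypothetical helpers and toy IDs:

```python
def first_divergence(a, b):
    """Index of the first position where token lists a and b differ, or None if
    they agree on their common length."""
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return i
    return None

def passage_is_stable(bare_ids, with_suffix_ids):
    """True if the bare passage encoding is an exact prefix of the passage+suffix
    encoding (the same condition Cell 6(b) tests by slicing)."""
    return len(with_suffix_ids) >= len(bare_ids) and first_divergence(bare_ids, with_suffix_ids) is None

bare = [1, 20, 21, 22]
clean = [1, 20, 21, 22, 30, 31, 40]    # separator starts a new token
merged = [1, 20, 21, 99, 31, 40]       # last passage token fused with the separator

print(passage_is_stable(bare, clean), first_divergence(bare, clean))    # True None
print(passage_is_stable(bare, merged), first_divergence(bare, merged))  # False 3
```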


In [None]:
# Cell 7: Bare baseline fairness diagnostic
# Compare matched bare [BOS]+doc_ids vs independent build_kv_cache(passage) NLLs
# to verify that using matched tokenization for bare doesn't create a systematic bias
# relative to the suffix conditions (which use independent tokenization via build_suffix_kv_cache).
#
# Why this matters: Truncated conditions use matched tokenization (doc_ids from
# concatenated encoding). Suffix conditions use build_suffix_kv_cache() which
# tokenizes passage independently. If matched bare and independent bare give
# systematically different NLLs, the bare baseline would be unfair to one mechanism.

print("Bare Baseline Fairness Diagnostic")
print("=" * 60)
print("Comparing matched bare (from truncated tokenization) vs")
print("independent bare (build_kv_cache(passage)) on 20 samples.")
print()

n_diag = 20
matched_nlls = []
independent_nlls = []

for i in range(n_diag):
    sample = samples[i]
    passage = sample['passage']
    query = sample['query']
    answer = sample['answer']
    query_prompt = QUERY_TEMPLATE.format(query=query)
    answer_text = ANSWER_TEMPLATE.format(answer=answer)

    # --- Matched bare: [BOS] + doc_ids from oracle concatenation ---
    oracle_prefix = SURROGATE_PREFIX_TEMPLATE.format(surrogate=query)
    document_text = DOCUMENT_TEMPLATE.format(document=passage)
    full_oracle_enc = tokenizer(oracle_prefix + document_text, return_tensors="pt",
                                add_special_tokens=True, padding=False, truncation=False)
    full_oracle_ids = full_oracle_enc['input_ids'].to(config.device)
    oracle_prefix_enc = tokenizer(oracle_prefix, return_tensors="pt",
                                  add_special_tokens=True, padding=False, truncation=False)
    oracle_prefix_len = oracle_prefix_enc['input_ids'].shape[1]
    bos_id = full_oracle_ids[:, :1]
    doc_ids = full_oracle_ids[:, oracle_prefix_len:]
    bare_ids = torch.cat([bos_id, doc_ids], dim=1)
    bare_len = bare_ids.shape[1]

    with torch.no_grad():
        bare_out = model(input_ids=bare_ids, attention_mask=torch.ones_like(bare_ids),
                         use_cache=True, return_dict=True)
    matched_nll = score_answer_with_cache(
        deepcopy_cache(bare_out.past_key_values), bare_len,
        query_prompt, answer_text, model, tokenizer, config)
    matched_nlls.append(matched_nll)

    # --- Independent bare: build_kv_cache(passage) ---
    indep_len, indep_cache = build_kv_cache(passage, model, tokenizer, config)
    independent_nll = score_answer_with_cache(
        deepcopy_cache(indep_cache), indep_len,
        query_prompt, answer_text, model, tokenizer, config)
    independent_nlls.append(independent_nll)

    del bare_out, indep_cache
    torch.cuda.empty_cache()

matched_nlls = np.array(matched_nlls)
independent_nlls = np.array(independent_nlls)
diffs = matched_nlls - independent_nlls

print(f"{'Sample':>8} {'Matched':>10} {'Independent':>12} {'Diff':>10}")
print("-" * 45)
for i in range(n_diag):
    print(f"{i:>8} {matched_nlls[i]:>10.4f} {independent_nlls[i]:>12.4f} {diffs[i]:>10.6f}")

# Filter out zero-NLL samples (single-token answers are uninformative)
nonzero_mask = (matched_nlls != 0.0) & (independent_nlls != 0.0)
diffs_nonzero = diffs[nonzero_mask]
n_nonzero = np.sum(nonzero_mask)

print(f"\nNon-zero NLL samples: {n_nonzero}/{n_diag}")
print(f"Mean difference: {np.mean(diffs_nonzero):.6f}")
print(f"Mean abs difference: {np.mean(np.abs(diffs_nonzero)):.6f}")
print(f"Max abs difference: {np.max(np.abs(diffs_nonzero)):.6f}")

# Context: Exp 01 effect sizes were 0.009-0.029 in mean delta.
# A systematic bias of ~0.005 would be small relative to these effects.
mean_bias = np.mean(diffs_nonzero)
mean_abs_diff = np.mean(np.abs(diffs_nonzero))

if abs(mean_bias) < 0.01:
    print(f"\nPASS: Mean systematic bias ({mean_bias:+.4f}) is negligible.")
    print(f"Individual samples vary (max abs {np.max(np.abs(diffs_nonzero)):.3f}), likely from")
    print(f"BPE boundary effects on the first passage token; the small mean is consistent")
    print(f"with symmetric noise rather than a directional bias. Bare baseline is fair for both mechanisms.")
else:
    print(f"\nCAUTION: Mean systematic bias ({mean_bias:+.4f}) detected.")
    print(f"This is {'small' if abs(mean_bias) < 0.03 else 'non-trivial'} relative to")
    print(f"expected effect sizes (Exp 01: d~0.02-0.09, Δ~0.009-0.029).")
    print(f"P1 is unaffected (both suffix conditions use independent tokenization),")
    print(f"as are P3 and S3 (truncated conditions and bare share matched doc_ids).")
    print(f"Comparisons that mix matched and independent tokenization (P2, S1, S2)")
    print(f"may carry a bias of ~{abs(mean_bias):.3f} in mean delta.")
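Which comparisons could absorb a matched-vs-independent bias follows directly from how each condition tokenizes the passage: matched `doc_ids` for bare and both truncated conditions, independent tokenization inside `build_suffix_kv_cache` for both suffix conditions. A small lookup makes this explicit; the dictionaries and function are illustrative, not part of the notebook.

```python
# Passage tokenization used by each condition's cache (per the conditions table
# and the Cell 7 comments).
TOKENIZATION = {
    "bare": "matched",
    "oracle_trunc": "matched",
    "random_trunc": "matched",
    "oracle_suffix": "independent",
    "random_suffix": "independent",
}

COMPARISONS = {
    "P1": ("oracle_suffix", "random_suffix"),
    "P2": ("oracle_suffix", "bare"),
    "P3": ("oracle_trunc", "bare"),
    "S1": ("random_suffix", "bare"),
    "S2": ("oracle_suffix", "oracle_trunc"),
    "S3": ("random_trunc", "bare"),
}

def mixed_tokenization(comparisons=COMPARISONS, tokenization=TOKENIZATION):
    """Comparisons whose two conditions tokenize the passage differently."""
    return sorted(name for name, (a, b) in comparisons.items()
                  if tokenization[a] != tokenization[b])

print(mixed_tokenization())  # ['P2', 'S1', 'S2']
```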

In [8]:
# Cell 8: Condition explanation printout (all 5 conditions with concrete examples)
print("=" * 70)
print("EXPERIMENTAL CONDITIONS EXPLAINED")
print("=" * 70)

ex = samples[0]
ex_query = ex['query']
ex_passage = ex['passage'][:80] + "..."
ex_answer = ex['answer'][:60] + "..."
ex_random = generate_random_prefix_text(ex_query, tokenizer, seed=SEED)

print(f"\nExample passage: {repr(ex_passage)}")
print(f"Example query:   {repr(ex_query)}")
print(f"Example answer:  {repr(ex_answer)}")
print(f"Example random:  {repr(ex_random[:60])}...")

# Show matched tokenization
oracle_prefix = SURROGATE_PREFIX_TEMPLATE.format(surrogate=ex_query)
doc_text = DOCUMENT_TEMPLATE.format(document=ex['passage'])
full_text = oracle_prefix + doc_text
full_ids = tokenizer.encode(full_text, add_special_tokens=True)
prefix_ids = tokenizer.encode(oracle_prefix, add_special_tokens=True)
prefix_len = len(prefix_ids)
doc_ids = full_ids[prefix_len:]

print(f"\n{'─' * 70}")
print("MATCHED TOKENIZATION (for truncated conditions)")
print(f"  Tokenize oracle_prefix + passage together → {len(full_ids)} tokens")
print(f"  Oracle prefix tokens (with BOS): {prefix_len}")
print(f"  Document tokens (shared by truncated conditions + bare): {len(doc_ids)}")

print(f"\n{'─' * 70}")
print("### CONDITION 1: BARE (baseline) ###")
print(f"  Input IDs:  [BOS] + doc_ids ({1 + len(doc_ids)} tokens)")
print(f"  Key insight: Pure baseline. Same doc tokens as truncated conditions.")
print(f"  Also fair for suffix comparisons (Cell 7 diagnostic verifies this).")

print(f"\n{'─' * 70}")
print("### CONDITION 2: ORACLE-TRUNCATED (Exp 01 replication) ###")
print(f"  Build:  [BOS]['{ex_query}\\n'][doc_ids] ({len(full_ids)} tokens)")
print(f"  After:  Truncate prefix → [BOS] + doc_ids ({1 + len(doc_ids)} tokens) + RoPE correct")
print(f"  Key insight: Value contamination from semantically relevant prefix.")

print(f"\n{'─' * 70}")
print("### CONDITION 3: RANDOM-TRUNCATED (Exp 01 replication) ###")
print(f"  Build:  [BOS]['{ex_random[:30]}...\\n'][doc_ids]")
print(f"  After:  Truncate prefix → [BOS] + doc_ids ({1 + len(doc_ids)} tokens) + RoPE correct")
print(f"  Key insight: Structural control — value contamination from random prefix.")

# Show suffix examples
suffix_oracle_text = ex['passage'] + SUFFIX_SEPARATOR + ex_query
suffix_oracle_ids = tokenizer.encode(suffix_oracle_text, add_special_tokens=True)
suffix_random_text = ex['passage'] + SUFFIX_SEPARATOR + ex_random
suffix_random_ids = tokenizer.encode(suffix_random_text, add_special_tokens=True)

print(f"\n{'─' * 70}")
print("### CONDITION 4: ORACLE-SUFFIX (NEW — cleanest semantic test) ###")
print(f"  Build:  [BOS][passage][{SUFFIX_SEPARATOR!r}]['{ex_query}'] ({len(suffix_oracle_ids)} tokens)")
print(f"  Scoring: Query attends to passage + separator + suffix KV entries")
print(f"  Key insight: Passage KV entries are unaffected by the suffix (causal masking).")
print(f"  Any benefit must come from query → suffix attention.")

print(f"\n{'─' * 70}")
print("### CONDITION 5: RANDOM-SUFFIX (structural control for suffix) ###")
print(f"  Build:  [BOS][passage][{SUFFIX_SEPARATOR!r}]['{ex_random[:30]}...'] ({len(suffix_random_ids)} tokens)")
print(f"  Scoring: Query attends to passage + separator + suffix KV entries")
print(f"  Key insight: Same structure as oracle-suffix but random content.")
print(f"  P1 (oracle-suffix vs random-suffix) isolates semantic signal.")

print(f"\n{'─' * 70}")
print("DESIGN NOTES:")
print(f"  Same random text used for BOTH random-truncated and random-suffix per sample.")
print(f"  Suffix separator: {repr(SUFFIX_SEPARATOR)} (identical for oracle/random suffix).")
print(f"  CACHE SAFETY: deepcopy_cache() before every score call.")

EXPERIMENTAL CONDITIONS EXPLAINED

Example passage: 'Usually planted in the early fall, cool-season grass seeds prefer daytime temper...'
Example query:   'what temperature should it be to plant grass seeds'
Example answer:  'Between 50deg and 65deg F...'
Example random:  'Restaur Mars didova少 DATA luxwalkshine'...

──────────────────────────────────────────────────────────────────────
MATCHED TOKENIZATION (for truncated conditions)
  Tokenize oracle_prefix + passage together → 134 tokens
  Oracle prefix tokens (with BOS): 11
  Document tokens (shared by truncated conditions + bare): 123

──────────────────────────────────────────────────────────────────────
### CONDITION 1: BARE (baseline) ###
  Input IDs:  [BOS] + doc_ids (124 tokens)
  Key insight: Pure baseline. Same doc tokens as truncated conditions.
  Also fair for suffix comparisons (Cell 7 diagnostic verifies this).

──────────────────────────────────────────────────────────────────────
### CONDITION 2: ORACLE-TRUNCATED (Exp 0

In [9]:
# Cell 9: Evaluation parameters + checkpoint loading

N = len(samples)
print(f"Total samples: {N}")

# Check for existing checkpoint
results = []
start_idx = 0

if CHECKPOINT_PATH.exists():
    with open(CHECKPOINT_PATH, 'r') as f:
        checkpoint = json.load(f)

    # Verify checkpoint matches current sample list (resume correctness)
    ckpt_sample_ids = checkpoint.get('sample_queries', [])
    current_sample_ids = [s['query'] for s in samples]

    if ckpt_sample_ids == current_sample_ids:
        results = checkpoint['results']
        start_idx = len(results)
        print(f"Resuming from checkpoint: {start_idx}/{N} samples completed")
    else:
        print("WARNING: Checkpoint sample list doesn't match current samples.")
        print("Starting from scratch.")
        results = []
        start_idx = 0
else:
    print("No checkpoint found. Starting from scratch.")

print(f"Will evaluate samples {start_idx} to {N-1}")
print(f"\nConditions per sample: 5")
print(f"Total condition evaluations: {(N - start_idx) * 5}")

Total samples: 2500
No checkpoint found. Starting from scratch.
Will evaluate samples 0 to 2499

Conditions per sample: 5
Total condition evaluations: 12500
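The checkpoint in the main loop is written with a plain `open(..., 'w')`, so a crash mid-write could leave truncated JSON that breaks the resume logic above. A hedged sketch of an atomic variant: write to a sibling temp file, `fsync`, then `os.replace` it over the checkpoint. It uses `open()` rather than `tempfile.mkstemp`, because `mkstemp` creates files readable and writable only by the creating user, which would undo the `os.umask(0o000)` set in Cell 1 for the two-user setup. Both function names are hypothetical.

```python
import json
import os
import tempfile
from pathlib import Path

def save_checkpoint_atomic(path, data):
    """Write JSON to a sibling temp file, then atomically replace `path`."""
    path = Path(path)
    tmp = path.with_name(path.name + ".tmp")
    with open(tmp, "w") as f:       # open() honors the process umask
        json.dump(data, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp, path)           # atomic rename: readers see old or new, never partial

def load_checkpoint(path, sample_queries):
    """Return (results, start_idx); start fresh if missing or if the sample list differs."""
    path = Path(path)
    if not path.exists():
        return [], 0
    with open(path) as f:
        ckpt = json.load(f)
    if ckpt.get("sample_queries") != sample_queries:
        return [], 0
    return ckpt["results"], len(ckpt["results"])

with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = Path(tmp_dir) / "checkpoint.json"
    queries = ["q0", "q1", "q2"]
    save_checkpoint_atomic(ckpt_path, {"results": [{"idx": 0}, {"idx": 1}], "sample_queries": queries})
    resumed = load_checkpoint(ckpt_path, queries)
    mismatched = load_checkpoint(ckpt_path, ["other"])

print(resumed[1], mismatched)  # 2 ([], 0)
```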


In [19]:
# Cell 10: Main evaluation loop (5 conditions × N samples, with checkpointing)
#
# Conditions 1-3: Matched tokenization (identical to Exp 01)
# Conditions 4-5: Suffix via build_suffix_kv_cache()
#
# Same random text used for both random-truncated (cond 3) and random-suffix (cond 5).

t_start = time.time()

for idx in tqdm(range(start_idx, N), initial=start_idx, total=N, desc="Evaluating"):
    sample = samples[idx]
    passage = sample['passage']
    query = sample['query']
    answer = sample['answer']

    query_prompt = QUERY_TEMPLATE.format(query=query)
    answer_text = ANSWER_TEMPLATE.format(answer=answer)

    # === Step 1: Determine canonical document token IDs (for truncated conditions) ===
    oracle_prefix = SURROGATE_PREFIX_TEMPLATE.format(surrogate=query)
    document_text = DOCUMENT_TEMPLATE.format(document=passage)
    full_oracle_text = oracle_prefix + document_text

    full_oracle_enc = tokenizer(full_oracle_text, return_tensors="pt",
                                add_special_tokens=True, padding=False, truncation=False)
    full_oracle_ids = full_oracle_enc['input_ids'].to(config.device)

    oracle_prefix_enc = tokenizer(oracle_prefix, return_tensors="pt",
                                  add_special_tokens=True, padding=False, truncation=False)
    oracle_prefix_len = oracle_prefix_enc['input_ids'].shape[1]  # includes BOS

    bos_id = full_oracle_ids[:, :1]
    doc_ids = full_oracle_ids[:, oracle_prefix_len:]
    doc_len = doc_ids.shape[1]

    # Generate random text ONCE — used for BOTH random-truncated and random-suffix
    random_text = generate_random_prefix_text(query, tokenizer, seed=SEED + idx)

    # === Condition 1: BARE — [BOS] + doc_ids ===
    bare_ids = torch.cat([bos_id, doc_ids], dim=1)
    bare_len = bare_ids.shape[1]  # = 1 + doc_len
    with torch.no_grad():
        bare_out = model(input_ids=bare_ids, attention_mask=torch.ones_like(bare_ids),
                         use_cache=True, return_dict=True)
    bare_cache = bare_out.past_key_values
    bare_nll = score_answer_with_cache(
        deepcopy_cache(bare_cache), bare_len, query_prompt, answer_text,
        model, tokenizer, config
    )

    # === Condition 2: ORACLE-TRUNCATED — [BOS][oracle_prefix][doc_ids] → truncate ===
    with torch.no_grad():
        oracle_trunc_out = model(input_ids=full_oracle_ids,
                                 attention_mask=torch.ones_like(full_oracle_ids),
                                 use_cache=True, return_dict=True)
    oracle_trunc_cache = extract_and_truncate_cache_with_bos(oracle_trunc_out.past_key_values, doc_len)
    oracle_trunc_len = 1 + doc_len
    correct_rope_positions_with_bos(oracle_trunc_cache, oracle_prefix_len - 1, model)
    oracle_trunc_nll = score_answer_with_cache(
        deepcopy_cache(oracle_trunc_cache), oracle_trunc_len, query_prompt, answer_text,
        model, tokenizer, config
    )

    # === Condition 3: RANDOM-TRUNCATED — [BOS][random_prefix][doc_ids] → truncate ===
    random_prefix = SURROGATE_PREFIX_TEMPLATE.format(surrogate=random_text)
    random_prefix_enc = tokenizer(random_prefix, return_tensors="pt",
                                  add_special_tokens=False, padding=False, truncation=False)
    random_prefix_ids = random_prefix_enc['input_ids'].to(config.device)
    random_full_ids = torch.cat([bos_id, random_prefix_ids, doc_ids], dim=1)
    random_prefix_len = 1 + random_prefix_ids.shape[1]  # BOS + prefix tokens

    with torch.no_grad():
        random_trunc_out = model(input_ids=random_full_ids,
                                 attention_mask=torch.ones_like(random_full_ids),
                                 use_cache=True, return_dict=True)
    random_trunc_cache = extract_and_truncate_cache_with_bos(random_trunc_out.past_key_values, doc_len)
    random_trunc_len = 1 + doc_len
    correct_rope_positions_with_bos(random_trunc_cache, random_prefix_len - 1, model)
    random_trunc_nll = score_answer_with_cache(
        deepcopy_cache(random_trunc_cache), random_trunc_len, query_prompt, answer_text,
        model, tokenizer, config
    )

    # === Condition 4: ORACLE-SUFFIX — build_suffix_kv_cache(passage, query) ===
    oracle_sfx_len, oracle_sfx_cache = build_suffix_kv_cache(
        passage, query, model, tokenizer, config, separator=SUFFIX_SEPARATOR)
    oracle_suffix_nll = score_answer_with_cache(
        deepcopy_cache(oracle_sfx_cache), oracle_sfx_len, query_prompt, answer_text,
        model, tokenizer, config
    )

    # === Condition 5: RANDOM-SUFFIX — build_suffix_kv_cache(passage, random_text) ===
    random_sfx_len, random_sfx_cache = build_suffix_kv_cache(
        passage, random_text, model, tokenizer, config, separator=SUFFIX_SEPARATOR)
    random_suffix_nll = score_answer_with_cache(
        deepcopy_cache(random_sfx_cache), random_sfx_len, query_prompt, answer_text,
        model, tokenizer, config
    )

    # Record all 5 NLLs + cache lengths + precomputed deltas
    result = {
        'idx': idx,
        'bare_nll': bare_nll,
        'oracle_trunc_nll': oracle_trunc_nll,
        'random_trunc_nll': random_trunc_nll,
        'oracle_suffix_nll': oracle_suffix_nll,
        'random_suffix_nll': random_suffix_nll,
        'bare_len': bare_len,
        'oracle_trunc_len': oracle_trunc_len,
        'random_trunc_len': random_trunc_len,
        'oracle_suffix_len': oracle_sfx_len,
        'random_suffix_len': random_sfx_len,
        'doc_len': doc_len,
        # Precomputed deltas (positive = first condition has LOWER NLL = better)
        # P1: oracle-suffix vs random-suffix
        'delta_p1_oracle_sfx_vs_random_sfx': random_suffix_nll - oracle_suffix_nll,
        # P2: oracle-suffix vs bare
        'delta_p2_oracle_sfx_vs_bare': bare_nll - oracle_suffix_nll,
        # P3: oracle-truncated vs bare
        'delta_p3_oracle_trunc_vs_bare': bare_nll - oracle_trunc_nll,
        # S1: random-suffix vs bare
        'delta_s1_random_sfx_vs_bare': bare_nll - random_suffix_nll,
        # S2: oracle-suffix vs oracle-truncated
        'delta_s2_oracle_sfx_vs_oracle_trunc': oracle_trunc_nll - oracle_suffix_nll,
        # S3: random-truncated vs bare
        'delta_s3_random_trunc_vs_bare': bare_nll - random_trunc_nll,
    }
    results.append(result)

    # GPU memory management
    del bare_cache, bare_out, oracle_trunc_cache, oracle_trunc_out
    del random_trunc_cache, random_trunc_out, oracle_sfx_cache, random_sfx_cache
    torch.cuda.empty_cache()

    # Checkpoint
    if (idx + 1) % CHECKPOINT_EVERY == 0 or idx == N - 1:
        checkpoint_data = {
            'results': results,
            'sample_queries': [s['query'] for s in samples],
            'completed': len(results),
            'total': N,
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        }
        with open(CHECKPOINT_PATH, 'w') as f:
            json.dump(checkpoint_data, f)

        elapsed = time.time() - t_start
        samples_done = idx - start_idx + 1
        rate = samples_done / elapsed if elapsed > 0 else 0
        remaining = (N - idx - 1) / rate if rate > 0 else 0
        tqdm.write(f"  Checkpoint at {idx+1}/{N} | Rate: {rate:.1f} samples/s | ETA: {remaining/60:.1f} min")

elapsed_total = time.time() - t_start
print(f"\nEvaluation complete: {len(results)} samples in {elapsed_total/60:.1f} minutes")

  Checkpoint at 50/2500 | Rate: 0.5 samples/s | ETA: 86.6 min
  Checkpoint at 100/2500 | Rate: 0.5 samples/s | ETA: 86.0 min
  Checkpoint at 150/2500 | Rate: 0.5 samples/s | ETA: 84.9 min
  Checkpoint at 200/2500 | Rate: 0.5 samples/s | ETA: 83.4 min
  Checkpoint at 250/2500 | Rate: 0.5 samples/s | ETA: 81.9 min
  Checkpoint at 300/2500 | Rate: 0.5 samples/s | ETA: 80.3 min
  Checkpoint at 350/2500 | Rate: 0.5 samples/s | ETA: 78.5 min
  Checkpoint at 400/2500 | Rate: 0.5 samples/s | ETA: 76.6 min
  Checkpoint at 450/2500 | Rate: 0.5 samples/s | ETA: 74.9 min
  Checkpoint at 500/2500 | Rate: 0.5 samples/s | ETA: 73.1 min
  Checkpoint at 550/2500 | Rate: 0.5 samples/s | ETA: 71.3 min
  Checkpoint at 600/2500 | Rate: 0.5 samples/s | ETA: 69.5 min
  Checkpoint at 650/2500 | Rate: 0.5 samples/s | ETA: 67.7 min
  Checkpoint at 700/2500 | Rate: 0.5 samples/s | ETA: 65.8 min
  Checkpoint at 750/2500 | Rate: 0.5 samples/s | ETA: 64.0 min
  Checkpoint at 800/2500 | Rate: 0.5 samples/s | ETA: 62
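
Cell 10 writes `results`, the full list of sample queries and a `completed` count to `CHECKPOINT_PATH`. The ETA line also references a `start_idx`, which suggests the loop can resume from a saved position. That resume logic is not shown in this excerpt, so the following is only a hypothetical sketch (`load_checkpoint` is an illustrative name, not a notebook function). It resumes only when the stored queries match the current sample set, so a changed seed or sample count cannot silently mix results.

```python
import json
from pathlib import Path


def load_checkpoint(path, samples):
    """Hypothetical resume helper: return (results_so_far, start_idx).

    Assumes the checkpoint layout written in Cell 10:
    {'results': [...], 'sample_queries': [...], 'completed': int, ...}
    """
    path = Path(path)
    if not path.exists():
        return [], 0  # fresh run: nothing evaluated yet

    with open(path) as f:
        ckpt = json.load(f)

    # Refuse to resume if the sample set differs from the one that was checkpointed
    current_queries = [s['query'] for s in samples]
    if ckpt.get('sample_queries') != current_queries:
        raise ValueError("Checkpoint was written for a different sample set")

    results = ckpt['results']
    # Results are appended in index order, so the next index equals the count
    return results, len(results)
```

With `results, start_idx = load_checkpoint(CHECKPOINT_PATH, samples)`, the main loop would then iterate from `start_idx` to `N - 1`.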

In [20]:
# Cell 11: Cache length diagnostics
print("Cache Length Diagnostics")
print("=" * 60)

# --- Truncated match check ---
print("\n(a) TRUNCATED: All cache lengths match bare?")
bare_lens = [r['bare_len'] for r in results]
oracle_trunc_lens = [r['oracle_trunc_len'] for r in results]
random_trunc_lens = [r['random_trunc_len'] for r in results]
doc_lens = [r['doc_len'] for r in results]

all_trunc_match = all(
    b == ot == rt == 1 + d
    for b, ot, rt, d in zip(bare_lens, oracle_trunc_lens, random_trunc_lens, doc_lens)
)
print(f"  All truncated cache lengths match (bare == oracle-trunc == random-trunc == 1+doc): {all_trunc_match}")
if all_trunc_match:
    print("  PASS: Matched tokenization guarantees identical cache lengths.")
else:
    mismatches = sum(1 for b, ot, rt, d in zip(bare_lens, oracle_trunc_lens, random_trunc_lens, doc_lens)
                     if not (b == ot == rt == 1 + d))
    print(f"  WARNING: {mismatches} samples have mismatched truncated lengths!")

# --- Suffix length distribution ---
print("\n(b) SUFFIX: Cache length distribution")
oracle_sfx_lens = np.array([r['oracle_suffix_len'] for r in results])
random_sfx_lens = np.array([r['random_suffix_len'] for r in results])
bare_lens_arr = np.array(bare_lens)

oracle_sfx_extra = oracle_sfx_lens - bare_lens_arr
random_sfx_extra = random_sfx_lens - bare_lens_arr

print(f"  Oracle-suffix extra tokens (vs bare):  mean={np.mean(oracle_sfx_extra):.1f}, min={np.min(oracle_sfx_extra)}, max={np.max(oracle_sfx_extra)}")
print(f"  Random-suffix extra tokens (vs bare):  mean={np.mean(random_sfx_extra):.1f}, min={np.min(random_sfx_extra)}, max={np.max(random_sfx_extra)}")
print(f"  Oracle-suffix total length:  mean={np.mean(oracle_sfx_lens):.1f}, min={np.min(oracle_sfx_lens)}, max={np.max(oracle_sfx_lens)}")
print(f"  Random-suffix total length:  mean={np.mean(random_sfx_lens):.1f}, min={np.min(random_sfx_lens)}, max={np.max(random_sfx_lens)}")

# Note: suffix cache length = passage tokens + separator tokens + suffix tokens.
# The passage portion need not match the bare cache token-for-token: bare uses matched
# tokenization, while build_suffix_kv_cache tokenizes the passage independently, so the
# "extra tokens" above are separator + suffix tokens plus any such tokenization difference.
# Cell 7 verified that this does not cause meaningful NLL differences.

Cache Length Diagnostics

(a) TRUNCATED: All cache lengths match bare?
  All truncated cache lengths match (bare == oracle-trunc == random-trunc == 1+doc): True
  PASS: Matched tokenization guarantees identical cache lengths.

(b) SUFFIX: Cache length distribution
  Oracle-suffix extra tokens (vs bare):  mean=13.8, min=7, max=28
  Random-suffix extra tokens (vs bare):  mean=13.8, min=7, max=28
  Oracle-suffix total length:  mean=140.4, min=68, max=295
  Random-suffix total length:  mean=140.4, min=68, max=295


In [21]:
# Cell 12: Primary analysis — filter zeros, compute all 6 comparisons with Bonferroni
print("=" * 70)
print("PRIMARY ANALYSIS — 6 COMPARISONS WITH BONFERRONI CORRECTION")
print("=" * 70)

bare_nlls_raw = np.array([r['bare_nll'] for r in results])
oracle_trunc_nlls_raw = np.array([r['oracle_trunc_nll'] for r in results])
random_trunc_nlls_raw = np.array([r['random_trunc_nll'] for r in results])
oracle_suffix_nlls_raw = np.array([r['oracle_suffix_nll'] for r in results])
random_suffix_nlls_raw = np.array([r['random_suffix_nll'] for r in results])

# Sanity checks
for name, arr in [('bare', bare_nlls_raw), ('oracle_trunc', oracle_trunc_nlls_raw),
                   ('random_trunc', random_trunc_nlls_raw), ('oracle_suffix', oracle_suffix_nlls_raw),
                   ('random_suffix', random_suffix_nlls_raw)]:
    assert not np.any(np.isnan(arr)), f"NaN in {name} NLLs!"

# Filter out degenerate samples where ANY NLL is 0.0
valid_mask = (
    (bare_nlls_raw != 0.0) &
    (oracle_trunc_nlls_raw != 0.0) &
    (random_trunc_nlls_raw != 0.0) &
    (oracle_suffix_nlls_raw != 0.0) &
    (random_suffix_nlls_raw != 0.0)
)
n_invalid = np.sum(~valid_mask)
n_valid = np.sum(valid_mask)
print(f"Total samples: {len(results)}")
print(f"Excluded (zero NLL from single-token answers): {n_invalid}")
print(f"Valid samples for analysis: {n_valid}")
print(f"Bonferroni-corrected alpha: {BONFERRONI_ALPHA:.4f} (0.05 / 6)\n")

bare = bare_nlls_raw[valid_mask]
oracle_trunc = oracle_trunc_nlls_raw[valid_mask]
random_trunc = random_trunc_nlls_raw[valid_mask]
oracle_suffix = oracle_suffix_nlls_raw[valid_mask]
random_suffix = random_suffix_nlls_raw[valid_mask]

# NLL summary
print(f"{'Condition':<20} {'Mean NLL':>10} {'Std':>10} {'Median':>10}")
print("-" * 55)
for name, arr in [('Bare', bare), ('Oracle-truncated', oracle_trunc),
                   ('Random-truncated', random_trunc), ('Oracle-suffix', oracle_suffix),
                   ('Random-suffix', random_suffix)]:
    print(f"{name:<20} {np.mean(arr):>10.4f} {np.std(arr):>10.4f} {np.median(arr):>10.4f}")

# All 6 comparisons
# Convention: positive delta = first condition is BETTER (lower NLL)
comparisons = [
    # Primary
    ('P1: Oracle-sfx vs Random-sfx', random_suffix - oracle_suffix, 'Semantic signal in suffix?'),
    ('P2: Oracle-sfx vs Bare', bare - oracle_suffix, 'Does suffix priming help?'),
    ('P3: Oracle-trunc vs Bare', bare - oracle_trunc, 'Exp 01 replication (d~0.023 ns)?'),
    # Secondary
    ('S1: Random-sfx vs Bare', bare - random_suffix, 'Any suffix helps?'),
    ('S2: Oracle-sfx vs Oracle-trunc', oracle_trunc - oracle_suffix, 'Suffix vs truncated?'),
    ('S3: Random-trunc vs Bare', bare - random_trunc, 'Exp 01 random replication (d~0.091)?'),
]

print(f"\n{'─' * 80}")
print("PAIRED COMPARISONS (positive delta = first condition has LOWER NLL = better)")
print(f"{'─' * 80}")
print(f"\n{'Comparison':<35} {'Mean Δ':>8} {'d':>8} {'Win%':>7} {'t':>8} {'p':>12} {'Sig':>5}")
print("-" * 88)

comparison_results = {}
for name, delta, question in comparisons:
    d = cohens_d(delta)
    win_rate = np.mean(delta > 0) * 100
    t_stat, p_val = stats.ttest_1samp(delta, 0)
    sig = "***" if p_val < 0.001 else "**" if p_val < BONFERRONI_ALPHA else "*" if p_val < 0.05 else "ns"
    bonf_sig = p_val < BONFERRONI_ALPHA
    print(f"{name:<35} {np.mean(delta):>8.4f} {d:>8.3f} {win_rate:>6.1f}% {t_stat:>8.2f} {p_val:>11.2e} {sig:>5}")
    comparison_results[name] = {
        'mean_delta': float(np.mean(delta)),
        'cohens_d': float(d),
        'win_rate': float(win_rate / 100),
        't_stat': float(t_stat),
        'p_value': float(p_val),
        'bonferroni_significant': bool(bonf_sig),
        'question': question,
    }

print(f"\nSignificance levels: *** p<0.001, ** p<{BONFERRONI_ALPHA:.4f} (Bonferroni), * p<0.05, ns = not significant")
print(f"Cohen's d interpretation: |d|<0.2 = negligible, 0.2-0.5 = small, 0.5-0.8 = medium, >0.8 = large")

PRIMARY ANALYSIS — 6 COMPARISONS WITH BONFERRONI CORRECTION
Total samples: 5000
Excluded (zero NLL from single-token answers): 394
Valid samples for analysis: 4606
Bonferroni-corrected alpha: 0.0083 (0.05 / 6)

Condition              Mean NLL        Std     Median
-------------------------------------------------------
Bare                     1.1455     1.5698     0.6206
Oracle-truncated         1.1369     1.5553     0.6270
Random-truncated         1.1171     1.5405     0.5923
Oracle-suffix            1.1168     1.4874     0.6098
Random-suffix            1.0513     1.4846     0.5464

────────────────────────────────────────────────────────────────────────────────
PAIRED COMPARISONS (positive delta = first condition has LOWER NLL = better)
────────────────────────────────────────────────────────────────────────────────

Comparison                            Mean Δ        d    Win%        t            p   Sig
------------------------------------------------------------------------------
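
The comparison loop computes Cohen's d, a one-sample t-test and a win rate on each vector of per-sample deltas. The notebook's `cohens_d` is not shown here; the sketch below assumes the common paired form d_z = mean(delta) / std(delta, ddof=1). Under that definition, the t statistic from `scipy.stats.ttest_1samp` is exactly d·√n. For example, P1's d = -0.192 with n = 4606 corresponds to t ≈ -13.0. The helper name `paired_effect` is hypothetical.

```python
import numpy as np
from scipy import stats


def paired_effect(delta, n_comparisons=6, alpha=0.05):
    """Paired effect size, one-sample t-test and win/tie rates on per-sample deltas.

    Assumption: Cohen's d for paired data is mean(delta) / std(delta, ddof=1).
    """
    delta = np.asarray(delta, dtype=float)
    d = delta.mean() / delta.std(ddof=1)
    t_stat, p_val = stats.ttest_1samp(delta, 0.0)
    return {
        'd': d,
        't': t_stat,
        'p': p_val,
        'bonferroni_significant': p_val < alpha / n_comparisons,  # 0.05 / 6 ≈ 0.0083
        'win_rate': np.mean(delta > 0),   # first condition has strictly lower NLL
        'tie_rate': np.mean(delta == 0),  # ties count as non-wins in the notebook's Win%
    }
```

Because t = d·√n, significance tracks effect size scaled by sample size. With n in the thousands, even |d| below 0.1 (negligible by the printed rule of thumb) can pass the Bonferroni threshold.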

In [22]:
# Cell 13: Summary table with verdicts
print("=" * 70)
print("SUMMARY TABLE WITH VERDICTS")
print("=" * 70)
print(f"\nN = {n_valid} valid samples (of {len(results)} total)")
print(f"Bonferroni-corrected alpha = {BONFERRONI_ALPHA:.4f}\n")

for name, delta, question in comparisons:
    cr = comparison_results[name]
    d = cr['cohens_d']
    p = cr['p_value']
    bonf = cr['bonferroni_significant']
    win = cr['win_rate'] * 100

    if bonf:
        if cr['mean_delta'] > 0:
            verdict = "YES — significant benefit (survives Bonferroni)"
        else:
            verdict = "YES — significant HARM (survives Bonferroni)"
    elif p < 0.05:
        if cr['mean_delta'] > 0:
            verdict = "Suggestive benefit (p<0.05 but fails Bonferroni)"
        else:
            verdict = "Suggestive harm (p<0.05 but fails Bonferroni)"
    else:
        verdict = "NO — not significant"

    print(f"Q: {question}")
    print(f"  {name}")
    print(f"  Δ = {cr['mean_delta']:+.4f}, d = {d:+.3f}, Win% = {win:.1f}%, p = {p:.2e}")
    print(f"  → {verdict}")
    print()

SUMMARY TABLE WITH VERDICTS

N = 4606 valid samples (of 5000 total)
Bonferroni-corrected alpha = 0.0083

Q: Semantic signal in suffix?
  P1: Oracle-sfx vs Random-sfx
  Δ = -0.0655, d = -0.192, Win% = 37.6%, p = 2.99e-38
  → YES — significant HARM (survives Bonferroni)

Q: Does suffix priming help?
  P2: Oracle-sfx vs Bare
  Δ = +0.0287, d = +0.071, Win% = 52.9%, p = 1.29e-06
  → YES — significant benefit (survives Bonferroni)

Q: Exp 01 replication (d~0.023 ns)?
  P3: Oracle-trunc vs Bare
  Δ = +0.0086, d = +0.023, Win% = 50.0%, p = 1.13e-01
  → NO — not significant

Q: Any suffix helps?
  S1: Random-sfx vs Bare
  Δ = +0.0942, d = +0.264, Win% = 65.2%, p = 1.14e-69
  → YES — significant benefit (survives Bonferroni)

Q: Suffix vs truncated?
  S2: Oracle-sfx vs Oracle-trunc
  Δ = +0.0201, d = +0.048, Win% = 52.8%, p = 1.09e-03
  → YES — significant benefit (survives Bonferroni)

Q: Exp 01 random replication (d~0.091)?
  S3: Random-trunc vs Bare
  Δ = +0.0285, d = +0.091, Win% = 59.5%, p
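
The mean of per-sample differences equals the difference of means. Each Δ in the verdicts can therefore be read off the Cell 12 NLL table using the delta convention: reference NLL minus first-condition NLL, where a positive value means the first condition is better. The snippet below redoes that subtraction with the printed four-decimal means. S3 comes out at 0.0284 rather than the reported 0.0285 only because the table means are rounded.

```python
# Mean NLLs copied from the Cell 12 summary table (valid samples only)
mean_nll = {
    'bare': 1.1455,
    'oracle_trunc': 1.1369,
    'random_trunc': 1.1171,
    'oracle_suffix': 1.1168,
    'random_suffix': 1.0513,
}

# (first condition, reference condition); delta = reference - first
pairs = {
    'P1': ('oracle_suffix', 'random_suffix'),
    'P2': ('oracle_suffix', 'bare'),
    'P3': ('oracle_trunc', 'bare'),
    'S1': ('random_suffix', 'bare'),
    'S2': ('oracle_suffix', 'oracle_trunc'),
    'S3': ('random_trunc', 'bare'),
}

mean_delta = {
    name: round(mean_nll[ref] - mean_nll[first], 4)
    for name, (first, ref) in pairs.items()
}
for name, value in mean_delta.items():
    print(f"{name}: {value:+.4f}")
```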

In [23]:
# Cell 14: Hardness interaction (all 4 non-bare conditions)
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend
import matplotlib.pyplot as plt

print("=" * 70)
print("HARDNESS INTERACTION ANALYSIS")
print("=" * 70)

# Deltas for all 4 non-bare conditions (positive = condition better than bare)
delta_oracle_trunc = bare - oracle_trunc
delta_random_trunc = bare - random_trunc
delta_oracle_sfx = bare - oracle_suffix
delta_random_sfx = bare - random_suffix

conditions = [
    ('Oracle-truncated', delta_oracle_trunc),
    ('Random-truncated', delta_random_trunc),
    ('Oracle-suffix', delta_oracle_sfx),
    ('Random-suffix', delta_random_sfx),
]

print(f"\nPearson r of bare NLL (hardness) with condition benefit:")
print(f"{'Condition':<25} {'r':>8} {'p':>12}")
print("-" * 48)

hardness_results = {}
for name, delta in conditions:
    r, p = stats.pearsonr(bare, delta)
    print(f"{name:<25} {r:>8.4f} {p:>12.2e}")
    hardness_results[name] = {'r': float(r), 'p': float(p)}

# Quartile breakdown
quartiles = np.percentile(bare, [25, 50, 75])
q_labels = ['Q1 (easy)', 'Q2', 'Q3', 'Q4 (hard)']
q_bounds = [(-np.inf, quartiles[0]), (quartiles[0], quartiles[1]),
            (quartiles[1], quartiles[2]), (quartiles[2], np.inf)]

print(f"\nQuartile breakdown (by bare NLL hardness):")
header = f"{'Quartile':<12} {'N':>5} {'Bare':>8}"
for name, _ in conditions:
    short = name.split('-')[0][:3] + '-' + name.split('-')[1][:3]
    header += f" {'Δ_'+short:>10} {'d_'+short:>8}"
print(header)
print("-" * len(header))

for label, (lo, hi) in zip(q_labels, q_bounds):
    mask = (bare > lo) & (bare <= hi)
    n_q = np.sum(mask)
    mean_bare = np.mean(bare[mask])
    row = f"{label:<12} {n_q:>5} {mean_bare:>8.3f}"
    for name, delta in conditions:
        dq = delta[mask]
        row += f" {np.mean(dq):>10.4f} {cohens_d(dq):>8.3f}"
    print(row)

HARDNESS INTERACTION ANALYSIS

Pearson r of bare NLL (hardness) with condition benefit:
Condition                        r            p
------------------------------------------------
Oracle-truncated            0.1569     8.77e-27
Random-truncated            0.1927     9.29e-40
Oracle-suffix               0.3278    7.54e-116
Random-suffix               0.3462    7.47e-130

Quartile breakdown (by bare NLL hardness):
Quartile         N     Bare  Δ_Ora-tru d_Ora-tru  Δ_Ran-tru d_Ran-tru  Δ_Ora-suf d_Ora-suf  Δ_Ran-suf d_Ran-suf
---------------------------------------------------------------------------------------------------------------
Q1 (easy)     1152    0.086    -0.0331   -0.195     0.0014    0.023    -0.0501   -0.377    -0.0023   -0.043
Q2            1152    0.417    -0.0213   -0.127     0.0132    0.105    -0.0270   -0.132     0.0348    0.280
Q3            1150    0.943     0.0114    0.044     0.0303    0.154     0.0056    0.019     0.0847    0.375
Q4 (hard)     1152    3.136    
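
Each hardness correlation pairs bare NLL with bare - condition, so the bare term appears on both sides. Suppose every condition's NLL is a shared per-sample difficulty plus its own independent noise. Under that simple model, the shared term alone produces a positive r even when the condition has no effect. One common check, borrowed from Bland-Altman analysis, is to correlate the benefit with the per-sample mean of the two NLLs instead. This cancels the shared-noise term when both noise levels are equal. The sketch uses synthetic data only: the gamma-distributed difficulty and Gaussian noise are assumptions, not properties of this dataset.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
n = 20000
difficulty = rng.gamma(shape=1.0, scale=1.0, size=n)   # assumed per-sample "true" hardness
bare_sim = difficulty + rng.normal(0.0, 0.3, size=n)   # bare NLL = hardness + its own noise
cond_sim = difficulty + rng.normal(0.0, 0.3, size=n)   # a condition with NO real effect
benefit_sim = bare_sim - cond_sim                       # same convention: positive = condition better

# Hardness measured by bare alone: shares bare's noise term with the benefit
r_coupled, _ = stats.pearsonr(bare_sim, benefit_sim)

# Hardness measured by the mean of both NLLs: the noise terms cancel (equal noise levels)
r_mean, _ = stats.pearsonr((bare_sim + cond_sim) / 2, benefit_sim)

print(f"r(bare, benefit)      = {r_coupled:.3f}")
print(f"r(mean NLL, benefit)  = {r_mean:.3f}")
```

Applying the same mean-of-both-NLLs check to the real `bare` and condition arrays would indicate how much of each printed r survives once this coupling is removed.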

In [24]:
# Cell 15: Cross-mechanism comparison (suffix vs truncated scatter plot)
print("Cross-Mechanism Comparison: Suffix vs Truncated")
print("=" * 60)

# Oracle: suffix benefit vs truncated benefit
oracle_sfx_benefit = bare - oracle_suffix  # positive = suffix better than bare
oracle_trunc_benefit = bare - oracle_trunc  # positive = trunc better than bare

# Random: suffix benefit vs truncated benefit
random_sfx_benefit = bare - random_suffix
random_trunc_benefit = bare - random_trunc

# Correlation
r_oracle_cross, p_oracle_cross = stats.pearsonr(oracle_trunc_benefit, oracle_sfx_benefit)
r_random_cross, p_random_cross = stats.pearsonr(random_trunc_benefit, random_sfx_benefit)

print(f"\nCorrelation of truncated benefit vs suffix benefit:")
print(f"  Oracle: r = {r_oracle_cross:.4f}, p = {p_oracle_cross:.2e}")
print(f"  Random: r = {r_random_cross:.4f}, p = {p_random_cross:.2e}")

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Oracle scatter
axes[0].scatter(oracle_trunc_benefit, oracle_sfx_benefit, alpha=0.15, s=5, c='steelblue')
axes[0].axhline(y=0, color='gray', linestyle='--', alpha=0.5)
axes[0].axvline(x=0, color='gray', linestyle='--', alpha=0.5)
# Regression line
z = np.polyfit(oracle_trunc_benefit, oracle_sfx_benefit, 1)
p_fit = np.poly1d(z)
x_range = np.linspace(oracle_trunc_benefit.min(), oracle_trunc_benefit.max(), 100)
axes[0].plot(x_range, p_fit(x_range), 'r-', alpha=0.8, label=f'r={r_oracle_cross:.3f}')
axes[0].set_xlabel('Oracle-truncated benefit (vs bare)')
axes[0].set_ylabel('Oracle-suffix benefit (vs bare)')
axes[0].set_title('Oracle: Truncated vs Suffix Benefit')
axes[0].legend()

# Random scatter
axes[1].scatter(random_trunc_benefit, random_sfx_benefit, alpha=0.15, s=5, c='darkorange')
axes[1].axhline(y=0, color='gray', linestyle='--', alpha=0.5)
axes[1].axvline(x=0, color='gray', linestyle='--', alpha=0.5)
z2 = np.polyfit(random_trunc_benefit, random_sfx_benefit, 1)
p_fit2 = np.poly1d(z2)
x_range2 = np.linspace(random_trunc_benefit.min(), random_trunc_benefit.max(), 100)
axes[1].plot(x_range2, p_fit2(x_range2), 'r-', alpha=0.8, label=f'r={r_random_cross:.3f}')
axes[1].set_xlabel('Random-truncated benefit (vs bare)')
axes[1].set_ylabel('Random-suffix benefit (vs bare)')
axes[1].set_title('Random: Truncated vs Suffix Benefit')
axes[1].legend()

plt.tight_layout()
plt.savefig(RESULTS_DIR / 'cross_mechanism_comparison.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"\nPlot saved to {RESULTS_DIR / 'cross_mechanism_comparison.png'}")

Cross-Mechanism Comparison: Suffix vs Truncated

Correlation of truncated benefit vs suffix benefit:
  Oracle: r = 0.4182, p = 2.01e-194
  Random: r = 0.4651, p = 5.58e-246

Plot saved to results/exp02/cross_mechanism_comparison.png
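
In Cells 15 and 17, the red regression line comes from `np.polyfit(x, y, 1)`, which returns `[slope, intercept]` (highest power first). For a straight-line least-squares fit, the slope equals r · s_y / s_x. The line therefore shares the sign of the printed r, but its steepness also depends on the ratio of the two spreads. A quick check on synthetic stand-in data (not experiment values):

```python
import numpy as np

rng = np.random.default_rng(1)
# Synthetic stand-ins for two per-sample benefit vectors
x = rng.normal(0.0, 0.3, size=2000)
y = 0.5 * x + rng.normal(0.0, 0.4, size=2000)

slope, intercept = np.polyfit(x, y, 1)        # coefficients come back highest power first
r = np.corrcoef(x, y)[0, 1]
slope_from_r = r * y.std() / x.std()          # OLS slope = r * s_y / s_x (ddof cancels)
intercept_from_means = y.mean() - slope * x.mean()

print(f"polyfit slope = {slope:.4f}, r * s_y / s_x = {slope_from_r:.4f}")
```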


In [25]:
# Cell 16: Delta distribution plots (2x3 grid)

fig, axes = plt.subplots(2, 3, figsize=(18, 10))

plot_configs = [
    # Row 1: Primary comparisons
    ('P1: Oracle-sfx vs Random-sfx', random_suffix - oracle_suffix, 'steelblue'),
    ('P2: Oracle-sfx vs Bare', bare - oracle_suffix, 'forestgreen'),
    ('P3: Oracle-trunc vs Bare', bare - oracle_trunc, 'darkorange'),
    # Row 2: Secondary comparisons
    ('S1: Random-sfx vs Bare', bare - random_suffix, 'mediumpurple'),
    ('S2: Oracle-sfx vs Oracle-trunc', oracle_trunc - oracle_suffix, 'crimson'),
    ('S3: Random-trunc vs Bare', bare - random_trunc, 'teal'),
]

for ax, (title, delta, color) in zip(axes.flat, plot_configs):
    cr = comparison_results[title]
    ax.hist(delta, bins=80, color=color, alpha=0.7, edgecolor='black', linewidth=0.3)
    ax.axvline(x=0, color='red', linestyle='--', alpha=0.7)
    ax.axvline(x=np.mean(delta), color='black', linestyle='-', alpha=0.8,
               label=f'mean={np.mean(delta):.4f}, d={cr["cohens_d"]:+.3f}')
    ax.set_xlabel('Delta NLL (+ = first better)')
    ax.set_ylabel('Count')
    ax.set_title(title, fontsize=10)
    ax.legend(fontsize=8)

plt.suptitle('Delta NLL Distributions — All 6 Comparisons', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig(RESULTS_DIR / 'delta_distributions.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"Plot saved to {RESULTS_DIR / 'delta_distributions.png'}")

Plot saved to results/exp02/delta_distributions.png


In [26]:
# Cell 17: Hardness scatter plots (2x2 grid — all 4 non-bare conditions)

fig, axes = plt.subplots(2, 2, figsize=(14, 12))

scatter_configs = [
    ('Oracle-truncated', delta_oracle_trunc, 'steelblue'),
    ('Random-truncated', delta_random_trunc, 'darkorange'),
    ('Oracle-suffix', delta_oracle_sfx, 'forestgreen'),
    ('Random-suffix', delta_random_sfx, 'mediumpurple'),
]

for ax, (name, delta, color) in zip(axes.flat, scatter_configs):
    r, p = hardness_results[name]['r'], hardness_results[name]['p']
    ax.scatter(bare, delta, alpha=0.12, s=5, c=color)
    ax.axhline(y=0, color='red', linestyle='--', alpha=0.5)
    # Regression line
    z = np.polyfit(bare, delta, 1)
    p_fit = np.poly1d(z)
    x_range = np.linspace(bare.min(), bare.max(), 100)
    ax.plot(x_range, p_fit(x_range), 'r-', alpha=0.8, label=f'r={r:.3f}, p={p:.1e}')
    ax.set_xlabel('Bare NLL (hardness)')
    ax.set_ylabel(f'{name} benefit (+ = helps)')
    ax.set_title(f'Hardness vs {name} Benefit')
    ax.legend()

plt.suptitle('Hardness Interaction — All 4 Non-Bare Conditions', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig(RESULTS_DIR / 'hardness_interaction.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"Plot saved to {RESULTS_DIR / 'hardness_interaction.png'}")

Plot saved to results/exp02/hardness_interaction.png


In [27]:
# Cell 18: Save results JSON

final_results = {
    'experiment': 'exp02_suffix_vs_truncated',
    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    'config': {
        'model_name': config.model_name,
        'num_samples': config.num_samples,
        'seed': SEED,
        'min_passage_words': config.min_passage_words,
        'max_passage_words': config.max_passage_words,
        'surrogate_prefix_template': SURROGATE_PREFIX_TEMPLATE,
        'document_template': DOCUMENT_TEMPLATE,
        'query_template': QUERY_TEMPLATE,
        'answer_template': ANSWER_TEMPLATE,
        'suffix_separator': SUFFIX_SEPARATOR,
        'bonferroni_alpha': BONFERRONI_ALPHA,
    },
    'summary': {
        'n_total': len(results),
        'n_valid': int(n_valid),
        'n_excluded_zero_nll': int(n_invalid),
        'nll_means': {
            'bare': float(np.mean(bare)),
            'oracle_truncated': float(np.mean(oracle_trunc)),
            'random_truncated': float(np.mean(random_trunc)),
            'oracle_suffix': float(np.mean(oracle_suffix)),
            'random_suffix': float(np.mean(random_suffix)),
        },
        'nll_stds': {
            'bare': float(np.std(bare)),
            'oracle_truncated': float(np.std(oracle_trunc)),
            'random_truncated': float(np.std(random_trunc)),
            'oracle_suffix': float(np.std(oracle_suffix)),
            'random_suffix': float(np.std(random_suffix)),
        },
        'comparisons': comparison_results,
        'hardness_interaction': hardness_results,
        'cross_mechanism_correlation': {
            'oracle_r': float(r_oracle_cross),
            'oracle_p': float(p_oracle_cross),
            'random_r': float(r_random_cross),
            'random_p': float(p_random_cross),
        },
    },
    'per_sample_results': results,
}

with open(FINAL_RESULTS_PATH, 'w') as f:
    json.dump(final_results, f, indent=2)

print(f"Final results saved to {FINAL_RESULTS_PATH}")
print(f"File size: {FINAL_RESULTS_PATH.stat().st_size / 1024:.1f} KB")
print(f"Total samples: {len(results)}")
print(f"Valid samples: {n_valid}")
print(f"\nDone!")

Final results saved to results/exp02/results.json
File size: 3715.0 KB
Total samples: 5000
Valid samples: 4606

Done!
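
Cell 18 wraps summary values in `float()`, `int()` and `bool()` because the standard `json` module cannot serialize NumPy scalars such as `np.int64` or `np.bool_`. (`np.float64` is the exception, since it subclasses Python `float`.) An alternative, which is not what the notebook does, is to pass a `default=` hook to `json.dump`. `to_builtin` below is a hypothetical helper name.

```python
import json
from pathlib import Path

import numpy as np


def to_builtin(obj):
    """json `default` hook: convert NumPy scalars/arrays and Paths to JSON-safe types.

    json calls this only for objects it cannot serialize natively.
    """
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, Path):
        return str(obj)
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
```

Usage would be `json.dump(final_results, f, indent=2, default=to_builtin)`. This also covers NumPy values nested inside `per_sample_results` without casting each field by hand.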
