# Exp 03: Combined Mechanisms + Suffix Length Scaling

## Motivation
Exp 02 found that random-suffix (d=0.264) is ~3x stronger than random-truncated (d=0.091), with a cross-mechanism correlation of r=0.47. Two questions remain open:
1. **Do they stack?** If partially independent, combining should yield d > 0.264.
2. **Does suffix benefit diminish with passage length?** Random tokens are a smaller fraction of longer caches.

## Conditions (5)

| # | Name | Cache | Tests |
|---|------|-------|-------|
| 1 | **Bare** | `[BOS][doc]` (matched tokenization) | Baseline |
| 2 | **Random-suffix** | `[BOS][doc][\n\nRelated question: ][random]` | Suffix alone (Exp 02 replication) |
| 3 | **Random-truncated** | `[BOS][random\n][doc]` → truncate + RoPE correct | Truncation alone (Exp 01/02 replication) |
| 4 | **Random-combined** | `[BOS][random_prefix\n][doc][\n\nRelated question: ][random_suffix]` → truncate prefix → RoPE correct | Both mechanisms stacked (NEW) |
| 5 | **Separator-only** | `[BOS][doc][\n\nRelated question: ]` (no suffix tokens) | Separator framing vs random token regularization (NEW) |
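
The "truncate" step in conditions 3 and 4 can be pictured as a slice over the sequence axis of each layer's keys and values. The sketch below is a minimal illustration on a NumPy array. It is not the notebook's `extract_and_truncate_cache_with_bos`, whose implementation is not shown here. The name `truncate_keep_bos` and the `(heads, seq_len, head_dim)` layout are assumptions made for the example, and the RoPE position correction is left out entirely.

```python
import numpy as np


def truncate_keep_bos(kv, keep_n):
    """Keep the BOS entry (position 0) plus the last keep_n positions.

    kv is assumed to hold one layer's keys or values with shape
    (heads, seq_len, head_dim). Hypothetical helper for illustration only.
    """
    seq_len = kv.shape[1]
    if not 0 < keep_n < seq_len:
        raise ValueError("keep_n must be between 1 and seq_len - 1")
    bos = kv[:, :1, :]          # the BOS slot
    tail = kv[:, -keep_n:, :]   # doc (and suffix) slots after the prefix
    return np.concatenate([bos, tail], axis=1)


# Toy cache: 2 heads, 6 positions = [BOS][prefix x2][doc x3], head_dim 3
toy = np.arange(2 * 6 * 3).reshape(2, 6, 3)
trimmed = truncate_keep_bos(toy, keep_n=3)
print(trimmed.shape)
```

After this slice the kept entries still carry the rotary phases of their original positions, which is why the conditions above pair truncation with a RoPE correction.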

Conditions 3+4 share the same random prefix text (seed `SEED+idx`). Conditions 2+4 share the same random suffix text (seed `SEED+idx+N`). This isolates the incremental contribution of each mechanism.
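
The seed scheme above can be sketched in a few lines. This is a minimal illustration assuming `SEED = 42` and `N = 3000` as configured in this notebook; the helper names `prefix_seed`, `suffix_seed`, and `seeds_are_disjoint` are hypothetical and exist only for the example.

```python
SEED = 42
N = 3000  # number of evaluation samples in this experiment


def prefix_seed(idx, seed=SEED):
    """Seed for the random prefix of sample idx (conditions 3 and 4)."""
    return seed + idx


def suffix_seed(idx, n_samples=N, seed=SEED):
    """Seed for the random suffix of sample idx (conditions 2 and 4)."""
    return seed + idx + n_samples


def seeds_are_disjoint(n_samples=N, seed=SEED):
    """True when no prefix seed is reused as a suffix seed."""
    prefix = {prefix_seed(i, seed) for i in range(n_samples)}
    suffix = {suffix_seed(i, n_samples, seed) for i in range(n_samples)}
    return prefix.isdisjoint(suffix)


print(seeds_are_disjoint())
```

Offsetting the suffix seeds by `N` places them in a range that does not overlap the prefix seeds, so the prefix and suffix of any sample are drawn from different random streams.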

## Comparisons (7, Bonferroni alpha = 0.05/7 = 0.0071)

**Combined mechanism:**
- P1: Combined vs Bare — total combined effect
- P2: Combined vs Random-suffix — does truncation add on top of suffix?
- P3: Combined vs Random-truncated — does suffix add on top of truncation?

**Separator decomposition + replication:**
- S1: Random-suffix vs Bare — replication (expect d≈0.264)
- S2: Random-truncated vs Bare — replication (expect d≈0.091)
- S3: Separator-only vs Bare — does the separator framing alone help?
- S4: Random-suffix vs Separator-only — do the random tokens add beyond the separator?
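
Each comparison above reduces to a paired effect size plus a p-value checked against the corrected threshold. The sketch below assumes a paired Cohen's d defined as the mean of the per-sample NLL differences divided by their sample standard deviation. The notebook's own `cohens_d` from `lib.analysis` may be defined differently, and `paired_cohens_d` and `is_significant` are hypothetical names. The input values are synthetic.

```python
import numpy as np

N_COMPARISONS = 7
BONFERRONI_ALPHA = 0.05 / N_COMPARISONS


def paired_cohens_d(baseline_nll, condition_nll):
    """Paired effect size: mean(baseline - condition) / std of the differences.

    Positive values mean the condition lowers NLL relative to the baseline.
    """
    diff = np.asarray(baseline_nll, dtype=float) - np.asarray(condition_nll, dtype=float)
    return diff.mean() / diff.std(ddof=1)


def is_significant(p_value, alpha=BONFERRONI_ALPHA):
    """Apply the Bonferroni-corrected threshold to a single comparison."""
    return p_value < alpha


# Synthetic NLLs, for illustration only
bare = [2.0, 1.5, 3.0, 2.5]
cond = [1.8, 1.4, 2.7, 2.5]
print(f"d = {paired_cohens_d(bare, cond):.3f}, alpha = {BONFERRONI_ALPHA:.4f}")
```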

**Additivity test**: paired t-test on `delta_combined - (delta_suffix + delta_truncated)` against zero.
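
A minimal sketch of the additivity test, assuming each delta is defined as `bare_nll - condition_nll` (so positive means improvement). That sign convention and the function name `additivity_test` are assumptions; only the tested quantity `delta_combined - (delta_suffix + delta_truncated)` comes from the design above. Because all deltas are per-sample, a one-sample t-test on the interaction term is equivalent to the paired test.

```python
import numpy as np
from scipy import stats


def additivity_test(bare, suffix, truncated, combined):
    """Test whether the combined gain equals the sum of the individual gains.

    Returns (mean interaction, t statistic, two-sided p-value).
    A negative mean interaction indicates sub-additive stacking.
    """
    bare, suffix, truncated, combined = (
        np.asarray(x, dtype=float) for x in (bare, suffix, truncated, combined)
    )
    delta_suffix = bare - suffix
    delta_truncated = bare - truncated
    delta_combined = bare - combined
    interaction = delta_combined - (delta_suffix + delta_truncated)
    t_stat, p_value = stats.ttest_1samp(interaction, 0.0)
    return interaction.mean(), t_stat, p_value
```

A mean interaction near zero with a non-significant p-value would be consistent with the two mechanisms stacking additively; it would not by itself show that they are independent.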

## Length scaling
- Wider passage range: 20-500 words (vs 50-300)
- Continuous regression: `delta ~ beta_0 + beta_1 * doc_token_len`
- With hardness covariate: `delta ~ beta_0 + beta_1 * doc_len + beta_2 * bare_nll`
- Quintile breakdown for all conditions
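
The two regressions and the quintile breakdown can be sketched with ordinary least squares in NumPy. This is an illustrative sketch, not the analysis code used later in the notebook; `fit_length_scaling` and `quintile_means` are hypothetical helpers.

```python
import numpy as np


def fit_length_scaling(delta, doc_len, bare_nll=None):
    """OLS fit of delta ~ beta_0 + beta_1 * doc_len (+ beta_2 * bare_nll).

    Returns the coefficient vector [beta_0, beta_1] or [beta_0, beta_1, beta_2].
    """
    delta = np.asarray(delta, dtype=float)
    columns = [np.ones_like(delta), np.asarray(doc_len, dtype=float)]
    if bare_nll is not None:
        columns.append(np.asarray(bare_nll, dtype=float))
    design = np.column_stack(columns)
    beta, *_ = np.linalg.lstsq(design, delta, rcond=None)
    return beta


def quintile_means(delta, doc_len):
    """Mean delta within each document-length quintile (shortest first)."""
    delta = np.asarray(delta, dtype=float)
    doc_len = np.asarray(doc_len, dtype=float)
    edges = np.percentile(doc_len, [20, 40, 60, 80])
    bins = np.digitize(doc_len, edges)  # bin index 0..4
    return [
        float(delta[bins == q].mean()) if np.any(bins == q) else float("nan")
        for q in range(5)
    ]
```

A negative `beta_1` would indicate that the benefit shrinks as passages get longer. Adding `bare_nll` as a covariate separates that trend from the possibility that longer passages are simply easier or harder at baseline.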

In [1]:
# Cell 1: Setup — permissions, seeds, output directory
import os
os.umask(0o000)  # Required: two-user environment (jupyter + CLI user)

import sys
import json
import time
import numpy as np
import torch
from pathlib import Path

# Reproducibility
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Output directory
RESULTS_DIR = Path("results/exp03")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# Paths
CHECKPOINT_PATH = RESULTS_DIR / "checkpoint.json"
FINAL_RESULTS_PATH = RESULTS_DIR / "results.json"

print(f"SEED: {SEED}")
print(f"Results directory: {RESULTS_DIR}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

SEED: 42
Results directory: results/exp03
CUDA available: True
GPU: NVIDIA L4
GPU memory: 23.6 GB


In [2]:
# Cell 2: Load model (Mistral-7B 4-bit) — identical to Exp 01/02
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

print(f"Loading {MODEL_NAME} (4-bit)...")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.eval()

print(f"Model loaded. dtype={model.dtype}, device={model.device}")
print(f"Vocab size: {len(tokenizer)}")

Loading mistralai/Mistral-7B-Instruct-v0.2 (4-bit)...


`torch_dtype` is deprecated! Use `dtype` instead!



Model loaded. dtype=torch.float16, device=cuda:0
Vocab size: 32000


In [3]:
# Cell 3: Library imports + config + templates
sys.path.insert(0, ".")

from lib.config import ExperimentConfig
from lib.kv_cache import (
    build_kv_cache,
    build_suffix_kv_cache,
    score_answer_with_cache,
    deepcopy_cache,
    extract_and_truncate_cache_with_bos,
    correct_rope_positions_with_bos,
)
from lib.data import load_ms_marco, load_evaluation_samples
from lib.analysis import cohens_d
from scipy import stats
from tqdm.auto import tqdm

config = ExperimentConfig(
    model_name=MODEL_NAME,
    num_samples=3000,
    min_passage_words=20,
    max_passage_words=500,
    seed=SEED,
)

# Templates: NO framing to avoid "Document:\n" artifact
SURROGATE_PREFIX_TEMPLATE = "{surrogate}\n"
DOCUMENT_TEMPLATE = "{document}"

# Query/answer format
QUERY_TEMPLATE = "\nQuery: {query}\nAnswer:"
ANSWER_TEMPLATE = " {answer}"  # Leading space for correct BPE of first word

# Suffix separator
SUFFIX_SEPARATOR = "\n\nRelated question: "

# Checkpoint frequency
CHECKPOINT_EVERY = 50

# Bonferroni-corrected alpha for 7 comparisons
N_COMPARISONS = 7
BONFERRONI_ALPHA = 0.05 / N_COMPARISONS

print("Config:")
print(f"  num_samples: {config.num_samples}")
print(f"  passage words: {config.min_passage_words}-{config.max_passage_words}")
print(f"  surrogate_prefix_template: {repr(SURROGATE_PREFIX_TEMPLATE)}")
print(f"  document_template: {repr(DOCUMENT_TEMPLATE)}")
print(f"  query_template: {repr(QUERY_TEMPLATE)}")
print(f"  answer_template: {repr(ANSWER_TEMPLATE)}")
print(f"  suffix_separator: {repr(SUFFIX_SEPARATOR)}")
print(f"  checkpoint_every: {CHECKPOINT_EVERY}")
print(f"  n_comparisons: {N_COMPARISONS}")
print(f"  bonferroni_alpha: {BONFERRONI_ALPHA:.4f}")

Config:
  num_samples: 3000
  passage words: 20-500
  surrogate_prefix_template: '{surrogate}\n'
  document_template: '{document}'
  query_template: '\nQuery: {query}\nAnswer:'
  answer_template: ' {answer}'
  suffix_separator: '\n\nRelated question: '
  checkpoint_every: 50
  n_comparisons: 7
  bonferroni_alpha: 0.0071


In [4]:
# Cell 4: Load dataset with wider word count filter; print length distribution
dataset = load_ms_marco(config)

# CRITICAL: Set seed immediately before load_evaluation_samples
np.random.seed(SEED)
samples = load_evaluation_samples(dataset, config, require_answer=True)

N = len(samples)
print(f"\nLoaded {N} samples")
print(f"\nExample sample:")
print(f"  Query: {samples[0]['query'][:100]}...")
print(f"  Passage: {samples[0]['passage'][:100]}...")
print(f"  Answer: {samples[0]['answer'][:100]}...")

# Length distribution
word_counts = [len(s['passage'].split()) for s in samples]
word_counts_arr = np.array(word_counts)
print(f"\nPassage word count distribution:")
print(f"  Min: {word_counts_arr.min()}, Max: {word_counts_arr.max()}")
print(f"  Mean: {word_counts_arr.mean():.1f}, Median: {np.median(word_counts_arr):.1f}")
print(f"  Std: {word_counts_arr.std():.1f}")
print(f"  Quintiles: {np.percentile(word_counts_arr, [20, 40, 60, 80]).astype(int)}")

# Histogram of word counts
bins = [20, 40, 60, 80, 100, 120, 150, 200, 300, 500]
print(f"\nWord count bins:")
for i in range(len(bins) - 1):
    count = np.sum((word_counts_arr >= bins[i]) & (word_counts_arr < bins[i+1]))
    print(f"  [{bins[i]:>3}-{bins[i+1]:>3}): {count:>5} ({100*count/N:.1f}%)")
count_last = np.sum(word_counts_arr >= bins[-1])
print(f"  [{bins[-1]:>3}+   ): {count_last:>5} ({100*count_last/N:.1f}%)")

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'microsoft/ms_marco' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


Loading microsoft/ms_marco dataset...
Dataset loaded: 10047 samples
Filtering samples...


Selected 3000 samples

Loaded 3000 samples

Example sample:
  Query: where is the rocky mountains located on a map...
  Passage: The Rocky Mountains, commonly known as the Rockies, are a major mountain range in western North Amer...
  Answer: Range in western North America. and also stretch from the northernmost part of British Columbia, in ...

Passage word count distribution:
  Min: 20, Max: 167
  Mean: 73.2, Median: 74.0
  Std: 26.3
  Quintiles: [47 64 82 96]

Word count bins:
  [ 20- 40):   275 (9.2%)
  [ 40- 60):   812 (27.1%)
  [ 60- 80):   633 (21.1%)
  [ 80-100):   788 (26.3%)
  [100-120):   372 (12.4%)
  [120-150):   112 (3.7%)
  [150-200):     8 (0.3%)
  [200-300):     0 (0.0%)
  [300-500):     0 (0.0%)
  [500+   ):     0 (0.0%)


In [5]:
# Cell 5: generate_random_prefix_text() — copy from Exp 01/02

def generate_random_prefix_text(target_text, tokenizer, seed):
    """
    Generate random text from vocabulary tokens that is length-matched
    (in tokens) to target_text.
    
    Uses a decode->re-encode verification loop to ensure the random
    prefix tokenizes to approximately the expected number of tokens.
    
    Args:
        target_text: Text to match in token length
        tokenizer: The tokenizer
        seed: Random seed for reproducibility
    
    Returns:
        Random text string with approximately the same token count as
        target_text (the decode/re-encode round trip can shift the count)
    """
    # Get target token count (no special tokens — we want content tokens only)
    target_ids = tokenizer.encode(target_text, add_special_tokens=False)
    target_len = len(target_ids)
    
    if target_len == 0:
        return ""
    
    rng = np.random.RandomState(seed)
    vocab_size = len(tokenizer)
    
    # Sample random token IDs from the full vocabulary
    # Exclude special tokens (IDs 0-2: UNK, BOS, EOS in the Mistral tokenizer)
    min_id = 3  # Skip UNK, BOS, EOS
    random_ids = rng.randint(min_id, vocab_size, size=target_len)
    
    # Decode to text
    random_text = tokenizer.decode(random_ids.tolist(), skip_special_tokens=True)
    
    # Verification: re-encode and check length
    reencoded = tokenizer.encode(random_text, add_special_tokens=False)
    
    # If lengths don't match after round-trip, truncate or pad
    if len(reencoded) != target_len:
        # Truncate the text and re-decode from exactly target_len tokens
        if len(reencoded) > target_len:
            random_text = tokenizer.decode(reencoded[:target_len], skip_special_tokens=True)
        else:
            # Pad with more random tokens
            extra_needed = target_len - len(reencoded)
            extra_ids = rng.randint(min_id, vocab_size, size=extra_needed)
            extra_text = tokenizer.decode(extra_ids.tolist(), skip_special_tokens=True)
            random_text = random_text + extra_text
            reencoded2 = tokenizer.encode(random_text, add_special_tokens=False)
            if len(reencoded2) > target_len:
                random_text = tokenizer.decode(reencoded2[:target_len], skip_special_tokens=True)
    
    return random_text


# Test the function
test_query = samples[0]['query']
test_random = generate_random_prefix_text(test_query, tokenizer, seed=SEED)

oracle_tokens = tokenizer.encode(test_query, add_special_tokens=False)
random_tokens = tokenizer.encode(test_random, add_special_tokens=False)

print(f"Oracle query: {repr(test_query)}")
print(f"Oracle tokens: {len(oracle_tokens)}")
print(f"Random prefix: {repr(test_random[:80])}...")
print(f"Random tokens: {len(random_tokens)}")
print(f"Length match: {len(oracle_tokens) == len(random_tokens)}")

Oracle query: 'where is the rocky mountains located on a map'
Oracle tokens: 10
Random prefix: 'Restaur Mars didova少 DATA luxwalkshineLevel'...
Random tokens: 10
Length match: True


In [6]:
# Cell 6: BPE diagnostics
# (a) Truncated mismatch confirmation (same as Exp 01/02)
# (b) Suffix passage token consistency
# (c) Combined condition tokenization consistency
# (d) Separator tokenization consistency

print("BPE Boundary Diagnostics")
print("=" * 60)

# --- (a) Truncated mismatch confirmation ---
print("\n(a) TRUNCATED: Independent vs concatenated tokenization")
n_mismatch = 0
n_total = min(100, N)

for i in range(n_total):
    passage = samples[i]['passage']
    query = samples[i]['query']
    oracle_prefix = SURROGATE_PREFIX_TEMPLATE.format(surrogate=query)
    document_text = DOCUMENT_TEMPLATE.format(document=passage)
    full_text = oracle_prefix + document_text
    full_ids = tokenizer.encode(full_text, add_special_tokens=True)
    prefix_ids = tokenizer.encode(oracle_prefix, add_special_tokens=True)
    prefix_len = len(prefix_ids)
    doc_from_concat = full_ids[prefix_len:]
    doc_independent = tokenizer.encode(passage, add_special_tokens=True)[1:]  # strip BOS
    if doc_from_concat != doc_independent:
        n_mismatch += 1

print(f"  Tested {n_total} samples")
print(f"  BPE mismatches: {n_mismatch}/{n_total} ({100*n_mismatch/n_total:.0f}%)")
print(f"  -> Matched tokenization required for truncated conditions.")

# --- (b) Suffix passage token consistency ---
print("\n(b) SUFFIX: Passage tokens consistency across different suffixes")
n_suffix_test = min(50, N)
n_passage_consistent = 0

for i in range(n_suffix_test):
    passage = samples[i]['passage']
    query = samples[i]['query']
    random_text = generate_random_prefix_text(query, tokenizer, seed=SEED + i)

    bare_ids = tokenizer.encode(passage, add_special_tokens=True)
    oracle_suffix_text = passage + SUFFIX_SEPARATOR + query
    oracle_suffix_ids = tokenizer.encode(oracle_suffix_text, add_special_tokens=True)
    random_suffix_text = passage + SUFFIX_SEPARATOR + random_text
    random_suffix_ids = tokenizer.encode(random_suffix_text, add_special_tokens=True)

    bare_len = len(bare_ids)
    oracle_passage_part = oracle_suffix_ids[:bare_len]
    random_passage_part = random_suffix_ids[:bare_len]

    if bare_ids == oracle_passage_part == random_passage_part:
        n_passage_consistent += 1

print(f"  Tested {n_suffix_test} samples")
print(f"  Passage tokens consistent: {n_passage_consistent}/{n_suffix_test} ({100*n_passage_consistent/n_suffix_test:.0f}%)")
if n_passage_consistent == n_suffix_test:
    print(f"  -> PASS: Passage tokens are identical regardless of suffix content.")
else:
    n_diff = n_suffix_test - n_passage_consistent
    print(f"  -> WARNING: {n_diff} samples have different passage tokens with different suffixes.")

# --- (c) Combined condition tokenization check ---
print("\n(c) COMBINED: Tokenization consistency check")
n_combined_test = min(50, N)
n_combined_ok = 0

for i in range(n_combined_test):
    passage = samples[i]['passage']
    query = samples[i]['query']
    random_prefix_text = generate_random_prefix_text(query, tokenizer, seed=SEED + i)
    random_suffix_text = generate_random_prefix_text(query, tokenizer, seed=SEED + i + N)
    
    prefix_str = SURROGATE_PREFIX_TEMPLATE.format(surrogate=random_prefix_text)
    document_text = DOCUMENT_TEMPLATE.format(document=passage)
    suffix_str = SUFFIX_SEPARATOR + random_suffix_text
    
    full_combined_text = prefix_str + document_text + suffix_str
    full_combined_ids = tokenizer.encode(full_combined_text, add_special_tokens=True)
    prefix_ids = tokenizer.encode(prefix_str, add_special_tokens=True)
    prefix_token_len = len(prefix_ids)
    keep_n = len(full_combined_ids) - prefix_token_len
    
    # Verify keep_n > 0 and reasonable
    doc_ids = tokenizer.encode(passage, add_special_tokens=False)
    sep_ids = tokenizer.encode(SUFFIX_SEPARATOR, add_special_tokens=False)
    suffix_ids = tokenizer.encode(random_suffix_text, add_special_tokens=False)
    expected_min = len(doc_ids) + len(sep_ids) + len(suffix_ids) - 5  # allow BPE variation
    expected_max = len(doc_ids) + len(sep_ids) + len(suffix_ids) + 5
    
    if expected_min <= keep_n <= expected_max and keep_n > 0:
        n_combined_ok += 1

print(f"  Tested {n_combined_test} samples")
print(f"  Combined keep_n in expected range: {n_combined_ok}/{n_combined_test} ({100*n_combined_ok/n_combined_test:.0f}%)")
if n_combined_ok == n_combined_test:
    print(f"  -> PASS: Combined condition tokenization is consistent.")
else:
    print(f"  -> WARNING: {n_combined_test - n_combined_ok} samples have unexpected token counts.")

# --- (d) Separator tokenization consistency ---
print("\n(d) SEPARATOR: Tokenization consistency across passages")
n_sep_test = min(50, N)
sep_only_ids_sets = []

for i in range(n_sep_test):
    passage = samples[i]['passage']
    # separator-only: passage + separator (no suffix tokens)
    sep_only_text = passage + SUFFIX_SEPARATOR
    sep_only_ids = tokenizer.encode(sep_only_text, add_special_tokens=True)
    bare_ids = tokenizer.encode(passage, add_special_tokens=True)
    # The separator should add a consistent number of tokens
    sep_only_ids_sets.append(len(sep_only_ids) - len(bare_ids))

sep_extra = np.array(sep_only_ids_sets)
print(f"  Tested {n_sep_test} samples")
print(f"  Separator adds {np.mean(sep_extra):.1f} tokens on average (min={sep_extra.min()}, max={sep_extra.max()})")
if sep_extra.min() == sep_extra.max():
    print(f"  -> PASS: Separator tokenization is perfectly consistent ({sep_extra[0]} tokens).")
else:
    print(f"  -> Note: Separator token count varies slightly due to BPE boundary effects.")
    print(f"    This is expected and does not affect the experiment.")

BPE Boundary Diagnostics

(a) TRUNCATED: Independent vs concatenated tokenization
  Tested 100 samples
  BPE mismatches: 100/100 (100%)
  -> Matched tokenization required for truncated conditions.

(b) SUFFIX: Passage tokens consistency across different suffixes
  Tested 50 samples
  Passage tokens consistent: 50/50 (100%)
  -> PASS: Passage tokens are identical regardless of suffix content.

(c) COMBINED: Tokenization consistency check
  Tested 50 samples
  Combined keep_n in expected range: 50/50 (100%)
  -> PASS: Combined condition tokenization is consistent.

(d) SEPARATOR: Tokenization consistency across passages
  Tested 50 samples
  Separator adds 7.0 tokens on average (min=7, max=7)
  -> PASS: Separator tokenization is perfectly consistent (7 tokens).


In [7]:
# Cell 7: Bare baseline fairness diagnostic
# Compare matched bare [BOS]+doc_ids vs independent build_kv_cache(passage) NLLs

print("Bare Baseline Fairness Diagnostic")
print("=" * 60)
print("Comparing matched bare (from truncated tokenization) vs")
print("independent bare (build_kv_cache(passage)) on 20 samples.")
print()

n_diag = 20
matched_nlls = []
independent_nlls = []

for i in range(n_diag):
    sample = samples[i]
    passage = sample['passage']
    query = sample['query']
    answer = sample['answer']
    query_prompt = QUERY_TEMPLATE.format(query=query)
    answer_text = ANSWER_TEMPLATE.format(answer=answer)

    # --- Matched bare: [BOS] + doc_ids from oracle concatenation ---
    oracle_prefix = SURROGATE_PREFIX_TEMPLATE.format(surrogate=query)
    document_text = DOCUMENT_TEMPLATE.format(document=passage)
    full_oracle_enc = tokenizer(oracle_prefix + document_text, return_tensors="pt",
                                add_special_tokens=True, padding=False, truncation=False)
    full_oracle_ids = full_oracle_enc['input_ids'].to(config.device)
    oracle_prefix_enc = tokenizer(oracle_prefix, return_tensors="pt",
                                  add_special_tokens=True, padding=False, truncation=False)
    oracle_prefix_len = oracle_prefix_enc['input_ids'].shape[1]
    bos_id = full_oracle_ids[:, :1]
    doc_ids = full_oracle_ids[:, oracle_prefix_len:]
    bare_ids = torch.cat([bos_id, doc_ids], dim=1)
    bare_len = bare_ids.shape[1]

    with torch.no_grad():
        bare_out = model(input_ids=bare_ids, attention_mask=torch.ones_like(bare_ids),
                         use_cache=True, return_dict=True)
    matched_nll = score_answer_with_cache(
        deepcopy_cache(bare_out.past_key_values), bare_len,
        query_prompt, answer_text, model, tokenizer, config)
    matched_nlls.append(matched_nll)

    # --- Independent bare: build_kv_cache(passage) ---
    indep_len, indep_cache = build_kv_cache(passage, model, tokenizer, config)
    independent_nll = score_answer_with_cache(
        deepcopy_cache(indep_cache), indep_len,
        query_prompt, answer_text, model, tokenizer, config)
    independent_nlls.append(independent_nll)

    del bare_out, indep_cache
    torch.cuda.empty_cache()

matched_nlls = np.array(matched_nlls)
independent_nlls = np.array(independent_nlls)
diffs = matched_nlls - independent_nlls

print(f"{'Sample':>8} {'Matched':>10} {'Independent':>12} {'Diff':>10}")
print("-" * 45)
for i in range(n_diag):
    print(f"{i:>8} {matched_nlls[i]:>10.4f} {independent_nlls[i]:>12.4f} {diffs[i]:>10.6f}")

nonzero_mask = (matched_nlls != 0.0) & (independent_nlls != 0.0)
diffs_nonzero = diffs[nonzero_mask]
n_nonzero = np.sum(nonzero_mask)

print(f"\nNon-zero NLL samples: {n_nonzero}/{n_diag}")
print(f"Mean difference: {np.mean(diffs_nonzero):.6f}")
print(f"Mean abs difference: {np.mean(np.abs(diffs_nonzero)):.6f}")
print(f"Max abs difference: {np.max(np.abs(diffs_nonzero)):.6f}")

mean_bias = np.mean(diffs_nonzero)
if abs(mean_bias) < 0.01:
    print(f"\nPASS: Mean systematic bias ({mean_bias:+.4f}) is negligible.")
    print(f"Bare baseline is fair for both mechanisms.")
else:
    print(f"\nCAUTION: Mean systematic bias ({mean_bias:+.4f}) detected.")

Bare Baseline Fairness Diagnostic
Comparing matched bare (from truncated tokenization) vs
independent bare (build_kv_cache(passage)) on 20 samples.

  Sample    Matched  Independent       Diff
---------------------------------------------
       0     1.4131       1.4258  -0.012695
       1     2.2940       2.3565  -0.062500
       2     0.6567       0.6581  -0.001379
       3     0.5826       0.5772   0.005409
       4     2.1058       2.1370  -0.031250
       5     0.1321       0.1327  -0.000579
       6     0.3503       0.3316   0.018663
       7     0.0000       0.0000   0.000000
       8     5.2344       5.1172   0.117188
       9     0.7922       0.8000  -0.007812
      10     1.3027       1.3009   0.001786
      11     2.0703       2.0234   0.046875
      12     2.4172       2.3969   0.020312
      13     0.7026       0.6910   0.011593
      14     1.4946       1.5009  -0.006250
      15     0.2955       0.3209  -0.025391
      16     0.7670       0.7622   0.004755
      17     

In [8]:
# Cell 8: Condition explanation printout (all 5 conditions with concrete examples)
print("=" * 70)
print("EXPERIMENTAL CONDITIONS EXPLAINED")
print("=" * 70)

ex = samples[0]
ex_query = ex['query']
ex_passage = ex['passage'][:80] + "..."
ex_answer = ex['answer'][:60] + "..."
ex_random_prefix = generate_random_prefix_text(ex_query, tokenizer, seed=SEED)
ex_random_suffix = generate_random_prefix_text(ex_query, tokenizer, seed=SEED + N)

print(f"\nExample passage: {repr(ex_passage)}")
print(f"Example query:   {repr(ex_query)}")
print(f"Example answer:  {repr(ex_answer)}")
print(f"Example random prefix: {repr(ex_random_prefix[:60])}...")
print(f"Example random suffix: {repr(ex_random_suffix[:60])}...")

# Show matched tokenization
oracle_prefix = SURROGATE_PREFIX_TEMPLATE.format(surrogate=ex_query)
doc_text = DOCUMENT_TEMPLATE.format(document=ex['passage'])
full_text = oracle_prefix + doc_text
full_ids = tokenizer.encode(full_text, add_special_tokens=True)
prefix_ids = tokenizer.encode(oracle_prefix, add_special_tokens=True)
prefix_len = len(prefix_ids)
doc_ids = full_ids[prefix_len:]

print(f"\n{'_' * 70}")
print("MATCHED TOKENIZATION (for truncated conditions)")
print(f"  Tokenize oracle_prefix + passage together -> {len(full_ids)} tokens")
print(f"  Oracle prefix tokens (with BOS): {prefix_len}")
print(f"  Document tokens (shared by truncated + bare): {len(doc_ids)}")

print(f"\n{'_' * 70}")
print("### CONDITION 1: BARE (baseline) ###")
print(f"  Input IDs:  [BOS] + doc_ids ({1 + len(doc_ids)} tokens)")
print(f"  Key insight: Pure baseline. Same doc tokens as truncated conditions.")

print(f"\n{'_' * 70}")
print("### CONDITION 2: RANDOM-SUFFIX ###")
suffix_random_text = ex['passage'] + SUFFIX_SEPARATOR + ex_random_suffix
suffix_random_ids = tokenizer.encode(suffix_random_text, add_special_tokens=True)
print(f"  Build:  [BOS][passage][separator][random_suffix] ({len(suffix_random_ids)} tokens)")
print(f"  Scoring: Query attends to passage + separator + suffix KV entries")
print(f"  Key insight: Passage KV entries IDENTICAL to bare (causal masking).")
print(f"  Replicates Exp 02 (expected d~0.264).")

print(f"\n{'_' * 70}")
print("### CONDITION 3: RANDOM-TRUNCATED ###")
print(f"  Build:  [BOS][random_prefix\\n][doc_ids]")
print(f"  After:  Truncate prefix -> [BOS] + doc_ids ({1 + len(doc_ids)} tokens) + RoPE correct")
print(f"  Key insight: Value contamination from random prefix.")
print(f"  Replicates Exp 01/02 (expected d~0.091).")

print(f"\n{'_' * 70}")
print("### CONDITION 4: RANDOM-COMBINED (NEW) ###")
prefix_str = SURROGATE_PREFIX_TEMPLATE.format(surrogate=ex_random_prefix)
suffix_str = SUFFIX_SEPARATOR + ex_random_suffix
combined_text = prefix_str + doc_text + suffix_str
combined_ids = tokenizer.encode(combined_text, add_special_tokens=True)
prefix_enc = tokenizer.encode(prefix_str, add_special_tokens=True)
keep_n_ex = len(combined_ids) - len(prefix_enc)
print(f"  Build:  [BOS][random_prefix\\n][doc][separator][random_suffix] ({len(combined_ids)} tokens)")
print(f"  After:  Truncate prefix -> [BOS] + doc + sep + suffix ({1 + keep_n_ex} tokens) + RoPE correct")
print(f"  Key insight: Both truncation (value contamination) AND suffix (attention target).")
print(f"  Tests whether the two mechanisms stack.")

print(f"\n{'_' * 70}")
print("### CONDITION 5: SEPARATOR-ONLY (NEW) ###")
sep_only_text = ex['passage'] + SUFFIX_SEPARATOR
sep_only_ids = tokenizer.encode(sep_only_text, add_special_tokens=True)
print(f"  Build:  [BOS][passage][separator] ({len(sep_only_ids)} tokens)")
print(f"  Scoring: Query attends to passage + separator KV entries (no suffix tokens)")
print(f"  Key insight: Isolates separator framing from random token effect.")
print(f"  If S3 sig: separator framing alone helps. If S4 sig: random tokens add beyond separator.")

print(f"\n{'_' * 70}")
print("DESIGN NOTES:")
print(f"  Random PREFIX (cond 3+4): seed SEED+idx (shared)")
print(f"  Random SUFFIX (cond 2+4): seed SEED+idx+N (shared, different from prefix)")
print(f"  Suffix separator: {repr(SUFFIX_SEPARATOR)}")
print(f"  CACHE SAFETY: deepcopy_cache() before every score call.")

EXPERIMENTAL CONDITIONS EXPLAINED

Example passage: 'The Rocky Mountains, commonly known as the Rockies, are a major mountain range i...'
Example query:   'where is the rocky mountains located on a map'
Example answer:  'Range in western North America. and also stretch from the no...'
Example random prefix: 'Restaur Mars didova少 DATA luxwalkshineLevel'...
Example random suffix: 'Douglas algoactualTCদanel associate […] civ keeps'...

______________________________________________________________________
MATCHED TOKENIZATION (for truncated conditions)
  Tokenize oracle_prefix + passage together -> 123 tokens
  Oracle prefix tokens (with BOS): 12
  Document tokens (shared by truncated + bare): 111

______________________________________________________________________
### CONDITION 1: BARE (baseline) ###
  Input IDs:  [BOS] + doc_ids (112 tokens)
  Key insight: Pure baseline. Same doc tokens as truncated conditions.

______________________________________________________________________
#

In [9]:
# Cell 9: Evaluation parameters + checkpoint loading

print(f"Total samples: {N}")

# Check for existing checkpoint
results = []
start_idx = 0

if CHECKPOINT_PATH.exists():
    with open(CHECKPOINT_PATH, 'r') as f:
        checkpoint = json.load(f)

    # Verify checkpoint matches current sample list (resume correctness)
    ckpt_sample_ids = checkpoint.get('sample_queries', [])
    current_sample_ids = [s['query'] for s in samples]

    if ckpt_sample_ids == current_sample_ids:
        results = checkpoint['results']
        start_idx = len(results)
        print(f"Resuming from checkpoint: {start_idx}/{N} samples completed")
    else:
        print("WARNING: Checkpoint sample list doesn't match current samples.")
        print("Starting from scratch.")
        results = []
        start_idx = 0
else:
    print("No checkpoint found. Starting from scratch.")

print(f"Will evaluate samples {start_idx} to {N-1}")
print(f"\nConditions per sample: 5")
print(f"Total condition evaluations: {(N - start_idx) * 5}")

Total samples: 3000
No checkpoint found. Starting from scratch.
Will evaluate samples 0 to 2999

Conditions per sample: 5
Total condition evaluations: 15000


In [10]:
# Cell 10: Main evaluation loop (5 conditions x N samples, with checkpointing)
#
# Conditions:
#   1. Bare — [BOS] + doc_ids (matched tokenization)
#   2. Random-suffix — build_suffix_kv_cache(passage, random_suffix)
#   3. Random-truncated — [BOS][random_prefix\n][doc_ids] -> truncate + RoPE correct
#   4. Random-combined — [BOS][random_prefix\n][doc][sep][random_suffix] -> truncate prefix -> RoPE correct
#   5. Separator-only — build_suffix_kv_cache(passage, "", separator=SUFFIX_SEPARATOR)
#
# Random prefix (cond 3+4): seed SEED+idx
# Random suffix (cond 2+4): seed SEED+idx+N

t_start = time.time()

for idx in tqdm(range(start_idx, N), initial=start_idx, total=N, desc="Evaluating"):
    sample = samples[idx]
    passage = sample['passage']
    query = sample['query']
    answer = sample['answer']

    query_prompt = QUERY_TEMPLATE.format(query=query)
    answer_text = ANSWER_TEMPLATE.format(answer=answer)

    # === Step 1: Determine canonical document token IDs (for truncated conditions) ===
    oracle_prefix = SURROGATE_PREFIX_TEMPLATE.format(surrogate=query)
    document_text = DOCUMENT_TEMPLATE.format(document=passage)
    full_oracle_text = oracle_prefix + document_text

    full_oracle_enc = tokenizer(full_oracle_text, return_tensors="pt",
                                add_special_tokens=True, padding=False, truncation=False)
    full_oracle_ids = full_oracle_enc['input_ids'].to(config.device)

    oracle_prefix_enc = tokenizer(oracle_prefix, return_tensors="pt",
                                  add_special_tokens=True, padding=False, truncation=False)
    oracle_prefix_len = oracle_prefix_enc['input_ids'].shape[1]  # includes BOS

    bos_id = full_oracle_ids[:, :1]
    doc_ids = full_oracle_ids[:, oracle_prefix_len:]
    doc_len = doc_ids.shape[1]

    # Generate random text — two independent seeds
    random_prefix_text = generate_random_prefix_text(query, tokenizer, seed=SEED + idx)
    random_suffix_text = generate_random_prefix_text(query, tokenizer, seed=SEED + idx + N)

    # === Condition 1: BARE — [BOS] + doc_ids ===
    bare_ids = torch.cat([bos_id, doc_ids], dim=1)
    bare_len = bare_ids.shape[1]  # = 1 + doc_len
    with torch.no_grad():
        bare_out = model(input_ids=bare_ids, attention_mask=torch.ones_like(bare_ids),
                         use_cache=True, return_dict=True)
    bare_cache = bare_out.past_key_values
    bare_nll = score_answer_with_cache(
        deepcopy_cache(bare_cache), bare_len, query_prompt, answer_text,
        model, tokenizer, config
    )

    # === Condition 2: RANDOM-SUFFIX — build_suffix_kv_cache(passage, random_suffix) ===
    random_sfx_len, random_sfx_cache = build_suffix_kv_cache(
        passage, random_suffix_text, model, tokenizer, config, separator=SUFFIX_SEPARATOR)
    random_suffix_nll = score_answer_with_cache(
        deepcopy_cache(random_sfx_cache), random_sfx_len, query_prompt, answer_text,
        model, tokenizer, config
    )

    # === Condition 3: RANDOM-TRUNCATED — [BOS][random_prefix\n][doc_ids] -> truncate ===
    random_prefix = SURROGATE_PREFIX_TEMPLATE.format(surrogate=random_prefix_text)
    random_prefix_enc = tokenizer(random_prefix, return_tensors="pt",
                                  add_special_tokens=False, padding=False, truncation=False)
    random_prefix_ids = random_prefix_enc['input_ids'].to(config.device)
    random_full_ids = torch.cat([bos_id, random_prefix_ids, doc_ids], dim=1)
    random_prefix_len = 1 + random_prefix_ids.shape[1]  # BOS + prefix tokens

    with torch.no_grad():
        random_trunc_out = model(input_ids=random_full_ids,
                                 attention_mask=torch.ones_like(random_full_ids),
                                 use_cache=True, return_dict=True)
    random_trunc_cache = extract_and_truncate_cache_with_bos(random_trunc_out.past_key_values, doc_len)
    random_trunc_len = 1 + doc_len
    correct_rope_positions_with_bos(random_trunc_cache, random_prefix_len - 1, model)
    random_trunc_nll = score_answer_with_cache(
        deepcopy_cache(random_trunc_cache), random_trunc_len, query_prompt, answer_text,
        model, tokenizer, config
    )

    # === Condition 4: RANDOM-COMBINED ===
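    # Build full text: prefix\n + passage + separator + suffix
    # NOTE: unlike Condition 3, which concatenates the matched doc_ids, this path
    # retokenizes the whole string, so the document tokens and the prefix boundary
    # may differ slightly from the bare condition.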
    prefix_str = SURROGATE_PREFIX_TEMPLATE.format(surrogate=random_prefix_text)
    suffix_str = SUFFIX_SEPARATOR + random_suffix_text
    full_combined_text = prefix_str + document_text + suffix_str

    full_combined_enc = tokenizer(full_combined_text, return_tensors="pt",
                                  add_special_tokens=True, padding=False, truncation=False)
    full_combined_ids = full_combined_enc['input_ids'].to(config.device)
    full_combined_len = full_combined_ids.shape[1]

    # Determine prefix length for truncation boundary
    prefix_enc = tokenizer(prefix_str, return_tensors="pt",
                           add_special_tokens=True, padding=False, truncation=False)
    prefix_token_len = prefix_enc['input_ids'].shape[1]  # includes BOS

    # keep_n = everything after the prefix (doc + sep + suffix tokens)
    keep_n = full_combined_len - prefix_token_len

    with torch.no_grad():
        combined_out = model(input_ids=full_combined_ids,
                             attention_mask=torch.ones_like(full_combined_ids),
                             use_cache=True, return_dict=True)

    # Truncate: keep [BOS] + last keep_n positions
    combined_cache = extract_and_truncate_cache_with_bos(combined_out.past_key_values, keep_n)
    combined_len = 1 + keep_n  # BOS + doc + sep + suffix

    # RoPE correct: shift = prefix_token_len - 1 (prefix tokens minus BOS)
    correct_rope_positions_with_bos(combined_cache, prefix_token_len - 1, model)

    combined_nll = score_answer_with_cache(
        deepcopy_cache(combined_cache), combined_len, query_prompt, answer_text,
        model, tokenizer, config
    )

    # === Condition 5: SEPARATOR-ONLY — build_suffix_kv_cache(passage, "") ===
    sep_only_len, sep_only_cache = build_suffix_kv_cache(
        passage, "", model, tokenizer, config, separator=SUFFIX_SEPARATOR)
    separator_only_nll = score_answer_with_cache(
        deepcopy_cache(sep_only_cache), sep_only_len, query_prompt, answer_text,
        model, tokenizer, config
    )

    # Record all 5 NLLs + cache lengths + precomputed deltas
    result = {
        'idx': idx,
        'bare_nll': bare_nll,
        'random_suffix_nll': random_suffix_nll,
        'random_trunc_nll': random_trunc_nll,
        'combined_nll': combined_nll,
        'separator_only_nll': separator_only_nll,
        'bare_len': bare_len,
        'random_suffix_len': random_sfx_len,
        'random_trunc_len': random_trunc_len,
        'combined_len': combined_len,
        'separator_only_len': sep_only_len,
        'doc_len': doc_len,
        'passage_word_count': len(passage.split()),
        # Precomputed deltas (positive = first condition has LOWER NLL = better)
        # P1: Combined vs Bare
        'delta_combined_vs_bare': bare_nll - combined_nll,
        # P2: Combined vs Random-suffix
        'delta_combined_vs_suffix': random_suffix_nll - combined_nll,
        # P3: Combined vs Random-truncated
        'delta_combined_vs_trunc': random_trunc_nll - combined_nll,
        # S1: Random-suffix vs Bare
        'delta_suffix_vs_bare': bare_nll - random_suffix_nll,
        # S2: Random-truncated vs Bare
        'delta_trunc_vs_bare': bare_nll - random_trunc_nll,
        # S3: Separator-only vs Bare
        'delta_seponly_vs_bare': bare_nll - separator_only_nll,
        # S4: Random-suffix vs Separator-only
        'delta_suffix_vs_seponly': separator_only_nll - random_suffix_nll,
    }
    results.append(result)

    # GPU memory management
    del bare_cache, bare_out, random_sfx_cache, random_trunc_cache, random_trunc_out
    del combined_cache, combined_out, sep_only_cache
    torch.cuda.empty_cache()

    # Checkpoint
    if (idx + 1) % CHECKPOINT_EVERY == 0 or idx == N - 1:
        checkpoint_data = {
            'results': results,
            'sample_queries': [s['query'] for s in samples],
            'completed': len(results),
            'total': N,
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        }
        with open(CHECKPOINT_PATH, 'w') as f:
            json.dump(checkpoint_data, f)

        elapsed = time.time() - t_start
        samples_done = idx - start_idx + 1
        rate = samples_done / elapsed if elapsed > 0 else 0
        remaining = (N - idx - 1) / rate if rate > 0 else 0
        tqdm.write(f"  Checkpoint at {idx+1}/{N} | Rate: {rate:.1f} samples/s | ETA: {remaining/60:.1f} min")

elapsed_total = time.time() - t_start
print(f"\nEvaluation complete: {len(results)} samples in {elapsed_total/60:.1f} minutes")

Evaluating:   0%|          | 0/3000 [00:00<?, ?it/s]

  Checkpoint at 50/3000 | Rate: 0.5 samples/s | ETA: 103.9 min
  Checkpoint at 100/3000 | Rate: 0.5 samples/s | ETA: 103.1 min
  Checkpoint at 150/3000 | Rate: 0.5 samples/s | ETA: 101.5 min
  Checkpoint at 200/3000 | Rate: 0.5 samples/s | ETA: 100.3 min
  Checkpoint at 250/3000 | Rate: 0.5 samples/s | ETA: 98.8 min
  Checkpoint at 300/3000 | Rate: 0.5 samples/s | ETA: 96.8 min
  Checkpoint at 350/3000 | Rate: 0.5 samples/s | ETA: 95.2 min
  Checkpoint at 400/3000 | Rate: 0.5 samples/s | ETA: 93.4 min
  Checkpoint at 450/3000 | Rate: 0.5 samples/s | ETA: 91.6 min
  Checkpoint at 500/3000 | Rate: 0.5 samples/s | ETA: 89.8 min
  Checkpoint at 550/3000 | Rate: 0.5 samples/s | ETA: 88.1 min
  Checkpoint at 600/3000 | Rate: 0.5 samples/s | ETA: 86.4 min
  Checkpoint at 650/3000 | Rate: 0.5 samples/s | ETA: 84.7 min
  Checkpoint at 700/3000 | Rate: 0.5 samples/s | ETA: 82.9 min
  Checkpoint at 750/3000 | Rate: 0.5 samples/s | ETA: 81.1 min
  Checkpoint at 800/3000 | Rate: 0.5 samples/s | ETA
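
The truncate-then-shift bookkeeping used for Conditions 3 and 4 can be sketched with plain position indices. This is a toy illustration only: `truncate_keep_bos` and `shift_positions` are hypothetical stand-ins, not the notebook's `extract_and_truncate_cache_with_bos` or `correct_rope_positions_with_bos`, and the sketch assumes the correction moves every kept non-BOS entry back by `prefix_token_len - 1` positions.

```python
def truncate_keep_bos(positions, keep_n):
    """Keep position 0 (BOS) plus the last keep_n positions."""
    return [positions[0]] + positions[len(positions) - keep_n:]


def shift_positions(kept, shift):
    """Leave BOS in place and move every other kept position back by `shift`."""
    return [kept[0]] + [p - shift for p in kept[1:]]


# Toy layout: BOS + 4 prefix tokens + 6 document tokens + 3 suffix tokens.
prefix_token_len = 1 + 4              # includes BOS, as in the notebook
full_len = prefix_token_len + 6 + 3
keep_n = full_len - prefix_token_len  # doc + suffix entries to keep

positions = list(range(full_len))
kept = truncate_keep_bos(positions, keep_n)
corrected = shift_positions(kept, prefix_token_len - 1)

print(kept)       # original positions of the surviving entries
print(corrected)  # positions after the shift
```

In this toy layout the surviving entries end up at contiguous positions starting right after BOS, which is the property the cache-length check `combined_len = 1 + keep_n` relies on.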

In [11]:
# Cell 11: Cache length diagnostics
print("Cache Length Diagnostics")
print("=" * 60)

# --- Truncated match check ---
print("\n(a) TRUNCATED: All cache lengths match bare?")
bare_lens = [r['bare_len'] for r in results]
random_trunc_lens = [r['random_trunc_len'] for r in results]
doc_lens = [r['doc_len'] for r in results]

all_trunc_match = all(
    b == rt == 1 + d
    for b, rt, d in zip(bare_lens, random_trunc_lens, doc_lens)
)
print(f"  All truncated cache lengths match (bare == random-trunc == 1+doc): {all_trunc_match}")
if all_trunc_match:
    print("  PASS: Matched tokenization guarantees identical cache lengths.")
else:
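    # Count with the same three-way condition used for all_trunc_match above
    mismatches = sum(1 for b, rt, d in zip(bare_lens, random_trunc_lens, doc_lens)
                     if not (b == rt == 1 + d))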
    print(f"  WARNING: {mismatches} samples have mismatched truncated lengths!")

# --- Suffix length distribution ---
print("\n(b) SUFFIX: Cache length distribution")
random_sfx_lens = np.array([r['random_suffix_len'] for r in results])
sep_only_lens = np.array([r['separator_only_len'] for r in results])
combined_lens = np.array([r['combined_len'] for r in results])
bare_lens_arr = np.array(bare_lens)

random_sfx_extra = random_sfx_lens - bare_lens_arr
sep_only_extra = sep_only_lens - bare_lens_arr
combined_extra = combined_lens - bare_lens_arr

print(f"  Random-suffix extra tokens (vs bare):  mean={np.mean(random_sfx_extra):.1f}, min={np.min(random_sfx_extra)}, max={np.max(random_sfx_extra)}")
print(f"  Separator-only extra tokens (vs bare):  mean={np.mean(sep_only_extra):.1f}, min={np.min(sep_only_extra)}, max={np.max(sep_only_extra)}")
print(f"  Combined extra tokens (vs bare):  mean={np.mean(combined_extra):.1f}, min={np.min(combined_extra)}, max={np.max(combined_extra)}")

print(f"\n  Expected: combined_extra ~= suffix_extra (both have sep + suffix tokens)")
combined_vs_sfx_diff = combined_extra - random_sfx_extra
print(f"  Combined - Suffix extra diff: mean={np.mean(combined_vs_sfx_diff):.2f}, min={np.min(combined_vs_sfx_diff)}, max={np.max(combined_vs_sfx_diff)}")

Cache Length Diagnostics

(a) TRUNCATED: All cache lengths match bare?
  All truncated cache lengths match (bare == random-trunc == 1+doc): True
  PASS: Matched tokenization guarantees identical cache lengths.

(b) SUFFIX: Cache length distribution
  Random-suffix extra tokens (vs bare):  mean=13.7, min=7, max=29
  Separator-only extra tokens (vs bare):  mean=6.9, min=5, max=9
  Combined extra tokens (vs bare):  mean=13.9, min=8, max=29

  Expected: combined_extra ~= suffix_extra (both have sep + suffix tokens)
  Combined - Suffix extra diff: mean=0.12, min=-2, max=2


In [12]:
# Cell 12: Primary analysis — filter zeros, compute all 7 comparisons with Bonferroni
print("=" * 70)
print("PRIMARY ANALYSIS — 7 COMPARISONS WITH BONFERRONI CORRECTION")
print("=" * 70)

bare_nlls_raw = np.array([r['bare_nll'] for r in results])
random_suffix_nlls_raw = np.array([r['random_suffix_nll'] for r in results])
random_trunc_nlls_raw = np.array([r['random_trunc_nll'] for r in results])
combined_nlls_raw = np.array([r['combined_nll'] for r in results])
seponly_nlls_raw = np.array([r['separator_only_nll'] for r in results])

# Sanity checks
for name, arr in [('bare', bare_nlls_raw), ('random_suffix', random_suffix_nlls_raw),
                   ('random_trunc', random_trunc_nlls_raw), ('combined', combined_nlls_raw),
                   ('separator_only', seponly_nlls_raw)]:
    assert not np.any(np.isnan(arr)), f"NaN in {name} NLLs!"

# Filter out degenerate samples where ANY NLL is 0.0
valid_mask = (
    (bare_nlls_raw != 0.0) &
    (random_suffix_nlls_raw != 0.0) &
    (random_trunc_nlls_raw != 0.0) &
    (combined_nlls_raw != 0.0) &
    (seponly_nlls_raw != 0.0)
)
n_invalid = np.sum(~valid_mask)
n_valid = np.sum(valid_mask)
print(f"Total samples: {len(results)}")
print(f"Excluded (zero NLL from single-token answers): {n_invalid}")
print(f"Valid samples for analysis: {n_valid}")
print(f"Bonferroni-corrected alpha: {BONFERRONI_ALPHA:.4f} (0.05 / {N_COMPARISONS})\n")

bare = bare_nlls_raw[valid_mask]
random_suffix = random_suffix_nlls_raw[valid_mask]
random_trunc = random_trunc_nlls_raw[valid_mask]
combined = combined_nlls_raw[valid_mask]
sep_only = seponly_nlls_raw[valid_mask]

# NLL summary
print(f"{'Condition':<20} {'Mean NLL':>10} {'Std':>10} {'Median':>10}")
print("-" * 55)
for name, arr in [('Bare', bare), ('Random-suffix', random_suffix),
                   ('Random-truncated', random_trunc), ('Combined', combined),
                   ('Separator-only', sep_only)]:
    print(f"{name:<20} {np.mean(arr):>10.4f} {np.std(arr):>10.4f} {np.median(arr):>10.4f}")

# All 7 comparisons
# Convention: positive delta = first condition is BETTER (lower NLL)
comparisons = [
    # Primary — Combined mechanism
    ('P1: Combined vs Bare', bare - combined, 'Total combined effect?'),
    ('P2: Combined vs Random-suffix', random_suffix - combined, 'Does truncation add on top of suffix?'),
    ('P3: Combined vs Random-trunc', random_trunc - combined, 'Does suffix add on top of truncation?'),
    # Secondary — Separator decomposition + replication
    ('S1: Random-suffix vs Bare', bare - random_suffix, 'Replication (expect d~0.264)?'),
    ('S2: Random-trunc vs Bare', bare - random_trunc, 'Replication (expect d~0.091)?'),
    ('S3: Separator-only vs Bare', bare - sep_only, 'Separator framing alone?'),
    ('S4: Random-suffix vs Sep-only', sep_only - random_suffix, 'Random tokens beyond separator?'),
]

print(f"\n{'_' * 90}")
print("PAIRED COMPARISONS (positive delta = first condition has LOWER NLL = better)")
print(f"{'_' * 90}")
print(f"\n{'Comparison':<35} {'Mean D':>8} {'d':>8} {'Win%':>7} {'t':>8} {'p':>12} {'Sig':>5}")
print("-" * 88)

comparison_results = {}
for name, delta, question in comparisons:
    d = cohens_d(delta)
    win_rate = np.mean(delta > 0) * 100
    t_stat, p_val = stats.ttest_1samp(delta, 0)
    sig = "***" if p_val < 0.001 else "**" if p_val < BONFERRONI_ALPHA else "*" if p_val < 0.05 else "ns"
    bonf_sig = p_val < BONFERRONI_ALPHA
    print(f"{name:<35} {np.mean(delta):>8.4f} {d:>8.3f} {win_rate:>6.1f}% {t_stat:>8.2f} {p_val:>11.2e} {sig:>5}")
    comparison_results[name] = {
        'mean_delta': float(np.mean(delta)),
        'cohens_d': float(d),
        'win_rate': float(win_rate / 100),
        't_stat': float(t_stat),
        'p_value': float(p_val),
        'bonferroni_significant': bool(bonf_sig),
        'question': question,
    }

print(f"\nSignificance levels: *** p<0.001, ** p<{BONFERRONI_ALPHA:.4f} (Bonferroni), * p<0.05, ns = not significant")
print(f"Cohen's d interpretation: |d|<0.2 = negligible, 0.2-0.5 = small, 0.5-0.8 = medium, >0.8 = large")

PRIMARY ANALYSIS — 7 COMPARISONS WITH BONFERRONI CORRECTION
Total samples: 3000
Excluded (zero NLL from single-token answers): 248
Valid samples for analysis: 2752
Bonferroni-corrected alpha: 0.0071 (0.05 / 7)

Condition              Mean NLL        Std     Median
-------------------------------------------------------
Bare                     1.1258     1.5961     0.6097
Random-suffix            1.0344     1.5293     0.5359
Random-truncated         1.0753     1.5520     0.5795
Combined                 1.0453     1.5569     0.5357
Separator-only           1.0255     1.4248     0.5520

__________________________________________________________________________________________
PAIRED COMPARISONS (positive delta = first condition has LOWER NLL = better)
__________________________________________________________________________________________

Comparison                            Mean D        d    Win%        t            p   Sig
----------------------------------------------------------
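
Because all five conditions are averaged over the same valid samples, each paired mean delta equals the difference of the two condition means, so the mean deltas can be recomputed from the NLL summary table above. A minimal sketch, using the rounded means as printed (so the results are only accurate to about ±0.0001); the dictionary and variable names are illustrative and not from the notebook.

```python
# Condition means copied from the NLL summary table (rounded to 4 decimals).
mean_nll = {
    "bare": 1.1258,
    "random_suffix": 1.0344,
    "random_trunc": 1.0753,
    "combined": 1.0453,
    "separator_only": 1.0255,
}

# (first condition, reference condition); positive delta = first has lower NLL.
pairs = {
    "P1": ("combined", "bare"),
    "P2": ("combined", "random_suffix"),
    "P3": ("combined", "random_trunc"),
    "S1": ("random_suffix", "bare"),
    "S2": ("random_trunc", "bare"),
    "S3": ("separator_only", "bare"),
    "S4": ("random_suffix", "separator_only"),
}

mean_delta = {
    label: mean_nll[reference] - mean_nll[first]
    for label, (first, reference) in pairs.items()
}

for label, value in mean_delta.items():
    print(f"{label}: {value:+.4f}")
```

This recovers only the mean deltas; effect sizes, win rates, and p-values depend on the per-sample differences and cannot be derived from the summary means.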

In [13]:
# Cell 13: Summary table with verdicts
print("=" * 70)
print("SUMMARY TABLE WITH VERDICTS")
print("=" * 70)
print(f"\nN = {n_valid} valid samples (of {len(results)} total)")
print(f"Bonferroni-corrected alpha = {BONFERRONI_ALPHA:.4f}\n")

for name, delta, question in comparisons:
    cr = comparison_results[name]
    d = cr['cohens_d']
    p = cr['p_value']
    bonf = cr['bonferroni_significant']
    win = cr['win_rate'] * 100

    if bonf:
        if cr['mean_delta'] > 0:
            verdict = "YES — significant benefit (survives Bonferroni)"
        else:
            verdict = "YES — significant HARM (survives Bonferroni)"
    elif p < 0.05:
        if cr['mean_delta'] > 0:
            verdict = "Suggestive benefit (p<0.05 but fails Bonferroni)"
        else:
            verdict = "Suggestive harm (p<0.05 but fails Bonferroni)"
    else:
        verdict = "NO — not significant"

    print(f"Q: {question}")
    print(f"  {name}")
    print(f"  D = {cr['mean_delta']:+.4f}, d = {d:+.3f}, Win% = {win:.1f}%, p = {p:.2e}")
    print(f"  -> {verdict}")
    print()

SUMMARY TABLE WITH VERDICTS

N = 2752 valid samples (of 3000 total)
Bonferroni-corrected alpha = 0.0071

Q: Total combined effect?
  P1: Combined vs Bare
  D = +0.0806, d = +0.209, Win% = 63.6%, p = 2.31e-27
  -> YES — significant benefit (survives Bonferroni)

Q: Does truncation add on top of suffix?
  P2: Combined vs Random-suffix
  D = -0.0108, d = -0.047, Win% = 50.1%, p = 1.39e-02
  -> Suggestive harm (p<0.05 but fails Bonferroni)

Q: Does suffix add on top of truncation?
  P3: Combined vs Random-trunc
  D = +0.0300, d = +0.104, Win% = 57.7%, p = 5.24e-08
  -> YES — significant benefit (survives Bonferroni)

Q: Replication (expect d~0.264)?
  S1: Random-suffix vs Bare
  D = +0.0914, d = +0.270, Win% = 66.9%, p = 6.14e-44
  -> YES — significant benefit (survives Bonferroni)

Q: Replication (expect d~0.091)?
  S2: Random-trunc vs Bare
  D = +0.0505, d = +0.160, Win% = 61.7%, p = 7.13e-17
  -> YES — significant benefit (survives Bonferroni)

Q: Separator framing alone?
  S3: Separato

In [14]:
# Cell 14: Additivity test
# Test whether combined effect = suffix effect + truncation effect
# If sub-additive (expected given r=0.47), the remainder is negative

print("=" * 70)
print("ADDITIVITY TEST")
print("=" * 70)
print("\nH0: delta_combined = delta_suffix + delta_truncated")
print("(positive remainder = super-additive, negative = sub-additive)\n")

delta_combined = bare - combined
delta_suffix = bare - random_suffix
delta_trunc = bare - random_trunc

# Remainder: how much the combined effect exceeds the sum of individual effects
remainder = delta_combined - (delta_suffix + delta_trunc)

t_stat_add, p_val_add = stats.ttest_1samp(remainder, 0)
d_add = cohens_d(remainder)

print(f"Mean delta_combined:   {np.mean(delta_combined):+.4f}")
print(f"Mean delta_suffix:     {np.mean(delta_suffix):+.4f}")
print(f"Mean delta_truncated:  {np.mean(delta_trunc):+.4f}")
print(f"Sum of individuals:    {np.mean(delta_suffix) + np.mean(delta_trunc):+.4f}")
print(f"Remainder (combined - sum): {np.mean(remainder):+.4f}")
print(f"\nCohen's d of remainder: {d_add:+.3f}")
print(f"t = {t_stat_add:.2f}, p = {p_val_add:.2e}")

if p_val_add < 0.05:
    if np.mean(remainder) > 0:
        print("\n-> SUPER-ADDITIVE: Combined effect exceeds sum of parts.")
    else:
        print("\n-> SUB-ADDITIVE: Combined effect is less than sum of parts.")
        print("   (Expected given r=0.47 cross-mechanism correlation from Exp 02)")
else:
    print("\n-> ADDITIVE: Combined effect ≈ sum of parts (within noise).")

# Cross-mechanism correlation in this dataset
r_cross, p_cross = stats.pearsonr(delta_suffix, delta_trunc)
print(f"\nCross-mechanism correlation: r = {r_cross:.3f}, p = {p_cross:.2e}")
print(f"  (Exp 02 found r = 0.47)")

ADDITIVITY TEST

H0: delta_combined = delta_suffix + delta_truncated
(positive remainder = super-additive, negative = sub-additive)

Mean delta_combined:   +0.0806
Mean delta_suffix:     +0.0914
Mean delta_truncated:  +0.0505
Sum of individuals:    +0.1419
Remainder (combined - sum): -0.0613

Cohen's d of remainder: -0.223
t = -11.72, p = 5.34e-31

-> SUB-ADDITIVE: Combined effect is less than sum of parts.
   (Expected given r=0.47 cross-mechanism correlation from Exp 02)

Cross-mechanism correlation: r = 0.410, p = 3.82e-112
  (Exp 02 found r = 0.47)
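
As a rough consistency check on the additivity output, the remainder and its t statistic can be recomputed from the printed summary values. This sketch assumes `cohens_d` is the one-sample form (mean divided by the sample standard deviation), in which case t = d * sqrt(n); that definition is not shown in this section, so treat it as an assumption.

```python
import math

# Values copied from the additivity output above (rounded as printed).
n_valid = 2752
delta_combined = 0.0806
delta_suffix = 0.0914
delta_trunc = 0.0505
d_remainder = -0.223

# Remainder of the means: combined effect minus the sum of the individual effects.
remainder = delta_combined - (delta_suffix + delta_trunc)

# Under the assumed definition of Cohen's d, t = d * sqrt(n).
t_approx = d_remainder * math.sqrt(n_valid)

print(f"remainder = {remainder:+.4f}")
print(f"t ~= {t_approx:.2f}")
```

The recomputed t lands close to the reported -11.72; the small gap is what one would expect from d being rounded to three decimals.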


In [15]:
# Cell 15: Length scaling — regression + quintile breakdown
import statsmodels.api as sm

print("=" * 70)
print("LENGTH SCALING ANALYSIS")
print("=" * 70)

# Get per-sample data (aligned with valid_mask)
doc_lens_valid = np.array([r['doc_len'] for r in results])[valid_mask]
word_counts_valid = np.array([r['passage_word_count'] for r in results])[valid_mask]

# ========== Continuous regression: delta_suffix ~ doc_token_len ==========
print("\n--- (a) Regression: delta_suffix ~ doc_token_len ---")
X = sm.add_constant(doc_lens_valid)
model_ols = sm.OLS(delta_suffix, X).fit()
print(f"  beta_0 (intercept) = {model_ols.params[0]:.4f} (p = {model_ols.pvalues[0]:.2e})")
print(f"  beta_1 (doc_len)   = {model_ols.params[1]:.6f} (p = {model_ols.pvalues[1]:.2e})")
print(f"  R^2 = {model_ols.rsquared:.4f}")
if model_ols.pvalues[1] < 0.05:
    direction = "DECREASING" if model_ols.params[1] < 0 else "INCREASING"
    print(f"  -> Significant: suffix benefit {direction} with passage length")
else:
    print(f"  -> Not significant: no length dependence detected")

# ========== With hardness covariate ==========
print("\n--- (b) Regression: delta_suffix ~ doc_len + bare_nll ---")
X2 = sm.add_constant(np.column_stack([doc_lens_valid, bare]))
model_ols2 = sm.OLS(delta_suffix, X2).fit()
print(f"  beta_0 (intercept) = {model_ols2.params[0]:.4f} (p = {model_ols2.pvalues[0]:.2e})")
print(f"  beta_1 (doc_len)   = {model_ols2.params[1]:.6f} (p = {model_ols2.pvalues[1]:.2e})")
print(f"  beta_2 (bare_nll)  = {model_ols2.params[2]:.4f} (p = {model_ols2.pvalues[2]:.2e})")
print(f"  R^2 = {model_ols2.rsquared:.4f}")

# ========== Suffix fraction regression ==========
print("\n--- (c) Regression: delta_suffix ~ suffix_fraction ---")
random_sfx_lens_valid = np.array([r['random_suffix_len'] for r in results])[valid_mask]
suffix_fraction = (random_sfx_lens_valid - np.array(bare_lens)[valid_mask]) / random_sfx_lens_valid
X3 = sm.add_constant(suffix_fraction)
model_ols3 = sm.OLS(delta_suffix, X3).fit()
print(f"  beta_0 (intercept)        = {model_ols3.params[0]:.4f} (p = {model_ols3.pvalues[0]:.2e})")
print(f"  beta_1 (suffix_fraction)  = {model_ols3.params[1]:.4f} (p = {model_ols3.pvalues[1]:.2e})")
print(f"  R^2 = {model_ols3.rsquared:.4f}")

# ========== Quintile breakdown (by doc token length) ==========
print("\n--- (d) Quintile breakdown by passage length ---")
quintile_edges = np.percentile(doc_lens_valid, [0, 20, 40, 60, 80, 100])
q_labels = ['Q1 (short)', 'Q2', 'Q3', 'Q4', 'Q5 (long)']

# Header
print(f"\n{'Bin':<14} {'N':>5} {'Tokens':>12} {'Words':>12}", end='')
for cname in ['Suffix', 'Trunc', 'Combined', 'Sep-only']:
    print(f" {'d_'+cname:>10} {'Win%':>6}", end='')
print()
print("-" * 120)

quintile_results = {}
for qi in range(5):
    lo = quintile_edges[qi]
    hi = quintile_edges[qi + 1]
    if qi == 4:  # last bin inclusive
        mask = (doc_lens_valid >= lo) & (doc_lens_valid <= hi)
    else:
        mask = (doc_lens_valid >= lo) & (doc_lens_valid < hi)
    n_q = np.sum(mask)
    mean_tokens = np.mean(doc_lens_valid[mask])
    mean_words = np.mean(word_counts_valid[mask])

    row = f"{q_labels[qi]:<14} {n_q:>5} {mean_tokens:>8.0f}tok {mean_words:>8.0f}wds"
    
    q_data = {}
    for cname, delta_arr in [('Suffix', delta_suffix), ('Trunc', delta_trunc),
                              ('Combined', delta_combined), ('Sep-only', bare - sep_only)]:
        dq = delta_arr[mask]
        d_q = cohens_d(dq)
        win_q = np.mean(dq > 0) * 100
        row += f" {d_q:>10.3f} {win_q:>5.1f}%"
        q_data[cname] = {'d': float(d_q), 'win': float(win_q / 100), 'n': int(n_q),
                         'mean_delta': float(np.mean(dq))}
    
    print(row)
    quintile_results[q_labels[qi]] = q_data

# Monotonicity: is suffix d decreasing across quintiles?
suffix_ds = [quintile_results[ql]['Suffix']['d'] for ql in q_labels]
print(f"\nSuffix d across quintiles: {[f'{d:.3f}' for d in suffix_ds]}")
if all(suffix_ds[i] >= suffix_ds[i+1] for i in range(4)):
    print("-> MONOTONICALLY DECREASING (consistent with length dilution hypothesis)")
else:
    print("-> NOT monotonically decreasing")

LENGTH SCALING ANALYSIS

--- (a) Regression: delta_suffix ~ doc_token_len ---
  beta_0 (intercept) = 0.1965 (p = 4.19e-25)
  beta_1 (doc_len)   = -0.000949 (p = 3.02e-09)
  R^2 = 0.0127
  -> Significant: suffix benefit DECREASING with passage length

--- (b) Regression: delta_suffix ~ doc_len + bare_nll ---
  beta_0 (intercept) = 0.1094 (p = 6.10e-09)
  beta_1 (doc_len)   = -0.000796 (p = 2.03e-07)
  beta_2 (bare_nll)  = 0.0623 (p = 2.49e-56)
  R^2 = 0.0986

--- (c) Regression: delta_suffix ~ suffix_fraction ---
  beta_0 (intercept)        = -0.0105 (p = 5.67e-01)
  beta_1 (suffix_fraction)  = 0.8453 (p = 3.36e-09)
  R^2 = 0.0126

--- (d) Quintile breakdown by passage length ---

Bin                N       Tokens        Words   d_Suffix   Win%    d_Trunc   Win% d_Combined   Win% d_Sep-only   Win%
------------------------------------------------------------------------------------------------------------------------
Q1 (short)       539       59tok       41wds      0.369  75.3%      0.2
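
To make the regression (a) coefficients concrete, the fitted line can be evaluated at a given passage length and solved for the length at which the predicted suffix benefit reaches zero. This is only a sketch: it uses the rounded coefficients printed above, and with R^2 = 0.0127 the line explains little of the per-sample variance, so the crossing point is a rough extrapolation rather than a measured threshold. `predicted_benefit` is a hypothetical helper.

```python
# Coefficients from regression (a): delta_suffix ~ doc_token_len (rounded as printed).
beta_0 = 0.1965      # intercept
beta_1 = -0.000949   # change in NLL benefit per document token


def predicted_benefit(doc_len):
    """Predicted suffix benefit (bare NLL minus random-suffix NLL) under the linear fit."""
    return beta_0 + beta_1 * doc_len


short_len = 59  # mean token count of the shortest quintile in the table above
zero_crossing = -beta_0 / beta_1  # doc length where the fitted line hits zero

print(f"predicted benefit at {short_len} tokens: {predicted_benefit(short_len):.4f}")
print(f"fitted line crosses zero near {zero_crossing:.0f} tokens")
```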

In [16]:
# Cell 16: Separator decomposition by length
print("=" * 70)
print("SEPARATOR DECOMPOSITION BY LENGTH")
print("=" * 70)
print("\nTwo competing hypotheses:")
print("  A (Separator framing): sep-only captures most benefit, constant with length")
print("  B (Random token regularization): sep-only ~0, suffix decreases with length")

delta_seponly = bare - sep_only
delta_suffix_vs_sep = sep_only - random_suffix  # positive = suffix better than sep-only

# Regression: separator-only benefit vs length
print("\n--- Regression: delta_sep_only ~ doc_len ---")
X_sep = sm.add_constant(doc_lens_valid)
model_sep = sm.OLS(delta_seponly, X_sep).fit()
print(f"  beta_1 (doc_len) = {model_sep.params[1]:.6f} (p = {model_sep.pvalues[1]:.2e})")
print(f"  R^2 = {model_sep.rsquared:.4f}")

# Regression: random tokens beyond separator vs length
print("\n--- Regression: delta_suffix_vs_sep ~ doc_len ---")
model_sfx_sep = sm.OLS(delta_suffix_vs_sep, X_sep).fit()
print(f"  beta_1 (doc_len) = {model_sfx_sep.params[1]:.6f} (p = {model_sfx_sep.pvalues[1]:.2e})")
print(f"  R^2 = {model_sfx_sep.rsquared:.4f}")

# Quintile breakdown
print("\n--- Quintile breakdown ---")
print(f"{'Bin':<14} {'N':>5} {'d_SepOnly':>10} {'Win%':>7} {'d_Sfx-Sep':>10} {'Win%':>7}")
print("-" * 60)

for qi in range(5):
    lo = quintile_edges[qi]
    hi = quintile_edges[qi + 1]
    if qi == 4:
        mask = (doc_lens_valid >= lo) & (doc_lens_valid <= hi)
    else:
        mask = (doc_lens_valid >= lo) & (doc_lens_valid < hi)
    n_q = np.sum(mask)
    
    d_sep = cohens_d(delta_seponly[mask])
    win_sep = np.mean(delta_seponly[mask] > 0) * 100
    d_sfx_sep = cohens_d(delta_suffix_vs_sep[mask])
    win_sfx_sep = np.mean(delta_suffix_vs_sep[mask] > 0) * 100
    
    print(f"{q_labels[qi]:<14} {n_q:>5} {d_sep:>10.3f} {win_sep:>6.1f}% {d_sfx_sep:>10.3f} {win_sfx_sep:>6.1f}%")

# Verdict
cr_s3 = comparison_results.get('S3: Separator-only vs Bare', {})
cr_s4 = comparison_results.get('S4: Random-suffix vs Sep-only', {})
s3_sig = cr_s3.get('bonferroni_significant', False)
s4_sig = cr_s4.get('bonferroni_significant', False)

print(f"\nInterpretation matrix:")
print(f"  S3 significant (separator alone)? {s3_sig}")
print(f"  S4 significant (random beyond sep)? {s4_sig}")
if s3_sig and not s4_sig:
    print("  -> Hypothesis A: SEPARATOR FRAMING explains the effect")
elif not s3_sig and s4_sig:
    print("  -> Hypothesis B: RANDOM TOKEN REGULARIZATION is the mechanism")
elif s3_sig and s4_sig:
    print("  -> BOTH contribute: separator framing AND random tokens")
else:
    print("  -> Neither alone significant: full suffix needed (interaction effect)")

SEPARATOR DECOMPOSITION BY LENGTH

Two competing hypotheses:
  A (Separator framing): sep-only captures most benefit, constant with length
  B (Random token regularization): sep-only ~0, suffix decreases with length

--- Regression: delta_sep_only ~ doc_len ---
  beta_1 (doc_len) = -0.001070 (p = 4.27e-07)
  R^2 = 0.0093

--- Regression: delta_suffix_vs_sep ~ doc_len ---
  beta_1 (doc_len) = 0.000121 (p = 5.66e-01)
  R^2 = 0.0001

--- Quintile breakdown ---
Bin                N  d_SepOnly    Win%  d_Sfx-Sep    Win%
------------------------------------------------------------
Q1 (short)       539      0.303   69.0%     -0.012   46.6%
Q2               542      0.299   65.9%     -0.051   48.5%
Q3               550      0.188   66.4%     -0.016   45.8%
Q4               562      0.189   63.9%      0.006   49.3%
Q5 (long)        559      0.118   66.2%     -0.017   45.1%

Interpretation matrix:
  S3 significant (separator alone)? True
  S4 significant (random beyond sep)? False
  -> Hypothesi

In [17]:
# Cell 17: Hardness interaction (all 4 non-bare conditions)
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend
import matplotlib.pyplot as plt

print("=" * 70)
print("HARDNESS INTERACTION ANALYSIS")
print("=" * 70)

conditions_hardness = [
    ('Random-suffix', delta_suffix),
    ('Random-truncated', delta_trunc),
    ('Combined', delta_combined),
    ('Separator-only', delta_seponly),
]

print(f"\nPearson r of bare NLL (hardness) with condition benefit:")
print(f"{'Condition':<25} {'r':>8} {'p':>12}")
print("-" * 48)

hardness_results = {}
for name, delta in conditions_hardness:
    r, p = stats.pearsonr(bare, delta)
    print(f"{name:<25} {r:>8.4f} {p:>12.2e}")
    hardness_results[name] = {'r': float(r), 'p': float(p)}

# Quartile breakdown
quartiles = np.percentile(bare, [25, 50, 75])
q_labels_h = ['Q1 (easy)', 'Q2', 'Q3', 'Q4 (hard)']
q_bounds = [(-np.inf, quartiles[0]), (quartiles[0], quartiles[1]),
            (quartiles[1], quartiles[2]), (quartiles[2], np.inf)]

print(f"\nQuartile breakdown (by bare NLL hardness):")
header = f"{'Quartile':<12} {'N':>5} {'Bare':>8}"
for name, _ in conditions_hardness:
    short = name[:8]
    header += f" {'D_'+short:>10} {'d_'+short:>8}"
print(header)
print("-" * len(header))

for label, (lo, hi) in zip(q_labels_h, q_bounds):
    mask = (bare > lo) & (bare <= hi)
    n_q = np.sum(mask)
    mean_bare = np.mean(bare[mask])
    row = f"{label:<12} {n_q:>5} {mean_bare:>8.3f}"
    for name, delta in conditions_hardness:
        dq = delta[mask]
        row += f" {np.mean(dq):>10.4f} {cohens_d(dq):>8.3f}"
    print(row)

HARDNESS INTERACTION ANALYSIS

Pearson r of bare NLL (hardness) with condition benefit:
Condition                        r            p
------------------------------------------------
Random-suffix               0.2995     3.97e-58
Random-truncated            0.2369     2.11e-36
Combined                    0.2213     7.27e-32
Separator-only              0.5026    4.38e-176

Quartile breakdown (by bare NLL hardness):
Quartile         N     Bare D_Random-s d_Random-s D_Random-t d_Random-t D_Combined d_Combined D_Separato d_Separato
-------------------------------------------------------------------------------------------------------------------
Q1 (easy)      688    0.079     0.0002    0.003     0.0055    0.099     0.0010    0.016    -0.0369   -0.258
Q2             688    0.408     0.0397    0.321     0.0257    0.202     0.0384    0.240     0.0004    0.002
Q3             688    0.941     0.0978    0.477     0.0518    0.267     0.0905    0.364     0.0840    0.239
Q4 (hard)      688    3

In [18]:
# Cell 18: Plots — delta distributions, length scaling, hardness scatter

# === Plot 1: Delta distributions (all 7 comparisons) ===
fig, axes = plt.subplots(2, 4, figsize=(22, 10))
axes_flat = axes.flat

plot_configs = [
    ('P1: Combined vs Bare', bare - combined, 'steelblue'),
    ('P2: Combined vs Suffix', random_suffix - combined, 'forestgreen'),
    ('P3: Combined vs Trunc', random_trunc - combined, 'darkorange'),
    ('S1: Suffix vs Bare', bare - random_suffix, 'mediumpurple'),
    ('S2: Trunc vs Bare', bare - random_trunc, 'teal'),
    ('S3: Sep-only vs Bare', bare - sep_only, 'crimson'),
    ('S4: Suffix vs Sep-only', sep_only - random_suffix, 'goldenrod'),
]

for i, (title, delta, color) in enumerate(plot_configs):
    ax = axes_flat[i]
    # plot_configs is listed in the same order as `comparisons`, so index i selects the matching stats
    cr = comparison_results[comparisons[i][0]]
    ax.hist(delta, bins=80, color=color, alpha=0.7, edgecolor='black', linewidth=0.3)
    ax.axvline(x=0, color='red', linestyle='--', alpha=0.7)
    ax.axvline(x=np.mean(delta), color='black', linestyle='-', alpha=0.8,
               label=f'mean={np.mean(delta):.4f}, d={cr["cohens_d"]:+.3f}')
    ax.set_xlabel('Delta NLL')
    ax.set_ylabel('Count')
    ax.set_title(title, fontsize=10)
    ax.legend(fontsize=7)

# Hide last subplot
axes_flat[7].set_visible(False)

plt.suptitle('Delta NLL Distributions — All 7 Comparisons', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig(RESULTS_DIR / 'delta_distributions.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"Plot saved to {RESULTS_DIR / 'delta_distributions.png'}")

# === Plot 2: Length scaling ===
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# (a) Suffix benefit vs doc length
axes[0].scatter(doc_lens_valid, delta_suffix, alpha=0.1, s=5, c='mediumpurple')
axes[0].axhline(y=0, color='red', linestyle='--', alpha=0.5)
z = np.polyfit(doc_lens_valid, delta_suffix, 1)
x_range = np.linspace(doc_lens_valid.min(), doc_lens_valid.max(), 100)
axes[0].plot(x_range, np.polyval(z, x_range), 'r-', alpha=0.8,
             label=f'beta={z[0]:.5f}, p={model_ols.pvalues[1]:.2e}')
axes[0].set_xlabel('Document token length')
axes[0].set_ylabel('Suffix benefit (+ = helps)')
axes[0].set_title('Suffix Benefit vs Passage Length')
axes[0].legend()

# (b) Separator-only benefit vs doc length
axes[1].scatter(doc_lens_valid, delta_seponly, alpha=0.1, s=5, c='crimson')
axes[1].axhline(y=0, color='red', linestyle='--', alpha=0.5)
z2 = np.polyfit(doc_lens_valid, delta_seponly, 1)
axes[1].plot(x_range, np.polyval(z2, x_range), 'r-', alpha=0.8,
             label=f'beta={z2[0]:.5f}')
axes[1].set_xlabel('Document token length')
axes[1].set_ylabel('Sep-only benefit (+ = helps)')
axes[1].set_title('Separator-Only Benefit vs Passage Length')
axes[1].legend()

# (c) Quintile d values for all conditions
x_pos = np.arange(5)
width = 0.2
colors = ['mediumpurple', 'teal', 'steelblue', 'crimson']  # one color per condition, in loop order
for ci, cname in enumerate(['Suffix', 'Trunc', 'Combined', 'Sep-only']):
    ds = [quintile_results[ql][cname]['d'] for ql in q_labels]
    axes[2].bar(x_pos + ci * width, ds, width, label=cname, alpha=0.8, color=colors[ci])
axes[2].set_xticks(x_pos + 1.5 * width)
axes[2].set_xticklabels(q_labels, rotation=15)
axes[2].axhline(y=0, color='gray', linestyle='--', alpha=0.5)
axes[2].set_ylabel("Cohen's d")
axes[2].set_title('Effect Size by Length Quintile')
axes[2].legend(fontsize=8)

plt.tight_layout()
plt.savefig(RESULTS_DIR / 'length_scaling.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"Plot saved to {RESULTS_DIR / 'length_scaling.png'}")

# === Plot 3: Hardness scatter (2x2 grid) ===
fig, axes = plt.subplots(2, 2, figsize=(14, 12))

scatter_configs = [
    ('Random-suffix', delta_suffix, 'mediumpurple'),
    ('Random-truncated', delta_trunc, 'teal'),
    ('Combined', delta_combined, 'steelblue'),
    ('Separator-only', delta_seponly, 'crimson'),
]

for ax, (name, delta, color) in zip(axes.flat, scatter_configs):
    r, p = hardness_results[name]['r'], hardness_results[name]['p']
    ax.scatter(bare, delta, alpha=0.12, s=5, c=color)
    ax.axhline(y=0, color='red', linestyle='--', alpha=0.5)
    z = np.polyfit(bare, delta, 1)
    x_range = np.linspace(bare.min(), bare.max(), 100)
    ax.plot(x_range, np.polyval(z, x_range), 'r-', alpha=0.8, label=f'r={r:.3f}, p={p:.1e}')
    ax.set_xlabel('Bare NLL (hardness)')
    ax.set_ylabel(f'{name} benefit (+ = helps)')
    ax.set_title(f'Hardness vs {name} Benefit')
    ax.legend()

plt.suptitle('Hardness Interaction — All 4 Non-Bare Conditions', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig(RESULTS_DIR / 'hardness_interaction.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"Plot saved to {RESULTS_DIR / 'hardness_interaction.png'}")

Plot saved to results/exp03/delta_distributions.png
Plot saved to results/exp03/length_scaling.png
Plot saved to results/exp03/hardness_interaction.png
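
The histogram legends above report a mean delta and a Cohen's d per comparison. The sketch below shows one common way such a paired effect size could be computed: the mean of the per-sample deltas divided by their sample standard deviation. The helper name `paired_delta_summary` and the `ddof=1` choice are assumptions for illustration; the notebook's own `cohens_d` is computed in an earlier cell and may be defined differently.

```python
import numpy as np

def paired_delta_summary(reference_nll, treatment_nll):
    """Return (mean_delta, cohens_d) for paired per-sample NLLs.

    Hypothetical helper. delta = reference - treatment, so positive values
    mean the treatment condition reached the lower NLL.
    """
    reference_nll = np.asarray(reference_nll, dtype=float)
    treatment_nll = np.asarray(treatment_nll, dtype=float)
    delta = reference_nll - treatment_nll
    sd = delta.std(ddof=1)  # sample standard deviation (assumed convention)
    d = delta.mean() / sd if sd > 0 else float('nan')
    return float(delta.mean()), float(d)

# Toy data, not taken from the experiment.
toy_bare = [2.0, 2.0, 2.0, 2.0]
toy_condition = [1.0, 2.0, 1.0, 2.0]
toy_mean, toy_d = paired_delta_summary(toy_bare, toy_condition)
print(f"mean delta={toy_mean:.3f}, d={toy_d:+.3f}")
```

With this convention the sign of d matches the plots: a histogram shifted to the right of zero corresponds to a positive d.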


In [19]:
# Cell 19: Save results JSON

final_results = {
    'experiment': 'exp03_combined_and_length',
    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    'config': {
        'model_name': config.model_name,
        'num_samples': config.num_samples,
        'seed': SEED,
        'min_passage_words': config.min_passage_words,
        'max_passage_words': config.max_passage_words,
        'surrogate_prefix_template': SURROGATE_PREFIX_TEMPLATE,
        'document_template': DOCUMENT_TEMPLATE,
        'query_template': QUERY_TEMPLATE,
        'answer_template': ANSWER_TEMPLATE,
        'suffix_separator': SUFFIX_SEPARATOR,
        'n_comparisons': N_COMPARISONS,
        'bonferroni_alpha': BONFERRONI_ALPHA,
    },
    'summary': {
        'n_total': len(results),
        'n_valid': int(n_valid),
        'n_excluded_zero_nll': int(n_invalid),
        'nll_means': {
            'bare': float(np.mean(bare)),
            'random_suffix': float(np.mean(random_suffix)),
            'random_truncated': float(np.mean(random_trunc)),
            'combined': float(np.mean(combined)),
            'separator_only': float(np.mean(sep_only)),
        },
        'nll_stds': {
            'bare': float(np.std(bare)),
            'random_suffix': float(np.std(random_suffix)),
            'random_truncated': float(np.std(random_trunc)),
            'combined': float(np.std(combined)),
            'separator_only': float(np.std(sep_only)),
        },
        'comparisons': comparison_results,
        'additivity_test': {
            'mean_remainder': float(np.mean(remainder)),
            'cohens_d': float(d_add),
            't_stat': float(t_stat_add),
            'p_value': float(p_val_add),
            'cross_mechanism_r': float(r_cross),
            'cross_mechanism_p': float(p_cross),
        },
        'hardness_interaction': hardness_results,
        'length_scaling': {
            'suffix_vs_doclen': {
                'beta_1': float(model_ols.params[1]),
                'p_value': float(model_ols.pvalues[1]),
                'r_squared': float(model_ols.rsquared),
            },
            'suffix_vs_doclen_with_hardness': {
                'beta_1_doclen': float(model_ols2.params[1]),
                'p_doclen': float(model_ols2.pvalues[1]),
                'beta_2_bare_nll': float(model_ols2.params[2]),
                'p_bare_nll': float(model_ols2.pvalues[2]),
                'r_squared': float(model_ols2.rsquared),
            },
            'suffix_vs_fraction': {
                'beta_1': float(model_ols3.params[1]),
                'p_value': float(model_ols3.pvalues[1]),
                'r_squared': float(model_ols3.rsquared),
            },
            'quintiles': quintile_results,
        },
    },
    'per_sample_results': results,
}

with open(FINAL_RESULTS_PATH, 'w') as f:
    json.dump(final_results, f, indent=2)

print(f"Final results saved to {FINAL_RESULTS_PATH}")
print(f"File size: {FINAL_RESULTS_PATH.stat().st_size / 1024:.1f} KB")
print(f"Total samples: {len(results)}")
print(f"Valid samples: {n_valid}")
print("\nDone!")

Final results saved to results/exp03/results.json
File size: 2309.5 KB
Total samples: 3000
Valid samples: 2752

Done!
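
Cell 19 casts each summary statistic with `float(...)` or `int(...)` before writing, because `json.dump` rejects some NumPy types such as `np.int64` and arrays. A minimal sketch of an alternative, assuming one prefers a single conversion hook over per-field casts: pass a `default=` function to `json.dump`/`json.dumps`. The name `to_jsonable` and the toy payload are hypothetical and not part of the notebook.

```python
import json
import numpy as np

def to_jsonable(obj):
    """`default=` hook for json: convert NumPy scalars and arrays to built-ins."""
    if isinstance(obj, np.generic):   # np.int64, np.float32, np.bool_, ...
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # Anything else is genuinely unsupported; keep json's usual error.
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

# Toy payload, not taken from the experiment.
toy_payload = {
    'n_valid': np.int64(3),
    'mean_delta': np.float64(0.5),
    'deltas': np.array([0.25, -0.5]),
}
encoded = json.dumps(toy_payload, default=to_jsonable)
decoded = json.loads(encoded)
print(decoded)
```

The hook is only called for objects the encoder cannot handle on its own, so plain Python values pass through unchanged.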
