# Experiment 01: First-Principles Surrogate Priming Test

**Goal**: Determine whether surrogate priming has an effect on NLL scoring, and whether the effect is structural (any prefix) or semantic (relevant prefix).

## Three Conditions

| Condition | How cache is built | What it tests |
|-----------|-------------------|---------------|
| **Bare** | `[BOS] + doc_ids` (no prefix) | Baseline — no prefix at all |
| **Random prefix** | `[BOS][random_tokens][doc_ids]` → truncate + RoPE correct | Does *any* prefix alter values in a way that affects scoring? |
| **Oracle prefix** | `[BOS][oracle_query][doc_ids]` → truncate + RoPE correct | Does *semantically relevant* prefix content matter? |

## Key Comparisons
- **Bare vs Random**: Does the truncation/RoPE-correction process itself (with arbitrary content) change NLL?
- **Bare vs Oracle**: Does the actual query as prefix help?
- **Random vs Oracle**: Is there a semantic signal beyond structural noise?
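
The "truncate + RoPE correct" step used by the random and oracle conditions can be illustrated with a toy example. This is a minimal sketch, not the code in `lib.kv_cache` (which is not shown in this notebook): it assumes a single rotary frequency `theta` and a 2-D key, and the helper names `rotate`, `rope_key`, and `correct_position` are hypothetical. It shows only that rotating a cached key back by the prefix length reproduces the rotation the key would have had without the prefix.

```python
import math

def rotate(pair, angle):
    """Rotate a 2-D (x, y) pair counter-clockwise by `angle` radians."""
    x, y = pair
    c, s = math.cos(angle), math.sin(angle)
    return (x * c - y * s, x * s + y * c)

def rope_key(raw_key, position, theta=0.1):
    """Apply a single-frequency rotary rotation for the given position (toy model)."""
    return rotate(raw_key, position * theta)

def correct_position(cached_key, offset, theta=0.1):
    """Re-rotate a cached key as if it had been computed `offset` positions earlier."""
    return rotate(cached_key, -offset * theta)

raw = (0.3, -1.2)   # hypothetical pre-rotation key for one document token
prefix_len = 10     # number of prefix tokens that will be truncated away
doc_index = 4       # index of the token within the document

# Position with the prefix present: BOS + prefix + doc_index
key_with_prefix = rope_key(raw, position=1 + prefix_len + doc_index)
# Position in the bare condition: BOS + doc_index
key_bare = rope_key(raw, position=1 + doc_index)

corrected = correct_position(key_with_prefix, offset=prefix_len)
print(corrected, key_bare)
```

The correction only undoes the positional rotation. Keys and values in layers above the first still carry whatever the document tokens absorbed from the prefix through attention, which is the effect the random and oracle conditions are designed to measure.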

## Critical Design Decisions (each addresses a bug seen in earlier runs)
1. **No template framing**: `surrogate_prefix_template="{surrogate}\n"`, `document_template="{document}"` — avoids "Document:\n" artifact
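2. **Matched tokenization**: Doc token IDs are extracted from the oracle concatenated tokenization and reused across all 3 conditions. This eliminates BPE boundary mismatch entirely: with `\n` as the only separator, independent tokenization of the passage gave 0% clean boundaries (100 of 100 mismatches in the Cell 6 diagnostic)
3. **`deepcopy_cache()` before every `score_answer_with_cache()` call**: Prevents the cache-mutation bug, in which a cache modified by one scoring call is reused by a later one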
4. **Random tokens from vocabulary**: Deterministic seed, length-matched to oracle, with decode→re-encode verification
5. **`np.random.seed(SEED)` immediately before `load_evaluation_samples()`**: Deterministic sample selection
6. **Checkpoint every 50 samples**: With full sample list saved for resume correctness
7. **Query format**: `"\nQuery: {query}\nAnswer:"` with answer `" {answer}"` (leading space for correct BPE)
8. **GPU memory management**: `del` caches + `torch.cuda.empty_cache()` after each sample
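
Decision 3 can be illustrated with a toy stand-in. The sketch below assumes that scoring extends the cache it receives in place; `toy_score` is a hypothetical function and plain lists stand in for the key/value tensors, so this is not how `score_answer_with_cache()` or `deepcopy_cache()` are implemented.

```python
import copy

def toy_score(cache, continuation):
    """Toy scorer: extends the cache it receives in place and returns the new length."""
    cache.extend(continuation)
    return len(cache)

doc_cache = ["BOS", "doc_0", "doc_1"]
continuation = ["query_0", "answer_0"]

# Unsafe: the same object is passed twice, so the second call sees leftovers
# from the first call.
shared = list(doc_cache)
first_unsafe = toy_score(shared, continuation)
second_unsafe = toy_score(shared, continuation)

# Safe: each call receives its own deep copy, so doc_cache is never modified.
first_safe = toy_score(copy.deepcopy(doc_cache), continuation)
second_safe = toy_score(copy.deepcopy(doc_cache), continuation)

print(first_unsafe, second_unsafe, first_safe, second_safe, len(doc_cache))
```

In the unsafe path the second score is computed over a longer context than the first, which is the kind of silent contamination the copy is meant to rule out.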

In [1]:
# Cell 1: Setup — permissions, seeds, output directory
import os
os.umask(0o000)  # Required: two-user environment (jupyter + CLI user)

import sys
import json
import time
import numpy as np
import torch
from pathlib import Path

# Reproducibility
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Output directory
RESULTS_DIR = Path("results/exp01")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# Paths
CHECKPOINT_PATH = RESULTS_DIR / "checkpoint.json"
FINAL_RESULTS_PATH = RESULTS_DIR / "results.json"

print(f"SEED: {SEED}")
print(f"Results directory: {RESULTS_DIR}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

SEED: 42
Results directory: results/exp01
CUDA available: True
GPU: NVIDIA L4
GPU memory: 23.6 GB


In [2]:
# Cell 2: Load model (Mistral-7B 4-bit)
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

print(f"Loading {MODEL_NAME} (4-bit)...")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.eval()

print(f"Model loaded. dtype={model.dtype}, device={model.device}")
print(f"Vocab size: {len(tokenizer)}")

Loading mistralai/Mistral-7B-Instruct-v0.2 (4-bit)...


`torch_dtype` is deprecated! Use `dtype` instead!



Model loaded. dtype=torch.float16, device=cuda:0
Vocab size: 32000


In [3]:
# Cell 3: Library imports + config
sys.path.insert(0, ".")

from lib.config import ExperimentConfig
from lib.kv_cache import (
    build_kv_cache,
    build_truncated_kv_cache_corrected,
    score_answer_with_cache,
    deepcopy_cache,
    extract_and_truncate_cache_with_bos,
    correct_rope_positions_with_bos,
)
from lib.data import load_ms_marco, load_evaluation_samples
from lib.analysis import cohens_d
from scipy import stats
from tqdm.auto import tqdm

config = ExperimentConfig(
    model_name=MODEL_NAME,
    num_samples=2500,
    seed=SEED,
)

# Templates: NO framing to avoid "Document:\n" artifact
SURROGATE_PREFIX_TEMPLATE = "{surrogate}\n"
DOCUMENT_TEMPLATE = "{document}"

# Query/answer format
QUERY_TEMPLATE = "\nQuery: {query}\nAnswer:"
ANSWER_TEMPLATE = " {answer}"  # Leading space for correct BPE of first word

# Checkpoint frequency
CHECKPOINT_EVERY = 50

print("Config:")
print(f"  num_samples: {config.num_samples}")
print(f"  passage words: {config.min_passage_words}-{config.max_passage_words}")
print(f"  surrogate_prefix_template: {repr(SURROGATE_PREFIX_TEMPLATE)}")
print(f"  document_template: {repr(DOCUMENT_TEMPLATE)}")
print(f"  query_template: {repr(QUERY_TEMPLATE)}")
print(f"  answer_template: {repr(ANSWER_TEMPLATE)}")
print(f"  checkpoint_every: {CHECKPOINT_EVERY}")

Config:
  num_samples: 2500
  passage words: 50-300
  surrogate_prefix_template: '{surrogate}\n'
  document_template: '{document}'
  query_template: '\nQuery: {query}\nAnswer:'
  answer_template: ' {answer}'
  checkpoint_every: 50
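
To make the four templates concrete, the sketch below assembles the text pieces for one made-up sample. It only illustrates string formatting: the notebook builds caches from token IDs rather than from these strings, and the `assemble` helper and the sample text are hypothetical.

```python
SURROGATE_PREFIX_TEMPLATE = "{surrogate}\n"
DOCUMENT_TEMPLATE = "{document}"
QUERY_TEMPLATE = "\nQuery: {query}\nAnswer:"
ANSWER_TEMPLATE = " {answer}"

def assemble(passage, query, answer, surrogate=None):
    """Return (cached text, query prompt, answer text) for one sample."""
    prefix = ""
    if surrogate is not None:
        prefix = SURROGATE_PREFIX_TEMPLATE.format(surrogate=surrogate)
    cache_text = prefix + DOCUMENT_TEMPLATE.format(document=passage)
    query_prompt = QUERY_TEMPLATE.format(query=query)
    answer_text = ANSWER_TEMPLATE.format(answer=answer)
    return cache_text, query_prompt, answer_text

# Made-up sample: the oracle condition uses the query itself as the surrogate.
cache_text, query_prompt, answer_text = assemble(
    passage="Grass likes cool soil.",
    query="when to plant grass",
    answer="early fall",
    surrogate="when to plant grass",
)
print(repr(cache_text + query_prompt + answer_text))
```

Passing `surrogate=None` gives the bare condition, where the cached text is the passage alone.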


In [4]:
# Cell 4: Load dataset (seed immediately before for determinism)
dataset = load_ms_marco(config)

# CRITICAL: Set seed immediately before load_evaluation_samples
# to ensure deterministic sample selection regardless of prior random state
np.random.seed(SEED)
samples = load_evaluation_samples(dataset, config, require_answer=True)

print(f"\nLoaded {len(samples)} samples")
print(f"\nExample sample:")
print(f"  Query: {samples[0]['query'][:100]}...")
print(f"  Passage: {samples[0]['passage'][:100]}...")
print(f"  Answer: {samples[0]['answer'][:100]}...")

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'microsoft/ms_marco' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


Loading microsoft/ms_marco dataset...
Dataset loaded: 10047 samples
Filtering samples...



Selected 2500 samples

Loaded 2500 samples

Example sample:
  Query: what temperature should it be to plant grass seeds...
  Passage: Usually planted in the early fall, cool-season grass seeds prefer daytime temperatures ranging from ...
  Answer: Between 50deg and 65deg F...


In [5]:
# Cell 5: generate_random_prefix_text() function + test

def generate_random_prefix_text(target_text, tokenizer, seed):
    """
    Generate random text from vocabulary tokens that is length-matched
    (in tokens) to target_text.
    
    Applies one decode->re-encode check: if the round trip changes the token
    count, the text is truncated or padded once. The final count is not
    re-verified, so callers that need an exact match should check it.
    
    Args:
        target_text: Text to match in token length
        tokenizer: The tokenizer
        seed: Random seed for reproducibility
    
    Returns:
        Random text string with same token count as target_text
    """
    # Get target token count (no special tokens — we want content tokens only)
    target_ids = tokenizer.encode(target_text, add_special_tokens=False)
    target_len = len(target_ids)
    
    if target_len == 0:
        return ""
    
    rng = np.random.RandomState(seed)
    vocab_size = len(tokenizer)
    
    # Sample random token IDs from the full vocabulary
    # Exclude special tokens (typically IDs 0-2: UNK, BOS, EOS for this tokenizer family)
    min_id = 3  # Skip UNK, BOS, EOS
    random_ids = rng.randint(min_id, vocab_size, size=target_len)
    
    # Decode to text
    random_text = tokenizer.decode(random_ids.tolist(), skip_special_tokens=True)
    
    # Verification: re-encode and check length
    reencoded = tokenizer.encode(random_text, add_special_tokens=False)
    
    # If lengths don't match after round-trip, truncate or pad
    if len(reencoded) != target_len:
        # Truncate the text and re-decode from exactly target_len tokens
        if len(reencoded) > target_len:
            random_text = tokenizer.decode(reencoded[:target_len], skip_special_tokens=True)
        else:
            # Pad with more random tokens
            extra_needed = target_len - len(reencoded)
            extra_ids = rng.randint(min_id, vocab_size, size=extra_needed)
            extra_text = tokenizer.decode(extra_ids.tolist(), skip_special_tokens=True)
            random_text = random_text + extra_text
            reencoded2 = tokenizer.encode(random_text, add_special_tokens=False)
            if len(reencoded2) > target_len:
                random_text = tokenizer.decode(reencoded2[:target_len], skip_special_tokens=True)
    
    return random_text


# Test the function
test_query = samples[0]['query']
test_random = generate_random_prefix_text(test_query, tokenizer, seed=SEED)

oracle_tokens = tokenizer.encode(test_query, add_special_tokens=False)
random_tokens = tokenizer.encode(test_random, add_special_tokens=False)

print(f"Oracle query: {repr(test_query)}")
print(f"Oracle tokens: {len(oracle_tokens)}")
print(f"Random prefix: {repr(test_random[:80])}...")
print(f"Random tokens: {len(random_tokens)}")
print(f"Length match: {len(oracle_tokens) == len(random_tokens)}")

Oracle query: 'what temperature should it be to plant grass seeds'
Oracle tokens: 9
Random prefix: 'Restaur Mars didova少 DATA luxwalkshine'...
Random tokens: 9
Length match: True
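
One way to guarantee an exact length match is a bounded resample loop that only returns text whose round-trip token count equals the target. The sketch below is an illustration under stated assumptions, not the notebook's method: it uses a toy greedy longest-match tokenizer over a five-entry vocabulary, and `VOCAB`, `encode`, `decode`, and `random_text_with_length` are all hypothetical names.

```python
import random

# Toy vocabulary: "ab" and "ca" make the decode -> encode round trip ambiguous.
VOCAB = ["a", "b", "ab", "c", "ca"]

def decode(ids):
    """Concatenate the token strings for the given IDs."""
    return "".join(VOCAB[i] for i in ids)

def encode(text):
    """Greedy longest-match tokenization over VOCAB."""
    by_length = sorted(range(len(VOCAB)), key=lambda i: -len(VOCAB[i]))
    ids, pos = [], 0
    while pos < len(text):
        for i in by_length:
            if text.startswith(VOCAB[i], pos):
                ids.append(i)
                pos += len(VOCAB[i])
                break
        else:
            raise ValueError(f"cannot tokenize at position {pos}")
    return ids

def random_text_with_length(target_len, seed, max_tries=100):
    """Resample until the decoded text re-encodes to exactly target_len tokens."""
    rng = random.Random(seed)
    for _ in range(max_tries):
        ids = [rng.randrange(len(VOCAB)) for _ in range(target_len)]
        text = decode(ids)
        if len(encode(text)) == target_len:
            return text
    raise RuntimeError("no length-stable sample found")

# Round-trip drift: tokens "a" + "b" decode to "ab", which re-encodes as one token.
print(encode(decode([0, 1])))
print(random_text_with_length(9, seed=42))
```

Raising on failure, rather than returning a mismatched string, keeps a length mismatch from silently entering the random-prefix condition.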


In [6]:
# Cell 6: BPE boundary diagnostic — why matched tokenization is necessary
#
# SentencePiece prepends "▁" to the first token of a standalone string, but a
# passage that directly follows "\n" gets no such marker. Tokenizing
# "prefix\npassage" and "passage" independently therefore produces DIFFERENT
# token sequences for the passage.
#
# Fix: extract doc token IDs from the concatenated tokenization, then reuse
# those exact IDs for bare/oracle/random caches. This guarantees all three
# conditions operate on identical document tokens.

print("BPE Boundary Diagnostic")
print("=" * 60)

n_mismatch = 0
n_total = min(100, len(samples))

for i in range(n_total):
    passage = samples[i]['passage']
    query = samples[i]['query']

    # Concatenated tokenization (what build_truncated_kv_cache_corrected does)
    oracle_prefix = SURROGATE_PREFIX_TEMPLATE.format(surrogate=query)
    document_text = DOCUMENT_TEMPLATE.format(document=passage)
    full_text = oracle_prefix + document_text
    full_ids = tokenizer.encode(full_text, add_special_tokens=True)
    prefix_ids = tokenizer.encode(oracle_prefix, add_special_tokens=True)
    prefix_len = len(prefix_ids)
    doc_from_concat = full_ids[prefix_len:]

    # Independent tokenization (what build_kv_cache does)
    doc_independent = tokenizer.encode(passage, add_special_tokens=True)[1:]  # strip BOS

    if doc_from_concat != doc_independent:
        n_mismatch += 1

print(f"Tested {n_total} samples")
print(f"BPE mismatches: {n_mismatch}/{n_total} ({100*n_mismatch/n_total:.0f}%)")
print()
if n_mismatch > 0:
    print("As expected, independent tokenization does NOT match concatenated tokenization.")
    print("The main loop uses MATCHED tokenization: doc token IDs are extracted from")
    print("the oracle concatenation, then reused for bare and random conditions.")
    print("This eliminates BPE mismatch entirely.")
else:
    print("Surprisingly, all boundaries are clean. Matched tokenization is still used")
    print("for safety, but independent tokenization would also work here.")

BPE Boundary Diagnostic
Tested 100 samples
BPE mismatches: 100/100 (100%)

As expected, independent tokenization does NOT match concatenated tokenization.
The main loop uses MATCHED tokenization: doc token IDs are extracted from
the oracle concatenation, then reused for bare and random conditions.
This eliminates BPE mismatch entirely.
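
The matched-tokenization fix can be sketched with a toy tokenizer. The example assumes a simplified SentencePiece-like rule (every word gets a leading "▁" unless it directly follows a newline); `toy_encode` is a hypothetical stand-in and does not reproduce the real Mistral tokenizer.

```python
def toy_encode(text):
    """Toy tokenizer: words get a leading "▁" unless they directly follow "\n"."""
    tokens = []
    for line_no, line in enumerate(text.split("\n")):
        if line_no > 0:
            tokens.append("\n")
        for word_no, word in enumerate(line.split(" ")):
            if not word:
                continue
            follows_newline = line_no > 0 and word_no == 0
            tokens.append(word if follows_newline else "▁" + word)
    return tokens

prefix = "when to plant grass\n"
passage = "Usually planted in fall"

# Concatenated tokenization: slice the document tokens off after the prefix.
full = toy_encode(prefix + passage)
prefix_len = len(toy_encode(prefix))
doc_from_concat = full[prefix_len:]

# Independent tokenization: the first document token differs.
doc_independent = toy_encode(passage)

# Matched tokenization: every condition reuses doc_from_concat.
bare_tokens = ["<s>"] + doc_from_concat
oracle_tokens = ["<s>"] + full

print(doc_from_concat)
print(doc_independent)
```

Because the bare sequence is built from `doc_from_concat` rather than from `doc_independent`, the document tokens are identical across conditions by construction.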


In [7]:
# Cell 7: Condition explanation printout
print("=" * 70)
print("EXPERIMENTAL CONDITIONS EXPLAINED")
print("=" * 70)

ex = samples[0]
ex_query = ex['query']
ex_passage = ex['passage'][:80] + "..."
ex_answer = ex['answer'][:60] + "..."
ex_random = generate_random_prefix_text(ex_query, tokenizer, seed=SEED)

print(f"\nExample passage: {repr(ex_passage)}")
print(f"Example query:   {repr(ex_query)}")
print(f"Example answer:  {repr(ex_answer)}")

# Show the matched tokenization approach
oracle_prefix = SURROGATE_PREFIX_TEMPLATE.format(surrogate=ex_query)
doc_text = DOCUMENT_TEMPLATE.format(document=ex['passage'])
full_text = oracle_prefix + doc_text
full_ids = tokenizer.encode(full_text, add_special_tokens=True)
prefix_ids = tokenizer.encode(oracle_prefix, add_special_tokens=True)
prefix_len = len(prefix_ids)
doc_ids = full_ids[prefix_len:]

print(f"\n{'─' * 70}")
print("MATCHED TOKENIZATION (avoids BPE boundary mismatch)")
print(f"  Tokenize oracle_prefix + passage together → {len(full_ids)} tokens")
print(f"  Oracle prefix tokens (with BOS): {prefix_len}")
print(f"  Document tokens (shared by ALL conditions): {len(doc_ids)}")

print(f"\n{'─' * 70}")
print("### CONDITION 1: BARE (baseline) ###")
print(f"  Input IDs:  [BOS] + doc_ids ({1 + len(doc_ids)} tokens)")
print(f"  Key insight: Pure baseline. Same doc tokens, no prefix, no RoPE correction.")

print(f"\n{'─' * 70}")
print("### CONDITION 2: RANDOM PREFIX (structural control) ###")
random_prefix = SURROGATE_PREFIX_TEMPLATE.format(surrogate=ex_random)
random_prefix_ids = tokenizer.encode(random_prefix, add_special_tokens=False)
print(f"  Random text:  {repr(ex_random[:60])}...")
print(f"  Input IDs:  [BOS] + random_prefix_ids + doc_ids ({1 + len(random_prefix_ids) + len(doc_ids)} tokens)")
print(f"  After truncation: [BOS] + doc_ids ({1 + len(doc_ids)} tokens) + RoPE correction")
print(f"  Key insight: Tests if ANY prefix affects NLL via value contamination.")

print(f"\n{'─' * 70}")
print("### CONDITION 3: ORACLE PREFIX (semantic signal) ###")
print(f"  Oracle query: {repr(ex_query)}")
print(f"  Input IDs:  [BOS] + oracle_prefix_ids + doc_ids ({len(full_ids)} tokens)")
print(f"  After truncation: [BOS] + doc_ids ({1 + len(doc_ids)} tokens) + RoPE correction")
print(f"  Key insight: Tests if RELEVANT prefix adds semantic signal beyond structural noise.")

print(f"\n{'─' * 70}")
print("ALL conditions use IDENTICAL doc_ids → differences are purely from prefix contamination.")
print("CACHE SAFETY: deepcopy_cache() before every score call.")

EXPERIMENTAL CONDITIONS EXPLAINED

Example passage: 'Usually planted in the early fall, cool-season grass seeds prefer daytime temper...'
Example query:   'what temperature should it be to plant grass seeds'
Example answer:  'Between 50deg and 65deg F...'

──────────────────────────────────────────────────────────────────────
MATCHED TOKENIZATION (avoids BPE boundary mismatch)
  Tokenize oracle_prefix + passage together → 134 tokens
  Oracle prefix tokens (with BOS): 11
  Document tokens (shared by ALL conditions): 123

──────────────────────────────────────────────────────────────────────
### CONDITION 1: BARE (baseline) ###
  Input IDs:  [BOS] + doc_ids (124 tokens)
  Key insight: Pure baseline. Same doc tokens, no prefix, no RoPE correction.

──────────────────────────────────────────────────────────────────────
### CONDITION 2: RANDOM PREFIX (structural control) ###
  Random text:  'Restaur Mars didova少 DATA luxwalkshine'...
  Input IDs:  [BOS] + random_prefix_ids + doc_ids (134 to

In [8]:
# Cell 8: Evaluation parameters + checkpoint loading

N = len(samples)
print(f"Total samples: {N}")

# Check for existing checkpoint
results = []
start_idx = 0

if CHECKPOINT_PATH.exists():
    with open(CHECKPOINT_PATH, 'r') as f:
        checkpoint = json.load(f)
    
    # Verify checkpoint matches current sample list (resume correctness)
    ckpt_sample_ids = checkpoint.get('sample_queries', [])
    current_sample_ids = [s['query'] for s in samples]
    
    if ckpt_sample_ids == current_sample_ids:
        results = checkpoint['results']
        start_idx = len(results)
        print(f"Resuming from checkpoint: {start_idx}/{N} samples completed")
    else:
        print("WARNING: Checkpoint sample list doesn't match current samples.")
        print("Starting from scratch.")
        results = []
        start_idx = 0
else:
    print("No checkpoint found. Starting from scratch.")

print(f"Will evaluate samples {start_idx} to {N-1}")

Total samples: 2500
No checkpoint found. Starting from scratch.
Will evaluate samples 0 to 2499
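
A checkpoint written with a plain `open(..., 'w')` can be left half-written if the kernel is interrupted mid-dump. The sketch below shows one possible hardening, assuming the same JSON layout (`results` plus `sample_queries`): it writes to a temporary file and swaps it in with `os.replace`. The helper names `save_checkpoint_atomic` and `load_checkpoint` are hypothetical and are not part of this notebook or of `lib`.

```python
import json
import os
import tempfile
from pathlib import Path

def save_checkpoint_atomic(path, results, sample_keys):
    """Write the checkpoint to a temp file, then replace the target in one step."""
    path = Path(path)
    payload = {"results": results, "sample_queries": sample_keys}
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(payload, f)
        os.replace(tmp_name, path)
    except BaseException:
        # Clean up the temp file if anything went wrong before the swap.
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise

def load_checkpoint(path, sample_keys):
    """Return (results, start_idx); start from scratch if the sample list differs."""
    path = Path(path)
    if not path.exists():
        return [], 0
    with open(path) as f:
        ckpt = json.load(f)
    if ckpt.get("sample_queries") != sample_keys:
        return [], 0
    results = ckpt["results"]
    return results, len(results)
```

As in Cell 8, a checkpoint is only resumed when its stored query list equals the current one, so a changed seed or filter cannot splice results from different sample sets.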


In [9]:
# Cell 9: Main evaluation loop (3 conditions × N samples, with checkpointing)
#
# KEY FIX: Matched tokenization. For each sample we:
# 1. Tokenize oracle_prefix + passage together to get the canonical doc_ids
# 2. Build bare cache from [BOS] + doc_ids
# 3. Build oracle full cache from [BOS][oracle_prefix_ids][doc_ids], truncate + RoPE correct
# 4. Build random full cache from [BOS][random_prefix_ids][doc_ids], truncate + RoPE correct
# All three caches use IDENTICAL document tokens — differences are purely from prefix.

t_start = time.time()

for idx in tqdm(range(start_idx, N), initial=start_idx, total=N, desc="Evaluating"):
    sample = samples[idx]
    passage = sample['passage']
    query = sample['query']
    answer = sample['answer']

    query_prompt = QUERY_TEMPLATE.format(query=query)
    answer_text = ANSWER_TEMPLATE.format(answer=answer)

    # === Step 1: Determine canonical document token IDs ===
    oracle_prefix = SURROGATE_PREFIX_TEMPLATE.format(surrogate=query)
    document_text = DOCUMENT_TEMPLATE.format(document=passage)
    full_oracle_text = oracle_prefix + document_text

    full_oracle_enc = tokenizer(full_oracle_text, return_tensors="pt",
                                add_special_tokens=True, padding=False, truncation=False)
    full_oracle_ids = full_oracle_enc['input_ids'].to(config.device)

    oracle_prefix_enc = tokenizer(oracle_prefix, return_tensors="pt",
                                  add_special_tokens=True, padding=False, truncation=False)
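    # Includes BOS. The slice below assumes the prefix tokenizes the same way on
    # its own as it does at the start of the full text (the prefix ends in "\n").
    oracle_prefix_len = oracle_prefix_enc['input_ids'].shape[1]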

    bos_id = full_oracle_ids[:, :1]
    doc_ids = full_oracle_ids[:, oracle_prefix_len:]
    doc_len = doc_ids.shape[1]

    # === Condition 1: BARE — [BOS] + doc_ids ===
    bare_ids = torch.cat([bos_id, doc_ids], dim=1)
    bare_len = bare_ids.shape[1]  # = 1 + doc_len
    with torch.no_grad():
        bare_out = model(input_ids=bare_ids, attention_mask=torch.ones_like(bare_ids),
                         use_cache=True, return_dict=True)
    bare_cache = bare_out.past_key_values
    bare_nll = score_answer_with_cache(
        deepcopy_cache(bare_cache), bare_len, query_prompt, answer_text,
        model, tokenizer, config
    )

    # === Condition 2: RANDOM PREFIX — [BOS][random_prefix][doc_ids] → truncate ===
    random_text = generate_random_prefix_text(query, tokenizer, seed=SEED + idx)
    random_prefix = SURROGATE_PREFIX_TEMPLATE.format(surrogate=random_text)
    random_prefix_enc = tokenizer(random_prefix, return_tensors="pt",
                                  add_special_tokens=False, padding=False, truncation=False)
    random_prefix_ids = random_prefix_enc['input_ids'].to(config.device)
    random_full_ids = torch.cat([bos_id, random_prefix_ids, doc_ids], dim=1)
    random_prefix_len = 1 + random_prefix_ids.shape[1]  # BOS + prefix tokens

    with torch.no_grad():
        random_out = model(input_ids=random_full_ids,
                           attention_mask=torch.ones_like(random_full_ids),
                           use_cache=True, return_dict=True)
    random_cache = extract_and_truncate_cache_with_bos(random_out.past_key_values, doc_len)
    random_len = 1 + doc_len
    correct_rope_positions_with_bos(random_cache, random_prefix_len - 1, model)

    random_nll = score_answer_with_cache(
        deepcopy_cache(random_cache), random_len, query_prompt, answer_text,
        model, tokenizer, config
    )

    # === Condition 3: ORACLE PREFIX — already have full_oracle_ids ===
    with torch.no_grad():
        oracle_out = model(input_ids=full_oracle_ids,
                           attention_mask=torch.ones_like(full_oracle_ids),
                           use_cache=True, return_dict=True)
    oracle_cache = extract_and_truncate_cache_with_bos(oracle_out.past_key_values, doc_len)
    oracle_len = 1 + doc_len
    correct_rope_positions_with_bos(oracle_cache, oracle_prefix_len - 1, model)

    oracle_nll = score_answer_with_cache(
        deepcopy_cache(oracle_cache), oracle_len, query_prompt, answer_text,
        model, tokenizer, config
    )

    # Record result
    result = {
        'idx': idx,
        'bare_nll': bare_nll,
        'random_nll': random_nll,
        'oracle_nll': oracle_nll,
        'bare_len': bare_len,
        'random_len': random_len,
        'oracle_len': oracle_len,
        'doc_len': doc_len,
        'delta_random_vs_bare': bare_nll - random_nll,     # positive = random better
        'delta_oracle_vs_bare': bare_nll - oracle_nll,     # positive = oracle better
        'delta_oracle_vs_random': random_nll - oracle_nll, # positive = oracle better
    }
    results.append(result)

    # GPU memory management
    del bare_cache, random_cache, oracle_cache, bare_out, random_out, oracle_out
    torch.cuda.empty_cache()

    # Checkpoint
    if (idx + 1) % CHECKPOINT_EVERY == 0 or idx == N - 1:
        checkpoint_data = {
            'results': results,
            'sample_queries': [s['query'] for s in samples],
            'completed': len(results),
            'total': N,
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        }
        with open(CHECKPOINT_PATH, 'w') as f:
            json.dump(checkpoint_data, f)

        elapsed = time.time() - t_start
        samples_done = idx - start_idx + 1
        rate = samples_done / elapsed if elapsed > 0 else 0
        remaining = (N - idx - 1) / rate if rate > 0 else 0
        tqdm.write(f"  Checkpoint at {idx+1}/{N} | Rate: {rate:.1f} samples/s | ETA: {remaining/60:.1f} min")

elapsed_total = time.time() - t_start
print(f"\nEvaluation complete: {len(results)} samples in {elapsed_total/60:.1f} minutes")


  Checkpoint at 50/2500 | Rate: 0.8 samples/s | ETA: 53.0 min
  Checkpoint at 100/2500 | Rate: 0.8 samples/s | ETA: 52.0 min
  Checkpoint at 150/2500 | Rate: 0.8 samples/s | ETA: 51.1 min
  Checkpoint at 200/2500 | Rate: 0.8 samples/s | ETA: 50.1 min
  Checkpoint at 250/2500 | Rate: 0.8 samples/s | ETA: 49.2 min
  Checkpoint at 300/2500 | Rate: 0.8 samples/s | ETA: 48.2 min
  Checkpoint at 350/2500 | Rate: 0.8 samples/s | ETA: 47.1 min
  Checkpoint at 400/2500 | Rate: 0.8 samples/s | ETA: 46.0 min
  Checkpoint at 450/2500 | Rate: 0.8 samples/s | ETA: 45.0 min
  Checkpoint at 500/2500 | Rate: 0.8 samples/s | ETA: 43.9 min
  Checkpoint at 550/2500 | Rate: 0.8 samples/s | ETA: 42.9 min
  Checkpoint at 600/2500 | Rate: 0.8 samples/s | ETA: 41.8 min
  Checkpoint at 650/2500 | Rate: 0.8 samples/s | ETA: 40.7 min
  Checkpoint at 700/2500 | Rate: 0.8 samples/s | ETA: 39.6 min
  Checkpoint at 750/2500 | Rate: 0.8 samples/s | ETA: 38.5 min
  Checkpoint at 800/2500 | Rate: 0.8 samples/s | ETA: 37

In [10]:
# Cell 10: Cache length diagnostic
print("Cache Length Diagnostic")
print("=" * 50)

bare_lens = [r['bare_len'] for r in results]
random_lens = [r['random_len'] for r in results]
oracle_lens = [r['oracle_len'] for r in results]

n = len(results)

# Check if results have doc_len (new matched code) or not (old code)
has_doc_len = 'doc_len' in results[0]
if has_doc_len:
    doc_lens = [r['doc_len'] for r in results]
    all_match = all(b == r == o == 1 + d
                    for b, r, o, d in zip(bare_lens, random_lens, oracle_lens, doc_lens))
    print(f"\nSamples: {n}")
    print(f"All cache lengths match (bare == random == oracle == 1+doc): {all_match}")
    if all_match:
        print("PASS: Matched tokenization guarantees identical cache lengths.")
    else:
        mismatches = sum(1 for b, r, o in zip(bare_lens, random_lens, oracle_lens) if not (b == r == o))
        print(f"WARNING: {mismatches} samples have mismatched lengths!")
else:
    # Old results — check how close they are
    exact_bare_rand = sum(1 for b, r in zip(bare_lens, random_lens) if b == r)
    exact_bare_orac = sum(1 for b, o in zip(bare_lens, oracle_lens) if b == o)
    exact_rand_orac = sum(1 for r, o in zip(random_lens, oracle_lens) if r == o)
    max_diff = max(max(abs(b - r) for b, r in zip(bare_lens, random_lens)),
                   max(abs(b - o) for b, o in zip(bare_lens, oracle_lens)))

    print(f"\nSamples: {n}")
    print(f"NOTE: Results from pre-matched-tokenization run (no doc_len key).")
    print(f"  Bare vs Random exact match:  {exact_bare_rand}/{n} ({100*exact_bare_rand/n:.1f}%)")
    print(f"  Bare vs Oracle exact match:  {exact_bare_orac}/{n} ({100*exact_bare_orac/n:.1f}%)")
    print(f"  Random vs Oracle exact match: {exact_rand_orac}/{n} ({100*exact_rand_orac/n:.1f}%)")
    print(f"  Max length difference: {max_diff} token(s)")
    print()
    print("Random and Oracle always match (both from concatenated tokenization).")
    print("Bare differs by ±1 token due to BPE boundary mismatch.")
    print("To get perfectly matched results, re-run cells 8 → 9 with the new code.")

Cache Length Diagnostic

Samples: 2500
All cache lengths match (bare == random == oracle == 1+doc): True
PASS: Matched tokenization guarantees identical cache lengths.


In [11]:
# Cell 11: Primary analysis — NLL means, paired comparisons
print("=" * 70)
print("PRIMARY ANALYSIS")
print("=" * 70)

bare_nlls_raw = np.array([r['bare_nll'] for r in results])
random_nlls_raw = np.array([r['random_nll'] for r in results])
oracle_nlls_raw = np.array([r['oracle_nll'] for r in results])

# Sanity checks
assert not np.any(np.isnan(bare_nlls_raw)), "NaN in bare NLLs!"
assert not np.any(np.isnan(random_nlls_raw)), "NaN in random NLLs!"
assert not np.any(np.isnan(oracle_nlls_raw)), "NaN in oracle NLLs!"

# Filter out degenerate samples where any NLL is 0.0
# This happens when the answer is a single token (num_scored = answer_len - 1 = 0)
valid_mask = (bare_nlls_raw != 0.0) & (random_nlls_raw != 0.0) & (oracle_nlls_raw != 0.0)
n_invalid = np.sum(~valid_mask)
print(f"Total samples: {len(results)}")
print(f"Excluded (zero NLL from single-token answers): {n_invalid}")
print(f"Valid samples for analysis: {np.sum(valid_mask)}\n")

bare_nlls = bare_nlls_raw[valid_mask]
random_nlls = random_nlls_raw[valid_mask]
oracle_nlls = oracle_nlls_raw[valid_mask]

# NLL summary
print(f"{'Condition':<20} {'Mean NLL':>10} {'Std':>10} {'Median':>10}")
print("-" * 55)
print(f"{'Bare':<20} {np.mean(bare_nlls):>10.4f} {np.std(bare_nlls):>10.4f} {np.median(bare_nlls):>10.4f}")
print(f"{'Random prefix':<20} {np.mean(random_nlls):>10.4f} {np.std(random_nlls):>10.4f} {np.median(random_nlls):>10.4f}")
print(f"{'Oracle prefix':<20} {np.mean(oracle_nlls):>10.4f} {np.std(oracle_nlls):>10.4f} {np.median(oracle_nlls):>10.4f}")

# Paired differences
delta_random_bare = bare_nlls - random_nlls   # positive = random better
delta_oracle_bare = bare_nlls - oracle_nlls   # positive = oracle better
delta_oracle_random = random_nlls - oracle_nlls  # positive = oracle better

print(f"\n{'─' * 70}")
print("PAIRED COMPARISONS (positive delta = first condition has HIGHER NLL)")
print(f"{'─' * 70}")

comparisons = [
    ("Bare vs Random", delta_random_bare, "Random better?"),
    ("Bare vs Oracle", delta_oracle_bare, "Oracle better?"),
    ("Random vs Oracle", delta_oracle_random, "Oracle better than random?"),
]

print(f"\n{'Comparison':<25} {'Mean Δ':>8} {'d':>8} {'Win%':>7} {'t':>8} {'p':>12}")
print("-" * 72)

for name, delta, question in comparisons:
    d = cohens_d(delta)
    win_rate = np.mean(delta > 0) * 100
    t_stat, p_val = stats.ttest_1samp(delta, 0)
    sig = "***" if p_val < 0.001 else "**" if p_val < 0.01 else "*" if p_val < 0.05 else "ns"
    print(f"{name:<25} {np.mean(delta):>8.4f} {d:>8.3f} {win_rate:>6.1f}% {t_stat:>8.2f} {p_val:>11.2e} {sig}")

print(f"\nNote: Cohen's d interpretation: |d|<0.2 = negligible, 0.2-0.5 = small, 0.5-0.8 = medium, >0.8 = large")

PRIMARY ANALYSIS
Total samples: 2500
Excluded (zero NLL from single-token answers): 197
Valid samples for analysis: 2303

Condition              Mean NLL        Std     Median
-------------------------------------------------------
Bare                     1.1455     1.5698     0.6206
Random prefix            1.1171     1.5405     0.5923
Oracle prefix            1.1369     1.5553     0.6270

──────────────────────────────────────────────────────────────────────
PAIRED COMPARISONS (positive delta = first condition has HIGHER NLL)
──────────────────────────────────────────────────────────────────────

Comparison                  Mean Δ        d    Win%        t            p
------------------------------------------------------------------------
Bare vs Random              0.0285    0.091   59.5%     4.37    1.31e-05 ***
Bare vs Oracle              0.0086    0.023   50.0%     1.12    2.63e-01 ns
Random vs Oracle           -0.0198   -0.051   44.9%    -2.42    1.54e-02 *

Note: Cohen's d i

In [12]:
# Cell 12: Summary table
print("=" * 70)
print("SUMMARY TABLE")
print("=" * 70)

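# Report both counts: the statistics below use only the valid samples from Cell 11
print(f"\nN = {len(results)} samples ({len(bare_nlls)} valid, used for the statistics below)\n")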

# Build summary
summary_rows = [
    {
        'comparison': 'Bare vs Random',
        'question': 'Does truncation/RoPE process itself change NLL?',
        'mean_delta': np.mean(delta_random_bare),
        'cohens_d': cohens_d(delta_random_bare),
        'win_pct': np.mean(delta_random_bare > 0) * 100,
        'p_value': stats.ttest_1samp(delta_random_bare, 0)[1],
    },
    {
        'comparison': 'Bare vs Oracle',
        'question': 'Does the actual query as prefix help?',
        'mean_delta': np.mean(delta_oracle_bare),
        'cohens_d': cohens_d(delta_oracle_bare),
        'win_pct': np.mean(delta_oracle_bare > 0) * 100,
        'p_value': stats.ttest_1samp(delta_oracle_bare, 0)[1],
    },
    {
        'comparison': 'Random vs Oracle',
        'question': 'Is there semantic signal beyond structural noise?',
        'mean_delta': np.mean(delta_oracle_random),
        'cohens_d': cohens_d(delta_oracle_random),
        'win_pct': np.mean(delta_oracle_random > 0) * 100,
        'p_value': stats.ttest_1samp(delta_oracle_random, 0)[1],
    },
]

for row in summary_rows:
    sig = "***" if row['p_value'] < 0.001 else "**" if row['p_value'] < 0.01 else "*" if row['p_value'] < 0.05 else "ns"
    verdict = "YES" if row['p_value'] < 0.05 else "NO"
    direction = "(helps)" if row['mean_delta'] > 0 else "(hurts)"
    print(f"Q: {row['question']}")
    print(f"  Δ = {row['mean_delta']:+.4f}, d = {row['cohens_d']:+.3f}, Win% = {row['win_pct']:.1f}%, p = {row['p_value']:.2e} {sig}")
    print(f"  Answer: {verdict} {direction if row['p_value'] < 0.05 else ''}")
    print()

SUMMARY TABLE

N = 2500 samples

Q: Does truncation/RoPE process itself change NLL?
  Δ = +0.0285, d = +0.091, Win% = 59.5%, p = 1.31e-05 ***
  Answer: YES (helps)

Q: Does the actual query as prefix help?
  Δ = +0.0086, d = +0.023, Win% = 50.0%, p = 2.63e-01 ns
  Answer: NO 

Q: Is there semantic signal beyond structural noise?
  Δ = -0.0198, d = -0.051, Win% = 44.9%, p = 1.54e-02 *
  Answer: YES (hurts)
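
The columns in the comparison table can be reproduced from a list of paired differences. `cohens_d` is imported from `lib.analysis` and its definition is not shown here, so the sketch below assumes the usual one-sample form (mean of the differences divided by their sample standard deviation). The function name `paired_effect` and the example deltas are hypothetical.

```python
import math
import statistics

def paired_effect(deltas):
    """Summarize paired differences: mean, Cohen's d, one-sample t, and win rate."""
    n = len(deltas)
    mean = statistics.fmean(deltas)
    sd = statistics.stdev(deltas)        # sample SD (ddof = 1)
    d = mean / sd                        # assumed definition of Cohen's d
    t = mean / (sd / math.sqrt(n))       # one-sample t statistic against 0
    win_rate = sum(x > 0 for x in deltas) / n
    return {"n": n, "mean": mean, "d": d, "t": t, "win_rate": win_rate}

# Made-up deltas (bare NLL minus prefixed NLL), for illustration only.
deltas = [0.10, -0.05, 0.20, 0.00, 0.15, -0.10, 0.05, 0.25]
summary = paired_effect(deltas)
print(summary)
```

Under this definition `t = d * sqrt(n)`. With the reported d = 0.091 and n = 2303 valid samples, `d * sqrt(n)` is about 4.37, which agrees with the printed t for Bare vs Random. This is why a very small effect size can still be highly significant at this sample size.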



In [13]:
# Cell 13: Hardness interaction analysis
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend
import matplotlib.pyplot as plt

print("=" * 70)
print("HARDNESS INTERACTION ANALYSIS")
print("=" * 70)

# Correlation of bare NLL (hardness) with oracle benefit
r_oracle, p_r_oracle = stats.pearsonr(bare_nlls, delta_oracle_bare)
r_random, p_r_random = stats.pearsonr(bare_nlls, delta_random_bare)

print(f"\nCorrelation of bare NLL (hardness) with:")
print(f"  Oracle benefit:  r = {r_oracle:.4f}, p = {p_r_oracle:.2e}")
print(f"  Random benefit:  r = {r_random:.4f}, p = {p_r_random:.2e}")

# Quartile breakdown
quartiles = np.percentile(bare_nlls, [25, 50, 75])
q_labels = ['Q1 (easy)', 'Q2', 'Q3', 'Q4 (hard)']
q_bounds = [(-np.inf, quartiles[0]), (quartiles[0], quartiles[1]),
            (quartiles[1], quartiles[2]), (quartiles[2], np.inf)]

print("\nQuartile breakdown (by bare NLL hardness):")
print(f"{'Quartile':<15} {'N':>5} {'Bare NLL':>10} {'Oracle Δ':>10} {'Oracle d':>10} {'Oracle Win%':>12} {'Random Δ':>10}")
print("-" * 80)

for label, (lo, hi) in zip(q_labels, q_bounds):
    mask = (bare_nlls > lo) & (bare_nlls <= hi)
    n_q = np.sum(mask)
    mean_bare = np.mean(bare_nlls[mask])
    oracle_delta_q = delta_oracle_bare[mask]
    random_delta_q = delta_random_bare[mask]
    d_q = cohens_d(oracle_delta_q) if n_q > 1 else 0
    win_q = np.mean(oracle_delta_q > 0) * 100
    print(f"{label:<15} {n_q:>5} {mean_bare:>10.3f} {np.mean(oracle_delta_q):>10.4f} {d_q:>10.3f} {win_q:>11.1f}% {np.mean(random_delta_q):>10.4f}")

# Scatter plot: hardness vs oracle benefit
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].scatter(bare_nlls, delta_oracle_bare, alpha=0.15, s=5, c='steelblue')
axes[0].axhline(y=0, color='red', linestyle='--', alpha=0.5)
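# Least-squares linear trend of oracle benefit against hardness
oracle_coeffs = np.polyfit(bare_nlls, delta_oracle_bare, 1)
oracle_trend = np.poly1d(oracle_coeffs)
x_range = np.linspace(bare_nlls.min(), bare_nlls.max(), 100)
axes[0].plot(x_range, oracle_trend(x_range), 'r-', alpha=0.8, label=f'r={r_oracle:.3f}')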
axes[0].set_xlabel('Bare NLL (hardness)')
axes[0].set_ylabel('Oracle benefit (positive = oracle helps)')
axes[0].set_title('Hardness vs Oracle Benefit')
axes[0].legend()

axes[1].scatter(bare_nlls, delta_random_bare, alpha=0.15, s=5, c='orange')
axes[1].axhline(y=0, color='red', linestyle='--', alpha=0.5)
# Same linear trend fit for the random-prefix benefit
random_coeffs = np.polyfit(bare_nlls, delta_random_bare, 1)
random_trend = np.poly1d(random_coeffs)
axes[1].plot(x_range, random_trend(x_range), 'r-', alpha=0.8, label=f'r={r_random:.3f}')
axes[1].set_xlabel('Bare NLL (hardness)')
axes[1].set_ylabel('Random benefit (positive = random helps)')
axes[1].set_title('Hardness vs Random Prefix Benefit')
axes[1].legend()

plt.tight_layout()
plt.savefig(RESULTS_DIR / 'hardness_interaction.png', dpi=150, bbox_inches='tight')
plt.close(fig)  # Agg is non-interactive: nothing to show, so release the figure
print(f"\nPlot saved to {RESULTS_DIR / 'hardness_interaction.png'}")

HARDNESS INTERACTION ANALYSIS

Correlation of bare NLL (hardness) with:
  Oracle benefit:  r = 0.1569, p = 3.63e-14
  Random benefit:  r = 0.1927, p = 1.08e-20

Quartile breakdown (by bare NLL hardness):
Quartile            N   Bare NLL   Oracle Δ   Oracle d  Oracle Win%   Random Δ
--------------------------------------------------------------------------------
Q1 (easy)         576      0.086    -0.0331     -0.195        37.8%     0.0014
Q2                576      0.417    -0.0213     -0.127        45.7%     0.0132
Q3                575      0.943     0.0114      0.044        56.7%     0.0303
Q4 (hard)         576      3.136     0.0775      0.121        59.7%     0.0690

Plot saved to results/exp01/hardness_interaction.png
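
The quartile table above bins samples by bare NLL into half-open intervals `(lo, hi]` and then averages the deltas within each bin. A minimal sketch of the same binning follows. It uses `np.digitize` in place of explicit masks, a swapped-in technique rather than the notebook's own code. `quartile_breakdown` is a hypothetical helper and the inputs are synthetic.

```python
import numpy as np


def quartile_breakdown(hardness, delta):
    """Group paired deltas into hardness quartiles using (lo, hi] intervals."""
    hardness = np.asarray(hardness, dtype=float)
    delta = np.asarray(delta, dtype=float)
    edges = np.percentile(hardness, [25, 50, 75])
    # right=True assigns a value equal to an edge to the lower bin,
    # which matches the (hardness > lo) & (hardness <= hi) masks.
    bin_idx = np.digitize(hardness, edges, right=True)
    rows = []
    for q in range(4):
        mask = bin_idx == q
        if not mask.any():
            continue  # skip empty bins (possible with heavy ties)
        rows.append({
            'quartile': f'Q{q + 1}',
            'n': int(mask.sum()),
            'mean_hardness': float(hardness[mask].mean()),
            'mean_delta': float(delta[mask].mean()),
            'win_pct': float(np.mean(delta[mask] > 0) * 100),
        })
    return rows


# Synthetic example: benefit grows with hardness (illustration only).
rng = np.random.default_rng(1)
demo_hardness = rng.exponential(scale=1.0, size=1000)
demo_delta = 0.05 * (demo_hardness - 1.0) + rng.normal(scale=0.2, size=1000)
for row in quartile_breakdown(demo_hardness, demo_delta):
    print(row)
```

Since every value falls in exactly one bin, the per-quartile `n` values sum to the number of samples passed in, which gives a quick sanity check on the table.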


In [14]:
# Cell 14: Delta distribution plots

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

plot_configs = [
    ('Bare - Random (+ = random helps)', delta_random_bare, 'steelblue'),
    ('Bare - Oracle (+ = oracle helps)', delta_oracle_bare, 'forestgreen'),
    ('Random - Oracle (+ = oracle beats random)', delta_oracle_random, 'darkorange'),
]

for ax, (title, delta, color) in zip(axes, plot_configs):
    ax.hist(delta, bins=80, color=color, alpha=0.7, edgecolor='black', linewidth=0.3)
    ax.axvline(x=0, color='red', linestyle='--', alpha=0.7)
    ax.axvline(x=np.mean(delta), color='black', linestyle='-', alpha=0.8,
               label=f'mean={np.mean(delta):.4f}')
    ax.set_xlabel('Delta NLL')
    ax.set_ylabel('Count')
    ax.set_title(title)
    ax.legend()

plt.tight_layout()
plt.savefig(RESULTS_DIR / 'delta_distributions.png', dpi=150, bbox_inches='tight')
plt.close(fig)  # Agg is non-interactive: nothing to show, so release the figure
print(f"Plot saved to {RESULTS_DIR / 'delta_distributions.png'}")

Plot saved to results/exp01/delta_distributions.png


In [15]:
# Cell 15: Save final results JSON

final_results = {
    'experiment': 'exp01_first_principles_priming',
    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    'config': {
        'model_name': config.model_name,
        'num_samples': config.num_samples,
        'seed': SEED,
        'min_passage_words': config.min_passage_words,
        'max_passage_words': config.max_passage_words,
        'surrogate_prefix_template': SURROGATE_PREFIX_TEMPLATE,
        'document_template': DOCUMENT_TEMPLATE,
        'query_template': QUERY_TEMPLATE,
        'answer_template': ANSWER_TEMPLATE,
    },
    'summary': {
        'n_samples': len(results),
        'bare_nll_mean': float(np.mean(bare_nlls)),
        'bare_nll_std': float(np.std(bare_nlls)),
        'random_nll_mean': float(np.mean(random_nlls)),
        'random_nll_std': float(np.std(random_nlls)),
        'oracle_nll_mean': float(np.mean(oracle_nlls)),
        'oracle_nll_std': float(np.std(oracle_nlls)),
        'comparisons': {
            'bare_vs_random': {
                'mean_delta': float(np.mean(delta_random_bare)),
                'cohens_d': float(cohens_d(delta_random_bare)),
                'win_rate': float(np.mean(delta_random_bare > 0)),
                't_stat': float(stats.ttest_1samp(delta_random_bare, 0)[0]),
                'p_value': float(stats.ttest_1samp(delta_random_bare, 0)[1]),
            },
            'bare_vs_oracle': {
                'mean_delta': float(np.mean(delta_oracle_bare)),
                'cohens_d': float(cohens_d(delta_oracle_bare)),
                'win_rate': float(np.mean(delta_oracle_bare > 0)),
                't_stat': float(stats.ttest_1samp(delta_oracle_bare, 0)[0]),
                'p_value': float(stats.ttest_1samp(delta_oracle_bare, 0)[1]),
            },
            'random_vs_oracle': {
                'mean_delta': float(np.mean(delta_oracle_random)),
                'cohens_d': float(cohens_d(delta_oracle_random)),
                'win_rate': float(np.mean(delta_oracle_random > 0)),
                't_stat': float(stats.ttest_1samp(delta_oracle_random, 0)[0]),
                'p_value': float(stats.ttest_1samp(delta_oracle_random, 0)[1]),
            },
        },
        'hardness_interaction': {
            'oracle_r': float(r_oracle),
            'oracle_p': float(p_r_oracle),
            'random_r': float(r_random),
            'random_p': float(p_r_random),
        },
    },
    'per_sample_results': results,
}

with open(FINAL_RESULTS_PATH, 'w') as f:
    json.dump(final_results, f, indent=2)

print(f"Final results saved to {FINAL_RESULTS_PATH}")
print(f"File size: {FINAL_RESULTS_PATH.stat().st_size / 1024:.1f} KB")
print(f"Total samples: {len(results)}")
print("\nDone!")

Final results saved to results/exp01/results.json
File size: 931.7 KB
Total samples: 2500

Done!
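
The saved file nests the comparison statistics under `summary` → `comparisons`, so later analysis can read them back without rerunning the experiment. A minimal sketch of that read path follows. `load_comparisons` and `format_comparisons` are hypothetical helpers, and the demo writes a placeholder file with made-up values rather than reading the real `results.json`.

```python
import json
import tempfile
from pathlib import Path


def load_comparisons(path):
    """Return the 'comparisons' dict from a saved results file."""
    with open(path) as f:
        saved = json.load(f)
    return saved['summary']['comparisons']


def format_comparisons(comparisons):
    """Render one line per comparison; win_rate is stored as a fraction."""
    lines = []
    for name, row in comparisons.items():
        lines.append(
            f"{name}: delta = {row['mean_delta']:+.4f}, d = {row['cohens_d']:+.3f}, "
            f"win rate = {row['win_rate']:.1%}, p = {row['p_value']:.2e}"
        )
    return lines


# Placeholder content with the same key layout; values are NOT experiment results.
placeholder = {
    'summary': {
        'comparisons': {
            'bare_vs_random': {
                'mean_delta': 0.0, 'cohens_d': 0.0, 'win_rate': 0.5,
                't_stat': 0.0, 'p_value': 1.0,
            },
        },
    },
}

with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / 'results.json'
    demo_path.write_text(json.dumps(placeholder))
    demo_lines = format_comparisons(load_comparisons(demo_path))

for line in demo_lines:
    print(line)
```

Note that the JSON stores `win_rate` as a fraction in [0, 1], whereas the printed summary reports `Win%` on a 0 to 100 scale.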
