# Experiment 17: Anchor Preservation

**Hypothesis**: KV cache truncation fails on long documents (Exp 11) because RoPE
correction shifts document positions down to [1, 2, ..., D], closing the positional
gap. This disrupts the BOS token's attention-sink role and the model's internal
position-based routing (the "Attention Sink" problem from StreamingLLM).

**The fix**: After truncation, **skip RoPE correction** so document tokens retain
their original positions [P, P+1, ..., P+D-1], creating a positional "gap" where
the surrogate prefix was. BOS stays at position 0 as the attention sink anchor.

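**Key difference from current code** (notation: D = number of document tokens; P = 1 + number of surrogate-prefix tokens, i.e. the position the first document token occupied in the original `[BOS][prefix][doc]` sequence):
- `_old`: BOS preserved + RoPE correction (positions shifted to [0, 1, ..., D])
- `_anchor`: BOS preserved + NO RoPE correction (positions are [0, gap, P, ..., P+D-1], where the gap is the unused positions 1..P-1)
- At scoring time: query tokens must get positions starting at P+D (not 1+D)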

**Experiments revisited**:
| Original Exp | Finding | Why Revisit |
|---|---|---|
| Exp 07 | static_fact_trunc d=+0.472 on short docs | Baseline: anchor should be neutral on short docs |
| Exp 11 | ALL conditions fail on long docs | **Primary target**: anchor should recover long-doc performance |

In [None]:
# Cell 1: Setup & Imports
# Global setup shared by every later cell: imports, RNG seeds, output
# directories, and the quantized model/tokenizer.
import os
os.umask(0o000)  # File permission safety (two-user environment)
# NOTE(review): umask 0 makes every file/directory created afterwards
# (results JSON, figures, surrogate caches) world-writable. This is intended
# for the shared two-user setup; confirm it is acceptable wherever this runs.

import json
import time
import numpy as np
import torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from scipy import stats
from scipy.stats import spearmanr, pearsonr
from tqdm.auto import tqdm
from datasets import load_dataset
import matplotlib.pyplot as plt

from lib.config import ExperimentConfig
# NOTE(review): build_kv_cache, build_truncated_kv_cache_corrected,
# build_truncated_cache_variable_prefix, _get_cache_keys and _get_cache_values
# are not referenced in the cells visible here.
from lib.kv_cache import (
    build_kv_cache,
    build_truncated_kv_cache_corrected,
    build_truncated_cache_variable_prefix,
    score_answer_with_cache,
    deepcopy_cache,
    extract_and_truncate_cache_with_bos,
    correct_rope_positions_with_bos,
    _get_cache_keys,
    _get_cache_values,
)
from lib.analysis import cohens_d
from lib.surrogate import STATIC_SURROGATE_QUERIES, generate_surrogate_with_template

# Reproducibility
# Seeds the global torch and NumPy RNGs once. Later draws from these global
# streams (e.g. torch.randint for the random-token prefix) depend on how much
# of the stream earlier cells consumed, so resuming mid-run changes them.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Paths
# Output layout: results/exp17/{figures,surrogates}; created up front with
# exist_ok=True so re-running the cell is harmless.
RESULTS_DIR = Path('results/exp17')
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
FIGURES_DIR = RESULTS_DIR / 'figures'
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
SURROGATES_DIR = RESULTS_DIR / 'surrogates'
SURROGATES_DIR.mkdir(parents=True, exist_ok=True)

# Model
# 4-bit bitsandbytes quantization with fp16 compute; device_map='auto' lets
# transformers/accelerate place the layers on the available device(s).
MODEL_NAME = 'mistralai/Mistral-7B-Instruct-v0.2'
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, quantization_config=bnb_config, device_map='auto'
)
# config.device is what the helper cells use to place input tensors.
config = ExperimentConfig(device='cuda')

print(f'Model loaded: {MODEL_NAME}')
print(f'Device: {config.device}')
print(f'Results dir: {RESULTS_DIR}')

In [None]:
# Cell 2: Configuration
# Experiment-wide constants; every later cell reads these as globals.

# MS MARCO: short documents, sized to replicate Exp 07.
N_MSMARCO = 300
MSMARCO_CHECKPOINT_PATH, MSMARCO_RESULTS_PATH = (
    RESULTS_DIR / 'msmarco_checkpoint.json',
    RESULTS_DIR / 'msmarco_results.json',
)
CHECKPOINT_EVERY = 25  # shared by both evaluation loops

# Natural Questions: long documents, sized to replicate Exp 11.
N_NQ = 300
NQ_CHECKPOINT_PATH, NQ_RESULTS_PATH, NQ_SAMPLES_CACHE_PATH = (
    RESULTS_DIR / 'nq_checkpoint.json',
    RESULTS_DIR / 'nq_results.json',
    RESULTS_DIR / 'nq_samples.json',
)

# (bin name, min words, max words) -- identical to the Exp 11 bins.
LENGTH_BINS = [
    ('short', 100, 300),
    ('medium', 300, 800),
    ('long', 800, 2000),
    ('very_long', 2000, 4000),
]
SAMPLES_PER_BIN = 75  # 4 bins x 75 = 300 NQ samples
MAX_DOC_WORDS = 4000

# Prompt pieces, matched to Exp 07/11.
SURROGATE_PREFIX_TEMPLATE = '{surrogate}\n'
DOCUMENT_TEMPLATE = '{document}'
QUERY_TEMPLATE = '\nQuery: {query}\nAnswer:'
ANSWER_TEMPLATE = ' {answer}'

# Fixed "static factual" surrogate phrase shared with earlier experiments.
STATIC_FACTUAL_PHRASE = STATIC_SURROGATE_QUERIES['static_factual']['query']

# Instruction used to have the model write a keyword-style query for a passage.
LLM_KW_PROMPT = ''.join((
    'You are helping index a document for search. Write a search query the way ',
    'real users type into Google: just keywords, no complete sentences, no question marks. ',
    'Think of someone quickly typing a few relevant words. ',
    'Output only the keyword query (3-6 words), nothing else.\n\n',
    'Document:',
))
LLM_KW_MAX_DOC_WORDS = 500  # long docs are cut to this many words before generation

# 'bare' followed by an (old, anchor) pair for every surrogate family.
CONDITION_NAMES = ['bare'] + [
    f'{family}_{variant}'
    for family in ('static_fact', 'random', 'oracle', 'llm_kw')
    for variant in ('old', 'anchor')
]

print('\n'.join([
    f'MS MARCO: N={N_MSMARCO}, checkpoint every {CHECKPOINT_EVERY}',
    f'NQ: N={N_NQ}, {len(LENGTH_BINS)} length bins, {SAMPLES_PER_BIN} per bin',
    f'Conditions: {CONDITION_NAMES}',
]))

In [None]:
# Cell 3: Explain Experimental Conditions
# Prints a human-readable description of every scored condition. Temporary
# names are deleted at the end so the cell leaves no new globals behind.
_rule = '=' * 70
print(_rule)
print('EXPERIMENTAL CONDITIONS EXPLAINED')
print(_rule)

print('''
Each sample is scored under 9 conditions. For primed conditions, we compare
two truncation strategies:

  _old    = Exp 07/11 method: truncate prefix + apply RoPE correction
            Positions become: [BOS=0, doc_1=1, doc_2=2, ..., doc_D=D]
            Query starts at position D+1

  _anchor = NEW: truncate prefix + SKIP RoPE correction (anchor preservation)
            Positions become: [BOS=0, <gap>, doc_1=P, doc_2=P+1, ..., doc_D=P+D-1]
            Query starts at position P+D (preserving the positional gap)
            BOS remains at position 0 as the attention sink anchor
''')

# (section header, description lines); sections are separated by one blank line.
_condition_docs = [
    ('### bare ###', [
        '  Baseline: document cached in isolation, no priming.',
        '  Cache: [BOS][doc tokens]',
        '  Positions: [0, 1, 2, ..., D]',
    ]),
    ('### static_fact_old ###', [
        f'  Prefix: "{STATIC_FACTUAL_PHRASE}"',
        '  Build: [BOS][prefix][doc] → truncate → RoPE correct',
        '  Cache positions: [0, 1, 2, ..., D]  (gap closed)',
        '  Replicates Exp 07 static_fact_trunc condition.',
    ]),
    ('### static_fact_anchor ###', [
        f'  Prefix: "{STATIC_FACTUAL_PHRASE}"',
        '  Build: [BOS][prefix][doc] → truncate → NO RoPE correction',
        '  Cache positions: [0, <gap>, P, P+1, ..., P+D-1]  (gap preserved)',
        '  NEW: tests whether preserving the positional gap helps.',
    ]),
    ('### random_old / random_anchor ###', [
        '  Prefix: random token string (same length as static_fact).',
        '  Tests structural vs semantic effect with both truncation strategies.',
    ]),
    ('### oracle_old / oracle_anchor ###', [
        '  Prefix: the actual query (perfect semantic match).',
        '  Tests whether semantic content interacts with anchor preservation.',
    ]),
    ('### llm_kw_old / llm_kw_anchor ###', [
        '  Prefix: LLM-generated keyword query from the passage.',
        '  Tests practical surrogate generation with both strategies.',
    ]),
]
print('\n\n'.join('\n'.join([_header, *_body]) for _header, _body in _condition_docs))
del _rule, _condition_docs

In [None]:
# Cell 4: Helper Functions

def build_matched_caches(passage, query, answer, keyword_surrogate, model, tokenizer, config,
                         *, random_seed=None):
    """Score one sample's answer NLL under every condition with matched document tokens.

    The document token ids are sliced once out of the tokenization of
    ``[BOS] + oracle_prefix + document`` (the oracle prefix is the query
    formatted with SURROGATE_PREFIX_TEMPLATE). Every condition -- bare and all
    primed variants -- then scores exactly these document ids, so conditions
    differ only in the prefix and the truncation strategy.

    Args:
        passage: Document text.
        query: Query text; also used as the oracle surrogate prefix.
        answer: Gold answer whose NLL is scored.
        keyword_surrogate: LLM-generated keyword surrogate. If empty/blank the
            ``llm_kw_*`` conditions are recorded as 0.0, the "missing" sentinel
            that run_analysis filters out.
        model, tokenizer, config: Model, tokenizer and experiment config
            (``config.device`` places the input tensors).
        random_seed: Optional seed for the random-token prefix. When None
            (default) the global torch RNG is used, exactly as before. Passing
            a per-sample seed makes the ``random_*`` conditions reproducible
            regardless of checkpoint resumes.

    Returns:
        (nll_results, doc_len): dict mapping condition name -> answer NLL, and
        the number of document tokens shared by all conditions.
    """
    query_prompt = QUERY_TEMPLATE.format(query=query)
    answer_text = ANSWER_TEMPLATE.format(answer=answer)
    document_text = DOCUMENT_TEMPLATE.format(document=passage)

    # --- Shared document tokenization, taken from the oracle-primed text ---
    oracle_prefix = SURROGATE_PREFIX_TEMPLATE.format(surrogate=query)
    full_oracle_text = oracle_prefix + document_text

    full_oracle_enc = tokenizer(
        full_oracle_text, return_tensors='pt', add_special_tokens=True,
        padding=False, truncation=False
    )
    full_oracle_ids = full_oracle_enc['input_ids'].to(config.device)

    oracle_prefix_enc = tokenizer(
        oracle_prefix, return_tensors='pt', add_special_tokens=True,
        padding=False, truncation=False
    )
    oracle_prefix_len = oracle_prefix_enc['input_ids'].shape[1]  # includes BOS

    bos_id = full_oracle_ids[:, :1]
    doc_ids = full_oracle_ids[:, oracle_prefix_len:]
    doc_len = doc_ids.shape[1]

    nll_results = {}

    def build_primed(prefix_text, apply_rope_correction):
        """Build a truncated cache from [BOS][prefix][doc] using the shared doc tokens.

        Args:
            prefix_text: The surrogate string (formatted with SURROGATE_PREFIX_TEMPLATE).
            apply_rope_correction: If True, shift the kept doc keys back by the
                prefix length (_old). If False, keep their original positions (_anchor).

        Returns:
            (cache, keep_len, position_offset): keep_len = 1 + doc_len (BOS + doc).
            position_offset is the number of prefix tokens (excluding BOS) when
            correction is skipped, else 0; it is forwarded to
            score_answer_with_cache, presumably so the query continues after the
            preserved gap (confirm against that helper).
        """
        prefix_str = SURROGATE_PREFIX_TEMPLATE.format(surrogate=prefix_text)
        prefix_enc = tokenizer(
            prefix_str, return_tensors='pt', add_special_tokens=False,
            padding=False, truncation=False
        )
        prefix_ids = prefix_enc['input_ids'].to(config.device)
        prefix_token_len = 1 + prefix_ids.shape[1]  # +1 for BOS

        full_ids = torch.cat([bos_id, prefix_ids, doc_ids], dim=1)

        with torch.no_grad():
            out = model(
                input_ids=full_ids,
                attention_mask=torch.ones_like(full_ids),
                use_cache=True, return_dict=True
            )

        # Keep BOS + the last doc_len positions (the document).
        cache = extract_and_truncate_cache_with_bos(out.past_key_values, doc_len)
        keep_len = 1 + doc_len

        surrogate_offset = prefix_token_len - 1  # number of prefix tokens, BOS excluded
        if apply_rope_correction:
            correct_rope_positions_with_bos(cache, surrogate_offset, model)
            position_offset = 0
        else:
            position_offset = surrogate_offset

        del out
        return cache, keep_len, position_offset

    def score(cache, keep_len, position_offset):
        """Score the answer on a private copy of ``cache`` (scoring may extend it)."""
        return score_answer_with_cache(
            deepcopy_cache(cache), keep_len, query_prompt, answer_text,
            model, tokenizer, config, position_offset=position_offset
        )

    def score_old_and_anchor(condition_prefix, prefix_text):
        """Record '<condition_prefix>_old' then '<condition_prefix>_anchor' NLLs for one prefix."""
        for suffix, apply_rope in (('old', True), ('anchor', False)):
            cache, keep_len, position_offset = build_primed(prefix_text, apply_rope)
            nll_results[f'{condition_prefix}_{suffix}'] = score(cache, keep_len, position_offset)
            del cache

    # --- 1. Bare ---
    bare_ids = torch.cat([bos_id, doc_ids], dim=1)
    with torch.no_grad():
        bare_out = model(
            input_ids=bare_ids,
            attention_mask=torch.ones_like(bare_ids),
            use_cache=True, return_dict=True
        )
    bare_cache = bare_out.past_key_values
    nll_results['bare'] = score_answer_with_cache(
        deepcopy_cache(bare_cache), bare_ids.shape[1],
        query_prompt, answer_text, model, tokenizer, config
    )
    del bare_cache, bare_out

    # --- 2. Static fact ---
    score_old_and_anchor('static_fact', STATIC_FACTUAL_PHRASE)

    # --- 3. Random ---
    # Random ids are decoded to text and re-tokenized inside build_primed, so
    # the final prefix length may differ slightly from n_random_tokens.
    n_random_tokens = max(5, len(tokenizer.encode(
        STATIC_FACTUAL_PHRASE, add_special_tokens=False)))
    generator = None
    if random_seed is not None:
        generator = torch.Generator().manual_seed(int(random_seed))
    random_ids = torch.randint(
        100, tokenizer.vocab_size - 100, (n_random_tokens,), device='cpu',
        generator=generator,
    )
    random_text = tokenizer.decode(random_ids, skip_special_tokens=True)
    score_old_and_anchor('random', random_text)

    # --- 4. Oracle ---
    score_old_and_anchor('oracle', query)

    # --- 5. LLM keyword ---
    if keyword_surrogate and keyword_surrogate.strip():
        score_old_and_anchor('llm_kw', keyword_surrogate)
    else:
        # 0.0 marks "missing"; run_analysis drops samples containing it.
        nll_results['llm_kw_old'] = 0.0
        nll_results['llm_kw_anchor'] = 0.0

    torch.cuda.empty_cache()
    return nll_results, doc_len


def _significance_stars(p_val):
    """Return the conventional significance marker for a p-value.

    '***' for p < 0.001, '**' for p < 0.01, '*' for p < 0.05, otherwise 'ns'.
    A NaN p-value compares False everywhere and therefore maps to 'ns'.
    """
    if p_val < 0.001:
        return '***'
    if p_val < 0.01:
        return '**'
    if p_val < 0.05:
        return '*'
    return 'ns'


def run_analysis(results, condition_names, dataset_label):
    """Print and return the statistical comparison of all conditions for one dataset.

    Samples where any analyzed condition has an NLL of exactly 0 are excluded:
    0.0 is the "missing" sentinel written by build_matched_caches when no
    keyword surrogate is available.

    Args:
        results: list of per-sample dicts mapping condition name -> NLL.
        condition_names: conditions to analyze; must include 'bare'.
            Old/anchor pairs and hardness rows whose conditions are absent are skipped.
        dataset_label: title used in the printed report.

    Returns:
        JSON-serializable dict with keys 'n_valid', 'n_excluded', 'all_vs_bare',
        'head_to_head' and 'hardness'. The last three are empty when fewer than
        two valid samples remain (effect sizes, t-tests and percentiles are
        undefined there).
    """
    cond_arrays = {cn: np.array([r[cn] for r in results]) for cn in condition_names}

    # Filter out failed samples (llm_kw can be 0.0 if generation failed)
    valid = np.ones(len(results), dtype=bool)
    for cn in condition_names:
        valid &= (cond_arrays[cn] != 0)
    n_valid = int(np.sum(valid))
    n_excluded = len(results) - n_valid

    c = {cn: cond_arrays[cn][valid] for cn in condition_names}

    print(f'\n{"=" * 70}')
    print(f'{dataset_label} ANALYSIS (n_valid={n_valid}, excluded={n_excluded})')
    print('=' * 70)

    all_vs_bare = {}
    h2h = {}
    hardness = {}
    analysis = {
        'n_valid': n_valid,
        'n_excluded': n_excluded,
        'all_vs_bare': all_vs_bare,
        'head_to_head': h2h,
        'hardness': hardness,
    }
    if n_valid < 2:
        # np.percentile raises on an empty array, and d / t are undefined for n < 2.
        print(f'Too few valid samples for analysis (need >= 2, have {n_valid}).')
        return analysis

    # --- All conditions vs bare ---
    print(f'\n{"Condition":<25} {"Mean Δ":>10} {"d":>8} {"Win%":>7} {"t":>8} {"p":>12} {"sig":>5}')
    print('-' * 75)

    for cn in condition_names:
        if cn == 'bare':
            print(f'{cn:<25} {np.mean(c[cn]):>10.4f} {"---":>8} {"---":>7}')
            continue
        delta = c['bare'] - c[cn]  # positive = condition is better
        d = cohens_d(delta)
        win = np.mean(delta > 0) * 100
        t_stat, p_val = stats.ttest_1samp(delta, 0)
        sig = _significance_stars(p_val)
        print(f'{cn:<25} {np.mean(delta):>10.4f} {d:>+8.3f} {win:>6.1f}% {t_stat:>8.2f} {p_val:>12.2e} {sig:>5}')
        all_vs_bare[cn] = {
            'mean_delta': float(np.mean(delta)),
            'cohens_d': float(d),
            'win_pct': float(win),
            't_stat': float(t_stat),
            'p_value': float(p_val),
        }

    # --- Head-to-head: old vs anchor ---
    print(f'\n--- Old vs Anchor (head-to-head) ---')
    print(f'{"Comparison":<35} {"Mean Δ":>10} {"d":>8} {"Win%":>7} {"p":>12} {"sig":>5}')
    print('-' * 75)

    for prefix in ['static_fact', 'random', 'oracle', 'llm_kw']:
        old_cn = f'{prefix}_old'
        new_cn = f'{prefix}_anchor'
        if old_cn not in c or new_cn not in c:
            continue
        delta = c[old_cn] - c[new_cn]  # positive = anchor is better (lower NLL)
        d = cohens_d(delta)
        win = np.mean(delta > 0) * 100
        _, p_val = stats.ttest_1samp(delta, 0)
        sig = _significance_stars(p_val)
        label = f'{old_cn} vs {new_cn}'
        print(f'{label:<35} {np.mean(delta):>10.4f} {d:>+8.3f} {win:>6.1f}% {p_val:>12.2e} {sig:>5}')
        h2h[prefix] = {
            'mean_delta': float(np.mean(delta)),
            'cohens_d': float(d),
            'anchor_win_pct': float(win),
            'p_value': float(p_val),
        }

    # --- Hardness stratification ---
    print(f'\n--- Hardness Stratification (quintiles by bare NLL) ---')
    bare_valid = c['bare']
    quintile_bounds = np.percentile(bare_valid, [20, 40, 60, 80])
    quintile_labels = ['Q1 (easy)', 'Q2', 'Q3', 'Q4', 'Q5 (hard)']
    # Index of the first bound >= nll (len(bounds) if none): a sample exactly
    # on a boundary falls into the lower quintile.
    quintiles = np.searchsorted(quintile_bounds, bare_valid, side='left')

    key_conds = [cn for cn in ('static_fact_old', 'static_fact_anchor', 'oracle_old', 'oracle_anchor')
                 if cn in c]
    header = f'{"Condition":<25}'
    for ql in quintile_labels:
        header += f'{ql:>14}'
    print(header)
    print('-' * (25 + 14 * 5))

    for cn in key_conds:
        row = f'{cn:<25}'
        cond_hardness = {}
        for q, q_label in enumerate(quintile_labels):
            mask_q = quintiles == q
            n_q = int(np.sum(mask_q))
            if n_q < 5:
                row += f'{"n/a":>14}'
                continue
            delta = bare_valid[mask_q] - c[cn][mask_q]
            d = cohens_d(delta)
            row += f'{d:>+14.3f}'
            cond_hardness[q_label] = float(d)
        print(row)
        hardness[cn] = cond_hardness

    return analysis


print('Helper functions defined.')

In [None]:
# Cell 5: MS MARCO Evaluation (Short Docs)
# Loads MS MARCO, generates (or reloads) keyword surrogates, then scores every
# sample under all conditions, checkpointing so an interrupted run can resume.

from lib.data import load_ms_marco, load_evaluation_samples

# Load data
msmarco_dataset = load_ms_marco(config)
all_msmarco_samples = load_evaluation_samples(msmarco_dataset, config, require_answer=True)
msmarco_samples = all_msmarco_samples[:N_MSMARCO]
# Clamp N to the samples actually available (mirrors the N_NQ clamp in the NQ
# cell); every loop below indexes msmarco_samples[idx] for idx < N_MSMARCO.
N_MSMARCO = min(N_MSMARCO, len(msmarco_samples))
print(f'MS MARCO samples: {len(msmarco_samples)}')

# Generate LLM keyword surrogates (or load from cache)
msmarco_surr_path = SURROGATES_DIR / 'msmarco_keyword_surrogates.json'
if msmarco_surr_path.exists():
    with open(msmarco_surr_path, 'r') as f:
        msmarco_kw_surrogates = json.load(f)['surrogates']
else:
    msmarco_kw_surrogates = []

# Continue generation where the cached list stops; the cache is re-saved every 50.
start_gen = len(msmarco_kw_surrogates)
if start_gen < N_MSMARCO:
    print(f'Generating keyword surrogates {start_gen}..{N_MSMARCO}')
    for idx in tqdm(range(start_gen, N_MSMARCO), initial=start_gen, total=N_MSMARCO,
                    desc='Keyword surrogates'):
        passage = msmarco_samples[idx]['passage']
        try:
            kw = generate_surrogate_with_template(
                passage, LLM_KW_PROMPT, model, tokenizer, config)
        except Exception as exc:
            # Best-effort: an empty surrogate marks this sample's llm_kw_*
            # conditions as missing (0.0), but report it instead of hiding it.
            tqdm.write(f'Keyword surrogate generation failed for sample {idx}: {exc!r}')
            kw = ''
        msmarco_kw_surrogates.append(kw)
        if (idx + 1) % 50 == 0 or idx == N_MSMARCO - 1:
            with open(msmarco_surr_path, 'w') as f:
                json.dump({'surrogates': msmarco_kw_surrogates}, f)
    print('Surrogates generated.')

# Evaluate
msmarco_results = []
start_idx = 0

# Resume only if the checkpoint was produced for exactly the same sample list.
if MSMARCO_CHECKPOINT_PATH.exists():
    with open(MSMARCO_CHECKPOINT_PATH, 'r') as f:
        ckpt = json.load(f)
    ckpt_queries = ckpt.get('sample_queries', [])
    current_queries = [s['query'] for s in msmarco_samples]
    if ckpt_queries == current_queries:
        msmarco_results = ckpt['results']
        start_idx = len(msmarco_results)
        print(f'Resuming from checkpoint: {start_idx}/{N_MSMARCO}')
    else:
        print('Checkpoint sample mismatch. Starting fresh.')

for idx in tqdm(range(start_idx, N_MSMARCO), initial=start_idx, total=N_MSMARCO,
                desc='MS MARCO eval'):
    sample = msmarco_samples[idx]
    kw_surr = msmarco_kw_surrogates[idx] if idx < len(msmarco_kw_surrogates) else ''

    nll_dict, doc_len = build_matched_caches(
        sample['passage'], sample['query'], sample['answer'],
        kw_surr, model, tokenizer, config
    )

    result = {
        'idx': idx,
        'doc_len_tokens': doc_len,
        'word_count': len(sample['passage'].split()),
        **nll_dict,
    }
    msmarco_results.append(result)

    if (idx + 1) % CHECKPOINT_EVERY == 0 or idx == N_MSMARCO - 1:
        ckpt_data = {
            'results': msmarco_results,
            'sample_queries': [s['query'] for s in msmarco_samples],
            'completed': len(msmarco_results),
            'total': N_MSMARCO,
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        }
        with open(MSMARCO_CHECKPOINT_PATH, 'w') as f:
            json.dump(ckpt_data, f)

print(f'MS MARCO evaluation complete: {len(msmarco_results)} samples')

In [None]:
# Cell 6: MS MARCO Analysis
# Run the shared statistical analysis, then persist the run config, the
# analysis, and the raw per-sample NLLs to one JSON file.

msmarco_analysis = run_analysis(msmarco_results, CONDITION_NAMES, 'MS MARCO (Short Docs)')

msmarco_final = dict(
    experiment='exp17_anchor_preservation_msmarco',
    timestamp=time.strftime('%Y-%m-%d %H:%M:%S'),
    config=dict(
        model_name=MODEL_NAME,
        seed=SEED,
        n_eval=N_MSMARCO,
        dataset='MS MARCO v1.1',
    ),
    condition_names=CONDITION_NAMES,
    analysis=msmarco_analysis,
    per_sample_results=msmarco_results,
)
MSMARCO_RESULTS_PATH.write_text(json.dumps(msmarco_final, indent=2))
print(f'\nSaved to {MSMARCO_RESULTS_PATH}')

In [None]:
# Cell 7: NQ Evaluation (Long Docs)
# Builds (or reloads) a length-stratified Natural Questions sample, generates
# keyword surrogates, then scores every sample with resumable checkpoints.

# --- Load NQ samples (stratified by length) ---
if NQ_SAMPLES_CACHE_PATH.exists():
    with open(NQ_SAMPLES_CACHE_PATH, 'r') as f:
        nq_cache = json.load(f)
    nq_samples = nq_cache['samples']
    print(f'Loaded {len(nq_samples)} NQ samples from cache')
else:
    print('Loading NQ dataset (streaming)...')
    nq = load_dataset(
        'google-research-datasets/natural_questions',
        split='validation',
        streaming=True,
    )

    bin_samples = {name: [] for name, _, _ in LENGTH_BINS}
    n_processed = 0

    for example in tqdm(nq, desc='Processing NQ'):
        n_processed += 1

        # Extract clean document text (HTML tokens dropped). The tokens field
        # is handled both column-wise (dict of lists) and row-wise (list of dicts).
        doc_tokens = example['document']['tokens']
        if isinstance(doc_tokens, dict):
            token_strs = doc_tokens['token']
            is_html_flags = doc_tokens['is_html']
            clean_tokens = [t for t, h in zip(token_strs, is_html_flags) if not h]
        else:
            clean_tokens = [t['token'] for t in doc_tokens if not t['is_html']]

        doc_text = ' '.join(clean_tokens)
        word_count = len(doc_text.split())

        if word_count < LENGTH_BINS[0][1]:
            continue
        # NOTE(review): a truncated doc gets word_count == MAX_DOC_WORDS, which
        # equals the exclusive upper bound of the last bin, so it is never
        # binned below -- documents longer than MAX_DOC_WORDS are effectively
        # dropped (which also avoids docs whose answer was cut off). Confirm this
        # is the intended selection before changing either constant.
        if word_count > MAX_DOC_WORDS:
            words = doc_text.split()
            doc_text = ' '.join(words[:MAX_DOC_WORDS])
            word_count = MAX_DOC_WORDS

        # Extract short answer: first annotator with answer text, else with a
        # valid token span (non-HTML tokens only).
        annotations = example['annotations']
        short_answers_list = annotations['short_answers']

        answer_text = None
        for annotator_sa in short_answers_list:
            if not annotator_sa:
                continue
            texts = annotator_sa.get('text', [])
            if texts:
                answer_text = texts[0]
                break
            starts = annotator_sa.get('start_token', [])
            ends = annotator_sa.get('end_token', [])
            if not starts or not ends:
                continue
            start_tok = starts[0] if isinstance(starts, list) else starts
            end_tok = ends[0] if isinstance(ends, list) else ends
            if start_tok >= 0 and end_tok > start_tok:
                if isinstance(doc_tokens, dict):
                    ans_tokens = [
                        doc_tokens['token'][i]
                        for i in range(start_tok, min(end_tok, len(doc_tokens['token'])))
                        if not doc_tokens['is_html'][i]
                    ]
                else:
                    ans_tokens = [
                        doc_tokens[i]['token']
                        for i in range(start_tok, min(end_tok, len(doc_tokens)))
                        if not doc_tokens[i]['is_html']
                    ]
                if ans_tokens:
                    answer_text = ' '.join(ans_tokens)
                    break

        # Keep only short, non-empty answers.
        if not answer_text or len(answer_text.strip()) == 0:
            continue
        if len(answer_text.split()) > 20:
            continue

        question = example['question']
        query = question.get('text', '') if isinstance(question, dict) else str(question)

        # Assign to length bin (min inclusive, max exclusive) until it is full.
        for bin_name, bin_min, bin_max in LENGTH_BINS:
            if bin_min <= word_count < bin_max:
                if len(bin_samples[bin_name]) < SAMPLES_PER_BIN:
                    bin_samples[bin_name].append({
                        'passage': doc_text,
                        'query': query,
                        'answer': answer_text,
                        'word_count': word_count,
                        'length_bin': bin_name,
                    })
                break

        if all(len(bin_samples[name]) >= SAMPLES_PER_BIN for name, _, _ in LENGTH_BINS):
            break

    # Combine: shuffle within each bin (re-seeded per bin), keep bin order.
    nq_samples = []
    for bin_name, _, _ in LENGTH_BINS:
        bs = bin_samples[bin_name]
        np.random.seed(SEED)
        np.random.shuffle(bs)
        nq_samples.extend(bs)
        print(f'  {bin_name}: {len(bs)} samples')

    with open(NQ_SAMPLES_CACHE_PATH, 'w') as f:
        json.dump({'samples': nq_samples, 'n_processed': n_processed}, f)
    print(f'Saved {len(nq_samples)} NQ samples')

N_NQ = min(N_NQ, len(nq_samples))
print(f'NQ samples to evaluate: {N_NQ}')

# Generate NQ keyword surrogates
nq_surr_path = SURROGATES_DIR / 'nq_keyword_surrogates.json'
if nq_surr_path.exists():
    with open(nq_surr_path, 'r') as f:
        nq_kw_surrogates = json.load(f)['surrogates']
else:
    nq_kw_surrogates = []

start_gen = len(nq_kw_surrogates)
if start_gen < N_NQ:
    print(f'Generating NQ keyword surrogates {start_gen}..{N_NQ}')
    for idx in tqdm(range(start_gen, N_NQ), initial=start_gen, total=N_NQ,
                    desc='NQ keyword surrogates'):
        passage = nq_samples[idx]['passage']
        # Long docs are cut to LLM_KW_MAX_DOC_WORDS words for generation only.
        words = passage.split()
        if len(words) > LLM_KW_MAX_DOC_WORDS:
            passage_for_gen = ' '.join(words[:LLM_KW_MAX_DOC_WORDS])
        else:
            passage_for_gen = passage
        try:
            kw = generate_surrogate_with_template(
                passage_for_gen, LLM_KW_PROMPT, model, tokenizer, config)
        except Exception as exc:
            # Best-effort: an empty surrogate marks this sample's llm_kw_*
            # conditions as missing (0.0), but report it instead of hiding it.
            tqdm.write(f'NQ keyword surrogate generation failed for sample {idx}: {exc!r}')
            kw = ''
        nq_kw_surrogates.append(kw)
        if (idx + 1) % 50 == 0 or idx == N_NQ - 1:
            with open(nq_surr_path, 'w') as f:
                json.dump({'surrogates': nq_kw_surrogates}, f)
    print('NQ surrogates generated.')

# Evaluate
nq_results = []
start_idx = 0

# Resume only if the checkpoint was produced for exactly the same sample list.
if NQ_CHECKPOINT_PATH.exists():
    with open(NQ_CHECKPOINT_PATH, 'r') as f:
        ckpt = json.load(f)
    ckpt_queries = ckpt.get('sample_queries', [])
    current_queries = [s['query'] for s in nq_samples[:N_NQ]]
    if ckpt_queries == current_queries:
        nq_results = ckpt['results']
        start_idx = len(nq_results)
        print(f'Resuming from checkpoint: {start_idx}/{N_NQ}')
    else:
        print('Checkpoint sample mismatch. Starting fresh.')

for idx in tqdm(range(start_idx, N_NQ), initial=start_idx, total=N_NQ,
                desc='NQ eval'):
    sample = nq_samples[idx]
    kw_surr = nq_kw_surrogates[idx] if idx < len(nq_kw_surrogates) else ''

    nll_dict, doc_len = build_matched_caches(
        sample['passage'], sample['query'], sample['answer'],
        kw_surr, model, tokenizer, config
    )

    result = {
        'idx': idx,
        'doc_len_tokens': doc_len,
        'word_count': sample['word_count'],
        'length_bin': sample['length_bin'],
        **nll_dict,
    }
    nq_results.append(result)

    if (idx + 1) % CHECKPOINT_EVERY == 0 or idx == N_NQ - 1:
        ckpt_data = {
            'results': nq_results,
            'sample_queries': [s['query'] for s in nq_samples[:N_NQ]],
            'completed': len(nq_results),
            'total': N_NQ,
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        }
        with open(NQ_CHECKPOINT_PATH, 'w') as f:
            json.dump(ckpt_data, f)

print(f'NQ evaluation complete: {len(nq_results)} samples')

In [None]:
# Cell 8: NQ Analysis
# Overall NQ statistics via run_analysis, then NQ-specific breakdowns by
# document length; everything is saved to NQ_RESULTS_PATH.

nq_analysis = run_analysis(nq_results, CONDITION_NAMES, 'NQ (Long Docs)')

# --- Per Length Bin Analysis ---
print(f'\n{"=" * 70}')
print('NQ: PER LENGTH BIN ANALYSIS')
print('=' * 70)

# Same validity mask as run_analysis: a 0.0 NLL marks a missing score.
cond_arrays = {cn: np.array([r[cn] for r in nq_results]) for cn in CONDITION_NAMES}
valid = np.ones(len(nq_results), dtype=bool)
for cn in CONDITION_NAMES:
    valid &= (cond_arrays[cn] != 0)
c = {cn: cond_arrays[cn][valid] for cn in CONDITION_NAMES}

length_bins_arr = np.array([r['length_bin'] for r in nq_results])[valid]
word_counts_arr = np.array([r['word_count'] for r in nq_results])[valid]

bin_names_ordered = [name for name, _, _ in LENGTH_BINS]
per_bin_results = {}

for cn in CONDITION_NAMES:
    if cn == 'bare':
        continue
    print(f'\n  {cn}:')
    bin_ds = {}
    for bin_name in bin_names_ordered:
        mask = length_bins_arr == bin_name
        n_bin = int(np.sum(mask))
        if n_bin < 5:
            continue
        delta = c['bare'][mask] - c[cn][mask]  # positive = condition beats bare
        d = cohens_d(delta)
        win = np.mean(delta > 0) * 100
        _, p_val = stats.ttest_1samp(delta, 0)
        sig = '***' if p_val < 0.001 else '**' if p_val < 0.01 else '*' if p_val < 0.05 else 'ns'
        print(f'    {bin_name}: n={n_bin}, d={d:+.3f}, win={win:.1f}%, p={p_val:.2e} {sig}')
        bin_ds[bin_name] = {'d': float(d), 'win': float(win), 'n': n_bin, 'p': float(p_val)}
    per_bin_results[cn] = bin_ds

# --- Old vs Anchor by length bin ---
print(f'\n--- Old vs Anchor by Length Bin (CRITICAL TEST) ---')
print(f'{"":<15}', end='')
for bn in bin_names_ordered:
    print(f'{bn:>15}', end='')
print()
print('-' * (15 + 15 * len(bin_names_ordered)))

h2h_by_bin = {}
for prefix in ['static_fact', 'random', 'oracle', 'llm_kw']:
    old_cn = f'{prefix}_old'
    new_cn = f'{prefix}_anchor'
    row = f'{prefix:<15}'
    h2h_bin = {}
    for bn in bin_names_ordered:
        mask = length_bins_arr == bn
        n_bin = int(np.sum(mask))
        if n_bin < 5:
            row += f'{"n/a":>15}'
            continue
        delta = c[old_cn][mask] - c[new_cn][mask]  # positive = anchor wins
        d = cohens_d(delta)
        win = np.mean(delta > 0) * 100
        # Paired one-sample t-test per bin, so the saved results record whether
        # each bin's old-vs-anchor difference is significant, not just its size.
        _, p_val = stats.ttest_1samp(delta, 0)
        row += f'{d:>+7.3f} ({win:4.0f}%)'
        h2h_bin[bn] = {'d': float(d), 'win': float(win), 'n': n_bin, 'p': float(p_val)}
    print(row)
    h2h_by_bin[prefix] = h2h_bin

# --- Length correlation ---
print(f'\n--- Length × Effect Size Correlation ---')
interaction_results = {}
for cn in CONDITION_NAMES:
    if cn == 'bare':
        continue
    delta = c['bare'] - c[cn]
    r_spear, p_spear = spearmanr(word_counts_arr, delta)
    r_pears, p_pears = pearsonr(word_counts_arr, delta)
    print(f'  {cn}: Spearman r={r_spear:+.3f} (p={p_spear:.3f}), '
          f'Pearson r={r_pears:+.3f} (p={p_pears:.3f})')
    interaction_results[cn] = {
        'spearman_r': float(r_spear), 'spearman_p': float(p_spear),
        'pearson_r': float(r_pears), 'pearson_p': float(p_pears),
    }

nq_analysis['per_bin'] = per_bin_results
nq_analysis['h2h_by_bin'] = h2h_by_bin
nq_analysis['length_interaction'] = interaction_results

# Save
nq_final = {
    'experiment': 'exp17_anchor_preservation_nq',
    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    'config': {
        'model_name': MODEL_NAME,
        'seed': SEED,
        'n_eval': N_NQ,
        'dataset': 'google-research-datasets/natural_questions',
        'length_bins': LENGTH_BINS,
    },
    'condition_names': CONDITION_NAMES,
    'analysis': nq_analysis,
    'per_sample_results': nq_results,
}
with open(NQ_RESULTS_PATH, 'w') as f:
    json.dump(nq_final, f, indent=2)
print(f'\nSaved to {NQ_RESULTS_PATH}')

In [None]:
# Cell 9: Summary & Visualization
# Figures comparing old (RoPE-corrected) vs anchor (gap-preserved) truncation,
# a printed verdict per surrogate family, and a combined JSON summary.

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# --- Plot 1: Effect sizes comparison (old vs anchor) ---
ax = axes[0]
prefixes = ['static_fact', 'random', 'oracle', 'llm_kw']
x = np.arange(len(prefixes))
width = 0.35

# MS MARCO data (same validity mask as run_analysis: 0.0 marks a missing score)
msmarco_arr = {cn: np.array([r[cn] for r in msmarco_results]) for cn in CONDITION_NAMES}
msmarco_valid = np.ones(len(msmarco_results), dtype=bool)
for cn in CONDITION_NAMES:
    msmarco_valid &= (msmarco_arr[cn] != 0)
msmarco_c = {cn: msmarco_arr[cn][msmarco_valid] for cn in CONDITION_NAMES}

old_d_msmarco = [cohens_d(msmarco_c['bare'] - msmarco_c[f'{p}_old']) for p in prefixes]
anchor_d_msmarco = [cohens_d(msmarco_c['bare'] - msmarco_c[f'{p}_anchor']) for p in prefixes]

ax.bar(x - width/2, old_d_msmarco, width, label='Old (RoPE corrected)', color='#2196F3', alpha=0.8)
ax.bar(x + width/2, anchor_d_msmarco, width, label='Anchor (gap preserved)', color='#FF9800', alpha=0.8)

ax.set_ylabel("Cohen's d vs bare")
ax.set_title('MS MARCO (Short Docs)')
ax.set_xticks(x)
ax.set_xticklabels(prefixes, rotation=15)
ax.legend()
ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
# set_ylim rejects NaN/inf limits, so pad around the finite effect sizes only.
finite_d = np.asarray(list(old_d_msmarco) + list(anchor_d_msmarco), dtype=float)
finite_d = finite_d[np.isfinite(finite_d)]
if finite_d.size:
    ax.set_ylim(finite_d.min() - 0.1, finite_d.max() + 0.1)

# --- Plot 2: NQ effect sizes by length bin ---
ax = axes[1]

# For NQ, show static_fact old vs anchor across length bins
nq_arr = {cn: np.array([r[cn] for r in nq_results]) for cn in CONDITION_NAMES}
nq_valid = np.ones(len(nq_results), dtype=bool)
for cn in CONDITION_NAMES:
    nq_valid &= (nq_arr[cn] != 0)
nq_c = {cn: nq_arr[cn][nq_valid] for cn in CONDITION_NAMES}
nq_bins_arr = np.array([r['length_bin'] for r in nq_results])[nq_valid]

old_d_bins = []
anchor_d_bins = []
for bn in bin_names_ordered:
    mask = nq_bins_arr == bn
    if np.sum(mask) < 5:
        # Too few samples: leave a gap (NaN) instead of a bar at d=0, which
        # would read as "no effect".
        old_d_bins.append(np.nan)
        anchor_d_bins.append(np.nan)
    else:
        old_d_bins.append(cohens_d(nq_c['bare'][mask] - nq_c['static_fact_old'][mask]))
        anchor_d_bins.append(cohens_d(nq_c['bare'][mask] - nq_c['static_fact_anchor'][mask]))

x2 = np.arange(len(bin_names_ordered))
ax.bar(x2 - width/2, old_d_bins, width, label='Old (RoPE corrected)', color='#2196F3', alpha=0.8)
ax.bar(x2 + width/2, anchor_d_bins, width, label='Anchor (gap preserved)', color='#FF9800', alpha=0.8)

ax.set_ylabel("Cohen's d vs bare")
ax.set_title('NQ: static_fact by Length Bin')
ax.set_xticks(x2)
ax.set_xticklabels(bin_names_ordered, rotation=15)
ax.legend()
ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

plt.tight_layout()
plt.savefig(FIGURES_DIR / 'effect_sizes_comparison.png', dpi=150, bbox_inches='tight')
plt.show()
print(f'Saved to {FIGURES_DIR / "effect_sizes_comparison.png"}')

# --- Length interaction plot ---
fig, ax = plt.subplots(figsize=(10, 6))

for cn, color, marker in [
    ('static_fact_old', '#2196F3', 'o'),
    ('static_fact_anchor', '#FF9800', 's'),
    ('oracle_old', '#4CAF50', '^'),
    ('oracle_anchor', '#F44336', 'D'),
]:
    ds = []
    bin_centers = []
    for bn, bmin, bmax in LENGTH_BINS:
        mask = nq_bins_arr == bn
        if np.sum(mask) < 5:
            continue  # under-populated bins are omitted from the line
        d = cohens_d(nq_c['bare'][mask] - nq_c[cn][mask])
        ds.append(d)
        bin_centers.append((bmin + bmax) / 2)
    ax.plot(bin_centers, ds, marker=marker, label=cn, color=color, linewidth=2, markersize=8)

ax.set_xlabel('Document Length (words, bin center)')
ax.set_ylabel("Cohen's d vs bare")
ax.set_title('NQ: Length × Method Interaction')
ax.legend()
ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
ax.set_xscale('log')

plt.tight_layout()
plt.savefig(FIGURES_DIR / 'length_interaction.png', dpi=150, bbox_inches='tight')
plt.show()
print(f'Saved to {FIGURES_DIR / "length_interaction.png"}')

# --- Summary ---
print('\n' + '=' * 70)
print('EXPERIMENT 17 SUMMARY')
print('=' * 70)
print()
print('Key question: Does anchor preservation (skipping RoPE correction)')
print('recover priming performance on long documents?')
print()

# Print summary comparison (head-to-head d > 0 means anchor has lower NLL)
for dataset_name, analysis in [('MS MARCO', msmarco_analysis), ('NQ', nq_analysis)]:
    print(f'\n--- {dataset_name} ---')
    h2h = analysis.get('head_to_head', {})
    for prefix in prefixes:
        if prefix in h2h:
            d = h2h[prefix]['cohens_d']
            win = h2h[prefix]['anchor_win_pct']
            p = h2h[prefix]['p_value']
            verdict = 'ANCHOR BETTER' if d > 0 and p < 0.05 else 'OLD BETTER' if d < 0 and p < 0.05 else 'NO DIFFERENCE'
            print(f'  {prefix}: anchor d={d:+.3f}, win={win:.0f}%, p={p:.3e} → {verdict}')

# Save combined analysis
combined = {
    'experiment': 'exp17_anchor_preservation',
    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    'msmarco': msmarco_analysis,
    'nq': nq_analysis,
}
with open(RESULTS_DIR / 'analysis_summary.json', 'w') as f:
    json.dump(combined, f, indent=2)
print(f'\nAll results saved to {RESULTS_DIR}/')