# Experiment 12: Definitive Semantic Signal Confirmation

## Motivation
Exp 11 found a suggestive quality gradient: higher-similarity real-query surrogates produced larger NLL improvements (Cohen's d: very_low bin = 0.157, high bin = 0.511). But the evidence was not conclusive:
- Individual-level Pearson r = 0.054 (p = 0.12) — not significant at α = 0.05, so the 95% CI for r includes 0
- High-similarity bin had only N=65 samples
- Single dataset (MS MARCO), single prefix template, no confound controls

## Four Investigations
- **A**: Scaled quality gradient with controls (N=2500, MS MARCO) — 11 conditions including shuffled, length-matched, and raw-query controls
- **B**: Cross-dataset replication (N=1000 each, SQuAD v2 + TriviaQA) — 5 conditions
- **C**: Prefix format ablation (N=400, MS MARCO) — 7 format conditions
- **D**: Ranking with bootstrap CIs (N=1000, MS MARCO) — MRR/Hit@1/Hit@3

## Pre-Registered Verdict Criteria
- **CONFIRMED** if: (a) Pearson r bootstrap 95% CI excludes 0 on MS MARCO (N=2500), AND (b) bin-level Spearman rho > 0.7, AND (c) sim_0.60 significantly beats shuffled_0.60 (p<0.05)
- **PARTIALLY CONFIRMED** if: (a) or (b) holds but not both, or signal on only 1 of 3 datasets
- **REFUTED** if: Pearson r CI includes 0, no monotonic trend, shuffled matches real

In [None]:
import sys, os, json, copy, time, datetime, random, re
from typing import Dict, List, Any, Optional, Tuple

import torch
import numpy as np
from tqdm.auto import tqdm
from scipy import stats
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['figure.dpi'] = 120

sys.path.insert(0, '.')

from lib import (
    ExperimentConfig,
    build_kv_cache,
    score_answer_with_cache,
    extract_and_truncate_cache_with_bos,
    correct_rope_positions_with_bos,
    load_evaluation_samples,
    load_ms_marco,
    _ensure_dynamic_cache,
)
from lib.kv_cache import _get_cache_keys, _get_cache_values

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, DynamicCache
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from datasets import load_dataset

# Record the software/hardware environment in the notebook output.
_has_cuda = torch.cuda.is_available()
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {_has_cuda}")
if _has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU Memory: {_gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
# ============================================================
# Configuration
# ============================================================

# Shared experiment settings; the passage-length bounds below are also the
# word-count filter intended for every dataset in this notebook.
config = ExperimentConfig(
    num_samples=8000,  # Load extra for large query pool
    # NOTE(review): the MS MARCO query pool is built from the full `dataset`, not from
    # these samples, so num_samples presumably only caps load_evaluation_samples();
    # Investigation A uses just the first N_INV_A of them. Confirm the comment above.
    min_passage_words=50,
    max_passage_words=300,
    seed=42,
)

# Sample counts per investigation (A: MS MARCO gradient, B: cross-dataset,
# C: prefix-format ablation, D: ranking).
N_INV_A = 2500
N_INV_B = 1000  # per dataset
N_INV_C = 400
N_INV_D = 1000

# Similarity bins for Investigation A (finer than Exp 11)
# Each entry is (low, high, name): find_surrogate_at_similarity() keeps pool queries
# whose cosine similarity to the target (all-MiniLM-L6-v2 embeddings) lies in the
# half-open range [low, high). Ranges 0.15-0.25, 0.35-0.40, 0.50-0.55 and 0.65-0.70
# are deliberately not sampled.
SIM_BINS_A = [
    (0.05, 0.15, 'sim_0.10'),
    (0.25, 0.35, 'sim_0.30'),
    (0.40, 0.50, 'sim_0.45'),
    (0.55, 0.65, 'sim_0.60'),
    (0.70, 0.80, 'sim_0.75'),
    (0.80, 0.90, 'sim_0.85'),
]

# Similarity bins for Investigation B (reduced)
SIM_BINS_B = [
    (0.25, 0.35, 'sim_0.30'),
    (0.55, 0.65, 'sim_0.60'),
    (0.75, 0.85, 'sim_0.80'),
]

# Prefix formats for Investigation C
# NOTE(review): build_matched_bare_and_truncated() applies prefix_text.strip() before
# appending a single space, so if these formats are passed through it the trailing
# "\n\n" of 'instruction' is discarded (and 'question' loses nothing). Confirm that
# this is intended for the format ablation.
PREFIX_FORMATS = {
    'template': 'This document answers: {query}',
    'raw': '{query}',
    'question': 'Question: {query}\n\nPassage:',
    'instruction': 'Find information about: {query}\n\n',
    'shuffled_template': 'This document answers: {shuffled}',
}

# Separate seeds per dataset so each dataset's sampling is reproducible on its own.
SEEDS = {'msmarco': 42, 'squad': 43, 'triviaqa': 44}

torch.manual_seed(config.seed)
np.random.seed(config.seed)

# Planned workload. Investigation A's 11 conditions are: bare, oracle, the 6
# SIM_BINS_A surrogates, shuffled_0.60, length_matched_random and raw_query_0.60.
print(f"Investigation A: {N_INV_A} samples x 11 conditions")
print(f"Investigation B: {N_INV_B} samples x 2 datasets x 5 conditions")
print(f"Investigation C: {N_INV_C} samples x 7 conditions")
print(f"Investigation D: {N_INV_D} samples x 5 caches x 5 queries")

In [None]:
# ============================================================
# Load Model
# ============================================================

# Causal LM is loaded with 4-bit NF4 weight quantization and fp16 compute.
_quant_kwargs = dict(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
bnb_config = BitsAndBytesConfig(**_quant_kwargs)

_model_name = config.model_name
print(f"Loading {_model_name}...")
tokenizer = AutoTokenizer.from_pretrained(_model_name)
model = AutoModelForCausalLM.from_pretrained(
    _model_name,
    device_map="auto",
    quantization_config=bnb_config,
)
model.eval()  # inference only
print(f"Model loaded. Layers: {model.config.num_hidden_layers}")

# Sentence embedder used to measure query/surrogate similarity.
_embedder_name = 'all-MiniLM-L6-v2'
print("Loading embedding model...")
embed_model = SentenceTransformer(_embedder_name)
print("Embedding model loaded.")

In [None]:
# ============================================================
# Load MS MARCO, Build Query Pool, Embed
# ============================================================

print("Loading MS MARCO dataset...")
dataset = load_dataset(
    config.dataset_name,
    config.dataset_config,
    split=config.dataset_split,
    trust_remote_code=True,
)
print(f"Dataset loaded: {len(dataset)} samples")

all_samples = load_evaluation_samples(dataset, config, require_answer=True)
print(f"Loaded {len(all_samples)} evaluation samples with answers")

# Build query pool from FULL dataset: unique stripped queries longer than 10
# characters, kept in first-seen order.
print("\nBuilding query pool (full dataset)...")
query_pool = []
seen_queries = set()
for item in tqdm(dataset, desc="Extracting queries"):
    q = item.get('query', '').strip()
    if not q or q in seen_queries or len(q) <= 10:
        continue
    seen_queries.add(q)
    query_pool.append(q)
print(f"Query pool size: {len(query_pool)}")

print("Embedding query pool...")
pool_embeddings = embed_model.encode(
    query_pool,
    batch_size=256,
    show_progress_bar=True,
)
print(f"Pool embeddings shape: {pool_embeddings.shape}")

In [None]:
# ============================================================
# Load SQuAD v2 + TriviaQA, Build Per-Dataset Query Pools
# ============================================================

# Use the same passage-length bounds as the MS MARCO config instead of re-hard-coding them.
_min_words = config.min_passage_words
_max_words = config.max_passage_words


def _build_unique_query_pool(items, field='question', min_chars=10):
    """Collect unique, stripped `field` strings longer than `min_chars`, in first-seen order.

    Args:
        items: Iterable of dict-like dataset rows.
        field: Row key holding the query text.
        min_chars: Queries must be strictly longer than this many characters.

    Returns:
        (pool, seen): the ordered list of accepted queries and the set of them.
        Missing or None fields are treated as empty and skipped.
    """
    pool = []
    seen = set()
    for row in items:
        text = (row.get(field) or '').strip()
        if text and text not in seen and len(text) > min_chars:
            pool.append(text)
            seen.add(text)
    return pool, seen


# --- SQuAD v2 ---
print("Loading SQuAD v2...")
squad_dataset = load_dataset("rajpurkar/squad_v2", split="validation")
print(f"SQuAD v2 loaded: {len(squad_dataset)} samples")

squad_train = load_dataset("rajpurkar/squad_v2", split="train")
print(f"SQuAD v2 train (for query pool): {len(squad_train)} samples")

np.random.seed(SEEDS['squad'])
squad_samples = []
for item in squad_dataset:
    ctx = item.get('context', '').strip()
    q = item.get('question', '').strip()
    ans_texts = item.get('answers', {}).get('text', [])
    # Unanswerable SQuAD v2 questions have an empty answer list; skip them.
    if not ctx or not q or not ans_texts:
        continue
    if _min_words <= len(ctx.split()) <= _max_words:
        squad_samples.append({
            'passage': ctx,
            'query': q,
            'answer': ans_texts[0],
        })
np.random.shuffle(squad_samples)
squad_samples = squad_samples[:N_INV_B]
print(f"SQuAD evaluation samples: {len(squad_samples)}")

# SQuAD query pool from train split
squad_query_pool, seen_sq = _build_unique_query_pool(squad_train)
print(f"SQuAD query pool: {len(squad_query_pool)}")
print("Embedding SQuAD query pool...")
squad_pool_embs = embed_model.encode(squad_query_pool, show_progress_bar=True, batch_size=256)

# --- TriviaQA ---
print("\nLoading TriviaQA (rc.wikipedia)...")
tqa_dataset = load_dataset("trivia_qa", "rc.wikipedia", split="validation")
print(f"TriviaQA loaded: {len(tqa_dataset)} samples")

tqa_train = load_dataset("trivia_qa", "rc.wikipedia", split="train")
print(f"TriviaQA train (for query pool): {len(tqa_train)} samples")

np.random.seed(SEEDS['triviaqa'])
tqa_samples = []
for item in tqa_dataset:
    q = item.get('question', '').strip()
    ans = item.get('answer') or {}
    aliases = ans.get('aliases') or []
    # Score against the canonical answer string (`value`); aliases[0] is only one of the
    # accepted variants and need not be the canonical form. Fall back to it only when
    # `value` is missing or empty.
    answer_text = (ans.get('value') or '').strip() or (aliases[0] if aliases else '')
    if not q or not answer_text:
        continue
    # Get first Wikipedia context
    entity_pages = item.get('entity_pages') or {}
    wiki_contexts = entity_pages.get('wiki_context') or []
    if not wiki_contexts:
        continue
    ctx = wiki_contexts[0].strip()
    words = ctx.split()
    if len(words) < _min_words:
        continue
    # Truncate long contexts to the word budget.
    # NOTE(review): truncation can cut off the answer-bearing text; consider keeping only
    # samples whose truncated ctx still mentions answer_text or one of its aliases.
    if len(words) > _max_words:
        ctx = ' '.join(words[:_max_words])
    tqa_samples.append({
        'passage': ctx,
        'query': q,
        'answer': answer_text,
    })

np.random.shuffle(tqa_samples)
tqa_samples = tqa_samples[:N_INV_B]
print(f"TriviaQA evaluation samples: {len(tqa_samples)}")

# TriviaQA query pool from train split
tqa_query_pool, seen_tq = _build_unique_query_pool(tqa_train)
print(f"TriviaQA query pool: {len(tqa_query_pool)}")
print("Embedding TriviaQA query pool...")
tqa_pool_embs = embed_model.encode(tqa_query_pool, show_progress_bar=True, batch_size=256)

In [None]:
# ============================================================
# Helper Functions
# ============================================================

def build_matched_bare_and_truncated(
    prefix_text: str,
    passage: str,
    model,
    tokenizer,
    config,
) -> Tuple[int, DynamicCache, int, DynamicCache, int]:
    """Build BPE-matched bare and prefix-truncated KV caches for one passage.

    ``prefix_text.strip() + " " + passage`` is tokenized jointly. The document token ids
    are everything after the first ``prefix_len`` tokens of that joint encoding, where
    ``prefix_len`` is the length of the prefix encoded on its own. Both caches are built
    from exactly those document ids, so a BPE merge at the prefix/passage boundary
    affects the bare and truncated caches identically.

    * bare cache: forward pass over ``[BOS] + doc_ids``.
    * truncated cache: forward pass over the joint ids, then
      ``extract_and_truncate_cache_with_bos(..., doc_len)`` and a RoPE correction of
      ``prefix_len - 1`` positions via ``correct_rope_positions_with_bos`` (applied in
      place; its return value is not used).

    Assumes the tokenizer prepends a single BOS token when ``add_special_tokens=True``
    (the first joint id is reused as the bare cache's BOS) — TODO confirm per model.

    Note: because the prefix is ``.strip()``-ed, any leading/trailing whitespace in
    ``prefix_text`` (e.g. a trailing "\\n\\n" in a format template) is discarded.

    Args:
        prefix_text: Text conditioning the truncated cache (e.g. a templated query).
        passage: Document text.
        model: Causal LM used for both forward passes.
        tokenizer: Tokenizer matching ``model``.
        config: Provides ``config.device`` for the input ids.

    Returns:
        ``(bare_len, bare_cache, trunc_len, trunc_cache, prefix_token_len)``; both
        lengths equal ``1 + doc_len``.

    Raises:
        ValueError: If the joint encoding has no tokens beyond the prefix (the caches
            would otherwise silently contain only BOS).
    """
    prefix_with_sep = prefix_text.strip() + " "
    encode_kwargs = dict(
        return_tensors="pt", add_special_tokens=True, padding=False, truncation=False
    )

    prefix_len = tokenizer(prefix_with_sep, **encode_kwargs)['input_ids'].shape[1]
    full_ids = tokenizer(prefix_with_sep + passage, **encode_kwargs)['input_ids'].to(config.device)
    doc_len = full_ids.shape[1] - prefix_len
    if doc_len <= 0:
        raise ValueError(
            f"Passage contributes no tokens beyond the {prefix_len}-token prefix "
            f"(joint length {full_ids.shape[1]})."
        )

    # Bare context: the joint encoding's BOS followed by its document tokens.
    bare_ids = torch.cat([full_ids[:, :1], full_ids[:, prefix_len:]], dim=1)
    bare_len = bare_ids.shape[1]

    with torch.no_grad():
        bare_out = model(
            input_ids=bare_ids,
            attention_mask=torch.ones_like(bare_ids),
            use_cache=True,
            return_dict=True,
        )
        bare_cache = _ensure_dynamic_cache(bare_out.past_key_values)

        full_out = model(
            input_ids=full_ids,
            attention_mask=torch.ones_like(full_ids),
            use_cache=True,
            return_dict=True,
        )

    truncated = extract_and_truncate_cache_with_bos(full_out.past_key_values, doc_len)
    keep_len = 1 + doc_len
    # Internal invariant: both caches cover BOS + the same doc_len positions.
    assert bare_len == keep_len, f"Length mismatch: {bare_len} vs {keep_len}"

    # Doc tokens sat at positions prefix_len.. in the joint pass but at 1.. in the bare
    # pass, i.e. prefix_len - 1 positions later; correct_rope_positions_with_bos is
    # expected to undo that offset on the truncated cache.
    surrogate_offset = prefix_len - 1
    correct_rope_positions_with_bos(truncated, surrogate_offset, model)

    return bare_len, bare_cache, keep_len, truncated, prefix_len


def find_surrogate_at_similarity(
    target_query: str,
    target_embedding: np.ndarray,
    sim_low: float,
    sim_high: float,
    pool_queries: list,
    pool_embs: np.ndarray,
    rng: np.random.RandomState,
) -> Optional[Tuple[str, float]]:
    """Pick a real pool query whose similarity to the target lies in [sim_low, sim_high).

    Candidates equal to the target (case-insensitive, whitespace-stripped) are excluded.
    Among the remaining in-range candidates, the 10% (at least one) closest to the
    bin midpoint are kept and one is drawn uniformly with ``rng``.

    Returns:
        ``(query_text, cosine_similarity)`` or ``None`` if no candidate is in range.
    """
    scores = cosine_similarity([target_embedding], pool_embs)[0]
    target_key = target_query.strip().lower()

    in_band = np.where((scores >= sim_low) & (scores < sim_high))[0]
    candidates = np.array(
        [j for j in in_band if pool_queries[j].strip().lower() != target_key],
        dtype=in_band.dtype,
    )
    if candidates.size == 0:
        return None

    # Prefer candidates near the bin centre so the bin label reflects the actual similarity.
    centre = (sim_low + sim_high) / 2
    closeness_order = np.argsort(np.abs(scores[candidates] - centre))
    keep = max(1, len(candidates) // 10)
    shortlist = candidates[closeness_order[:keep]]

    picked = rng.choice(shortlist)
    return pool_queries[picked], float(scores[picked])


def shuffle_query_words(query: str, rng: np.random.RandomState, max_tries: int = 10) -> str:
    """Return ``query`` with its words in a random order that differs from the original.

    Words are split on whitespace and re-joined with single spaces, so the multiset of
    words is preserved. A single random shuffle reproduces the original order with
    non-trivial probability for short queries (1/k! for k distinct words), which would
    turn the "shuffled" control into a copy of the real surrogate. The shuffle is
    therefore retried up to ``max_tries`` times; if every attempt reproduces the
    original order, the words are rotated by one position instead, which always changes
    the order when at least two distinct words exist.

    Queries with fewer than two distinct words cannot be reordered and are returned
    (whitespace-normalized) unchanged.

    Args:
        query: Query text.
        rng: Random state used for the shuffles.
        max_tries: Number of random shuffles attempted before falling back to rotation.
    """
    words = query.split()
    if len(set(words)) < 2:
        return ' '.join(words)

    shuffled = list(words)
    for _ in range(max_tries):
        rng.shuffle(shuffled)
        if shuffled != words:
            return ' '.join(shuffled)
    # Deterministic fallback: a rotation equals the original only if all words are equal.
    return ' '.join(words[1:] + words[:1])


def find_length_matched_random(
    target_token_len: int,
    target_query: str,
    target_embedding: np.ndarray,
    pool_queries: list,
    pool_embs: np.ndarray,
    tokenizer,
    rng: np.random.RandomState,
    max_sim: float = 0.15,
    tolerance: int = 2,
    relaxed_tolerance: int = 5,
    max_checked: int = 500,
) -> Optional[Tuple[str, float]]:
    """Find a random, semantically unrelated query with a similar token count.

    Candidates are pool queries with cosine similarity to the target below ``max_sim``.
    An exact copy of the target would have similarity close to 1 and is therefore
    excluded by this filter; ``target_query`` itself is not consulted (it is kept for
    interface compatibility).

    Up to ``max_checked`` candidates are sampled without replacement and tokenized once.
    The first tolerance (``tolerance``, then ``relaxed_tolerance``) with at least one
    candidate whose token count is within that many tokens of ``target_token_len`` wins,
    and one such candidate is drawn uniformly with ``rng``.

    Returns:
        ``(query_text, cosine_similarity)`` or ``None`` if nothing qualifies.
    """
    sims = cosine_similarity([target_embedding], pool_embs)[0]
    candidates = np.where(sims < max_sim)[0]
    if len(candidates) == 0:
        return None

    check_idxs = rng.choice(candidates, size=min(max_checked, len(candidates)), replace=False)
    # Tokenize each sampled candidate exactly once; both tolerance passes reuse the gaps.
    length_gaps = [
        abs(len(tokenizer.encode(pool_queries[ci], add_special_tokens=False)) - target_token_len)
        for ci in check_idxs
    ]

    good = []
    for tol in (tolerance, relaxed_tolerance):
        good = [
            (ci, float(sims[ci]))
            for ci, gap in zip(check_idxs, length_gaps)
            if gap <= tol
        ]
        if good:
            break
    if not good:
        return None

    chosen_idx, chosen_sim = good[rng.randint(0, len(good))]
    return pool_queries[chosen_idx], chosen_sim


def bootstrap_ci(data, stat_fn=np.mean, n_boot=10000, ci=0.95, seed=42):
    """Percentile bootstrap confidence interval for ``stat_fn(data)``.

    Draws ``n_boot`` resamples of ``data`` (same size, with replacement) from a
    ``RandomState(seed)`` and evaluates ``stat_fn`` on each.

    Returns:
        ``(lo, hi, boot_stats)``: the lower/upper percentile bounds for coverage ``ci``
        as floats, and the float64 array of bootstrap statistics.
    """
    generator = np.random.RandomState(seed)
    size = len(data)
    boot_stats = np.fromiter(
        (stat_fn(generator.choice(data, size=size, replace=True)) for _ in range(n_boot)),
        dtype=float,
        count=n_boot,
    )
    tail = (1 - ci) / 2
    lo, hi = np.percentile(boot_stats, [100 * tail, 100 * (1 - tail)])
    return float(lo), float(hi), boot_stats


def bootstrap_corr_ci(x, y, n_boot=10000, ci=0.95, seed=42):
    """Percentile bootstrap confidence interval for the Pearson correlation of paired data.

    Pairs ``(x[i], y[i])`` are resampled with replacement and treated as independent.
    If several pairs come from the same unit (e.g. the same evaluated sample), the CI
    is too narrow; resample whole units instead.

    Args:
        x, y: Equal-length 1-D sequences (lists or arrays) of paired values.
        n_boot: Number of bootstrap replicates.
        ci: Coverage of the interval.
        seed: Seed for ``np.random.RandomState``.

    Returns:
        ``(lo, hi, boot_rs)``: percentile bounds as floats and the array of replicate r.

    Raises:
        ValueError: If ``x`` and ``y`` are not 1-D with equal length, or have < 2 pairs.
    """
    # asarray lets callers pass plain lists (fancy indexing below needs arrays).
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.ndim != 1 or x.shape != y.shape:
        raise ValueError(
            f"x and y must be 1-D and equally long, got shapes {x.shape} and {y.shape}"
        )
    n = len(x)
    if n < 2:
        raise ValueError(f"Need at least 2 pairs for a correlation, got {n}")

    rng = np.random.RandomState(seed)
    boot_rs = np.empty(n_boot)
    for i in range(n_boot):
        idx = rng.choice(n, size=n, replace=True)
        r, _ = stats.pearsonr(x[idx], y[idx])
        boot_rs[i] = r
    alpha = (1 - ci) / 2
    lo = np.percentile(boot_rs, 100 * alpha)
    hi = np.percentile(boot_rs, 100 * (1 - alpha))
    return float(lo), float(hi), boot_rs


def score_with_deepcopy(cache, cache_len, query_prompt, answer, model, tokenizer, config):
    """Score ``answer`` against a private deep copy of ``cache``.

    The caller's ``cache`` object is never handed to ``score_answer_with_cache``, so it
    stays untouched even if scoring extends or mutates the cache it receives
    (presumably the case, given this wrapper exists — confirm in lib).
    Returns whatever ``score_answer_with_cache`` returns.
    """
    private_cache = copy.deepcopy(cache)
    return score_answer_with_cache(
        private_cache, cache_len, query_prompt, answer,
        model, tokenizer, config,
    )


# Marker in the notebook output confirming that the helper cell finished executing.
print("Helper functions defined.")

In [None]:
# ============================================================
# Investigation A: Surrogate Pre-Selection (N=2500)
# ============================================================

print("="*80)
print("INVESTIGATION A: SURROGATE PRE-SELECTION")
print("="*80)

samples_a = all_samples[:N_INV_A]
# Coverage is reported against the samples actually available, which can be fewer
# than N_INV_A if load_evaluation_samples returned a smaller set.
n_samples_a = len(samples_a)
print(f"Using {n_samples_a} samples for Investigation A")

# Embed target queries
print("Embedding target queries...")
target_queries_a = [s['query'] for s in samples_a]
target_embeddings_a = embed_model.encode(target_queries_a, show_progress_bar=True, batch_size=256)

# A single RNG drives every random choice in this cell, so the call order below
# determines which surrogates are chosen for a given seed.
rng_a = np.random.RandomState(SEEDS['msmarco'])

# For each sample, find surrogates at each similarity level
# Plus: shuffled version of sim_0.60, length-matched random, raw query for sim_0.60
sample_surrogates_a = []
skipped_bins_a = {b[2]: 0 for b in SIM_BINS_A}
skipped_controls = {'shuffled': 0, 'length_matched': 0}

for i in tqdm(range(n_samples_a), desc="Selecting surrogates"):
    target_q = target_queries_a[i]
    target_emb = target_embeddings_a[i]
    surr = {}

    # Standard similarity bins: surr[bin_name] = (query_text, similarity)
    for sim_low, sim_high, bin_name in SIM_BINS_A:
        result = find_surrogate_at_similarity(
            target_q, target_emb,
            sim_low, sim_high,
            query_pool, pool_embeddings, rng_a
        )
        if result is None:
            skipped_bins_a[bin_name] += 1
        else:
            surr[bin_name] = result

    # Controls are all derived from this sample's sim_0.60 surrogate.
    if 'sim_0.60' in surr:
        surr_query_060, surr_sim_060 = surr['sim_0.60']

        # Shuffled control: same words, order destroyed; carries the unshuffled similarity.
        surr['shuffled_0.60'] = (shuffle_query_words(surr_query_060, rng_a), surr_sim_060)

        # Length-matched random control: low-similarity query with a similar token count.
        tlen = len(tokenizer.encode(surr_query_060, add_special_tokens=False))
        lm_result = find_length_matched_random(
            tlen, target_q, target_emb,
            query_pool, pool_embeddings, tokenizer, rng_a
        )
        if lm_result is None:
            skipped_controls['length_matched'] += 1
        else:
            surr['length_matched_random'] = lm_result

        # Raw query control: same surrogate text, scored later without the template.
        surr['raw_query_0.60'] = (surr_query_060, surr_sim_060)
    else:
        skipped_controls['shuffled'] += 1
        skipped_controls['length_matched'] += 1

    sample_surrogates_a.append(surr)

print("\nSurrogate coverage per bin:")
for sim_low, sim_high, bin_name in SIM_BINS_A:
    n_found = n_samples_a - skipped_bins_a[bin_name]
    pct_found = 100 * n_found / n_samples_a if n_samples_a else 0.0
    print(f"  {bin_name} ({sim_low:.2f}-{sim_high:.2f}): {n_found}/{n_samples_a} ({pct_found:.1f}%)")
print(f"  shuffled_0.60: {n_samples_a - skipped_controls['shuffled']}/{n_samples_a}")
print(f"  length_matched_random: {n_samples_a - skipped_controls['length_matched']}/{n_samples_a}")

In [None]:
# ============================================================
# Investigation A: Evaluation Loop with Checkpointing
# ============================================================

results_a = []
skipped_a = 0
errors_a = 0
start_a = time.time()

CHECKPOINT_PATH_A = 'results/exp12/12_checkpoint_a.json'
# open(..., 'w') does not create parent directories; without this the first
# checkpoint write raises FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(CHECKPOINT_PATH_A), exist_ok=True)


def _checkpoint_is_usable_a(ckpt) -> bool:
    """Return True if ``ckpt`` is an Investigation A checkpoint for the current samples_a.

    Besides the experiment tag, every stored result must point at an index of samples_a
    whose query matches, and the implied resume index must not exceed len(samples_a);
    otherwise resuming would silently mix in results from a different sample set.
    """
    if not isinstance(ckpt, dict) or ckpt.get('experiment') != '12_inv_a':
        return False
    results = ckpt.get('results')
    skipped = ckpt.get('skipped')
    errors = ckpt.get('errors')
    if not isinstance(results, list) or not isinstance(skipped, int) or not isinstance(errors, int):
        return False
    if len(results) + skipped + errors > len(samples_a):
        return False
    for res in results:
        res_idx = res.get('idx') if isinstance(res, dict) else None
        if not isinstance(res_idx, int) or not 0 <= res_idx < len(samples_a):
            return False
        if samples_a[res_idx]['query'] != res.get('query'):
            return False
    return True


def _save_checkpoint_a() -> None:
    """Write the current Investigation A progress to CHECKPOINT_PATH_A.

    The JSON is written to a temporary file and moved into place with os.replace, so an
    interrupted write cannot leave a truncated checkpoint behind.
    """
    tmp_path = CHECKPOINT_PATH_A + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump({
            'experiment': '12_inv_a',
            'results': results_a, 'skipped': skipped_a,
            'errors': errors_a,
        }, f)
    os.replace(tmp_path, CHECKPOINT_PATH_A)


def _score_prefix_a(prefix_text, passage, query_prompt, answer):
    """Score ``answer`` with the prefix-conditioned, prefix-truncated cache of ``passage``.

    Returns ``(nll, prefix_token_len)``. The cache is only referenced inside this
    function, so it becomes collectable as soon as the call returns.
    """
    _, _, trunc_len, trunc_cache, ptl = \
        build_matched_bare_and_truncated(prefix_text, passage, model, tokenizer, config)
    nll = score_answer_with_cache(
        trunc_cache, trunc_len, query_prompt, answer, model, tokenizer, config
    )
    return nll, ptl


# Resume only from a checkpoint that belongs to this sample set; anything else
# (another experiment, different samples, unreadable/corrupt JSON) is discarded.
if os.path.exists(CHECKPOINT_PATH_A):
    try:
        with open(CHECKPOINT_PATH_A) as f:
            ckpt = json.load(f)
    except (OSError, ValueError):  # json.JSONDecodeError subclasses ValueError
        ckpt = None
    if _checkpoint_is_usable_a(ckpt):
        results_a = ckpt['results']
        skipped_a = ckpt['skipped']
        errors_a = ckpt['errors']
        print(f"Resumed from checkpoint: {len(results_a)} results")
    else:
        os.remove(CHECKPOINT_PATH_A)
        print("Deleted stale checkpoint")

# Every processed index is exactly one of: result, skip, or error.
n_resumed_a = len(results_a)  # excluded from this session's throughput estimate
start_idx_a = len(results_a) + skipped_a + errors_a

for idx in tqdm(range(start_idx_a, len(samples_a)), desc="Inv A",
                initial=start_idx_a, total=len(samples_a)):
    sample = samples_a[idx]
    passage = sample['passage']
    query = sample['query']
    answer = sample['answer']

    # Answers shorter than 2 tokens are skipped.
    answer_ids = tokenizer(answer, return_tensors='pt', add_special_tokens=False)['input_ids']
    if answer_ids.shape[1] < 2:
        skipped_a += 1
        continue

    query_prompt = config.query_template.format(query=query)
    surrogates = sample_surrogates_a[idx]

    try:
        result = {'idx': idx, 'query': query}

        # --- Oracle (also provides BPE-matched bare baseline) ---
        oracle_prefix = f"This document answers: {query}"
        bare_len, bare_cache, oracle_len, oracle_cache, oracle_ptl = \
            build_matched_bare_and_truncated(oracle_prefix, passage, model, tokenizer, config)
        result['nll_bare'] = score_answer_with_cache(
            bare_cache, bare_len, query_prompt, answer, model, tokenizer, config
        )
        result['nll_oracle'] = score_answer_with_cache(
            oracle_cache, oracle_len, query_prompt, answer, model, tokenizer, config
        )
        result['prefix_token_len_oracle'] = oracle_ptl
        del bare_cache, oracle_cache  # each cache is scored once; free it early

        # --- Similarity bins ---
        for _, _, bin_name in SIM_BINS_A:
            if bin_name in surrogates:
                surr_query, surr_sim = surrogates[bin_name]
                nll_surr, surr_ptl = _score_prefix_a(
                    f"This document answers: {surr_query}", passage, query_prompt, answer
                )
                result[f'nll_{bin_name}'] = nll_surr
                result[f'sim_{bin_name}'] = surr_sim
                result[f'ptl_{bin_name}'] = surr_ptl
            else:
                result[f'nll_{bin_name}'] = None
                result[f'sim_{bin_name}'] = None
                result[f'ptl_{bin_name}'] = None

        # --- Shuffled control (sim_0.60 words shuffled) ---
        if 'shuffled_0.60' in surrogates:
            shuf_query = surrogates['shuffled_0.60'][0]
            nll_shuf, shuf_ptl = _score_prefix_a(
                f"This document answers: {shuf_query}", passage, query_prompt, answer
            )
            result['nll_shuffled_0.60'] = nll_shuf
            result['ptl_shuffled_0.60'] = shuf_ptl
        else:
            result['nll_shuffled_0.60'] = None

        # --- Length-matched random control ---
        if 'length_matched_random' in surrogates:
            lm_query, lm_sim = surrogates['length_matched_random']
            nll_lm, lm_ptl = _score_prefix_a(
                f"This document answers: {lm_query}", passage, query_prompt, answer
            )
            result['nll_length_matched_random'] = nll_lm
            result['sim_length_matched_random'] = lm_sim
            result['ptl_length_matched_random'] = lm_ptl
        else:
            result['nll_length_matched_random'] = None

        # --- Raw query (no template framing): surrogate text used directly as prefix ---
        if 'raw_query_0.60' in surrogates:
            raw_query = surrogates['raw_query_0.60'][0]
            nll_raw, raw_ptl = _score_prefix_a(raw_query, passage, query_prompt, answer)
            result['nll_raw_query_0.60'] = nll_raw
            result['ptl_raw_query_0.60'] = raw_ptl
        else:
            result['nll_raw_query_0.60'] = None

        results_a.append(result)

    except Exception as e:
        # Best-effort: count and report (first 5) failures, keep going.
        errors_a += 1
        if errors_a <= 5:
            print(f"\n  Error on sample {idx}: {e}")
        continue
    finally:
        torch.cuda.empty_cache()

    # Checkpoint every 25 samples
    if len(results_a) % 25 == 0:
        _save_checkpoint_a()
        elapsed = time.time() - start_a
        done_this_session = len(results_a) - n_resumed_a
        rate = done_this_session / (elapsed / 60) if elapsed > 0 else 0
        print(f"\n  [{len(results_a)} done | {elapsed/60:.0f}m | {rate:.1f} samples/min]")

# Persist the tail that did not land on a multiple of 25.
_save_checkpoint_a()

elapsed_a = time.time() - start_a
print(f"\nDone. {len(results_a)} evaluated, {skipped_a} skipped, {errors_a} errors.")
print(f"Time: {elapsed_a/60:.1f} min ({elapsed_a/3600:.1f} hr)")

In [None]:
# ============================================================
# Investigation A: Analysis
# ============================================================

print("="*80)
print("INVESTIGATION A RESULTS: SCALED QUALITY GRADIENT WITH CONTROLS")
print("="*80)

# Throughout this analysis, delta = NLL(bare) - NLL(condition): a positive delta means
# the condition's cache lowered the answer NLL relative to the bare passage cache.
bare_nlls_a = np.array([r['nll_bare'] for r in results_a])
oracle_nlls_a = np.array([r['nll_oracle'] for r in results_a])
oracle_deltas = bare_nlls_a - oracle_nlls_a
# Win rate: % of samples where the oracle cache gives strictly lower NLL than bare.
oracle_wr = np.mean(oracle_deltas > 0) * 100
# Paired t-test: each sample is scored under both caches.
t_oracle, p_oracle = stats.ttest_rel(bare_nlls_a, oracle_nlls_a)
# Paired Cohen's d = mean(delta) / sample SD(delta).
# NOTE(review): unlike the per-bin code, this has no guard for a zero SD.
d_oracle = np.mean(oracle_deltas) / np.std(oracle_deltas, ddof=1)

print(f"\n{'Condition':<30} {'N':>5} {'Mean NLL':>10} {'Win%':>8} {'Delta':>10} {'Cohen d':>10} {'p-value':>12}")
print("-" * 90)
print(f"{'Bare (baseline)':<30} {len(results_a):>5} {np.mean(bare_nlls_a):>10.4f} {'--':>8} {'--':>10} {'--':>10} {'--':>12}")
print(f"{'Oracle':<30} {len(results_a):>5} {np.mean(oracle_nlls_a):>10.4f} {oracle_wr:>7.1f}% {np.mean(oracle_deltas):>+10.4f} {d_oracle:>10.3f} {p_oracle:>12.2e}")

# Per-bin stats
# Each bin is evaluated only on the samples for which that bin found a surrogate, so
# different bins cover different sample subsets. Bins with < 10 such samples are
# printed as insufficient and left out of bin_stats_a (and therefore out of the
# bin-level Spearman test computed further down). p-values are per-bin and not
# corrected for multiple comparisons.
bin_stats_a = {}
for sim_low, sim_high, bin_name in SIM_BINS_A:
    valid = [r for r in results_a if r.get(f'nll_{bin_name}') is not None]
    if len(valid) < 10:
        print(f"{bin_name:<30} {len(valid):>5} -- insufficient data")
        continue
    nlls = np.array([r[f'nll_{bin_name}'] for r in valid])
    bares = np.array([r['nll_bare'] for r in valid])
    sims = np.array([r[f'sim_{bin_name}'] for r in valid])
    deltas = bares - nlls
    wr = np.mean(deltas > 0) * 100
    t, p = stats.ttest_rel(bares, nlls)
    # Paired Cohen's d; 0 when every delta is identical (zero spread).
    d = np.mean(deltas) / np.std(deltas, ddof=1) if np.std(deltas) > 0 else 0
    bin_stats_a[bin_name] = {
        'n': len(valid), 'mean_nll': float(np.mean(nlls)),
        'mean_sim': float(np.mean(sims)), 'win_rate': float(wr),
        'mean_delta': float(np.mean(deltas)), 'cohens_d': float(d),
        'p_value': float(p),
    }
    # Label shows the realized mean similarity, which can differ from the bin name.
    label = f"{bin_name} (sim~{np.mean(sims):.2f})"
    print(f"{label:<30} {len(valid):>5} {np.mean(nlls):>10.4f} {wr:>7.1f}% {np.mean(deltas):>+10.4f} {d:>10.3f} {p:>12.2e}")

# --- Controls ---
print("\n--- Controls ---")

# Fallbacks so later cells can reference these names even when a control has fewer
# than 10 usable samples (same idea as the rho_bins_a = 0.0 fallback of the bin-level
# Spearman test): t = 0 and p = 1 read as "no evidence".
t_rs, p_rs, p_rs_one_sided = 0.0, 1.0, 1.0
t_rl, p_rl = 0.0, 1.0
t_tr, p_tr = 0.0, 1.0

# Shuffled vs real sim_0.60 (paired over samples that have both)
valid_shuf = [r for r in results_a if r.get('nll_shuffled_0.60') is not None and r.get('nll_sim_0.60') is not None]
if len(valid_shuf) >= 10:
    real_060 = np.array([r['nll_sim_0.60'] for r in valid_shuf])
    shuf_060 = np.array([r['nll_shuffled_0.60'] for r in valid_shuf])
    bares_shuf = np.array([r['nll_bare'] for r in valid_shuf])
    delta_real = bares_shuf - real_060
    delta_shuf = bares_shuf - shuf_060
    t_rs, p_rs = stats.ttest_rel(real_060, shuf_060)
    # Verdict criterion (c) is directional: the real surrogate must BEAT the shuffled one
    # (lower NLL, i.e. t_rs < 0). A two-sided p < 0.05 alone would also fire if the
    # shuffled control were significantly better, so report the one-sided p as well.
    p_rs_one_sided = p_rs / 2 if t_rs < 0 else 1 - p_rs / 2
    wr_real_vs_shuf = np.mean(real_060 < shuf_060) * 100
    print(f"  sim_0.60 vs shuffled_0.60: N={len(valid_shuf)}, real wins {wr_real_vs_shuf:.1f}%, t={t_rs:.3f}, p={p_rs:.6f}")
    print(f"    Real delta={np.mean(delta_real):.4f}, Shuffled delta={np.mean(delta_shuf):.4f}")
    print(f"    One-sided p (real < shuffled) = {p_rs_one_sided:.6f}")

# Length-matched random vs real sim_0.60
valid_lm = [r for r in results_a if r.get('nll_length_matched_random') is not None and r.get('nll_sim_0.60') is not None]
if len(valid_lm) >= 10:
    real_060_lm = np.array([r['nll_sim_0.60'] for r in valid_lm])
    lm_nlls = np.array([r['nll_length_matched_random'] for r in valid_lm])
    bares_lm = np.array([r['nll_bare'] for r in valid_lm])
    t_rl, p_rl = stats.ttest_rel(real_060_lm, lm_nlls)
    wr_real_vs_lm = np.mean(real_060_lm < lm_nlls) * 100
    print(f"  sim_0.60 vs length_matched_random: N={len(valid_lm)}, real wins {wr_real_vs_lm:.1f}%, t={t_rl:.3f}, p={p_rl:.6f}")

# Raw query vs template sim_0.60
valid_raw = [r for r in results_a if r.get('nll_raw_query_0.60') is not None and r.get('nll_sim_0.60') is not None]
if len(valid_raw) >= 10:
    tmpl_060 = np.array([r['nll_sim_0.60'] for r in valid_raw])
    raw_060 = np.array([r['nll_raw_query_0.60'] for r in valid_raw])
    t_tr, p_tr = stats.ttest_rel(tmpl_060, raw_060)
    wr_tmpl_vs_raw = np.mean(tmpl_060 < raw_060) * 100
    print(f"  template vs raw (sim_0.60): N={len(valid_raw)}, template wins {wr_tmpl_vs_raw:.1f}%, t={t_tr:.3f}, p={p_tr:.6f}")

# --- Critical: Pearson r with bootstrap CI ---


def _cluster_bootstrap_corr_ci(x, y, groups, n_boot=10000, ci=0.95, seed=42):
    """Bootstrap CI for Pearson r that resamples whole groups instead of single pairs.

    Pairs sharing a group id (here: the same evaluated sample, which contributes up to
    one pair per similarity bin, all sharing its nll_bare) stay together, so the CI
    reflects between-sample variability instead of treating correlated pairs as
    independent. Each replicate draws as many groups as exist, with replacement, and
    computes r from per-group sums (n, Σx, Σy, Σx², Σy², Σxy).

    Returns:
        (lo, hi, boot_rs); replicates with zero variance are NaN and are ignored by
        the percentile computation.

    Raises:
        ValueError: if fewer than 2 groups are present.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    _, group_idx = np.unique(np.asarray(groups), return_inverse=True)
    group_idx = group_idx.ravel()
    n_groups = int(group_idx.max()) + 1 if group_idx.size else 0
    if n_groups < 2:
        raise ValueError(f"Need at least 2 groups for a cluster bootstrap, got {n_groups}")
    columns = (np.ones_like(x), x, y, x * x, y * y, x * y)
    suff = np.column_stack(
        [np.bincount(group_idx, weights=c, minlength=n_groups) for c in columns]
    )
    rng = np.random.RandomState(seed)
    boot = np.empty(n_boot)
    for b in range(n_boot):
        n, sx, sy, sxx, syy, sxy = suff[rng.randint(0, n_groups, size=n_groups)].sum(axis=0)
        cov = n * sxy - sx * sy
        var = (n * sxx - sx * sx) * (n * syy - sy * sy)
        boot[b] = cov / np.sqrt(var) if var > 0 else np.nan
    tail = (1 - ci) / 2
    lo, hi = np.nanpercentile(boot, [100 * tail, 100 * (1 - tail)])
    return float(lo), float(hi), boot


print("\n--- Critical Test: Similarity-Delta Correlation ---")
all_sims_a = []
all_deltas_aa = []
pair_sample_ids_a = []  # index into results_a of the sample each pair came from
for sample_i, r in enumerate(results_a):
    for _, _, bin_name in SIM_BINS_A:
        if r.get(f'nll_{bin_name}') is not None and r.get(f'sim_{bin_name}') is not None:
            all_sims_a.append(r[f'sim_{bin_name}'])
            all_deltas_aa.append(r['nll_bare'] - r[f'nll_{bin_name}'])
            pair_sample_ids_a.append(sample_i)

all_sims_a = np.array(all_sims_a)
all_deltas_aa = np.array(all_deltas_aa)
pair_sample_ids_a = np.array(pair_sample_ids_a, dtype=int)

# NOTE: these parametric p-values assume independent pairs, but each sample contributes
# up to len(SIM_BINS_A) pairs sharing one nll_bare, so they are optimistic.
r_pearson_a, p_pearson_a = stats.pearsonr(all_sims_a, all_deltas_aa)
r_spearman_a, p_spearman_a = stats.spearmanr(all_sims_a, all_deltas_aa)

print(f"  Total (sim, delta) pairs: {len(all_sims_a)}")
print(f"  Pearson r = {r_pearson_a:.4f}, p = {p_pearson_a:.2e}")
print(f"  Spearman rho = {r_spearman_a:.4f}, p = {p_spearman_a:.2e}")

# Pair-level bootstrap (treats every pair as independent) — reported for comparison only.
r_ci_lo_pairs, r_ci_hi_pairs, boot_rs_pairs = bootstrap_corr_ci(all_sims_a, all_deltas_aa, n_boot=10000)
print(f"  Pair-level bootstrap 95% CI for Pearson r: [{r_ci_lo_pairs:.4f}, {r_ci_hi_pairs:.4f}]")

# Sample-level (cluster) bootstrap: the CI used for verdict criterion (a).
r_ci_lo, r_ci_hi, boot_rs = _cluster_bootstrap_corr_ci(
    all_sims_a, all_deltas_aa, pair_sample_ids_a, n_boot=10000
)
print(f"  Sample-level (cluster) bootstrap 95% CI for Pearson r: [{r_ci_lo:.4f}, {r_ci_hi:.4f}]")
ci_excludes_zero = r_ci_lo > 0 or r_ci_hi < 0
print(f"  CI excludes 0: {ci_excludes_zero}")

# Bin-level Spearman over the bins that had enough data, in SIM_BINS_A order.
bin_order_a = [bn for _, _, bn in SIM_BINS_A if bn in bin_stats_a]
if len(bin_order_a) >= 3:
    bin_sims_ord = [bin_stats_a[b]['mean_sim'] for b in bin_order_a]
    bin_deltas_ord = [bin_stats_a[b]['mean_delta'] for b in bin_order_a]
    rho_bins_a, p_bins_a = stats.spearmanr(bin_sims_ord, bin_deltas_ord)
    print(f"\n  Bin-level Spearman rho = {rho_bins_a:.3f}, p = {p_bins_a:.4f}")
else:
    # Too few bins for a trend: treat as no evidence of monotonicity.
    rho_bins_a, p_bins_a = 0.0, 1.0

# --- Partial correlation: delta ~ similarity | prefix_token_length ---
# Checks whether the similarity-delta relation survives after removing the linear
# effect of prefix token length (longer surrogates could help regardless of meaning).
# The pairs are the same (sample, bin) pairs as in the correlation test, restricted to
# those with a recorded ptl.
print("\n--- Confound Control: Partial Correlation ---")
all_ptls = []
all_sims_pc = []
all_deltas_pc = []
for r in results_a:
    for _, _, bin_name in SIM_BINS_A:
        if (r.get(f'nll_{bin_name}') is not None and 
            r.get(f'sim_{bin_name}') is not None and
            r.get(f'ptl_{bin_name}') is not None):
            all_sims_pc.append(r[f'sim_{bin_name}'])
            all_deltas_pc.append(r['nll_bare'] - r[f'nll_{bin_name}'])
            all_ptls.append(r[f'ptl_{bin_name}'])

all_sims_pc = np.array(all_sims_pc)
all_deltas_pc = np.array(all_deltas_pc)
all_ptls = np.array(all_ptls, dtype=float)

# Partial correlation: regress out prefix_token_length from both
# (ordinary least squares on [ptl, 1], then correlate the residuals).
# NOTE(review): pearsonr's p-value here assumes independent pairs and uses n-2 degrees
# of freedom (a partial correlation has n-3); with several pairs per sample it is
# optimistic.
from numpy.linalg import lstsq
X = np.column_stack([all_ptls, np.ones(len(all_ptls))])
res_sim = all_sims_pc - X @ lstsq(X, all_sims_pc, rcond=None)[0]
res_delta = all_deltas_pc - X @ lstsq(X, all_deltas_pc, rcond=None)[0]
r_partial, p_partial = stats.pearsonr(res_sim, res_delta)
print(f"  Partial r (sim ~ delta | prefix_len) = {r_partial:.4f}, p = {p_partial:.2e}")

# --- High-beats-oracle check ---
print("\n--- High-Beats-Oracle Distribution Check ---")
has_high = [r for r in results_a if r.get('nll_sim_0.85') is not None]
no_high = [r for r in results_a if r.get('nll_sim_0.85') is None]
if len(has_high) >= 10 and len(no_high) >= 10:
    bare_has = np.array([r['nll_bare'] for r in has_high])
    bare_no = np.array([r['nll_bare'] for r in no_high])
    t_hb, p_hb = stats.ttest_ind(bare_has, bare_no)
    print(f"  Bare NLL (has sim_0.85 match): mean={np.mean(bare_has):.4f}, N={len(has_high)}")
    print(f"  Bare NLL (no sim_0.85 match):  mean={np.mean(bare_no):.4f}, N={len(no_high)}")
    print(f"  t={t_hb:.3f}, p={p_hb:.4f}")
    if p_hb < 0.05:
        print("  WARNING: Samples with high-sim matches have different bare NLL — selection bias!")
    else:
        print("  OK: No significant difference — no selection bias detected.")

# --- Visualization ---
# Six-panel summary of Investigation A. Reads values computed earlier in this
# cell (bin_stats_a, all_sims_a / all_deltas_aa, boot_rs, r_partial, ...) and in
# earlier cells (oracle_wr, bare_nlls_a, oracle_nlls_a, and the control arrays
# used by Plot 4). None of those are defined in this block.
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# Plot 1: Win rates by condition
# 'bare' is a fixed 50% reference bar (not a measured rate). Each populated
# similarity bin is labelled with its mean similarity, and oracle is drawn last.
ax = axes[0, 0]
conds = ['bare']
wrs = [50.0]
for _, _, bn in SIM_BINS_A:
    if bn in bin_stats_a:
        conds.append(f"{bn}\n({bin_stats_a[bn]['mean_sim']:.2f})")
        wrs.append(bin_stats_a[bn]['win_rate'])
conds.append('oracle')
wrs.append(oracle_wr)
# Grey for bare, blue for every similarity bin, red for oracle.
colors = ['#888888'] + ['#4c72b0'] * (len(conds)-2) + ['#c44e52']
ax.bar(range(len(conds)), wrs, color=colors)
ax.axhline(50, color='gray', linestyle='--', linewidth=0.8)
ax.set_xticks(range(len(conds)))
ax.set_xticklabels(conds, fontsize=6, rotation=45, ha='right')
ax.set_ylabel('Win Rate vs Bare (%)')
ax.set_title('Win Rate by Surrogate Quality')

# Plot 2: Scatter sim vs delta
# Faint per-pair cloud, bin means as large red markers, and a degree-1
# least-squares fit line labelled with the pooled Pearson r.
ax = axes[0, 1]
ax.scatter(all_sims_a, all_deltas_aa, alpha=0.05, s=3, color='#4c72b0')
for _, _, bn in SIM_BINS_A:
    if bn in bin_stats_a:
        ax.scatter(bin_stats_a[bn]['mean_sim'], bin_stats_a[bn]['mean_delta'],
                  s=100, color='#c44e52', zorder=5, edgecolor='black')
ax.axhline(0, color='gray', linestyle='--', linewidth=0.8)
z = np.polyfit(all_sims_a, all_deltas_aa, 1)
x_line = np.linspace(all_sims_a.min(), all_sims_a.max(), 100)
ax.plot(x_line, np.poly1d(z)(x_line), 'r-', linewidth=2, label=f'r={r_pearson_a:.4f}')
ax.set_xlabel('Surrogate-Query Similarity')
ax.set_ylabel('Delta NLL')
ax.set_title(f'Similarity vs Improvement (r={r_pearson_a:.4f})')
ax.legend()

# Plot 3: NLL gradient
# Mean NLL for bare -> each populated bin -> oracle, with standard-error bars.
# Errors use np.std's default ddof=0.
# NOTE(review): each bin's SEM is over its own valid subset, so the points are
# not computed on a common sample set.
ax = axes[0, 2]
cl = ['bare']
cn = [np.mean(bare_nlls_a)]
ce = [np.std(bare_nlls_a)/np.sqrt(len(bare_nlls_a))]
for _, _, bn in SIM_BINS_A:
    if bn in bin_stats_a:
        cl.append(bn)
        cn.append(bin_stats_a[bn]['mean_nll'])
        valid = [r for r in results_a if r.get(f'nll_{bn}') is not None]
        ce.append(np.std([r[f'nll_{bn}'] for r in valid])/np.sqrt(len(valid)))
cl.append('oracle')
cn.append(np.mean(oracle_nlls_a))
ce.append(np.std(oracle_nlls_a)/np.sqrt(len(oracle_nlls_a)))
ax.errorbar(range(len(cl)), cn, yerr=ce, fmt='o-', color='#4c72b0', capsize=3)
ax.set_xticks(range(len(cl)))
ax.set_xticklabels(cl, fontsize=6, rotation=45, ha='right')
ax.set_ylabel('Mean NLL')
ax.set_title('NLL Gradient: Bare to Oracle')

# Plot 4: Controls comparison
# Mean delta NLL (bare - condition) for each control with >= 10 valid samples.
# valid_shuf / delta_real / delta_shuf / valid_lm / lm_nlls / valid_raw /
# raw_060 come from an earlier cell. They presumably are per-sample arrays
# aligned with the matching valid_* lists; confirm against that cell.
# NOTE(review): colours are taken by position, so if the real/shuffled pair is
# skipped the remaining bars shift onto the first palette entries.
ax = axes[1, 0]
ctrl_labels = []
ctrl_deltas_mean = []
ctrl_deltas_err = []
if len(valid_shuf) >= 10:
    ctrl_labels.extend(['sim_0.60\n(real)', 'shuffled\n_0.60'])
    ctrl_deltas_mean.extend([np.mean(delta_real), np.mean(delta_shuf)])
    ctrl_deltas_err.extend([np.std(delta_real)/np.sqrt(len(delta_real)),
                            np.std(delta_shuf)/np.sqrt(len(delta_shuf))])
if len(valid_lm) >= 10:
    bares_lm2 = np.array([r['nll_bare'] for r in valid_lm])
    ctrl_labels.append('len_match\nrandom')
    ctrl_deltas_mean.append(float(np.mean(bares_lm2 - lm_nlls)))
    ctrl_deltas_err.append(float(np.std(bares_lm2 - lm_nlls)/np.sqrt(len(lm_nlls))))
if len(valid_raw) >= 10:
    bares_raw = np.array([r['nll_bare'] for r in valid_raw])
    ctrl_labels.append('raw_query\n_0.60')
    ctrl_deltas_mean.append(float(np.mean(bares_raw - raw_060)))
    ctrl_deltas_err.append(float(np.std(bares_raw - raw_060)/np.sqrt(len(raw_060))))
if ctrl_labels:
    ax.bar(range(len(ctrl_labels)), ctrl_deltas_mean, yerr=ctrl_deltas_err,
           color=['#4c72b0', '#e377c2', '#8c564b', '#bcbd22'][:len(ctrl_labels)], capsize=3)
    ax.set_xticks(range(len(ctrl_labels)))
    ax.set_xticklabels(ctrl_labels, fontsize=7)
    ax.axhline(0, color='gray', linestyle='--')
    ax.set_ylabel('Mean Delta NLL')
    ax.set_title('Control Comparisons')

# Plot 5: Bootstrap CI for Pearson r
# Bootstrap distribution with the observed r (red), the CI bounds (orange) and
# zero (grey). Criterion (a) asks whether zero lies outside the orange lines.
ax = axes[1, 1]
ax.hist(boot_rs, bins=50, color='#4c72b0', alpha=0.7, edgecolor='black', linewidth=0.3)
ax.axvline(r_pearson_a, color='red', linewidth=2, label=f'r={r_pearson_a:.4f}')
ax.axvline(r_ci_lo, color='orange', linestyle='--', label=f'95% CI: [{r_ci_lo:.4f}, {r_ci_hi:.4f}]')
ax.axvline(r_ci_hi, color='orange', linestyle='--')
ax.axvline(0, color='gray', linestyle=':', linewidth=1)
ax.set_xlabel('Pearson r')
ax.set_ylabel('Count')
ax.set_title('Bootstrap Distribution of Pearson r')
ax.legend(fontsize=7)

# Plot 6: Prefix length vs delta (confound check)
# Raw prefix token length vs delta (not the residualized values behind
# r_partial). The title reports the partial correlation for reference.
ax = axes[1, 2]
ax.scatter(all_ptls, all_deltas_pc, alpha=0.03, s=3, color='#4c72b0')
ax.set_xlabel('Prefix Token Length')
ax.set_ylabel('Delta NLL')
ax.set_title(f'Prefix Length Confound Check\n(partial r={r_partial:.4f})')
ax.axhline(0, color='gray', linestyle='--')

plt.tight_layout()
# Assumes results/exp12/ already exists; it is not created in this block.
plt.savefig('results/exp12/12_investigation_a.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 12_investigation_a.png')

In [None]:
# ============================================================
# Investigation B: SQuAD v2 — Data Prep + Surrogate Selection
# ============================================================

print("=" * 80)
print("INVESTIGATION B: CROSS-DATASET REPLICATION")
print("=" * 80)

# --- SQuAD v2 ---
n_squad = len(squad_samples)
print(f"\nSQuAD v2: {n_squad} samples")

print("Embedding SQuAD target queries...")
squad_target_qs = [sample['query'] for sample in squad_samples]
squad_target_embs = embed_model.encode(squad_target_qs, show_progress_bar=True, batch_size=256)

# rng_sq is shared by every lookup below. Assuming find_surrogate_at_similarity
# draws from it, the chosen surrogates depend on the (sample-major, bin-minor)
# call order.
rng_sq = np.random.RandomState(SEEDS['squad'])
sq_skip = dict.fromkeys((bin_spec[2] for bin_spec in SIM_BINS_B), 0)

# squad_surrogates[i] maps bin_name -> whatever find_surrogate_at_similarity
# returned (unpacked later as (surrogate_query, similarity)). Bins with no
# match are omitted and tallied in sq_skip.
squad_surrogates = []
for target_q, target_emb in tqdm(zip(squad_target_qs, squad_target_embs),
                                 total=n_squad, desc="SQuAD surrogates"):
    per_bin = {}
    for lo, hi, bn in SIM_BINS_B:
        found = find_surrogate_at_similarity(
            target_q, target_emb,
            lo, hi,
            squad_query_pool, squad_pool_embs, rng_sq
        )
        if found is None:
            sq_skip[bn] += 1
        else:
            per_bin[bn] = found
    squad_surrogates.append(per_bin)

print("SQuAD surrogate coverage:")
for _, _, bn in SIM_BINS_B:
    print(f"  {bn}: {n_squad - sq_skip[bn]}/{n_squad}")

In [None]:
# ============================================================
# Investigation B: SQuAD v2 Evaluation Loop
# ============================================================

# Evaluates bare, oracle, and every available surrogate-bin prefix on SQuAD v2,
# checkpointing so an interrupted run can resume where it stopped.
results_squad = []
skipped_sq = 0
errors_sq = 0
start_sq = time.time()

CHECKPOINT_PATH_SQ = 'results/exp12/12_checkpoint_squad.json'


def _save_checkpoint_squad():
    """Persist SQuAD progress (results plus skip/error counters) for resumption.

    The JSON is written to a temporary file and then renamed over the
    checkpoint. A crash mid-write therefore cannot leave a truncated file that
    would make the next resume fail in json.load.
    """
    tmp_path = CHECKPOINT_PATH_SQ + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump({'experiment': '12_squad', 'results': results_squad,
                   'skipped': skipped_sq, 'errors': errors_sq}, f)
    os.replace(tmp_path, CHECKPOINT_PATH_SQ)


if os.path.exists(CHECKPOINT_PATH_SQ):
    with open(CHECKPOINT_PATH_SQ) as f:
        ckpt = json.load(f)
    if ckpt.get('experiment') == '12_squad':
        results_squad = ckpt['results']
        skipped_sq = ckpt['skipped']
        errors_sq = ckpt['errors']
        print(f"Resumed: {len(results_squad)} results")

# Each processed index is counted exactly once, as a result, a skip or an
# error, so the sum is the next index to evaluate.
start_idx_sq = len(results_squad) + skipped_sq + errors_sq

for idx in tqdm(range(start_idx_sq, len(squad_samples)), desc="SQuAD",
                initial=start_idx_sq, total=len(squad_samples)):
    sample = squad_samples[idx]
    passage = sample['passage']
    query = sample['query']
    answer = sample['answer']

    # Answers shorter than 2 tokens are skipped (and counted for resume).
    answer_ids = tokenizer(answer, return_tensors='pt', add_special_tokens=False)['input_ids']
    if answer_ids.shape[1] < 2:
        skipped_sq += 1
        continue

    query_prompt = config.query_template.format(query=query)
    surrogates = squad_surrogates[idx]

    try:
        result = {'idx': idx, 'query': query}

        oracle_prefix = f"This document answers: {query}"
        bare_len, bare_cache, oracle_len, oracle_cache, _ = \
            build_matched_bare_and_truncated(oracle_prefix, passage, model, tokenizer, config)

        result['nll_bare'] = score_answer_with_cache(
            bare_cache, bare_len, query_prompt, answer, model, tokenizer, config)
        result['nll_oracle'] = score_answer_with_cache(
            oracle_cache, oracle_len, query_prompt, answer, model, tokenizer, config)

        # Bins without a surrogate are recorded as None so analysis can filter them.
        for _, _, bin_name in SIM_BINS_B:
            if bin_name in surrogates:
                sq, ss = surrogates[bin_name]
                sp = f"This document answers: {sq}"
                _, _, sl, sc, _ = build_matched_bare_and_truncated(sp, passage, model, tokenizer, config)
                result[f'nll_{bin_name}'] = score_answer_with_cache(
                    sc, sl, query_prompt, answer, model, tokenizer, config)
                result[f'sim_{bin_name}'] = ss
            else:
                result[f'nll_{bin_name}'] = None
                result[f'sim_{bin_name}'] = None

        results_squad.append(result)
    except Exception as e:
        errors_sq += 1
        if errors_sq <= 3:
            print(f"\n  Error: {e}")
        continue
    finally:
        torch.cuda.empty_cache()

    # Periodic save; only reached right after a successful sample.
    if len(results_squad) % 25 == 0:
        _save_checkpoint_squad()
        elapsed = time.time() - start_sq
        print(f"\n  [{len(results_squad)} done | {elapsed/60:.0f}m]")

# Flush the final partial batch (< 25 results), which the periodic save misses.
_save_checkpoint_squad()

print(f"\nSQuAD done. {len(results_squad)} evaluated, {skipped_sq} skipped, {errors_sq} errors.")
print(f"Time: {(time.time()-start_sq)/60:.1f} min")

In [None]:
# ============================================================
# Investigation B: TriviaQA — Data Prep + Evaluation Loop
# ============================================================

n_tqa = len(tqa_samples)
print(f"\nTriviaQA: {n_tqa} samples")

print("Embedding TriviaQA target queries...")
tqa_target_qs = [sample['query'] for sample in tqa_samples]
tqa_target_embs = embed_model.encode(tqa_target_qs, show_progress_bar=True, batch_size=256)

# rng_tq is shared by every lookup below. Assuming find_surrogate_at_similarity
# draws from it, the chosen surrogates depend on the (sample-major, bin-minor)
# call order.
rng_tq = np.random.RandomState(SEEDS['triviaqa'])
tq_skip = dict.fromkeys((bin_spec[2] for bin_spec in SIM_BINS_B), 0)

# tqa_surrogates[i] maps bin_name -> the (surrogate_query, similarity) match.
# Bins without a match are omitted and tallied in tq_skip.
tqa_surrogates = []
for target_q, target_emb in tqdm(zip(tqa_target_qs, tqa_target_embs),
                                 total=n_tqa, desc="TQA surrogates"):
    per_bin = {}
    for lo, hi, bn in SIM_BINS_B:
        found = find_surrogate_at_similarity(
            target_q, target_emb,
            lo, hi,
            tqa_query_pool, tqa_pool_embs, rng_tq
        )
        if found is None:
            tq_skip[bn] += 1
        else:
            per_bin[bn] = found
    tqa_surrogates.append(per_bin)

print("TriviaQA surrogate coverage:")
for _, _, bn in SIM_BINS_B:
    print(f"  {bn}: {n_tqa - tq_skip[bn]}/{n_tqa}")

# --- TriviaQA Evaluation ---
# Evaluates bare, oracle, and every available surrogate-bin prefix on TriviaQA,
# checkpointing so an interrupted run can resume where it stopped.
results_tqa = []
skipped_tq = 0
errors_tq = 0
start_tq = time.time()

CHECKPOINT_PATH_TQ = 'results/exp12/12_checkpoint_tqa.json'


def _save_checkpoint_tqa():
    """Persist TriviaQA progress (results plus skip/error counters) for resumption.

    The JSON is written to a temporary file and then renamed over the
    checkpoint. A crash mid-write therefore cannot leave a truncated file that
    would make the next resume fail in json.load.
    """
    tmp_path = CHECKPOINT_PATH_TQ + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump({'experiment': '12_tqa', 'results': results_tqa,
                   'skipped': skipped_tq, 'errors': errors_tq}, f)
    os.replace(tmp_path, CHECKPOINT_PATH_TQ)


if os.path.exists(CHECKPOINT_PATH_TQ):
    with open(CHECKPOINT_PATH_TQ) as f:
        ckpt = json.load(f)
    if ckpt.get('experiment') == '12_tqa':
        results_tqa = ckpt['results']
        skipped_tq = ckpt['skipped']
        errors_tq = ckpt['errors']
        print(f"Resumed: {len(results_tqa)} results")

# Each processed index is counted exactly once, as a result, a skip or an
# error, so the sum is the next index to evaluate.
start_idx_tq = len(results_tqa) + skipped_tq + errors_tq

for idx in tqdm(range(start_idx_tq, len(tqa_samples)), desc="TriviaQA",
                initial=start_idx_tq, total=len(tqa_samples)):
    sample = tqa_samples[idx]
    passage = sample['passage']
    query = sample['query']
    answer = sample['answer']

    # Answers shorter than 2 tokens are skipped (and counted for resume).
    answer_ids = tokenizer(answer, return_tensors='pt', add_special_tokens=False)['input_ids']
    if answer_ids.shape[1] < 2:
        skipped_tq += 1
        continue

    query_prompt = config.query_template.format(query=query)
    surrogates = tqa_surrogates[idx]

    try:
        result = {'idx': idx, 'query': query}

        oracle_prefix = f"This document answers: {query}"
        bare_len, bare_cache, oracle_len, oracle_cache, _ = \
            build_matched_bare_and_truncated(oracle_prefix, passage, model, tokenizer, config)

        result['nll_bare'] = score_answer_with_cache(
            bare_cache, bare_len, query_prompt, answer, model, tokenizer, config)
        result['nll_oracle'] = score_answer_with_cache(
            oracle_cache, oracle_len, query_prompt, answer, model, tokenizer, config)

        # Bins without a surrogate are recorded as None so analysis can filter them.
        for _, _, bin_name in SIM_BINS_B:
            if bin_name in surrogates:
                sq, ss = surrogates[bin_name]
                sp = f"This document answers: {sq}"
                _, _, sl, sc, _ = build_matched_bare_and_truncated(sp, passage, model, tokenizer, config)
                result[f'nll_{bin_name}'] = score_answer_with_cache(
                    sc, sl, query_prompt, answer, model, tokenizer, config)
                result[f'sim_{bin_name}'] = ss
            else:
                result[f'nll_{bin_name}'] = None
                result[f'sim_{bin_name}'] = None

        results_tqa.append(result)
    except Exception as e:
        errors_tq += 1
        if errors_tq <= 3:
            print(f"\n  Error: {e}")
        continue
    finally:
        torch.cuda.empty_cache()

    # Periodic save; only reached right after a successful sample.
    if len(results_tqa) % 25 == 0:
        _save_checkpoint_tqa()
        elapsed = time.time() - start_tq
        print(f"\n  [{len(results_tqa)} done | {elapsed/60:.0f}m]")

# Flush the final partial batch (< 25 results), which the periodic save misses.
_save_checkpoint_tqa()

print(f"\nTriviaQA done. {len(results_tqa)} evaluated, {skipped_tq} skipped, {errors_tq} errors.")
print(f"Time: {(time.time()-start_tq)/60:.1f} min")

In [None]:
# ============================================================
# Investigation B: Cross-Dataset Analysis
# ============================================================

# Section banner for the cross-dataset analysis output.
_rule_b = "=" * 80
print("\n".join([_rule_b, "INVESTIGATION B: CROSS-DATASET ANALYSIS", _rule_b]))

def _paired_cohens_d(deltas):
    """Return mean(deltas) / sd(deltas, ddof=1), or 0.0 when that is undefined.

    The ratio is undefined for fewer than two values or zero spread. Returning
    0.0 in those cases avoids reporting ``inf``/``nan`` effect sizes.
    """
    deltas = np.asarray(deltas, dtype=float)
    if deltas.size < 2:
        return 0.0
    sd = np.std(deltas, ddof=1)
    if not sd > 0:  # also catches a nan sd
        return 0.0
    return float(np.mean(deltas) / sd)


def analyze_dataset(results, dataset_name, bins):
    """Print and return per-similarity-bin improvement statistics for one dataset.

    Args:
        results: list of per-sample dicts. Each must contain 'nll_bare' and
            'nll_oracle'. For each bin name ``bn`` it may contain 'nll_{bn}' and
            'sim_{bn}'; a missing or None 'nll_{bn}' means no surrogate was
            scored for that bin.
        dataset_name: label used in the printed header.
        bins: sequence of (sim_low, sim_high, bin_name) tuples. Only bin_name is
            used; the sequence order fixes the bin order for the monotonicity check.

    Returns:
        dict mapping bin_name -> {'n', 'mean_sim', 'mean_delta', 'win_rate',
        'cohens_d'} for every bin with at least 10 scored samples. Deltas are
        nll_bare - nll_bin, so positive means the surrogate prefix lowered NLL.
        An empty ``results`` returns {}.
    """
    print(f"\n--- {dataset_name} (N={len(results)}) ---")
    if not results:
        # Nothing to summarise; avoid nan means and empty-array warnings.
        print("  No results to analyze.")
        return {}

    bare = np.array([r['nll_bare'] for r in results])
    oracle = np.array([r['nll_oracle'] for r in results])
    od = bare - oracle
    owr = np.mean(od > 0) * 100
    print(f"  Oracle: win%={owr:.1f}%, delta={np.mean(od):+.4f}, d={_paired_cohens_d(od):.3f}")

    bstats = {}
    for _, _, bn in bins:
        valid = [r for r in results if r.get(f'nll_{bn}') is not None]
        if len(valid) < 10:  # too few matched surrogates for stable bin statistics
            continue
        nlls = np.array([r[f'nll_{bn}'] for r in valid])
        bares = np.array([r['nll_bare'] for r in valid])
        sims = np.array([r[f'sim_{bn}'] for r in valid])
        deltas = bares - nlls
        wr = np.mean(deltas > 0) * 100
        d = _paired_cohens_d(deltas)
        bstats[bn] = {'n': len(valid), 'mean_sim': float(np.mean(sims)),
                      'mean_delta': float(np.mean(deltas)), 'win_rate': float(wr),
                      'cohens_d': d}
        print(f"  {bn} (sim~{np.mean(sims):.2f}): N={len(valid)}, win%={wr:.1f}%, d={d:.3f}")

    # Monotonicity: Spearman rank correlation between bin mean similarity and
    # bin mean delta, in the order given by `bins`. It needs at least 3 bins.
    ordered = [b[2] for b in bins if b[2] in bstats]
    if len(ordered) >= 3:
        rho, p = stats.spearmanr([bstats[bn]['mean_sim'] for bn in ordered],
                                 [bstats[bn]['mean_delta'] for bn in ordered])
        if np.isnan(rho):
            print("  Bin-level Spearman undefined (constant bin values)")
        else:
            print(f"  Bin-level Spearman rho = {rho:.3f}, p = {p:.4f}")
    elif len(ordered) == 2:
        print("  Too few bins for Spearman")

    return bstats

squad_stats = analyze_dataset(results_squad, "SQuAD v2", SIM_BINS_B)
tqa_stats = analyze_dataset(results_tqa, "TriviaQA", SIM_BINS_B)

# Cross-dataset comparison: one panel per dataset tracing mean delta NLL from a
# bare anchor drawn at (sim=0, delta=0), through each populated similarity bin,
# to the oracle drawn at sim=1.
# Each panel is (per-sample results, title, per-bin stats, bin definitions).
panels_b = [
    (results_a, "MS MARCO", bin_stats_a, SIM_BINS_A),
    (results_squad, "SQuAD v2", squad_stats, SIM_BINS_B),
    (results_tqa, "TriviaQA", tqa_stats, SIM_BINS_B),
]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for ax, (results, dname, bstats, bins_used) in zip(axes, panels_b):
    present = [bn for _, _, bn in bins_used if bn in bstats]
    bare = np.array([r['nll_bare'] for r in results])
    oracle = np.array([r['nll_oracle'] for r in results])
    sims_plot = [0.0, *(bstats[bn]['mean_sim'] for bn in present), 1.0]
    deltas_plot = [0.0,
                   *(bstats[bn]['mean_delta'] for bn in present),
                   float(np.mean(bare - oracle))]
    ax.plot(sims_plot, deltas_plot, 'o-', color='#4c72b0', linewidth=2, markersize=8)
    ax.axhline(0, color='gray', linestyle='--', linewidth=0.8)
    ax.set_xlabel('Surrogate Similarity')
    ax.set_ylabel('Mean Delta NLL')
    ax.set_title(f'{dname} (N={len(results)})')

plt.tight_layout()
plt.savefig('results/exp12/12_investigation_b.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 12_investigation_b.png')

In [None]:
# ============================================================
# Investigation C: Prefix Format Ablation (N=400)
# ============================================================

print("=" * 80)
print("INVESTIGATION C: PREFIX FORMAT ABLATION")
print("=" * 80)

samples_c = all_samples[:N_INV_C]
print(f"Using {len(samples_c)} samples")

# Embed the target queries, then draw one surrogate per sample from the
# 0.55-0.65 similarity band (nominal sim~0.60). Boundary inclusivity is decided
# by find_surrogate_at_similarity. A None entry means no surrogate was found.
print("Embedding target queries for Inv C...")
target_qs_c = [s['query'] for s in samples_c]
target_embs_c = embed_model.encode(target_qs_c, batch_size=256)

SIM_BAND_C = (0.55, 0.65)
rng_c = np.random.RandomState(SEEDS['msmarco'] + 100)
surrogates_c = [
    find_surrogate_at_similarity(
        target_q, target_emb,
        SIM_BAND_C[0], SIM_BAND_C[1], query_pool, pool_embeddings, rng_c
    )
    for target_q, target_emb in tqdm(zip(target_qs_c, target_embs_c),
                                     total=len(samples_c), desc="C surrogates")
]

n_found_c = sum(s is not None for s in surrogates_c)
print(f"Found sim~0.60 surrogates: {n_found_c}/{len(samples_c)}")

# Evaluation loop
# Scores bare, oracle and five prefix formats built from the same sim~0.60
# surrogate, checkpointing so an interrupted run can resume.
results_c = []
skipped_c = 0
errors_c = 0
start_c = time.time()

CHECKPOINT_PATH_C = 'results/exp12/12_checkpoint_c.json'


def _save_checkpoint_c():
    """Persist Investigation C progress (results plus skip/error counters).

    The JSON is written to a temporary file and then renamed over the
    checkpoint, so a crash mid-write cannot corrupt the resume file.
    """
    tmp_path = CHECKPOINT_PATH_C + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump({'experiment': '12_inv_c', 'results': results_c,
                   'skipped': skipped_c, 'errors': errors_c}, f)
    os.replace(tmp_path, CHECKPOINT_PATH_C)


if os.path.exists(CHECKPOINT_PATH_C):
    with open(CHECKPOINT_PATH_C) as f:
        ckpt = json.load(f)
    if ckpt.get('experiment') == '12_inv_c':
        results_c = ckpt['results']
        errors_c = ckpt['errors']
        if 'skipped' in ckpt:
            skipped_c = ckpt['skipped']
        elif results_c:
            # Older checkpoints did not record skips. They were always written
            # right after a successful sample, so indices 0..last idx were all
            # processed; whatever is not a result or an error was a skip.
            skipped_c = max(0, results_c[-1]['idx'] + 1 - len(results_c) - errors_c)
        print(f"Resumed: {len(results_c)} results")

# Every processed index is counted once (result, skip or error), so the sum is
# the next index. Previously skips were not counted, which made a resumed run
# restart too early and append duplicate rows.
start_idx_c = len(results_c) + skipped_c + errors_c

for idx in tqdm(range(start_idx_c, len(samples_c)), desc="Inv C",
                initial=start_idx_c, total=len(samples_c)):
    sample = samples_c[idx]
    passage = sample['passage']
    query = sample['query']
    answer = sample['answer']

    # Skip answers shorter than 2 tokens and samples without a sim~0.60
    # surrogate. Both count towards the resume index.
    answer_ids = tokenizer(answer, return_tensors='pt', add_special_tokens=False)['input_ids']
    if answer_ids.shape[1] < 2 or surrogates_c[idx] is None:
        skipped_c += 1
        continue

    surr_query, surr_sim = surrogates_c[idx]
    query_prompt = config.query_template.format(query=query)

    try:
        result = {'idx': idx, 'surr_sim': surr_sim}

        # Bare + Oracle
        oracle_prefix = f"This document answers: {query}"
        bare_len, bare_cache, oracle_len, oracle_cache, _ = \
            build_matched_bare_and_truncated(oracle_prefix, passage, model, tokenizer, config)
        result['nll_bare'] = score_answer_with_cache(
            bare_cache, bare_len, query_prompt, answer, model, tokenizer, config)
        result['nll_oracle'] = score_answer_with_cache(
            oracle_cache, oracle_len, query_prompt, answer, model, tokenizer, config)

        # Template format
        prefix_tmpl = f"This document answers: {surr_query}"
        _, _, tl, tc, _ = build_matched_bare_and_truncated(prefix_tmpl, passage, model, tokenizer, config)
        result['nll_template'] = score_answer_with_cache(tc, tl, query_prompt, answer, model, tokenizer, config)

        # Raw format (just the query)
        _, _, rl, rc, _ = build_matched_bare_and_truncated(surr_query, passage, model, tokenizer, config)
        result['nll_raw'] = score_answer_with_cache(rc, rl, query_prompt, answer, model, tokenizer, config)

        # Question format
        prefix_q = f"Question: {surr_query}\n\nPassage:"
        _, _, ql, qc, _ = build_matched_bare_and_truncated(prefix_q, passage, model, tokenizer, config)
        result['nll_question'] = score_answer_with_cache(qc, ql, query_prompt, answer, model, tokenizer, config)

        # Instruction format
        prefix_i = f"Find information about: {surr_query}\n\n"
        _, _, il, ic, _ = build_matched_bare_and_truncated(prefix_i, passage, model, tokenizer, config)
        result['nll_instruction'] = score_answer_with_cache(ic, il, query_prompt, answer, model, tokenizer, config)

        # Shuffled template
        # NOTE(review): rng_c is also the surrogate-selection RNG and its state is
        # not checkpointed, so a resumed run shuffles differently from an
        # uninterrupted one. A per-sample seed would make this resume-invariant.
        shuffled_q = shuffle_query_words(surr_query, rng_c)
        prefix_st = f"This document answers: {shuffled_q}"
        _, _, stl, stc, _ = build_matched_bare_and_truncated(prefix_st, passage, model, tokenizer, config)
        result['nll_shuffled_template'] = score_answer_with_cache(stc, stl, query_prompt, answer, model, tokenizer, config)

        results_c.append(result)
    except Exception as e:
        errors_c += 1
        if errors_c <= 3:
            print(f"\n  Error: {e}")
        continue
    finally:
        torch.cuda.empty_cache()

    # Periodic save; only reached right after a successful sample.
    if len(results_c) % 25 == 0:
        _save_checkpoint_c()
        elapsed = time.time() - start_c
        print(f"\n  [{len(results_c)} done | {elapsed/60:.0f}m]")

# Flush the final partial batch (< 25 results), which the periodic save misses.
_save_checkpoint_c()

print(f"\nInv C done. {len(results_c)} evaluated, {skipped_c} skipped, {errors_c} errors.")
print(f"Time: {(time.time()-start_c)/60:.1f} min")

In [None]:
# ============================================================
# Investigation C: Analysis
# ============================================================

print("=" * 80)
print("INVESTIGATION C RESULTS: PREFIX FORMAT ABLATION")
print("=" * 80)

# (result-key suffix, display label, bar colour) for each prefix condition.
FORMAT_SPECS_C = [
    ('oracle', 'Oracle', '#c44e52'),
    ('template', 'Template', '#4c72b0'),
    ('raw', 'Raw query', '#55a868'),
    ('question', 'Question:', '#e377c2'),
    ('instruction', 'Instruction:', '#8c564b'),
    ('shuffled_template', 'Shuffled template', '#bcbd22'),
]
formats = [key for key, _, _ in FORMAT_SPECS_C]
format_labels = [label for _, label, _ in FORMAT_SPECS_C]
colors_fmt = [color for _, _, color in FORMAT_SPECS_C]


def _nll_column_c(key):
    """Collect results_c[*][key] into a NumPy array, preserving row order."""
    return np.array([row[key] for row in results_c])


n_c = len(results_c)
bare_c = _nll_column_c('nll_bare')

print(f"\n{'Format':<25} {'N':>5} {'Mean NLL':>10} {'Win%':>8} {'Delta':>10} {'Cohen d':>10} {'p-value':>12}")
print("-" * 85)
print(f"{'Bare':<25} {n_c:>5} {np.mean(bare_c):>10.4f}")

# Each format vs bare: delta = bare - format (positive = format lowered NLL),
# paired t-test p-value, and paired Cohen's d (0 when the deltas are constant).
format_stats = {}
for fmt, label in zip(formats, format_labels):
    nlls = _nll_column_c(f'nll_{fmt}')
    deltas = bare_c - nlls
    mean_delta = np.mean(deltas)
    wr = np.mean(deltas > 0) * 100
    t, p = stats.ttest_rel(bare_c, nlls)
    if np.std(deltas) > 0:
        d = mean_delta / np.std(deltas, ddof=1)
    else:
        d = 0
    format_stats[fmt] = {'mean_nll': float(np.mean(nlls)), 'win_rate': float(wr),
                         'mean_delta': float(mean_delta), 'cohens_d': float(d), 'p_value': float(p)}
    print(f"{label:<25} {n_c:>5} {np.mean(nlls):>10.4f} {wr:>7.1f}% {mean_delta:>+10.4f} {d:>10.3f} {p:>12.2e}")

# Head-to-head comparisons that isolate the template wording and word order.
print("\n--- Key Comparisons ---")
tmpl_nlls = _nll_column_c('nll_template')
raw_nlls = _nll_column_c('nll_raw')
shuf_nlls = _nll_column_c('nll_shuffled_template')

t_tr, p_tr = stats.ttest_rel(tmpl_nlls, raw_nlls)
print(f"  Template vs Raw: t={t_tr:.3f}, p={p_tr:.4f}, template wins {np.mean(tmpl_nlls < raw_nlls)*100:.1f}%")

t_ts, p_ts = stats.ttest_rel(tmpl_nlls, shuf_nlls)
print(f"  Template vs Shuffled: t={t_ts:.3f}, p={p_ts:.4f}, template wins {np.mean(tmpl_nlls < shuf_nlls)*100:.1f}%")

# Visualization: one bar per format with its mean delta printed above/below it.
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
x = range(len(format_labels))
deltas_plot = [format_stats[key]['mean_delta'] for key in formats]
ax.bar(x, deltas_plot, color=colors_fmt)
ax.set_xticks(x)
ax.set_xticklabels(format_labels, rotation=30, ha='right')
ax.set_ylabel('Mean Delta NLL vs Bare')
ax.set_title('Prefix Format Comparison (sim~0.60 surrogates)')
ax.axhline(0, color='gray', linestyle='--')
for bar_pos, bar_delta in enumerate(deltas_plot):
    if bar_delta >= 0:
        label_y = bar_delta + 0.001
    else:
        label_y = bar_delta - 0.003
    ax.text(bar_pos, label_y, f'{bar_delta:.4f}', ha='center', fontsize=8)

plt.tight_layout()
plt.savefig('results/exp12/12_investigation_c.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 12_investigation_c.png')

In [None]:
# ============================================================
# Investigation D: Ranking with Bootstrap CIs (N=1000)
# ============================================================

print("=" * 80)
print("INVESTIGATION D: RANKING WITH BOOTSTRAP CIs")
print("=" * 80)

samples_d = all_samples[:N_INV_D]
print(f"Using {len(samples_d)} samples")

print("Embedding target queries for Inv D...")
target_qs_d = [s['query'] for s in samples_d]
target_embs_d = embed_model.encode(target_qs_d, batch_size=256)

# rng_d is shared by surrogate lookup and distractor sampling. Per sample, the
# draws happen in this order: the surrogate bins, then the distractor bands,
# then any random top-up.
rng_d = np.random.RandomState(SEEDS['msmarco'] + 200)

# Distractor recipe: (sim_low, sim_high, count) bands relative to the target
# query. If a band holds fewer than `count` candidates it contributes whatever
# it has, and the list is topped up with uniformly random pool queries.
DISTRACTOR_BANDS_D = [(0.1, 0.3, 2), (0.3, 0.5, 2)]
N_DISTRACTORS_D = 4


def _pick_distractors_d(target_emb):
    """Return N_DISTRACTORS_D distractor queries for one target embedding.

    NOTE(review): the random top-up does not exclude the target query itself
    (or duplicates) if it is present in query_pool.
    """
    sims_to_pool = cosine_similarity([target_emb], pool_embeddings)[0]
    picked = []
    for lo, hi, k in DISTRACTOR_BANDS_D:
        in_band = np.flatnonzero((sims_to_pool >= lo) & (sims_to_pool < hi))
        if len(in_band) >= k:
            chosen = rng_d.choice(in_band, size=k, replace=False)
        else:
            chosen = in_band[:k]
        picked += [query_pool[c] for c in chosen]
    while len(picked) < N_DISTRACTORS_D:
        picked.append(query_pool[rng_d.randint(0, len(query_pool))])
    return picked[:N_DISTRACTORS_D]


d_surrogates = []  # per sample: {bin_name: (surrogate_query, sim)} for bins with a match
d_distractors = []  # per sample: list of 4 distractor queries

for target_q, target_emb in tqdm(zip(target_qs_d, target_embs_d),
                                 total=len(samples_d), desc="D surrogates"):
    per_bin = {}
    for lo, hi, bn in SIM_BINS_B:
        found = find_surrogate_at_similarity(
            target_q, target_emb,
            lo, hi, query_pool, pool_embeddings, rng_d
        )
        if found is not None:
            per_bin[bn] = found
    d_surrogates.append(per_bin)
    d_distractors.append(_pick_distractors_d(target_emb))

print("Surrogate and distractor selection done.")

# Evaluation loop
# For each sample, builds bare, oracle and surrogate-prefixed caches, then scores
# the answer under the true query plus 4 distractor queries. The rank of the
# true query (index 0) among the 5 NLLs is recorded per condition.
results_d = []
skipped_d = 0
errors_d = 0
start_d = time.time()

CHECKPOINT_PATH_D = 'results/exp12/12_checkpoint_d.json'


def _save_checkpoint_d():
    """Persist Investigation D progress (results plus skip/error counters).

    The JSON is written to a temporary file and then renamed over the
    checkpoint, so a crash mid-write cannot corrupt the resume file.
    """
    tmp_path = CHECKPOINT_PATH_D + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump({'experiment': '12_inv_d', 'results': results_d,
                   'skipped': skipped_d, 'errors': errors_d}, f)
    os.replace(tmp_path, CHECKPOINT_PATH_D)


if os.path.exists(CHECKPOINT_PATH_D):
    with open(CHECKPOINT_PATH_D) as f:
        ckpt = json.load(f)
    if ckpt.get('experiment') == '12_inv_d':
        results_d = ckpt['results']
        errors_d = ckpt['errors']
        if 'skipped' in ckpt:
            skipped_d = ckpt['skipped']
        elif results_d:
            # Older checkpoints did not record skips. They were always written
            # right after a successful sample, so indices 0..last idx were all
            # processed; whatever is not a result or an error was a skip.
            skipped_d = max(0, results_d[-1]['idx'] + 1 - len(results_d) - errors_d)
        print(f"Resumed: {len(results_d)} results")

# Every processed index is counted once (result, skip or error), so the sum is
# the next index. Previously skips were not counted, which made a resumed run
# restart too early and append duplicate rows.
start_idx_d = len(results_d) + skipped_d + errors_d

for idx in tqdm(range(start_idx_d, len(samples_d)), desc="Inv D",
                initial=start_idx_d, total=len(samples_d)):
    sample = samples_d[idx]
    passage = sample['passage']
    query = sample['query']
    answer = sample['answer']

    # Answers shorter than 2 tokens are skipped (and counted for resume).
    answer_ids = tokenizer(answer, return_tensors='pt', add_special_tokens=False)['input_ids']
    if answer_ids.shape[1] < 2:
        skipped_d += 1
        continue

    distractors = d_distractors[idx]
    all_queries = [query] + distractors  # correct is index 0
    surrogates = d_surrogates[idx]

    try:
        result = {'idx': idx}

        # Build caches for each condition
        # bare + oracle + one per sim bin that found a surrogate
        oracle_prefix = f"This document answers: {query}"
        bare_len, bare_cache, oracle_len, oracle_cache, _ = \
            build_matched_bare_and_truncated(oracle_prefix, passage, model, tokenizer, config)

        cache_conditions = {
            'bare': (bare_len, bare_cache),
            'oracle': (oracle_len, oracle_cache),
        }

        for _, _, bn in SIM_BINS_B:
            if bn in surrogates:
                sq, _ = surrogates[bn]
                sp = f"This document answers: {sq}"
                _, _, sl, sc, _ = build_matched_bare_and_truncated(sp, passage, model, tokenizer, config)
                cache_conditions[bn] = (sl, sc)

        # Score all 5 queries under each cache condition. score_with_deepcopy is
        # used because the same cache is reused for every query.
        for cond_name, (clen, ccache) in cache_conditions.items():
            scores = []
            for q in all_queries:
                qp = config.query_template.format(query=q)
                nll = score_with_deepcopy(ccache, clen, qp, answer, model, tokenizer, config)
                scores.append(nll)
            # Rank: lower NLL = better match, correct is index 0.
            # NOTE(review): exact ties are ordered by np.argsort's default
            # (non-stable) sort.
            rank = int(np.argsort(scores).tolist().index(0)) + 1
            result[f'scores_{cond_name}'] = scores
            result[f'rank_{cond_name}'] = rank

        results_d.append(result)
    except Exception as e:
        errors_d += 1
        if errors_d <= 3:
            print(f"\n  Error: {e}")
        continue
    finally:
        torch.cuda.empty_cache()

    # Periodic save; only reached right after a successful sample.
    if len(results_d) % 25 == 0:
        _save_checkpoint_d()
        elapsed = time.time() - start_d
        print(f"\n  [{len(results_d)} done | {elapsed/60:.0f}m]")

# Flush the final partial batch (< 25 results), which the periodic save misses.
_save_checkpoint_d()

print(f"\nInv D done. {len(results_d)} evaluated, {skipped_d} skipped, {errors_d} errors.")
print(f"Time: {(time.time()-start_d)/60:.1f} min")

In [None]:
# ============================================================
# Investigation D: Analysis with Bootstrap CIs
# ============================================================

print("=" * 80)
print("INVESTIGATION D RESULTS: RANKING WITH BOOTSTRAP CIs")
print("=" * 80)

n_d = len(results_d)

# Compute MRR, Hit@1, Hit@3 for each condition
cond_names_d = ['bare', 'oracle'] + [bn for _, _, bn in SIM_BINS_B]

print(f"\n{'Condition':<15} {'N':>5} {'MRR':>10} {'MRR 95% CI':>22} {'Hit@1':>8} {'Hit@3':>8}")
print("-" * 75)

ranking_stats = {}
for cond in cond_names_d:
    # Sim-bin conditions only exist for samples that had a surrogate in that bin.
    valid = [r for r in results_d if f'rank_{cond}' in r]
    if len(valid) < 10:
        print(f"{cond:<15} {len(valid):>5} -- insufficient data")
        continue

    ranks = np.array([r[f'rank_{cond}'] for r in valid])
    rrs = 1.0 / ranks  # reciprocal rank of the true query among 5 candidates
    mrr = float(np.mean(rrs))
    hit1 = float(np.mean(ranks == 1))
    hit3 = float(np.mean(ranks <= 3))

    # Bootstrap CIs (computed by the bootstrap_ci helper)
    mrr_lo, mrr_hi, _ = bootstrap_ci(rrs, np.mean, n_boot=10000)
    hit1_lo, hit1_hi, _ = bootstrap_ci((ranks == 1).astype(float), np.mean, n_boot=10000)
    hit3_lo, hit3_hi, _ = bootstrap_ci((ranks <= 3).astype(float), np.mean, n_boot=10000)

    ranking_stats[cond] = {
        'n': len(valid), 'mrr': mrr, 'mrr_ci': (mrr_lo, mrr_hi),
        'hit1': hit1, 'hit1_ci': (hit1_lo, hit1_hi),
        'hit3': hit3, 'hit3_ci': (hit3_lo, hit3_hi),
    }

    print(f"{cond:<15} {len(valid):>5} {mrr:>10.3f} [{mrr_lo:.3f}, {mrr_hi:.3f}]  {hit1:>7.1%} {hit3:>7.1%}")

# Visualization
# Colours are keyed by condition name rather than list position, so a
# condition dropped for insufficient data cannot shift the bare/oracle colours
# onto other bars.
COND_COLORS_D = {'bare': '#888888', 'oracle': '#c44e52'}
DEFAULT_COND_COLOR_D = '#4c72b0'
conds_plot = [c for c in cond_names_d if c in ranking_stats]
colors_d = [COND_COLORS_D.get(c, DEFAULT_COND_COLOR_D) for c in conds_plot]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for ax_idx, (metric, mlabel) in enumerate([('mrr', 'MRR'), ('hit1', 'Hit@1'), ('hit3', 'Hit@3')]):
    ax = axes[ax_idx]
    vals = [ranking_stats[c][metric] for c in conds_plot]
    ci_los = [ranking_stats[c][f'{metric}_ci'][0] for c in conds_plot]
    ci_his = [ranking_stats[c][f'{metric}_ci'][1] for c in conds_plot]
    # Asymmetric error bars from the CI bounds, clipped at 0. Recent matplotlib
    # rejects negative error lengths, which would occur if a bootstrap CI did
    # not bracket the point estimate.
    errs = [[max(0.0, v - lo) for v, lo in zip(vals, ci_los)],
            [max(0.0, hi - v) for v, hi in zip(vals, ci_his)]]

    ax.bar(range(len(conds_plot)), vals, yerr=errs, color=colors_d,
           capsize=4, edgecolor='black', linewidth=0.5)
    ax.set_xticks(range(len(conds_plot)))
    ax.set_xticklabels(conds_plot, rotation=30, ha='right')
    ax.set_ylabel(mlabel)
    ax.set_title(f'{mlabel} with 95% Bootstrap CIs')
    ax.set_ylim(0, 1)

plt.tight_layout()
plt.savefig('results/exp12/12_investigation_d.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 12_investigation_d.png')

In [None]:
# ============================================================
# Comprehensive Summary + Pre-Registered Verdicts
# ============================================================
# Applies the pre-registered criteria to statistics computed in earlier
# cells (ci_excludes_zero, r_pearson_a, r_ci_lo/r_ci_hi, rho_bins_a,
# valid_shuf, p_rs, wr_real_vs_shuf) and sets `verdict`.

print("="*80)
print("EXPERIMENT 12: COMPREHENSIVE SUMMARY")
print("="*80)

# --- Criterion (a): Pearson r bootstrap 95% CI excludes 0 ---
criterion_a = ci_excludes_zero
print(f"\n(a) Pearson r = {r_pearson_a:.4f}, 95% CI = [{r_ci_lo:.4f}, {r_ci_hi:.4f}]")
print(f"    CI excludes 0: {criterion_a}")

# --- Criterion (b): Bin-level Spearman rho > 0.7 ---
criterion_b = rho_bins_a > 0.7
print(f"\n(b) Bin-level Spearman rho = {rho_bins_a:.3f}")
print(f"    rho > 0.7: {criterion_b}")

# --- Criterion (c): sim_0.60 beats shuffled_0.60 (p<0.05) ---
# Requires significance AND the real surrogate winning the majority of paired
# comparisons (wr_real_vs_shuf is printed as a percentage).
criterion_c = False
if len(valid_shuf) >= 10:
    criterion_c = p_rs < 0.05 and wr_real_vs_shuf > 50
    print(f"\n(c) sim_0.60 vs shuffled_0.60: p={p_rs:.6f}, real wins {wr_real_vs_shuf:.1f}%")
    print(f"    Significant (p<0.05) and real wins: {criterion_c}")
else:
    # Make the silent "not met" explicit instead of printing nothing.
    print(f"\n(c) sim_0.60 vs shuffled_0.60: insufficient paired data (N={len(valid_shuf)} < 10)")
    print(f"    Significant (p<0.05) and real wins: {criterion_c}")

# --- Overall verdict ---
n_met = sum(bool(c) for c in (criterion_a, criterion_b, criterion_c))
print("\n" + "="*80)
if criterion_a and criterion_b and criterion_c:
    verdict = "CONFIRMED"
    print("VERDICT: CONFIRMED")
    print("All three pre-registered criteria met.")
elif criterion_a or criterion_b:
    verdict = "PARTIALLY CONFIRMED"
    print("VERDICT: PARTIALLY CONFIRMED")
    print("Some but not all criteria met.")
else:
    verdict = "REFUTED"
    print("VERDICT: REFUTED")
    if criterion_c:
        # NOTE: the pre-registration does not cover "(c) only". It is labelled
        # REFUTED because neither gradient criterion holds, but the message must
        # not claim that no criteria were met.
        print("Only criterion (c) met (real beats shuffled); neither gradient criterion (a) nor (b) holds.")
        print("Semantic quality gradient not confirmed at scale.")
    else:
        print("No criteria met. Semantic quality gradient not confirmed at scale.")
print(f"Criteria met: (a)={bool(criterion_a)}, (b)={bool(criterion_b)}, (c)={bool(criterion_c)} [{n_met}/3]")
print("="*80)

# --- Cross-dataset summary ---
print("\n--- Cross-Dataset Replication ---")
for dname, bstats in [("MS MARCO", bin_stats_a), ("SQuAD v2", squad_stats), ("TriviaQA", tqa_stats)]:
    bins_used = SIM_BINS_A if dname == "MS MARCO" else SIM_BINS_B
    if len(bstats) >= 2:
        bs = [bstats[bn]['mean_sim'] for bn in [b[2] for b in bins_used] if bn in bstats]
        bd = [bstats[bn]['mean_delta'] for bn in [b[2] for b in bins_used] if bn in bstats]
        if len(bs) >= 3:
            rho, p = stats.spearmanr(bs, bd)
            print(f"  {dname}: bin-level rho={rho:.3f}, p={p:.4f}")
        else:
            print(f"  {dname}: too few bins ({len(bs)})")

# --- Format ablation summary ---
print("\n--- Format Ablation ---")
best_fmt = max(format_stats.items(), key=lambda x: x[1]['mean_delta'])
print(f"  Best format: {best_fmt[0]} (delta={best_fmt[1]['mean_delta']:.4f})")
worst_fmt = min(format_stats.items(), key=lambda x: x[1]['mean_delta'])
print(f"  Worst format: {worst_fmt[0]} (delta={worst_fmt[1]['mean_delta']:.4f})")

# --- Ranking summary ---
print("\n--- Ranking ---")
if 'bare' in ranking_stats and 'oracle' in ranking_stats:
    print(f"  Bare MRR: {ranking_stats['bare']['mrr']:.3f} {ranking_stats['bare']['mrr_ci']}")
    print(f"  Oracle MRR: {ranking_stats['oracle']['mrr']:.3f} {ranking_stats['oracle']['mrr_ci']}")
    for _, _, bn in SIM_BINS_B:
        if bn in ranking_stats:
            print(f"  {bn} MRR: {ranking_stats[bn]['mrr']:.3f} {ranking_stats[bn]['mrr_ci']}")

# --- Control comparisons ---
print("\n--- Control Comparisons ---")
if len(valid_shuf) >= 10:
    print(f"  sim_0.60 vs shuffled: p={p_rs:.6f} (word order {'matters' if p_rs < 0.05 else 'does NOT matter'})")
if len(valid_lm) >= 10:
    print(f"  sim_0.60 vs length-matched random: p={p_rl:.6f} (similarity {'matters' if p_rl < 0.05 else 'does NOT matter'} beyond length)")
if len(valid_raw) >= 10:
    print(f"  template vs raw: p={p_tr:.4f} (template framing {'matters' if p_tr < 0.05 else 'does NOT matter'})")
print(f"  Partial r (controlling prefix length): {r_partial:.4f}, p={p_partial:.2e}")

In [None]:
# ============================================================
# Save Results
# ============================================================


def _json_default(obj):
    """json.dump fallback that keeps numpy/torch values numeric.

    A bare ``default=str`` would save np.int64(5) as the string "5" and
    np.bool_(True) as "True". Numpy scalars and arrays, and torch tensors,
    are converted to native Python values instead. Any other unknown object
    still falls back to str(), matching the previous behaviour.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    return str(obj)


output = {
    'metadata': {
        'experiment': '12_definitive_semantic_signal',
        'timestamp': datetime.datetime.now().isoformat(),
        'model_name': config.model_name,
        'seeds': SEEDS,
        'n_inv_a': N_INV_A, 'n_inv_b': N_INV_B,
        'n_inv_c': N_INV_C, 'n_inv_d': N_INV_D,
    },
    'investigation_a': {
        'n_evaluated': len(results_a),
        'skipped': skipped_a, 'errors': errors_a,
        'bin_stats': bin_stats_a,
        'oracle_win_rate': float(oracle_wr),
        'pearson_r': float(r_pearson_a),
        'pearson_p': float(p_pearson_a),
        'pearson_r_ci': [float(r_ci_lo), float(r_ci_hi)],
        'spearman_rho': float(r_spearman_a),
        'bin_spearman_rho': float(rho_bins_a),
        'partial_r': float(r_partial),
        'partial_p': float(p_partial),
        # Control statistics exist only when >= 10 paired samples were available.
        'shuffled_p': float(p_rs) if len(valid_shuf) >= 10 else None,
        'shuffled_real_wins': float(wr_real_vs_shuf) if len(valid_shuf) >= 10 else None,
        'length_matched_p': float(p_rl) if len(valid_lm) >= 10 else None,
        'raw_vs_template_p': float(p_tr) if len(valid_raw) >= 10 else None,
        'results': results_a,
    },
    'investigation_b': {
        'squad': {
            'n_evaluated': len(results_squad),
            'bin_stats': squad_stats,
            'results': results_squad,
        },
        'triviaqa': {
            'n_evaluated': len(results_tqa),
            'bin_stats': tqa_stats,
            'results': results_tqa,
        },
    },
    'investigation_c': {
        'n_evaluated': len(results_c),
        'format_stats': format_stats,
        'results': results_c,
    },
    'investigation_d': {
        'n_evaluated': len(results_d),
        # Shallow per-condition copy; CI tuples are serialised as JSON lists.
        'ranking_stats': {k: dict(v) for k, v in ranking_stats.items()},
        'results': results_d,
    },
    'verdict': {
        'criterion_a': bool(criterion_a),
        'criterion_b': bool(criterion_b),
        'criterion_c': bool(criterion_c),
        'overall': verdict,
    },
}

output_path = 'results/exp12/12_results.json'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'w') as f:
    json.dump(output, f, indent=2, default=_json_default)
print(f"Results saved to {output_path}")
print(f"File size: {os.path.getsize(output_path) / 1e6:.1f} MB")

# Final summary figure
# One row of four panels (Investigations A-D), titled with the overall verdict.
fig, axes = plt.subplots(1, 4, figsize=(22, 5))

# A: Quality gradient
ax = axes[0]
# Anchor points: (0.0, 0.0) presumably stands for the bare baseline, and
# (1.0, mean oracle delta) for the oracle; the similarity bins lie between.
# NOTE(review): confirm these anchor coordinates are the intended convention.
sp = [0.0]
dp = [0.0]
for _, _, bn in SIM_BINS_A:
    if bn in bin_stats_a:
        sp.append(bin_stats_a[bn]['mean_sim'])
        dp.append(bin_stats_a[bn]['mean_delta'])
sp.append(1.0)
dp.append(float(np.mean(oracle_deltas)))
ax.plot(sp, dp, 'o-', color='#4c72b0', linewidth=2, markersize=8)
ax.axhline(0, color='gray', linestyle='--')
ax.set_xlabel('Surrogate Similarity')
ax.set_ylabel('Mean Delta NLL')
ax.set_title(f'A: Quality Gradient (N={len(results_a)})\nr={r_pearson_a:.4f} [{r_ci_lo:.4f},{r_ci_hi:.4f}]')

# B: Cross-dataset
ax = axes[1]
for dname, bstats, marker, color in [
    ("MARCO", bin_stats_a, 'o', '#4c72b0'),
    ("SQuAD", squad_stats, 's', '#c44e52'),
    ("TQA", tqa_stats, '^', '#55a868'),
]:
    bins_used = SIM_BINS_A if dname == "MARCO" else SIM_BINS_B
    present_bins = [b[2] for b in bins_used if b[2] in bstats]
    ss = [bstats[bn]['mean_sim'] for bn in present_bins]
    dd = [bstats[bn]['mean_delta'] for bn in present_bins]
    ax.plot(ss, dd, f'{marker}-', color=color, label=dname, markersize=8)
ax.axhline(0, color='gray', linestyle='--')
ax.set_xlabel('Surrogate Similarity')
ax.set_ylabel('Mean Delta NLL')
ax.set_title('B: Cross-Dataset')
ax.legend()

# C: Format ablation (formats sorted by descending mean delta)
ax = axes[2]
fmts_sorted = sorted(format_stats.items(), key=lambda x: -x[1]['mean_delta'])
ax.barh(range(len(fmts_sorted)), [f[1]['mean_delta'] for f in fmts_sorted],
        color='#4c72b0')
ax.set_yticks(range(len(fmts_sorted)))
ax.set_yticklabels([f[0] for f in fmts_sorted], fontsize=8)
ax.set_xlabel('Mean Delta NLL')
ax.set_title('C: Format Ablation')
ax.axvline(0, color='gray', linestyle='--')

# D: Ranking
ax = axes[3]
conds_d = [c for c in ['bare'] + [bn for _, _, bn in SIM_BINS_B] + ['oracle'] if c in ranking_stats]
mrrs_d = [ranking_stats[c]['mrr'] for c in conds_d]
# Asymmetric error bars from the bootstrap CI bounds.
errs_d = [[ranking_stats[c]['mrr'] - ranking_stats[c]['mrr_ci'][0] for c in conds_d],
          [ranking_stats[c]['mrr_ci'][1] - ranking_stats[c]['mrr'] for c in conds_d]]
# Colour keyed by condition name (grey bare, red oracle, blue surrogates). A
# positional "first grey, last red" list would recolour a surrogate bar
# whenever bare or oracle is missing from ranking_stats.
panel_d_colors = {'bare': '#888888', 'oracle': '#c44e52'}
ax.bar(range(len(conds_d)), mrrs_d, yerr=errs_d,
       color=[panel_d_colors.get(c, '#4c72b0') for c in conds_d], capsize=3)
ax.set_xticks(range(len(conds_d)))
ax.set_xticklabels(conds_d, rotation=30, ha='right', fontsize=8)
ax.set_ylabel('MRR')
ax.set_title('D: Ranking (95% CI)')
ax.set_ylim(0, 1)

plt.suptitle(f'Experiment 12: Verdict = {verdict}', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('results/exp12/12_summary.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 12_summary.png')
print(f"\n{'='*80}")
print(f"EXPERIMENT 12 COMPLETE. VERDICT: {verdict}")
print(f"{'='*80}")