# Experiment 13: Multi-Query Amplification + Hardness-Gated Priming (Parts A & B)

## Motivation
Exp 12 confirmed a real but tiny semantic signal (r=0.034, CI excludes 0). Every prior experiment used a single prefix query. Two key observations suggest we can amplify the effect:

1. **Multiple queries**: In production ad-serving, documents have click history with many queries. Concatenating K raw queries as prefix could amplify value contamination additively.
2. **Hard-sample targeting**: Exp 03 found surrogates hurt easy samples (low baseline NLL) and help hard ones. Exp 12's sim_0.75 bin had higher baseline NLL and the biggest benefit (d=0.226). A gating strategy that primes only hard samples could double the effective d.

## Experiment 13, Part A: Multi-Query Prefix Amplification
**Hypothesis**: Concatenating K high-similarity raw queries as prefix amplifies value contamination, producing larger deltas than a single query.

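11 conditions on MS MARCO (N=1000); the SQuAD replication (N=1000) runs a 6-condition subset (`bare`, `oracle_1q`, `oracle_5q`, `real_1q_0.70`, `real_5q_0.70`, `random_5q`):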
1. `bare` — no prefix
2. `oracle_1q` — 1 oracle query, raw
3. `oracle_2q` — oracle + 1 query at sim>=0.85, raw
4. `oracle_3q` — oracle + 2 queries at sim>=0.85, raw
5. `oracle_5q` — oracle + 4 queries at sim>=0.85, raw
6. `real_1q_0.70` — 1 real query at sim>=0.70, raw
7. `real_3q_0.70` — 3 real queries at sim>=0.70, raw
8. `real_5q_0.70` — 5 real queries at sim>=0.70, raw
9. `real_5q_0.50` — 5 real queries at sim>=0.50 (easier to find)
10. `random_5q` — 5 random queries (structural control)
11. `repeated_1q_5x` — the `real_1q_0.70` query repeated 5x (repetition control)
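
Each multi-query prefix is simply the selected queries joined by a separator. The sketch below shows that composition for a few of the conditions above. It is only an illustration: the helper name `compose_prefixes` and the toy query strings are assumptions, not part of the experiment code.

```python
from typing import Dict, List, Optional


def compose_prefixes(
    oracle: str,
    companions: List[str],   # hypothetical input: queries highly similar to the oracle
    real: List[str],         # hypothetical input: real queries at sim>=0.70
    sep: str = " ",
) -> Dict[str, Optional[str]]:
    """Build raw prefix strings for a subset of the conditions listed above.

    A condition maps to None when too few queries are available.
    """
    def join_k(queries: List[str], k: int) -> Optional[str]:
        if len(queries) < k:
            return None
        return sep.join(q.strip() for q in queries[:k])

    prefixes: Dict[str, Optional[str]] = {"bare": ""}
    for k in (1, 2, 3, 5):
        # The oracle query always comes first, followed by k-1 companions
        prefixes[f"oracle_{k}q"] = join_k([oracle] + companions, k)
    for k in (1, 3, 5):
        prefixes[f"real_{k}q_0.70"] = join_k(real, k)
    # Repetition control: one real query repeated five times
    prefixes["repeated_1q_5x"] = sep.join([real[0].strip()] * 5) if real else None
    return prefixes


demo = compose_prefixes(
    oracle="what is a kv cache",
    companions=["how does a kv cache work"],
    real=["kv cache explained", "transformer cache memory", "attention key value store"],
)
```

With only one companion available, `demo["oracle_3q"]` and `demo["oracle_5q"]` are `None`, which corresponds to the per-sample coverage gaps that the evaluation loop records as missing values.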

## Experiment 13, Part B: Hardness-Gated Priming
**Hypothesis**: Priming benefits only hard samples. A gating strategy that primes selectively can double the effective Cohen's d.

Uses the Part A MS MARCO data (bare + oracle_1q already computed), extended to N=2000 with 3 conditions. The code refers to this part as Exp 14 (`N_EXP14`):
- bare, oracle_1q (raw), real_1q_0.70 (raw)
- Bins by bare NLL quartile
- Tests hardness predictors available at indexing time
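
As a rough illustration of the gating analysis, the sketch below bins samples by bare-NLL quartile, computes a paired Cohen's d per bin, and compares always-prime against a policy that primes only above a threshold. Everything here is an assumption made for illustration: the data are synthetic, the function names (`paired_cohens_d`, `quartile_bins`, `gated_delta`) are hypothetical, and the gate uses bare NLL directly, whereas a deployable gate would need a predictor available at indexing time.

```python
import numpy as np


def paired_cohens_d(deltas: np.ndarray) -> float:
    """Cohen's d for paired differences: mean divided by sample std."""
    sd = np.std(deltas, ddof=1)
    return float(np.mean(deltas) / sd) if sd > 0 else 0.0


def quartile_bins(bare_nll: np.ndarray) -> np.ndarray:
    """Assign each sample to a bare-NLL quartile (0 = easiest, 3 = hardest)."""
    edges = np.percentile(bare_nll, [25, 50, 75])
    return np.searchsorted(edges, bare_nll, side="right")


def gated_delta(bare_nll: np.ndarray, primed_nll: np.ndarray, threshold: float) -> np.ndarray:
    """Per-sample NLL improvement when priming is used only if bare NLL >= threshold."""
    use_primed = bare_nll >= threshold
    chosen = np.where(use_primed, primed_nll, bare_nll)
    return bare_nll - chosen


# Synthetic data (not experimental results): priming slightly hurts
# easy samples and helps hard ones.
rng = np.random.RandomState(0)
bare = rng.gamma(shape=2.0, scale=1.0, size=2000)
primed = bare - 0.1 * (bare - np.median(bare)) + rng.normal(0, 0.05, size=2000)

bins = quartile_bins(bare)
for q in range(4):
    d_q = paired_cohens_d((bare - primed)[bins == q])
    print(f"Q{q + 1}: d={d_q:+.3f}")

d_always = paired_cohens_d(bare - primed)
d_gated = paired_cohens_d(gated_delta(bare, primed, np.median(bare)))
print(f"always-prime d={d_always:+.3f}, gated d={d_gated:+.3f}")
```

Samples below the threshold contribute a delta of exactly zero under the gated policy, so the gate can only remove harm on easy samples; it cannot add benefit beyond what priming already gives on hard ones.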

In [None]:
import sys, os, json, copy, time, datetime
from typing import Dict, List, Any, Optional, Tuple

import torch
import numpy as np
from tqdm.auto import tqdm
from scipy import stats
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['figure.dpi'] = 120

sys.path.insert(0, '.')

from lib import (
    ExperimentConfig,
    build_kv_cache,
    score_answer_with_cache,
    extract_and_truncate_cache_with_bos,
    correct_rope_positions_with_bos,
    load_evaluation_samples,
    load_ms_marco,
    _ensure_dynamic_cache,
)

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, DynamicCache
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from datasets import load_dataset

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# ============================================================
# Configuration
# ============================================================

config = ExperimentConfig(
    num_samples=8000,  # Large pool for surrogate selection
    min_passage_words=50,
    max_passage_words=300,
    seed=42,
)

N_EXP13_MARCO = 1000
N_EXP13_SQUAD = 1000
N_EXP14 = 2000  # Part B (hardness gating); extends the Exp 13 Part A MS MARCO samples

SEEDS = {'msmarco': 42, 'squad': 43}

# Multi-query separator: space between concatenated queries
# (Exp 12 showed raw queries work best, so just concatenate with spaces)
MQ_SEP = " "

torch.manual_seed(config.seed)
np.random.seed(config.seed)

print(f"Exp 13 MS MARCO: {N_EXP13_MARCO} samples x 11 conditions")
print(f"Exp 13 SQuAD:    {N_EXP13_SQUAD} samples x 6 conditions")
print(f"Exp 14 MS MARCO: {N_EXP14} samples x 3 conditions (extends Exp 13)")

In [None]:
# ============================================================
# Load Model
# ============================================================

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
)

print(f"Loading {config.model_name}...")
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    quantization_config=bnb_config,
    device_map="auto",
)
model.eval()
print(f"Model loaded. Layers: {model.config.num_hidden_layers}")

print("Loading embedding model...")
embed_model = SentenceTransformer('all-MiniLM-L6-v2')
print("Embedding model loaded.")

In [None]:
# ============================================================
# Load MS MARCO + Build Query Pool
# ============================================================

print("Loading MS MARCO dataset...")
dataset = load_dataset(
    config.dataset_name, config.dataset_config,
    split=config.dataset_split, trust_remote_code=True
)
print(f"Dataset loaded: {len(dataset)} samples")

all_samples = load_evaluation_samples(dataset, config, require_answer=True)
print(f"Loaded {len(all_samples)} evaluation samples with answers")

# Build query pool from FULL dataset
print("\nBuilding query pool (full dataset)...")
query_pool = []
seen_queries = set()
for item in tqdm(dataset, desc="Extracting queries"):
    q = item.get('query', '').strip()
    if q and q not in seen_queries and len(q) > 10:
        query_pool.append(q)
        seen_queries.add(q)
print(f"Query pool size: {len(query_pool)}")

print("Embedding query pool...")
pool_embeddings = embed_model.encode(query_pool, show_progress_bar=True, batch_size=256)
print(f"Pool embeddings shape: {pool_embeddings.shape}")

In [None]:
# ============================================================
# Load SQuAD v2 + Build Query Pool
# ============================================================

print("Loading SQuAD v2...")
squad_dataset = load_dataset("rajpurkar/squad_v2", split="validation")
squad_train = load_dataset("rajpurkar/squad_v2", split="train")
print(f"SQuAD val: {len(squad_dataset)}, train: {len(squad_train)}")

np.random.seed(SEEDS['squad'])
squad_samples = []
for item in squad_dataset:
    ctx = item.get('context', '').strip()
    q = item.get('question', '').strip()
    ans_texts = item.get('answers', {}).get('text', [])
    if not ctx or not q or not ans_texts:
        continue
    wc = len(ctx.split())
    if 50 <= wc <= 300:
        squad_samples.append({'passage': ctx, 'query': q, 'answer': ans_texts[0]})
np.random.shuffle(squad_samples)
squad_samples = squad_samples[:N_EXP13_SQUAD]
print(f"SQuAD evaluation samples: {len(squad_samples)}")

# SQuAD query pool from train
squad_query_pool = []
seen_sq = set()
for item in squad_train:
    q = item.get('question', '').strip()
    if q and q not in seen_sq and len(q) > 10:
        squad_query_pool.append(q)
        seen_sq.add(q)
print(f"SQuAD query pool: {len(squad_query_pool)}")
print("Embedding SQuAD query pool...")
squad_pool_embs = embed_model.encode(squad_query_pool, show_progress_bar=True, batch_size=256)

In [None]:
# ============================================================
# Helper Functions
# ============================================================

def build_matched_bare_and_truncated(
    prefix_text: str,
    passage: str,
    model, tokenizer, config,
) -> Tuple[int, DynamicCache, int, DynamicCache, int]:
    """Build BPE-matched bare and truncated caches.
    Returns: (bare_len, bare_cache, trunc_len, trunc_cache, prefix_token_len)
    """
    prefix_with_sep = prefix_text.strip() + " "

    prefix_encoding = tokenizer(
        prefix_with_sep, return_tensors="pt", add_special_tokens=True,
        padding=False, truncation=False
    )
    prefix_len = prefix_encoding['input_ids'].shape[1]

    full_context = prefix_with_sep + passage
    full_encoding = tokenizer(
        full_context, return_tensors="pt", add_special_tokens=True,
        padding=False, truncation=False
    )
    full_ids = full_encoding['input_ids'].to(config.device)
    doc_len = full_ids.shape[1] - prefix_len

    doc_token_ids = full_ids[:, prefix_len:]
    bos_id = full_ids[:, :1]
    bare_ids = torch.cat([bos_id, doc_token_ids], dim=1)
    bare_len = bare_ids.shape[1]

    with torch.no_grad():
        bare_out = model(
            input_ids=bare_ids,
            attention_mask=torch.ones_like(bare_ids),
            use_cache=True, return_dict=True
        )
    bare_cache = _ensure_dynamic_cache(bare_out.past_key_values)

    with torch.no_grad():
        full_out = model(
            input_ids=full_ids,
            attention_mask=torch.ones_like(full_ids),
            use_cache=True, return_dict=True
        )

    truncated = extract_and_truncate_cache_with_bos(full_out.past_key_values, doc_len)
    keep_len = 1 + doc_len

    assert bare_len == keep_len, f"Length mismatch: {bare_len} vs {keep_len}"

    surrogate_offset = prefix_len - 1
    correct_rope_positions_with_bos(truncated, surrogate_offset, model)

    return bare_len, bare_cache, keep_len, truncated, prefix_len


def find_queries_at_similarity(
    target_query: str,
    target_embedding: np.ndarray,
    sim_low: float,
    sim_high: float,
    pool_queries: list,
    pool_embs: np.ndarray,
    rng: np.random.RandomState,
    k: int = 1,
    diverse: bool = True,
) -> List[Tuple[str, float]]:
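    """Find up to k real queries from the pool with sim_low <= sim < sim_high.

    The first pick is always the candidate whose similarity is nearest the
    bin midpoint (e.g. 0.85 for the range [0.70, 1.0)). If diverse=True and
    more than k candidates exist, the remaining picks are chosen greedily to
    maximize the minimum cosine distance to those already selected.
    Otherwise the k candidates nearest the midpoint are returned.

    Selection is deterministic; `rng` is accepted but currently unused.

    Returns a list of (query, similarity) tuples, possibly shorter than k.
    """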
    sims = cosine_similarity([target_embedding], pool_embs)[0]
    mask = (sims >= sim_low) & (sims < sim_high)
    # Exclude exact match
    for idx in np.where(mask)[0]:
        if pool_queries[idx].strip().lower() == target_query.strip().lower():
            mask[idx] = False

    candidates = np.where(mask)[0]
    if len(candidates) == 0:
        return []

    if k == 1 or not diverse or len(candidates) <= k:
        # Simple: pick k nearest to bin midpoint
        mid = (sim_low + sim_high) / 2
        dist_to_mid = np.abs(sims[candidates] - mid)
        chosen_idxs = candidates[np.argsort(dist_to_mid)[:k]]
        return [(pool_queries[ci], float(sims[ci])) for ci in chosen_idxs]

    # Diverse selection: greedy max-min distance
    # Start with the candidate nearest to bin midpoint
    mid = (sim_low + sim_high) / 2
    dist_to_mid = np.abs(sims[candidates] - mid)
    first = candidates[np.argmin(dist_to_mid)]

    selected = [first]
    cand_embs = pool_embs[candidates]
    first_emb = pool_embs[first:first+1]

    # Precompute pairwise distances from first
    remaining = set(range(len(candidates)))
    first_local = np.where(candidates == first)[0][0]
    remaining.discard(first_local)

    # min_dist[i] = min distance from candidate i to any selected
    all_dists = 1 - cosine_similarity(cand_embs, first_emb).flatten()
    min_dists = all_dists.copy()

    for _ in range(k - 1):
        if not remaining:
            break
        rem_list = list(remaining)
        best_local = rem_list[np.argmax(min_dists[rem_list])]
        selected.append(candidates[best_local])
        remaining.discard(best_local)
        # Update min distances
        new_dists = 1 - cosine_similarity(cand_embs, cand_embs[best_local:best_local+1]).flatten()
        min_dists = np.minimum(min_dists, new_dists)

    return [(pool_queries[ci], float(sims[ci])) for ci in selected]


def build_multi_query_prefix(queries: List[str], sep: str = " ") -> str:
    """Concatenate multiple queries into a single raw prefix string."""
    return sep.join(q.strip() for q in queries)


def bootstrap_ci(data, stat_fn=np.mean, n_boot=10000, ci=0.95, seed=42):
    """Compute bootstrap confidence interval."""
    rng = np.random.RandomState(seed)
    n = len(data)
    boot_stats = np.empty(n_boot)
    for i in range(n_boot):
        sample = rng.choice(data, size=n, replace=True)
        boot_stats[i] = stat_fn(sample)
    alpha = (1 - ci) / 2
    lo = np.percentile(boot_stats, 100 * alpha)
    hi = np.percentile(boot_stats, 100 * (1 - alpha))
    return float(lo), float(hi), boot_stats


def bootstrap_corr_ci(x, y, n_boot=10000, ci=0.95, seed=42):
    """Bootstrap CI for Pearson correlation."""
    rng = np.random.RandomState(seed)
    n = len(x)
    boot_rs = np.empty(n_boot)
    for i in range(n_boot):
        idx = rng.choice(n, size=n, replace=True)
        r, _ = stats.pearsonr(x[idx], y[idx])
        boot_rs[i] = r
    alpha = (1 - ci) / 2
    lo = np.percentile(boot_rs, 100 * alpha)
    hi = np.percentile(boot_rs, 100 * (1 - alpha))
    return float(lo), float(hi), boot_rs


print("Helper functions defined.")
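
To make the `diverse=True` branch of `find_queries_at_similarity` concrete, here is a self-contained toy version of greedy max-min (farthest-point) selection on 2-D vectors. The function name `greedy_max_min` and the toy embeddings are illustrative assumptions; the sketch mirrors the idea but is not a drop-in replacement for the helper.

```python
import numpy as np


def greedy_max_min(embs: np.ndarray, first: int, k: int) -> list:
    """Pick up to k row indices so that each new pick maximizes its minimum
    cosine distance to the rows already selected."""
    unit = embs / np.linalg.norm(embs, axis=1, keepdims=True)
    selected = [first]
    # min_dists[i] = cosine distance from row i to its nearest selected row
    min_dists = 1.0 - unit @ unit[first]
    for _ in range(min(k, len(embs)) - 1):
        masked = min_dists.copy()
        masked[selected] = -np.inf          # never re-pick a selected row
        nxt = int(np.argmax(masked))
        selected.append(nxt)
        min_dists = np.minimum(min_dists, 1.0 - unit @ unit[nxt])
    return selected


# Toy candidates at angles 0, 5, 20, 80, and 90 degrees
angles = np.deg2rad([0.0, 5.0, 20.0, 80.0, 90.0])
toy = np.stack([np.cos(angles), np.sin(angles)], axis=1)
picked = greedy_max_min(toy, first=0, k=3)
print(picked)
```

Starting from the 0-degree vector, the second pick is the 90-degree vector (farthest away), and the third is the 20-degree vector, because the 80-degree vector now sits close to an already selected one.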

In [None]:
# ============================================================
# Exp 13: Surrogate Pre-Selection (MS MARCO)
# ============================================================

print("="*80)
print("EXP 13: MULTI-QUERY SURROGATE PRE-SELECTION (MS MARCO)")
print("="*80)

# Use first N_EXP14 samples (superset of Exp 13)
samples_marco = all_samples[:N_EXP14]
print(f"Using {len(samples_marco)} MS MARCO samples (first {N_EXP13_MARCO} for Exp 13, all {N_EXP14} for Exp 14)")

print("Embedding target queries...")
target_qs_marco = [s['query'] for s in samples_marco]
target_embs_marco = embed_model.encode(target_qs_marco, show_progress_bar=True, batch_size=256)

rng_m = np.random.RandomState(SEEDS['msmarco'])

# For each sample, pre-select all needed surrogates
marco_surrogates = []  # list of dicts per sample

for i in tqdm(range(len(samples_marco)), desc="Selecting surrogates"):
    surr = {}

    # Oracle companions: queries with 0.85 <= sim < 1.0 to the target (for oracle_Kq conditions)
    oracle_companions = find_queries_at_similarity(
        target_qs_marco[i], target_embs_marco[i],
        0.85, 1.0, query_pool, pool_embeddings, rng_m, k=4, diverse=True
    )
    surr['oracle_companions'] = oracle_companions

    # Real queries at sim>=0.70 (for real_Kq_0.70 conditions)
    real_070 = find_queries_at_similarity(
        target_qs_marco[i], target_embs_marco[i],
        0.70, 1.0, query_pool, pool_embeddings, rng_m, k=5, diverse=True
    )
    surr['real_070'] = real_070

    # Real queries at sim>=0.50 (for real_5q_0.50 condition)
    real_050 = find_queries_at_similarity(
        target_qs_marco[i], target_embs_marco[i],
        0.50, 1.0, query_pool, pool_embeddings, rng_m, k=5, diverse=True
    )
    surr['real_050'] = real_050

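    # Random queries (for random_5q control). Similarity is not computed,
    # so 0.0 is stored as a placeholder; the draw is not filtered against the target.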
    rand_idxs = rng_m.choice(len(query_pool), size=5, replace=False)
    surr['random_5'] = [(query_pool[ri], 0.0) for ri in rand_idxs]

    marco_surrogates.append(surr)

# Report coverage
def coverage_report(surrogates, key, k_needed, n_total):
    counts = [len(s[key]) for s in surrogates[:n_total]]
    have_k = sum(1 for c in counts if c >= k_needed)
    print(f"  {key} (need {k_needed}): {have_k}/{n_total} samples ({100*have_k/n_total:.1f}%)")
    if counts:
        print(f"    mean found: {np.mean(counts):.1f}, min: {min(counts)}, max: {max(counts)}")

print(f"\nCoverage for Exp 13 ({N_EXP13_MARCO} samples):")
coverage_report(marco_surrogates, 'oracle_companions', 4, N_EXP13_MARCO)
coverage_report(marco_surrogates, 'real_070', 5, N_EXP13_MARCO)
coverage_report(marco_surrogates, 'real_050', 5, N_EXP13_MARCO)

In [None]:
# ============================================================
# Exp 13: MS MARCO Evaluation Loop
# ============================================================
# All prefixes are RAW queries (no template framing).
# Multi-query prefixes: queries joined by space.
# ============================================================

results_13m = []
skipped_13m = 0
errors_13m = 0
start_13m = time.time()

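CKPT_13M = 'results/exp13/13_checkpoint_marco.json'
os.makedirs(os.path.dirname(CKPT_13M), exist_ok=True)  # output dir must exist before checkpoints/figures are written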
if os.path.exists(CKPT_13M):
    with open(CKPT_13M) as f:
        ckpt = json.load(f)
    if ckpt.get('experiment') == '13_marco':
        results_13m = ckpt['results']
        skipped_13m = ckpt['skipped']
        errors_13m = ckpt['errors']
        print(f"Resumed: {len(results_13m)} results")

start_idx = len(results_13m) + skipped_13m + errors_13m

for idx in tqdm(range(start_idx, N_EXP13_MARCO), desc="Exp13 MARCO",
                initial=start_idx, total=N_EXP13_MARCO):
    sample = samples_marco[idx]
    passage = sample['passage']
    query = sample['query']
    answer = sample['answer']

    answer_ids = tokenizer(answer, return_tensors='pt', add_special_tokens=False)['input_ids']
    if answer_ids.shape[1] < 2:
        skipped_13m += 1
        continue

    query_prompt = config.query_template.format(query=query)
    surr = marco_surrogates[idx]

    try:
        result = {'idx': idx, 'query': query}

        # --- Condition 1: bare (use oracle prefix to get matched bare) ---
        # We use the oracle query as the reference prefix for BPE matching
        bare_len, bare_cache, _, oracle_cache, oracle_ptl = \
            build_matched_bare_and_truncated(query, passage, model, tokenizer, config)

        nll_bare = score_answer_with_cache(
            bare_cache, bare_len, query_prompt, answer, model, tokenizer, config)
        result['nll_bare'] = nll_bare

        # --- Condition 2: oracle_1q (single oracle query, raw) ---
        nll_oracle_1q = score_answer_with_cache(
            oracle_cache, bare_len, query_prompt, answer, model, tokenizer, config)
        result['nll_oracle_1q'] = nll_oracle_1q
        result['ptl_oracle_1q'] = oracle_ptl

        # --- Conditions 3-5: oracle_2q, oracle_3q, oracle_5q ---
        companions = surr['oracle_companions']
        for k, cond_name in [(2, 'oracle_2q'), (3, 'oracle_3q'), (5, 'oracle_5q')]:
            needed = k - 1  # We already have the oracle query
            if len(companions) >= needed:
                comp_queries = [c[0] for c in companions[:needed]]
                multi_prefix = build_multi_query_prefix([query] + comp_queries, MQ_SEP)
                _, _, ml, mc, mptl = build_matched_bare_and_truncated(
                    multi_prefix, passage, model, tokenizer, config)
                nll = score_answer_with_cache(
                    mc, ml, query_prompt, answer, model, tokenizer, config)
                result[f'nll_{cond_name}'] = nll
                result[f'ptl_{cond_name}'] = mptl
            else:
                result[f'nll_{cond_name}'] = None
                result[f'ptl_{cond_name}'] = None

        # --- Condition 6: real_1q_0.70 ---
        real070 = surr['real_070']
        if len(real070) >= 1:
            r1_prefix = real070[0][0]
            _, _, r1l, r1c, r1ptl = build_matched_bare_and_truncated(
                r1_prefix, passage, model, tokenizer, config)
            result['nll_real_1q_070'] = score_answer_with_cache(
                r1c, r1l, query_prompt, answer, model, tokenizer, config)
            result['ptl_real_1q_070'] = r1ptl
            result['sim_real_1q_070'] = real070[0][1]
        else:
            result['nll_real_1q_070'] = None

        # --- Condition 7: real_3q_0.70 ---
        if len(real070) >= 3:
            r3_prefix = build_multi_query_prefix([q for q, s in real070[:3]], MQ_SEP)
            _, _, r3l, r3c, r3ptl = build_matched_bare_and_truncated(
                r3_prefix, passage, model, tokenizer, config)
            result['nll_real_3q_070'] = score_answer_with_cache(
                r3c, r3l, query_prompt, answer, model, tokenizer, config)
            result['ptl_real_3q_070'] = r3ptl
            result['mean_sim_real_3q_070'] = float(np.mean([s for _, s in real070[:3]]))
        else:
            result['nll_real_3q_070'] = None

        # --- Condition 8: real_5q_0.70 ---
        if len(real070) >= 5:
            r5_prefix = build_multi_query_prefix([q for q, s in real070[:5]], MQ_SEP)
            _, _, r5l, r5c, r5ptl = build_matched_bare_and_truncated(
                r5_prefix, passage, model, tokenizer, config)
            result['nll_real_5q_070'] = score_answer_with_cache(
                r5c, r5l, query_prompt, answer, model, tokenizer, config)
            result['ptl_real_5q_070'] = r5ptl
        else:
            result['nll_real_5q_070'] = None

        # --- Condition 9: real_5q_0.50 ---
        real050 = surr['real_050']
        if len(real050) >= 5:
            r5_050_prefix = build_multi_query_prefix([q for q, s in real050[:5]], MQ_SEP)
            _, _, r5_050l, r5_050c, r5_050ptl = build_matched_bare_and_truncated(
                r5_050_prefix, passage, model, tokenizer, config)
            result['nll_real_5q_050'] = score_answer_with_cache(
                r5_050c, r5_050l, query_prompt, answer, model, tokenizer, config)
            result['ptl_real_5q_050'] = r5_050ptl
        else:
            result['nll_real_5q_050'] = None

        # --- Condition 10: random_5q ---
        rand5 = surr['random_5']
        rand_prefix = build_multi_query_prefix([q for q, s in rand5], MQ_SEP)
        _, _, randl, randc, randptl = build_matched_bare_and_truncated(
            rand_prefix, passage, model, tokenizer, config)
        result['nll_random_5q'] = score_answer_with_cache(
            randc, randl, query_prompt, answer, model, tokenizer, config)
        result['ptl_random_5q'] = randptl

        # --- Condition 11: repeated_1q_5x ---
        if len(real070) >= 1:
            rep_query = real070[0][0]
            rep_prefix = build_multi_query_prefix([rep_query] * 5, MQ_SEP)
            _, _, repl, repc, repptl = build_matched_bare_and_truncated(
                rep_prefix, passage, model, tokenizer, config)
            result['nll_repeated_1q_5x'] = score_answer_with_cache(
                repc, repl, query_prompt, answer, model, tokenizer, config)
            result['ptl_repeated_1q_5x'] = repptl
        else:
            result['nll_repeated_1q_5x'] = None

        results_13m.append(result)

    except Exception as e:
        errors_13m += 1
        if errors_13m <= 5:
            print(f"\n  Error on sample {idx}: {e}")
        continue
    finally:
        torch.cuda.empty_cache()

    if len(results_13m) % 25 == 0:
        with open(CKPT_13M, 'w') as f:
            json.dump({'experiment': '13_marco', 'results': results_13m,
                       'skipped': skipped_13m, 'errors': errors_13m}, f)
        elapsed = time.time() - start_13m
        rate = len(results_13m) / (elapsed / 60) if elapsed > 0 else 0
        print(f"\n  [{len(results_13m)} done | {elapsed/60:.0f}m | {rate:.1f}/min]")

print(f"\nExp 13 MARCO done. {len(results_13m)} evaluated, {skipped_13m} skipped, {errors_13m} errors.")
print(f"Time: {(time.time()-start_13m)/60:.1f} min")

In [None]:
# ============================================================
# Exp 13: MS MARCO Analysis
# ============================================================

print("="*80)
print("EXP 13 RESULTS: MULTI-QUERY AMPLIFICATION (MS MARCO)")
print("="*80)

conditions_13 = [
    ('bare', 'nll_bare'),
    ('oracle_1q', 'nll_oracle_1q'),
    ('oracle_2q', 'nll_oracle_2q'),
    ('oracle_3q', 'nll_oracle_3q'),
    ('oracle_5q', 'nll_oracle_5q'),
    ('real_1q_070', 'nll_real_1q_070'),
    ('real_3q_070', 'nll_real_3q_070'),
    ('real_5q_070', 'nll_real_5q_070'),
    ('real_5q_050', 'nll_real_5q_050'),
    ('random_5q', 'nll_random_5q'),
    ('repeated_1q_5x', 'nll_repeated_1q_5x'),
]

bare_nlls = np.array([r['nll_bare'] for r in results_13m])

print(f"\n{'Condition':<20} {'N':>5} {'Mean NLL':>10} {'Win%':>8} {'Delta':>10} {'Cohen d':>10} {'p-value':>12}")
print("-" * 80)

stats_13m = {}
for cname, nll_key in conditions_13:
    valid = [r for r in results_13m if r.get(nll_key) is not None]
    if len(valid) < 10:
        print(f"{cname:<20} {len(valid):>5} -- insufficient data")
        continue
    nlls = np.array([r[nll_key] for r in valid])
    bares = np.array([r['nll_bare'] for r in valid])
    deltas = bares - nlls
    wr = np.mean(deltas > 0) * 100
    if cname == 'bare':
        print(f"{cname:<20} {len(valid):>5} {np.mean(nlls):>10.4f}")
        stats_13m[cname] = {'n': len(valid), 'mean_nll': float(np.mean(nlls))}
        continue
    t, p = stats.ttest_rel(bares, nlls)
    d = np.mean(deltas) / np.std(deltas, ddof=1) if np.std(deltas) > 0 else 0
    print(f"{cname:<20} {len(valid):>5} {np.mean(nlls):>10.4f} {wr:>7.1f}% {np.mean(deltas):>+10.4f} {d:>10.3f} {p:>12.2e}")
    stats_13m[cname] = {
        'n': len(valid), 'mean_nll': float(np.mean(nlls)),
        'win_rate': float(wr), 'mean_delta': float(np.mean(deltas)),
        'cohens_d': float(d), 'p_value': float(p),
    }

# --- Key comparisons ---
print("\n--- Key Comparisons ---")

# Does K matter? oracle_1q vs oracle_3q vs oracle_5q
for a, b in [('oracle_1q', 'oracle_3q'), ('oracle_1q', 'oracle_5q'), ('oracle_3q', 'oracle_5q')]:
    va = [r for r in results_13m if r.get(f'nll_{a}') is not None and r.get(f'nll_{b}') is not None]
    if len(va) >= 10:
        na = np.array([r[f'nll_{a}'] for r in va])
        nb = np.array([r[f'nll_{b}'] for r in va])
        t, p = stats.ttest_rel(na, nb)
        wr = np.mean(nb < na) * 100
        print(f"  {b} vs {a}: {b} wins {wr:.1f}%, p={p:.4f}")

# Diversity: real_5q_070 vs repeated_1q_5x
va = [r for r in results_13m if r.get('nll_real_5q_070') is not None and r.get('nll_repeated_1q_5x') is not None]
if len(va) >= 10:
    n_div = np.array([r['nll_real_5q_070'] for r in va])
    n_rep = np.array([r['nll_repeated_1q_5x'] for r in va])
    t, p = stats.ttest_rel(n_div, n_rep)
    print(f"  real_5q_070 vs repeated_1q_5x: diverse wins {np.mean(n_div < n_rep)*100:.1f}%, p={p:.4f}")

# Semantic: real_5q_070 vs random_5q
va = [r for r in results_13m if r.get('nll_real_5q_070') is not None and r.get('nll_random_5q') is not None]
if len(va) >= 10:
    n_real = np.array([r['nll_real_5q_070'] for r in va])
    n_rand = np.array([r['nll_random_5q'] for r in va])
    t, p = stats.ttest_rel(n_real, n_rand)
    print(f"  real_5q_070 vs random_5q: real wins {np.mean(n_real < n_rand)*100:.1f}%, p={p:.4f}")

# Practical: real_5q_050 vs real_1q_070
va = [r for r in results_13m if r.get('nll_real_5q_050') is not None and r.get('nll_real_1q_070') is not None]
if len(va) >= 10:
    n_5_050 = np.array([r['nll_real_5q_050'] for r in va])
    n_1_070 = np.array([r['nll_real_1q_070'] for r in va])
    t, p = stats.ttest_rel(n_5_050, n_1_070)
    print(f"  real_5q_050 vs real_1q_070: 5q@0.50 wins {np.mean(n_5_050 < n_1_070)*100:.1f}%, p={p:.4f}")

# --- Visualization ---
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Plot 1: Delta by condition (bar chart)
ax = axes[0]
cond_order = ['oracle_1q', 'oracle_2q', 'oracle_3q', 'oracle_5q',
              'real_1q_070', 'real_3q_070', 'real_5q_070', 'real_5q_050',
              'random_5q', 'repeated_1q_5x']
plot_conds = [c for c in cond_order if c in stats_13m]
plot_deltas = [stats_13m[c]['mean_delta'] for c in plot_conds]
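# Map colors by condition name so they stay aligned if a condition is missing from stats_13m
color_map_13 = dict(zip(cond_order, ['#c44e52']*4 + ['#4c72b0']*4 + ['#888888', '#e377c2']))
colors_13 = [color_map_13[c] for c in plot_conds]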
ax.bar(range(len(plot_conds)), plot_deltas, color=colors_13)
ax.set_xticks(range(len(plot_conds)))
ax.set_xticklabels(plot_conds, rotation=45, ha='right', fontsize=7)
ax.set_ylabel('Mean Delta NLL vs Bare')
ax.set_title('Multi-Query Amplification')
ax.axhline(0, color='gray', linestyle='--')

# Plot 2: Scaling curve (oracle K=1,2,3,5)
ax = axes[1]
oracle_ks = []
oracle_ds = []
for k, cname in [(1, 'oracle_1q'), (2, 'oracle_2q'), (3, 'oracle_3q'), (5, 'oracle_5q')]:
    if cname in stats_13m:
        oracle_ks.append(k)
        oracle_ds.append(stats_13m[cname]['mean_delta'])
if oracle_ks:
    ax.plot(oracle_ks, oracle_ds, 'o-', color='#c44e52', linewidth=2, markersize=8, label='Oracle K-query')
real_ks = []
real_ds = []
for k, cname in [(1, 'real_1q_070'), (3, 'real_3q_070'), (5, 'real_5q_070')]:
    if cname in stats_13m:
        real_ks.append(k)
        real_ds.append(stats_13m[cname]['mean_delta'])
if real_ks:
    ax.plot(real_ks, real_ds, 's-', color='#4c72b0', linewidth=2, markersize=8, label='Real K-query (sim>=0.70)')
ax.set_xlabel('Number of prefix queries (K)')
ax.set_ylabel('Mean Delta NLL')
ax.set_title('Scaling: Does K Help?')
ax.legend()
ax.axhline(0, color='gray', linestyle='--')

# Plot 3: Win rates
ax = axes[2]
plot_wrs = [stats_13m[c]['win_rate'] for c in plot_conds]
ax.bar(range(len(plot_conds)), plot_wrs, color=colors_13)
ax.axhline(50, color='gray', linestyle='--')
ax.set_xticks(range(len(plot_conds)))
ax.set_xticklabels(plot_conds, rotation=45, ha='right', fontsize=7)
ax.set_ylabel('Win Rate vs Bare (%)')
ax.set_title('Win Rates')

plt.tight_layout()
plt.savefig('results/exp13/13_exp13_marco.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 13_exp13_marco.png')
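
The table above reports paired t-test p-values. A bootstrap interval on the mean delta, in the spirit of the `bootstrap_ci` helper defined earlier, is a natural companion. The sketch below is a vectorized, self-contained variant run on synthetic deltas; the name `bootstrap_mean_ci` and the toy data are assumptions, not outputs of this experiment.

```python
import numpy as np


def bootstrap_mean_ci(deltas, n_boot=2000, ci=0.95, seed=42):
    """Percentile bootstrap CI for the mean of paired deltas."""
    rng = np.random.RandomState(seed)
    deltas = np.asarray(deltas, dtype=float)
    # Draw all resamples at once: one row of indices per bootstrap replicate
    idx = rng.randint(0, len(deltas), size=(n_boot, len(deltas)))
    boot_means = deltas[idx].mean(axis=1)
    alpha = (1 - ci) / 2
    lo, hi = np.percentile(boot_means, [100 * alpha, 100 * (1 - alpha)])
    return float(lo), float(hi)


# Synthetic per-sample deltas (bare NLL minus primed NLL), for illustration only
toy_rng = np.random.RandomState(0)
toy_deltas = toy_rng.normal(loc=0.05, scale=0.5, size=1000)
lo, hi = bootstrap_mean_ci(toy_deltas)
print(f"mean delta={toy_deltas.mean():+.4f}, 95% CI=[{lo:+.4f}, {hi:+.4f}]")
```

In the notebook this could be applied to `bares - nlls` for each condition; an interval that excludes zero supports the same conclusion as the paired t-test without relying on normality of the deltas.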

In [None]:
# ============================================================
# Exp 13: SQuAD Surrogate Selection + Evaluation
# ============================================================

print("="*80)
print("EXP 13: SQUAD REPLICATION")
print("="*80)

print("Embedding SQuAD target queries...")
squad_target_qs = [s['query'] for s in squad_samples]
squad_target_embs = embed_model.encode(squad_target_qs, show_progress_bar=True, batch_size=256)

rng_sq = np.random.RandomState(SEEDS['squad'])

squad_surrogates = []
for i in tqdm(range(len(squad_samples)), desc="SQuAD surrogates"):
    surr = {}
    surr['oracle_companions'] = find_queries_at_similarity(
        squad_target_qs[i], squad_target_embs[i],
        0.85, 1.0, squad_query_pool, squad_pool_embs, rng_sq, k=4, diverse=True)
    surr['real_070'] = find_queries_at_similarity(
        squad_target_qs[i], squad_target_embs[i],
        0.70, 1.0, squad_query_pool, squad_pool_embs, rng_sq, k=5, diverse=True)
    surr['real_050'] = find_queries_at_similarity(
        squad_target_qs[i], squad_target_embs[i],
        0.50, 1.0, squad_query_pool, squad_pool_embs, rng_sq, k=5, diverse=True)
    rand_idxs = rng_sq.choice(len(squad_query_pool), size=5, replace=False)
    surr['random_5'] = [(squad_query_pool[ri], 0.0) for ri in rand_idxs]
    squad_surrogates.append(surr)

print(f"\nSQuAD coverage:")
coverage_report(squad_surrogates, 'oracle_companions', 4, N_EXP13_SQUAD)
coverage_report(squad_surrogates, 'real_070', 5, N_EXP13_SQUAD)
coverage_report(squad_surrogates, 'real_050', 5, N_EXP13_SQUAD)

# --- Evaluation ---
results_13s = []
skipped_13s = 0
errors_13s = 0
start_13s = time.time()

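CKPT_13S = 'results/exp13/13_checkpoint_squad.json'
os.makedirs(os.path.dirname(CKPT_13S), exist_ok=True)  # output dir must exist before checkpoints are written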
if os.path.exists(CKPT_13S):
    with open(CKPT_13S) as f:
        ckpt = json.load(f)
    if ckpt.get('experiment') == '13_squad':
        results_13s = ckpt['results']
        skipped_13s = ckpt['skipped']
        errors_13s = ckpt['errors']
        print(f"Resumed: {len(results_13s)} results")

start_idx_s = len(results_13s) + skipped_13s + errors_13s

for idx in tqdm(range(start_idx_s, N_EXP13_SQUAD), desc="Exp13 SQuAD",
                initial=start_idx_s, total=N_EXP13_SQUAD):
    sample = squad_samples[idx]
    passage, query, answer = sample['passage'], sample['query'], sample['answer']

    answer_ids = tokenizer(answer, return_tensors='pt', add_special_tokens=False)['input_ids']
    if answer_ids.shape[1] < 2:
        skipped_13s += 1
        continue

    query_prompt = config.query_template.format(query=query)
    surr = squad_surrogates[idx]

    try:
        result = {'idx': idx, 'query': query}

        bare_len, bare_cache, _, oracle_cache, _ = \
            build_matched_bare_and_truncated(query, passage, model, tokenizer, config)
        result['nll_bare'] = score_answer_with_cache(
            bare_cache, bare_len, query_prompt, answer, model, tokenizer, config)
        result['nll_oracle_1q'] = score_answer_with_cache(
            oracle_cache, bare_len, query_prompt, answer, model, tokenizer, config)

        # oracle_5q
        companions = surr['oracle_companions']
        if len(companions) >= 4:
            mp = build_multi_query_prefix([query] + [c[0] for c in companions[:4]], MQ_SEP)
            _, _, ml, mc, _ = build_matched_bare_and_truncated(mp, passage, model, tokenizer, config)
            result['nll_oracle_5q'] = score_answer_with_cache(
                mc, ml, query_prompt, answer, model, tokenizer, config)
        else:
            result['nll_oracle_5q'] = None

        # real_1q_070, real_5q_070
        real070 = surr['real_070']
        if len(real070) >= 1:
            _, _, rl, rc, _ = build_matched_bare_and_truncated(
                real070[0][0], passage, model, tokenizer, config)
            result['nll_real_1q_070'] = score_answer_with_cache(
                rc, rl, query_prompt, answer, model, tokenizer, config)
        else:
            result['nll_real_1q_070'] = None

        if len(real070) >= 5:
            mp = build_multi_query_prefix([q for q, s in real070[:5]], MQ_SEP)
            _, _, rl, rc, _ = build_matched_bare_and_truncated(mp, passage, model, tokenizer, config)
            result['nll_real_5q_070'] = score_answer_with_cache(
                rc, rl, query_prompt, answer, model, tokenizer, config)
        else:
            result['nll_real_5q_070'] = None

        # random_5q
        rand_prefix = build_multi_query_prefix([q for q, s in surr['random_5']], MQ_SEP)
        _, _, randl, randc, _ = build_matched_bare_and_truncated(
            rand_prefix, passage, model, tokenizer, config)
        result['nll_random_5q'] = score_answer_with_cache(
            randc, randl, query_prompt, answer, model, tokenizer, config)

        results_13s.append(result)
    except Exception as e:
        errors_13s += 1
        if errors_13s <= 3:
            print(f"\n  Error: {e}")
        continue
    finally:
        torch.cuda.empty_cache()

    if len(results_13s) % 25 == 0:
        with open(CKPT_13S, 'w') as f:
            json.dump({'experiment': '13_squad', 'results': results_13s,
                       'skipped': skipped_13s, 'errors': errors_13s}, f)
        elapsed = time.time() - start_13s
        print(f"\n  [{len(results_13s)} done | {elapsed/60:.0f}m]")

print(f"\nSQuAD done. {len(results_13s)} evaluated, {skipped_13s} skipped, {errors_13s} errors.")
print(f"Time: {(time.time()-start_13s)/60:.1f} min")
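
The checkpoint written inside the loop above is overwritten in place, so an interruption during `json.dump` can leave a truncated file that fails to load on resume. A minimal sketch of one way to harden this, assuming a write-to-temp-then-rename pattern is acceptable here; `save_checkpoint` and `load_checkpoint` are hypothetical helpers, not part of the notebook:

```python
import json
import os
import tempfile


def save_checkpoint(path, experiment, results, skipped, errors):
    """Write the checkpoint to a temp file, then swap it into place."""
    payload = {'experiment': experiment, 'results': results,
               'skipped': skipped, 'errors': errors}
    directory = os.path.dirname(path) or '.'
    os.makedirs(directory, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix='.tmp')
    try:
        with os.fdopen(fd, 'w') as f:
            json.dump(payload, f)
        # os.replace is atomic on POSIX when source and target share a filesystem
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def load_checkpoint(path, experiment):
    """Return (results, skipped, errors, next_idx); empty state if nothing usable."""
    if os.path.exists(path):
        with open(path) as f:
            ckpt = json.load(f)
        if ckpt.get('experiment') == experiment:
            results, skipped, errors = ckpt['results'], ckpt['skipped'], ckpt['errors']
            # Same resume rule as the loop above: every processed index is
            # either a result, a skip, or an error.
            return results, skipped, errors, len(results) + skipped + errors
    return [], 0, 0, 0
```

The resume index follows the same bookkeeping as `start_idx_s` above, so the helpers could replace the inline load and save without changing which samples are re-evaluated.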

In [None]:
# ============================================================
# Exp 13: SQuAD Analysis
# ============================================================

print("="*80)
print("EXP 13 RESULTS: SQUAD REPLICATION")
print("="*80)

squad_conds = [
    ('bare', 'nll_bare'), ('oracle_1q', 'nll_oracle_1q'),
    ('oracle_5q', 'nll_oracle_5q'),
    ('real_1q_070', 'nll_real_1q_070'), ('real_5q_070', 'nll_real_5q_070'),
    ('random_5q', 'nll_random_5q'),
]

print(f"\n{'Condition':<20} {'N':>5} {'Mean NLL':>10} {'Win%':>8} {'Delta':>10} {'Cohen d':>10}")
print("-" * 70)

stats_13s = {}
for cname, nll_key in squad_conds:
    valid = [r for r in results_13s if r.get(nll_key) is not None]
    if len(valid) < 10:
        print(f"{cname:<20} {len(valid):>5} -- insufficient")
        continue
    nlls = np.array([r[nll_key] for r in valid])
    bares = np.array([r['nll_bare'] for r in valid])
    if cname == 'bare':
        print(f"{cname:<20} {len(valid):>5} {np.mean(nlls):>10.4f}")
        stats_13s[cname] = {'n': len(valid), 'mean_nll': float(np.mean(nlls))}
        continue
    deltas = bares - nlls
    wr = np.mean(deltas > 0) * 100
    d = np.mean(deltas) / np.std(deltas, ddof=1) if np.std(deltas, ddof=1) > 0 else 0
    print(f"{cname:<20} {len(valid):>5} {np.mean(nlls):>10.4f} {wr:>7.1f}% {np.mean(deltas):>+10.4f} {d:>10.3f}")
    stats_13s[cname] = {'n': len(valid), 'mean_delta': float(np.mean(deltas)),
                         'win_rate': float(wr), 'cohens_d': float(d)}

# Key: does oracle_5q beat oracle_1q on SQuAD?
va = [r for r in results_13s if r.get('nll_oracle_1q') is not None and r.get('nll_oracle_5q') is not None]
if len(va) >= 10:
    o1 = np.array([r['nll_oracle_1q'] for r in va])
    o5 = np.array([r['nll_oracle_5q'] for r in va])
    t, p = stats.ttest_rel(o1, o5)
    print(f"\n  oracle_5q vs oracle_1q: 5q wins {np.mean(o5 < o1)*100:.1f}%, p={p:.4f}")

va = [r for r in results_13s if r.get('nll_real_1q_070') is not None and r.get('nll_real_5q_070') is not None]
if len(va) >= 10:
    r1 = np.array([r['nll_real_1q_070'] for r in va])
    r5 = np.array([r['nll_real_5q_070'] for r in va])
    t, p = stats.ttest_rel(r1, r5)
    print(f"  real_5q_070 vs real_1q_070: 5q wins {np.mean(r5 < r1)*100:.1f}%, p={p:.4f}")
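
The table above reports point estimates of the paired Cohen's d only. Since Exp 12 was judged by whether its confidence interval excluded 0, a bootstrap interval on d may be a useful companion here. The following is a sketch under stated assumptions: `paired_effect` is a hypothetical helper, the percentile bootstrap is one choice among several, and `n_boot` and `seed` are illustrative defaults.

```python
import numpy as np


def paired_effect(bare, cond, n_boot=2000, seed=0):
    """Paired delta stats (bare - cond) with a percentile-bootstrap CI on Cohen's d."""
    bare = np.asarray(bare, dtype=float)
    cond = np.asarray(cond, dtype=float)
    deltas = bare - cond  # positive delta = condition lowered NLL
    sd = np.std(deltas, ddof=1)
    d = float(np.mean(deltas) / sd) if sd > 0 else 0.0

    # Resample sample indices with replacement; recompute d on each resample.
    rng = np.random.default_rng(seed)
    idx = rng.integers(0, len(deltas), size=(n_boot, len(deltas)))
    boot = deltas[idx]
    boot_sd = boot.std(axis=1, ddof=1)
    ok = boot_sd > 0  # drop degenerate resamples with zero variance
    if ok.any():
        boot_d = boot[ok].mean(axis=1) / boot_sd[ok]
        lo, hi = np.percentile(boot_d, [2.5, 97.5])
    else:
        lo = hi = d
    return {
        'n': int(len(deltas)),
        'mean_delta': float(np.mean(deltas)),
        'win_rate': float(np.mean(deltas > 0) * 100),
        'cohens_d': d,
        'd_ci95': (float(lo), float(hi)),
    }
```

Called as `paired_effect(bares, nlls)` inside the loop above, it would return the same `mean_delta`, `win_rate`, and `cohens_d` values that the table prints, plus the interval.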

In [None]:
# ============================================================
# Exp 13B: Hardness-Gated Priming — Extend to N=2000
# ============================================================
# Reuse Exp 13A first 1000 samples, extend with 1000 more
# Only 3 conditions: bare, oracle_1q (raw), real_1q_0.70 (raw)
# ============================================================

print("="*80)
print("EXP 13B: HARDNESS-GATED PRIMING (MS MARCO, N=2000)")
print("="*80)

# Results from Exp 13A already have bare, oracle_1q, and real_1q_070 for the first 1000
# Extend to samples 1000-1999 with just these 3 conditions
# (Exp 13B variables and checkpoint keys carry a `14` suffix)
results_14 = []
# Copy the needed fields from the Exp 13A results (first 1000)
for r in results_13m:
    r14 = {
        'idx': r['idx'], 'query': r['query'],
        'nll_bare': r['nll_bare'],
        'nll_oracle_1q': r['nll_oracle_1q'],
        'nll_real_1q_070': r.get('nll_real_1q_070'),
    }
    results_14.append(r14)

print(f"Copied {len(results_14)} results from Exp 13A")

# Extend with samples 1000-1999
start_14 = time.time()
errors_14 = 0
skipped_14 = 0

CKPT_14 = 'results/exp13/13b_checkpoint.json'
start_idx_14 = N_EXP13_MARCO  # Start where Exp 13A ended

if os.path.exists(CKPT_14):
    with open(CKPT_14) as f:
        ckpt = json.load(f)
    if ckpt.get('experiment') == '14_extend':
        extra = ckpt['results']
        results_14 = results_14[:N_EXP13_MARCO]  # Keep only Exp 13A portion
        results_14.extend(extra)
        errors_14 = ckpt['errors']
        skipped_14 = ckpt['skipped']
        start_idx_14 = N_EXP13_MARCO + len(extra) + errors_14 + skipped_14
        print(f"Resumed Exp 13B extension: {len(extra)} additional results")

for idx in tqdm(range(start_idx_14, N_EXP14), desc="Exp13B extend",
                initial=start_idx_14 - N_EXP13_MARCO, total=N_EXP14 - N_EXP13_MARCO):
    sample = samples_marco[idx]
    passage, query, answer = sample['passage'], sample['query'], sample['answer']

    answer_ids = tokenizer(answer, return_tensors='pt', add_special_tokens=False)['input_ids']
    if answer_ids.shape[1] < 2:
        skipped_14 += 1
        continue

    query_prompt = config.query_template.format(query=query)
    surr = marco_surrogates[idx]

    try:
        result = {'idx': idx, 'query': query}

        # bare + oracle_1q
        bare_len, bare_cache, _, oracle_cache, _ = \
            build_matched_bare_and_truncated(query, passage, model, tokenizer, config)
        result['nll_bare'] = score_answer_with_cache(
            bare_cache, bare_len, query_prompt, answer, model, tokenizer, config)
        result['nll_oracle_1q'] = score_answer_with_cache(
            oracle_cache, bare_len, query_prompt, answer, model, tokenizer, config)

        # real_1q_0.70
        real070 = surr['real_070']
        if len(real070) >= 1:
            _, _, rl, rc, _ = build_matched_bare_and_truncated(
                real070[0][0], passage, model, tokenizer, config)
            result['nll_real_1q_070'] = score_answer_with_cache(
                rc, rl, query_prompt, answer, model, tokenizer, config)
        else:
            result['nll_real_1q_070'] = None

        results_14.append(result)
    except Exception as e:
        errors_14 += 1
        if errors_14 <= 3:
            print(f"\n  Error: {e}")
        continue
    finally:
        torch.cuda.empty_cache()

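    # Extension results start right after the Exp 13A copies. Slice by
    # len(results_13m) rather than N_EXP13_MARCO: Exp 13A holds fewer than
    # N_EXP13_MARCO rows whenever any of its samples were skipped or errored.
    extra_count = len(results_14) - len(results_13m)
    if extra_count > 0 and extra_count % 25 == 0:
        with open(CKPT_14, 'w') as f:
            json.dump({'experiment': '14_extend',
                       'results': results_14[len(results_13m):],
                       'skipped': skipped_14, 'errors': errors_14}, f)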
        elapsed = time.time() - start_14
        print(f"\n  [{len(results_14)} total | {elapsed/60:.0f}m]")

print(f"\nExp 13B done. {len(results_14)} total, {errors_14} errors, {skipped_14} skipped.")
print(f"Extension time: {(time.time()-start_14)/60:.1f} min")

In [None]:
# ============================================================
# Exp 13B: Hardness Analysis
# ============================================================

print("="*80)
print("EXP 13B RESULTS: HARDNESS-GATED PRIMING")
print("="*80)

bare_14 = np.array([r['nll_bare'] for r in results_14])
oracle_14 = np.array([r['nll_oracle_1q'] for r in results_14])
oracle_delta_14 = bare_14 - oracle_14

# --- Phase 1: Bin by bare NLL quartile ---
quartiles = np.percentile(bare_14, [25, 50, 75])
q_labels = [
    f'Q1 (bare<{quartiles[0]:.2f})',
    f'Q2 ({quartiles[0]:.2f}-{quartiles[1]:.2f})',
    f'Q3 ({quartiles[1]:.2f}-{quartiles[2]:.2f})',
    f'Q4 (bare>{quartiles[2]:.2f})',
]
q_masks = [
    bare_14 < quartiles[0],
    (bare_14 >= quartiles[0]) & (bare_14 < quartiles[1]),
    (bare_14 >= quartiles[1]) & (bare_14 < quartiles[2]),
    bare_14 >= quartiles[2],
]

print(f"\nBare NLL quartiles: {quartiles}")
print(f"Total samples: {len(results_14)}")

print(f"\n{'Quartile':<35} {'N':>5} {'Oracle Win%':>12} {'Oracle Delta':>14} {'Oracle d':>10}")
print("-" * 80)

quartile_stats = []
for ql, qm in zip(q_labels, q_masks):
    n = np.sum(qm)
    od = oracle_delta_14[qm]
    wr = np.mean(od > 0) * 100
    d = np.mean(od) / np.std(od, ddof=1) if np.std(od, ddof=1) > 0 else 0
    print(f"{ql:<35} {n:>5} {wr:>11.1f}% {np.mean(od):>+14.4f} {d:>10.3f}")
    quartile_stats.append({'label': ql, 'n': int(n), 'win_rate': float(wr),
                           'mean_delta': float(np.mean(od)), 'cohens_d': float(d)})

# Also for real_1q_070
valid_real = [i for i, r in enumerate(results_14) if r.get('nll_real_1q_070') is not None]
if len(valid_real) >= 40:
    bare_r = np.array([results_14[i]['nll_bare'] for i in valid_real])
    real_r = np.array([results_14[i]['nll_real_1q_070'] for i in valid_real])
    real_delta = bare_r - real_r

    rq = np.percentile(bare_r, [25, 50, 75])
    print(f"\n{'Quartile':<35} {'N':>5} {'Real070 Win%':>12} {'Real070 Delta':>14} {'Real070 d':>10}")
    print("-" * 80)
    for qi, (lo, hi) in enumerate([(0, rq[0]), (rq[0], rq[1]), (rq[1], rq[2]), (rq[2], np.inf)]):
        mask = (bare_r >= lo) & (bare_r < hi)
        n = np.sum(mask)
        if n < 5:
            continue
        rd = real_delta[mask]
        wr = np.mean(rd > 0) * 100
        d = np.mean(rd) / np.std(rd, ddof=1) if np.std(rd, ddof=1) > 0 else 0
        print(f"Q{qi+1:<34} {n:>5} {wr:>11.1f}% {np.mean(rd):>+14.4f} {d:>10.3f}")

# --- Phase 2: Hardness predictors ---
print("\n--- Hardness Predictors ---")

# Query length (word count); kept for reference, not used as a predictor below
query_word_counts = np.array([len(r['query'].split()) for r in results_14])
# Passage and answer text are not stored in results, so look them up in samples_marco
pass_word_counts = np.array([len(samples_marco[r['idx']]['passage'].split()) for r in results_14])
ans_lengths = np.array([len(tokenizer.encode(samples_marco[r['idx']]['answer'], add_special_tokens=False))
                        for r in results_14])

# Correlations with bare NLL (hardness predictors)
predictors = [
    ('passage_word_count', pass_word_counts),
    ('answer_token_length', ans_lengths),
]
for pname, pvals in predictors:
    r_pred, p_pred = stats.pearsonr(pvals.astype(float), bare_14)
    print(f"  {pname} vs bare_NLL: r={r_pred:.4f}, p={p_pred:.2e}")

# Correlations with oracle delta (who benefits?)
print("\n  Predictors of benefit (oracle delta):")
for pname, pvals in predictors:
    r_pred, p_pred = stats.pearsonr(pvals.astype(float), oracle_delta_14)
    print(f"  {pname} vs oracle_delta: r={r_pred:.4f}, p={p_pred:.2e}")

# Bare NLL itself as predictor
r_bare_delta, p_bare_delta = stats.pearsonr(bare_14, oracle_delta_14)
print(f"  bare_NLL vs oracle_delta: r={r_bare_delta:.4f}, p={p_bare_delta:.2e}")

# --- Phase 3: Gated strategies ---
print("\n--- Gated Strategies ---")

# Oracle gate: only prime if bare NLL >= T. Bare NLL is scored against the
# reference answer, hence "oracle" gate.
for T_percentile in [25, 50, 75]:
    T = np.percentile(bare_14, T_percentile)
    prime_mask = bare_14 >= T
    n_primed = np.sum(prime_mask)
    # Gated NLL: use oracle where primed, bare where not
    gated_nlls = np.where(prime_mask, oracle_14, bare_14)
    gated_delta = bare_14 - gated_nlls
    # Unprimed samples have delta == 0 and count as non-wins, so win% is capped
    # at the primed fraction and is not directly comparable to always-prime.
    gated_wr = np.mean(gated_delta > 0) * 100
    gated_d = np.mean(gated_delta) / np.std(gated_delta, ddof=1) if np.std(gated_delta, ddof=1) > 0 else 0
    # Compare to always-prime
    always_d = np.mean(oracle_delta_14) / np.std(oracle_delta_14, ddof=1)
    print(f"  Gate T=P{T_percentile} (NLL>={T:.2f}): prime {n_primed}/{len(bare_14)}, "
          f"win%={gated_wr:.1f}%, d={gated_d:.3f} (vs always-prime d={always_d:.3f})")
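
The gate loop above counts every unprimed sample as a non-win, because its delta is exactly 0. That makes the gated win rate hard to compare with always-prime. A small sketch that reports both views, assuming the same `>=` threshold rule; `gate_by_bare_nll` is a hypothetical helper and the toy arrays are illustrative, not experiment data:

```python
import numpy as np


def gate_by_bare_nll(bare, primed, percentile):
    """Prime only samples whose bare NLL is at or above the given percentile."""
    bare = np.asarray(bare, dtype=float)
    primed = np.asarray(primed, dtype=float)
    threshold = np.percentile(bare, percentile)
    prime_mask = bare >= threshold
    # Gated NLL: primed score where the gate fires, bare score elsewhere
    gated = np.where(prime_mask, primed, bare)
    delta = bare - gated
    sd = np.std(delta, ddof=1)
    return {
        'threshold': float(threshold),
        'n_primed': int(prime_mask.sum()),
        'mean_delta': float(delta.mean()),
        'cohens_d': float(delta.mean() / sd) if sd > 0 else 0.0,
        # Over all samples: unprimed rows are ties and count as non-wins
        'win_rate_all': float(np.mean(delta > 0) * 100),
        # Over primed samples only: comparable to a per-quartile win rate
        'win_rate_primed': float(np.mean(delta[prime_mask] > 0) * 100),
    }


# Toy example: priming hurts the two easy samples and helps the two hard ones
toy_bare = [1.0, 2.0, 3.0, 4.0]
toy_primed = [1.5, 2.5, 2.0, 3.0]
print(gate_by_bare_nll(toy_bare, toy_primed, 50))
```

In the toy example the always-prime mean delta is 0.25, while the P50 gate reaches 0.5 by leaving the two easy samples unprimed.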

In [None]:
# ============================================================
# Exp 13B: Visualization
# ============================================================

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: Oracle delta by bare NLL quartile
ax = axes[0, 0]
qs_d = [qs['mean_delta'] for qs in quartile_stats]
qs_n = [qs['n'] for qs in quartile_stats]
colors_q = ['#55a868', '#4c72b0', '#dd8452', '#c44e52']
bars = ax.bar(range(4), qs_d, color=colors_q)
ax.set_xticks(range(4))
ax.set_xticklabels([f'Q{i+1}\n(N={n})' for i, n in enumerate(qs_n)], fontsize=8)
ax.set_ylabel('Mean Oracle Delta NLL')
ax.set_title('Oracle Benefit by Baseline Difficulty')
ax.axhline(0, color='gray', linestyle='--')
for i, d in enumerate(qs_d):
    ax.text(i, d + 0.002 if d >= 0 else d - 0.005, f'd={quartile_stats[i]["cohens_d"]:.3f}',
            ha='center', fontsize=8)

# Plot 2: Scatter bare NLL vs oracle delta
ax = axes[0, 1]
ax.scatter(bare_14, oracle_delta_14, alpha=0.05, s=3, color='#4c72b0')
ax.axhline(0, color='gray', linestyle='--')
ax.axvline(np.median(bare_14), color='orange', linestyle=':', label=f'median={np.median(bare_14):.2f}')
z = np.polyfit(bare_14, oracle_delta_14, 1)
x_line = np.linspace(bare_14.min(), min(bare_14.max(), 10), 100)
ax.plot(x_line, np.poly1d(z)(x_line), 'r-', linewidth=2)
ax.set_xlabel('Bare NLL (baseline difficulty)')
ax.set_ylabel('Oracle Delta NLL')
ax.set_title(f'Difficulty vs Benefit (r={r_bare_delta:.3f})')
ax.set_xlim(0, min(10, np.percentile(bare_14, 99)))
ax.legend()

# Plot 3: Gated strategy comparison
ax = axes[1, 0]
always_delta = np.mean(oracle_delta_14)
strat_labels = ['Never\nprime', 'Always\nprime']
strat_deltas = [0.0, always_delta]
strat_colors = ['#888888', '#4c72b0']
for T_pct in [25, 50, 75]:
    T = np.percentile(bare_14, T_pct)
    pm = bare_14 >= T
    gated = np.where(pm, oracle_14, bare_14)
    gd = np.mean(bare_14 - gated)
    strat_labels.append(f'Gate\nP{T_pct}')
    strat_deltas.append(gd)
    strat_colors.append('#c44e52')
ax.bar(range(len(strat_labels)), strat_deltas, color=strat_colors)
ax.set_xticks(range(len(strat_labels)))
ax.set_xticklabels(strat_labels, fontsize=8)
ax.set_ylabel('Mean Delta NLL')
ax.set_title('Gating Strategies (Oracle Gate)')
ax.axhline(0, color='gray', linestyle='--')

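# Plot 4: Win rate by strategy
# 'Never prime' is drawn at the 50% reference line. For gated strategies,
# unprimed samples have delta == 0 and count as non-wins.
ax = axes[1, 1]
strat_wrs = [50.0, np.mean(oracle_delta_14 > 0) * 100]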
for T_pct in [25, 50, 75]:
    T = np.percentile(bare_14, T_pct)
    pm = bare_14 >= T
    gated = np.where(pm, oracle_14, bare_14)
    gwr = np.mean((bare_14 - gated) > 0) * 100
    strat_wrs.append(gwr)
ax.bar(range(len(strat_labels)), strat_wrs, color=strat_colors)
ax.axhline(50, color='gray', linestyle='--')
ax.set_xticks(range(len(strat_labels)))
ax.set_xticklabels(strat_labels, fontsize=8)
ax.set_ylabel('Win Rate vs Bare (%)')
ax.set_title('Win Rate by Gating Strategy')

plt.tight_layout()
plt.savefig('results/exp13/13b_hardness_gating.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 13b_hardness_gating.png')

In [None]:
# ============================================================
# Comprehensive Summary
# ============================================================

print("="*80)
print("EXPERIMENT 13 (PARTS A & B): COMPREHENSIVE SUMMARY")
print("="*80)

# --- Exp 13A Summary ---
print("\n--- Exp 13A: Multi-Query Amplification ---")
print("\nMS MARCO:")
for cname in ['oracle_1q', 'oracle_2q', 'oracle_3q', 'oracle_5q',
              'real_1q_070', 'real_3q_070', 'real_5q_070', 'real_5q_050',
              'random_5q', 'repeated_1q_5x']:
    if cname in stats_13m:
        s = stats_13m[cname]
        print(f"  {cname:<20} d={s['cohens_d']:.3f}, win%={s['win_rate']:.1f}%, delta={s['mean_delta']:.4f}")

print("\nSQuAD:")
for cname in ['oracle_1q', 'oracle_5q', 'real_1q_070', 'real_5q_070', 'random_5q']:
    if cname in stats_13s:
        s = stats_13s[cname]
        print(f"  {cname:<20} d={s['cohens_d']:.3f}, win%={s['win_rate']:.1f}%, delta={s.get('mean_delta', 0):.4f}")

# Key verdicts
print("\n--- Key Verdicts ---")
# Multi-query amplification
if 'oracle_5q' in stats_13m and 'oracle_1q' in stats_13m:
    amp = stats_13m['oracle_5q']['mean_delta'] / stats_13m['oracle_1q']['mean_delta'] if stats_13m['oracle_1q']['mean_delta'] != 0 else 0
    print(f"  Multi-query amplification factor (oracle 5q/1q): {amp:.2f}x")
if 'real_5q_070' in stats_13m and 'real_1q_070' in stats_13m:
    amp = stats_13m['real_5q_070']['mean_delta'] / stats_13m['real_1q_070']['mean_delta'] if stats_13m['real_1q_070']['mean_delta'] != 0 else 0
    print(f"  Multi-query amplification factor (real 5q/1q @0.70): {amp:.2f}x")

# Diversity
if 'real_5q_070' in stats_13m and 'repeated_1q_5x' in stats_13m:
    div_better = stats_13m['real_5q_070']['mean_delta'] > stats_13m['repeated_1q_5x']['mean_delta']
    print(f"  Diversity helps: {div_better} (5 diverse: d={stats_13m['real_5q_070']['cohens_d']:.3f}, "
          f"1 repeated 5x: d={stats_13m['repeated_1q_5x']['cohens_d']:.3f})")

# Practical comparison
if 'real_5q_050' in stats_13m and 'real_1q_070' in stats_13m:
    prac = stats_13m['real_5q_050']['mean_delta'] > stats_13m['real_1q_070']['mean_delta']
    print(f"  5q@0.50 > 1q@0.70: {prac} (5q@0.50: d={stats_13m['real_5q_050']['cohens_d']:.3f}, "
          f"1q@0.70: d={stats_13m['real_1q_070']['cohens_d']:.3f})")

# --- Exp 13B Summary ---
print("\n--- Exp 13B: Hardness Gating ---")
print(f"  bare NLL vs oracle delta correlation: r={r_bare_delta:.3f}")
for qs in quartile_stats:
    print(f"  {qs['label']}: d={qs['cohens_d']:.3f}, win%={qs['win_rate']:.1f}%")

# Best gated strategy
# Start below any achievable d so a best gate is reported even if every gate has d <= 0
best_gate_d = -np.inf
best_gate_pct = None
always_d_val = np.mean(oracle_delta_14) / np.std(oracle_delta_14, ddof=1)
for T_pct in [25, 50, 75]:
    T = np.percentile(bare_14, T_pct)
    pm = bare_14 >= T
    gated = np.where(pm, oracle_14, bare_14)
    gd = bare_14 - gated
    gd_d = np.mean(gd) / np.std(gd, ddof=1) if np.std(gd, ddof=1) > 0 else 0
    if gd_d > best_gate_d:
        best_gate_d = gd_d
        best_gate_pct = T_pct

print(f"\n  Best gate: P{best_gate_pct} (d={best_gate_d:.3f} vs always-prime d={always_d_val:.3f})")
if best_gate_d > always_d_val:
    print(f"  Gating IMPROVES over always-prime by {best_gate_d - always_d_val:.3f} d")
else:
    print("  Gating does NOT improve over always-prime")
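
The best gate above is chosen and scored on the same samples, which can make it look better than it would on new data. One way to check this, sketched here as an assumption rather than as part of the experiment design, is to pick the percentile on one random half and report d on the other half. `select_gate_holdout` and its arguments are hypothetical names.

```python
import numpy as np


def cohens_d(delta):
    """Paired Cohen's d: mean delta over its sample standard deviation."""
    sd = np.std(delta, ddof=1)
    return float(np.mean(delta) / sd) if sd > 0 else 0.0


def select_gate_holdout(bare, primed, percentiles=(25, 50, 75), seed=0):
    """Choose the gate percentile on a fit half; evaluate it on the held-out half."""
    bare = np.asarray(bare, dtype=float)
    primed = np.asarray(primed, dtype=float)
    perm = np.random.default_rng(seed).permutation(len(bare))
    half = len(bare) // 2
    fit, test = perm[:half], perm[half:]

    def gated_delta(idx, threshold):
        b, p = bare[idx], primed[idx]
        return b - np.where(b >= threshold, p, b)

    best = None
    for pct in percentiles:
        # The threshold is taken from the fit half only
        thr = np.percentile(bare[fit], pct)
        d_fit = cohens_d(gated_delta(fit, thr))
        if best is None or d_fit > best[2]:
            best = (pct, thr, d_fit)

    pct, thr, d_fit = best
    return {
        'percentile': pct,
        'threshold': float(thr),
        'd_fit': d_fit,
        'd_test': cohens_d(gated_delta(test, thr)),
        'd_always_test': cohens_d(bare[test] - primed[test]),
    }
```

With the notebook's arrays this would be called as `select_gate_holdout(bare_14, oracle_14)`; comparing `d_test` against `d_always_test` gives the held-out version of the verdict printed above.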

In [None]:
# ============================================================
# Save Results
# ============================================================

output = {
    'metadata': {
        'experiment': '13_multi_query_and_gating',
        'timestamp': datetime.datetime.now().isoformat(),
        'model_name': config.model_name,
        'seeds': SEEDS,
        'n_exp13a_marco': N_EXP13_MARCO,
        'n_exp13a_squad': N_EXP13_SQUAD,
        'n_exp13b': N_EXP14,
    },
    'exp13a_marco': {
        'n_evaluated': len(results_13m),
        'skipped': skipped_13m, 'errors': errors_13m,
        'stats': stats_13m,
        'results': results_13m,
    },
    'exp13a_squad': {
        'n_evaluated': len(results_13s),
        'skipped': skipped_13s, 'errors': errors_13s,
        'stats': stats_13s,
        'results': results_13s,
    },
    'exp13b': {
        'n_evaluated': len(results_14),
        'quartile_stats': quartile_stats,
        'bare_delta_r': float(r_bare_delta),
        'results': results_14,
    },
}

output_path = 'results/exp13/13_results.json'
with open(output_path, 'w') as f:
    json.dump(output, f, indent=2, default=str)
print(f"Results saved to {output_path}")
print(f"File size: {os.path.getsize(output_path) / 1e6:.1f} MB")

# Final summary figure
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Plot 1: Scaling curves (MS MARCO)
ax = axes[0]
for series, label, marker, color in [
    ([(k, c) for k, c in [(1,'oracle_1q'),(2,'oracle_2q'),(3,'oracle_3q'),(5,'oracle_5q')]
      if c in stats_13m], 'Oracle', 'o', '#c44e52'),
    ([(k, c) for k, c in [(1,'real_1q_070'),(3,'real_3q_070'),(5,'real_5q_070')]
      if c in stats_13m], 'Real (sim>=0.70)', 's', '#4c72b0'),
]:
    if series:
        ks = [s[0] for s in series]
        ds = [stats_13m[s[1]]['mean_delta'] for s in series]
        ax.plot(ks, ds, f'{marker}-', color=color, linewidth=2, markersize=8, label=label)
ax.axhline(0, color='gray', linestyle='--')
ax.set_xlabel('Number of prefix queries (K)')
ax.set_ylabel('Mean Delta NLL')
ax.set_title('Exp 13A: Multi-Query Scaling')
ax.legend()

# Plot 2: Hardness interaction
ax = axes[1]
ax.bar(range(4), [qs['cohens_d'] for qs in quartile_stats], color=colors_q)
ax.set_xticks(range(4))
ax.set_xticklabels([f'Q{i+1}' for i in range(4)])
ax.set_ylabel("Cohen's d")
ax.set_title('Exp 13B: Effect Size by Difficulty Quartile')
ax.axhline(0, color='gray', linestyle='--')

# Plot 3: Cross-dataset comparison (1q vs 5q)
ax = axes[2]
datasets_plot = []
d1q_plot = []
d5q_plot = []
if 'oracle_1q' in stats_13m and 'oracle_5q' in stats_13m:
    datasets_plot.append('MARCO')
    d1q_plot.append(stats_13m['oracle_1q']['cohens_d'])
    d5q_plot.append(stats_13m['oracle_5q']['cohens_d'])
if 'oracle_1q' in stats_13s and 'oracle_5q' in stats_13s:
    datasets_plot.append('SQuAD')
    d1q_plot.append(stats_13s['oracle_1q']['cohens_d'])
    d5q_plot.append(stats_13s['oracle_5q']['cohens_d'])
if datasets_plot:
    x = np.arange(len(datasets_plot))
    w = 0.35
    ax.bar(x - w/2, d1q_plot, w, label='1 query', color='#4c72b0')
    ax.bar(x + w/2, d5q_plot, w, label='5 queries', color='#c44e52')
    ax.set_xticks(x)
    ax.set_xticklabels(datasets_plot)
    ax.set_ylabel("Cohen's d")
    ax.set_title('1q vs 5q: Cross-Dataset')
    ax.legend()

plt.suptitle('Experiment 13 Summary (Parts A & B)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('results/exp13/13_summary.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 13_summary.png')
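
The save step in this cell passes `default=str` to `json.dump`. That keeps the dump from failing, but any NumPy scalar that is not a Python `float` subclass (for example `np.int64` or `np.float32`) is written as a quoted string. If numeric types need to survive a round trip, a converter along these lines could be passed as `default` instead; `to_jsonable` is a hypothetical helper, shown as a sketch:

```python
import json

import numpy as np


def to_jsonable(obj):
    """Fallback for json.dump: convert NumPy values to native Python types."""
    if isinstance(obj, np.generic):   # NumPy scalars such as np.int64, np.float32
        return obj.item()
    if isinstance(obj, np.ndarray):   # arrays become nested lists
        return obj.tolist()
    return str(obj)                   # anything else: same behavior as default=str


record = {'n': np.int64(3), 'x': np.float32(0.5), 'arr': np.arange(3)}
print(json.dumps(record, default=to_jsonable))
```

`json` only calls the `default` hook for objects it cannot serialize itself, so plain Python floats, ints, lists, and dicts are unaffected.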