# Exp 05: Hardness-Gated Semantic Priming + LLM Surrogates

## Motivation

Exp 01-03 consistently show priming helps MORE for hard samples (r=0.15-0.30 hardness
correlation). But the overall effect is dragged to zero by easy samples where priming
interferes. This experiment asks: **if we only prime hard samples, does oracle beat random?**

Also tests LLM-generated surrogates — the production scenario where you generate a topic
query rather than having the oracle.

## Design (two-pass)

1. **First pass**: Score bare NLL on 4000 MS MARCO samples
2. **Filter**: Keep the hardest ~50% (bare NLL > median of the non-zero bare NLLs), capped at 2000 samples
3. **Generate**: An LLM surrogate query for each hard sample, using Mistral-7B-Instruct-v0.2 (the same model that does the scoring)
4. **Second pass**: Full 5-condition eval on hard samples

## Conditions (5)

| # | Condition | Cache | Tests |
|---|-----------|-------|-------|
| 1 | Bare | `[BOS][doc]` | Baseline |
| 2 | Oracle-truncated | `[BOS][query\n][doc]` → truncate + RoPE | Semantic signal |
| 3 | Random-truncated | `[BOS][random\n][doc]` → truncate + RoPE | Structural control |
| 4 | Separator-only | `[BOS][doc][\n\nRelated question: ]` | Framing effect |
| 5 | LLM-surrogate-truncated | `[BOS][llm_query\n][doc]` → truncate + RoPE | Production scenario |

In [1]:
# Cell 1: Setup
import os
os.umask(0o000)

import sys
import json
import time
import numpy as np
import torch
from pathlib import Path

# Single seed shared by every stochastic step in the notebook.
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# All artifacts of this experiment live under one directory.
RESULTS_DIR = Path("results/exp05")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

CHECKPOINT_PATH, SURROGATES_PATH, FINAL_RESULTS_PATH = (
    RESULTS_DIR / file_name
    for file_name in ("checkpoint.json", "surrogates.json", "results.json")
)

# Environment banner so the saved output records which hardware produced it.
cuda_ok = torch.cuda.is_available()
print(f"SEED: {SEED}")
print(f"Results directory: {RESULTS_DIR}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {gpu_props.total_memory / 1e9:.1f} GB")

SEED: 42
Results directory: results/exp05
CUDA available: True
GPU: NVIDIA L4
GPU memory: 23.6 GB


In [2]:
# Cell 2: Load model
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

print(f"Loading {MODEL_NAME} (4-bit)...")

# 4-bit NF4 weights with double quantization; matmuls are computed in fp16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    # The installed transformers release reports `torch_dtype` as deprecated and
    # asks for `dtype` instead (see this cell's recorded output).
    # NOTE(review): releases that predate the rename only accept `torch_dtype`.
    dtype=torch.float16,
)
# Inference only: switch off dropout and other train-time behaviour.
model.eval()

print(f"Model loaded. dtype={model.dtype}, device={model.device}")



Loading mistralai/Mistral-7B-Instruct-v0.2 (4-bit)...


`torch_dtype` is deprecated! Use `dtype` instead!


Loading weights:   0%|          | 0/291 [00:00<?, ?it/s]

Model loaded. dtype=torch.float16, device=cuda:0


In [3]:
# Cell 3: Imports + config + templates + shared functions
sys.path.insert(0, ".")

from lib.config import ExperimentConfig
from lib.kv_cache import (
    build_kv_cache,
    build_suffix_kv_cache,
    score_answer_with_cache,
    deepcopy_cache,
    extract_and_truncate_cache_with_bos,
    correct_rope_positions_with_bos,
)
from lib.data import load_ms_marco, load_evaluation_samples
from lib.analysis import cohens_d
from lib.surrogate import generate_surrogate
from scipy import stats
from tqdm.auto import tqdm

# --- Experiment configuration -------------------------------------------------
# num_samples is the size of the screening pool scored in the first pass.
config = ExperimentConfig(
    seed=SEED,
    model_name=MODEL_NAME,
    num_samples=4000,
    min_passage_words=20,
    max_passage_words=500,
)

# --- Prompt templates ---------------------------------------------------------
# Prefix placed before the document in the primed conditions (later truncated away).
SURROGATE_PREFIX_TEMPLATE = "{surrogate}\n"
DOCUMENT_TEMPLATE = "{document}"
# Suffix appended after the document in the separator-only condition.
SUFFIX_SEPARATOR = "\n\nRelated question: "
# Query and answer are scored on top of whichever cache a condition built.
QUERY_TEMPLATE = "\nQuery: {query}\nAnswer:"
ANSWER_TEMPLATE = " {answer}"

# --- Run control and statistics -----------------------------------------------
CHECKPOINT_EVERY = 50
# Number of pairwise comparisons reported in the analysis cell; the family-wise
# 0.05 level is divided by it to get the per-comparison threshold.
N_COMPARISONS = 7
BONFERRONI_ALPHA = 0.05 / N_COMPARISONS


def generate_random_prefix_text(target_text, tokenizer, seed, max_attempts=8):
    """Build random text whose token count matches that of ``target_text``.

    Random token ids are drawn, decoded to text, and the text is re-encoded to
    check its length. Decoding and re-encoding does not always round-trip to the
    same number of tokens, so the text is trimmed or padded with further random
    tokens and re-checked, up to ``max_attempts`` times. The original version
    adjusted only once, so it could return text that was still too short.

    Args:
        target_text: Text whose token count (without special tokens) is matched.
        tokenizer: Object providing ``encode(text, add_special_tokens=...)``,
            ``decode(ids, skip_special_tokens=...)`` and ``len()``.
        seed: Seed for the private ``np.random.RandomState``; the same seed and
            tokenizer always give the same text.
        max_attempts: Upper bound on trim/pad rounds. If the length still does
            not match afterwards, the closest text reached is returned.

    Returns:
        The random text, or ``""`` when ``target_text`` encodes to no tokens.
    """
    target_len = len(tokenizer.encode(target_text, add_special_tokens=False))
    if target_len == 0:
        return ""

    rng = np.random.RandomState(seed)
    vocab_size = len(tokenizer)
    # Ids below 3 are never sampled.
    # NOTE(review): presumably these are the tokenizer's special tokens — confirm.
    min_id = 3

    random_ids = rng.randint(min_id, vocab_size, size=target_len)
    random_text = tokenizer.decode(random_ids.tolist(), skip_special_tokens=True)

    for _ in range(max_attempts):
        reencoded = tokenizer.encode(random_text, add_special_tokens=False)
        shortfall = target_len - len(reencoded)
        if shortfall == 0:
            break
        if shortfall < 0:
            # Too long: keep the first target_len tokens and re-check next round.
            random_text = tokenizer.decode(reencoded[:target_len], skip_special_tokens=True)
        else:
            # Too short: append exactly as many fresh random tokens as are missing.
            extra_ids = rng.randint(min_id, vocab_size, size=shortfall)
            random_text += tokenizer.decode(extra_ids.tolist(), skip_special_tokens=True)
    return random_text


# One-shot summary of the settings that define this run.
print("\n".join([
    "Config ready",
    f"  num_samples pool: {config.num_samples}",
    f"  passage words: {config.min_passage_words}-{config.max_passage_words}",
    f"  bonferroni_alpha: {BONFERRONI_ALPHA:.4f}",
]))

Config ready
  num_samples pool: 4000
  passage words: 20-500
  bonferroni_alpha: 0.0071


In [4]:
# Cell 4: Load MS MARCO (large pool)
# Builds the screening pool that the first pass scores for hardness.
dataset = load_ms_marco(config)

# Re-seed NumPy's global RNG immediately before sample selection so the pool is
# the same on every run.
# NOTE(review): presumably load_evaluation_samples draws from np.random — confirm.
np.random.seed(SEED)
# NOTE(review): require_answer=True presumably drops samples without an answer — confirm.
all_samples = load_evaluation_samples(dataset, config, require_answer=True)

# Pool size actually obtained; later cells use it as the loop bound.
N_POOL = len(all_samples)
print(f"Loaded {N_POOL} samples for hardness screening")

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'microsoft/ms_marco' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


Loading microsoft/ms_marco dataset...
Dataset loaded: 10047 samples
Filtering samples...


Filtering:   0%|          | 0/10047 [00:00<?, ?it/s]

Selected 4000 samples
Loaded 4000 samples for hardness screening


In [5]:
# Cell 5: First pass — score bare NLL on all samples
# This identifies the "hard" samples where priming is most likely to help.

print("=" * 70)
print("FIRST PASS: BARE NLL SCORING")
print("=" * 70)

bare_nlls_path = RESULTS_DIR / "bare_nlls.json"
BARE_SAVE_EVERY = 200  # flush partial scores to disk every N samples


def _save_bare_nlls():
    """Write the scores collected so far to the cache file."""
    with open(bare_nlls_path, 'w') as f:
        json.dump({'bare_nlls': bare_nlls_all}, f)


# The cache file is also written mid-run, so it may hold only a prefix of the
# pool. Load whatever is there and resume from the first missing index instead of
# treating a partial file as a finished first pass.
bare_nlls_all = []
if bare_nlls_path.exists():
    with open(bare_nlls_path, 'r') as f:
        bare_data = json.load(f)
    bare_nlls_all = bare_data['bare_nlls']
    print(f"Loaded {len(bare_nlls_all)} bare NLLs from cache")
    if len(bare_nlls_all) > N_POOL:
        # More scores than samples means the file belongs to a different pool.
        print(f"Cache is larger than the pool ({N_POOL}). Starting fresh.")
        bare_nlls_all = []

start_bare = len(bare_nlls_all)
if start_bare < N_POOL:
    t_start = time.time()

    for idx in tqdm(range(start_bare, N_POOL), initial=start_bare,
                     total=N_POOL, desc="Bare NLL"):
        sample = all_samples[idx]
        query_prompt = QUERY_TEMPLATE.format(query=sample['query'])
        answer_text = ANSWER_TEMPLATE.format(answer=sample['answer'])

        bare_len, bare_cache = build_kv_cache(sample['passage'], model, tokenizer, config)
        # Scoring receives a copy of the cache.
        # NOTE(review): presumably because the scorer extends the cache — confirm.
        bare_nll = score_answer_with_cache(
            deepcopy_cache(bare_cache), bare_len,
            query_prompt, answer_text, model, tokenizer, config)
        bare_nlls_all.append(bare_nll)

        del bare_cache
        torch.cuda.empty_cache()

        if (idx + 1) % BARE_SAVE_EVERY == 0:
            _save_bare_nlls()
            elapsed = time.time() - t_start
            # Rate counts only samples scored in this session, so the ETA stays
            # meaningful after a resume.
            rate = (idx - start_bare + 1) / elapsed if elapsed > 0 else 0
            eta_min = (N_POOL - idx - 1) / rate / 60 if rate > 0 else 0
            print(f"  {idx+1}/{N_POOL} | {rate:.1f} s/s | ETA: {eta_min:.1f} min")

    _save_bare_nlls()

    elapsed = time.time() - t_start
    print(f"First pass: {N_POOL - start_bare} samples in {elapsed/60:.1f} min")

# Summary statistics over strictly positive scores only; these are the scores
# the hardness threshold is later computed from.
bare_nlls_arr = np.array(bare_nlls_all)
nonzero_mask = bare_nlls_arr > 0
print(f"\nBare NLL distribution (non-zero only, N={np.sum(nonzero_mask)}):")
bnz = bare_nlls_arr[nonzero_mask]
print(f"  Mean: {bnz.mean():.3f}, Median: {np.median(bnz):.3f}")
print(f"  Q25: {np.percentile(bnz, 25):.3f}, Q75: {np.percentile(bnz, 75):.3f}")

FIRST PASS: BARE NLL SCORING


Bare NLL:   0%|          | 0/4000 [00:00<?, ?it/s]

  200/4000 | 2.3 s/s | ETA: 27.3 min
  400/4000 | 2.3 s/s | ETA: 25.8 min
  600/4000 | 2.3 s/s | ETA: 24.3 min
  800/4000 | 2.3 s/s | ETA: 22.9 min
  1000/4000 | 2.3 s/s | ETA: 21.4 min
  1200/4000 | 2.3 s/s | ETA: 20.0 min
  1400/4000 | 2.3 s/s | ETA: 18.5 min
  1600/4000 | 2.3 s/s | ETA: 17.1 min
  1800/4000 | 2.3 s/s | ETA: 15.7 min
  2000/4000 | 2.3 s/s | ETA: 14.2 min
  2200/4000 | 2.3 s/s | ETA: 12.8 min
  2400/4000 | 2.3 s/s | ETA: 11.4 min
  2600/4000 | 2.3 s/s | ETA: 10.0 min
  2800/4000 | 2.3 s/s | ETA: 8.5 min
  3000/4000 | 2.3 s/s | ETA: 7.1 min
  3200/4000 | 2.3 s/s | ETA: 5.7 min
  3400/4000 | 2.3 s/s | ETA: 4.3 min
  3600/4000 | 2.3 s/s | ETA: 2.8 min
  3800/4000 | 2.3 s/s | ETA: 1.4 min
  4000/4000 | 2.3 s/s | ETA: 0.0 min
First pass: 4000 samples in 28.4 min

Bare NLL distribution (non-zero only, N=3669):
  Mean: 1.119, Median: 0.610
  Q25: 0.221, Q75: 1.370


In [6]:
# Cell 6: Filter to hard samples + generate LLM surrogates
print("=" * 70)
print("FILTERING TO HARD SAMPLES + GENERATING SURROGATES")
print("=" * 70)

MAX_HARD_SAMPLES = 2000      # upper bound on the evaluated hard subset
SURROGATE_SAVE_EVERY = 100   # flush generated surrogates to disk every N samples

# Filter: keep samples with non-zero bare NLL above median
median_nll = np.median(bare_nlls_arr[bare_nlls_arr > 0])
print(f"Median bare NLL (non-zero): {median_nll:.3f}")

hard_indices = [i for i, nll in enumerate(bare_nlls_all) if nll > median_nll]

# Subsample with a dedicated seed, then restore index order so the evaluation
# order does not depend on the shuffle.
np.random.seed(SEED + 100)
np.random.shuffle(hard_indices)
hard_indices = hard_indices[:MAX_HARD_SAMPLES]
hard_indices.sort()  # sort for deterministic ordering

hard_samples = [all_samples[i] for i in hard_indices]
hard_bare_nlls = [bare_nlls_all[i] for i in hard_indices]
hard_queries = [s['query'] for s in hard_samples]

N_HARD = len(hard_samples)
print(f"Selected {N_HARD} hard samples")
print(f"Hard sample bare NLL: mean={np.mean(hard_bare_nlls):.3f}, "
      f"median={np.median(hard_bare_nlls):.3f}")


def _save_surrogates():
    """Persist the surrogates so far, together with the queries they belong to."""
    with open(SURROGATES_PATH, 'w') as f:
        json.dump({'surrogates': surrogates,
                   'sample_queries': hard_queries[:len(surrogates)]}, f)


# Generate LLM surrogates
print(f"\nGenerating LLM surrogates for {N_HARD} samples...")

# Surrogates are matched to hard samples by position, so a cache is only reusable
# when it was produced for the same samples in the same order. Files written
# before 'sample_queries' was stored cannot be verified and are accepted as long
# as they are not longer than the current hard set.
surrogates = []
if SURROGATES_PATH.exists():
    with open(SURROGATES_PATH, 'r') as f:
        surrogates_data = json.load(f)
    cached_surrogates = surrogates_data['surrogates']
    cached_queries = surrogates_data.get('sample_queries')
    n_cached = len(cached_surrogates)
    cache_matches = n_cached <= N_HARD and (
        cached_queries is None or cached_queries == hard_queries[:n_cached])
    if cache_matches:
        surrogates = cached_surrogates
        print(f"Loaded {len(surrogates)} surrogates from cache")
    else:
        print("Surrogate cache does not match the current hard samples. Regenerating.")

start_gen = len(surrogates)
if start_gen < N_HARD:
    t_start = time.time()
    for idx in tqdm(range(start_gen, N_HARD), initial=start_gen, total=N_HARD,
                     desc="LLM Surrogates"):
        sample = hard_samples[idx]
        try:
            surrogate = generate_surrogate(sample['passage'], model, tokenizer, config)
        except Exception as e:
            # NOTE(review): this fallback feeds a slice of the ORACLE query into the
            # LLM-surrogate condition, so any sample that hits it is no longer an
            # independent surrogate. Kept as best-effort; watch for this warning.
            surrogate = sample['query'][:30]  # fallback
            print(f"  WARNING: Generation failed for sample {idx}: {e}")
        surrogates.append(surrogate)

        if (idx + 1) % SURROGATE_SAVE_EVERY == 0 or idx == N_HARD - 1:
            _save_surrogates()
            elapsed = time.time() - t_start
            rate = (idx - start_gen + 1) / elapsed if elapsed > 0 else 0
            remaining = (N_HARD - idx - 1) / rate if rate > 0 else 0
            tqdm.write(f"  Generated {idx+1}/{N_HARD} | {rate:.1f} s/s | ETA: {remaining/60:.1f} min")

    _save_surrogates()

# The evaluation loop indexes surrogates by hard-sample position.
assert len(surrogates) == N_HARD, (
    f"expected {N_HARD} surrogates, have {len(surrogates)}")

# Show examples
print(f"\nSurrogate examples:")
for i in range(min(5, N_HARD)):
    print(f"  Oracle:    {hard_samples[i]['query'][:60]}...")
    print(f"  LLM:       {surrogates[i][:60]}...")
    print()

FILTERING TO HARD SAMPLES + GENERATING SURROGATES
Median bare NLL (non-zero): 0.610
Selected 1833 hard samples
Hard sample bare NLL: mean=1.999, median=1.377

Generating LLM surrogates for 1833 samples...


LLM Surrogates:   0%|          | 0/1833 [00:00<?, ?it/s]

  Generated 100/1833 | 0.7 s/s | ETA: 38.8 min
  Generated 200/1833 | 0.7 s/s | ETA: 39.5 min
  Generated 300/1833 | 0.7 s/s | ETA: 37.3 min
  Generated 400/1833 | 0.7 s/s | ETA: 34.3 min
  Generated 500/1833 | 0.7 s/s | ETA: 32.6 min
  Generated 600/1833 | 0.7 s/s | ETA: 30.4 min
  Generated 700/1833 | 0.7 s/s | ETA: 27.9 min
  Generated 800/1833 | 0.7 s/s | ETA: 25.6 min
  Generated 900/1833 | 0.7 s/s | ETA: 23.2 min
  Generated 1000/1833 | 0.7 s/s | ETA: 20.9 min
  Generated 1100/1833 | 0.7 s/s | ETA: 18.4 min
  Generated 1200/1833 | 0.7 s/s | ETA: 15.9 min
  Generated 1300/1833 | 0.7 s/s | ETA: 13.3 min
  Generated 1400/1833 | 0.7 s/s | ETA: 10.8 min
  Generated 1500/1833 | 0.7 s/s | ETA: 8.4 min
  Generated 1600/1833 | 0.7 s/s | ETA: 5.9 min
  Generated 1700/1833 | 0.7 s/s | ETA: 3.4 min
  Generated 1800/1833 | 0.7 s/s | ETA: 0.8 min
  Generated 1833/1833 | 0.7 s/s | ETA: 0.0 min

Surrogate examples:
  Oracle:    what is provider based billing mean...
  LLM:       Provider Based B

In [7]:
# Cell 7: Main eval loop — 5 conditions on hard samples
print("=" * 70)
print("SECOND PASS: 5-CONDITION EVAL ON HARD SAMPLES")
print("=" * 70)


def _encode(text, with_bos):
    """Tokenize ``text`` to a (1, n) id tensor, optionally with special tokens."""
    return tokenizer(text, return_tensors="pt", add_special_tokens=with_bos,
                     padding=False, truncation=False)['input_ids']


def _forward(ids):
    """Run the model over ``ids`` without gradients, returning outputs with a KV cache."""
    with torch.no_grad():
        return model(input_ids=ids, attention_mask=torch.ones_like(ids),
                     use_cache=True, return_dict=True)


def _score_truncated(full_ids, n_prefix_tokens, n_doc_tokens, query_prompt, answer_text):
    """Score the answer for a '[BOS][prefix][doc]' input after removing the prefix.

    The cache is built over the full input, cut down with
    extract_and_truncate_cache_with_bos, position-corrected with
    correct_rope_positions_with_bos using the prefix token count (BOS excluded),
    and then scored as a cache of length 1 + n_doc_tokens.
    """
    primed_out = _forward(full_ids)
    trimmed_cache = extract_and_truncate_cache_with_bos(primed_out.past_key_values, n_doc_tokens)
    correct_rope_positions_with_bos(trimmed_cache, n_prefix_tokens, model)
    return score_answer_with_cache(
        deepcopy_cache(trimmed_cache), 1 + n_doc_tokens,
        query_prompt, answer_text, model, tokenizer, config)


results = []
start_idx = 0
# Query list identifying this exact hard-sample set; used to validate and to
# stamp checkpoints.
hard_sample_queries = [s['query'] for s in hard_samples]

if not CHECKPOINT_PATH.exists():
    print("No checkpoint found. Starting fresh.")
else:
    with open(CHECKPOINT_PATH, 'r') as f:
        ckpt = json.load(f)
    if ckpt.get('sample_queries', []) == hard_sample_queries:
        results = ckpt['results']
        start_idx = len(results)
        print(f"Resuming from checkpoint: {start_idx}/{N_HARD}")
    else:
        print("Checkpoint mismatch. Starting fresh.")

print(f"Evaluating samples {start_idx} to {N_HARD-1}")
print(f"Conditions: 5 (bare, oracle, random, separator, LLM-surrogate)")

t_start = time.time()

for idx in tqdm(range(start_idx, N_HARD), initial=start_idx, total=N_HARD,
                 desc="Evaluating"):
    sample = hard_samples[idx]
    passage, query, answer = sample['passage'], sample['query'], sample['answer']
    llm_surrogate = surrogates[idx]

    query_prompt = QUERY_TEMPLATE.format(query=query)
    answer_text = ANSWER_TEMPLATE.format(answer=answer)

    # Matched tokenization: the document ids are sliced out of the oracle-primed
    # encoding and reused by the bare, random and LLM conditions.
    oracle_prefix = SURROGATE_PREFIX_TEMPLATE.format(surrogate=query)
    full_oracle_ids = _encode(
        oracle_prefix + DOCUMENT_TEMPLATE.format(document=passage), True
    ).to(config.device)
    oracle_prefix_len = _encode(oracle_prefix, True).shape[1]  # includes BOS

    bos_id = full_oracle_ids[:, :1]
    doc_ids = full_oracle_ids[:, oracle_prefix_len:]
    doc_len = doc_ids.shape[1]

    random_text = generate_random_prefix_text(query, tokenizer, seed=SEED + idx)

    # === Condition 1: BARE ===
    bare_ids = torch.cat([bos_id, doc_ids], dim=1)
    bare_out = _forward(bare_ids)
    bare_nll = score_answer_with_cache(
        deepcopy_cache(bare_out.past_key_values), bare_ids.shape[1],
        query_prompt, answer_text, model, tokenizer, config)

    # === Condition 2: ORACLE-TRUNCATED ===
    oracle_trunc_nll = _score_truncated(
        full_oracle_ids, oracle_prefix_len - 1, doc_len, query_prompt, answer_text)

    # === Condition 3: RANDOM-TRUNCATED ===
    random_prefix_ids = _encode(
        SURROGATE_PREFIX_TEMPLATE.format(surrogate=random_text), False
    ).to(config.device)
    random_trunc_nll = _score_truncated(
        torch.cat([bos_id, random_prefix_ids, doc_ids], dim=1),
        random_prefix_ids.shape[1], doc_len, query_prompt, answer_text)

    # === Condition 4: SEPARATOR-ONLY ===
    sep_only_len, sep_only_cache = build_suffix_kv_cache(
        passage, "", model, tokenizer, config, separator=SUFFIX_SEPARATOR)
    separator_only_nll = score_answer_with_cache(
        deepcopy_cache(sep_only_cache), sep_only_len,
        query_prompt, answer_text, model, tokenizer, config)

    # === Condition 5: LLM-SURROGATE-TRUNCATED ===
    llm_prefix_ids = _encode(
        SURROGATE_PREFIX_TEMPLATE.format(surrogate=llm_surrogate), False
    ).to(config.device)
    llm_trunc_nll = _score_truncated(
        torch.cat([bos_id, llm_prefix_ids, doc_ids], dim=1),
        llm_prefix_ids.shape[1], doc_len, query_prompt, answer_text)

    # Deltas are "reference minus condition": positive means the condition has
    # the lower NLL.
    results.append({
        'idx': idx,
        'bare_nll': bare_nll,
        'oracle_trunc_nll': oracle_trunc_nll,
        'random_trunc_nll': random_trunc_nll,
        'separator_only_nll': separator_only_nll,
        'llm_trunc_nll': llm_trunc_nll,
        'doc_len': doc_len,
        'passage_word_count': len(passage.split()),
        'pre_filter_bare_nll': hard_bare_nlls[idx],
        'llm_surrogate': llm_surrogate,
        'delta_oracle_vs_bare': bare_nll - oracle_trunc_nll,
        'delta_random_vs_bare': bare_nll - random_trunc_nll,
        'delta_oracle_vs_random': random_trunc_nll - oracle_trunc_nll,
        'delta_seponly_vs_bare': bare_nll - separator_only_nll,
        'delta_llm_vs_bare': bare_nll - llm_trunc_nll,
        'delta_llm_vs_random': random_trunc_nll - llm_trunc_nll,
        'delta_llm_vs_oracle': oracle_trunc_nll - llm_trunc_nll,
    })

    # The primed caches were local to _score_truncated; release what remains.
    del bare_out, sep_only_cache
    torch.cuda.empty_cache()

    is_last = idx == N_HARD - 1
    if (idx + 1) % CHECKPOINT_EVERY == 0 or is_last:
        with open(CHECKPOINT_PATH, 'w') as f:
            json.dump({
                'results': results,
                'sample_queries': hard_sample_queries,
                'completed': len(results),
                'total': N_HARD,
                'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
            }, f)
        elapsed = time.time() - t_start
        rate = (idx - start_idx + 1) / elapsed if elapsed > 0 else 0
        remaining = (N_HARD - idx - 1) / rate if rate > 0 else 0
        tqdm.write(f"  Checkpoint {idx+1}/{N_HARD} | {rate:.2f} s/s | ETA: {remaining/60:.1f} min")

elapsed_total = time.time() - t_start
print(f"\nEvaluation complete: {len(results)} samples in {elapsed_total/60:.1f} min")

SECOND PASS: 5-CONDITION EVAL ON HARD SAMPLES
No checkpoint found. Starting fresh.
Evaluating samples 0 to 1832
Conditions: 5 (bare, oracle, random, separator, LLM-surrogate)


Evaluating:   0%|          | 0/1833 [00:00<?, ?it/s]

  Checkpoint 50/1833 | 0.46 s/s | ETA: 64.3 min
  Checkpoint 100/1833 | 0.46 s/s | ETA: 62.7 min
  Checkpoint 150/1833 | 0.46 s/s | ETA: 61.1 min
  Checkpoint 200/1833 | 0.46 s/s | ETA: 59.4 min
  Checkpoint 250/1833 | 0.46 s/s | ETA: 57.7 min
  Checkpoint 300/1833 | 0.46 s/s | ETA: 55.8 min
  Checkpoint 350/1833 | 0.46 s/s | ETA: 54.0 min
  Checkpoint 400/1833 | 0.46 s/s | ETA: 52.1 min
  Checkpoint 450/1833 | 0.46 s/s | ETA: 50.3 min
  Checkpoint 500/1833 | 0.46 s/s | ETA: 48.5 min
  Checkpoint 550/1833 | 0.46 s/s | ETA: 46.6 min
  Checkpoint 600/1833 | 0.46 s/s | ETA: 44.8 min
  Checkpoint 650/1833 | 0.46 s/s | ETA: 43.0 min
  Checkpoint 700/1833 | 0.46 s/s | ETA: 41.2 min
  Checkpoint 750/1833 | 0.46 s/s | ETA: 39.4 min
  Checkpoint 800/1833 | 0.46 s/s | ETA: 37.6 min
  Checkpoint 850/1833 | 0.46 s/s | ETA: 35.8 min
  Checkpoint 900/1833 | 0.46 s/s | ETA: 33.9 min
  Checkpoint 950/1833 | 0.46 s/s | ETA: 32.1 min
  Checkpoint 1000/1833 | 0.46 s/s | ETA: 30.3 min
  Checkpoint 1050/18

In [8]:
# Cell 8: Analysis
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

print("=" * 70)
print("ANALYSIS — HARDNESS-GATED SEMANTIC PRIMING")
print("=" * 70)

bare_raw = np.array([r['bare_nll'] for r in results])
oracle_raw = np.array([r['oracle_trunc_nll'] for r in results])
random_raw = np.array([r['random_trunc_nll'] for r in results])
seponly_raw = np.array([r['separator_only_nll'] for r in results])
llm_raw = np.array([r['llm_trunc_nll'] for r in results])

# A sample is usable only when every condition produced a finite, non-zero NLL.
# Zero was already excluded; NaN/inf are excluded too so that one bad score
# cannot poison every mean and t-test.
# NOTE(review): zero is presumably the scorer's failure value — confirm.
all_raw = np.vstack([bare_raw, oracle_raw, random_raw, seponly_raw, llm_raw])
valid = np.all(np.isfinite(all_raw) & (all_raw != 0), axis=0)
n_valid = int(np.sum(valid))
n_excluded = int(np.sum(~valid))

bare = bare_raw[valid]
oracle = oracle_raw[valid]
random = random_raw[valid]
sep_only = seponly_raw[valid]
llm = llm_raw[valid]

print(f"Total: {len(results)}, Valid: {n_valid}, Excluded: {n_excluded}")
print(f"Bonferroni alpha: {BONFERRONI_ALPHA:.4f} ({N_COMPARISONS} comparisons)")

# NLL summary
print(f"\n{'Condition':<25} {'Mean NLL':>10} {'Std':>10}")
print("-" * 50)
for name, arr in [('Bare', bare), ('Oracle-truncated', oracle),
                   ('Random-truncated', random), ('Separator-only', sep_only),
                   ('LLM-surrogate-trunc', llm)]:
    print(f"{name:<25} {np.mean(arr):>10.4f} {np.std(arr):>10.4f}")

# Comparisons. Each delta is "reference minus condition", so a positive mean
# means the first-named condition has the lower NLL.
comparisons = [
    ('Oracle vs Bare', bare - oracle, 'Does oracle help hard samples?'),
    ('Random vs Bare', bare - random, 'Does random help hard samples?'),
    ('Oracle vs Random', random - oracle, 'KEY: semantic priming in hard?'),
    ('Sep-only vs Bare', bare - sep_only, 'Separator on hard samples?'),
    ('LLM vs Bare', bare - llm, 'LLM surrogate overall?'),
    ('LLM vs Random', random - llm, 'LLM better than random?'),
    ('LLM vs Oracle', oracle - llm, 'LLM better than oracle?'),
]
# BONFERRONI_ALPHA was derived from N_COMPARISONS; fail loudly if the list of
# tests and the correction factor ever drift apart.
assert len(comparisons) == N_COMPARISONS, (
    f"N_COMPARISONS={N_COMPARISONS} but {len(comparisons)} comparisons are defined")

print(f"\n{'Comparison':<25} {'Mean D':>8} {'d':>8} {'Win%':>7} {'t':>8} {'p':>12} {'Sig':>5}")
print("-" * 80)

comparison_results = {}
for name, delta, question in comparisons:
    d = cohens_d(delta)
    win = np.mean(delta > 0) * 100
    t_stat, p_val = stats.ttest_1samp(delta, 0)
    # '**' marks significance after Bonferroni correction; '*' is uncorrected only.
    sig = "***" if p_val < 0.001 else "**" if p_val < BONFERRONI_ALPHA else "*" if p_val < 0.05 else "ns"
    print(f"{name:<25} {np.mean(delta):>8.4f} {d:>8.3f} {win:>6.1f}% {t_stat:>8.2f} {p_val:>11.2e} {sig:>5}")
    comparison_results[name] = {
        'mean_delta': float(np.mean(delta)),
        'cohens_d': float(d),
        'win_rate': float(win / 100),
        't_stat': float(t_stat),
        'p_value': float(p_val),
        'bonferroni_significant': bool(p_val < BONFERRONI_ALPHA),
        'question': question,
    }

# Hardness interaction within hard samples
print(f"\nHardness interaction (within hard subset):")
for name, delta in [('Oracle', bare - oracle), ('Random', bare - random),
                     ('LLM', bare - llm), ('Sep-only', bare - sep_only)]:
    r, p = stats.pearsonr(bare, delta)
    print(f"  {name:<15}: r={r:.3f}, p={p:.2e}")

# Summary verdict. Verdicts use the Bonferroni-corrected threshold so they agree
# with the 'bonferroni_significant' flag saved for each comparison, rather than
# the uncorrected 0.05 level.
print(f"\n{'_' * 70}")
print("VERDICTS:")
d_or = comparison_results['Oracle vs Random']['cohens_d']
p_or = comparison_results['Oracle vs Random']['p_value']
d_lr = comparison_results['LLM vs Random']['cohens_d']
p_lr = comparison_results['LLM vs Random']['p_value']

if p_or < BONFERRONI_ALPHA and d_or > 0:
    print(f"  Oracle vs Random: SEMANTIC PRIMING DETECTED in hard samples (d={d_or:+.3f})")
elif p_or < BONFERRONI_ALPHA and d_or < 0:
    print(f"  Oracle vs Random: Oracle still INTERFERES even for hard samples (d={d_or:+.3f})")
else:
    print(f"  Oracle vs Random: No semantic signal even in hard samples (d={d_or:+.3f}, ns)")

if p_lr < BONFERRONI_ALPHA and d_lr > 0:
    print(f"  LLM vs Random: LLM surrogates ADD value beyond random (d={d_lr:+.3f})")
else:
    print(f"  LLM vs Random: LLM surrogates no better than random (d={d_lr:+.3f})")

ANALYSIS — HARDNESS-GATED SEMANTIC PRIMING
Total: 1833, Valid: 1833, Excluded: 0
Bonferroni alpha: 0.0071 (7 comparisons)

Condition                   Mean NLL        Std
--------------------------------------------------
Bare                          1.9998     1.8515
Oracle-truncated              1.9305     1.8396
Random-truncated              1.9290     1.8609
Separator-only                1.7865     1.6843
LLM-surrogate-trunc           1.8241     1.7409

Comparison                  Mean D        d    Win%        t            p   Sig
--------------------------------------------------------------------------------
Oracle vs Bare              0.0693    0.131   58.6%     5.59    2.59e-08   ***
Random vs Bare              0.0708    0.153   60.2%     6.55    7.65e-11   ***
Oracle vs Random           -0.0015   -0.003   51.1%    -0.12    9.06e-01    ns
Sep-only vs Bare            0.2133    0.372   81.0%    15.93    1.39e-53   ***
LLM vs Bare                 0.1757    0.374   74.7%    16.01

In [9]:
# Cell 9: Plots
from lib.analysis import compute_token_overlap


def _plot_delta_hist(ax, delta, color, title):
    """Histogram of paired NLL deltas with a zero line and the mean, labelled by Cohen's d."""
    ax.hist(delta, bins=60, color=color, alpha=0.7, edgecolor='black', linewidth=0.3)
    ax.axvline(x=0, color='red', linestyle='--')
    ax.axvline(x=np.mean(delta), color='black', linestyle='-',
               label=f'd={cohens_d(delta):+.3f}')
    ax.set_title(title)
    ax.legend()


def _plot_hardness_scatter(ax, hardness, benefit, color, ylabel, title):
    """Scatter of benefit against bare NLL with a degree-1 least-squares fit overlaid."""
    ax.scatter(hardness, benefit, alpha=0.15, s=5, c=color)
    ax.axhline(y=0, color='red', linestyle='--')
    line_coeffs = np.polyfit(hardness, benefit, 1)
    xs = np.linspace(hardness.min(), hardness.max(), 100)
    ax.plot(xs, np.polyval(line_coeffs, xs), 'r-', alpha=0.8)
    ax.set_xlabel('Bare NLL')
    ax.set_ylabel(ylabel)
    ax.set_title(title)


fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# Plot 1: Delta distributions — Oracle vs Random (THE key test)
delta_or = random - oracle
_plot_delta_hist(axes[0, 0], delta_or, 'steelblue', 'Oracle vs Random (hard samples)')

# Plot 2: Delta distributions — LLM vs Random
delta_lr = random - llm
_plot_delta_hist(axes[0, 1], delta_lr, 'forestgreen', 'LLM Surrogate vs Random (hard samples)')

# Plot 3: All conditions bar chart (effect size of each condition against bare)
bar_specs = [
    ('Oracle', oracle, 'steelblue'),
    ('Random', random, 'darkorange'),
    ('Sep-only', sep_only, 'crimson'),
    ('LLM', llm, 'forestgreen'),
]
bar_ax = axes[0, 2]
bar_positions = range(len(bar_specs))
bar_ax.bar(bar_positions,
           [cohens_d(bare - cond_nll) for _, cond_nll, _ in bar_specs],
           color=[bar_color for _, _, bar_color in bar_specs],
           alpha=0.8, edgecolor='black')
bar_ax.set_xticks(bar_positions)
bar_ax.set_xticklabels([label for label, _, _ in bar_specs])
bar_ax.axhline(y=0, color='gray', linestyle='--')
bar_ax.set_ylabel("Cohen's d vs Bare")
bar_ax.set_title('All Conditions vs Bare (hard samples)')

# Plot 4: Hardness scatter — Oracle benefit
_plot_hardness_scatter(axes[1, 0], bare, bare - oracle, 'steelblue',
                       'Oracle benefit', 'Hardness vs Oracle Benefit')

# Plot 5: Hardness scatter — LLM benefit
_plot_hardness_scatter(axes[1, 1], bare, bare - llm, 'forestgreen',
                       'LLM benefit', 'Hardness vs LLM Surrogate Benefit')

# Plot 6: LLM surrogate quality
# Token overlap between LLM surrogate and oracle query (first 500 results at most)
overlaps = [
    compute_token_overlap(hard_samples[r['idx']]['query'], r['llm_surrogate'], tokenizer)
    for r in results[:500]
]
overlap_ax = axes[1, 2]
overlap_ax.hist(overlaps, bins=40, color='mediumpurple', alpha=0.7, edgecolor='black', linewidth=0.3)
overlap_ax.set_xlabel('Token Jaccard similarity')
overlap_ax.set_ylabel('Count')
overlap_ax.set_title(f'LLM vs Oracle Query Similarity (mean={np.mean(overlaps):.3f})')

plot_path = RESULTS_DIR / 'analysis_plots.png'
plt.suptitle('Exp 05: Hardness-Gated Semantic Priming + LLM Surrogates', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig(plot_path, dpi=150, bbox_inches='tight')
plt.show()
print(f"Plot saved to {plot_path}")

Plot saved to results/exp05/analysis_plots.png


In [10]:
# Cell 10: Save results
# Assemble the report section by section, then write it in one go.
run_config = {
    'model_name': config.model_name,
    'seed': SEED,
    'n_pool': N_POOL,
    'n_hard': N_HARD,
    'median_nll_threshold': float(median_nll),
    'min_passage_words': config.min_passage_words,
    'max_passage_words': config.max_passage_words,
    'bonferroni_alpha': BONFERRONI_ALPHA,
}

# Mean NLL per condition over the valid samples; cast to plain floats for JSON.
condition_arrays = (
    ('bare', bare),
    ('oracle_trunc', oracle),
    ('random_trunc', random),
    ('separator_only', sep_only),
    ('llm_trunc', llm),
)
run_summary = {
    'n_total': len(results),
    'n_valid': n_valid,
    'n_excluded': n_excluded,
    'nll_means': {label: float(np.mean(values)) for label, values in condition_arrays},
    'comparisons': comparison_results,
}

final = {
    'experiment': 'exp05_hardness_gated_surrogates',
    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    'config': run_config,
    'summary': run_summary,
    'per_sample_results': results,
}

with open(FINAL_RESULTS_PATH, 'w') as f:
    json.dump(final, f, indent=2)

size_kb = FINAL_RESULTS_PATH.stat().st_size / 1024
print(f"Results saved to {FINAL_RESULTS_PATH}")
print(f"File size: {size_kb:.1f} KB")
print(f"\nDone!")

Results saved to results/exp05/results.json
File size: 1335.8 KB

Done!


In [11]:
# Cell 11: GPU cleanup — free all VRAM
import gc


def _allocated_gb():
    """CUDA memory currently allocated by tensors, in GB (bytes / 1e9)."""
    return torch.cuda.memory_allocated() / 1e9


print("Cleaning up GPU memory...")
mem_before = _allocated_gb()

# Drop the notebook's references to the model and tokenizer, collect garbage so
# the objects are actually destroyed, then return cached blocks to the driver.
del model, tokenizer

gc.collect()
torch.cuda.empty_cache()
gc.collect()

mem_after = _allocated_gb()
print(f"GPU memory: {mem_before:.2f} GB -> {mem_after:.2f} GB")
print("Cleanup complete.")

Cleaning up GPU memory...
GPU memory: 4.14 GB -> 0.01 GB
Cleanup complete.
