# Exp 04: Cross-Dataset Semantic Priming Battery

## Goal

Test whether semantic priming (an oracle-query prefix outperforming a random prefix) emerges with
longer documents across multiple datasets. Exp 01-03 used MS MARCO passages (~80 words) and found
that the random prefix beat the oracle. Hypothesis: those passages were too short for the semantic
signal to matter.

## Datasets (length gradient)

| Dataset | Avg doc words | N samples | Source |
|---------|--------------|-----------|--------|
| MS MARCO multi-passage | ~550 | 1500 | 10 passages concatenated per query |
| SQuAD 2.0 | ~150 | 1500 | Wikipedia paragraphs |
| Natural Questions | ~800 | 800 | Wikipedia article windows |
| TriviaQA | ~800 | 800 | Wikipedia/web evidence windows |

## Conditions (4 per dataset)

| # | Condition | Cache | Tests |
|---|-----------|-------|-------|
| 1 | Bare | `[BOS][doc]` (matched tokenization) | Baseline |
| 2 | Oracle-truncated | `[BOS][query\n][doc]` → truncate + RoPE | Semantic signal |
| 3 | Random-truncated | `[BOS][random\n][doc]` → truncate + RoPE | Structural control |
| 4 | Separator-only | `[BOS][doc][\n\nRelated question: ]` | Framing effect |

## Key comparison: Oracle vs Random (positive d = oracle prefix gives lower answer NLL than random prefix, i.e. semantic priming works)

In [None]:
# Cell 1: Setup
import os
# Clear the process umask so nothing created below has permission bits masked off.
os.umask(0o000)

import sys
import json
import time
import numpy as np
import torch
from pathlib import Path

# Seed torch (CPU and every CUDA device). NumPy is seeded separately in each dataset cell.
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# All outputs of this notebook live under this directory.
RESULTS_DIR = Path("results/exp04")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

_cuda_ok = torch.cuda.is_available()

print(f"SEED: {SEED}")
print(f"Results directory: {RESULTS_DIR}")
print(f"CUDA available: {_cuda_ok}")
if _cuda_ok:
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {_gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
# Cell 2: Load model (Mistral-7B 4-bit)
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

print(f"Loading {MODEL_NAME} (4-bit)...")

# 4-bit NF4 quantization with double quantization; compute dtype is fp16.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
    load_in_4bit=True,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Keyword arguments for the quantized load, gathered in one place for readability.
_model_load_kwargs = {
    "quantization_config": bnb_config,
    "device_map": "auto",
    "torch_dtype": torch.float16,
}
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **_model_load_kwargs)
# Inference only: switch off dropout and other train-time behaviour.
model.eval()

print(f"Model loaded. dtype={model.dtype}, device={model.device}")
print(f"Vocab size: {len(tokenizer)}")
print(f"Vocab size: {len(tokenizer)}")

In [None]:
# Cell 3: Imports + config + templates + shared functions
sys.path.insert(0, ".")

from lib.config import ExperimentConfig
from lib.kv_cache import (
    build_kv_cache,
    build_suffix_kv_cache,
    score_answer_with_cache,
    deepcopy_cache,
    extract_and_truncate_cache_with_bos,
    correct_rope_positions_with_bos,
)
from lib.analysis import cohens_d
from scipy import stats
from tqdm.auto import tqdm
import traceback

# Text templates shared by every condition. The prefix template ends in a newline so the
# surrogate text and the document are separated by exactly one "\n".
SURROGATE_PREFIX_TEMPLATE = "{surrogate}\n"
DOCUMENT_TEMPLATE = "{document}"
QUERY_TEMPLATE = "\nQuery: {query}\nAnswer:"
ANSWER_TEMPLATE = " {answer}"
# Text appended after the document in the separator-only condition.
SUFFIX_SEPARATOR = "\n\nRelated question: "
CHECKPOINT_EVERY = 50

config = ExperimentConfig(
    seed=SEED,
    model_name=MODEL_NAME,
    num_samples=5000,
    min_passage_words=20,
    max_passage_words=5000,
)

print("Config ready")
print(f"  device: {config.device}")
print(f"  templates: prefix={SURROGATE_PREFIX_TEMPLATE!r}, doc={DOCUMENT_TEMPLATE!r}")
print(f"  suffix_separator: {SUFFIX_SEPARATOR!r}")


# ===================== SHARED UTILITY FUNCTIONS =====================

def generate_random_prefix_text(target_text, tokenizer, seed):
    """Build random text whose token count aims to match that of ``target_text``.

    Token ids are drawn uniformly from ``[3, len(tokenizer))`` with a
    ``np.random.RandomState(seed)`` generator, decoded to text, and re-encoded
    to check the length. The lower bound of 3 presumably skips the tokenizer's
    special tokens (TODO confirm for the tokenizer in use).

    Length matching is best-effort: one truncation or one top-up is attempted,
    and the result is not re-verified afterwards.

    Args:
        target_text: Text whose token count (without special tokens) is the target.
        tokenizer: Object providing ``encode``, ``decode`` and ``__len__``.
        seed: Seed for the random generator, so output is deterministic per seed.

    Returns:
        The random text, or ``""`` when ``target_text`` encodes to zero tokens.
    """
    def to_ids(text):
        return tokenizer.encode(text, add_special_tokens=False)

    def to_text(ids):
        return tokenizer.decode(ids, skip_special_tokens=True)

    wanted = len(to_ids(target_text))
    if wanted == 0:
        return ""

    lowest_id = 3
    sampler = np.random.RandomState(seed)
    n_vocab = len(tokenizer)

    draft = to_text(sampler.randint(lowest_id, n_vocab, size=wanted).tolist())
    round_trip = to_ids(draft)

    # Decoding then re-encoding need not preserve the token count.
    if len(round_trip) == wanted:
        return draft

    if len(round_trip) > wanted:
        # Too long: keep only the first `wanted` re-encoded tokens.
        return to_text(round_trip[:wanted])

    # Too short: top up with freshly sampled tokens, then trim if that overshoots.
    shortfall = wanted - len(round_trip)
    draft = draft + to_text(sampler.randint(lowest_id, n_vocab, size=shortfall).tolist())
    topped_up = to_ids(draft)
    if len(topped_up) > wanted:
        draft = to_text(topped_up[:wanted])
    return draft


def evaluate_sample(sample, idx, model, tokenizer, device, seed, n_total):
    """Score one sample's answer under the four KV-cache conditions.

    For each condition a cache is built for the document, then
    ``score_answer_with_cache`` is called with the query prompt and the answer
    text. The value it returns is stored as that condition's NLL.

    Reads the module-level ``config`` and the ``*_TEMPLATE`` / ``SUFFIX_SEPARATOR``
    constants in addition to its arguments.

    Args:
        sample: Dict with ``'passage'``, ``'query'`` and ``'answer'`` strings.
        idx: Index of the sample; added to ``seed`` to seed the random prefix.
        model: Causal LM called with ``input_ids`` / ``attention_mask`` and
            ``use_cache=True``.
        tokenizer: Tokenizer for ``model``.
        device: Device the token-id tensors are moved to.
        seed: Base seed for the random prefix.
        n_total: Not used inside this function.

    Returns:
        Dict with the four NLLs, the cache lengths, ``doc_len``, the passage
        word count and four deltas. Deltas against bare are ``bare - condition``
        and ``delta_oracle_vs_random`` is ``random - oracle``, so a positive
        delta means the first-named condition scored a lower NLL.
    """
    passage = sample['passage']
    query = sample['query']
    answer = sample['answer']
    query_prompt = QUERY_TEMPLATE.format(query=query)
    answer_text = ANSWER_TEMPLATE.format(answer=answer)

    # Matched tokenization: the document token ids are sliced out of the
    # tokenization of "<query>\n<document>", and those same ids are reused for
    # the bare and random conditions.
    oracle_prefix = SURROGATE_PREFIX_TEMPLATE.format(surrogate=query)
    document_text = DOCUMENT_TEMPLATE.format(document=passage)
    full_oracle_text = oracle_prefix + document_text

    full_oracle_enc = tokenizer(full_oracle_text, return_tensors="pt",
                                add_special_tokens=True, padding=False, truncation=False)
    full_oracle_ids = full_oracle_enc['input_ids'].to(device)

    # Length of the prefix tokenized on its own, special tokens included.
    # NOTE(review): slicing with this length assumes the prefix tokenizes to the
    # same ids standalone as it does in front of the document — confirm.
    oracle_prefix_enc = tokenizer(oracle_prefix, return_tensors="pt",
                                  add_special_tokens=True, padding=False, truncation=False)
    oracle_prefix_len = oracle_prefix_enc['input_ids'].shape[1]

    # First token is treated as BOS (assumes the tokenizer prepends exactly one).
    bos_id = full_oracle_ids[:, :1]
    doc_ids = full_oracle_ids[:, oracle_prefix_len:]
    doc_len = doc_ids.shape[1]

    # Random text targeting the same token count as the query; seeded per sample.
    random_text = generate_random_prefix_text(query, tokenizer, seed=seed + idx)

    # Condition 1: BARE — cache built from [BOS] + document ids only.
    bare_ids = torch.cat([bos_id, doc_ids], dim=1)
    bare_len = bare_ids.shape[1]
    with torch.no_grad():
        bare_out = model(input_ids=bare_ids, attention_mask=torch.ones_like(bare_ids),
                         use_cache=True, return_dict=True)
    # Scoring is always given a deep copy of the cache, never the cache itself.
    bare_nll = score_answer_with_cache(
        deepcopy_cache(bare_out.past_key_values), bare_len,
        query_prompt, answer_text, model, tokenizer, config)

    # Condition 2: ORACLE-TRUNCATED — run [BOS][query\n][doc], then reduce the
    # cache via the helper (presumably to BOS + the last doc_len positions —
    # TODO confirm against lib.kv_cache).
    with torch.no_grad():
        oracle_out = model(input_ids=full_oracle_ids,
                           attention_mask=torch.ones_like(full_oracle_ids),
                           use_cache=True, return_dict=True)
    oracle_cache = extract_and_truncate_cache_with_bos(oracle_out.past_key_values, doc_len)
    oracle_trunc_len = 1 + doc_len
    # Offset passed is the prefix length without BOS; presumably the helper
    # shifts RoPE positions of the kept entries by that amount in place.
    correct_rope_positions_with_bos(oracle_cache, oracle_prefix_len - 1, model)
    oracle_trunc_nll = score_answer_with_cache(
        deepcopy_cache(oracle_cache), oracle_trunc_len,
        query_prompt, answer_text, model, tokenizer, config)

    # Condition 3: RANDOM-TRUNCATED — same procedure with the random prefix.
    # The prefix is tokenized without special tokens and BOS is prepended by hand.
    random_prefix = SURROGATE_PREFIX_TEMPLATE.format(surrogate=random_text)
    random_prefix_enc = tokenizer(random_prefix, return_tensors="pt",
                                  add_special_tokens=False, padding=False, truncation=False)
    random_prefix_ids = random_prefix_enc['input_ids'].to(device)
    random_full_ids = torch.cat([bos_id, random_prefix_ids, doc_ids], dim=1)
    # Includes BOS, mirroring oracle_prefix_len.
    random_prefix_len = 1 + random_prefix_ids.shape[1]

    with torch.no_grad():
        random_out = model(input_ids=random_full_ids,
                           attention_mask=torch.ones_like(random_full_ids),
                           use_cache=True, return_dict=True)
    random_cache = extract_and_truncate_cache_with_bos(random_out.past_key_values, doc_len)
    random_trunc_len = 1 + doc_len
    correct_rope_positions_with_bos(random_cache, random_prefix_len - 1, model)
    random_trunc_nll = score_answer_with_cache(
        deepcopy_cache(random_cache), random_trunc_len,
        query_prompt, answer_text, model, tokenizer, config)

    # Condition 4: SEPARATOR-ONLY — built by the helper from the raw passage text
    # (not from doc_ids) with an empty suffix and the separator string.
    sep_only_len, sep_only_cache = build_suffix_kv_cache(
        passage, "", model, tokenizer, config, separator=SUFFIX_SEPARATOR)
    separator_only_nll = score_answer_with_cache(
        deepcopy_cache(sep_only_cache), sep_only_len,
        query_prompt, answer_text, model, tokenizer, config)

    # Drop references to the model outputs and caches before the next sample.
    del bare_out, oracle_out, oracle_cache, random_out, random_cache, sep_only_cache
    torch.cuda.empty_cache()

    return {
        'idx': idx,
        'bare_nll': bare_nll,
        'oracle_trunc_nll': oracle_trunc_nll,
        'random_trunc_nll': random_trunc_nll,
        'separator_only_nll': separator_only_nll,
        'bare_len': bare_len,
        'oracle_trunc_len': oracle_trunc_len,
        'random_trunc_len': random_trunc_len,
        'separator_only_len': sep_only_len,
        'doc_len': doc_len,
        'passage_word_count': len(passage.split()),
        # Positive delta = the first-named condition has the lower NLL.
        'delta_oracle_vs_bare': bare_nll - oracle_trunc_nll,
        'delta_random_vs_bare': bare_nll - random_trunc_nll,
        'delta_oracle_vs_random': random_trunc_nll - oracle_trunc_nll,
        'delta_seponly_vs_bare': bare_nll - separator_only_nll,
    }


def run_experiment(samples, dataset_name, model, tokenizer, config, results_dir,
                   seed=42, checkpoint_every=50):
    """Evaluate every sample of one dataset, with resumable checkpoints.

    Results are written under ``results_dir / dataset_name``:
    ``checkpoint.json`` (periodically, and after the last sample) and
    ``results.json`` (once, at the end, including the analysis summary).

    A checkpoint is resumed only when its stored list of queries equals the
    queries of ``samples``. An unreadable or malformed checkpoint is ignored
    and the run starts from scratch instead of raising.

    Args:
        samples: List of dicts with ``'passage'``, ``'query'``, ``'answer'``.
        dataset_name: Sub-directory name and label used in progress output.
        model: Model passed through to ``evaluate_sample``.
        tokenizer: Tokenizer passed through to ``evaluate_sample``.
        config: Experiment config; ``config.device`` is passed to ``evaluate_sample``.
        results_dir: ``pathlib.Path`` of the experiment's results directory.
        seed: Base seed forwarded to ``evaluate_sample`` and recorded in the output.
        checkpoint_every: Write a checkpoint every this many samples.

    Returns:
        Tuple ``(results, analysis)``: the per-sample result dicts and the
        summary returned by ``analyze_experiment``.
    """
    dataset_dir = results_dir / dataset_name
    dataset_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = dataset_dir / "checkpoint.json"
    checkpoint_tmp_path = dataset_dir / "checkpoint.json.tmp"
    results_path = dataset_dir / "results.json"

    N = len(samples)
    # Built once: used both to validate an existing checkpoint and in every checkpoint written.
    sample_queries = [s['query'] for s in samples]
    results = []
    start_idx = 0

    if checkpoint_path.exists():
        try:
            with open(checkpoint_path, 'r') as f:
                ckpt = json.load(f)
        except (OSError, ValueError) as exc:
            # json.JSONDecodeError is a ValueError; a truncated file ends up here.
            print(f"  Checkpoint unreadable ({exc}). Starting fresh.")
            ckpt = None

        if ckpt is not None:
            usable = (
                isinstance(ckpt, dict)
                and isinstance(ckpt.get('results'), list)
                and ckpt.get('sample_queries', []) == sample_queries
            )
            if usable:
                results = ckpt['results']
                start_idx = len(results)
                print(f"  Resuming from checkpoint: {start_idx}/{N}")
            else:
                print("  Checkpoint mismatch. Starting fresh.")

    t_start = time.time()
    for idx in tqdm(range(start_idx, N), initial=start_idx, total=N,
                     desc=dataset_name):
        result = evaluate_sample(samples[idx], idx, model, tokenizer,
                                 config.device, seed, N)
        results.append(result)

        if (idx + 1) % checkpoint_every == 0 or idx == N - 1:
            ckpt_data = {
                'results': results,
                'sample_queries': sample_queries,
                'completed': len(results),
                'total': N,
                'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
            }
            # Write to a sibling temp file and rename it over the checkpoint, so an
            # interruption mid-write cannot leave a truncated checkpoint.json behind.
            with open(checkpoint_tmp_path, 'w') as f:
                json.dump(ckpt_data, f)
            checkpoint_tmp_path.replace(checkpoint_path)

            elapsed = time.time() - t_start
            # Rate counts only samples processed in this session, not resumed ones.
            rate = (idx - start_idx + 1) / elapsed if elapsed > 0 else 0
            remaining = (N - idx - 1) / rate if rate > 0 else 0
            tqdm.write(f"  [{dataset_name}] {idx+1}/{N} | "
                       f"{rate:.2f} s/s | ETA: {remaining/60:.1f} min")

    elapsed = time.time() - t_start
    print(f"  {dataset_name}: {len(results)} samples in {elapsed/60:.1f} min")

    analysis = analyze_experiment(results, dataset_name)

    final = {
        'dataset': dataset_name,
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        'config': {'num_samples': N, 'seed': seed},
        'summary': analysis,
        'per_sample_results': results,
    }
    with open(results_path, 'w') as f:
        json.dump(final, f, indent=2)
    print(f"  Saved to {results_path}")

    return results, analysis


def analyze_experiment(results, dataset_name):
    """Summarise per-sample NLLs into paired effect sizes and print a table.

    A sample is valid only if all four NLLs are finite and non-zero. Zero is
    treated as "could not be scored" (presumably what the scoring helper returns
    in that case — TODO confirm). Non-finite values are excluded because a
    single NaN or inf would otherwise turn every statistic into NaN.

    For each comparison the per-sample delta is ``baseline - condition``, so a
    positive delta means the first-named condition has the lower NLL. Cohen's d,
    win rate and a one-sample t-test against 0 are reported per comparison.

    Args:
        results: List of per-sample dicts as returned by ``evaluate_sample``.
        dataset_name: Label used in the printed table.

    Returns:
        Dict with sample counts, per-comparison statistics (keyed by
        ``'Oracle vs Bare'``, ``'Random vs Bare'``, ``'Oracle vs Random'``,
        ``'Sep-only vs Bare'``), mean NLL per condition, the hardness
        correlation, and document length statistics. The hardness correlation
        is NaN when fewer than two valid samples remain.
    """
    bare = np.array([row['bare_nll'] for row in results], dtype=float)
    oracle = np.array([row['oracle_trunc_nll'] for row in results], dtype=float)
    rand = np.array([row['random_trunc_nll'] for row in results], dtype=float)
    seponly = np.array([row['separator_only_nll'] for row in results], dtype=float)

    valid = np.ones(len(results), dtype=bool)
    for nll in (bare, oracle, rand, seponly):
        valid &= np.isfinite(nll) & (nll != 0)
    n_valid = int(np.sum(valid))
    n_excluded = int(np.sum(~valid))

    b, o, rnd, s = bare[valid], oracle[valid], rand[valid], seponly[valid]

    comparisons = [
        ('Oracle vs Bare', b - o, 'Does oracle prefix help?'),
        ('Random vs Bare', b - rnd, 'Does any prefix help?'),
        ('Oracle vs Random', rnd - o, 'THE KEY TEST: semantic signal?'),
        ('Sep-only vs Bare', b - s, 'Separator framing alone?'),
    ]

    print(f"\n  [{dataset_name}] N={n_valid} valid ({n_excluded} excluded)")
    print(f"  {'Comparison':<25} {'d':>8} {'Win%':>7} {'p':>12} {'Sig':>5}")
    print(f"  " + "-" * 60)

    comp_results = {}
    for name, delta, question in comparisons:
        d = cohens_d(delta)
        win = np.mean(delta > 0) * 100
        t_stat, p_val = stats.ttest_1samp(delta, 0)
        sig = "***" if p_val < 0.001 else "**" if p_val < 0.01 else "*" if p_val < 0.05 else "ns"
        print(f"  {name:<25} {d:>8.3f} {win:>6.1f}% {p_val:>12.2e} {sig:>5}")
        comp_results[name] = {
            'mean_delta': float(np.mean(delta)),
            'cohens_d': float(d),
            'win_rate': float(win / 100),
            't_stat': float(t_stat),
            'p_value': float(p_val),
            'question': question,
        }

    # Hardness interaction for oracle: does the oracle help more on samples
    # whose bare NLL is higher? pearsonr raises with fewer than two points.
    delta_oracle = b - o
    if n_valid >= 2:
        r_hard, p_hard = stats.pearsonr(b, delta_oracle)
    else:
        r_hard, p_hard = float('nan'), float('nan')
    print(f"  Hardness x Oracle: r={r_hard:.3f}, p={p_hard:.2e}")

    # Length stats over the valid samples only.
    doc_lens = np.array([row['doc_len'] for row in results])[valid]
    word_counts = np.array([row['passage_word_count'] for row in results])[valid]

    return {
        'n_total': len(results),
        'n_valid': n_valid,
        'n_excluded': n_excluded,
        'comparisons': comp_results,
        'nll_means': {
            'bare': float(np.mean(b)), 'oracle_trunc': float(np.mean(o)),
            'random_trunc': float(np.mean(rnd)), 'separator_only': float(np.mean(s)),
        },
        'hardness_oracle_r': float(r_hard),
        'hardness_oracle_p': float(p_hard),
        'doc_len_mean': float(np.mean(doc_lens)),
        'doc_len_median': float(np.median(doc_lens)),
        'word_count_mean': float(np.mean(word_counts)),
        'word_count_median': float(np.median(word_counts)),
    }


print("All shared functions defined.")

In [None]:
# Cell 4: Sub-experiment A — MS MARCO Multi-Passage
# All 10 passages of a query are joined into one ~550-word document, so the
# oracle query has to find the answer in ONE passage among 10.

try:
    print("=" * 70)
    print("SUB-EXPERIMENT A: MS MARCO MULTI-PASSAGE")
    print("=" * 70)

    from datasets import load_dataset as hf_load_dataset

    print("  Loading MS MARCO v2.1 (all passages)...")
    ds_marco = hf_load_dataset('microsoft/ms_marco', 'v2.1', split='validation')

    msmarco_mp_candidates = []
    for record in ds_marco:
        # Keep only queries whose first answer is present and non-empty.
        first_answer = record['answers'][0] if record['answers'] else None
        if not first_answer or first_answer == 'No Answer Present.':
            continue

        merged_doc = '\n\n'.join(record['passages']['passage_text'])
        if len(merged_doc.split()) < 100:
            continue

        msmarco_mp_candidates.append({
            'passage': merged_doc,
            'query': record['query'].strip('. '),
            'answer': first_answer,
        })

    # Seeded shuffle, then take the first 1500 candidates as the sample set.
    np.random.seed(SEED)
    np.random.shuffle(msmarco_mp_candidates)
    msmarco_mp_samples = msmarco_mp_candidates[:1500]

    wc = np.array([len(s['passage'].split()) for s in msmarco_mp_samples])
    example = msmarco_mp_samples[0]
    print(f"  Loaded {len(msmarco_mp_samples)} samples")
    print(f"  Word counts: mean={wc.mean():.0f}, median={np.median(wc):.0f}, "
          f"min={wc.min()}, max={wc.max()}")
    print(f"  Example query: {example['query'][:80]}...")
    print(f"  Example answer: {example['answer'][:80]}...")

    msmarco_mp_results, msmarco_mp_analysis = run_experiment(
        msmarco_mp_samples, "msmarco_mp", model, tokenizer, config,
        RESULTS_DIR, seed=SEED)

    del ds_marco, msmarco_mp_candidates
except Exception as e:
    # A failed sub-experiment must not stop the notebook; later cells test for None.
    print(f"MS MARCO multi-passage FAILED: {e}")
    traceback.print_exc()
    msmarco_mp_results, msmarco_mp_analysis = None, None

In [None]:
# Cell 5: Sub-experiment B — SQuAD 2.0
# Wikipedia paragraphs (~150 words). Intermediate length.

try:
    print("=" * 70)
    print("SUB-EXPERIMENT B: SQuAD 2.0")
    print("=" * 70)

    from datasets import load_dataset as hf_load_dataset

    print("  Loading SQuAD 2.0...")
    ds_squad = hf_load_dataset('rajpurkar/squad_v2', split='validation')

    squad_candidates = []
    for record in ds_squad:
        gold_answers = record['answers']['text']
        if not gold_answers:
            continue  # unanswerable question
        first_gold = gold_answers[0]
        if len(first_gold.split()) < 2:
            continue  # skip single-word answers (likely single-token NLL=0)
        squad_candidates.append({
            'passage': record['context'],
            'query': record['question'],
            'answer': first_gold,
        })

    # Seeded shuffle (offset from the MS MARCO seed), then take the first 1500.
    np.random.seed(SEED + 1)
    np.random.shuffle(squad_candidates)
    squad_samples = squad_candidates[:1500]

    wc = np.array([len(s['passage'].split()) for s in squad_samples])
    example = squad_samples[0]
    print(f"  Loaded {len(squad_samples)} samples (from {len(squad_candidates)} candidates)")
    print(f"  Word counts: mean={wc.mean():.0f}, median={np.median(wc):.0f}, "
          f"min={wc.min()}, max={wc.max()}")
    print(f"  Example query: {example['query'][:80]}...")
    print(f"  Example answer: {example['answer'][:80]}...")

    squad_results, squad_analysis = run_experiment(
        squad_samples, "squad", model, tokenizer, config,
        RESULTS_DIR, seed=SEED)

    del ds_squad, squad_candidates
except Exception as e:
    # A failed sub-experiment must not stop the notebook; later cells test for None.
    print(f"SQuAD FAILED: {e}")
    traceback.print_exc()
    squad_results, squad_analysis = None, None

In [None]:
# Cell 6: Sub-experiment C — Natural Questions
# Wikipedia articles windowed to ~800 words around the short answer.

try:
    print("=" * 70)
    print("SUB-EXPERIMENT C: NATURAL QUESTIONS")
    print("=" * 70)

    from datasets import load_dataset as hf_load_dataset

    print("  Loading Natural Questions (streaming)...")
    ds_nq = hf_load_dataset('google-research-datasets/natural_questions',
                             'default', split='validation', streaming=True)

    WINDOW_WORDS = 800
    # Stop streaming once this many candidates are collected; 800 are sampled from them below.
    NQ_CANDIDATE_POOL = 2400
    nq_candidates = []
    n_scanned = 0

    for item in ds_nq:
        # Test the cap before counting, so n_scanned only counts items actually examined.
        if len(nq_candidates) >= NQ_CANDIDATE_POOL:
            break
        n_scanned += 1

        # Take the first annotated short answer that has at least two words.
        answer_text = None
        for sa in item['annotations']['short_answers']:
            if sa['text'] and len(sa['text'][0].split()) >= 2:
                answer_text = sa['text'][0]
                break
        if answer_text is None:
            continue

        # Extract plain text from document tokens (HTML tokens dropped).
        tokens = item['document']['tokens']
        plain_tokens = [t for t, h in zip(tokens['token'], tokens['is_html']) if not h]
        full_text = ' '.join(plain_tokens)

        # Case-insensitive search for the first occurrence of the answer.
        lowered_text = full_text.lower()
        answer_pos = lowered_text.find(answer_text.lower())
        if answer_pos == -1:
            continue

        # Word index of the match = number of whitespace-separated words before it.
        # Counting on the lowercased text keeps answer_pos and the slice consistent,
        # and unlike a running character count it does not depend on how many
        # whitespace characters separate words.
        words = full_text.split()
        answer_word_pos = len(lowered_text[:answer_pos].split())

        # Centre the window on the answer, then shift it back if it runs past the end.
        half = WINDOW_WORDS // 2
        start = max(0, answer_word_pos - half)
        end = min(len(words), start + WINDOW_WORDS)
        start = max(0, end - WINDOW_WORDS)

        window = words[start:end]
        if len(window) < 100:
            continue
        passage = ' '.join(window)

        question = item['question']
        if isinstance(question, dict):
            question = question.get('text', str(question))

        nq_candidates.append({
            'passage': passage,
            'query': question,
            'answer': answer_text,
        })

    print(f"  Scanned {n_scanned} items, found {len(nq_candidates)} candidates")

    np.random.seed(SEED + 2)
    np.random.shuffle(nq_candidates)
    nq_samples = nq_candidates[:800]

    wc = np.array([len(s['passage'].split()) for s in nq_samples])
    print(f"  Selected {len(nq_samples)} samples")
    print(f"  Word counts: mean={wc.mean():.0f}, median={np.median(wc):.0f}, "
          f"min={wc.min()}, max={wc.max()}")
    print(f"  Example query: {nq_samples[0]['query'][:80]}...")
    print(f"  Example answer: {nq_samples[0]['answer'][:80]}...")

    nq_results, nq_analysis = run_experiment(
        nq_samples, "nq", model, tokenizer, config,
        RESULTS_DIR, seed=SEED)

    del nq_candidates
except Exception as e:
    # A failed sub-experiment must not stop the notebook; later cells test for None.
    print(f"Natural Questions FAILED: {e}")
    traceback.print_exc()
    nq_results, nq_analysis = None, None

In [None]:
# Cell 7: Sub-experiment D — TriviaQA
# Wikipedia/web evidence documents windowed to ~800 words around the answer.

try:
    print("=" * 70)
    print("SUB-EXPERIMENT D: TriviaQA")
    print("=" * 70)

    from datasets import load_dataset as hf_load_dataset

    print("  Loading TriviaQA RC (streaming)...")
    ds_tqa = hf_load_dataset('trivia_qa', 'rc', split='validation', streaming=True)

    WINDOW_WORDS = 800
    # Stop streaming once this many candidates are collected; 800 are sampled from them below.
    TQA_CANDIDATE_POOL = 2400
    tqa_candidates = []
    n_scanned = 0

    for item in ds_tqa:
        # Test the cap before counting, so n_scanned only counts items actually examined.
        if len(tqa_candidates) >= TQA_CANDIDATE_POOL:
            break
        n_scanned += 1

        answer_text = item['answer']['value']
        if len(answer_text.split()) < 2:
            continue

        # Find the first evidence document containing the answer (case-insensitive),
        # remembering where it matched so the search is not repeated.
        wiki_contexts = item['entity_pages'].get('wiki_context', [])
        search_contexts = item['search_results'].get('search_context', [])
        answer_lower = answer_text.lower()

        context = None
        lowered_context = None
        answer_pos = -1
        for ctx in wiki_contexts + search_contexts:
            ctx_lower = ctx.lower()
            hit = ctx_lower.find(answer_lower)
            if hit != -1:
                context, lowered_context, answer_pos = ctx, ctx_lower, hit
                break
        if context is None:
            continue

        # Word index of the match = number of whitespace-separated words before it.
        # Evidence text contains newlines and runs of whitespace that split() collapses,
        # so a running "len(word) + 1" character count drifts away from the real
        # position; counting words in the prefix does not.
        words = context.split()
        answer_word_pos = len(lowered_context[:answer_pos].split())

        # Centre the window on the answer, then shift it back if it runs past the end.
        half = WINDOW_WORDS // 2
        start = max(0, answer_word_pos - half)
        end = min(len(words), start + WINDOW_WORDS)
        start = max(0, end - WINDOW_WORDS)

        window = words[start:end]
        if len(window) < 100:
            continue
        passage = ' '.join(window)

        tqa_candidates.append({
            'passage': passage,
            'query': item['question'],
            'answer': answer_text,
        })

    print(f"  Scanned {n_scanned} items, found {len(tqa_candidates)} candidates")

    np.random.seed(SEED + 3)
    np.random.shuffle(tqa_candidates)
    tqa_samples = tqa_candidates[:800]

    wc = np.array([len(s['passage'].split()) for s in tqa_samples])
    print(f"  Selected {len(tqa_samples)} samples")
    print(f"  Word counts: mean={wc.mean():.0f}, median={np.median(wc):.0f}, "
          f"min={wc.min()}, max={wc.max()}")
    print(f"  Example query: {tqa_samples[0]['query'][:80]}...")
    print(f"  Example answer: {tqa_samples[0]['answer'][:80]}...")

    tqa_results, tqa_analysis = run_experiment(
        tqa_samples, "triviaqa", model, tokenizer, config,
        RESULTS_DIR, seed=SEED)

    del tqa_candidates
except Exception as e:
    # A failed sub-experiment must not stop the notebook; later cells test for None.
    print(f"TriviaQA FAILED: {e}")
    traceback.print_exc()
    tqa_results, tqa_analysis = None, None

In [None]:
# Cell 8: Cross-dataset comparison table
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

print("=" * 70)
print("CROSS-DATASET COMPARISON")
print("=" * 70)

# Keep only the sub-experiments that finished; failed ones left their analysis as None.
all_analyses = {}
for name, analysis in [('msmarco_mp', msmarco_mp_analysis),
                        ('squad', squad_analysis),
                        ('nq', nq_analysis),
                        ('triviaqa', tqa_analysis)]:
    if analysis is not None:
        all_analyses[name] = analysis

if not all_analyses:
    print("No experiments completed successfully!")
else:
    # Header and row use the same column widths. The 3-character gap after
    # d(Orc-Rnd) in the header reserves room for the significance marker in the rows.
    print(f"\n{'Dataset':<18} {'N':>6} {'Words':>8} {'Tokens':>8} "
          f"{'d(Orc-Bare)':>12} {'d(Rnd-Bare)':>12} {'d(Orc-Rnd)':>12}{'':<3} {'d(Sep-Bare)':>12} "
          f"{'Hard r':>8}")
    print("-" * 110)

    for name, a in all_analyses.items():
        c = a['comparisons']
        d_ob = c['Oracle vs Bare']['cohens_d']
        d_rb = c['Random vs Bare']['cohens_d']
        d_or = c['Oracle vs Random']['cohens_d']
        d_sb = c['Sep-only vs Bare']['cohens_d']
        p_or = c['Oracle vs Random']['p_value']
        # Significance marker for the key Oracle-vs-Random test only.
        sig = "***" if p_or < 0.001 else "**" if p_or < 0.01 else "*" if p_or < 0.05 else "ns"

        print(f"{name:<18} {a['n_valid']:>6} {a['word_count_mean']:>8.0f} "
              f"{a['doc_len_mean']:>8.0f} "
              f"{d_ob:>+12.3f} {d_rb:>+12.3f} {d_or:>+12.3f}{sig:<3} "
              f"{d_sb:>+12.3f} {a['hardness_oracle_r']:>8.3f}")

    print(f"\n{'_' * 110}")
    print("KEY: d(Orc-Rnd) is THE critical test. Positive = semantic priming works.")
    print("     Negative = oracle interferes (same as Exp 01).")
    print(f"\nInterpretation:")
    for name, a in all_analyses.items():
        d_or = a['comparisons']['Oracle vs Random']['cohens_d']
        p_or = a['comparisons']['Oracle vs Random']['p_value']
        if p_or < 0.05 and d_or > 0:
            print(f"  {name}: SEMANTIC PRIMING DETECTED (d={d_or:+.3f}, p={p_or:.2e})")
        elif p_or < 0.05 and d_or < 0:
            print(f"  {name}: Oracle INTERFERES (d={d_or:+.3f}, p={p_or:.2e})")
        else:
            print(f"  {name}: No semantic signal (d={d_or:+.3f}, p={p_or:.2e}, ns)")

In [None]:
# Cell 9: Cross-dataset plots

if all_analyses:
    datasets = list(all_analyses.keys())
    x = np.arange(len(datasets))

    def _effect_sizes(comparison):
        """Cohen's d of one comparison, one value per dataset in `datasets` order."""
        return [all_analyses[ds]['comparisons'][comparison]['cohens_d'] for ds in datasets]

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    ax_key, ax_all, ax_len = axes

    # Plot 1: Oracle vs Random d (THE KEY TEST); green bars are positive, red otherwise.
    key_d = _effect_sizes('Oracle vs Random')
    bar_colors = ['green' if value > 0 else 'red' for value in key_d]
    ax_key.bar(x, key_d, color=bar_colors, alpha=0.7, edgecolor='black')
    ax_key.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    ax_key.set_xticks(x)
    ax_key.set_xticklabels(datasets, rotation=15)
    ax_key.set_ylabel("Cohen's d")
    ax_key.set_title("Oracle vs Random (+ = semantic priming works)")

    # Plot 2: All conditions vs Bare, as grouped bars (one group per dataset).
    width = 0.2
    vs_bare = [
        ('Oracle', 'Oracle vs Bare'),
        ('Random', 'Random vs Bare'),
        ('Sep-only', 'Sep-only vs Bare'),
    ]
    for slot, (label, comparison) in enumerate(vs_bare):
        ax_all.bar(x + slot * width, _effect_sizes(comparison), width, label=label, alpha=0.8)
    ax_all.axhline(y=0, color='gray', linestyle='--')
    # Ticks sit under the middle bar of each group.
    ax_all.set_xticks(x + width)
    ax_all.set_xticklabels(datasets, rotation=15)
    ax_all.set_ylabel("Cohen's d vs Bare")
    ax_all.set_title("All Conditions vs Bare")
    ax_all.legend()

    # Plot 3: Effect size vs document length
    mean_doc_tokens = [all_analyses[ds]['doc_len_mean'] for ds in datasets]
    d_oracle_random = _effect_sizes('Oracle vs Random')
    d_random_bare = _effect_sizes('Random vs Bare')
    ax_len.scatter(mean_doc_tokens, d_oracle_random, s=100, c='steelblue',
                   label='Oracle vs Random', zorder=3)
    ax_len.scatter(mean_doc_tokens, d_random_bare, s=100, c='darkorange', marker='^',
                   label='Random vs Bare', zorder=3)
    ax_len.axhline(y=0, color='gray', linestyle='--')
    # Dataset labels are attached to the Oracle-vs-Random points only.
    for label, x_pos, y_pos in zip(datasets, mean_doc_tokens, d_oracle_random):
        ax_len.annotate(label, (x_pos, y_pos), textcoords="offset points",
                        xytext=(5, 5), fontsize=8)
    ax_len.set_xlabel('Mean document token length')
    ax_len.set_ylabel("Cohen's d")
    ax_len.set_title("Effect Size vs Document Length")
    ax_len.legend()

    plot_path = RESULTS_DIR / 'cross_dataset_comparison.png'
    plt.tight_layout()
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"Plot saved to {plot_path}")

In [None]:
# Cell 10: Save combined results

# One JSON file holding the summary of every dataset that completed.
combined = {
    'experiment': 'exp04_semantic_priming_battery',
    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    'datasets': dict(all_analyses),
}

combined_path = RESULTS_DIR / 'combined_results.json'
with combined_path.open('w') as f:
    json.dump(combined, f, indent=2)

n_completed = len(all_analyses)
print(f"Combined results saved to {combined_path}")
print(f"Individual dataset results in {RESULTS_DIR}/{{dataset_name}}/results.json")
print(f"\nDone! Total datasets completed: {n_completed}")

In [None]:
# Cell 11: GPU cleanup — free all VRAM before next notebook
import gc

print("Cleaning up GPU memory...")
mem_before = torch.cuda.memory_allocated() / 1e9

# Drop the notebook-level references to the model and tokenizer.
del model, tokenizer

# Collect garbage, release cached CUDA blocks, then collect once more.
gc.collect()
torch.cuda.empty_cache()
gc.collect()

mem_after = torch.cuda.memory_allocated() / 1e9
print(f"GPU memory: {mem_before:.2f} GB -> {mem_after:.2f} GB")
print("Cleanup complete. Safe to start next notebook.")