# Experiment 19: Cross-Dataset Survey

**Date:** 2026-02-04

## Purpose

Exp 18 showed that priming hurts on MS MARCO. But MS MARCO may be the wrong task:
- Short passages (~74 words)
- Extractive answers
- Model already performs well (low bare NLL)

This experiment surveys multiple datasets to find where priming helps:

| Dataset | Type | Why it might help |
|---------|------|-------------------|
| MS MARCO (hard) | Factoid QA | Filter to hard samples only |
| NarrativeQA | Long-doc QA | Long documents need focus |
| HotpotQA | Multi-hop | Reasoning chains need priming |
| PubMedQA | Scientific | Domain expertise needed |
| CNN/DailyMail | Summarization | Generative, needs focus |
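| Natural Questions | Wikipedia QA | Mixed difficulty (loader is defined but disabled in this run) |
| SQuAD v2 | Reading comprehension | Short extractive answers; answerable questions only |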

## Design

For each dataset:
- Sample 100 examples (MS MARCO: score the bare NLL of 200 candidates and keep the 100 hardest)
- Test 3 conditions: bare, oracle_truncated, oracle_fullctx
- Report win rates and effect sizes

In [1]:
# Cell 1: Setup
import os
os.umask(0o000)

import sys
sys.path.insert(0, '/home/jupyter/research/directed_kvcache')

import json
import random
import numpy as np
import torch
from tqdm.auto import tqdm
from typing import Dict, List, Tuple, Any, Optional
from dataclasses import dataclass
from scipy import stats

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, DynamicCache
from datasets import load_dataset

from lib.kv_cache import (
    extract_and_truncate_cache_with_bos,
    correct_rope_positions_with_bos,
    score_answer_with_cache,
    deepcopy_cache,
)
from lib.config import ExperimentConfig

# Seed every RNG used in this notebook so sampling and shuffling are repeatable.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# Directory that receives the checkpoint and the final JSON artifacts.
OUTPUT_DIR = '/home/jupyter/research/directed_kvcache/results/exp19'
# Number of examples evaluated per dataset.
N_SAMPLES_PER_DATASET = 100  # Keep manageable for survey

# Record the runtime environment next to the results.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")

PyTorch: 2.10.0+cu128
CUDA: True


In [2]:
# Cell 2: Load Model
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

# 4-bit NF4 quantization with double quantization; compute dtype is bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    # The installed transformers release warns "`torch_dtype` is deprecated!
    # Use `dtype` instead!", so the non-deprecated keyword is used here.
    dtype=torch.bfloat16,
)
model.eval()  # this notebook only scores; the model is never trained

# Shared settings object handed to the scoring helper from lib.kv_cache.
config = ExperimentConfig(
    model_name=MODEL_NAME,
    device=model.device,
    seed=SEED,
)

print(f"Model loaded on {model.device}, dtype={model.dtype}")

`torch_dtype` is deprecated! Use `dtype` instead!


Loading weights:   0%|          | 0/291 [00:00<?, ?it/s]

Model loaded on cuda:0, dtype=torch.bfloat16


In [3]:
# Cell 3: Core Evaluation Functions

@dataclass
class Sample:
    """Unified sample format across datasets."""
    # Context text the model conditions on (passage, article, or joined paragraphs).
    passage: str
    # Question (or instruction) asked about the passage.
    query: str
    # Reference answer whose likelihood is scored.
    answer: str
    # Short dataset identifier, e.g. 'squad_v2'.
    dataset: str
    # Optional dataset-specific extras; None when a loader supplies nothing.
    metadata: Optional[dict] = None

def build_bare_cache(passage: str) -> Tuple[DynamicCache, int]:
    """Run the model over ``passage`` alone and return its KV cache.

    Returns:
        ``(past_key_values, n_tokens)``, where ``n_tokens`` is the length of the
        encoded passage including the special tokens the tokenizer adds.
    """
    input_ids = tokenizer.encode(passage, return_tensors='pt', add_special_tokens=True)
    input_ids = input_ids.to(model.device)
    with torch.no_grad():
        outputs = model(input_ids, use_cache=True)
    n_tokens = input_ids.shape[1]
    return outputs.past_key_values, n_tokens

def build_primed_cache_fullcontext(prefix: str, passage: str) -> Tuple[DynamicCache, int]:
    """Run the model over ``prefix + " " + passage`` and keep the whole cache.

    Unlike the truncated variant, the prefix entries stay in the returned cache.

    Returns:
        ``(past_key_values, n_tokens)`` for the combined prefix-plus-passage text.
    """
    input_ids = tokenizer.encode(
        prefix + " " + passage, return_tensors='pt', add_special_tokens=True
    ).to(model.device)
    with torch.no_grad():
        outputs = model(input_ids, use_cache=True)
    n_tokens = input_ids.shape[1]
    return outputs.past_key_values, n_tokens

def build_primed_cache_truncated(prefix: str, passage: str) -> Tuple[DynamicCache, int]:
    """Build a passage cache that was computed with ``prefix`` in context, then drop the prefix.

    The model runs once over ``prefix + " " + passage``. The resulting cache is
    handed to ``extract_and_truncate_cache_with_bos`` with the passage token
    count, and then to ``correct_rope_positions_with_bos`` with an offset of
    ``prefix_len - 1``.

    NOTE(review): the prefix length is measured by tokenizing ``prefix + " "``
    on its own; this assumes the prefix yields the same number of tokens when it
    is tokenized as part of the full text - TODO confirm.

    Returns:
        ``(truncated_cache, 1 + doc_len)``; the extra 1 presumably accounts for
        a retained BOS entry - verify against lib.kv_cache.
    """
    primed_prefix = prefix + " "
    # Only the token count of the prefix is needed, so it is never moved to the device.
    prefix_len = tokenizer.encode(
        primed_prefix, return_tensors='pt', add_special_tokens=True
    ).shape[1]

    combined_ids = tokenizer.encode(
        primed_prefix + passage, return_tensors='pt', add_special_tokens=True
    ).to(model.device)
    doc_len = combined_ids.shape[1] - prefix_len

    with torch.no_grad():
        outputs = model(combined_ids, use_cache=True)

    cache = extract_and_truncate_cache_with_bos(outputs.past_key_values, doc_len)
    correct_rope_positions_with_bos(cache, prefix_len - 1, model)

    return cache, 1 + doc_len

def score_sample(cache: DynamicCache, cache_len: int, query: str, answer: str) -> float:
    """Score ``answer`` given a passage cache and ``query``.

    The query is wrapped in the fixed "Query: ... Answer:" template and the
    answer is prefixed with a single space before both are passed to
    ``score_answer_with_cache``. Callers treat the returned value as an NLL
    (lower is better).
    """
    query_prompt = f"\n\nQuery: {query}\nAnswer:"
    answer_text = " " + answer
    return score_answer_with_cache(
        cache, cache_len, query_prompt, answer_text, model, tokenizer, config
    )

def evaluate_sample(sample: Sample) -> dict:
    """Score one sample under the bare, truncated-prefix and full-context conditions.

    The passage is cut to 4000 characters and the answer to 500 characters. The
    "oracle" prefix is the query repeated five times, separated by spaces.

    Returns:
        A dict with the dataset name, word counts, the three NLLs, and the two
        deltas (bare NLL minus primed NLL, so a positive delta means priming
        lowered the NLL).
    """
    query = sample.query
    passage = sample.passage[:4000]  # character cap for very long passages
    answer = sample.answer[:500]  # character cap for very long answers

    oracle_prefix = " ".join([query] * 5)

    # Build all three caches first, in the order bare -> truncated -> full-context.
    conditions = (
        ('bare', build_bare_cache(passage)),
        ('truncated', build_primed_cache_truncated(oracle_prefix, passage)),
        ('fullctx', build_primed_cache_fullcontext(oracle_prefix, passage)),
    )

    # Each cache is deep-copied before it is handed to the scorer.
    nll = {}
    for label, (cache, cache_len) in conditions:
        nll[label] = score_sample(deepcopy_cache(cache), cache_len, query, answer)

    return {
        'dataset': sample.dataset,
        'passage_words': len(passage.split()),
        'answer_words': len(answer.split()),
        'nll_bare': nll['bare'],
        'nll_truncated': nll['truncated'],
        'nll_fullctx': nll['fullctx'],
        'delta_truncated': nll['bare'] - nll['truncated'],
        'delta_fullctx': nll['bare'] - nll['fullctx'],
    }

print("Evaluation functions defined.")

Evaluation functions defined.


In [4]:
# Cell 4: Dataset Loaders

def load_msmarco_hard(n: int) -> List[Sample]:
    """Collect ``3 * n`` MS MARCO v1.1 candidates for later difficulty filtering.

    Each sample pairs the query with the first listed passage and the first
    listed answer. The returned list is shuffled; picking the hardest ``n`` is
    left to the caller.
    """
    print("Loading MS MARCO...")
    ds = load_dataset("ms_marco", "v1.1", split="train")
    pool_size = n * 3  # oversample so the caller has candidates to filter
    pool = []
    for record in ds:
        answers = record['answers']
        if answers and answers[0] and record['passages']['passage_text']:
            pool.append(Sample(
                passage=record['passages']['passage_text'][0],
                query=record['query'],
                answer=answers[0],
                dataset='msmarco_hard',
            ))
        if len(pool) >= pool_size:
            break
    random.shuffle(pool)
    return pool[:pool_size]

def load_narrativeqa(n: int) -> List[Sample]:
    """Load up to ``n`` NarrativeQA test questions paired with the story summary.

    The passage is the document *summary* text (capped at 6000 characters), not
    the full story. Rows without a summary or without a non-empty first
    reference answer are skipped, because an empty answer cannot be scored
    meaningfully.

    Returns:
        A shuffled list of at most ``n`` samples, or ``[]`` if loading fails
        (the error is printed).
    """
    print("Loading NarrativeQA...")
    try:
        # `trust_remote_code` is not passed: the installed `datasets` release
        # reports it as unsupported and asks for it to be removed.
        ds = load_dataset("narrativeqa", split="test")
        samples = []
        for item in ds:
            summary_text = item['document']['summary']['text']
            answers = item['answers']
            if summary_text and answers and answers[0]['text']:
                samples.append(Sample(
                    passage=summary_text[:6000],
                    query=item['question']['text'],
                    answer=answers[0]['text'],
                    dataset='narrativeqa',
                    metadata={'doc_type': item['document'].get('kind', 'unknown')}
                ))
            if len(samples) >= n:
                break
        random.shuffle(samples)
        return samples[:n]
    except Exception as e:
        print(f"  Error loading NarrativeQA: {e}")
        return []

def load_hotpotqa(n: int) -> List[Sample]:
    """Load up to ``n`` HotpotQA (fullwiki, validation) questions.

    The passage is built from the first three context paragraphs, each rendered
    as ``"<title>: <sentences joined by spaces>"`` and separated by blank lines.

    NOTE(review): paragraphs are taken in listed order and are not filtered by
    the supporting facts, so the passage may not contain the evidence for the
    answer - confirm this is intended.

    Returns:
        A shuffled list of at most ``n`` samples, or ``[]`` if loading fails
        (the error is printed).
    """
    print("Loading HotpotQA...")
    try:
        ds = load_dataset("hotpot_qa", "fullwiki", split="validation")
        collected = []
        for record in ds:
            context = record['context']
            paragraphs = [
                f"{title}: {' '.join(sentences)}"
                for title, sentences in zip(context['title'], context['sentences'])
            ]
            passage = "\n\n".join(paragraphs[:3])  # keep the first 3 paragraphs only

            if passage and record['answer']:
                collected.append(Sample(
                    passage=passage,
                    query=record['question'],
                    answer=record['answer'],
                    dataset='hotpotqa',
                    metadata={'type': record['type'], 'level': record['level']},
                ))
            if len(collected) >= n:
                break
        random.shuffle(collected)
        return collected[:n]
    except Exception as e:
        print(f"  Error loading HotpotQA: {e}")
        return []

def load_pubmedqa(n: int) -> List[Sample]:
    """Load up to ``n`` PubMedQA (pqa_labeled, train) questions.

    The passage is all context snippets joined by spaces and the answer is the
    ``long_answer`` field; the ``final_decision`` label is kept in metadata.

    Returns:
        A shuffled list of at most ``n`` samples, or ``[]`` if loading fails
        (the error is printed).
    """
    print("Loading PubMedQA...")
    try:
        ds = load_dataset("pubmed_qa", "pqa_labeled", split="train")
        collected = []
        for record in ds:
            snippets = record['context']['contexts']
            context = " ".join(snippets) if snippets else ""
            if context and record['long_answer']:
                collected.append(Sample(
                    passage=context,
                    query=record['question'],
                    answer=record['long_answer'],
                    dataset='pubmedqa',
                    metadata={'final_decision': record['final_decision']},
                ))
            if len(collected) >= n:
                break
        random.shuffle(collected)
        return collected[:n]
    except Exception as e:
        print(f"  Error loading PubMedQA: {e}")
        return []

def load_cnn_dailymail(n: int) -> List[Sample]:
    """Load up to ``n`` CNN/DailyMail (3.0.0, test) articles framed as QA.

    Every sample uses the same fixed summarization instruction as its query; the
    article (capped at 5000 characters) is the passage and the highlights are
    the answer.

    Returns:
        A shuffled list of at most ``n`` samples, or ``[]`` if loading fails
        (the error is printed).
    """
    print("Loading CNN/DailyMail...")
    try:
        ds = load_dataset("cnn_dailymail", "3.0.0", split="test")
        collected = []
        for record in ds:
            article = record['article']
            if article and record['highlights']:
                collected.append(Sample(
                    passage=article[:5000],
                    query="Summarize the key points of this article.",
                    answer=record['highlights'],
                    dataset='cnn_dailymail',
                ))
            if len(collected) >= n:
                break
        random.shuffle(collected)
        return collected[:n]
    except Exception as e:
        print(f"  Error loading CNN/DailyMail: {e}")
        return []

def load_natural_questions(n: int) -> List[Sample]:
    """Load up to ``n`` Natural Questions validation items that have a short answer.

    The passage is the first 1000 document tokens joined by spaces; the answer
    is the token span of the first short answer of the first annotation.

    The previous version compared ``start_token`` with an integer and then
    indexed it as a list, and indexed each token as a dict after already
    selecting the ``'token'`` column. Both layouts are now accepted so the
    loader does not raise on either one.

    NOTE(review): the answer span is taken from the full token list, so it may
    lie beyond the 1000-token passage - confirm whether that is acceptable.

    Returns:
        A shuffled list of at most ``n`` samples, or ``[]`` if loading fails
        (the error is printed).
    """
    print("Loading Natural Questions...")

    def _token_text(tok):
        # A token is either a plain string or a per-token dict with a 'token' key.
        return tok['token'] if isinstance(tok, dict) else tok

    try:
        ds = load_dataset("natural_questions", "default", split="validation")
        samples = []
        for item in ds:
            # Check for a usable short answer first so that no document text is
            # built for items that will be skipped anyway.
            short_answers = item['annotations']['short_answers']
            if not short_answers:
                continue
            ans_start = short_answers[0]['start_token']
            ans_end = short_answers[0]['end_token']
            if isinstance(ans_start, (list, tuple)):
                # List-valued spans: an empty list means "no short answer".
                if not ans_start or not ans_end:
                    continue
                ans_start, ans_end = ans_start[0], ans_end[0]
            if ans_start < 0:
                continue  # Skip if no short answer

            tokens = item['document']['tokens']['token']
            doc_text = ' '.join(_token_text(t) for t in tokens[:1000])  # Limit tokens
            answer = ' '.join(_token_text(t) for t in tokens[ans_start:ans_end])

            if doc_text and answer:
                samples.append(Sample(
                    passage=doc_text,
                    query=item['question']['text'],
                    answer=answer,
                    dataset='natural_questions'
                ))
            if len(samples) >= n:
                break
        random.shuffle(samples)
        return samples[:n]
    except Exception as e:
        print(f"  Error loading Natural Questions: {e}")
        return []

def load_squad_v2(n: int) -> List[Sample]:
    """Load up to ``n`` answerable SQuAD v2 validation questions.

    Questions with an empty answer list (the unanswerable ones) are skipped; the
    first listed answer text is used as the reference.

    Returns:
        A shuffled list of at most ``n`` samples, or ``[]`` if loading fails
        (the error is printed).
    """
    print("Loading SQuAD v2...")
    try:
        ds = load_dataset("squad_v2", split="validation")
        collected = []
        for record in ds:
            answer_texts = record['answers']['text']
            if answer_texts:  # empty list => unanswerable question
                collected.append(Sample(
                    passage=record['context'],
                    query=record['question'],
                    answer=answer_texts[0],
                    dataset='squad_v2',
                ))
            if len(collected) >= n:
                break
        random.shuffle(collected)
        return collected[:n]
    except Exception as e:
        print(f"  Error loading SQuAD v2: {e}")
        return []

print("Dataset loaders defined.")

Dataset loaders defined.


In [5]:
# Cell 5: Load All Datasets

print("=" * 70)
print("LOADING DATASETS")
print("=" * 70)

# (result key, loader) pairs; list order is the load and evaluation order.
loaders = [
    ('msmarco_hard', load_msmarco_hard),
    ('squad_v2', load_squad_v2),
    ('hotpotqa', load_hotpotqa),
    ('pubmedqa', load_pubmedqa),
    ('cnn_dailymail', load_cnn_dailymail),
    ('narrativeqa', load_narrativeqa),
    # ('natural_questions', load_natural_questions),  # Often slow/large
]

# Only datasets that produced at least one sample are registered here.
all_samples = {}

for name, loader in loaders:
    try:
        samples = loader(N_SAMPLES_PER_DATASET)
    except Exception as e:
        # A failing loader must not abort the survey; report it and move on.
        print(f"  {name}: Error - {e}")
        continue
    if not samples:
        print(f"  {name}: No samples loaded")
        continue
    all_samples[name] = samples
    print(f"  {name}: {len(samples)} samples loaded")

print(f"\nTotal datasets loaded: {len(all_samples)}")

LOADING DATASETS
Loading MS MARCO...
  msmarco_hard: 300 samples loaded
Loading SQuAD v2...
  squad_v2: 100 samples loaded
Loading HotpotQA...
  hotpotqa: 100 samples loaded
Loading PubMedQA...
  pubmedqa: 100 samples loaded
Loading CNN/DailyMail...


`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'narrativeqa' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


  cnn_dailymail: 100 samples loaded
Loading NarrativeQA...


Resolving data files:   0%|          | 0/24 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/24 [00:00<?, ?it/s]

  narrativeqa: 100 samples loaded

Total datasets loaded: 6


In [6]:
# Cell 6: Show Dataset Statistics

def _clip(text, limit=100):
    """Return ``text`` unchanged, or its first ``limit`` characters plus '...' if longer."""
    return f"{text[:limit]}..." if len(text) > limit else text


print("=" * 70)
print("DATASET STATISTICS")
print("=" * 70)

print(f"\n{'Dataset':<20} {'N':>6} {'Passage Words':>15} {'Answer Words':>15}")
print("-" * 60)

for name, samples in all_samples.items():
    # Word-count means are estimated from the first 50 samples only; N is the full count.
    head = samples[:50]
    mean_passage_words = np.mean([len(item.passage.split()) for item in head])
    mean_answer_words = np.mean([len(item.answer.split()) for item in head])
    print(f"{name:<20} {len(samples):>6} {mean_passage_words:>15.1f} {mean_answer_words:>15.1f}")

print("\n" + "=" * 70)
print("SAMPLE EXAMPLES")
print("=" * 70)

# Show the first sample of each dataset as a quick sanity check of the loaders.
for name, samples in all_samples.items():
    s = samples[0]
    print(f"\n### {name} ###")
    print(f"Query: {_clip(s.query)}")
    print(f"Answer: {_clip(s.answer)}")
    print(f"Passage: {s.passage[:150]}...")

DATASET STATISTICS

Dataset                   N   Passage Words    Answer Words
------------------------------------------------------------
msmarco_hard            300            64.3            12.9
squad_v2                100           105.4             1.6
hotpotqa                100           303.2             2.1
pubmedqa                100           187.5            44.6
cnn_dailymail           100           521.6            34.0
narrativeqa             100           301.6             3.4

SAMPLE EXAMPLES

### msmarco_hard ###
Query: what was true of the desegregation of the armed forces under president truman
Answer: An official government policy.
Passage: 1 Full text of Executive Order 9981 from the Harry S. Truman Presidential Library and Museum. 2  Integration of the Armed Forces, 1940-1965 (Defense S...

### squad_v2 ###
Query: In what century did important classical music developments occur in Normandy?
Answer: 11th
Passage: Normandy was the site of several important devel

In [7]:
# Cell 7: Run Evaluation

print("="*70)
print("RUNNING EVALUATION")
print("="*70)

all_results = {}
CHECKPOINT_PATH = f'{OUTPUT_DIR}/checkpoint.json'

# Nothing else in the notebook creates the output directory; without this the
# first checkpoint write fails on a fresh machine.
os.makedirs(OUTPUT_DIR, exist_ok=True)


def _save_checkpoint():
    """Write ``all_results`` to the checkpoint file via a temp file and rename.

    The rename keeps the previous checkpoint intact if the run is interrupted
    in the middle of a write.
    """
    tmp_path = CHECKPOINT_PATH + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(all_results, f)
    os.replace(tmp_path, CHECKPOINT_PATH)


# Load checkpoint if exists
if os.path.exists(CHECKPOINT_PATH):
    with open(CHECKPOINT_PATH, 'r') as f:
        all_results = json.load(f)
    print(f"Loaded checkpoint with {len(all_results)} datasets")

for dataset_name, samples in all_samples.items():
    # A dataset counts as done only when it has a full set of results.
    if dataset_name in all_results and len(all_results[dataset_name]) >= N_SAMPLES_PER_DATASET:
        print(f"\n{dataset_name}: Already complete ({len(all_results[dataset_name])} samples)")
        continue

    print(f"\nProcessing {dataset_name}...")

    # For MS MARCO hard, we need to evaluate bare first then filter
    if dataset_name == 'msmarco_hard':
        # First pass: bare NLL for the first 2*N candidates.
        bare_nlls = []
        n_failed = 0
        last_error = None
        print("  First pass: computing bare NLL to filter hard samples...")
        for s in tqdm(samples[:N_SAMPLES_PER_DATASET * 2], desc="  Bare NLL"):
            try:
                cache, cache_len = build_bare_cache(s.passage[:4000])
                nll = score_sample(deepcopy_cache(cache), cache_len, s.query, s.answer[:500])
            except Exception as e:
                # Skip the sample but keep a record; a bare `except:` would also
                # swallow KeyboardInterrupt and hide every failure.
                n_failed += 1
                last_error = repr(e)
                continue
            bare_nlls.append((nll, s))
        if n_failed:
            print(f"  Skipped {n_failed} samples that raised errors (last: {last_error})")

        # Sort by NLL (descending) and take hardest N
        bare_nlls.sort(key=lambda x: x[0], reverse=True)
        samples = [s for _, s in bare_nlls[:N_SAMPLES_PER_DATASET]]
        if samples:
            # Index by the number actually kept: fewer than N may have been scored.
            print(f"  Filtered to {len(samples)} hardest samples (NLL range: {bare_nlls[0][0]:.2f} - {bare_nlls[len(samples)-1][0]:.2f})")
        else:
            print("  Filtered to 0 hardest samples (no candidate could be scored)")

    results = []
    n_failed = 0
    last_error = None
    for sample in tqdm(samples[:N_SAMPLES_PER_DATASET], desc=f"  {dataset_name}"):
        try:
            result = evaluate_sample(sample)
        except Exception as e:
            n_failed += 1
            last_error = repr(e)
            continue
        results.append(result)

        # Checkpoint every 25 samples
        if len(results) % 25 == 0:
            all_results[dataset_name] = results
            _save_checkpoint()

    all_results[dataset_name] = results
    # Persist the finished dataset even when its count is not a multiple of 25.
    _save_checkpoint()
    if n_failed:
        print(f"  Skipped {n_failed} samples that raised errors (last: {last_error})")
    print(f"  Completed: {len(results)} samples")

# Final save
with open(f'{OUTPUT_DIR}/results.json', 'w') as f:
    json.dump(all_results, f, indent=2)
print(f"\nSaved all results to {OUTPUT_DIR}/results.json")

RUNNING EVALUATION
Loaded checkpoint with 6 datasets

msmarco_hard: Already complete (100 samples)

squad_v2: Already complete (100 samples)

hotpotqa: Already complete (100 samples)

pubmedqa: Already complete (100 samples)

cnn_dailymail: Already complete (100 samples)

narrativeqa: Already complete (100 samples)

Saved all results to /home/jupyter/research/directed_kvcache/results/exp19/results.json


In [8]:
# Cell 8: Analysis - Summary Table

def cohens_d(x):
    """Return the one-sample Cohen's d of ``x``: mean / sample standard deviation.

    ``x`` is expected to hold paired differences (here: bare NLL minus primed
    NLL), so the effect is measured against zero. The standard deviation uses
    ``ddof=1`` and is computed once.

    Returns:
        A Python float. ``0.0`` is returned when the effect size is undefined:
        fewer than two values, zero variance, or a non-finite deviation.
    """
    values = np.asarray(x, dtype=float)
    if values.size < 2:
        # The sample standard deviation is undefined for fewer than two values.
        return 0.0
    sd = values.std(ddof=1)
    if not sd > 0:  # also catches NaN, which compares False
        return 0.0
    return float(values.mean() / sd)

print("=" * 70)
print("CROSS-DATASET SURVEY RESULTS")
print("=" * 70)

# One row per dataset; reused by the ranking and conclusion cells.
summary = []

print(f"\n{'Dataset':<18} {'N':>5} {'Bare NLL':>10} {'Trunc Win%':>12} {'Trunc d':>10} {'Full Win%':>12} {'Full d':>10}")
print("-" * 85)

for dataset_name, results in all_results.items():
    if not results:
        continue

    bare = np.array([r['nll_bare'] for r in results])
    delta_trunc = np.array([r['delta_truncated'] for r in results])
    delta_full = np.array([r['delta_fullctx'] for r in results])

    # A "win" is a sample whose NLL dropped under priming (positive delta).
    row = {
        'dataset': dataset_name,
        'n': len(results),
        'mean_bare_nll': np.mean(bare),
        'mean_passage_words': np.mean([r['passage_words'] for r in results]),
        'mean_answer_words': np.mean([r['answer_words'] for r in results]),
        'truncated_win_rate': np.mean(delta_trunc > 0) * 100,
        'truncated_cohens_d': cohens_d(delta_trunc),
        'fullctx_win_rate': np.mean(delta_full > 0) * 100,
        'fullctx_cohens_d': cohens_d(delta_full),
    }

    print(
        f"{row['dataset']:<18} {row['n']:>5} {row['mean_bare_nll']:>10.2f} "
        f"{row['truncated_win_rate']:>11.1f}% {row['truncated_cohens_d']:>+10.3f} "
        f"{row['fullctx_win_rate']:>11.1f}% {row['fullctx_cohens_d']:>+10.3f}"
    )

    summary.append(row)

CROSS-DATASET SURVEY RESULTS

Dataset                N   Bare NLL   Trunc Win%    Trunc d    Full Win%     Full d
-------------------------------------------------------------------------------------
msmarco_hard         100       3.59        59.0%     +0.190        32.0%     -0.358
squad_v2             100       0.14        52.0%     +0.003        47.0%     -0.026
hotpotqa             100       1.68        34.0%     -0.348        29.0%     -0.381
pubmedqa             100       1.96        16.0%     -0.728         8.0%     -1.132
cnn_dailymail        100       2.77         8.0%     -1.307         8.0%     -1.330
narrativeqa          100       1.26        45.0%     -0.348        30.0%     -0.599


In [9]:
# Cell 9: Statistical Significance Tests

print("\n" + "=" * 70)
print("STATISTICAL TESTS: Does priming help? (one-sample t-test, H0: delta=0)")
print("=" * 70)

print(f"\n{'Dataset':<18} {'Truncated':>25} {'Full-Context':>25}")
print(f"{'':18} {'t-stat':>12} {'p-value':>12} {'t-stat':>12} {'p-value':>12}")
print("-" * 70)

for dataset_name, results in all_results.items():
    # Datasets with fewer than 10 results are not tested.
    if not (results and len(results) >= 10):
        continue

    cells = [f"{dataset_name:<18}"]
    for delta_key in ('delta_truncated', 'delta_fullctx'):
        deltas = np.array([r[delta_key] for r in results])
        t_stat, p_value = stats.ttest_1samp(deltas, 0)
        star = "*" if p_value < 0.05 else ""  # mark significant results
        cells.append(f"{t_stat:>+12.2f} {p_value:>11.4f}{star}")

    print(" ".join(cells))

print("\n* = p < 0.05 (significant)")
print("Positive t-stat = priming HELPS, Negative t-stat = priming HURTS")


STATISTICAL TESTS: Does priming help? (one-sample t-test, H0: delta=0)

Dataset                            Truncated              Full-Context
                         t-stat      p-value       t-stat      p-value
----------------------------------------------------------------------
msmarco_hard              +1.90      0.0603        -3.58      0.0005*
squad_v2                  +0.03      0.9723        -0.26      0.7963
hotpotqa                  -3.48      0.0008*        -3.81      0.0002*
pubmedqa                  -7.28      0.0000*       -11.32      0.0000*
cnn_dailymail            -13.07      0.0000*       -13.30      0.0000*
narrativeqa               -3.48      0.0007*        -5.99      0.0000*

* = p < 0.05 (significant)
Positive t-stat = priming HELPS, Negative t-stat = priming HURTS


In [10]:
# Cell 10: Identify Best Datasets for Priming

def _verdict(d):
    """Label an effect size: above 0.1 helps, below -0.1 hurts, otherwise neutral."""
    if d > 0.1:
        return "HELPS"
    if d < -0.1:
        return "HURTS"
    return "NEUTRAL"


print("\n" + "=" * 70)
print("RANKING: Which datasets benefit most from priming?")
print("=" * 70)

# Largest (most helpful) truncated effect size first.
summary_sorted = sorted(summary, key=lambda row: row['truncated_cohens_d'], reverse=True)

print("\n### Ranked by TRUNCATED effect size (Cohen's d) ###")
print(f"{'Rank':<6} {'Dataset':<18} {'d':>10} {'Win%':>10} {'Bare NLL':>12} {'Passage Words':>15}")
print("-" * 75)
for i, s in enumerate(summary_sorted, 1):
    verdict = _verdict(s['truncated_cohens_d'])
    print(f"{i:<6} {s['dataset']:<18} {s['truncated_cohens_d']:>+10.3f} {s['truncated_win_rate']:>9.1f}% {s['mean_bare_nll']:>12.2f} {s['mean_passage_words']:>15.0f}  [{verdict}]")

print("\n### Ranked by FULL-CONTEXT effect size (Cohen's d) ###")
# Same ranking, keyed on the full-context effect size.
summary_sorted_full = sorted(summary, key=lambda row: row['fullctx_cohens_d'], reverse=True)
print(f"{'Rank':<6} {'Dataset':<18} {'d':>10} {'Win%':>10}")
print("-" * 50)
for i, s in enumerate(summary_sorted_full, 1):
    verdict = _verdict(s['fullctx_cohens_d'])
    print(f"{i:<6} {s['dataset']:<18} {s['fullctx_cohens_d']:>+10.3f} {s['fullctx_win_rate']:>9.1f}%  [{verdict}]")


RANKING: Which datasets benefit most from priming?

### Ranked by TRUNCATED effect size (Cohen's d) ###
Rank   Dataset                     d       Win%     Bare NLL   Passage Words
---------------------------------------------------------------------------
1      msmarco_hard           +0.190      59.0%         3.59              69  [HELPS]
2      squad_v2               +0.003      52.0%         0.14             105  [NEUTRAL]
3      hotpotqa               -0.348      34.0%         1.68             296  [HURTS]
4      narrativeqa            -0.348      45.0%         1.26             323  [HURTS]
5      pubmedqa               -0.728      16.0%         1.96             191  [HURTS]
6      cnn_dailymail          -1.307       8.0%         2.77             446  [HURTS]

### Ranked by FULL-CONTEXT effect size (Cohen's d) ###
Rank   Dataset                     d       Win%
--------------------------------------------------
1      squad_v2               -0.026      47.0%  [NEUTRAL]
2      msm

In [11]:
# Cell 11: Correlation Analysis - What predicts priming benefit?

print("\n" + "=" * 70)
print("CORRELATION: What predicts priming benefit?")
print("=" * 70)

# Pool the per-sample results of every dataset into one flat list.
all_flat = [r for results in all_results.values() for r in results]

# The correlations are only reported when more than 50 pooled samples exist.
if len(all_flat) > 50:
    bare_nll = np.array([r['nll_bare'] for r in all_flat])
    passage_words = np.array([r['passage_words'] for r in all_flat])
    answer_words = np.array([r['answer_words'] for r in all_flat])
    delta_trunc = np.array([r['delta_truncated'] for r in all_flat])
    delta_full = np.array([r['delta_fullctx'] for r in all_flat])

    print(f"\nTotal samples across all datasets: {len(all_flat)}")

    print(f"\n{'Predictor':<20} {'Truncated Delta':>20} {'Full-Context Delta':>20}")
    print(f"{'':20} {'r':>10} {'p':>9} {'r':>10} {'p':>9}")
    print("-" * 62)

    predictors = [
        ('Bare NLL', bare_nll),
        ('Passage Words', passage_words),
        ('Answer Words', answer_words),
    ]
    for name, predictor in predictors:
        cells = [f"{name:<20}"]
        # Pearson correlation of the predictor with each delta, truncated first.
        for target in (delta_trunc, delta_full):
            r_val, p_val = stats.pearsonr(predictor, target)
            star = "*" if p_val < 0.05 else ""
            cells.append(f"{r_val:>+10.3f} {p_val:>8.4f}{star}")
        print(" ".join(cells))

    print("\n* = p < 0.05")


CORRELATION: What predicts priming benefit?

Total samples across all datasets: 600

Predictor                 Truncated Delta   Full-Context Delta
                              r         p          r         p
--------------------------------------------------------------
Bare NLL                 -0.175   0.0000*     -0.211   0.0000*
Passage Words            -0.197   0.0000*     -0.139   0.0007*
Answer Words             -0.027   0.5027     +0.079   0.0521

* = p < 0.05


In [12]:
# Cell 12: Save Final Analysis

analysis = {
    'summary': summary,
    # Counted directly from all_results so the value does not depend on
    # whether (or when) the correlation cell was run.
    'total_samples': sum(len(r) for r in all_results.values()),
    'datasets_evaluated': list(all_results.keys()),
}

# Make sure the target directory exists before writing.
os.makedirs(OUTPUT_DIR, exist_ok=True)
with open(f'{OUTPUT_DIR}/analysis.json', 'w') as f:
    json.dump(analysis, f, indent=2)

print(f"Analysis saved to {OUTPUT_DIR}/analysis.json")


def _report_effect(title, d_key, win_key, qualifies):
    """Print ``title`` followed by every summary row whose effect size qualifies.

    Args:
        title: Section heading.
        d_key: Summary key holding the Cohen's d for the condition.
        win_key: Summary key holding the win rate (percent) for the condition.
        qualifies: Predicate applied to the Cohen's d value.

    Prints "  (none)" when no dataset qualifies, so no section is left empty.
    """
    print(f"\n{title}")
    hits = [row for row in summary if qualifies(row[d_key])]
    for row in hits:
        print(f"  - {row['dataset']}: d={row[d_key]:+.3f}, win={row[win_key]:.1f}%")
    if not hits:
        print("  (none)")


print("\n" + "="*70)
print("CONCLUSIONS")
print("="*70)

_report_effect(
    "Datasets where TRUNCATED priming helps (d > 0.1):",
    'truncated_cohens_d', 'truncated_win_rate', lambda d: d > 0.1,
)
_report_effect(
    "Datasets where TRUNCATED priming hurts (d < -0.1):",
    'truncated_cohens_d', 'truncated_win_rate', lambda d: d < -0.1,
)
# The "(none)" fallback for this section used to test the TRUNCATED effect
# size, so it stayed silent whenever any truncated effect exceeded 0.1.
_report_effect(
    "Datasets where FULL-CONTEXT priming helps (d > 0.1):",
    'fullctx_cohens_d', 'fullctx_win_rate', lambda d: d > 0.1,
)

Analysis saved to /home/jupyter/research/directed_kvcache/results/exp19/analysis.json

CONCLUSIONS

Datasets where TRUNCATED priming helps (d > 0.1):
  - msmarco_hard: d=+0.190, win=59.0%

Datasets where TRUNCATED priming hurts (d < -0.1):
  - hotpotqa: d=-0.348, win=34.0%
  - pubmedqa: d=-0.728, win=16.0%
  - cnn_dailymail: d=-1.307, win=8.0%
  - narrativeqa: d=-0.348, win=45.0%

Datasets where FULL-CONTEXT priming helps (d > 0.1):


In [13]:
# Cell 13: Per-Dataset Deep Dive

print("\n" + "=" * 70)
print("PER-DATASET DETAILS")
print("=" * 70)

for dataset_name, results in all_results.items():
    # Datasets with fewer than 10 results are not reported.
    if not (results and len(results) >= 10):
        continue

    print(f"\n### {dataset_name.upper()} ###")

    bare = np.array([r['nll_bare'] for r in results])
    # `trunc` and `full` are not used below; they remain available in the
    # notebook namespace after the cell has run.
    trunc = np.array([r['nll_truncated'] for r in results])
    full = np.array([r['nll_fullctx'] for r in results])
    delta_trunc = np.array([r['delta_truncated'] for r in results])
    delta_full = np.array([r['delta_fullctx'] for r in results])

    mean_passage_words = np.mean([r['passage_words'] for r in results])
    mean_answer_words = np.mean([r['answer_words'] for r in results])

    print(f"  N = {len(results)}")
    print(f"  Mean passage words: {mean_passage_words:.0f}")
    print(f"  Mean answer words: {mean_answer_words:.0f}")
    print(f"  Mean bare NLL: {np.mean(bare):.3f}")
    print("  ")
    # The labels are padded so that both condition lines align.
    for label, deltas in (("TRUNCATED:", delta_trunc), ("FULL-CTX: ", delta_full)):
        print(f"  {label} win={np.mean(deltas > 0)*100:.1f}%, mean_delta={np.mean(deltas):+.3f}, d={cohens_d(deltas):+.3f}")

    # Within-dataset check: does priming benefit vary with baseline difficulty?
    r_trunc, p_trunc = stats.pearsonr(bare, delta_trunc)
    print(f"  Hardness interaction (bare NLL vs delta_trunc): r={r_trunc:+.3f}, p={p_trunc:.4f}")


PER-DATASET DETAILS

### MSMARCO_HARD ###
  N = 100
  Mean passage words: 69
  Mean answer words: 15
  Mean bare NLL: 3.594
  
  TRUNCATED: win=59.0%, mean_delta=+0.068, d=+0.190
  FULL-CTX:  win=32.0%, mean_delta=-0.177, d=-0.358
  Hardness interaction (bare NLL vs delta_trunc): r=+0.255, p=0.0105

### SQUAD_V2 ###
  N = 100
  Mean passage words: 105
  Mean answer words: 2
  Mean bare NLL: 0.139
  
  TRUNCATED: win=52.0%, mean_delta=+0.000, d=+0.003
  FULL-CTX:  win=47.0%, mean_delta=-0.006, d=-0.026
  Hardness interaction (bare NLL vs delta_trunc): r=+0.443, p=0.0000

### HOTPOTQA ###
  N = 100
  Mean passage words: 296
  Mean answer words: 2
  Mean bare NLL: 1.678
  
  TRUNCATED: win=34.0%, mean_delta=-0.130, d=-0.348
  FULL-CTX:  win=29.0%, mean_delta=-0.471, d=-0.381
  Hardness interaction (bare NLL vs delta_trunc): r=-0.403, p=0.0000

### PUBMEDQA ###
  N = 100
  Mean passage words: 191
  Mean answer words: 42
  Mean bare NLL: 1.963
  
  TRUNCATED: win=16.0%, mean_delta=-0.072, 