# Exp 15: NLL Ensemble Ranking

## Motivation

Exp 14 found that combining bare + primed NLL marginally improves ranking
(global ΔMRR=+0.008), but cross-validated improvement was only +0.006 (p=0.21, ns).
The per-query oracle alpha gap (+0.050) suggests latent signal exists but a single
primed cache can't reliably extract it.

**Core hypothesis:** Diverse priming caches produce NLL estimates with partially
independent errors. Ensembling (averaging) these estimates reduces ranking noise,
just as averaging multiple measurements improves precision.

## Design

**5 Scoring Signals** (each produces a per-passage NLL):

| # | Signal | Cache | Scoring Prompt | Purpose |
|---|--------|-------|---------------|---------|
| 1 | `bare` | Bare cache | Standard prompt | Baseline |
| 2 | `rescore` | Bare cache | Alt prompt | **Control**: diversity without priming |
| 3 | `sf` | Static fact prefix | Standard prompt | Replicate Exp 14 |
| 4 | `rand` | Random text prefix | Standard prompt | Is prefix content irrelevant? |
| 5 | `intent` | Intent prefix | Standard prompt | Different semantic angle |

**Ensemble Conditions** (equal-weight NLL average, no tuning):

| Ensemble | Members | Tests |
|----------|---------|-------|
| `ens_2_sf` | bare + sf | Replicates Exp 14 |
| `ens_2_rand` | bare + rand | Random prefix ensemble |
| `ens_2_rescore` | bare + rescore | **Non-priming control** |
| `ens_3` | bare + sf + rand | 3-member ensemble |
| `ens_4` | bare + sf + rand + intent | 4-member ensemble |
| `ens_5_all` | all 5 signals | Maximum diversity |

**Critical comparison:** `ens_2_sf` vs `ens_2_rescore`. If the control matches
priming, then priming isn't special — any scoring diversity helps.

In [1]:
# Cell 1: Setup
import os
os.umask(0o000)

import sys
import json
import time
import numpy as np
import torch
import gc
from pathlib import Path

SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)

RESULTS_DIR = Path("results/exp15")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

CHECKPOINT_PATH = RESULTS_DIR / "checkpoint.json"
FINAL_RESULTS_PATH = RESULTS_DIR / "results.json"

print(f"SEED: {SEED}")
print(f"Results directory: {RESULTS_DIR}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

SEED: 42
Results directory: results/exp15
CUDA available: True
GPU: NVIDIA L4
GPU memory: 23.6 GB


In [2]:
# Cell 2: Load model (Mistral-7B 4-bit)
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, DynamicCache

MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

print(f"Loading {MODEL_NAME} (4-bit)...")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.eval()

print(f"Model loaded. dtype={model.dtype}, device={model.device}")

Loading mistralai/Mistral-7B-Instruct-v0.2 (4-bit)...


`torch_dtype` is deprecated! Use `dtype` instead!


Loading weights:   0%|          | 0/291 [00:00<?, ?it/s]

Model loaded. dtype=torch.float16, device=cuda:0


In [3]:
# Cell 3: Config and library imports
sys.path.insert(0, ".")

from lib.config import ExperimentConfig
from lib.kv_cache import (
    _get_cache_keys,
    _get_cache_values,
    _set_cache_keys,
    _set_cache_values,
    _ensure_dynamic_cache,
    extract_and_truncate_cache_with_bos,
    correct_rope_positions_with_bos,
    score_answer_with_cache,
    deepcopy_cache,
    build_hybrid_cache,
)
from lib.analysis import compute_ranking_metrics
from lib.surrogate import STATIC_SURROGATE_QUERIES
from lib.data import count_words
from scipy import stats
from scipy.stats import wilcoxon
from tqdm.auto import tqdm

config = ExperimentConfig(
    model_name=MODEL_NAME,
    num_samples=2000,
    seed=SEED,
)

# Templates
SURROGATE_PREFIX_TEMPLATE = "{surrogate}\n"
DOCUMENT_TEMPLATE = "{document}"
QUERY_TEMPLATE = "\nQuery: {query}\nAnswer:"
ANSWER_TEMPLATE = " {answer}"

# Alternative prompt for rescore control
ALT_QUERY_TEMPLATE = "\nQuestion: {query}\nResponse:"

# Prefix texts
STATIC_FACT = STATIC_SURROGATE_QUERIES['static_factual']['query']
RANDOM_PREFIX_TEXT = "The purple elephant danced gracefully on the frozen lake during twilight"
INTENT_PREFIX_TEXT = "What is this passage about?"

# Experiment parameters
MAX_QUERIES = 300
MAX_PASSAGE_WORDS = 300
MIN_PASSAGES_PER_QUERY = 2
CHECKPOINT_EVERY = 25

SIGNAL_NAMES = ['bare', 'rescore', 'sf', 'rand', 'intent']

print("Config ready")
print(f"  MAX_QUERIES: {MAX_QUERIES}")
print(f"  Prefixes:")
print(f"    sf:     '{STATIC_FACT}'")
print(f"    rand:   '{RANDOM_PREFIX_TEXT}'")
print(f"    intent: '{INTENT_PREFIX_TEXT}'")
print(f"  Alt prompt: '{ALT_QUERY_TEMPLATE.format(query='...')}'")
print(f"  Signals: {SIGNAL_NAMES}")

Config ready
  MAX_QUERIES: 300
  Prefixes:
    sf:     'What are the key facts I need to know?'
    rand:   'The purple elephant danced gracefully on the frozen lake during twilight'
    intent: 'What is this passage about?'
  Alt prompt: '
Question: ...
Response:'
  Signals: ['bare', 'rescore', 'sf', 'rand', 'intent']


In [4]:
# Cell 4: Load MS MARCO v1.1 (same filtering as Exp 14)
from datasets import load_dataset

print("=" * 70)
print("LOADING MS MARCO v1.1 — ALL PASSAGES PER QUERY")
print("=" * 70)

dataset = load_dataset("microsoft/ms_marco", "v1.1", split="validation",
                        trust_remote_code=True)
print(f"Total items in validation: {len(dataset)}")

queries = []
np.random.seed(SEED)

for item in tqdm(dataset, desc="Filtering"):
    passages_info = item.get('passages', {})
    passage_texts = passages_info.get('passage_text', [])
    is_selected = passages_info.get('is_selected', [])
    query = item.get('query', '')
    answers = item.get('answers', [])
    well_formed = item.get('wellFormedAnswers', [])

    if not passage_texts or not query:
        continue
    if len(passage_texts) < MIN_PASSAGES_PER_QUERY:
        continue
    if not is_selected or sum(is_selected) == 0:
        continue

    word_counts = [count_words(p) for p in passage_texts]
    if any(wc > MAX_PASSAGE_WORDS for wc in word_counts):
        continue

    answer = None
    if well_formed and len(well_formed) > 0 and well_formed[0] != '[]':
        answer = well_formed[0]
    elif answers and len(answers) > 0 and answers[0] != 'No Answer Present.':
        answer = answers[0]
    else:
        continue

    passage_list = []
    for i, (ptext, sel) in enumerate(zip(passage_texts, is_selected)):
        passage_list.append({
            'passage': ptext,
            'is_relevant': bool(sel == 1),
            'word_count': word_counts[i],
            'passage_idx': i,
        })

    queries.append({
        'query': query,
        'answer': answer,
        'passages': passage_list,
        'n_passages': len(passage_list),
        'n_relevant': sum(1 for p in passage_list if p['is_relevant']),
    })

    if len(queries) >= MAX_QUERIES * 3:
        break

np.random.shuffle(queries)
queries = queries[:MAX_QUERIES]
N = len(queries)

n_passages_list = [q['n_passages'] for q in queries]
total_passages = sum(n_passages_list)

print(f"\nSelected {N} queries ({total_passages} total passages)")
print(f"Passages per query: mean={np.mean(n_passages_list):.1f}, "
      f"min={min(n_passages_list)}, max={max(n_passages_list)}")
print(f"Word counts: mean={np.mean([p['word_count'] for q in queries for p in q['passages']]):.0f}")

del dataset
gc.collect()

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'microsoft/ms_marco' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


LOADING MS MARCO v1.1 — ALL PASSAGES PER QUERY


Total items in validation: 10047


Filtering:   0%|          | 0/10047 [00:00<?, ?it/s]


Selected 300 queries (2504 total passages)
Passages per query: mean=8.3, min=3, max=10
Word counts: mean=71


192

In [5]:
# Cell 5: Tokenize prefixes and verify BPE boundaries

print("=" * 70)
print("EXPERIMENTAL CONDITIONS — NLL ENSEMBLE RANKING")
print("=" * 70)

# Tokenize each prefix
sf_str = SURROGATE_PREFIX_TEMPLATE.format(surrogate=STATIC_FACT)
rand_str = SURROGATE_PREFIX_TEMPLATE.format(surrogate=RANDOM_PREFIX_TEXT)
intent_str = SURROGATE_PREFIX_TEMPLATE.format(surrogate=INTENT_PREFIX_TEXT)

sf_ids = tokenizer(sf_str, return_tensors="pt",
                    add_special_tokens=False)['input_ids'].to(config.device)
rand_ids = tokenizer(rand_str, return_tensors="pt",
                      add_special_tokens=False)['input_ids'].to(config.device)
intent_ids = tokenizer(intent_str, return_tensors="pt",
                        add_special_tokens=False)['input_ids'].to(config.device)

PREFIX_CONFIGS = [
    ('sf', STATIC_FACT, sf_str, sf_ids),
    ('rand', RANDOM_PREFIX_TEXT, rand_str, rand_ids),
    ('intent', INTENT_PREFIX_TEXT, intent_str, intent_ids),
]

print("\nPREFIX TOKEN LENGTHS:")
for name, text, full_str, ids in PREFIX_CONFIGS:
    print(f"  {name:<8} {ids.shape[1]:>3} tokens | '{text}'")

# Verify BPE boundary consistency across prefixes
print("\nBPE BOUNDARY CHECK (first passage):")
example_doc = queries[0]['passages'][0]['passage']
for name, text, full_str, ids in PREFIX_CONFIGS:
    concat = full_str + DOCUMENT_TEMPLATE.format(document=example_doc)
    concat_enc = tokenizer(concat, add_special_tokens=True)['input_ids']
    prefix_enc = tokenizer(full_str, add_special_tokens=True)['input_ids']
    doc_ids_from_concat = concat_enc[len(prefix_enc):]

    bare_doc_enc = tokenizer(DOCUMENT_TEMPLATE.format(document=example_doc),
                              add_special_tokens=False)['input_ids']
    match = sum(1 for a, b in zip(doc_ids_from_concat, bare_doc_enc) if a == b)
    total = max(len(bare_doc_enc), 1)
    print(f"  {name}: {match}/{total} tokens match ({100*match/total:.1f}%)")

# Explain conditions
print("\n" + "=" * 70)
print("CONDITION DETAILS")
print("=" * 70)

conditions_detail = [
    ("bare", "Standard bare cache scored with standard prompt",
     "Baseline. All other conditions are compared to this."),
    ("rescore", "Bare cache scored with alt prompt ('Question:...Response:')",
     "NON-PRIMING CONTROL. Same cache, different prompt. Tests if scoring "
     "diversity alone improves ensembles, without any cache modification."),
    ("sf (static_fact)", f"Prefix: '{STATIC_FACT}'",
     "Replicates Exp 14's primed_1x. Bare keys + primed values (truncated)."),
    ("rand (random)", f"Prefix: '{RANDOM_PREFIX_TEXT}'",
     "Semantically unrelated prefix. Tests if ANY prefix content works or "
     "if semantic relevance matters."),
    ("intent", f"Prefix: '{INTENT_PREFIX_TEXT}'",
     "Different semantic angle than static_fact. Tests prefix diversity."),
]

for name, detail, purpose in conditions_detail:
    print(f"\n### {name} ###")
    print(f"  {detail}")
    print(f"  Purpose: {purpose}")

print("\n" + "=" * 70)
print("ENSEMBLE CONDITIONS (equal-weight NLL average)")
print("=" * 70)
print("  ens_2_sf:      bare + sf           (replicate Exp 14)")
print("  ens_2_rand:    bare + rand         (random prefix)")
print("  ens_2_rescore: bare + rescore      (NON-PRIMING CONTROL)")
print("  ens_3:         bare + sf + rand    (3-member)")
print("  ens_4:         bare + sf + rand + intent  (4-member)")
print("  ens_5_all:     all 5 signals       (maximum diversity)")

EXPERIMENTAL CONDITIONS — NLL ENSEMBLE RANKING

PREFIX TOKEN LENGTHS:
  sf        11 tokens | 'What are the key facts I need to know?'
  rand      17 tokens | 'The purple elephant danced gracefully on the frozen lake during twilight'
  intent     7 tokens | 'What is this passage about?'

BPE BOUNDARY CHECK (first passage):
  sf: 2/168 tokens match (1.2%)
  rand: 2/168 tokens match (1.2%)
  intent: 2/168 tokens match (1.2%)

CONDITION DETAILS

### bare ###
  Standard bare cache scored with standard prompt
  Purpose: Baseline. All other conditions are compared to this.

### rescore ###
  Bare cache scored with alt prompt ('Question:...Response:')
  Purpose: NON-PRIMING CONTROL. Same cache, different prompt. Tests if scoring diversity alone improves ensembles, without any cache modification.

### sf (static_fact) ###
  Prefix: 'What are the key facts I need to know?'
  Purpose: Replicates Exp 14's primed_1x. Bare keys + primed values (truncated).

### rand (random) ###
  Prefix: 'The purp

In [6]:
# Cell 6: Main loop — score all passages under all conditions

print("=" * 70)
print(f"MAIN EVALUATION ({N} queries, ~{total_passages} passages)")
print("=" * 70)

# Checkpoint resume
all_results = []
start_idx = 0

if CHECKPOINT_PATH.exists():
    with open(CHECKPOINT_PATH, 'r') as f:
        ckpt = json.load(f)
    ckpt_queries = ckpt.get('query_texts', [])
    current_queries = [q['query'] for q in queries]
    if ckpt_queries == current_queries:
        all_results = ckpt['results']
        start_idx = len(all_results)
        print(f"Resuming from checkpoint: {start_idx}/{N}")
    else:
        print("Checkpoint query mismatch. Starting fresh.")
else:
    print("No checkpoint found. Starting fresh.")

print(f"Evaluating queries {start_idx} to {N-1}")
print(f"Per passage: 4 forward passes (bare + 3 primed) + 5 scoring passes")

t_start = time.time()

for qidx in tqdm(range(start_idx, N), initial=start_idx, total=N, desc="Queries"):
    query_data = queries[qidx]
    query = query_data['query']
    answer = query_data['answer']
    query_prompt = QUERY_TEMPLATE.format(query=query)
    alt_query_prompt = ALT_QUERY_TEMPLATE.format(query=query)
    answer_text = ANSWER_TEMPLATE.format(answer=answer)

    passage_results = []

    for pidx, pinfo in enumerate(query_data['passages']):
        passage = pinfo['passage']
        document_text = DOCUMENT_TEMPLATE.format(document=passage)

        # --- Matched tokenization (using sf prefix) ---
        full_text = sf_str + document_text
        full_enc = tokenizer(full_text, return_tensors="pt",
                              add_special_tokens=True, padding=False, truncation=False)
        full_ids = full_enc['input_ids'].to(config.device)

        sf_prefix_enc = tokenizer(sf_str, return_tensors="pt",
                                   add_special_tokens=True, padding=False, truncation=False)
        sf_prefix_len_matched = sf_prefix_enc['input_ids'].shape[1]

        bos_id = full_ids[:, :1]
        doc_ids = full_ids[:, sf_prefix_len_matched:]
        doc_len = doc_ids.shape[1]
        context_len = 1 + doc_len  # BOS + doc

        del full_enc, full_ids, sf_prefix_enc

        # === 1. Build bare cache ===
        bare_input = torch.cat([bos_id, doc_ids], dim=1)
        with torch.no_grad():
            bare_out = model(input_ids=bare_input,
                             attention_mask=torch.ones_like(bare_input),
                             use_cache=True, return_dict=True)
        bare_cache = _ensure_dynamic_cache(bare_out.past_key_values)
        del bare_out, bare_input

        # === 2. Score rescore (deepcopy bare, alt prompt) ===
        bare_copy = deepcopy_cache(bare_cache)
        rescore_nll = score_answer_with_cache(
            bare_copy, context_len, alt_query_prompt, answer_text,
            model, tokenizer, config)
        del bare_copy

        # === 3-5. For each priming prefix: build, truncate, hybrid, score ===
        primed_nlls = {}
        for p_name, p_text, p_str, p_ids in PREFIX_CONFIGS:
            primed_input = torch.cat([bos_id, p_ids, doc_ids], dim=1)
            with torch.no_grad():
                primed_out = model(input_ids=primed_input,
                                   attention_mask=torch.ones_like(primed_input),
                                   use_cache=True, return_dict=True)
            primed_full = _ensure_dynamic_cache(primed_out.past_key_values)
            del primed_out, primed_input

            # Truncate + RoPE correct
            primed_trunc = extract_and_truncate_cache_with_bos(primed_full, doc_len)
            correct_rope_positions_with_bos(primed_trunc, p_ids.shape[1], model)
            del primed_full

            # Hybrid: bare keys + primed values (pure value contamination)
            hybrid = build_hybrid_cache(bare_cache, primed_trunc)
            del primed_trunc

            primed_nlls[p_name] = score_answer_with_cache(
                hybrid, context_len, query_prompt, answer_text,
                model, tokenizer, config)
            del hybrid

        # === 6. Score bare LAST (mutates cache) ===
        bare_nll = score_answer_with_cache(
            bare_cache, context_len, query_prompt, answer_text,
            model, tokenizer, config)
        del bare_cache

        gc.collect()
        torch.cuda.empty_cache()

        passage_results.append({
            'passage_idx': pinfo['passage_idx'],
            'is_relevant': pinfo['is_relevant'],
            'word_count': pinfo['word_count'],
            'bare_nll': bare_nll,
            'rescore_nll': rescore_nll,
            'sf_nll': primed_nlls['sf'],
            'rand_nll': primed_nlls['rand'],
            'intent_nll': primed_nlls['intent'],
        })

    all_results.append({
        'query_idx': qidx,
        'query': query,
        'n_passages': len(passage_results),
        'n_relevant': query_data['n_relevant'],
        'passage_data': passage_results,
    })

    # Checkpoint
    if (qidx + 1) % CHECKPOINT_EVERY == 0 or qidx == N - 1:
        ckpt_data = {
            'results': all_results,
            'query_texts': [q['query'] for q in queries],
            'completed': len(all_results),
            'total': N,
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        }
        with open(CHECKPOINT_PATH, 'w') as f:
            json.dump(ckpt_data, f)
        elapsed = time.time() - t_start
        n_done = qidx - start_idx + 1
        rate = n_done / elapsed if elapsed > 0 else 0
        remaining = (N - qidx - 1) / rate if rate > 0 else 0
        tqdm.write(f"  Checkpoint {qidx+1}/{N} | {n_done} done in {elapsed/60:.1f}m | "
                   f"ETA: {remaining/60:.1f} min")

elapsed_total = time.time() - t_start
print(f"\nEvaluation complete: {len(all_results)} queries in {elapsed_total/60:.1f} min")

MAIN EVALUATION (300 queries, ~2504 passages)
No checkpoint found. Starting fresh.
Evaluating queries 0 to 299
Per passage: 4 forward passes (bare + 3 primed) + 5 scoring passes


Queries:   0%|          | 0/300 [00:00<?, ?it/s]

  Checkpoint 25/300 | 25 done in 7.5m | ETA: 82.5 min


  Checkpoint 50/300 | 50 done in 15.2m | ETA: 75.8 min


  Checkpoint 75/300 | 75 done in 22.9m | ETA: 68.6 min


  Checkpoint 100/300 | 100 done in 30.7m | ETA: 61.5 min


  Checkpoint 125/300 | 125 done in 38.9m | ETA: 54.4 min


  Checkpoint 150/300 | 150 done in 46.4m | ETA: 46.4 min


  Checkpoint 175/300 | 175 done in 54.2m | ETA: 38.7 min


  Checkpoint 200/300 | 200 done in 61.9m | ETA: 30.9 min


  Checkpoint 225/300 | 225 done in 69.7m | ETA: 23.2 min


  Checkpoint 250/300 | 250 done in 77.5m | ETA: 15.5 min


  Checkpoint 275/300 | 275 done in 85.1m | ETA: 7.7 min


  Checkpoint 300/300 | 300 done in 92.5m | ETA: 0.0 min

Evaluation complete: 300 queries in 92.5 min


In [7]:
# Cell 7: Analysis — individual signals, ensembles, significance, scaling
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

print("=" * 70)
print("ANALYSIS")
print("=" * 70)

N_VALID = len(all_results)
print(f"Valid queries: {N_VALID}")

# --- Helper functions ---
def mrr_for_signal(results, sig_name):
    """Compute per-query MRR ranking by a single NLL signal."""
    mrrs = []
    for r in results:
        pd = r['passage_data']
        scores = {i: pd[i][f'{sig_name}_nll'] for i in range(len(pd))}
        rel_idx = next(i for i, p in enumerate(pd) if p['is_relevant'])
        m = compute_ranking_metrics(scores, relevant_idx=rel_idx)
        mrrs.append(m['mrr'])
    return np.array(mrrs)


def mrr_for_ensemble(results, sig_names):
    """Compute per-query MRR ranking by equal-weight NLL average."""
    mrrs = []
    for r in results:
        pd = r['passage_data']
        scores = {}
        for i in range(len(pd)):
            scores[i] = np.mean([pd[i][f'{s}_nll'] for s in sig_names])
        rel_idx = next(i for i, p in enumerate(pd) if p['is_relevant'])
        m = compute_ranking_metrics(scores, relevant_idx=rel_idx)
        mrrs.append(m['mrr'])
    return np.array(mrrs)


def sig_test(mrrs_a, mrrs_b):
    """Wilcoxon signed-rank test, returns (delta, p, sig_str)."""
    delta = float(np.mean(mrrs_a) - np.mean(mrrs_b))
    nonzero = int(np.sum(mrrs_a != mrrs_b))
    if nonzero > 10:
        _, p = wilcoxon(mrrs_a, mrrs_b)
    else:
        p = 1.0
    sig = "***" if p < 0.001 else "**" if p < 0.01 else "*" if p < 0.05 else "ns"
    return delta, float(p), sig, nonzero


# === 1. Individual signal MRR ===
print("\n" + "=" * 70)
print("INDIVIDUAL SIGNAL RANKING")
print("=" * 70)

individual_mrrs = {}
for sig in SIGNAL_NAMES:
    individual_mrrs[sig] = mrr_for_signal(all_results, sig)

bare_mrrs = individual_mrrs['bare']
print(f"\n{'Signal':<12} {'MRR':>8} {'ΔMRR':>8} {'p':>12} {'Sig':>5} {'Changed':>8}")
print("-" * 60)

individual_stats = {}
for sig in SIGNAL_NAMES:
    mrrs = individual_mrrs[sig]
    if sig == 'bare':
        print(f"{sig:<12} {np.mean(mrrs):>8.4f} {'--':>8} {'--':>12} {'--':>5} {'--':>8}")
    else:
        d, p, s_str, n_changed = sig_test(mrrs, bare_mrrs)
        print(f"{sig:<12} {np.mean(mrrs):>8.4f} {d:>+8.4f} {p:>12.3e} {s_str:>5} {n_changed:>8}")
        individual_stats[sig] = {'delta_mrr': d, 'p_value': p, 'significant': bool(p < 0.05)}


# === 2. Ensemble MRR ===
print("\n" + "=" * 70)
print("ENSEMBLE RANKING (EQUAL-WEIGHT NLL AVERAGE)")
print("=" * 70)

ENSEMBLE_CONFIGS = {
    'ens_2_sf':      ['bare', 'sf'],
    'ens_2_rand':    ['bare', 'rand'],
    'ens_2_intent':  ['bare', 'intent'],
    'ens_2_rescore': ['bare', 'rescore'],
    'ens_3':         ['bare', 'sf', 'rand'],
    'ens_4':         ['bare', 'sf', 'rand', 'intent'],
    'ens_5_all':     ['bare', 'sf', 'rand', 'intent', 'rescore'],
}

ensemble_mrrs = {}
ensemble_stats = {}
print(f"\n{'Ensemble':<20} {'Members':>4} {'MRR':>8} {'ΔMRR':>8} {'p':>12} {'Sig':>5} {'Changed':>8}")
print("-" * 72)

for ens_name, members in ENSEMBLE_CONFIGS.items():
    mrrs = mrr_for_ensemble(all_results, members)
    ensemble_mrrs[ens_name] = mrrs
    d, p, s_str, n_changed = sig_test(mrrs, bare_mrrs)
    print(f"{ens_name:<20} {len(members):>4} {np.mean(mrrs):>8.4f} {d:>+8.4f} "
          f"{p:>12.3e} {s_str:>5} {n_changed:>8}")
    ensemble_stats[ens_name] = {
        'members': members,
        'mrr_mean': float(np.mean(mrrs)),
        'delta_mrr': d,
        'p_value': p,
        'significant': bool(p < 0.05),
        'n_changed': n_changed,
    }


# === 3. Critical comparison: priming vs non-priming control ===
print("\n" + "=" * 70)
print("CRITICAL: PRIMING vs NON-PRIMING CONTROL")
print("=" * 70)

sf_mrr = float(np.mean(ensemble_mrrs['ens_2_sf']))
rescore_mrr = float(np.mean(ensemble_mrrs['ens_2_rescore']))
d_sf_res, p_sf_res, s_sf_res, n_sf_res = sig_test(
    ensemble_mrrs['ens_2_sf'], ensemble_mrrs['ens_2_rescore'])

print(f"  ens_2_sf (priming):      MRR={sf_mrr:.4f}")
print(f"  ens_2_rescore (control): MRR={rescore_mrr:.4f}")
print(f"  Difference:              {d_sf_res:+.4f}  (p={p_sf_res:.3e}, {s_sf_res})")
if sf_mrr > rescore_mrr + 0.001:
    print("  => Priming adds value BEYOND prompt diversity")
elif rescore_mrr > sf_mrr + 0.001:
    print("  => Prompt diversity alone BEATS priming")
else:
    print("  => Priming and prompt diversity are equivalent")


# === 4. Greedy forward selection (scaling curve) ===
print("\n" + "=" * 70)
print("GREEDY SCALING CURVE: best member to add at each step")
print("=" * 70)

available = ['rescore', 'sf', 'rand', 'intent']
selected = ['bare']
greedy_results = [{'members': list(selected), 'mrr': float(np.mean(bare_mrrs))}]

for step in range(len(available)):
    best_next = None
    best_mrr = -1
    for candidate in available:
        trial = selected + [candidate]
        trial_mrrs = mrr_for_ensemble(all_results, trial)
        mean_mrr = float(np.mean(trial_mrrs))
        if mean_mrr > best_mrr:
            best_mrr = mean_mrr
            best_next = candidate
    selected.append(best_next)
    available.remove(best_next)
    greedy_results.append({'members': list(selected), 'mrr': best_mrr})

print(f"\n{'K':<4} {'Added':>10} {'Ensemble':<35} {'MRR':>8} {'ΔMRR':>8}")
print("-" * 70)
for i, gr in enumerate(greedy_results):
    added = gr['members'][-1] if i > 0 else '--'
    delta = gr['mrr'] - greedy_results[0]['mrr']
    members_str = '+'.join(gr['members'])
    print(f"{i+1:<4} {added:>10} {members_str:<35} {gr['mrr']:>8.4f} {delta:>+8.4f}")


# === 5. NLL correlation matrix ===
print("\n" + "=" * 70)
print("NLL CORRELATION MATRIX (Pearson, across all passages)")
print("=" * 70)

all_nlls = {sig: [] for sig in SIGNAL_NAMES}
for r in all_results:
    for p in r['passage_data']:
        for sig in SIGNAL_NAMES:
            all_nlls[sig].append(p[f'{sig}_nll'])

all_nlls = {sig: np.array(vals) for sig, vals in all_nlls.items()}

print(f"\n{'':>12}", end='')
for sig in SIGNAL_NAMES:
    print(f" {sig:>10}", end='')
print()

corr_matrix = {}
for sig_a in SIGNAL_NAMES:
    print(f"{sig_a:>12}", end='')
    for sig_b in SIGNAL_NAMES:
        r, _ = stats.pearsonr(all_nlls[sig_a], all_nlls[sig_b])
        corr_matrix[f'{sig_a}_{sig_b}'] = float(r)
        print(f" {r:>10.4f}", end='')
    print()

print("\nNote: Lower correlation = more diversity = better ensembles")

ANALYSIS
Valid queries: 300

INDIVIDUAL SIGNAL RANKING

Signal            MRR     ΔMRR            p   Sig  Changed
------------------------------------------------------------
bare           0.8011       --           --    --       --
rescore        0.7900  -0.0111    1.916e-01    ns       40
sf             0.8058  +0.0047    5.795e-01    ns       39
rand           0.8034  +0.0023    7.056e-01    ns       42
intent         0.7934  -0.0077    2.875e-01    ns       43

ENSEMBLE RANKING (EQUAL-WEIGHT NLL AVERAGE)

Ensemble             Members      MRR     ΔMRR            p   Sig  Changed
------------------------------------------------------------------------
ens_2_sf                2   0.8060  +0.0049    3.629e-01    ns       21
ens_2_rand              2   0.8077  +0.0066    1.599e-01    ns       20
ens_2_intent            2   0.8013  +0.0002    9.424e-01    ns       27
ens_2_rescore           2   0.7911  -0.0101    1.057e-01    ns       25
ens_3                   3   0.8079  +0.0068    


K         Added Ensemble                                 MRR     ΔMRR
----------------------------------------------------------------------
1            -- bare                                  0.8011  +0.0000
2          rand bare+rand                             0.8077  +0.0066
3            sf bare+rand+sf                          0.8079  +0.0068
4       rescore bare+rand+sf+rescore                  0.8045  +0.0034
5        intent bare+rand+sf+rescore+intent           0.8056  +0.0044

NLL CORRELATION MATRIX (Pearson, across all passages)

                   bare    rescore         sf       rand     intent
        bare     1.0000     0.9821     0.9898     0.9891     0.9828
     rescore     0.9821     1.0000     0.9724     0.9763     0.9785
          sf     0.9898     0.9724     1.0000     0.9904     0.9872
        rand     0.9891     0.9763     0.9904     1.0000     0.9876
      intent     0.9828     0.9785     0.9872     0.9876     1.0000

Note: Lower correlation = more diversity = 

In [8]:
# Cell 8: Plots (4-panel figure)

fig, axes = plt.subplots(2, 2, figsize=(16, 12))

colors_ens = {
    'ens_2_sf': '#d62728',
    'ens_2_rand': '#ff7f0e',
    'ens_2_intent': '#9467bd',
    'ens_2_rescore': '#2ca02c',
    'ens_3': '#1f77b4',
    'ens_4': '#e377c2',
    'ens_5_all': '#17becf',
}

# --- Plot 1: Ensemble MRR bar chart ---
ax = axes[0, 0]
names = ['bare'] + list(ENSEMBLE_CONFIGS.keys())
mrr_vals = [float(np.mean(bare_mrrs))] + [float(np.mean(ensemble_mrrs[e])) for e in ENSEMBLE_CONFIGS]
bar_colors = ['#7f7f7f'] + [colors_ens.get(e, '#333') for e in ENSEMBLE_CONFIGS]
bars = ax.bar(range(len(names)), mrr_vals, color=bar_colors, edgecolor='black', linewidth=0.5)
for i, (n, m) in enumerate(zip(names, mrr_vals)):
    ax.text(i, m + 0.002, f"{m:.4f}", ha='center', fontsize=7, rotation=45)
ax.set_xticks(range(len(names)))
ax.set_xticklabels(names, rotation=45, ha='right', fontsize=7)
ax.set_ylabel("MRR")
ax.set_title("MRR by Condition")
ax.axhline(y=float(np.mean(bare_mrrs)), color='gray', linestyle='--', alpha=0.5, label='bare')
ax.legend(fontsize=8)

# --- Plot 2: Scaling curve ---
ax = axes[0, 1]
k_vals = list(range(1, len(greedy_results) + 1))
mrr_curve = [gr['mrr'] for gr in greedy_results]
ax.plot(k_vals, mrr_curve, 'o-', color='#1f77b4', linewidth=2, markersize=8)
for i, gr in enumerate(greedy_results):
    label = gr['members'][-1] if i > 0 else 'bare'
    ax.annotate(f"+{label}" if i > 0 else label,
                (k_vals[i], mrr_curve[i]),
                textcoords="offset points", xytext=(5, 8), fontsize=8)
ax.axhline(y=float(np.mean(bare_mrrs)), color='gray', linestyle='--', alpha=0.5)
ax.set_xlabel("Ensemble Size K")
ax.set_ylabel("MRR")
ax.set_title("Greedy Scaling Curve")
ax.set_xticks(k_vals)

# --- Plot 3: Correlation heatmap ---
ax = axes[1, 0]
corr_data = np.zeros((len(SIGNAL_NAMES), len(SIGNAL_NAMES)))
for i, sa in enumerate(SIGNAL_NAMES):
    for j, sb in enumerate(SIGNAL_NAMES):
        corr_data[i, j] = corr_matrix[f'{sa}_{sb}']
im = ax.imshow(corr_data, vmin=0.9, vmax=1.0, cmap='YlOrRd', aspect='auto')
ax.set_xticks(range(len(SIGNAL_NAMES)))
ax.set_xticklabels(SIGNAL_NAMES, fontsize=8)
ax.set_yticks(range(len(SIGNAL_NAMES)))
ax.set_yticklabels(SIGNAL_NAMES, fontsize=8)
for i in range(len(SIGNAL_NAMES)):
    for j in range(len(SIGNAL_NAMES)):
        ax.text(j, i, f"{corr_data[i,j]:.3f}", ha='center', va='center', fontsize=7)
plt.colorbar(im, ax=ax, label="Pearson r")
ax.set_title("NLL Correlation Matrix")

# --- Plot 4: Per-query ΔMRR distributions ---
ax = axes[1, 1]
for ens_name in ['ens_2_sf', 'ens_2_rescore', 'ens_4']:
    deltas = ensemble_mrrs[ens_name] - bare_mrrs
    ax.hist(deltas, bins=30, alpha=0.5, label=ens_name,
            color=colors_ens.get(ens_name, 'gray'))
ax.axvline(x=0, color='gray', linestyle='--', alpha=0.5)
ax.set_xlabel("ΔMRR (ensemble - bare)")
ax.set_ylabel("Count")
ax.set_title("Per-Query MRR Change Distribution")
ax.legend(fontsize=8)

plt.suptitle('Exp 15: NLL Ensemble Ranking', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig(RESULTS_DIR / 'analysis_plots.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"Plots saved to {RESULTS_DIR / 'analysis_plots.png'}")

Plots saved to results/exp15/analysis_plots.png


In [9]:
# Cell 9: Save results JSON
final = {
    'experiment': 'exp15_nll_ensemble_ranking',
    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    'config': {
        'model_name': config.model_name,
        'seed': SEED,
        'n_queries': N,
        'n_valid': N_VALID,
        'max_passage_words': MAX_PASSAGE_WORDS,
        'min_passages_per_query': MIN_PASSAGES_PER_QUERY,
        'dataset': 'MS MARCO v1.1 validation',
        'prefixes': {
            'sf': STATIC_FACT,
            'rand': RANDOM_PREFIX_TEXT,
            'intent': INTENT_PREFIX_TEXT,
        },
        'alt_query_template': ALT_QUERY_TEMPLATE,
    },
    'signal_names': SIGNAL_NAMES,
    'individual_mrrs': {sig: float(np.mean(individual_mrrs[sig])) for sig in SIGNAL_NAMES},
    'individual_stats': individual_stats,
    'ensemble_configs': {k: v for k, v in ENSEMBLE_CONFIGS.items()},
    'ensemble_stats': ensemble_stats,
    'priming_vs_control': {
        'ens_2_sf_mrr': sf_mrr,
        'ens_2_rescore_mrr': rescore_mrr,
        'difference': float(d_sf_res),
        'p_value': float(p_sf_res),
        'priming_is_special': bool(sf_mrr > rescore_mrr + 0.001),
    },
    'greedy_scaling': greedy_results,
    'correlation_matrix': corr_matrix,
    'per_query_results': [
        {k: v for k, v in r.items() if k != 'passage_data'}
        for r in all_results
    ],
}

with open(FINAL_RESULTS_PATH, 'w') as f:
    json.dump(final, f, indent=2)

print(f"Results saved to {FINAL_RESULTS_PATH}")
print(f"File size: {FINAL_RESULTS_PATH.stat().st_size / 1024:.1f} KB")

# Print summary
print("\n" + "=" * 70)
print("SUMMARY")
print("=" * 70)
print(f"Bare MRR:           {float(np.mean(bare_mrrs)):.4f}")
best_ens = max(ensemble_stats.items(), key=lambda x: x[1]['mrr_mean'])
print(f"Best ensemble:      {best_ens[0]} (MRR={best_ens[1]['mrr_mean']:.4f}, "
      f"ΔMRR={best_ens[1]['delta_mrr']:+.4f}, p={best_ens[1]['p_value']:.3e})")
print(f"Priming vs control: {d_sf_res:+.4f} (p={p_sf_res:.3e})")
print(f"Scaling saturates:  K={len(greedy_results[-1]['members'])} members, "
      f"MRR={greedy_results[-1]['mrr']:.4f}")
print("\nDone!")

Results saved to results/exp15/results.json
File size: 44.9 KB

SUMMARY
Bare MRR:           0.8011
Best ensemble:      ens_3 (MRR=0.8079, ΔMRR=+0.0068, p=2.323e-01)
Priming vs control: +0.0149 (p=3.890e-02)
Scaling saturates:  K=5 members, MRR=0.8056

Done!
