# Experiment 31: Query-Likelihood Ranking with PMI

## Motivation

Every ranking experiment so far (Exps 14, 15, 22, 23, 28) scored **NLL(answer | passage + query)**.
But in ad serving, there IS no answer at query time — there's only the user's query and the ads.

**Query-likelihood** — NLL(query | passage) — is a classic IR scoring function that measures
how well a document predicts the query. This has never been tested in our framework.

## Design

Model: Gemma 3 4B (4-bit). Data: MS MARCO v1.1 validation, 200 queries × ~8 passages each.
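All caches are **bare** (no priming), so this experiment purely compares scoring methods. Every score below is an NLL or a difference of NLLs, so **lower = more relevant**; the "PMI" variants subtract the passage-free (BOS-only) baseline NLL and therefore have the opposite sign to the usual PMI log-ratio.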

| # | Method | Score | PMI baseline |
|---|--------|-------|--------------|
| 1 | `al` | NLL(answer \| passage + query_template) | — |
| 2 | `pmi_al` | al − NLL(answer \| BOS + query_template) | BOS-only |
| 3 | `ql` | NLL(query \| passage + "\n") | — |
| 4 | `pmi_ql` | ql − NLL(query \| BOS + "\n") | BOS-only |
| 5 | `ql_search` | NLL(query \| passage + "\nSearch query: ") | — |
| 6 | `pmi_ql_search` | ql_search − NLL(query \| BOS + "\nSearch query: ") | BOS-only |

## Success Criteria

- **Primary:** Does any query-likelihood method achieve AUC > 0.80? (For reference, bare answer-likelihood reached AUC = 0.828 in Exp 22.)
- **Secondary:** Does PMI improve query-likelihood as it does for answer-likelihood?
- **Informational:** How does query-likelihood correlate with answer-likelihood?

## Reference Values (from Exp 22, Gemma 3 4B)

| Method | AUC | MRR@10 |
|--------|-----|--------|
| Raw bare NLL (answer) | 0.828 | 0.860 |
| PMI bare (answer) | 0.841 | 0.860 |

In [None]:
# Cell 1: Setup & Imports
import os
os.umask(0o000)

import sys
import json
import time
import csv
import numpy as np
import torch
import gc
from pathlib import Path
from scipy import stats
from sklearn.metrics import roc_auc_score, roc_curve
from datasets import load_dataset
from tqdm.auto import tqdm

sys.path.insert(0, ".")

from lib.config import ExperimentConfig
from lib.model_utils import load_model
from lib.kv_cache import (
    _ensure_dynamic_cache,
    score_answer_with_cache,
    deepcopy_cache,
)
from lib.data import count_words
from lib.analysis import cohens_d

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# === Constants ===
SEED = 42
N_QUERIES = 200            # number of queries kept after filtering + shuffling
MAX_PASSAGE_WORDS = 300    # a query is dropped if any of its passages exceeds this
CHECKPOINT_EVERY = 10      # evaluation loop writes a checkpoint every N queries

# Every artefact of this experiment (checkpoint, CSV, JSON, plots) goes here.
RESULTS_DIR = Path("results") / "exp31"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
CHECKPOINT_PATH = RESULTS_DIR / "checkpoint_eval.json"

# Scoring templates
QUERY_TEMPLATE = "\nQuery: {query}\nAnswer:"   # prompt placed before the answer
ANSWER_TEMPLATE = " {answer}"                  # answer text that gets scored
QL_NEWLINE_SEP = "\n"                          # minimal separator for query-likelihood
QL_SEARCH_SEP = "\nSearch query: "             # search-framed query-likelihood

# Reference values from Exp 22 (Gemma 3 4B)
REF = dict(
    raw_bare_auc=0.828,
    pmi_bare_auc=0.841,
    raw_bare_mrr=0.860,
)

# Seed both RNGs used in this notebook.
torch.manual_seed(SEED)
np.random.seed(SEED)

print(f"Results dir: {RESULTS_DIR}")
cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Memory: {total_gb:.1f} GB")

In [None]:
# Cell 2: Load Model
# Settings handed to the project's ExperimentConfig; the same seed as the
# rest of the notebook is forwarded.
model_settings = {
    "model_name": "google/gemma-3-4b-it",
    "model_type": "gemma3",
    "use_4bit": True,
    "compute_dtype": "bfloat16",
    "seed": SEED,
}
exp_config = ExperimentConfig(**model_settings)

print("Loading Gemma 3 4B...")
model, tokenizer = load_model(exp_config)
print(f"Model loaded. dtype={model.dtype}")

# Report how much GPU memory is allocated once the model has been loaded.
allocated_gb = torch.cuda.memory_allocated() / 1e9
print(f"GPU memory: {allocated_gb:.2f} GB")

In [None]:
# Cell 3: Load MS MARCO Multi-Passage Validation Data
#
# Builds `queries`: a list of dicts (query, answer, passages, n_passages,
# n_relevant). Up to 3 * N_QUERIES items that pass the filters are collected
# in dataset order, shuffled with a fixed seed, and truncated to N_QUERIES.
print("Loading MS MARCO v1.1 validation (multi-passage format)...")
dataset = load_dataset("microsoft/ms_marco", "v1.1", split="validation",
                        trust_remote_code=True)
print(f"Total items: {len(dataset)}")

queries = []
# Re-seed here so the shuffle below gives the same order when the cell is re-run.
np.random.seed(SEED)

for item in tqdm(dataset, desc="Filtering"):
    passages_info = item.get('passages', {})
    passage_texts = passages_info.get('passage_text', [])
    is_selected = passages_info.get('is_selected', [])
    query = item.get('query', '')
    answers = item.get('answers', [])
    well_formed = item.get('wellFormedAnswers', [])

    if not passage_texts or not query:
        continue
    # Require at least 1 relevant passage
    if not is_selected or sum(is_selected) == 0:
        continue
    # All passages must be within word limits
    word_counts = [count_words(p) for p in passage_texts]
    if any(wc > MAX_PASSAGE_WORDS for wc in word_counts):
        continue

    # Require valid answer (needed for answer-likelihood reference).
    # Prefer the well-formed answer; fall back to the first raw answer unless
    # it is the "no answer" placeholder.
    if well_formed and well_formed[0] != '[]':
        answer = well_formed[0]
    elif answers and answers[0] != 'No Answer Present.':
        answer = answers[0]
    else:
        continue

    passage_list = [
        {
            'passage': ptext,
            'is_relevant': bool(sel == 1),
            'word_count': word_counts[i],
            'passage_idx': i,
        }
        for i, (ptext, sel) in enumerate(zip(passage_texts, is_selected))
    ]

    queries.append({
        'query': query,
        'answer': answer,
        'passages': passage_list,
        'n_passages': len(passage_list),
        'n_relevant': sum(1 for p in passage_list if p['is_relevant']),
    })

    # Stop early once the candidate pool (3x the target) is full.
    if len(queries) >= N_QUERIES * 3:
        break

np.random.shuffle(queries)
queries = queries[:N_QUERIES]

# Fail with a clear message instead of a ZeroDivisionError in the stats below.
if not queries:
    raise RuntimeError(
        "No MS MARCO items passed the filters "
        f"(MAX_PASSAGE_WORDS={MAX_PASSAGE_WORDS}); cannot build the query set."
    )
# Later cells index queries[0 .. N_QUERIES-1]; flag a short query set now.
if len(queries) < N_QUERIES:
    print(f"WARNING: only {len(queries)} queries passed the filters "
          f"(N_QUERIES={N_QUERIES}); lower N_QUERIES before running the evaluation.")

# Stats
total_passages = sum(q['n_passages'] for q in queries)
total_relevant = sum(q['n_relevant'] for q in queries)
print(f"\nLoaded {len(queries)} queries")
print(f"Total passages: {total_passages} (mean {total_passages/len(queries):.1f}/query)")
print(f"Relevant: {total_relevant} ({100*total_relevant/total_passages:.1f}%)")
print(f"Query lengths (words): mean={np.mean([count_words(q['query']) for q in queries]):.1f}")
print(f"Answer lengths (words): mean={np.mean([count_words(q['answer']) for q in queries]):.1f}")

In [None]:
# Cell 4: Explain Experimental Conditions
# Purely informational: prints what each scoring condition conditions on,
# illustrated with the first query of the shuffled set.
RULE = "=" * 70

print(RULE, "EXPERIMENTAL CONDITIONS EXPLAINED", RULE, sep="\n")

example = queries[0]
ex_query = example['query']
ex_answer = example['answer']
ex_passage = example['passages'][0]['passage'][:100] + "..."

print(f"\nExample query:   {ex_query}")
print(f"Example answer:  {ex_answer[:80]}...")
print(f"Example passage: {ex_passage}")

# --- Answer-likelihood ---
answer_preview = ANSWER_TEMPLATE.format(answer=ex_answer[:40])
print("\n" + RULE, "ANSWER-LIKELIHOOD (standard — reference from Exp 22)", RULE, sep="\n")
print("Cache: [BOS][passage tokens]")
print(f'Score: NLL( "{answer_preview}..."')
print(f'       | passage + "\\nQuery: {ex_query}\\nAnswer:" )')
print("PMI baseline: NLL(answer | BOS + query_template)  [no passage]")
print("Key insight: Measures how much the passage helps predict the ANSWER.")

# --- Query-likelihood, newline separator ---
print("\n" + RULE, "QUERY-LIKELIHOOD with newline separator (NEW)", RULE, sep="\n")
print("Cache: [BOS][passage tokens]")
print(f'Score: NLL( "{ex_query}" | passage + "\\n" )')
print('PMI baseline: NLL(query | BOS + "\\n")  [no passage]')
print("Key insight: Measures how much the passage helps predict the QUERY.")
print("  If passage is about running shoes and query is 'best running shoes',")
print("  the passage should lower NLL(query) = more relevant.")

# --- Query-likelihood, search frame ---
print("\n" + RULE, "QUERY-LIKELIHOOD with search frame (NEW)", RULE, sep="\n")
print("Cache: [BOS][passage tokens]")
print(f'Score: NLL( "{ex_query}" | passage + "\\nSearch query: " )')
print('PMI baseline: NLL(query | BOS + "\\nSearch query: ")  [no passage]')
print("Key insight: Same as above but with explicit search framing.")
print("  The frame 'Search query: ' may tell the model to predict search terms.")

# --- Cost summary ---
print("\n" + RULE)
print("TOTAL: 6 scoring methods (3 raw + 3 PMI)")
print("Per passage: 1 cache build + 3 scoring calls")
print("Per query: 3 baseline scoring calls")
print(f"Total: {N_QUERIES} queries × ~8 passages = ~{N_QUERIES * 8} passage scores")
print(RULE)

In [None]:
# Cell 5: Helper Functions

def build_bare_cache(passage_text, model, tokenizer, config):
    """Run one forward pass over a passage and keep its KV cache.

    The passage is tokenized on its own (no prefix text) with the tokenizer's
    special tokens enabled, moved to ``config.device`` and fed through the
    model without gradient tracking.

    Returns:
        (cache, context_len): the cache wrapped by ``_ensure_dynamic_cache``
        and the number of token positions that were encoded.
    """
    encoded = tokenizer(
        passage_text,
        return_tensors="pt",
        add_special_tokens=True,
        padding=False,
        truncation=False,
    )
    token_ids = encoded['input_ids'].to(config.device)
    full_mask = torch.ones_like(token_ids)  # single unpadded sequence: attend everywhere
    n_positions = token_ids.shape[1]

    with torch.no_grad():
        forward = model(
            input_ids=token_ids,
            attention_mask=full_mask,
            use_cache=True,
            return_dict=True,
        )

    passage_cache = _ensure_dynamic_cache(forward.past_key_values)
    # Drop local references to the forward-pass tensors before returning.
    del forward, full_mask, token_ids, encoded
    return passage_cache, n_positions


def build_bos_cache(model, tokenizer, config):
    """Build the passage-free cache used for the PMI baselines.

    The model is run on a single token, ``tokenizer.bos_token_id``, placed on
    ``config.device``.

    Returns:
        (cache, context_len): the resulting cache and a context length of 1.
    """
    bos_len = 1
    bos_only = torch.tensor([[tokenizer.bos_token_id]], device=config.device)

    with torch.no_grad():
        forward = model(
            input_ids=bos_only,
            attention_mask=torch.ones_like(bos_only),
            use_cache=True,
            return_dict=True,
        )

    bos_cache = _ensure_dynamic_cache(forward.past_key_values)
    del forward
    return bos_cache, bos_len


def score_all_methods(cache, context_len, query, answer, model, tokenizer, config):
    """Score one passage cache under the three raw scoring methods.

    Each method is a (prompt, target) pair passed to
    ``score_answer_with_cache``. Every call gets its own deep copy of
    ``cache`` (scoring mutates the cache it is given), so the caller's cache
    can be reused afterwards.

    Returns:
        dict with keys ``nll_al`` (answer-likelihood), ``nll_ql``
        (query-likelihood, newline separator) and ``nll_ql_search``
        (query-likelihood, search frame).
    """
    def _score(prompt, target):
        # Fresh copy per call; see docstring.
        return score_answer_with_cache(
            deepcopy_cache(cache), context_len,
            prompt, target, model, tokenizer, config)

    # (result key, text appended after the passage, text whose NLL is measured)
    jobs = (
        ('nll_al', QUERY_TEMPLATE.format(query=query), ANSWER_TEMPLATE.format(answer=answer)),
        ('nll_ql', QL_NEWLINE_SEP, query),
        ('nll_ql_search', QL_SEARCH_SEP, query),
    )
    return {key: _score(prompt, target) for key, prompt, target in jobs}


def score_baselines(query, answer, model, tokenizer, config):
    """Compute the passage-free baseline NLLs for one query.

    A fresh BOS-only cache is built, then the same three (prompt, target)
    pairs used for passage scoring are scored against it. These values are
    what the PMI scores subtract, and they only depend on the query/answer,
    so the caller computes them once per query.

    Returns:
        dict with keys ``bl_al``, ``bl_ql`` and ``bl_ql_search``.
    """
    bos_cache, bos_len = build_bos_cache(model, tokenizer, config)

    # (result key, prompt text, text whose NLL is measured)
    jobs = [
        ('bl_al', QUERY_TEMPLATE.format(query=query), ANSWER_TEMPLATE.format(answer=answer)),
        ('bl_ql', QL_NEWLINE_SEP, query),
        ('bl_ql_search', QL_SEARCH_SEP, query),
    ]
    final_pos = len(jobs) - 1

    baselines = {}
    for pos, (key, prompt, target) in enumerate(jobs):
        # The last call may consume the original cache; earlier ones need a copy.
        call_cache = bos_cache if pos == final_pos else deepcopy_cache(bos_cache)
        baselines[key] = score_answer_with_cache(
            call_cache, bos_len,
            prompt, target, model, tokenizer, config)
    return baselines


# === Smoke Test ===
# End-to-end check of the helpers on a single (query, passage) pair before
# the long evaluation loop is started.
print("Smoke test on first query, first passage...")
q0 = queries[0]
p0 = q0['passages'][0]

cache_test, ctx_test = build_bare_cache(p0['passage'], model, tokenizer, exp_config)
scores_test = score_all_methods(
    cache_test, ctx_test, q0['query'], q0['answer'],
    model, tokenizer, exp_config)
baselines_test = score_baselines(q0['query'], q0['answer'], model, tokenizer, exp_config)
del cache_test

print(f"  Passage ({p0['word_count']} words, relevant={p0['is_relevant']}):")
# (row label, key suffix shared by the nll_* and bl_* entries)
for row_label, suffix in (
    ("    Answer-likelihood:  ", 'al'),
    ("    Query-likelihood:   ", 'ql'),
    ("    Query-lik (search): ", 'ql_search'),
):
    raw_nll = scores_test['nll_' + suffix]
    base_nll = baselines_test['bl_' + suffix]
    print(f"{row_label}NLL={raw_nll:.4f}  (baseline={base_nll:.4f})")
for suffix in ('al', 'ql', 'ql_search'):
    delta = scores_test['nll_' + suffix] - baselines_test['bl_' + suffix]
    print(f"  PMI_{suffix} = {delta:.4f}")
print("Smoke test passed!")

gc.collect()
torch.cuda.empty_cache()

In [None]:
# Cell 6: Main Evaluation Loop
#
# For every query: compute the three BOS-only baselines once, then for each
# candidate passage build a bare cache and score it with all three methods.
# Results accumulate in `all_results` and are checkpointed to disk so an
# interrupted run can resume.
print(f"\nEvaluating {N_QUERIES} queries...")
print(f"Checkpoint every {CHECKPOINT_EVERY} queries to {CHECKPOINT_PATH}")

all_results = []
start_idx = 0

# The ordered list of query texts identifies this run's query set. It is
# stored in every checkpoint and compared on resume; it never changes during
# the loop, so build it once.
current_queries = [q['query'] for q in queries]

# Resume from checkpoint
if CHECKPOINT_PATH.exists():
    with open(CHECKPOINT_PATH, 'r') as f:
        ckpt = json.load(f)
    ckpt_queries = ckpt.get('query_texts', [])
    if ckpt_queries == current_queries:
        all_results = ckpt['results']
        start_idx = len(all_results)
        print(f"Resumed from checkpoint: {start_idx}/{N_QUERIES}")
    else:
        print("Checkpoint query mismatch — starting fresh.")

# The checkpoint is written to a sibling temp file and then renamed over the
# real one, so an interrupt during json.dump cannot leave a truncated
# checkpoint that would break the next resume.
ckpt_tmp_path = CHECKPOINT_PATH.with_name(CHECKPOINT_PATH.name + '.tmp')

t_start = time.time()

for qidx in tqdm(range(start_idx, N_QUERIES), initial=start_idx, total=N_QUERIES,
                  desc="Queries"):
    q = queries[qidx]
    query_text = q['query']
    answer_text = q['answer']

    # Compute baselines (once per query)
    baselines = score_baselines(query_text, answer_text, model, tokenizer, exp_config)

    passage_data = []
    for pidx, p in enumerate(q['passages']):
        # Build bare cache
        cache, ctx_len = build_bare_cache(p['passage'], model, tokenizer, exp_config)

        # Score all methods
        scores = score_all_methods(cache, ctx_len, query_text, answer_text,
                                    model, tokenizer, exp_config)
        del cache

        # Baselines are repeated on every passage row so that the analysis
        # and CSV cells can compute PMI per passage without a join.
        passage_data.append({
            'passage_idx': p['passage_idx'],
            'is_relevant': p['is_relevant'],
            'word_count': p['word_count'],
            'doc_len': ctx_len - 1,  # exclude BOS
            'nll_al': scores['nll_al'],
            'nll_ql': scores['nll_ql'],
            'nll_ql_search': scores['nll_ql_search'],
            'bl_al': baselines['bl_al'],
            'bl_ql': baselines['bl_ql'],
            'bl_ql_search': baselines['bl_ql_search'],
        })

    all_results.append({
        'query_idx': qidx,
        'query': query_text,
        'answer': answer_text,
        'n_passages': q['n_passages'],
        'n_relevant': q['n_relevant'],
        'baselines': baselines,
        'passage_data': passage_data,
    })

    # Checkpoint (every CHECKPOINT_EVERY queries and after the final query)
    if (qidx + 1) % CHECKPOINT_EVERY == 0 or qidx == N_QUERIES - 1:
        ckpt_data = {
            'results': all_results,
            'query_texts': current_queries,
            'completed': len(all_results),
            'total': N_QUERIES,
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        }
        with open(ckpt_tmp_path, 'w') as f:
            json.dump(ckpt_data, f)
        os.replace(ckpt_tmp_path, CHECKPOINT_PATH)

    # Periodic cleanup
    if (qidx + 1) % 20 == 0:
        gc.collect()
        torch.cuda.empty_cache()

elapsed = time.time() - t_start
# `elapsed` only covers queries scored in this session, so the speed must not
# be averaged over results that were restored from a checkpoint.
n_scored_now = len(all_results) - start_idx
print(f"\nDone! {len(all_results)} queries scored "
      f"({n_scored_now} this session) in {elapsed/60:.1f} min")
if n_scored_now > 0:
    print(f"Speed: {elapsed/n_scored_now:.1f} s/query")

In [None]:
# Cell 7: Analysis — AUC, MRR, Differential NLL

# === Flatten passage-level data ===
# One entry per (query, passage) pair, in evaluation order.
flat_passages = [p for r in all_results for p in r['passage_data']]

is_relevant_all = [int(p['is_relevant']) for p in flat_passages]
nll_al_all = [p['nll_al'] for p in flat_passages]
nll_ql_all = [p['nll_ql'] for p in flat_passages]
nll_ql_search_all = [p['nll_ql_search'] for p in flat_passages]
bl_al_all = [p['bl_al'] for p in flat_passages]
bl_ql_all = [p['bl_ql'] for p in flat_passages]
bl_ql_search_all = [p['bl_ql_search'] for p in flat_passages]

is_relevant = np.array(is_relevant_all)
nll_al, nll_ql, nll_ql_search = (
    np.array(nll_al_all), np.array(nll_ql_all), np.array(nll_ql_search_all))
bl_al, bl_ql, bl_ql_search = (
    np.array(bl_al_all), np.array(bl_ql_all), np.array(bl_ql_search_all))

# PMI scores: passage-conditioned NLL minus the passage-free baseline NLL.
pmi_al = nll_al - bl_al
pmi_ql = nll_ql - bl_ql
pmi_ql_search = nll_ql_search - bl_ql_search

n_total = len(is_relevant)
n_rel = is_relevant.sum()
n_irr = n_total - n_rel
print(f"Total passages: {n_total} ({n_rel} relevant, {n_irr} irrelevant)")

# === AUC-ROC ===
# Insertion order matters: the plotting cell pairs these with short labels.
scoring_methods = dict(zip(
    ['Raw AL (answer-likelihood)', 'PMI AL',
     'Raw QL (query-likelihood)', 'PMI QL',
     'Raw QL-search', 'PMI QL-search'],
    [nll_al, pmi_al, nll_ql, pmi_ql, nll_ql_search, pmi_ql_search],
))

# Only the two answer-likelihood methods have an Exp 22 reference AUC.
auc_reference = {
    'Raw AL (answer-likelihood)': REF['raw_bare_auc'],
    'PMI AL': REF['pmi_bare_auc'],
}

print("\n" + "=" * 70)
print("AUC-ROC RESULTS")
print("=" * 70)
print(f"{'Method':<35} {'AUC':>8} {'vs Exp22 ref':>15}")
print("-" * 60)

auc_results = {}
for name, scores in scoring_methods.items():
    # negate: lower NLL = more relevant, but roc_auc_score wants higher = positive
    auc = roc_auc_score(is_relevant, -scores)
    auc_results[name] = float(auc)
    ref_str = f"(ref: {auc_reference[name]:.3f})" if name in auc_reference else ""
    print(f"{name:<35} {auc:>8.3f} {ref_str:>15}")

# === MRR@10 ===
def compute_mrr_at_k(all_results, score_fn, k=10):
    """Mean reciprocal rank of the first relevant passage within the top ``k``.

    Args:
        all_results: iterable of per-query dicts, each with a
            ``'passage_data'`` list of passage dicts carrying ``'is_relevant'``.
        score_fn: maps a passage dict to its score; lower means more relevant.
        k: rank cut-off; a query with no relevant passage in the top ``k``
            contributes 0.

    Returns:
        (mean reciprocal rank, list of per-query reciprocal ranks).
    """
    def _reciprocal_rank(passages):
        # Ascending, stable sort: ties keep their original passage order.
        ranked = sorted(
            ((score_fn(p), p['is_relevant']) for p in passages),
            key=lambda pair: pair[0],
        )
        return next(
            (1.0 / pos for pos, (_, rel) in enumerate(ranked[:k], 1) if rel),
            0.0,
        )

    rr_list = [_reciprocal_rank(r['passage_data']) for r in all_results]
    return np.mean(rr_list), rr_list

def _raw_score(nll_key):
    """Score function that reads a raw NLL field from a passage row."""
    return lambda p: p[nll_key]


def _pmi_score(nll_key, baseline_key):
    """Score function that subtracts the stored baseline from the raw NLL."""
    return lambda p: p[nll_key] - p[baseline_key]


# Per-passage score functions, in the same method order as the AUC table.
mrr_fns = {
    'Raw AL': _raw_score('nll_al'),
    'PMI AL': _pmi_score('nll_al', 'bl_al'),
    'Raw QL': _raw_score('nll_ql'),
    'PMI QL': _pmi_score('nll_ql', 'bl_ql'),
    'Raw QL-search': _raw_score('nll_ql_search'),
    'PMI QL-search': _pmi_score('nll_ql_search', 'bl_ql_search'),
}

print("\n" + "=" * 70)
print("MRR@10 RESULTS")
print("=" * 70)
print(f"{'Method':<35} {'MRR@10':>8} {'vs Exp22 ref':>15}")
print("-" * 60)

mrr_results, mrr_per_query = {}, {}
for name, fn in mrr_fns.items():
    mrr, rr_list = compute_mrr_at_k(all_results, fn, k=10)
    mrr_results[name] = float(mrr)
    mrr_per_query[name] = rr_list
    # Only raw answer-likelihood has an Exp 22 reference MRR.
    ref_str = f"(ref: {REF['raw_bare_mrr']:.3f})" if name == 'Raw AL' else ""
    print(f"{name:<35} {mrr:>8.3f} {ref_str:>15}")

# === Differential NLL (Cohen's d between relevant vs irrelevant) ===
print("\n" + "=" * 70)
print("DIFFERENTIAL NLL (relevant vs irrelevant)")
print("=" * 70)
print(f"{'Method':<25} {'Mean Rel':>10} {'Mean Irr':>10} {'Gap':>8} {'d':>8} {'p':>12}")
print("-" * 75)

rel_mask = is_relevant == 1
irr_mask = is_relevant == 0

diff_results = {}
for name, scores in scoring_methods.items():
    short_name = name.split(' (')[0]  # drop the "(…-likelihood)" suffix
    rel_vals, irr_vals = scores[rel_mask], scores[irr_mask]
    n_r, n_i = len(rel_vals), len(irr_vals)

    mean_rel, mean_irr = np.mean(rel_vals), np.mean(irr_vals)
    gap = mean_irr - mean_rel  # positive = relevant gets lower score (good)

    # Pooled standard deviation from the two sample variances.
    ss_rel = np.var(rel_vals, ddof=1) * (n_r - 1)
    ss_irr = np.var(irr_vals, ddof=1) * (n_i - 1)
    pooled_std = np.sqrt((ss_rel + ss_irr) / (n_r + n_i - 2))
    d = gap / pooled_std if pooled_std > 0 else 0

    _, p_val = stats.ttest_ind(irr_vals, rel_vals)
    diff_results[name] = {'mean_rel': mean_rel, 'mean_irr': mean_irr, 'gap': gap, 'd': d, 'p': p_val}
    print(f"{short_name:<25} {mean_rel:>10.4f} {mean_irr:>10.4f} {gap:>+8.4f} {d:>+8.3f} {p_val:>12.2e}")

# === Correlation between methods ===
print("\n" + "=" * 70)
print("INTER-METHOD CORRELATIONS (Pearson r)")
print("=" * 70)
methods_for_corr = ['Raw AL (answer-likelihood)', 'Raw QL (query-likelihood)', 'Raw QL-search',
                     'PMI AL', 'PMI QL', 'PMI QL-search']
short_labels = ['Raw AL', 'Raw QL', 'Raw QL-s', 'PMI AL', 'PMI QL', 'PMI QL-s']
corr_arrays = [scoring_methods[m] for m in methods_for_corr]

# Header row: blank corner cell followed by one column per method.
header_cells = "".join(f"{sl:>10}" for sl in short_labels)
print(f"{'':>12}{header_cells}")

# One row per method; each cell is the pairwise Pearson r.
for row_label, row_arr in zip(short_labels, corr_arrays):
    row_cells = "".join(
        f"{np.corrcoef(row_arr, col_arr)[0, 1]:>10.3f}" for col_arr in corr_arrays
    )
    print(f"{row_label:>12}{row_cells}")

In [None]:
# Cell 8: Plots

def _draw_roc_panel(ax, title, curves):
    """Plot one ROC curve per (method name, color) pair plus the chance diagonal."""
    for name, color in curves:
        # Scores are negated because lower NLL means more relevant.
        fpr, tpr, _ = roc_curve(is_relevant, -scoring_methods[name])
        legend_name = name.split(' (')[0]
        ax.plot(fpr, tpr, color=color, label=f"{legend_name} (AUC={auc_results[name]:.3f})")
    ax.plot([0, 1], [0, 1], 'k--', alpha=0.3)
    ax.set_xlabel('FPR')
    ax.set_ylabel('TPR')
    ax.set_title(title)
    ax.legend(fontsize=8)


def _label_bars(ax, bars, values, lift):
    """Write each bar's value (3 decimals) just above the bar."""
    for bar, val in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + lift,
                f'{val:.3f}', ha='center', va='bottom', fontsize=8)


fig, axes = plt.subplots(2, 3, figsize=(18, 10))
method_colors = ['steelblue', 'steelblue', 'coral', 'coral', 'orange', 'orange']

# --- Row 1: ROC curves ---
# Panel 1: Raw scores
_draw_roc_panel(axes[0, 0], 'ROC — Raw Scores',
                [('Raw AL (answer-likelihood)', 'blue'),
                 ('Raw QL (query-likelihood)', 'red'),
                 ('Raw QL-search', 'orange')])

# Panel 2: PMI scores
_draw_roc_panel(axes[0, 1], 'ROC — PMI Scores',
                [('PMI AL', 'blue'), ('PMI QL', 'red'), ('PMI QL-search', 'orange')])

# Panel 3: AUC comparison bar chart
ax = axes[0, 2]
names_short = ['Raw AL', 'PMI AL', 'Raw QL', 'PMI QL', 'Raw QL-s', 'PMI QL-s']
aucs = [auc_results[m] for m in scoring_methods]  # same order as names_short
bar_positions = range(len(names_short))
bars = ax.bar(bar_positions, aucs, color=method_colors, alpha=0.8)
ax.set_xticks(bar_positions)
ax.set_xticklabels(names_short, rotation=30, ha='right', fontsize=8)
ax.set_ylabel('AUC')
ax.set_title('AUC Comparison')
ax.axhline(y=REF['raw_bare_auc'], color='gray', linestyle='--', alpha=0.5, label=f"Exp22 bare={REF['raw_bare_auc']}")
ax.axhline(y=REF['pmi_bare_auc'], color='gray', linestyle=':', alpha=0.5, label=f"Exp22 PMI={REF['pmi_bare_auc']}")
ax.legend(fontsize=7)
_label_bars(ax, bars, aucs, 0.003)
ax.set_ylim(0.5, 1.0)

# --- Row 2: Score distributions and MRR ---
# Panel 4: mean NLL per method, relevant (x=0) vs irrelevant (x=1);
# only the relevant bars get a legend entry.
ax = axes[1, 0]
for label, arr, color, offset in [('al', nll_al, 'blue', -0.2),
                                   ('ql', nll_ql, 'red', 0),
                                   ('ql_search', nll_ql_search, 'orange', 0.2)]:
    rel_m = np.mean(arr[is_relevant == 1])
    irr_m = np.mean(arr[is_relevant == 0])
    ax.bar([0 + offset], [rel_m], width=0.18, color=color, alpha=0.7, label=f"{label} (rel)")
    ax.bar([1 + offset], [irr_m], width=0.18, color=color, alpha=0.4)
ax.set_xticks([0, 1])
ax.set_xticklabels(['Relevant', 'Irrelevant'])
ax.set_ylabel('Mean NLL')
ax.set_title('Mean NLL by Relevance')
ax.legend(fontsize=8)

# Panel 5: MRR comparison
ax = axes[1, 1]
mrr_names = list(mrr_results)
mrr_vals = [mrr_results[n] for n in mrr_names]
mrr_positions = range(len(mrr_names))
bars = ax.bar(mrr_positions, mrr_vals, color=method_colors, alpha=0.8)
ax.set_xticks(mrr_positions)
ax.set_xticklabels(mrr_names, rotation=30, ha='right', fontsize=8)
ax.set_ylabel('MRR@10')
ax.set_title('MRR@10 Comparison')
ax.axhline(y=REF['raw_bare_mrr'], color='gray', linestyle='--', alpha=0.5)
_label_bars(ax, bars, mrr_vals, 0.005)
ax.set_ylim(0.5, 1.0)

# Panel 6: AL vs QL scatter (per-passage), on at most 2000 sampled passages
ax = axes[1, 2]
sample_idx = np.random.choice(len(nll_al), size=min(2000, len(nll_al)), replace=False)
sampled_rel = is_relevant[sample_idx]
sampled_al = nll_al[sample_idx]
sampled_ql = nll_ql[sample_idx]
# Irrelevant points are drawn first so the relevant ones sit on top.
for flag, marker_style in [(0, dict(c='gray', alpha=0.2, s=10, label='Irrelevant')),
                            (1, dict(c='red', alpha=0.6, s=20, label='Relevant'))]:
    chosen = sampled_rel == flag
    ax.scatter(sampled_al[chosen], sampled_ql[chosen], **marker_style)
r_corr = np.corrcoef(nll_al, nll_ql)[0, 1]  # computed on all passages, not the sample
ax.set_xlabel('NLL (answer-likelihood)')
ax.set_ylabel('NLL (query-likelihood)')
ax.set_title(f'AL vs QL per passage (r={r_corr:.3f})')
ax.legend(fontsize=8)

plt.tight_layout()
plot_path = RESULTS_DIR / 'ranking_plots.png'
plt.savefig(plot_path, dpi=150, bbox_inches='tight')
print(f"Saved to {plot_path}")
plt.show()

In [None]:
# Cell 9: Save Results + Final Verdict

# Save passage-level CSV
csv_path = RESULTS_DIR / 'passage_scores.csv'
csv_header = ['query_idx', 'passage_idx', 'is_relevant', 'word_count', 'doc_len',
              'nll_al', 'nll_ql', 'nll_ql_search',
              'bl_al', 'bl_ql', 'bl_ql_search',
              'pmi_al', 'pmi_ql', 'pmi_ql_search']


def _csv_row(result, passage):
    """Build one CSV row: identifiers, then raw / baseline / PMI scores (6 decimals)."""
    raw = [passage[key] for key in ('nll_al', 'nll_ql', 'nll_ql_search')]
    base = [passage[key] for key in ('bl_al', 'bl_ql', 'bl_ql_search')]
    pmi = [r_val - b_val for r_val, b_val in zip(raw, base)]
    return [
        result['query_idx'], passage['passage_idx'], int(passage['is_relevant']),
        passage['word_count'], passage['doc_len'],
        *(f"{value:.6f}" for value in raw + base + pmi),
    ]


with open(csv_path, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(csv_header)
    writer.writerows(_csv_row(r, p) for r in all_results for p in r['passage_data'])
print(f"CSV saved: {csv_path} ({n_total} rows)")

# Save full results JSON
# diff_results holds NumPy scalars; cast them to plain floats for json.
diff_results_plain = {
    method: {stat: float(value) for stat, value in stats_by_name.items()}
    for method, stats_by_name in diff_results.items()
}
results_json = {
    'experiment': 'Exp 31: Query-Likelihood Ranking with PMI',
    'model': 'google/gemma-3-4b-it',
    'n_queries': len(all_results),
    'n_passages': n_total,
    'n_relevant': int(n_rel),
    'seed': SEED,
    'auc_results': auc_results,
    'mrr_results': mrr_results,
    'diff_results': diff_results_plain,
    'reference_exp22': REF,
    'all_results': all_results,
}
json_path = RESULTS_DIR / 'results.json'
with open(json_path, 'w') as f:
    json.dump(results_json, f, indent=2)
json_kb = json_path.stat().st_size / 1024
print(f"Results saved: {json_path} ({json_kb:.1f} KB)")

# === FINAL VERDICT ===
print("\n" + "=" * 70)
print("FINAL VERDICT — Exp 31: Query-Likelihood Ranking")
print("=" * 70)

ql_keys = ['Raw QL (query-likelihood)', 'PMI QL', 'Raw QL-search', 'PMI QL-search']
al_keys = ['Raw AL (answer-likelihood)', 'PMI AL']
best_ql_auc = max(auc_results[key] for key in ql_keys)
best_al_auc = max(auc_results[key] for key in al_keys)

print(f"\nModel: Gemma 3 4B | N={len(all_results)} queries, {n_total} passages")
print("\nAnswer-Likelihood (reference):")
print(f"  Raw AL AUC:  {auc_results['Raw AL (answer-likelihood)']:.3f}  (Exp22 ref: {REF['raw_bare_auc']:.3f})")
print(f"  PMI AL AUC:  {auc_results['PMI AL']:.3f}  (Exp22 ref: {REF['pmi_bare_auc']:.3f})")
print("\nQuery-Likelihood (NEW):")
# Row prefixes carry their own padding so the values line up as before.
for row_prefix, key in zip(
    ["  Raw QL AUC:       ", "  PMI QL AUC:       ",
     "  Raw QL-search AUC: ", "  PMI QL-search AUC: "],
    ql_keys,
):
    print(f"{row_prefix}{auc_results[key]:.3f}")

print("\nMRR@10:")
for name, val in mrr_results.items():
    print(f"  {name:<20} {val:.3f}")

# Success criteria from the notebook header.
ql_passes_primary = best_ql_auc > 0.80
ql_beats_al = best_ql_auc > best_al_auc
# PMI "helps" if it raises AUC for at least one of the two QL variants.
pmi_helps_ql = any(
    auc_results[pmi_key] > auc_results[raw_key]
    for pmi_key, raw_key in [('PMI QL', 'Raw QL (query-likelihood)'),
                             ('PMI QL-search', 'Raw QL-search')]
)


def _yes_no(flag):
    return 'YES' if flag else 'NO'


print(f"\nPrimary:   QL AUC > 0.80?  {_yes_no(ql_passes_primary)} (best={best_ql_auc:.3f})")
print(f"Secondary: PMI helps QL?    {_yes_no(pmi_helps_ql)}")
print(f"Compare:   QL beats AL?     {_yes_no(ql_beats_al)} (best QL={best_ql_auc:.3f}, best AL={best_al_auc:.3f})")

r_al_ql = np.corrcoef(nll_al, nll_ql)[0, 1]
print(f"\nAL-QL correlation: r={r_al_ql:.3f}")

if not ql_passes_primary:
    print("\nVERDICT: Query-likelihood FAILS as a ranking signal (AUC < 0.80).")
    print("  The model cannot reliably predict queries from passages.")
else:
    print("\nVERDICT: Query-likelihood IS a viable ranking signal.")
    if ql_beats_al:
        print("  Query-likelihood BEATS answer-likelihood — pursue for ad ranking.")
    else:
        print("  Query-likelihood works but doesn't beat answer-likelihood.")
        print("  May still be valuable for ad serving where answers are unavailable.")

print("\nDone!")

In [None]:
# Cell 10: GPU Cleanup
# Drops the notebook's references to the model and tokenizer, then reports
# allocated GPU memory before and after.
def _allocated_gb():
    """Currently allocated GPU memory as reported by torch, in GB (1e9 bytes)."""
    return torch.cuda.memory_allocated() / 1e9


print("Cleaning up GPU memory...")
mem_before = _allocated_gb()

del model, tokenizer
gc.collect()
torch.cuda.empty_cache()
gc.collect()

mem_after = _allocated_gb()
print(f"GPU memory: {mem_before:.2f} GB -> {mem_after:.2f} GB")
print("Cleanup complete.")