# Exp 28: Contrastive Ranking Soft Prefix

## Motivation

Exps 22-23 proved that NLL-trained priming cannot improve document ranking: value contamination
from a document-independent prefix lowers NLL equally for relevant and irrelevant passages.
But those experiments used either discrete prefixes (static_fact) or NLL-optimized soft prefixes.
**What if we train the soft prefix with a ranking loss instead?**

## Hypothesis

Training the soft prefix with a hinge-based ranking loss can create differential value
contamination — reducing NLL more for relevant passages than irrelevant ones. The prefix
values might learn to "amplify" answer-predictive tokens more than other tokens.

## Core Mechanism

MS MARCO provides ~8-10 candidate passages per query with relevance labels. For each step:
1. Pick 1 relevant + 1 irrelevant passage for the same query
2. Score both through hybrid cache (same soft prefix)
3. Hinge loss: `max(0, margin + NLL_relevant - NLL_irrelevant)`
4. Gradient pushes prefix to make relevant passages predict the answer better

## Why This Might Fail

The soft prefix is the same for all documents, and the cache is built without seeing the query.
The prefix can change a document's cached values only through how that document's own tokens
attend to it. So the contrastive gradient ("help relevant passages more") has no direct way to tell
relevant from irrelevant passages at cache-build time. The only hope: the prefix values create an
"amplifier" that happens to boost answer-predictive tokens more than other tokens. Theoretically
possible, practically unlikely.

## Success Criteria

- **Primary**: Contrastive prefix AUC > 0.835 (bare=0.828) or PMI AUC > 0.845 (bare PMI=0.841)
- **Secondary**: Still helps average NLL (d > 0 vs bare)
- **Failure is informative**: Confirms value contamination from document-independent prefix
  fundamentally cannot create ranking signal

In [None]:
# Cell 1: Setup
import os
# NOTE(review): a umask of 0 makes every file/directory created afterwards
# world-writable; presumably deliberate for a shared results volume — confirm.
os.umask(0o000)

import csv
import gc
import json
import sys
import time
from pathlib import Path

import numpy as np
import torch
from scipy import stats

# One seed for every RNG the notebook touches.
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)

# Output directory for this experiment, created up front so later cells can write to it.
RESULTS_DIR = Path("results/exp28")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# Exp 25's NLL-trained soft prefix: used as the warm-start init and as an eval condition.
EXP25_SOFT_FACT = Path("results/exp25/soft_prefix_fact.pt")

# Resume checkpoints: one per training init mode, plus one for the ranking eval.
CHECKPOINT_TRAIN_WARM_PATH = RESULTS_DIR / "checkpoint_train_warm.json"
CHECKPOINT_TRAIN_COLD_PATH = RESULTS_DIR / "checkpoint_train_cold.json"
CHECKPOINT_EVAL_PATH = RESULTS_DIR / "checkpoint_eval.json"

# Final artifacts.
FINAL_RESULTS_PATH = RESULTS_DIR / "results.json"
CSV_EVAL_PATH = RESULTS_DIR / "passage_scores.csv"
SOFT_WARM_PATH = RESULTS_DIR / "soft_prefix_contrastive_warm.pt"
SOFT_COLD_PATH = RESULTS_DIR / "soft_prefix_contrastive_cold.pt"

_has_cuda = torch.cuda.is_available()
print(f"SEED: {SEED}")
print(f"Results directory: {RESULTS_DIR}")
print(f"Exp 25 soft_prefix_fact: {EXP25_SOFT_FACT} (exists: {EXP25_SOFT_FACT.exists()})")
print(f"CUDA available: {_has_cuda}")
if _has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    _total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU memory: {_total_mem_gb:.1f} GB")

In [None]:
# Cell 2: Load Gemma 3 4B

# Make the repository's local `lib` package importable.
sys.path.insert(0, ".")

from lib.config import ExperimentConfig
from lib.model_utils import load_model

MODEL_NAME = "google/gemma-3-4b-it"

# 4-bit quantized load; compute_dtype="auto" is passed through to the lib loader.
exp_config = ExperimentConfig(
    model_name=MODEL_NAME,
    model_type="gemma3",
    compute_dtype="auto",
    use_4bit=True,
    num_samples=500,
    seed=SEED,
)

print(f"Loading {MODEL_NAME} (4-bit, bfloat16)...")
model, tokenizer = load_model(exp_config)

from lib.kv_cache import (
    _get_text_config, _get_head_dim,
    _get_cache_keys, _get_cache_values,
    _set_cache_keys, _set_cache_values,
    _ensure_dynamic_cache,
    extract_and_truncate_cache_with_bos,
    correct_rope_positions_with_bos,
    score_answer_with_cache,
    deepcopy_cache,
    replace_values_at_layers,
)

# Layer count and hidden size come from the config returned by _get_text_config.
text_config = _get_text_config(model.config)
NUM_LAYERS = text_config.num_hidden_layers
HIDDEN_SIZE = text_config.hidden_size
HEAD_DIM = _get_head_dim(model.config)

print("Model loaded.")
print(f"  Layers: {NUM_LAYERS}, hidden: {HIDDEN_SIZE}, head_dim: {HEAD_DIM}")
print(f"  KV heads: {text_config.num_key_value_heads}")
print(f"  BOS token: {tokenizer.bos_token_id}")

# Report the dtype of cached keys from a tiny forward pass on the string "test".
sample_ids = tokenizer("test", return_tensors="pt")['input_ids']
sample_ids = sample_ids.to(exp_config.device)
with torch.no_grad():
    out = model(sample_ids, use_cache=True)
    cache_check = _ensure_dynamic_cache(out.past_key_values)
    k0 = _get_cache_keys(cache_check, 0)
print(f"  Cache dtype: {k0.dtype}")
del out, sample_ids, cache_check
torch.cuda.empty_cache()

In [None]:
# Cell 3: Constants

from lib.analysis import cohens_d
from lib.data import count_words
from lib.surrogate import STATIC_SURROGATE_QUERIES
from tqdm.auto import tqdm

# --- Prompt templates (same as Exp 25) ---
STATIC_FACT = STATIC_SURROGATE_QUERIES['static_factual']['query']
SURROGATE_PREFIX_TEMPLATE = "{surrogate}\n"
DOCUMENT_TEMPLATE = "{document}"
QUERY_TEMPLATE = "\nQuery: {query}\nAnswer:"
ANSWER_TEMPLATE = " {answer}"

# --- Architecture ---
CUTOFF = 16              # soft-prefix values replace values in layers 0..CUTOFF-1
MAX_PASSAGE_WORDS = 300  # a query is skipped if any of its passages is longer

# --- Training hyperparameters ---
MARGIN = 0.1        # hinge loss margin
LR = 0.05           # lower than Exp 25 (0.1) for warm-start stability
N_EPOCHS = 5
GRAD_ACCUM = 4
WARMUP_STEPS = 30   # linear LR warmup, counted in optimizer steps
N_TRAIN = 500       # training queries (each with ~8 passages)
N_EVAL = 200        # validation queries for ranking eval
CHECKPOINT_EVERY = 50

# The static-fact prefix fixes the soft prefix length and the matched tokenization.
sf_str = SURROGATE_PREFIX_TEMPLATE.format(surrogate=STATIC_FACT)
sf_ids = tokenizer(
    sf_str, return_tensors="pt", add_special_tokens=False,
)['input_ids'].to(exp_config.device)
PREFIX_LEN = sf_ids.shape[1]

embed_fn = model.get_input_embeddings()

# Exp 22 reference metrics for comparison.
EXP22_REF = dict(
    raw_bare_auc=0.828,
    raw_primed_auc=0.829,
    pmi_bare_auc=0.841,
    pmi_primed_auc=0.832,
    raw_bare_mrr=0.860,
)

_summary_lines = [
    "Config:",
    f"  MARGIN={MARGIN}, LR={LR}, N_EPOCHS={N_EPOCHS}, GRAD_ACCUM={GRAD_ACCUM}",
    f"  N_TRAIN={N_TRAIN} queries, N_EVAL={N_EVAL} queries",
    f"  CUTOFF={CUTOFF} (layers 0-{CUTOFF-1})",
    f"  PREFIX_LEN={PREFIX_LEN} tokens",
    f"  Trainable params: {PREFIX_LEN * HIDDEN_SIZE:,}",
    "\nReference (Exp 22):",
]
print("\n".join(_summary_lines))
for k, v in EXP22_REF.items():
    print(f"  {k}: {v:.3f}")

In [None]:
# Cell 4: Load MS MARCO train — multi-passage format

from datasets import load_dataset

_rule = "=" * 70
print(_rule)
print("LOADING MS MARCO v1.1 TRAIN — MULTI-PASSAGE FORMAT")
print(_rule)

dataset = load_dataset("microsoft/ms_marco", "v1.1", split="train")
print(f"Total items: {len(dataset)}")

# Collect up to 3 * N_TRAIN qualifying queries in dataset order, then shuffle
# with the fixed seed and keep the first N_TRAIN.
train_queries = []
np.random.seed(SEED)

for item in tqdm(dataset, desc="Filtering train"):
    passages_info = item.get('passages', {})
    passage_texts = passages_info.get('passage_text', [])
    is_selected = passages_info.get('is_selected', [])
    query = item.get('query', '')
    answers = item.get('answers', [])
    well_formed = item.get('wellFormedAnswers', [])

    # Need passages, a query, and at least one selected passage.
    if not (passage_texts and query and is_selected and sum(is_selected) != 0):
        continue

    # Reject the whole query if any candidate passage is over the word limit.
    word_counts = [count_words(p) for p in passage_texts]
    if max(word_counts) > MAX_PASSAGE_WORDS:
        continue

    # Require an answer: the well-formed one first, otherwise the first raw answer.
    answer = None
    if well_formed and well_formed[0] != '[]':
        answer = well_formed[0]
    elif answers and answers[0] != 'No Answer Present.':
        answer = answers[0]
    else:
        continue

    # Contrastive training needs both a relevant and an irrelevant passage.
    n_rel = sum(1 for s in is_selected if s == 1)
    n_irr = len(is_selected) - n_rel
    if not (n_rel and n_irr):
        continue

    passage_list = [
        {
            'passage': ptext,
            'is_relevant': bool(sel == 1),
            'word_count': wc,
            'passage_idx': idx,
        }
        for idx, (ptext, sel, wc) in enumerate(zip(passage_texts, is_selected, word_counts))
    ]

    train_queries.append({
        'query': query,
        'answer': answer,
        'passages': passage_list,
        'n_passages': len(passage_list),
        'n_relevant': n_rel,
    })

    if len(train_queries) >= N_TRAIN * 3:
        break

np.random.shuffle(train_queries)
train_queries = train_queries[:N_TRAIN]

n_passages_train = [q['n_passages'] for q in train_queries]
total_train_passages = sum(n_passages_train)
total_train_rel = sum(q['n_relevant'] for q in train_queries)

print(f"\nSelected {len(train_queries)} training queries ({total_train_passages} passages)")
print(f"Passages per query: mean={np.mean(n_passages_train):.1f}, "
      f"min={min(n_passages_train)}, max={max(n_passages_train)}")
print(f"Relevant: {total_train_rel} ({100*total_train_rel/total_train_passages:.1f}%)")

del dataset
gc.collect()

In [None]:
# Cell 5: Explain experimental conditions

# Human-readable summary of the training objective, init modes, and eval conditions.
_CONDITIONS_TEXT = """
### Training: Contrastive (Hinge) Loss ###

For each training query:
  1. Sample 1 relevant passage (R) and 1 irrelevant passage (I)
  2. Build hybrid cache with soft prefix for R -> compute NLL_R
  3. Build hybrid cache with soft prefix for I -> compute NLL_I
  4. Loss = max(0, MARGIN + NLL_R - NLL_I)
     - If NLL_R < NLL_I - MARGIN: loss = 0 (satisfied)
     - If NLL_R > NLL_I: loss > MARGIN (violated)
  5. Gradient pushes soft prefix to reduce NLL_R relative to NLL_I

Both forward passes MUST be in the same computation graph
(both NLLs share the same soft_prefix) for loss.backward() to work.

### Two Init Conditions ###

  warm: Initialize from results/exp25/soft_prefix_fact.pt
        Tests: can we add ranking signal to an already-useful NLL prefix?

  cold: Initialize from random N(0, 0.02)
        Tests: can contrastive loss learn ranking signal from scratch?

### Evaluation Conditions (5 total) ###

  bare:             No prefix, no cache modification
  exp25_fact:       Exp 25 soft_prefix_fact (NLL-optimized)
  contrastive_warm: This experiment's warm-start prefix
  contrastive_cold: This experiment's cold-start prefix
  baseline:         BOS-only cache (no document) for PMI computation

### Ranking Metrics ###

  AUC-ROC: Can the NLL scores separate relevant from irrelevant?
  MRR@10:  Is the first relevant passage ranked near the top?
  Both computed for raw NLL and PMI (NLL - baseline) scoring.
"""

_rule = "=" * 70
for _line in (_rule, "EXPERIMENTAL CONDITIONS", _rule):
    print(_line)

print(_CONDITIONS_TEXT)

In [None]:
# Cell 6: differentiable_hybrid_score() — reuse from Exp 25 Cell 7 verbatim

# NOTE(review): none of these three helpers is referenced in this cell.
from lib.kv_cache import _get_rope_theta_for_layer, _build_rope_correction, _rotate_half

print(f"Embedding layer: {type(embed_fn).__name__}, shape={embed_fn.weight.shape}")


def differentiable_hybrid_score(
    soft_prefix: torch.Tensor,       # (1, prefix_len, hidden_size), requires_grad
    doc_ids: torch.Tensor,            # (1, doc_len)
    bos_id: torch.Tensor,             # (1, 1)
    query_prompt: str,
    answer_text: str,
    model,
    tokenizer,
    config,
    cutoff: int,
) -> torch.Tensor:
    """
    Compute answer NLL through a hybrid cache built from soft prefix embeddings.

    The hybrid cache has the length of the bare ``[BOS] + doc`` context:
      - keys for every layer, and values for layers >= ``cutoff``, come from a
        no-grad forward pass over ``[BOS] + doc`` (no prefix);
      - values for layers < ``cutoff`` come from a forward pass over
        ``[BOS] + soft_prefix + doc`` with the prefix positions sliced out.
    The query and answer are then run on top of that cache, and the mean
    cross-entropy over the answer tokens is returned.

    Args:
        soft_prefix: trainable prefix embeddings; cast to the embedding dtype.
        doc_ids: document token ids (assumed already on ``config.device``).
        bos_id: BOS token id tensor.
        query_prompt: query text, tokenized without special tokens.
        answer_text: answer text, tokenized without special tokens.
        model: the causal LM.
        tokenizer: its tokenizer.
        config: experiment config; only ``config.device`` is read.
        cutoff: number of leading layers whose values come from the primed pass.

    Returns:
        Scalar tensor: mean answer-token NLL, with gradients flowing back to
        ``soft_prefix``.

    NOTE(review): embeddings use the module-level ``embed_fn`` and the layer
    loop uses the module-level ``NUM_LAYERS``, not attributes of ``model``.
    NOTE(review): ``doc_len == 0`` would make ``primed_v[:, :, -doc_len:, :]``
    select the whole sequence; callers are assumed to pass non-empty documents.
    """
    device = config.device
    doc_len = doc_ids.shape[1]
    prefix_len = soft_prefix.shape[1]  # NOTE(review): not used below
    context_len = 1 + doc_len  # BOS + doc

    # --- Step 1: Bare cache (no gradients needed) ---
    # Keys for all layers (and values for layers >= cutoff) come from this pass.
    bare_input = torch.cat([bos_id, doc_ids], dim=1)
    with torch.no_grad():
        bare_out = model(input_ids=bare_input,
                         attention_mask=torch.ones_like(bare_input),
                         use_cache=True, return_dict=True)
    bare_cache = bare_out.past_key_values
    del bare_out

    # --- Step 2: Primed cache via inputs_embeds (gradients enabled) ---
    # BOS/doc embeddings are computed without grad and detached, so soft_prefix
    # is the only tensor in this pass that carries gradient.
    with torch.no_grad():
        bos_emb = embed_fn(bos_id)
        doc_emb = embed_fn(doc_ids)

    soft_cast = soft_prefix.to(dtype=bos_emb.dtype)  # match the embedding dtype
    # Layout: [BOS] + soft_prefix + doc
    inputs_embeds = torch.cat([bos_emb.detach(), soft_cast, doc_emb.detach()], dim=1)
    total_len = inputs_embeds.shape[1]
    attn_mask = torch.ones((1, total_len), device=device, dtype=torch.long)

    primed_out = model(inputs_embeds=inputs_embeds,
                       attention_mask=attn_mask,
                       use_cache=True, return_dict=True)
    primed_cache = primed_out.past_key_values
    del primed_out

    # --- Step 3+4: Build hybrid cache ---
    primed_cache_dc = _ensure_dynamic_cache(primed_cache)
    bare_cache_dc = _ensure_dynamic_cache(bare_cache)

    from transformers import DynamicCache
    from transformers.cache_utils import DynamicSlidingWindowLayer, DynamicLayer

    hybrid_cache = DynamicCache()
    for layer_idx in range(NUM_LAYERS):
        k = _get_cache_keys(bare_cache_dc, layer_idx)

        if layer_idx < cutoff:
            # Drop the prefix positions from the primed values: keep BOS
            # (index 0) and the last doc_len positions, so the length matches k.
            primed_v = _get_cache_values(primed_cache_dc, layer_idx)
            bos_v = primed_v[:, :, :1, :]
            doc_v = primed_v[:, :, -doc_len:, :]
            v = torch.cat([bos_v, doc_v], dim=2)
        else:
            v = _get_cache_values(bare_cache_dc, layer_idx)

        # Rebuild a layer of the same kind as the bare cache's layer by assigning
        # its fields directly (including the private `_sliding_window_tensor`);
        # this depends on transformers' internal cache-layer attributes.
        src_layer = bare_cache_dc.layers[layer_idx]
        if isinstance(src_layer, DynamicSlidingWindowLayer):
            new_layer = DynamicSlidingWindowLayer(sliding_window=src_layer.sliding_window)
            new_layer.dtype = k.dtype
            new_layer.device = k.device
            new_layer.keys = k
            new_layer.values = v
            new_layer.is_initialized = True
            new_layer.cumulative_length = src_layer.cumulative_length
            new_layer._sliding_window_tensor = new_layer._sliding_window_tensor.to(k.device)
        else:
            new_layer = DynamicLayer()
            new_layer.dtype = k.dtype
            new_layer.device = k.device
            new_layer.keys = k
            new_layer.values = v
            new_layer.is_initialized = True
        hybrid_cache.layers.append(new_layer)

    # --- Step 5: Score answer through hybrid cache ---
    # Query and answer are tokenized separately, without special tokens.
    query_ids = tokenizer(query_prompt, return_tensors="pt",
                          add_special_tokens=False)['input_ids'].to(device)
    answer_ids = tokenizer(answer_text, return_tensors="pt",
                           add_special_tokens=False)['input_ids'].to(device)
    query_len = query_ids.shape[1]
    answer_len = answer_ids.shape[1]

    qa_ids = torch.cat([query_ids, answer_ids], dim=1)
    qa_len = qa_ids.shape[1]
    # Mask spans the cached context (BOS + doc) plus the new query+answer tokens.
    qa_attn_full = torch.ones((1, context_len + qa_len), device=device)

    qa_out = model(input_ids=qa_ids,
                   attention_mask=qa_attn_full,
                   past_key_values=hybrid_cache,
                   use_cache=False, return_dict=True)

    logits = qa_out.logits
    # Logit at position t predicts token t+1, so the answer tokens (positions
    # query_len .. query_len+answer_len-1) are predicted by this slice.
    answer_logits = logits[:, query_len - 1 : query_len + answer_len - 1, :]

    # Mean NLL over the answer tokens.
    loss = torch.nn.functional.cross_entropy(
        answer_logits.reshape(-1, answer_logits.shape[-1]),
        answer_ids.reshape(-1),
        reduction='mean'
    )

    del qa_out, logits, bare_cache, bare_cache_dc, primed_cache, primed_cache_dc

    return loss


print("differentiable_hybrid_score() defined")
print("  Input: soft_prefix (requires_grad), doc_ids, query, answer")
print("  Output: scalar NLL loss with gradients to soft_prefix")

In [None]:
# Cell 7: Gradient flow sanity check with contrastive loss

_rule = "=" * 70
print(_rule)
print("GRADIENT FLOW SANITY CHECK — CONTRASTIVE LOSS")
print(_rule)

# Unit-variance random prefix, used only to confirm gradients reach it.
test_prefix = torch.randn(
    1, PREFIX_LEN, HIDDEN_SIZE,
    device=exp_config.device, dtype=torch.float32, requires_grad=True,
)

# First training query; the train filter guarantees it has both a relevant
# and an irrelevant passage.
test_q = train_queries[0]
query_prompt = QUERY_TEMPLATE.format(query=test_q['query'])
answer_text = ANSWER_TEMPLATE.format(answer=test_q['answer'])

rel_passage = next(p for p in test_q['passages'] if p['is_relevant'])
irr_passage = next(p for p in test_q['passages'] if not p['is_relevant'])

for _label, _text in (("Test query", test_q['query']),
                      ("Relevant passage", rel_passage['passage']),
                      ("Irrelevant passage", irr_passage['passage'])):
    print(f"{_label}: '{_text[:60]}...'")

# Helper to get doc_ids via matched tokenization.
# Length (including special tokens) of the static-fact prefix. It is the same for
# every passage, so compute it once instead of re-tokenizing it on each call.
_SF_LEN_WITH_SPECIALS = tokenizer(sf_str, return_tensors="pt",
                                  add_special_tokens=True, padding=False,
                                  truncation=False)['input_ids'].shape[1]


def get_matched_doc_ids(passage_text):
    """Tokenize a passage exactly as it appears after the static-fact prefix.

    The passage is tokenized as ``sf_str + passage`` (with special tokens), and the
    leading prefix tokens are sliced off. The document tokens therefore match those
    produced when the passage follows the prefix, rather than a standalone
    tokenization.

    NOTE(review): this assumes the tokenizer does not merge tokens across the
    prefix/document boundary (``sf_str`` ends in a newline) — confirm for new
    templates.

    Args:
        passage_text: raw passage string.

    Returns:
        ``(bos_id, doc_ids)``, both on ``exp_config.device``. ``bos_id`` is the
        first token of the full encoding, shape (1, 1); ``doc_ids`` holds the tokens
        after the prefix, shape (1, doc_len).
    """
    doc_text = DOCUMENT_TEMPLATE.format(document=passage_text)
    full_ids = tokenizer(sf_str + doc_text, return_tensors="pt",
                         add_special_tokens=True, padding=False,
                         truncation=False)['input_ids'].to(exp_config.device)
    bos_id = full_ids[:, :1]
    doc_ids = full_ids[:, _SF_LEN_WITH_SPECIALS:]
    return bos_id, doc_ids

# Pre-bind the tensors created below so the cleanup in `finally` can always
# release them, even if an exception fires part-way through.
nll_rel = nll_irr = hinge_loss = None

try:
    # Score relevant passage
    bos_rel, doc_rel = get_matched_doc_ids(rel_passage['passage'])
    nll_rel = differentiable_hybrid_score(
        test_prefix, doc_rel, bos_rel,
        query_prompt, answer_text,
        model, tokenizer, exp_config, CUTOFF)
    print(f"\nNLL relevant: {nll_rel.item():.4f} (requires_grad: {nll_rel.requires_grad})")

    # Score irrelevant passage
    bos_irr, doc_irr = get_matched_doc_ids(irr_passage['passage'])
    nll_irr = differentiable_hybrid_score(
        test_prefix, doc_irr, bos_irr,
        query_prompt, answer_text,
        model, tokenizer, exp_config, CUTOFF)
    print(f"NLL irrelevant: {nll_irr.item():.4f} (requires_grad: {nll_irr.requires_grad})")

    # Contrastive hinge loss
    hinge_loss = torch.clamp(MARGIN + nll_rel - nll_irr, min=0.0)
    print(f"\nHinge loss: max(0, {MARGIN} + {nll_rel.item():.4f} - {nll_irr.item():.4f}) = {hinge_loss.item():.4f}")
    print(f"Hinge requires_grad: {hinge_loss.requires_grad}")

    # Backward
    hinge_loss.backward()

    print(f"\nBackward pass: SUCCESS")
    print(f"Gradient shape: {test_prefix.grad.shape}")
    print(f"Gradient norm: {test_prefix.grad.norm().item():.6f}")

    if hinge_loss.item() > 0 and test_prefix.grad.norm().item() > 0:
        print("\n>>> CONTRASTIVE GRADIENT FLOW CONFIRMED <<<")
    elif hinge_loss.item() == 0:
        # The clamp is flat below zero, so a satisfied margin yields a zero gradient.
        print("\n>>> Hinge loss is 0 (margin satisfied). Gradient expected to be 0. <<<")
        print(">>> This is correct behavior — try with different sample if needed. <<<")
    else:
        print("\n>>> WARNING: Non-zero loss but zero gradient. Debug needed. <<<")

except Exception as e:
    print(f"\n>>> GRADIENT FLOW FAILED: {e} <<<")
    import traceback
    traceback.print_exc()

finally:
    # Release the prefix AND the loss tensors: if backward never ran, the losses
    # still reference the autograd graph (and its GPU activations).
    del test_prefix, nll_rel, nll_irr, hinge_loss
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Cell 8: train_contrastive_soft_prefix()

def train_contrastive_soft_prefix(init_mode, train_data, n_epochs, lr, grad_accum,
                                   margin, checkpoint_path, save_path, warmup_steps=30):
    """
    Train soft prefix with contrastive hinge loss for ranking.

    Per step: sample 1 relevant + 1 irrelevant passage for one query.
    Loss = max(0, margin + NLL_relevant - NLL_irrelevant).
    Both are scored through the SAME soft prefix (sequential forward passes), so
    a single backward() propagates through both.

    Checkpointing / resume:
      - Every CHECKPOINT_EVERY steps (and at the last step) a JSON checkpoint with
        the prefix, history, and step counters is written atomically to
        ``checkpoint_path``. The AdamW state goes to ``<stem>.optim.pt`` beside it.
      - On resume, the optimizer-step counter (and with it the LR-warmup position)
        is restored. The AdamW moments are restored too when the saved optimizer
        state matches the checkpoint step.
      - Passage sampling uses a per-epoch RNG that also advances over skipped
        steps, so a resumed run draws the same pairs as an uninterrupted one.
      - Gradients accumulated since the last optimizer step are not checkpointed.

    Args:
        init_mode: 'warm' (from exp25_fact) or 'cold' (random)
        train_data: list of query dicts with 'passages' containing relevance
            labels; every query needs >= 1 relevant and >= 1 irrelevant passage
        n_epochs: number of passes
        lr: peak learning rate (reached at the end of warmup)
        grad_accum: gradient accumulation steps
        margin: hinge loss margin
        checkpoint_path: pathlib.Path of the JSON training checkpoint
        save_path: path to save final embeddings
        warmup_steps: linear warmup length, in optimizer steps

    Returns:
        dict with 'soft_prefix' (detached tensor), 'history' (one entry per
        optimizer step), and 'init_mode'.

    Raises:
        ValueError: unknown ``init_mode``.
    """
    print(f"\n{'=' * 70}")
    print(f"TRAINING CONTRASTIVE SOFT PREFIX — init={init_mode}")
    print(f"{'=' * 70}")

    # Initialize soft prefix
    if init_mode == 'warm':
        soft_prefix = torch.load(EXP25_SOFT_FACT).to(exp_config.device).float()
        print(f"  Loaded warm-start from {EXP25_SOFT_FACT}")
    elif init_mode == 'cold':
        soft_prefix = torch.randn(1, PREFIX_LEN, HIDDEN_SIZE,
                                   device=exp_config.device, dtype=torch.float32) * 0.02
        print(f"  Cold-start: random N(0, 0.02)")
    else:
        raise ValueError(f"Unknown init_mode: {init_mode}")

    soft_prefix = soft_prefix.detach().requires_grad_(True)
    print(f"  Shape: {soft_prefix.shape}, norm: {soft_prefix.norm().item():.4f}")

    optimizer = torch.optim.AdamW([soft_prefix], lr=lr, weight_decay=0.01)

    total_steps = n_epochs * len(train_data)
    total_optim_steps = total_steps // grad_accum
    print(f"  Total steps: {total_steps}, optim steps: {total_optim_steps}")

    # AdamW moments are tensors, so they are stored next to the JSON checkpoint.
    optim_state_path = checkpoint_path.with_name(checkpoint_path.stem + '.optim.pt')

    # Checkpoint resume
    history = []
    start_step = 0
    resumed_optim_step = 0
    if checkpoint_path.exists():
        ckpt = json.loads(checkpoint_path.read_text())
        if ckpt.get('init_mode') == init_mode and ckpt.get('total_steps') == total_steps:
            history = ckpt['history']
            start_step = ckpt['completed_steps']
            # Restore the optimizer-step counter so LR warmup is not replayed.
            # Checkpoints without this field fall back to the last history entry.
            resumed_optim_step = ckpt.get(
                'optim_step', history[-1]['optim_step'] if history else 0)
            soft_prefix = torch.tensor(ckpt['soft_prefix'], device=exp_config.device,
                                       dtype=torch.float32).requires_grad_(True)
            optimizer = torch.optim.AdamW([soft_prefix], lr=lr, weight_decay=0.01)
            if optim_state_path.exists():
                optim_ckpt = torch.load(optim_state_path, map_location='cpu')
                if optim_ckpt.get('completed_steps') == start_step:
                    optimizer.load_state_dict(optim_ckpt['optimizer'])
                else:
                    print("  WARNING: optimizer state is from a different step; "
                          "AdamW moments reset")
            else:
                print("  NOTE: no saved optimizer state; AdamW moments reset")
            # Re-derive the scheduled LR for the restored position.
            if resumed_optim_step >= warmup_steps:
                scheduled_lr = lr
            else:
                scheduled_lr = lr * (resumed_optim_step / warmup_steps)
            for pg in optimizer.param_groups:
                pg['lr'] = scheduled_lr
            print(f"  Resumed from checkpoint: step {start_step}/{total_steps}")

    t_start = time.time()
    step = 0
    optim_step = resumed_optim_step
    running_loss = 0.0
    running_nll_gap = 0.0  # NLL_irr - NLL_rel (positive = good)
    running_satisfied = 0  # number of micro-steps where hinge loss == 0
    running_count = 0

    for epoch in range(n_epochs):
        # Per-epoch RNG. It yields the same stream as np.random.seed(SEED + epoch)
        # followed by global np.random calls, and it also advances over steps
        # that are skipped on resume (see the sampling below).
        rng = np.random.RandomState(SEED + epoch)
        epoch_indices = rng.permutation(len(train_data))

        for data_idx in epoch_indices:
            qdata = train_data[data_idx]

            # Sample 1 relevant + 1 irrelevant passage. The draw happens before the
            # resume skip, so resumed runs see the same pairs as uninterrupted ones.
            rel_passages = [p for p in qdata['passages'] if p['is_relevant']]
            irr_passages = [p for p in qdata['passages'] if not p['is_relevant']]
            rel_p = rel_passages[rng.randint(len(rel_passages))]
            irr_p = irr_passages[rng.randint(len(irr_passages))]

            if step < start_step:
                step += 1
                continue

            query_prompt = QUERY_TEMPLATE.format(query=qdata['query'])
            answer_text = ANSWER_TEMPLATE.format(answer=qdata['answer'])

            nll_rel = nll_irr = hinge_loss = scaled_loss = None
            failed = False
            try:
                # Score relevant passage
                bos_rel, doc_rel = get_matched_doc_ids(rel_p['passage'])
                nll_rel = differentiable_hybrid_score(
                    soft_prefix, doc_rel, bos_rel,
                    query_prompt, answer_text,
                    model, tokenizer, exp_config, CUTOFF)

                # Score irrelevant passage
                bos_irr, doc_irr = get_matched_doc_ids(irr_p['passage'])
                nll_irr = differentiable_hybrid_score(
                    soft_prefix, doc_irr, bos_irr,
                    query_prompt, answer_text,
                    model, tokenizer, exp_config, CUTOFF)

                # Hinge loss, scaled for gradient accumulation
                hinge_loss = torch.clamp(margin + nll_rel - nll_irr, min=0.0)
                scaled_loss = hinge_loss / grad_accum
                scaled_loss.backward()

                running_loss += hinge_loss.item()
                running_nll_gap += (nll_irr.item() - nll_rel.item())
                running_satisfied += int(hinge_loss.item() == 0)
                running_count += 1

            except RuntimeError as e:  # e.g. CUDA out-of-memory
                print(f"  Step {step}: RuntimeError: {e}")
                failed = True

            if failed:
                # Recover outside the except block: until it ends, the exception's
                # traceback keeps the failed forward pass's frames (and activations)
                # alive. Gradients accumulated in this window are discarded.
                nll_rel = nll_irr = hinge_loss = scaled_loss = None
                optimizer.zero_grad()
                gc.collect()
                torch.cuda.empty_cache()
                step += 1
                continue

            # Optimizer step
            if (step + 1) % grad_accum == 0:
                optim_step += 1
                if optim_step <= warmup_steps:
                    for pg in optimizer.param_groups:
                        pg['lr'] = lr * (optim_step / warmup_steps)

                # Logged norm is measured before clipping.
                grad_norm = soft_prefix.grad.norm().item() if soft_prefix.grad is not None else 0
                torch.nn.utils.clip_grad_norm_([soft_prefix], max_norm=1.0)

                optimizer.step()
                optimizer.zero_grad()

                avg_loss = running_loss / running_count if running_count > 0 else 0
                avg_gap = running_nll_gap / running_count if running_count > 0 else 0
                sat_frac = running_satisfied / running_count if running_count > 0 else 0

                history.append({
                    'step': step,
                    'optim_step': optim_step,
                    'epoch': epoch,
                    'avg_loss': avg_loss,
                    'avg_nll_gap': avg_gap,
                    'satisfied_frac': sat_frac,
                    'grad_norm': grad_norm,
                    'prefix_norm': soft_prefix.norm().item(),
                    'lr': optimizer.param_groups[0]['lr'],
                })
                running_loss = 0.0
                running_nll_gap = 0.0
                running_satisfied = 0
                running_count = 0

            # Cleanup
            del hinge_loss, scaled_loss, nll_rel, nll_irr
            gc.collect()
            torch.cuda.empty_cache()

            step += 1

            # Checkpoint
            if step % CHECKPOINT_EVERY == 0 or step == total_steps:
                elapsed = time.time() - t_start
                steps_done = step - start_step
                rate = steps_done / elapsed if elapsed > 0 else 0
                remaining = (total_steps - step) / rate if rate > 0 else 0

                last = history[-1] if history else {}
                tqdm.write(
                    f"  [{init_mode}] Step {step}/{total_steps} | "
                    f"loss={last.get('avg_loss', 0):.4f} | "
                    f"gap={last.get('avg_nll_gap', 0):.3f} | "
                    f"sat={last.get('satisfied_frac', 0):.1%} | "
                    f"norm={soft_prefix.norm().item():.3f} | "
                    f"ETA: {remaining/60:.1f}m")

                # Optimizer state first. On resume it is only used when its
                # completed_steps matches the JSON checkpoint.
                optim_tmp = optim_state_path.with_name(optim_state_path.name + '.tmp')
                torch.save({'completed_steps': step,
                            'optimizer': optimizer.state_dict()}, optim_tmp)
                optim_tmp.replace(optim_state_path)

                ckpt_data = {
                    'init_mode': init_mode,
                    'completed_steps': step,
                    'total_steps': total_steps,
                    'optim_step': optim_step,
                    'history': history,
                    'soft_prefix': soft_prefix.detach().cpu().tolist(),
                    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
                }
                # Write-then-rename so an interrupted write cannot corrupt the checkpoint.
                ckpt_tmp = checkpoint_path.with_name(checkpoint_path.name + '.tmp')
                with open(ckpt_tmp, 'w') as f:
                    json.dump(ckpt_data, f)
                ckpt_tmp.replace(checkpoint_path)

    # Save final
    torch.save(soft_prefix.detach().cpu(), save_path)

    elapsed = time.time() - t_start
    print(f"\n  Training complete: {step} steps in {elapsed/60:.1f} min")
    print(f"  Final prefix norm: {soft_prefix.norm().item():.4f}")
    print(f"  Saved to: {save_path}")

    return {
        'soft_prefix': soft_prefix.detach(),
        'history': history,
        'init_mode': init_mode,
    }


print("train_contrastive_soft_prefix() defined")

In [None]:
# Cell 9: Train warm start (init from exp25 soft_prefix_fact.pt)

# All trainer settings in one place, then unpacked into the call.
_warm_kwargs = dict(
    init_mode='warm',
    train_data=train_queries,
    n_epochs=N_EPOCHS,
    lr=LR,
    grad_accum=GRAD_ACCUM,
    margin=MARGIN,
    checkpoint_path=CHECKPOINT_TRAIN_WARM_PATH,
    save_path=SOFT_WARM_PATH,
    warmup_steps=WARMUP_STEPS,
)
result_warm = train_contrastive_soft_prefix(**_warm_kwargs)
soft_contrastive_warm, history_warm = result_warm['soft_prefix'], result_warm['history']

In [None]:
# Cell 10: Train cold start (init from random)

# All trainer settings in one place, then unpacked into the call.
_cold_kwargs = dict(
    init_mode='cold',
    train_data=train_queries,
    n_epochs=N_EPOCHS,
    lr=LR,
    grad_accum=GRAD_ACCUM,
    margin=MARGIN,
    checkpoint_path=CHECKPOINT_TRAIN_COLD_PATH,
    save_path=SOFT_COLD_PATH,
    warmup_steps=WARMUP_STEPS,
)
result_cold = train_contrastive_soft_prefix(**_cold_kwargs)
soft_contrastive_cold, history_cold = result_cold['soft_prefix'], result_cold['history']

In [None]:
# Cell 11: Training curves

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt


def _plot_raw_and_smoothed(ax, xs, ys, window, color, label):
    """Plot ``ys`` as a faint raw line plus a bold moving average.

    If there are not more than ``window`` points, the raw series is drawn bold
    instead of a moving average.
    """
    ax.plot(xs, ys, alpha=0.2, color=color)
    if len(ys) > window:
        smoothed = np.convolve(ys, np.ones(window) / window, mode='valid')
        # 'valid' convolution drops the first window-1 points; align x to match.
        ax.plot(xs[window - 1:], smoothed, linewidth=2, color=color, label=label)
    else:
        ax.plot(xs, ys, linewidth=2, color=color, label=label)


fig, axes = plt.subplots(2, 2, figsize=(16, 10))

for hist, label, color in [(history_warm, 'warm (exp25_fact)', '#ff7f0e'),
                            (history_cold, 'cold (random)', '#1f77b4')]:
    if not hist:
        continue
    steps = [h['optim_step'] for h in hist]
    # Moving-average window: at most 20 points, about a third of a short curve.
    w = min(20, len(hist) // 3 + 1)

    for ax, key in ((axes[0, 0], 'avg_loss'),
                    (axes[0, 1], 'avg_nll_gap'),
                    (axes[1, 0], 'satisfied_frac')):
        _plot_raw_and_smoothed(ax, steps, [h[key] for h in hist], w, color, label)

    # Prefix norm is plotted raw only.
    axes[1, 1].plot(steps, [h['prefix_norm'] for h in hist],
                    linewidth=2, color=color, label=label)

# (axis, y-label, title) for each panel; all share the x-label and grid style.
_panel_specs = [
    (axes[0, 0], 'Hinge Loss', 'Contrastive Loss'),
    (axes[0, 1], 'NLL Gap (irr - rel)', 'NLL Gap (positive = correct direction)'),
    (axes[1, 0], 'Fraction Satisfied', 'Hinge Margin Satisfied'),
    (axes[1, 1], 'Norm', 'Prefix Embedding Norm'),
]
for ax, ylabel, title in _panel_specs:
    ax.set_xlabel('Optimizer Step')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(alpha=0.3)

axes[0, 1].axhline(y=0, color='gray', linestyle='--', alpha=0.5)
axes[1, 0].set_ylim(-0.05, 1.05)

# Legends only when something was plotted (avoids "no artists with labels" warnings).
if history_warm or history_cold:
    for ax in axes.flat:
        ax.legend(fontsize=8)

plt.suptitle('Exp 28: Contrastive Training Curves', fontsize=13)
plt.tight_layout()
plt.savefig(RESULTS_DIR / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.show()
# Free the figure; Agg is a non-interactive backend, and the PNG is already saved.
plt.close(fig)
print(f"Saved to {RESULTS_DIR / 'training_curves.png'}")

In [None]:
# Cell 12: Load 200 validation queries — multi-passage format (like Exp 22)

_rule = "=" * 70
print(_rule)
print("LOADING VALIDATION QUERIES — MULTI-PASSAGE FORMAT")
print(_rule)

dataset = load_dataset("microsoft/ms_marco", "v1.1", split="validation")
print(f"Total items: {len(dataset)}")

# Collect up to 3 * N_EVAL qualifying queries, shuffle, and keep the first N_EVAL.
# Unlike the training filter, a query whose passages are all relevant is kept.
val_queries = []
np.random.seed(SEED)  # Same seed as Exp 22 for comparability

for item in tqdm(dataset, desc="Filtering val"):
    passages_info = item.get('passages', {})
    passage_texts = passages_info.get('passage_text', [])
    is_selected = passages_info.get('is_selected', [])
    query = item.get('query', '')
    answers = item.get('answers', [])
    well_formed = item.get('wellFormedAnswers', [])

    # Need passages, a query, and at least one selected passage.
    if not (passage_texts and query and is_selected and sum(is_selected) != 0):
        continue

    # Reject the whole query if any candidate passage is over the word limit.
    word_counts = [count_words(p) for p in passage_texts]
    if max(word_counts) > MAX_PASSAGE_WORDS:
        continue

    # Require an answer: the well-formed one first, otherwise the first raw answer.
    answer = None
    if well_formed and well_formed[0] != '[]':
        answer = well_formed[0]
    elif answers and answers[0] != 'No Answer Present.':
        answer = answers[0]
    else:
        continue

    passage_list = [
        {
            'passage': ptext,
            'is_relevant': bool(sel == 1),
            'word_count': wc,
            'passage_idx': idx,
        }
        for idx, (ptext, sel, wc) in enumerate(zip(passage_texts, is_selected, word_counts))
    ]

    val_queries.append({
        'query': query,
        'answer': answer,
        'passages': passage_list,
        'n_passages': len(passage_list),
        'n_relevant': sum(p['is_relevant'] for p in passage_list),
    })

    if len(val_queries) >= N_EVAL * 3:
        break

np.random.shuffle(val_queries)
val_queries = val_queries[:N_EVAL]
N = len(val_queries)

n_val_passages = [q['n_passages'] for q in val_queries]
total_val_passages = sum(n_val_passages)
total_val_rel = sum(q['n_relevant'] for q in val_queries)

print(f"\nSelected {N} val queries ({total_val_passages} passages)")
print(f"Passages per query: mean={np.mean(n_val_passages):.1f}")
print(f"Relevant: {total_val_rel} ({100*total_val_rel/total_val_passages:.1f}%)")

del dataset
gc.collect()

In [None]:
# Cell 13: Ranking eval — score ALL passages per query under 5 conditions

_rule = "=" * 70
print(_rule)
print(f"RANKING EVALUATION ({N} queries, {total_val_passages} passages, 5 conditions)")
print(_rule)

# After a kernel restart the training cells may be skipped; reload their outputs from disk.
if 'soft_contrastive_warm' not in globals():
    soft_contrastive_warm = torch.load(SOFT_WARM_PATH).to(exp_config.device)
    print(f"Loaded contrastive_warm from {SOFT_WARM_PATH}")
if 'soft_contrastive_cold' not in globals():
    soft_contrastive_cold = torch.load(SOFT_COLD_PATH).to(exp_config.device)
    print(f"Loaded contrastive_cold from {SOFT_COLD_PATH}")

# The NLL-optimized Exp 25 prefix is always loaded from disk.
soft_exp25_fact = torch.load(EXP25_SOFT_FACT).to(exp_config.device)
print(f"Loaded exp25_fact from {EXP25_SOFT_FACT}")

# Layers whose values come from the soft-prefix pass (0..CUTOFF-1, as in training).
layer_indices = list(range(CUTOFF))

def score_baseline_nll(query_prompt, answer_text):
    """Return the answer NLL when the only context is a single BOS token.

    This is the document-free reference score. The analysis cell subtracts
    it from every passage-conditioned NLL of the same query to form the PMI
    scores.

    Args:
        query_prompt: formatted query text scored after the cache.
        answer_text: formatted answer text whose NLL is measured.
    """
    start_ids = torch.tensor([[tokenizer.bos_token_id]], device=exp_config.device)
    with torch.no_grad():
        prefill = model(
            input_ids=start_ids,
            attention_mask=torch.ones_like(start_ids),
            use_cache=True,
            return_dict=True,
        )
    start_cache = _ensure_dynamic_cache(prefill.past_key_values)
    del prefill  # only the KV cache is needed from the forward pass
    # Context length 1: the cache holds just the BOS position.
    return score_answer_with_cache(
        start_cache, 1, query_prompt, answer_text,
        model, tokenizer, exp_config)


def score_soft_passage(soft_embs, doc_ids, bos_id, doc_len, context_len,
                       bare_cache, query_prompt, answer_text):
    """Score the answer against a hybrid cache built with a soft prefix (no grad).

    Steps:
      1. Embed [BOS] + soft prefix + document and run the model to get a cache.
      2. ``extract_and_truncate_cache_with_bos(..., doc_len)`` reduces that
         cache (presumably to the BOS + document positions, dropping the soft
         prefix — confirm against the helper).
      3. ``replace_values_at_layers(bare_cache, ..., layer_indices)`` builds
         the hybrid cache. Judging by its name, it takes values from the
         soft-prefixed cache at the layers in the module-level
         ``layer_indices`` — confirm.
      4. Score the answer on a deep copy of the hybrid cache.

    NOTE(review): the caller passes the same ``bare_cache`` to several
    conditions. That is only safe if ``replace_values_at_layers`` does not
    modify ``bare_cache`` in place — verify.

    Returns:
        Whatever ``score_answer_with_cache`` returns (the caller casts it to float).
    """
    device = exp_config.device
    with torch.no_grad():
        bos_vec = embed_fn(bos_id)
        doc_vecs = embed_fn(doc_ids)
        # Match the model's embedding dtype/device before concatenating.
        prefix_vecs = soft_embs.to(device=device, dtype=bos_vec.dtype)

        seq_embeds = torch.cat([bos_vec, prefix_vecs, doc_vecs], dim=1)
        full_mask = torch.ones((1, seq_embeds.shape[1]), device=device, dtype=torch.long)

        prefixed = model(inputs_embeds=seq_embeds,
                         attention_mask=full_mask,
                         use_cache=True, return_dict=True)
        prefixed_cache = _ensure_dynamic_cache(prefixed.past_key_values)
        del prefixed

        trimmed = extract_and_truncate_cache_with_bos(prefixed_cache, doc_len)
        del prefixed_cache

        hybrid = replace_values_at_layers(bare_cache, trimmed, layer_indices)
        del trimmed

        nll = score_answer_with_cache(
            deepcopy_cache(hybrid), context_len,
            query_prompt, answer_text, model, tokenizer, exp_config)
        del hybrid

    return nll


# Checkpoint resume: reuse completed queries only when the checkpoint was
# produced for exactly the same ordered list of val queries.
all_results = []
start_idx = 0

if CHECKPOINT_EVAL_PATH.exists():
    try:
        with open(CHECKPOINT_EVAL_PATH, 'r') as f:
            ckpt = json.load(f)
    except json.JSONDecodeError:
        # A truncated file (e.g. left by an interrupted write in an older run)
        # must not block the evaluation; start over instead.
        ckpt = None
        print("Checkpoint unreadable. Starting fresh.")
    if ckpt is not None:
        ckpt_queries = ckpt.get('query_texts', [])
        current_queries = [q['query'] for q in val_queries]
        if ckpt_queries == current_queries:
            all_results = ckpt['results']
            start_idx = len(all_results)
            print(f"Resuming from checkpoint: {start_idx}/{N}")
        else:
            print("Checkpoint query mismatch. Starting fresh.")
else:
    print("No checkpoint found. Starting fresh.")

EVAL_CHECKPOINT_EVERY = 10
t_start = time.time()

# Token length of sf_str tokenized on its own (with special tokens). It does
# not depend on the passage, so it is computed once instead of per passage.
sf_len = tokenizer(sf_str, return_tensors="pt",
                   add_special_tokens=True, padding=False,
                   truncation=False)['input_ids'].shape[1]

for qidx in tqdm(range(start_idx, N), initial=start_idx, total=N, desc="Ranking eval"):
    qdata = val_queries[qidx]
    query_prompt = QUERY_TEMPLATE.format(query=qdata['query'])
    answer_text = ANSWER_TEMPLATE.format(answer=qdata['answer'])

    # Baseline NLL (BOS-only, once per query); stored with every passage of
    # this query and used for the PMI scores in the analysis cell.
    nll_baseline = score_baseline_nll(query_prompt, answer_text)

    passage_results = []
    for pinfo in qdata['passages']:
        document_text = DOCUMENT_TEMPLATE.format(document=pinfo['passage'])

        # Matched tokenization: tokenize sf_str + document as one string and
        # drop the first sf_len tokens. The document ids then come from the
        # prefixed tokenization rather than a standalone one.
        full_ids = tokenizer(sf_str + document_text, return_tensors="pt",
                             add_special_tokens=True, padding=False,
                             truncation=False)['input_ids'].to(exp_config.device)
        bos_id = full_ids[:, :1]   # first token (add_special_tokens=True)
        doc_ids = full_ids[:, sf_len:]
        doc_len = doc_ids.shape[1]
        context_len = 1 + doc_len  # BOS + document positions
        del full_ids

        # Bare cache: BOS + document with no prefix.
        bare_input = torch.cat([bos_id, doc_ids], dim=1)
        with torch.no_grad():
            bare_out = model(input_ids=bare_input,
                             attention_mask=torch.ones_like(bare_input),
                             use_cache=True, return_dict=True)
        bare_cache = _ensure_dynamic_cache(bare_out.past_key_values)
        del bare_out, bare_input

        # Condition 1: bare. Scored on a copy because bare_cache is reused by
        # the soft-prefix conditions below.
        nll_bare = score_answer_with_cache(
            deepcopy_cache(bare_cache), context_len,
            query_prompt, answer_text, model, tokenizer, exp_config)

        # Conditions 2-4 all build on the same bare_cache.
        # NOTE(review): only sound if replace_values_at_layers does not modify
        # bare_cache in place — confirm.
        # Condition 2: exp25_fact
        nll_exp25 = score_soft_passage(
            soft_exp25_fact, doc_ids, bos_id, doc_len, context_len,
            bare_cache, query_prompt, answer_text)

        # Condition 3: contrastive_warm
        nll_warm = score_soft_passage(
            soft_contrastive_warm, doc_ids, bos_id, doc_len, context_len,
            bare_cache, query_prompt, answer_text)

        # Condition 4: contrastive_cold
        nll_cold = score_soft_passage(
            soft_contrastive_cold, doc_ids, bos_id, doc_len, context_len,
            bare_cache, query_prompt, answer_text)

        del bare_cache
        gc.collect()
        torch.cuda.empty_cache()

        passage_results.append({
            'passage_idx': pinfo['passage_idx'],
            'is_relevant': pinfo['is_relevant'],
            'word_count': pinfo['word_count'],
            'doc_len': doc_len,
            'nll_bare': float(nll_bare),
            'nll_exp25_fact': float(nll_exp25),
            'nll_contrastive_warm': float(nll_warm),
            'nll_contrastive_cold': float(nll_cold),
            'nll_baseline': float(nll_baseline),
        })

    all_results.append({
        'query_idx': qidx,
        'query': qdata['query'],
        'answer': qdata['answer'],
        'n_passages': len(passage_results),
        'n_relevant': qdata['n_relevant'],
        'nll_baseline': float(nll_baseline),
        'passage_data': passage_results,
    })

    # Checkpoint. It is written to a temp file and then renamed over the real
    # path (os.replace is atomic on POSIX), so an interruption mid-write can't
    # leave a corrupt checkpoint behind.
    if (qidx + 1) % EVAL_CHECKPOINT_EVERY == 0 or qidx == N - 1:
        ckpt_data = {
            'results': all_results,
            'query_texts': [q['query'] for q in val_queries],
            'completed': len(all_results),
            'total': N,
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        }
        tmp_path = CHECKPOINT_EVAL_PATH.with_name(CHECKPOINT_EVAL_PATH.name + '.tmp')
        with open(tmp_path, 'w') as f:
            json.dump(ckpt_data, f)
        os.replace(tmp_path, CHECKPOINT_EVAL_PATH)
        elapsed = time.time() - t_start
        n_done = qidx - start_idx + 1
        rate = n_done / elapsed if elapsed > 0 else 0
        remaining = (N - qidx - 1) / rate if rate > 0 else 0
        tqdm.write(f"  Checkpoint {qidx+1}/{N} | {n_done} done in {elapsed/60:.1f}m | "
                   f"ETA: {remaining/60:.1f} min")

elapsed_total = time.time() - t_start
print(f"\nEval complete: {len(all_results)} queries in {elapsed_total/60:.1f} min")

In [None]:
# Cell 14: Ranking analysis — AUC-ROC, MRR@10, differential NLL

from sklearn.metrics import roc_auc_score, roc_curve

print("=" * 70)
print("RANKING ANALYSIS")
print("=" * 70)

# Flatten to one entry per (query, passage), in query order then passage order.
_flat = [p for r in all_results for p in r['passage_data']]
is_relevant_all = [int(p['is_relevant']) for p in _flat]
nll_bare_all = [p['nll_bare'] for p in _flat]
nll_exp25_all = [p['nll_exp25_fact'] for p in _flat]
nll_warm_all = [p['nll_contrastive_warm'] for p in _flat]
nll_cold_all = [p['nll_contrastive_cold'] for p in _flat]
nll_baseline_all = [p['nll_baseline'] for p in _flat]
del _flat

is_relevant = np.array(is_relevant_all)
nll_bare = np.array(nll_bare_all)
nll_exp25 = np.array(nll_exp25_all)
nll_warm = np.array(nll_warm_all)
nll_cold = np.array(nll_cold_all)
nll_baseline = np.array(nll_baseline_all)

# PMI-style scores: passage-conditioned NLL minus the query's BOS-only NLL
# (lower = the passage helps the answer more).
pmi_bare = nll_bare - nll_baseline
pmi_exp25 = nll_exp25 - nll_baseline
pmi_warm = nll_warm - nll_baseline
pmi_cold = nll_cold - nll_baseline

n_total = len(is_relevant)
n_rel = int(np.count_nonzero(is_relevant))
n_irr = n_total - n_rel

print(f"Total passages: {n_total}")
print(f"Relevant: {n_rel} ({100*n_rel/n_total:.1f}%), Irrelevant: {n_irr}")

# === AUC-ROC ===
scoring_methods = {
    'Raw bare': nll_bare,
    'Raw exp25_fact': nll_exp25,
    'Raw contr_warm': nll_warm,
    'Raw contr_cold': nll_cold,
    'PMI bare': pmi_bare,
    'PMI exp25_fact': pmi_exp25,
    'PMI contr_warm': pmi_warm,
    'PMI contr_cold': pmi_cold,
}

auc_results = {}
print(f"\n{'Method':<20} {'AUC':>8}")
print("-" * 30)
for name, scores in scoring_methods.items():
    # Lower score = more relevant, so negate to give roc_auc_score a
    # "higher = positive class" score.
    auc = roc_auc_score(is_relevant, -scores)
    auc_results[name] = float(auc)
    marker = " <<<" if name in ('Raw contr_warm', 'PMI contr_warm') else ""
    print(f"{name:<20} {auc:>8.3f}{marker}")

# === MRR@10 ===
def compute_mrr_at_k(all_results, score_fn, k=10):
    """Mean reciprocal rank of the first relevant passage within the top ``k``.

    Each query's passages are ranked by ``score_fn`` in ascending order
    (lower score = ranked higher). Ties keep the original passage order
    because the sort is stable. A query whose first relevant passage is not
    in the top ``k`` (or that has none) contributes 0.

    Args:
        all_results: per-query dicts, each with a ``'passage_data'`` list of
            passage dicts carrying an ``'is_relevant'`` flag.
        score_fn: maps a passage dict to its ranking score.
        k: rank cutoff.

    Returns:
        ``(mean_rr, rr_list)``: the mean over queries and the per-query
        reciprocal ranks, in query order.
    """
    rr_list = []
    for query_result in all_results:
        ranked = sorted(query_result['passage_data'], key=score_fn)
        hit_rank = next(
            (rank for rank, passage in enumerate(ranked[:k], start=1)
             if passage['is_relevant']),
            None,
        )
        rr_list.append(0.0 if hit_rank is None else 1.0 / hit_rank)
    return np.mean(rr_list), rr_list

# Per-passage ranking scores; lower = ranked higher (see compute_mrr_at_k).
mrr_fns = {
    'Raw bare': lambda p: p['nll_bare'],
    'Raw exp25_fact': lambda p: p['nll_exp25_fact'],
    'Raw contr_warm': lambda p: p['nll_contrastive_warm'],
    'Raw contr_cold': lambda p: p['nll_contrastive_cold'],
    'PMI bare': lambda p: p['nll_bare'] - p['nll_baseline'],
    'PMI exp25_fact': lambda p: p['nll_exp25_fact'] - p['nll_baseline'],
    'PMI contr_warm': lambda p: p['nll_contrastive_warm'] - p['nll_baseline'],
    'PMI contr_cold': lambda p: p['nll_contrastive_cold'] - p['nll_baseline'],
}

mrr_results = {}
mrr_per_query = {}
print(f"\n{'Method':<20} {'MRR@10':>8}")
print("-" * 30)
for name, fn in mrr_fns.items():
    mrr, rr_list = compute_mrr_at_k(all_results, fn, k=10)
    mrr_results[name] = float(mrr)
    mrr_per_query[name] = rr_list
    marker = " <<<" if name in ['Raw contr_warm', 'PMI contr_warm'] else ""
    print(f"{name:<20} {mrr:>8.3f}{marker}")


# === Differential NLL ===
def _pooled_std(a, b):
    """Pooled sample standard deviation of two groups.

    Pooled variance is ``((n_a-1)*s_a**2 + (n_b-1)*s_b**2) / (n_a+n_b-2)``,
    where ``s**2`` is the ddof=1 sample variance. This is the same pooling
    that ``stats.ttest_ind`` (equal_var=True) uses. ``(n-1)*s**2`` is just the
    sum of squared deviations, so it is computed directly; a single-element
    group then contributes 0 rather than NaN.
    """
    ss_a = np.sum((a - np.mean(a)) ** 2)
    ss_b = np.sum((b - np.mean(b)) ** 2)
    return np.sqrt((ss_a + ss_b) / (len(a) + len(b) - 2))


print(f"\n{'Method':<20} {'Mean Rel':>10} {'Mean Irr':>10} {'Diff':>10} {'d':>8}")
print("-" * 62)
diff_results = {}
for name, scores in scoring_methods.items():
    rel_vals = scores[is_relevant == 1]
    irr_vals = scores[is_relevant == 0]
    # Positive diff = relevant passages score lower (better) than irrelevant.
    diff = np.mean(irr_vals) - np.mean(rel_vals)
    pooled_std = _pooled_std(rel_vals, irr_vals)
    d = diff / pooled_std if pooled_std > 0 else 0
    t_stat, p_val = stats.ttest_ind(irr_vals, rel_vals)
    diff_results[name] = {
        'mean_relevant': float(np.mean(rel_vals)),
        'mean_irrelevant': float(np.mean(irr_vals)),
        'diff': float(diff),
        'cohens_d': float(d),
        't_stat': float(t_stat),
        'p_value': float(p_val),
    }
    print(f"{name:<20} {np.mean(rel_vals):>10.4f} {np.mean(irr_vals):>10.4f} "
          f"{diff:>+10.4f} {d:>+8.3f}")

# === Summary ===
print(f"\n{'=' * 70}")
print("SUMMARY TABLE")
print(f"{'=' * 70}")
print(f"{'Method':<20} {'AUC':>8} {'MRR@10':>8} {'Diff NLL':>10} {'d':>8}")
print("-" * 58)
for name in scoring_methods:
    print(f"{name:<20} {auc_results[name]:>8.3f} {mrr_results[name]:>8.3f} "
          f"{diff_results[name]['diff']:>+10.4f} {diff_results[name]['cohens_d']:>+8.3f}")

In [None]:
# Cell 15: NLL improvement check — does contrastive prefix still help average NLL?

print("=" * 70)
print("NLL IMPROVEMENT CHECK (avg NLL vs bare, per-query paired)")
print("=" * 70)

# For each query, average the NLL over its relevant passage(s) only (like
# Exp 25). Queries with no relevant passage are dropped entirely, so the four
# per-query lists stay aligned query-for-query.
_rel_per_query = [
    rel for rel in (
        [p for p in r['passage_data'] if p['is_relevant']] for r in all_results
    )
    if rel
]
per_query_bare = [np.mean([p['nll_bare'] for p in rel]) for rel in _rel_per_query]
per_query_exp25 = [np.mean([p['nll_exp25_fact'] for p in rel]) for rel in _rel_per_query]
per_query_warm = [np.mean([p['nll_contrastive_warm'] for p in rel]) for rel in _rel_per_query]
per_query_cold = [np.mean([p['nll_contrastive_cold'] for p in rel]) for rel in _rel_per_query]
del _rel_per_query

pq_bare, pq_exp25, pq_warm, pq_cold = (
    np.array(v) for v in (per_query_bare, per_query_exp25, per_query_warm, per_query_cold)
)

nll_conditions = {
    'exp25_fact': pq_exp25,
    'contrastive_warm': pq_warm,
    'contrastive_cold': pq_cold,
}

print(f"\n{'Condition':<20} {'Mean NLL':>10} {'Delta':>10} {'d':>8} {'Win%':>7} {'p':>12}")
print("-" * 72)
print(f"{'bare':<20} {np.mean(pq_bare):>10.4f}")

nll_improvement = {}
for name, arr in nll_conditions.items():
    # Paired per-query improvement: positive delta = lower NLL than bare.
    delta = pq_bare - arr
    d = cohens_d(delta)
    win = np.mean(delta > 0) * 100
    _, p_val = stats.ttest_1samp(delta, 0)
    if p_val < 0.001:
        sig = '***'
    elif p_val < 0.01:
        sig = '**'
    elif p_val < 0.05:
        sig = '*'
    else:
        sig = 'ns'
    print(f"{name:<20} {np.mean(arr):>10.4f} {np.mean(delta):>+10.4f} "
          f"{d:>+8.3f} {win:>6.1f}% {p_val:>12.2e} {sig}")
    nll_improvement[name] = dict(
        mean_nll=float(np.mean(arr)),
        mean_delta=float(np.mean(delta)),
        cohens_d=float(d),
        win_pct=float(win),
        p_value=float(p_val),
    )

In [None]:
# Cell 16: Plots — ROC curves (raw NLL, PMI), AUC & MRR@10 bars,
# NLL-improvement (Cohen's d vs bare) bars


def _sig_stars(p_value):
    """Map a p-value to the significance stars used in this notebook."""
    if p_value < 0.001:
        return '***'
    if p_value < 0.01:
        return '**'
    if p_value < 0.05:
        return '*'
    return 'ns'


def _plot_roc(ax, names, title):
    """Draw one ROC curve per scoring method in ``names`` on ``ax``.

    Scores are negated because a lower NLL/PMI means more relevant.
    """
    for method in names:
        fpr, tpr, _ = roc_curve(is_relevant, -scoring_methods[method])
        ax.plot(fpr, tpr, color=colors[method], linewidth=2,
                label=f"{method} (AUC={auc_results[method]:.3f})")
    ax.plot([0, 1], [0, 1], 'k--', alpha=0.3)
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title(title)
    ax.legend(fontsize=8, loc='lower right')
    ax.grid(alpha=0.3)


fig, axes = plt.subplots(2, 2, figsize=(16, 12))

colors = {
    'Raw bare': '#1f77b4',
    'Raw exp25_fact': '#2ca02c',
    'Raw contr_warm': '#ff7f0e',
    'Raw contr_cold': '#d62728',
    'PMI bare': '#1f77b4',
    'PMI exp25_fact': '#2ca02c',
    'PMI contr_warm': '#ff7f0e',
    'PMI contr_cold': '#d62728',
}

# --- Panel 1: ROC curves (Raw NLL) ---
_plot_roc(axes[0, 0],
          ['Raw bare', 'Raw exp25_fact', 'Raw contr_warm', 'Raw contr_cold'],
          'ROC Curves — Raw NLL')

# --- Panel 2: ROC curves (PMI) ---
_plot_roc(axes[0, 1],
          ['PMI bare', 'PMI exp25_fact', 'PMI contr_warm', 'PMI contr_cold'],
          'ROC Curves — PMI')

# --- Panel 3: AUC + MRR summary bars ---
ax = axes[1, 0]
methods = ['Raw bare', 'Raw exp25_fact', 'Raw contr_warm', 'Raw contr_cold']
method_labels = ['bare', 'exp25\nfact', 'contr\nwarm', 'contr\ncold']
aucs = [auc_results[m] for m in methods]
mrrs = [mrr_results[m] for m in methods]

x = np.arange(len(methods))
width = 0.35
ax.bar(x - width/2, aucs, width, label='AUC', color=[colors[m] for m in methods], alpha=0.7)
ax.bar(x + width/2, mrrs, width, label='MRR@10', color=[colors[m] for m in methods], alpha=0.4,
       edgecolor=[colors[m] for m in methods], linewidth=2)
ax.set_xticks(x)
ax.set_xticklabels(method_labels, fontsize=8)
ax.set_ylabel('Score')
ax.set_title('AUC & MRR@10 (Raw NLL)')
ax.legend(fontsize=8)
# The axis starts at 0.5 as before, but is lowered if any bar is below that,
# so low MRR/AUC bars (and their value labels) are not clipped out of view.
y_floor = min(0.5, min(aucs + mrrs) - 0.05)
ax.set_ylim(y_floor, 1.0)
ax.grid(axis='y', alpha=0.3)

for i, (a, m) in enumerate(zip(aucs, mrrs)):
    ax.text(i - width/2, a + 0.005, f'{a:.3f}', ha='center', va='bottom', fontsize=7)
    ax.text(i + width/2, m + 0.005, f'{m:.3f}', ha='center', va='bottom', fontsize=7)

# --- Panel 4: NLL improvement (per-query d vs bare) ---
ax = axes[1, 1]
nll_names = ['exp25_fact', 'contrastive_warm', 'contrastive_cold']
nll_labels = ['exp25\nfact', 'contr\nwarm', 'contr\ncold']
nll_ds = [nll_improvement[n]['cohens_d'] for n in nll_names]
nll_colors = ['#2ca02c', '#ff7f0e', '#d62728']

bars = ax.bar(range(len(nll_names)), nll_ds, color=nll_colors,
              edgecolor='black', linewidth=0.5)
ax.axhline(y=0, color='gray', linestyle='-', alpha=0.5)
ax.set_xticks(range(len(nll_names)))
ax.set_xticklabels(nll_labels, fontsize=8)
ax.set_ylabel("Cohen's d vs bare")
ax.set_title('NLL Improvement (relevant passages only)')
ax.grid(axis='y', alpha=0.3)

for i, d_val in enumerate(nll_ds):
    sig = _sig_stars(nll_improvement[nll_names[i]]['p_value'])
    # Place the label just outside the bar end (above positive, below negative).
    ax.text(i, d_val + 0.01 if d_val >= 0 else d_val - 0.02,
            f"{d_val:+.3f} {sig}", ha='center',
            va='bottom' if d_val >= 0 else 'top', fontsize=9)

plt.suptitle('Exp 28: Contrastive Ranking Soft Prefix', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig(RESULTS_DIR / 'ranking_plots.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"Saved to {RESULTS_DIR / 'ranking_plots.png'}")

In [None]:
# Cell 17: Save results.json + CSV + final verdict

# --- CSV: one row per (query, passage) with raw NLLs and PMI-style scores ---
csv_rows = [
    {
        'query_idx': r['query_idx'],
        'passage_idx': p['passage_idx'],
        'is_relevant': int(p['is_relevant']),
        'nll_bare': p['nll_bare'],
        'nll_exp25_fact': p['nll_exp25_fact'],
        'nll_contrastive_warm': p['nll_contrastive_warm'],
        'nll_contrastive_cold': p['nll_contrastive_cold'],
        'nll_baseline': p['nll_baseline'],
        'pmi_bare': p['nll_bare'] - p['nll_baseline'],
        'pmi_exp25': p['nll_exp25_fact'] - p['nll_baseline'],
        'pmi_warm': p['nll_contrastive_warm'] - p['nll_baseline'],
        'pmi_cold': p['nll_contrastive_cold'] - p['nll_baseline'],
    }
    for r in all_results
    for p in r['passage_data']
]

with open(CSV_EVAL_PATH, 'w', newline='') as f:
    writer = csv.DictWriter(f, fieldnames=list(csv_rows[0]))
    writer.writeheader()
    writer.writerows(csv_rows)
print(f"CSV saved: {CSV_EVAL_PATH} ({len(csv_rows)} rows)")

# --- Results JSON (key order is preserved in the written file) ---
final = dict(
    experiment='exp28_contrastive_ranking_soft_prefix',
    timestamp=time.strftime('%Y-%m-%d %H:%M:%S'),
    config=dict(
        model_name=MODEL_NAME,
        model_type='gemma3',
        seed=SEED,
        cutoff=CUTOFF,
        prefix_len=PREFIX_LEN,
        hidden_size=HIDDEN_SIZE,
        trainable_params=PREFIX_LEN * HIDDEN_SIZE,
        training=dict(
            margin=MARGIN,
            lr=LR,
            n_epochs=N_EPOCHS,
            grad_accum=GRAD_ACCUM,
            warmup_steps=WARMUP_STEPS,
            n_train_queries=N_TRAIN,
            loss='hinge: max(0, margin + NLL_rel - NLL_irrel)',
        ),
        eval=dict(
            n_queries=N,
            total_passages=total_val_passages,
            n_relevant=total_val_rel,
            conditions=['bare', 'exp25_fact', 'contrastive_warm', 'contrastive_cold', 'baseline'],
        ),
    ),
    # Training histories exist only if the training cells ran in this session.
    training_history=dict(
        warm=globals().get('history_warm', []),
        cold=globals().get('history_cold', []),
    ),
    ranking_analysis=dict(
        auc=auc_results,
        mrr_at_10=mrr_results,
        differential_nll=diff_results,
    ),
    nll_improvement=nll_improvement,
    reference_exp22=EXP22_REF,
    per_query_results=all_results,
)

with open(FINAL_RESULTS_PATH, 'w') as f:
    json.dump(final, f, indent=2)
print(f"Results saved: {FINAL_RESULTS_PATH} ({FINAL_RESULTS_PATH.stat().st_size / 1024:.1f} KB)")

# --- Final verdict ---
print("\n" + "=" * 70)
print("FINAL VERDICT — Exp 28: Contrastive Ranking Soft Prefix")
print("=" * 70)

print(f"\nModel: Gemma 3 4B | Cutoff: {CUTOFF} | Margin: {MARGIN}")
print(f"Training: {N_TRAIN} queries x {N_EPOCHS} epochs, lr={LR}")
print(f"Eval: {N} queries, {total_val_passages} passages")

print(f"\nRanking Results (Raw NLL):")
print(f"  bare AUC:         {auc_results['Raw bare']:.3f} (ref: {EXP22_REF['raw_bare_auc']:.3f})")
print(f"  exp25_fact AUC:   {auc_results['Raw exp25_fact']:.3f}")
print(f"  contr_warm AUC:   {auc_results['Raw contr_warm']:.3f}")
print(f"  contr_cold AUC:   {auc_results['Raw contr_cold']:.3f}")

print(f"\nRanking Results (PMI):")
print(f"  bare PMI AUC:     {auc_results['PMI bare']:.3f} (ref: {EXP22_REF['pmi_bare_auc']:.3f})")
print(f"  exp25 PMI AUC:    {auc_results['PMI exp25_fact']:.3f}")
print(f"  warm PMI AUC:     {auc_results['PMI contr_warm']:.3f}")
print(f"  cold PMI AUC:     {auc_results['PMI contr_cold']:.3f}")

print(f"\nNLL Improvement (relevant passages, d vs bare):")
for name, data in nll_improvement.items():
    sig = '***' if data['p_value'] < 0.001 else '**' if data['p_value'] < 0.01 else \
          '*' if data['p_value'] < 0.05 else 'ns'
    print(f"  {name:<20} d={data['cohens_d']:+.3f}, win={data['win_pct']:.0f}% {sig}")

# Primary success thresholds from the experiment plan (bare raw AUC 0.828,
# bare PMI AUC 0.841). Named once so the check and the printout agree.
RAW_AUC_TARGET = 0.835
PMI_AUC_TARGET = 0.845

best_raw_auc = max(auc_results['Raw contr_warm'], auc_results['Raw contr_cold'])
best_pmi_auc = max(auc_results['PMI contr_warm'], auc_results['PMI contr_cold'])

ranking_improved = (best_raw_auc > RAW_AUC_TARGET) or (best_pmi_auc > PMI_AUC_TARGET)
# The secondary criterion concerns the *contrastive* prefixes only. The
# exp25_fact reference condition must not be able to satisfy it.
nll_still_helps = any(nll_improvement[cond]['cohens_d'] > 0
                      for cond in ('contrastive_warm', 'contrastive_cold'))

if ranking_improved:
    print(f"\nVERDICT: Contrastive training IMPROVES ranking!")
    print(f"  Best raw AUC: {best_raw_auc:.3f} (target: >{RAW_AUC_TARGET})")
    print(f"  Best PMI AUC: {best_pmi_auc:.3f} (target: >{PMI_AUC_TARGET})")
    print(f"  Value contamination CAN create differential ranking signal.")
else:
    print(f"\nVERDICT: Contrastive training FAILS to improve ranking.")
    print(f"  Best raw AUC: {best_raw_auc:.3f} (target: >{RAW_AUC_TARGET}, bare: {auc_results['Raw bare']:.3f})")
    print(f"  Best PMI AUC: {best_pmi_auc:.3f} (target: >{PMI_AUC_TARGET}, bare: {auc_results['PMI bare']:.3f})")
    print(f"  CONFIRMS: Document-independent prefix cannot create query-specific")
    print(f"  relevance discrimination, even with ranking-aware training.")

if nll_still_helps:
    print(f"  Contrastive prefix still helps average NLL (secondary success).")
elif ranking_improved:
    # The ranking goal was met, so this is not a "full failure".
    print("  Contrastive prefix does NOT help average NLL (ranking-only success).")
else:
    print(f"  Contrastive training also HURTS average NLL (full failure).")

print(f"\nDone!")

In [None]:
# Cell 18: GPU cleanup

print("Cleaning up GPU memory...")
mem_before = torch.cuda.memory_allocated() / 1e9

# Drop the notebook-level references to the model, tokenizer and soft
# prefixes so their tensors can be freed. globals().pop(name, None) is a no-op
# for names that are already gone, so re-running this cell no longer raises
# NameError (a bare `del model` did), and no exec() is needed.
# NOTE(review): other live objects (e.g. embed_fn, if it closes over the model)
# may still keep GPU tensors alive — not checked here.
for var_name in ['model', 'tokenizer',
                 'soft_contrastive_warm', 'soft_contrastive_cold', 'soft_exp25_fact']:
    globals().pop(var_name, None)

gc.collect()
torch.cuda.empty_cache()
gc.collect()

mem_after = torch.cuda.memory_allocated() / 1e9
print(f"GPU memory: {mem_before:.2f} GB -> {mem_after:.2f} GB")
print("Cleanup complete.")