# Exp 25: Layer-Selective Soft Surrogates for Gemma 3

## Motivation

Experiments 16, 19, and 24 showed that **value contamination works on Gemma 3 4B**, but
only when the primed keys are stripped out and the primed values are restricted to early
layers (0-15, yielding d=**+0.227** in Exp 21). Full-cache priming fails due to late-layer
key interference.

The discrete prefix `"What are the key facts I need to know?"` (static_fact) was chosen
by hand. **Can we do better by learning an optimal continuous prefix?**

## Core Question

Can we learn a sequence of continuous embedding vectors (a "Soft Prompt") that maximizes
document value contamination for factoid QA, beating the discrete static_fact prefix when
applied via Gemma's required layer-selective hybrid cache?

## Theoretical Mechanism

We combine **Soft Prompt Tuning** with the **values_early_layers** mechanism. The model
remains completely frozen; only the soft prefix embeddings are updated. Gradients flow
backward from the answer loss, through the hybrid cache splice, and into the soft prefix
embeddings.
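
A minimal toy sketch of this gradient path, assuming PyTorch is available. The `frozen` linear layer is only a stand-in for the frozen LLM, there is no cache splice here, and all names and shapes are illustrative rather than the experiment's real code. It shows that only the prefix tensor ends up with a gradient.

```python
import torch

torch.manual_seed(0)
hidden = 8

# Stand-in for the frozen model: its parameters never receive gradients.
frozen = torch.nn.Linear(hidden, hidden)
for p in frozen.parameters():
    p.requires_grad_(False)

# Trainable soft prefix (2 positions) and fixed "document" embeddings (3 positions).
soft_prefix = torch.randn(1, 2, hidden, requires_grad=True)
doc_emb = torch.randn(1, 3, hidden)

# Concatenate along the sequence axis, run the frozen model, reduce to a scalar loss.
inputs_embeds = torch.cat([soft_prefix, doc_emb], dim=1)
loss = frozen(inputs_embeds).pow(2).mean()
loss.backward()

prefix_has_grad = soft_prefix.grad is not None and soft_prefix.grad.norm().item() > 0
frozen_has_grad = any(p.grad is not None for p in frozen.parameters())
print(prefix_has_grad, frozen_has_grad)
```

In the real experiment the same pattern applies, except that the loss is the answer NLL computed through the hybrid cache.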

## Design

| Part | Phase | Data | N | Description |
|------|-------|------|---|-------------|
| 1 | Train | MS MARCO train | 2000 | Learn soft_prefix_embeddings via gradient descent |
| 2 | Eval | MS MARCO val | 300 | Compare 4 conditions against Exp 21 baselines |

### Training (Part 1)
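- **Trainable params**: `soft_prefix_embeddings` of shape `(prefix_len, hidden_size)`,
  where `prefix_len = 7` is meant to match the static_fact token count (Cell 5 resets it
  to the tokenizer's actual count if the two differ)
- **Two init conditions**: random (Gaussian, mean 0, std 0.02) and static_fact-initialized
- **Loss**: Mean NLL of answer tokens scored through the hybrid cache
- **Optimizer**: AdamW, lr=0.1 (within the range commonly used for soft prompt tuning),
  weight decay 0.01, 50 linear warmup steps, gradient accumulation over 4 samples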
- **Epochs**: 3 passes over 2000 training samples
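
The step counts implied by these settings can be sketched as follows. The numbers are the defaults from Cell 3 (`GRAD_ACCUM_STEPS = 4`, `WARMUP_STEPS = 50`); the helper name `training_schedule` is hypothetical and exists only for this illustration.

```python
def training_schedule(n_train=2000, n_epochs=3, grad_accum=4, warmup_steps=50, lr=0.1):
    """Return forward-pass count, optimizer-step count, and a learning-rate function."""
    total_forward = n_train * n_epochs
    total_optim = total_forward // grad_accum

    def lr_at(optim_step):
        # Linear warmup over the first `warmup_steps` optimizer steps, then constant.
        if optim_step <= warmup_steps:
            return lr * (optim_step / warmup_steps)
        return lr

    return total_forward, total_optim, lr_at


total_forward, total_optim, lr_at = training_schedule()
print(total_forward, total_optim)
print(lr_at(1), lr_at(25), lr_at(50), lr_at(51))
```

With the defaults this gives 6000 forward passes and 1500 optimizer steps, so warmup covers only the first 50 of those 1500 updates.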

### Evaluation Conditions (Part 2)

| Condition | Keys | Values (L0-15) | Values (L16-33) |
|-----------|------|----------------|------------------|
| bare | bare | bare | bare |
| vel_static | bare | static_fact primed | bare |
| vel_soft_random | bare | soft (random init) primed | bare |
| vel_soft_fact | bare | soft (fact init) primed | bare |
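
The table can be read as a per-layer plan of where keys and values come from. The sketch below encodes it in plain Python; `cache_sources` and the source labels are hypothetical names used only for illustration, while the layer count (34) and cutoff (16) come from this notebook.

```python
NUM_LAYERS = 34   # layers 0-33, as in the table header
CUTOFF = 16       # primed values are used for layers 0..CUTOFF-1


def cache_sources(condition, num_layers=NUM_LAYERS, cutoff=CUTOFF):
    """Return a list of (key_source, value_source) pairs, one per layer."""
    primed = {
        "bare": None,
        "vel_static": "static_fact",
        "vel_soft_random": "soft_random",
        "vel_soft_fact": "soft_fact",
    }[condition]
    plan = []
    for layer in range(num_layers):
        use_primed = primed is not None and layer < cutoff
        # Keys always come from the bare cache; only early-layer values are swapped.
        plan.append(("bare", primed if use_primed else "bare"))
    return plan


plan = cache_sources("vel_soft_fact")
print(plan[0], plan[15], plan[16], plan[33])
```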

## Reference Values

| Source | Condition | d |
|--------|-----------|---|
| Exp 19 | values_only (all 34 layers) | +0.056 |
| Exp 19 | values_early_layers (0-16) | +0.211 |
| Exp 21 | values_early_layers (0-15) | +0.227 |
| Exp 24 | static_fact @ cutoff=16 | ~+0.21 |

## Success Criteria

- **Primary**: Does `vel_soft_fact` achieve Cohen's d > +0.25 (beating discrete +0.227)?
- **Secondary**: Does `vel_soft_random` learn anything useful from scratch (d > +0.10)?
- **Diagnostic**: Track training loss curves for convergence validation
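
For reference, a minimal sketch of a paired Cohen's d on per-query NLLs. It assumes that d is computed on the differences `bare - condition`, so that a positive d means the condition lowers NLL. The project's `lib.analysis.cohens_d` may use a different convention, and the numbers below are made up purely to exercise the function.

```python
import statistics


def paired_cohens_d(bare_nll, cond_nll):
    """Cohen's d on per-query differences (bare - condition)."""
    deltas = [b - c for b, c in zip(bare_nll, cond_nll)]
    sd = statistics.stdev(deltas)  # sample standard deviation (n - 1)
    return statistics.mean(deltas) / sd


# Illustrative values only, not experimental results.
bare = [2.0, 1.5, 3.0, 2.5]
primed = [1.8, 1.6, 2.7, 2.4]
d = paired_cohens_d(bare, primed)
print(f"d = {d:+.3f}")
```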

## Technical Watch-Outs

1. **Memory**: Backpropagating through the full model keeps activations for every layer
   in memory, so training uses batch size 1 with gradient accumulation.
2. **RoPE differentiability**: `correct_rope_positions_with_bos` uses in-place ops (and the
   existing lib helpers run under `torch.no_grad()`). The training loop therefore rebuilds
   the hybrid splice from functional ops (slicing and `torch.cat`); because it keeps the
   bare keys, no RoPE correction is needed on that path.
3. **4-bit model**: BitsAndBytes 4-bit models may not support gradient computation through
   the full forward pass. We pass `inputs_embeds` to bypass the embedding lookup, so the
   soft prefix is the only tensor that requires gradients. Cell 8 checks that gradients
   actually reach it.
4. **Learning rate**: Soft prompts typically need lr=0.1-0.3 (much higher than fine-tuning).
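
As a small worked example for watch-out 1, the sketch below checks that accumulating the gradient of `loss / accum` one sample at a time gives the same result as the gradient of the mean loss over the batch. It uses a toy scalar loss `(w - x)**2` with a hand-written derivative, which is an illustration and not the experiment's loss.

```python
def grad(w, x):
    # d/dw of the per-sample loss (w - x)**2
    return 2.0 * (w - x)


w = 0.5
samples = [1.0, 2.0, 4.0, 5.0]
accum = len(samples)

# Batch size 1: accumulate the gradient of loss / accum, one sample at a time.
accumulated = 0.0
for x in samples:
    accumulated += grad(w, x) / accum

# Reference: gradient of the mean loss over the same four samples.
batched = sum(grad(w, x) for x in samples) / accum

print(accumulated, batched)
```

This is why the training loop divides each loss by the accumulation count before calling `backward()`.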

In [None]:
# Cell 1: Setup
import os
os.umask(0o000)

import sys
import json
import time
import numpy as np
import torch
import gc
from pathlib import Path

SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)

RESULTS_DIR = Path("results/exp25")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

CHECKPOINT_TRAIN_RAND_PATH = RESULTS_DIR / "checkpoint_train_random.json"
CHECKPOINT_TRAIN_FACT_PATH = RESULTS_DIR / "checkpoint_train_fact.json"
CHECKPOINT_EVAL_PATH = RESULTS_DIR / "checkpoint_eval.json"
FINAL_RESULTS_PATH = RESULTS_DIR / "results.json"
CSV_EVAL_PATH = RESULTS_DIR / "eval_results.csv"
SOFT_RANDOM_PATH = RESULTS_DIR / "soft_prefix_random.pt"
SOFT_FACT_PATH = RESULTS_DIR / "soft_prefix_fact.pt"

print(f"SEED: {SEED}")
print(f"Results directory: {RESULTS_DIR}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Cell 2: Load Gemma 3 4B via load_model()
sys.path.insert(0, ".")

from lib.config import ExperimentConfig
from lib.model_utils import load_model

MODEL_NAME = "google/gemma-3-4b-it"

exp_config = ExperimentConfig(
    model_name=MODEL_NAME,
    model_type="gemma3",
    compute_dtype="auto",  # resolves to bfloat16 for gemma3
    use_4bit=True,
    num_samples=2000,
    seed=SEED,
)

print(f"Loading {MODEL_NAME} (4-bit, bfloat16)...")
model, tokenizer = load_model(exp_config)

# Architecture diagnostics
from lib.kv_cache import (
    _get_text_config, _get_head_dim, _get_rope_theta_for_layer,
    _get_cache_keys, _get_cache_values, _set_cache_keys, _set_cache_values,
    _ensure_dynamic_cache, extract_and_truncate_cache_with_bos,
    correct_rope_positions_with_bos, score_answer_with_cache,
    deepcopy_cache, replace_values_at_layers,
)

text_config = _get_text_config(model.config)
NUM_LAYERS = text_config.num_hidden_layers
HIDDEN_SIZE = text_config.hidden_size
HEAD_DIM = _get_head_dim(model.config)

print(f"\nModel loaded successfully.")
print(f"  Num layers: {NUM_LAYERS}")
print(f"  Hidden size: {HIDDEN_SIZE}")
print(f"  Head dim: {HEAD_DIM}")
print(f"  Num KV heads: {text_config.num_key_value_heads}")
print(f"  BOS token ID: {tokenizer.bos_token_id}")

# Verify dtype
sample_ids = tokenizer("test", return_tensors="pt")['input_ids'].to(exp_config.device)
with torch.no_grad():
    out = model(sample_ids, use_cache=True)
    cache_check = _ensure_dynamic_cache(out.past_key_values)
    k0 = _get_cache_keys(cache_check, 0)
    print(f"  Cache key dtype: {k0.dtype}")
    print(f"  Cache key shape: {k0.shape}")
del out, sample_ids, cache_check
torch.cuda.empty_cache()

In [None]:
# Cell 3: Constants, templates, imports
from lib.analysis import cohens_d
from lib.data import count_words
from lib.surrogate import STATIC_SURROGATE_QUERIES
from scipy import stats
from tqdm.auto import tqdm

# Templates
SURROGATE_PREFIX_TEMPLATE = "{surrogate}\n"
DOCUMENT_TEMPLATE = "{document}"
QUERY_TEMPLATE = "\nQuery: {query}\nAnswer:"
ANSWER_TEMPLATE = " {answer}"

STATIC_FACT = STATIC_SURROGATE_QUERIES['static_factual']['query']

# Experiment parameters
N_TRAIN = 2000
N_EVAL = 300
MAX_PASSAGE_WORDS = 300
CUTOFF = 16          # layers 0-15
PREFIX_LEN = 7       # match static_fact token count
CHECKPOINT_EVERY = 50

# Training hyperparameters
LR = 0.1
N_EPOCHS = 3
GRAD_ACCUM_STEPS = 4  # effective batch size = 4
WARMUP_STEPS = 50

# Reference values
EXP19_REF = {'values_only_d': 0.056, 'values_early_layers_d': 0.211}
EXP21_REF = {'values_early_layers_d': 0.227}

print("Config ready")
print(f"  Model: {MODEL_NAME}")
print(f"  Num layers: {NUM_LAYERS}, hidden_size: {HIDDEN_SIZE}")
print(f"  Cutoff: {CUTOFF} (layers 0-{CUTOFF-1})")
print(f"  Soft prefix length: {PREFIX_LEN} tokens")
print(f"  Trainable params: {PREFIX_LEN * HIDDEN_SIZE:,} ({PREFIX_LEN} x {HIDDEN_SIZE})")
print(f"  Training: {N_TRAIN} samples x {N_EPOCHS} epochs, lr={LR}, grad_accum={GRAD_ACCUM_STEPS}")
print(f"  Eval: {N_EVAL} samples, 4 conditions")
print(f"  Static fact prefix: '{STATIC_FACT}'")
print(f"\nReference values:")
print(f"  Exp 19 values_early_layers: d={EXP19_REF['values_early_layers_d']:+.3f}")
print(f"  Exp 21 values_early_layers: d={EXP21_REF['values_early_layers_d']:+.3f}")

In [None]:
# Cell 4: Load MS MARCO training + validation splits
from datasets import load_dataset

def load_marco_split(split_name, n_samples, seed):
    """Load MS MARCO samples with positive passages."""
    dataset = load_dataset("microsoft/ms_marco", "v1.1", split=split_name,
                           trust_remote_code=True)
    print(f"Total items in {split_name}: {len(dataset)}")

    samples = []
    np.random.seed(seed)

    for item in tqdm(dataset, desc=f"Filtering {split_name}"):
        passages_info = item.get('passages', {})
        passage_texts = passages_info.get('passage_text', [])
        is_selected = passages_info.get('is_selected', [])
        query = item.get('query', '')
        answers = item.get('answers', [])
        well_formed = item.get('wellFormedAnswers', [])

        if not passage_texts or not query:
            continue

        answer = None
        if well_formed and len(well_formed) > 0 and well_formed[0] != '[]':
            answer = well_formed[0]
        elif answers and len(answers) > 0 and answers[0] != 'No Answer Present.':
            answer = answers[0]
        else:
            continue

        for ptext, sel in zip(passage_texts, is_selected):
            if sel == 1 and count_words(ptext) <= MAX_PASSAGE_WORDS:
                samples.append({
                    'query': query,
                    'answer': answer,
                    'passage': ptext,
                    'word_count': count_words(ptext),
                })
                break

        if len(samples) >= n_samples * 3:
            break

    np.random.shuffle(samples)
    samples = samples[:n_samples]
    del dataset
    gc.collect()
    return samples

print("=" * 70)
print("LOADING MS MARCO v1.1 — TRAINING SPLIT")
print("=" * 70)
train_samples = load_marco_split("train", N_TRAIN, SEED)
print(f"Selected {len(train_samples)} training samples")
print(f"Word counts: mean={np.mean([q['word_count'] for q in train_samples]):.0f}, "
      f"min={min(q['word_count'] for q in train_samples)}, "
      f"max={max(q['word_count'] for q in train_samples)}")

print("\n" + "=" * 70)
print("LOADING MS MARCO v1.1 — VALIDATION SPLIT")
print("=" * 70)
eval_samples = load_marco_split("validation", N_EVAL, SEED + 1)
print(f"Selected {len(eval_samples)} eval samples")
print(f"Word counts: mean={np.mean([q['word_count'] for q in eval_samples]):.0f}, "
      f"min={min(q['word_count'] for q in eval_samples)}, "
      f"max={max(q['word_count'] for q in eval_samples)}")

In [None]:
# Cell 5: Tokenize static_fact prefix + BPE boundary check

print("=" * 70)
print("PREFIX TOKENIZATION")
print("=" * 70)

sf_str = SURROGATE_PREFIX_TEMPLATE.format(surrogate=STATIC_FACT)
sf_ids = tokenizer(sf_str, return_tensors="pt",
                    add_special_tokens=False)['input_ids'].to(exp_config.device)
SF_TOKEN_LEN = sf_ids.shape[1]

print(f"Static fact prefix: '{STATIC_FACT}'")
print(f"  Formatted: '{sf_str.strip()}'")
print(f"  Token length: {SF_TOKEN_LEN}")
print(f"  Soft prefix length: {PREFIX_LEN} (matches: {PREFIX_LEN == SF_TOKEN_LEN})")

# If mismatch, update PREFIX_LEN to match
if PREFIX_LEN != SF_TOKEN_LEN:
    print(f"  WARNING: Updating PREFIX_LEN from {PREFIX_LEN} to {SF_TOKEN_LEN}")
    PREFIX_LEN = SF_TOKEN_LEN

# BPE boundary check
example_doc = train_samples[0]['passage']
concat = sf_str + DOCUMENT_TEMPLATE.format(document=example_doc)
concat_enc = tokenizer(concat, add_special_tokens=True)['input_ids']
prefix_enc = tokenizer(sf_str, add_special_tokens=True)['input_ids']
doc_ids_from_concat = concat_enc[len(prefix_enc):]
bare_doc_enc = tokenizer(DOCUMENT_TEMPLATE.format(document=example_doc),
                          add_special_tokens=False)['input_ids']
match = sum(1 for a, b in zip(doc_ids_from_concat, bare_doc_enc) if a == b)
total = max(len(bare_doc_enc), 1)
print(f"\nBPE boundary check: {match}/{total} tokens match ({100*match/total:.1f}%)")

In [None]:
# Cell 6: Explain experimental conditions

print("=" * 70)
print("EXPERIMENTAL CONDITIONS")
print("=" * 70)

print("\n### Part 1: Training the Soft Prefix ###")
print(f"  Data: {N_TRAIN} MS MARCO training queries")
print(f"  Trainable: soft_prefix_embeddings ({PREFIX_LEN} x {HIDDEN_SIZE} = {PREFIX_LEN * HIDDEN_SIZE:,} params)")
print(f"  Two runs: random init (N(0, 0.02)) and static_fact init")
print(f"  Epochs: {N_EPOCHS}, lr={LR}, grad_accum={GRAD_ACCUM_STEPS}")
print()
print("  Per training step:")
print("    1. Build inputs_embeds = [BOS_emb] + [soft_prefix (grad)] + [doc_embs (detached)]")
print("    2. Forward pass -> get primed cache (gradients flow through soft prefix)")
print("    3. Extract primed values at layers 0-15 (functional ops, no in-place mutation)")
print("    4. Build bare cache (no grad)")
print("    5. Splice: bare keys + primed values (L0-15) + bare values (L16-33)")
print("    6. Score answer NLL through hybrid cache")
print("    7. loss.backward() -> updates only soft_prefix_embeddings")

print("\n### Part 2: Evaluation (4 conditions) ###")
print(f"  Data: {N_EVAL} MS MARCO validation queries")
print()
print("  bare:")
print("    Cache: [BOS][doc] -> score as-is")
print("    Baseline, no modifications.")
print()
print("  vel_static:")
print("    Cache: [BOS][static_fact][doc] -> truncate -> RoPE correct")
print(f"    Replace values at layers 0-{CUTOFF-1} into bare cache")
print(f"    This is the Exp 21 condition (d=+0.227). Ceiling to beat.")
print()
print("  vel_soft_random:")
print("    Cache: [BOS][soft_random_embs][doc_embs] -> forward -> extract values L0-15")
print(f"    Splice into bare cache. Tests: can random-init soft prefix learn useful values?")
print()
print("  vel_soft_fact:")
print("    Cache: [BOS][soft_fact_embs][doc_embs] -> forward -> extract values L0-15")
print(f"    Splice into bare cache. Tests: can we refine static_fact in continuous space?")
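
The position bookkeeping behind steps 3 and 5 above can be sketched with plain lists: the primed sequence is `[BOS][prefix][doc]`, and the splice keeps the BOS slot plus the last `doc_len` slots so that it lines up with the bare `[BOS][doc]` cache. The helper name `splice_positions` is hypothetical, and the sketch assumes `doc_len >= 1`.

```python
def splice_positions(primed_positions, doc_len):
    """Keep the BOS slot and the last `doc_len` slots, dropping the prefix slots."""
    return primed_positions[:1] + primed_positions[-doc_len:]


primed = ["BOS", "p0", "p1", "p2", "d0", "d1", "d2", "d3"]
kept = splice_positions(primed, doc_len=4)
print(kept)
```
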

In [None]:
# Cell 7: Differentiable hybrid cache scoring function
#
# The existing lib functions use in-place ops and torch.no_grad().
# For training we need a fully differentiable path from soft embeddings to loss.

from lib.kv_cache import _get_rope_theta_for_layer, _build_rope_correction, _rotate_half

# Get the embedding layer using standard HuggingFace API
embed_fn = model.get_input_embeddings()
print(f"Embedding layer found: {type(embed_fn).__name__}")
print(f"Embedding weight shape: {embed_fn.weight.shape}")


def differentiable_hybrid_score(
    soft_prefix: torch.Tensor,       # (1, prefix_len, hidden_size), requires_grad
    doc_ids: torch.Tensor,            # (1, doc_len)
    bos_id: torch.Tensor,             # (1, 1)
    query_prompt: str,
    answer_text: str,
    model,
    tokenizer,
    config,
    cutoff: int,
):
    """
    Compute answer NLL through a hybrid cache built from soft prefix embeddings.

    Returns a scalar loss tensor with gradients flowing back to soft_prefix.
    """
    device = config.device
    doc_len = doc_ids.shape[1]
    prefix_len = soft_prefix.shape[1]
    context_len = 1 + doc_len  # BOS + doc

    # --- Step 1: Bare cache (no gradients needed) ---
    bare_input = torch.cat([bos_id, doc_ids], dim=1)
    with torch.no_grad():
        bare_out = model(input_ids=bare_input,
                         attention_mask=torch.ones_like(bare_input),
                         use_cache=True, return_dict=True)
    bare_cache = bare_out.past_key_values
    del bare_out

    # --- Step 2: Primed cache via inputs_embeds (gradients enabled) ---
    # Get embeddings for BOS and doc tokens (detached from graph)
    with torch.no_grad():
        bos_emb = embed_fn(bos_id)          # (1, 1, hidden)
        doc_emb = embed_fn(doc_ids)          # (1, doc_len, hidden)

    # Cast soft prefix to model dtype for forward pass
    soft_cast = soft_prefix.to(dtype=bos_emb.dtype)

    # Concatenate: [BOS_emb, soft_prefix, doc_emb]
    inputs_embeds = torch.cat([bos_emb.detach(), soft_cast, doc_emb.detach()], dim=1)
    total_len = inputs_embeds.shape[1]  # 1 + prefix_len + doc_len
    attn_mask = torch.ones((1, total_len), device=device, dtype=torch.long)

    # Forward pass with gradients flowing through soft_prefix
    primed_out = model(inputs_embeds=inputs_embeds,
                       attention_mask=attn_mask,
                       use_cache=True, return_dict=True)
    primed_cache = primed_out.past_key_values
    del primed_out

    # --- Step 3+4: Build hybrid cache ---
    # Keys: from bare cache (already at correct positions)
    # Values L0-{cutoff-1}: from primed cache (BOS + last doc_len positions)
    # Values L{cutoff}-{N-1}: from bare cache
    # No RoPE correction needed — we use bare keys, and values have no RoPE.

    primed_cache_dc = _ensure_dynamic_cache(primed_cache)
    bare_cache_dc = _ensure_dynamic_cache(bare_cache)

    from transformers import DynamicCache
    from transformers.cache_utils import DynamicSlidingWindowLayer, DynamicLayer

    hybrid_cache = DynamicCache()
    for layer_idx in range(NUM_LAYERS):
        k = _get_cache_keys(bare_cache_dc, layer_idx)

        if layer_idx < cutoff:
            primed_v = _get_cache_values(primed_cache_dc, layer_idx)
            bos_v = primed_v[:, :, :1, :]
            doc_v = primed_v[:, :, -doc_len:, :]
            v = torch.cat([bos_v, doc_v], dim=2)
        else:
            v = _get_cache_values(bare_cache_dc, layer_idx)

        src_layer = bare_cache_dc.layers[layer_idx]
        if isinstance(src_layer, DynamicSlidingWindowLayer):
            new_layer = DynamicSlidingWindowLayer(sliding_window=src_layer.sliding_window)
            new_layer.dtype = k.dtype
            new_layer.device = k.device
            new_layer.keys = k
            new_layer.values = v
            new_layer.is_initialized = True
            new_layer.cumulative_length = src_layer.cumulative_length
            new_layer._sliding_window_tensor = new_layer._sliding_window_tensor.to(k.device)
        else:
            new_layer = DynamicLayer()
            new_layer.dtype = k.dtype
            new_layer.device = k.device
            new_layer.keys = k
            new_layer.values = v
            new_layer.is_initialized = True
        hybrid_cache.layers.append(new_layer)

    # --- Step 5: Score answer through hybrid cache ---
    query_ids = tokenizer(query_prompt, return_tensors="pt",
                          add_special_tokens=False)['input_ids'].to(device)
    answer_ids = tokenizer(answer_text, return_tensors="pt",
                           add_special_tokens=False)['input_ids'].to(device)
    query_len = query_ids.shape[1]
    answer_len = answer_ids.shape[1]

    # Single forward pass: query + answer together
    qa_ids = torch.cat([query_ids, answer_ids], dim=1)
    qa_len = qa_ids.shape[1]
    qa_attn_full = torch.ones((1, context_len + qa_len), device=device, dtype=torch.long)

    qa_out = model(input_ids=qa_ids,
                   attention_mask=qa_attn_full,
                   past_key_values=hybrid_cache,
                   use_cache=False, return_dict=True)

    logits = qa_out.logits
    # logit at position [query_len-1] predicts answer_ids[0]
    answer_logits = logits[:, query_len - 1 : query_len + answer_len - 1, :]

    loss = torch.nn.functional.cross_entropy(
        answer_logits.reshape(-1, answer_logits.shape[-1]),
        answer_ids.reshape(-1),
        reduction='mean'
    )

    del qa_out, logits, bare_cache, bare_cache_dc, primed_cache, primed_cache_dc

    return loss


print("differentiable_hybrid_score() defined")
print("  Input: soft_prefix (requires_grad), doc_ids, query, answer")
print("  Output: scalar NLL loss with gradients to soft_prefix")
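
The logit slicing in step 5 relies on the usual one-position shift of causal language models: the logit at position `i` predicts token `i + 1`. A minimal index-only sketch follows; the helper name `answer_logit_slice` is hypothetical.

```python
def answer_logit_slice(query_len, answer_len):
    """Half-open range of logit positions whose predictions are the answer tokens."""
    start = query_len - 1              # this logit predicts answer token 0
    stop = query_len + answer_len - 1  # exclusive; the final logit is not scored
    return start, stop


start, stop = answer_logit_slice(query_len=5, answer_len=3)
# Pair each logit position with the index of the answer token it predicts.
pairs = [(pos, pos - start) for pos in range(start, stop)]
print(pairs)
```
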

In [None]:
# Cell 8: Gradient flow sanity check
# Verify that gradients actually flow back to the soft prefix before committing
# to the full training loop.

print("=" * 70)
print("GRADIENT FLOW SANITY CHECK")
print("=" * 70)

# Create a tiny soft prefix
test_prefix = torch.randn(1, PREFIX_LEN, HIDDEN_SIZE,
                           device=exp_config.device,
                           dtype=torch.float32,
                           requires_grad=True)

# Use first training sample
sample = train_samples[0]
sf_str_test = SURROGATE_PREFIX_TEMPLATE.format(surrogate=STATIC_FACT)
doc_text = DOCUMENT_TEMPLATE.format(document=sample['passage'])

# Matched tokenization
full_text = sf_str_test + doc_text
full_enc = tokenizer(full_text, return_tensors="pt",
                      add_special_tokens=True, padding=False, truncation=False)
full_ids = full_enc['input_ids'].to(exp_config.device)
sf_prefix_enc = tokenizer(sf_str_test, return_tensors="pt",
                           add_special_tokens=True, padding=False, truncation=False)
sf_prefix_len_matched = sf_prefix_enc['input_ids'].shape[1]
bos_id = full_ids[:, :1]
doc_ids = full_ids[:, sf_prefix_len_matched:]

query_prompt = QUERY_TEMPLATE.format(query=sample['query'])
answer_text = ANSWER_TEMPLATE.format(answer=sample['answer'])

print(f"Test sample: doc_len={doc_ids.shape[1]}, query='{sample['query'][:50]}...'")
print(f"Test prefix shape: {test_prefix.shape}")

test_loss = None
try:
    test_loss = differentiable_hybrid_score(
        test_prefix, doc_ids, bos_id,
        query_prompt, answer_text,
        model, tokenizer, exp_config, CUTOFF)

    print(f"\nForward pass: loss = {test_loss.item():.4f}")
    print(f"Loss requires_grad: {test_loss.requires_grad}")

    test_loss.backward()

    print(f"Backward pass: SUCCESS")
    print(f"Gradient shape: {test_prefix.grad.shape}")
    print(f"Gradient norm: {test_prefix.grad.norm().item():.6f}")
    print(f"Gradient mean: {test_prefix.grad.mean().item():.6f}")
    print(f"Gradient max: {test_prefix.grad.abs().max().item():.6f}")

    if test_prefix.grad.norm().item() > 0:
        print("\n>>> GRADIENT FLOW CONFIRMED. Training loop should work. <<<")
    else:
        print("\n>>> WARNING: Zero gradients. Check computational graph. <<<")

except Exception as e:
    print(f"\n>>> GRADIENT FLOW FAILED: {e} <<<")
    print("Training will not work. Need to debug the differentiable path.")
    import traceback
    traceback.print_exc()

finally:
    del test_prefix
    if test_loss is not None:
        del test_loss
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Cell 9: Training loop

def train_soft_prefix(init_mode, train_data, n_epochs, lr, grad_accum,
                       checkpoint_path, save_path, warmup_steps=50):
    """
    Train soft prefix embeddings via gradient descent on answer NLL.

    Args:
        init_mode: 'random' or 'fact' (initialize from static_fact embeddings)
        train_data: list of query dicts
        n_epochs: number of passes over data
        lr: learning rate
        grad_accum: gradient accumulation steps
        checkpoint_path: path to save training checkpoints
        save_path: path to save final trained embeddings
        warmup_steps: linear warmup steps

    Returns:
        dict with training history and final embeddings
    """
    print(f"\n{'=' * 70}")
    print(f"TRAINING SOFT PREFIX — init={init_mode}")
    print(f"{'=' * 70}")

    # Initialize soft prefix
    if init_mode == 'random':
        soft_prefix = torch.randn(1, PREFIX_LEN, HIDDEN_SIZE,
                                   device=exp_config.device, dtype=torch.float32) * 0.02
    elif init_mode == 'fact':
        # Get embeddings for static_fact tokens
        with torch.no_grad():
            fact_emb = embed_fn(sf_ids)  # (1, prefix_len, hidden)
        soft_prefix = fact_emb.float().clone()
    else:
        raise ValueError(f"Unknown init_mode: {init_mode}")

    soft_prefix = soft_prefix.detach().requires_grad_(True)

    print(f"  Soft prefix shape: {soft_prefix.shape}")
    print(f"  Soft prefix dtype: {soft_prefix.dtype}")
    print(f"  Soft prefix norm: {soft_prefix.norm().item():.4f}")

    optimizer = torch.optim.AdamW([soft_prefix], lr=lr, weight_decay=0.01)

    total_steps = n_epochs * len(train_data)
    total_optim_steps = total_steps // grad_accum
    print(f"  Total forward passes: {total_steps}")
    print(f"  Total optimizer steps: {total_optim_steps}")
    print(f"  Warmup steps: {warmup_steps}")

    # Checkpoint resume
    # Note: only the prefix and history are checkpointed; AdamW moment estimates
    # restart from zero after a resume.
    history = []
    start_step = 0
    optim_step = 0
    if checkpoint_path.exists():
        ckpt = json.loads(checkpoint_path.read_text())
        if ckpt.get('init_mode') == init_mode and ckpt.get('total_steps') == total_steps:
            history = ckpt['history']
            start_step = ckpt['completed_steps']
            # Restore the optimizer-step counter so linear warmup does not restart
            optim_step = history[-1]['optim_step'] if history else 0
            # Restore soft_prefix
            soft_prefix_data = torch.tensor(ckpt['soft_prefix'],
                                            device=exp_config.device, dtype=torch.float32)
            soft_prefix = soft_prefix_data.requires_grad_(True)
            optimizer = torch.optim.AdamW([soft_prefix], lr=lr, weight_decay=0.01)
            print(f"  Resumed from checkpoint: step {start_step}/{total_steps}")

    t_start = time.time()
    step = 0
    running_loss = 0.0
    running_count = 0

    for epoch in range(n_epochs):
        # Shuffle training data each epoch
        np.random.seed(SEED + epoch)
        epoch_indices = np.random.permutation(len(train_data))

        for idx_in_epoch, data_idx in enumerate(epoch_indices):
            # Skip already-completed steps
            if step < start_step:
                step += 1
                continue

            sample = train_data[data_idx]
            doc_text = DOCUMENT_TEMPLATE.format(document=sample['passage'])
            query_prompt = QUERY_TEMPLATE.format(query=sample['query'])
            answer_text = ANSWER_TEMPLATE.format(answer=sample['answer'])

            # Matched tokenization using sf_str as reference
            full_text = sf_str + doc_text
            full_enc = tokenizer(full_text, return_tensors="pt",
                                  add_special_tokens=True, padding=False, truncation=False)
            full_ids_t = full_enc['input_ids'].to(exp_config.device)
            sf_enc = tokenizer(sf_str, return_tensors="pt",
                                add_special_tokens=True, padding=False, truncation=False)
            sf_len = sf_enc['input_ids'].shape[1]
            bos_t = full_ids_t[:, :1]
            doc_ids_t = full_ids_t[:, sf_len:]

            try:
                loss = differentiable_hybrid_score(
                    soft_prefix, doc_ids_t, bos_t,
                    query_prompt, answer_text,
                    model, tokenizer, exp_config, CUTOFF)

                # Scale loss for gradient accumulation
                scaled_loss = loss / grad_accum
                scaled_loss.backward()

                running_loss += loss.item()
                running_count += 1

            except RuntimeError as e:
                print(f"  Step {step}: RuntimeError: {e}")
                optimizer.zero_grad()
                gc.collect()
                torch.cuda.empty_cache()
                step += 1
                continue

            # Optimizer step
            if (step + 1) % grad_accum == 0:
                # Linear warmup
                optim_step += 1
                if optim_step <= warmup_steps:
                    warmup_factor = optim_step / warmup_steps
                    for pg in optimizer.param_groups:
                        pg['lr'] = lr * warmup_factor

                # Gradient clipping
                grad_norm = soft_prefix.grad.norm().item() if soft_prefix.grad is not None else 0
                torch.nn.utils.clip_grad_norm_([soft_prefix], max_norm=1.0)

                optimizer.step()
                optimizer.zero_grad()

                avg_loss = running_loss / running_count if running_count > 0 else 0
                history.append({
                    'step': step,
                    'optim_step': optim_step,
                    'epoch': epoch,
                    'avg_loss': avg_loss,
                    'grad_norm': grad_norm,
                    'prefix_norm': soft_prefix.norm().item(),
                    'lr': optimizer.param_groups[0]['lr'],
                })
                running_loss = 0.0
                running_count = 0

            # Cleanup
            del loss, scaled_loss
            gc.collect()
            torch.cuda.empty_cache()

            step += 1

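            # Checkpoint (if CHECKPOINT_EVERY is not a multiple of grad_accum, a resume
            # can discard a partially accumulated gradient window)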
            if step % CHECKPOINT_EVERY == 0 or step == total_steps:
                elapsed = time.time() - t_start
                steps_done = step - start_step
                rate = steps_done / elapsed if elapsed > 0 else 0
                remaining = (total_steps - step) / rate if rate > 0 else 0

                last_loss = history[-1]['avg_loss'] if history else 0
                tqdm.write(f"  [{init_mode}] Step {step}/{total_steps} | "
                           f"loss={last_loss:.4f} | "
                           f"prefix_norm={soft_prefix.norm().item():.3f} | "
                           f"ETA: {remaining/60:.1f}m")

                ckpt_data = {
                    'init_mode': init_mode,
                    'completed_steps': step,
                    'total_steps': total_steps,
                    'history': history,
                    'soft_prefix': soft_prefix.detach().cpu().tolist(),
                    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
                }
                with open(checkpoint_path, 'w') as f:
                    json.dump(ckpt_data, f)

    # Save final embeddings
    torch.save(soft_prefix.detach().cpu(), save_path)

    elapsed = time.time() - t_start
    print(f"\n  Training complete: {step} steps in {elapsed/60:.1f} min")
    print(f"  Final prefix norm: {soft_prefix.norm().item():.4f}")
    print(f"  Saved to: {save_path}")

    return {
        'soft_prefix': soft_prefix.detach(),
        'history': history,
        'init_mode': init_mode,
    }


print("train_soft_prefix() defined")

In [None]:
# Cell 10: Train random-init soft prefix

result_random = train_soft_prefix(
    init_mode='random',
    train_data=train_samples,
    n_epochs=N_EPOCHS,
    lr=LR,
    grad_accum=GRAD_ACCUM_STEPS,
    checkpoint_path=CHECKPOINT_TRAIN_RAND_PATH,
    save_path=SOFT_RANDOM_PATH,
    warmup_steps=WARMUP_STEPS,
)
soft_prefix_random = result_random['soft_prefix']
history_random = result_random['history']

In [None]:
# Cell 11: Train fact-init soft prefix

result_fact = train_soft_prefix(
    init_mode='fact',
    train_data=train_samples,
    n_epochs=N_EPOCHS,
    lr=LR,
    grad_accum=GRAD_ACCUM_STEPS,
    checkpoint_path=CHECKPOINT_TRAIN_FACT_PATH,
    save_path=SOFT_FACT_PATH,
    warmup_steps=WARMUP_STEPS,
)
soft_prefix_fact = result_fact['soft_prefix']
history_fact = result_fact['history']

In [None]:
# Cell 12: Training curves visualization

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for hist, label, color in [(history_random, 'random init', '#1f77b4'),
                            (history_fact, 'fact init', '#ff7f0e')]:
    if not hist:
        continue
    steps = [h['optim_step'] for h in hist]
    losses = [h['avg_loss'] for h in hist]
    gnorms = [h['grad_norm'] for h in hist]
    pnorms = [h['prefix_norm'] for h in hist]

    # Smoothed loss (rolling average, window of at most 20 optimizer steps)
    w = min(20, len(losses) // 3 + 1)
    smoothed = np.convolve(losses, np.ones(w)/w, mode='valid') if len(losses) > w else losses
    smooth_steps = steps[w-1:] if len(losses) > w else steps

    axes[0].plot(steps, losses, alpha=0.2, color=color)
    axes[0].plot(smooth_steps, smoothed, linewidth=2, color=color, label=label)
    axes[1].plot(steps, gnorms, alpha=0.5, linewidth=1, color=color, label=label)
    axes[2].plot(steps, pnorms, linewidth=2, color=color, label=label)

axes[0].set_xlabel('Optimizer Step')
axes[0].set_ylabel('Loss (NLL)')
axes[0].set_title('Training Loss')
axes[0].legend()

axes[1].set_xlabel('Optimizer Step')
axes[1].set_ylabel('Gradient Norm')
axes[1].set_title('Gradient Norm')
axes[1].legend()

axes[2].set_xlabel('Optimizer Step')
axes[2].set_ylabel('Prefix Embedding Norm')
axes[2].set_title('Prefix Norm')
axes[2].legend()

plt.suptitle('Exp 25: Soft Prefix Training Curves', fontsize=13)
plt.tight_layout()
plt.savefig(RESULTS_DIR / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"Training curves saved to {RESULTS_DIR / 'training_curves.png'}")

In [None]:
# Cell 13: Part 2 — Evaluation (300 validation queries, 4 conditions)

print("=" * 70)
print(f"PART 2: EVALUATION ({N_EVAL} queries, 4 conditions)")
print("=" * 70)

# Load trained soft prefixes from disk if the training cells were skipped (e.g. after a restart)
if 'soft_prefix_random' not in globals():
    soft_prefix_random = torch.load(SOFT_RANDOM_PATH, map_location=exp_config.device)
    print(f"Loaded soft_prefix_random from {SOFT_RANDOM_PATH}")
if 'soft_prefix_fact' not in globals():
    soft_prefix_fact = torch.load(SOFT_FACT_PATH, map_location=exp_config.device)
    print(f"Loaded soft_prefix_fact from {SOFT_FACT_PATH}")

# Checkpoint resume
eval_results = []
eval_start_idx = 0

if CHECKPOINT_EVAL_PATH.exists():
    with open(CHECKPOINT_EVAL_PATH, 'r') as f:
        ckpt = json.load(f)
    ckpt_queries = ckpt.get('query_texts', [])
    current_queries = [q['query'] for q in eval_samples[:N_EVAL]]
    if ckpt_queries == current_queries:
        eval_results = ckpt['results']
        eval_start_idx = len(eval_results)
        print(f"Resuming from checkpoint: {eval_start_idx}/{N_EVAL}")
    else:
        print("Checkpoint query mismatch. Starting fresh.")
else:
    print("No checkpoint found. Starting fresh.")

layer_indices = list(range(CUTOFF))
t_start = time.time()

for qidx in tqdm(range(eval_start_idx, N_EVAL), initial=eval_start_idx, total=N_EVAL,
                  desc="Eval"):
    qdata = eval_samples[qidx]
    query_prompt = QUERY_TEMPLATE.format(query=qdata['query'])
    answer_text = ANSWER_TEMPLATE.format(answer=qdata['answer'])
    passage = qdata['passage']
    document_text = DOCUMENT_TEMPLATE.format(document=passage)

    # Matched tokenization (using sf_str as reference for BPE boundaries)
    full_text = sf_str + document_text
    full_enc = tokenizer(full_text, return_tensors="pt",
                          add_special_tokens=True, padding=False, truncation=False)
    full_ids = full_enc['input_ids'].to(exp_config.device)
    sf_prefix_enc = tokenizer(sf_str, return_tensors="pt",
                               add_special_tokens=True, padding=False, truncation=False)
    sf_prefix_len_matched = sf_prefix_enc['input_ids'].shape[1]
    bos_id = full_ids[:, :1]
    doc_ids = full_ids[:, sf_prefix_len_matched:]
    doc_len = doc_ids.shape[1]
    context_len = 1 + doc_len

    del full_enc, full_ids, sf_prefix_enc

    # --- Condition 1: bare ---
    bare_input = torch.cat([bos_id, doc_ids], dim=1)
    with torch.no_grad():
        bare_out = model(input_ids=bare_input,
                         attention_mask=torch.ones_like(bare_input),
                         use_cache=True, return_dict=True)
    bare_cache = _ensure_dynamic_cache(bare_out.past_key_values)
    del bare_out

    bare_nll = score_answer_with_cache(
        deepcopy_cache(bare_cache), context_len,
        query_prompt, answer_text, model, tokenizer, exp_config)

    # --- Condition 2: vel_static (discrete static_fact prefix) ---
    primed_input = torch.cat([bos_id, sf_ids, doc_ids], dim=1)
    with torch.no_grad():
        primed_out = model(input_ids=primed_input,
                           attention_mask=torch.ones_like(primed_input),
                           use_cache=True, return_dict=True)
    primed_full = _ensure_dynamic_cache(primed_out.past_key_values)
    del primed_out

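    trunc_raw = extract_and_truncate_cache_with_bos(primed_full, doc_len)
    sf_trunc_cache = deepcopy_cache(trunc_raw)
    # RoPE is applied to keys, not values. The splice below takes only values from
    # this cache, so the correction should not affect the vel_static score.
    correct_rope_positions_with_bos(sf_trunc_cache, sf_ids.shape[1], model)
    del primed_full, trunc_raw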

    vel_static_cache = replace_values_at_layers(bare_cache, sf_trunc_cache, layer_indices)
    vel_static_nll = score_answer_with_cache(
        deepcopy_cache(vel_static_cache), context_len,
        query_prompt, answer_text, model, tokenizer, exp_config)
    del sf_trunc_cache, vel_static_cache

    # --- Helper: score soft prefix condition ---
    def score_soft_condition(soft_embs):
        """Build hybrid cache from soft embeddings and score."""
        with torch.no_grad():
            bos_emb = embed_fn(bos_id)
            doc_emb = embed_fn(doc_ids)
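            soft_cast = soft_embs.to(device=exp_config.device, dtype=bos_emb.dtype)
            if soft_cast.dim() == 2:
                # (prefix_len, hidden) -> (1, prefix_len, hidden) to match bos_emb / doc_emb
                soft_cast = soft_cast.unsqueeze(0)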

            inputs_embeds = torch.cat([bos_emb, soft_cast, doc_emb], dim=1)
            total_len = inputs_embeds.shape[1]
            attn_mask = torch.ones((1, total_len), device=exp_config.device, dtype=torch.long)

            soft_out = model(inputs_embeds=inputs_embeds,
                            attention_mask=attn_mask,
                            use_cache=True, return_dict=True)
            soft_cache = _ensure_dynamic_cache(soft_out.past_key_values)
            del soft_out

            # Extract BOS + doc values from soft cache, splice into bare
            soft_trunc = extract_and_truncate_cache_with_bos(soft_cache, doc_len)
            # No RoPE correction needed because we use bare keys
            # (values don't have positional encoding)
            del soft_cache

            vel_soft_cache = replace_values_at_layers(bare_cache, soft_trunc, layer_indices)
            del soft_trunc

            nll = score_answer_with_cache(
                deepcopy_cache(vel_soft_cache), context_len,
                query_prompt, answer_text, model, tokenizer, exp_config)
            del vel_soft_cache

        return nll

    # --- Condition 3: vel_soft_random ---
    vel_soft_random_nll = score_soft_condition(soft_prefix_random)

    # --- Condition 4: vel_soft_fact ---
    vel_soft_fact_nll = score_soft_condition(soft_prefix_fact)

    del bare_cache, bare_input
    gc.collect()
    torch.cuda.empty_cache()

    eval_results.append({
        'query_idx': qidx,
        'query': qdata['query'],
        'doc_len': doc_len,
        'bare_nll': bare_nll,
        'vel_static_nll': vel_static_nll,
        'vel_soft_random_nll': vel_soft_random_nll,
        'vel_soft_fact_nll': vel_soft_fact_nll,
    })

    # Checkpoint
    if (qidx + 1) % CHECKPOINT_EVERY == 0 or qidx == N_EVAL - 1:
        ckpt_data = {
            'results': eval_results,
            'query_texts': [q['query'] for q in eval_samples[:N_EVAL]],
            'completed': len(eval_results),
            'total': N_EVAL,
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        }
        with open(CHECKPOINT_EVAL_PATH, 'w') as f:
            json.dump(ckpt_data, f)
        elapsed = time.time() - t_start
        n_done = qidx - eval_start_idx + 1
        rate = n_done / elapsed if elapsed > 0 else 0
        remaining = (N_EVAL - qidx - 1) / rate if rate > 0 else 0
        tqdm.write(f"  Checkpoint {qidx+1}/{N_EVAL} | {n_done} done in {elapsed/60:.1f}m | "
                   f"ETA: {remaining/60:.1f} min")

elapsed_total = time.time() - t_start
print(f"\nEval complete: {len(eval_results)} queries in {elapsed_total/60:.1f} min")
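
Conditions 2 to 4 all score the answer through the same kind of hybrid cache: keys at every layer and values at late layers come from the bare pass, while values at layers below `CUTOFF` come from the primed pass. The toy sketch below shows that splice on plain Python lists. `splice_values` and the per-layer `(keys, values)` tuple layout are assumptions made for illustration; they are not the notebook's `replace_values_at_layers` or the real cache object.

```python
from copy import deepcopy

def splice_values(bare_cache, primed_cache, layer_indices):
    """Return a new cache with primed values at the selected layers.

    Both caches are lists of (keys, values) tuples, one per layer
    (a toy layout assumed for this sketch). Keys always come from bare_cache.
    """
    if len(bare_cache) != len(primed_cache):
        raise ValueError("caches must have the same number of layers")
    selected = set(layer_indices)
    hybrid = []
    for layer, (bare_k, bare_v) in enumerate(bare_cache):
        primed_v = primed_cache[layer][1]
        if layer in selected and len(primed_v) != len(bare_v):
            raise ValueError(f"layer {layer}: value lengths differ")
        values = primed_v if layer in selected else bare_v
        # Copy so later mutation of the hybrid cannot touch the inputs.
        hybrid.append((deepcopy(bare_k), deepcopy(values)))
    return hybrid

# Toy 4-layer cache with 2 positions per layer (BOS + one document token).
bare = [(["k_bare"] * 2, ["v_bare"] * 2) for _ in range(4)]
primed = [(["k_primed"] * 2, ["v_primed"] * 2) for _ in range(4)]

cutoff = 2  # stands in for CUTOFF; early layers are 0 .. cutoff-1
hybrid = splice_values(bare, primed, range(cutoff))

for layer, (k, v) in enumerate(hybrid):
    print(layer, k[0], v[0])
```

The length check mirrors why the primed cache is truncated to BOS plus document positions before the splice: the two caches must cover the same positions for a layer-wise value swap to make sense.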

In [None]:
# Cell 14: Evaluation analysis

print("=" * 70)
print("EVALUATION ANALYSIS")
print("=" * 70)

bare_arr = np.array([r['bare_nll'] for r in eval_results])
static_arr = np.array([r['vel_static_nll'] for r in eval_results])
soft_rand_arr = np.array([r['vel_soft_random_nll'] for r in eval_results])
soft_fact_arr = np.array([r['vel_soft_fact_nll'] for r in eval_results])

# Keep only queries where every condition produced a finite, non-zero NLL
valid = (
    (bare_arr != 0) & np.isfinite(bare_arr) &
    (static_arr != 0) & np.isfinite(static_arr) &
    (soft_rand_arr != 0) & np.isfinite(soft_rand_arr) &
    (soft_fact_arr != 0) & np.isfinite(soft_fact_arr)
)

b = bare_arr[valid]
conditions = {
    'vel_static': static_arr[valid],
    'vel_soft_random': soft_rand_arr[valid],
    'vel_soft_fact': soft_fact_arr[valid],
}

print(f"\nValid samples: {np.sum(valid)}/{len(eval_results)}")
print(f"\n{'Condition':<20} {'Mean NLL':>10} {'Mean D':>10} {'d':>8} {'Win%':>7} {'p':>12} {'sig':>5}")
print("-" * 78)
print(f"{'bare':<20} {np.mean(b):>10.4f} {'—':>10} {'—':>8} {'—':>7} {'—':>12} {'—':>5}")

eval_analysis = {}
for cname, carr in conditions.items():
    delta = b - carr
    d = cohens_d(delta)
    win = np.mean(delta > 0) * 100
    t_stat, p_val = stats.ttest_1samp(delta, 0)
    sig = '***' if p_val < 0.001 else '**' if p_val < 0.01 else '*' if p_val < 0.05 else 'ns'
    print(f"{cname:<20} {np.mean(carr):>10.4f} {np.mean(delta):>+10.4f} "
          f"{d:>+8.3f} {win:>6.1f}% {p_val:>12.2e} {sig:>5}")
    eval_analysis[cname] = {
        'n_valid': int(np.sum(valid)),
        'mean_nll': float(np.mean(carr)),
        'mean_delta': float(np.mean(delta)),
        'cohens_d': float(d),
        'win_pct': float(win),
        't_stat': float(t_stat),
        'p_value': float(p_val),
    }

# Pairwise: soft vs static
print("\nPairwise comparisons:")
for n1, a1, n2, a2 in [
    ('vel_soft_fact', conditions['vel_soft_fact'], 'vel_static', conditions['vel_static']),
    ('vel_soft_random', conditions['vel_soft_random'], 'vel_static', conditions['vel_static']),
    ('vel_soft_fact', conditions['vel_soft_fact'], 'vel_soft_random', conditions['vel_soft_random']),
]:
    delta = a2 - a1  # positive = n1 better (lower NLL)
    d = cohens_d(delta)
    t_stat, p_val = stats.ttest_1samp(delta, 0)
    sig = '***' if p_val < 0.001 else '**' if p_val < 0.01 else '*' if p_val < 0.05 else 'ns'
    print(f"  {n1} vs {n2}: d={d:+.3f}, p={p_val:.2e} {sig}")

# Reference comparison
print(f"\nReference comparison:")
print(f"  Exp 21 vel_static d: {EXP21_REF['values_early_layers_d']:+.3f}")
print(f"  This exp vel_static d: {eval_analysis['vel_static']['cohens_d']:+.3f}")
print(f"  This exp vel_soft_fact d: {eval_analysis['vel_soft_fact']['cohens_d']:+.3f}")
print(f"  This exp vel_soft_random d: {eval_analysis['vel_soft_random']['cohens_d']:+.3f}")

# Hardness gradient (quintiles)
print("\nHardness gradient (bare NLL quintiles):")
quintile_bounds = np.percentile(b, [20, 40, 60, 80])
qlabels = ['Q1 easy', 'Q2', 'Q3', 'Q4', 'Q5 hard']
quintiles = np.digitize(b, quintile_bounds)

print(f"{'Condition':<20}", end='')
for ql in qlabels:
    print(f"{ql:>12}", end='')
print()
print("-" * (20 + 12 * 5))

hardness_data = {}
for cname, carr in conditions.items():
    delta = b - carr
    row = []
    print(f"{cname:<20}", end='')
    for q in range(5):
        mask = quintiles == q
        if np.sum(mask) < 5:
            print(f"{'n/a':>12}", end='')
            row.append(None)
        else:
            d_q = cohens_d(delta[mask])
            print(f"{d_q:>+12.3f}", end='')
            row.append(float(d_q))
    print()
    hardness_data[cname] = row
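
The analysis above reports, for each condition, a paired effect size on the per-query NLL reduction `bare - condition`. The notebook's `cohens_d` helper is defined in an earlier cell; the stdlib sketch below shows one common definition (mean of the deltas divided by their sample standard deviation) and how the one-sample t statistic follows from it. The function name `paired_effect`, the `ddof=1` choice, and the NLL values are assumptions for illustration and may differ from the notebook's helper.

```python
import math
import statistics

def paired_effect(bare_nll, cond_nll):
    """Effect size of the per-query NLL reduction (bare - condition)."""
    if len(bare_nll) != len(cond_nll) or len(bare_nll) < 2:
        raise ValueError("need two equal-length samples with n >= 2")
    delta = [b - c for b, c in zip(bare_nll, cond_nll)]
    mean = statistics.fmean(delta)
    sd = statistics.stdev(delta)  # sample std (ddof=1), an assumption
    d = mean / sd if sd > 0 else float("nan")
    # One-sample t statistic: mean / (sd / sqrt(n)) = d * sqrt(n)
    t = d * math.sqrt(len(delta))
    win_pct = 100.0 * sum(x > 0 for x in delta) / len(delta)
    return {"mean_delta": mean, "cohens_d": d, "t_stat": t, "win_pct": win_pct}

# Made-up NLLs for four queries (not results from this experiment).
bare = [2.0, 3.0, 1.5, 4.0]
cond = [1.5, 3.5, 0.5, 2.0]
res = paired_effect(bare, cond)
print(res)
```

Because `t = d * sqrt(n)`, a fixed effect size becomes more significant as the number of valid queries grows, which is worth keeping in mind when comparing p-values across experiments with different N.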

In [None]:
# Cell 15: Evaluation plots (4-panel)

fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# --- Panel 1: Bar chart of Cohen's d ---
ax = axes[0, 0]
cond_names = ['vel_static', 'vel_soft_random', 'vel_soft_fact']
cond_labels = ['static_fact\n(discrete)', 'soft_random\n(random init)', 'soft_fact\n(fact init)']
cond_ds = [eval_analysis[cn]['cohens_d'] for cn in cond_names]
cond_colors = ['#2ca02c', '#1f77b4', '#ff7f0e']

bars = ax.bar(range(len(cond_names)), cond_ds, color=cond_colors,
              edgecolor='black', linewidth=0.5)
ax.axhline(y=EXP21_REF['values_early_layers_d'], color='#9467bd', linestyle='--',
           linewidth=1.5, label=f"Exp 21 d={EXP21_REF['values_early_layers_d']:+.3f}")
ax.axhline(y=0, color='gray', linestyle='-', alpha=0.3)

ax.set_xticks(range(len(cond_names)))
ax.set_xticklabels(cond_labels)
ax.set_ylabel("Cohen's d vs Bare")
ax.set_title("Effect Size by Condition")
ax.legend(fontsize=8)

for i, d_val in enumerate(cond_ds):
    p_val = eval_analysis[cond_names[i]]['p_value']
    sig = '***' if p_val < 0.001 else '**' if p_val < 0.01 else '*' if p_val < 0.05 else 'ns'
    ax.text(i, d_val + 0.005 if d_val >= 0 else d_val - 0.015,
            f"{d_val:+.3f} {sig}", ha='center',
            va='bottom' if d_val >= 0 else 'top', fontsize=9)

# --- Panel 2: Per-sample scatter (soft_fact vs static) ---
ax = axes[0, 1]
ax.scatter(conditions['vel_static'], conditions['vel_soft_fact'],
           alpha=0.4, s=20, color='#ff7f0e', edgecolors='none')
lims = [min(conditions['vel_static'].min(), conditions['vel_soft_fact'].min()),
        max(conditions['vel_static'].max(), conditions['vel_soft_fact'].max())]
ax.plot(lims, lims, 'k--', alpha=0.5, linewidth=1, label='y=x')
ax.set_xlabel('vel_static NLL (discrete)')
ax.set_ylabel('vel_soft_fact NLL (learned)')
ax.set_title('Per-Sample: Soft Fact vs Static')
ax.legend(fontsize=8)

# Count wins
soft_wins = np.sum(conditions['vel_soft_fact'] < conditions['vel_static'])
static_wins = np.sum(conditions['vel_soft_fact'] > conditions['vel_static'])
ties = np.sum(conditions['vel_soft_fact'] == conditions['vel_static'])
ax.text(0.05, 0.95, f"Soft wins: {soft_wins}\nStatic wins: {static_wins}\nTies: {ties}",
        transform=ax.transAxes, fontsize=9, va='top',
        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

# --- Panel 3: Hardness gradient heatmap ---
ax = axes[1, 0]
heatmap = np.zeros((len(cond_names), 5))
for i, cn in enumerate(cond_names):
    for q in range(5):
        val = hardness_data[cn][q]
        heatmap[i, q] = val if val is not None else np.nan

im = ax.imshow(heatmap, cmap='RdBu', aspect='auto', vmin=-0.5, vmax=0.5)
ax.set_xticks(range(5))
ax.set_xticklabels(qlabels, fontsize=8)
ax.set_yticks(range(len(cond_names)))
ax.set_yticklabels(['static', 'soft_random', 'soft_fact'])
ax.set_xlabel('Difficulty Quintile')
ax.set_title("Hardness x Condition (Cohen's d)")

for i in range(len(cond_names)):
    for j in range(5):
        val = heatmap[i, j]
        if not np.isnan(val):
            ax.text(j, i, f"{val:+.2f}", ha='center', va='center',
                    fontsize=8, color='white' if abs(val) > 0.25 else 'black')
fig.colorbar(im, ax=ax, shrink=0.8)

# --- Panel 4: Distribution of NLL deltas vs bare ---
ax = axes[1, 1]
delta_static = b - conditions['vel_static']
delta_soft_fact = b - conditions['vel_soft_fact']
delta_soft_random = b - conditions['vel_soft_random']

ax.hist(delta_static, bins=40, alpha=0.5, color='#2ca02c', label='static', density=True)
ax.hist(delta_soft_fact, bins=40, alpha=0.5, color='#ff7f0e', label='soft_fact', density=True)
ax.hist(delta_soft_random, bins=40, alpha=0.3, color='#1f77b4', label='soft_random', density=True)
ax.axvline(x=0, color='black', linestyle='-', linewidth=0.5)
ax.set_xlabel('NLL Delta (bare - condition, positive = helps)')
ax.set_ylabel('Density')
ax.set_title('Delta Distribution')
ax.legend(fontsize=8)

plt.suptitle('Exp 25: Soft Prefix Evaluation', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig(RESULTS_DIR / 'eval_plots.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"Eval plots saved to {RESULTS_DIR / 'eval_plots.png'}")
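
The hardness heatmap in the plots above relies on splitting queries into five bins by bare NLL, using the 20th, 40th, 60th, and 80th percentiles as cut points. A minimal stdlib sketch of that binning follows. `quintile_bins` is a hypothetical helper, the input values are made up, and `statistics.quantiles(..., method="inclusive")` is used here as a stand-in for `np.percentile`, so cut points may differ slightly from the notebook's.

```python
import bisect
import statistics

def quintile_bins(values):
    """Assign each value a bin 0..4 using the 20/40/60/80th percentiles."""
    bounds = statistics.quantiles(values, n=5, method="inclusive")
    # bisect_right counts the bounds <= x, so a value equal to a bound
    # lands in the upper bin.
    return bounds, [bisect.bisect_right(bounds, x) for x in values]

# Ten made-up bare NLLs, 1.0 through 10.0.
bare_nll = [float(i) for i in range(1, 11)]
bounds, bins = quintile_bins(bare_nll)
counts = [bins.count(q) for q in range(5)]

print("bounds:", bounds)
print("bins:  ", bins)
print("counts:", counts)
```

Checking `counts` before computing a per-bin effect size matters in practice: heavy ties in bare NLL can leave a bin nearly empty, which is why the analysis cell prints `n/a` for bins with fewer than 5 queries.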

In [None]:
# Cell 16: Save results.json + CSV
import csv

# --- Eval CSV ---
with open(CSV_EVAL_PATH, 'w', newline='') as f:
    writer = csv.DictWriter(f, fieldnames=[
        'query_idx', 'doc_len', 'bare_nll',
        'vel_static_nll', 'vel_soft_random_nll', 'vel_soft_fact_nll'])
    writer.writeheader()
    for r in eval_results:
        writer.writerow({
            'query_idx': r['query_idx'],
            'doc_len': r['doc_len'],
            'bare_nll': r['bare_nll'],
            'vel_static_nll': r['vel_static_nll'],
            'vel_soft_random_nll': r['vel_soft_random_nll'],
            'vel_soft_fact_nll': r['vel_soft_fact_nll'],
        })
print(f"Eval CSV saved: {CSV_EVAL_PATH}")

# --- Combined results.json ---
final = {
    'experiment': 'exp25_soft_prefix_optimization',
    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    'config': {
        'model_name': MODEL_NAME,
        'model_type': 'gemma3',
        'seed': SEED,
        'cutoff': CUTOFF,
        'prefix_len': PREFIX_LEN,
        'hidden_size': HIDDEN_SIZE,
        'trainable_params': PREFIX_LEN * HIDDEN_SIZE,
        'training': {
            'n_samples': N_TRAIN,
            'n_epochs': N_EPOCHS,
            'lr': LR,
            'grad_accum': GRAD_ACCUM_STEPS,
            'warmup_steps': WARMUP_STEPS,
            'dataset': 'MS MARCO v1.1 train',
        },
        'eval': {
            'n_samples': N_EVAL,
            'dataset': 'MS MARCO v1.1 validation',
            'conditions': ['bare', 'vel_static', 'vel_soft_random', 'vel_soft_fact'],
        },
    },
    'training_history': {
        'random': history_random if 'history_random' in dir() else [],
        'fact': history_fact if 'history_fact' in dir() else [],
    },
    'eval_analysis': eval_analysis,
    'eval_hardness': hardness_data,
    'reference_values': {
        'exp19_gemma': EXP19_REF,
        'exp21_gemma': EXP21_REF,
    },
    'eval_per_query': eval_results,
}

with open(FINAL_RESULTS_PATH, 'w') as f:
    json.dump(final, f, indent=2)

print(f"\nResults saved to {FINAL_RESULTS_PATH}")
print(f"File size: {FINAL_RESULTS_PATH.stat().st_size / 1024:.1f} KB")

# Final summary
print("\n" + "=" * 70)
print("SUMMARY — Exp 25: Soft Prefix Optimization")
print("=" * 70)
print(f"Model: Gemma 3 4B ({NUM_LAYERS} layers, hidden={HIDDEN_SIZE}, bfloat16)")
print(f"Soft prefix: {PREFIX_LEN} vectors x {HIDDEN_SIZE} dims = {PREFIX_LEN * HIDDEN_SIZE:,} params")
print(f"Training: {N_TRAIN} samples x {N_EPOCHS} epochs, lr={LR}")

print(f"\nEvaluation ({N_EVAL} queries):")
for cn in ['vel_static', 'vel_soft_random', 'vel_soft_fact']:
    a = eval_analysis[cn]
    sig = '***' if a['p_value'] < 0.001 else '**' if a['p_value'] < 0.01 else '*' if a['p_value'] < 0.05 else 'ns'
    print(f"  {cn:<20} d={a['cohens_d']:>+.3f}  win={a['win_pct']:.0f}%  {sig}")

d_static = eval_analysis['vel_static']['cohens_d']
d_soft_fact = eval_analysis['vel_soft_fact']['cohens_d']
d_soft_random = eval_analysis['vel_soft_random']['cohens_d']

if d_soft_fact > d_static + 0.02:
    print(f"\nVERDICT: Soft fact-init BEATS discrete static_fact "
          f"({d_soft_fact:+.3f} vs {d_static:+.3f}). "
          f"Continuous optimization improves value contamination.")
elif d_soft_fact > d_static - 0.02:
    print(f"\nVERDICT: Soft fact-init MATCHES discrete static_fact "
          f"({d_soft_fact:+.3f} vs {d_static:+.3f}). "
          f"Continuous space adds no benefit beyond the discrete prefix.")
else:
    print(f"\nVERDICT: Soft fact-init WORSE than discrete static_fact "
          f"({d_soft_fact:+.3f} vs {d_static:+.3f}). "
          f"Gradient optimization may be disrupting the prefix signal.")

if d_soft_random > 0.10:
    print(f"  Random-init learned useful signal from scratch (d={d_soft_random:+.3f}).")
else:
    print(f"  Random-init did NOT learn useful signal (d={d_soft_random:+.3f}).")

print(f"\nDone!")
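
The verdict printed above is a three-way rule on the difference between two Cohen's d values, with a tolerance band of 0.02 around the discrete baseline. The sketch below isolates that rule as a function so the band is explicit; `verdict` and its `margin` parameter are hypothetical names, and the d values passed in are illustrative rather than results from this experiment.

```python
def verdict(d_soft, d_static, margin=0.02):
    """Classify a soft prefix against the discrete baseline by Cohen's d."""
    if d_soft > d_static + margin:
        return "BEATS"
    if d_soft > d_static - margin:
        return "MATCHES"
    return "WORSE"

# Illustrative inputs only.
for d_soft in (0.30, 0.23, 0.10):
    print(d_soft, verdict(d_soft, d_static=0.227))
```

This rule compares point estimates only. The pairwise t-tests in the analysis cell are the place to check whether a difference between conditions is statistically distinguishable.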

In [None]:
# Cell 17: GPU cleanup
import gc

print("Cleaning up GPU memory...")
mem_before = torch.cuda.memory_allocated() / 1e9

del model
del tokenizer

gc.collect()
torch.cuda.empty_cache()
gc.collect()

mem_after = torch.cuda.memory_allocated() / 1e9
print(f"GPU memory: {mem_before:.2f} GB -> {mem_after:.2f} GB")
print("Cleanup complete.")