# Experiment 01: Decoder-Only Surrogate Prefix Conditioning

## Motivation

In a causal (decoder-only) model, document tokens D cannot attend to a query Q
that comes after them. We test whether prepending a **surrogate query** before the
document allows D to encode query-relevant features via causal attention, improving
downstream answer NLL.

## Method — Two-Phase KV Cache with Natural Encoding + RoPE Repositioning

We use Gemma 3 12B-IT with a two-phase scoring approach that preserves semantic
signal during encoding while eliminating position confounds during inference.

**Phase A (conditioning):** Encode `[BOS] + prefix + \n + doc` with **natural**
(sequential) positions, where P, NL, and D are the token counts of the prefix, the
newline, and the document. BOS sits at position 0, the prefix at `[1, ..., P]`, the
newline at `[P+1, ..., P+NL]`, and the doc at `[P+NL+1, ..., P+NL+D]`.
Doc tokens attend to the prefix through ordinary causal attention at in-distribution
RoPE relative distances, so the semantic content of the prefix can genuinely
influence doc representations.

**Slice:** Remove the first `1 + len(prefix) + len(\n)` KV entries, keeping only
the doc KV cache.
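The index bookkeeping for Phase A and the slice can be checked with plain integers. This sketch uses a hypothetical helper, `position_layout` (not part of the notebook), that mirrors the arithmetic in `score()`. The values P=4, NL=1, D=5 are illustrative only; the real NL depends on how the Gemma tokenizer splits `\n`.

```python
from typing import Dict


def position_layout(P: int, NL: int, D: int) -> Dict[str, object]:
    """Hypothetical helper: Phase A positions, slice start, delta, Phase B start.

    P == 0 stands for the bare condition (no prefix, no newline).
    """
    if P > 0:
        # Conditioned: [BOS] + prefix + newline + doc at natural positions
        seq_len = 1 + P + NL + D
        slice_start = 1 + P + NL        # drop BOS, prefix, newline
        delta = -(P + NL)               # rotation applied to the doc keys
    else:
        # Bare: [BOS] + doc
        seq_len = 1 + D
        slice_start = 1                 # drop BOS only
        delta = 0
    phase_a_positions = list(range(seq_len))
    doc_positions = phase_a_positions[slice_start:]
    return {
        "slice_start": slice_start,
        "delta": delta,
        "doc_positions_phase_a": doc_positions,
        "doc_positions_after": [p + delta for p in doc_positions],
        "phase_b_start": D + 1,         # identical for both conditions
    }


bare = position_layout(P=0, NL=1, D=5)
cond = position_layout(P=4, NL=1, D=5)
print(cond["doc_positions_phase_a"], "->", cond["doc_positions_after"])
print(bare["doc_positions_after"] == cond["doc_positions_after"])
```

Both conditions end with the doc at `[1, ..., D]` and Phase B starting at `D + 1`, which is the invariant the repositioning step relies on.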

**Reposition:** Rotate every doc **key** by `-(P+NL)` RoPE positions, which moves
the doc from `[P+NL+1, ..., P+NL+D]` to `[1, ..., D]`, the same positions it has in
the bare condition. Values carry no RoPE and are left untouched, so the semantic
enrichment from Phase A is preserved while the positional geometry matches bare
(up to bf16 rounding).
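A minimal numerical sketch of the repositioning step, assuming the rotate-half RoPE convention implemented by `rotate_half` in Cell 3 and one set of standard frequencies (theta = 10,000, no linear scaling; the global layers divide `inv_freq` by their scaling factor, which does not change the argument). It encodes an arbitrary pre-RoPE key at its Phase A position, applies one extra rotation by `-(P+NL)`, and compares the result with the same content encoded directly at its bare position. This runs in float64, so the only discrepancy is round-off; in the notebook the cached keys are bf16, so agreement there is limited by bf16 rounding.

```python
import math
from typing import List


def rotate_half(x: List[float]) -> List[float]:
    # HuggingFace convention: (x1, x2) -> (-x2, x1) over the two halves
    h = len(x) // 2
    return [-v for v in x[h:]] + x[:h]


def rope(x: List[float], pos: float, inv_freq: List[float]) -> List[float]:
    # One angle per frequency, duplicated across both halves (like cat([a, a]))
    angles = [pos * f for f in inv_freq] * 2
    rx = rotate_half(x)
    return [xi * math.cos(a) + ri * math.sin(a) for xi, ri, a in zip(x, rx, angles)]


head_dim = 8
theta = 10_000.0
inv_freq = [1.0 / theta ** (i / head_dim) for i in range(0, head_dim, 2)]

k_content = [0.3, -1.2, 0.7, 2.0, -0.4, 0.9, 1.5, -0.8]  # arbitrary pre-RoPE key
P, NL, j = 7, 1, 3              # illustrative prefix/newline lengths, doc offset
k_phase_a = rope(k_content, P + NL + j, inv_freq)   # key as cached in Phase A

# Repositioning: one more rotation by delta = -(P + NL)
k_repositioned = rope(k_phase_a, -(P + NL), inv_freq)
k_bare = rope(k_content, j, inv_freq)               # same content at bare position j

max_err = max(abs(a - b) for a, b in zip(k_repositioned, k_bare))
print(f"max abs error: {max_err:.2e}")
```

Because each pair of dimensions is rotated by an angle proportional to position, two rotations compose into one rotation by the summed positions; that is the only property the repositioning step needs.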

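**Phase B (inference):** Score `[\n + query + \n + answer]` against the repositioned
doc-only cache, with position_ids starting at `D + 1` (same as bare). BOS is sliced
from the bare cache as well, so in both conditions Phase B attends only to the D doc
entries.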

**Why this works:** During Phase A, attention to the prefix changes the hidden
states from which every doc key and value is projected, so the enrichment is baked
into both the key content and the values. Repositioning changes only the RoPE
rotation of the keys, not their content. Phase B therefore sees the doc cache at
bare-identical positions, and any NLL difference should come from how the prefix
**content** altered doc representations rather than from a position offset.
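The argument rests on two properties of the RoPE rotation. For one frequency pair with angle $\omega$ per position ($\omega$ is one entry of the layer's `inv_freq`; the notation is introduced here), the rotate-half convention rotates dimensions $(x_i, x_{i+h})$ by

$$R(\alpha)=\begin{pmatrix}\cos\alpha & -\sin\alpha\\ \sin\alpha & \cos\alpha\end{pmatrix},\qquad R(\alpha)\,R(\beta)=R(\alpha+\beta),\qquad R(\alpha)^\top=R(-\alpha).$$

A doc token at offset $j$ (Phase A position $P+\mathrm{NL}+j$) with pre-RoPE key content $k$ is repositioned as

$$R\big(-(P+\mathrm{NL})\,\omega\big)\,R\big((P+\mathrm{NL}+j)\,\omega\big)\,k \;=\; R(j\,\omega)\,k,$$

which is how the bare condition encodes content at position $j$. For a Phase B query $q$ at position $m \ge D+1$,

$$\big(R(m\,\omega)\,q\big)^\top R(j\,\omega)\,k \;=\; q^\top R\big((j-m)\,\omega\big)\,k,$$

so the attention logit depends on positions only through $j-m$, exactly as in bare; the prefix can influence the score only through $k$ and the values.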

## Conditions (8 total)

| # | Condition | Prefix | Description |
|---|-----------|--------|-------------|
| 1 | bare | (none) | Standard causal — lower bound |
| 2 | oracle | real query | Real query conditions doc — upper bound |
| 3 | surr_universal | generic analysis | "Analyze for entities, facts, relationships" |
| 4 | surr_extractor | data extraction | "Examine for data points, dates, attributes" |
| 5 | surr_reasonant | reasoning | "Evaluate arguments, sentiment, intent" |
| 6 | surr_analytic | technical | "Technical breakdown of systems/processes" |
| 7 | surr_doc_kw | doc keywords | Top-5 document keywords (v3's best) |
| 8 | adversarial | off-topic | Off-topic text — negative control |

## Key metrics
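- **Recovery rate**: (surrogate − bare) / (oracle − bare) × 100%. Lower NLL is better, so the ratio is positive when surrogate and oracle both improve on bare: 0% means no gain, 100% means the surrogate matches the oracle.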
- Cohen's d, win%, paired t-test
- Hardness gradient analysis
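A sketch of how the per-condition metrics could be computed from paired per-sample NLLs. Assumptions: `paired_metrics` is a hypothetical helper; Cohen's d is taken as the paired effect size d_z (mean difference over the SD of the differences), which may not match the definition in `lib.analysis.cohens_d`; and the recovery rate is computed on mean NLLs. Only the standard library is used; for a p-value, `scipy.stats.ttest_rel(bare, cond)` returns the same t statistic together with a two-sided p-value.

```python
import math
import statistics
from typing import Dict, Sequence


def paired_metrics(bare: Sequence[float], cond: Sequence[float],
                   oracle: Sequence[float]) -> Dict[str, float]:
    """Hypothetical summary of one condition against bare (per-sample NLLs).

    Lower NLL is better, so a positive difference means the condition helped.
    """
    n = len(bare)
    assert n == len(cond) == len(oracle) and n > 1
    diffs = [b - c for b, c in zip(bare, cond)]   # per-sample NLL reduction
    mean_diff = statistics.mean(diffs)
    sd_diff = statistics.stdev(diffs)             # sample SD (n - 1)
    oracle_gain = statistics.mean(bare) - statistics.mean(oracle)
    return {
        # (cond - bare) / (oracle - bare) on means, written as a ratio of gains
        "recovery_pct": 100.0 * mean_diff / oracle_gain,
        "cohens_d": mean_diff / sd_diff,           # paired d_z (assumed definition)
        "win_pct": 100.0 * sum(d > 0 for d in diffs) / n,
        "t_stat": mean_diff / (sd_diff / math.sqrt(n)),
    }


bare = [2.0, 1.5, 3.0, 2.5]        # made-up NLLs, for illustration only
oracle = [1.0, 1.0, 2.0, 1.5]
surrogate = [1.5, 1.4, 2.4, 2.6]
m = paired_metrics(bare, surrogate, oracle)
print({k: round(v, 3) for k, v in m.items()})
```

The recovery rate is undefined when the oracle's mean NLL equals bare's (division by zero), so that case needs separate handling.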

In [None]:
# Cell 2: Setup and model loading
import os
os.umask(0o000)

import sys, json, time, gc, re
import random as pyrandom
import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path
from collections import Counter
from scipy import stats
from tqdm.auto import tqdm

sys.path.insert(0, "../../..")
from lib.analysis import cohens_d

SEED = 42
N_SAMPLES = 400
MODEL_NAME = "google/gemma-3-12b-it"

RESULTS_DIR = Path("../../../results/decoder_only/exp01")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
CHECKPOINT_PATH = RESULTS_DIR / "checkpoint.json"

np.random.seed(SEED)
torch.manual_seed(SEED)
pyrandom.seed(SEED)

from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv())
HF_TOKEN = os.environ.get("HF_TOKEN")

from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

print(f"Loading {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, device_map="auto", torch_dtype=torch.bfloat16, token=HF_TOKEN,
)
model.eval()

DEVICE = next(model.parameters()).device
BOS_ID = tokenizer.bos_token_id
NEWLINE_IDS = tokenizer("\n", add_special_tokens=False).input_ids

print("Exp 01: Decoder-Only Surrogate Prefix Conditioning")
print("Scoring: Two-phase KV cache + natural encoding + RoPE repositioning")
print(f"N: {N_SAMPLES}, Model: {MODEL_NAME}")
print(f"DEVICE: {DEVICE}, dtype: {next(model.parameters()).dtype}")
print(f"GPU memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
text_cfg = getattr(model.config, 'text_config', model.config)
print(f"Vocab size: {getattr(text_cfg, 'vocab_size', 'N/A')}")
print(f"Num layers: {getattr(text_cfg, 'num_hidden_layers', 'N/A')}")
print(f"Num KV heads: {getattr(text_cfg, 'num_key_value_heads', 'N/A')}")
rope_params = getattr(text_cfg, 'rope_parameters', {})
layer_types_list = getattr(text_cfg, 'layer_types', [])
print(f"Layer types: {set(layer_types_list)} ({len(layer_types_list)} layers)")
for ltype, params in rope_params.items():
    print(f"  {ltype}: theta={params.get('rope_theta')}, "
          f"type={params.get('rope_type')}, factor={params.get('factor', 'N/A')}")
n_global = sum(1 for t in layer_types_list if t == 'full_attention')
print(f"  Global layers: {n_global}/{len(layer_types_list)} "
      f"(indices: {[i for i, t in enumerate(layer_types_list) if t == 'full_attention']})")

In [None]:
# Cell 3: Two-phase scoring with natural encoding + RoPE repositioning

def slice_kv_cache(cache, start_idx):
    sliced = DynamicCache()
    for i in range(len(cache.layers)):
        k = cache.layers[i].keys[:, :, start_idx:, :]
        v = cache.layers[i].values[:, :, start_idx:, :]
        sliced.update(k, v, i)
    return sliced


def rotate_half(x):
    """Rotates half the hidden dims (HuggingFace RoPE convention)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def build_layer_inv_freqs(config, head_dim, device):
    """Build per-layer inv_freq tensors matching Gemma 3's hybrid RoPE.

    Gemma 3 uses different RoPE parameters for local (sliding) vs global
    (full) attention layers:
      - sliding_attention: theta=10,000, standard RoPE
      - full_attention: theta=1,000,000, linear scaling (÷8.0)

    Returns a list of inv_freq tensors, one per layer.
    """
    text_cfg = getattr(config, 'text_config', config)
    layer_types = text_cfg.layer_types
    rope_params = text_cfg.rope_parameters

    # Pre-compute inv_freq for each layer type
    type_inv_freqs = {}
    for ltype, params in rope_params.items():
        theta = float(params['rope_theta'])
        inv_freq = 1.0 / (theta ** (
            torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
            / head_dim
        ))
        if params.get('rope_type') == 'linear':
            inv_freq = inv_freq / float(params['factor'])
        type_inv_freqs[ltype] = inv_freq

    # Map each layer to its inv_freq
    return [type_inv_freqs[layer_types[i]] for i in range(len(layer_types))]


def reposition_kv_cache(cache, delta, layer_inv_freqs):
    """Rotate all keys in cache by `delta` RoPE positions.

    Uses per-layer RoPE parameters to handle Gemma 3's hybrid attention:
    local (sliding) and global (full) layers use different theta values.

    Computation done in float32, cast back to original dtype.
    """
    # Pre-compute rotation embeddings per unique inv_freq
    rotation_cache = {}
    repositioned = DynamicCache()

    for i in range(len(cache.layers)):
        k = cache.layers[i].keys
        v = cache.layers[i].values

        # Use id() to cache rotation matrices for layers sharing the same inv_freq
        freq_id = id(layer_inv_freqs[i])
        if freq_id not in rotation_cache:
            inv_freq = layer_inv_freqs[i]
            angles = delta * inv_freq
            emb = torch.cat([angles, angles])
            rotation_cache[freq_id] = (
                emb.cos().view(1, 1, 1, -1),
                emb.sin().view(1, 1, 1, -1),
            )

        cos_d, sin_d = rotation_cache[freq_id]
        k_f = k.float()
        k_new = (k_f * cos_d + rotate_half(k_f) * sin_d).to(k.dtype)
        repositioned.update(k_new, v, i)

    return repositioned


# Pre-compute layer inv_freqs (used by reposition_kv_cache and validation)
head_dim_check = text_cfg.head_dim  # text_cfg from Cell 2 handles a nested text_config
LAYER_INV_FREQS = build_layer_inv_freqs(model.config, head_dim_check, DEVICE)
print(f"Built per-layer inv_freqs for {len(LAYER_INV_FREQS)} layers")
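# Verify the two types have different frequencies. inv_freq[0] = 1/theta^0 = 1
# before scaling, so the ratio at index 0 reflects only the global layers'
# linear scaling factor (expected 8.0x).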
local_freq = LAYER_INV_FREQS[0]  # layer 0 = sliding
global_idx = [i for i, t in enumerate(text_cfg.layer_types) if t == 'full_attention'][0]
global_freq = LAYER_INV_FREQS[global_idx]
print(f"  Local  inv_freq[0]: {local_freq[0].item():.6e}")
print(f"  Global inv_freq[0]: {global_freq[0].item():.6e}")
print(f"  Ratio (local/global): {local_freq[0].item() / global_freq[0].item():.1f}x")


def score(doc_text, query_text, answer_text, prefix_text=None):
    # Two-phase KV cache scoring with natural encoding + RoPE repositioning.
    #
    # Phase A: encode with NATURAL positions so prefix-doc attention uses
    # in-distribution RoPE relative distances (semantic signal preserved).
    #
    # Conditioned: [BOS + prefix + \n + doc] → natural positions [0..P+NL+D]
    #   - Doc tokens attend to prefix at normal positive relative distances
    #   - Slice removes BOS + prefix + \n
    #   - Reposition: rotate doc keys by -(P+NL) so they land at positions [1..D]
    #
    # Bare: [BOS + doc] → natural positions [0..D]
    #   - Slice removes BOS only
    #   - No repositioning needed (doc already at [1..D])
    #
    # Phase B: score [\n + query + \n + answer] at positions [D+1, ...]
    # Position geometry is IDENTICAL for bare and conditioned.
    # Any NLL difference comes from semantic enrichment of doc representations.

    doc_ids = tokenizer(doc_text, add_special_tokens=False,
                        truncation=True, max_length=1536).input_ids
    D = len(doc_ids)

    if prefix_text:
        prefix_ids = tokenizer(prefix_text, add_special_tokens=False,
                               truncation=True, max_length=512).input_ids
        P = len(prefix_ids)
        NL = len(NEWLINE_IDS)

        # Build Phase A token sequence: [BOS, prefix, \n, doc]
        cond_ids = [BOS_ID] + prefix_ids + NEWLINE_IDS + doc_ids
        slice_start = 1 + P + NL  # remove BOS + prefix + \n
        reposition_delta = -(P + NL)  # rotate doc keys back to [1..D]
    else:
        # Bare: [BOS + doc], natural positions, slice BOS
        cond_ids = [BOS_ID] + doc_ids
        slice_start = 1  # remove BOS
        reposition_delta = 0

    # Phase B always starts at D+1 (after repositioning, doc is at [1..D])
    phase_b_start = D + 1

    # --- Phase A: build KV cache with natural positions ---
    with torch.no_grad():
        pa = model(
            input_ids=torch.tensor([cond_ids], device=DEVICE),
            use_cache=True,
        )
    cache = pa.past_key_values
    del pa

    # Slice prefix entries
    cache = slice_kv_cache(cache, slice_start)

    # Reposition doc keys to [1..D] (no-op for bare)
    if reposition_delta != 0:
        cache = reposition_kv_cache(cache, reposition_delta, LAYER_INV_FREQS)

    # --- Phase B: score query + answer ---
    query_ids = tokenizer("\n" + query_text + "\n", add_special_tokens=False).input_ids
    answer_ids = tokenizer(answer_text, add_special_tokens=False,
                           truncation=True, max_length=256).input_ids
    if not answer_ids:
        del cache
        # NaN rather than 0.0: a zero NLL would read as a perfect score.
        return float('nan')

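    pb_ids = query_ids + answer_ids
    # RoPE positions continue after the doc at [1..D] (same for bare and conditioned).
    pos = torch.arange(phase_b_start, phase_b_start + len(pb_ids), device=DEVICE)
    # cache_position indexes cache slots, not RoPE positions: the sliced cache
    # holds exactly D doc entries, so new tokens go to slots [D, D+1, ...].
    # Passing `pos` here would shift the causal mask by one and let each
    # Phase B token attend to the token after it.
    cache_pos = torch.arange(D, D + len(pb_ids), device=DEVICE)

    with torch.no_grad():
        pb = model(
            input_ids=torch.tensor([pb_ids], device=DEVICE),
            past_key_values=cache,
            position_ids=pos.unsqueeze(0),
            cache_position=cache_pos,
            use_cache=False,
        )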

    # Score answer tokens only
    n_q = len(query_ids)
    logits = pb.logits[0, n_q - 1:n_q - 1 + len(answer_ids), :].float()
    targets = torch.tensor(answer_ids, device=DEVICE)
    nll = -F.log_softmax(logits, dim=-1).gather(1, targets.unsqueeze(1)).squeeze(1).mean().item()
    del cache, pb
    return nll


# === Surrogate definitions ===
SURROGATES = {
    'universal': "Analyze the following text for all key entities, factual claims, and logical relationships.",
    'extractor': "Examine this document specifically for data points, dates, numerical values, and specific named attributes.",
    'reasonant': "Evaluate the underlying arguments, sentiment, and intent of the following passage.",
    'analytic': "Provide a technical breakdown of the systems and processes described in this text.",
}

ADVERSARIAL_PREFIX = "The recipe calls for two cups of flour, one cup of sugar, and a pinch of salt."

STOP_WORDS = {
    'a', 'an', 'the', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
    'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could',
    'should', 'may', 'might', 'can', 'shall', 'to', 'of', 'in', 'for',
    'on', 'with', 'at', 'by', 'from', 'as', 'into', 'through', 'during',
    'before', 'after', 'above', 'below', 'between', 'and', 'but', 'or',
    'not', 'no', 'if', 'then', 'than', 'so', 'up', 'out', 'about',
    'what', 'which', 'who', 'whom', 'this', 'that', 'these', 'those',
    'it', 'its', 'i', 'me', 'my', 'we', 'our', 'you', 'your', 'he',
    'him', 'his', 'she', 'her', 'they', 'them', 'their', 'how', 'when',
    'where', 'why', 'much', 'many', 'some', 'any', 'all', 'each',
    'also', 'just', 'more', 'most', 'very', 'too', 'only',
}

def extract_keywords(text):
    words = re.sub(r'[^\w\s]', '', text.lower()).split()
    return [w for w in words if w not in STOP_WORDS and len(w) > 2]

def make_doc_keywords(passage):
    content_words = extract_keywords(passage)
    if not content_words:
        return "information"
    counts = Counter(content_words)
    return " ".join(w for w, _ in counts.most_common(5))


print("\nScoring function defined (natural encoding + per-layer RoPE repositioning).")
print(f"\nSurrogate prompts:")
for name, prompt in SURROGATES.items():
    n_tok = len(tokenizer(prompt, add_special_tokens=False).input_ids)
    print(f"  {name:<12} ({n_tok:>2} tok): {prompt[:60]}...")
adv_tok = len(tokenizer(ADVERSARIAL_PREFIX, add_special_tokens=False).input_ids)
print(f"  {'adversarial':<12} ({adv_tok:>2} tok): {ADVERSARIAL_PREFIX[:60]}...")

In [None]:
# Cell 4: Load MS MARCO data and generate surrogates
from lib.data import count_words
from datasets import load_dataset

print("Loading MS MARCO v1.1 validation...")
ds = load_dataset("microsoft/ms_marco", "v1.1", split="validation")

all_candidates = []
for item in ds:
    if len(all_candidates) >= 3 * N_SAMPLES:
        break
    passages = item.get('passages', {})
    ptexts = passages.get('passage_text', [])
    is_sel = passages.get('is_selected', [])
    query = item.get('query', '')
    answers = item.get('answers', [])
    well_formed = item.get('wellFormedAnswers', [])
    answer = None
    if well_formed and len(well_formed) > 0 and well_formed[0] not in ('[]', ''):
        answer = well_formed[0]
    elif answers and len(answers) > 0 and answers[0] != 'No Answer Present.':
        answer = answers[0]
    if not answer:
        continue
    for pt, sel in zip(ptexts, is_sel):
        wc = count_words(pt)
        if sel == 1 and 30 <= wc <= 300:
            all_candidates.append({
                'passage': pt, 'query': query, 'answer': answer,
                'word_count': wc,
            })
            break

print(f"Total candidates: {len(all_candidates)}")
np.random.seed(SEED)
indices = np.random.permutation(len(all_candidates))
samples = [all_candidates[i] for i in indices[:N_SAMPLES]]
del ds, all_candidates
gc.collect()

# Generate surrogates
for s in samples:
    s['surr_doc_kw'] = make_doc_keywords(s['passage'])

print(f"Loaded {len(samples)} samples")
print(f"Mean passage words: {np.mean([s['word_count'] for s in samples]):.0f}")
print(f"Mean answer words: {np.mean([count_words(s['answer']) for s in samples]):.0f}")
print(f"Mean query words: {np.mean([count_words(s['query']) for s in samples]):.0f}")
print(f"\nFirst sample:")
print(f"  Query:  {samples[0]['query'][:70]}...")
print(f"  Answer: {samples[0]['answer'][:70]}...")
print(f"  Passage ({samples[0]['word_count']}w): {samples[0]['passage'][:70]}...")
print(f"  Doc keywords: {samples[0]['surr_doc_kw']}")


In [None]:
# Cell 5: DEEP VALIDATION — test repositioning against native Gemma 3 RoPE
print("=" * 70)
print("DEEP VALIDATION: RoPE repositioning vs native Gemma 3 implementation")
print("=" * 70)

s = samples[0]
doc_ids = tokenizer(s['passage'], add_special_tokens=False,
                    truncation=True, max_length=1536).input_ids
D_test = len(doc_ids)

# ================================================================
# TEST 1: inv_freq comparison against model's internal buffers
# ================================================================
print("\n--- Test 1: inv_freq matches model's rotary_emb buffers ---")
rotary = model.model.rotary_emb
for ltype in ['sliding_attention', 'full_attention']:
    model_inv = getattr(rotary, f'{ltype}_inv_freq').float().to(DEVICE)
    layer_idx_for_type = [i for i, t in enumerate(text_cfg.layer_types)
                          if t == ltype][0]
    our_inv = LAYER_INV_FREQS[layer_idx_for_type]
    diff = (model_inv - our_inv).abs().max().item()
    print(f"  {ltype}: max diff = {diff:.2e} "
          f"(model shape={model_inv.shape}, ours={our_inv.shape})")
    assert diff < 1e-6, f"inv_freq mismatch for {ltype}: diff={diff}"
print("  PASSED")

# ================================================================
# TEST 2: attention_scaling is 1.0 for both layer types
# ================================================================
print("\n--- Test 2: attention_scaling values ---")
for ltype in ['sliding_attention', 'full_attention']:
    scaling = getattr(rotary, f'{ltype}_attention_scaling')
    print(f"  {ltype}: attention_scaling = {scaling}")
    assert scaling == 1.0, f"attention_scaling != 1.0 for {ltype}: {scaling}"
print("  PASSED")

# ================================================================
# TEST 3: Layer-0 repositioning matches native model computation
# ================================================================
# At layer 0, keys = RoPE(k_norm(k_proj(embed(token))), position).
# The pre-RoPE content depends ONLY on token identity, not position.
# So: reposition(key_at_pos_p, delta) should EXACTLY match key_at_pos_{p+delta}
# (up to bf16 quantization from storing the original key).
print("\n--- Test 3: Layer-0 key repositioning vs native computation ---")
test_tokens = tokenizer("The quick brown fox jumps over the lazy dog",
                        add_special_tokens=False).input_ids
delta_test = 23

# Encode at natural positions [0, 1, ..., N]
ids_test = [BOS_ID] + test_tokens
with torch.no_grad():
    out_nat = model(input_ids=torch.tensor([ids_test], device=DEVICE),
                    use_cache=True)
cache_nat = out_nat.past_key_values
del out_nat

# Encode at shifted positions [delta, delta+1, ..., delta+N]
pos_shifted = torch.arange(delta_test, delta_test + len(ids_test), device=DEVICE)
with torch.no_grad():
    out_shifted = model(
        input_ids=torch.tensor([ids_test], device=DEVICE),
        position_ids=pos_shifted.unsqueeze(0),
        cache_position=torch.arange(len(ids_test), device=DEVICE),  # cache slots 0..N, not RoPE positions
        use_cache=True,
    )
cache_shifted = out_shifted.past_key_values
del out_shifted

# Test layer 0 (sliding) and first global layer
for test_layer in [0, global_idx]:
    ltype = text_cfg.layer_types[test_layer]
    k_nat = cache_nat.layers[test_layer].keys  # at positions 0..N
    k_shifted = cache_shifted.layers[test_layer].keys  # at positions delta..delta+N

    # Reposition k_nat by delta_test using our function
    inv_freq = LAYER_INV_FREQS[test_layer]
    angles = delta_test * inv_freq
    emb = torch.cat([angles, angles])
    cos_d = emb.cos().view(1, 1, 1, -1)
    sin_d = emb.sin().view(1, 1, 1, -1)
    k_nat_f = k_nat.float()
    k_repositioned = (k_nat_f * cos_d + rotate_half(k_nat_f) * sin_d)

    # Compare repositioned vs natively shifted
    k_shifted_f = k_shifted.float()
    abs_diff = (k_repositioned - k_shifted_f).abs().max().item()
    rel_diff = abs_diff / k_shifted_f.abs().max().item()

    # Also compare values (no RoPE on values, so should be identical)
    v_nat = cache_nat.layers[test_layer].values
    v_shifted = cache_shifted.layers[test_layer].values
    val_diff = (v_nat.float() - v_shifted.float()).abs().max().item()

    print(f"  Layer {test_layer:>2} ({ltype[:7]}): "
          f"key rel diff = {rel_diff:.2e}, "
          f"val abs diff = {val_diff:.2e}")
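    # Both the stored key and the native RoPE are bf16, so agreement is limited
    # to bf16 rounding; use the same tolerance as Test 4.
    assert rel_diff < 0.02, (
        f"Repositioning mismatch at layer {test_layer} ({ltype}): rel_diff={rel_diff}")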

# Higher layers: keys depend on earlier attention outputs, so they are not a pure
# rotation of the natural-position keys; report the drift for reference.
print("  Higher layers (diagnostic only, no assertion):")
for test_layer in [1, 2, 3]:
    k_nat = cache_nat.layers[test_layer].keys
    k_shifted = cache_shifted.layers[test_layer].keys
    inv_freq = LAYER_INV_FREQS[test_layer]
    angles = delta_test * inv_freq
    emb = torch.cat([angles, angles])
    cos_d = emb.cos().view(1, 1, 1, -1)
    sin_d = emb.sin().view(1, 1, 1, -1)
    k_repositioned = (k_nat.float() * cos_d + rotate_half(k_nat.float()) * sin_d)
    rel_diff = ((k_repositioned - k_shifted.float()).abs().max().item()
                / k_shifted.float().abs().max().item())
    print(f"    Layer {test_layer}: key rel diff = {rel_diff:.2e}")

del cache_nat, cache_shifted
print("  PASSED: repositioned keys match native computation within bf16 tolerance")

# ================================================================
# TEST 4: Full pipeline test — conditioned + slice + reposition vs bare
# ================================================================
# Encode [BOS + prefix + \n + doc], slice prefix, reposition doc keys.
# At layer 0, the repositioned doc keys should MATCH bare [BOS + doc] keys
# because pre-RoPE keys depend only on token identity.
print("\n--- Test 4: Full pipeline (cond+slice+reposition) vs bare at layer 0 ---")
prefix_ids = tokenizer(s['query'], add_special_tokens=False,
                       truncation=True, max_length=512).input_ids
P = len(prefix_ids)
NL = len(NEWLINE_IDS)
D = len(doc_ids)
print(f"  Prefix tokens: {P}, Newline tokens: {NL}, Doc tokens: {D}")
print(f"  Reposition delta: {-(P + NL)}")

# Bare: [BOS + doc]
with torch.no_grad():
    out_bare = model(
        input_ids=torch.tensor([[BOS_ID] + doc_ids], device=DEVICE),
        use_cache=True)
cache_bare = out_bare.past_key_values
del out_bare

# Conditioned: [BOS + prefix + \n + doc]
cond_ids = [BOS_ID] + prefix_ids + NEWLINE_IDS + doc_ids
with torch.no_grad():
    out_cond = model(
        input_ids=torch.tensor([cond_ids], device=DEVICE),
        use_cache=True)
cache_cond = out_cond.past_key_values
del out_cond

# Slice prefix from conditioned cache
slice_start = 1 + P + NL
cache_cond_sliced = slice_kv_cache(cache_cond, slice_start)

# Reposition doc keys
cache_cond_repo = reposition_kv_cache(cache_cond_sliced, -(P + NL),
                                      LAYER_INV_FREQS)

# Compare at layer 0 (keys and values)
# Bare keys: skip BOS (index 0), take doc entries (indices 1..D)
bare_k0 = cache_bare.layers[0].keys[:, :, 1:, :].float()
repo_k0 = cache_cond_repo.layers[0].keys.float()
bare_v0 = cache_bare.layers[0].values[:, :, 1:, :].float()
repo_v0 = cache_cond_repo.layers[0].values.float()

key_rel_diff = (bare_k0 - repo_k0).abs().max().item() / bare_k0.abs().max().item()
val_abs_diff = (bare_v0 - repo_v0).abs().max().item()
print(f"  Layer 0 keys:   rel diff = {key_rel_diff:.2e}")
print(f"  Layer 0 values: abs diff = {val_abs_diff:.2e}")
assert key_rel_diff < 0.02, f"Layer 0 key mismatch after full pipeline: {key_rel_diff}"
assert val_abs_diff < 1e-6, f"Layer 0 value mismatch: {val_abs_diff}"

# Do the same for the first global layer
bare_kg = cache_bare.layers[global_idx].keys[:, :, 1:, :].float()
repo_kg = cache_cond_repo.layers[global_idx].keys.float()
bare_vg = cache_bare.layers[global_idx].values[:, :, 1:, :].float()
repo_vg = cache_cond_repo.layers[global_idx].values.float()
key_rel_diff_g = (bare_kg - repo_kg).abs().max().item() / bare_kg.abs().max().item()
val_abs_diff_g = (bare_vg - repo_vg).abs().max().item()
print(f"  Layer {global_idx} keys:   rel diff = {key_rel_diff_g:.2e}")
print(f"  Layer {global_idx} values: abs diff = {val_abs_diff_g:.2e}")

# Show per-layer comparison (layer 0 should match, higher layers diverge)
print(f"\n  Per-layer key comparison (first 15 layers):")
print(f"  {'Layer':>5} {'Type':>4} {'Key RelDiff':>12} {'Val RelDiff':>12} {'Match':>6}")
for L in range(min(15, len(cache_bare.layers))):
    bare_k = cache_bare.layers[L].keys[:, :, 1:, :].float()
    repo_k = cache_cond_repo.layers[L].keys.float()
    bare_v = cache_bare.layers[L].values[:, :, 1:, :].float()
    repo_v = cache_cond_repo.layers[L].values.float()
    krd = (bare_k - repo_k).abs().max().item() / bare_k.abs().max().item()
    vad = (bare_v - repo_v).abs().max().item() / bare_v.abs().max().item()
    ltype = 'G' if text_cfg.layer_types[L] == 'full_attention' else 'L'
    match = 'Y' if krd < 0.02 and vad < 0.02 else 'N'
    print(f"  {L:>5} {ltype:>4} {krd:>12.4e} {vad:>12.4e} {match:>6}")

del cache_bare, cache_cond, cache_cond_sliced, cache_cond_repo

print("  PASSED — layer-0 keys and values match between bare and conditioned")

# ================================================================
# TEST 5: Phase B NLL comparison (the key end-to-end test)
# ================================================================
# If repositioning is correct, the ONLY difference between bare NLL
# and conditioned NLL should come from the semantic content change
# in doc representations (values and key content at higher layers).
#
# This test is diagnostic only (no assertion): it compares bare against the
# oracle prefix (real query) and the off-topic adversarial prefix on one
# sample. The oracle is expected to lower NLL; the adversarial prefix shows
# how much an unrelated prefix moves it.
print("\n--- Test 5: End-to-end NLL comparison ---")
nll_bare = score(s['passage'], s['query'], s['answer'])
nll_oracle = score(s['passage'], s['query'], s['answer'], prefix_text=s['query'])
nll_adv = score(s['passage'], s['query'], s['answer'],
                prefix_text=ADVERSARIAL_PREFIX)
print(f"  Bare:        {nll_bare:.6f}")
print(f"  Oracle:      {nll_oracle:.6f} (delta: {nll_bare - nll_oracle:+.6f})")
print(f"  Adversarial: {nll_adv:.6f} (delta: {nll_bare - nll_adv:+.6f})")

# ================================================================
# TEST 6: Verify the model's cos/sin at specific positions
# ================================================================
# Generate cos/sin through the model's rotary embedding, and compare
# what we'd compute for the rotation delta.
print("\n--- Test 6: cos/sin verification against model's rotary_emb ---")
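# float32 input: rotary_emb returns cos/sin in the input's dtype, and bf16
# cos/sin would not meet the 1e-5 composition tolerance asserted below.
dummy_hidden = torch.zeros(1, 1, head_dim_check, device=DEVICE,
                           dtype=torch.float32)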
pos_ids = torch.tensor([[5, 22]], device=DEVICE)  # positions 5 and 22
for ltype in ['sliding_attention', 'full_attention']:
    cos_model, sin_model = rotary(dummy_hidden, pos_ids, layer_type=ltype)
    # cos_model shape: [1, 2, head_dim]
    # Extract cos/sin at position 5 and position 22
    cos_5 = cos_model[0, 0, :].float()  # [head_dim]
    cos_22 = cos_model[0, 1, :].float()
    sin_5 = sin_model[0, 0, :].float()
    sin_22 = sin_model[0, 1, :].float()

    # Our rotation delta = 17 (from pos 5 to pos 22)
    inv_freq_type = LAYER_INV_FREQS[
        [i for i, t in enumerate(text_cfg.layer_types) if t == ltype][0]]
    angles_17 = 17.0 * inv_freq_type
    emb_17 = torch.cat([angles_17, angles_17])
    our_cos_17 = emb_17.cos()
    our_sin_17 = emb_17.sin()

    # Verify: cos(22*theta) = cos(5*theta)*cos(17*theta) - sin(5*theta)*sin(17*theta)
    cos_22_from_composition = cos_5 * our_cos_17 - sin_5 * our_sin_17
    sin_22_from_composition = sin_5 * our_cos_17 + cos_5 * our_sin_17

    cos_err = (cos_22 - cos_22_from_composition).abs().max().item()
    sin_err = (sin_22 - sin_22_from_composition).abs().max().item()
    print(f"  {ltype}: cos composition err = {cos_err:.2e}, "
          f"sin composition err = {sin_err:.2e}")
    assert cos_err < 1e-5, f"cos composition failed for {ltype}: {cos_err}"
    assert sin_err < 1e-5, f"sin composition failed for {ltype}: {sin_err}"
print("  PASSED — rotation composition identity verified against model")

# ================================================================
# TEST 7: Bare NLL consistency
# ================================================================
print("\n--- Test 7: Bare NLL consistency ---")
nll_1 = score(s['passage'], s['query'], s['answer'])
nll_2 = score(s['passage'], s['query'], s['answer'])
print(f"  Call 1: {nll_1:.6f}, Call 2: {nll_2:.6f}, diff: {abs(nll_1 - nll_2):.2e}")
assert abs(nll_1 - nll_2) < 1e-5, "Bare NLL not consistent"
print("  PASSED")

# ================================================================
# TEST 8: Multi-sample oracle check
# ================================================================
print("\n--- Test 8: 5-sample bare vs oracle ---")
oracle_wins = 0
for i in range(5):
    s_test = samples[i]
    nll_b = score(s_test['passage'], s_test['query'], s_test['answer'])
    nll_o = score(s_test['passage'], s_test['query'], s_test['answer'],
                  prefix_text=s_test['query'])
    delta = nll_b - nll_o
    win = delta > 0
    oracle_wins += win
    print(f"  Sample {i}: bare={nll_b:.6f}, oracle={nll_o:.6f}, "
          f"delta={delta:+.6f} {'(oracle wins)' if win else '(bare wins)'}")
print(f"  Oracle wins: {oracle_wins}/5")

gc.collect()
torch.cuda.empty_cache()
print("\n" + "=" * 70)
print("ALL DEEP VALIDATION TESTS PASSED")
print("=" * 70)