# Experiment 08: Diagnosing the Suffix Priming Signal

## Motivation

Experiment 07 showed that suffix priming produces content-independent effects: relevant, irrelevant,
and shuffled suffixes all perform similarly, with a ~50% win rate against the bare baseline. This mirrors
the prefix result (r=0.924).

Three hypotheses could explain this:

1. **The query makes the suffix redundant.** At scoring time, the model sees
   `[passage + suffix] + query + answer`. The real query already tells the model what to
   focus on — the suffix is redundant information competing for attention.

2. **The model ignores suffix tokens.** Query-time attention to suffix positions may be
   negligible — the model has learned to attend primarily to document content.

3. **MS MARCO is too easy.** Short passages with extractive answers may not require
   better document representations.

## Three Investigations

**Investigation A: Query-free scoring.** Remove the query from the scoring pipeline.
The cache holds `[passage + suffix]`; we extend it with just `"\n\nAnswer:"` and score the answer.
Now the suffix is the *only* intent signal. If suffixes matter, this is where they'll show it.

**Investigation B: Attention pattern analysis.** Extract attention weights during scoring
to see whether query/answer tokens actually attend to suffix positions, and whether
relevant vs irrelevant suffixes produce different attention patterns.

**Investigation C: Hard-sample analysis.** Filter MS MARCO for samples where the bare
model struggles (high NLL), where better representations should matter most.

## Experimental Notes

- Motivated by Exp 07's failure to find a semantic signal with suffix placement — content-independent effects persisted even when document KV entries were guaranteed identical.
- **Key discovery:** Suffixes STEAL attention from the query. At scoring time, answer tokens place attention on suffix tokens that would otherwise go to query tokens, reducing the query attention share from ~20% to 9-10%. This explains why suffixes consistently hurt answer quality.
- **Content-independence r=0.797**, still very high — confirming that the effect is structural, not semantic.
- **Verdict:** Causal attention is the fundamental blocker. In a causal (autoregressive) model, passage tokens cannot "see" suffix tokens placed after them. Suffix tokens can only influence later tokens' attention patterns, where they act as distractors.
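
The two structural claims above can be illustrated with a toy causal softmax. This is a minimal sketch, not the experiment's code: the region sizes are made up, all raw scores are set equal so that only the causal mask matters, and `causal_attention_row` and `share` are hypothetical helpers introduced here for illustration.

```python
import math

def causal_attention_row(scores, query_pos):
    """Softmax over positions 0..query_pos only; later positions get weight 0."""
    visible = scores[: query_pos + 1]
    peak = max(visible)
    exps = [math.exp(s - peak) for s in visible]
    total = sum(exps)
    return [e / total for e in exps] + [0.0] * (len(scores) - query_pos - 1)

def share(row, name, layout):
    """Total attention weight that one row places on the region called `name`."""
    return sum(w for w, region in zip(row, layout) if region == name)

# Hypothetical layout: 4 passage, 2 suffix, 2 query tokens, then 1 answer token.
regions = ["passage"] * 4 + ["suffix"] * 2 + ["query"] * 2 + ["answer"]
uniform_scores = [0.0] * len(regions)  # equal raw scores (assumption)

# The last passage token (position 3) cannot see the suffix that follows it.
passage_row = causal_attention_row(uniform_scores, query_pos=3)
print("passage -> suffix:", share(passage_row, "suffix", regions))

# The answer token sees everything, so suffix keys compete with query keys.
answer_row = causal_attention_row(uniform_scores, query_pos=len(regions) - 1)
print("answer -> query (with suffix):", round(share(answer_row, "query", regions), 3))

# Same layout without the suffix: the query share is larger.
bare_regions = ["passage"] * 4 + ["query"] * 2 + ["answer"]
bare_row = causal_attention_row([0.0] * len(bare_regions), query_pos=len(bare_regions) - 1)
print("answer -> query (bare):", round(share(bare_row, "query", bare_regions), 3))
```

Real attention scores are not uniform, so the measured shares differ from this toy; the sketch only shows why extra keys placed before the query can dilute the query's share without ever changing the passage representations.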

## Setup

In [None]:
import sys
import os
import copy
import json
import time
import datetime
import random
from typing import Dict, List, Any, Optional, Tuple

import torch
import numpy as np
from tqdm.auto import tqdm
from scipy import stats
import matplotlib.pyplot as plt

sys.path.insert(0, '.')

from lib import (
    ExperimentConfig,
    build_kv_cache,
    build_suffix_kv_cache,
    score_answer_with_cache,
    build_truncated_kv_cache_corrected,
    generate_all_5_surrogates,
    compute_similarity,
    load_evaluation_samples,
    load_ms_marco,
    TOP_5_SURROGATE_TEMPLATES,
    STATIC_SURROGATE_QUERIES,
)
from lib.analysis import cohens_d

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
config = ExperimentConfig(
    num_samples=800,
    min_passage_words=50,
    max_passage_words=300,
    surrogate_max_tokens=45,
    surrogate_temperature=0.3,
    seed=42,
)
print(f"Model: {config.model_name}")
print(f"Device: {config.device}")

## Model Loading

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer

torch.manual_seed(config.seed)
np.random.seed(config.seed)
random.seed(config.seed)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

print(f"Loading {config.model_name}...")
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    attn_implementation="eager",  # Required for output_attentions=True in Investigation B
)
model.eval()
print(f"Model loaded. Parameters: {sum(p.numel() for p in model.parameters()):,}")

embed_model = SentenceTransformer(config.embedding_model_name)
print(f"Embedding model loaded: {config.embedding_model_name}")

## Dataset

In [None]:
dataset = load_ms_marco(config)
raw_samples = load_evaluation_samples(dataset, config, require_answer=True)
print(f"Raw samples after basic filtering: {len(raw_samples)}")

filtered_samples = []
excluded_ratio = 0
excluded_short_answer = 0

for s in raw_samples:
    if len(s['answer']) / max(len(s['passage']), 1) > 0.5:
        excluded_ratio += 1
        continue
    answer_ids = tokenizer(s['answer'], return_tensors='pt', add_special_tokens=False)['input_ids']
    if answer_ids.shape[1] < 2:
        excluded_short_answer += 1
        continue
    filtered_samples.append(s)

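samples = filtered_samples
print(f"Excluded (answer/passage length ratio > 0.5): {excluded_ratio}")
print(f"Excluded (answer shorter than 2 tokens): {excluded_short_answer}")
print(f"Remaining samples: {len(samples)}")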

## Helper Functions

In [None]:
def _get_kv(cache, layer_idx):
    """Get (keys, values) from cache layer — compatible across DynamicCache versions."""
    if hasattr(cache, 'key_cache'):
        return cache.key_cache[layer_idx], cache.value_cache[layer_idx]
    return cache.layers[layer_idx].keys, cache.layers[layer_idx].values


def _num_layers(cache):
    if hasattr(cache, 'key_cache'):
        return len(cache.key_cache)
    return len(cache.layers)


def deep_copy_cache(cache):
    return copy.deepcopy(cache)


def get_irrelevant_query(samples, current_idx, rng):
    other_idx = current_idx
    while other_idx == current_idx:
        other_idx = rng.randint(0, len(samples) - 1)
    return samples[other_idx]['query']


def shuffle_text(text, rng):
    words = text.split()
    rng.shuffle(words)
    return ' '.join(words)


def score_answer_queryless(
    past_key_values, context_len, answer_prompt, answer,
    model, tokenizer, config
):
    """Score answer WITHOUT providing the query.

    Instead of: [cache] + query + answer
    We do:      [cache] + answer_prompt + answer

    The answer_prompt is a short transition like '\n\nAnswer:' — no query.
    The suffix (if present in cache) is the only intent signal.
    """
    # Tokenize answer prompt (transition text)
    prompt_enc = tokenizer(
        answer_prompt, return_tensors="pt", add_special_tokens=False,
        padding=False, truncation=False
    )
    prompt_ids = prompt_enc['input_ids'].to(config.device)
    prompt_len = prompt_ids.shape[1]

    # Tokenize answer
    answer_enc = tokenizer(
        answer, return_tensors="pt", add_special_tokens=False,
        padding=False, truncation=False
    )
    answer_ids = answer_enc['input_ids'].to(config.device)
    answer_len = answer_ids.shape[1]

    # Extend cache with prompt
    attn_mask = torch.ones((1, context_len + prompt_len), device=config.device)

    with torch.no_grad():
        prompt_out = model(
            input_ids=prompt_ids,
            attention_mask=attn_mask,
            past_key_values=past_key_values,
            use_cache=True,
            return_dict=True,
        )
        extended_cache = prompt_out.past_key_values

    # Score answer
    attn_mask_final = torch.ones(
        (1, context_len + prompt_len + answer_len), device=config.device
    )

    with torch.no_grad():
        answer_out = model(
            input_ids=answer_ids,
            attention_mask=attn_mask_final,
            past_key_values=extended_cache,
            use_cache=False,
            return_dict=True,
        )

    logits = answer_out.logits
    shift_logits = logits[:, :-1, :].contiguous().view(-1, logits.size(-1))
    shift_labels = answer_ids[:, 1:].contiguous().view(-1)

    loss = torch.nn.CrossEntropyLoss(reduction='sum')
    nll = loss(shift_logits, shift_labels).item()
    num_scored = answer_len - 1
    return nll / num_scored if num_scored > 0 else 0.0


def extract_attention_to_suffix(
    past_key_values, context_len, query_prompt, answer,
    model, tokenizer, config,
    passage_len, suffix_start,
):
    """Run the scoring forward pass with output_attentions=True.

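    Returns a dict with attention statistics. Each value is the attention mass
    that answer tokens place on a region (summed over the region's positions,
    averaged over heads and answer positions):
    - attn_to_passage: mass on the passage region, averaged over layers
    - attn_to_suffix: mass on the suffix region, averaged over layers
    - attn_to_query: mass on the query region, averaged over layers
    - per_layer_suffix_attn: list of per-layer mass on the suffix region

    Regions are derived from suffix_start and context_len; passage_len is not
    used in the computation.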
    """
    cache_copy = deep_copy_cache(past_key_values)

    # Tokenize query
    query_enc = tokenizer(
        query_prompt, return_tensors="pt", add_special_tokens=False
    )
    query_ids = query_enc['input_ids'].to(config.device)
    query_len = query_ids.shape[1]

    # Tokenize answer
    answer_enc = tokenizer(
        answer, return_tensors="pt", add_special_tokens=False
    )
    answer_ids = answer_enc['input_ids'].to(config.device)
    answer_len = answer_ids.shape[1]

    # Extend cache with query (need use_cache=True to build extended cache)
    attn_mask = torch.ones((1, context_len + query_len), device=config.device)
    with torch.no_grad():
        query_out = model(
            input_ids=query_ids,
            attention_mask=attn_mask,
            past_key_values=cache_copy,
            use_cache=True,
            return_dict=True,
        )
        extended_cache = query_out.past_key_values

    # Score answer WITH attention output
    total_len = context_len + query_len + answer_len
    attn_mask_final = torch.ones((1, total_len), device=config.device)

    with torch.no_grad():
        answer_out = model(
            input_ids=answer_ids,
            attention_mask=attn_mask_final,
            past_key_values=extended_cache,
            use_cache=False,
            return_dict=True,
            output_attentions=True,
        )

    # answer_out.attentions is a tuple with one tensor per layer
    # Each tensor shape: (batch, n_heads, answer_len, total_len)
    # The total_len includes all cached positions + answer positions

    # Define regions in the full sequence:
    # [0, suffix_start) = passage/BOS tokens
    # [suffix_start, context_len) = suffix tokens (if suffix_start < context_len)
    # [context_len, context_len + query_len) = query tokens
    # [context_len + query_len, total_len) = answer tokens (but these only attend to prior)

    # For answer tokens, attention is over [0, context_len + query_len + answer_pos]
    # We look at the mean attention across all answer token positions

    per_layer_suffix_attn = []
    per_layer_passage_attn = []
    per_layer_query_attn = []

    suffix_len = context_len - suffix_start if suffix_start < context_len else 0
    query_start = context_len

    for layer_attn in answer_out.attentions:
        # shape: (1, n_heads, answer_len, total_ctx_for_this_token)
        # For Mistral with GQA, n_heads = 32 (query heads)
        attn = layer_attn[0]  # (n_heads, answer_len, total_len)

        # Mean over heads and answer positions
        mean_attn = attn.mean(dim=(0, 1))  # (total_len,)

        # Passage region: [0, suffix_start) — the pure passage tokens
        passage_attn = mean_attn[:suffix_start].sum().item()
        per_layer_passage_attn.append(passage_attn)

        # Suffix region: [suffix_start, context_len)
        if suffix_len > 0:
            sfx_attn = mean_attn[suffix_start:context_len].sum().item()
        else:
            sfx_attn = 0.0
        per_layer_suffix_attn.append(sfx_attn)

        # Query region: [context_len, context_len + query_len)
        q_attn = mean_attn[query_start:query_start + query_len].sum().item()
        per_layer_query_attn.append(q_attn)

    return {
        'attn_to_passage': np.mean(per_layer_passage_attn),
        'attn_to_suffix': np.mean(per_layer_suffix_attn),
        'attn_to_query': np.mean(per_layer_query_attn),
        'per_layer_suffix_attn': per_layer_suffix_attn,
        'per_layer_passage_attn': per_layer_passage_attn,
        'per_layer_query_attn': per_layer_query_attn,
        'suffix_len': suffix_len,
        'query_len': query_len,
        'answer_len': answer_len,
    }


print("Helper functions defined.")

## Investigation A: Query-Free Scoring

**Key idea**: Remove the query from scoring. If the suffix is the only intent signal,
relevant suffixes should dramatically outperform irrelevant ones.

### Conditions (6):
1. `bare` — passage only, no query, no suffix
2. `bare_with_query` — passage only, with query (standard scoring, for calibration)
3. `sfx_relevant_no_query` — passage + relevant suffix, NO query
4. `sfx_irrel_no_query` — passage + irrelevant suffix, NO query
5. `sfx_perfect_no_query` — passage + actual query as suffix, NO query
6. `sfx_shuffled_no_query` — passage + shuffled suffix, NO query

If the suffix carries semantic signal, conditions 3 and 5 should beat 4 and 6.
The gap should be MUCH larger than in exp 07, where the query may have made the suffix redundant.
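
The comparison described above is a paired one: both conditions are scored on the same samples, so the per-sample difference is the quantity of interest. Below is a minimal sketch using only the standard library. The NLL values are invented for illustration, `paired_summary` is a hypothetical helper, and the effect size shown (mean difference over its sample standard deviation) is one common paired definition that may not match `lib.analysis.cohens_d`.

```python
import math
import statistics

def paired_summary(nll_a, nll_b):
    """Compare two conditions scored on the same samples (lower NLL is better)."""
    diffs = [b - a for a, b in zip(nll_a, nll_b)]  # positive: condition A wins
    n = len(diffs)
    mean_diff = statistics.mean(diffs)
    sd_diff = statistics.stdev(diffs)  # sample standard deviation (ddof=1)
    return {
        "mean_diff": mean_diff,
        "win_rate_a": sum(d > 0 for d in diffs) / n,
        "t": mean_diff / (sd_diff / math.sqrt(n)),  # paired t statistic
        "cohens_d": mean_diff / sd_diff,            # assumed paired definition
    }

# Made-up per-sample NLLs, for illustration only (not experiment results).
relevant = [2.1, 1.8, 2.5, 3.0, 2.2]
irrelevant = [2.4, 1.9, 2.4, 3.3, 2.6]
summary = paired_summary(relevant, irrelevant)
print(summary)
```

A semantic signal would show up as a clearly positive `mean_diff` and a win rate well above 50% for the relevant and perfect suffixes against the irrelevant and shuffled ones.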

In [None]:
# Pipeline verification on one sample
test_sample = samples[0]
passage = test_sample['passage']
query = test_sample['query']
answer = test_sample['answer']
query_prompt = config.query_template.format(query=query)
answer_prompt = "\n\nAnswer:"

print(f"Passage: {passage[:80]}...")
print(f"Query:   {query}")
print(f"Answer:  {answer[:60]}")
print()

# Generate surrogates
test_surrogates = generate_all_5_surrogates(passage, model, tokenizer, config)
test_sims = {k: compute_similarity(v, query, embed_model) for k, v in test_surrogates.items()}
routed_key = max(test_sims, key=test_sims.get)
routed_surr = test_surrogates[routed_key]
print(f"Routed surrogate ({routed_key}): {routed_surr}")
print()

# --- With query (standard) ---
bare_len, bare_cache = build_kv_cache(passage, model, tokenizer, config)
nll_bare_q = score_answer_with_cache(
    deep_copy_cache(bare_cache), bare_len, query_prompt, answer, model, tokenizer, config
)
print(f"1. bare (with query):            NLL = {nll_bare_q:.4f}")

# --- Without query ---
nll_bare_noq = score_answer_queryless(
    deep_copy_cache(bare_cache), bare_len, answer_prompt, answer, model, tokenizer, config
)
print(f"2. bare (NO query):              NLL = {nll_bare_noq:.4f}")

# --- Suffix relevant, no query ---
sfx_len, sfx_cache = build_suffix_kv_cache(passage, routed_surr, model, tokenizer, config)
nll_sfx_rel_noq = score_answer_queryless(
    deep_copy_cache(sfx_cache), sfx_len, answer_prompt, answer, model, tokenizer, config
)
print(f"3. sfx_relevant (NO query):      NLL = {nll_sfx_rel_noq:.4f}")

# --- Suffix irrelevant, no query ---
rng = random.Random(config.seed)
irrel_q = get_irrelevant_query(samples, 0, rng)
sfx_len_i, sfx_cache_i = build_suffix_kv_cache(passage, irrel_q, model, tokenizer, config)
nll_sfx_irrel_noq = score_answer_queryless(
    deep_copy_cache(sfx_cache_i), sfx_len_i, answer_prompt, answer, model, tokenizer, config
)
print(f"4. sfx_irrelevant (NO query):    NLL = {nll_sfx_irrel_noq:.4f}")

# --- Suffix perfect (actual query), no query ---
sfx_len_p, sfx_cache_p = build_suffix_kv_cache(passage, query, model, tokenizer, config)
nll_sfx_perf_noq = score_answer_queryless(
    deep_copy_cache(sfx_cache_p), sfx_len_p, answer_prompt, answer, model, tokenizer, config
)
print(f"5. sfx_perfect (NO query):       NLL = {nll_sfx_perf_noq:.4f}")

# --- Suffix shuffled, no query ---
shuffled_surr = shuffle_text(routed_surr, rng)
sfx_len_s, sfx_cache_s = build_suffix_kv_cache(passage, shuffled_surr, model, tokenizer, config)
nll_sfx_shuf_noq = score_answer_queryless(
    deep_copy_cache(sfx_cache_s), sfx_len_s, answer_prompt, answer, model, tokenizer, config
)
print(f"6. sfx_shuffled (NO query):      NLL = {nll_sfx_shuf_noq:.4f}")

print(f"\n--- Deltas (lower = better) ---")
print(f"Query-free: relevant - bare    = {nll_sfx_rel_noq - nll_bare_noq:+.4f}")
print(f"Query-free: perfect - bare     = {nll_sfx_perf_noq - nll_bare_noq:+.4f}")
print(f"Query-free: irrelevant - bare  = {nll_sfx_irrel_noq - nll_bare_noq:+.4f}")
print(f"Query-free: shuffled - bare    = {nll_sfx_shuf_noq - nll_bare_noq:+.4f}")
print(f"Query-free: relevant - irrel   = {nll_sfx_rel_noq - nll_sfx_irrel_noq:+.4f}")
print(f"With-query (baseline ref):       {nll_bare_q:.4f}")

### Investigation A: Main Loop

In [None]:
def evaluate_queryless(sample, idx, all_samples, model, tokenizer, embed_model, config):
    """Evaluate query-free scoring for one sample."""
    passage = sample['passage']
    query = sample['query']
    answer = sample['answer']
    query_prompt = config.query_template.format(query=query)
    answer_prompt = "\n\nAnswer:"
    rng = random.Random(config.seed + idx)

    # Generate surrogates
    surrogates = generate_all_5_surrogates(passage, model, tokenizer, config)
    sims = {k: compute_similarity(v, query, embed_model) for k, v in surrogates.items()}
    routed_key = max(sims, key=sims.get)
    routed_surr = surrogates[routed_key]

    # 1. Bare (with query — calibration)
    bare_len, bare_cache = build_kv_cache(passage, model, tokenizer, config)
    nll_bare_q = score_answer_with_cache(
        deep_copy_cache(bare_cache), bare_len, query_prompt, answer, model, tokenizer, config
    )

    # 2. Bare (no query)
    nll_bare_noq = score_answer_queryless(
        deep_copy_cache(bare_cache), bare_len, answer_prompt, answer, model, tokenizer, config
    )

    # 3. Suffix relevant (no query)
    sfx_len, sfx_cache = build_suffix_kv_cache(passage, routed_surr, model, tokenizer, config)
    nll_sfx_rel_noq = score_answer_queryless(
        deep_copy_cache(sfx_cache), sfx_len, answer_prompt, answer, model, tokenizer, config
    )

    # 4. Suffix irrelevant (no query)
    irrel_q = get_irrelevant_query(all_samples, idx, rng)
    sfx_len_i, sfx_cache_i = build_suffix_kv_cache(passage, irrel_q, model, tokenizer, config)
    nll_sfx_irrel_noq = score_answer_queryless(
        deep_copy_cache(sfx_cache_i), sfx_len_i, answer_prompt, answer, model, tokenizer, config
    )

    # 5. Suffix perfect (no query)
    sfx_len_p, sfx_cache_p = build_suffix_kv_cache(passage, query, model, tokenizer, config)
    nll_sfx_perf_noq = score_answer_queryless(
        deep_copy_cache(sfx_cache_p), sfx_len_p, answer_prompt, answer, model, tokenizer, config
    )

    # 6. Suffix shuffled (no query)
    shuffled = shuffle_text(routed_surr, rng)
    sfx_len_s, sfx_cache_s = build_suffix_kv_cache(passage, shuffled, model, tokenizer, config)
    nll_sfx_shuf_noq = score_answer_queryless(
        deep_copy_cache(sfx_cache_s), sfx_len_s, answer_prompt, answer, model, tokenizer, config
    )

    # Also score WITH query for the suffix conditions (exp 07 replication)
    nll_sfx_rel_wq = score_answer_with_cache(
        deep_copy_cache(sfx_cache), sfx_len, query_prompt, answer, model, tokenizer, config
    )
    nll_sfx_irrel_wq = score_answer_with_cache(
        deep_copy_cache(sfx_cache_i), sfx_len_i, query_prompt, answer, model, tokenizer, config
    )

    return {
        'idx': idx,
        'query': query,
        'passage_len': len(passage),
        'answer_len': len(answer),
        'routed_key': routed_key,
        'routed_sim': sims[routed_key],

        # With query
        'bare_wq': nll_bare_q,
        'sfx_rel_wq': nll_sfx_rel_wq,
        'sfx_irrel_wq': nll_sfx_irrel_wq,

        # Without query
        'bare_noq': nll_bare_noq,
        'sfx_rel_noq': nll_sfx_rel_noq,
        'sfx_irrel_noq': nll_sfx_irrel_noq,
        'sfx_perfect_noq': nll_sfx_perf_noq,
        'sfx_shuffled_noq': nll_sfx_shuf_noq,
    }


print("evaluate_queryless() defined — 8 conditions.")

In [None]:
results_a = []
errors_a = 0
start_time_a = time.time()
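checkpoint_path_a = 'results/exp08/08a_checkpoint.json'
os.makedirs(os.path.dirname(checkpoint_path_a), exist_ok=True)  # ensure the output directory exists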

# Resume
start_from_a = 0
if os.path.exists(checkpoint_path_a):
    with open(checkpoint_path_a) as f:
        ckpt = json.load(f)
    results_a = ckpt['results']
    start_from_a = ckpt['n_done']
    print(f"Resuming Investigation A from sample {start_from_a}")
else:
    print("Starting Investigation A from scratch.")

N_SAMPLES_A = min(200, len(samples))  # 200 samples for Investigation A

print("=" * 80)
print("INVESTIGATION A: QUERY-FREE SCORING")
print(f"Samples: {N_SAMPLES_A}")
print("=" * 80)

for idx in tqdm(range(start_from_a, N_SAMPLES_A), desc="Inv A"):
    sample = samples[idx]
    try:
        result = evaluate_queryless(
            sample, idx, samples, model, tokenizer, embed_model, config
        )
        if result is not None:
            results_a.append(result)
    except Exception as e:
        errors_a += 1
        if errors_a <= 5:
            print(f"\n  Error on sample {idx}: {type(e).__name__}: {e}")
        continue

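    if len(results_a) > 0 and len(results_a) % 10 == 0:
        # Checkpoint: store the next sample index (not the result count) so that
        # resuming stays aligned with `samples` even if some samples errored out.
        with open(checkpoint_path_a, 'w') as f:
            json.dump({'n_done': idx + 1, 'results': results_a}, f, default=str)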

    if len(results_a) > 0 and len(results_a) % 50 == 0:
        recent = results_a[-50:]
        bare_noq_m = np.mean([r['bare_noq'] for r in recent])
        rel_noq_m = np.mean([r['sfx_rel_noq'] for r in recent])
        irrel_noq_m = np.mean([r['sfx_irrel_noq'] for r in recent])
        perf_noq_m = np.mean([r['sfx_perfect_noq'] for r in recent])
        wr_rel = np.mean([r['bare_noq'] > r['sfx_rel_noq'] for r in recent]) * 100
        wr_perf = np.mean([r['bare_noq'] > r['sfx_perfect_noq'] for r in recent]) * 100
        print(
            f"\n  [{len(results_a)} done] bare_noq={bare_noq_m:.3f} "
            f"rel_noq={rel_noq_m:.3f}({wr_rel:.0f}% win) "
            f"perf_noq={perf_noq_m:.3f}({wr_perf:.0f}% win) "
            f"irrel_noq={irrel_noq_m:.3f}"
        )

elapsed_a = time.time() - start_time_a
print(f"\nDone. {len(results_a)} evaluated, {errors_a} errors, {elapsed_a/60:.1f}m")

### Investigation A: Results

In [None]:
n_a = len(results_a)
print("=" * 120)
print(f"INVESTIGATION A: QUERY-FREE SCORING RESULTS (N = {n_a})")
print("=" * 120)

conditions_a = [
    ('bare (with query)',         'bare_wq'),
    ('bare (NO query)',           'bare_noq'),
    ('sfx_relevant (NO query)',   'sfx_rel_noq'),
    ('sfx_perfect (NO query)',    'sfx_perfect_noq'),
    ('sfx_irrelevant (NO query)', 'sfx_irrel_noq'),
    ('sfx_shuffled (NO query)',   'sfx_shuffled_noq'),
    ('sfx_relevant (with query)', 'sfx_rel_wq'),
    ('sfx_irrel (with query)',    'sfx_irrel_wq'),
]

bare_noq = np.array([r['bare_noq'] for r in results_a])

print(f"\n{'Condition':<32} {'Mean NLL':>10} {'Std':>8} {'Delta vs bare_noq':>18} {'Win%':>8} {'t':>8} {'p':>12} {'d':>8}")
print("-" * 120)

for label, key in conditions_a:
    arr = np.array([r[key] for r in results_a])
    delta = bare_noq - arr
    mn = np.mean(arr)
    sd = np.std(arr)
    if key == 'bare_noq':
        print(f"{label:<32} {mn:>10.4f} {sd:>8.4f} {'BASELINE':>18} {'--':>8} {'--':>8} {'--':>12} {'--':>8}")
    else:
        t, p = stats.ttest_rel(bare_noq, arr)
        d = cohens_d(delta)
        wr = np.mean(delta > 0) * 100
        print(f"{label:<32} {mn:>10.4f} {sd:>8.4f} {np.mean(delta):>+18.4f} {wr:>7.1f}% {t:>8.3f} {p:>12.6f} {d:>8.4f}")

# Key pairwise comparisons
print(f"\n{'=' * 120}")
print("KEY PAIRWISE: Does removing the query expose semantic signal?")
print("=" * 120)

pairs_a = [
    ("relevant vs irrelevant (NO query)",  'sfx_rel_noq',     'sfx_irrel_noq'),
    ("perfect vs irrelevant (NO query)",   'sfx_perfect_noq', 'sfx_irrel_noq'),
    ("relevant vs shuffled (NO query)",    'sfx_rel_noq',     'sfx_shuffled_noq'),
    ("relevant vs irrelevant (WITH query)",'sfx_rel_wq',      'sfx_irrel_wq'),
]

for label, ka, kb in pairs_a:
    a = np.array([r[ka] for r in results_a])
    b = np.array([r[kb] for r in results_a])
    diff = b - a  # positive = a better (lower NLL)
    t, p = stats.ttest_rel(b, a)  # argument order chosen so t has the same sign as diff
    d = cohens_d(diff)
    print(f"  {label}")
    print(f"    mean A={np.mean(a):.4f}, mean B={np.mean(b):.4f}, delta={np.mean(diff):+.4f}, t={t:.3f}, p={p:.6f}, d={d:.4f}")

# Correlation: query-free deltas
print(f"\n{'=' * 120}")
print("CORRELATION: query-free deltas (relevant vs shuffled)")
print("=" * 120)
rel_deltas = np.array([r['bare_noq'] - r['sfx_rel_noq'] for r in results_a])
shuf_deltas = np.array([r['bare_noq'] - r['sfx_shuffled_noq'] for r in results_a])
irrel_deltas = np.array([r['bare_noq'] - r['sfx_irrel_noq'] for r in results_a])

r_shuf, p_shuf = stats.pearsonr(rel_deltas, shuf_deltas)
r_irrel, p_irrel = stats.pearsonr(rel_deltas, irrel_deltas)
print(f"  relevant vs shuffled deltas:   r = {r_shuf:.4f} (p = {p_shuf:.6f})")
print(f"  relevant vs irrelevant deltas: r = {r_irrel:.4f} (p = {p_irrel:.6f})")
print(f"  (Exp 07 with-query prefix ref: r ~ 0.924)")

## Investigation B: Attention Pattern Analysis

Extract attention weights to see:
1. How much do answer tokens attend to suffix positions vs passage vs query?
2. Does this differ between relevant and irrelevant suffixes?
3. Which layers attend most to the suffix?

Run on a smaller subset (30 samples) since `output_attentions=True` is memory-intensive.
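
To make the region bookkeeping concrete, here is a minimal pure-Python sketch of how attention mass can be split by region. It assumes the same layout used in this notebook (passage, then suffix, then query, then answer tokens); `region_shares` is a hypothetical helper and the toy rows are invented, so this is an illustration of the accounting rather than the extraction code itself.

```python
def region_shares(attn_rows, suffix_start, context_len, query_len):
    """Average attention mass per region over answer-token rows.

    attn_rows: one list per answer token, each a distribution over all positions.
    Assumed layout: [0, suffix_start) passage, [suffix_start, context_len) suffix,
    [context_len, context_len + query_len) query, remainder answer tokens.
    """
    n_rows = len(attn_rows)
    width = len(attn_rows[0])
    # Mean over answer positions, one value per key position.
    mean_row = [sum(row[j] for row in attn_rows) / n_rows for j in range(width)]
    query_end = context_len + query_len
    return {
        "passage": sum(mean_row[:suffix_start]),
        "suffix": sum(mean_row[suffix_start:context_len]),
        "query": sum(mean_row[context_len:query_end]),
        "answer": sum(mean_row[query_end:]),
    }

# Toy rows: 2 answer tokens over 3 passage + 2 suffix + 2 query + 2 answer positions.
# The first answer token cannot attend to the second, hence the trailing 0.0.
rows = [
    [0.10, 0.10, 0.10, 0.15, 0.15, 0.10, 0.10, 0.20, 0.00],
    [0.10, 0.10, 0.10, 0.10, 0.10, 0.10, 0.10, 0.15, 0.15],
]
shares = region_shares(rows, suffix_start=3, context_len=5, query_len=2)
print(shares)
```

Because each answer-token row sums to 1, the four shares also sum to 1. The passage, suffix, and query shares alone therefore fall short of 1 by whatever the answer tokens spend on each other.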

In [None]:
N_SAMPLES_B = 30
results_b = []
errors_b = 0

# Force eager attention for output_attentions=True (SDPA returns None for attention weights)
model.config._attn_implementation = "eager"
model.config._attn_implementation_internal = "eager"
# Also patch each layer's attention module to use eager forward
for layer in model.model.layers:
    layer.self_attn._attn_implementation = "eager"

print("=" * 80)
print(f"INVESTIGATION B: ATTENTION ANALYSIS (N = {N_SAMPLES_B})")
print("=" * 80)

for idx in tqdm(range(N_SAMPLES_B), desc="Inv B"):
    sample = samples[idx]
    try:
        passage = sample['passage']
        query = sample['query']
        answer = sample['answer']
        query_prompt = config.query_template.format(query=query)
        rng = random.Random(config.seed + idx)

        # Generate surrogate
        surrogates = generate_all_5_surrogates(passage, model, tokenizer, config)
        sims = {k: compute_similarity(v, query, embed_model) for k, v in surrogates.items()}
        routed_key = max(sims, key=sims.get)
        routed_surr = surrogates[routed_key]

        # Get bare passage length (for computing suffix start position)
        bare_len, bare_cache = build_kv_cache(passage, model, tokenizer, config)

        # Build suffix caches
        sfx_len_rel, sfx_cache_rel = build_suffix_kv_cache(
            passage, routed_surr, model, tokenizer, config
        )
        irrel_q = get_irrelevant_query(samples, idx, rng)
        sfx_len_irrel, sfx_cache_irrel = build_suffix_kv_cache(
            passage, irrel_q, model, tokenizer, config
        )

        # Extract attention for bare (no suffix, suffix_start = bare_len)
        attn_bare = extract_attention_to_suffix(
            deep_copy_cache(bare_cache), bare_len, query_prompt, answer,
            model, tokenizer, config,
            passage_len=bare_len, suffix_start=bare_len,
        )

        # Extract attention for relevant suffix
        attn_rel = extract_attention_to_suffix(
            deep_copy_cache(sfx_cache_rel), sfx_len_rel, query_prompt, answer,
            model, tokenizer, config,
            passage_len=bare_len, suffix_start=bare_len,
        )

        # Extract attention for irrelevant suffix
        attn_irrel = extract_attention_to_suffix(
            deep_copy_cache(sfx_cache_irrel), sfx_len_irrel, query_prompt, answer,
            model, tokenizer, config,
            passage_len=bare_len, suffix_start=bare_len,
        )

        results_b.append({
            'idx': idx,
            'bare_len': bare_len,
            'sfx_rel_len': sfx_len_rel,
            'sfx_irrel_len': sfx_len_irrel,
            'attn_bare': attn_bare,
            'attn_rel': attn_rel,
            'attn_irrel': attn_irrel,
        })

        torch.cuda.empty_cache()

    except Exception as e:
        errors_b += 1
        if errors_b <= 5:
            print(f"\n  Error on sample {idx}: {type(e).__name__}: {e}")
        torch.cuda.empty_cache()
        continue

print(f"\nDone. {len(results_b)} evaluated, {errors_b} errors.")

### Investigation B: Results

In [None]:
n_b = len(results_b)
print("=" * 100)
print(f"ATTENTION ANALYSIS RESULTS (N = {n_b})")
print("=" * 100)

# Overall attention distribution
print(f"\n{'Condition':<25} {'Attn to Passage':>18} {'Attn to Suffix':>18} {'Attn to Query':>18}")
print("-" * 80)

for label, key in [('Bare (no suffix)', 'attn_bare'), ('Relevant suffix', 'attn_rel'), ('Irrelevant suffix', 'attn_irrel')]:
    passage_attn = np.mean([r[key]['attn_to_passage'] for r in results_b])
    suffix_attn = np.mean([r[key]['attn_to_suffix'] for r in results_b])
    query_attn = np.mean([r[key]['attn_to_query'] for r in results_b])
    print(f"{label:<25} {passage_attn:>18.4f} {suffix_attn:>18.4f} {query_attn:>18.4f}")

# Per-layer suffix attention
print(f"\n{'=' * 100}")
print("PER-LAYER MEAN ATTENTION TO SUFFIX (relevant vs irrelevant)")
print("=" * 100)

n_layers_model = len(results_b[0]['attn_rel']['per_layer_suffix_attn'])
print(f"\n{'Layer':<8} {'Relevant':>12} {'Irrelevant':>12} {'Delta':>12}")
print("-" * 48)

for layer_idx in range(n_layers_model):
    rel_attn = np.mean([r['attn_rel']['per_layer_suffix_attn'][layer_idx] for r in results_b])
    irrel_attn = np.mean([r['attn_irrel']['per_layer_suffix_attn'][layer_idx] for r in results_b])
    delta = rel_attn - irrel_attn
    print(f"{layer_idx:<8} {rel_attn:>12.6f} {irrel_attn:>12.6f} {delta:>+12.6f}")

# Statistical test: does relevant suffix get more attention than irrelevant?
rel_sfx_attn = np.array([r['attn_rel']['attn_to_suffix'] for r in results_b])
irrel_sfx_attn = np.array([r['attn_irrel']['attn_to_suffix'] for r in results_b])
t_attn, p_attn = stats.ttest_rel(rel_sfx_attn, irrel_sfx_attn)
print(f"\nRelevant vs irrelevant suffix attention: t={t_attn:.3f}, p={p_attn:.6f}")
print(f"  Mean relevant: {np.mean(rel_sfx_attn):.6f}")
print(f"  Mean irrel:    {np.mean(irrel_sfx_attn):.6f}")
print(f"  Delta:         {np.mean(rel_sfx_attn - irrel_sfx_attn):+.6f}")

# What fraction of total attention goes to suffix?
rel_total = rel_sfx_attn / np.array([
    r['attn_rel']['attn_to_passage'] + r['attn_rel']['attn_to_suffix'] + r['attn_rel']['attn_to_query']
    for r in results_b
])
print(f"\nFraction of attention to suffix (relevant): {np.mean(rel_total)*100:.2f}%")

### Investigation B: Visualization

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
fig.suptitle('Investigation B: Attention Pattern Analysis', fontsize=14, fontweight='bold')

# (0,0) Attention distribution: bar chart
ax = axes[0, 0]
conditions_attn = ['Bare', 'Rel Suffix', 'Irrel Suffix']
regions = ['Passage', 'Suffix', 'Query']
data_attn = np.zeros((3, 3))  # conditions x regions

for i, key in enumerate(['attn_bare', 'attn_rel', 'attn_irrel']):
    data_attn[i, 0] = np.mean([r[key]['attn_to_passage'] for r in results_b])
    data_attn[i, 1] = np.mean([r[key]['attn_to_suffix'] for r in results_b])
    data_attn[i, 2] = np.mean([r[key]['attn_to_query'] for r in results_b])

x = np.arange(len(conditions_attn))
w = 0.25
for j, (region, color) in enumerate(zip(regions, ['#4c72b0', '#55a868', '#dd8452'])):
    ax.bar(x + j*w, data_attn[:, j], w, label=region, color=color)
ax.set_xticks(x + w)
ax.set_xticklabels(conditions_attn)
ax.set_ylabel('Mean Attention Weight')
ax.set_title('Attention Distribution by Region')
ax.legend()

# (0,1) Per-layer suffix attention
ax = axes[0, 1]
layers = list(range(n_layers_model))
rel_per_layer = [np.mean([r['attn_rel']['per_layer_suffix_attn'][l] for r in results_b]) for l in layers]
irrel_per_layer = [np.mean([r['attn_irrel']['per_layer_suffix_attn'][l] for r in results_b]) for l in layers]
ax.plot(layers, rel_per_layer, 'b-', label='Relevant suffix', alpha=0.8)
ax.plot(layers, irrel_per_layer, 'r-', label='Irrelevant suffix', alpha=0.8)
ax.set_xlabel('Layer')
ax.set_ylabel('Mean Attention to Suffix')
ax.set_title('Per-Layer Suffix Attention')
ax.legend()

# (1,0) Suffix attention: relevant vs irrelevant scatter
ax = axes[1, 0]
ax.scatter(rel_sfx_attn, irrel_sfx_attn, alpha=0.5, s=20, c='#4c72b0')
lims = [0, max(rel_sfx_attn.max(), irrel_sfx_attn.max()) * 1.1]
ax.plot(lims, lims, 'r--', linewidth=1, alpha=0.5)
ax.set_xlabel('Attention to Relevant Suffix')
ax.set_ylabel('Attention to Irrelevant Suffix')
ax.set_title('Suffix Attention: Relevant vs Irrelevant')

# (1,1) Per-layer delta (relevant - irrelevant)
ax = axes[1, 1]
delta_per_layer = [r - i for r, i in zip(rel_per_layer, irrel_per_layer)]
colors = ['#55a868' if d > 0 else '#c44e52' for d in delta_per_layer]
ax.bar(layers, delta_per_layer, color=colors)
ax.axhline(0, color='black', linestyle='-', linewidth=0.5)
ax.set_xlabel('Layer')
ax.set_ylabel('Attention Delta (rel - irrel)')
ax.set_title('Per-Layer: Does Relevant Suffix Get More Attention?')

plt.tight_layout()
plt.savefig('results/exp08/08_attention_analysis.png', dpi=150, bbox_inches='tight')
plt.show()
print("Saved: 08_attention_analysis.png")

## Investigation C: Hard-Sample Analysis

Filter for samples where the bare NLL is high, i.e. where the model struggles. On these
samples, better document representations should matter most. Use the Exp 07 results
if available; otherwise fall back to the Investigation A data.
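
As a compact illustration of the stratification idea, here is a minimal, standard-library-only sketch. The helper name `win_rate_by_quartile` and the toy NLL values are hypothetical and are not taken from the experiment. It also simplifies the binning: it splits by rank into four equal bins, whereas the notebook cell thresholds on `np.percentile`, so samples tied at a boundary may land in different bins.

```python
def win_rate_by_quartile(bare, cond):
    """Split paired samples into four difficulty bins by bare NLL (rank-based).

    Returns one (n, win_rate_pct, mean_delta) tuple per bin, easiest first.
    A "win" means the condition's NLL is lower than the bare NLL.
    """
    order = sorted(range(len(bare)), key=lambda i: bare[i])
    n = len(order)
    rows = []
    for q in range(4):
        idx = order[q * n // 4:(q + 1) * n // 4]
        if not idx:
            rows.append((0, float('nan'), float('nan')))
            continue
        deltas = [bare[i] - cond[i] for i in idx]  # positive = condition helps
        wins = sum(d > 0 for d in deltas)
        rows.append((len(idx), 100.0 * wins / len(idx), sum(deltas) / len(idx)))
    return rows


# Hypothetical toy data: the condition only helps on the hardest samples.
bare = [0.2, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 4.0]
cond = [0.3, 0.6, 0.9, 1.6, 1.8, 2.6, 2.5, 3.0]
rows = win_rate_by_quartile(bare, cond)
for label, (n, wr, delta) in zip(['Q1', 'Q2', 'Q3', 'Q4'], rows):
    print(f"{label}: N={n}  win={wr:.1f}%  delta={delta:+.3f}")
```

If hypothesis 3 held, the real data would show a pattern like this toy one: win rate and mean delta rising toward Q4.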

In [None]:
# Try to load exp 07 results for richer data
exp07_path = 'results/exp07/07_suffix_priming_results.json'
exp07_ckpt = 'results/exp07/07_checkpoint.json'

if os.path.exists(exp07_path):
    with open(exp07_path) as f:
        exp07_data = json.load(f)
    results_07 = exp07_data['results']
    source = "exp 07 results"
elif os.path.exists(exp07_ckpt):
    with open(exp07_ckpt) as f:
        exp07_data = json.load(f)
    results_07 = exp07_data['results']
    source = "exp 07 checkpoint"
else:
    results_07 = None
    source = "Investigation A data"

if results_07:
    print(f"Loaded {len(results_07)} samples from {source}")
else:
    print(f"No exp 07 data found. Using Investigation A data ({len(results_a)} samples).")

In [None]:
# Use whichever data source is available
if results_07:
    hard_data = results_07
    bare_key = 'bare_nll'
    sfx_rel_key = 'sfx_gen_routed_nll'
    sfx_irrel_key = 'sfx_irrel_nll'
    sfx_shuf_key = 'sfx_shuffled_nll'
    sfx_perf_key = 'sfx_perfect_nll'
    pfx_key = 'pfx_trunc_routed_nll'
    has_prefix = True
else:
    # Map Investigation A keys
    hard_data = results_a
    bare_key = 'bare_wq'
    sfx_rel_key = 'sfx_rel_wq'
    sfx_irrel_key = 'sfx_irrel_wq'
    sfx_shuf_key = None
    sfx_perf_key = None
    pfx_key = None
    has_prefix = False

bare_nlls = np.array([r[bare_key] for r in hard_data])

# Compute difficulty quartiles
q25, q50, q75 = np.percentile(bare_nlls, [25, 50, 75])
print(f"Bare NLL quartiles: Q25={q25:.3f}, Q50={q50:.3f}, Q75={q75:.3f}")
print(f"Range: [{bare_nlls.min():.3f}, {bare_nlls.max():.3f}]")

# Split into quartiles
quartile_masks = {
    'Q1 (easiest)': bare_nlls <= q25,
    'Q2': (bare_nlls > q25) & (bare_nlls <= q50),
    'Q3': (bare_nlls > q50) & (bare_nlls <= q75),
    'Q4 (hardest)': bare_nlls > q75,
}

# Also define a "hard" subset: top 25%
hard_mask = bare_nlls > q75
# And "very hard": top 10%
q90 = np.percentile(bare_nlls, 90)
very_hard_mask = bare_nlls > q90

print(f"\nHard samples (Q4, top 25%): {hard_mask.sum()}")
print(f"Very hard samples (top 10%): {very_hard_mask.sum()}")

print(f"\n{'=' * 120}")
print("WIN RATES BY DIFFICULTY QUARTILE")
print("=" * 120)

conditions_c = [('sfx_relevant', sfx_rel_key), ('sfx_irrelevant', sfx_irrel_key)]
if sfx_shuf_key:
    conditions_c.append(('sfx_shuffled', sfx_shuf_key))
if sfx_perf_key:
    conditions_c.append(('sfx_perfect', sfx_perf_key))
if pfx_key:
    conditions_c.append(('pfx_trunc', pfx_key))

header = f"{'Quartile':<18} {'N':>5}"
for name, _ in conditions_c:
    header += f" {name:>16}"
print(header)
print("-" * 120)

for q_label, mask in quartile_masks.items():
    if mask.sum() == 0:
        continue
    row = f"{q_label:<18} {mask.sum():>5}"
    for name, key in conditions_c:
        if key is None:
            row += f" {'N/A':>16}"
            continue
        cond_nlls = np.array([r[key] for r in hard_data])[mask]
        bare_q = bare_nlls[mask]
        wr = np.mean(bare_q > cond_nlls) * 100
        delta = np.mean(bare_q - cond_nlls)
        row += f" {wr:>6.1f}% ({delta:+.3f})"
    print(row)

# Detailed analysis on hard samples
print(f"\n{'=' * 120}")
print("HARD SAMPLES ONLY (top 25% bare NLL)")
print("=" * 120)

hard_indices = np.where(hard_mask)[0]
print(f"N = {len(hard_indices)}")

for name, key in conditions_c:
    if key is None:
        continue
    cond_nlls = np.array([r[key] for r in hard_data])[hard_mask]
    bare_q = bare_nlls[hard_mask]
    delta = bare_q - cond_nlls
    t, p = stats.ttest_rel(bare_q, cond_nlls)
    d = cohens_d(delta)
    wr = np.mean(delta > 0) * 100
    print(f"  {name:<20} delta={np.mean(delta):+.4f}  win={wr:.1f}%  t={t:.3f}  p={p:.6f}  d={d:.4f}")

# Semantic separation on hard samples only
if sfx_rel_key and sfx_irrel_key:
    hard_rel = np.array([r[sfx_rel_key] for r in hard_data])[hard_mask]
    hard_irrel = np.array([r[sfx_irrel_key] for r in hard_data])[hard_mask]
    t_sep, p_sep = stats.ttest_rel(hard_irrel, hard_rel)  # irrel - rel, so t shares the sign of delta and d
    d_sep = cohens_d(hard_irrel - hard_rel)
    print(f"\n  Semantic separation (rel vs irrel) on hard samples:")
    print(f"    delta={np.mean(hard_irrel - hard_rel):+.4f}  t={t_sep:.3f}  p={p_sep:.6f}  d={d_sep:.4f}")

# Correlation test on hard samples
if sfx_shuf_key:
    hard_rel_delta = bare_nlls[hard_mask] - np.array([r[sfx_rel_key] for r in hard_data])[hard_mask]
    hard_shuf_delta = bare_nlls[hard_mask] - np.array([r[sfx_shuf_key] for r in hard_data])[hard_mask]
    r_hard, p_hard = stats.pearsonr(hard_rel_delta, hard_shuf_delta)
    print(f"\n  Delta correlation on hard samples (rel vs shuffled): r={r_hard:.4f} (p={p_hard:.6f})")

### Investigation C: Extractive vs Non-Extractive
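
One caveat about measuring extractiveness by whitespace-split word overlap: punctuation attached to a word (`"Paris."` vs `"Paris"`) counts as a mismatch, which can understate overlap for short answers. The sketch below contrasts the whitespace version with a punctuation-stripping variant. The variant (`overlap_normalized`) and the example strings are illustrative assumptions, not the method this notebook uses. Stripping punctuation has costs of its own, for example `"3.5"` becomes `"35"`.

```python
import string

# Translation table that deletes ASCII punctuation.
_PUNCT = str.maketrans('', '', string.punctuation)


def overlap_whitespace(answer, passage):
    """Fraction of answer words found in passage, splitting on whitespace only."""
    a = set(answer.lower().split())
    p = set(passage.lower().split())
    return len(a & p) / len(a) if a else 1.0


def overlap_normalized(answer, passage):
    """Same ratio, but with ASCII punctuation removed before splitting."""
    a = set(answer.lower().translate(_PUNCT).split())
    p = set(passage.lower().translate(_PUNCT).split())
    return len(a & p) / len(a) if a else 1.0


# Hypothetical example: the answer is extractive, but the passage word carries a period.
passage = "The capital of France is Paris."
answer = "Paris"
raw = overlap_whitespace(answer, passage)
norm = overlap_normalized(answer, passage)
print(f"whitespace: {raw:.2f}  normalized: {norm:.2f}")
```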

In [None]:
# Measure answer extractiveness: token overlap between answer and passage
print("=" * 100)
print("EXTRACTIVE vs NON-EXTRACTIVE ANALYSIS")
print("=" * 100)

def token_overlap_ratio(answer, passage):
    """Fraction of answer words found in passage."""
    answer_words = set(answer.lower().split())
    passage_words = set(passage.lower().split())
    if not answer_words:
        return 1.0
    return len(answer_words & passage_words) / len(answer_words)

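# Prefer the answer stored with the result, else the sample's own answer.
# Falling back to the query would measure query-passage overlap, not extractiveness.
def _answer_and_passage(r):
    sample = samples[r['idx']] if r['idx'] < len(samples) else {}
    return r.get('answer') or sample.get('answer', ''), sample.get('passage', '')

overlaps = np.array([token_overlap_ratio(*_answer_and_passage(r)) for r in hard_data])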

# If overlaps couldn't be computed properly, try direct
if np.all(overlaps == 0):
    print("Note: overlap computation fell back — using passage from samples directly")
    overlaps = np.array([token_overlap_ratio(samples[i]['answer'], samples[i]['passage'])
                          for i in range(len(hard_data))])

overlap_median = np.median(overlaps)
extractive_mask = overlaps >= overlap_median
nonextractive_mask = overlaps < overlap_median

print(f"Overlap median: {overlap_median:.3f}")
print(f"Extractive (>= median): {extractive_mask.sum()}")
print(f"Non-extractive (< median): {nonextractive_mask.sum()}")

for subset_name, mask in [("Extractive", extractive_mask), ("Non-extractive", nonextractive_mask)]:
    print(f"\n  --- {subset_name} (N={mask.sum()}) ---")
    for name, key in conditions_c:
        if key is None:
            continue
        cond_nlls = np.array([r[key] for r in hard_data])[mask]
        bare_q = bare_nlls[mask]
        delta = bare_q - cond_nlls
        wr = np.mean(delta > 0) * 100
        if mask.sum() >= 5:
            t, p = stats.ttest_rel(bare_q, cond_nlls)
            print(f"    {name:<20} delta={np.mean(delta):+.4f}  win={wr:.1f}%  t={t:.3f}  p={p:.6f}")
        else:
            print(f"    {name:<20} delta={np.mean(delta):+.4f}  win={wr:.1f}%  (too few for t-test)")

# Does semantic separation appear in non-extractive subset?
if sfx_rel_key and sfx_irrel_key and nonextractive_mask.sum() >= 10:
    ne_rel = np.array([r[sfx_rel_key] for r in hard_data])[nonextractive_mask]
    ne_irrel = np.array([r[sfx_irrel_key] for r in hard_data])[nonextractive_mask]
    t_ne, p_ne = stats.ttest_rel(ne_irrel, ne_rel)  # irrel - rel, so t shares the sign of delta and d
    d_ne = cohens_d(ne_irrel - ne_rel)
    print(f"\n  Semantic separation on non-extractive samples:")
    print(f"    rel vs irrel: delta={np.mean(ne_irrel - ne_rel):+.4f}  t={t_ne:.3f}  p={p_ne:.6f}  d={d_ne:.4f}")

### Investigation C: Visualization

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
fig.suptitle('Investigation C: Hard-Sample and Extractiveness Analysis', fontsize=14, fontweight='bold')

# (0,0) Win rate by difficulty
ax = axes[0, 0]
q_labels = list(quartile_masks.keys())
if sfx_rel_key:
    rel_wrs = []
    irrel_wrs = []
    for q_label, mask in quartile_masks.items():
        if mask.sum() == 0:
            rel_wrs.append(0)
            irrel_wrs.append(0)
            continue
        rel_wrs.append(np.mean(bare_nlls[mask] > np.array([r[sfx_rel_key] for r in hard_data])[mask]) * 100)
        irrel_wrs.append(np.mean(bare_nlls[mask] > np.array([r[sfx_irrel_key] for r in hard_data])[mask]) * 100)
    x = np.arange(len(q_labels))
    w = 0.35
    ax.bar(x - w/2, rel_wrs, w, label='Relevant suffix', color='#4c72b0')
    ax.bar(x + w/2, irrel_wrs, w, label='Irrelevant suffix', color='#c44e52')
    ax.axhline(50, color='black', linestyle='--', linewidth=0.8, alpha=0.3)
    ax.set_xticks(x)
    ax.set_xticklabels([l.replace(' ', '\n') for l in q_labels], fontsize=8)
    ax.set_ylabel('Win Rate vs Bare (%)')
    ax.set_title('Win Rate by Difficulty Quartile')
    ax.legend(fontsize=8)
    ax.set_ylim(0, 100)

# (0,1) Delta by difficulty
ax = axes[0, 1]
if sfx_rel_key:
    rel_deltas_q = []
    irrel_deltas_q = []
    for q_label, mask in quartile_masks.items():
        if mask.sum() == 0:
            rel_deltas_q.append(0)
            irrel_deltas_q.append(0)
            continue
        rel_deltas_q.append(np.mean(bare_nlls[mask] - np.array([r[sfx_rel_key] for r in hard_data])[mask]))
        irrel_deltas_q.append(np.mean(bare_nlls[mask] - np.array([r[sfx_irrel_key] for r in hard_data])[mask]))
    ax.bar(x - w/2, rel_deltas_q, w, label='Relevant', color='#4c72b0')
    ax.bar(x + w/2, irrel_deltas_q, w, label='Irrelevant', color='#c44e52')
    ax.axhline(0, color='black', linestyle='-', linewidth=0.5)
    ax.set_xticks(x)
    ax.set_xticklabels([l.replace(' ', '\n') for l in q_labels], fontsize=8)
    ax.set_ylabel('Mean Delta NLL')
    ax.set_title('Mean Improvement by Quartile')
    ax.legend(fontsize=8)

# (1,0) Overlap histogram
ax = axes[1, 0]
ax.hist(overlaps, bins=30, color='#4c72b0', alpha=0.7, edgecolor='black', linewidth=0.5)
ax.axvline(overlap_median, color='red', linestyle='--', label=f'Median={overlap_median:.2f}')
ax.set_xlabel('Answer-Passage Token Overlap')
ax.set_ylabel('Count')
ax.set_title('Answer Extractiveness Distribution')
ax.legend()

# (1,1) Extractive vs non-extractive semantic gap
ax = axes[1, 1]
if sfx_rel_key and sfx_irrel_key:
    categories = ['Extractive', 'Non-extractive']
    gaps = []
    for mask in [extractive_mask, nonextractive_mask]:
        if mask.sum() > 0:
            rel = np.array([r[sfx_rel_key] for r in hard_data])[mask]
            irrel = np.array([r[sfx_irrel_key] for r in hard_data])[mask]
            gaps.append(np.mean(irrel - rel))
        else:
            gaps.append(0)
    colors = ['#55a868' if g > 0 else '#c44e52' for g in gaps]
    ax.bar(categories, gaps, color=colors)
    ax.axhline(0, color='black', linestyle='-', linewidth=0.5)
    ax.set_ylabel('Mean NLL Gap (irrel - rel)')
    ax.set_title('Semantic Gap: Extractive vs Non-Extractive')
    ax.annotate('Positive = relevant suffix\nis better', xy=(0.02, 0.95),
                xycoords='axes fraction', fontsize=8, va='top')

plt.tight_layout()
plt.savefig('results/exp08/08_hard_sample_analysis.png', dpi=150, bbox_inches='tight')
plt.show()
print("Saved: 08_hard_sample_analysis.png")

## Summary and Conclusions
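
The summary cell reports paired t-statistics next to `cohens_d` values. As a reading aid, here is a minimal sketch of how the two relate, assuming `cohens_d` is the paired form (mean of the differences divided by their sample standard deviation). The notebook's own helper is defined outside this section and may differ. The function names and NLL values below are hypothetical. Under that assumption, the paired t-statistic is d times sqrt(n), and its sign follows the order of the arguments.

```python
import math
import statistics


def paired_cohens_d(deltas):
    """Paired effect size: mean difference over the sample std of the differences."""
    return statistics.mean(deltas) / statistics.stdev(deltas)


def paired_t(a, b):
    """Paired t-statistic for a - b, computed as d * sqrt(n)."""
    diffs = [x - y for x, y in zip(a, b)]
    return paired_cohens_d(diffs) * math.sqrt(len(diffs))


# Hypothetical NLLs: the relevant suffix is slightly better (lower) on average.
rel = [1.0, 1.2, 0.8, 1.1]
irrel = [1.3, 1.1, 1.2, 1.4]

d = paired_cohens_d([i - r for r, i in zip(rel, irrel)])  # irrel - rel
t = paired_t(rel, irrel)                                  # rel - irrel
print(f"d (irrel - rel) = {d:+.3f}, t (rel - irrel) = {t:+.3f}")
```

Because the two calls use opposite orderings, `d` is positive while `t` is negative. Keeping the argument order the same for both makes the printed signs directly comparable.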

In [None]:
print("=" * 80)
print("EXPERIMENT 08 SUMMARY")
print("=" * 80)

print("\n--- Investigation A: Query-Free Scoring ---")
rel_noq = np.array([r['sfx_rel_noq'] for r in results_a])
irrel_noq = np.array([r['sfx_irrel_noq'] for r in results_a])
perf_noq = np.array([r['sfx_perfect_noq'] for r in results_a])
bare_noq_arr = np.array([r['bare_noq'] for r in results_a])

t_ri, p_ri = stats.ttest_rel(irrel_noq, rel_noq)  # irrel - rel, matching the sign of d_ri
d_ri = cohens_d(irrel_noq - rel_noq)
print(f"  Relevant vs Irrelevant (no query): t={t_ri:.3f}, p={p_ri:.6f}, d={d_ri:.4f}")

t_pb, p_pb = stats.ttest_rel(bare_noq_arr, perf_noq)  # bare - perfect, matching the sign of d_pb
d_pb = cohens_d(bare_noq_arr - perf_noq)
print(f"  Perfect vs Bare (no query): t={t_pb:.3f}, p={p_pb:.6f}, d={d_pb:.4f}")

if p_ri < 0.01 and np.mean(rel_noq) < np.mean(irrel_noq):
    print("  VERDICT: Semantic signal EXISTS when query is removed!")
    print("  => The query was masking the suffix effect in exp 07.")
elif p_pb < 0.01 and np.mean(perf_noq) < np.mean(bare_noq_arr):
    print("  VERDICT: Perfect suffix helps but generated surrogates don't separate.")
    print("  => Surrogate quality is the bottleneck, not the approach.")
else:
    print("  VERDICT: No semantic signal even without query.")
    print("  => The suffix mechanism itself is insufficient.")

print("\n--- Investigation B: Attention Analysis ---")
mean_sfx_attn_rel = np.mean(rel_sfx_attn)
mean_sfx_attn_irrel = np.mean(irrel_sfx_attn)
sfx_fraction = np.mean(rel_total) * 100
print(f"  Mean attention to suffix: relevant={mean_sfx_attn_rel:.4f}, irrelevant={mean_sfx_attn_irrel:.4f}")
print(f"  Suffix attention as % of total: {sfx_fraction:.2f}%")
if sfx_fraction < 1.0:
    print("  VERDICT: Suffix receives negligible attention (<1%). Model ignores it.")
elif abs(mean_sfx_attn_rel - mean_sfx_attn_irrel) < 0.001:
    print("  VERDICT: Suffix gets attention but same amount regardless of content.")
else:
    print("  VERDICT: Differential attention to relevant vs irrelevant suffix detected.")

print("\n--- Investigation C: Hard Samples ---")
if sfx_rel_key and sfx_irrel_key:
    hard_rel_arr = np.array([r[sfx_rel_key] for r in hard_data])[hard_mask]
    hard_irrel_arr = np.array([r[sfx_irrel_key] for r in hard_data])[hard_mask]
    t_h, p_h = stats.ttest_rel(hard_rel_arr, hard_irrel_arr)
    hard_wr = np.mean(hard_rel_arr < hard_irrel_arr) * 100
    print(f"  Hard samples — relevant beats irrelevant: {hard_wr:.1f}% (p={p_h:.6f})")

    if p_h < 0.05 and hard_wr > 55:
        print("  VERDICT: Semantic signal appears on hard samples.")
        print("  => MS MARCO's easy samples are washing out the effect.")
    else:
        print("  VERDICT: No semantic signal even on hard samples.")

print("\n--- Overall Assessment ---")
print("If Investigation A shows a query-free effect but B shows low attention:")
print("  => Architectural issue — model can't leverage suffix via attention.")
print("If A shows effect AND B shows differential attention:")
print("  => Signal exists but too weak for the with-query setting.")
print("If nothing works:")
print("  => Move to long-document QA (Experiment 09) where priming should matter more.")

## Save Results

In [None]:
# Create the nested directory too: the canonical copy below is written to results/exp08/
os.makedirs('results/exp08', exist_ok=True)
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

output = {
    'metadata': {
        'experiment': '08_diagnostic_suffix_signal',
        'description': (
            'Three investigations diagnosing why suffix priming shows no semantic signal: '
            f'A) query-free scoring ({len(results_a)} samples), '
            f'B) attention pattern analysis ({len(results_b)} samples), '
            'C) hard-sample and extractiveness stratification.'
        ),
        'timestamp': datetime.datetime.now().isoformat(),
        'model_name': config.model_name,
        'seed': config.seed,
        'n_samples_a': len(results_a),
        'n_samples_b': len(results_b),
    },
    'results_a': results_a,
    'results_b': results_b,
}

output_path = f'results/08_diagnostic_results_{timestamp}.json'
with open(output_path, 'w') as f:
    json.dump(output, f, indent=2, default=str)
print(f"Saved: {output_path}")

output_path_canonical = 'results/exp08/08_diagnostic_results.json'
with open(output_path_canonical, 'w') as f:
    json.dump(output, f, indent=2, default=str)
print(f"Saved: {output_path_canonical}")

## Next Steps

Based on the results above:

- **If a semantic signal appears in query-free scoring (Investigation A)**: The approach works,
  but the with-query setting makes the suffix redundant. Consider alternative scoring approaches
  or tasks where the query isn't available at scoring time.

- **If hard/non-extractive samples show separation (Investigation C)**: Move to long-document
  QA where ALL samples are "hard" and extractive answers are rare. Create Experiment 09
  with NarrativeQA or QuALITY.

- **If nothing works**: The suffix KV mechanism may be fundamentally limited. Consider
  alternative approaches: encoder-decoder models, cross-attention injection, or learned
  prefix tuning.