# Exp 26: Hardness-Gated Soft Prefix

## Motivation

Exp 25 showed soft prefix optimization beats discrete prefix (d=+0.288 vs +0.195) but has a
problem: **it hurts easy queries** (Q1: d=-0.43) while massively helping hard ones (Q4: d=+0.85).
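This suggests a simple gating strategy: only apply the soft prefix when the query is hard enough
to benefit, using the bare (no-prefix) answer NLL as the measure of hardness. Note that bare NLL is
computed from the reference answer, so this gate measures *where* the prefix helps; applying it at
inference time would require a proxy for difficulty.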

## Design

| Part | Data | Description |
|------|------|-------------|
| 1 | Exp 25 eval CSV (275 valid samples) | Threshold sweep on existing data |
| 2 | 300 fresh validation queries (seed=44) | Test generalization of optimal threshold |
| 3 | Parts 1 + 2 (held-out threshold transfer) | Does the threshold chosen on Part 1 data hold on Part 2 data? |

### Gating Rule

```
gated_nll = soft_fact_nll  if bare_nll >= threshold
            bare_nll       otherwise
```

### Success Criteria

- **Primary**: On the fresh validation data, gated d > ungated d measured on that same data (for reference, ungated d was +0.288 on Exp 25 data)
- **Secondary**: Q1 harm eliminated (d >= 0 for easiest quintile)
- **Tertiary**: Optimal threshold generalizes from Exp 25 data to fresh data

In [None]:
# Cell 1: Setup
# Process-global setup used by every later cell: file-permission mask, RNG seeds,
# and input/output paths. Order matters: the umask is set before any file or
# directory is created, and seeds are set before any RNG use.
import os
# NOTE(review): umask 0o000 makes every file/dir this process creates world-writable
# (including everything under results/exp26). Presumably intended for a shared
# machine; consider 0o002 (group-writable) if world write access is not required.
os.umask(0o000)

import sys
import json
import time
import csv
import numpy as np
import torch
import gc
from pathlib import Path
from scipy import stats

SEED = 42
# Seed torch (CPU and all CUDA devices) and numpy's global RNG.
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)

RESULTS_DIR = Path("results/exp26")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# Inputs written by Exp 25: per-query eval NLLs and the learned soft prefix tensor.
EXP25_CSV = Path("results/exp25/eval_results.csv")
EXP25_SOFT_FACT = Path("results/exp25/soft_prefix_fact.pt")
# Outputs of this experiment (scoring checkpoint, JSON summary, per-query CSV).
CHECKPOINT_EVAL_PATH = RESULTS_DIR / "checkpoint_eval.json"
FINAL_RESULTS_PATH = RESULTS_DIR / "results.json"
CSV_EVAL_PATH = RESULTS_DIR / "eval_results.csv"

print(f"SEED: {SEED}")
print(f"Results directory: {RESULTS_DIR}")
print(f"Exp 25 CSV: {EXP25_CSV} (exists: {EXP25_CSV.exists()})")
print(f"Exp 25 soft_prefix_fact: {EXP25_SOFT_FACT} (exists: {EXP25_SOFT_FACT.exists()})")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Cell 2: Load Exp 25 evaluation data
#
# Reads Exp 25's per-query NLLs, applies the same validity filter as Exp 25
# (drop zero / non-finite NLLs), and computes the ungated soft-prefix reference
# effect that every gating threshold is compared against.

print("=" * 70)
print("PART 1: THRESHOLD SIMULATION ON EXP 25 DATA")
print("=" * 70)

# `lib` is a project-local package resolved relative to the working directory.
# Cell 6 adds "." to sys.path, but the lib.analysis import below runs earlier,
# so make the path available here too (a duplicate entry would be harmless).
if "." not in sys.path:
    sys.path.insert(0, ".")
from lib.analysis import cohens_d

# Load CSV (newline='' as the csv module documentation recommends for readers)
exp25_data = []
with open(EXP25_CSV, 'r', newline='') as f:
    reader = csv.DictReader(f)
    for row in reader:
        exp25_data.append({
            'query_idx': int(row['query_idx']),
            'doc_len': int(row['doc_len']),
            'bare_nll': float(row['bare_nll']),
            'vel_soft_fact_nll': float(row['vel_soft_fact_nll']),
        })

print(f"Loaded {len(exp25_data)} rows from Exp 25 CSV")

bare_arr = np.array([r['bare_nll'] for r in exp25_data])
soft_arr = np.array([r['vel_soft_fact_nll'] for r in exp25_data])

# Filter valid samples (same as Exp 25): an exact 0 or a non-finite NLL marks a
# failed / missing score in either condition.
valid = (
    (bare_arr != 0) & np.isfinite(bare_arr) &
    (soft_arr != 0) & np.isfinite(soft_arr)
)
bare = bare_arr[valid]
soft = soft_arr[valid]
n_valid = int(np.sum(valid))

print(f"Valid samples: {n_valid}/{len(exp25_data)}")
print(f"Bare NLL: mean={np.mean(bare):.4f}, std={np.std(bare):.4f}")
print(f"Soft fact NLL: mean={np.mean(soft):.4f}, std={np.std(soft):.4f}")

# Ungated reference: the soft prefix applied to every sample.
# delta > 0 means the soft prefix lowered the NLL (helped).
ungated_delta = bare - soft
ungated_d = cohens_d(ungated_delta)
ungated_win = np.mean(ungated_delta > 0) * 100
_, ungated_p = stats.ttest_1samp(ungated_delta, 0)

print(f"\nUngated soft_fact: d={ungated_d:+.3f}, win={ungated_win:.1f}%, p={ungated_p:.2e}")

In [None]:
# Cell 3: Threshold simulation sweep

print("=" * 70)
print("THRESHOLD SWEEP")
print("=" * 70)

def compute_gated_metrics(bare, soft, threshold):
    """Apply hardness gating to paired NLLs and summarize the gated effect.

    A sample is "gated" (uses its soft-prefix NLL) when its bare NLL is
    >= ``threshold``. Otherwise it keeps its bare NLL, so its delta is exactly 0.

    Args:
        bare: 1-D array-like of bare (no-prefix) NLLs.
        soft: 1-D array-like of soft-prefix NLLs, paired element-wise with ``bare``.
        threshold: Bare-NLL cut-off. ``-np.inf`` gates every sample, which is
            the ungated soft prefix.

    Returns:
        dict with keys:
          - ``threshold``
          - ``cohens_d``: Cohen's d of the per-sample delta ``bare - gated``.
          - ``win_pct``: % of samples with delta > 0; ungated samples never count.
          - ``p_value``: one-sample t-test of the delta against 0.
          - ``n_gated``, ``frac_gated``
          - ``mean_delta``
    """
    # Accept lists as well as arrays (a plain list cannot be compared with `>=`).
    bare = np.asarray(bare, dtype=float)
    soft = np.asarray(soft, dtype=float)

    # Compute the gate once and reuse it for the gated NLLs and the count.
    gate = bare >= threshold
    gated = np.where(gate, soft, bare)
    delta = bare - gated
    n_gated = int(np.count_nonzero(gate))

    d = cohens_d(delta)
    win_pct = np.mean(delta > 0) * 100
    _, p_val = stats.ttest_1samp(delta, 0)

    return {
        'threshold': float(threshold),
        'cohens_d': float(d),
        'win_pct': float(win_pct),
        'p_value': float(p_val),
        'n_gated': n_gated,
        # Avoid ZeroDivisionError when there are no samples.
        'frac_gated': n_gated / len(bare) if len(bare) else 0.0,
        'mean_delta': float(np.mean(delta)),
    }

# Sweep: percentile-based + fine grid
percentiles = [0, 10, 20, 30, 40, 50, 60, 70, 75, 80, 85, 90, 95]
percentile_thresholds = np.percentile(bare, percentiles)

# Also sweep a fine grid of absolute thresholds from 0 up to the 95th percentile.
# This keeps at least ~5% of samples gated at every grid point.
fine_thresholds = np.linspace(0, np.percentile(bare, 95), 100)

# Combine and deduplicate (np.unique already returns sorted values)
all_thresholds = np.unique(np.concatenate([percentile_thresholds, fine_thresholds]))

sweep_results = [compute_gated_metrics(bare, soft, t) for t in all_thresholds]

# Find optimal threshold. nanargmax: plain argmax returns the position of the first
# NaN if cohens_d is undefined for any threshold, silently picking a bad "optimum".
# NOTE: this is the in-sample optimum on Exp 25 data, so its d is optimistically
# biased; Part 2 re-evaluates it on fresh data.
best_idx = int(np.nanargmax([r['cohens_d'] for r in sweep_results]))
best = sweep_results[best_idx]

print(f"\nSweep: {len(all_thresholds)} thresholds tested")
print(f"\nOptimal threshold: {best['threshold']:.4f}")
print(f"  Cohen's d: {best['cohens_d']:+.3f} (ungated: {ungated_d:+.3f})")
print(f"  Win%: {best['win_pct']:.1f}% (ungated: {ungated_win:.1f}%)")
print(f"  p-value: {best['p_value']:.2e}")
print(f"  N gated: {best['n_gated']}/{n_valid} ({best['frac_gated']*100:.1f}%)")

# Show percentile-based results
print(f"\n{'Percentile':<12} {'Threshold':>10} {'d':>8} {'Win%':>7} {'N gated':>8} {'% gated':>8}")
print("-" * 58)
for pct in percentiles:
    t = np.percentile(bare, pct)
    m = compute_gated_metrics(bare, soft, t)
    marker = " <<<" if abs(t - best['threshold']) < 1e-6 else ""
    print(f"P{pct:<11} {t:>10.4f} {m['cohens_d']:>+8.3f} {m['win_pct']:>6.1f}% "
          f"{m['n_gated']:>8} {m['frac_gated']*100:>7.1f}%{marker}")

# Select candidate thresholds for fresh validation:
# percentile-based P50, P60, P70, P75, P80, plus the absolute optimum found above.
candidate_percentiles = [50, 60, 70, 75, 80]
candidate_thresholds = {f'P{p}': float(np.percentile(bare, p)) for p in candidate_percentiles}
candidate_thresholds['optimal'] = best['threshold']

print(f"\nCandidate thresholds for fresh validation:")
for name, t in candidate_thresholds.items():
    m = compute_gated_metrics(bare, soft, t)
    print(f"  {name}: threshold={t:.4f}, d={m['cohens_d']:+.3f}, "
          f"win={m['win_pct']:.1f}%, gated={m['frac_gated']*100:.0f}%")

In [None]:
# Cell 4: Bootstrap CI + oracle gating upper bound

print("=" * 70)
print("BOOTSTRAP CONFIDENCE INTERVALS & ORACLE GATING")
print("=" * 70)

N_BOOTSTRAP = 1000

def bootstrap_d(bare, soft, threshold, n_boot=N_BOOTSTRAP, seed=42):
    """Percentile-bootstrap the gated Cohen's d.

    Draws ``n_boot`` with-replacement resamples of the paired (bare, soft) NLLs
    from a ``RandomState`` seeded with ``seed``, so results are reproducible.
    Each resample re-applies the gate ``bare >= threshold`` and takes Cohen's d
    of ``bare - gated``.

    Returns:
        dict with the bootstrap ``mean`` and ``std``, plus the 2.5th / 97.5th
        percentiles as ``ci_low`` / ``ci_high`` (a 95% CI).
    """
    rng = np.random.RandomState(seed)
    n_samples = len(bare)

    def _resampled_d():
        # One bootstrap replicate; RNG calls happen in the same order as a plain loop.
        pick = rng.choice(n_samples, size=n_samples, replace=True)
        b_res = bare[pick]
        s_res = soft[pick]
        return cohens_d(b_res - np.where(b_res >= threshold, s_res, b_res))

    boot_ds = np.array([_resampled_d() for _ in range(n_boot)])
    lo, hi = np.percentile(boot_ds, [2.5, 97.5])
    return {
        'mean': float(boot_ds.mean()),
        'ci_low': float(lo),
        'ci_high': float(hi),
        'std': float(boot_ds.std()),
    }

# Bootstrap for ungated. threshold=-inf gates every sample regardless of sign, so
# this is exactly the ungated soft prefix without relying on NLLs being >= 0.
boot_ungated = bootstrap_d(bare, soft, -np.inf)
print(f"\nUngated d: {ungated_d:+.3f} (95% CI: [{boot_ungated['ci_low']:+.3f}, {boot_ungated['ci_high']:+.3f}])")

# Bootstrap for optimal threshold
boot_optimal = bootstrap_d(bare, soft, best['threshold'])
print(f"Optimal gated d: {best['cohens_d']:+.3f} (95% CI: [{boot_optimal['ci_low']:+.3f}, {boot_optimal['ci_high']:+.3f}])")

# Bootstrap for each candidate
print(f"\n{'Candidate':<12} {'d':>8} {'CI low':>8} {'CI high':>8}")
print("-" * 40)
boot_candidates = {}
for name, t in candidate_thresholds.items():
    m = compute_gated_metrics(bare, soft, t)
    boot = bootstrap_d(bare, soft, t)
    boot_candidates[name] = boot
    print(f"{name:<12} {m['cohens_d']:>+8.3f} {boot['ci_low']:>+8.3f} {boot['ci_high']:>+8.3f}")

# Oracle gating: apply soft only when it actually helps
oracle_gated = np.where(soft < bare, soft, bare)
oracle_delta = bare - oracle_gated
oracle_d = cohens_d(oracle_delta)
oracle_win = np.mean(oracle_delta > 0) * 100

# The oracle keeps every positive delta and zeroes every negative one. That maximizes
# mean delta and win% over all gating rules, but NOT necessarily Cohen's d (mean/std).
# A threshold that drops a few large, noisy wins can have a smaller std and a larger d.
print(f"\nOracle gating (perfect knowledge):")
print(f"  d={oracle_d:+.3f}, win={oracle_win:.1f}%, mean delta={np.mean(oracle_delta):+.4f}")
print(f"  This is the UPPER BOUND on mean delta and win% for any gate "
      f"(not necessarily on d, which also depends on the spread).")
print(f"  Fraction where soft helps: {np.mean(soft < bare)*100:.1f}%")
print(f"  Fraction where soft hurts: {np.mean(soft > bare)*100:.1f}%")

In [None]:
# Cell 5: Part 1 plots
# Three-panel summary of the Exp 25 threshold simulation, saved to
# results/exp26/part1_threshold_sweep.png.

import matplotlib
# Agg is a non-interactive backend: the figure is written to disk, and plt.show()
# below will not display it inline.
matplotlib.use('Agg')
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# --- Panel 1: d vs threshold curve ---
# Gated d over every swept threshold, with the ungated and oracle references.
ax = axes[0]
thresholds = [r['threshold'] for r in sweep_results]
ds = [r['cohens_d'] for r in sweep_results]

ax.plot(thresholds, ds, linewidth=2, color='#1f77b4')
ax.axhline(y=ungated_d, color='red', linestyle='--', linewidth=1.5,
           label=f'Ungated d={ungated_d:+.3f}')
ax.axhline(y=oracle_d, color='green', linestyle=':', linewidth=1.5,
           label=f'Oracle d={oracle_d:+.3f}')
ax.axvline(x=best['threshold'], color='orange', linestyle='--', alpha=0.7,
           label=f'Optimal t={best["threshold"]:.3f}')
ax.set_xlabel('Bare NLL Threshold')
ax.set_ylabel("Cohen's d (gated)")
ax.set_title('Gated d vs Threshold')
ax.legend(fontsize=8)
ax.grid(alpha=0.3)

# --- Panel 2: Per-quintile heatmap (gated vs ungated) ---
ax = axes[1]

# Quintiles of bare NLL: np.digitize against the 20/40/60/80th percentiles gives
# bins 0..4, where 0 is the easiest (lowest bare NLL).
quintile_bounds = np.percentile(bare, [20, 40, 60, 80])
qlabels = ['Q1\neasy', 'Q2', 'Q3', 'Q4', 'Q5\nhard']
quintiles = np.digitize(bare, quintile_bounds)

# Compare ungated vs optimal gated per quintile
gated_optimal = np.where(bare >= best['threshold'], soft, bare)
conditions = {
    'ungated': soft,
    'gated_optimal': gated_optimal,
}

# Cells for quintiles with fewer than 5 samples stay 0 (indistinguishable from d=0).
# NOTE(review): in the gated row, a quintile lying entirely below the threshold has an
# all-zero delta, so its d depends on how lib.analysis.cohens_d handles zero variance
# (possibly NaN) — confirm.
heatmap = np.zeros((len(conditions), 5))
for i, (cname, carr) in enumerate(conditions.items()):
    delta = bare - carr
    for q in range(5):
        mask = quintiles == q
        if np.sum(mask) >= 5:
            heatmap[i, q] = cohens_d(delta[mask])

im = ax.imshow(heatmap, cmap='RdBu', aspect='auto', vmin=-0.5, vmax=1.0)
ax.set_xticks(range(5))
ax.set_xticklabels(qlabels, fontsize=8)
ax.set_yticks(range(len(conditions)))
ax.set_yticklabels(['ungated', 'gated'])
ax.set_title('Quintile d: Ungated vs Gated')

# Annotate each cell with its d; use white text on strongly coloured cells.
for i in range(len(conditions)):
    for j in range(5):
        val = heatmap[i, j]
        ax.text(j, i, f"{val:+.2f}", ha='center', va='center',
                fontsize=9, color='white' if abs(val) > 0.3 else 'black')
fig.colorbar(im, ax=ax, shrink=0.8)

# --- Panel 3: Scatter (bare NLL vs delta) with threshold ---
ax = axes[2]
delta_ungated = bare - soft
ax.scatter(bare, delta_ungated, alpha=0.4, s=20, c='#1f77b4', edgecolors='none',
           label='ungated')
ax.axhline(y=0, color='gray', linestyle='-', alpha=0.5)
ax.axvline(x=best['threshold'], color='orange', linestyle='--', linewidth=2,
           label=f'threshold={best["threshold"]:.3f}')
ax.set_xlabel('Bare NLL (difficulty)')
ax.set_ylabel('Delta (bare - soft, positive = helps)')
ax.set_title('Per-Sample: Difficulty vs Benefit')
ax.legend(fontsize=8)
ax.grid(alpha=0.3)

# Annotate regions. Positions are derived from the current y-limits, so this must
# run after the scatter has been drawn.
ax.text(best['threshold'] * 0.4, ax.get_ylim()[0] * 0.7,
        'SKIP\n(easy)', ha='center', fontsize=10, color='red', alpha=0.7)
ax.text(best['threshold'] * 1.5 if best['threshold'] > 0 else 0.5,
        ax.get_ylim()[1] * 0.7,
        'APPLY\n(hard)', ha='center', fontsize=10, color='green', alpha=0.7)

plt.suptitle('Exp 26 Part 1: Threshold Simulation on Exp 25 Data', fontsize=13)
plt.tight_layout()
plt.savefig(RESULTS_DIR / 'part1_threshold_sweep.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"Saved to {RESULTS_DIR / 'part1_threshold_sweep.png'}")

In [None]:
# Cell 6: Load Gemma 3 4B + Exp 25 soft_prefix_fact.pt
#
# Loads the 4-bit model and tokenizer via the project helpers, and imports the
# KV-cache utilities used by the scoring loop. Also loads the soft prefix learned
# in Exp 25 and exposes the model's input-embedding module as `embed_fn`.

print("=" * 70)
print("PART 2: FRESH VALIDATION")
print("=" * 70)

sys.path.insert(0, ".")

from lib.config import ExperimentConfig
from lib.model_utils import load_model

MODEL_NAME = "google/gemma-3-4b-it"

exp_config = ExperimentConfig(
    model_name=MODEL_NAME,
    model_type="gemma3",
    compute_dtype="auto",
    use_4bit=True,
    num_samples=300,
    seed=SEED,
)

# NOTE(review): the message says bfloat16 but compute_dtype="auto"; confirm what
# "auto" resolves to in lib.config / lib.model_utils.
print(f"Loading {MODEL_NAME} (4-bit, bfloat16)...")
model, tokenizer = load_model(exp_config)

from lib.kv_cache import (
    _get_text_config, _get_head_dim,
    _get_cache_keys, _get_cache_values,
    _set_cache_keys, _set_cache_values,
    _ensure_dynamic_cache,
    extract_and_truncate_cache_with_bos,
    correct_rope_positions_with_bos,
    score_answer_with_cache,
    deepcopy_cache,
    replace_values_at_layers,
)

text_config = _get_text_config(model.config)
NUM_LAYERS = text_config.num_hidden_layers
HIDDEN_SIZE = text_config.hidden_size

print(f"Model loaded. Layers={NUM_LAYERS}, hidden={HIDDEN_SIZE}")

# Load soft prefix. map_location="cpu" deserializes onto the CPU first, so a tensor
# saved from a CUDA device still loads on a host with a different GPU (or none).
# .to() then moves it to the configured device.
# NOTE: torch.load deserializes a pickle-based file; only load trusted artifacts
# such as this Exp 25 output.
soft_prefix_fact = torch.load(EXP25_SOFT_FACT, map_location="cpu").to(exp_config.device)
print(f"Loaded soft_prefix_fact: shape={soft_prefix_fact.shape}")

# Get embedding function: the model's input-embedding module. The scoring loop
# uses it to build inputs_embeds around the soft prefix.
embed_fn = model.get_input_embeddings()
print(f"Embedding layer: {type(embed_fn).__name__}, shape={embed_fn.weight.shape}")

In [None]:
# Cell 7: Load 300 fresh validation queries (seed=44, non-overlapping with Exp 25's seed=43)
# NOTE(review): nothing in this cell enforces non-overlap with Exp 25. The candidate
# pool is the first N_FRESH * 3 qualifying items of the validation split, taken in
# dataset order; only the shuffle depends on FRESH_SEED. If Exp 25 built its eval set
# the same way from the same split, the two samples can share queries. Confirm
# against Exp 25's query list before treating Part 2 as fully independent.
# NOTE: the resulting order of `fresh_samples` must stay reproducible. Cell 8 resumes
# from its checkpoint only when the saved query list matches this one exactly.

from lib.data import count_words
from lib.surrogate import STATIC_SURROGATE_QUERIES
from datasets import load_dataset
from tqdm.auto import tqdm

CUTOFF = 16  # Cell 8 passes layers 0..CUTOFF-1 to replace_values_at_layers
MAX_PASSAGE_WORDS = 300  # skip passages longer than this (per count_words)
N_FRESH = 300
FRESH_SEED = 44  # Different from Exp 25 train (42) and eval (43)
CHECKPOINT_EVERY = 50  # Cell 8 writes a checkpoint every this many queries

STATIC_FACT = STATIC_SURROGATE_QUERIES['static_factual']['query']
SURROGATE_PREFIX_TEMPLATE = "{surrogate}\n"
DOCUMENT_TEMPLATE = "{document}"
QUERY_TEMPLATE = "\nQuery: {query}\nAnswer:"
ANSWER_TEMPLATE = " {answer}"

# Textual static-factual surrogate prefix. Cell 8 tokenizes sf_str together with each
# document so the document token ids match those in the prefixed context.
sf_str = SURROGATE_PREFIX_TEMPLATE.format(surrogate=STATIC_FACT)
sf_ids = tokenizer(sf_str, return_tensors="pt",
                    add_special_tokens=False)['input_ids'].to(exp_config.device)
# NOTE(review): PREFIX_LEN (and sf_ids) are not used by any later cell in this notebook.
PREFIX_LEN = sf_ids.shape[1]

print(f"Loading MS MARCO validation with seed={FRESH_SEED}...")

dataset = load_dataset("microsoft/ms_marco", "v1.1", split="validation")
print(f"Total items: {len(dataset)}")

fresh_samples = []
# Reseeds numpy's *global* RNG; the shuffle below (and any later np.random use) depends on it.
np.random.seed(FRESH_SEED)

for item in tqdm(dataset, desc="Filtering"):
    passages_info = item.get('passages', {})
    passage_texts = passages_info.get('passage_text', [])
    is_selected = passages_info.get('is_selected', [])
    query = item.get('query', '')
    answers = item.get('answers', [])
    well_formed = item.get('wellFormedAnswers', [])

    # Skip items without passages or query text.
    if not passage_texts or not query:
        continue

    # Answer choice: the first well-formed answer unless it is the literal '[]'.
    # Otherwise the first raw answer, unless it is 'No Answer Present.'.
    # Otherwise skip the item.
    answer = None
    if well_formed and len(well_formed) > 0 and well_formed[0] != '[]':
        answer = well_formed[0]
    elif answers and len(answers) > 0 and answers[0] != 'No Answer Present.':
        answer = answers[0]
    else:
        continue

    # Keep the first passage marked is_selected == 1 that fits within MAX_PASSAGE_WORDS.
    for ptext, sel in zip(passage_texts, is_selected):
        if sel == 1 and count_words(ptext) <= MAX_PASSAGE_WORDS:
            fresh_samples.append({
                'query': query,
                'answer': answer,
                'passage': ptext,
                'word_count': count_words(ptext),
            })
            break

    # Stop once the candidate pool is 3x the target size.
    if len(fresh_samples) >= N_FRESH * 3:
        break

# Shuffle the pool in place with the seeded global RNG, then keep the first N_FRESH.
np.random.shuffle(fresh_samples)
fresh_samples = fresh_samples[:N_FRESH]

del dataset
gc.collect()

print(f"\nSelected {len(fresh_samples)} fresh validation samples (seed={FRESH_SEED})")
print(f"Word counts: mean={np.mean([q['word_count'] for q in fresh_samples]):.0f}, "
      f"min={min(q['word_count'] for q in fresh_samples)}, "
      f"max={max(q['word_count'] for q in fresh_samples)}")

In [None]:
# Cell 8: Score bare + soft_fact for all 300 fresh queries
#
# For each fresh query, the answer is scored twice:
#   1. Bare: cache built from [BOS] + document.
#   2. Soft: run [BOS] + learned soft prefix + document, truncate that cache back
#      to the document, and combine it with the bare cache at layers
#      0..CUTOFF-1 via replace_values_at_layers.
# Progress is checkpointed every CHECKPOINT_EVERY queries so an interrupted run can resume.
# NOTE(review): assumes len(fresh_samples) >= N_FRESH; otherwise fresh_samples[qidx]
# raises IndexError.

print("=" * 70)
print(f"SCORING {N_FRESH} FRESH QUERIES (bare + soft_fact)")
print("=" * 70)

layer_indices = list(range(CUTOFF))

# Checkpoint resume
fresh_results = []
fresh_start_idx = 0

# Resume only if the checkpoint was written for exactly the same ordered query list.
if CHECKPOINT_EVAL_PATH.exists():
    with open(CHECKPOINT_EVAL_PATH, 'r') as f:
        ckpt = json.load(f)
    ckpt_queries = ckpt.get('query_texts', [])
    current_queries = [q['query'] for q in fresh_samples[:N_FRESH]]
    if ckpt_queries == current_queries:
        fresh_results = ckpt['results']
        fresh_start_idx = len(fresh_results)
        print(f"Resuming from checkpoint: {fresh_start_idx}/{N_FRESH}")
    else:
        print("Checkpoint query mismatch. Starting fresh.")
else:
    print("No checkpoint found. Starting fresh.")

t_start = time.time()

for qidx in tqdm(range(fresh_start_idx, N_FRESH), initial=fresh_start_idx,
                  total=N_FRESH, desc="Scoring"):
    qdata = fresh_samples[qidx]
    query_prompt = QUERY_TEMPLATE.format(query=qdata['query'])
    answer_text = ANSWER_TEMPLATE.format(answer=qdata['answer'])
    document_text = DOCUMENT_TEMPLATE.format(document=qdata['passage'])

    # Matched tokenization: tokenize surrogate prefix + document jointly (with special
    # tokens), then drop the prefix's own token count. The document ids are then
    # exactly those produced in the prefixed context, and BOS comes from the joint
    # encoding.
    # NOTE(review): assumes the prefix tokenizes identically alone and inside
    # full_text (no token merge across the "\n" boundary) — confirm for this tokenizer.
    full_text = sf_str + document_text
    full_enc = tokenizer(full_text, return_tensors="pt",
                          add_special_tokens=True, padding=False, truncation=False)
    full_ids = full_enc['input_ids'].to(exp_config.device)
    sf_prefix_enc = tokenizer(sf_str, return_tensors="pt",
                               add_special_tokens=True, padding=False, truncation=False)
    sf_prefix_len_matched = sf_prefix_enc['input_ids'].shape[1]
    bos_id = full_ids[:, :1]
    doc_ids = full_ids[:, sf_prefix_len_matched:]
    doc_len = doc_ids.shape[1]
    context_len = 1 + doc_len  # BOS + document tokens

    del full_enc, full_ids, sf_prefix_enc

    # --- Bare NLL ---
    bare_input = torch.cat([bos_id, doc_ids], dim=1)
    with torch.no_grad():
        bare_out = model(input_ids=bare_input,
                         attention_mask=torch.ones_like(bare_input),
                         use_cache=True, return_dict=True)
    bare_cache = _ensure_dynamic_cache(bare_out.past_key_values)
    del bare_out

    # Score on a deep copy so bare_cache is still available for the soft-prefix
    # combination below (scoring presumably extends the cache it is given).
    bare_nll = score_answer_with_cache(
        deepcopy_cache(bare_cache), context_len,
        query_prompt, answer_text, model, tokenizer, exp_config)

    # --- Soft fact NLL ---
    with torch.no_grad():
        bos_emb = embed_fn(bos_id)
        doc_emb = embed_fn(doc_ids)
        # Match the device and dtype of the model's embeddings before concatenating.
        soft_cast = soft_prefix_fact.to(device=exp_config.device, dtype=bos_emb.dtype)

        # [BOS] + learned soft prefix + document, fed as embeddings.
        inputs_embeds = torch.cat([bos_emb, soft_cast, doc_emb], dim=1)
        total_len = inputs_embeds.shape[1]
        attn_mask = torch.ones((1, total_len), device=exp_config.device, dtype=torch.long)

        soft_out = model(inputs_embeds=inputs_embeds,
                        attention_mask=attn_mask,
                        use_cache=True, return_dict=True)
        soft_cache = _ensure_dynamic_cache(soft_out.past_key_values)
        del soft_out

        # Presumably keeps BOS + the last doc_len positions, dropping the soft prefix;
        # confirm in lib.kv_cache.
        soft_trunc = extract_and_truncate_cache_with_bos(soft_cache, doc_len)
        del soft_cache

        # NOTE(review): bare_cache is passed without a copy. If replace_values_at_layers
        # modifies it in place, that is harmless here: the bare NLL is already computed
        # and bare_cache is deleted right after.
        vel_soft_cache = replace_values_at_layers(bare_cache, soft_trunc, layer_indices)
        del soft_trunc

        soft_fact_nll = score_answer_with_cache(
            deepcopy_cache(vel_soft_cache), context_len,
            query_prompt, answer_text, model, tokenizer, exp_config)
        del vel_soft_cache

    # Free per-query memory before the next query.
    del bare_cache, bare_input
    gc.collect()
    torch.cuda.empty_cache()

    fresh_results.append({
        'query_idx': qidx,
        'doc_len': doc_len,
        'bare_nll': float(bare_nll),
        'soft_fact_nll': float(soft_fact_nll),
    })

    # Checkpoint every CHECKPOINT_EVERY queries and after the last one. The ETA uses
    # the rate of queries scored in this session only.
    if (qidx + 1) % CHECKPOINT_EVERY == 0 or qidx == N_FRESH - 1:
        ckpt_data = {
            'results': fresh_results,
            'query_texts': [q['query'] for q in fresh_samples[:N_FRESH]],
            'completed': len(fresh_results),
            'total': N_FRESH,
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        }
        with open(CHECKPOINT_EVAL_PATH, 'w') as f:
            json.dump(ckpt_data, f)
        elapsed = time.time() - t_start
        n_done = qidx - fresh_start_idx + 1
        rate = n_done / elapsed if elapsed > 0 else 0
        remaining = (N_FRESH - qidx - 1) / rate if rate > 0 else 0
        tqdm.write(f"  Checkpoint {qidx+1}/{N_FRESH} | {n_done} done in {elapsed/60:.1f}m | "
                   f"ETA: {remaining/60:.1f} min")

elapsed_total = time.time() - t_start
print(f"\nScoring complete: {len(fresh_results)} queries in {elapsed_total/60:.1f} min")

In [None]:
# Cell 9: Apply threshold candidates to fresh data
#
# Re-applies the Exp 25 candidate thresholds to the fresh NLLs, using the same
# validity filter as Part 1. Also reports oracle gating and an in-sample threshold
# sweep on the fresh data.

print("=" * 70)
print("GATED METRICS ON FRESH DATA")
print("=" * 70)

fresh_bare = np.array([r['bare_nll'] for r in fresh_results])
fresh_soft = np.array([r['soft_fact_nll'] for r in fresh_results])

# Filter valid (same rule as Part 1: drop zero / non-finite NLLs)
fresh_valid = (
    (fresh_bare != 0) & np.isfinite(fresh_bare) &
    (fresh_soft != 0) & np.isfinite(fresh_soft)
)
fb = fresh_bare[fresh_valid]
fs = fresh_soft[fresh_valid]
n_fresh_valid = int(np.sum(fresh_valid))

print(f"Fresh valid samples: {n_fresh_valid}/{len(fresh_results)}")

# Ungated on fresh data
fresh_ungated_delta = fb - fs
fresh_ungated_d = cohens_d(fresh_ungated_delta)
fresh_ungated_win = np.mean(fresh_ungated_delta > 0) * 100
_, fresh_ungated_p = stats.ttest_1samp(fresh_ungated_delta, 0)

print(f"\nFresh ungated: d={fresh_ungated_d:+.3f}, win={fresh_ungated_win:.1f}%, p={fresh_ungated_p:.2e}")

# Apply each candidate threshold.
# Sig column: *** p<0.001, ** p<0.01, * p<0.05, ns otherwise.
print(f"\n{'Candidate':<12} {'Threshold':>10} {'d':>8} {'Win%':>7} {'p':>12} {'Sig':>4} {'N gated':>8} {'Improve?':>10}")
print("-" * 80)

fresh_gated_results = {}
for name, t in candidate_thresholds.items():
    m = compute_gated_metrics(fb, fs, t)
    improve = "YES" if m['cohens_d'] > fresh_ungated_d else "no"
    sig = '***' if m['p_value'] < 0.001 else '**' if m['p_value'] < 0.01 else '*' if m['p_value'] < 0.05 else 'ns'
    print(f"{name:<12} {t:>10.4f} {m['cohens_d']:>+8.3f} {m['win_pct']:>6.1f}% "
          f"{m['p_value']:>12.2e} {sig:>4} {m['n_gated']:>8} {improve:>10}")
    fresh_gated_results[name] = m

# Oracle gating on fresh data
fresh_oracle_gated = np.where(fs < fb, fs, fb)
fresh_oracle_delta = fb - fresh_oracle_gated
fresh_oracle_d = cohens_d(fresh_oracle_delta)
print(f"\nOracle gating on fresh data: d={fresh_oracle_d:+.3f}")

# Also sweep fresh data for its own optimal. This optimum is chosen in-sample on the
# fresh data, so it is an optimistic ceiling, not a held-out estimate.
fresh_sweep = [compute_gated_metrics(fb, fs, t)
               for t in np.linspace(0, np.percentile(fb, 95), 100)]

# nanargmax: plain argmax would return the first NaN position if any d is undefined.
fresh_best_idx = int(np.nanargmax([r['cohens_d'] for r in fresh_sweep]))
fresh_best = fresh_sweep[fresh_best_idx]
print(f"Fresh-optimal threshold: {fresh_best['threshold']:.4f}, d={fresh_best['cohens_d']:+.3f}")
print(f"Exp25-optimal threshold: {best['threshold']:.4f}")

In [None]:
# Cell 10: Generalization analysis
#
# Compares the Exp 25-selected threshold with the fresh-data optimum, breaks the
# effect down by fresh-data difficulty quintile, and bootstraps CIs on fresh data.
#
# Selection caveat: `best_fresh_name` below is the candidate with the highest d on
# the fresh data itself, so its fresh-data d and CI are optimistically biased. The
# held-out estimate is the Exp 25-optimal threshold applied to fresh data
# (`exp25_optimal_on_fresh`, bootstrapped as `boot_fresh_exp25`).

print("=" * 70)
print("PART 3: GENERALIZATION ANALYSIS")
print("=" * 70)

# Does the Exp 25-optimal threshold generalize to fresh data?
exp25_optimal_on_fresh = compute_gated_metrics(fb, fs, best['threshold'])
fresh_optimal_on_fresh = fresh_best

print(f"\nThreshold transfer analysis:")
print(f"  Exp 25 optimal threshold: {best['threshold']:.4f}")
print(f"  Fresh optimal threshold:  {fresh_best['threshold']:.4f}")
print(f"  Difference: {abs(best['threshold'] - fresh_best['threshold']):.4f}")
print(f"")
print(f"  Exp25 threshold on Exp25 data: d={best['cohens_d']:+.3f}")
print(f"  Exp25 threshold on fresh data: d={exp25_optimal_on_fresh['cohens_d']:+.3f}")
print(f"  Fresh threshold on fresh data: d={fresh_optimal_on_fresh['cohens_d']:+.3f}")
print(f"  Ungated on fresh data:         d={fresh_ungated_d:+.3f}")

# Per-quintile analysis on fresh data (gated vs ungated); bins 0..4, 0 = easiest.
fresh_quintile_bounds = np.percentile(fb, [20, 40, 60, 80])
fresh_quintiles = np.digitize(fb, fresh_quintile_bounds)
fresh_qlabels = ['Q1 easy', 'Q2', 'Q3', 'Q4', 'Q5 hard']

# Best-performing candidate on the fresh data (selected on this data; see caveat above)
best_fresh_name = max(fresh_gated_results, key=lambda k: fresh_gated_results[k]['cohens_d'])
best_fresh_t = candidate_thresholds[best_fresh_name]
fresh_gated_best = np.where(fb >= best_fresh_t, fs, fb)

print(f"\nBest candidate on fresh data: {best_fresh_name} (threshold={best_fresh_t:.4f}) "
      f"[selected on fresh data]")

# NOTE(review): a quintile lying entirely below the threshold has an all-zero gated
# delta; its d depends on how cohens_d handles zero variance (possibly NaN).
print(f"\n{'Quintile':<12} {'Ungated d':>10} {'Gated d':>10} {'Improve?':>10}")
print("-" * 46)
for q in range(5):
    mask = fresh_quintiles == q
    n_q = int(np.sum(mask))
    if n_q < 5:
        print(f"{fresh_qlabels[q]:<12} {'n/a':>10} {'n/a':>10}")
        continue

    d_ungated = cohens_d(fb[mask] - fs[mask])
    d_gated = cohens_d(fb[mask] - fresh_gated_best[mask])
    improve = "YES" if d_gated > d_ungated else "no"
    print(f"{fresh_qlabels[q]:<12} {d_ungated:>+10.3f} {d_gated:>+10.3f} {improve:>10}")

# Bootstrap CI on fresh data (threshold=-inf gates every sample, i.e. ungated)
boot_fresh_ungated = bootstrap_d(fb, fs, -np.inf, seed=44)
boot_fresh_gated = bootstrap_d(fb, fs, best_fresh_t, seed=44)
# Held-out check: threshold fixed on Exp 25 data before looking at the fresh data.
boot_fresh_exp25 = bootstrap_d(fb, fs, best['threshold'], seed=44)
print(f"\nFresh ungated d: {fresh_ungated_d:+.3f} (95% CI: [{boot_fresh_ungated['ci_low']:+.3f}, {boot_fresh_ungated['ci_high']:+.3f}])")
print(f"Fresh gated d (best candidate, selected on fresh data): "
      f"{fresh_gated_results[best_fresh_name]['cohens_d']:+.3f} "
      f"(95% CI: [{boot_fresh_gated['ci_low']:+.3f}, {boot_fresh_gated['ci_high']:+.3f}])")
print(f"Fresh gated d (Exp25-optimal threshold, held-out): "
      f"{exp25_optimal_on_fresh['cohens_d']:+.3f} "
      f"(95% CI: [{boot_fresh_exp25['ci_low']:+.3f}, {boot_fresh_exp25['ci_high']:+.3f}])")

In [None]:
# Cell 11: Comprehensive 6-panel visualization
# 2x3 summary of Parts 1-3, saved to results/exp26/comprehensive_plots.png.
# NOTE: every "Gated (Fresh)" panel uses best_fresh_name / best_fresh_t, which were
# selected by their d on the fresh data itself, so those numbers are optimistic.

fig, axes = plt.subplots(2, 3, figsize=(20, 12))

# --- Panel 1: d vs threshold (Exp 25 data) ---
ax = axes[0, 0]
thresholds_p1 = [r['threshold'] for r in sweep_results]
ds_p1 = [r['cohens_d'] for r in sweep_results]
ax.plot(thresholds_p1, ds_p1, linewidth=2, color='#1f77b4', label='Exp 25 data')
ax.axhline(y=ungated_d, color='red', linestyle='--', linewidth=1.5,
           label=f'Ungated d={ungated_d:+.3f}')
ax.axvline(x=best['threshold'], color='orange', linestyle='--', alpha=0.7,
           label=f'Exp25 optimal')
ax.set_xlabel('Bare NLL Threshold')
ax.set_ylabel("Cohen's d")
ax.set_title('Part 1: Threshold Sweep (Exp 25)')
ax.legend(fontsize=7)
ax.grid(alpha=0.3)

# --- Panel 2: d vs threshold (fresh data) ---
# Both the transferred Exp 25 threshold and the fresh in-sample optimum are marked.
ax = axes[0, 1]
fresh_ts = [r['threshold'] for r in fresh_sweep]
fresh_ds = [r['cohens_d'] for r in fresh_sweep]
ax.plot(fresh_ts, fresh_ds, linewidth=2, color='#ff7f0e', label='Fresh data')
ax.axhline(y=fresh_ungated_d, color='red', linestyle='--', linewidth=1.5,
           label=f'Ungated d={fresh_ungated_d:+.3f}')
ax.axvline(x=best['threshold'], color='blue', linestyle='--', alpha=0.7,
           label=f'Exp25 optimal t')
ax.axvline(x=fresh_best['threshold'], color='orange', linestyle=':', alpha=0.7,
           label=f'Fresh optimal t')
ax.set_xlabel('Bare NLL Threshold')
ax.set_ylabel("Cohen's d")
ax.set_title('Part 2: Threshold Sweep (Fresh)')
ax.legend(fontsize=7)
ax.grid(alpha=0.3)

# --- Panel 3: Bar chart comparing conditions ---
# "Gated (Exp25)" is the in-sample optimum on Exp 25 data; "Gated (Fresh)" is the
# best candidate picked on fresh data. Both are optimistic by construction.
ax = axes[0, 2]
bar_names = ['Ungated\n(Exp25)', f'Gated\n(Exp25)', 'Ungated\n(Fresh)', f'Gated\n(Fresh)']
bar_ds = [ungated_d, best['cohens_d'], fresh_ungated_d,
          fresh_gated_results[best_fresh_name]['cohens_d']]
bar_colors = ['#1f77b4', '#2ca02c', '#ff7f0e', '#d62728']

bars = ax.bar(range(4), bar_ds, color=bar_colors, edgecolor='black', linewidth=0.5)
ax.axhline(y=0, color='gray', linestyle='-', alpha=0.3)
ax.set_xticks(range(4))
ax.set_xticklabels(bar_names, fontsize=8)
ax.set_ylabel("Cohen's d")
ax.set_title('Summary: Ungated vs Gated')
# Value labels just above the bar value (for negative d they sit inside the bar).
for i, d_val in enumerate(bar_ds):
    ax.text(i, d_val + 0.01, f"{d_val:+.3f}", ha='center', va='bottom', fontsize=9)

# --- Panel 4: Per-quintile comparison (fresh data) ---
ax = axes[1, 0]

# Quintiles with fewer than 5 samples are plotted as 0 (a placeholder, not a real d).
q_ungated_ds = []
q_gated_ds = []
for q in range(5):
    mask = fresh_quintiles == q
    if np.sum(mask) >= 5:
        q_ungated_ds.append(cohens_d(fb[mask] - fs[mask]))
        q_gated_ds.append(cohens_d(fb[mask] - fresh_gated_best[mask]))
    else:
        q_ungated_ds.append(0)
        q_gated_ds.append(0)

x = np.arange(5)
width = 0.35
ax.bar(x - width/2, q_ungated_ds, width, label='Ungated', color='#1f77b4', alpha=0.7)
ax.bar(x + width/2, q_gated_ds, width, label='Gated', color='#2ca02c', alpha=0.7)
ax.axhline(y=0, color='gray', linestyle='-', alpha=0.5)
ax.set_xticks(x)
ax.set_xticklabels(fresh_qlabels, fontsize=8)
ax.set_ylabel("Cohen's d")
ax.set_title('Per-Quintile: Ungated vs Gated (Fresh)')
ax.legend(fontsize=8)
ax.grid(axis='y', alpha=0.3)

# --- Panel 5: Scatter bare NLL vs delta (fresh data) ---
ax = axes[1, 1]
delta_fresh = fb - fs
ax.scatter(fb, delta_fresh, alpha=0.4, s=20, c='#ff7f0e', edgecolors='none')
ax.axhline(y=0, color='gray', linestyle='-', alpha=0.5)
ax.axvline(x=best_fresh_t, color='green', linestyle='--', linewidth=2,
           label=f'threshold={best_fresh_t:.3f}')
ax.set_xlabel('Bare NLL (difficulty)')
ax.set_ylabel('Delta (bare - soft, positive = helps)')
ax.set_title('Fresh Data: Difficulty vs Benefit')
ax.legend(fontsize=8)
ax.grid(alpha=0.3)

# --- Panel 6: NLL distribution (fresh) ---
ax = axes[1, 2]
ax.hist(fb, bins=40, alpha=0.5, color='#1f77b4', label='Bare', density=True)
ax.hist(fs, bins=40, alpha=0.5, color='#ff7f0e', label='Soft fact', density=True)
ax.axvline(x=best_fresh_t, color='green', linestyle='--', linewidth=2,
           label=f'Gate threshold')
ax.set_xlabel('NLL')
ax.set_ylabel('Density')
ax.set_title('Fresh: NLL Distributions')
ax.legend(fontsize=8)
ax.grid(alpha=0.3)

plt.suptitle('Exp 26: Hardness-Gated Soft Prefix', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig(RESULTS_DIR / 'comprehensive_plots.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"Saved to {RESULTS_DIR / 'comprehensive_plots.png'}")

In [None]:
# Cell 12: Save results.json + CSV + final verdict
#
# Writes the fresh per-query NLLs (CSV) and a JSON summary of Parts 1-3, then prints
# the verdict. The primary verdict uses the threshold chosen on Exp 25 data, applied
# to the fresh data; that is the only genuinely held-out comparison. The
# best-of-candidates gate (`best_fresh_name`) is selected by its d on the fresh data
# itself, so it is reported only as a secondary, optimistically biased figure.

# --- CSV ---
with open(CSV_EVAL_PATH, 'w', newline='') as f:
    writer = csv.DictWriter(f, fieldnames=[
        'query_idx', 'doc_len', 'bare_nll', 'soft_fact_nll'])
    writer.writeheader()
    writer.writerows(fresh_results)
print(f"CSV saved: {CSV_EVAL_PATH}")

# Held-out gated effect on fresh data: the Exp 25-optimal threshold.
heldout_gated_d = float(exp25_optimal_on_fresh['cohens_d'])

# --- Results JSON ---
# Values that may be numpy scalars are cast to float. json accepts np.float64 (a float
# subclass) but rejects other numpy scalar types such as np.float32.
final = {
    'experiment': 'exp26_hardness_gated_soft_prefix',
    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    'config': {
        'model_name': MODEL_NAME,
        'cutoff': CUTOFF,
        'fresh_seed': FRESH_SEED,
        'n_fresh': N_FRESH,
        'n_fresh_valid': n_fresh_valid,
        'soft_prefix_source': str(EXP25_SOFT_FACT),
    },
    'part1_exp25_data': {
        'n_valid': n_valid,
        'ungated_d': float(ungated_d),
        'ungated_win_pct': float(ungated_win),
        'ungated_p': float(ungated_p),
        'optimal_threshold': best['threshold'],
        'optimal_d': best['cohens_d'],
        'oracle_d': float(oracle_d),
        'candidate_thresholds': candidate_thresholds,
        'bootstrap_ungated': boot_ungated,
        'bootstrap_optimal': boot_optimal,
    },
    'part2_fresh_data': {
        'n_valid': n_fresh_valid,
        'ungated_d': float(fresh_ungated_d),
        'ungated_win_pct': float(fresh_ungated_win),
        'ungated_p': float(fresh_ungated_p),
        'fresh_optimal_threshold': fresh_best['threshold'],
        'fresh_optimal_d': fresh_best['cohens_d'],
        'oracle_d': float(fresh_oracle_d),
        'gated_results': fresh_gated_results,
        'bootstrap_ungated': boot_fresh_ungated,
        'bootstrap_gated': boot_fresh_gated,
    },
    'part3_generalization': {
        'exp25_threshold': best['threshold'],
        'fresh_threshold': fresh_best['threshold'],
        'exp25_threshold_on_exp25': best['cohens_d'],
        'exp25_threshold_on_fresh': exp25_optimal_on_fresh['cohens_d'],
        'exp25_threshold_on_fresh_win_pct': exp25_optimal_on_fresh['win_pct'],
        'exp25_threshold_on_fresh_p': exp25_optimal_on_fresh['p_value'],
        'fresh_threshold_on_fresh': fresh_best['cohens_d'],
        'best_candidate_name': best_fresh_name,
        'best_candidate_threshold': candidate_thresholds[best_fresh_name],
        'best_candidate_d': fresh_gated_results[best_fresh_name]['cohens_d'],
        # Which number the printed verdict is based on (held-out, not best-on-fresh).
        'verdict_basis': 'exp25_threshold_on_fresh',
    },
    'fresh_per_query': fresh_results,
}

with open(FINAL_RESULTS_PATH, 'w') as f:
    json.dump(final, f, indent=2)
print(f"Results saved: {FINAL_RESULTS_PATH} ({FINAL_RESULTS_PATH.stat().st_size / 1024:.1f} KB)")

# --- Final verdict ---
print("\n" + "=" * 70)
print("FINAL VERDICT — Exp 26: Hardness-Gated Soft Prefix")
print("=" * 70)
print(f"\nPart 1 (Exp 25 data, {n_valid} samples):")
print(f"  Ungated d: {ungated_d:+.3f}")
print(f"  Best gated d: {best['cohens_d']:+.3f} (threshold={best['threshold']:.4f})")
print(f"  Oracle d: {oracle_d:+.3f}")
print(f"\nPart 2 (Fresh data, {n_fresh_valid} samples):")
print(f"  Ungated d: {fresh_ungated_d:+.3f}")
print(f"  Held-out gated d: {heldout_gated_d:+.3f} "
      f"(Exp25-optimal threshold={best['threshold']:.4f})")
print(f"  Best gated d: {fresh_gated_results[best_fresh_name]['cohens_d']:+.3f} "
      f"(threshold={candidate_thresholds[best_fresh_name]:.4f}, {best_fresh_name}; "
      f"selected on fresh data, optimistic)")
print(f"  Oracle d: {fresh_oracle_d:+.3f}")

# Primary criterion: judged on the held-out threshold, not the best-on-fresh candidate.
gated_d_fresh = heldout_gated_d
if gated_d_fresh > fresh_ungated_d + 0.02:
    print(f"\nVERDICT: Hardness gating IMPROVES over ungated soft prefix ")
    print(f"  ({gated_d_fresh:+.3f} vs {fresh_ungated_d:+.3f} on fresh data).")
    print(f"  Gating avoids easy-query harm while preserving hard-query benefit.")
elif gated_d_fresh > fresh_ungated_d - 0.02:
    print(f"\nVERDICT: Hardness gating MATCHES ungated soft prefix ")
    print(f"  ({gated_d_fresh:+.3f} vs {fresh_ungated_d:+.3f} on fresh data).")
    print(f"  Gating does not significantly change overall effect.")
else:
    print(f"\nVERDICT: Hardness gating HURTS vs ungated soft prefix ")
    print(f"  ({gated_d_fresh:+.3f} vs {fresh_ungated_d:+.3f} on fresh data).")
    print(f"  Excluding easy queries reduces total positive signal.")

# Transfer heuristic: the Exp 25 threshold lands within 0.05 d of the fresh in-sample optimum.
exp25_transfers = abs(exp25_optimal_on_fresh['cohens_d'] - fresh_best['cohens_d']) < 0.05
print(f"\nThreshold generalizes: {'YES' if exp25_transfers else 'NO'}")
print(f"  (Exp25 optimal on fresh: d={exp25_optimal_on_fresh['cohens_d']:+.3f}, "
      f"Fresh optimal: d={fresh_best['cohens_d']:+.3f})")

print(f"\nDone!")

In [None]:
# Cell 13: GPU cleanup
#
# Drops the notebook's references to GPU-resident objects so their memory can be
# reclaimed, then reports allocated CUDA memory before and after.

print("Cleaning up GPU memory...")
mem_before = torch.cuda.memory_allocated() / 1e9

# Besides the model, tokenizer and soft prefix, this also releases:
#  - embed_fn: model.get_input_embeddings() from Cell 6, a live reference to the
#    model's embedding module. While it exists, `del model` cannot free those weights.
#  - sf_ids (Cell 7) and the device tensors left from the last Cell 8 loop iteration.
# globals().pop(name, None) tolerates names that were never created or were already
# removed (e.g. when this cell is re-run); a bare `del` would raise NameError.
for _name in ('model', 'tokenizer', 'embed_fn', 'soft_prefix_fact', 'sf_ids',
              'soft_cast', 'inputs_embeds', 'attn_mask', 'bos_emb', 'doc_emb',
              'bos_id', 'doc_ids'):
    globals().pop(_name, None)
del _name

gc.collect()
torch.cuda.empty_cache()
gc.collect()

mem_after = torch.cuda.memory_allocated() / 1e9
print(f"GPU memory: {mem_before:.2f} GB -> {mem_after:.2f} GB")
print("Cleanup complete.")