# Directed KV Cache: Comprehensive Experiment

## Hypothesis: RoPE Position Mismatch Explains Truncation Failure

When we build a KV cache from `[surrogate (S tokens)][document (D tokens)]` and then truncate
the surrogate, document token `i` has its key stored as `RoPE(K_i, S+i)`. Queries at position
`j` compute attention with relative position `j - (S+i)` instead of the correct `j - i`.
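Every document token therefore appears S positions *closer* to the query than it should (`j - (S+i) = (j - i) - S`). When the query is placed right after the D cached entries, as in the truncated conditions below, the relative position for tokens near the end of the document even becomes negative: those keys look as if they come *after* the query.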

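**Fix:** Apply `RoPE(-S)` to all cached keys after truncation. RoPE is a position-dependent rotation, so `RoPE(RoPE(K_i, S+i), -S) = RoPE(K_i, i)`: every document key returns to its correct position `i`. Values carry no positional rotation and need no change.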

### Experimental Conditions

| Group | ID | Name | Description |
|-------|----|------|-------------|
| A | A1 | baseline | `"Document:\n{document}"` |
| A | A2 | bare_doc | `"{document}"` — no framing |
| A | A3 | baseline_offset | A1 with position IDs starting at offset N |
| B | B1 | full_generated | Generated surrogate kept in context |
| B | B2 | full_perfect | Actual query kept in context |
| B | B3 | full_static | Static intent query kept in context |
| B | B4 | full_suffix | Surrogate AFTER document, kept |
| C | C1 | trunc_generated_broken | Truncated, NO position correction |
| C | C2 | trunc_generated_corrected | Truncated, WITH RoPE correction |
| C | C3 | trunc_perfect_broken | Actual query, truncated, no correction |
| C | C4 | trunc_perfect_corrected | Actual query, truncated, WITH correction |
| C | C5 | trunc_static_corrected | Static query, truncated, WITH correction |
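| D | D1 | trunc_suffix_corrected | Surrogate after doc, suffix truncated (no RoPE correction needed, since document positions are unchanged). Under causal attention the kept document cache never saw the suffix, so expect ≈ A1 (sanity check) |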
| E | E1 | random_prefix_trunc_corrected | Random tokens, truncated with correction |
| E | E2 | full_random_prefix | Random tokens kept in context |

## 1. Setup

In [None]:
import sys
import os
import json
import time
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from scipy import stats
from tqdm.auto import tqdm
from datetime import datetime

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, DynamicCache
from sentence_transformers import SentenceTransformer

# Add lib to path
sys.path.insert(0, '.')
from lib import (
    ExperimentConfig,
    build_kv_cache,
    score_answer_with_cache,
    extract_and_truncate_cache,
    build_truncated_kv_cache,
    correct_rope_positions,
    build_truncated_kv_cache_corrected,
    generate_surrogate,
    compute_similarity,
    load_evaluation_samples,
    load_ms_marco,
    STATIC_SURROGATE_QUERIES,
)

# Environment report: library version and GPU visibility.
print(f"PyTorch: {torch.__version__}")
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Global experiment configuration. The same seed initialises both PyTorch's and
# NumPy's global RNGs so that RNG-dependent steps are repeatable run to run.
config = ExperimentConfig(num_samples=200, seed=42)
torch.manual_seed(config.seed)
np.random.seed(config.seed)

print(f"Model: {config.model_name}")
print(f"Samples: {config.num_samples}")
print(f"Device: {config.device}")

In [None]:
# Load the causal LM in 4-bit NF4 (with double quantization); compute runs in fp16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(config.model_name)
# Some causal-LM tokenizers ship without a pad token; fall back to EOS.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)
model.eval()  # inference only

print(f"Model loaded: {config.model_name}")
# rope_theta is the RoPE base frequency; 10000.0 is reported when the config omits it.
print(f"RoPE theta: {getattr(model.config, 'rope_theta', 10000.0)}")

In [None]:
# Sentence-embedding model; in this notebook it is only used by
# compute_similarity to compare each surrogate with the real query.
embedding_name = config.embedding_model_name
embed_model = SentenceTransformer(embedding_name)
print(f"Embedding model loaded: {embedding_name}")

In [None]:
# Load MS MARCO and select evaluation samples. require_answer=True presumably
# keeps only samples with a reference answer, which every condition scores.
dataset = load_ms_marco(config)
samples = load_evaluation_samples(
    dataset,
    config,
    require_answer=True,
)
print(f"Loaded {len(samples)} evaluation samples")

## 2. Define Evaluation Conditions

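Each condition is evaluated by one of the helper functions below; their signatures differ by condition. Every helper builds a KV cache from some context (full, position-offset, or truncated) and returns the NLL of the reference answer given the query prompt.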

In [None]:
def eval_with_context(context_text, query, answer):
    """Score ``answer`` for ``query`` with ``context_text`` fully kept in the cache.

    The whole context is cached via ``build_kv_cache``; the query prompt is
    rendered from ``config.query_template``.

    Returns:
        The NLL computed by ``score_answer_with_cache`` (lower is better).
    """
    cached_len, kv_cache = build_kv_cache(context_text, model, tokenizer, config)
    prompt = config.query_template.format(query=query)
    return score_answer_with_cache(
        kv_cache, cached_len, prompt, answer, model, tokenizer, config
    )


def eval_with_offset(context_text, query, answer, offset):
    """Score ``answer`` after caching ``context_text`` at shifted position ids.

    The context is encoded with special tokens and run through the model with
    ``position_ids = offset, offset + 1, ..., offset + seq_len - 1`` rather
    than starting at 0. Every cached key therefore carries a RoPE rotation
    ``offset`` positions larger than in ``eval_with_context``.

    NOTE(review): ``score_answer_with_cache`` receives ``seq_len`` as the
    context length, not ``offset + seq_len``. If that helper (or the model's
    default cache positions) places the query right after ``seq_len``, the
    query sits at positions ``seq_len...`` while the context keys sit at
    ``offset...``. This condition would then reproduce the un-corrected
    truncation mismatch rather than a pure absolute shift. Confirm which is
    intended.

    Returns:
        The NLL computed by ``score_answer_with_cache``.
    """
    encoded = tokenizer(
        context_text,
        return_tensors="pt",
        add_special_tokens=True,
        padding=False,
        truncation=False,
    )
    input_ids = encoded['input_ids'].to(config.device)
    seq_len = input_ids.shape[1]
    shifted_positions = (torch.arange(seq_len, device=config.device) + offset).unsqueeze(0)

    with torch.no_grad():
        forward_out = model(
            input_ids=input_ids,
            attention_mask=torch.ones_like(input_ids),
            position_ids=shifted_positions,
            use_cache=True,
            return_dict=True,
        )

    prompt = config.query_template.format(query=query)
    return score_answer_with_cache(
        forward_out.past_key_values, seq_len, prompt, answer, model, tokenizer, config
    )


def eval_truncated_broken(surrogate_prefix, document_text, query, answer):
    """Score with a surrogate-prefixed cache truncated to the document, NO RoPE fix.

    ``surrogate_prefix + document_text`` is encoded with special tokens and run
    through the model. ``extract_and_truncate_cache`` then reduces the cache to
    ``doc_len`` entries (presumably the trailing document tokens; check the lib
    helper). ``doc_len`` is the token count of ``document_text`` encoded on its
    own WITHOUT special tokens. The kept keys still carry the rotations of
    their original (shifted) positions, which is the mismatch this condition
    measures.

    NOTE(review): ``doc_len`` assumes the document tokenizes the same way alone
    as after the prefix. SentencePiece/BPE tokenizers can split the boundary
    differently; spot-check a few samples. If the trailing entries are kept,
    the BOS entry is dropped as well, unlike in the full-context conditions.

    Returns:
        The NLL computed by ``score_answer_with_cache``.
    """
    doc_len = tokenizer(
        document_text, return_tensors="pt", add_special_tokens=False,
        padding=False, truncation=False,
    )['input_ids'].shape[1]

    full_ids = tokenizer(
        surrogate_prefix + document_text, return_tensors="pt",
        add_special_tokens=True, padding=False, truncation=False,
    )['input_ids'].to(config.device)

    with torch.no_grad():
        full_out = model(
            input_ids=full_ids,
            attention_mask=torch.ones_like(full_ids),
            use_cache=True,
            return_dict=True,
        )

    kept_cache = extract_and_truncate_cache(full_out.past_key_values, doc_len)
    prompt = config.query_template.format(query=query)
    return score_answer_with_cache(kept_cache, doc_len, prompt, answer, model, tokenizer, config)


def eval_truncated_corrected(surrogate_prefix, document_text, query, answer):
    """Score with a surrogate-prefixed cache truncated to the document, WITH RoPE fix.

    Same construction as the uncorrected variant: encode
    ``surrogate_prefix + document_text`` (with special tokens) and reduce the
    cache to ``doc_len`` entries via ``extract_and_truncate_cache``. Then
    ``correct_rope_positions`` is given ``surrogate_len``, the number of tokens
    in the full encoding that are not counted in ``doc_len``. That count
    includes any special tokens the tokenizer prepends, because ``doc_len`` is
    measured without them.

    NOTE(review): ``correct_rope_positions``'s return value is ignored, so it is
    assumed to modify the cache in place. ``doc_len`` assumes identical
    tokenization of the document alone and after the prefix. Confirm both.

    Returns:
        The NLL computed by ``score_answer_with_cache``.
    """
    doc_len = tokenizer(
        document_text, return_tensors="pt", add_special_tokens=False,
        padding=False, truncation=False,
    )['input_ids'].shape[1]

    full_ids = tokenizer(
        surrogate_prefix + document_text, return_tensors="pt",
        add_special_tokens=True, padding=False, truncation=False,
    )['input_ids'].to(config.device)
    surrogate_len = full_ids.shape[1] - doc_len  # every non-document token, incl. special tokens

    with torch.no_grad():
        full_out = model(
            input_ids=full_ids,
            attention_mask=torch.ones_like(full_ids),
            use_cache=True,
            return_dict=True,
        )

    kept_cache = extract_and_truncate_cache(full_out.past_key_values, doc_len)
    correct_rope_positions(kept_cache, surrogate_len, model)
    prompt = config.query_template.format(query=query)
    return score_answer_with_cache(kept_cache, doc_len, prompt, answer, model, tokenizer, config)


def eval_suffix_truncated_corrected(document_text, surrogate_suffix, query, answer):
    """Encode ``document_text + surrogate_suffix``, then keep only the document's cache.

    Despite the name, no RoPE correction is applied. The document comes first,
    so its cache entries already occupy positions ``0 .. doc_len - 1``.
    ``doc_len`` is measured WITH special tokens, matching the full encoding.
    The first ``doc_len`` key/value entries of every layer are copied into a
    fresh ``DynamicCache``.

    NOTE(review): the model is a causal LM, so the kept document entries were
    computed without attending to the suffix. This cache should therefore match
    one built from ``document_text`` alone, up to tokenization at the
    document/suffix boundary and numerical noise. D1 is effectively a sanity
    check rather than a test of ordering.

    Returns:
        The NLL computed by ``score_answer_with_cache``.
    """
    doc_len = tokenizer(
        document_text, return_tensors="pt", add_special_tokens=True,
        padding=False, truncation=False,
    )['input_ids'].shape[1]

    full_ids = tokenizer(
        document_text + surrogate_suffix, return_tensors="pt",
        add_special_tokens=True, padding=False, truncation=False,
    )['input_ids'].to(config.device)

    with torch.no_grad():
        full_out = model(
            input_ids=full_ids,
            attention_mask=torch.ones_like(full_ids),
            use_cache=True,
            return_dict=True,
        )

    # Normalise the returned cache to an iterable of per-layer (key, value) pairs.
    per_layer = full_out.past_key_values
    if hasattr(per_layer, 'to_legacy_cache'):
        per_layer = per_layer.to_legacy_cache()
    elif not isinstance(per_layer, (tuple, list)):
        per_layer = tuple(per_layer)

    doc_cache = DynamicCache()
    for layer_no, kv_pair in enumerate(per_layer):
        keys, values = kv_pair[0], kv_pair[1]
        # Slice the sequence axis (dim 2) down to the document prefix.
        doc_cache.update(
            keys[:, :, :doc_len, :].contiguous(),
            values[:, :, :doc_len, :].contiguous(),
            layer_no,
        )

    prompt = config.query_template.format(query=query)
    return score_answer_with_cache(doc_cache, doc_len, prompt, answer, model, tokenizer, config)


def generate_random_prefix(length_tokens=20):
    """Return the decoded text of ``length_tokens`` random vocabulary ids.

    Ids are drawn uniformly from ``[100, tokenizer.vocab_size - 100)`` using
    NumPy's *global* RNG. A caller's ``np.random.seed(...)`` therefore fully
    determines the prefix, which the experiment loop relies on for
    per-sample reproducibility. A ``torch.randint`` draw would ignore that
    seed and depend on whatever state torch's RNG happened to be in.

    The returned text is re-tokenized by the callers, so the resulting prefix
    is not guaranteed to be exactly ``length_tokens`` tokens long.
    """
    vocab_size = tokenizer.vocab_size
    random_ids = np.random.randint(100, vocab_size - 100, size=length_tokens)
    return tokenizer.decode(random_ids.tolist(), skip_special_tokens=True)


# Sentinel output: reached only if every helper definition in this cell succeeded.
print("Evaluation functions defined.")

## 3. Run Experiment

In [None]:
results = []
static_query = STATIC_SURROGATE_QUERIES['static_definitional']['query']

start_time = time.time()

for idx, sample in enumerate(tqdm(samples, desc="Evaluating")):
    document = sample['passage']
    query = sample['query']
    answer = sample['answer']

    # Framed document used by every condition except A2 (bare document).
    doc_text = f"Document:\n{document}"

    # Generate surrogate for this document
    surrogate = generate_surrogate(document, model, tokenizer, config)

    # Random-token control prefix, reproducible per sample. Seed both NumPy and
    # torch's CPU generator so the draw is fixed whichever RNG
    # generate_random_prefix uses (np.random.seed alone does not reach
    # torch.randint). torch.default_generator is CPU-only, so the CUDA RNG state
    # is left untouched.
    sample_seed = config.seed + idx
    np.random.seed(sample_seed)
    torch.default_generator.manual_seed(sample_seed)
    random_prefix_text = generate_random_prefix(20)

    result = {
        'idx': idx,
        'query': query,
        'surrogate': surrogate,
        'similarity': compute_similarity(surrogate, query, embed_model),
    }

    try:
        # ===== Group A: Baselines =====
        # A1: baseline
        result['A1_baseline'] = eval_with_context(doc_text, query, answer)

        # A2: bare_doc (no framing)
        result['A2_bare_doc'] = eval_with_context(document, query, answer)

        # A3: baseline_offset (position offset = 20, same as typical surrogate length)
        result['A3_baseline_offset'] = eval_with_offset(doc_text, query, answer, offset=20)

        # ===== Group B: Full context (surrogate kept) =====
        surr_prefix = f"This document may be relevant to queries like: {surrogate}\n\n"
        perfect_prefix = f"This document may be relevant to queries like: {query}\n\n"
        static_prefix = f"This document may be relevant to queries like: {static_query}\n\n"

        # B1: full_generated
        result['B1_full_generated'] = eval_with_context(surr_prefix + doc_text, query, answer)

        # B2: full_perfect
        result['B2_full_perfect'] = eval_with_context(perfect_prefix + doc_text, query, answer)

        # B3: full_static
        result['B3_full_static'] = eval_with_context(static_prefix + doc_text, query, answer)

        # B4: full_suffix
        surr_suffix = f"\n\nThis document may be relevant to queries like: {surrogate}"
        result['B4_full_suffix'] = eval_with_context(doc_text + surr_suffix, query, answer)

        # ===== Group C: Truncated =====
        # C1: trunc_generated_broken
        result['C1_trunc_generated_broken'] = eval_truncated_broken(
            surr_prefix, doc_text, query, answer)

        # C2: trunc_generated_corrected
        result['C2_trunc_generated_corrected'] = eval_truncated_corrected(
            surr_prefix, doc_text, query, answer)

        # C3: trunc_perfect_broken
        result['C3_trunc_perfect_broken'] = eval_truncated_broken(
            perfect_prefix, doc_text, query, answer)

        # C4: trunc_perfect_corrected
        result['C4_trunc_perfect_corrected'] = eval_truncated_corrected(
            perfect_prefix, doc_text, query, answer)

        # C5: trunc_static_corrected
        result['C5_trunc_static_corrected'] = eval_truncated_corrected(
            static_prefix, doc_text, query, answer)

        # ===== Group D: Ordering =====
        # D1: trunc_suffix_corrected
        result['D1_trunc_suffix_corrected'] = eval_suffix_truncated_corrected(
            doc_text, surr_suffix, query, answer)

        # ===== Group E: Controls =====
        random_prefix = f"{random_prefix_text}\n\n"

        # E1: random_prefix_trunc_corrected
        result['E1_random_prefix_trunc_corrected'] = eval_truncated_corrected(
            random_prefix, doc_text, query, answer)

        # E2: full_random_prefix
        result['E2_full_random_prefix'] = eval_with_context(
            random_prefix + doc_text, query, answer)

    except Exception as e:
        # Record the failure and keep going: one bad sample should not abort the
        # run. Flagged rows are filtered out before analysis.
        print(f"Error on sample {idx}: {e}")
        result['error'] = str(e)

    results.append(result)

    # Progress update every 20 samples
    if (idx + 1) % 20 == 0:
        valid = [r for r in results if 'error' not in r]
        if valid:
            a1 = np.mean([r['A1_baseline'] for r in valid])
            c1 = np.mean([r['C1_trunc_generated_broken'] for r in valid])
            c2 = np.mean([r['C2_trunc_generated_corrected'] for r in valid])
            b1 = np.mean([r['B1_full_generated'] for r in valid])
            print(f"  [{idx+1}/{len(samples)}] A1={a1:.4f}, C1(broken)={c1:.4f}, C2(corrected)={c2:.4f}, B1(full)={b1:.4f}")

elapsed = time.time() - start_time
print(f"\nDone. {len(results)} samples in {elapsed/60:.1f} minutes.")
print(f"Errors: {sum(1 for r in results if 'error' in r)}")

## 4. Analysis

In [None]:
# Filter out errors
valid_results = [r for r in results if 'error' not in r]
print(f"Valid results: {len(valid_results)} / {len(results)}")
if not valid_results:
    # Fail with an explicit message instead of an IndexError on valid_results[0].
    raise RuntimeError(
        "No valid results to analyze: every sample raised an error during evaluation."
    )

# Condition columns = every result key except the per-sample metadata fields.
metadata_keys = ('idx', 'query', 'surrogate', 'similarity', 'error')
condition_keys = [k for k in valid_results[0] if k not in metadata_keys]
print(f"Conditions: {condition_keys}")

In [None]:
# Paired comparison of every condition against the A1 baseline.
df = pd.DataFrame(valid_results)
baseline = df['A1_baseline'].values


def summarize_condition(cond_name):
    """Return one summary-table row comparing ``cond_name`` with A1 (paired)."""
    cond_vals = df[cond_name].values
    diff = baseline - cond_vals  # > 0 means lower NLL than the baseline
    t_val, p_val = stats.ttest_rel(baseline, cond_vals)
    # Paired effect size: mean difference / SD of differences (0 when all diffs are equal).
    effect = np.mean(diff) / np.std(diff, ddof=1) if np.std(diff) > 0 else 0
    return {
        'Condition': cond_name,
        'Mean NLL': np.mean(cond_vals),
        'Std NLL': np.std(cond_vals),
        'Mean Delta': np.mean(diff),
        'Win Rate': np.mean(diff > 0),
        't-stat': t_val,
        'p-value': p_val,
        'Cohen\'s d': effect,
    }


summary_rows = [summarize_condition(name) for name in condition_keys]
summary_df = pd.DataFrame(summary_rows).sort_values('Mean NLL')

print("\n" + "=" * 100)
print("ALL CONDITIONS vs BASELINE (A1) — sorted by Mean NLL (lower = better)")
print("=" * 100)
print("Positive Delta = condition outperforms baseline")
print()
print(summary_df.to_string(index=False, float_format='%.4f'))

In [None]:
# ===== Critical Comparison: C1 (broken) vs C2 (corrected) =====
print("=" * 80)
print("CRITICAL TEST: Does RoPE correction fix truncation?")
print("=" * 80)

a1 = df['A1_baseline'].values
c1 = df['C1_trunc_generated_broken'].values
c2 = df['C2_trunc_generated_corrected'].values

# Paired t-tests on per-sample NLLs.
t_c1_c2, p_c1_c2 = stats.ttest_rel(c1, c2)  # broken vs corrected truncation
t_c2_a1, p_c2_a1 = stats.ttest_rel(a1, c2)  # baseline vs corrected truncation

fix_gain = c1.mean() - c2.mean()
vs_baseline = a1.mean() - c2.mean()

print(f"\nC1 (broken truncation):    Mean NLL = {c1.mean():.4f} ± {c1.std():.4f}")
print(f"C2 (corrected truncation): Mean NLL = {c2.mean():.4f} ± {c2.std():.4f}")
print(f"A1 (baseline):             Mean NLL = {a1.mean():.4f} ± {a1.std():.4f}")

print(f"\nC1 vs C2: t={t_c1_c2:.3f}, p={p_c1_c2:.2e}")
print(f"C2 vs A1: t={t_c2_a1:.3f}, p={p_c2_a1:.2e}")
print(f"\nC2 improvement over C1: {fix_gain:.4f} NLL")
print(f"C2 vs A1 delta: {vs_baseline:.4f} NLL (positive = C2 better)")

In [None]:
# ===== Surrogate quality comparison =====
print("=" * 80)
print("SURROGATE QUALITY COMPARISON (all with RoPE correction)")
print("=" * 80)

# Corrected-truncation conditions that differ only in what the prefix was.
quality_conditions = {
    'C2_trunc_generated_corrected': 'Generated surrogate',
    'C4_trunc_perfect_corrected': 'Perfect (actual query)',
    'C5_trunc_static_corrected': 'Static query',
    'E1_random_prefix_trunc_corrected': 'Random prefix',
}
for cond, label in quality_conditions.items():
    vals = df[cond].values
    delta = a1 - vals  # > 0 means lower NLL than the baseline
    win_pct = np.mean(delta > 0) * 100
    print(f"  {label:30s}: NLL={vals.mean():.4f}, delta vs baseline={delta.mean():+.4f}, win={win_pct:.1f}%")

In [None]:
# ===== Full context vs Truncated comparison =====
print("=" * 80)
print("FULL CONTEXT vs TRUNCATED (is benefit baked into cache or from inference-time attention?)")
print("=" * 80)

# gap = Full - Trunc (NLL). Negative means keeping the surrogate in context
# helps beyond what survives truncation in the document cache.
full_trunc_pairs = (
    ('B1_full_generated', 'C2_trunc_generated_corrected', 'Generated'),
    ('B2_full_perfect', 'C4_trunc_perfect_corrected', 'Perfect (actual query)'),
    ('B3_full_static', 'C5_trunc_static_corrected', 'Static'),
)
for full_cond, trunc_cond, label in full_trunc_pairs:
    full_mean = df[full_cond].values.mean()
    trunc_mean = df[trunc_cond].values.mean()
    print(f"  {label:25s}: Full={full_mean:.4f}, Trunc(corrected)={trunc_mean:.4f}, gap={full_mean-trunc_mean:.4f}")

In [None]:
# ===== Ordering effects =====
print("=" * 80)
print("ORDERING EFFECTS (prefix vs suffix)")
print("=" * 80)

# NOTE(review): D1 keeps only document entries computed before the suffix, so
# under causal attention it should track A1 rather than measure ordering.
c2_v = df['C2_trunc_generated_corrected'].values
d1_v = df['D1_trunc_suffix_corrected'].values
b1_v = df['B1_full_generated'].values
b4_v = df['B4_full_suffix'].values

# Paired t-tests: surrogate before vs after the document.
t_ord, p_ord = stats.ttest_rel(c2_v, d1_v)
t_ord2, p_ord2 = stats.ttest_rel(b1_v, b4_v)

print(f"  C2 (prefix, truncated+corrected): NLL={c2_v.mean():.4f}")
print(f"  D1 (suffix, truncated+corrected): NLL={d1_v.mean():.4f}")
print(f"  t={t_ord:.3f}, p={p_ord:.4f}")

print(f"\n  B1 (prefix, full context): NLL={b1_v.mean():.4f}")
print(f"  B4 (suffix, full context): NLL={b4_v.mean():.4f}")
print(f"  t={t_ord2:.3f}, p={p_ord2:.4f}")

In [None]:
# ===== Similarity correlation =====
print("=" * 80)
print("SURROGATE SIMILARITY CORRELATION")
print("=" * 80)

sims = df['similarity'].values

# Pearson correlation between surrogate/query similarity and per-sample
# improvement over A1 (positive improvement = lower NLL than baseline).
for cond, label in (
    ('C2_trunc_generated_corrected', 'C2 trunc corrected'),
    ('B1_full_generated', 'B1 full generated'),
):
    r, p = stats.pearsonr(sims, a1 - df[cond].values)
    print(f"  {label:30s}: r={r:.3f}, p={p:.4f}")

In [None]:
# ===== Win rate by baseline quality bins =====
print("=" * 80)
print("WIN RATE BY BASELINE QUALITY (C2 vs A1)")
print("=" * 80)

# Quartiles of baseline NLL: low NLL = easy sample for the model.
quartile_labels = ['Easy (low NLL)', 'Medium-Easy', 'Medium-Hard', 'Hard (high NLL)']
baseline_q = pd.qcut(df['A1_baseline'], q=4, labels=quartile_labels)
c2_gain = df['A1_baseline'].values - df['C2_trunc_generated_corrected'].values

for label in quartile_labels:
    in_bin = (baseline_q == label).values
    n = in_bin.sum()
    delta = c2_gain[in_bin]
    wr = np.mean(delta > 0) * 100
    print(f"  {label:20s} (n={n:3d}): win rate={wr:.1f}%, mean delta={delta.mean():+.4f}")

## 5. Visualizations

In [None]:
# Four-panel summary figure, saved under results/exp01/.
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# --- Plot 1: Bar chart of all conditions ---
ax = axes[0, 0]
plot_conds = [
    'A1_baseline', 'A2_bare_doc', 'A3_baseline_offset',
    'B1_full_generated', 'B2_full_perfect', 'B3_full_static', 'B4_full_suffix',
    'C1_trunc_generated_broken', 'C2_trunc_generated_corrected',
    'C3_trunc_perfect_broken', 'C4_trunc_perfect_corrected', 'C5_trunc_static_corrected',
    'D1_trunc_suffix_corrected',
    'E1_random_prefix_trunc_corrected', 'E2_full_random_prefix',
]
means = [df[c].mean() for c in plot_conds]
sems = [df[c].std() / np.sqrt(len(df)) for c in plot_conds]  # standard error of the mean
# One colour per entry of plot_conds: A blue, B green, C red (broken) / orange
# (corrected), D purple, E brown.
colors = (['#2196F3'] * 3 + ['#4CAF50'] * 4 + ['#F44336', '#FF9800'] +
          ['#F44336', '#FF9800', '#FF9800'] + ['#9C27B0'] + ['#795548'] * 2)

ax.barh(range(len(plot_conds)), means, xerr=sems, color=colors, alpha=0.8)
ax.set_yticks(range(len(plot_conds)))
ax.set_yticklabels([c.split('_', 1)[1] for c in plot_conds], fontsize=8)
ax.set_xlabel('Mean NLL (lower = better)')
ax.set_title('All Conditions: Mean NLL')
ax.axvline(x=df['A1_baseline'].mean(), color='blue', linestyle='--', alpha=0.5, label='Baseline')
ax.invert_yaxis()
ax.legend(fontsize=8)

# --- Plot 2: Critical comparison C1 vs C2 ---
ax = axes[0, 1]
conds = ['A1_baseline', 'C1_trunc_generated_broken', 'C2_trunc_generated_corrected', 'B1_full_generated']
labels = ['A1: Baseline', 'C1: Broken\n(no RoPE fix)', 'C2: Corrected\n(RoPE fix)', 'B1: Full\n(surrogate kept)']
vals = [df[c].mean() for c in conds]
errs = [df[c].std() / np.sqrt(len(df)) for c in conds]  # SEM
colors2 = ['#2196F3', '#F44336', '#FF9800', '#4CAF50']
ax.bar(range(4), vals, yerr=errs, color=colors2, alpha=0.8)
ax.set_xticks(range(4))
ax.set_xticklabels(labels, fontsize=9)
ax.set_ylabel('Mean NLL')
ax.set_title('Critical Test: RoPE Correction Effect')
ax.axhline(y=df['A1_baseline'].mean(), color='blue', linestyle='--', alpha=0.5)

# --- Plot 3: Distribution of C2 deltas ---
ax = axes[1, 0]
delta_c2 = df['A1_baseline'].values - df['C2_trunc_generated_corrected'].values
ax.hist(delta_c2, bins=30, color='#FF9800', alpha=0.7, edgecolor='black')
ax.axvline(x=0, color='red', linestyle='--', linewidth=2)
ax.axvline(x=delta_c2.mean(), color='blue', linestyle='-', linewidth=2, label=f'Mean={delta_c2.mean():.4f}')
ax.set_xlabel('Delta NLL (positive = C2 better than baseline)')
ax.set_ylabel('Count')
ax.set_title('C2 (Corrected Truncation) vs Baseline: Delta Distribution')
ax.legend()

# --- Plot 4: Similarity vs Delta scatter ---
ax = axes[1, 1]
ax.scatter(df['similarity'], delta_c2, alpha=0.4, s=20, color='#FF9800')
# Least-squares trend line
z = np.polyfit(df['similarity'].values, delta_c2, 1)
p_line = np.poly1d(z)
x_line = np.linspace(df['similarity'].min(), df['similarity'].max(), 100)
ax.plot(x_line, p_line(x_line), 'r-', linewidth=2)
r, p = stats.pearsonr(df['similarity'].values, delta_c2)
ax.set_xlabel('Surrogate-Query Similarity')
ax.set_ylabel('Delta NLL (positive = C2 better)')
ax.set_title(f'Similarity vs C2 Benefit (r={r:.3f}, p={p:.4f})')
ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

plt.tight_layout()
# savefig does not create missing directories; make sure the target exists.
os.makedirs('results/exp01', exist_ok=True)
plt.savefig('results/exp01/directed_kvcache_results.png', dpi=150, bbox_inches='tight')
plt.show()
print("Saved: directed_kvcache_results.png")

## 6. Save Results

In [None]:
# Ensure the output directory exists: open() and DataFrame.to_csv() do not create it.
os.makedirs('results/exp01', exist_ok=True)


def json_default(obj):
    """JSON fallback serializer for ``json.dump``.

    NumPy scalars (e.g. ``np.float32``, which is not a ``float`` subclass) and
    single-element torch tensors are written as native JSON numbers. Anything
    else falls back to ``str``, as before, instead of numbers being silently
    stringified.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, torch.Tensor) and obj.numel() == 1:
        return obj.item()
    return str(obj)


# Save raw results
output_path = 'results/exp01/directed_kvcache_experiment_results.json'
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2, default=json_default)
print(f"Saved raw results to {output_path}")

# Save summary table
summary_path = 'results/exp01/directed_kvcache_experiment_summary.csv'
summary_df.to_csv(summary_path, index=False)
print(f"Saved summary to {summary_path}")

# Print final summary
print("\n" + "=" * 80)
print("EXPERIMENT COMPLETE")
print("=" * 80)
print(f"Samples: {len(valid_results)}")
print(f"\nKey findings:")
print(f"  Baseline (A1):           {df['A1_baseline'].mean():.4f}")
print(f"  Broken truncation (C1):  {df['C1_trunc_generated_broken'].mean():.4f}")
print(f"  Corrected truncation (C2): {df['C2_trunc_generated_corrected'].mean():.4f}")
print(f"  Full context (B1):       {df['B1_full_generated'].mean():.4f}")
print(f"  Perfect+corrected (C4):  {df['C4_trunc_perfect_corrected'].mean():.4f}")
print(f"\n  RoPE fix improvement (C1→C2): {(df['C1_trunc_generated_broken'].mean() - df['C2_trunc_generated_corrected'].mean()):.4f}")
print(f"  C2 vs baseline:          {(df['A1_baseline'].mean() - df['C2_trunc_generated_corrected'].mean()):+.4f}")