# Experiment 23: Position-Based Priming Strategies

**Date:** 2025-02-06

## Motivation

Experiment 22 revealed that **position matters more than early context contamination**:
- Condition A (query at scoring time) beat Condition B (query in cache at start)
- Document tokens allocate only ~8% of their attention to query tokens placed at the start of the context
- But cache content DOES matter: Exp 22's Condition C, built with the wrong query, was completely misled

## Hypothesis

**Recency bias**: Tokens closer to the generation point have more influence on the output.
Instead of PREFIX priming (query before document), try SUFFIX priming (query after document).

## Experimental Conditions

| Condition | Cache Structure | Scoring Prompt | Rationale |
|-----------|----------------|----------------|-----------|
| A | `[doc]` | `Question: X\nAnswer:` | Baseline (query at scoring) |
| B | `[query][doc]` | `Answer:` | Prefix priming (Exp 21 style) |
| C | `[doc][query]` | `Answer:` | **Suffix priming** (query near generation) |
| D | `[query][doc][query]` | `Answer:` | Bookend (query at both ends) |
| E | `[doc][instruction][query]` | `Answer:` | Suffix priming plus an explicit "answer the following question" instruction |

Here `[query]` denotes `Question: X`. Newline separators between segments are omitted.

## Expected Outcome

If recency bias is the key factor (">" means lower NLL, i.e. better):
- C (suffix) > B (prefix): a query closer to the generation point helps more
- C ≈ A or C > A: suffix priming might match or beat query-at-scoring
- D ≥ C: the bookend provides both recency and early contamination
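The metric is NLL, so every "better" in these predictions means a *lower* mean NLL. The sketch below encodes the three predictions as checks over a dict of per-condition mean NLLs. The function name and the tolerance used to read "C ≈ A" are illustrative assumptions, not values taken from the experiment. The usage line plugs in the MS MARCO means that Cell 8 reports. With a looser tolerance (e.g. 0.1), the second check would flip.

```python
from typing import Dict


def check_predictions(means: Dict[str, float], approx_tol: float = 0.05) -> Dict[str, bool]:
    """Evaluate the recency-bias predictions on mean NLLs (lower = better).

    approx_tol is a hypothetical tolerance for treating two means as "about equal".
    """
    A, B, C, D = means["A"], means["B"], means["C"], means["D"]
    return {
        # C (suffix) > B (prefix): suffix has the lower NLL
        "C_beats_B": C < B,
        # C ≈ A or C > A: within tolerance of A, or strictly lower than A
        "C_matches_or_beats_A": C < A or abs(C - A) <= approx_tol,
        # D ≥ C: bookend at least as good as suffix
        "D_at_least_C": D <= C,
    }


# Mean NLLs as reported in Cell 8 (MS MARCO, 200 samples)
print(check_predictions({"A": 3.3912, "B": 3.7393, "C": 3.4711, "D": 3.7335}))
```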

In [1]:
# Cell 1: Setup
import os
os.umask(0o000)

import sys
sys.path.insert(0, '/home/jupyter/research/directed_kvcache')

import json
import random
import numpy as np
import torch
from tqdm.auto import tqdm
from typing import Dict, List, Tuple, Optional
from scipy import stats

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, DynamicCache
from datasets import load_dataset

from lib.kv_cache import (
    deepcopy_cache,
    _get_cache_keys,
    _get_cache_values,
    _ensure_dynamic_cache,
)

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

OUTPUT_DIR = '/home/jupyter/research/directed_kvcache/results/exp23'
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")

PyTorch: 2.10.0+cu128
CUDA: True


In [2]:
# Cell 2: Load Model
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()

print(f"Model loaded on {model.device}")

`torch_dtype` is deprecated! Use `dtype` instead!



Model loaded on cuda:0


In [3]:
# Cell 3: Core Functions

def build_cache(text: str) -> DynamicCache:
    """Run `text` through the model (BOS prepended) and return its KV cache."""
    ids = tokenizer.encode(text, return_tensors='pt', add_special_tokens=True).to(model.device)
    with torch.no_grad():
        out = model(ids, use_cache=True)
    return _ensure_dynamic_cache(out.past_key_values)


def score_answer_nll(cache: DynamicCache, prompt: str, answer: str) -> float:
    """Score P(answer | cache, prompt) as the mean per-token NLL of the answer."""
    cache = _ensure_dynamic_cache(cache)
    cache_len = _get_cache_keys(cache, 0).shape[2]  # cached sequence length (layer 0 keys)

    # Prompt and answer continue the cached sequence, so no BOS is added here
    prompt_ids = tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=False).to(model.device)
    answer_ids = tokenizer.encode(answer, return_tensors='pt', add_special_tokens=False).to(model.device)
    prompt_len = prompt_ids.shape[1]
    answer_len = answer_ids.shape[1]

    if answer_len == 0:
        return 0.0  # nothing to score; skip the forward pass

    input_ids = torch.cat([prompt_ids, answer_ids], dim=1)
    total_len = cache_len + input_ids.shape[1]
    # The mask must cover the cached positions as well as the new tokens
    attention_mask = torch.ones((1, total_len), device=model.device)

    # The forward pass extends a DynamicCache in place, so score on a copy
    # to keep the original cache reusable
    cache_copy = deepcopy_cache(cache)
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=cache_copy,
            use_cache=True,
            return_dict=True
        )

    # The logit at position t predicts token t+1, so the answer tokens are
    # predicted by positions prompt_len-1 .. prompt_len+answer_len-2
    logits = outputs.logits
    answer_logits = logits[:, prompt_len-1:prompt_len+answer_len-1, :]
    loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')
    loss = loss_fct(answer_logits.reshape(-1, answer_logits.size(-1)), answer_ids.reshape(-1))

    return loss.item()


print("Core functions defined.")

Core functions defined.
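The slice `logits[:, prompt_len-1:prompt_len+answer_len-1, :]` in `score_answer_nll` relies on the causal-LM convention that the logit at position *t* predicts token *t+1*. The first answer token is therefore scored by the last prompt position. Below is a minimal NumPy sketch of the same mean-NLL computation on a toy logit array. The helper name `answer_nll` and the toy sizes are illustrative, not part of the notebook.

```python
import numpy as np


def answer_nll(logits: np.ndarray, answer_ids: np.ndarray, prompt_len: int) -> float:
    """Mean NLL of the answer tokens, given logits over [prompt + answer] positions.

    logits: (seq_len, vocab) for the concatenated prompt+answer input.
    answer_ids: (answer_len,) token ids of the answer.
    """
    answer_len = len(answer_ids)
    if answer_len == 0:
        return 0.0
    # Position t predicts token t+1: shift the scoring window back by one
    window = logits[prompt_len - 1 : prompt_len + answer_len - 1]
    # Numerically stable log-softmax over the vocabulary
    shifted = window - window.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Log-probability of each gold answer token, averaged and negated
    token_lp = log_probs[np.arange(answer_len), answer_ids]
    return float(-token_lp.mean())


# Uniform logits over a 4-token vocabulary give NLL = log(4) ≈ 1.386
print(answer_nll(np.zeros((4, 4)), np.array([3, 1]), prompt_len=2))
```

Averaging over answer tokens matches `CrossEntropyLoss(reduction='mean')`. Longer answers are therefore not penalised simply for having more tokens.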


In [4]:
# Cell 4: Define Position-Based Conditions

def evaluate_position_conditions(doc: str, query: str, answer: str) -> Dict[str, float]:
    """
    Evaluate all position-based priming conditions.
    
    Returns dict of condition_name -> NLL
    """
    results = {}
    answer_with_space = " " + answer
    
    # Condition A: Bare doc, query at scoring (baseline)
    cache_A = build_cache(doc)
    prompt_A = f"\n\nQuestion: {query}\nAnswer:"
    results['A_bare_query_at_scoring'] = score_answer_nll(cache_A, prompt_A, answer_with_space)
    
    # Condition B: Prefix priming [query][doc], minimal scoring prompt
    text_B = f"Question: {query}\n\n{doc}"
    cache_B = build_cache(text_B)
    prompt_B = "\n\nAnswer:"
    results['B_prefix_priming'] = score_answer_nll(cache_B, prompt_B, answer_with_space)
    
    # Condition C: Suffix priming [doc][query], minimal scoring prompt
    text_C = f"{doc}\n\nQuestion: {query}"
    cache_C = build_cache(text_C)
    prompt_C = "\nAnswer:"
    results['C_suffix_priming'] = score_answer_nll(cache_C, prompt_C, answer_with_space)
    
    # Condition D: Bookend [query][doc][query], minimal scoring prompt
    text_D = f"Question: {query}\n\n{doc}\n\nQuestion: {query}"
    cache_D = build_cache(text_D)
    prompt_D = "\nAnswer:"
    results['D_bookend_priming'] = score_answer_nll(cache_D, prompt_D, answer_with_space)
    
    # Condition E: Suffix with explicit answer prompt structure
    text_E = f"{doc}\n\nBased on the above, answer the following question.\nQuestion: {query}"
    cache_E = build_cache(text_E)
    prompt_E = "\nAnswer:"
    results['E_suffix_explicit'] = score_answer_nll(cache_E, prompt_E, answer_with_space)
    
    return results


print("Position conditions defined.")
print()
print("Conditions:")
print("  A: [doc] + 'Question: X\\nAnswer:' at scoring")
print("  B: [Question: X][doc] + 'Answer:' at scoring (PREFIX)")
print("  C: [doc][Question: X] + 'Answer:' at scoring (SUFFIX)")
print("  D: [Question: X][doc][Question: X] + 'Answer:' (BOOKEND)")
print("  E: [doc][instruction][Question: X] + 'Answer:' (SUFFIX + instruction)")

Position conditions defined.

Conditions:
  A: [doc] + 'Question: X\nAnswer:' at scoring
  B: [Question: X][doc] + 'Answer:' at scoring (PREFIX)
  C: [doc][Question: X] + 'Answer:' at scoring (SUFFIX)
  D: [Question: X][doc][Question: X] + 'Answer:' (BOOKEND)
  E: [doc][instruction][Question: X] + 'Answer:' (SUFFIX + instruction)
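One consequence of the string templates above can be checked without the model. For A and C, the cache text followed by the scoring prompt is the same string, `doc + "\n\nQuestion: " + query + "\nAnswer:"`, split at a different point. A causal model's forward pass over a continuation with a prefix KV cache should, in principle, match a single pass over the full sequence. An A vs C gap therefore more likely reflects one of two things rather than the query's position:
- tokenization at the split boundary (the scoring prompt is encoded separately with `add_special_tokens=False`);
- numerical differences.

The pure-Python sketch below shows this. `condition_strings` and `full_text` are hypothetical helpers that copy the f-strings from `evaluate_position_conditions`.

```python
from typing import Dict, Tuple


def condition_strings(doc: str, query: str) -> Dict[str, Tuple[str, str]]:
    """(cache_text, scoring_prompt) per condition, mirroring evaluate_position_conditions."""
    return {
        "A": (doc, f"\n\nQuestion: {query}\nAnswer:"),
        "B": (f"Question: {query}\n\n{doc}", "\n\nAnswer:"),
        "C": (f"{doc}\n\nQuestion: {query}", "\nAnswer:"),
        "D": (f"Question: {query}\n\n{doc}\n\nQuestion: {query}", "\nAnswer:"),
        "E": (f"{doc}\n\nBased on the above, answer the following question.\nQuestion: {query}",
              "\nAnswer:"),
    }


def full_text(doc: str, query: str) -> Dict[str, str]:
    """Concatenate cache text and scoring prompt for each condition."""
    return {name: cache + prompt for name, (cache, prompt) in condition_strings(doc, query).items()}


texts = full_text("The capital of France is Paris.", "What is the capital of France?")
print(texts["A"] == texts["C"])  # True: same text, different cache/prompt split
print(texts["A"] == texts["B"])  # False
```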


In [5]:
# Cell 5: Sanity Check on Synthetic Examples

print("="*70)
print("SANITY CHECK: Synthetic Examples")
print("="*70)

synthetic_examples = [
    {
        'doc': "The capital of France is Paris. It is known for the Eiffel Tower.",
        'query': "What is the capital of France?",
        'answer': "Paris",
    },
    {
        'doc': "Apple Inc. was founded in 1976. Microsoft was founded in 1975. Google was founded in 1998.",
        'query': "When was Apple founded?",
        'answer': "1976",
    },
    {
        'doc': "The Nile is 6,650 km long. The Amazon is 6,400 km long. The Yangtze is 6,300 km long.",
        'query': "How long is the Amazon river?",
        'answer': "6,400 km",
    },
]

for i, ex in enumerate(synthetic_examples):
    print(f"\nExample {i+1}: {ex['query']}")
    results = evaluate_position_conditions(ex['doc'], ex['query'], ex['answer'])
    
    # Sort by NLL (lower is better)
    sorted_results = sorted(results.items(), key=lambda x: x[1])
    
    print(f"  {'Condition':<30} {'NLL':>8}  {'Rank':>6}")
    print(f"  {'-'*50}")
    for rank, (cond, nll) in enumerate(sorted_results, 1):
        marker = "<-- BEST" if rank == 1 else ""
        print(f"  {cond:<30} {nll:>8.3f}  {rank:>6}  {marker}")

SANITY CHECK: Synthetic Examples

Example 1: What is the capital of France?
  Condition                           NLL    Rank
  --------------------------------------------------
  A_bare_query_at_scoring           0.049       1  <-- BEST
  E_suffix_explicit                 0.089       2  
  D_bookend_priming                 0.181       3  
  C_suffix_priming                  0.201       4  
  B_prefix_priming                  0.898       5  

Example 2: When was Apple founded?
  Condition                           NLL    Rank
  --------------------------------------------------
  A_bare_query_at_scoring           0.447       1  <-- BEST
  B_prefix_priming                  0.508       2  
  C_suffix_priming                  0.707       3  
  D_bookend_priming                 1.016       4  
  E_suffix_explicit                 2.156       5  

Example 3: How long is the Amazon river?
  Condition                           NLL    Rank
  --------------------------------------------------
 

In [6]:
# Cell 6: Load MS MARCO for Full Evaluation

print("\nLoading MS MARCO...")
msmarco = load_dataset("ms_marco", "v1.1", split="train")

samples = []
for item in msmarco:
    if len(samples) >= 250:
        break
    
    query = item.get('query', '')
    passages = item.get('passages', {}).get('passage_text', [])
    answers = item.get('answers', [])
    
    if not query or not passages or not passages[0] or not answers or not answers[0]:
        continue
    
    passage = passages[0]
    answer = answers[0]
    
    # Filter reasonable lengths
    if len(passage.split()) < 20 or len(passage.split()) > 150:
        continue
    if len(answer.split()) < 1 or len(answer.split()) > 30:
        continue
    
    samples.append({
        'query': query,
        'passage': passage,
        'answer': answer,
    })

print(f"Loaded {len(samples)} samples")


Loading MS MARCO...
Loaded 250 samples
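Cell 6 always scores against `passages[0]`, the first passage in the record. That passage is not necessarily the one that supports the answer. The Hugging Face `ms_marco` v1.1 records also include `passages['is_selected']`, a 0/1 flag per passage that marks the passage(s) the annotator used. The sketch below is one possible alternative picker: it prefers a selected passage and falls back to `passages[0]`. `pick_passage` is a hypothetical helper, and this variant was not used in this experiment.

```python
from typing import Any, Dict, Optional


def pick_passage(item: Dict[str, Any]) -> Optional[str]:
    """Return the first passage flagged is_selected == 1, else passages[0], else None."""
    passages = item.get("passages", {})
    texts = passages.get("passage_text", [])
    selected = passages.get("is_selected", [])
    for text, flag in zip(texts, selected):
        if flag == 1 and text:
            return text
    return texts[0] if texts else None


# Toy record shaped like an ms_marco v1.1 item
item = {"passages": {"passage_text": ["first", "relevant", "third"], "is_selected": [0, 1, 0]}}
print(pick_passage(item))  # relevant
```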


In [7]:
# Cell 7: Full Evaluation

print("\n" + "="*70)
print("MS MARCO EVALUATION: Position-Based Priming")
print("="*70)

N_SAMPLES = 200

all_results = {cond: [] for cond in [
    'A_bare_query_at_scoring',
    'B_prefix_priming',
    'C_suffix_priming',
    'D_bookend_priming',
    'E_suffix_explicit',
]}

for sample in tqdm(samples[:N_SAMPLES], desc="Evaluating"):
    results = evaluate_position_conditions(
        sample['passage'],
        sample['query'],
        sample['answer']
    )
    
    for cond, nll in results.items():
        all_results[cond].append(nll)

print("\nEvaluation complete.")


MS MARCO EVALUATION: Position-Based Priming




Evaluation complete.


In [8]:
# Cell 8: Results Summary

print("\n" + "="*70)
print("RESULTS SUMMARY")
print("="*70)

# Convert to numpy
results_np = {k: np.array(v) for k, v in all_results.items()}

# Compute statistics
print("\nMean NLL by Condition (lower = better):")
print("-" * 60)

# Sort by mean NLL
sorted_conds = sorted(results_np.items(), key=lambda x: np.mean(x[1]))

baseline = results_np['A_bare_query_at_scoring']
baseline_mean = np.mean(baseline)

print(f"{'Condition':<30} {'Mean NLL':>10} {'Std':>8} {'vs A':>10} {'Win%':>8}")
print("-" * 70)

for cond, nlls in sorted_conds:
    mean_nll = np.mean(nlls)
    std_nll = np.std(nlls)
    delta = baseline_mean - mean_nll  # Positive = better than A
    win_rate = np.mean(baseline > nlls) * 100  # % where this condition beats A
    
    marker = "(BASELINE)" if cond == 'A_bare_query_at_scoring' else ""
    print(f"{cond:<30} {mean_nll:>10.4f} {std_nll:>8.4f} {delta:>+10.4f} {win_rate:>7.1f}% {marker}")


RESULTS SUMMARY

Mean NLL by Condition (lower = better):
------------------------------------------------------------
Condition                        Mean NLL      Std       vs A     Win%
----------------------------------------------------------------------
A_bare_query_at_scoring            3.3912   2.7358    +0.0000     0.0% (BASELINE)
C_suffix_priming                   3.4711   2.7604    -0.0798    37.0% 
D_bookend_priming                  3.7335   2.7164    -0.3423    32.5% 
B_prefix_priming                   3.7393   2.8655    -0.3481    36.0% 
E_suffix_explicit                  4.1673   3.1877    -0.7761    14.0% 


In [9]:
# Cell 9: Key Comparisons

print("\n" + "="*70)
print("KEY COMPARISONS")
print("="*70)

A = results_np['A_bare_query_at_scoring']
B = results_np['B_prefix_priming']
C = results_np['C_suffix_priming']
D = results_np['D_bookend_priming']
E = results_np['E_suffix_explicit']

def compare(name, x, y):
    """Compare x vs y. Positive delta means x is better."""
    delta = np.mean(y) - np.mean(x)  # Positive = x better (lower NLL)
    win_rate = np.mean(x < y)
    t_stat, p_val = stats.ttest_rel(x, y)
    # Paired effect size (d_z): mean difference / SD of differences (np.std default ddof=0)
    d = delta / np.std(y - x) if np.std(y - x) > 0 else 0
    
    print(f"\n{name}")
    print(f"  Mean: {np.mean(x):.4f} vs {np.mean(y):.4f}")
    print(f"  Delta: {delta:+.4f} (positive = first is better)")
    print(f"  Win rate: {win_rate*100:.1f}%")
    print(f"  Cohen's d: {d:+.3f}")
    print(f"  p-value: {p_val:.4f}")
    
    if p_val < 0.05:
        winner = name.split(' vs ')[0] if delta > 0 else name.split(' vs ')[1]
        print(f"  --> {winner} is significantly better")
    else:
        print(f"  --> No significant difference")

print("\n### Test 1: Suffix vs Prefix (C vs B) ###")
print("Does putting query AFTER document help more than BEFORE?")
compare("C (suffix) vs B (prefix)", C, B)

print("\n### Test 2: Suffix vs Baseline (C vs A) ###")
print("Does suffix priming beat query-at-scoring?")
compare("C (suffix) vs A (baseline)", C, A)

print("\n### Test 3: Bookend vs Suffix (D vs C) ###")
print("Does adding query at BOTH ends help?")
compare("D (bookend) vs C (suffix)", D, C)

print("\n### Test 4: Prefix vs Baseline (B vs A) ###")
print("Sanity check: prefix priming vs baseline (should match Exp 21)")
compare("B (prefix) vs A (baseline)", B, A)

print("\n### Test 5: Explicit Suffix vs Simple Suffix (E vs C) ###")
print("Does adding instruction help suffix priming?")
compare("E (explicit) vs C (simple)", E, C)


KEY COMPARISONS

### Test 1: Suffix vs Prefix (C vs B) ###
Does putting query AFTER document help more than BEFORE?

C (suffix) vs B (prefix)
  Mean: 3.4711 vs 3.7393
  Delta: +0.2682 (positive = first is better)
  Win rate: 57.0%
  Cohen's d: +0.155
  p-value: 0.0296
  --> C (suffix) is significantly better

### Test 2: Suffix vs Baseline (C vs A) ###
Does suffix priming beat query-at-scoring?

C (suffix) vs A (baseline)
  Mean: 3.4711 vs 3.3912
  Delta: -0.0798 (positive = first is better)
  Win rate: 37.0%
  Cohen's d: -0.233
  p-value: 0.0012
  --> A (baseline) is significantly better

### Test 3: Bookend vs Suffix (D vs C) ###
Does adding query at BOTH ends help?

D (bookend) vs C (suffix)
  Mean: 3.7335 vs 3.4711
  Delta: -0.2624 (positive = first is better)
  Win rate: 36.5%
  Cohen's d: -0.206
  p-value: 0.0040
  --> C (suffix) is significantly better

### Test 4: Prefix vs Baseline (B vs A) ###
Sanity check: prefix priming vs baseline (should match Exp 21)

B (prefix) vs A (b

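`compare` divides the mean paired difference by the standard deviation of the differences. This is the paired effect size often written d_z. For a paired t-test, t = d_z·√n when d_z uses the sample standard deviation (`ddof=1`). `compare` uses NumPy's default `ddof=0`, for which the identity becomes t = d·√(n − 1). For example, d = −0.233 with n = 200 gives t ≈ −3.29, in line with the reported p = 0.0012. The NumPy-only sketch below computes these quantities. `paired_effect` is a hypothetical helper.

```python
import numpy as np


def paired_effect(x, y) -> dict:
    """Paired comparison of per-sample NLLs, x vs y (lower NLL = better).

    Sign convention follows `compare`: positive values mean x is better.
    This is the opposite sign to scipy.stats.ttest_rel(x, y), which tests x - y.
    """
    diff = np.asarray(y, dtype=float) - np.asarray(x, dtype=float)
    n = diff.size
    mean = diff.mean()
    sd0 = diff.std(ddof=0)  # population SD, as used in `compare`
    sd1 = diff.std(ddof=1)  # sample SD, as used by the paired t-test
    return {
        "delta": mean,
        "d_z_ddof0": mean / sd0,
        "d_z_ddof1": mean / sd1,
        "t": mean / (sd1 / np.sqrt(n)),
    }


# Toy check: a constant 0.1 improvement with alternating +/-0.05 noise
x = np.linspace(1.0, 5.0, 200)
y = x + 0.1 + np.where(np.arange(200) % 2 == 0, 0.05, -0.05)
r = paired_effect(x, y)
print(r["t"], r["d_z_ddof0"] * np.sqrt(199))  # equal up to rounding
```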
In [10]:
# Cell 10: Ranking Analysis

print("\n" + "="*70)
print("CONDITION RANKING PER SAMPLE")
print("="*70)

# For each sample, rank the conditions
condition_names = list(all_results.keys())
n_conditions = len(condition_names)

rank_counts = {cond: {r: 0 for r in range(1, n_conditions+1)} for cond in condition_names}

for i in range(N_SAMPLES):
    sample_nlls = [(cond, all_results[cond][i]) for cond in condition_names]
    sorted_sample = sorted(sample_nlls, key=lambda x: x[1])
    
    for rank, (cond, _) in enumerate(sorted_sample, 1):
        rank_counts[cond][rank] += 1

print("\nHow often each condition ranks 1st, 2nd, etc.:")
print(f"{'Condition':<30} {'#1':>6} {'#2':>6} {'#3':>6} {'#4':>6} {'#5':>6}")
print("-" * 66)

for cond in condition_names:
    counts = [rank_counts[cond][r] for r in range(1, n_conditions+1)]
    pcts = [f"{c/N_SAMPLES*100:.0f}%" for c in counts]
    print(f"{cond:<30} {pcts[0]:>6} {pcts[1]:>6} {pcts[2]:>6} {pcts[3]:>6} {pcts[4]:>6}")


CONDITION RANKING PER SAMPLE

How often each condition ranks 1st, 2nd, etc.:
Condition                          #1     #2     #3     #4     #5
------------------------------------------------------------------
A_bare_query_at_scoring           30%    35%    21%    12%     2%
B_prefix_priming                  26%    13%    22%    19%    20%
C_suffix_priming                  20%    32%    23%    20%     6%
D_bookend_priming                 17%    16%    16%    34%    18%
E_suffix_explicit                  7%     4%    18%    16%    55%
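Each condition's mean rank condenses the rank table into one number. The sketch below computes it from the rounded percentages printed above. The rows for C and D sum to 101% because of rounding, so each row is normalised by its own sum and the results are approximate. Inside the notebook, the exact values would come from `rank_counts`. The `mean_rank` helper is hypothetical.

```python
from typing import Dict, List


def mean_rank(pct_by_rank: Dict[str, List[int]]) -> Dict[str, float]:
    """Mean rank per condition from a rank-frequency row (rank 1 first).

    Rows are normalised by their own sum, since rounded percentages
    may not add up to exactly 100.
    """
    out = {}
    for cond, pcts in pct_by_rank.items():
        total = sum(pcts)
        out[cond] = sum(rank * p for rank, p in enumerate(pcts, start=1)) / total
    return out


# Rounded percentages from the Cell 10 output
table = {
    "A_bare_query_at_scoring": [30, 35, 21, 12, 2],
    "B_prefix_priming": [26, 13, 22, 19, 20],
    "C_suffix_priming": [20, 32, 23, 20, 6],
    "D_bookend_priming": [17, 16, 16, 34, 18],
    "E_suffix_explicit": [7, 4, 18, 16, 55],
}
for cond, r in sorted(mean_rank(table).items(), key=lambda kv: kv[1]):
    print(f"{cond:<30} {r:.2f}")
```

On this table the mean ranks are A ≈ 2.21, C ≈ 2.60, B ≈ 2.94, D ≈ 3.20 and E ≈ 4.08. This ordering matches mean NLL, with one exception: B and D, nearly tied on mean NLL (3.7393 vs 3.7335), swap places.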


In [11]:
# Cell 11: Diagnosis and Conclusions

print("\n" + "="*70)
print("DIAGNOSIS")
print("="*70)

# Determine the winner
means = {cond: np.mean(nlls) for cond, nlls in results_np.items()}
best_cond = min(means, key=means.get)
worst_cond = max(means, key=means.get)

print(f"\nBest condition: {best_cond} (mean NLL = {means[best_cond]:.4f})")
print(f"Worst condition: {worst_cond} (mean NLL = {means[worst_cond]:.4f})")
print(f"Gap: {means[worst_cond] - means[best_cond]:.4f}")

# Check hypotheses
print("\n" + "-"*50)
print("Hypothesis Testing:")
print("-"*50)

# H1: Suffix > Prefix
suffix_beats_prefix = np.mean(C) < np.mean(B)
print(f"\n1. Suffix > Prefix (recency helps)?")
print(f"   C mean: {np.mean(C):.4f}, B mean: {np.mean(B):.4f}")
print(f"   Result: {'YES' if suffix_beats_prefix else 'NO'} - Suffix is {'better' if suffix_beats_prefix else 'worse'}")

# H2: Suffix ≈ Baseline or better
suffix_beats_baseline = np.mean(C) < np.mean(A)
print(f"\n2. Suffix >= Baseline?")
print(f"   C mean: {np.mean(C):.4f}, A mean: {np.mean(A):.4f}")
print(f"   Result: {'YES' if suffix_beats_baseline else 'NO'} - Suffix is {'better' if suffix_beats_baseline else 'worse'} than baseline")

# H3: Bookend >= Suffix
bookend_beats_suffix = np.mean(D) <= np.mean(C)
print(f"\n3. Bookend >= Suffix?")
print(f"   D mean: {np.mean(D):.4f}, C mean: {np.mean(C):.4f}")
print(f"   Result: {'YES' if bookend_beats_suffix else 'NO'}")


DIAGNOSIS

Best condition: A_bare_query_at_scoring (mean NLL = 3.3912)
Worst condition: E_suffix_explicit (mean NLL = 4.1673)
Gap: 0.7761

--------------------------------------------------
Hypothesis Testing:
--------------------------------------------------

1. Suffix > Prefix (recency helps)?
   C mean: 3.4711, B mean: 3.7393
   Result: YES - Suffix is better

2. Suffix >= Baseline?
   C mean: 3.4711, A mean: 3.3912
   Result: NO - Suffix is worse than baseline

3. Bookend >= Suffix?
   D mean: 3.7335, C mean: 3.4711
   Result: NO


In [12]:
# Cell 12: Save Results

output = {
    'n_samples': N_SAMPLES,
    'conditions': {
        cond: {
            'mean': float(np.mean(nlls)),
            'std': float(np.std(nlls)),
        }
        for cond, nlls in results_np.items()
    },
    'comparisons': {
        'C_vs_B': {
            'delta': float(np.mean(B) - np.mean(C)),
            'p_value': float(stats.ttest_rel(C, B)[1]),
        },
        'C_vs_A': {
            'delta': float(np.mean(A) - np.mean(C)),
            'p_value': float(stats.ttest_rel(C, A)[1]),
        },
        'D_vs_C': {
            'delta': float(np.mean(C) - np.mean(D)),
            'p_value': float(stats.ttest_rel(D, C)[1]),
        },
    },
    'best_condition': best_cond,
    'raw_results': {k: [float(x) for x in v] for k, v in all_results.items()},
}

with open(f'{OUTPUT_DIR}/results.json', 'w') as f:
    json.dump(output, f, indent=2)

print(f"Results saved to {OUTPUT_DIR}/results.json")

Results saved to /home/jupyter/research/directed_kvcache/results/exp23/results.json
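Cell 12 writes three kinds of data to `results.json`:
- per-condition means and standard deviations;
- three paired comparisons, where each `delta` is positive when the first-named condition has the lower mean NLL;
- the raw per-sample NLLs.

The sketch below re-ranks conditions from a saved file without rerunning the model. It assumes only the schema of the `output` dict above. `load_results`, `rank_conditions` and `paired_win_rate` are hypothetical helpers.

```python
import json
from typing import Any, Dict, List, Tuple


def load_results(path: str) -> Dict[str, Any]:
    """Load a results.json written by Cell 12."""
    with open(path) as f:
        return json.load(f)


def rank_conditions(data: Dict[str, Any]) -> List[Tuple[str, float]]:
    """(condition, mean NLL) pairs, best (lowest NLL) first."""
    means = [(cond, summary["mean"]) for cond, summary in data["conditions"].items()]
    return sorted(means, key=lambda kv: kv[1])


def paired_win_rate(data: Dict[str, Any], cond: str, baseline: str) -> float:
    """Fraction of samples where `cond` has strictly lower NLL than `baseline`."""
    x = data["raw_results"][cond]
    y = data["raw_results"][baseline]
    return sum(a < b for a, b in zip(x, y)) / len(x)


# Usage (path as in Cell 12):
# data = load_results(f"{OUTPUT_DIR}/results.json")
# print(rank_conditions(data))
# print(paired_win_rate(data, "C_suffix_priming", "A_bare_query_at_scoring"))
```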


---
## Summary

### Key Question
Does **suffix priming** (query after document) work better than **prefix priming** (query before document)?

### Conditions Tested
- **A**: Bare doc, query at scoring (baseline)
- **B**: Prefix priming `[query][doc]`
- **C**: Suffix priming `[doc][query]`
- **D**: Bookend `[query][doc][query]`
- **E**: Suffix with explicit instruction

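### Results
MS MARCO v1.1, first 200 samples passing the length filters; mean answer NLL, lower is better.

| Condition | Mean NLL | Win% vs A |
|-----------|----------|-----------|
| A: query at scoring (baseline) | 3.3912 | n/a |
| C: suffix `[doc][query]` | 3.4711 | 37.0% |
| D: bookend `[query][doc][query]` | 3.7335 | 32.5% |
| B: prefix `[query][doc]` | 3.7393 | 36.0% |
| E: suffix + instruction | 4.1673 | 14.0% |

- **Suffix beats prefix**: C has lower NLL than B (delta +0.2682, p = 0.0296), consistent with the recency hypothesis.
- **Suffix does not reach the baseline**: A beats C (C is worse by 0.0798, p = 0.0012), so query-at-scoring remains the best condition.
- **Bookend does not help**: C beats D (D is worse by 0.2624, p = 0.0040).
- **Explicit instruction hurts**: E is the worst condition and ranks last on 55% of samples.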