# Decoder-Only Exp 01: Surrogate KV Caching — Condition Examples

This notebook shows the actual inputs for each experimental condition, using real MS MARCO data.
No GPU is needed: the notebook only loads data and displays text.

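The experiment tests whether conditioning a **decoder-only** model's KV cache with
surrogate prompts improves answer quality, measured as NLL on the answer tokens.
The surrogate entries are sliced out of the KV cache, and position IDs are aligned
so that RoPE distances are preserved.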

In [1]:
import json, re, sys
import numpy as np
from pathlib import Path
from collections import Counter

sys.path.insert(0, "../../..")
from lib.data import count_words

SEED = 42
N_SAMPLES = 400

# ---- Load MS MARCO (same reconstruction as scoring notebook) ----
from datasets import load_dataset
ds = load_dataset("microsoft/ms_marco", "v1.1", split="validation")

all_candidates = []
for item in ds:
    if len(all_candidates) >= 3 * N_SAMPLES:
        break
    passages = item.get('passages', {})
    ptexts = passages.get('passage_text', [])
    is_sel = passages.get('is_selected', [])
    query = item.get('query', '')
    answers = item.get('answers', [])
    well_formed = item.get('wellFormedAnswers', [])
    answer = None
    # Prefer the well-formed answer; fall back to the first raw answer.
    if well_formed and well_formed[0] not in ('[]', ''):
        answer = well_formed[0]
    elif answers and answers[0] != 'No Answer Present.':
        answer = answers[0]
    if not answer:
        continue
    for pt, sel in zip(ptexts, is_sel):
        wc = count_words(pt)
        if sel == 1 and 30 <= wc <= 300:
            all_candidates.append({
                'passage': pt, 'query': query, 'answer': answer,
                'word_count': wc,
            })
            break

np.random.seed(SEED)
indices = np.random.permutation(len(all_candidates))
samples = [all_candidates[i] for i in indices[:N_SAMPLES]]
del ds, all_candidates

STOP_WORDS = {
    'a', 'an', 'the', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
    'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could',
    'should', 'may', 'might', 'can', 'shall', 'to', 'of', 'in', 'for',
    'on', 'with', 'at', 'by', 'from', 'as', 'into', 'through', 'during',
    'before', 'after', 'above', 'below', 'between', 'and', 'but', 'or',
    'not', 'no', 'if', 'then', 'than', 'so', 'up', 'out', 'about',
    'what', 'which', 'who', 'whom', 'this', 'that', 'these', 'those',
    'it', 'its', 'i', 'me', 'my', 'we', 'our', 'you', 'your', 'he',
    'him', 'his', 'she', 'her', 'they', 'them', 'their', 'how', 'when',
    'where', 'why', 'much', 'many', 'some', 'any', 'all', 'each',
    'also', 'just', 'more', 'most', 'very', 'too', 'only',
}

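def extract_keywords(text):
    """Lowercase, strip punctuation, and drop stop words and words of 1-2 characters."""
    words = re.sub(r'[^\w\s]', '', text.lower()).split()
    return [w for w in words if w not in STOP_WORDS and len(w) > 2]

def make_doc_keywords(passage):
    """Return the 5 most frequent content words of the passage, space-separated."""
    content_words = extract_keywords(passage)
    if not content_words:
        return "information"  # fallback when no content words remain
    counts = Counter(content_words)
    return " ".join(w for w, _ in counts.most_common(5))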

SURROGATES = {
    'universal': "Analyze the following text for all key entities, factual claims, and logical relationships.",
    'extractor': "Examine this document specifically for data points, dates, numerical values, and specific named attributes.",
    'reasonant': "Evaluate the underlying arguments, sentiment, and intent of the following passage.",
    'analytic': "Provide a technical breakdown of the systems and processes described in this text.",
}

ADVERSARIAL_PREFIX = "The recipe calls for two cups of flour, one cup of sugar, and a pinch of salt."

# Verify that sample 0 matches the first checkpointed result
# (compares the first 50 characters of the query).
ckpt_path = Path("../../../results/decoder_only/exp01/checkpoint.json")
if ckpt_path.exists():
    ckpt = json.loads(ckpt_path.read_text())
    results = ckpt.get('results', [])
    if not results:
        print("Checkpoint verification: SKIPPED (checkpoint has no results)")
    elif results[0].get('query', '')[:50] == samples[0]['query'][:50]:
        print("Checkpoint verification: MATCH")
    else:
        print("Checkpoint verification: MISMATCH")
        print(f"  Checkpoint: {results[0].get('query', '')[:50]}")
        print(f"  Samples:    {samples[0]['query'][:50]}")
else:
    print("No checkpoint found")

print(f"Loaded {len(samples)} MS MARCO samples (SEED={SEED})")
print(f"Sample 0 query: {samples[0]['query'][:70]}")

Checkpoint verification: MATCH
Loaded 400 MS MARCO samples (SEED=42)
Sample 0 query: average annual temperature of Uruguay
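
The document-keyword surrogate used in the later cells is the five most frequent content words of a passage. The sketch below replays that extraction on a toy input so the behavior is easy to inspect. The passage text, the shortened stop-word list `TOY_STOP_WORDS`, and the function name `toy_doc_keywords` are made up for illustration; the notebook itself uses the full `STOP_WORDS` set and `make_doc_keywords`.

```python
import re
from collections import Counter

# Shortened, illustrative stop-word list (the notebook's STOP_WORDS is larger).
TOY_STOP_WORDS = {"the", "is", "in", "of", "and", "a", "at"}

def toy_doc_keywords(passage, k=5):
    """Return the k most frequent content words, joined by spaces."""
    # Lowercase and strip punctuation, as extract_keywords does.
    words = re.sub(r"[^\w\s]", "", passage.lower()).split()
    content = [w for w in words if w not in TOY_STOP_WORDS and len(w) > 2]
    if not content:
        return "information"  # same fallback string as make_doc_keywords
    return " ".join(w for w, _ in Counter(content).most_common(k))

# Made-up passage, not taken from MS MARCO.
toy_passage = "The climate in Lima is mild. Lima climate data: mild winters, mild summers."
print(toy_doc_keywords(toy_passage, k=3))
```

Words with equal counts come out in order of first occurrence, so the keyword string is deterministic for a given passage.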


In [2]:
# Show 3 representative samples
for idx in [0, 1, 2]:
    ex = samples[idx]
    doc_kw = make_doc_keywords(ex['passage'])
    print("=" * 80)
    print(f"SAMPLE {idx}")
    print("=" * 80)
    print(f"  Query:        {ex['query']}")
    print(f"  Answer:       {ex['answer']}")
    print(f"  Document:     {ex['passage'][:100]}...")
    print(f"  Doc words:    {ex['word_count']}")
    print(f"  Doc keywords: {doc_kw}")
    print()

SAMPLE 0
  Query:        average annual temperature of Uruguay
  Answer:       Very mild at 15.8 degrees Celsius (60.4 degrees Fahrenheit).
  Document:     Average Temperatures in Montevideo, Uruguay. 1  The average annual temperature in Montevideo, Urugua...
  Doc words:    76
  Doc keywords: average degrees temperatures montevideo uruguay

SAMPLE 1
  Query:        average cost for an acre of land in arizona
  Answer:       $4,300 per acre.
  Document:     Arizona. With more than 72 million acres, Arizona was one of the largest states in the country. Howe...
  Doc words:    102
  Doc keywords: acre land average less arizona

SAMPLE 2
  Query:        where can i buy nematodes
  Answer:       Here to buy THE GOOD BUGS Supplier beneficial insects, mites and nematodes for commercial growers Buying insects from a reliable source is a very important step. yes
  Document:     Where to buy THE GOOD BUGS Supplier beneficial insects, mites and nematodes for commercial growers B...
  Doc words: 

In [3]:
ex = samples[0]
doc = ex['passage']
query = ex['query']
answer = ex['answer']
doc_kw = make_doc_keywords(doc)
doc_short = doc[:80]

print("=" * 80)
print("HOW THIS EXPERIMENT WORKS")
print("=" * 80)
print()
print("  Gemma 2 2B is a DECODER-ONLY model with causal attention.")
print("  In standard left-to-right processing [doc, query, answer]:")
print("    - Doc tokens can only see other doc tokens (causal mask)")
print("    - Query tokens can see doc + query")
print("    - Answer tokens can see everything")
print()
print("  KEY LIMITATION: Doc tokens NEVER see the query. Their KV")
print("  representations are query-blind.")
print()
print("  THE TRICK: Put a surrogate prompt BEFORE the document:")
print("    [surrogate, doc, query, answer]")
print("  Now doc tokens can see the surrogate via causal attention,")
print("  and their KV representations encode surrogate-aware features.")
print()
print("  PRODUCTION DEPLOYMENT:")
print("    Phase A (offline): Forward [surrogate + doc] -> KV cache")
print("                       Slice off surrogate entries -> store doc KV")
print("    Phase B (online):  Load cached doc KV")
print("                       Forward [query + answer] with cached KV")
print("                       Set position_ids to continue after doc")
print()
print("  POSITION ALIGNMENT:")
print("    If surrogate = S tokens, doc = D tokens:")
print("      Conditioning: [BOS, s1, ..., sS, sep, d1, ..., dD]")
print("      Positions:    [0,   1,  ..., S,  S+1, S+2, ..., S+1+D]")
print("      After slice:  keep [d1, ..., dD] at positions [S+2, ..., S+1+D]")
print("      Query starts: position S+2+D")
print()
print("    RoPE relative positions are IDENTICAL across conditions:")
print("      last_doc - first_query distance = 2 (always)")
print("      doc internal distances = unchanged")
print("    So the position offset does NOT confound the comparison.")

HOW THIS EXPERIMENT WORKS

  Gemma 2 2B is a DECODER-ONLY model with causal attention.
  In standard left-to-right processing [doc, query, answer]:
    - Doc tokens can only see other doc tokens (causal mask)
    - Query tokens can see doc + query
    - Answer tokens can see everything

  KEY LIMITATION: Doc tokens NEVER see the query. Their KV
  representations are query-blind.

  THE TRICK: Put a surrogate prompt BEFORE the document:
    [surrogate, doc, query, answer]
  Now doc tokens can see the surrogate via causal attention,
  and their KV representations encode surrogate-aware features.

  PRODUCTION DEPLOYMENT:
    Phase A (offline): Forward [surrogate + doc] -> KV cache
                       Slice off surrogate entries -> store doc KV
    Phase B (online):  Load cached doc KV
                       Forward [query + answer] with cached KV
                       Set position_ids to continue after doc

  POSITION ALIGNMENT:
    If surrogate = S tokens, doc = D tokens:
      Cond
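
The position arithmetic printed by this cell can be checked with a few lines of plain Python. This is a minimal sketch under the layout stated in the cell, `[BOS, s1..sS, sep, d1..dD]`; the bare layout `[BOS, d1..dD]`, the function names, and the token counts `S = 4`, `D = 6` are assumptions chosen for illustration.

```python
def conditioned_layout(S, D):
    """Positions under the stated layout [BOS, s1..sS, sep, d1..dD]."""
    doc_positions = list(range(S + 2, S + 2 + D))  # d1..dD
    phase_b_start = S + 2 + D                      # first Phase B token
    return doc_positions, phase_b_start

def bare_layout(D):
    """Assumed bare layout [BOS, d1..dD]: no surrogate and no separator."""
    doc_positions = list(range(1, 1 + D))
    phase_b_start = 1 + D
    return doc_positions, phase_b_start

S, D = 4, 6  # hypothetical token counts
cond_doc, cond_start = conditioned_layout(S, D)
bare_doc, bare_start = bare_layout(D)

print(cond_doc, cond_start)  # [6, 7, 8, 9, 10, 11] 12
print(bare_doc, bare_start)  # [1, 2, 3, 4, 5, 6] 7

# The gap from the last doc token to the first Phase B token is identical.
print(cond_start - cond_doc[-1], bare_start - bare_doc[-1])  # 1 1
```

The absolute positions shift by `S + 1` between the two layouts, but every relative distance is unchanged. If the first Phase B token is the newline separator, the first query token sits two positions after the last document token, which would match the stated distance of 2; that reading assumes the newline is a single token.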

In [4]:
print("=" * 80)
print(f"CONDITIONS (8 total) — Sample 0")
print("=" * 80)
print()
print(f"  Query:   {query}")
print(f"  Answer:  {answer}")
print(f"  Doc:     {doc_short}...")

print()
print("--- CONDITION 1: bare (LOWER BOUND) ---")
print()
print("  Phase A cache: [doc]")
print("  Slicing:       none")
print("  Phase B input: [query + answer]")
print()
print("  Standard causal LM. Doc tokens are query-blind.")
print("  Equivalent to a single forward pass of [doc, query, answer].")
print()
print(f"  Phase A sees: \"{doc_short}...\"")
print(f"  Phase B sees: \"\\n{query}\\n{answer}\"")

print()
print("--- CONDITION 2: oracle (UPPER BOUND) ---")
print()
print(f"  Prefix:        \"{query}\"")
print(f"  Phase A cache: [prefix + \\n + doc]")
print(f"  Slicing:       remove prefix entries")
print(f"  Phase B input: [query + answer] (same query again)")
print()
print("  The real query conditions the doc KV via causal attention.")
print("  Doc tokens now \"know\" what the query is about.")
print("  This is the ceiling -- but uses future knowledge (cheating).")
print()
print(f"  Phase A sees: \"{query}\\n{doc_short}...\"")
print(f"  After slice:  doc KV only (at positions ~{len(query.split())+2}+, using word count as a rough proxy for token count)")
print(f"  Phase B sees: \"\\n{query}\\n{answer}\"")

print()
print("--- CONDITION 3: surr_universal ---")
print()
surr = SURROGATES['universal']
print(f"  Prefix:        \"{surr}\"")
print(f"  Phase A cache: [prefix + \\n + doc]")
print(f"  Slicing:       remove prefix entries")
print(f"  Phase B input: [query + answer]")
print()
print("  Generic analysis prompt. Forces the model to activate entity,")
print("  factual, and relational features in the document KV.")
print()
print(f"  Phase A sees: \"{surr}\\n{doc_short}...\"")

print()
print("--- CONDITION 4: surr_extractor ---")
print()
surr = SURROGATES['extractor']
print(f"  Prefix:        \"{surr}\"")
print(f"  Phase A cache: [prefix + \\n + doc]")
print(f"  Slicing:       remove prefix entries")
print(f"  Phase B input: [query + answer]")
print()
print("  Targets numerical/factual extraction. Best for data-point queries.")
print()
print(f"  Phase A sees: \"{surr}\\n{doc_short}...\"")

print()
print("--- CONDITION 5: surr_reasonant ---")
print()
surr = SURROGATES['reasonant']
print(f"  Prefix:        \"{surr}\"")
print(f"  Phase A cache: [prefix + \\n + doc]")
print(f"  Slicing:       remove prefix entries")
print(f"  Phase B input: [query + answer]")
print()
print("  Targets reasoning/sentiment. Best for \"how\" or \"why\" queries.")
print()
print(f"  Phase A sees: \"{surr}\\n{doc_short}...\"")

print()
print("--- CONDITION 6: surr_analytic ---")
print()
surr = SURROGATES['analytic']
print(f"  Prefix:        \"{surr}\"")
print(f"  Phase A cache: [prefix + \\n + doc]")
print(f"  Slicing:       remove prefix entries")
print(f"  Phase B input: [query + answer]")
print()
print("  Targets technical/process analysis. Best for engineering docs.")
print()
print(f"  Phase A sees: \"{surr}\\n{doc_short}...\"")

print()
print("--- CONDITION 7: surr_doc_kw ---")
print()
print(f"  Prefix:        \"{doc_kw}\"  (top-5 TF keywords from the document)")
print(f"  Phase A cache: [prefix + \\n + doc]")
print(f"  Slicing:       remove prefix entries")
print(f"  Phase B input: [query + answer]")
print()
print("  Document-derived surrogate. In v3 (encoder-decoder), this was the")
print("  best non-oracle condition, capturing 89% of oracle benefit.")
print()
print(f"  Phase A sees: \"{doc_kw}\\n{doc_short}...\"")

print()
print("--- CONDITION 8: adversarial (NEGATIVE CONTROL) ---")
print()
print(f"  Prefix:        \"{ADVERSARIAL_PREFIX}\"")
print(f"  Phase A cache: [prefix + \\n + doc]")
print(f"  Slicing:       remove prefix entries")
print(f"  Phase B input: [query + answer]")
print()
print("  Completely off-topic prefix. Tests semantic sensitivity:")
print("  - If adversarial HURTS vs bare: conditioning is semantic")
print("  - If adversarial HELPS like others: effect is structural")
print("  - If adversarial = bare: prefix content is ignored")
print()
print(f"  Phase A sees: \"{ADVERSARIAL_PREFIX}\\n{doc_short}...\"")

CONDITIONS (8 total) — Sample 0

  Query:   average annual temperature of Uruguay
  Answer:  Very mild at 15.8 degrees Celsius (60.4 degrees Fahrenheit).
  Doc:     Average Temperatures in Montevideo, Uruguay. 1  The average annual temperature i...

--- CONDITION 1: bare (LOWER BOUND) ---

  Phase A cache: [doc]
  Slicing:       none
  Phase B input: [query + answer]

  Standard causal LM. Doc tokens are query-blind.
  Equivalent to a single forward pass of [doc, query, answer].

  Phase A sees: "Average Temperatures in Montevideo, Uruguay. 1  The average annual temperature i..."
  Phase B sees: "\naverage annual temperature of Uruguay\nVery mild at 15.8 degrees Celsius (60.4 degrees Fahrenheit)."

--- CONDITION 2: oracle (UPPER BOUND) ---

  Prefix:        "average annual temperature of Uruguay"
  Phase A cache: [prefix + \n + doc]
  Slicing:       remove prefix entries
  Phase B input: [query + answer] (same query again)

  The real query conditions the doc KV via causal attention.
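
The "remove prefix entries" step shared by conditions 2-8 can be illustrated with toy arrays. This is a sketch, not the experiment's code: the cache is stood in for by NumPy arrays of assumed shape `(layers, sequence, dim)`, whereas real model caches typically carry batch and head axes as well. It follows the layout in the position-alignment notes, which keeps only `d1..dD`; whether BOS is retained in the actual experiment is not shown in this notebook. `slice_doc_kv` is a hypothetical helper name.

```python
import numpy as np

def slice_doc_kv(keys, values, n_prefix):
    """Drop the first n_prefix sequence entries (BOS + surrogate + sep)."""
    return keys[:, n_prefix:, :], values[:, n_prefix:, :]

S, D = 4, 6               # hypothetical surrogate and document token counts
n_layers, dim = 2, 8      # toy sizes
seq_len = 1 + S + 1 + D   # BOS + surrogate + sep + doc

rng = np.random.default_rng(0)
keys = rng.normal(size=(n_layers, seq_len, dim))
values = rng.normal(size=(n_layers, seq_len, dim))

# Keep only the document entries, which were computed with the prefix visible.
doc_k, doc_v = slice_doc_kv(keys, values, n_prefix=S + 2)
print(doc_k.shape)  # (2, 6, 8)
```

The retained entries are unchanged by the slice, so they still encode whatever the document tokens attended to in the prefix.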
 

In [5]:
print("=" * 80)
print("SUMMARY TABLE — All conditions for Sample 0")
print("=" * 80)
print()

conditions = [
    ('bare',           None),
    ('oracle',         query),
    ('surr_universal',  SURROGATES['universal']),
    ('surr_extractor',  SURROGATES['extractor']),
    ('surr_reasonant',  SURROGATES['reasonant']),
    ('surr_analytic',   SURROGATES['analytic']),
    ('surr_doc_kw',     doc_kw),
    ('adversarial',     ADVERSARIAL_PREFIX),
]

print(f"  {'#':<3} {'Condition':<20} {'Prefix words':>13} {'Phase A: conditioning input (first 70 chars)'}")
print(f"  {'-'*110}")
for i, (name, prefix) in enumerate(conditions, 1):
    if prefix is None:
        pw = '(none)'
        phase_a = doc[:70]
    else:
        pw = f"{len(prefix.split())}w"
        phase_a = (prefix + '\\n' + doc)[:70]
    print(f"  {i:<3} {name:<20} {pw:>13}   {phase_a}...")

print()
print(f"  Phase B (same for ALL conditions): \"\\n{query}\\n{answer}\"")
print(f"  NLL measured on: answer tokens only ({len(answer.split())} words)")
print()
print("  For conditions 2-8, Phase A output is sliced: prefix KV entries")
print("  are removed, keeping only the conditioned document KV.")
print("  Phase B position_ids are set to continue after the original")
print("  conditioning sequence (preserving correct RoPE distances).")

SUMMARY TABLE — All conditions for Sample 0

  #   Condition             Prefix words Phase A: conditioning input (first 70 chars)
  --------------------------------------------------------------------------------------------------------------
  1   bare                        (none)   Average Temperatures in Montevideo, Uruguay. 1  The average annual tem...
  2   oracle                          5w   average annual temperature of Uruguay\nAverage Temperatures in Montevi...
  3   surr_universal                 13w   Analyze the following text for all key entities, factual claims, and l...
  4   surr_extractor                 14w   Examine this document specifically for data points, dates, numerical v...
  5   surr_reasonant                 11w   Evaluate the underlying arguments, sentiment, and intent of the follow...
  6   surr_analytic                  13w   Provide a technical breakdown of the systems and processes described i...
  7   surr_doc_kw                     5w   average deg
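
Since NLL is measured on answer tokens only, the scoring step has to mask out the query tokens in Phase B. A minimal sketch of that masking, assuming per-token log-probabilities are already available and that the score is the mean over answer tokens (the notebook does not say whether mean or sum is used). The function name and the numbers are illustrative, not experiment outputs.

```python
import numpy as np

def answer_nll(token_logprobs, is_answer):
    """Mean negative log-likelihood over answer tokens only."""
    lp = np.asarray(token_logprobs, dtype=float)
    mask = np.asarray(is_answer, dtype=bool)
    if not mask.any():
        raise ValueError("no answer tokens to score")
    return float(-lp[mask].mean())

# Hypothetical Phase B log-probs: 3 query tokens followed by 2 answer tokens.
logprobs = [-2.0, -1.5, -3.0, -0.5, -1.5]
is_answer = [False, False, False, True, True]
print(answer_nll(logprobs, is_answer))  # 1.0
```

Because Phase B is identical across conditions, any difference in this score comes from the cached document KV alone.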

In [6]:
# Show full condition details for 2 more samples
for idx in [1, 2]:
    ex = samples[idx]
    dkw = make_doc_keywords(ex['passage'])
    print("=" * 80)
    print(f"SAMPLE {idx}")
    print("=" * 80)
    print(f"  Query:        {ex['query']}")
    print(f"  Answer:       {ex['answer']}")
    print(f"  Document:     {ex['passage'][:100]}...")
    print(f"  Doc words:    {ex['word_count']}")
    print(f"  Doc keywords: {dkw}")
    print()

    conds = [
        ('bare',           None),
        ('oracle',         ex['query']),
        ('surr_universal',  SURROGATES['universal']),
        ('surr_extractor',  SURROGATES['extractor']),
        ('surr_reasonant',  SURROGATES['reasonant']),
        ('surr_analytic',   SURROGATES['analytic']),
        ('surr_doc_kw',     dkw),
        ('adversarial',     ADVERSARIAL_PREFIX),
    ]

    print(f"  {'#':<3} {'Condition':<20} {'Prefix words':>13} {'Phase A input (first 70 chars)'}")
    print(f"  {'-'*110}")
    for i, (name, prefix) in enumerate(conds, 1):
        if prefix is None:
            pw = '(none)'
            phase_a = ex['passage'][:70]
        else:
            pw = f"{len(prefix.split())}w"
            phase_a = (prefix + '\\n' + ex['passage'])[:70]
        print(f"  {i:<3} {name:<20} {pw:>13}   {phase_a}...")
    print()
    print(f"  Phase B: \"\\n{ex['query']}\\n{ex['answer']}\"")
    print()

SAMPLE 1
  Query:        average cost for an acre of land in arizona
  Answer:       $4,300 per acre.
  Document:     Arizona. With more than 72 million acres, Arizona was one of the largest states in the country. Howe...
  Doc words:    102
  Doc keywords: acre land average less arizona

  #   Condition             Prefix words Phase A input (first 70 chars)
  --------------------------------------------------------------------------------------------------------------
  1   bare                        (none)   Arizona. With more than 72 million acres, Arizona was one of the large...
  2   oracle                          9w   average cost for an acre of land in arizona\nArizona. With more than 7...
  3   surr_universal                 13w   Analyze the following text for all key entities, factual claims, and l...
  4   surr_extractor                 14w   Examine this document specifically for data points, dates, numerical v...
  5   surr_reasonant                 11w   Evaluate the u

In [7]:
print("=" * 80)
print("WHAT TO LOOK FOR IN RESULTS")
print("=" * 80)
print()
print("  1. oracle >> bare?")
print("     Does conditioning with the real query improve answer NLL?")
print("     Expected: yes (d~+0.4 based on v3 findings).")
print()
print("  2. surrogates > bare?")
print("     Do generic prompt surrogates help without knowing the query?")
print("     This is the practical question -- surrogates are free at offline time.")
print()
print("  3. adversarial vs bare?")
print("     The semantic sensitivity test:")
print("     - adversarial < bare: conditioning is content-sensitive (good)")
print("     - adversarial ~ surrogates: effect is purely structural (interesting)")
print("     - adversarial = bare: prefix is ignored entirely")
print()
print("  4. Recovery rate: (surr - bare) / (oracle - bare) x 100%")
print("     How much of the oracle ceiling does each surrogate capture?")
print("     v3 found surrogates often EXCEED oracle (>100% recovery)")
print("     because the real query can create semantic interference.")
print()
print("  5. Surrogate type ranking:")
print("     Do task-specific prompts (extractor, reasonant, analytic)")
print("     outperform generic ones (universal)? Or is the effect")
print("     mostly structural (v3 found 85% structural)?")
print()
print("  6. Hardness gradient:")
print("     Does conditioning help more for harder queries (higher bare NLL)?")
print("     v3 found huge gains for hard queries, slight degradation for easy.")

WHAT TO LOOK FOR IN RESULTS

  1. oracle >> bare?
     Does conditioning with the real query improve answer NLL?
     Expected: yes (d~+0.4 based on v3 findings).

  2. surrogates > bare?
     Do generic prompt surrogates help without knowing the query?
     This is the practical question -- surrogates are free at offline time.

  3. adversarial vs bare?
     The semantic sensitivity test:
     - adversarial < bare: conditioning is content-sensitive (good)
     - adversarial ~ surrogates: effect is purely structural (interesting)
     - adversarial = bare: prefix is ignored entirely

  4. Recovery rate: (surr - bare) / (oracle - bare) x 100%
     How much of the oracle ceiling does each surrogate capture?
     v3 found surrogates often EXCEED oracle (>100% recovery)
     because the real query can create semantic interference.

  5. Surrogate type ranking:
     Do task-specific prompts (extractor, reasonant, analytic)
     outperform generic ones (universal)? Or is the effect
     most
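
The recovery rate in item 4 and the effect size in item 1 can be computed from per-sample NLLs as sketched below. The NLL values are made up for illustration and are not experiment results. The effect size is written as a paired Cohen's d on per-sample differences, which is an assumption about how `d` is defined here; the function names are hypothetical.

```python
import numpy as np

def recovery_rate(bare, surr, oracle):
    """(surr - bare) / (oracle - bare) x 100%, computed on mean NLLs."""
    gap = np.mean(oracle) - np.mean(bare)
    if gap == 0:
        return float("nan")  # oracle gives no benefit; the ratio is undefined
    return float((np.mean(surr) - np.mean(bare)) / gap * 100.0)

def paired_cohens_d(bare, cond):
    """mean(bare - cond) / std(bare - cond); positive means cond has lower NLL."""
    diff = np.asarray(bare, dtype=float) - np.asarray(cond, dtype=float)
    return float(diff.mean() / diff.std(ddof=1))

# Made-up per-sample answer NLLs for four samples (lower is better).
bare = [2.0, 3.0, 4.0, 3.0]
oracle = [1.0, 2.0, 3.5, 2.5]
surr = [1.5, 2.5, 3.5, 3.0]

print(recovery_rate(bare, surr, oracle))  # 50.0
print(paired_cohens_d(bare, oracle) > 0)  # True
```

A recovery rate above 100% means the surrogate lowered NLL more than the real query did, and a negative rate means it did worse than the bare condition.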