# Experiment 01: Truncation Test — Condition Examples

This notebook shows the actual text for each experimental condition using real data from the dataset. No GPU needed.

In [None]:
import os, sys, json, re
import numpy as np
from pathlib import Path
from collections import Counter

sys.path.insert(0, "../..")
from lib.data import count_words

# Seed for the NumPy shuffle further down, so the kept subset is deterministic
# for a given candidate list.
SEED = 42

# ---- Load MS MARCO (same reconstruction as Exp 01/02/etc.) ----
from datasets import load_dataset
ds = load_dataset("microsoft/ms_marco", "v1.1", split="validation")

# Build the candidate pool: at most 1500 records, at most one per dataset item.
# An item contributes only if it has a usable answer AND a selected passage
# whose word count (per lib.data.count_words) is between 30 and 300 inclusive.
samples = []
for item in ds:
    if len(samples) >= 1500:
        break
    passages = item.get('passages', {})
    ptexts = passages.get('passage_text', [])
    is_sel = passages.get('is_selected', [])
    query = item.get('query', '')
    answers = item.get('answers', [])
    well_formed = item.get('wellFormedAnswers', [])
    # Answer preference: first well-formed answer (unless it is the placeholder
    # '[]' or empty), otherwise the first raw answer (unless it is the
    # 'No Answer Present.' marker). Items with neither are skipped.
    answer = None
    if well_formed and len(well_formed) > 0 and well_formed[0] not in ('[]', ''):
        answer = well_formed[0]
    elif answers and len(answers) > 0 and answers[0] != 'No Answer Present.':
        answer = answers[0]
    if not answer:
        continue
    # Take the FIRST passage that is flagged as selected and fits the length
    # window, then stop scanning this item's passages.
    for pt, sel in zip(ptexts, is_sel):
        wc = count_words(pt)
        if sel == 1 and 30 <= wc <= 300:
            samples.append({
                'passage': pt, 'query': query, 'answer': answer,
                'word_count': wc,
            })
            break

# Seeded in-place shuffle of the candidate pool, then keep the first 500.
# NOTE(review): this uses NumPy's global RNG state; changing the seeding or
# the shuffle call would change which samples are kept.
np.random.seed(SEED)
np.random.shuffle(samples)
samples = samples[:500]
# Drop the reference to the dataset object; the code below in this section
# only uses `samples`.
del ds

# Lower-case function words that extract_keywords() discards.
STOP_WORDS = {
    # articles and forms of "to be"
    'a', 'an', 'the', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
    # auxiliaries and modals
    'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could',
    'should', 'may', 'might', 'can', 'shall',
    # prepositions
    'to', 'of', 'in', 'for', 'on', 'with', 'at', 'by', 'from', 'as', 'into',
    'through', 'during', 'before', 'after', 'above', 'below', 'between',
    # conjunctions, negation and particles
    'and', 'but', 'or', 'not', 'no', 'if', 'then', 'than', 'so', 'up', 'out',
    'about',
    # interrogatives and demonstratives
    'what', 'which', 'who', 'whom', 'this', 'that', 'these', 'those',
    # pronouns and possessives
    'it', 'its', 'i', 'me', 'my', 'we', 'our', 'you', 'your', 'he',
    'him', 'his', 'she', 'her', 'they', 'them', 'their',
    # question adverbs and quantifiers
    'how', 'when', 'where', 'why', 'much', 'many', 'some', 'any', 'all',
    'each',
    # degree words and fillers
    'also', 'just', 'more', 'most', 'very', 'too', 'only',
}


def extract_keywords(text):
    """Return the content words of *text*, in their original order.

    The text is lower-cased, every character that is neither a word character
    nor whitespace is deleted (so "don't" becomes "dont"), and the result is
    split on whitespace. Tokens of one or two characters and tokens listed in
    STOP_WORDS are dropped. Duplicates are kept.
    """
    cleaned = re.sub(r'[^\w\s]', '', text.lower())
    kept = []
    for token in cleaned.split():
        if len(token) > 2 and token not in STOP_WORDS:
            kept.append(token)
    return kept

def make_surrogate_paraphrase(query):
    """Build a surrogate prefix from the query's keywords in reverse order.

    Returns the keywords from extract_keywords(query), reversed and joined
    with single spaces. If the query yields no keywords, the query itself is
    returned unchanged.
    """
    keywords = extract_keywords(query)
    if not keywords:
        return query
    return " ".join(reversed(keywords))

def make_surrogate_from_doc(passage):
    """Build a surrogate prefix from the passage's most frequent keywords.

    Returns up to five keywords (as produced by extract_keywords), most
    frequent first, joined with single spaces. If the passage yields no
    keywords, the literal string "information" is returned.
    """
    keywords = extract_keywords(passage)
    if keywords:
        top_five = Counter(keywords).most_common(5)
        return " ".join(word for word, _count in top_five)
    return "information"

def make_surrogate_template(passage):
    """Build a "What is <word>?" question from the passage's top keyword.

    <word> is the single most frequent keyword returned by
    extract_keywords(passage). If the passage yields no keywords, the generic
    question "What is this about?" is returned.
    """
    keywords = extract_keywords(passage)
    if not keywords:
        return "What is this about?"
    top_word, _count = Counter(keywords).most_common(1)[0]
    return f"What is {top_word}?"

# Verify against checkpoint
def verify_checkpoint(exp_name):
    """Check the reconstructed samples against a saved experiment checkpoint.

    Reads ``../../results/<exp_name>/checkpoint.json`` (relative to the current
    working directory) and compares the first 50 characters of the first
    result's ``query`` with the first 50 characters of ``samples[0]['query']``
    (the module-level ``samples`` list). The outcome is printed.

    Args:
        exp_name: Name of the results sub-directory, e.g. ``"exp01"``.

    Returns:
        True if the two query prefixes match, False if they differ, and None
        when nothing could be compared (no checkpoint file, or a checkpoint
        whose ``results`` entry is missing or empty).
    """
    ckpt_path = Path(f"../../results/{exp_name}/checkpoint.json")
    if not ckpt_path.exists():
        print(f"  No checkpoint found for {exp_name}")
        return None

    # JSON is UTF-8 by convention; do not depend on the platform's default
    # encoding, which may fail on non-ASCII query text.
    ckpt = json.loads(ckpt_path.read_text(encoding="utf-8"))
    results = ckpt.get('results', [])
    if not results:
        # Previously this case returned None without printing anything, which
        # was indistinguishable from the function never having run.
        print(f"  Checkpoint verification: SKIPPED ({exp_name}) -- checkpoint has no results")
        return None

    ckpt_query = results[0].get('query', '')[:50]
    sample_query = samples[0]['query'][:50]
    if ckpt_query == sample_query:
        print(f"  Checkpoint verification: MATCH ({exp_name})")
        return True

    print(f"  Checkpoint verification: MISMATCH ({exp_name})")
    print(f"    Checkpoint: {ckpt_query}")
    print(f"    Samples:    {sample_query}")
    return False

# Summary of what was loaded: the number of kept samples, the shuffle seed,
# and the first sample's query cut to at most 70 characters.
print(f"Loaded {len(samples)} MS MARCO samples (SEED={SEED})")
print(f"Sample 0 query: {samples[0]['query'][:70]}")


def show_sample(s, doc_key='passage', n=0):
    """Print a banner-framed summary of one sample.

    Args:
        s: Mapping with 'query' and 'answer' entries plus the document text
            stored under *doc_key*.
        doc_key: Key in *s* that holds the document text.
        n: Sample number shown in the banner.

    The document is previewed by its first 100 characters, and its length is
    reported as the number of whitespace-separated tokens. A blank line is
    printed at the end.
    """
    doc = s[doc_key]
    rule = "=" * 80
    print(
        rule,
        f"SAMPLE {n}",
        rule,
        f"  Query:    {s['query']}",
        f"  Answer:   {s['answer']}",
        f"  Document: {doc[:100]}...",
        f"  Doc words: {len(doc.split())}",
        "",
        sep="\n",
    )

def show_conditions(conditions, doc_text):
    """Print a table previewing the encoder input for each condition.

    Args:
        conditions: Iterable of (name, description, prefix_text) triples.
            prefix_text is None for a bare condition (document only).
        doc_text: The document text fed to the encoder.

    Each row shows the condition name, the prefix length in whitespace-
    separated words (or "(none)"), and the first 70 characters of the encoder
    input. With a prefix, the encoder input is the prefix, a newline, then the
    document. A non-empty description is printed on its own line under the
    row. A blank line is printed after the table.
    """
    print(f"{'Condition':<30} {'Prefix':<14} {'Encoder input (first 70 chars)'}")
    print("-" * 100)
    for name, desc, prefix_text in conditions:
        if prefix_text is not None:
            encoder_input = prefix_text + "\n" + doc_text
            prefix_label = str(len(prefix_text.split())) + 'w'
        else:
            encoder_input = doc_text
            prefix_label = '(none)'
        print(f"{name:<30} {prefix_label:<14} {encoder_input[:70]}...")
        if desc:
            print(f"  {'':>28} ^ {desc}")
    print()


# Compare the first reconstructed sample against the exp01 checkpoint, if any.
verify_checkpoint("exp01")

# Worked example: the first sample and the surrogate prefixes derived from it.
ex = samples[0]
surr_para = make_surrogate_paraphrase(ex['query'])
surr_doc_kw = make_surrogate_from_doc(ex['passage'])
doc_short = ex['passage'][:80]  # 80-character document preview

print(
    "=" * 80,
    "SAMPLE",
    "=" * 80,
    f"  Query:      {ex['query']}",
    f"  Answer:     {ex['answer']}",
    f"  Document:   {doc_short}...",
    f"  Doc words:  {len(ex['passage'].split())}",
    "",
    sep="\n",
)
# Static explainer: the two model stages and the two factors being varied.
print(
    "=" * 80,
    "HOW THIS EXPERIMENT WORKS",
    "=" * 80,
    "",
    "  The T5Gemma encoder-decoder has two stages:",
    "",
    "    1. ENCODER: reads text with bidirectional attention (sees everything)",
    "    2. DECODER: generates the answer, cross-attending to encoder output",
    "",
    "  We vary two things:",
    "    - What PREFIX (if any) is prepended to the document in the encoder",
    "    - Whether the decoder can see the prefix tokens (full) or only the",
    "      document tokens (trunc = truncated cross-attention)",
    "",
    "  The decoder always scores the same answer text via NLL.",
    "",
    sep="\n",
)

# Section banner for the per-condition walkthrough.
print(
    "=" * 80,
    "CONDITIONS (7 total)",
    "=" * 80,
    sep="\n",
)

# Condition 1: document only, no prefix.
print(
    "",
    "--- CONDITION 1: bare (BASELINE) ---",
    "",
    "  Encoder input:      [document]",
    "  Decoder attends to: [document]",
    "",
    "  No prefix. This is the control -- how well does the model predict",
    "  the answer from the document alone?",
    "",
    f"  Encoder sees: \"{doc_short}...\"",
    sep="\n",
)

# Condition 2: real query as prefix, decoder sees prefix and document.
print(
    "",
    "--- CONDITION 2: oracle_full ---",
    "",
    f"  Prefix:             \"{ex['query']}\"",
    f"  Encoder input:      [prefix] + [document]",
    f"  Decoder attends to: [prefix + document]  <-- decoder CAN read the query",
    "",
    "  Upper bound. But is the benefit because the decoder reads the query",
    "  directly, or because co-encoding improved the document representations?",
    sep="\n",
)

# Condition 3: real query as prefix, decoder sees the document only.
print(
    "",
    "--- CONDITION 3: oracle_trunc  *** THE KEY CONDITION ***",
    "",
    f"  Prefix:             \"{ex['query']}\"",
    f"  Encoder input:      [prefix] + [document]  (same as oracle_full)",
    f"  Decoder attends to: [document ONLY]  <-- prefix tokens MASKED",
    "",
    "  Same encoder input as oracle_full, but the decoder CANNOT see the",
    "  query tokens. If this still beats bare, the document representations",
    "  themselves are improved by co-encoding with the query.",
    sep="\n",
)

# Conditions 4 and 5: reversed query keywords as prefix (full, then truncated).
print(
    "",
    "--- CONDITION 4: surr_para_full ---",
    "",
    f"  Prefix:             \"{surr_para}\"  (query keywords reversed)",
    f"  Encoder input:      [prefix] + [document]",
    f"  Decoder attends to: [prefix + document]",
    sep="\n",
)
print(
    "",
    "--- CONDITION 5: surr_para_trunc ---",
    "",
    f"  Prefix:             \"{surr_para}\"  (query keywords reversed)",
    f"  Encoder input:      [prefix] + [document]",
    f"  Decoder attends to: [document ONLY]  <-- prefix MASKED",
    sep="\n",
)

# Conditions 6 and 7: top document keywords as prefix (full, then truncated).
print(
    "",
    "--- CONDITION 6: surr_doc_full ---",
    "",
    f"  Prefix:             \"{surr_doc_kw}\"  (top-5 TF keywords from document)",
    f"  Encoder input:      [prefix] + [document]",
    f"  Decoder attends to: [prefix + document]",
    sep="\n",
)
print(
    "",
    "--- CONDITION 7: surr_doc_trunc ---",
    "",
    f"  Prefix:             \"{surr_doc_kw}\"  (top-5 TF keywords from document)",
    f"  Encoder input:      [prefix] + [document]",
    f"  Decoder attends to: [document ONLY]  <-- prefix MASKED",
    sep="\n",
)

# Static reading guide for the results: four questions, preceded by a blank line.
print(
    "",
    "=" * 80,
    "WHAT TO LOOK FOR IN RESULTS",
    "=" * 80,
    "",
    "  1. oracle_full >> bare?",
    "     Expected yes (replicates v2 Exp 33b, d ~ +0.35).",
    "",
    "  2. oracle_trunc > bare?",
    "     If YES: document reps are genuinely improved by co-encoding.",
    "     This is the key finding -- the benefit isn't just the decoder",
    "     reading the query from the encoder output.",
    "",
    "  3. oracle_trunc / oracle_full retention %?",
    "     High (>50%): most benefit is from improved doc representations.",
    "     Low  (<20%): decoder was mostly just reading the query directly.",
    "",
    "  4. surr_*_trunc > bare?",
    "     Do surrogate prefixes also improve doc representations,",
    "     even when the decoder can't see them?",
    sep="\n",
)
