# Experiment 03d: Cross-Dataset Ablation — Condition Examples

This notebook shows the actual text for each experimental condition using real data from the dataset. No GPU needed.

In [None]:
import os, sys, json, re
import numpy as np
from pathlib import Path
from collections import Counter

# Put the current working directory first on the import path.
sys.path.insert(0, ".")

# Seed for the sample permutation and the number of samples to keep.
SEED = 42
N_SAMPLES = 500

# Load the train split of the neural-bridge RAG dataset via the `datasets` package.
from datasets import load_dataset
ds = load_dataset("neural-bridge/rag-dataset-12000", split="train")

# Collect every row that has a non-empty question, context and answer,
# a question of at least 15 whitespace-separated words, and an answer of
# at least 5 words. Word counts are stored alongside the text.
all_candidates = []
for row in ds:
    q = row.get("question", "")
    doc = row.get("context", "")
    answer = row.get("answer", "")
    if not q or not doc or not answer:
        continue
    q_words = len(q.split())
    a_words = len(answer.split())
    if q_words >= 15 and a_words >= 5:
        all_candidates.append({
            "query": q, "document": doc, "answer": answer,
            "query_words": q_words, "doc_words": len(doc.split()),
            "answer_words": a_words,
        })

# Seeded shuffle of the candidate indices; the first N_SAMPLES become the
# working set. NOTE(review): presumably this must mirror the sampling done in
# the experiment run so that verify_checkpoint() can match -- do not reorder.
np.random.seed(SEED)
indices = np.random.permutation(len(all_candidates))
samples = [all_candidates[i] for i in indices[:N_SAMPLES]]
# Drop the reference to the full dataset; only `samples` is used from here on.
del ds

def verify_checkpoint(exp_name, reference_samples=None):
    """Check that the rebuilt samples match a saved experiment checkpoint.

    Reads ``results/<exp_name>/checkpoint.json`` and compares the first 50
    characters of the first stored query with the first reference sample's
    query. The outcome is always printed so that a missing checkpoint and a
    mismatching one can be told apart.

    Args:
        exp_name: Experiment directory name under ``results/``.
        reference_samples: Optional list of sample dicts (each with a
            ``'query'`` key) to compare against. When omitted, the
            module-level ``samples`` list is used.

    Returns:
        True when the checkpoint exists and its first query matches;
        False otherwise (checkpoint missing, nothing to compare, or mismatch).
    """
    ckpt_path = Path(f"results/{exp_name}/checkpoint.json")
    if not ckpt_path.exists():
        print(f"  Checkpoint verification: SKIPPED ({exp_name}, no checkpoint file)")
        return False

    if reference_samples is None:
        reference_samples = samples

    # JSON is UTF-8 by specification; do not rely on the locale's encoding.
    ckpt = json.loads(ckpt_path.read_text(encoding="utf-8"))
    meta = ckpt.get('sample_meta', ckpt.get('results', []))

    # Only a non-empty list of dict entries can be compared against a sample.
    first = meta[0] if isinstance(meta, list) and meta else None
    if not isinstance(first, dict) or not reference_samples:
        print(f"  Checkpoint verification: SKIPPED ({exp_name}, nothing to compare)")
        return False

    if first.get('query', '')[:50] == reference_samples[0]['query'][:50]:
        print(f"  Checkpoint verification: MATCH ({exp_name})")
        return True

    print(f"  Checkpoint verification: MISMATCH ({exp_name})")
    return False

# Report how many samples were kept and preview the first query
# (its stored word count plus the first 70 characters of its text).
print(f"Loaded {len(samples)} neural-bridge samples (SEED={SEED})")
print(f"Sample 0 query ({samples[0]['query_words']}w): {samples[0]['query'][:70]}")


def show_sample(s, doc_key='passage', n=0):
    """Print a short, human-readable summary of one sample.

    Args:
        s: Sample mapping with 'query', 'answer' and ``doc_key`` entries.
        doc_key: Key under which the document text is stored in ``s``.
        n: Sample number shown in the banner.
    """
    text = s[doc_key]
    rule = '=' * 80
    report = [
        rule,
        f"SAMPLE {n}",
        rule,
        f"  Query:    {s['query']}",
        f"  Answer:   {s['answer']}",
        # Only the first 100 characters of the document are previewed.
        f"  Document: {text[:100]}...",
        f"  Doc words: {len(text.split())}",
        "",  # trailing blank line after the summary
    ]
    print("\n".join(report))

def show_conditions(conditions, doc_text):
    """Print a table previewing the encoder input of each condition.

    Args:
        conditions: Iterable of ``(name, description, prefix_text)`` tuples.
            ``prefix_text`` is None for bare conditions (no encoder prefix).
        doc_text: Document text that follows the prefix in the encoder input.
    """
    print(f"{'Condition':<30} {'Prefix':<14} {'Encoder input (first 70 chars)'}")
    print('-' * 100)
    for label, note, prefix in conditions:
        if prefix is not None:
            # Prefixed condition: prefix, a newline, then the document.
            size = f"{len(prefix.split())}w"
            preview = (prefix + "\n" + doc_text)[:70]
        else:
            size = '(none)'
            preview = doc_text[:70]
        print(f"{label:<30} {size:<14} {preview}...")
        if note:
            # Description goes on its own line, aligned under the row.
            print(f"  {'':>28} ^ {note}")
    print()

# Show the first sample (document text is stored under 'document' here) and
# compare it with the saved exp03d checkpoint.
show_sample(samples[0], doc_key='document')
verify_checkpoint("exp03d")

# Worked example: the prefixes below are all built from the first sample.
ex = samples[0]
query_words = ex['query'].split()
n_qw = len(query_words)

# Pick the sample half the list away from index 0 as the source of unrelated
# words. NOTE(review): the literal 0 looks like the current sample index in an
# `(i + N_SAMPLES // 2) % N_SAMPLES` pattern -- confirm against the experiment.
other_idx = (0 + N_SAMPLES // 2) % N_SAMPLES
other_doc = samples[other_idx]['document']
other_words = other_doc.split()

# Length-matched random prefix: the other document's first n_qw words.
rand_matched = " ".join(other_words[:n_qw])

# Scrambled prefix: the query's own words, shuffled with a dedicated
# RandomState seeded with SEED + 1 (independent of the global NumPy seed).
rng = np.random.RandomState(SEED + 1)
shuffled = list(query_words)
rng.shuffle(shuffled)
scrambled = " ".join(shuffled)

# Function words that never count as document keywords.
STOP_WORDS = set("""
    a an the is are was were be been being
    have has had do does did will would could
    should may might can shall to of in for
    on with at by from as into through during
    before after above below between and but or
    not no if then than so up out about
    what which who whom this that these those
    it its i me my we our you your he
    him his she her they them their how when
    where why much many some any all each
    also just more most very too only
""".split())
import re


def extract_keywords(text):
    """Return the content words of ``text`` in order of appearance.

    The text is lower-cased, every character that is neither a word
    character nor whitespace is removed, and the result is split on
    whitespace. Tokens of one or two characters and tokens listed in
    STOP_WORDS are dropped; duplicates are kept.
    """
    stripped = re.sub(r'[^\w\s]', '', text.lower())
    keywords = []
    for token in stripped.split():
        if len(token) <= 2 or token in STOP_WORDS:
            continue
        keywords.append(token)
    return keywords

# Most frequent keyword of the example document; it feeds the surrogate
# template and the repeated-keyword prefix.
doc_kws = extract_keywords(ex['document'])
counts = Counter(doc_kws)
if counts:
    top_kw = counts.most_common(1)[0][0]
else:
    # The document produced no keyword at all: use a generic placeholder.
    top_kw = "topic"
surr_template = f"What is {top_kw}?"

# One (name, description, encoder prefix) triple per condition.
# A prefix of None means the document is encoded without any prefix.
conditions = [
    (
        "bare",
        "Baseline",
        None,
    ),
    (
        "oracle_trunc",
        "Real query",
        ex['query'],
    ),
    (
        "random_matched_trunc",
        f"Random words, {n_qw}w matched",
        rand_matched,
    ),
    (
        "scrambled_oracle_trunc",
        "Query words shuffled",
        scrambled,
    ),
    (
        "surr_template_trunc",
        "'What is [kw]?'",
        surr_template,
    ),
    (
        "repeat_the_trunc",
        f"'the' x {n_qw}",
        " ".join(["the"] * n_qw),
    ),
    (
        "repeat_kw_trunc",
        f"Top doc kw x {n_qw}",
        " ".join([top_kw] * n_qw),
    ),
]
show_conditions(conditions, ex['document'])

# Closing note contrasting this example's sizes with the MS MARCO figures.
print(
    "CROSS-DATASET TEST: Same decomposition as Exp 2B but on neural-bridge.",
    f"  Queries are ~{n_qw} words (vs ~6w in MS MARCO)",
    f"  Documents are ~{ex['doc_words']} words (vs ~90w in MS MARCO)",
    sep="\n",
)
