# Experiment 20: Retrieval Ranking & Semantic Steering Survey

**Date:** 2026-02-04

## Purpose

Exp 19 showed priming benefits are MS MARCO-specific for QA. This experiment surveys
alternative framings where priming might show stronger effects:

### Part A: Retrieval Ranking
Instead of scoring P(answer|passage), score P(passage|query) for ranking.
This is closer to ad-serving: given a query, which ad/document is most relevant?

### Part B: Semantic Steering (Generation Diversity)
Measure whether priming reduces output entropy and steers generation toward
a target direction. Relevant for ad copy generation.

### Part C: Multi-Document Focus
Given multiple retrieved passages, does priming help the model focus on the
relevant one? Simulates noisy retrieval in production.

### Part D: Product Search / E-commerce
Test on Amazon product search data - closest analogy to ad-serving.

## Datasets

| Part | Dataset | Why |
|------|---------|-----|
| A | MS MARCO (reframed) | Baseline comparison |
| B | ELI5 | Long generative answers, high ambiguity |
| C | Natural Questions (open) | Multi-passage retrieval |
| D | Amazon ESCI | Product search relevance |

In [1]:
# Cell 1: Setup
import os
os.umask(0o000)

import sys
sys.path.insert(0, '/home/jupyter/research/directed_kvcache')

import json
import random
import numpy as np
import torch
from tqdm.auto import tqdm
from typing import Dict, List, Tuple, Any, Optional
from dataclasses import dataclass, field
from scipy import stats
from collections import Counter

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, DynamicCache
from datasets import load_dataset

from lib.kv_cache import (
    extract_and_truncate_cache_with_bos,
    correct_rope_positions_with_bos,
    score_answer_with_cache,
    deepcopy_cache,
)
from lib.config import ExperimentConfig

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

OUTPUT_DIR = '/home/jupyter/research/directed_kvcache/results/exp20'
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")

PyTorch: 2.10.0+cu128
CUDA: True


In [2]:
# Cell 2: Load Model
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()

config = ExperimentConfig(
    model_name=MODEL_NAME,
    device=model.device,
    seed=SEED
)

print(f"Model loaded on {model.device}")

`torch_dtype` is deprecated! Use `dtype` instead!


Loading weights:   0%|          | 0/291 [00:00<?, ?it/s]

Model loaded on cuda:0


In [3]:
# Cell 3: Core Cache Functions (from Exp 19, with additions)

def build_bare_cache(text: str) -> Tuple[DynamicCache, int]:
    """Build baseline cache from text only."""
    ids = tokenizer.encode(text, return_tensors='pt', add_special_tokens=True).to(model.device)
    with torch.no_grad():
        out = model(ids, use_cache=True)
    return out.past_key_values, ids.shape[1]

def build_primed_cache_truncated(prefix: str, text: str) -> Tuple[DynamicCache, int]:
    """Build truncated cache: prefix removed after forward pass."""
    prefix_with_sep = prefix + " "
    prefix_ids = tokenizer.encode(prefix_with_sep, return_tensors='pt', add_special_tokens=True)
    prefix_len = prefix_ids.shape[1]
    
    full_text = prefix_with_sep + text
    full_ids = tokenizer.encode(full_text, return_tensors='pt', add_special_tokens=True).to(model.device)
    full_len = full_ids.shape[1]
    doc_len = full_len - prefix_len
    
    with torch.no_grad():
        out = model(full_ids, use_cache=True)
    
    truncated_cache = extract_and_truncate_cache_with_bos(out.past_key_values, doc_len)
    surrogate_offset = prefix_len - 1
    correct_rope_positions_with_bos(truncated_cache, surrogate_offset, model)
    
    return truncated_cache, 1 + doc_len

def score_continuation_nll(cache: DynamicCache, cache_len: int, continuation: str) -> float:
    """
    Score negative log-likelihood of continuation text given cache.
    
    This scores P(continuation | cache) directly without any prompt.
    """
    # Tokenize continuation (no special tokens - we're continuing from cache)
    cont_ids = tokenizer.encode(continuation, return_tensors='pt', add_special_tokens=False).to(model.device)
    
    if cont_ids.shape[1] == 0:
        return 0.0
    
    # Get the full sequence length for attention mask
    total_len = cache_len + cont_ids.shape[1]
    attention_mask = torch.ones((1, total_len), device=model.device)
    
    # Forward pass with cache
    cache_copy = deepcopy_cache(cache)
    with torch.no_grad():
        outputs = model(
            input_ids=cont_ids,
            attention_mask=attention_mask,
            past_key_values=cache_copy,
            use_cache=True,
            return_dict=True
        )
    
    # Compute NLL over continuation tokens
    logits = outputs.logits  # [1, cont_len, vocab]
    
    # Shift for next-token prediction: logits[:-1] predicts tokens[1:]
    # But we also need the last token of cache to predict first continuation token
    # The model output already accounts for this with past_key_values
    
    # For tokens 0..n-1 in continuation, logits[i] predicts continuation[i+1]
    # We need to score continuation[0] using last cache position, which isn't in logits
    # So we score continuation[1:] using logits[:-1]
    
    if cont_ids.shape[1] == 1:
        # Only one token - can't compute NLL this way, use full forward
        return 0.0
    
    shift_logits = logits[:, :-1, :].contiguous()  # [1, cont_len-1, vocab]
    shift_labels = cont_ids[:, 1:].contiguous()     # [1, cont_len-1]
    
    # Compute cross-entropy loss
    loss_fct = torch.nn.CrossEntropyLoss(reduction='mean')
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    
    return loss.item()

def score_with_prompt(cache: DynamicCache, cache_len: int, prompt: str, completion: str) -> float:
    """Score P(completion | cache, prompt)."""
    return score_answer_with_cache(
        cache, cache_len,
        prompt,
        completion,
        model, tokenizer, config
    )

print("Cache functions defined.")

Cache functions defined.


---
# PART A: Retrieval Ranking

**Task:** Given a query and N candidate documents, rank by relevance.

**Scoring:** Instead of P(answer|doc, query), we score P(query|doc) — how well does the document "predict" or align with the query?

**Hypothesis:** Priming documents with relevant queries should increase P(query|doc) for matching pairs, improving ranking.

**Conditions:**
- `bare`: Score P(query|doc) with unprimed document cache
- `oracle_primed`: Prime doc with the target query, score P(query|doc)
- `random_primed`: Prime doc with random query, score P(query|doc)
- `topic_primed`: Prime doc with extracted topic/keywords

In [4]:
# Cell 5: Load MS MARCO for Retrieval Ranking

print("Loading MS MARCO for retrieval ranking...")
msmarco = load_dataset("ms_marco", "v1.1", split="train")

# Build retrieval samples: 1 relevant doc + 9 distractors per query
retrieval_samples = []
all_passages = []

# Collect passages
for item in msmarco:
    if item['passages']['passage_text']:
        for p in item['passages']['passage_text']:
            if p and len(p.split()) > 20:
                all_passages.append(p)
    if len(all_passages) > 10000:
        break

random.shuffle(all_passages)
distractor_pool = all_passages[:5000]

# Build samples
N_RETRIEVAL_SAMPLES = 200
N_CANDIDATES = 10  # 1 relevant + 9 distractors

for item in msmarco:
    if len(retrieval_samples) >= N_RETRIEVAL_SAMPLES:
        break
    if not item['passages']['passage_text'] or not item['query']:
        continue
    
    relevant_passage = item['passages']['passage_text'][0]
    if not relevant_passage or len(relevant_passage.split()) < 20:
        continue
    
    # Sample distractors
    distractors = random.sample(distractor_pool, N_CANDIDATES - 1)
    
    # Build candidate list with relevance labels
    candidates = [(relevant_passage, 1)]  # (passage, is_relevant)
    for d in distractors:
        candidates.append((d, 0))
    
    random.shuffle(candidates)
    
    retrieval_samples.append({
        'query': item['query'],
        'candidates': candidates,
        'relevant_idx': [i for i, (_, rel) in enumerate(candidates) if rel == 1][0]
    })

print(f"Built {len(retrieval_samples)} retrieval samples with {N_CANDIDATES} candidates each")

Loading MS MARCO for retrieval ranking...
Built 200 retrieval samples with 10 candidates each


In [5]:
# Cell 6: Retrieval Ranking Evaluation

def compute_mrr(rankings: List[int]) -> float:
    """Compute Mean Reciprocal Rank."""
    rrs = [1.0 / (r + 1) for r in rankings]
    return np.mean(rrs)

def compute_hit_at_k(rankings: List[int], k: int) -> float:
    """Compute Hit@k."""
    return np.mean([1 if r < k else 0 for r in rankings])

def evaluate_retrieval_ranking(samples: List[dict], condition: str) -> dict:
    """
    Evaluate retrieval ranking under a condition.
    
    For each sample:
    - Build cache for each candidate document (with/without priming)
    - Score P(query | doc_cache) for each candidate
    - Rank by ascending NLL (lower = more likely = more relevant)
    - Record rank of the relevant document
    """
    rankings = []
    
    for sample in tqdm(samples, desc=f"Retrieval ({condition})"):
        query = sample['query']
        candidates = sample['candidates']
        relevant_idx = sample['relevant_idx']
        
        scores = []  # (idx, nll)
        
        for idx, (passage, _) in enumerate(candidates):
            passage_truncated = passage[:2000]  # Limit length
            
            if condition == 'bare':
                cache, cache_len = build_bare_cache(passage_truncated)
            elif condition == 'oracle_primed':
                # Prime with the query we're about to score
                prefix = " ".join([query] * 3)
                cache, cache_len = build_primed_cache_truncated(prefix, passage_truncated)
            elif condition == 'random_primed':
                # Prime with random query from pool
                random_query = random.choice([s['query'] for s in samples if s['query'] != query])
                prefix = " ".join([random_query] * 3)
                cache, cache_len = build_primed_cache_truncated(prefix, passage_truncated)
            else:
                raise ValueError(f"Unknown condition: {condition}")
            
            # Score P(query | doc) using a prompt format
            # We use score_with_prompt with a transition prompt
            nll = score_with_prompt(
                deepcopy_cache(cache), cache_len,
                "\n\nThis document is relevant to the query:",
                " " + query
            )
            scores.append((idx, nll))
        
        # Rank by ascending NLL
        scores.sort(key=lambda x: x[1])
        ranked_indices = [idx for idx, _ in scores]
        
        # Find rank of relevant document (0-indexed)
        rank_of_relevant = ranked_indices.index(relevant_idx)
        rankings.append(rank_of_relevant)
    
    return {
        'condition': condition,
        'mrr': compute_mrr(rankings),
        'hit_at_1': compute_hit_at_k(rankings, 1),
        'hit_at_3': compute_hit_at_k(rankings, 3),
        'mean_rank': np.mean(rankings),
        'rankings': rankings
    }

print("Retrieval evaluation function defined.")

Retrieval evaluation function defined.


In [6]:
# Cell 7: Run Retrieval Ranking Evaluation

print("="*70)
print("PART A: RETRIEVAL RANKING")
print("="*70)
print("\nScoring P(query | document) to rank candidates.")
print("Hypothesis: Priming documents with query should improve ranking.\n")

retrieval_results = {}

for condition in ['bare', 'oracle_primed', 'random_primed']:
    result = evaluate_retrieval_ranking(retrieval_samples[:200], condition)
    retrieval_results[condition] = result
    print(f"\n{condition}:")
    print(f"  MRR: {result['mrr']:.3f}")
    print(f"  Hit@1: {result['hit_at_1']*100:.1f}%")
    print(f"  Hit@3: {result['hit_at_3']*100:.1f}%")
    print(f"  Mean Rank: {result['mean_rank']:.2f}")

PART A: RETRIEVAL RANKING

Scoring P(query | document) to rank candidates.
Hypothesis: Priming documents with query should improve ranking.



Retrieval (bare):   0%|          | 0/200 [00:00<?, ?it/s]


bare:
  MRR: 0.985
  Hit@1: 97.5%
  Hit@3: 99.5%
  Mean Rank: 0.04


Retrieval (oracle_primed):   0%|          | 0/200 [00:00<?, ?it/s]


oracle_primed:
  MRR: 0.962
  Hit@1: 93.5%
  Hit@3: 99.0%
  Mean Rank: 0.11


Retrieval (random_primed):   0%|          | 0/200 [00:00<?, ?it/s]


random_primed:
  MRR: 0.974
  Hit@1: 95.5%
  Hit@3: 99.5%
  Mean Rank: 0.10


In [7]:
# Cell 8: Retrieval Ranking Analysis

print("\n" + "="*70)
print("RETRIEVAL RANKING ANALYSIS")
print("="*70)

bare_rankings = retrieval_results['bare']['rankings']
oracle_rankings = retrieval_results['oracle_primed']['rankings']
random_rankings = retrieval_results['random_primed']['rankings']

# Paired comparisons
oracle_better = sum(1 for b, o in zip(bare_rankings, oracle_rankings) if o < b)
oracle_worse = sum(1 for b, o in zip(bare_rankings, oracle_rankings) if o > b)
oracle_same = sum(1 for b, o in zip(bare_rankings, oracle_rankings) if o == b)

print(f"\nOracle vs Bare:")
print(f"  Oracle better: {oracle_better} ({oracle_better/len(bare_rankings)*100:.1f}%)")
print(f"  Oracle worse: {oracle_worse} ({oracle_worse/len(bare_rankings)*100:.1f}%)")
print(f"  Same: {oracle_same}")

# Statistical test (Wilcoxon signed-rank)
stat, p_value = stats.wilcoxon(bare_rankings, oracle_rankings, alternative='greater')
print(f"  Wilcoxon p-value (bare > oracle): {p_value:.4f}")

random_better = sum(1 for b, r in zip(bare_rankings, random_rankings) if r < b)
random_worse = sum(1 for b, r in zip(bare_rankings, random_rankings) if r > b)

print(f"\nRandom vs Bare:")
print(f"  Random better: {random_better} ({random_better/len(bare_rankings)*100:.1f}%)")
print(f"  Random worse: {random_worse} ({random_worse/len(bare_rankings)*100:.1f}%)")

# Oracle vs Random
oracle_beats_random = sum(1 for o, r in zip(oracle_rankings, random_rankings) if o < r)
print(f"\nOracle vs Random:")
print(f"  Oracle better: {oracle_beats_random} ({oracle_beats_random/len(bare_rankings)*100:.1f}%)")


RETRIEVAL RANKING ANALYSIS

Oracle vs Bare:
  Oracle better: 2 (1.0%)
  Oracle worse: 13 (6.5%)
  Same: 185
  Wilcoxon p-value (bare > oracle): 0.9899

Random vs Bare:
  Random better: 1 (0.5%)
  Random worse: 5 (2.5%)

Oracle vs Random:
  Oracle better: 5 (2.5%)


---
# PART B: Semantic Steering (Generation Diversity)

**Task:** Generate explanations for questions. Measure if priming reduces diversity and steers toward target.

**Dataset:** ELI5 (Explain Like I'm 5) - long generative answers with high ambiguity.

**Hypothesis:** Priming with topic/intent keywords should:
1. Reduce generation entropy (more focused)
2. Increase overlap with reference answer
3. Produce more consistent outputs across samples

**Conditions:**
- `bare`: Generate from question alone
- `topic_primed`: Prime with extracted keywords from reference answer
- `random_primed`: Prime with random keywords

In [8]:
# Cell 10: Load ELI5 (or fallback to similar Q&A dataset)

print("Loading Q&A dataset for semantic steering...")

# Try multiple options for long-form Q&A
eli5_samples = []

# Option 1: Try Yahoo Answers (similar long-form Q&A)
try:
    print("Trying Yahoo Answers Topics...")
    yahoo = load_dataset("yahoo_answers_topics", split="train")
    for item in yahoo:
        if len(eli5_samples) >= 250:
            break
        question = item.get('question_title', '') + " " + item.get('question_content', '')
        answer = item.get('best_answer', '')
        if question.strip() and answer and len(answer.split()) > 20:
            eli5_samples.append({
                'question': question.strip()[:500],
                'reference_answer': answer[:1000],
            })
    print(f"Loaded {len(eli5_samples)} samples from Yahoo Answers")
except Exception as e:
    print(f"Yahoo Answers failed: {e}")

# Option 2: Try Natural Questions (open-domain)
if len(eli5_samples) < 250:
    try:
        print("Trying Natural Questions...")
        nq = load_dataset("google-research-datasets/natural_questions", "default", split="train", streaming=True)
        for item in nq:
            if len(eli5_samples) >= 250:
                break
            question = item.get('question', {}).get('text', '')
            # NQ has short answers - use the document text as context
            annotations = item.get('annotations', [])
            if annotations and annotations[0].get('long_answer', {}).get('candidate_index', -1) >= 0:
                doc = item.get('document', {}).get('tokens', {}).get('token', [])
                if doc:
                    answer = ' '.join(doc[:200])  # First 200 tokens as "answer"
                    if question and len(answer.split()) > 20:
                        eli5_samples.append({
                            'question': question,
                            'reference_answer': answer[:1000],
                        })
        print(f"Loaded {len(eli5_samples)} samples")
    except Exception as e:
        print(f"Natural Questions failed: {e}")

# Option 3: Use MS MARCO QA pairs (already loaded) as fallback
if len(eli5_samples) < 250:
    print("Using MS MARCO for semantic steering (already loaded)...")
    for item in msmarco:
        if len(eli5_samples) >= 250:
            break
        query = item.get('query', '')
        answers = item.get('answers', [])
        passages = item.get('passages', {}).get('passage_text', [])
        if query and passages and passages[0] and len(passages[0].split()) > 20:
            # Use passage as "reference answer" for steering test
            eli5_samples.append({
                'question': query,
                'reference_answer': passages[0][:1000],
            })
    print(f"Using {len(eli5_samples)} MS MARCO samples for steering evaluation")

print(f"\nTotal samples: {len(eli5_samples)}")
if eli5_samples:
    print(f"\nExample:")
    print(f"Q: {eli5_samples[0]['question'][:100]}...")
    print(f"A: {eli5_samples[0]['reference_answer'][:200]}...")

Loading Q&A dataset for semantic steering...
Trying Yahoo Answers Topics...




Loaded 250 samples from Yahoo Answers

Total samples: 250

Example:
Q: why doesn't an optical mouse work on a glass table? or even on some surfaces?...
A: Optical mice use an LED and a camera to rapidly capture images of the surface beneath the mouse.  The infomation from the camera is analyzed by a DSP (Digital Signal Processor) and used to detect impe...


In [9]:
# Cell 11: Extract Keywords for Topic Priming

import re
from collections import Counter

# Simple keyword extraction (top TF words, excluding stopwords)
STOPWORDS = set(['the', 'a', 'an', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
                 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could',
                 'should', 'may', 'might', 'must', 'shall', 'can', 'need', 'dare',
                 'ought', 'used', 'to', 'of', 'in', 'for', 'on', 'with', 'at', 'by',
                 'from', 'as', 'into', 'through', 'during', 'before', 'after',
                 'above', 'below', 'between', 'under', 'again', 'further', 'then',
                 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all',
                 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no',
                 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very',
                 's', 't', 'just', 'don', 'now', 'and', 'but', 'or', 'because',
                 'if', 'while', 'although', 'this', 'that', 'these', 'those',
                 'it', 'its', 'they', 'them', 'their', 'what', 'which', 'who',
                 'whom', 'i', 'you', 'he', 'she', 'we', 'your', 'his', 'her', 'my'])

def extract_keywords(text: str, n: int = 5) -> List[str]:
    """Extract top-n keywords from text."""
    words = re.findall(r'\b[a-z]+\b', text.lower())
    words = [w for w in words if w not in STOPWORDS and len(w) > 2]
    counts = Counter(words)
    return [w for w, _ in counts.most_common(n)]

# Add keywords to samples
for sample in eli5_samples:
    sample['topic_keywords'] = extract_keywords(sample['reference_answer'], n=5)

print("Example keywords:")
for i in range(3):
    print(f"  Q: {eli5_samples[i]['question'][:50]}...")
    print(f"  Keywords: {eli5_samples[i]['topic_keywords']}")

Example keywords:
  Q: why doesn't an optical mouse work on a glass table...
  Keywords: ['surface', 'mouse', 'dsp', 'motion', 'camera']
  Q: What is Trans Fat? How to reduce that? I heard tha...
  Keywords: ['trans', 'fat', 'fats', 'foods', 'oil']
  Q: How many planes Fedex has? I heard that it is the ...
  Keywords: ['boeing', 'airbus', 'atr', 'cessna', 'according']


In [10]:
# Cell 12: Semantic Steering Evaluation (NLL-based, no generation)

def compute_keyword_overlap(text: str, keywords: List[str]) -> float:
    """Compute fraction of keywords present in text."""
    text_lower = text.lower()
    return sum(1 for k in keywords if k in text_lower) / len(keywords) if keywords else 0

def evaluate_steering(samples: List[dict], condition: str, n_samples: int = 200) -> dict:
    """
    Evaluate semantic steering via NLL scoring.
    
    For each sample:
    - Build context cache (with/without priming)
    - Score P(reference_answer | context) via NLL
    - Lower NLL = priming steers model toward reference content
    
    Note: Generation removed due to transformers compatibility issues with custom caches.
    NLL is the primary metric for measuring steering effectiveness.
    """
    results = []
    
    for sample in tqdm(samples[:n_samples], desc=f"Steering ({condition})"):
        question = sample['question']
        keywords = sample['topic_keywords']
        reference = sample['reference_answer']
        
        # Build context
        context = f"Question: {question}\n\nExplain simply:"
        
        if condition == 'bare':
            cache, cache_len = build_bare_cache(context)
        elif condition == 'topic_primed':
            # Prime with topic keywords (oracle - knows the answer keywords)
            prefix = "Key topics: " + ", ".join(keywords)
            cache, cache_len = build_primed_cache_truncated(prefix, context)
        elif condition == 'random_primed':
            # Prime with random keywords from other samples
            other_keywords = random.choice([s['topic_keywords'] for s in samples if s != sample])
            prefix = "Key topics: " + ", ".join(other_keywords)
            cache, cache_len = build_primed_cache_truncated(prefix, context)
        else:
            raise ValueError(f"Unknown condition: {condition}")
        
        # Score reference answer NLL (primary metric)
        # Lower NLL = model assigns higher probability to reference content
        ref_nll = score_continuation_nll(deepcopy_cache(cache), cache_len, " " + reference[:200])
        
        # Also compute keyword overlap in reference (sanity check)
        ref_keyword_overlap = compute_keyword_overlap(reference, keywords)
        
        results.append({
            'ref_nll': ref_nll,
            'ref_keyword_overlap': ref_keyword_overlap,
        })
    
    return {
        'condition': condition,
        'mean_ref_nll': np.mean([r['ref_nll'] for r in results]),
        'std_ref_nll': np.std([r['ref_nll'] for r in results]),
        'mean_keyword_overlap': np.mean([r['ref_keyword_overlap'] for r in results]),
        'results': results
    }

print("Steering evaluation function defined (NLL-based).")

Steering evaluation function defined (NLL-based).


In [11]:
# Cell 13: Run Semantic Steering Evaluation

print("\n" + "="*70)
print("PART B: SEMANTIC STEERING")
print("="*70)
print("\nMeasuring if topic priming steers model toward reference content.")
print("Metric: NLL of reference answer (lower = better steering)")
print("Oracle condition: primed with keywords extracted from the reference answer.\n")

steering_results = {}

for condition in ['bare', 'topic_primed', 'random_primed']:
    result = evaluate_steering(eli5_samples, condition, n_samples=200)
    steering_results[condition] = result
    print(f"\n{condition}:")
    print(f"  Mean Reference NLL: {result['mean_ref_nll']:.3f} (+/- {result['std_ref_nll']:.3f})")


PART B: SEMANTIC STEERING

Measuring if topic priming steers model toward reference content.
Metric: NLL of reference answer (lower = better steering)
Oracle condition: primed with keywords extracted from the reference answer.



Steering (bare):   0%|          | 0/200 [00:00<?, ?it/s]


bare:
  Mean Reference NLL: 3.077 (+/- 0.792)


Steering (topic_primed):   0%|          | 0/200 [00:00<?, ?it/s]


topic_primed:
  Mean Reference NLL: 3.484 (+/- 0.831)


Steering (random_primed):   0%|          | 0/200 [00:00<?, ?it/s]


random_primed:
  Mean Reference NLL: 3.457 (+/- 0.844)


In [12]:
# Cell 14: Steering Analysis

print("\n" + "="*70)
print("SEMANTIC STEERING ANALYSIS")
print("="*70)

bare_nlls = [r['ref_nll'] for r in steering_results['bare']['results']]
topic_nlls = [r['ref_nll'] for r in steering_results['topic_primed']['results']]
random_nlls = [r['ref_nll'] for r in steering_results['random_primed']['results']]

# Statistical tests
t_nll, p_nll = stats.ttest_rel(bare_nlls, topic_nlls)

print(f"\nTopic Primed vs Bare:")
print(f"  Reference NLL: {np.mean(topic_nlls):.3f} vs {np.mean(bare_nlls):.3f}")
delta_nll = np.mean(bare_nlls) - np.mean(topic_nlls)
print(f"    Delta: {delta_nll:+.3f} (positive = priming helps)")
print(f"    p-value: {p_nll:.4f}")

# Win rate
win_rate = np.mean(np.array(bare_nlls) > np.array(topic_nlls))
print(f"    Win Rate: {win_rate*100:.1f}%")

# Effect size (Cohen's d)
d_nll = delta_nll / np.std(np.array(bare_nlls) - np.array(topic_nlls)) if np.std(np.array(bare_nlls) - np.array(topic_nlls)) > 0 else 0
print(f"    Cohen's d: {d_nll:+.3f}")

# Random vs Bare comparison
t_random, p_random = stats.ttest_rel(bare_nlls, random_nlls)
delta_random = np.mean(bare_nlls) - np.mean(random_nlls)
print(f"\nRandom Primed vs Bare:")
print(f"  Reference NLL: {np.mean(random_nlls):.3f} vs {np.mean(bare_nlls):.3f}")
print(f"    Delta: {delta_random:+.3f}")
print(f"    p-value: {p_random:.4f}")

# Topic vs Random (semantic signal test)
t_topic_random, p_topic_random = stats.ttest_rel(topic_nlls, random_nlls)
delta_topic_random = np.mean(random_nlls) - np.mean(topic_nlls)
print(f"\nTopic vs Random (semantic signal):")
print(f"  Topic NLL: {np.mean(topic_nlls):.3f} vs Random NLL: {np.mean(random_nlls):.3f}")
print(f"    Delta: {delta_topic_random:+.3f} (positive = topic better)")
print(f"    p-value: {p_topic_random:.4f}")


SEMANTIC STEERING ANALYSIS

Topic Primed vs Bare:
  Reference NLL: 3.484 vs 3.077
    Delta: -0.407 (positive = priming helps)
    p-value: 0.0000
    Win Rate: 6.0%
    Cohen's d: -1.376

Random Primed vs Bare:
  Reference NLL: 3.457 vs 3.077
    Delta: -0.380
    p-value: 0.0000

Topic vs Random (semantic signal):
  Topic NLL: 3.484 vs Random NLL: 3.457
    Delta: -0.027 (positive = topic better)
    p-value: 0.0276


---
# PART C: Multi-Document Focus

**Task:** Given a question and multiple retrieved passages (some relevant, some not),
can priming help the model focus on the relevant passage?

**Setup:** Concatenate 3 passages (1 relevant + 2 distractors), ask model to answer.
Score P(answer | multi-passage context).

**Hypothesis:** Priming the relevant passage should help the model attend to it
despite distractor noise.

**Conditions:**
- `bare`: No priming on any passage
- `relevant_primed`: Prime only the relevant passage with the query
- `all_primed`: Prime all passages with the query

In [13]:
# Cell 16: Build Multi-Document Samples

print("Building multi-document samples from MS MARCO...")

multidoc_samples = []

for item in msmarco:
    if len(multidoc_samples) >= 250:  # Extra buffer
        break
    
    if not item['answers'] or not item['answers'][0]:
        continue
    if not item['passages']['passage_text'] or not item['passages']['passage_text'][0]:
        continue
    
    relevant_passage = item['passages']['passage_text'][0][:500]
    query = item['query']
    answer = item['answers'][0]
    
    if len(relevant_passage.split()) < 20 or len(answer.split()) < 3:
        continue
    
    # Sample 2 distractor passages
    distractors = random.sample(distractor_pool, 2)
    distractors = [d[:500] for d in distractors]
    
    # Random position for relevant passage
    position = random.randint(0, 2)
    passages = distractors[:position] + [relevant_passage] + distractors[position:]
    passages = passages[:3]  # Ensure exactly 3
    
    multidoc_samples.append({
        'query': query,
        'answer': answer,
        'passages': passages,
        'relevant_position': position,
    })

print(f"Built {len(multidoc_samples)} multi-document samples")

Building multi-document samples from MS MARCO...
Built 250 multi-document samples


In [14]:
# Cell 17: Multi-Document Evaluation

def evaluate_multidoc(samples: List[dict], condition: str, n_samples: int = 200) -> dict:
    """
    Evaluate multi-document focus.
    
    Build a combined context from 3 passages, score P(answer | context, query).
    """
    results = []
    
    for sample in tqdm(samples[:n_samples], desc=f"MultiDoc ({condition})"):
        query = sample['query']
        answer = sample['answer']
        passages = sample['passages']
        relevant_pos = sample['relevant_position']
        
        # Build combined context
        if condition == 'bare':
            # No priming
            context_parts = [f"Passage {i+1}: {p}" for i, p in enumerate(passages)]
        elif condition == 'relevant_primed':
            # Prime only the relevant passage
            context_parts = []
            for i, p in enumerate(passages):
                if i == relevant_pos:
                    # This passage gets primed (we simulate by adding query context)
                    context_parts.append(f"Passage {i+1} [RELEVANT]: {p}")
                else:
                    context_parts.append(f"Passage {i+1}: {p}")
        elif condition == 'query_hint':
            # Add query hint before relevant passage
            context_parts = []
            for i, p in enumerate(passages):
                if i == relevant_pos:
                    context_parts.append(f"Passage {i+1} (answers: {query[:50]}): {p}")
                else:
                    context_parts.append(f"Passage {i+1}: {p}")
        else:
            raise ValueError(f"Unknown condition: {condition}")
        
        context = "\n\n".join(context_parts)
        
        # Build cache and score
        cache, cache_len = build_bare_cache(context)
        prompt = f"\n\nQuery: {query}\nAnswer:"
        nll = score_with_prompt(deepcopy_cache(cache), cache_len, prompt, " " + answer)
        
        results.append({
            'nll': nll,
            'relevant_position': relevant_pos,
        })
    
    return {
        'condition': condition,
        'mean_nll': np.mean([r['nll'] for r in results]),
        'std_nll': np.std([r['nll'] for r in results]),
        'results': results
    }

print("Multi-doc evaluation function defined.")

Multi-doc evaluation function defined.


In [15]:
# Cell 18: Run Multi-Document Evaluation

print("\n" + "="*70)
print("PART C: MULTI-DOCUMENT FOCUS")
print("="*70)
print("\n3 passages per sample (1 relevant + 2 distractors).")
print("Testing if hints/priming helps model focus on relevant passage.\n")

multidoc_results = {}

for condition in ['bare', 'relevant_primed', 'query_hint']:
    result = evaluate_multidoc(multidoc_samples, condition, n_samples=200)
    multidoc_results[condition] = result
    print(f"\n{condition}:")
    print(f"  Mean NLL: {result['mean_nll']:.3f} (+/- {result['std_nll']:.3f})")


PART C: MULTI-DOCUMENT FOCUS

3 passages per sample (1 relevant + 2 distractors).
Testing if hints/priming helps model focus on relevant passage.



MultiDoc (bare):   0%|          | 0/200 [00:00<?, ?it/s]


bare:
  Mean NLL: 2.801 (+/- 1.734)


MultiDoc (relevant_primed):   0%|          | 0/200 [00:00<?, ?it/s]


relevant_primed:
  Mean NLL: 2.739 (+/- 1.688)


MultiDoc (query_hint):   0%|          | 0/200 [00:00<?, ?it/s]


query_hint:
  Mean NLL: 2.726 (+/- 1.738)


In [16]:
# Cell 19: Multi-Doc Analysis

print("\n" + "="*70)
print("MULTI-DOCUMENT ANALYSIS")
print("="*70)

bare_nlls = [r['nll'] for r in multidoc_results['bare']['results']]
hint_nlls = [r['nll'] for r in multidoc_results['query_hint']['results']]

delta = np.array(bare_nlls) - np.array(hint_nlls)
win_rate = np.mean(delta > 0)
d = np.mean(delta) / np.std(delta) if np.std(delta) > 0 else 0

t_stat, p_value = stats.ttest_rel(bare_nlls, hint_nlls)

print(f"\nQuery Hint vs Bare:")
print(f"  Mean NLL: {np.mean(hint_nlls):.3f} vs {np.mean(bare_nlls):.3f}")
print(f"  Delta: {np.mean(delta):+.3f} (positive = hint helps)")
print(f"  Win Rate: {win_rate*100:.1f}%")
print(f"  Cohen's d: {d:+.3f}")
print(f"  p-value: {p_value:.4f}")


MULTI-DOCUMENT ANALYSIS

Query Hint vs Bare:
  Mean NLL: 2.726 vs 2.801
  Delta: +0.074 (positive = hint helps)
  Win Rate: 66.0%
  Cohen's d: +0.195
  p-value: 0.0064


---
# PART D: Product Search (Amazon ESCI)

**Task:** Given a search query and product descriptions, rank products by relevance.

**Dataset:** Amazon ESCI (Shopping Queries Dataset) - product search relevance.

**Hypothesis:** Priming product descriptions with query intent should improve ranking.
This is the closest analogy to ad-serving.

In [17]:
# Cell 21: Load Amazon ESCI

print("Loading Amazon ESCI dataset...")

try:
    esci = load_dataset("tasksource/esci", split="train")
    
    # Build product search samples
    esci_samples = []
    query_products = {}  # Group products by query
    
    for item in esci:
        query = item['query']
        product = item['product_title']
        label = item['esci_label']  # E (exact), S (substitute), C (complement), I (irrelevant)
        
        if query not in query_products:
            query_products[query] = []
        query_products[query].append((product, label))
        
        if len(query_products) > 1000:  # More queries for 200 samples
            break
    
    # Build ranking samples (queries with both relevant and irrelevant products)
    for query, products in query_products.items():
        relevant = [p for p, l in products if l in ['E', 'S']]
        irrelevant = [p for p, l in products if l == 'I']
        
        if relevant and irrelevant and len(irrelevant) >= 4:
            # 1 relevant + 4 irrelevant
            candidates = [(relevant[0], 1)] + [(p, 0) for p in irrelevant[:4]]
            random.shuffle(candidates)
            
            esci_samples.append({
                'query': query,
                'candidates': candidates,
                'relevant_idx': [i for i, (_, rel) in enumerate(candidates) if rel == 1][0]
            })
        
        if len(esci_samples) >= 250:  # Buffer for 200
            break
    
    print(f"Built {len(esci_samples)} product search samples")
    print(f"\nExample:")
    print(f"  Query: {esci_samples[0]['query']}")
    print(f"  Relevant: {esci_samples[0]['candidates'][esci_samples[0]['relevant_idx']][0][:80]}...")
    
except Exception as e:
    print(f"Could not load ESCI: {e}")
    print("Falling back to MS MARCO product-like queries...")
    esci_samples = []

Loading Amazon ESCI dataset...


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00011-2d36455632bef8(…):   0%|          | 0.00/115M [00:00<?, ?B/s]

data/train-00001-of-00011-18b81793a48399(…):   0%|          | 0.00/120M [00:00<?, ?B/s]

data/train-00002-of-00011-71f741fdff9a6f(…):   0%|          | 0.00/144M [00:00<?, ?B/s]

data/train-00003-of-00011-986bc53b83688d(…):   0%|          | 0.00/155M [00:00<?, ?B/s]

data/train-00004-of-00011-207d8e840a42bc(…):   0%|          | 0.00/166M [00:00<?, ?B/s]

data/train-00005-of-00011-14047762cd2d57(…):   0%|          | 0.00/177M [00:00<?, ?B/s]

data/train-00006-of-00011-8832797e39def5(…):   0%|          | 0.00/184M [00:00<?, ?B/s]

data/train-00007-of-00011-75a55aecb7275f(…):   0%|          | 0.00/189M [00:00<?, ?B/s]

data/train-00008-of-00011-75a25564d1f0fd(…):   0%|          | 0.00/206M [00:00<?, ?B/s]

data/train-00009-of-00011-5cd393dda922ee(…):   0%|          | 0.00/182M [00:00<?, ?B/s]

data/train-00010-of-00011-232f0dd1a755c7(…):   0%|          | 0.00/164M [00:00<?, ?B/s]

data/test-00000-of-00004-d48474212b95f33(…):   0%|          | 0.00/161M [00:00<?, ?B/s]

data/test-00001-of-00004-b7602f1b5c13695(…):   0%|          | 0.00/187M [00:00<?, ?B/s]

data/test-00002-of-00004-a81cff173329b48(…):   0%|          | 0.00/193M [00:00<?, ?B/s]

data/test-00003-of-00004-22af4ca7fa1313b(…):   0%|          | 0.00/175M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2027874 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/652490 [00:00<?, ? examples/s]

Built 0 product search samples

Example:
Could not load ESCI: list index out of range
Falling back to MS MARCO product-like queries...


In [18]:
# Cell 22: Product Search Evaluation

if esci_samples:
    print("\n" + "="*70)
    print("PART D: PRODUCT SEARCH (AMAZON ESCI)")
    print("="*70)
    print("\nRanking products by P(query | product_description).\n")
    
    product_results = {}
    
    for condition in ['bare', 'oracle_primed']:
        rankings = []
        
        for sample in tqdm(esci_samples[:200], desc=f"Products ({condition})"):
            query = sample['query']
            candidates = sample['candidates']
            relevant_idx = sample['relevant_idx']
            
            scores = []
            
            for idx, (product, _) in enumerate(candidates):
                product_text = f"Product: {product}"
                
                if condition == 'bare':
                    cache, cache_len = build_bare_cache(product_text)
                elif condition == 'oracle_primed':
                    prefix = f"Search query: {query} {query} {query}"
                    cache, cache_len = build_primed_cache_truncated(prefix, product_text)
                
                # Score P(query | product) with a transition prompt
                nll = score_with_prompt(
                    deepcopy_cache(cache), cache_len,
                    "\n\nRelevant search query:",
                    " " + query
                )
                scores.append((idx, nll))
            
            scores.sort(key=lambda x: x[1])
            ranked_indices = [idx for idx, _ in scores]
            rank_of_relevant = ranked_indices.index(relevant_idx)
            rankings.append(rank_of_relevant)
        
        product_results[condition] = {
            'mrr': compute_mrr(rankings),
            'hit_at_1': compute_hit_at_k(rankings, 1),
            'rankings': rankings
        }
        
        print(f"\n{condition}:")
        print(f"  MRR: {product_results[condition]['mrr']:.3f}")
        print(f"  Hit@1: {product_results[condition]['hit_at_1']*100:.1f}%")
else:
    print("Skipping Part D (ESCI dataset not available)")
    product_results = {}

Skipping Part D (ESCI dataset not available)


In [19]:
# Cell 23: Overall Summary

print("\n" + "="*70)
print("EXPERIMENT 20: OVERALL SUMMARY")
print("="*70)

print("\n### PART A: Retrieval Ranking (MS MARCO) ###")
print(f"Task: Rank documents by P(query|doc)")
if retrieval_results:
    print(f"Bare MRR: {retrieval_results['bare']['mrr']:.3f}")
    print(f"Oracle Primed MRR: {retrieval_results['oracle_primed']['mrr']:.3f}")
    delta_mrr = retrieval_results['oracle_primed']['mrr'] - retrieval_results['bare']['mrr']
    print(f"Delta: {delta_mrr:+.3f}")
    verdict_a = "HELPS" if delta_mrr > 0.02 else "HURTS" if delta_mrr < -0.02 else "NEUTRAL"
    print(f"Verdict: {verdict_a}")

print("\n### PART B: Semantic Steering (Yahoo Answers) ###")
print(f"Task: Score reference answer NLL with topic priming")
if steering_results:
    bare_nll = steering_results['bare']['mean_ref_nll']
    topic_nll = steering_results['topic_primed']['mean_ref_nll']
    random_nll = steering_results['random_primed']['mean_ref_nll']
    print(f"Bare Reference NLL: {bare_nll:.3f}")
    print(f"Topic Primed Reference NLL: {topic_nll:.3f}")
    print(f"Random Primed Reference NLL: {random_nll:.3f}")
    print(f"Delta (bare - topic): {bare_nll - topic_nll:+.3f}")
    print(f"Delta (random - topic): {random_nll - topic_nll:+.3f} (semantic signal)")
    verdict_b = "HELPS" if topic_nll < bare_nll - 0.1 else "HURTS" if topic_nll > bare_nll + 0.1 else "NEUTRAL"
    print(f"Verdict: {verdict_b}")

print("\n### PART C: Multi-Document Focus (MS MARCO) ###")
print(f"Task: Answer from 3 passages (1 relevant + 2 distractors)")
if multidoc_results:
    bare_nll = multidoc_results['bare']['mean_nll']
    hint_nll = multidoc_results['query_hint']['mean_nll']
    print(f"Bare NLL: {bare_nll:.3f}")
    print(f"Query Hint NLL: {hint_nll:.3f}")
    print(f"Delta: {bare_nll - hint_nll:+.3f}")
    verdict_c = "HELPS" if hint_nll < bare_nll - 0.1 else "HURTS" if hint_nll > bare_nll + 0.1 else "NEUTRAL"
    print(f"Verdict: {verdict_c}")

print("\n### PART D: Product Search (Amazon ESCI) ###")
if product_results:
    print(f"Task: Rank products by P(query|product)")
    print(f"Bare MRR: {product_results['bare']['mrr']:.3f}")
    print(f"Oracle Primed MRR: {product_results['oracle_primed']['mrr']:.3f}")
    delta_mrr = product_results['oracle_primed']['mrr'] - product_results['bare']['mrr']
    print(f"Delta: {delta_mrr:+.3f}")
    verdict_d = "HELPS" if delta_mrr > 0.02 else "HURTS" if delta_mrr < -0.02 else "NEUTRAL"
    print(f"Verdict: {verdict_d}")
else:
    print("(Skipped - dataset not available)")

print("\n" + "="*70)
print("KEY FINDINGS")
print("="*70)


EXPERIMENT 20: OVERALL SUMMARY

### PART A: Retrieval Ranking (MS MARCO) ###
Task: Rank documents by P(query|doc)
Bare MRR: 0.985
Oracle Primed MRR: 0.962
Delta: -0.023
Verdict: HURTS

### PART B: Semantic Steering (Yahoo Answers) ###
Task: Score reference answer NLL with topic priming
Bare Reference NLL: 3.077
Topic Primed Reference NLL: 3.484
Random Primed Reference NLL: 3.457
Delta (bare - topic): -0.407
Delta (random - topic): -0.027 (semantic signal)
Verdict: HURTS

### PART C: Multi-Document Focus (MS MARCO) ###
Task: Answer from 3 passages (1 relevant + 2 distractors)
Bare NLL: 2.801
Query Hint NLL: 2.726
Delta: +0.074
Verdict: NEUTRAL

### PART D: Product Search (Amazon ESCI) ###
(Skipped - dataset not available)

KEY FINDINGS


In [20]:
# Cell 24: Save Results

all_results = {
    'retrieval_ranking': {k: {kk: vv for kk, vv in v.items() if kk != 'rankings'} 
                         for k, v in retrieval_results.items()} if retrieval_results else {},
    'semantic_steering': {k: {kk: vv for kk, vv in v.items() if kk != 'results'} 
                         for k, v in steering_results.items()} if steering_results else {},
    'multi_document': {k: {kk: vv for kk, vv in v.items() if kk != 'results'} 
                      for k, v in multidoc_results.items()} if multidoc_results else {},
    'product_search': {k: {kk: vv for kk, vv in v.items() if kk != 'rankings'} 
                      for k, v in product_results.items()} if product_results else {},
}

with open(f'{OUTPUT_DIR}/results.json', 'w') as f:
    json.dump(all_results, f, indent=2)

print(f"Results saved to {OUTPUT_DIR}/results.json")
print("\nSteering results summary:")
for cond, res in steering_results.items():
    print(f"  {cond}: NLL={res['mean_ref_nll']:.3f}")

Results saved to /home/jupyter/research/directed_kvcache/results/exp20/results.json

Steering results summary:
  bare: NLL=3.077
  topic_primed: NLL=3.484
  random_primed: NLL=3.457
