# HQF-DE: Document Expansion Pipeline

This notebook expands MS MARCO documents using the HQF-DE pipeline (a Doc2Query-only baseline is also produced for comparison):

1. **LLM** (Llama-3-8B) → Identifies semantic gaps and generates expansions
2. **NLI** (DeBERTa) → Keeps only the expansions that the document entails (entailment score ≥ 0.7)
3. **Doc2Query** (T5) → Generates synthetic queries
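4. **Combiner** (Sentence-BERT) → Merges the validated expansions with the synthetic queries, drops near-duplicates and items too similar to the document itself, and keeps at most 10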

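**Input:** `collection_subset.tsv` (1M docs, 331MB; Step 5 processes the first 100K by default via `LIMIT`)  
**Output:** `expanded_hqfde.tsv` (HQF-DE) and `expanded_d2q.tsv` (Doc2Query baseline), both written to the Drive folder configured in Step 1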

## Step 1: Setup

In [None]:
# Check GPU
!nvidia-smi

# Install dependencies
!pip install -q transformers sentence-transformers accelerate bitsandbytes pydantic-settings sentencepiece protobuf tqdm

In [None]:
# Login to HuggingFace (required for Llama-3)
# Add your token to Colab secrets: Key = HF_TOKEN, Value = your_token
from google.colab import userdata
from huggingface_hub import login

try:
    token = userdata.get('HF_TOKEN')
    login(token=token)
    print("Logged in with Colab secret!")
except Exception as exc:
    # `Exception` rather than a bare `except:` so Ctrl-C still interrupts the cell. The
    # failure may be a missing secret OR a rejected token, so report the actual cause.
    print(f"Login with the HF_TOKEN Colab secret failed ({type(exc).__name__}: {exc}).")
    print("If HF_TOKEN is missing, add it via the key icon in the left sidebar.")
    login()  # Fallback to manual prompt

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

import os
os.makedirs('/content/output', exist_ok=True)

# Set paths - UPDATE THESE to match your Drive folder
DRIVE_FOLDER = "/content/drive/MyDrive/hqf_de"  # Change if needed
INPUT_FILE = f"{DRIVE_FOLDER}/collection_subset.tsv"
OUTPUT_HQFDE = f"{DRIVE_FOLDER}/expanded_hqfde.tsv"
OUTPUT_D2Q = f"{DRIVE_FOLDER}/expanded_d2q.tsv"

# Verify input file exists
if os.path.exists(INPUT_FILE):
    print(f"Found: {INPUT_FILE}")
    !wc -l {INPUT_FILE}
else:
    print(f"ERROR: File not found: {INPUT_FILE}")
    print(f"Please upload collection_subset.tsv to Google Drive at: {DRIVE_FOLDER}/")

## Step 2: Load Models

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from transformers import T5ForConditionalGeneration, T5Tokenizer, set_seed
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

DEVICE = "cuda"
CACHE = "/content/cache"
SEED = 42                  # doc2query samples (do_sample=True); the LLM may sample too
LLM_MAX_NEW_TOKENS = 256   # upper bound on the length of each LLM expansion list

# Seed Python/NumPy/torch once so sampled expansions and queries are repeatable across runs.
set_seed(SEED)

# Load Llama-3-8B (4-bit quantized)
print("Loading Llama-3-8B...")
llm_name = "meta-llama/Meta-Llama-3-8B-Instruct"
llm_tokenizer = AutoTokenizer.from_pretrained(llm_name, cache_dir=CACHE)
# Reuse EOS as the pad token (llm_expand passes it as pad_token_id to generation).
llm_tokenizer.pad_token = llm_tokenizer.eos_token
llm_model = AutoModelForCausalLM.from_pretrained(
    llm_name, cache_dir=CACHE,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16),
    device_map="auto"
)
llm_pipe = pipeline("text-generation", model=llm_model, tokenizer=llm_tokenizer,
                    max_new_tokens=LLM_MAX_NEW_TOKENS)

# Load NLI (DeBERTa)
print("Loading DeBERTa NLI...")
# NOTE(review): confirm this checkpoint id resolves on the Hugging Face Hub; nli_validate
# expects its labels to include "entailment" (compared case-insensitively).
NLI_MODEL = "microsoft/deberta-v3-large-mnli"
# top_k=None -> scores for every label, not just the top one.
nli_pipe = pipeline("text-classification", model=NLI_MODEL, device=0, top_k=None)

# Load Doc2Query (T5)
print("Loading Doc2Query...")
D2Q_MODEL = "castorini/doc2query-t5-base-msmarco"
d2q_tokenizer = T5Tokenizer.from_pretrained(D2Q_MODEL, cache_dir=CACHE)
d2q_model = T5ForConditionalGeneration.from_pretrained(D2Q_MODEL, cache_dir=CACHE).to(DEVICE)
d2q_model.eval()  # inference only

# Load Sentence-BERT
print("Loading Sentence-BERT...")
sbert = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=DEVICE)

print("All models loaded!")

## Step 3: Define Expansion Functions

In [None]:
def llm_expand(doc, max_expansions=5, pipe=None):
    """Ask the LLM for brief factual expansions of ``doc`` and parse them into a list.

    Args:
        doc: Passage text to expand.
        max_expansions: Maximum number of expansions returned (default 5, matching the
            "3-5" requested in the prompt).
        pipe: Text-generation pipeline to call; defaults to the global ``llm_pipe``.

    Returns:
        Up to ``max_expansions`` strings with list markers ("1.", "2)", "-", "*", "•")
        stripped. If the reply contains a marked list, only the marked items are kept, so
        a preamble such as "Here are 5 expansions:" is not mistaken for an expansion; if
        it contains no markers, every line longer than 10 characters is used instead.
    """
    import re

    pipe = llm_pipe if pipe is None else pipe
    # No literal <|begin_of_text|> here: the pipeline tokenizes this string with special
    # tokens enabled, so the Llama-3 tokenizer already prepends BOS; writing it again
    # would feed the model a doubled BOS token.
    prompt = f"""<|start_header_id|>user<|end_header_id|>

Analyze this document and generate 3-5 brief factual expansions that add missing context:
{doc}

Expansions:<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""
    result = pipe(prompt, return_full_text=False, pad_token_id=pipe.tokenizer.pad_token_id)
    text = result[0]["generated_text"]

    # "1. x", "12) x", "- x", "* x", "• x" -> "x". Whitespace is required after the marker so
    # lines like "1889 saw ..." or "3.5 million ..." are not treated as numbered items.
    marker = re.compile(r"^(?:\d+[.)]|[-*\u2022])\s+(.*)$")
    marked, unmarked = [], []
    for raw in text.split("\n"):
        line = raw.strip()
        match = marker.match(line)
        if match:
            item = match.group(1).strip()
            if len(item) > 10:
                marked.append(item)
        elif len(line) > 10:
            unmarked.append(line)
    expansions = marked or unmarked
    return expansions[:max_expansions]


def nli_validate(doc, expansions, threshold=0.7, pipe=None):
    """Keep only the expansions that ``doc`` entails according to the NLI model.

    Args:
        doc: Premise passage.
        expansions: Candidate hypothesis strings.
        threshold: Minimum "entailment" score for an expansion to be kept.
        pipe: Text-classification pipeline returning every label's score; defaults to the
            global ``nli_pipe``.

    Returns:
        The expansions (in input order) whose entailment score is >= ``threshold``.
        An expansion whose classification call raises is dropped and a message is printed.
    """
    pipe = nli_pipe if pipe is None else pipe
    valid = []
    for exp in expansions:
        try:
            # Send premise/hypothesis as a real sentence pair: the tokenizer inserts the
            # separator itself, and "only_first" truncates the (long) premise rather than
            # cutting the hypothesis off the end of the sequence.
            result = pipe({"text": doc, "text_pair": exp}, truncation="only_first", max_length=512)
        except Exception as exc:
            print(f"NLI failed for expansion {exp[:60]!r}: {type(exc).__name__}: {exc}")
            continue
        # With top_k=None set on the pipeline, a single input may come back wrapped in an
        # extra list ([[{label, score}, ...]]); unwrap it so both shapes work.
        if result and isinstance(result[0], list):
            result = result[0]
        scores = {r["label"].lower(): r["score"] for r in result}
        if scores.get("entailment", 0) >= threshold:
            valid.append(exp)
    return valid


def doc2query(doc, n=5):
    """Sample ``n`` synthetic queries for ``doc`` with the Doc2Query T5 model.

    Queries are returned in generation order with empty strings and exact repeats
    removed, so fewer than ``n`` may come back.
    """
    encoded = d2q_tokenizer(doc, max_length=512, truncation=True, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        generated = d2q_model.generate(
            **encoded,
            max_length=64,
            do_sample=True,
            top_k=10,
            num_return_sequences=n,
        )
    decoded = (d2q_tokenizer.decode(seq, skip_special_tokens=True).strip() for seq in generated)
    # dict.fromkeys keeps the first occurrence of each query and preserves order.
    return list(dict.fromkeys(q for q in decoded if q))


def deduplicate(doc, expansions, threshold=0.85, doc_threshold=None, max_keep=10, encoder=None):
    """Remove near-duplicate expansions and expansions too similar to the document.

    Greedy pass in input order: each expansion not already removed is kept, and every
    later expansion whose cosine similarity to it is >= ``threshold`` is removed. Kept
    expansions whose similarity to ``doc`` is >= ``doc_threshold`` are then dropped.

    Args:
        doc: The original passage.
        expansions: Candidate strings (LLM expansions and/or synthetic queries).
        threshold: Similarity at or above which two expansions count as duplicates.
        doc_threshold: Similarity to ``doc`` at or above which an expansion is dropped;
            defaults to ``threshold`` (the original behaviour).
        max_keep: Maximum number of expansions returned.
        encoder: Object with a SentenceTransformer-style ``encode(texts,
            normalize_embeddings=True)``; defaults to the global ``sbert``.

    Returns:
        The surviving expansions, in input order, at most ``max_keep`` of them.
    """
    if not expansions:
        return []
    encoder = sbert if encoder is None else encoder
    doc_threshold = threshold if doc_threshold is None else doc_threshold

    embs = np.asarray(encoder.encode(list(expansions), normalize_embeddings=True))
    # Embeddings are L2-normalised, so a dot product equals cosine similarity.
    sim = embs @ embs.T
    kept_idx = []
    removed = set()
    for i in range(len(expansions)):
        if i in removed:
            continue
        kept_idx.append(i)
        for j in range(i + 1, len(expansions)):
            if sim[i, j] >= threshold:
                removed.add(j)

    # Reuse the embeddings computed above instead of re-encoding the kept expansions.
    doc_emb = np.asarray(encoder.encode([doc], normalize_embeddings=True))[0]
    doc_sims = embs[kept_idx] @ doc_emb
    kept = [expansions[i] for i, s in zip(kept_idx, doc_sims) if s < doc_threshold]
    return kept[:max_keep]


def expand_hqfde(doc_id, doc):
    """Run the full HQF-DE pipeline on one passage.

    LLM expansions are filtered by NLI, merged with Doc2Query queries, deduplicated,
    and the survivors are appended to the original text.

    Returns:
        ``(doc_id, expanded_text)``. If nothing survives, ``expanded_text`` is ``doc``
        itself, without a trailing space.
    """
    expansions = llm_expand(doc)
    valid = nli_validate(doc, expansions)
    queries = doc2query(doc)
    final = deduplicate(doc, valid + queries)
    return doc_id, " ".join([doc, *final])


def expand_d2q_only(doc_id, doc):
    """Doc2Query baseline: append the sampled queries to ``doc`` with no filtering.

    Returns:
        ``(doc_id, expanded_text)``. If no queries were generated, ``expanded_text`` is
        ``doc`` itself, without a trailing space.
    """
    queries = doc2query(doc)
    return doc_id, " ".join([doc, *queries])


# Confirms this cell executed, i.e. all expansion functions above are now defined.
print("Functions defined!")

## Step 4: Test on Sample Document

In [None]:
test_doc = "The Eiffel Tower is a famous landmark in Paris, France. It was built in 1889."

print("Original:", test_doc)
print()

expansions = llm_expand(test_doc)
print("LLM Expansions:")
for e in expansions:
    print(f"  - {e}")

valid = nli_validate(test_doc, expansions)
print(f"\nNLI Valid ({len(valid)}/{len(expansions)}):")
for e in valid:
    print(f"  + {e}")

queries = doc2query(test_doc)
print("\nDoc2Query:")
for q in queries:
    print(f"  ? {q}")

# Build the final text from the intermediate results printed above (the same steps
# expand_hqfde performs). Calling expand_hqfde here would re-sample the LLM and Doc2Query,
# so its output would not match the lists shown above, and it would double the cost.
final = deduplicate(test_doc, valid + queries)
print(f"\nAfter dedup ({len(final)}/{len(valid) + len(queries)}):")
for e in final:
    print(f"  = {e}")

expanded = " ".join([test_doc, *final])
print("\nFinal Expanded:")
print(expanded)

## Step 5: Process Documents

In [None]:
import csv
from tqdm import tqdm

LIMIT = 100000  # Process 100K docs (set to None for all 1M)

# Load documents.
# quoting=QUOTE_NONE: passages may contain literal double quotes. With the default
# quoting, a passage that starts with `"` is parsed as a quoted field (its quotes are
# stripped, and an unbalanced quote swallows the following lines), so treat every
# character literally instead.
docs = []
skipped = 0
with open(INPUT_FILE, 'r', encoding='utf-8', newline='') as f:
    for row in csv.reader(f, delimiter='\t', quoting=csv.QUOTE_NONE):
        if len(row) < 2:
            skipped += 1
            continue
        docs.append((row[0], row[1]))
        if LIMIT and len(docs) >= LIMIT:
            break

print(f"Loaded {len(docs):,} documents ({skipped:,} malformed rows skipped)")

In [None]:
# Run HQF-DE expansion
# Resumable: with RESUME = True, ids already present in OUTPUT_HQFDE are skipped and new
# rows are appended, so re-running after a Colab disconnect continues where it stopped.
# Set RESUME = False to regenerate the file from scratch.
RESUME = True

# Rows are written as plain "id<TAB>text\n" lines. csv.writer would wrap any text containing
# a double quote in quotes (doubling the inner ones) and end rows with "\r\n", which a reader
# that simply splits on tabs/newlines -- presumably the downstream C++ indexer; confirm --
# would index verbatim. Tabs/line breaks in the text become spaces so each row stays intact.
TSV_SAFE = str.maketrans({"\t": " ", "\r": " ", "\n": " "})

done_ids = set()
if RESUME and os.path.exists(OUTPUT_HQFDE):
    with open(OUTPUT_HQFDE, 'r', encoding='utf-8') as f:
        existing = f.readlines()
    if existing and not existing[-1].endswith('\n'):
        # Drop a half-written last row (e.g. from a disconnect mid-write) before appending.
        existing.pop()
        with open(OUTPUT_HQFDE, 'w', encoding='utf-8', newline='') as f:
            f.writelines(existing)
    done_ids = {line.split('\t', 1)[0] for line in existing}

todo = [(doc_id, text) for doc_id, text in docs if doc_id not in done_ids]
print(f"Running HQF-DE expansion on {len(todo):,} docs ({len(docs) - len(todo):,} already done)...")
print(f"Output: {OUTPUT_HQFDE}")

failures = 0
with open(OUTPUT_HQFDE, 'a' if RESUME else 'w', encoding='utf-8', newline='') as f:
    for doc_id, text in tqdm(todo):
        try:
            did, expanded = expand_hqfde(doc_id, text)
        except Exception as exc:
            # Best effort: keep the unexpanded passage so every id still gets a row,
            # but surface the first few errors instead of hiding them all.
            failures += 1
            if failures <= 5:
                print(f"Expansion failed for {doc_id}: {type(exc).__name__}: {exc}")
            did, expanded = doc_id, text
        f.write(f"{did}\t{expanded.translate(TSV_SAFE)}\n")

print(f"Done! {failures:,} docs fell back to their original text.")

In [None]:
# Run Doc2Query baseline (for comparison)
# Resumable: with RESUME_D2Q = True, ids already present in OUTPUT_D2Q are skipped and new
# rows are appended. Set RESUME_D2Q = False to regenerate the file from scratch.
RESUME_D2Q = True

# Plain "id<TAB>text\n" rows (no csv quoting, no "\r\n"), matching what a tab-splitting
# reader expects; tabs/line breaks inside the text become spaces.
TSV_SAFE = str.maketrans({"\t": " ", "\r": " ", "\n": " "})

done_ids = set()
if RESUME_D2Q and os.path.exists(OUTPUT_D2Q):
    with open(OUTPUT_D2Q, 'r', encoding='utf-8') as f:
        existing = f.readlines()
    if existing and not existing[-1].endswith('\n'):
        # Drop a half-written last row (e.g. from a disconnect mid-write) before appending.
        existing.pop()
        with open(OUTPUT_D2Q, 'w', encoding='utf-8', newline='') as f:
            f.writelines(existing)
    done_ids = {line.split('\t', 1)[0] for line in existing}

todo = [(doc_id, text) for doc_id, text in docs if doc_id not in done_ids]
print(f"Running Doc2Query baseline on {len(todo):,} docs ({len(docs) - len(todo):,} already done)...")
print(f"Output: {OUTPUT_D2Q}")

failures = 0
with open(OUTPUT_D2Q, 'a' if RESUME_D2Q else 'w', encoding='utf-8', newline='') as f:
    for doc_id, text in tqdm(todo):
        try:
            did, expanded = expand_d2q_only(doc_id, text)
        except Exception as exc:
            # Best effort: keep the unexpanded passage so every id still gets a row,
            # but surface the first few errors instead of hiding them all.
            failures += 1
            if failures <= 5:
                print(f"Expansion failed for {doc_id}: {type(exc).__name__}: {exc}")
            did, expanded = doc_id, text
        f.write(f"{did}\t{expanded.translate(TSV_SAFE)}\n")

print(f"Done! {failures:,} docs fell back to their original text.")

## Step 6: Verify Output

In [None]:
import os

print("Output files saved to Google Drive:")
for label, path in [("HQF-DE:", OUTPUT_HQFDE), ("D2Q:   ", OUTPUT_D2Q)]:
    if not os.path.exists(path):
        # The two expansion cells are independent; report a skipped run instead of crashing.
        print(f"  {label} {path} (missing - run its expansion cell first)")
        continue
    with open(path, 'r', encoding='utf-8') as f:
        n_rows = sum(1 for _ in f)
    print(f"  {label} {path} ({os.path.getsize(path)/1e6:.1f} MB, {n_rows:,} rows)")
    if n_rows != len(docs):
        # Each loaded doc should produce exactly one row (failures fall back to the original).
        print(f"    WARNING: expected {len(docs):,} rows, one per loaded document")
print()
print("Files are automatically synced to your Google Drive.")
print("Next: Use these with your C++ indexer for evaluation.")