# HQF-DE: Document Expansion Pipeline

This notebook expands MS MARCO documents using the HQF-DE pipeline:

1. **LLM** (Llama-3-8B) → Identifies semantic gaps and generates expansions
2. **NLI** (BART-large-mnli) → Validates expansions are factually consistent
3. **Doc2Query** (T5) → Generates synthetic queries
4. **Combiner** → Deduplicates and selects best expansions

**Input:** `collection_100k.tsv` (100K docs, 33MB)  
**Output:** `expanded_hqfde.tsv`, `expanded_d2q.tsv`

## Step 1: Setup

In [25]:
# Check GPU
!nvidia-smi

# Install dependencies
!pip install -q transformers sentence-transformers accelerate bitsandbytes pydantic-settings sentencepiece protobuf tqdm

Sat Dec 20 22:52:23 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:00:05.0 Off |                    0 |
| N/A   36C    P0             55W /  400W |   20747MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
# Login to HuggingFace (required for Llama-3)
# Option 1: Use Colab secrets (recommended)
# Option 2: Paste token directly (uncomment and add token)

from huggingface_hub import login

try:
    from google.colab import userdata
    token = userdata.get('HF_TOKEN')
    login(token=token)
    print("Logged in using Colab secret!")
except:
    # For VS Code: uncomment and add your token
    # login(token="your_token_here")
    print("Add your HF token: login(token='your_token_here')")

In [None]:
# Setup data paths
import os
os.makedirs('/content/data', exist_ok=True)
os.makedirs('/content/output', exist_ok=True)

INPUT_FILE = "/content/data/collection_100k.tsv"
OUTPUT_HQFDE = "/content/output/expanded_hqfde.tsv"
OUTPUT_D2Q = "/content/output/expanded_d2q.tsv"

# Option 1: If file already uploaded to runtime
if os.path.exists(INPUT_FILE):
    print(f"Found: {INPUT_FILE}")
    !wc -l {INPUT_FILE}
else:
    # Download from Google Drive
    # Replace FILE_ID with your file ID from the shareable link
    FILE_ID = "YOUR_FILE_ID_HERE"  # Update this after uploading collection_100k.tsv
    
    if FILE_ID != "YOUR_FILE_ID_HERE":
        !pip install -q gdown
        import gdown
        gdown.download(f"https://drive.google.com/uc?id={FILE_ID}", INPUT_FILE, quiet=False)
    else:
        print("File not found. To download from Drive:")
        print("1. Upload collection_100k.tsv to Google Drive")
        print("2. Share it (Anyone with link)")
        print("3. Copy the FILE_ID from the share link")
        print("4. Replace YOUR_FILE_ID_HERE above and re-run")

## Step 2: Load Models

In [28]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from transformers import T5ForConditionalGeneration, T5Tokenizer
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

DEVICE = "cuda"
CACHE = "/content/cache"

# Load Llama-3-8B (4-bit quantized)
print("Loading Llama-3-8B...")
llm_name = "meta-llama/Meta-Llama-3-8B-Instruct"
llm_tokenizer = AutoTokenizer.from_pretrained(llm_name, cache_dir=CACHE)
llm_tokenizer.pad_token = llm_tokenizer.eos_token
llm_model = AutoModelForCausalLM.from_pretrained(
    llm_name, cache_dir=CACHE,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16),
    device_map="auto"
)
llm_pipe = pipeline("text-generation", model=llm_model, tokenizer=llm_tokenizer, max_new_tokens=256)

# Load NLI (DeBERTa) - using correct model name
print("Loading DeBERTa NLI...")
nli_pipe = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=0)

# Load Doc2Query (T5)
print("Loading Doc2Query...")
d2q_tokenizer = T5Tokenizer.from_pretrained("castorini/doc2query-t5-base-msmarco", cache_dir=CACHE)
d2q_model = T5ForConditionalGeneration.from_pretrained("castorini/doc2query-t5-base-msmarco", cache_dir=CACHE).to(DEVICE)

# Load Sentence-BERT
print("Loading Sentence-BERT...")
sbert = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=DEVICE)

print("All models loaded!")

Loading Llama-3-8B...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Device set to use cuda:0


Loading DeBERTa NLI...


Device set to use cuda:0


Loading Doc2Query...
Loading Sentence-BERT...
All models loaded!


## Step 3: Define Expansion Functions

In [29]:
import torch
from tqdm import tqdm

# Fix padding warning
llm_tokenizer.padding_side = 'left'

def llm_expand_batch(docs, batch_size=8):
    """Batch LLM expansion for efficiency."""
    all_expansions = []
    
    for i in tqdm(range(0, len(docs), batch_size), desc="  LLM", leave=False):
        batch = docs[i:i+batch_size]
        prompts = []
        for doc in batch:
            prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Analyze this document and generate 3-5 brief factual expansions that add missing context:
{doc}

Expansions:<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""
            prompts.append(prompt)
        
        results = llm_pipe(prompts, return_full_text=False, pad_token_id=llm_tokenizer.pad_token_id, batch_size=len(prompts))
        
        for result in results:
            text = result[0]["generated_text"]
            expansions = []
            for line in text.split("\n"):
                line = line.strip()
                if line and len(line) > 10:
                    if line[0].isdigit():
                        line = line.split(".", 1)[-1].strip()
                    if line.startswith("-"):
                        line = line[1:].strip()
                    if line:
                        expansions.append(line)
            all_expansions.append(expansions[:5])
    
    return all_expansions


def nli_validate_batch(docs, expansions_list, threshold=0.7):
    """Batch NLI validation."""
    # Flatten all (doc, expansion) pairs
    pairs = []
    pair_indices = []
    for i, (doc, expansions) in enumerate(zip(docs, expansions_list)):
        for exp in expansions:
            pairs.append((doc, exp))
            pair_indices.append(i)
    
    if not pairs:
        return [[] for _ in docs]
    
    # Process in batches
    valid_flags = []
    batch_size = 32
    for i in tqdm(range(0, len(pairs), batch_size), desc="  NLI", leave=False):
        batch = pairs[i:i+batch_size]
        try:
            results = nli_pipe(
                [exp for _, exp in batch],
                candidate_labels=["entailed", "not entailed"],
                batch_size=len(batch)
            )
            if not isinstance(results, list):
                results = [results]
            for r in results:
                is_valid = r['labels'][0] == 'entailed' and r['scores'][0] >= threshold
                valid_flags.append(is_valid)
        except:
            valid_flags.extend([False] * len(batch))
    
    # Reconstruct per-document valid lists
    all_valid = [[] for _ in docs]
    for idx, ((doc, exp), is_valid) in enumerate(zip(pairs, valid_flags)):
        if is_valid:
            all_valid[pair_indices[idx]].append(exp)
    
    return all_valid


def doc2query_batch(docs, n=5, batch_size=16):
    """Batch Doc2Query generation."""
    all_queries = []
    
    for i in tqdm(range(0, len(docs), batch_size), desc="  D2Q", leave=False):
        batch = docs[i:i+batch_size]
        inputs = d2q_tokenizer(batch, max_length=512, truncation=True, padding=True, return_tensors="pt").to(DEVICE)
        
        with torch.no_grad():
            outputs = d2q_model.generate(
                **inputs, 
                max_length=64, 
                do_sample=True, 
                top_k=10, 
                num_return_sequences=n
            )
        
        # Reshape: (batch_size * n) -> (batch_size, n)
        for j in range(len(batch)):
            queries = []
            for k in range(n):
                idx = j * n + k
                if idx < len(outputs):
                    q = d2q_tokenizer.decode(outputs[idx], skip_special_tokens=True).strip()
                    if q and q not in queries:
                        queries.append(q)
            all_queries.append(queries)
    
    return all_queries


def deduplicate_batch(docs, expansions_list, threshold=0.85):
    """Batch deduplication using embeddings."""
    all_final = []
    
    for doc, expansions in zip(docs, expansions_list):
        if not expansions:
            all_final.append([])
            continue
        
        # Remove duplicates among expansions
        embs = sbert.encode(expansions, normalize_embeddings=True)
        sim = cosine_similarity(embs)
        kept = []
        removed = set()
        for i in range(len(expansions)):
            if i not in removed:
                kept.append(expansions[i])
                for j in range(i+1, len(expansions)):
                    if sim[i,j] >= threshold:
                        removed.add(j)
        
        # Remove expansions too similar to doc
        if kept:
            doc_emb = sbert.encode([doc], normalize_embeddings=True)
            exp_embs = sbert.encode(kept, normalize_embeddings=True)
            sims = cosine_similarity(exp_embs, doc_emb).flatten()
            kept = [e for e, s in zip(kept, sims) if s < threshold]
        
        all_final.append(kept[:10])
    
    return all_final


def expand_batch(doc_ids, docs, batch_size=8):
    """Full HQF-DE pipeline with batching."""
    print(f"    Processing {len(docs)} docs...")
    
    # Phase 1: LLM expansion
    llm_expansions = llm_expand_batch(docs, batch_size=batch_size)
    
    # Phase 2: NLI validation
    valid_expansions = nli_validate_batch(docs, llm_expansions)
    
    # Phase 3: Doc2Query
    queries = doc2query_batch(docs, n=5, batch_size=batch_size*2)
    
    # Phase 4: Combine and deduplicate
    combined = [v + q for v, q in zip(valid_expansions, queries)]
    final = deduplicate_batch(docs, combined)
    
    # Build results
    results = []
    for doc_id, doc, f in zip(doc_ids, docs, final):
        expanded = f"{doc} {' '.join(f)}"
        results.append((doc_id, expanded))
    
    return results


def expand_d2q_batch(doc_ids, docs, batch_size=16):
    """Doc2Query baseline with batching."""
    queries = doc2query_batch(docs, n=5, batch_size=batch_size)
    
    results = []
    for doc_id, doc, q in zip(doc_ids, docs, queries):
        expanded = f"{doc} {' '.join(q)}"
        results.append((doc_id, expanded))
    
    return results


print("Batch functions defined!")

Batch functions defined!


## Step 4: Test on Sample Document

In [30]:
# Test on a sample document
test_doc = "The Eiffel Tower is a famous landmark in Paris, France. It was built in 1889."
test_id = "test"

print("Original:", test_doc)
print()

# Test batch functions with single doc
results = expand_batch([test_id], [test_doc], batch_size=1)
doc_id, expanded = results[0]

print("Expanded:")
print(expanded)
print()
print("Expansion added:", expanded.replace(test_doc, "").strip())

Original: The Eiffel Tower is a famous landmark in Paris, France. It was built in 1889.

    Processing 1 docs...


                                                    

Expanded:
The Eiffel Tower is a famous landmark in Paris, France. It was built in 1889. Based on the provided document, I've generated the following brief factual expansions that add missing context: **The Eiffel Tower was built for the World's Fair**: The Eiffel Tower was constructed for the 1889 World's Fair, held in Paris, France. The fair was a celebration of French culture, technology, and innovation, and the tower was the main centerpiece of the event. **Designed by Gustave Eiffel**: The Eiffel Tower was designed and built by Gustave Eiffel, a French engineer and entrepreneur. Eiffel's company, Compagnie des Établissements Eiffel, was responsible for the tower's construction. eiffel tower famous buildings

Expansion added: Based on the provided document, I've generated the following brief factual expansions that add missing context: **The Eiffel Tower was built for the World's Fair**: The Eiffel Tower was constructed for the 1889 World's Fair, held in Paris, France. The fair was 



## Step 5: Process Documents

In [None]:
import csv
from tqdm import tqdm

LIMIT = None  # Process all 100K docs (set to 1000 for testing)
BATCH_SIZE = 32  # Adjust based on GPU memory
CHECKPOINT_EVERY = 5  # Save progress every N batches (smaller = more updates)

# Load documents
docs = []
doc_ids = []
with open(INPUT_FILE, 'r') as f:
    for row in csv.reader(f, delimiter='\t'):
        if len(row) >= 2:
            doc_ids.append(row[0])
            docs.append(row[1])
            if LIMIT and len(docs) >= LIMIT:
                break

print(f"Loaded {len(docs):,} documents")
print(f"Will process in mega-batches of {BATCH_SIZE * CHECKPOINT_EVERY:,} docs")
print(f"Estimated time at 0.6 docs/sec: {len(docs)/0.6/3600:.1f} hours")

In [45]:
# Run HQF-DE expansion with batching
import time
print(f"Running HQF-DE expansion on {len(docs):,} docs with batch_size={BATCH_SIZE}...")
print(f"Output: {OUTPUT_HQFDE}")

start_time = time.time()
processed = 0
mega_batch = BATCH_SIZE * CHECKPOINT_EVERY  # Process this many before checkpoint

with open(OUTPUT_HQFDE, 'w', newline='') as f:
    writer = csv.writer(f, delimiter='\t')
    
    for i in tqdm(range(0, len(docs), mega_batch), desc="HQF-DE"):
        batch_doc_ids = doc_ids[i:i+mega_batch]
        batch_docs = docs[i:i+mega_batch]
        
        try:
            results = expand_batch(batch_doc_ids, batch_docs, batch_size=BATCH_SIZE)
            for doc_id, expanded in results:
                writer.writerow([doc_id, expanded])
            f.flush()  # Ensure data is written
        except Exception as e:
            print(f"Error at batch {i}: {e}")
            # Write originals on error
            for doc_id, doc in zip(batch_doc_ids, batch_docs):
                writer.writerow([doc_id, doc])
        
        processed += len(batch_docs)
        elapsed = time.time() - start_time
        rate = processed / elapsed
        remaining = (len(docs) - processed) / rate if rate > 0 else 0
        
        if (i // mega_batch) % 10 == 0:
            print(f"  Processed: {processed:,} | Rate: {rate:.1f} docs/sec | ETA: {remaining/3600:.1f} hrs")

total_time = time.time() - start_time
print(f"\nDone! Processed {processed:,} docs in {total_time/3600:.2f} hours")
print(f"Rate: {processed/total_time:.1f} docs/sec")

Running HQF-DE expansion on 1,000 docs with batch_size=32...
Output: /content/output/expanded_hqfde.tsv


HQF-DE:   0%|          | 0/7 [00:00<?, ?it/s]

    Processing 160 docs...


HQF-DE:  14%|█▍        | 1/7 [04:05<24:35, 245.91s/it]

  Processed: 160 | Rate: 0.7 docs/sec | ETA: 0.4 hrs
    Processing 160 docs...


HQF-DE:  29%|██▊       | 2/7 [09:06<23:10, 278.18s/it]

    Processing 160 docs...


HQF-DE:  43%|████▎     | 3/7 [13:13<17:35, 263.92s/it]

    Processing 160 docs...


HQF-DE:  57%|█████▋    | 4/7 [17:41<13:16, 265.43s/it]

    Processing 160 docs...


HQF-DE:  71%|███████▏  | 5/7 [22:13<08:55, 267.67s/it]

    Processing 160 docs...


HQF-DE:  86%|████████▌ | 6/7 [26:46<04:29, 269.79s/it]

    Processing 40 docs...


HQF-DE: 100%|██████████| 7/7 [28:23<00:00, 243.41s/it]


Done! Processed 1,000 docs in 0.47 hours
Rate: 0.6 docs/sec





In [46]:
# Run Doc2Query baseline with batching (much faster)
import time
print(f"Running Doc2Query baseline on {len(docs):,} docs...")
print(f"Output: {OUTPUT_D2Q}")

start_time = time.time()
D2Q_BATCH = 32  # D2Q can handle larger batches

with open(OUTPUT_D2Q, 'w', newline='') as f:
    writer = csv.writer(f, delimiter='\t')
    
    for i in tqdm(range(0, len(docs), mega_batch), desc="Doc2Query"):
        batch_doc_ids = doc_ids[i:i+mega_batch]
        batch_docs = docs[i:i+mega_batch]
        
        try:
            results = expand_d2q_batch(batch_doc_ids, batch_docs, batch_size=D2Q_BATCH)
            for doc_id, expanded in results:
                writer.writerow([doc_id, expanded])
            f.flush()
        except Exception as e:
            print(f"Error at batch {i}: {e}")
            for doc_id, doc in zip(batch_doc_ids, batch_docs):
                writer.writerow([doc_id, doc])

total_time = time.time() - start_time
print(f"\nDone! Processed {len(docs):,} docs in {total_time/3600:.2f} hours")
print(f"Rate: {len(docs)/total_time:.1f} docs/sec")

Running Doc2Query baseline on 1,000 docs...
Output: /content/output/expanded_d2q.tsv


Doc2Query: 100%|██████████| 7/7 [00:27<00:00,  3.96s/it]


Done! Processed 1,000 docs in 0.01 hours
Rate: 36.0 docs/sec





## Step 6: Verify Output

In [56]:
import os

print("Output files saved to Google Drive:")
print(f"  HQF-DE: {OUTPUT_HQFDE} ({os.path.getsize(OUTPUT_HQFDE)/1e6:.1f} MB)")
print(f"  D2Q:    {OUTPUT_D2Q} ({os.path.getsize(OUTPUT_D2Q)/1e6:.1f} MB)")
print()
print("Files are automatically synced to your Google Drive.")
print("Next: Use these with your C++ indexer for evaluation.")

Output files saved to Google Drive:
  HQF-DE: /content/output/expanded_hqfde.tsv (1.7 MB)
  D2Q:    /content/output/expanded_d2q.tsv (0.5 MB)

Files are automatically synced to your Google Drive.
Next: Use these with your C++ indexer for evaluation.


In [60]:
# Upload output files to get download links (VS Code compatible)
!pip install -q requests

import os
import requests

print("Uploading output files...\n")

for filename in ['expanded_hqfde.tsv', 'expanded_d2q.tsv']:
    filepath = f'/content/output/{filename}'
    if os.path.exists(filepath):
        size_mb = os.path.getsize(filepath) / 1e6
        print(f"{filename} ({size_mb:.1f} MB)")
        
        uploaded = False
        
        # Try 0x0.st
        try:
            with open(filepath, 'rb') as f:
                response = requests.post('https://0x0.st', files={'file': f}, timeout=60)
            if response.status_code == 200 and response.text.startswith('http'):
                print(f"  ✓ {response.text.strip()}")
                uploaded = True
        except Exception as e:
            print(f"  0x0.st failed: {e}")
        
        # Try transfer.sh
        if not uploaded:
            try:
                with open(filepath, 'rb') as f:
                    response = requests.put(f'https://transfer.sh/{filename}', data=f, timeout=60)
                if response.status_code == 200 and response.text.startswith('http'):
                    print(f"  ✓ {response.text.strip()}")
                    uploaded = True
            except Exception as e:
                print(f"  transfer.sh failed: {e}")
        
        # Try tmpfiles.org
        if not uploaded:
            try:
                with open(filepath, 'rb') as f:
                    response = requests.post('https://tmpfiles.org/api/v1/upload', files={'file': f}, timeout=60)
                if response.status_code == 200:
                    data = response.json()
                    if data.get('status') == 'success':
                        # Convert URL from tmpfiles.org/123/file to tmpfiles.org/dl/123/file
                        url = data['data']['url'].replace('tmpfiles.org/', 'tmpfiles.org/dl/')
                        print(f"  ✓ {url}")
                        uploaded = True
            except Exception as e:
                print(f"  tmpfiles.org failed: {e}")
        
        if not uploaded:
            print(f"  ✗ All upload methods failed")
        print()
    else:
        print(f"✗ {filepath} not found\n")

print("Download these files using the URLs above.")

Uploading output files...

expanded_hqfde.tsv (1.7 MB)
  transfer.sh failed: HTTPSConnectionPool(host='transfer.sh', port=443): Max retries exceeded with url: /expanded_hqfde.tsv (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x78819008d490>: Failed to establish a new connection: [Errno 101] Network is unreachable'))
  ✓ http://tmpfiles.org/dl/16360996/expanded_hqfde.tsv

expanded_d2q.tsv (0.5 MB)
  transfer.sh failed: HTTPSConnectionPool(host='transfer.sh', port=443): Max retries exceeded with url: /expanded_d2q.tsv (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x78819008f230>: Failed to establish a new connection: [Errno 101] Network is unreachable'))
  ✓ http://tmpfiles.org/dl/16361000/expanded_d2q.tsv

Download these files using the URLs above.
