# HQF-DE: Hybrid Query- and Fact-Guided Document Expansion

This notebook runs the HQF-DE pipeline on Google Colab with GPU support.

In [6]:
# Install dependencies
!pip install -q torch transformers sentence-transformers scikit-learn pydantic-settings sentencepiece protobuf accelerate bitsandbytes

In [None]:
# Login to HuggingFace (required for Llama access)
# Option 1: Use Colab secrets (recommended)
# from google.colab import userdata
# from huggingface_hub import login
# login(token=userdata.get('HF_TOKEN'))

# Option 2: Enter token manually when prompted
from huggingface_hub import login
login()  # Will prompt for token

In [8]:
# Mount Google Drive (optional - for saving results)
from google.colab import drive
drive.mount('/content/drive')

ValueError: mount failed

In [None]:
# Clone or upload the project
# Option 1: If uploaded to Drive
# import sys
# sys.path.append('/content/drive/MyDrive/wse-final-project')

# Option 2: Clone from GitHub (if you push it)
# !git clone https://github.com/yourusername/wse-final-project.git
# %cd wse-final-project

In [None]:
# Create the hqf_de package inline (if not using Drive/GitHub)
!mkdir -p hqf_de/models hqf_de/pipeline hqf_de/evaluation

In [None]:
%%writefile hqf_de/config.py
from pydantic_settings import BaseSettings
from pydantic import Field
from pathlib import Path

class HQFDEConfig(BaseSettings):
    project_root: Path = Field(default=Path("/content"))
    data_dir: Path = Field(default=Path("/content/data"))
    output_dir: Path = Field(default=Path("/content/output"))
    cache_dir: Path = Field(default=Path("/content/cache"))

    input_tsv: str = Field(default="collection.tsv")
    output_tsv: str = Field(default="expanded_passages.tsv")

    llm_model_name: str = Field(default="meta-llama/Meta-Llama-3-8B-Instruct")
    llm_max_new_tokens: int = Field(default=256)
    llm_temperature: float = Field(default=0.7)

    doc2query_model: str = Field(default="castorini/doc2query-t5-base-msmarco")
    num_queries_per_doc: int = Field(default=5)

    nli_model: str = Field(default="microsoft/deberta-v3-large-mnli")
    nli_entailment_threshold: float = Field(default=0.9)

    embedding_model: str = Field(default="sentence-transformers/all-MiniLM-L6-v2")
    dedup_similarity_threshold: float = Field(default=0.85)

    batch_size: int = Field(default=8)
    max_doc_length: int = Field(default=512)
    device: str = Field(default="cuda")

    class Config:
        env_prefix = "HQFDE_"

config = HQFDEConfig()

In [None]:
%%writefile hqf_de/__init__.py
__version__ = "0.1.0"
from .config import config, HQFDEConfig
__all__ = ["config", "HQFDEConfig"]

In [None]:
%%writefile hqf_de/models/__init__.py
from .llm import LLM
from .doc2query import Doc2Query
from .nli import NLI
from .embeddings import Embedder
__all__ = ["LLM", "Doc2Query", "NLI", "Embedder"]

In [None]:
%%writefile hqf_de/models/llm.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from typing import List
from dataclasses import dataclass
import logging
from ..config import config

logger = logging.getLogger(__name__)

@dataclass
class Expansion:
    text: str
    gaps: List[str]
    expansions: List[str]

class LLM:
    GAP_PROMPT = """Analyze this document and list semantic gaps (max 5):
{document}

Gaps:"""

    EXPAND_PROMPT = """Generate brief factual expansions for this document:
{document}

Gaps: {gaps}

Expansions:"""

    def __init__(self, model: str = None, device: str = None, quantize: bool = True):
        self.model_name = model or config.llm_model_name
        self.device = device or config.device
        self.quantize = quantize
        self.model = None
        self.tokenizer = None
        self.pipe = None
        self._loaded = False

    def load(self):
        if self._loaded:
            return self
        logger.info(f"Loading LLM: {self.model_name}")
        quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16) if self.quantize else None
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, cache_dir=config.cache_dir)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name, cache_dir=config.cache_dir, quantization_config=quant_config, torch_dtype=torch.float16, device_map="auto")
        self.pipe = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer, max_new_tokens=config.llm_max_new_tokens, temperature=config.llm_temperature, do_sample=True)
        self._loaded = True
        return self

    def _fmt(self, doc: str, template: str, **kw) -> str:
        content = template.format(document=doc, **kw)
        return f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

    def _gen(self, prompt: str) -> str:
        if not self._loaded:
            self.load()
        result = self.pipe(prompt, return_full_text=False, pad_token_id=self.tokenizer.pad_token_id)
        return result[0]["generated_text"].strip()

    def _parse(self, text: str) -> List[str]:
        items = []
        for line in text.split("\n"):
            line = line.strip()
            if line and not line.startswith("#"):
                if line[0].isdigit():
                    line = line.split(".", 1)[-1].strip()
                if line.startswith("-"):
                    line = line[1:].strip()
                if line and len(line) > 5:
                    items.append(line)
        return items[:5]

    def gaps(self, doc: str) -> List[str]:
        return self._parse(self._gen(self._fmt(doc, self.GAP_PROMPT)))

    def expand(self, doc: str, gaps: List[str] = None) -> List[str]:
        gaps_text = "\n".join(gaps) if gaps else "none"
        return self._parse(self._gen(self._fmt(doc, self.EXPAND_PROMPT, gaps=gaps_text)))

    def run(self, doc: str) -> Expansion:
        g = self.gaps(doc)
        e = self.expand(doc, g)
        return Expansion(text=doc, gaps=g, expansions=e)

    def unload(self):
        if self.model:
            del self.model, self.tokenizer, self.pipe
            self.model = self.tokenizer = self.pipe = None
            self._loaded = False
            torch.cuda.empty_cache()

In [None]:
%%writefile hqf_de/models/doc2query.py
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
from typing import List
import logging
from ..config import config

logger = logging.getLogger(__name__)

class Doc2Query:
    def __init__(self, model: str = None, device: str = None, n: int = 5):
        self.model_name = model or config.doc2query_model
        self.device = device or config.device
        self.n = n
        self.model = None
        self.tokenizer = None
        self._loaded = False

    def load(self):
        if self._loaded:
            return self
        logger.info(f"Loading doc2query: {self.model_name}")
        self.tokenizer = T5Tokenizer.from_pretrained(self.model_name, cache_dir=config.cache_dir)
        self.model = T5ForConditionalGeneration.from_pretrained(self.model_name, cache_dir=config.cache_dir).to(self.device)
        self.model.eval()
        self._loaded = True
        return self

    def gen(self, doc: str, n: int = None) -> List[str]:
        if not self._loaded:
            self.load()
        n = n or self.n
        inputs = self.tokenizer(doc, max_length=config.max_doc_length, truncation=True, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(**inputs, max_length=64, do_sample=True, top_k=10, num_return_sequences=n)
        queries = []
        for out in outputs:
            q = self.tokenizer.decode(out, skip_special_tokens=True).strip()
            if q and q not in queries:
                queries.append(q)
        return queries

    def unload(self):
        if self.model:
            del self.model, self.tokenizer
            self.model = self.tokenizer = None
            self._loaded = False
            torch.cuda.empty_cache()

In [None]:
%%writefile hqf_de/models/nli.py
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
from typing import List, Tuple
from dataclasses import dataclass
import logging
from ..config import config

logger = logging.getLogger(__name__)

@dataclass
class NLIResult:
    hypothesis: str
    entailment: float
    valid: bool

class NLI:
    def __init__(self, model: str = None, device: str = None, threshold: float = 0.9):
        self.model_name = model or config.nli_model
        self.device = device or config.device
        self.threshold = threshold
        self.model = None
        self.tokenizer = None
        self.pipe = None
        self._loaded = False

    def load(self):
        if self._loaded:
            return self
        logger.info(f"Loading NLI: {self.model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, cache_dir=config.cache_dir)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name, cache_dir=config.cache_dir).to(self.device)
        self.model.eval()
        self.pipe = pipeline("text-classification", model=self.model, tokenizer=self.tokenizer, device=0, top_k=None)
        self._loaded = True
        return self

    def check(self, premise: str, hypothesis: str) -> NLIResult:
        if not self._loaded:
            self.load()
        try:
            results = self.pipe(f"{premise} [SEP] {hypothesis}", truncation=True, max_length=512)
            scores = {r["label"].lower(): r["score"] for r in results}
            ent = scores.get("entailment", 0.0)
            return NLIResult(hypothesis=hypothesis, entailment=ent, valid=ent >= self.threshold)
        except:
            return NLIResult(hypothesis=hypothesis, entailment=0.0, valid=False)

    def validate(self, doc: str, expansions: List[str]) -> Tuple[List[str], List[NLIResult]]:
        if not self._loaded:
            self.load()
        valid = []
        results = []
        for exp in expansions:
            r = self.check(doc, exp)
            results.append(r)
            if r.valid:
                valid.append(exp)
        return valid, results

    def unload(self):
        if self.model:
            del self.model, self.tokenizer, self.pipe
            self.model = self.tokenizer = self.pipe = None
            self._loaded = False
            torch.cuda.empty_cache()

In [None]:
%%writefile hqf_de/models/embeddings.py
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Tuple
import logging
from ..config import config

logger = logging.getLogger(__name__)

class Embedder:
    def __init__(self, model: str = None, device: str = None):
        self.model_name = model or config.embedding_model
        self.device = device or config.device
        self.model = None
        self._loaded = False

    def load(self):
        if self._loaded:
            return self
        logger.info(f"Loading embedder: {self.model_name}")
        self.model = SentenceTransformer(self.model_name, cache_folder=str(config.cache_dir), device=self.device)
        self._loaded = True
        return self

    def encode(self, texts: List[str]) -> np.ndarray:
        if not self._loaded:
            self.load()
        return self.model.encode(texts, normalize_embeddings=True, convert_to_numpy=True)

    def sim(self, texts1: List[str], texts2: List[str] = None) -> np.ndarray:
        e1 = self.encode(texts1)
        if texts2 is None:
            return cosine_similarity(e1)
        return cosine_similarity(e1, self.encode(texts2))

    def dedup(self, texts: List[str], threshold: float = 0.85) -> Tuple[List[str], List[int]]:
        if len(texts) <= 1:
            return texts, list(range(len(texts)))
        sim = self.sim(texts)
        kept, indices, removed = [], [], set()
        for i in range(len(texts)):
            if i in removed:
                continue
            kept.append(texts[i])
            indices.append(i)
            for j in range(i + 1, len(texts)):
                if sim[i, j] >= threshold:
                    removed.add(j)
        return kept, indices

    def dedup_vs_doc(self, doc: str, expansions: List[str], threshold: float = 0.85) -> List[str]:
        if not expansions:
            return []
        sims = self.sim(expansions, [doc]).flatten()
        return [e for e, s in zip(expansions, sims) if s < threshold]

    def select(self, expansions: List[str], n: int = 5, doc: str = None) -> List[str]:
        if len(expansions) <= n:
            return expansions
        embs = self.encode(expansions)
        rel = self.sim(expansions, [doc]).flatten() if doc else np.ones(len(expansions))
        selected, indices = [], []
        for _ in range(n):
            best_idx, best_score = -1, -float('inf')
            for i in range(len(expansions)):
                if i in indices:
                    continue
                div = 1 - max(cosine_similarity(embs[i:i+1], embs[indices]).flatten()) if indices else 1.0
                score = 0.5 * rel[i] + 0.5 * div
                if score > best_score:
                    best_score, best_idx = score, i
            if best_idx >= 0:
                selected.append(expansions[best_idx])
                indices.append(best_idx)
        return selected

    def unload(self):
        if self.model:
            del self.model
            self.model = None
            self._loaded = False
            torch.cuda.empty_cache()

In [None]:
%%writefile hqf_de/pipeline/__init__.py
from .expander import Expander
from .combiner import Combiner
__all__ = ["Expander", "Combiner"]

In [None]:
%%writefile hqf_de/pipeline/combiner.py
from typing import List, Dict, Any
from dataclasses import dataclass, field
import logging
from ..models.embeddings import Embedder

logger = logging.getLogger(__name__)
GENERIC = {"information", "details", "things", "stuff", "content", "topic", "subject", "matter", "example", "case", "way", "method", "people", "time", "place", "thing"}

@dataclass
class Combined:
    original: str
    semantic: List[str] = field(default_factory=list)
    queries: List[str] = field(default_factory=list)
    final: List[str] = field(default_factory=list)
    text: str = ""
    meta: Dict[str, Any] = field(default_factory=dict)

class Combiner:
    def __init__(self, embedder: Embedder = None, threshold: float = 0.85, max_exp: int = 10):
        self.embedder = embedder or Embedder()
        self.threshold = threshold
        self.max_exp = max_exp
        self._loaded = False

    def _load(self):
        if not self._loaded:
            self.embedder.load()
            self._loaded = True

    def _filter(self, expansions: List[str], doc: str = None) -> List[str]:
        out = []
        for e in expansions:
            e = e.strip()
            if not e or len(e.split()) < 3 or len(e.split()) > 50:
                continue
            words = e.lower().split()
            if sum(1 for w in words if w in GENERIC) / len(words) > 0.5:
                continue
            if doc and e.lower() in doc.lower():
                continue
            out.append(e)
        return out

    def _dedup(self, expansions: List[str], doc: str = None) -> List[str]:
        if len(expansions) <= 1:
            return expansions
        self._load()
        deduped, _ = self.embedder.dedup(expansions, self.threshold)
        if doc:
            deduped = self.embedder.dedup_vs_doc(doc, deduped, self.threshold)
        return deduped

    def combine(self, doc: str, semantic: List[str], queries: List[str]) -> Combined:
        self._load()
        sem = self._filter(semantic, doc)
        q = self._filter(queries, doc)
        all_exp = sem + q
        deduped = self._dedup(all_exp, doc)
        final = self.embedder.select(deduped, self.max_exp, doc) if len(deduped) > self.max_exp else deduped[:self.max_exp]
        text = f"{doc} {' '.join(final)}"
        return Combined(original=doc, semantic=sem, queries=q, final=final, text=text, meta={"n_sem": len(sem), "n_q": len(q), "n_final": len(final)})

    def unload(self):
        if self._loaded:
            self.embedder.unload()
            self._loaded = False

In [None]:
%%writefile hqf_de/pipeline/expander.py
from typing import List, Dict, Any
from dataclasses import dataclass, field
import logging
from ..models.llm import LLM
from ..models.nli import NLI
from ..models.doc2query import Doc2Query
from ..models.embeddings import Embedder
from .combiner import Combiner
from ..config import config

logger = logging.getLogger(__name__)

@dataclass
class Result:
    doc_id: str
    original: str
    expanded: str
    gaps: List[str] = field(default_factory=list)
    raw: List[str] = field(default_factory=list)
    valid: List[str] = field(default_factory=list)
    rejected: List[str] = field(default_factory=list)
    queries: List[str] = field(default_factory=list)
    final: List[str] = field(default_factory=list)
    meta: Dict[str, Any] = field(default_factory=dict)

class Expander:
    def __init__(self, use_llm: bool = True, use_nli: bool = True, use_d2q: bool = True, device: str = None):
        self.device = device or config.device
        self.llm = LLM(device=self.device) if use_llm else None
        self.nli = NLI(device=self.device) if use_nli else None
        self.d2q = Doc2Query(device=self.device) if use_d2q else None
        self.combiner = Combiner(Embedder(device=self.device))
        self._loaded = False

    def load(self):
        if self._loaded:
            return self
        if self.llm:
            self.llm.load()
        if self.nli:
            self.nli.load()
        if self.d2q:
            self.d2q.load()
        self._loaded = True
        return self

    def expand(self, doc_id: str, doc: str) -> Result:
        if not self._loaded:
            self.load()
        result = Result(doc_id=doc_id, original=doc, expanded=doc)
        raw = []
        if self.llm:
            try:
                exp = self.llm.run(doc)
                result.gaps = exp.gaps
                raw = exp.expansions
                result.raw = raw
            except Exception as e:
                logger.error(f"LLM error: {e}")
        valid = raw
        if self.nli and raw:
            try:
                valid, _ = self.nli.validate(doc, raw)
                result.valid = valid
                result.rejected = [e for e in raw if e not in valid]
            except Exception as e:
                logger.error(f"NLI error: {e}")
        queries = []
        if self.d2q:
            try:
                queries = self.d2q.gen(doc)
                result.queries = queries
            except Exception as e:
                logger.error(f"D2Q error: {e}")
        try:
            combined = self.combiner.combine(doc, valid, queries)
            result.final = combined.final
            result.expanded = combined.text
            result.meta = combined.meta
        except Exception as e:
            logger.error(f"Combiner error: {e}")
            all_exp = valid + queries
            result.final = all_exp[:10]
            result.expanded = f"{doc} {' '.join(all_exp[:10])}"
        return result

    def d2q_only(self, doc_id: str, doc: str) -> Result:
        if self.d2q and not self.d2q._loaded:
            self.d2q.load()
        queries = self.d2q.gen(doc) if self.d2q else []
        return Result(doc_id=doc_id, original=doc, expanded=f"{doc} {' '.join(queries)}", queries=queries, final=queries, meta={"method": "d2q_only"})

    def unload(self):
        if self.llm:
            self.llm.unload()
        if self.nli:
            self.nli.unload()
        if self.d2q:
            self.d2q.unload()
        self.combiner.unload()
        self._loaded = False

    def __enter__(self):
        return self.load()

    def __exit__(self, *args):
        self.unload()

In [None]:
%%writefile hqf_de/evaluation/__init__.py
# Evaluation module

In [None]:
# Create data and output directories
!mkdir -p /content/data /content/output /content/cache

## Run HQF-DE Demo

In [None]:
import logging
logging.basicConfig(level=logging.INFO)

from hqf_de.pipeline.expander import Expander

# Initialize expander
exp = Expander()
exp.load()

In [None]:
# Run on a sample document
doc = "The Eiffel Tower is a famous landmark in Paris, France. It was built in 1889."

result = exp.expand("demo", doc)

print("=" * 60)
print("ORIGINAL:")
print(result.original)
print()
print("SEMANTIC GAPS:")
for g in result.gaps:
    print(f"  - {g}")
print()
print("VALID EXPANSIONS (passed NLI):")
for e in result.valid:
    print(f"  + {e}")
print()
print("REJECTED (failed NLI):")
for e in result.rejected:
    print(f"  x {e}")
print()
print("SYNTHETIC QUERIES:")
for q in result.queries:
    print(f"  ? {q}")
print()
print("=" * 60)
print("FINAL EXPANDED:")
print(result.expanded)

In [None]:
# Clean up
exp.unload()

## Batch Processing

In [None]:
# Upload collection.tsv from local machine
from google.colab import files
uploaded = files.upload()  # Select your collection.tsv file

# Move to data directory
import shutil
for filename in uploaded.keys():
    shutil.move(filename, f'/content/data/{filename}')
    print(f"Moved {filename} to /content/data/")

In [None]:
# Batch expand documents and save to TSV
import csv
from tqdm import tqdm

def load_docs(path, limit=None):
    docs = []
    with open(path, 'r', encoding='utf-8') as f:
        for row in csv.reader(f, delimiter='\t'):
            if len(row) >= 2:
                docs.append((row[0], row[1]))
                if limit and len(docs) >= limit:
                    break
    return docs

def save_expanded(results, path):
    with open(path, 'w', encoding='utf-8', newline='') as f:
        w = csv.writer(f, delimiter='\t')
        for r in results:
            w.writerow([r.doc_id, r.expanded])
    print(f"Saved {len(results)} documents to {path}")

# Load documents
LIMIT = 1000  # Adjust as needed
docs = load_docs('/content/data/collection.tsv', limit=LIMIT)
print(f"Loaded {len(docs)} documents")

# Expand with full HQF-DE pipeline
exp = Expander()
exp.load()

results = []
for doc_id, text in tqdm(docs, desc="Expanding"):
    result = exp.expand(doc_id, text)
    results.append(result)

exp.unload()

# Save results
save_expanded(results, '/content/output/expanded_hqfde.tsv')

# Also save d2q-only baseline for comparison
exp_d2q = Expander(use_llm=False, use_nli=False, use_d2q=True)
exp_d2q.load()

results_d2q = []
for doc_id, text in tqdm(docs, desc="D2Q only"):
    result = exp_d2q.d2q_only(doc_id, text)
    results_d2q.append(result)

exp_d2q.unload()

save_expanded(results_d2q, '/content/output/expanded_d2q.tsv')

In [None]:
# Download expanded files to local machine
from google.colab import files

files.download('/content/output/expanded_hqfde.tsv')
files.download('/content/output/expanded_d2q.tsv')

print("Download complete! Use these files locally for indexing and evaluation.")

In [None]:
# Optional: Save to Google Drive for persistence
!cp /content/output/expanded_hqfde.tsv /content/drive/MyDrive/
!cp /content/output/expanded_d2q.tsv /content/drive/MyDrive/
print("Saved to Google Drive!")