# Semantic Retrieval for Scientific Documents

This notebook implements a deep learning-based semantic retrieval system, trained and evaluated on the SciFact dataset.

## Project Overview
- Fine-tune embedding models using sentence-transformers
- Train with the MultipleNegativesRankingLoss objective (in-batch negatives)
- Train and evaluate on BEIR-format datasets
- Compare with traditional methods like BM25

## 1. Environment Setup and Dependencies

In [None]:
# Install dependencies (comment out the line below if they are already installed)
!pip install sentence-transformers datasets pandas scikit-learn torch accelerate beir

In [None]:
import logging
import json
import math
from pathlib import Path
from collections import defaultdict
from typing import Dict, Set, List, Tuple, Optional, Any, Union
import pandas as pd
from sklearn.model_selection import train_test_split

# Setup logging
# Configure the root logger so each record is printed as
# "<timestamp> - <LEVEL> - <message>"; records below INFO are not emitted.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
# Module-level logger used by the functions defined in the cells below.
logger = logging.getLogger(__name__)

## 2. Data Loading and Preprocessing

We use the BeIR/scifact-generated-queries dataset, a scientific literature retrieval dataset.

In [None]:
def load_scifact_raw(output_dir: str = "./data/raw"):
    """
    Fetch the BeIR/scifact-generated-queries dataset and keep a CSV copy.

    Args:
        output_dir: Directory that receives ``scifact_raw.csv``; it is
            created (including parents) when it does not exist yet.

    Returns:
        The ``train`` split of the dataset converted to a pandas DataFrame.
    """
    from datasets import load_dataset

    # Make sure the destination exists before the download starts
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    # Download the dataset and convert its train split in one go
    logger.info("Loading SciFact dataset from HuggingFace...")
    frame = load_dataset("BeIR/scifact-generated-queries")["train"].to_pandas()  # type: ignore

    # Persist an untouched copy next to the processed data
    csv_path = target_dir / "scifact_raw.csv"
    frame.to_csv(csv_path, index=False)  # type: ignore
    logger.info(f"Raw data saved to {csv_path}")
    logger.info(f"Dataset size: {len(frame)} rows")  # type: ignore

    return frame  # type: ignore

In [None]:
def preprocess_scifact(raw_df: pd.DataFrame, output_dir: str = "./data/processed") -> Path:
    """
    Preprocess SciFact dataset and generate BEIR-format training data

    Rows whose text or query is missing or blank (after stripping) are
    dropped, as are repeated (doc_id, query) pairs. The title is merged into
    the document content; documents without a title keep just their text
    instead of getting a stray leading ". ".

    Args:
        raw_df: Raw data with the columns ``_id``, ``title``, ``text`` and
            ``query``.
        output_dir: Directory that receives all output files (created if
            missing).

    Returns:
        Path of the ``beir_format`` directory inside ``output_dir``.

    Raises:
        ValueError: If no usable query-document pair remains after cleaning.

    Output files:
    - scifact_pairs.csv: Query-document pairs
    - scifact_corpus.csv: Deduplicated document corpus
    - beir_format/corpus.jsonl: BEIR-format corpus
    - beir_format/queries.jsonl: BEIR-format queries
    - beir_format/qrels/train.tsv: Training set relevance labels (70%)
    - beir_format/qrels/dev.tsv: Validation set relevance labels (10%)
    - beir_format/qrels/test.tsv: Test set relevance labels (20%)
    """
    processed_dir = Path(output_dir)
    processed_dir.mkdir(parents=True, exist_ok=True)

    # Process data
    df = raw_df[["_id", "title", "text", "query"]].copy()
    df = df.dropna(subset=["text", "query"])  # type: ignore
    df = df.rename(columns={"_id": "doc_id"})

    # Data cleaning
    df["title"] = df["title"].fillna("").astype(str).str.strip()
    df["text"] = df["text"].astype(str).str.strip()
    df["query"] = df["query"].astype(str).str.strip()

    # dropna() only removes NaN; whitespace-only values become "" after the
    # strip above and would otherwise yield empty queries/documents.
    df = df[(df["text"] != "") & (df["query"] != "")]
    # The same (document, query) pair carries no extra information when repeated.
    df = df.drop_duplicates(subset=["doc_id", "query"]).reset_index(drop=True)  # type: ignore

    if df.empty:
        raise ValueError("No usable query-document pairs left after cleaning")

    # Combine title and text; keep the bare text where there is no title so
    # the content does not start with ". "
    df["content"] = df["text"].where(
        df["title"] == "", df["title"] + ". " + df["text"]
    )

    # Assign ID to each unique query
    df["query_id"] = pd.factorize(df["query"])[0]

    # Save query-document pairs
    pairs_path = processed_dir / "scifact_pairs.csv"
    df.to_csv(pairs_path, index=False)
    logger.info(f"Query-doc pairs saved to {pairs_path}")

    # Save deduplicated corpus
    corpus = (
        df[["doc_id", "content"]]
        .drop_duplicates(subset=["doc_id"])  # type: ignore
        .reset_index(drop=True)
    )
    corpus_path = processed_dir / "scifact_corpus.csv"
    corpus.to_csv(corpus_path, index=False)
    logger.info(f"Corpus saved to {corpus_path} ({len(corpus)} documents)")

    # Create BEIR-format data
    beir_dir = processed_dir / "beir_format"
    beir_dir.mkdir(parents=True, exist_ok=True)
    qrels_dir = beir_dir / "qrels"
    qrels_dir.mkdir(parents=True, exist_ok=True)

    # Write corpus.jsonl ("title" stays empty because the title is already
    # part of "text" via the combined content)
    corpus_jsonl_path = beir_dir / "corpus.jsonl"
    with open(corpus_jsonl_path, "w", encoding="utf-8") as f:
        for doc_id, content in zip(corpus["doc_id"], corpus["content"]):
            doc = {"_id": str(doc_id), "text": content, "title": ""}
            f.write(json.dumps(doc, ensure_ascii=False) + "\n")
    logger.info(f"BEIR corpus saved to {corpus_jsonl_path}")

    # Write queries.jsonl
    queries_jsonl_path = beir_dir / "queries.jsonl"
    unique_queries = df[["query_id", "query"]].drop_duplicates(subset=["query_id"])  # type: ignore
    with open(queries_jsonl_path, "w", encoding="utf-8") as f:
        for query_id, query_text in zip(unique_queries["query_id"], unique_queries["query"]):
            query = {"_id": str(query_id), "text": query_text}
            f.write(json.dumps(query, ensure_ascii=False) + "\n")
    logger.info(f"BEIR queries saved to {queries_jsonl_path}")

    # Split into train/dev/test sets (70/10/20). The split is done on query
    # ids, so all pairs of one query land in the same split.
    unique_query_ids = df["query_id"].unique()
    train_dev_query_ids, test_query_ids = train_test_split(
        unique_query_ids, test_size=0.2, random_state=42
    )
    # 0.125 of the remaining 80% equals 10% of all queries
    train_query_ids, dev_query_ids = train_test_split(
        train_dev_query_ids, test_size=0.125, random_state=42
    )

    train_df = df[df["query_id"].isin(train_query_ids)]
    dev_df = df[df["query_id"].isin(dev_query_ids)]
    test_df = df[df["query_id"].isin(test_query_ids)]

    # Write qrels files (headerless TSV: query_id, doc_id, relevance=1)
    for split_name, split_df in [
        ("train", train_df),
        ("dev", dev_df),
        ("test", test_df),
    ]:
        qrels_path = qrels_dir / f"{split_name}.tsv"
        with open(qrels_path, "w", encoding="utf-8") as f:
            f.writelines(
                f"{query_id}\t{doc_id}\t1\n"
                for query_id, doc_id in zip(split_df["query_id"], split_df["doc_id"])
            )
        logger.info(f"{split_name} qrels saved to {qrels_path} ({len(split_df)} pairs)")

    logger.info("\nData preprocessing completed!")
    logger.info(f"Train: {len(train_query_ids)} queries, {len(train_df)} pairs")
    logger.info(f"Dev: {len(dev_query_ids)} queries, {len(dev_df)} pairs")
    logger.info(f"Test: {len(test_query_ids)} queries, {len(test_df)} pairs")

    return beir_dir

## 3. Data Loading Utilities

In [None]:
def load_qrels(qrels_path: Path) -> Dict[str, Set[str]]:
    """Load qrels file and return query_id -> {doc_ids} mapping

    The file is read as tab-separated ``query_id<TAB>doc_id[<TAB>score]``
    lines. Lines with fewer than two columns are skipped, as is a BEIR-style
    ``query-id / corpus-id`` header line. When a third column holds a numeric
    score, only pairs with a positive score are kept; pairs without a score
    (or with a non-numeric one) are treated as relevant.

    Args:
        qrels_path: Path of the qrels TSV file.

    Returns:
        Plain dict mapping each query id to the set of its relevant doc ids.
    """
    query_to_docs: Dict[str, Set[str]] = defaultdict(set)

    with open(qrels_path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split("\t")
            if len(parts) < 2:
                continue

            query_id, doc_id = parts[0], parts[1]

            # Standard BEIR qrels start with a header row; it is not a judgment
            if (query_id, doc_id) == ("query-id", "corpus-id"):
                continue

            # A score of 0 (or below) marks the document as NOT relevant
            if len(parts) >= 3:
                try:
                    if float(parts[2]) <= 0:
                        continue
                except ValueError:
                    pass  # unparseable score: keep the pair

            query_to_docs[query_id].add(doc_id)

    return dict(query_to_docs)


def load_queries(queries_path: Path) -> Dict[str, str]:
    """Load queries.jsonl file and return query_id -> query_text mapping

    Each non-blank line must be a JSON object with the keys ``_id`` and
    ``text``. Blank lines are skipped, and ids are converted to ``str`` so
    they match the string ids read from qrels files.

    Args:
        queries_path: Path of the queries JSONL file.

    Returns:
        Dict mapping query id to query text.
    """
    queries: Dict[str, str] = {}

    with open(queries_path, "r", encoding="utf-8") as f:
        for line in f:
            # json.loads() raises on empty input, e.g. a trailing blank line
            if not line.strip():
                continue
            record = json.loads(line)
            queries[str(record["_id"])] = record["text"]

    return queries


def load_corpus(corpus_path: Path) -> Dict[str, str]:
    """Load corpus.jsonl file and return doc_id -> doc_text mapping

    Each non-blank line must be a JSON object with the keys ``_id`` and
    ``text``; any other key (such as ``title``) is ignored. Blank lines are
    skipped, and ids are converted to ``str`` so they match the string ids
    read from qrels files.

    Args:
        corpus_path: Path of the corpus JSONL file.

    Returns:
        Dict mapping document id to document text.
    """
    corpus: Dict[str, str] = {}

    with open(corpus_path, "r", encoding="utf-8") as f:
        for line in f:
            # json.loads() raises on empty input, e.g. a trailing blank line
            if not line.strip():
                continue
            record = json.loads(line)
            corpus[str(record["_id"])] = record["text"]

    return corpus

## 4. Model Training

Fine-tune embedding models using sentence-transformers library with MultipleNegativesRankingLoss.

In [None]:
# Training configuration
# Module-level constants read by train_model().
# Name passed to SentenceTransformer(...) as the starting checkpoint
MODEL_NAME = "BAAI/bge-small-en-v1.5"  # Base model
# Batch size of the training DataLoader
BATCH_SIZE = 32
# Number of training epochs passed to model.fit()
EPOCHS = 6
# Fraction of the total training steps used as warmup steps
WARMUP_RATIO = 0.1
# Passed to the optimizer as "lr"
LEARNING_RATE = 3e-5
# Default output directory of train_model()
OUTPUT_DIR = "./models/finetuned-mnrl"

In [None]:
def load_training_data(data_dir: Path) -> list:
    """
    Build the (query, positive document) training pairs for the train split.

    Reads ``qrels/train.tsv``, ``queries.jsonl`` and ``corpus.jsonl`` from
    ``data_dir`` and wraps every judged pair in an ``InputExample``. Pairs
    whose query id or document id cannot be resolved are left out.

    MultipleNegativesRankingLoss needs positives only: the other examples of
    a batch act as negatives.

    Args:
        data_dir: Directory holding the BEIR-format files.

    Returns:
        List of ``InputExample`` objects with ``texts=[query, document]``.
    """
    from sentence_transformers import InputExample

    logger.info("Loading training data...")
    relevance = load_qrels(data_dir / "qrels" / "train.tsv")
    query_texts = load_queries(data_dir / "queries.jsonl")
    doc_texts = load_corpus(data_dir / "corpus.jsonl")

    # One example per judged pair, keeping only ids known to both lookups
    examples = [
        InputExample(texts=[query_texts[qid], doc_texts[did]])
        for qid, positives in relevance.items()
        if qid in query_texts
        for did in positives
        if did in doc_texts
    ]

    logger.info(f"Created {len(examples)} training examples")
    return examples


def create_evaluator(data_dir: Path, split: str = "dev"):
    """
    Build an InformationRetrievalEvaluator for one qrels split.

    Queries are limited to those judged in ``qrels/<split>.tsv`` that also
    have a text in ``queries.jsonl``; retrieval runs against the complete
    corpus.

    Args:
        data_dir: Directory holding the BEIR-format files.
        split: Name of the qrels split, also used as the evaluator name.

    Returns:
        The configured ``InformationRetrievalEvaluator``.
    """
    from sentence_transformers.evaluation import InformationRetrievalEvaluator

    logger.info(f"Loading {split} data for evaluation...")
    relevance = load_qrels(data_dir / "qrels" / f"{split}.tsv")
    all_queries = load_queries(data_dir / "queries.jsonl")
    documents = load_corpus(data_dir / "corpus.jsonl")

    # Collect query texts and relevance sets in a single pass
    eval_queries = {}
    eval_qrels: Dict[str, Set[str]] = {}
    for qid, positives in relevance.items():
        if qid not in all_queries:
            continue
        eval_queries[qid] = all_queries[qid]
        eval_qrels[qid] = set(positives)

    logger.info(f"Evaluator: {len(eval_queries)} queries, {len(documents)} documents")

    evaluator = InformationRetrievalEvaluator(
        queries=eval_queries,
        corpus=documents,
        relevant_docs=eval_qrels,
        name=split,
        ndcg_at_k=[10, 100],
        precision_recall_at_k=[10, 100],
        map_at_k=[100],
        mrr_at_k=[10],
        show_progress_bar=True,
    )
    return evaluator

In [None]:
def train_model(data_dir: Path, output_dir: str = OUTPUT_DIR):
    """
    Train the embedding model

    Fine-tunes ``MODEL_NAME`` on the train split with
    MultipleNegativesRankingLoss, validates on the dev split twice per epoch,
    then reports metrics on the test split.

    ``fit(save_best_model=True)`` writes the best-scoring checkpoint to
    ``output_dir`` but leaves the in-memory model at its final-step weights.
    The saved checkpoint is therefore reloaded after training, so the test
    metrics and the returned model correspond to what is stored on disk.

    Args:
        data_dir: Directory holding the BEIR-format files.
        output_dir: Directory that receives the trained model (created if
            missing).

    Returns:
        The trained ``SentenceTransformer`` (best saved checkpoint when one
        was written, otherwise the final in-memory model).

    Raises:
        ValueError: If the train split yields no training examples.
    """
    from sentence_transformers import SentenceTransformer, losses
    from torch.utils.data import DataLoader

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Load model
    logger.info(f"Loading model: {MODEL_NAME}")
    model = SentenceTransformer(MODEL_NAME)

    # Load training data
    train_examples = load_training_data(data_dir)
    if not train_examples:
        raise ValueError(f"No training examples found in {data_dir}")
    train_dataloader = DataLoader(
        train_examples,  # type: ignore
        shuffle=True,
        batch_size=BATCH_SIZE,
    )

    # Setup loss function
    train_loss = losses.MultipleNegativesRankingLoss(model)

    # Create validation evaluator
    dev_evaluator = create_evaluator(data_dir, split="dev")

    # Calculate training steps
    steps_per_epoch = len(train_dataloader)
    total_steps = steps_per_epoch * EPOCHS
    warmup_steps = int(total_steps * WARMUP_RATIO)

    logger.info("\nTraining configuration:")
    logger.info(f"  Model: {MODEL_NAME}")
    logger.info(f"  Batch size: {BATCH_SIZE}")
    logger.info(f"  Epochs: {EPOCHS}")
    logger.info(f"  Total steps: {total_steps}")
    logger.info(f"  Warmup steps: {warmup_steps}")
    logger.info(f"  Learning rate: {LEARNING_RATE}")
    logger.info(f"  Output directory: {output_path}")

    # Start training
    logger.info("\nStarting training...")
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        evaluator=dev_evaluator,
        epochs=EPOCHS,
        warmup_steps=warmup_steps,
        optimizer_params={"lr": LEARNING_RATE},
        output_path=str(output_path),
        evaluation_steps=steps_per_epoch // 2,  # Evaluate twice per epoch
        save_best_model=True,
        show_progress_bar=True,
    )

    logger.info(f"\nTraining completed! Model saved to: {output_path}")

    # Continue with the checkpoint that was actually saved (best dev score)
    # rather than the last-step weights still held in memory. modules.json is
    # written by SentenceTransformer.save(), so its presence signals a saved model.
    if (output_path / "modules.json").exists():
        model = SentenceTransformer(str(output_path))
        logger.info(f"Reloaded best checkpoint from: {output_path}")

    # Final evaluation on test set
    logger.info("\nRunning final evaluation on test set...")
    test_evaluator = create_evaluator(data_dir, split="test")
    test_results = test_evaluator(model, output_path=str(output_path))

    # Older sentence-transformers releases return a single float (the main
    # score) instead of a metric dict; normalise so the report never crashes.
    if not isinstance(test_results, dict):
        test_results = {"main_score": float(test_results)}

    logger.info("\n" + "=" * 50)
    logger.info("Final Test Results:")
    for metric, value in test_results.items():
        logger.info(f"  {metric}: {value:.4f}")
    logger.info("=" * 50)

    return model

## 5. Evaluation Metrics

Implement standard information retrieval evaluation metrics.

In [None]:
def traditional_eval(samples: List[Dict[str, Any]], k: int = 10) -> Dict[str, float]:
    """
    Evaluate retrieval results

    Every ranking is de-duplicated (first occurrence wins) before scoring, so
    a document returned twice cannot be counted as two hits and push recall,
    AP or nDCG above 1. A cutoff ``k <= 0`` is treated as an empty top-k:
    all cutoff-based metrics are 0.0 instead of raising ZeroDivisionError or
    slicing from the end of the ranking.

    Args:
        samples: List of samples containing queries and retrieval results
            [{
                "question": str,
                "contexts": List[str],  # Ranked top-K doc_ids
                "ground_truth": str | list | tuple | set | frozenset  # Ground truth doc_id(s)
            }]
        k: Cutoff position for evaluation

    Returns:
        Dictionary with the sample-averaged metrics ``Hit@k``,
        ``Precision@k``, ``Recall@k``, ``MRR`` (computed over the full
        ranking, not truncated at k), ``MAP@k``, ``NDCG@k`` and ``N`` (number
        of samples). With no samples every metric is 0.0.
    """

    def to_set(gt: Union[str, List[str], Set[str]]) -> Set[str]:
        # Any common collection is a set of ids; everything else is one id
        if isinstance(gt, (list, tuple, set, frozenset)):
            return {str(x) for x in gt}
        return {str(gt)}

    def reciprocal_rank(gt_set: Set[str], ranking: List[str]) -> float:
        for rank, doc_id in enumerate(ranking, start=1):
            if doc_id in gt_set:
                return 1.0 / rank
        return 0.0

    def average_precision(flags: List[bool], num_relevant: int, cutoff: int) -> float:
        # Normalise by the best achievable number of hits within the cutoff
        denominator = min(num_relevant, cutoff)
        if denominator <= 0:
            return 0.0
        found = 0
        total = 0.0
        for rank, is_relevant in enumerate(flags, start=1):
            if is_relevant:
                found += 1
                total += found / rank
        return total / denominator

    def ndcg(flags: List[bool], num_relevant: int, cutoff: int) -> float:
        # Binary relevance: gain 1 for relevant documents, 0 otherwise
        dcg = sum(
            1.0 / math.log2(rank + 1)
            for rank, is_relevant in enumerate(flags, start=1)
            if is_relevant
        )
        idcg = sum(
            1.0 / math.log2(rank + 1)
            for rank in range(1, min(num_relevant, cutoff) + 1)
        )
        return dcg / idcg if idcg > 0 else 0.0

    cutoff = max(k, 0)

    # Calculate metrics for all samples
    hits, precisions, recalls, mrrs, aps, ndcgs = [], [], [], [], [], []

    for sample in samples:
        gt_set = to_set(sample["ground_truth"])
        # dict.fromkeys() removes duplicates while preserving rank order
        ranking = list(dict.fromkeys(str(x) for x in sample["contexts"]))
        flags = [doc_id in gt_set for doc_id in ranking[:cutoff]]
        num_hits = sum(flags)

        hits.append(1.0 if num_hits else 0.0)
        precisions.append(num_hits / cutoff if cutoff else 0.0)
        recalls.append(num_hits / len(gt_set) if gt_set else 0.0)
        mrrs.append(reciprocal_rank(gt_set, ranking))
        aps.append(average_precision(flags, len(gt_set), cutoff))
        ndcgs.append(ndcg(flags, len(gt_set), cutoff))

    # Avoid dividing by zero when there are no samples
    n = len(samples) if samples else 1
    return {
        f"Hit@{k}": sum(hits) / n,
        f"Precision@{k}": sum(precisions) / n,
        f"Recall@{k}": sum(recalls) / n,
        "MRR": sum(mrrs) / n,
        f"MAP@{k}": sum(aps) / n,
        f"NDCG@{k}": sum(ndcgs) / n,
        "N": float(len(samples)),
    }

## 6. Complete Training and Evaluation Pipeline

Run the code below to execute the complete workflow.

In [None]:
# Step 1: Load raw data
# load_scifact_raw() returns the dataset as a DataFrame and also writes a
# CSV copy to its default output directory ("./data/raw").
logger.info("=" * 60)
logger.info("Step 1: Loading raw data")
logger.info("=" * 60)
raw_df = load_scifact_raw()

# Display sample data (first rows only, as a quick sanity check)
logger.info("\nSample data:")
print(raw_df.head())  # type: ignore

In [None]:
# Step 2: Preprocess data
# preprocess_scifact() writes the cleaned CSVs and BEIR-format files and
# returns the BEIR directory, which the later steps take as their data_dir.
logger.info("\n" + "=" * 60)
logger.info("Step 2: Preprocessing data")
logger.info("=" * 60)
beir_dir = preprocess_scifact(raw_df)  # type: ignore

In [None]:
# Step 3: Train model
# train_model() fine-tunes the base model, evaluates it on the test split
# and returns the trained model.
logger.info("\n" + "=" * 60)
logger.info("Step 3: Training model")
logger.info("=" * 60)

# Comment out the line below if you want to skip training
model = train_model(beir_dir)

In [None]:
# Step 4: Load a previously trained model (alternative to Step 3)
# Uncomment this section to reuse a saved model instead of training again
# from sentence_transformers import SentenceTransformer
# model = SentenceTransformer(OUTPUT_DIR)
# logger.info(f"Model loaded from {OUTPUT_DIR}")

## 7. Semantic Retrieval with Trained Model

Demonstration of how to use the trained model for semantic retrieval.

In [None]:
def semantic_search_demo(model, data_dir: Path, num_queries: int = 5, top_k: int = 10):
    """
    Demonstrate semantic search functionality

    Encodes the whole corpus once, then for the first ``num_queries`` test
    queries prints the ``top_k`` most similar documents (cosine similarity),
    marking each one as relevant or not according to the test qrels.

    Args:
        model: Model exposing ``encode(...)`` as used below (e.g. the
            SentenceTransformer returned by ``train_model``).
        data_dir: Directory holding the BEIR-format files.
        num_queries: Maximum number of test queries to show.
        top_k: Number of documents to display per query.

    Raises:
        ValueError: If ``top_k`` is not positive.
    """
    import numpy as np
    from sentence_transformers import util

    if top_k <= 0:
        raise ValueError("top_k must be a positive integer")

    # Load test data
    queries_path = data_dir / "queries.jsonl"
    corpus_path = data_dir / "corpus.jsonl"
    qrels_path = data_dir / "qrels" / "test.tsv"

    queries = load_queries(queries_path)
    corpus = load_corpus(corpus_path)
    query_to_docs = load_qrels(qrels_path)

    # Select test queries; ids without a query text are skipped instead of
    # raising KeyError later on
    test_queries = [qid for qid in query_to_docs if qid in queries][:max(num_queries, 0)]
    if not test_queries:
        logger.warning("No test queries available for the demo")
        return

    # Encode corpus
    logger.info("Encoding corpus...")
    corpus_ids = list(corpus.keys())
    corpus_texts = [corpus[doc_id] for doc_id in corpus_ids]
    corpus_embeddings = model.encode(
        corpus_texts, convert_to_tensor=True, show_progress_bar=True
    )

    # Report the number of queries actually run, which may be below num_queries
    logger.info(f"\nRunning semantic search on {len(test_queries)} test queries...")

    for query_id in test_queries:
        query_text = queries[query_id]
        ground_truth = query_to_docs[query_id]

        # Encode query
        query_embedding = model.encode(query_text, convert_to_tensor=True)

        # Compute similarity and retrieve (negated scores => descending order)
        cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
        top_results = np.argsort(-cos_scores.cpu().numpy())[:top_k]

        # Display results
        print("\n" + "=" * 80)
        print(f"Query: {query_text}")
        print(f"Ground truth docs: {ground_truth}")
        print(f"\nTop {len(top_results)} retrieved documents:")

        for rank, idx in enumerate(top_results, 1):
            doc_id = corpus_ids[idx]
            score = cos_scores[idx].item()
            is_relevant = "✓" if doc_id in ground_truth else "✗"
            doc_preview = corpus[doc_id][:100] + "..."
            print(f"\n{rank}. [{is_relevant}] Doc {doc_id} (score: {score:.4f})")
            print(f"   {doc_preview}")

In [None]:
# Run retrieval demo
# semantic_search_demo(model, beir_dir, num_queries=3)

## 8. Results Visualization

Visualize training results and evaluation metrics.

In [None]:
def plot_evaluation_results(results_csv_path: str):
    """
    Plot evaluation results

    Draws NDCG@10, MAP@100, Recall@10 and Precision@10 from an evaluation
    results CSV in a 2x2 grid. Column names are matched flexibly: the plain
    names (``ndcg_at_10`` ...) are accepted, as are headers of the form
    ``<score_function>-NDCG@10`` written by InformationRetrievalEvaluator.
    With several matching columns the first one in the file is used.

    Args:
        results_csv_path: Path of the results CSV. A missing file, or a file
            without any of the expected metrics, only logs a warning.
    """
    import matplotlib.pyplot as plt

    if not Path(results_csv_path).exists():
        logger.warning(f"Results file not found: {results_csv_path}")
        return

    df = pd.read_csv(results_csv_path)

    # (plain column name, suffix used in the evaluator's CSV headers)
    metrics = [
        ("ndcg_at_10", "NDCG@10"),
        ("map_at_100", "MAP@100"),
        ("recall_at_10", "Recall@10"),
        ("precision_at_10", "Precision@10"),
    ]

    def find_column(plain_name: str, suffix: str) -> Optional[str]:
        for column in df.columns:
            name = str(column)
            if name == plain_name or name == suffix or name.endswith("-" + suffix):
                return column
        return None

    resolved = [
        (position, plain_name, find_column(plain_name, suffix))
        for position, (plain_name, suffix) in enumerate(metrics, 1)
    ]
    if all(column is None for _, _, column in resolved):
        logger.warning(f"No known metric columns found in: {results_csv_path}")
        return

    # Fall back to the row order when the file has no epoch column
    has_epoch = "epoch" in df.columns
    x_values = df["epoch"] if has_epoch else df.index
    x_label = "Epoch" if has_epoch else "Evaluation"

    # Plot main metrics
    plt.figure(figsize=(12, 8))

    for position, plain_name, column in resolved:
        if column is None:
            continue
        plt.subplot(2, 2, position)
        plt.plot(x_values, df[column], marker="o")
        if column == plain_name:
            plt.title(plain_name.replace("_", " ").title())
        else:
            plt.title(str(column))
        plt.xlabel(x_label)
        plt.ylabel("Score")
        plt.grid(True)

    plt.tight_layout()
    plt.show()

In [None]:
# Visualize results (if training logs exist)
# results_path = f"{OUTPUT_DIR}/eval/Information-Retrieval_evaluation_dev_results.csv"
# plot_evaluation_results(results_path)

## 9. Summary

This notebook implements a complete semantic retrieval system, including:

1. **Data Loading and Preprocessing**: Load SciFact dataset from HuggingFace and convert to BEIR format
2. **Model Training**: Fine-tune sentence-transformers model using MultipleNegativesRankingLoss
3. **Evaluation**: Evaluate model performance on test set with multiple retrieval metrics
4. **Inference**: Perform semantic retrieval using the trained model

### Custom Configuration

You can modify the following parameters to customize training:

- `MODEL_NAME`: Base model (default: "BAAI/bge-small-en-v1.5")
- `BATCH_SIZE`: Batch size (default: 32)
- `EPOCHS`: Number of epochs (default: 6)
- `LEARNING_RATE`: Learning rate (default: 3e-5)
- `WARMUP_RATIO`: Fraction of the total training steps used for warmup (default: 0.1)
- `OUTPUT_DIR`: Model save path