## Sparse Retrieval
This notebook implements sparse passage retrieval with TF-IDF and BM25. Both retrievers are evaluated on a 10% subset of the MS MARCO dataset, reporting Mean Reciprocal Rank (MRR) and the average retrieval time per query.

In [1]:
from datasets import load_dataset
from sklearn.feature_extraction.text import TfidfVectorizer
from rank_bm25 import BM25Okapi
import numpy as np
import time
from ranx import Run, Qrels, evaluate

## Load Dataset

In [2]:
# Load the three MS MARCO configurations (passages, queries, relevance lists).
# Only the first 10% of each "train" split is requested to keep experiments small.
MSMARCO_REPO = "sentence-transformers/msmarco"
SUBSET_SPLIT = "train[:10%]"

docs_dataset = load_dataset(MSMARCO_REPO, "corpus", split=SUBSET_SPLIT)
queries_dataset = load_dataset(MSMARCO_REPO, "queries", split=SUBSET_SPLIT)
qrels_dataset = load_dataset(MSMARCO_REPO, "labeled-list", split=SUBSET_SPLIT)

# Column names of every split first ...
for heading, split_data in (
    ("corpus labels:", docs_dataset),
    ("queries labels:", queries_dataset),
    ("qrels labels:", qrels_dataset),
):
    print(heading, split_data.column_names)

# ... then the first record of each, as a quick look at the schema.
for heading, split_data in (
    ("corpus:", docs_dataset),
    ("queries:", queries_dataset),
    ("qrels:", qrels_dataset),
):
    print(heading, split_data[0])


corpus labels: ['passage_id', 'passage']
queries labels: ['query_id', 'query']
qrels labels: ['query_id', 'doc_ids', 'labels']
corpus: {'passage_id': '0', 'passage': 'The presence of communication amid scientific minds was equally important to the success of the Manhattan Project as scientific intellect was. The only cloud hanging over the impressive achievement of the atomic researchers and engineers is what their success truly meant; hundreds of thousands of innocent lives obliterated.'}
queries: {'query_id': '121352', 'query': 'define extreme'}
qrels: {'query_id': '100', 'doc_ids': ['3837260', '7854412', '4778006', '7929416', '5833477', '2715823', '903728', '1418399', '2544108', '4592808', '3565885', '260356', '5885724', '2976754', '3530456', '903722', '5136237', '6166367', '5372728', '6166373', '1615726', '5909725', '3278290', '570067', '2628703', '3619930', '3282101', '570061', '1442855', '5293099', '3976606', '3542912', '4358422', '4729309', '3542156', '102825', '2141701', '58857

## Data Preprocessing

In [3]:
from collections import defaultdict

# Materialize every split as an in-memory Python list so the passes below
# iterate over plain rows instead of the dataset objects.
docs, queries, qrels = (
    list(split) for split in (docs_dataset, queries_dataset, qrels_dataset)
)

# Text normalization shared by passages and queries
def clean(text):
    """Return ``text`` with surrounding whitespace removed and lower-cased."""
    trimmed = text.strip()
    return trimmed.lower()

# Parallel tuples of passage ids and normalized passage texts.
doc_ids, doc_texts = zip(
    *((row["passage_id"], clean(row["passage"])) for row in docs)
)

# Parallel tuples of query ids and normalized query texts.
query_ids, query_texts = zip(
    *((row["query_id"], clean(row["query"])) for row in queries)
)

# Relevance judgments in the nested {query_id: {doc_id: label}} layout handed to ranx.
# Only pairs whose query AND passage are both in the loaded subset are kept,
# and only positive labels count as relevant.
doc_id_set = set(doc_ids)
query_id_set = set(query_ids)

qrels_dict = defaultdict(dict)
for row in qrels:
    qid = row["query_id"]
    if qid not in query_id_set:
        # Query is outside the loaded subset: none of its judgments are usable.
        continue
    for doc_id, label in zip(row["doc_ids"], row["labels"]):
        if doc_id in doc_id_set and label > 0:
            qrels_dict[qid][doc_id] = label

## Build Retriever

In [4]:
from sklearn.preprocessing import normalize

# Build indices

# Whitespace tokens for BM25 (queries are tokenized the same way at search time).
tokenized_passages = [p.lower().split() for p in doc_texts]

# TF-IDF over unigrams and bigrams with English stop words removed.
# TfidfVectorizer L2-normalizes every row by default (norm="l2"), so the dot
# product of a query row with a document row is already their cosine similarity;
# no extra normalization step is needed.
tfidf = TfidfVectorizer(lowercase=True, stop_words='english', ngram_range=(1, 2))
tfidf_vectors = tfidf.fit_transform(doc_texts)

# The document-term matrix is deliberately kept sparse: a dense
# (n_docs x vocabulary) float32 array for a unigram+bigram vocabulary would not
# fit in memory, and a sparse matrix product yields the same cosine scores.

# BM25
bm25 = BM25Okapi(tokenized_passages)


In [5]:
# Retrieval
def run_sparse_retrieval(method="tfidf", top_k=10):
    """Rank the corpus for every query with TF-IDF or BM25 and return a ranx Run.

    Reads the notebook-level objects built in earlier cells: ``query_ids`` and
    ``query_texts``, ``doc_ids``, the fitted ``tfidf`` vectorizer together with
    its sparse document matrix ``tfidf_vectors``, and the ``bm25`` index.

    Args:
        method: "tfidf" or "bm25".
        top_k: number of highest-scoring passages kept per query (at least 1).

    Returns:
        A ``Run`` mapping each query id to ``{doc_id: score}`` for its
        ``top_k`` best passages.

    Raises:
        ValueError: if ``method`` is not recognised or ``top_k`` is below 1.
    """
    # Validate once, before any work is done or timed.
    if method not in ("tfidf", "bm25"):
        raise ValueError("Unknown method")
    if top_k < 1:
        # A slice such as [-0:] would silently return the whole corpus.
        raise ValueError("top_k must be at least 1")

    run_dict = {}
    start_time = time.perf_counter()

    # loop through all queries
    for query_id, query in zip(query_ids, query_texts):
        if method == "tfidf":
            query_vec = tfidf.transform([query])
            # Document and query rows are L2-normalized by the vectorizer
            # (default norm="l2"), so this sparse product is the cosine
            # similarity of the query with every passage. The (n_docs x 1)
            # result is flattened to one score per passage.
            scores = (tfidf_vectors @ query_vec.T).toarray().ravel()
        else:
            scores = bm25.get_scores(query.lower().split())

        # Both branches yield a 1-D score array aligned with doc_ids:
        # take the top_k largest scores, best first.
        top_indices = np.argsort(scores)[-top_k:][::-1]
        run_dict[query_id] = {doc_ids[idx]: float(scores[idx]) for idx in top_indices}

    elapsed = time.perf_counter() - start_time
    # Guard against an empty query set instead of dividing by zero.
    avg_time = elapsed / len(query_ids) if query_ids else 0.0
    print(f"{method.upper()} retrieval time per query: {avg_time:.4f} seconds")
    return Run(run_dict)

## Evaluation

In [None]:
# Score both retrievers against the relevance judgments with MRR.
qrels_rx = Qrels(qrels_dict)

tfidf_run = run_sparse_retrieval("tfidf")
bm25_run = run_sparse_retrieval("bm25")

# NOTE(review): make_comparable=True is understood to let ranx align the query
# sets of qrels and run before scoring - confirm against the ranx documentation.
tfidf_mrr, bm25_mrr = (
    evaluate(qrels_rx, candidate_run, "mrr", make_comparable=True)
    for candidate_run in (tfidf_run, bm25_run)
)

print(f"TF-IDF MRR: {tfidf_mrr:.4f}")
print(f"BM25 MRR:  {bm25_mrr:.4f}")

In [None]:
# Persist both runs to JSON files in the working directory.
for finished_run, out_path in (
    (tfidf_run, "tfidf_run.json"),
    (bm25_run, "bm25_run.json"),
):
    finished_run.save(out_path)