## Dense Retrieval
Implementation of dense passage retrieval using a SOTA sentence embedding model and FAISS for retrieval. Evaluated on the MS MARCO dataset using MRR and retrieval time.

In [12]:
!pip install -r requirements.txt -q # upload requirements.txt then uncomment for colab

from sentence_transformers import SentenceTransformer
import torch
import faiss
import time
import pandas as pd
import numpy as np

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.6/53.6 MB[0m [31m41.7 MB/s[0m eta [36m0:00:00[0m
[?25h

## Load MS MARCO Dataset + Data Preprocessing

In [13]:
from datasets import load_dataset

# load MS MARCO dataset (use 10% of the dataset for testing)
docs_dataset = load_dataset("sentence-transformers/msmarco", "corpus", split="train[:10%]")
queries_dataset = load_dataset("sentence-transformers/msmarco", "queries", split="train[:10%]")
qrels_dataset = load_dataset("sentence-transformers/msmarco", "labeled-list", split="train[:10%]")

print("corpus labels:", docs_dataset.column_names)
print("queries labels:", queries_dataset.column_names)
print("qrels labels:", qrels_dataset.column_names)

print("corpus:", docs_dataset[0])
print("queries:", queries_dataset[0])
print("qrels:", qrels_dataset[0])


corpus labels: ['passage_id', 'passage']
queries labels: ['query_id', 'query']
qrels labels: ['query_id', 'doc_ids', 'labels']
corpus: {'passage_id': '0', 'passage': 'The presence of communication amid scientific minds was equally important to the success of the Manhattan Project as scientific intellect was. The only cloud hanging over the impressive achievement of the atomic researchers and engineers is what their success truly meant; hundreds of thousands of innocent lives obliterated.'}
queries: {'query_id': '121352', 'query': 'define extreme'}
qrels: {'query_id': '100', 'doc_ids': ['3837260', '7854412', '4778006', '7929416', '5833477', '2715823', '903728', '1418399', '2544108', '4592808', '3565885', '260356', '5885724', '2976754', '3530456', '903722', '5136237', '6166367', '5372728', '6166373', '1615726', '5909725', '3278290', '570067', '2628703', '3619930', '3282101', '570061', '1442855', '5293099', '3976606', '3542912', '4358422', '4729309', '3542156', '102825', '2141701', '58857

## Data Preprocessing

In [14]:
from collections import defaultdict

# convert to list for faster processing
docs = list(docs_dataset)
queries = list(queries_dataset)
qrels = list(qrels_dataset)

#  Normalize helper
def clean(text):
    return text.strip().lower()

# extract document and query IDs + texts
doc_ids, doc_texts = zip(*[(d["passage_id"], clean(d["passage"])) for d in docs])
query_ids, query_texts = zip(*[(q["query_id"], clean(q["query"])) for q in queries])

# prepare qrels dict for ranx
doc_id_set = set(doc_ids)
query_id_set = set(query_ids)

qrels_dict = defaultdict(dict)
for row in qrels:
    qid = row["query_id"]
    for doc_id, label in zip(row["doc_ids"], row["labels"]):
        if qid in query_id_set and doc_id in doc_id_set and label > 0:
            qrels_dict[qid][doc_id] = label

## Document Embedding

In [15]:
# load encoder model pretrained on MS MARCO
model_name = "msmarco-distilbert-base-v2"
model = SentenceTransformer(model_name)

# use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 32 if device == "cuda" else 16
model.eval() # put in eval mode to speed up inference

# encode documents and queries
document_embeddings = model.encode(doc_texts, batch_size=batch_size, convert_to_tensor=True, device=device, show_progress_bar=True)
query_embeddings = model.encode(query_texts, batch_size=batch_size, convert_to_tensor=True, device=device, show_progress_bar=True)

# convert to numpy array for FAISS
doc_embs = document_embeddings.cpu().numpy()
query_embs = query_embeddings.cpu().numpy()

Batches:   0%|          | 0/27631 [00:00<?, ?it/s]

Batches:   0%|          | 0/2528 [00:00<?, ?it/s]

## Build FAISS Index

In [16]:
# normalize embeddings with faiss
faiss.normalize_L2(doc_embs)
faiss.normalize_L2(query_embs)

# build index
dim = doc_embs.shape[1]
index = faiss.IndexFlatL2(dim)
index.add(doc_embs)


## Evaluation

In [17]:
from ranx import Qrels, Run, evaluate

# do KNN Search and return retrieval time
k = 5
start_time = time.time()
distances, knn = index.search(query_embs, k)
retrieval_time = (time.time() - start_time) / len(query_ids)
print(f"Retrieval time per query: {retrieval_time:.4f} seconds")


Retrieval time per query: 0.0121 seconds


In [18]:
# calculate MRR

# Run: stores the relevance scores estimated by the model under evaluation
# map each query_id -> { doc_id: score }
run = {
    query_ids[i]: {
        doc_ids[knn[i][j]]: -float(distances[i][j]) for j in range(k) # use negative distance as ranx interprets higher score = higher rank
    }
    for i in range(len(query_ids))
}

qrels_rx = Qrels(qrels_dict)
run_rx = Run(run)

# # measure MRR
mrr = evaluate(qrels_rx, run_rx, "mrr", make_comparable=True)
print(f"MRR: {mrr:.4f}")


MRR: 0.4686


In [20]:
# save run results to use in dense + sparse retrieval hybrid retriever
# run_rx.save("../dense_sparse_retrieval/dense_results.json")
run_rx.save("dense_results.json")