In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torchvision
# Sanity-check the installed versions and GPU visibility before doing any work.
print(torch.__version__)       # should be 2.8.0
print(torchvision.__version__)  # should be 0.23.0 (the build installed with torch 2.8.0 here, not 0.15.2)
print(torch.cuda.is_available())  # should be True if GPU available

2.8.0+cu128
0.23.0+cu128
True


The full `triples.train` file from the MS MARCO website is very large, so we sample a subset of its rows to use for training and testing.

This code samples some data for training

import pandas as pd

# Input chunk of the triples file and destination of the sampled subset.
DATA_PATH = "/teamspace/studios/this_studio/triples_chunk1000.tsv"
OUTPUT_PATH = "/teamspace/studios/this_studio/triples_sample_10k.tsv"

# Number of triples to keep. The original code sampled only 100 rows although
# both the comment and the output file name ("..._10k.tsv") promise 10,000.
SAMPLE_SIZE = 10_000

# Load the full TSV (read as headerless; the three columns are named here)
df = pd.read_csv(DATA_PATH, sep="\t", names=["query", "positive", "negative"])

# Sample SAMPLE_SIZE rows randomly (or every row if the file has fewer);
# the fixed random_state makes the selection reproducible.
sample_df = df.sample(n=min(SAMPLE_SIZE, len(df)), random_state=42).reset_index(drop=True)

# Save to a new TSV
# NOTE(review): to_csv writes a header row by default, whereas the input above is
# read as headerless. A reader that passes names=[...] will see the header as data.
sample_df.to_csv(OUTPUT_PATH, sep="\t", index=False)

print(f"Sampled {len(sample_df)} rows and saved to {OUTPUT_PATH}")

Viewing the training data

In [2]:
import pandas as pd
import os
from download import download_with_progress, extract_tar_with_progress
from load_corpus import read_triples_train_small

# Archives to fetch: each entry gives the remote URL and the local file name it is saved under.
files = [
    {
        "url": "https://msmarco.z22.web.core.windows.net/msmarcoranking/triples.train.small.tar.gz",
        "name": "triples.train.small.tar.gz"
    },
]

for file in files:
    # The archive name with ".tar.gz" replaced by ".tsv" acts as the "already extracted" marker:
    # if that file exists in the working directory, download and extraction are skipped.
    tsv_name = file["name"].replace(".tar.gz", ".tsv")
    if not os.path.exists(tsv_name):
        # download with progress
        download_with_progress(file["url"], file["name"])
        # extract with progress
        extract_tar_with_progress(file["name"], ".")
        # (optional) delete the archive to save space
        # os.remove(file["name"])

# NOTE: this message is printed unconditionally, i.e. also when the download was skipped.
print("Files downloaded and extracted.")
# Load the TSV into a DataFrame (note: use the extracted .tsv file)
# NOTE(review): read_triples_train_small is a project helper not shown here; presumably
# `limit` caps the number of rows read -- confirm against load_corpus.
train_df = read_triples_train_small(limit=10000)
train_df.head()

Files downloaded and extracted.


Unnamed: 0,query,positive,negative
0,is a little caffeine ok during pregnancy,We donât know a lot about the effects of caf...,It is generally safe for pregnant women to eat...
1,what fruit is native to australia,Passiflora herbertiana. A rare passion fruit n...,"The kola nut is the fruit of the kola tree, a ..."
2,how large is the canadian military,The Canadian Armed Forces. 1 The first large-...,The Canadian Physician Health Institute (CPHI)...
3,types of fruit trees,Cherry. Cherry trees are found throughout the ...,"The kola nut is the fruit of the kola tree, a ..."
4,how many calories a day are lost breastfeeding,"Not only is breastfeeding better for the baby,...","However, you still need some niacin each day; ..."


Training the DPR model

In [3]:
# -------------------- 1️⃣ Imports --------------------
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
from transformers import (
    DPRQuestionEncoder, DPRQuestionEncoderTokenizer,
    DPRContextEncoder, DPRContextEncoderTokenizer
)
from torch.optim import AdamW

# -------------------- 2️⃣ Config --------------------
# Training hyper-parameters; everything a reader may want to tune lives here.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when no GPU is visible
BATCH_SIZE = 16      # you can increase if GPU allows
EPOCHS = 10           # for testing
LEARNING_RATE = 4e-6  # passed to AdamW for both encoders

# -------------------- 3️⃣ Dataset --------------------
class DPRDataset(Dataset):
    """Map-style dataset that serves (query, positive, negative) triples.

    Wraps a DataFrame exposing ``query``, ``positive`` and ``negative`` columns
    and hands each row back as a plain dict.
    """

    def __init__(self, dataFrame):
        # The frame is kept by reference (no copy); rows are looked up on demand.
        self.data = dataFrame

    def __len__(self):
        """Number of triples available."""
        n_rows = len(self.data)
        return n_rows

    def __getitem__(self, idx):
        """Return the triple stored at positional index ``idx``."""
        triple = self.data.iloc[idx]
        sample = {}
        sample["query"] = triple["query"]
        sample["positive_passage"] = triple["positive"]
        sample["negative_passage"] = triple["negative"]
        return sample

# Wrap the training frame loaded in an earlier cell.
dataset = DPRDataset(train_df)

# -------------------- 4️⃣ Collate Function --------------------
# Token-length caps handed to the tokenizers inside collate_fn (longer inputs are truncated).
Q_MAX_LEN = 64      # queries are short
CTX_MAX_LEN = 256   # MS MARCO passages fit well here (try 128–256)

def collate_fn(batch):
    """Tokenize a list of triple dicts into one query batch and one context batch.

    The context batch holds all positives followed by all negatives, so the
    caller can split it in half along dim 0 to recover the two groups.

    Returns:
        (question_batch, context_batch): tokenizer outputs as PyTorch tensors.
    """
    queries, positives, negatives = [], [], []
    for sample in batch:
        queries.append(sample["query"])
        positives.append(sample["positive_passage"])
        negatives.append(sample["negative_passage"])

    # Options shared by both tokenizers: pad to the longest item of the batch,
    # rounded up to a multiple of 8 (Tensor Core friendly), and truncate.
    shared_opts = dict(
        padding=True,
        truncation=True,
        pad_to_multiple_of=8,
        return_tensors="pt",
    )

    question_batch = question_tokenizer(queries, max_length=Q_MAX_LEN, **shared_opts)
    # Positives and negatives go through the tokenizer in a single call.
    context_batch = context_tokenizer(
        positives + negatives, max_length=CTX_MAX_LEN, **shared_opts
    )
    return question_batch, context_batch

# Batches are shuffled every epoch and tokenized inside collate_fn.
# NOTE: `os` is not imported in this cell; it comes from the earlier download cell.
# NOTE(review): os.cpu_count() can return None on some platforms -- confirm DataLoader accepts that here.
# collate_fn looks up question_tokenizer / context_tokenizer when it is called, so it is fine
# that they are only created below, as long as that happens before the loader is iterated.
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn,
                          num_workers=os.cpu_count(), pin_memory=torch.cuda.is_available()
                          )

# -------------------- 5️⃣ Load DPR Models --------------------
# Question encoder
question_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base").to(DEVICE)
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

# Context encoder
context_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base").to(DEVICE)
context_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

# -------------------- 6️⃣ Optimizer --------------------
# A single optimizer updates the parameters of both encoders with the same learning rate.
optimizer = AdamW(list(question_encoder.parameters()) + list(context_encoder.parameters()), lr=LEARNING_RATE)

# -------------------- 7️⃣ Training Loop --------------------
# Mixed precision is used only when CUDA is present; otherwise the scaler and
# autocast below are disabled and training runs in full precision.
USE_AMP = torch.cuda.is_available()
# Required gap between the positive and the negative cosine similarity.
MARGIN = 0.2

# torch.amp.GradScaler / torch.amp.autocast replace the deprecated
# torch.cuda.amp.* entry points, which emit FutureWarnings on recent torch.
scaler = torch.amp.GradScaler("cuda", enabled=USE_AMP)

for epoch in range(EPOCHS):
    question_encoder.train()
    context_encoder.train()
    total = 0.0  # running sum of per-batch losses for the epoch report

    for q_enc, ctx_enc in train_loader:
        optimizer.zero_grad(set_to_none=True)

        # to GPU
        q_enc   = {k: v.to(DEVICE, non_blocking=True) for k, v in q_enc.items()}
        ctx_enc = {k: v.to(DEVICE, non_blocking=True) for k, v in ctx_enc.items()}

        # Number of queries in this batch; the context batch holds 2*B passages
        # (positives first, then negatives -- see collate_fn).
        B = q_enc["input_ids"].size(0)

        with torch.amp.autocast("cuda", enabled=USE_AMP):
            q_out   = question_encoder(**q_enc).pooler_output                  # [B, H]
            ctx_out = context_encoder(**ctx_enc).pooler_output                 # [2B, H]
            p_out, n_out = ctx_out.split(B, dim=0)                             # each [B, H]

            # cosine margin loss: penalize triples whose positive does not beat
            # the negative by at least MARGIN (out-of-place clamp, same value)
            pos = torch.cosine_similarity(q_out, p_out, dim=1)
            neg = torch.cosine_similarity(q_out, n_out, dim=1)
            loss = (MARGIN - pos + neg).clamp_min(0).mean()

        # Scaled backward pass and optimizer step; scaler.update() adjusts the
        # scale for the next iteration.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total += loss.item()

    print(f"Epoch {epoch+1}/{EPOCHS}  Loss: {total/len(train_loader):.4f}")


# -------------------- 8️⃣ Save Models --------------------
# Each encoder and its tokenizer are written to the same directory, so the
# evaluation cell can reload both from one path.
question_encoder.save_pretrained("./dpr_question_encoder")
context_encoder.save_pretrained("./dpr_context_encoder")
question_tokenizer.save_pretrained("./dpr_question_encoder")
context_tokenizer.save_pretrained("./dpr_context_encoder")

print("Training complete and models saved!")

  from .autonotebook import tqdm as notebook_tqdm
Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected 

Epoch 2/10  Loss: 0.0258
Epoch 3/10  Loss: 0.0137
Epoch 4/10  Loss: 0.0073
Epoch 5/10  Loss: 0.0043
Epoch 6/10  Loss: 0.0024
Epoch 7/10  Loss: 0.0015
Epoch 8/10  Loss: 0.0010
Epoch 9/10  Loss: 0.0008
Epoch 10/10  Loss: 0.0006
Training complete and models saved!


Encode documents & queries into embeddings, retrieve top-k docs (using FAISS or
similar).

In [5]:
# -------------------- 1️⃣ Imports --------------------
import pandas as pd
import torch
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer, DPRContextEncoder, DPRContextEncoderTokenizer
import faiss
import numpy as np
from sklearn.metrics import average_precision_score, ndcg_score

# -------------------- 2️⃣ Config --------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TOP_K = 5  # number of passages retrieved per query
TEST_SAMPLE = 5  # number of test triples to sample
# NOTE: this rebinds DATA_PATH (and `df` below), which earlier cells also assign.
DATA_PATH = "sample.tsv"  # update path

# -------------------- 3️⃣ Load & Sample Test Data --------------------
# The file is read as headerless with explicit column names.
# NOTE(review): if sample.tsv was written with a header row, that header is read as a data row -- confirm.
df = pd.read_csv(DATA_PATH, sep="\t", names=["query", "positive", "negative"])
# Fixed random_state keeps the evaluated triples the same across runs.
df = df.sample(TEST_SAMPLE, random_state=42).reset_index(drop=True)

# Row i of both lists belongs to the same triple, so passage i is the relevant one for query i.
queries = df["query"].tolist()
positive_passages = df["positive"].tolist()

# -------------------- 4️⃣ Load Trained DPR Models --------------------
# Reload the fine-tuned encoders and tokenizers from the directories written by the training cell.
question_encoder = DPRQuestionEncoder.from_pretrained("./dpr_question_encoder").to(DEVICE)
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("./dpr_question_encoder")

context_encoder = DPRContextEncoder.from_pretrained("./dpr_context_encoder").to(DEVICE)
context_tokenizer = DPRContextEncoderTokenizer.from_pretrained("./dpr_context_encoder")

# -------------------- 5️⃣ Encode Passages --------------------
ENCODE_BATCH_SIZE = 16   # texts per forward pass
ENCODE_MAX_LEN = 512     # truncation limit handed to the tokenizers


def _encode_texts(texts, tokenizer, encoder, batch_size=ENCODE_BATCH_SIZE, max_length=ENCODE_MAX_LEN):
    """Embed ``texts`` with ``encoder`` and return L2-normalized vectors.

    Passages and queries previously went through two copy-pasted loops; the
    passage loop also tokenized every batch twice, and its first call (without
    ``max_length``) only produced a truncation warning before being overwritten.

    Args:
        texts: non-empty list of strings to embed.
        tokenizer: tokenizer matching ``encoder``.
        encoder: DPR encoder whose ``pooler_output`` is used as the embedding.
        batch_size: number of texts per forward pass.
        max_length: token limit; longer inputs are truncated.

    Returns:
        NumPy array with one row per text, normalized in place by
        ``faiss.normalize_L2`` so that inner product equals cosine similarity.
    """
    encoder.eval()
    chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            enc = tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt"
            ).to(DEVICE)
            # Move each chunk to the CPU right away to keep GPU memory flat.
            chunks.append(encoder(**enc).pooler_output.cpu())

    embeddings = torch.cat(chunks, dim=0).numpy()
    faiss.normalize_L2(embeddings)  # for cosine similarity
    return embeddings


passage_embeddings = _encode_texts(positive_passages, context_tokenizer, context_encoder)

# -------------------- 6️⃣ Encode Queries --------------------
query_embeddings = _encode_texts(queries, question_tokenizer, question_encoder)

# -------------------- 7️⃣ FAISS Retrieval --------------------
# The index holds only the sampled positive passages; passage i gets FAISS id i.
# Both passage and query vectors were L2-normalized above, hence inner product = cosine.
d = passage_embeddings.shape[1]
index = faiss.IndexFlatIP(d)  # inner product = cosine similarity
index.add(passage_embeddings)
# D holds the similarity scores, I the passage ids, both ordered best-first per query.
# NOTE: with the current config (TEST_SAMPLE == TOP_K == 5) every search returns the whole
# index, so Precision@K below is always 1/K; only the rank-aware metrics are informative.
D, I = index.search(query_embeddings, TOP_K)  # top-K indices

# -------------------- 8️⃣ Evaluation --------------------
# Define relevance: only the matching positive passage is relevant
qrels = {i: [i] for i in range(len(queries))}

# Precision@K
def precision_at_k(retrieved_idx, relevant_idx, k):
    """Fraction of the top-``k`` retrieved ids that are relevant.

    Duplicate ids in the top-``k`` slice are counted once, and the denominator
    is always ``k`` even when fewer than ``k`` ids were retrieved.
    """
    relevant = set(relevant_idx)
    distinct_top = set(retrieved_idx[:k])
    n_hits = sum(1 for pid in distinct_top if pid in relevant)
    return n_hits / k

# Per-query precision over the retrieved top-K, averaged across queries.
precisions = [precision_at_k(I[q], qrels[q], TOP_K) for q in range(len(queries))]
# The label follows TOP_K instead of hard-coding "@5", so it stays correct when K changes.
print(f"Mean Precision@{TOP_K}:", np.mean(precisions))

# MAP
def mean_average_precision(retrieved_indices, qrels):
    """Mean Average Precision over ranked retrieval lists.

    Args:
        retrieved_indices: one ranked sequence of passage ids per query
            (best first); the position in the outer sequence is the query id.
        qrels: mapping from query id to the collection of relevant passage ids.

    Returns:
        The mean of the per-query average precision over *all* queries. A query
        whose list contains no relevant passage contributes 0 (the previous
        version silently dropped such queries, which inflated the score).
        Returns 0.0 when there are no queries instead of NaN.
    """
    per_query_ap = []
    for qid, retrieved in enumerate(retrieved_indices):
        relevant = set(qrels[qid])
        hits = 0
        precision_sum = 0.0
        # Average the precision measured at each rank where a relevant passage
        # appears; for a strictly ordered list this is the same value
        # average_precision_score gave with descending rank scores.
        for rank, pid in enumerate(retrieved, start=1):
            if pid in relevant:
                hits += 1
                precision_sum += hits / rank
        per_query_ap.append(precision_sum / hits if hits else 0.0)

    if not per_query_ap:
        return 0.0
    return np.mean(per_query_ap)

# MAP over the FAISS result lists `I`, using the one-relevant-passage-per-query qrels.
map_score = mean_average_precision(I, qrels)
print("MAP:", map_score)

# nDCG@K
# Dense relevance matrix: one row per query, one column per indexed passage.
y_true = np.zeros((len(queries), len(positive_passages)))
for qid, rel in qrels.items():
    for pid in rel:
        y_true[qid, pid] = 1

# Convert FAISS ranks to scores (reciprocal rank) so ndcg_score sees the same
# ordering; passages that were not retrieved keep a score of 0.
y_score = np.zeros_like(y_true)
for qid, retrieved in enumerate(I):
    for rank, pid in enumerate(retrieved):
        if pid < 0:
            # FAISS pads the result with -1 when fewer than TOP_K neighbours
            # exist; without this guard -1 would index the last passage column.
            continue
        y_score[qid, pid] = 1 / (rank + 1)

ndcg = ndcg_score(y_true, y_score, k=TOP_K)
# The label follows TOP_K instead of hard-coding "@5".
print(f"nDCG@{TOP_K}:", ndcg)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Mean Precision@5: 0.2
MAP: 1.0
nDCG@5: 1.0


Comparing the ground-truth positive passage with the retrieved passages

In [6]:
from IPython.display import display
import pandas as pd

# Lookup table from FAISS ids back to passage text.
# NOTE: the FAISS index was built from the positive passages only, so this list is longer
# than the index. Ids still line up because the positives occupy the first len(df) slots;
# the negatives appended after them are never returned by a search.
corpus = df["positive"].tolist() + df["negative"].tolist()

# -------------------- Show Queries, Positives & Retrievals --------------------
def show_retrieval_results(queries, positive_passages, index, question_encoder, question_tokenizer,
                           TOP_K=5, max_queries=5, max_length=512):
    """Display queries next to their ground-truth passage and the top-k retrieved passages.

    Relies on the module-level ``DEVICE`` and ``corpus`` (the id -> text lookup
    list). The encoder is used as passed in; put it in eval mode beforehand.

    Args:
        queries: list of query strings.
        positive_passages: ground-truth passage for each query, aligned by position.
        index: FAISS index to search.
        question_encoder: encoder whose ``pooler_output`` embeds the query.
        question_tokenizer: tokenizer matching ``question_encoder``.
        TOP_K: number of passages to retrieve per query.
        max_queries: how many of the leading queries to show (default 5, as before).
        max_length: token limit for the query; passing it explicitly avoids the
            "no maximum length is provided" tokenizer warning.

    Returns:
        None. The table is rendered with ``display``.
    """
    results = []

    for i, query in enumerate(queries[:max_queries]):
        # Encode query
        q_enc = question_tokenizer(
            [query], return_tensors="pt", padding=True, truncation=True, max_length=max_length
        ).to(DEVICE)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            q_emb = question_encoder(**q_enc).pooler_output.cpu().numpy()

        # Search FAISS. The query vector is not L2-normalized here; a positive
        # scale does not change the inner-product ranking and the scores are unused.
        _, I = index.search(q_emb, TOP_K)

        retrieved = []
        for rank, idx in enumerate(I[0]):
            if idx < 0:
                # FAISS pads with -1 when the index holds fewer than TOP_K
                # vectors; corpus[-1] would show an unrelated passage.
                continue
            passage = corpus[idx][:200] + "..."  # truncate for display
            retrieved.append(f"{rank+1}. {passage}")

        # Add to results table
        results.append({
            "Query": query[:200] + "...",
            "Ground Truth Positive": positive_passages[i][:200] + "...",
            "Top-k Retrieved": "\n".join(retrieved)
        })

    # Pretty display
    df_results = pd.DataFrame(results)
    display(df_results)

# Call function: reuses the queries, index and fine-tuned question encoder from the evaluation cell.
show_retrieval_results(queries, positive_passages, index, question_encoder, question_tokenizer, TOP_K=5)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Unnamed: 0,Query,Ground Truth Positive,Top-k Retrieved
0,weather in austin tx in march...,"Average Weather in March in Austin Texas, Unit...","1. Average Weather in March in Austin Texas, U..."
1,what is the navy leave instruction...,"July 2, Secretary of the Navy Ray Mabus announ...","1. July 2, Secretary of the Navy Ray Mabus ann..."
2,what is a retrovirus and name an example....,"Retrovirus, any of a group of viruses that bel...","1. Retrovirus, any of a group of viruses that ..."
3,what is considered a normal dosage for xanax...,The usual starting Xanax dosage is 0.25 mg to ...,1. The usual starting Xanax dosage is 0.25 mg ...
4,personal tax rates 2015 us...,The United States federal government taxes per...,1. The United States federal government taxes ...
