In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizer
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup

**Dataset**

In [2]:
class DPRDataset(Dataset):
    """Aligned (passage, question) pairs for DPR training / evaluation.

    Item ``i`` is ``(passages[i], questions[i])``. ``collate_fn`` tokenizes a batch
    with separate passage and question tokenizers and is meant to be passed as a
    ``DataLoader``'s ``collate_fn``.

    Args:
        passages: sequence of passage texts.
        questions: sequence of question texts, the same length as ``passages``.
        p_tokenizer: tokenizer applied to passages in ``collate_fn``.
        q_tokenizer: tokenizer applied to questions in ``collate_fn``.

    Raises:
        ValueError: if ``passages`` and ``questions`` differ in length.
    """

    def __init__(self, passages, questions, p_tokenizer, q_tokenizer):
        # __len__ is based on passages alone, so a mismatch would otherwise surface
        # later as an IndexError (or silently ignore surplus questions).
        if len(passages) != len(questions):
            raise ValueError(
                "passages and questions must have the same length, "
                f"got {len(passages)} and {len(questions)}"
            )
        self.passages = passages
        self.questions = questions
        self.p_tokenizer = p_tokenizer
        self.q_tokenizer = q_tokenizer

    def __len__(self):
        return len(self.passages)

    def __getitem__(self, index):
        passage = self.passages[index]
        question = self.questions[index]
        return passage, question

    def collate_fn(self, batch):
        """Tokenize a list of ``(passage, question)`` pairs into two batches.

        Returns:
            ``(passage_inputs, question_inputs)``: the tokenizers' outputs for the
            passages and the questions, padded, truncated and returned as PyTorch
            tensors (``return_tensors="pt"``).
        """
        passages, questions = zip(*batch)
        # Calling the tokenizer directly is the supported batch API;
        # batch_encode_plus is deprecated in transformers.
        passage_inputs = self.p_tokenizer(list(passages), padding=True, truncation=True, return_tensors="pt")
        question_inputs = self.q_tokenizer(list(questions), padding=True, truncation=True, return_tensors="pt")
        return passage_inputs, question_inputs

**Model & tokenizer**

In [None]:
# DPR uses two towers: independent copies of the same multilingual BERT checkpoint for
# passages and for questions, each with its own tokenizer instance.
p_encoder, q_encoder = (
    BertModel.from_pretrained("bert-base-multilingual-cased") for _ in range(2)
)

p_tokenizer, q_tokenizer = (
    BertTokenizer.from_pretrained("bert-base-multilingual-cased") for _ in range(2)
)

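**Training**

Each question is scored against every passage in its batch. Its own passage is the positive and the other passages in the batch serve as negatives (in-batch negatives).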

In [4]:
def train(train_loader, optimizer, scheduler, device, batch):
    """Run one epoch of DPR training with in-batch negatives.

    Every question in a batch is scored against every passage in the batch by dot
    product. Passage ``i`` is the positive for question ``i`` and the other passages
    act as negatives, so the loss is cross-entropy over each row with target ``i``.

    Uses the module-level ``p_encoder`` and ``q_encoder``; both are put in train mode.

    Args:
        train_loader: iterable of ``(passage_inputs, question_inputs)`` pairs of
            tokenizer outputs (mappings name -> tensor), e.g. from
            ``DPRDataset.collate_fn``.
        optimizer: optimizer over the encoders' parameters; stepped once per batch.
        scheduler: LR scheduler; stepped once per batch.
        device: device each batch's tensors are moved to.
        batch: configured batch size. Kept for backward compatibility only: targets
            are sized from the actual batch, so a smaller final batch is handled.

    Returns:
        Mean per-batch loss over the epoch, or 0.0 if the loader yields no batches.
    """
    p_encoder.train()
    q_encoder.train()
    total_loss = 0.0
    num_batches = 0

    pbar = tqdm(train_loader, desc="Training", leave=False)
    for passage_inputs, question_inputs in pbar:
        passage_inputs = {k: v.to(device) for k, v in passage_inputs.items()}
        question_inputs = {k: v.to(device) for k, v in question_inputs.items()}

        optimizer.zero_grad()
        passage_embeddings = p_encoder(**passage_inputs).pooler_output
        question_embeddings = q_encoder(**question_inputs).pooler_output

        # (num_questions, num_passages) score matrix; the diagonal holds the positives.
        sim_scores = torch.matmul(question_embeddings, torch.transpose(passage_embeddings, 0, 1))

        # Size targets from this batch rather than `batch`: the last batch of an epoch
        # is smaller whenever len(dataset) is not a multiple of the batch size.
        targets = torch.arange(sim_scores.size(0), device=sim_scores.device)

        # Same value as nll_loss(log_softmax(sim_scores, dim=1), targets).
        loss = torch.nn.functional.cross_entropy(sim_scores, targets)

        loss.backward()
        optimizer.step()
        scheduler.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1

        pbar.set_postfix({"Loss": batch_loss})
    return total_loss / num_batches if num_batches else 0.0

In [5]:
# Training hyperparameters.
# train() scores each question against the passages in its own batch, so
# batch_size=2 gives every question exactly one negative passage per step.
batch_size = 2
num_epochs = 10
learning_rate = 1e-5
warmup_steps = 1000  # linear LR warmup length, in scheduler (= optimizer) steps
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG so DataLoader shuffling is repeatable across runs.
# GPU kernels may still introduce small run-to-run differences.
seed = 42
torch.manual_seed(seed)

In [6]:
import pandas as pd

# Pair files carry x_id (question key) and y_id (passage key); the lookup tables
# hold the question and passage texts.
traindata = pd.read_csv("train.csv")
question = pd.read_csv("question.csv")
collection = pd.read_csv("collection.csv")
evaldata = pd.read_csv("test.csv")

# Fixed column names: question text is "content" keyed by x_id,
# passage text is "document" keyed by y_id.
question.columns = ['0', 'x_id', 'content']
collection.columns = ['0', 'y_id', 'document']

# Cap on the number of training pairs used.
max_train_pairs = 50000


def build_qa_pairs(pairs_df, question_df, collection_df):
    """Resolve (x_id, y_id) pairs into a two-column (question, context) text frame.

    Left-joins question texts on ``x_id`` and passage texts on ``y_id``. After each
    join, every row containing a missing value is dropped (a NaN in any column, not
    only the joined text). Returns only the text columns, renamed to
    ``question`` / ``context``, with a fresh 0..n-1 index.
    """
    merged = (
        pairs_df
        .merge(question_df, on="x_id", how="left")
        .dropna()
        .reset_index(drop=True)
        .merge(collection_df, on="y_id", how="left")
        .dropna()
        .reset_index(drop=True)
    )
    return merged[["content", "document"]].rename(
        columns={"content": "question", "document": "context"}
    )


traindata = build_qa_pairs(traindata, question, collection)
evaldata = build_qa_pairs(evaldata, question, collection)

train_data = traindata.iloc[:max_train_pairs]
eval_data = evaldata
assert len(train_data) > 0, "no training pairs left after joining question/passage texts"

train_passages = list(train_data["context"])
train_questions = list(train_data["question"])
train_dataset = DPRDataset(train_passages, train_questions, p_tokenizer, q_tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=train_dataset.collate_fn)

eval_passages = list(eval_data["context"])
eval_questions = list(eval_data["question"])
eval_dataset = DPRDataset(eval_passages, eval_questions, p_tokenizer, q_tokenizer)
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, collate_fn=eval_dataset.collate_fn)

In [7]:
# NOTE(review): not referenced by train(), which computes its loss functionally.
criterion = nn.NLLLoss()
no_decay = ['bias', 'LayerNorm.weight']

# Four AdamW parameter groups, in order: p_encoder with weight decay, p_encoder
# without, q_encoder with, q_encoder without. A parameter whose name contains any
# `no_decay` tag (biases, LayerNorm weights) is exempt from weight decay.
optimizer_grouped_parameters = [
    {
        'params': [
            param
            for name, param in encoder.named_parameters()
            if any(tag in name for tag in no_decay) == exempt
        ],
        'weight_decay': 0.0 if exempt else 0.01,
    }
    for encoder in (p_encoder, q_encoder)
    for exempt in (False, True)
]
optimizer = optim.AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=1e-8)

# train() calls scheduler.step() once per batch, so the schedule spans all batches
# of all epochs.
total_steps = num_epochs * len(train_dataloader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

In [None]:
# Place both towers on the training device; train() moves each batch there as well.
p_encoder.to(device)
q_encoder.to(device)

for epoch in range(num_epochs):
    train_loss = train(
        train_loader=train_dataloader,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        batch=batch_size,
    )
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss: .4f}")

**Inference: passage embeddings and retrieval recall@k**

In [None]:
# Retrieval corpus. Apply the same column names the data-loading cell gives
# collection.csv, so the "document" column exists regardless of the raw header.
# Row order follows the CSV, so documents[i] is row i of the passage embeddings.
passages = pd.read_csv("collection.csv")
passages.columns = ['0', 'y_id', 'document']
# The tokenizer needs strings; pandas reads a missing document as NaN (float).
# Fill with "" instead of dropping so every row keeps its position.
documents = list(passages['document'].fillna(""))

In [None]:
with torch.no_grad():
    p_encoder.eval()

    # Encode on whatever device the encoder currently lives on; it may have been
    # moved since the training cell.
    encoder_device = next(p_encoder.parameters()).device
    embed_batch_size = 32  # passages per forward pass

    # passage embeddings
    p_embs = []
    num_documents = len(documents)
    if num_documents == 0:
        raise ValueError("`documents` is empty; nothing to embed")

    for start in tqdm(range(0, num_documents, embed_batch_size), desc="Computing Passage Embeddings"):
        batch_docs = documents[start:start + embed_batch_size]
        p = p_tokenizer(batch_docs, padding=True, truncation=True, return_tensors="pt").to(encoder_device)
        # Keep finished embeddings on the CPU while the rest of the corpus is encoded.
        p_embs.append(p_encoder(**p).pooler_output.cpu())

    # (num_documents, hidden_size); row i is the embedding of documents[i].
    # torch.cat keeps the result 2-D even for a single document (squeeze() would not).
    p_embs = torch.cat(p_embs, dim=0).to(device)

In [None]:
# The passage embeddings are computed in the preceding cell; re-encoding the whole
# corpus here would repeat the most expensive step for the same embeddings (the
# encoder is in eval mode under no_grad). Instead, check that they line up with
# `documents`: the evaluation uses row i of p_embs as the embedding of documents[i].
assert p_embs.dim() == 2 and p_embs.size(0) == len(documents), (
    f"p_embs has shape {tuple(p_embs.shape)}, expected ({len(documents)}, hidden_size)"
)

In [None]:
with torch.no_grad():

    recall_1 = 0
    recall_10 = 0
    recall_20 = 0
    recall_100 = 0

    total_actual_positives = 0

    q_encoder.eval()
    # Encode queries where the encoder lives; score them where p_embs lives.
    q_device = next(q_encoder.parameters()).device

    if not eval_questions:
        raise ValueError("no evaluation questions to score")

    # Row index of each passage text in `documents` / `p_embs`. The first occurrence
    # wins, matching list.index, but each lookup is O(1) instead of an O(N) scan.
    doc_to_idx = {}
    for doc_idx, doc in enumerate(documents):
        doc_to_idx.setdefault(doc, doc_idx)

    for sample_idx in tqdm(range(len(eval_questions))):
        query = eval_questions[sample_idx]

        q_seqs_val = q_tokenizer([query], padding=True, truncation=True, return_tensors="pt").to(q_device)
        q_emb = q_encoder(**q_seqs_val).pooler_output.to(p_embs.device)

        dot_prod_scores = torch.matmul(q_emb, torch.transpose(p_embs, 0, 1))

        # Passage indices ordered from highest to lowest score for this query.
        rank = torch.argsort(dot_prod_scores, dim=1, descending=True).squeeze(0)

        # The gold passage is the context paired with this question.
        correct_passage = eval_passages[sample_idx]
        gold_doc_idx = doc_to_idx.get(correct_passage)

        # A gold passage missing from the corpus counts as a miss for every k.
        if gold_doc_idx is not None:
            matches = (rank == gold_doc_idx).nonzero()
            if matches.numel() > 0:
                # 0-based position of the gold passage in the ranking.
                correct_idx = matches[0].item()

                if correct_idx < 1:
                    recall_1 += 1
                if correct_idx < 10:
                    recall_10 += 1
                if correct_idx < 20:
                    recall_20 += 1
                if correct_idx < 100:
                    recall_100 += 1

        total_actual_positives += 1

    recall_1 /= total_actual_positives
    recall_10 /= total_actual_positives
    recall_20 /= total_actual_positives
    recall_100 /= total_actual_positives

    print("Recall@1: ", recall_1)
    print("Recall@10: ", recall_10)
    print("Recall@20: ", recall_20)
    print("Recall@100: ", recall_100)

**Model save & load**

In [None]:
# Remember where each encoder lives so it can be put back after saving.
p_device = next(p_encoder.parameters()).device
q_device = next(q_encoder.parameters()).device

# Save CPU tensors so the checkpoint does not depend on the device it was written from.
p_encoder.to("cpu")
q_encoder.to("cpu")
try:
    torch.save({
        "p_encoder_state_dict": p_encoder.state_dict(),
        "q_encoder_state_dict": q_encoder.state_dict(),
    }, "encoder_new.pt")
finally:
    # Restore the original placement: leaving the encoders on the CPU would make any
    # later cell that puts its inputs on `device` fail with a device mismatch on a GPU.
    p_encoder.to(p_device)
    q_encoder.to(q_device)

In [None]:
# Restore both encoders from the "encoder_new.pt" checkpoint.
# map_location="cpu" loads every stored tensor onto the CPU regardless of the device it
# was saved from; load_state_dict then copies the values into each encoder's existing
# parameters. The checkpoint is a dict holding one state dict per encoder.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
model_state_dict = torch.load("encoder_new.pt", map_location="cpu")
p_encoder.load_state_dict(model_state_dict["p_encoder_state_dict"])
q_encoder.load_state_dict(model_state_dict['q_encoder_state_dict'])