## DPR (Dense Passage Retrieval) for Open-Domain Question Answering



In [None]:
import torch
import torch.nn as nn
import transformers
from transformers import BertModel,BertConfig,BertTokenizer
from transformers import DistilBertConfig,DistilBertModel,DistilBertTokenizer,AutoModel
from torch.utils.data import DataLoader,Dataset
import pandas as pd
import numpy as np
from typing import Any,Dict
import random
import logging
logging.disable(logging.WARNING)

## Initialize Tokenizer

In [None]:
from transformers import BertTokenizer,DistilBertTokenizer
class HFBertTokenizer():
    """Wrap a Hugging Face tokenizer with fixed length and padding settings.

    Args:
        tokenizer: Callable Hugging Face tokenizer used for encoding.
        max_length: Length used for truncation and for ``"max_length"`` padding.
        pad_to_max: When True every sequence is padded to ``max_length``;
            when False sequences are only padded to the longest item of the
            call so a batch can still be returned as one tensor.
    """

    def __init__(self, tokenizer: "BertTokenizer", max_length: int, pad_to_max: bool) -> None:
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.pad_to_max = pad_to_max

    def text_to_tensor(self, text, title=None, apply_max_len=True, add_special_tokens=True):
        """Encode ``text`` (optionally paired with ``title``) as PyTorch tensors.

        Args:
            text: A string or list of strings to encode.
            title: Optional title(s). When truthy, the title is encoded as the
                first segment and ``text`` as the second segment.
            apply_max_len: Truncate to ``max_length`` when True.
            add_special_tokens: Forwarded to the tokenizer.

        Returns:
            Whatever the wrapped tokenizer returns for ``return_tensors='pt'``.
        """
        # max_length is only meaningful when we truncate or pad to it.
        uses_max_length = bool(apply_max_len) or bool(self.pad_to_max)
        encode_kwargs = {
            "return_tensors": "pt",
            "max_length": self.max_length if uses_max_length else None,
            "truncation": bool(apply_max_len),
            "padding": "max_length" if self.pad_to_max else "longest",
            "add_special_tokens": add_special_tokens,
        }
        if title:
            return self.tokenizer(title, text_pair=text, **encode_kwargs)
        return self.tokenizer(text, **encode_kwargs)

    def get_tokenizer(self):
        """Return the wrapped tokenizer object."""
        return self.tokenizer
class HFDistilBertTokenizer(HFBertTokenizer):
    """DistilBERT-named variant of HFBertTokenizer; it adds no behaviour of its own."""

    def __init__(self, tokenizer, max_length: int, pad_to_max: bool) -> None:
        # Delegate everything to the shared base-class constructor.
        super().__init__(
            tokenizer=tokenizer,
            max_length=max_length,
            pad_to_max=pad_to_max,
        )

In [None]:
# Load the pretrained uncased BERT tokenizer and wrap it with the
# 200-token length / padding settings used throughout this notebook.
bert_tokenozer = BertTokenizer.from_pretrained('bert-base-uncased')
hf_tokenizer = HFBertTokenizer(
    tokenizer=bert_tokenozer,
    max_length=200,
    pad_to_max=True,
)

## Prepare Data

In [None]:
from datasets import load_dataset
# JSON file holding the bi-encoder training samples.
DATA_PATH = 'biencoder-nq-train-sample.json'
# Load the file through the `datasets` JSON loader.
dataset = load_dataset("json", data_files=DATA_PATH)

In [None]:
import random
# Tokenizer wrapper read as a global by tokenize_data below
# (same settings as the wrapper created in the earlier tokenizer cell).
hf_tokenizer = HFBertTokenizer(tokenizer=bert_tokenozer,max_length=200,pad_to_max = True)
# DistilBERT alternative, left disabled:
# hf_tokenizer = HFDistilBertTokenizer(tokenizer=dist_tokenozer,max_length=200,pad_to_max = True)
def tokenize_data(x, num_neg=5, num_hard_neg=5):
    """Tokenize one sample into a question encoding and its context encodings.

    One positive context is drawn at random; it is followed by the first
    ``num_neg + num_hard_neg`` entries of the hard negatives concatenated
    with the regular negatives (hard negatives first).

    NOTE(review): the two counts are only used as a combined cap, not as
    separate per-type quotas -- confirm this is intended.

    Args:
        x: Sample dict with ``question``, ``positive_ctxs``,
            ``negative_ctxs`` and ``hard_negative_ctxs`` entries.
        num_neg: Contribution of regular negatives to the combined cap.
        num_hard_neg: Contribution of hard negatives to the combined cap.

    Returns:
        Dict with the question ids/mask and the context ids/mask, where the
        positive context is always the first context row.
    """
    positives = x['positive_ctxs']
    chosen_positive = positives[np.random.choice(len(positives))]

    # Hard negatives come first so they are preferred when the cap applies.
    negative_pool = x['hard_negative_ctxs'] + x['negative_ctxs']
    negative_pool = negative_pool[:num_neg + num_hard_neg]

    question_enc = hf_tokenizer.text_to_tensor(text=x['question'])

    contexts = [chosen_positive] + negative_pool
    titles = []
    passages = []
    for ctx in contexts:
        title = ctx.get("title")
        passage = ctx.get("text")
        # Missing fields are encoded as empty strings.
        titles.append(title if title is not None else "")
        passages.append(passage if passage is not None else "")
    contexts_enc = hf_tokenizer.text_to_tensor(text=passages, title=titles)

    return {
        "q_input_ids": question_enc['input_ids'],
        "q_attention_mask": question_enc['attention_mask'],
        "all_ctxs_input_ids": contexts_enc['input_ids'],
        "all_ctxs_attention_mask": contexts_enc['attention_mask'],
    }


# Negative-context counts forwarded to tokenize_data.
num_hard_neg = 5
num_neg = 5

# Tokenize every sample with two worker processes and drop the raw text
# columns that have been replaced by token ids / attention masks.
dataset = dataset.map(
    tokenize_data,
    fn_kwargs={"num_hard_neg": num_hard_neg, "num_neg": num_neg},
    remove_columns=['hard_negative_ctxs', 'question', 'negative_ctxs', 'positive_ctxs'],
    num_proc=2,
)

In [None]:
# Persist the tokenized dataset to disk so the tokenization step above
# does not have to be repeated.
dataset.save_to_disk("biencoder-nq-train-tokenize-data-bert.hf")

In [None]:
from datasets import load_from_disk

# `set_caching_enabled` is the legacy switch; newer `datasets` releases expose
# `disable_caching()` instead, so prefer it and fall back only when missing.
try:
    from datasets import disable_caching
except ImportError:
    from datasets import set_caching_enabled
    set_caching_enabled(False)
else:
    disable_caching()

# Reload the tokenized dataset written earlier, keeping it in memory.
dataset = load_from_disk("biencoder-nq-train-tokenize-data-bert.hf",keep_in_memory=True)
dataset

In [None]:
# Hold out 10% of the samples as a test split.
# NOTE(review): no seed is passed, so the split is presumably not reproducible across runs -- confirm.
dataset = dataset.train_test_split(test_size=0.1)

In [None]:
# Convenience handles for the two splits produced by train_test_split.
train_dataset, test_dataset = dataset['train'], dataset['test']

## Initialize Model

In [None]:
# Load a standalone pretrained BERT model.
# NOTE(review): `model` is rebound to a BIEncoder further down in this file
# and nothing visible uses this instance in between -- confirm it is needed.
model = BertModel.from_pretrained('bert-base-uncased')

In [None]:
class HFEncoderModel(nn.Module):
    """BERT encoder followed by a linear projection of the pooled output.

    Args:
        model_path: Checkpoint passed to ``BertModel.from_pretrained``. When
            None, ``bert-base-uncased`` is loaded and all of its weights are
            frozen, leaving only the projection layer trainable.
        project_dim: Output size of the projection layer.
    """

    def __init__(self, model_path = None,project_dim=768):
        super(HFEncoderModel,self).__init__()
        if model_path is None:
            self.model = BertModel.from_pretrained("bert-base-uncased")
            # Default checkpoint is used as a frozen feature extractor.
            for param in self.model.parameters():
                param.requires_grad = False
        else:
            self.model = BertModel.from_pretrained(model_path)

        # Read the encoder width from the loaded config instead of assuming
        # 768, so checkpoints with a different hidden size also work.
        self.encoder_proj = nn.Linear(self.model.config.hidden_size, project_dim)

    def forward(self,
                input_ids,
                attention_mask,
                pooling = 'cls'):
        """Encode a batch and return one projected vector per sequence.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.
            pooling: Pooling strategy; only ``'cls'`` (first-token hidden
                state) is implemented.

        Returns:
            The projected first-token representation of every sequence.

        Raises:
            NotImplementedError: If ``pooling`` is anything other than ``'cls'``.
        """
        # Fail before running the encoder; previously the exception object was
        # created but never raised, so the whole sequence was projected.
        if pooling != 'cls':
            raise NotImplementedError(f"pooling '{pooling}' is not implemented; only 'cls' is supported")

        out = self.model(input_ids = input_ids, attention_mask= attention_mask)
        cls_state = out.last_hidden_state[:,0,:]
        return self.encoder_proj(cls_state)

In [None]:
class BIEncoder(nn.Module):
    """Bi-encoder holding one HFEncoderModel for questions and one for contexts."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Two independent encoders: questions and passages do not share weights.
        self.query_encoder = HFEncoderModel()
        self.doc_encoder = HFEncoderModel()

    @staticmethod
    def get_token_representation(sub_model,tokenize):
        """Call ``sub_model`` with the entries of ``tokenize`` as keyword arguments."""
        representation = sub_model(**tokenize)
        return representation

    def forward(self,question_idx,ctx_idx):
        """Encode questions and contexts; return ``(question_vectors, context_vectors)``."""
        question_vectors = self.get_token_representation(self.query_encoder, question_idx)
        context_vectors = self.get_token_representation(self.doc_encoder, ctx_idx)
        return question_vectors, context_vectors


In [None]:
def calculate_total_num_parameter(model):
    """Return ``(total, trainable)`` parameter element counts for ``model``.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    """
    sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
    total_count = sum(count for count, _ in sizes)
    trainable_count = sum(count for count, trainable in sizes if trainable)
    return total_count, trainable_count


In [None]:
# Build the bi-encoder (this rebinds the module-level `model` name).
model = BIEncoder()
# Cell output: (total parameter count, trainable parameter count).
calculate_total_num_parameter(model)

## Optimizer

In [None]:
from torch.optim.lr_scheduler import LambdaLR
from torch.optim import AdamW
def get_optimizer(model_param, learning_rate: float = 1e-5, adam_eps: float = 1e-8) -> torch.optim.Optimizer:
    """Build an AdamW optimizer over ``model_param``.

    Args:
        model_param: Iterable of parameters (or parameter groups) to optimize.
        learning_rate: Passed to AdamW as ``lr``.
        adam_eps: Passed to AdamW as ``eps``.
    """
    return AdamW(model_param, lr=learning_rate, eps=adam_eps)

def get_schedule_linear(
    optimizer,
    warmup_steps,
    total_training_steps,
    steps_shift=0,
    last_epoch=-1,
):
    """Return a LambdaLR with linear warmup followed by linear decay.

    The multiplier rises from 0 to 1 over ``warmup_steps`` and then falls
    linearly towards 0 at ``total_training_steps``, never going below 1e-7.
    ``steps_shift`` is added to the scheduler step before the multiplier is
    computed.
    """
    # Length of the decay phase, guarded against division by zero.
    decay_span = float(max(1, total_training_steps - warmup_steps))
    warmup_span = float(max(1, warmup_steps))

    def lr_lambda(current_step):
        step = current_step + steps_shift
        if step >= warmup_steps:
            remaining = float(total_training_steps - step)
            return max(1e-7, remaining / decay_span)
        return float(step) / warmup_span

    return LambdaLR(optimizer, lr_lambda=lr_lambda, last_epoch=last_epoch)


## Loss function

In [None]:
import torch.nn.functional as F
class NegativeLogLikeHood(nn.Module):
    """Negative log-likelihood of the positive context under a softmax over dot products.

    ``pos_neg_vec`` is indexed as ``[batch, context, dim]`` with the positive
    context at index 0 and the negatives after it.
    """

    def __init__(self):
        super().__init__()

    def forward(self,query_vec,pos_neg_vec):
        """Return the mean loss over the batch."""
        batch_size, _ = query_vec.size()
        positives = pos_neg_vec[:, 0, :]
        negatives = pos_neg_vec[:, 1:, :]

        # Dot-product scores: one column for the positive, then the negatives.
        pos_scores = (query_vec * positives).sum(dim=-1).unsqueeze(1)
        neg_scores = torch.bmm(negatives, query_vec.unsqueeze(-1)).squeeze(-1)
        log_probs = F.log_softmax(torch.cat([pos_scores, neg_scores], dim=1), dim=1)

        # The positive always sits in column 0, so every label is 0.
        labels = torch.zeros(batch_size, dtype=torch.long).to(query_vec.device)
        return F.nll_loss(log_probs, labels)

## Accuracy

In [None]:
def top_k_accuracy(query_vec,pos_neg_vec,k=3):
    """Return the fraction of queries whose positive context ranks in the top ``k``.

    Args:
        query_vec: Query vectors, unpacked as ``[batch, dim]``.
        pos_neg_vec: Context vectors indexed as ``[batch, context, dim]`` with
            the positive context at index 0.
        k: Number of top-scoring contexts to inspect. Values larger than the
            number of contexts are clamped, since ``torch.topk`` would
            otherwise raise.

    Returns:
        A Python float in ``[0, 1]``.
    """
    batch_size,_ = query_vec.size()
    pos_vec = pos_neg_vec[:,0,:]
    neg_vec = pos_neg_vec[:,1:,:]
    pos_sim = torch.sum(query_vec * pos_vec, dim=-1) #[batch_size]
    neg_sim = torch.bmm(neg_vec, query_vec.unsqueeze(-1)).squeeze(-1)  # [batch_size,num_negative]

    all_sim = torch.cat([pos_sim.unsqueeze(1),neg_sim],dim=1)

    # With fewer candidates than k, the positive is trivially inside the top k.
    k = min(k, all_sim.size(1))
    top_k_ind = torch.topk(all_sim,k=k,dim=1).indices

    # The positive lives in column 0, so a hit is any top-k index equal to 0.
    hits = (top_k_ind == 0).any(dim=1).sum().item()
    return hits/batch_size

## Dataloader

In [None]:
def collate_fn(batch):
    """Stack tokenized samples into batch tensors and trim shared padding.

    Each sample must provide ``q_input_ids`` / ``q_attention_mask`` (one row
    per question) and ``all_ctxs_input_ids`` / ``all_ctxs_attention_mask``
    (one row per context). Any other columns in a sample are ignored, and the
    sequence length is taken from the data instead of being fixed at 200.

    Trailing positions that are padding for every item in the batch are
    removed. The attention mask decides which positions to keep, so ids and
    mask always end up with the same length.

    Args:
        batch: List of sample dicts whose values are nested lists or tensors.

    Returns:
        Dict with ``q_*`` tensors indexed ``[batch, seq]`` and ``all_ctxs_*``
        tensors indexed ``[batch, context, seq]``, all of dtype long.
    """
    batch_size = len(batch)

    def stack(key):
        # as_tensor accepts both plain nested lists and already-built tensors.
        return torch.stack([torch.as_tensor(sample[key], dtype=torch.long) for sample in batch])

    # Questions are stored with a leading singleton row; flatten it away.
    q_ids = stack('q_input_ids').reshape(batch_size, -1)
    q_mask = stack('q_attention_mask').reshape(batch_size, -1)
    ctx_ids = stack('all_ctxs_input_ids')
    ctx_mask = stack('all_ctxs_attention_mask')

    # Keep only the positions attended to by at least one item of the batch.
    q_keep = q_mask.sum(dim=0).bool()
    ctx_keep = ctx_mask.sum(dim=(0, 1)).bool()

    return {
        'q_input_ids': q_ids[:, q_keep],
        'q_attention_mask': q_mask[:, q_keep],
        'all_ctxs_input_ids': ctx_ids[:, :, ctx_keep],
        'all_ctxs_attention_mask': ctx_mask[:, :, ctx_keep],
    }

In [None]:
from torch.utils.data import DataLoader
# Number of questions per batch; also read by the training / evaluation cells.
batch_size = 32

# Shuffled training batches; incomplete final batches are dropped.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn,
    pin_memory=True,
    drop_last=True,
)

# Held-out batches, built with the same settings as the training loader.
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn,
    pin_memory=True,
    drop_last=True,
)


## Training

In [None]:
# Number of passes over the training data.
epoch = 2
# Total optimizer steps; used as the end point of the linear LR decay.
total_training_steps = len(train_dataloader) * epoch
# NOTE(review): max_seq_len and ctx_sample are not read by any visible code
# (the loops rebind max_seq_len from the batch shape) -- confirm they are needed.
max_seq_len = 200
ctx_sample = 11

# AdamW with default hyper-parameters over all model parameters.
optimizer = get_optimizer(model.parameters())
# Linear warmup for 50 steps, then linear decay until total_training_steps.
scheduler = get_schedule_linear(optimizer=optimizer,warmup_steps=50,total_training_steps = total_training_steps)
loss_function = NegativeLogLikeHood()

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Cell output: the selected device string.
device

In [None]:
# training
from tqdm import tqdm
def train_one_epoch(epoch_no):
    """Run one pass over ``train_dataloader`` and return the mean batch loss.

    Uses the module-level ``model``, ``optimizer``, ``scheduler``,
    ``loss_function`` and ``device``. A batch that raises an ordinary
    exception is reported and skipped (best effort); KeyboardInterrupt and
    other non-``Exception`` errors propagate.

    Args:
        epoch_no: Epoch index, used only in log messages.

    Returns:
        Mean loss over the batches that completed an optimizer step, or 0.0
        when no batch succeeded.
    """
    train_loss = 0.0
    completed_steps = 0
    with tqdm(total=len(train_dataloader)) as progress_bar:
        for batch in train_dataloader:
            b_size, num_ctx, max_seq_len = batch['all_ctxs_input_ids'].size()
            # Clear gradients up front so a batch that failed midway through
            # backward() cannot leak stale gradients into this step.
            optimizer.zero_grad()
            try:
                question_idx = {"input_ids": batch['q_input_ids'].to(device),
                                "attention_mask": batch['q_attention_mask'].to(device)}

                # Flatten [batch, context, seq] to [batch * context, seq] for the encoder.
                ctx_idx = {"input_ids": batch['all_ctxs_input_ids'].reshape(num_ctx * b_size, max_seq_len).to(device),
                           "attention_mask": batch['all_ctxs_attention_mask'].reshape(num_ctx * b_size, max_seq_len).to(device)}
                q_embedd, ctx_embedd = model(question_idx, ctx_idx)

                # Restore the per-question grouping from the actual batch shape
                # rather than the global batch size and a hard-coded width.
                ctx_embedd = ctx_embedd.view(b_size, num_ctx, -1)
                loss = loss_function(q_embedd, ctx_embedd)
                loss.backward()

                optimizer.step()
                scheduler.step()
            except Exception as err:
                print(f"error in epoch {epoch_no}, skipping batch: {err!r}")
                continue
            finally:
                progress_bar.update(1)

            train_loss += loss.item()
            completed_steps += 1
            if completed_steps % 200 == 0:
                print(f"epoch is {epoch_no},batch is {completed_steps}, loss is {loss.item()}")

    # Average over the batches that actually contributed to the total.
    return train_loss / max(completed_steps, 1)

In [None]:
# Training + evaluation loop: after every epoch, score the model on the test
# split and keep the checkpoint with the lowest validation loss.
best_vloss = float("inf")
model.to(device)
for ep in range(epoch):
    model.train(True)
    train_avg_loss = train_one_epoch(ep)

    model.eval()
    total_val_loss = 0.0
    val_accuracy = 0
    with torch.no_grad():
        for vdata in test_dataloader:

            b_size,num_ctx,max_seq_len = vdata['all_ctxs_input_ids'].size()
            val_question_idx = {"input_ids":vdata['q_input_ids'].to(device),
                            "attention_mask":vdata['q_attention_mask'].to(device)}

            # Flatten [batch, context, seq] to [batch * context, seq] for the encoder.
            val_ctx_idx = {"input_ids":vdata['all_ctxs_input_ids'].reshape(num_ctx*b_size,max_seq_len).to(device),
                    "attention_mask":vdata['all_ctxs_attention_mask'].reshape(num_ctx*b_size,max_seq_len).to(device)}

            q_embedd,ctx_embedd = model(val_question_idx,val_ctx_idx)

            # Regroup contexts per question using the actual batch shape.
            ctx_embedd = ctx_embedd.view(b_size,num_ctx,-1)
            val_accuracy += top_k_accuracy(q_embedd, ctx_embedd)

            # Accumulate a Python float rather than a (possibly GPU) tensor.
            total_val_loss += loss_function(q_embedd,ctx_embedd).item()

    avg_loss = total_val_loss/len(test_dataloader)

    print(f"epoch {ep}: training loss = {train_avg_loss}")
    print(f"validation loss = {avg_loss}, validation accuracy = {val_accuracy/len(test_dataloader)}")
    if avg_loss<best_vloss:
        # Record the new best so later epochs only overwrite the checkpoint
        # when they actually improve on it.
        best_vloss = avg_loss
        torch.save(model.state_dict(), 'biencoder-model.pt')
    