#### Dense Passage Retrieval (DPR)

We saw how to use the TF-IDF representation of passages to perform retrieval. One major problem with this kind of sparse vector representation is that if the query words don't exactly match any words in the relevant passages, the retrieval system cannot find those passages, because the cosine similarity between the query and passage vectors is zero.
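To make this failure mode concrete, here is a minimal sketch that scores a query against a passage using raw term-count vectors. This is a simplification, because TF-IDF also reweights each term. The conclusion carries over, though: a term that appears in only one of the two vectors contributes nothing to the dot product, so with no shared terms the cosine similarity is exactly zero. The helper `bow_cosine` and the example strings are made up for illustration.

```python
from collections import Counter
import math


def bow_cosine(text_a: str, text_b: str) -> float:
    """Cosine similarity between raw term-count vectors (a simplified stand-in for TF-IDF)."""
    a = Counter(text_a.lower().split())
    b = Counter(text_b.lower().split())
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())  # only shared terms contribute
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


query = "automobile price"
passage = "the car costs twenty thousand dollars"
print(bow_cosine(query, passage))        # no shared terms, so the score is zero despite the passage being relevant
print(bow_cosine("car price", passage))  # a single overlapping term gives a non-zero score
```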

In DPR, we instead use a `bi-encoder`, i.e. two separate BERT networks, a `query encoder` and a `passage encoder`, which learn to map queries and passages respectively into a dense vector space in which the similarity between a query vector and its corresponding relevant passage(s) is maximized. We use the output for the `[CLS]` token from each encoder as the dense vector representation.
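A minimal sketch of the bi-encoder idea, assuming a tiny randomly initialized `ToyEncoder` (hypothetical, one Transformer layer) as a stand-in for MobileBERT so that it runs without downloading weights. It shows the three ingredients: two encoders with separate weights, the vector at position 0 (where `[CLS]` sits) used as the sequence representation, and retrieval by ranking passages on their dot product with the query vector. The token ids are random, so the ranking itself is meaningless here.

```python
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for a BERT-style encoder: embeddings + one Transformer layer."""

    def __init__(self, vocab_size: int = 100, hidden: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        layer = nn.TransformerEncoderLayer(d_model=hidden, nhead=2, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=1)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        hidden_states = self.encoder(self.embed(token_ids))  # (batch, seq_len, hidden)
        return hidden_states[:, 0]  # position 0, where [CLS] sits in a real BERT input


torch.manual_seed(0)
query_encoder = ToyEncoder()    # separate weights for queries...
passage_encoder = ToyEncoder()  # ...and for passages
query_encoder.eval()
passage_encoder.eval()

queries = torch.randint(0, 100, (2, 8))    # 2 queries of 8 random token ids
passages = torch.randint(0, 100, (5, 12))  # 5 passages of 12 random token ids
with torch.no_grad():
    q_vecs = query_encoder(queries)     # (2, 16)
    p_vecs = passage_encoder(passages)  # (5, 16)
    scores = q_vecs @ p_vecs.T          # (2, 5): dot product of every query with every passage
    top_scores, top_idx = scores.topk(k=3, dim=1)  # the 3 highest-scoring passages per query
print(top_idx)
```

Because the passage encoder never sees the query, the passage vectors can be computed once for the whole collection and reused for every incoming query.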

The two encoders are trained jointly on a supervised classification task where each training instance is a tuple $(q_i, p_i^{+}, p_{i,1}^{-}, \dots, p_{i,n}^{-})$, where $q_i$ is a query, $p_i^{+}$ is a relevant (positive) passage and each of the $n$ passages $p_{i,j}^{-}$ is an irrelevant (negative) passage. We use the query encoder to compute the dense vector representation of the query, $E_Q(q_i)$, and the passage encoder to compute $E_P(p)$ for every passage. We then compute a similarity score between the query vector and each passage vector, $sim(q_i, p)$ for $p \in \{p_i^{+}, p_{i,1}^{-}, \dots, p_{i,n}^{-}\}$. These similarity scores can be interpreted as unnormalized logits for $(n+1)$ different class labels. With this interpretation, $sim(q_i, p_i^{+})$ is the logit of the "correct/ground-truth class", and we can simply use the `softmax cross-entropy/negative log-likelihood loss` function:

$L(q_i, p_i^{+}, p_{i,1}^{-}, \dots, p_{i,n}^{-}) = -\log \frac{\exp(sim(q_i, p_i^{+}))}{\exp(sim(q_i, p_i^{+})) + \sum_{j=1}^{n} \exp(sim(q_i, p_{i,j}^{-}))}$

Note that dividing each term of $[\exp(sim(q_i, p_i^{+})), \exp(sim(q_i, p_{i,1}^{-})), \dots, \exp(sim(q_i, p_{i,n}^{-}))]$ by their sum (i.e. taking the softmax of the similarity scores) gives a probability distribution over the $n+1$ passages. Minimizing the loss pushes the probability assigned to the positive passage towards 1 and the probabilities assigned to the negative passages towards 0. This is what produces a dense vector space in which a query vector is maximally similar to its positive passage vector and dissimilar to the negative passage vectors. We use the simple `dot product` as our similarity metric, i.e. $sim(q, p) = E_Q(q)^{T} E_P(p)$.
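A small worked example of this loss for a single query, using made-up similarity scores with the positive passage at index 0. The hypothetical `dpr_loss` helper computes the log-softmax with the usual max-subtraction trick and returns the negative log-probability of the positive passage. Raising the positive passage's score lowers the loss, which is the pressure that pulls query and positive passage vectors together during training.

```python
import numpy as np


def dpr_loss(sim_scores: np.ndarray) -> float:
    """Softmax cross-entropy for one query: index 0 holds sim(q, p+), the rest are negatives."""
    shifted = sim_scores - sim_scores.max()              # subtract the max for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum())  # log-softmax over the n + 1 passages
    return float(-log_probs[0])                          # negative log-probability of the positive


scores = np.array([2.0, 0.5, -1.0, 0.0])  # made-up scores: [positive, negative 1, negative 2, negative 3]
probs = np.exp(scores - scores.max())
probs /= probs.sum()                      # softmax: a proper probability distribution
print(probs.round(3), probs.sum())
print(dpr_loss(scores))                          # loss with the original scores
print(dpr_loss(np.array([6.0, 0.5, -1.0, 0.0])))  # higher positive score -> lower loss
```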

For the SQuAD dataset, we are already given (question, context passage) pairs, so we only need to choose negative passages for each pair. For training efficiency, we can use a simple trick: given a minibatch of $B$ such pairs, we assign the passages from the other $B-1$ pairs as the negatives for each pair. We can then compute the dot product between every question-passage pair in the batch with a single matrix multiplication. Given a matrix $Q$ of shape $(B, d)$ containing the batch of query vectors (where $d$ is the hidden dimension of the encoded vectors) and a matrix $P$ of the same shape containing the batch of passage vectors, we compute the matrix $QP^T$, whose $(i, j)$-th entry is the dot product between the $i$-th question and the $j$-th passage. The $i$-th diagonal entry is therefore the dot product between the $i$-th question and its positive passage, and all other entries in that row are dot products with negative passages. After taking the softmax of each row, the loss for the batch is the sum of the negative logs of the diagonal entries (the code below averages them instead, which only rescales the loss).

Besides training efficiency, the other big advantage of this technique is that the dataset is reshuffled before each epoch, so each (question, positive passage) pair is grouped with a different random set of negative passages every epoch. Over the course of training, each pair is therefore effectively contrasted against far more negatives than the $B-1$ it sees in any single batch.
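The in-batch negatives computation can be sketched in a few lines of PyTorch, assuming random tensors `Q` and `P` in place of real encoder outputs. Because the positive passage of question $i$ sits in column $i$, the targets are simply `torch.arange(B)`, and `F.cross_entropy` (which applies a row-wise log-softmax and averages the negative log-likelihood of the target entries) gives the same value as the mean of $-\log$ of the diagonal of the row-wise softmax.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, d = 4, 8             # batch size and embedding dimension (toy values)
Q = torch.randn(B, d)   # query vectors, row i = question i
P = torch.randn(B, d)   # passage vectors, row i = positive passage of question i

scores = Q @ P.T        # (B, B); entry (i, j) = dot(question i, passage j)
targets = torch.arange(B)  # the positive for row i sits in column i

# Row-wise softmax cross-entropy = mean of -log(diagonal softmax probabilities)
loss = F.cross_entropy(scores, targets)
manual = -torch.log_softmax(scores, dim=1).diagonal().mean()
print(loss.item(), manual.item())

# In-batch accuracy: how often the positive passage gets the highest score in its row
accuracy = (scores.argmax(dim=1) == targets).float().mean()
print(accuracy.item())
```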

We will use two MobileBERT models for our bi-encoder (two BERT-base models probably won't fit on my GPU, and MobileBERT is less than half the size of BERT-base while performing comparably).


In [2]:
import torch
from transformers import BertTokenizerFast, MobileBertModel
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import csv
import random
from tqdm import tqdm
import psutil
import json
import wandb
import os

wandb.login()
print(torch.cuda.is_available())

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtanzids[0m. Use [1m`wandb login --relogin`[0m to force relogin


True


#### Let's define our Bi-encoder model first.

In [6]:
class BERTBiEncoder(torch.nn.Module):
    def __init__(self, dropout_rate=0.1):
        super().__init__()
        # load two separate pretrained MobileBERT encoders (their weights are not shared)
        self.query_encoder = MobileBertModel.from_pretrained('google/mobilebert-uncased')
        self.passage_encoder = MobileBertModel.from_pretrained('google/mobilebert-uncased')
        self.dropout = torch.nn.Dropout(dropout_rate)

        for param in self.query_encoder.parameters():
            param.requires_grad = True
        
        for param in self.passage_encoder.parameters():
            param.requires_grad = True

    def forward(self, query_idx, query_attn_mask, passage_idx, passage_attn_mask):
        # compute BERT encodings
        query_output = self.query_encoder(query_idx, attention_mask=query_attn_mask)
        passage_output = self.passage_encoder(passage_idx, attention_mask=passage_attn_mask)
        # extract the `[CLS]` encoding (first element of the sequence), apply dropout
        query_enc = self.dropout(query_output.last_hidden_state[:, 0]) # shape: (batch_size, hidden_size)
        passage_enc = self.dropout(passage_output.last_hidden_state[:,0]) # shape: (batch_size, hidden_size)
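        # compute similarity score matrix: entry (i, j) = dot(query i, passage j)
        scores = torch.mm(query_enc, passage_enc.transpose(0, 1)) # shape: (batch_size, batch_size)
        # the positive passage for query i is passage i, so the targets are the diagonal indices
        targets = torch.arange(scores.size(0), device=scores.device) # shape: (batch_size,)
        # negative log-likelihood of the diagonal under a row-wise softmax, computed via
        # log-softmax in one step (softmax followed by log can underflow to log(0) = -inf)
        loss = F.cross_entropy(scores, targets)

        # return the row-wise softmax probabilities together with the loss
        return F.softmax(scores, dim=1), loss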
    

# training loop
def train(model, optimizer, train_dataloader, val_dataloader, scheduler=None, device="cpu", num_epochs=10, val_every=1, save_every=None, log_metrics=None):
    avg_loss = 0
    train_acc = 0
    val_loss = 0
    val_acc = 0
    model.train()
    for epoch in range(num_epochs):
        num_correct = 0
        num_total = 0
        pbar = tqdm(train_dataloader, desc="Epochs")
        for batch in pbar:
            query_idx, query_attn_mask, passage_idx, passage_attn_mask = batch
            # move batch to device
            query_idx, query_attn_mask, passage_idx, passage_attn_mask = query_idx.to(device), query_attn_mask.to(device), passage_idx.to(device), passage_attn_mask.to(device)
            # forward pass
            scores, loss = model(query_idx, query_attn_mask, passage_idx, passage_attn_mask )
            # reset gradients
            optimizer.zero_grad()
            # backward pass
            loss.backward()
            # optimizer step
            optimizer.step()
            avg_loss = 0.9* avg_loss + 0.1*loss.item()
            B, _ = query_idx.shape
            y_pred = scores.argmax(dim=-1).view(-1) # shape (B,)
            targets = torch.arange(B).to(device) # shape (B,)
            num_correct += (y_pred.eq(targets.view(-1))).sum().item()      
            num_total += B
            train_acc = num_correct / num_total        
            
            pbar.set_description(f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train Accuracy: {train_acc: .3f}, Val Loss: {val_loss: .3f}, Val Accuracy: {val_acc: .3f}")  

            if log_metrics:
                metrics = {"Batch loss" : loss.item(), "Moving Avg Loss" : avg_loss, "Val Loss": val_loss}
                log_metrics(metrics)

        if scheduler is not None:
            scheduler.step()
        
        if val_every is not None:
            if epoch%val_every == 0:
                # compute validation loss
                val_loss, val_acc = validation(model, val_dataloader, device=device)
                # the progress bar is already closed at this point, so print the validation metrics instead
                print(f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train Accuracy: {train_acc: .3f}, Val Loss: {val_loss: .3f}, Val Accuracy: {val_acc: .3f}")

        if save_every is not None:
            if (epoch+1) % save_every == 0:
                save_model_checkpoint(model, optimizer, epoch, avg_loss)


def validation(model, val_dataloader, device="cpu"):
    model.eval()
    val_losses = torch.zeros(len(val_dataloader))
    with torch.no_grad():
        num_correct = 0
        num_total = 0
        for i,batch in enumerate(val_dataloader):
            query_idx, query_attn_mask, passage_idx, passage_attn_mask = batch
            query_idx, query_attn_mask, passage_idx, passage_attn_mask = query_idx.to(device), query_attn_mask.to(device), passage_idx.to(device), passage_attn_mask.to(device)
            scores, loss = model(query_idx, query_attn_mask, passage_idx, passage_attn_mask )
            B, _ = query_idx.shape
            y_pred = scores.argmax(dim=-1).view(-1) # shape (B,)
            targets = torch.arange(B).to(device) # shape (B,)
            num_correct += (y_pred.eq(targets.view(-1))).sum().item()      
            num_total += B
            val_losses[i] = loss.item()
    model.train()
    val_loss = val_losses.mean().item()
    val_accuracy = num_correct / num_total
    return val_loss, val_accuracy

def save_model_checkpoint(model, optimizer, epoch=None, loss=None, filename=None):
    # Save the model and optimizer state_dict
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    # Save the checkpoint to a file
    if filename:
        torch.save(checkpoint, filename)
    else:
        torch.save(checkpoint, 'dpr_checkpoint.pth')
    print(f"Saved model checkpoint!")


def load_model_checkpoint(model, optimizer, filename=None):
    if filename:
        checkpoint = torch.load(filename)
    else:
        checkpoint = torch.load('dpr_checkpoint.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    model.train()
    print("Loaded model from checkpoint!")
    return model, optimizer          
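One numerical detail worth checking in a loss like this: computing `F.softmax` and then `torch.log` as two separate steps can underflow when one logit dominates a row, because the smaller probabilities round to exactly 0 in float32 and their log becomes `-inf`. The sketch below uses made-up logits to show the effect; `F.cross_entropy` (like `F.log_softmax`) computes the log-probabilities directly and stays finite. Whether gaps this large actually occur depends on the model and the stage of training, since unnormalized dot products of high-dimensional embeddings are not bounded.

```python
import torch
import torch.nn.functional as F

# One row of made-up logits with a large gap: the positive passage is column 0
# and currently scores far below one of the negatives (e.g. early in training)
scores = torch.tensor([[0.0, 150.0, 10.0]])
targets = torch.tensor([0])

probs = F.softmax(scores, dim=1)           # exp(-150) underflows, so probs[0, 0] becomes 0 in float32
naive = -torch.log(probs[0, 0])            # -log(0) = inf: the loss blows up
stable = F.cross_entropy(scores, targets)  # log-softmax computed directly: finite, about 150
print(probs, naive.item(), stable.item())
```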

In [7]:
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

RAM used: 1190.82 MB


#### Let's load up the SQuAD v2.0 data (keeping only the answerable questions)

In [10]:
# load the train and dev JSON documents
with open("train.json", "r") as train_file:
    squad_train = json.load(train_file)         
with open("dev.json", "r") as dev_file:
    squad_dev = json.load(dev_file) 

def get_passages(squad, num_titles=None):
    if num_titles is None:
        num_titles = len(squad['data'])
    # for each title, collect its passages and pair every answerable question with its passage's index
    passages = []
    questions = []
    num_questions = 0
    j = 0
    for i in range(num_titles):
        #print(f"Title# {i}: {squad['data'][i]['title']}, Number of passages: {len(squad['data'][i]['paragraphs'])}")
        for p in squad['data'][i]['paragraphs']:
            passages.append(p['context'])
            for q in p['qas']:
                if not q.get('is_impossible', False):  # keep only answerable questions; this key exists in SQuAD v2.0 but not v1.1
                    questions.append((q,j))    
                    num_questions += 1
            j += 1
    print(f"Number of passages: {len(passages)}")
    print(f"Number of questions: {num_questions}")
    return passages, questions

passages_train, questions_train = get_passages(squad_train)
passages_val, questions_val = get_passages(squad_dev)

Number of passages: 19035
Number of questions: 86821
Number of passages: 1204
Number of questions: 5928
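The nested layout that `get_passages` walks (articles, then paragraphs, then question/answer entries) is easiest to see on a toy example. The dictionary below is made up but mirrors the SQuAD v2.0 field names. It also shows why the code filters on `is_impossible`: SQuAD v2.0 includes unanswerable questions with an empty `answers` list, which would break the answer-span logic in the dataset class. The loop is a condensed sketch of the same flattening into (question dict, global passage index) pairs.

```python
# A toy dictionary mirroring SQuAD v2.0's nesting (all text here is made up)
toy_squad = {
    "data": [
        {"title": "Toy_Title",
         "paragraphs": [
             {"context": "Paris is the capital of France.",
              "qas": [
                  {"question": "What is the capital of France?", "is_impossible": False,
                   "answers": [{"text": "Paris", "answer_start": 0}]},
                  {"question": "What is the capital of Mars?", "is_impossible": True,
                   "answers": []},
              ]},
             {"context": "Berlin is the capital of Germany.",
              "qas": [
                  {"question": "Which city is Germany's capital?", "is_impossible": False,
                   "answers": [{"text": "Berlin", "answer_start": 0}]},
              ]},
         ]},
    ]
}

passages, questions = [], []
for article in toy_squad["data"]:
    for paragraph in article["paragraphs"]:
        passage_index = len(passages)  # global index of this paragraph across all articles
        passages.append(paragraph["context"])
        for qa in paragraph["qas"]:
            if not qa.get("is_impossible", False):  # keep answerable questions only
                questions.append((qa, passage_index))

print(len(passages), len(questions))
print([(qa["question"], i) for qa, i in questions])
```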


In [8]:
# disable the fast (Rust) tokenizer's internal parallelism: tokenization happens inside __getitem__,
# i.e. in forked DataLoader worker processes, where that parallelism can deadlock
os.environ["TOKENIZERS_PARALLELISM"] = "false"

class SquadDataset(Dataset):
    def __init__(self, passages, questions, block_size=128):
        self.passages = passages
        self.questions = questions
        self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
        self.block_size = block_size
        
    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        # get the question and context passage
        q = self.questions[idx][0]
        passage_idx = self.questions[idx][1]
        question = q['question']
        passage = self.passages[passage_idx]
        # tokenize the context passage
        passage_encoding = self.tokenizer.encode_plus(passage, add_special_tokens=False, return_offsets_mapping=False, return_attention_mask=False, return_token_type_ids=False)
        passage_idx = passage_encoding['input_ids']
        # tokenize the question
        question_encoding = self.tokenizer.encode_plus(question, add_special_tokens=False, return_offsets_mapping=False, return_attention_mask=False, return_token_type_ids=False)
        question_idx = question_encoding['input_ids']

        # get answer span start and end character positions, for multiple answers, we will only use the first answer
        first_answer_idx = 0
        answer_start_char = q['answers'][first_answer_idx]['answer_start']
        answer_end_char = answer_start_char + len(q['answers'][first_answer_idx]['text'])
        # convert char positions to token positions
        answer_start_token = passage_encoding.char_to_token(answer_start_char)
        answer_end_token = passage_encoding.char_to_token(answer_end_char-1)

        # select a window size so that the passage sequence will be no longer than the block size
        window_size_tokens = self.block_size - 2 # 2 special tokens ([CLS], [SEP])
        # now create a window around the answer span, pick the window start position randomly
        window_start_min = max(0, answer_end_token - window_size_tokens + 1)
        window_start_max = min(answer_start_token, max(0,len(passage_idx) - window_size_tokens)) # we want to make the window as large as possible, but not go over the end of the context
        window_start = random.randint(window_start_min, window_start_max)
        window_end = window_start + window_size_tokens        
        # select window of passage tokens
        passage_window_tokens = passage_idx[window_start:window_end]

        # create padded query and passage sequences
        question_idx = [self.tokenizer.cls_token_id] + question_idx + [self.tokenizer.sep_token_id]
        question_idx = question_idx + [self.tokenizer.pad_token_id]*(self.block_size-len(question_idx))
        passage_idx = [self.tokenizer.cls_token_id] + passage_window_tokens + [self.tokenizer.sep_token_id]
        passage_idx = passage_idx + [self.tokenizer.pad_token_id]*(self.block_size-len(passage_idx))

        # make sure the passage and query sequences are not longer than block_size
        if len(question_idx) > self.block_size or len(passage_idx) > self.block_size:
            raise Exception(f"Passage sequence length {len(passage_idx)} or question sequence length {len(question_idx)} is longer than max_length {self.block_size}!")
        
        # create attention masks
        question_attn_mask = [1 if idx != self.tokenizer.pad_token_id else 0 for idx in question_idx]
        passage_attn_mask  = [1 if idx != self.tokenizer.pad_token_id else 0 for idx in passage_idx]

        # convert to tensors
        question_idx = torch.tensor(question_idx)
        question_attn_mask = torch.tensor(question_attn_mask)
        passage_idx = torch.tensor(passage_idx)
        passage_attn_mask = torch.tensor(passage_attn_mask)

        return question_idx, question_attn_mask, passage_idx, passage_attn_mask
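The window arithmetic in `__getitem__` is the subtle part, so here is a standalone sketch of the same bounds with a randomized check that the answer span always lands inside the chosen window. `choose_window` is a hypothetical helper; token positions are inclusive, as returned by `char_to_token`, and the passage and answer lengths in the check are arbitrary. When the passage is shorter than the window, the window end runs past the passage, which is harmless because slicing simply stops at the end of the list.

```python
import random


def choose_window(num_tokens, answer_start, answer_end, window_size, rng=random):
    """Return a random [start, end) token window containing answer_start..answer_end (inclusive).

    Uses the same bounds as SquadDataset.__getitem__.
    """
    # the window must start late enough to still cover the answer's last token...
    start_min = max(0, answer_end - window_size + 1)
    # ...early enough to cover its first token, and not so late that it runs past the passage
    start_max = min(answer_start, max(0, num_tokens - window_size))
    start = rng.randint(start_min, start_max)  # randint includes both endpoints
    return start, start + window_size


window_size = 192 - 2  # block_size minus [CLS] and [SEP]
rng = random.Random(0)
for _ in range(1000):
    n = rng.randint(1, 600)                           # passage length in tokens
    a_start = rng.randint(0, n - 1)                   # first answer token
    a_end = min(n - 1, a_start + rng.randint(0, 20))  # last answer token (short span)
    s, e = choose_window(n, a_start, a_end, window_size, rng)
    assert s <= a_start and a_end < e                 # the whole answer is inside the window
print("answer span stayed inside the window in all 1000 random cases")
```

One side effect of drawing the start uniformly at random, rather than always placing the answer at the same offset, is that the same (question, passage) pair yields slightly different passage windows across epochs.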

In [15]:
block_size = 192
train_dataset = SquadDataset(passages_train, questions_train, block_size=block_size)
val_dataset = SquadDataset(passages_val, questions_val, block_size=block_size)