#### Dense Passage Retrieval (DPR)

We saw how to use the TF-IDF representation of passages to perform retrieval. One major problem with this kind of sparse vector representation is that if the query words don't exactly match any words in the relevant passages, the retrieval system will not be able to find those passages (because the cosine similarity between the query and passage vectors is zero).
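
To make the vocabulary-mismatch problem concrete, here is a minimal sketch that uses plain term-count vectors as a simplified stand-in for TF-IDF (the example texts and the `cosine_similarity` helper are made up for illustration). Any weighting of the shared terms gives the same result when there are no shared terms at all.

```python
import math
from collections import Counter

def cosine_similarity(text_a: str, text_b: str) -> float:
    """Cosine similarity between simple term-count vectors of two texts."""
    counts_a = Counter(text_a.lower().split())
    counts_b = Counter(text_b.lower().split())
    # Only words present in both texts contribute to the dot product.
    dot = sum(counts_a[w] * counts_b[w] for w in counts_a.keys() & counts_b.keys())
    norm_a = math.sqrt(sum(c * c for c in counts_a.values()))
    norm_b = math.sqrt(sum(c * c for c in counts_b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

# Hypothetical query and passages, chosen only to illustrate the mismatch.
query = "who invented the telephone"
relevant = "alexander graham bell patented an early phone"
irrelevant = "who invented the light bulb"

print(cosine_similarity(query, relevant))    # zero: no shared words
print(cosine_similarity(query, irrelevant))  # positive: three shared words
```

The relevant passage scores zero because "phone" and "telephone" are different tokens, while the irrelevant passage scores higher purely through word overlap.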

In DPR, we instead have `bi-encoders`, i.e. two separate BERT networks, a `query encoder` and a `passage encoder`, which learn to map queries and passages respectively into a dense vector space in which the similarity between a query vector and its corresponding relevant passage(s) is maximized. We use the output for the `[CLS]` token from each encoder as the dense vector representation.

The bi-encoders are jointly trained using a supervised classification task where each input instance is a tuple $(q_i, p_i^{+}, p_{i,1}^{-}, ...,p_{i,n}^{-})$ where $q_i$ is a query, $p_i^{+}$ is a relevant/positive passage and each of the $n$ passages $p_{i,j}^{-}$ is an irrelevant/negative passage. We use the query encoder to compute the dense vector representation of the query, $E_{Q}(q)$, and the passage encoder to compute the representation of each passage, $E_P(p)$. Then we compute similarity scores between the query vector and each passage vector: $sim(q_i, p)$ for $p \in \{p_i^{+}, p_{i,1}^{-}, ...,p_{i,n}^{-}\}$. We can interpret these similarity scores as unnormalized logits for $(n+1)$ different class labels. With this interpretation, we can define $sim(q_i, p_i^{+})$ as the logit for the correct (ground-truth) class and then simply use the `softmax cross-entropy/negative log-likelihood loss` function:

$L(q_i, p_i^{+}, p_{i,1}^{-}, ...,p_{i,n}^{-}) = -\log \frac{exp(sim(q_i, p_i^{+}))}{exp(sim(q_i, p_i^{+})) + \sum_{j=1}^n exp(sim(q_i, p_{i,j}^{-}))}$

Note that the softmax-normalized scores $\frac{exp(sim(q_i, p))}{exp(sim(q_i, p_i^{+})) + \sum_{j=1}^n exp(sim(q_i, p_{i,j}^{-}))}$ for $p \in \{p_i^{+}, p_{i,1}^{-}, ...,p_{i,n}^{-}\}$ form a probability distribution over the $(n+1)$ passages. Minimizing the loss function pushes the probability assigned to the positive passage towards 1 and the probabilities assigned to the negative passages towards zero. This gives us the dense vector space we want, in which a query vector is maximally similar to the positive passage vector and dissimilar to the negative passage vectors. We use the simple `dot product` as our similarity metric: $sim(q, p) = E_Q(q)^T E_P(p)$.
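
As a small worked example of this loss for a single query, the sketch below computes the softmax probabilities and the negative log-likelihood from a handful of similarity scores. The scores and the helper name `dpr_loss` are hypothetical and only meant to show how the loss reacts when the positive passage is scored higher.

```python
import math

def dpr_loss(sim_positive: float, sim_negatives: list[float]) -> tuple[list[float], float]:
    """Softmax probabilities over (positive, negatives...) and the NLL loss."""
    logits = [sim_positive] + list(sim_negatives)
    # Subtract the max logit before exponentiating for numerical stability.
    max_logit = max(logits)
    exps = [math.exp(s - max_logit) for s in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    loss = -math.log(probs[0])  # index 0 holds the positive passage
    return probs, loss

# Hypothetical similarity scores for one query with two negatives.
probs_bad, loss_bad = dpr_loss(1.0, [2.0, 0.5])    # positive is not ranked first
probs_good, loss_good = dpr_loss(6.0, [2.0, 0.5])  # positive is scored much higher

print(probs_bad, loss_bad)
print(probs_good, loss_good)
```

In both cases the probabilities sum to 1, and the loss shrinks as the positive passage's probability approaches 1.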

For the SQuAD dataset, we are already given (question, context passage) pairs, so we only need to choose negative passages for each pair. For training efficiency, we can use a simple trick. Given a minibatch of $B$ such (question, context passage) pairs, for each pair we simply treat the passages from the other $B-1$ pairs as the negatives. We can then compute the dot product between every question and every passage in the batch with a single matrix multiplication. Given a matrix $Q$ of shape $(B,d)$ containing the batch of query vectors (where $d$ is the hidden dimension of the encoded vectors) and a matrix $P$ of the same shape containing the batch of passage vectors, we compute the matrix $QP^T$, whose $(i,j)$th entry gives us the dot product between the $i$th question and the $j$th passage. So the $i$th diagonal entry of this matrix is the dot product between the $i$th question and its corresponding positive passage, and all other elements of that row are dot products with the negative passages. By taking the softmax of each row of this matrix, we can compute the loss for the batch by summing (or averaging) the negative log of the terms along the diagonal.

In addition to training efficiency, the other big advantage of this technique is that, when the dataset is shuffled before each epoch, each question-positive passage pair will generally be grouped with a different set of negative passages in each epoch, so over the course of training we effectively get a very large set of negatives per pair.
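
The in-batch negatives computation can be sketched in a few lines of PyTorch. This is only an illustration under assumptions: random tensors stand in for the encoder outputs, and the batch size and dimension are arbitrary. It also shows that the diagonal negative log-likelihood is the same quantity as a standard cross-entropy loss in which the target class for row $i$ is column $i$ (averaged here rather than summed, which differs only by a factor of $B$).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, d = 4, 8  # hypothetical batch size and embedding dimension

Q = torch.randn(B, d)  # stand-in for the batch of query vectors
P = torch.randn(B, d)  # stand-in for the batch of passage vectors

# (B, B) matrix: entry (i, j) is the dot product of query i and passage j.
logits = Q @ P.T

# Row-wise log-softmax, then average the negative diagonal entries.
loss_manual = -F.log_softmax(logits, dim=1).diag().mean()

# Equivalent: cross-entropy where the correct "class" for row i is column i.
targets = torch.arange(B)
loss_ce = F.cross_entropy(logits, targets)

print(loss_manual.item(), loss_ce.item())
```

Using `F.log_softmax` (or `F.cross_entropy`) rather than taking the log of a softmax output is generally the more numerically stable choice.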

We will use two MobileBERT models for our bi-encoders (two full BERT models probably won't fit on my GPU, and MobileBERT is less than half the size of BERT while performing comparably on many benchmarks).


In [2]:
import torch
from transformers import BertTokenizerFast, MobileBertModel
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import csv
import random
from tqdm import tqdm
import psutil
import json
import wandb
import os

wandb.login()
print(torch.cuda.is_available())

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: tanzids. Use `wandb login --relogin` to force relogin


True


#### Let's define our Bi-encoder model first.

In [5]:
class BERTBiEncoder(torch.nn.Module):
    def __init__(self, dropout_rate=0.1):
        super().__init__()
        # load two separate pretrained MobileBERT encoders (query and passage)
        self.query_encoder = MobileBertModel.from_pretrained('google/mobilebert-uncased')
        self.passage_encoder = MobileBertModel.from_pretrained('google/mobilebert-uncased')
        self.dropout = torch.nn.Dropout(dropout_rate)

        for param in self.query_encoder.parameters():
            param.requires_grad = True
        
        for param in self.passage_encoder.parameters():
            param.requires_grad = True

    def forward(self, query_idx, query_attn_mask, passage_idx, passage_attn_mask):
        # compute BERT encodings
        query_output = self.query_encoder(query_idx, attention_mask=query_attn_mask)
        passage_output = self.passage_encoder(passage_idx, attention_mask=passage_attn_mask)
        # extract the `[CLS]` encoding (first element of the sequence), apply dropout
        query_output = self.dropout(query_output.last_hidden_state[:, 0]) # shape: (batch_size, hidden_size)
        passage_output = self.dropout(passage_output.last_hidden_state[:,0]) # shape: (batch_size, hidden_size)
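        # compute similarity score matrix (dot products), shape: (batch_size, batch_size)
        logits = torch.mm(query_output, passage_output.transpose(0, 1))
        # take row-wise log-softmax (more numerically stable than log(softmax(...)))
        log_probs = F.log_softmax(logits, dim=1)
        # compute negative log-likelihood loss over the diagonal (positive pairs)
        loss = -log_probs.diag().mean()
        # recover the row-wise softmax probabilities
        scores = log_probs.exp()

        return scores, loss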
    

In [4]:
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

RAM used: 1188.15 MB
