#### In the previous notebook, we tried a simple BM25 retrieval system and saw that it performed poorly: the average F1 score for top-5 retrieval was only about 9%. Now we will train a Dense Passage Retrieval (DPR) model and see if we can do better.

In [1]:
import torch  # used directly below (torch.cuda, torch.tensor, torch.optim)
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizerFast
from collections import Counter
import csv
import random
from tqdm import tqdm
import psutil
import json
import os
from DPR_biencoder_simple import *
import wandb

%load_ext autoreload
%autoreload 2

wandb.login()
print(torch.cuda.is_available())

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: tanzids. Use `wandb login --relogin` to force relogin


True


First, let's load the data and filter out some of the bad passages, as we did before.

In [2]:
# load the evidence passages
with open("project-data/evidence.json", "r") as evidence_file:
    document_store = json.load(evidence_file)
print(f"Number of evidence passages: {len(document_store)}")

# load the training data instances
with open("project-data/train-claims.json", "r") as train_file:
    train_data = json.load(train_file)
print(f"Number of training instances: {len(train_data)}")

# load the validation data instances
with open("project-data/dev-claims.json", "r") as dev_file:
    val_data = json.load(dev_file)    
print(f"Number of validation instances: {len(val_data)}")

# we remove duplicate values from the document_store dictionary (we arbitrarily keep the first one)
seen = set()
document_store_no_duplicates = {}
for key, value in document_store.items():
    if value not in seen:
        document_store_no_duplicates[key] = value
        seen.add(value)

# remove all "bad" documents from the document store, except those that occur in claim gold evidence lists;
# we define "bad" documents as ones that have fewer than 30 characters
claim_evidence_list = [claim['evidences'] for claim in train_data.values()]
claim_evidence_list = claim_evidence_list + [claim['evidences'] for claim in val_data.values()]
# use a set so that the membership test in the filter below is O(1) per passage
claim_evidence_list = set(evidence for evidence_list in claim_evidence_list for evidence in evidence_list)

document_store_cleaned = {i: evidence_text for i, evidence_text in document_store_no_duplicates.items() if len(evidence_text) >= 30 or i in claim_evidence_list}
print(f"Number of evidence passages remaining after cleaning: {len(document_store_cleaned)}")

Number of evidence passages: 1208827
Number of training instances: 1228
Number of validation instances: 154
Number of evidence passages remaining after cleaning: 1190647


In [3]:
claim_evidence_list_train = set([ev for claim in train_data.values() for ev in claim['evidences']])
claim_evidence_list_val = set([ev for claim in val_data.values() for ev in claim['evidences']])
print(f"Number of unique evidence passages in training data: {len(claim_evidence_list_train)}")
print(f"Number of unique evidence passages in validation data: {len(claim_evidence_list_val)}")

Number of unique evidence passages in training data: 3121
Number of unique evidence passages in validation data: 463


To train the DPR retriever, we need to create pairs of $(\text{claim}_i, \text{evidence passage}_{i,1})$. However, each claim can have multiple evidence passages, so we will create multiple pairs: $(\text{claim}_i, \text{evidence passage}_{i,1})$, $(\text{claim}_i, \text{evidence passage}_{i,2})$, ...

Next, we prepare a minibatch of claims and corresponding passages. Given a matrix $C$ of shape $(B,d)$ containing the batch of encoded claim vectors (where $d$ is the hidden dimension of the encoded vectors) and a matrix $P$ of the same shape containing the batch of encoded passage vectors, we can compute the matrix $CP^T$ of shape $(B,B)$, whose $(i,j)$th entry gives us the dot product between the $i$th claim and the $j$th passage. The elements along the diagonal of $CP^T$ are the scores for positive pairs, and the off-diagonal entries are the scores for negative pairs. We can then train a softmax classifier that treats the diagonal term in each row as the score for the positive class and the remaining $B-1$ terms as scores for negative classes. This is the trick of `in-batch negatives`.
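
Written out, the softmax objective described above is the negative log-likelihood of the diagonal entry in each row of $CP^T$. As a sketch, writing $c_i$ and $p_j$ for the $i$th row of $C$ and the $j$th row of $P$ (notation introduced here only for illustration), and assuming the loss is averaged over the batch:

$$\mathcal{L} = -\frac{1}{B}\sum_{i=1}^{B} \log \frac{\exp(c_i \cdot p_i)}{\sum_{j=1}^{B} \exp(c_i \cdot p_j)}$$

The numerator is the score of the positive pair and the denominator sums over all $B$ passages in the batch. The averaging is an assumption; the training code in `DPR_biencoder_simple` may reduce the per-row losses differently.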

However, one issue is that out of the ~1M different passages, only ~3000 appear as positive evidence in (claim, evidence) pairs. Therefore, the in-batch negatives will also be restricted to these few passages. For good performance, we need to be able to select negatives from all of the evidence passages in the document store. That's why, in addition to passing a minibatch of claims $C$ and corresponding positive passages $P$, we will also pass in a batch of negatives $N$, selected from the set of passages outside of the ~3000 that appear as positives. Then we compute the matrix $CN^T$, which has shape $(B,B)$. All terms in the $i$th row of this matrix are scores for the $i$th claim against the $B$ negatives. By horizontally concatenating $[CP^T; CN^T]$, we get a matrix of shape $(B,2B)$, and we can train a softmax classifier that treats the $(i,i)$ entry in each row as the score for the positive class and the remaining $2B-1$ terms as scores for negative classes.
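
To make the $(B,2B)$ construction concrete, here is a minimal, dependency-free sketch of the loss on toy vectors. It uses plain Python lists so that every step of the arithmetic is visible; the function name `softmax_loss_with_extra_negatives` and the toy inputs are hypothetical and are not taken from `DPR_biencoder_simple`. In PyTorch the same idea would roughly correspond to concatenating the two score matrices with `torch.cat` and applying `torch.nn.functional.cross_entropy` with targets $0, \dots, B-1$, but that is an assumption about how the real training code is written.

```python
import math


def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def softmax_loss_with_extra_negatives(claims, positives, negatives):
    """Mean cross-entropy over the rows of [C P^T ; C N^T].

    Row i holds the scores of claim i against the B in-batch passages
    followed by the B extra negatives; the target column for row i is i.
    """
    losses = []
    for i, claim in enumerate(claims):
        scores = [dot(claim, p) for p in positives] + [dot(claim, n) for n in negatives]
        # numerically stable log-sum-exp over the 2B scores in this row
        max_score = max(scores)
        log_z = max_score + math.log(sum(math.exp(s - max_score) for s in scores))
        # negative log-probability of the positive (column i)
        losses.append(log_z - scores[i])
    return sum(losses) / len(losses)


# Toy batch with B = 2 and d = 2 (made-up numbers).
claims = [[1.0, 0.0], [0.0, 1.0]]
positives = [[2.0, 0.0], [0.0, 2.0]]
negatives = [[-1.0, -1.0], [-1.0, -1.0]]

aligned = softmax_loss_with_extra_negatives(claims, positives, negatives)
# Reversing the positives puts the wrong passage on the diagonal.
swapped = softmax_loss_with_extra_negatives(claims, positives[::-1], negatives)
print(aligned < swapped)
```

The loss is lower when each claim's own passage sits on the diagonal than when the passages are misaligned, which is the behavior training is meant to reinforce.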

`Hard-negative mining`: We could simply create the batch of negatives $N$ by randomly sampling from the set of all negatives. However, a slightly better option would be to select "hard" negatives: passages that are very similar to the positive ones. We could either use the highest-scoring non-positive documents from a BM25 retriever as hard negatives, or first train our model with random negative selection and then use the highest-scoring non-positives from the trained model as negatives for some further finetuning.
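
As a rough illustration of the selection step, the sketch below picks the top-scoring non-positive passages for a single claim. It assumes we already have a score for each candidate passage, from BM25 or from a trained bi-encoder; the function name `mine_hard_negatives`, the passage ids, and the scores are all hypothetical.

```python
import heapq


def mine_hard_negatives(scores, positive_ids, k):
    """Return the k highest-scoring passage ids that are not gold positives.

    `scores` maps passage id -> retrieval score for one claim.
    """
    candidates = ((pid, score) for pid, score in scores.items() if pid not in positive_ids)
    top = heapq.nlargest(k, candidates, key=lambda item: item[1])
    return [pid for pid, _ in top]


# Toy scores for a single claim; "evidence-2" is its gold passage.
scores = {"evidence-1": 7.1, "evidence-2": 9.4, "evidence-3": 8.8, "evidence-4": 0.3}
hard_negatives = mine_hard_negatives(scores, positive_ids={"evidence-2"}, k=2)
print(hard_negatives)  # ['evidence-3', 'evidence-1']
```

The gold passage is excluded even though it has the highest score, so the returned ids are the passages the retriever is most likely to confuse with the true evidence.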

Note: Because we need to use fixed-size passages for BERT, if a passage exceeds the block size, we will take a random window of it.

In [4]:
# let's separate out all ~3000 positive evidence passages from the document store and define the remaining as negatives.
all_passages_ids = list(document_store_cleaned.keys())
positives_ids_train = claim_evidence_list_train
negatives_ids_train = list(set(all_passages_ids) - set(positives_ids_train))
                         
# create claim-positive pairs
claim_positive_pairs_train = []
for claim_id in train_data.keys():
    for evidence_id in train_data[claim_id]['evidences']:
        claim_positive_pairs_train.append((claim_id, evidence_id))    


In [5]:
# now let's create a PyTorch dataset
class ClaimsDataset(Dataset):
    def __init__(self, claims_data, document_store, block_size=128):
        self.claims_data = claims_data
        self.document_store = document_store
        self.tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
        self.block_size = block_size
        self.claim_positive_pairs, self.negatives_ids = self.create_pairs()

    def create_pairs(self):
        claim_evidence_list = set([ev for claim in self.claims_data.values() for ev in claim['evidences']])
        all_passages_ids = list(self.document_store.keys())
        positives_ids = claim_evidence_list
        negatives_ids = list(set(all_passages_ids) - set(positives_ids))
        claim_positive_pairs = []
        for claim_id in self.claims_data.keys():
            for evidence_id in self.claims_data[claim_id]['evidences']:
                claim_positive_pairs.append((claim_id, evidence_id))    
        return claim_positive_pairs, negatives_ids

    def __len__(self):
        return len(self.claim_positive_pairs)

    def __getitem__(self, idx):
        # get claim id and positive evidence id
        claim_id, positive_id = self.claim_positive_pairs[idx]
        # randomly sample a negative evidence id
        negative_id = random.choice(self.negatives_ids)
        # get the claim, positive and negative text
        claim_text = self.claims_data[claim_id]['claim_text']
        positive_text = self.document_store[positive_id]
        negative_text = self.document_store[negative_id]

        # tokenize the claim, positive and negative text  
        claim_encoding = self.tokenizer.encode_plus(claim_text, add_special_tokens=False, return_offsets_mapping=False, return_attention_mask=False, return_token_type_ids=False)
        claim_idx = claim_encoding['input_ids']
        
        positive_encoding = self.tokenizer.encode_plus(positive_text, add_special_tokens=False, return_offsets_mapping=False, return_attention_mask=False, return_token_type_ids=False)
        positive_idx = positive_encoding['input_ids']
        
        negative_encoding = self.tokenizer.encode_plus(negative_text, add_special_tokens=False, return_offsets_mapping=False, return_attention_mask=False, return_token_type_ids=False)
        negative_idx = negative_encoding['input_ids']

        # select a window from the positive passage if it is longer than block size
        if len(positive_idx) > (self.block_size-2):
            # pick a random start position
            start_pos = random.randint(0, max(0,len(positive_idx) - (self.block_size-2)))
            # select the window
            positive_idx = positive_idx[start_pos:start_pos+self.block_size-2]

        # select a window from the negative passage if it is longer than block size
        if len(negative_idx) > (self.block_size-2):
            # pick a random start position
            start_pos = random.randint(0, max(0,len(negative_idx) - (self.block_size-2)))
            # select the window
            negative_idx = negative_idx[start_pos:start_pos+self.block_size-2]    

        # add special tokens and padding
        claim_idx = [self.tokenizer.cls_token_id] + claim_idx + [self.tokenizer.sep_token_id]
        claim_idx = claim_idx + [self.tokenizer.pad_token_id]*(self.block_size-len(claim_idx))
        positive_idx = [self.tokenizer.cls_token_id] + positive_idx + [self.tokenizer.sep_token_id]
        positive_idx = positive_idx + [self.tokenizer.pad_token_id]*(self.block_size-len(positive_idx))
        negative_idx = [self.tokenizer.cls_token_id] + negative_idx + [self.tokenizer.sep_token_id]
        negative_idx = negative_idx + [self.tokenizer.pad_token_id]*(self.block_size-len(negative_idx))

        # make sure the passage sequences and claim sequences are not longer than max_length
        if len(claim_idx) > self.block_size or len(positive_idx) > self.block_size or len(negative_idx) > self.block_size:
            raise Exception(f"Claim sequence length {len(claim_idx)} or positive sequence length {len(positive_idx)} or negative sequence length: {len(negative_idx)} is longer than max_length {self.block_size}!")
        
        # create attention masks
        claim_attn_mask = [1 if idx != self.tokenizer.pad_token_id else 0 for idx in claim_idx]
        positive_attn_mask  = [1 if idx != self.tokenizer.pad_token_id else 0 for idx in positive_idx]
        negative_attn_mask  = [1 if idx != self.tokenizer.pad_token_id else 0 for idx in negative_idx]

        # convert to tensors
        claim_idx = torch.tensor(claim_idx)
        positive_idx = torch.tensor(positive_idx)
        negative_idx = torch.tensor(negative_idx)
        claim_attn_mask = torch.tensor(claim_attn_mask)
        positive_attn_mask = torch.tensor(positive_attn_mask)
        negative_attn_mask = torch.tensor(negative_attn_mask)

        return claim_idx, claim_attn_mask, positive_idx, positive_attn_mask, negative_idx, negative_attn_mask

In [6]:
block_size = 192
train_dataset = ClaimsDataset(train_data, document_store_cleaned, block_size=block_size)
val_dataset = ClaimsDataset(val_data, document_store_cleaned, block_size=block_size)

In [7]:
len(train_dataset), len(val_dataset)

(4122, 491)

#### Now that the dataset has been prepared, let's train a DPR model.

In [8]:
B = 16
DEVICE = "cuda"
learning_rate = 5e-6

train_dataloader = DataLoader(train_dataset, batch_size=B, shuffle=True, pin_memory=True, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=B, shuffle=True, pin_memory=True, num_workers=2)

# model with finetuning disabled
model = BERTBiEncoder().to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.95)
model, optimizer = load_model_checkpoint(model, optimizer)

num_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in transformer network: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

Loaded model from checkpoint!
Total number of parameters in transformer network: 132.72576 M
RAM used: 3680.88 MB


In [9]:
"""# create a W&B run
run = wandb.init(
    project="Automated Climate Fact Checker", 
    config={
        "bi-encoder model": "MobileBERT",
        "learning_rate": learning_rate, 
        "epochs": 5,
        "batch_size": B, 
        "corpus": "SQuAD v1"},)   

def log_metrics(metrics):
    wandb.log(metrics)"""


In [10]:
#train(model, optimizer, train_dataloader, val_dataloader, device=DEVICE, num_epochs=8, save_every=2, val_every=100, log_metrics=log_metrics) 

In [11]:
#wandb.finish()

In [12]:
validation(model, val_dataloader, device=DEVICE)

(0.8016946315765381, 0.814663951120163)

#### Now that we have a trained bi-encoder model, we can precompute the dense embeddings for all the documents in the store.

In [13]:
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

def tokenize_passage(passage_text, block_size):
        positive_encoding = tokenizer.encode_plus(passage_text, add_special_tokens=False, return_offsets_mapping=False, return_attention_mask=False, return_token_type_ids=False)
        passage_idx = positive_encoding['input_ids']
        if len(passage_idx) > (block_size-2):
            start_pos = random.randint(0, max(0,len(passage_idx) - (block_size-2)))
            passage_idx = passage_idx[start_pos:start_pos+block_size-2]

        passage_idx = [tokenizer.cls_token_id] + passage_idx + [tokenizer.sep_token_id]    
        passage_idx = passage_idx + [tokenizer.pad_token_id]*(block_size-len(passage_idx))

        if len(passage_idx) > block_size:
            raise Exception(f"Sequence length {len(passage_idx)} is longer than max_length {block_size}!")

        passage_attn_mask  = [1 if idx != tokenizer.pad_token_id else 0 for idx in passage_idx]
        passage_idx = torch.tensor(passage_idx)
        passage_attn_mask = torch.tensor(passage_attn_mask)
        return passage_idx, passage_attn_mask

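# block_size defaults to the global value set earlier, so existing tokenize_claim(claim_text) calls keep working
def tokenize_claim(claim_text, block_size=block_size):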
        claim_encoding = tokenizer.encode_plus(claim_text, add_special_tokens=False, return_offsets_mapping=False, return_attention_mask=False, return_token_type_ids=False)
        claim_idx = claim_encoding['input_ids']
        claim_idx = [tokenizer.cls_token_id] + claim_idx + [tokenizer.sep_token_id]    
        claim_idx = claim_idx + [tokenizer.pad_token_id]*(block_size-len(claim_idx))

        if len(claim_idx) > block_size:
            raise Exception(f"Sequence length {len(claim_idx)} is longer than max_length {block_size}!")

        claim_attn_mask  = [1 if idx != tokenizer.pad_token_id else 0 for idx in claim_idx]
        claim_idx = torch.tensor(claim_idx)
        claim_attn_mask = torch.tensor(claim_attn_mask)
        return claim_idx, claim_attn_mask

In [14]:
@torch.no_grad()  # inference only: without this, each batch's autograd graph stays attached to passage_embeddings
def precompute_passage_embeddings(document_store, block_size, batch_size=8):
    # precompute the passage embeddings for all passages in the document store
    model.eval()  # switch off dropout while encoding
    document_store_list = list(document_store.items())  # convert the document store into a list so it can be sliced
    num_passages = len(document_store_list)
    # preallocate memory; 768 is the embedding size assumed by the original code
    passage_embeddings = torch.zeros((num_passages, 768), device=DEVICE)
    for i in tqdm(range(0, num_passages, batch_size)):
        # tokenize the passages in this batch
        passages_idx_batch = []
        passages_attn_mask_batch = []
        for _, passage_text in document_store_list[i:i+batch_size]:
            passage_idx, passage_attn_mask = tokenize_passage(passage_text, block_size)
            passages_idx_batch.append(passage_idx)
            passages_attn_mask_batch.append(passage_attn_mask)

        passages_idx_batch = torch.stack(passages_idx_batch).to(DEVICE)
        passages_attn_mask_batch = torch.stack(passages_attn_mask_batch).to(DEVICE)
        passage_embedding = model.encode_passages(passages_idx_batch, passages_attn_mask_batch)
        # with gradients disabled there is no graph to free, so no per-batch del / empty_cache is needed
        passage_embeddings[i:i+batch_size] = passage_embedding

    return passage_embeddings

passage_ids = list(document_store_cleaned.keys())   


In [16]:
evidence_passage_embeds = precompute_passage_embeddings(document_store_cleaned, block_size)


  2%|▏         | 2275/148831 [00:56<1:04:52, 37.65it/s]