In [1]:
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizerFast
from collections import Counter
import csv
import random
from tqdm import tqdm
import psutil
import json
import os
from DPR_biencoder_simple import *
import wandb
from utils import *
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer

# Reload edited project modules automatically so changes to the local
# helper files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Authenticate with Weights & Biases before any run is created.
wandb.login()
# NOTE(review): `torch` is not imported explicitly in this cell; it is
# presumably re-exported by one of the star imports above - confirm, or add
# an explicit `import torch`.
print(torch.cuda.is_available())

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtanzids[0m. Use [1m`wandb login --relogin`[0m to force relogin


True


#### Instead of training the DPR bi-encoder with in-batch negatives, we will now train with explicitly sampled negatives: hard negatives, which were mined by using the cross-encoder to rerank the top-k passages retrieved by our old DPR (mixed with an equal number of randomly sampled negatives). First, let's set up the dataset and batch creation.

In [2]:
# Load the evidence corpus together with the train / validation claims.
document_store, train_data, val_data = load_data(clean=True)


def _read_pickle(path):
    """Return the object stored in the pickle file at ``path``."""
    # NOTE: unpickling can execute arbitrary code - only load files that
    # were produced by this project.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Hard negatives (reranked retrieval results) for each split.
train_hard_negatives = _read_pickle("dpr_embeddings/train_hard_negatives_reranked.pkl")
val_hard_negatives = _read_pickle("dpr_embeddings/val_hard_negatives_reranked.pkl")

Number of evidence passages: 1208827
Number of training instances: 1228
Number of validation instances: 154
Number of evidence passages remaining after cleaning: 1190647


In [3]:
# Set tokenizer parallelism to False. The DataLoaders created later in this
# notebook use worker processes (num_workers=2).
os.environ["TOKENIZERS_PARALLELISM"] = "false"  

# Now let's create a PyTorch dataset.
class ClaimsDataset(Dataset):
    def __init__(self, claims_data, document_store, hard_negatives, hard_negative_topk=20, num_negatives=10, block_size=128):
        self.claims_data = claims_data
        self.document_store = document_store
        self.hard_negatives = hard_negatives
        assert num_negatives % 2 == 0, "num_negatives must be even"
        self.num_negatives = num_negatives
        self.hard_negative_topk = hard_negative_topk
        self.block_size = block_size
        self.tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
        all_passages_ids = list(self.document_store.keys())
        all_positive_ids = set([ev for claim in self.claims_data.values() for ev in claim['evidences']])
        self.all_negatives_ids = list(set(all_passages_ids) - set(all_positive_ids))
        self.claim_pairs = self.create_pairs()

    def __len__(self):
        return len(self.claim_pairs)

    def create_pairs(self):
        claim_pairs = []
        for claim_id in self.claims_data.keys():
            for evidence_id in self.claims_data[claim_id]['evidences']:
                # for each positive evidence, sample 1/2 * num_negatives evidences from hard negatives list
                # and 1/2 * num_negatives evidences from all negatives list
                negative_ids = random.sample(self.hard_negatives[claim_id][:self.hard_negative_topk], self.num_negatives//2)
                negative_ids += random.sample(self.all_negatives_ids, self.num_negatives//2)
                claim_pairs.append((claim_id, evidence_id, negative_ids))      
        # shuffle the instances 
        random.shuffle(claim_pairs)                
        return claim_pairs
    
    def on_epoch_end(self):
        self.claim_pairs = self.create_pairs()

    def tokenize_and_encode_claim(self, claim_text, to_tensor=True):
        # tokenize  
        claim_encoding = self.tokenizer.encode_plus(claim_text, add_special_tokens=False, return_offsets_mapping=False, return_attention_mask=False, return_token_type_ids=False)
        claim_idx = claim_encoding['input_ids']
        # add special tokens and padding
        claim_idx = [self.tokenizer.cls_token_id] + claim_idx + [self.tokenizer.sep_token_id]
        claim_idx = claim_idx + [self.tokenizer.pad_token_id]*(self.block_size-len(claim_idx))
        # create attention masks
        claim_attn_mask = [1 if idx != self.tokenizer.pad_token_id else 0 for idx in claim_idx]
        if to_tensor:
            # convert to tensors
            claim_idx = torch.tensor(claim_idx)
            claim_attn_mask = torch.tensor(claim_attn_mask)
        return claim_idx, claim_attn_mask

    def tokenize_and_encode_evidence(self, evidence_text, to_tensor=True):
        # tokenize  
        evidence_encoding = self.tokenizer.encode_plus(evidence_text, add_special_tokens=False, return_offsets_mapping=False, return_attention_mask=False, return_token_type_ids=False)
        evidence_idx = evidence_encoding['input_ids']
        # select a window from the passage if it is longer than block size
        if len(evidence_idx) > (self.block_size-2):
            # pick a random start position
            start_pos = random.randint(0, max(0,len(evidence_idx) - (self.block_size-2)))
            # select the window
            evidence_idx = evidence_idx[start_pos:start_pos+self.block_size-2]
        # add special tokens and padding
        evidence_idx = [self.tokenizer.cls_token_id] + evidence_idx + [self.tokenizer.sep_token_id]
        evidence_idx = evidence_idx + [self.tokenizer.pad_token_id]*(self.block_size-len(evidence_idx))
        # create attention mask
        evidence_attn_mask  = [1 if idx != self.tokenizer.pad_token_id else 0 for idx in evidence_idx]
        if to_tensor:
            # convert to tensors
            evidence_idx = torch.tensor(evidence_idx)
            evidence_attn_mask = torch.tensor(evidence_attn_mask)
        return evidence_idx, evidence_attn_mask

    def __getitem__(self, idx, to_tensor=True):
        # get claim id and positive evidence id
        claim_id, positive_id, negative_ids = self.claim_pairs[idx]
        # get the claim, positive and negative text
        claim_text = self.claims_data[claim_id]['claim_text']
        positive_text = self.document_store[positive_id]
        negatives_text = [self.document_store[id] for id in negative_ids]
        # tokenize and encode the claim
        claim_idx, claim_attn_mask = self.tokenize_and_encode_claim(claim_text, to_tensor=to_tensor)
        # tokenize and encode the positive evidence
        positive_idx, positive_attn_mask = self.tokenize_and_encode_evidence(positive_text, to_tensor=to_tensor)
        # tokenize and encode the negative evidences
        negative_idx = []
        negative_attn_mask = []
        for negative_text in negatives_text:
            negative_idx_i, negative_attn_mask_i = self.tokenize_and_encode_evidence(negative_text)
            negative_idx.append(negative_idx_i)
            negative_attn_mask.append(negative_attn_mask_i) 
        if to_tensor:    
            negative_idx = torch.stack(negative_idx)
            negative_attn_mask = torch.stack(negative_attn_mask)  
        return claim_idx, claim_attn_mask, positive_idx, positive_attn_mask, negative_idx, negative_attn_mask
    

def collate_fn(batch):
    """Merge a list of dataset items into batched tensors.

    Each item is a 6-tuple ``(query_idx, query_attn_mask, pos_idx,
    pos_attn_mask, neg_idx, neg_attn_mask)``. Query and positive entries are
    stacked along a new batch dimension; the per-item negative tensors are
    concatenated so that their leading dimension becomes
    ``batch_size * num_negatives``.
    """
    fields = list(zip(*batch))

    # queries and positives: (batch_size, max_seq_len)
    query_idx, query_attn_mask, pos_idx, pos_attn_mask = (
        torch.stack(field) for field in fields[:4]
    )
    # negatives: (batch_size, num_negatives, max_seq_len) -> (batch_size*num_negatives, max_seq_len)
    neg_idx, neg_attn_mask = (
        torch.cat(field).view(-1, field[0].shape[-1]) for field in fields[4:]
    )

    return query_idx, query_attn_mask, pos_idx, pos_attn_mask, neg_idx, neg_attn_mask


In [4]:
# Fixed encoded sequence length (tokens, including special tokens and padding).
block_size = 128

# Both splits share the same sampling configuration.
shared_dataset_kwargs = {"num_negatives": 4, "block_size": block_size}
train_dataset = ClaimsDataset(train_data, document_store, train_hard_negatives, **shared_dataset_kwargs)
val_dataset = ClaimsDataset(val_data, document_store, val_hard_negatives, **shared_dataset_kwargs)

# number of (claim, positive evidence) instances per split
print(len(train_dataset), len(val_dataset))

4122 491


In [5]:
B = 16
# Use the GPU when one is visible, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
learning_rate = 5e-6

# shuffle=False: the instance order is already randomised inside
# ClaimsDataset.create_pairs(). Pinned host memory only helps for
# host-to-GPU copies, so it is enabled only when running on CUDA.
loader_kwargs = dict(batch_size=B, shuffle=False, pin_memory=(DEVICE == "cuda"), num_workers=2, collate_fn=collate_fn)
train_dataloader = DataLoader(train_dataset, **loader_kwargs)
val_dataloader = DataLoader(val_dataset, **loader_kwargs)

# model with finetuning disabled
# NOTE(review): every model parameter is handed to the optimizer below, so
# whether fine-tuning is really disabled depends on BERTBiEncoder - confirm.
model = BERTBiEncoder(out_of_batch_negs=True).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# NOTE(review): `scheduler` is not passed to train(...) in the training cell,
# so unless it is stepped elsewhere the learning rate never decays - confirm.
scheduler =  torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.95)
#model, optimizer = load_dpr_model_checkpoint(model, optimizer, filename='dpr_checkpoint_1.pth')

num_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in transformer network: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

Total number of parameters in transformer network: 132.72576 M
RAM used: 2930.82 MB


In [6]:
# create a W&B run
# NOTE: the values in `config` are logged metadata only. In particular
# "epochs" is not read by the training call, which receives its own
# `num_epochs` argument - keep the two in sync by hand.
run = wandb.init(
    project="Automated Climate Fact Checker", 
    config={
        "bi-encoder model": "DistillBERT DPR",
        "learning_rate": learning_rate, 
        "epochs": 5,
        "batch_size": B, 
        "corpus": "COMP90042 2023 project"},)   

def log_metrics(metrics):
    """Forward ``metrics`` to Weights & Biases via ``wandb.log``.

    Passed to ``train`` as its ``log_metrics`` callback.
    """
    wandb.log(metrics)

In [9]:
# Train the bi-encoder for 5 epochs; metrics are sent to W&B through the
# log_metrics callback. No checkpoints are requested (save_every=None).
# NOTE(review): val_every=100 presumably triggers a validation pass every
# 100 training steps - confirm against `train`.
train(model, optimizer, train_dataloader, val_dataloader, device=DEVICE, num_epochs=5, save_every=None, val_every=100, log_metrics=log_metrics) 

Train Epochs:   0%|          | 0/258 [00:00<?, ?it/s]

Val Epochs: 100%|██████████| 31/31 [00:05<00:00,  6.04it/s] Loss:  0.000, Val Accuracy:  0.000:  38%|███▊      | 99/258 [00:48<01:16,  2.07it/s]
Val Epochs: 100%|██████████| 31/31 [00:05<00:00,  6.00it/s] Loss:  1.857, Val Accuracy:  0.532:  77%|███████▋  | 199/258 [01:42<00:28,  2.04it/s]
Epoch 1, EMA Train Loss: 0.422, Train Accuracy:  0.832, Val Loss:  1.916, Val Accuracy:  0.544: 100%|██████████| 258/258 [02:16<00:00,  1.89it/s]
Val Epochs: 100%|██████████| 31/31 [00:05<00:00,  5.95it/s] Loss:  1.916, Val Accuracy:  0.544:  38%|███▊      | 99/258 [00:48<01:17,  2.05it/s]
Val Epochs: 100%|██████████| 31/31 [00:05<00:00,  5.89it/s] Loss:  2.051, Val Accuracy:  0.521:  77%|███████▋  | 199/258 [01:42<00:28,  2.04it/s]
Epoch 2, EMA Train Loss: 0.387, Train Accuracy:  0.867, Val Loss:  2.077, Val Accuracy:  0.544: 100%|██████████| 258/258 [02:16<00:00,  1.88it/s]
Val Epochs: 100%|██████████| 31/31 [00:05<00:00,  5.96it/s] Loss:  2.077, Val Accuracy:  0.544:  38%|███▊      | 99/258 [00:48

In [8]:
# Run validation once more on the trained model.
# NOTE(review): the displayed pair looks like (val loss, val accuracy) -
# confirm against `validation`. The validation dataset crops long passages
# at random positions, so repeated runs can differ slightly.
validation(model, val_dataloader, device=DEVICE)

Val Epochs: 100%|██████████| 31/31 [00:05<00:00,  5.62it/s]


(1.9114601612091064, 0.5315682281059063)

In [9]:
#save_dpr_model_checkpoint(model, optimizer, filename='dpr_checkpoint_1.pth')

In [10]:
def precompute_passage_embeddings(document_store, batch_size=16, embed_dim=768):
    """Encode every passage in ``document_store`` with the passage encoder.

    Uses the notebook-level ``model``, ``train_dataset`` (for tokenization)
    and ``DEVICE``.

    Args:
        document_store: mapping ``passage_id -> passage text``.
        batch_size: number of passages encoded per forward pass.
        embed_dim: width of the embeddings returned by ``model.encode_passages``.

    Returns:
        Tensor of shape ``(len(document_store), embed_dim)`` on ``DEVICE``.
        Row ``i`` belongs to the ``i``-th entry in the iteration order of
        ``document_store``.

    Note:
        Passages longer than the block size are cropped to a random window by
        ``tokenize_and_encode_evidence``, so their embeddings vary between runs
        unless ``random`` is seeded.
    """
    passage_texts = list(document_store.values())
    num_passages = len(passage_texts)
    passage_embeddings = torch.zeros((num_passages, embed_dim), device=DEVICE)  # Preallocate memory

    # Inference only: switch to eval mode for the duration of the call and
    # restore the previous mode afterwards.
    was_training = model.training
    model.eval()
    try:
        # Without no_grad every batch would keep its autograd graph alive
        # through the in-place write below, growing memory across the whole
        # corpus. With it, no per-batch cache clearing is required.
        with torch.no_grad():
            for start in tqdm(range(0, num_passages, batch_size)):
                encoded = [
                    train_dataset.tokenize_and_encode_evidence(passage_text)
                    for passage_text in passage_texts[start:start + batch_size]
                ]
                passages_idx_batch = torch.stack([ids for ids, _ in encoded]).to(DEVICE)
                passages_attn_mask_batch = torch.stack([mask for _, mask in encoded]).to(DEVICE)
                passage_embeddings[start:start + batch_size] = model.encode_passages(
                    passages_idx_batch, passages_attn_mask_batch)
    finally:
        model.train(was_training)

    return passage_embeddings

# Passage ids in document-store iteration order. precompute_passage_embeddings
# walks the same dict in the same order, so row i of its result corresponds to
# passage_ids[i] (as long as document_store is not modified in between).
passage_ids = list(document_store.keys()) 

In [11]:
# Encode the entire document store with the trained passage encoder.
# NOTE: this is the slow step of the notebook (the whole cleaned corpus is
# encoded); use the save/load lines below to avoid recomputing it.
evidence_passage_embeds = precompute_passage_embeddings(document_store)

# save precomputed embeddings
#torch.save(evidence_passage_embeds, "dpr_embeddings/evidence_passage_simple_dpr_embeds_1.pt")

# load embeddings from file
#evidence_passage_embeds = torch.load("dpr_embeddings/evidence_passage_simple_dpr_embeds_1.pt")

  3%|▎         | 2231/74416 [01:15<36:44, 32.74it/s] 

In [12]:
# Save the passage_ids list to the pickle file
#with open("dpr_embeddings/passage_ids_1.pkl", "wb") as f:
#    pickle.dump(passage_ids, f)

In [None]:
import numpy as np

def find_topk_evidence(claim_text, passage_ids, k=5):
    """Retrieve the ``k`` passages whose embeddings score highest against a claim.

    Uses the notebook-level ``model``, ``train_dataset`` (for tokenization),
    ``DEVICE`` and the precomputed ``evidence_passage_embeds``.

    Args:
        claim_text: raw claim string.
        passage_ids: passage ids aligned with the rows of ``evidence_passage_embeds``.
        k: number of passages to return; clamped to the number of passages.

    Returns:
        ``(topk_passage_ids, topk_scores)`` - two lists of equal length, ordered
        by decreasing dot-product score. Both are lists even when ``k == 1``.
    """
    # tokenize claim text
    claim_idx, claim_attn_mask = train_dataset.tokenize_and_encode_claim(claim_text)
    claim_idx = claim_idx.unsqueeze(0).to(DEVICE)
    claim_attn_mask = claim_attn_mask.unsqueeze(0).to(DEVICE)

    # Inference only: eval mode and no autograd bookkeeping; the previous
    # train/eval mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            claim_embedding = model.encode_queries(claim_idx, claim_attn_mask)
            # (num_passages, dim) @ (dim, 1) -> (num_passages,)
            scores = torch.mm(evidence_passage_embeds, claim_embedding.T).squeeze(1)
            # torch.topk raises if k exceeds the number of candidates
            k = min(k, scores.shape[0])
            topk_scores, topk_ids = torch.topk(scores, k=k)
    finally:
        model.train(was_training)

    # topk on a 1-D tensor already yields 1-D results; no squeeze here, which
    # would collapse the k == 1 case to a scalar and break the lookup below.
    topk_scores = topk_scores.tolist()
    topk_ids = topk_ids.tolist()
    topk_passage_ids = [passage_ids[i] for i in topk_ids]
    return topk_passage_ids, topk_scores


def eval(claims_list, passage_ids, topk=(5,)):
    """Compute average retrieval precision, recall and F1 at several cutoffs.

    NOTE: the name shadows Python's builtin ``eval`` inside this notebook; it
    is kept because later cells call it by this name.

    Args:
        claims_list: list of ``(claim_id, claim)`` pairs, where ``claim`` has
            the keys ``'claim_text'`` and ``'evidences'`` (gold passage ids).
        passage_ids: passage ids aligned with the precomputed embeddings,
            forwarded to ``find_topk_evidence``.
        topk: iterable of cutoffs ``k`` to evaluate.

    Returns:
        Three dicts ``(precision, recall, f1)`` keyed ``"Precision@k"``,
        ``"Recall@k"`` and ``"F1@k"``, each averaged over all claims. A claim
        without gold evidence contributes 0 to recall, and an empty
        ``claims_list`` yields all-zero averages.
    """
    topk = list(topk)
    max_k = max(topk)
    precision_total = np.zeros(len(topk))
    recall_total = np.zeros(len(topk))
    f1_total = np.zeros(len(topk))

    for claim_entry in tqdm(claims_list):
        claim = claim_entry[1]
        gold_evidence_list = claim['evidences']
        # retrieve once at the largest cutoff and slice for the smaller ones
        topk_passage_ids, _ = find_topk_evidence(claim['claim_text'], passage_ids, k=max_k)
        for i, k in enumerate(topk):
            topk_passage_ids_k = topk_passage_ids[:k]
            num_hits = len(set(topk_passage_ids_k).intersection(gold_evidence_list))
            # guard both denominators: nothing retrieved / no gold evidence
            precision = num_hits / len(topk_passage_ids_k) if topk_passage_ids_k else 0.0
            recall = num_hits / len(gold_evidence_list) if gold_evidence_list else 0.0
            f1 = (2*precision*recall/(precision + recall)) if (precision + recall) > 0 else 0
            precision_total[i] += precision
            recall_total[i] += recall
            f1_total[i] += f1

    # avoid 0/0 when there are no claims at all
    num_claims = max(len(claims_list), 1)
    precision_avg = precision_total / num_claims
    recall_avg = recall_total / num_claims
    f1_avg = f1_total / num_claims

    # convert to dictionary
    precision_avg = {f"Precision@{k}":v for k,v in zip(topk, precision_avg)}
    recall_avg = {f"Recall@{k}":v for k,v in zip(topk, recall_avg)}
    f1_avg = {f"F1@{k}":v for k,v in zip(topk, f1_avg)}

    print(f"\nAvg Precision: {precision_avg}, Avg Recall: {recall_avg}, Avg F1: {f1_avg}")
    return precision_avg, recall_avg, f1_avg


In [None]:
# Materialise the claim dicts as lists of (claim_id, claim) pairs - the
# format the evaluation function indexes into. All claims of each split are
# included, not just a sample.
claims_list_train = list(train_data.items()) 
claims_list_val = list(val_data.items()) 

In [None]:
# Retrieval quality of the retrained bi-encoder at a range of cutoffs k,
# first on the training claims, then on the validation claims.
TOPK_CUTOFFS = [3, 5, 8, 10, 15, 20, 30, 50, 100, 250, 500]

print("Eval on training set:")
eval(claims_list_train, passage_ids, topk=TOPK_CUTOFFS)

print("\nEval on validation set:")
eval(claims_list_val, passage_ids, topk=TOPK_CUTOFFS)

Eval on training set:


  0%|          | 0/1228 [00:00<?, ?it/s]

100%|██████████| 1228/1228 [00:26<00:00, 45.57it/s]



Avg Precision: {'Precision@3': 0.1921824104234536, 'Precision@5': 0.15781758957654662, 'Precision@8': 0.12347312703583062, 'Precision@10': 0.11009771986970632, 'Precision@15': 0.08566775244299726, 'Precision@20': 0.07219055374592778, 'Precision@30': 0.05483170466883844, 'Precision@50': 0.038387622149837104, 'Precision@100': 0.02237785016286649, 'Precision@250': 0.010576547231270187, 'Precision@500': 0.00572801302931587}, Avg Recall: {'Recall@3': 0.18813789359391947, 'Recall@5': 0.25578175895765415, 'Recall@8': 0.32005971769815356, 'Recall@10': 0.3551845819761122, 'Recall@15': 0.40754614549402746, 'Recall@20': 0.4576954397394132, 'Recall@30': 0.5147122692725297, 'Recall@50': 0.5997692725298588, 'Recall@100': 0.6955890336590655, 'Recall@250': 0.8146579804560249, 'Recall@500': 0.8741313789359376}, Avg F1: {'F1@3': 0.17770280750736822, 'F1@5': 0.18370042914016857, 'F1@8': 0.16939439496442865, 'F1@10': 0.16075740220691445, 'F1@15': 0.13689085454743155, 'F1@20': 0.12130079039121686, 'F1@30'

100%|██████████| 154/154 [00:03<00:00, 46.02it/s]


Avg Precision: {'Precision@3': 0.07359307359307361, 'Precision@5': 0.062337662337662345, 'Precision@8': 0.05113636363636364, 'Precision@10': 0.04740259740259738, 'Precision@15': 0.03766233766233767, 'Precision@20': 0.030844155844155827, 'Precision@30': 0.02402597402597403, 'Precision@50': 0.018441558441558457, 'Precision@100': 0.011688311688311696, 'Precision@250': 0.006233766233766238, 'Precision@500': 0.003714285714285717}, Avg Recall: {'Recall@3': 0.07337662337662339, 'Recall@5': 0.11699134199134197, 'Recall@8': 0.15108225108225104, 'Recall@10': 0.1713203463203463, 'Recall@15': 0.20357142857142851, 'Recall@20': 0.218073593073593, 'Recall@30': 0.25541125541125537, 'Recall@50': 0.3254329004329004, 'Recall@100': 0.4024891774891775, 'Recall@250': 0.5054112554112555, 'Recall@500': 0.6012987012987012}, Avg F1: {'F1@3': 0.0701762523191095, 'F1@5': 0.076149247577819, 'F1@8': 0.0725274725274725, 'F1@10': 0.070787653904537, 'F1@15': 0.061344537815126006, 'F1@20': 0.05258230852527861, 'F1@30'




({'Precision@3': 0.07359307359307361,
  'Precision@5': 0.062337662337662345,
  'Precision@8': 0.05113636363636364,
  'Precision@10': 0.04740259740259738,
  'Precision@15': 0.03766233766233767,
  'Precision@20': 0.030844155844155827,
  'Precision@30': 0.02402597402597403,
  'Precision@50': 0.018441558441558457,
  'Precision@100': 0.011688311688311696,
  'Precision@250': 0.006233766233766238,
  'Precision@500': 0.003714285714285717},
 {'Recall@3': 0.07337662337662339,
  'Recall@5': 0.11699134199134197,
  'Recall@8': 0.15108225108225104,
  'Recall@10': 0.1713203463203463,
  'Recall@15': 0.20357142857142851,
  'Recall@20': 0.218073593073593,
  'Recall@30': 0.25541125541125537,
  'Recall@50': 0.3254329004329004,
  'Recall@100': 0.4024891774891775,
  'Recall@250': 0.5054112554112555,
  'Recall@500': 0.6012987012987012},
 {'F1@3': 0.0701762523191095,
  'F1@5': 0.076149247577819,
  'F1@8': 0.0725274725274725,
  'F1@10': 0.070787653904537,
  'F1@15': 0.061344537815126006,
  'F1@20': 0.052582308

In [8]:
#wandb.finish()