In [1]:
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizerFast
from collections import Counter
import csv
import random
from tqdm import tqdm
import psutil
import json
import os
from DPR_biencoder_simple import *
import wandb
from utils import *

# Reload edited project modules (DPR_biencoder_simple, utils) automatically
# before each cell runs, so changes to the .py files are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

# Authenticate with Weights & Biases so the training run can be logged.
wandb.login()
# Sanity check that a GPU is visible before building the model.
# NOTE(review): `torch` is not imported explicitly in this notebook; it is
# presumably brought into scope by one of the star imports above — confirm,
# and prefer an explicit `import torch`.
print(torch.cuda.is_available())

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtanzids[0m. Use [1m`wandb login --relogin`[0m to force relogin


True


#### Instead of training the DPR bi-encoder with in-batch negatives, we will now train exclusively on hard negatives, which were mined by using the cross-encoder to rerank the top-k passages retrieved by our old DPR model. First, let's set up the dataset and batch creation.

In [2]:
# Read the evidence corpus plus the training and validation claims.
document_store, train_data, val_data = load_data(clean=True)


def _read_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    # NOTE: unpickling can execute arbitrary code; only open trusted,
    # locally generated files here.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Hard negatives mined offline, one pickle file per split.
train_hard_negatives = _read_pickle("dpr_embeddings/train_hard_negatives_reranked.pkl")
val_hard_negatives = _read_pickle("dpr_embeddings/val_hard_negatives_reranked.pkl")

Number of evidence passages: 1208827
Number of training instances: 1228
Number of validation instances: 154
Number of evidence passages remaining after cleaning: 1190647


In [3]:
# Turn off the Hugging Face tokenizers' internal parallelism. The DataLoaders
# created later in this notebook use worker processes (num_workers=2), and the
# tokenizers library warns about (and can deadlock on) process forks that
# happen after its thread pool has been used.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# now let's create a PyTorch dataset
class ClaimsDataset(Dataset):
    """Dataset of (claim, positive evidence, hard negatives) training instances.

    One instance is created for every (claim, gold evidence) pair. Each instance
    carries `num_negatives` passage ids drawn at random from that claim's
    hard-negative list. Every encoded sequence is exactly `block_size` ids long:
    [CLS] tokens [SEP], followed by [PAD] up to `block_size`.

    Args:
        claims_data: mapping claim_id -> {'claim_text': str, 'evidences': ids}.
        document_store: mapping evidence id -> passage text.
        hard_negatives: mapping claim_id -> sequence of evidence ids to sample
            negatives from.
        num_negatives: number of negatives sampled per instance.
        block_size: length of every encoded sequence, special tokens included.
    """

    def __init__(self, claims_data, document_store, hard_negatives, num_negatives=10, block_size=128):
        self.claims_data = claims_data
        self.document_store = document_store
        self.hard_negatives = hard_negatives
        self.num_negatives = num_negatives
        self.block_size = block_size
        self.tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
        self.claim_pairs = self.create_pairs()

    def __len__(self):
        """Number of (claim, positive evidence) instances."""
        return len(self.claim_pairs)

    def create_pairs(self):
        """Build the shuffled list of (claim_id, positive_id, negative_ids) instances.

        Raises:
            ValueError: if a claim has fewer than `num_negatives` hard negatives.
        """
        claim_pairs = []
        for claim_id, claim in self.claims_data.items():
            for evidence_id in claim['evidences']:
                pool = self.hard_negatives[claim_id]
                # random.sample would raise an anonymous ValueError here;
                # name the offending claim so the data problem is easy to find.
                if len(pool) < self.num_negatives:
                    raise ValueError(
                        f"claim {claim_id!r} has only {len(pool)} hard negatives, "
                        f"but num_negatives={self.num_negatives}"
                    )
                # fresh negatives for every positive evidence of the claim
                negative_ids = random.sample(pool, self.num_negatives)
                claim_pairs.append((claim_id, evidence_id, negative_ids))
        # shuffle the instances
        random.shuffle(claim_pairs)
        return claim_pairs

    def on_epoch_end(self):
        """Re-sample negatives and reshuffle the instances.

        PyTorch's DataLoader does not call this hook on its own; the training
        loop has to invoke it between epochs for it to have any effect.
        """
        self.claim_pairs = self.create_pairs()

    def _tokenize(self, text):
        """Return the token ids of `text` without any special tokens."""
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=False,
            return_offsets_mapping=False,
            return_attention_mask=False,
            return_token_type_ids=False,
        )
        return encoding['input_ids']

    def _add_special_tokens_and_pad(self, token_ids, to_tensor):
        """Wrap `token_ids` in [CLS]/[SEP], pad to `block_size`, build the mask.

        `token_ids` must already hold at most `block_size - 2` ids.
        """
        tokens = [self.tokenizer.cls_token_id] + token_ids + [self.tokenizer.sep_token_id]
        pad_len = self.block_size - len(tokens)
        input_ids = tokens + [self.tokenizer.pad_token_id] * pad_len
        # 1 for real tokens (special tokens included), 0 for padding positions
        attn_mask = [1] * len(tokens) + [0] * pad_len
        if to_tensor:
            input_ids = torch.tensor(input_ids)
            attn_mask = torch.tensor(attn_mask)
        return input_ids, attn_mask

    def tokenize_and_encode_claim(self, claim_text, to_tensor=True):
        """Encode a claim as (input_ids, attention_mask) of length `block_size`.

        Claims longer than `block_size - 2` tokens keep only their leading
        tokens, so every sample has the same length and can be stacked.
        """
        token_ids = self._tokenize(claim_text)[:self.block_size - 2]
        return self._add_special_tokens_and_pad(token_ids, to_tensor)

    def tokenize_and_encode_evidence(self, evidence_text, to_tensor=True):
        """Encode a passage as (input_ids, attention_mask) of length `block_size`.

        Passages longer than `block_size - 2` tokens are cropped to a randomly
        positioned window of that size.
        """
        token_ids = self._tokenize(evidence_text)
        max_tokens = self.block_size - 2
        if len(token_ids) > max_tokens:
            # pick a random start position (randint is inclusive on both ends)
            start_pos = random.randint(0, len(token_ids) - max_tokens)
            token_ids = token_ids[start_pos:start_pos + max_tokens]
        return self._add_special_tokens_and_pad(token_ids, to_tensor)

    def __getitem__(self, idx, to_tensor=True):
        """Return the encoded claim, positive passage and negatives of instance `idx`.

        Returns:
            (claim_idx, claim_attn_mask, positive_idx, positive_attn_mask,
            negative_idx, negative_attn_mask). With `to_tensor=True` the first
            four are 1-D tensors of length `block_size` and the negatives are
            stacked to (num_negatives, block_size). With `to_tensor=False`
            everything is returned as plain Python lists.
        """
        claim_id, positive_id, negative_ids = self.claim_pairs[idx]
        claim_idx, claim_attn_mask = self.tokenize_and_encode_claim(
            self.claims_data[claim_id]['claim_text'], to_tensor=to_tensor)
        positive_idx, positive_attn_mask = self.tokenize_and_encode_evidence(
            self.document_store[positive_id], to_tensor=to_tensor)
        negative_idx = []
        negative_attn_mask = []
        for negative_id in negative_ids:
            # honour `to_tensor` for negatives too, so list mode returns lists throughout
            ids_i, mask_i = self.tokenize_and_encode_evidence(
                self.document_store[negative_id], to_tensor=to_tensor)
            negative_idx.append(ids_i)
            negative_attn_mask.append(mask_i)
        if to_tensor:
            negative_idx = torch.stack(negative_idx)
            negative_attn_mask = torch.stack(negative_attn_mask)
        return claim_idx, claim_attn_mask, positive_idx, positive_attn_mask, negative_idx, negative_attn_mask
    

def collate_fn(batch):
    """Merge a list of ClaimsDataset samples into batched tensors.

    Each sample is a 6-tuple (claim ids, claim mask, positive ids, positive
    mask, negative ids, negative mask). Claims and positives are stacked along
    a new leading batch dimension; the per-sample negative matrices are
    concatenated row-wise, so the negatives of all samples share one flat
    leading dimension.
    """
    claims, claim_masks, positives, positive_masks, negatives, negative_masks = zip(*batch)

    return (
        torch.stack(claims),
        torch.stack(claim_masks),
        torch.stack(positives),
        torch.stack(positive_masks),
        # (batch_size, num_negatives, seq_len) -> (batch_size*num_negatives, seq_len)
        torch.cat(negatives).view(-1, negatives[0].shape[-1]),
        torch.cat(negative_masks).view(-1, negative_masks[0].shape[-1]),
    )


In [4]:
block_size = 128
# Both splits are built with identical sampling and encoding settings.
shared_dataset_kwargs = {"num_negatives": 4, "block_size": block_size}
train_dataset = ClaimsDataset(train_data, document_store, train_hard_negatives, **shared_dataset_kwargs)
val_dataset = ClaimsDataset(val_data, document_store, val_hard_negatives, **shared_dataset_kwargs)
# Dataset length = number of (claim, gold evidence) pairs in each split.
print(len(train_dataset), len(val_dataset))

4122 491


In [5]:
B = 16
# Use the GPU when one is visible; otherwise fall back to CPU instead of
# failing later on `.to("cuda")`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
learning_rate = 5e-6

# shuffle=False: ClaimsDataset.create_pairs() already shuffles the instances.
# Pinned host memory only speeds up host-to-GPU copies, so request it on CUDA only.
train_dataloader = DataLoader(train_dataset, batch_size=B, shuffle=False, pin_memory=(DEVICE == "cuda"), num_workers=2, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=B, shuffle=False, pin_memory=(DEVICE == "cuda"), num_workers=2, collate_fn=collate_fn)

# NOTE(review): the original comment here read "model with finetuning disabled",
# but every model parameter is handed to the optimizer below; whether any
# weights are frozen is decided inside BERTBiEncoder — confirm.
model = BERTBiEncoder(out_of_batch_negs=True).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# NOTE(review): this scheduler is never passed to train() in the training cell,
# so unless train() builds its own schedule the learning rate stays constant — confirm.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.95)
# Uncomment to resume from a saved checkpoint:
#model, optimizer = load_model_checkpoint(model, optimizer)

# Report model size and the resident memory of this process.
num_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in transformer network: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

Total number of parameters in transformer network: 132.72576 M
RAM used: 2927.67 MB


In [6]:
# create a W&B run
run = wandb.init(
    project="Automated Climate Fact Checker",
    config={
        "bi-encoder model": "DistillBERT DPR",
        "learning_rate": learning_rate,
        # Must match the `num_epochs` passed to train() in the training cell
        # (10); the value logged here used to be 5, which mislabelled the run.
        "epochs": 10,
        "batch_size": B,
        "corpus": "COMP90042 2023 project"},)

def log_metrics(metrics):
    """Forward a dict of metrics to the active W&B run.

    Passed to train() as its `log_metrics` callback.
    """
    wandb.log(metrics)

In [7]:
# Train the bi-encoder on the hard-negative batches, reporting metrics to W&B
# through the log_metrics callback. No checkpoints are requested (save_every=None).
train(
    model,
    optimizer,
    train_dataloader,
    val_dataloader,
    device=DEVICE,
    num_epochs=10,
    save_every=None,
    val_every=100,
    log_metrics=log_metrics,
)

Train Epochs:   0%|          | 0/258 [00:00<?, ?it/s]

In [None]:
#wandb.finish()