#### Cross-Encoder Passage Retrieval Model

We will now explore an alternative model for document retrieval, called a `cross-encoder`. Unlike the bi-encoder, which encodes claims and passages separately, we use a single BERT-style model (DistilBERT here) and feed it one input sequence formed by concatenating a claim-passage pair. Using the output embedding of the [CLS] token, we then perform a `binary classification` of whether the passage is relevant to the claim. We can set up training instances from both positive (relevant) and negative (non-relevant) pairs and train with a binary cross-entropy loss. The predicted probability of the relevant class can then be interpreted as a relevance score in $[0,1]$: with a single output logit this is its sigmoid, and with the two-logit head used below it is the softmax probability of the positive class (equivalently, the sigmoid of the difference between the two logits).
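The model defined later in this notebook outputs two logits and is trained with a two-class cross-entropy, whereas a single logit with a sigmoid and binary cross-entropy is the textbook way to describe this setup. The two formulations describe the same classifier, because only the difference between the two logits matters. The plain-Python sketch below checks this numerically; the helper names are illustrative and not part of the notebook.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def softmax_relevance(z_neg: float, z_pos: float) -> float:
    """P(relevant) from a 2-logit head, using a numerically stable softmax."""
    m = max(z_neg, z_pos)
    e_neg, e_pos = math.exp(z_neg - m), math.exp(z_pos - m)
    return e_pos / (e_neg + e_pos)


def two_class_cross_entropy(z_neg: float, z_pos: float, label: int) -> float:
    """Softmax cross-entropy of a 2-logit head for label in {0, 1}."""
    p_pos = softmax_relevance(z_neg, z_pos)
    return -math.log(p_pos if label == 1 else 1.0 - p_pos)


def binary_cross_entropy_with_logit(z: float, label: int) -> float:
    """BCE of a single-logit head: -[y log sigmoid(z) + (1 - y) log(1 - sigmoid(z))]."""
    p = sigmoid(z)
    return -(label * math.log(p) + (1 - label) * math.log(1.0 - p))


# A 2-logit head behaves like a 1-logit head whose logit is z_pos - z_neg.
z_neg, z_pos = -0.4, 1.3
print(softmax_relevance(z_neg, z_pos), sigmoid(z_pos - z_neg))
print(two_class_cross_entropy(z_neg, z_pos, 1), binary_cross_entropy_with_logit(z_pos - z_neg, 1))
```

Each printed pair should agree up to floating-point rounding.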

Since each claim can have multiple relevant evidence passages, we can create multiple positive pairs per claim. To keep the two classes balanced, we also create the same number of negative pairs. This leaves the question of how to select the negative passages. The simplest approach is to `randomly sample` a passage from the document store that is not among the claim's positive passages. However, such random negatives are usually too easy for the model to tell apart, so ideally we want some form of `hard-negative mining` to select more informative negatives: passages that look relevant to the claim but are not.
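As a baseline, the random-negative strategy fits in a few lines of plain Python. This is a minimal sketch under simplifying assumptions: claims are given as a hypothetical dict mapping each claim id to its list of gold passage ids, and `make_balanced_pairs` is an illustrative helper, not part of the notebook.

```python
import random


def make_balanced_pairs(claims, passage_ids, rng=None):
    """Return shuffled (claim_id, passage_id, label) triples, one random negative per positive.

    claims: hypothetical dict mapping claim_id -> list of gold passage ids.
    passage_ids: list of every passage id in the document store.
    """
    if rng is None:
        rng = random.Random()
    all_ids = set(passage_ids)
    pairs = []
    for claim_id, gold_ids in claims.items():
        gold = set(gold_ids)
        if all_ids <= gold:
            raise ValueError(f"no non-gold passages available for claim {claim_id}")
        for pos_id in gold_ids:
            pairs.append((claim_id, pos_id, 1))  # positive pair
            # rejection sampling: redraw until the passage is not one of this claim's gold passages
            neg_id = rng.choice(passage_ids)
            while neg_id in gold:
                neg_id = rng.choice(passage_ids)
            pairs.append((claim_id, neg_id, 0))  # negative pair
    rng.shuffle(pairs)
    return pairs
```

With over a million passages and only a handful of gold passages per claim, the rejection loop almost always succeeds on the first draw, which is exactly why these negatives tend to be easy.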

So first, let's set up our dataset and implement hard-negative mining.

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizerFast
from collections import Counter
import random
from tqdm import tqdm
import psutil
from utils import *
from DPR_biencoder_simple import *
import wandb
import pickle

%load_ext autoreload
%autoreload 2

wandb.login()
print(torch.cuda.is_available())

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: tanzids. Use `wandb login --relogin` to force relogin


True


In [2]:
# load the data
document_store, train_data, val_data = load_data(clean=True)

Number of evidence passages: 1208827
Number of training instances: 1228
Number of validation instances: 154
Number of evidence passages remaining after cleaning: 1190647


Hard-negative mining: we will use our pre-trained DPR bi-encoder to mine hard negatives. For each training claim we retrieve the top-k passages, remove any gold evidence passages, and later sample hard negatives from what remains.
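The retrieval step behind hard-negative mining can be sketched without a GPU. We assume here, as in standard DPR, that `find_topk_evidence_dpr` (imported from `utils`, not shown) ranks passages by the inner product between the claim embedding and the precomputed passage embeddings. The pure-Python helpers below (`topk_by_inner_product`, `mine_hard_negatives`) are hypothetical stand-ins for that batched tensor computation.

```python
import heapq


def topk_by_inner_product(query_vec, passage_vecs, passage_ids, k):
    """Ids of the k passages with the highest dot-product score, best first."""
    scored = [
        (sum(q * p for q, p in zip(query_vec, vec)), pid)
        for vec, pid in zip(passage_vecs, passage_ids)
    ]
    # key= compares scores only, so ties never fall back to comparing ids
    best = heapq.nlargest(k, scored, key=lambda item: item[0])
    return [pid for _, pid in best]


def mine_hard_negatives(query_vec, passage_vecs, passage_ids, gold_ids, k=100):
    """Top-k retrieved passages minus the gold ones: ranked highly, but not relevant."""
    gold = set(gold_ids)
    return [pid for pid in topk_by_inner_product(query_vec, passage_vecs, passage_ids, k)
            if pid not in gold]
```

Because the gold passages are filtered out after retrieval, a claim can end up with slightly fewer than k hard negatives.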

In [3]:
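# load pretrained dpr model (wrapped in a string to disable it; only needed to regenerate the hard negatives, which are loaded from disk below)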
"""
DEVICE = "cuda"
dpr_model = BERTBiEncoder().to(DEVICE)
dpr_model = load_dpr_model_checkpoint(dpr_model)

# load dpr passage embeddings
evidence_passage_embeds, passage_ids = load_dpr_passage_embeddings()
"""


In [4]:
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
block_size = 196

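# now get the hard negatives for each claim
def get_hard_negatives(dpr_model, tokenizer, data, passage_ids, block_size, k=100):
    """
    Retrieve the top-k DPR passages for each claim in the dataset and drop its gold
    evidence passages, leaving up to k hard negatives per claim.
    Note: uses the global `evidence_passage_embeds` loaded in the DPR cell above.
    """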
    claims_list = list(data.items()) 
    hard_negatives = {}
    for claim_id, claim in tqdm(claims_list):
        claim_text = claim["claim_text"]
        gold_evidence_list = claim["evidences"]
        topk_passage_ids, topk_scores = find_topk_evidence_dpr(dpr_model, tokenizer, claim_text, evidence_passage_embeds, passage_ids, block_size, k=k)
        # remove the gold evidence from the topk passages
        topk_passage_ids = [p_id for p_id in topk_passage_ids if p_id not in gold_evidence_list]
        hard_negatives[claim_id] = topk_passage_ids

    return hard_negatives

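We will generate up to 100 hard negatives for each claim (the DPR top-100 minus any gold passages) and save them to a file, so this expensive retrieval step only has to run once.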
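The cells below switch between computing and loading the hard negatives by commenting code in and out. A small cache helper captures the same compute-once, load-afterwards pattern; `load_or_compute` is a hypothetical helper sketched with the standard-library `pickle` module, not something the notebook defines.

```python
import os
import pickle


def load_or_compute(path, compute_fn):
    """Load a pickled object from `path`, or compute it with `compute_fn()` and cache it there."""
    if os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)
    result = compute_fn()
    # create the parent directory (e.g. dpr_embeddings/) if it does not exist yet
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(result, f)
    return result
```

For example, `load_or_compute("dpr_embeddings/train_hard_negatives.pkl", lambda: get_hard_negatives(dpr_model, tokenizer, train_data, passage_ids, block_size))` would only run the DPR retrieval when the pickle file is missing; the `lambda` defers the expensive call until it is actually needed.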

In [5]:
#train_hard_negatives = get_hard_negatives(dpr_model, tokenizer, train_data, passage_ids, block_size)

In [6]:
#val_hard_negatives = get_hard_negatives(dpr_model, tokenizer, val_data, passage_ids, block_size)

In [7]:
"""
# Save the hard negatives to the pickle file
with open("dpr_embeddings/train_hard_negatives.pkl", "wb") as f:
    pickle.dump(train_hard_negatives, f)

# Save the hard negatives to the pickle file
with open("dpr_embeddings/val_hard_negatives.pkl", "wb") as f:
    pickle.dump(val_hard_negatives, f)
"""

# load from pickle file
with open("dpr_embeddings/train_hard_negatives.pkl", "rb") as f:
    train_hard_negatives = pickle.load(f)

with open("dpr_embeddings/val_hard_negatives.pkl", "rb") as f:
    val_hard_negatives = pickle.load(f)

In [8]:
# show some examples of the hard negatives
claim_id = random.choice(list(train_hard_negatives.keys()))
print("Claim:", train_data[claim_id]["claim_text"])
print("\nGold evidence: ")
for evidence in train_data[claim_id]["evidences"]:
    print(f"\t{document_store[evidence]}")
print("\nTop-5 Hard negatives: ")
for passage_id in train_hard_negatives[claim_id][:5]:
    print(f"\t{document_store[passage_id]}")

Claim: Empirical measurements of the Earth's heat content show the planet is still accumulating heat and global warming is still happening.

Gold evidence: 
	"Evidence is now 'unequivocal' that humans are causing global warming – UN report".
	This is predicted to produce changes such as the melting of glaciers and ice sheets, more extreme temperature ranges, significant changes in weather and a global rise in average sea levels.
	Since the pre-industrial period, global average land temperatures have increased almost twice as fast as global average temperatures.
	This is much colder than the conditions that actually exist at the Earth's surface (the global mean surface temperature is about 14 °C).
	The global average and combined land and ocean surface temperature, show a warming of 0.85 [0.65 to 1.06] °C, in the period 1880 to 2012, based on multiple independently produced datasets.

Top-5 Hard negatives: 
	Infrared Thermography is the science of measuring and mapping surface temperatu

Now let's create a PyTorch dataset for cross-encoder training.

In [21]:
# now let's create a PyTorch dataset
class CrossEncoderDataset(Dataset):
    def __init__(self, claims_data, hard_negatives, document_store, block_size=192):
        self.claims_data = claims_data
        self.hard_negatives = hard_negatives
        self.document_store = document_store
        self.tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
        self.block_size = block_size
        self.positive_label = 1
        self.negative_label = 0
        self.claim_pairs = self.create_pairs()

    def create_pairs(self):
        claim_pairs = []
        for claim_id in self.claims_data.keys():
            for evidence_id in self.claims_data[claim_id]['evidences']:
                claim_pairs.append((claim_id, evidence_id, self.positive_label))  
                # for each positive evidence, sample a negative evidence from hard negatives list
                negative_id = random.choice(self.hard_negatives[claim_id])
                claim_pairs.append((claim_id, negative_id, self.negative_label)) 
        # shuffle the pairs 
        random.shuffle(claim_pairs)                
        return claim_pairs

    def __len__(self):
        return len(self.claim_pairs)

    def __getitem__(self, idx):
        # get claim id and evidence id
        claim_id, evidence_id, target_label = self.claim_pairs[idx]
        # get the claim and evidence text
        claim_text = self.claims_data[claim_id]['claim_text']
        evidence_text = self.document_store[evidence_id]
        # tokenize the claim and evidence text  
        claim_encoding = self.tokenizer.encode_plus(claim_text, add_special_tokens=False, return_offsets_mapping=False, return_attention_mask=False, return_token_type_ids=False)
        claim_idx = claim_encoding['input_ids']
        evidence_encoding = self.tokenizer.encode_plus(evidence_text, add_special_tokens=False, return_offsets_mapping=False, return_attention_mask=False, return_token_type_ids=False)
        evidence_idx = evidence_encoding['input_ids']

        # select a random window from the evidence passage if it won't fit in block size
        max_evidence_size = self.block_size - len(claim_idx) - 3
        if len(evidence_idx) > max_evidence_size:
            # pick a random start position
            start_pos = random.randint(0, max(0,len(evidence_idx)-max_evidence_size))
            # select the window
            evidence_idx = evidence_idx[start_pos:start_pos+max_evidence_size]
 
        # concatenate the claim and evidence, add special tokens and padding
        input_idx = [self.tokenizer.cls_token_id] + claim_idx + [self.tokenizer.sep_token_id] + evidence_idx + [self.tokenizer.sep_token_id]
        input_idx = input_idx + [self.tokenizer.pad_token_id] * (self.block_size - len(input_idx))    

        # create segment ids
        claim_len = len(claim_idx) + 2
        evidence_len = len(evidence_idx) + 1
        token_type_idx = [0] * claim_len + [1] * evidence_len + [0] * (self.block_size - claim_len - evidence_len)

        # make sure the packed claim + evidence sequence fits in block_size
        if len(input_idx) > self.block_size:
            raise Exception(f"Input sequence length {len(input_idx)} is longer than max_length {self.block_size}!")
    
        # create attention masks
        input_attn_mask = [1 if idx != self.tokenizer.pad_token_id else 0 for idx in input_idx]
        # convert to tensors
        input_idx = torch.tensor(input_idx)
        target_label = torch.tensor(target_label)
        input_attn_mask = torch.tensor(input_attn_mask)
        token_type_idx = torch.tensor(token_type_idx)

        return input_idx, input_attn_mask, token_type_idx, target_label

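    def on_epoch_end(self):
        # resample the hard negatives and reshuffle the pairs; PyTorch's DataLoader does
        # not call this automatically, so it has to be called between epochs
        self.claim_pairs = self.create_pairs()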
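The packing logic in `__getitem__` reserves three positions for the special tokens, so a passage gets at most `block_size - len(claim) - 3` tokens, and longer passages are cut to a random contiguous window. The stdlib-only sketch below isolates that step on plain integer token ids (the helper name `pack_pair` and the toy ids are hypothetical). Unlike the dataset, it builds the attention mask from the sequence length instead of comparing every id with the pad id, and it raises an error up front when the claim alone leaves no room for the passage.

```python
import random


def pack_pair(claim_ids, passage_ids, block_size, cls_id, sep_id, pad_id, rng=None):
    """Build [CLS] claim [SEP] passage-window [SEP] + padding.

    Returns (input_ids, attention_mask), both of length block_size.
    """
    if rng is None:
        rng = random.Random()
    max_passage = block_size - len(claim_ids) - 3  # room left after [CLS] and two [SEP]s
    if max_passage <= 0:
        raise ValueError("claim leaves no room for the passage within block_size")
    if len(passage_ids) > max_passage:
        # random contiguous window; randint's upper bound is inclusive
        start = rng.randint(0, len(passage_ids) - max_passage)
        passage_ids = passage_ids[start:start + max_passage]
    tokens = [cls_id] + claim_ids + [sep_id] + passage_ids + [sep_id]
    n_pad = block_size - len(tokens)
    return tokens + [pad_id] * n_pad, [1] * len(tokens) + [0] * n_pad
```

Sampling a fresh window on each access acts as light data augmentation for long passages, at the cost of sometimes cutting away the part that actually supports the claim.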

In [22]:
block_size = 192
train_dataset = CrossEncoderDataset(train_data, train_hard_negatives, document_store, block_size=block_size)
val_dataset = CrossEncoderDataset(val_data, val_hard_negatives, document_store, block_size=block_size)

Now let's create the cross-encoder model. It is just a single DistilBERT encoder with a sentence-level binary classification head on top of the [CLS] token embedding.
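Because a cross-encoder needs a full forward pass for every claim-passage pair, scoring all of the roughly 1.19 million passages in the document store per claim is usually too expensive. A common pattern is to let the bi-encoder retrieve a shortlist and have the cross-encoder rerank it. The sketch below shows that reranking step under stated assumptions: `rerank` and `score_logits` are hypothetical, with `score_logits` standing in for one forward pass of the model that returns its two logits (non-relevant, relevant), and relevance is the softmax probability of the relevant class.

```python
import math
from typing import Callable, List, Sequence, Tuple


def rerank(claim: str,
           candidates: Sequence[Tuple[str, str]],
           score_logits: Callable[[str, str], Tuple[float, float]],
           top_n: int = 5) -> List[Tuple[str, float]]:
    """Rerank (passage_id, passage_text) candidates by cross-encoder relevance probability."""
    scored = []
    for pid, text in candidates:
        # hypothetical stand-in for a cross-encoder forward pass on one claim-passage pair
        z_neg, z_pos = score_logits(claim, text)
        # softmax over two logits == sigmoid(z_pos - z_neg)
        p_relevant = 1.0 / (1.0 + math.exp(z_neg - z_pos))
        scored.append((pid, p_relevant))
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:top_n]
```

In the notebook's setting, the candidates would come from the DPR top-k for a claim, and `score_logits` would tokenize the pair the same way `CrossEncoderDataset` does before calling the model in `eval()` mode.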

In [None]:
from transformers import DistilBertModel

class BERTCrossEncoder(torch.nn.Module):
    def __init__(self, dropout_rate=0.1):
        super().__init__()
        # load pretrained DistilBERT model
        self.bert_encoder = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.dropout = torch.nn.Dropout(dropout_rate)
        # define classifier head: 2 logits (non-relevant, relevant) from the [CLS] encoding
        self.classifier_head = torch.nn.Linear(self.bert_encoder.config.dim, 2)
        # make sure BERT parameters are trainable
        for param in self.bert_encoder.parameters():
            param.requires_grad = True

    def forward(self, input_idx, input_attn_mask, token_type_idx=None, targets=None):
        # DistilBERT has no segment (token type) embeddings and its forward() does not accept
        # token_type_ids, so token_type_idx is kept for API compatibility but unused;
        # the [SEP] token alone separates the claim from the passage
        bert_output = self.bert_encoder(input_idx, attention_mask=input_attn_mask)
        # extract the [CLS] encoding (first element of the sequence) and apply dropout
        cls_enc = self.dropout(bert_output.last_hidden_state[:, 0])  # shape: (batch_size, hidden_size)
        # compute output logits, shape: (batch_size, 2)
        logits = self.classifier_head(cls_enc)
        # compute cross-entropy loss over the two classes
        loss = None
        if targets is not None:
            loss = torch.nn.functional.cross_entropy(logits, targets)
        return logits, loss
    

# training loop
def train(model, optimizer, train_dataloader, val_dataloader, scheduler=None, device="cpu", num_epochs=10, accumulation_steps=1, val_every=100, save_every=None, log_metrics=None):
    avg_loss = 0
    train_acc = 0
    val_loss = 0
    val_acc = 0
    model.train()
    # reset gradients
    optimizer.zero_grad()
    for epoch in range(num_epochs):
        num_correct = 0
        num_total = 0
        pbar = tqdm(train_dataloader, desc="Epochs")
        for i, batch in enumerate(pbar):
            input_idx, input_attn_mask, token_type_idx, targets = batch
            # move batch to device
            input_idx, input_attn_mask, token_type_idx, targets = input_idx.to(device), input_attn_mask.to(device), token_type_idx.to(device), targets.to(device)
            # forward pass
            logits, loss = model(input_idx, input_attn_mask, token_type_idx, targets)
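            # backward pass (scale the loss so the accumulated gradient is an average over accumulation_steps batches)
            (loss / accumulation_steps).backward()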
            # apply gradient step 
            if (i+1) % accumulation_steps == 0:
                # optimizer step
                optimizer.step()
                # reset gradients
                optimizer.zero_grad()
   
            avg_loss = 0.9* avg_loss + 0.1*loss.item()
            B, _ = input_idx.shape
            y_pred = logits.argmax(dim=-1).view(-1) # shape (B,)
            num_correct += (y_pred.eq(targets.view(-1))).sum().item()      
            num_total += B
            train_acc = num_correct / num_total        

            if val_every is not None and i % val_every == 0:
                # compute validation loss and accuracy
                val_loss, val_acc = validation(model, val_dataloader, device=device)

            pbar.set_description(f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train Accuracy: {train_acc: .3f}, Val Loss: {val_loss: .3f}, Val Accuracy: {val_acc: .3f}")

            if log_metrics:
                metrics = {"Batch loss" : loss.item(), "Moving Avg Loss" : avg_loss, "Val Loss": val_loss}
                log_metrics(metrics)

        # run optimizer step for remainder batches
        if len(train_dataloader) % accumulation_steps != 0:
            optimizer.step()
            optimizer.zero_grad()

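        if scheduler is not None:
            scheduler.step()

        # resample the hard negatives for the next epoch (the DataLoader never calls on_epoch_end itself)
        if hasattr(train_dataloader.dataset, "on_epoch_end"):
            train_dataloader.dataset.on_epoch_end()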

        if save_every is not None:
            if (epoch+1) % save_every == 0:
                save_dpr_model_checkpoint(model, optimizer, epoch, avg_loss)


def validation(model, val_dataloader, device="cpu"):
    model.eval()
    val_losses = torch.zeros(len(val_dataloader))
    with torch.no_grad():
        num_correct = 0
        num_total = 0
        for i,batch in enumerate(val_dataloader):
            input_idx, input_attn_mask, token_type_idx, targets = batch
            input_idx, input_attn_mask, token_type_idx, targets = input_idx.to(device), input_attn_mask.to(device), token_type_idx.to(device), targets.to(device)
            logits, loss = model(input_idx, input_attn_mask, token_type_idx, targets)
            B, _ = input_idx.shape
            y_pred = logits.argmax(dim=-1).view(-1) # shape (B,)
            num_correct += (y_pred.eq(targets.view(-1))).sum().item()      
            num_total += B
            val_losses[i] = loss.item()
    model.train()
    val_loss = val_losses.mean().item()
    val_accuracy = num_correct / num_total
    return val_loss, val_accuracy

def save_dpr_model_checkpoint(model, optimizer, epoch=None, loss=None, filename=None):
    # Save the model and optimizer state_dict.
    # The cross-encoder uses its own default file so that saving it does not overwrite
    # a bi-encoder checkpoint stored under the old default name 'dpr_checkpoint.pth'.
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    # Save the checkpoint to a file
    torch.save(checkpoint, filename or 'cross_encoder_checkpoint.pth')
    print("Saved model checkpoint!")


def load_dpr_model_checkpoint(model, optimizer=None, filename=None):
    # Load from the cross-encoder's default checkpoint file unless a filename is given
    checkpoint = torch.load(filename or 'cross_encoder_checkpoint.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    print("Loaded model from checkpoint!")
    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        model.train()
        return model, optimizer
    else:
        return model