# 2024 COMP90042 Project

# Readme

#### **Document Retreival using monoBERT Model**: 

In this notebook, we finetune our custom pre-trained custom BERT model to perform binary classification of $(claim, evidence)$ pairs into relevant and non-relevant classes. We then use this trained re-ranker model to rerank top-k documents retreived by our BM25 model.

*** **PLEASE NOTE**: The BM25 model implementation is contained in a separate python script called: `BM25.py`. We are importing the `BM25` class from this script. We also import helper function that we implemented for pre-processing/cleaning our data from the python script called `utils.py`. Our custom BERT model implementation is contained in the python script called `min_bert_multi.py` and pre-trained weights are loaded from a checkpoint file.

In [1]:
%load_ext autoreload
%autoreload 2

# install required packages
!pip install unidecode
!python -m nltk.downloader stopwords
!pip install wandb

from utils import *
from min_bert_multi import *
from BM25 import BM25_retriever, eval

from torch.utils.data import Dataset, DataLoader
import torch
import random
import numpy as np

from tqdm import tqdm
import pickle 
import wandb
import psutil
import random

#wandb.login()
print(torch.cuda.is_available())

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtanzids[0m. Use [1m`wandb login --relogin`[0m to force relogin


True


# 1.DataSet Processing

#### Load the Claims Dataset with Knowledge Source and Clean the Text

In [2]:
# load dataset and prepare corpus
knowledge_source, train_data, val_data = load_dataset()      
print(f"Number of evidence passages: {len(knowledge_source)}")
print(f"Number of training instances: {len(train_data)}")  
print(f"Number of validation instances: {len(val_data)}")

# clean all senteneces in the dataset (this involves converting from unicode to asc-ii, removing URLS, removing repeating non-alphanumeric characters, removing latex equation blocks, discarding evidence passages that contain code or references to mathematical equations, etc. Just a bunch of thing that are most likely will not be useful for claim classification task)
cleaner = SentenceCleaner()
knowledge_source, train_data, val_data = cleaner.clean_dataset(knowledge_source, train_data, val_data)
print(f"\nNumber of evidence passages after cleaning: {len(knowledge_source)}")
print(f"Number of training instances after cleaning: {len(train_data)}")  
print(f"Number of validation instances after cleaning: {len(val_data)}")

# dictionary for mapping integer to document id 
int2docID = {i:evidence_id for i,evidence_id in enumerate(list(knowledge_source.keys()))}

claim_ids = [claim_id for claim_id in train_data.keys()]

# load trained wordpiece tokenizer from file
with open('tokenizer_worpiece_20000_aug.pkl', 'rb') as f:
    tokenizer = pickle.load(f)

# load hard negatives from file if available
#with open("hard_negatives_2.pkl", "rb") as file:
#    hard_negatives = pickle.load(file)  
hard_negatives = None


Number of evidence passages: 1208827
Number of training instances: 1228
Number of validation instances: 154

Number of evidence passages after cleaning: 1206800
Number of training instances after cleaning: 1228
Number of validation instances after cleaning: 154


#### Set up the Re-Ranking Task Dataset

In [3]:
class CrossEncoderDataset(Dataset):
    def __init__(self, claims_data, document_store, tokenizer, hard_negatives=None, num_negative=2, block_size=128):
        self.claims_data = claims_data
        self.document_store = document_store
        self.tokenizer = tokenizer
        self.hard_negatives = hard_negatives
        self.num_negative = num_negative
        self.block_size = block_size
        self.negative_label = 0
        self.positive_label = 1
        self.document_ids = list(document_store.keys())
        self.claim_pairs = self.create_pairs()        
        
    def create_pairs(self):
        claim_pairs = []
        for claim_id in self.claims_data.keys():
            for evidence_id in self.claims_data[claim_id]['evidences']:
                claim_pairs.append((claim_id, evidence_id, self.positive_label))  
                # for each positive evidence, sample `num_negative` "negative" evidences randomly from document store
                negative_ids = random.sample(self.document_ids, self.num_negative)
                for negative_id in negative_ids:
                    claim_pairs.append((claim_id, negative_id, self.negative_label))
                # also sample a hard negative
                if self.hard_negatives is not None:
                    hard_negative_id = random.choice(self.hard_negatives[claim_id])
                    claim_pairs.append((claim_id, hard_negative_id, self.negative_label))   
        # shuffle the pairs 
        random.shuffle(claim_pairs)                
        return claim_pairs

    def __len__(self):
        return len(self.claim_pairs)

    def __getitem__(self, idx):
        # get claim id and evidence id
        claim_id, evidence_id, target_label = self.claim_pairs[idx]
        # get the claim and evidence text
        claim_text = self.claims_data[claim_id]['claim_text']
        evidence_text = self.document_store[evidence_id]

        # encode and create tensors
        input_idx, input_attn_mask, token_type_idx = self.tokenize_and_encode(claim_text, evidence_text)
        target_label = torch.tensor(target_label)
        return input_idx, input_attn_mask, token_type_idx, target_label

    def tokenize_and_encode(self, claim_text, evidence_text):
        # tokenize the claim and evidence text  
        claim_idx = self.tokenizer.encode([claim_text])[0]
        evidence_idx = self.tokenizer.encode([evidence_text])[0]

        # select a random window from the evidence passage if it won't fit in block size
        max_evidence_size = self.block_size - len(claim_idx) - 3
        if len(evidence_idx) > max_evidence_size:
            # pick a random start position
            start_pos = random.randint(0, max(0,len(evidence_idx)-max_evidence_size))
            # select the window
            evidence_idx = evidence_idx[start_pos:start_pos+max_evidence_size]
 
        # concatenate the claim and evidence, add special tokens and padding
        input_idx = [self.tokenizer.cls_token_id()] + claim_idx + [self.tokenizer.sep_token_id()] + evidence_idx + [self.tokenizer.sep_token_id()]
        input_idx = input_idx + [self.tokenizer.pad_token_id()] * (self.block_size - len(input_idx))    

        # create segment ids
        claim_len = len(claim_idx) + 2
        evidence_len = len(evidence_idx) + 1
        token_type_idx = [0] * claim_len + [1] * evidence_len + [0] * (self.block_size - claim_len - evidence_len)

        # make sure the passage sequences and claim sequences are not longer than max_length
        if len(input_idx) > self.block_size:
            raise Exception(f"Input sequence length {len(input_idx)} is longer than max_length {self.block_size}!")
    
        # create attention masks
        input_attn_mask = [1 if idx != self.tokenizer.pad_token_id() else 0 for idx in input_idx]
        # convert to tensors
        input_idx = torch.tensor(input_idx)
        input_attn_mask = torch.tensor(input_attn_mask)
        token_type_idx = torch.tensor(token_type_idx) 

        return input_idx, input_attn_mask, token_type_idx

    def on_epoch_end(self):
        self.claim_pairs = self.create_pairs()

# 2. Model Implementation

#### Define the Re-Ranking BERT Model

This is just a classifier head on top of the custom BERT.

In [4]:
class BERTReranker(torch.nn.Module):
    def __init__(self, bert_pretrained, dropout_rate=0.2):
        super().__init__()
        # load pretrained BERT model
        self.bert_encoder = bert_pretrained
        self.dropout = torch.nn.Dropout(dropout_rate)
        # define classifier head which is a single linear layer
        self.classifier_head = torch.nn.Sequential(torch.nn.Linear(bert_pretrained.embedding_dim, 2))

        # make sure BERT parameters are trainable
        for param in self.bert_encoder.parameters():
            param.requires_grad = True

    def forward(self, input_idx, input_attn_mask, token_type_idx, targets=None):
        # compute BERT encodings, extract the pooler output (which is just the [CLS] embedding fed through a feedforward network or just the [CLS] embedding), apply dropout        
        #MLM_logits, entailment_logits, claim_class_logits = self.bert_encoder(input_idx, input_attn_mask, segment_idx=token_type_idx) # shape: (batch_size, hidden_size)
        bert_cls_encoding = self.bert_encoder(input_idx, input_attn_mask, segment_idx=token_type_idx, return_cls=True) # shape: (batch_size, hidden_size)
        bert_cls_encoding = self.dropout(bert_cls_encoding) # shape: (batch_size, hidden_size)
        # compute output logits
        logits = self.classifier_head(bert_cls_encoding) # shape: (batch_size, 2)  
        
        # compute cross-entropy loss on the entailment logits
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    

# training loop
def train(model, optimizer, train_dataloader, val_dataloader, scheduler=None, device="cpu", num_epochs=10, accumulation_steps=1, val_every=100, save_every=None, log_metrics=None):
    avg_loss = 0
    train_acc = 0
    train_precision = 0
    train_recall = 0
    val_loss = 0
    val_acc = 0
    val_precision = 0
    val_recall = 0
    model.train()
    # reset gradients
    optimizer.zero_grad()
    for epoch in range(num_epochs):
        num_correct = 0
        num_total = 0
        num_pos = 0
        num_true_pos = 0
        num_pred_pos = 0
        pbar = tqdm(train_dataloader, desc="Epochs")
        for i, batch in enumerate(pbar):
            input_idx, input_attn_mask, token_type_idx, targets = batch
            # move batch to device
            input_idx, input_attn_mask, token_type_idx, targets = input_idx.to(device), input_attn_mask.to(device), token_type_idx.to(device), targets.to(device)
            # forward pass
            logits, loss = model(input_idx, input_attn_mask, token_type_idx, targets)
            # backward pass
            loss.backward()
            # apply gradient step 
            if (i+1) % accumulation_steps == 0:
                # optimizer step
                optimizer.step()
                # reset gradients
                optimizer.zero_grad()
   
            avg_loss = 0.9* avg_loss + 0.1*loss.item()
            B, _ = input_idx.shape
            y_pred = logits.argmax(dim=-1).view(-1) # shape (B,)
            num_correct += (y_pred.eq(targets.view(-1))).sum().item()      
            num_total += B
            train_acc = num_correct / num_total     

            # compute recall for positive class
            num_pos += targets.eq(1).sum().item()
            num_true_pos += (y_pred.eq(1) & targets.eq(1)).sum().item()
            train_recall = num_true_pos / num_pos

            # compute precision
            num_pred_pos += y_pred.eq(1).sum().item()
            train_precision = num_true_pos / max(1,num_pred_pos)
               
            if val_every is not None:
                if (i+1)%val_every == 0:
                    # compute validation loss
                    val_loss, val_acc, val_precision, val_recall = validation(model, val_dataloader, device=device)

            pbar.set_description(f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train Accuracy: {train_acc: .3f}, Train Precision: {train_precision: .3f}, Train Recall: {train_recall: .3f}, Val Loss: {val_loss: .3f}, Val Accuracy: {val_acc: .3f}, Val Precision: {val_precision: .3f}, Val Recall: {val_recall: .3f}")  

            if log_metrics:
                metrics = {"Batch loss":loss.item(), "Moving Avg Loss":avg_loss, "Train Accuracy":train_acc, "Train Recall":train_recall, "Val Loss": val_loss, "Val Accuracy":val_acc, "Val Recall":val_precision, "Val Recall":val_recall}
                log_metrics(metrics)

        # run optimizer step for remainder batches
        if len(train_dataloader) % accumulation_steps != 0:
            optimizer.step()
            optimizer.zero_grad()

        if scheduler is not None:
            scheduler.step()

        if save_every is not None:
            if (epoch+1) % save_every == 0:
                save_ce_model_checkpoint(model, optimizer, epoch, avg_loss)


def validation(model, val_dataloader, device="cpu"):
    model.eval()
    val_losses = torch.zeros(len(val_dataloader))
    with torch.no_grad():
        num_correct = 0
        num_total = 0
        num_pos = 0
        num_true_pos = 0
        num_pred_pos = 0
        for i,batch in enumerate(val_dataloader):
            input_idx, input_attn_mask, token_type_idx, targets = batch
            input_idx, input_attn_mask, token_type_idx, targets = input_idx.to(device), input_attn_mask.to(device), token_type_idx.to(device), targets.to(device)
            logits, loss = model(input_idx, input_attn_mask, token_type_idx, targets)
            B, _ = input_idx.shape
            y_pred = logits.argmax(dim=-1).view(-1) # shape (B,)
            num_correct += (y_pred.eq(targets.view(-1))).sum().item()      
            num_total += B

            # compute recall for positive class
            num_pos += targets.eq(1).sum().item()
            num_true_pos += (y_pred.eq(1) & targets.eq(1)).sum().item()
            val_losses[i] = loss.item()

            # compute precision
            num_pred_pos += y_pred.eq(1).sum().item()

    model.train()
    val_loss = val_losses.mean().item()
    val_accuracy = num_correct / num_total
    val_precision = num_true_pos / max(1,num_pred_pos)
    val_recall = num_true_pos / num_pos

    return val_loss, val_accuracy, val_precision, val_recall

def save_ce_model_checkpoint(model, optimizer, epoch=None, loss=None, filename=None):
    # Save the model and optimizer state_dict
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    # Save the checkpoint to a file
    if filename:
        torch.save(checkpoint, filename)
    else:
        torch.save(checkpoint, 'cross_enc_checkpoint.pth')
    print(f"Saved cross-encoder model checkpoint!")


def load_ce_model_checkpoint(model, optimizer=None, filename=None):
    if filename:
        checkpoint = torch.load(filename)
    else:
        checkpoint = torch.load('cross_enc_checkpoint.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    print("Loaded cross-encoder model from checkpoint!")
    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        model.train()
        return model, optimizer          
    else:
        return model

#### Prepare the Training and Validation Dataloaders

In [5]:
block_size = 128
batch_size = 32

# create dataset
train_dataset = CrossEncoderDataset(train_data, knowledge_source, tokenizer, hard_negatives=hard_negatives, block_size=block_size, num_negative=12)
val_dataset = CrossEncoderDataset(val_data, knowledge_source, tokenizer, block_size=block_size, num_negative=12)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=2)  # set pin_memory for faster pre-fetching
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=2)   
print(f"Total number of training batches: {len(train_dataloader)}, Total number of validation batches: {len(val_dataloader)}")
print(f"Total number of training instances: {len(train_dataset)}, Total number of validation instances: {len(val_dataset)}")

Total number of training batches: 1804, Total number of validation batches: 200
Total number of training instances: 57708, Total number of validation instances: 6383


#### Instantiate the RE-Ranking BERT model, load pre-trained custom BERT

In [6]:
# BERT model hyperparameters
embedding_dim = 512
head_size = embedding_dim
num_heads = 16
num_layers = 8
dropout_rate = 0.2
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu' 

# instantiate custom BERT model
bert = BERTModel(vocab_size=tokenizer.vocab_size(), block_size=block_size, embedding_dim=embedding_dim, head_size=head_size, num_heads=num_heads, num_layers=num_layers, pad_token_id=tokenizer.pad_token_id(), device=device)
bert = bert.to(device)

# load pretrained custom BERT model from checkpoint
bert, _, _ =  load_bert_model_checkpoint(bert, name="BERT_multitask_checkpoint_entaiment_claims_long_600_epochs", device=device, strict=False)

# instantiate Re-ranker model and optimizer
model = BERTReranker(bert)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
model = model.to(device)

# instantiate learning rate scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=1e-8)

# load monoBERT cross encoder model from checkpoint
#model = load_ce_model_checkpoint(model, filename='cross_enc_checkpoint_4_epochs_12_negatives.pth')

num_params = sum(p.numel() for p in model.parameters())
print(f"Device: {device}")
print(f"Total number of parameters in transformer network: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")


Loaded pretrained BERT model checkpoint at epoch 40 with loss 1.8120101885718962
Device: cuda
Total number of parameters in transformer network: 45.031247 M
RAM used: 3429.38 MB


  _torch_pytree._register_pytree_node(


In [9]:
"""
run = wandb.init(
    project="BERT Pretrain MLM", 
    config={
        "model": "Cross Encoder",
        "learning_rate": learning_rate, 
        "batch_size": batch_size, 
        "corpus": "Climate Claims"},)   

def log_metrics(metrics):
    wandb.log(metrics)
"""

#### Finetune for 4 epochs

In [8]:
train(model, optimizer, train_dataloader, val_dataloader, scheduler=scheduler, device=device, num_epochs=4, accumulation_steps=20, val_every=1300, save_every=None, log_metrics=None)

Epochs:   0%|          | 0/1804 [00:00<?, ?it/s]

Epoch 1, EMA Train Loss: 0.055, Train Accuracy:  0.968, Train Precision:  0.763, Train Recall:  0.808, Val Loss:  0.187, Val Accuracy:  0.954, Val Precision:  0.939, Val Recall:  0.436: 100%|██████████| 1804/1804 [05:00<00:00,  6.00it/s]  
Epoch 2, EMA Train Loss: 0.050, Train Accuracy:  0.983, Train Precision:  0.867, Train Recall:  0.902, Val Loss:  0.148, Val Accuracy:  0.959, Val Precision:  0.883, Val Recall:  0.540: 100%|██████████| 1804/1804 [05:16<00:00,  5.70it/s]  
Epoch 3, EMA Train Loss: 0.012, Train Accuracy:  0.989, Train Precision:  0.920, Train Recall:  0.934, Val Loss:  0.225, Val Accuracy:  0.954, Val Precision:  0.933, Val Recall:  0.428: 100%|██████████| 1804/1804 [05:22<00:00,  5.59it/s]  
Epoch 4, EMA Train Loss: 0.006, Train Accuracy:  0.995, Train Precision:  0.961, Train Recall:  0.970, Val Loss:  0.254, Val Accuracy:  0.956, Val Precision:  0.944, Val Recall:  0.448: 100%|██████████| 1804/1804 [05:25<00:00,  5.54it/s]  


#### Save model checkpoint

In [21]:
save_ce_model_checkpoint(model, optimizer, filename='cross_enc_checkpoint_4_epochs_12_negatives.pth')

Saved cross-encoder model checkpoint!


# 3.Testing and Evaluation

#### **Re-ranking with the trained model**:

#### Define helper functions for computing document relevancy scores usong trained Re-ranker model

In [9]:
passage_ids = list(knowledge_source.keys())   

# compute score for a given claim and evidence pair
def compute_score(claim_text, evidence_text, model, device='cuda'):
    model.eval()
    input_idx, input_attn_mask, token_type_idx = val_dataset.tokenize_and_encode(claim_text, evidence_text)
    input_idx = input_idx.unsqueeze(0)
    input_attn_mask = input_attn_mask.unsqueeze(0)
    token_type_idx = token_type_idx.unsqueeze(0)
    with torch.no_grad():
        scores, _ = model(input_idx.to(device), input_attn_mask.to(device), token_type_idx.to(device))
    
    # apply sigmoid to get probability
    score = torch.sigmoid(scores[0,1]).item()
    return score


def compute_score_batch(claim_text, evidence_ids, model, device='cuda', batch_size=16):
    model.eval()
    input_idx_batch, input_attn_mask_batch, token_type_idx_batch = [], [], []
    scores = torch.zeros(len(evidence_ids))
    for i in range(0, len(evidence_ids), batch_size):
        input_idx_batch, input_attn_mask_batch, token_type_idx_batch = [], [], []
        for evidence_id in evidence_ids[i:i+batch_size]:
            input_idx, input_attn_mask, token_type_idx = val_dataset.tokenize_and_encode(claim_text, knowledge_source[evidence_id])
            input_idx_batch.append(input_idx)
            input_attn_mask_batch.append(input_attn_mask)
            token_type_idx_batch.append(token_type_idx)
        input_idx_batch = torch.stack(input_idx_batch)
        input_attn_mask_batch = torch.stack(input_attn_mask_batch)
        token_type_idx_batch = torch.stack(token_type_idx_batch)    
        with torch.no_grad():
            scores_batch, _ = model(input_idx_batch.to(device), input_attn_mask_batch.to(device), token_type_idx_batch.to(device))
        scores[i:i+batch_size] = torch.sigmoid(scores_batch[:,1]) 
    scores = scores.cpu().tolist()
    return scores


# find top-k passages for given claim and compare with gold evidences
def topk_passages(data, claim_id, topk=100, device='cuda', batch_size=16):
    claim_text = data[claim_id]["claim_text"]
    model.eval()
    scores = torch.zeros(len(passage_ids))
    for i in tqdm(range(0, len(passage_ids), batch_size)):
        input_idx_batch, input_attn_mask_batch, token_type_idx_batch = [], [], []
        for evidence in passage_ids[i:i+batch_size]:
            evidence_text = knowledge_source[evidence]
            input_idx, input_attn_mask, token_type_idx = val_dataset.tokenize_and_encode(claim_text, evidence_text)
            input_idx_batch.append(input_idx)
            input_attn_mask_batch.append(input_attn_mask)
            token_type_idx_batch.append(token_type_idx)
        input_idx_batch = torch.stack(input_idx_batch)
        input_attn_mask_batch = torch.stack(input_attn_mask_batch)
        token_type_idx_batch = torch.stack(token_type_idx_batch)    
        with torch.no_grad():
            scores_batch, _ = model(input_idx_batch.to(device), input_attn_mask_batch.to(device), token_type_idx_batch.to(device))
        scores[i:i+batch_size] = scores_batch[:,1] 

    # sort the scores 
    sorted_scores, sorted_indices = torch.sort(scores, descending=True)    
    # sort the passage ids
    ranked_passage_ids = [passage_ids[i] for i in sorted_indices]
    # get top-k passages
    topk_passage_ids = ranked_passage_ids[:topk]
    topk_scores = sorted_scores[:topk]

    # get gold evidences
    gold_evidence_list = data[claim_id]["evidences"]
    print(f"Claim: {claim_text}")
    print(f"\nGold evidence:")
    for evidence_id in gold_evidence_list:
        print(f"{evidence_id} --> {knowledge_source[evidence_id]}")  

    print(f"\nPredicted evidence:")
    for i, evidence_id in enumerate(topk_passage_ids):
        print(f"{evidence_id} --> {knowledge_source[evidence_id]} --> {topk_scores[i]}")

    # evaluation (precision, recall, F1)
    intersection = set(topk_passage_ids).intersection(gold_evidence_list)
    print(f"\nMatching evidence passages: {intersection}")
    precision = len(intersection) / len(topk_passage_ids)
    recall = len(intersection) / len(gold_evidence_list)
    f1 = (2*precision*recall/(precision + recall)) if (precision + recall) > 0 else 0 
    print(f"\nPrecision: {precision}, Recall: {recall}, F1: {f1}")

    return topk_passage_ids, topk_scores

#### Run some tests for validation claims to compare yhe relevenacy scores assigned to gold documents vs non-gold documents selected randomly from knowledge scource. Observe that the gold documents will get much higher scores. This will indicate the re-ranking model is working correctly.

In [15]:
# pick a random claim from the validation set
claim_id = random.choice(list(val_data.keys())) 
claim = val_data[claim_id]
claim_text = claim["claim_text"]
print(f"{claim_id} --> {claim_text}")

# get gold evidence passages
gold_evidence_list = claim["evidences"]
# get an twice as many random passages from knowledge source
random_passages = random.sample(passage_ids, len(gold_evidence_list)*4)

# now score the claim against the gold evidence passages and random passages
scores_gold = [compute_score(claim_text, knowledge_source[evidence_id], model) for evidence_id in gold_evidence_list]
scores_random = [compute_score(claim_text, knowledge_source[evidence_id], model) for evidence_id in random_passages]

print("\nGold evidences and scores:")
for i, evidence_id in enumerate(gold_evidence_list):
    print(f"{evidence_id} --> {knowledge_source[evidence_id]} --> {scores_gold[i]}")

print("\nRandom evidences and scores:")
for i, evidence_id in enumerate(random_passages):
    print(f"{evidence_id} --> {knowledge_source[evidence_id]} --> {scores_random[i]}")    


claim-2060 --> Global warming is causing more hurricanes and stronger hurricanes.

Gold evidences and scores:
evidence-553897 --> The maximum rainfall and wind speed from hurricanes and typhoons are likely increasing. --> 0.9698548316955566
evidence-8304 --> As the Earth's climate warms, we are seeing many changes: stronger, more destructive hurricanes; heavier rainfall; more disastrous flooding; more areas of the world experiencing severe drought; and more heat waves." --> 0.7158042788505554

Random evidences and scores:
evidence-938605 --> On 23 September 2014, Al-Turki was killed during a series of a U.S.-led anti-Khorasan Group coalition airstrikes over Syria. --> 0.00016451627016067505
evidence-951564 --> He had written about his journeys through Afghanistan, once at 19 and again, as described in the book, An Unexpected Light : Travels in Afghanistan, for which he received the Thomas Cook Travel Book Award in 2000 and the ALA Notable Books for Adults in 2002. --> 9.351714106742293

#### Load the trained BM25 model.

In [10]:
# load trained BM25 model
with open("bm25_b=0.3_k=0.5.pkl", "rb") as file:
    bm25 = pickle.load(file)

#### Run some tests comparing BM25 retreived document rankings before and after Re-ranking

In [11]:
# evaluate retreival results for a given claim text
def claim_retreive_eval(claim, topk=5000):
    print(f"Claim: {claim['claim_text']}")
    claim_text = claim['claim_text']

    # get top k relevant evidence passages and their scores
    topk_ev_indices, topk_scores = bm25.retrieve_docs(query=claim_text, topk=topk)
    # convert indices to document ids
    topk_ev_ids = [int2docID[i] for i in topk_ev_indices]
    # get the gold evidence ids
    gold_ev_ids = claim['evidences']
    # compute precision, recall, and F1
    intersection = set(topk_ev_ids).intersection(gold_ev_ids)
    precision = len(intersection) / len(topk_ev_ids)
    recall = len(intersection) / len(gold_ev_ids)
    f1 = (2*precision*recall/(precision + recall)) if (precision + recall) > 0 else 0 

    print(f"\nGold evidence passages: {gold_ev_ids}")
    for ev in gold_ev_ids:
        print(f"{ev} --> {knowledge_source[ev]}")

    print(f"\nMatching evidence passages: {intersection}")

    # get cross-encoder scores for retireved passages
    ce_scores = compute_score_batch(claim_text, topk_ev_ids, model) #[compute_score(claim_text, knowledge_source[evidence_id], model) for evidence_id in topk_ev_ids]
    # sort the passages by cross-encoder scores
    sorted_indices = sorted(range(len(ce_scores)), key=lambda i: ce_scores[i], reverse=True)
    sorted_ce_scores = [ce_scores[i] for i in sorted_indices] 
    topk_ev_ids_ce = [topk_ev_ids[i] for i in sorted_indices]

    print(f"Top 10 BM-25 Passages:")
    for i, ev in enumerate(topk_ev_ids[:15]):
        print(f"{ev} --> {knowledge_source[ev]} --> {topk_scores[i]}")

    print(f"Top 10 Cross-Encoder Passages:")
    for i, ev in enumerate(topk_ev_ids_ce[:15]):
        print(f"{ev} --> {knowledge_source[ev]} --> {sorted_ce_scores[i]}")    


    # get BM25 and cross-encoder rankings of matching evidence passages
    bm25_rankings = [topk_ev_ids.index(ev) for ev in intersection]
    ce_rankings = [topk_ev_ids_ce.index(ev) for ev in intersection]

    print(f"\nBM25 and Cross-Encoder rankings of matching evidence passages:")
    for i, ev in enumerate(intersection):
        print(f"{ev} --> BM25 Rank: {bm25_rankings[i]}, CE Rank: {ce_rankings[i]}")
    
    print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")

    return precision, recall, f1

In [32]:
# pick a random claim from validation set
id = random.choice(list(val_data.keys()))   # random.choice(list(train_data.keys()))
claim = val_data[id]                        # train_data[id] 
print(f"Claim id --> {id}")
# retreive with BM25 and then perform re-ranking
claim_retreive_eval(claim, topk=1000)

Claim id --> claim-1219
Claim: Global warming is driving major melting on the surface of Greenland's glaciers and is speeding up their travel into the sea."

Gold evidence passages: ['evidence-88825', 'evidence-1127398', 'evidence-1115033']
evidence-88825 --> Surface temperature increases are greatest in the Arctic, which has contributed to the retreat of glaciers, permafrost, and sea ice.
evidence-1127398 --> This acceleration is due mostly to human-caused global warming, which is driving thermal expansion of seawater and the melting of land-based ice sheets and glaciers.
evidence-1115033 --> Global warming could lead to an increase in freshwater in the northern oceans, by melting glaciers in Greenland, and by increasing precipitation, especially through Siberian rivers.

Matching evidence passages: {'evidence-88825', 'evidence-1115033', 'evidence-1127398'}
Top 10 BM-25 Passages:
evidence-797505 --> Recent global warming has caused mountain glaciers and the ice sheets in Greenland and

(0.003, 1.0, 0.005982053838484547)

#### Perform Retreival on Validation Claims and Evaluate Performance by computing the following metrics at specified range of k-values: Mean Average Precision (mAP), Average Recall, Average F1-Score.

In [12]:
# compute average precision, recall, f1 scores and mean average precision (mAP) at every k value
def reranking_eval(claims_dataset, bm25, cross_encoder, k_values=[10], mixed=False):
    k_values = sorted(k_values)
    if mixed:
        num_vars = 3
    else:
        num_vars = 2    
    precision = np.zeros(shape=(len(k_values), num_vars))
    recall = np.zeros(shape=(len(k_values), num_vars))
    f1 = np.zeros(shape=(len(k_values), num_vars))
    ap_values = np.zeros((len(k_values), num_vars))
    topk = {}

    for claim_id, claim in tqdm(claims_dataset.items(), total=len(claims_dataset)):
        query = claim["claim_text"]
        gold_evidence_list = claim["evidences"]
        # retrieve bm25 topk documents
        topk_ev_indices, topk_scores_bm25 = bm25.retrieve_docs(query, topk=max(k_values))
        if topk_ev_indices == []:
            continue
        # convert document indices to document id strings
        topk_ev_ids_bm25 = [int2docID[i] for i in topk_ev_indices]

        # compute cross-encoder scores for topk documents
        ce_scores = compute_score_batch(query, topk_ev_ids_bm25, cross_encoder)
        # sort the scores
        sorted_indices = sorted(range(len(ce_scores)), key=lambda i: ce_scores[i], reverse=True)
        sorted_ce_scores = [ce_scores[i] for i in sorted_indices] 
        # compute cross-encoder re-ranking
        topk_ev_ids_ce = [topk_ev_ids_bm25[i] for i in sorted_indices]
        # save topk ranked docs with scores
        topk[claim_id] = {"bm25": (topk_ev_ids_bm25, topk_scores_bm25), "ce": (topk_ev_ids_ce, sorted_ce_scores)}

        for i, k in enumerate(k_values):
            # (precision, recall, F1) for bm25
            intersection = set(topk_ev_ids_bm25[:k]).intersection(gold_evidence_list)
            p = len(intersection) / len(topk_ev_ids_bm25[:k])
            r = len(intersection) / len(gold_evidence_list)
            precision[i,0] += p
            recall[i,0] += r
            f1[i,0] += (2*p*r/(p + r)) if (p+r) > 0 else 0

            # Compute Average Precision (AP) for bm25
            precisions_bm25 = []
            for j in range(k):
                intersection = set(topk_ev_ids_bm25[:j+1]).intersection(gold_evidence_list)
                if intersection:
                    precisions_bm25.append(len(intersection)/(j+1))
            ap_values[i, 0] += sum(precisions_bm25) / len(precisions_bm25) if precisions_bm25 else 0

            # (precision, recall, F1) for cross-encoder
            intersection = set(topk_ev_ids_ce[:k]).intersection(gold_evidence_list)
            p = len(intersection) / len(topk_ev_ids_ce[:k])
            r = len(intersection) / len(gold_evidence_list)
            precision[i,1] += p
            recall[i,1] += r
            f1[i,1] += (2*p*r/(p + r)) if (p+r) > 0 else 0

            # Compute Average Precision (AP) for cross-encoder
            precisions_ce = []
            for j in range(k):
                intersection = set(topk_ev_ids_ce[:j+1]).intersection(gold_evidence_list)
                if intersection:
                    precisions_ce.append(len(intersection)/(j+1))
            ap_values[i, 1] += sum(precisions_ce) / len(precisions_ce) if precisions_ce else 0

            if mixed:
                # (precision, recall, F1) for 1/2 bm25 + 1/2 cross-encoder
                #topk_ev_ids_mixed = topk_ev_ids_bm25[:max(1,k//2)] + topk_ev_ids_ce[:max(1,k//2)]
                topk_ev_ids_mixed = []
                for ii in range(k):
                    topk_ev_ids_mixed.append(topk_ev_ids_bm25[ii])
                    topk_ev_ids_mixed.append(topk_ev_ids_ce[ii])
                # remove duplicates while maintaining order
                topk_ev_ids_mixed = list(dict.fromkeys(topk_ev_ids_mixed))[:k]           

                intersection = set(topk_ev_ids_mixed).intersection(gold_evidence_list)
                p = len(intersection) / len(topk_ev_ids_mixed)
                r = len(intersection) / len(gold_evidence_list)
                precision[i,2] += p
                recall[i,2] += r
                f1[i,2] += (2*p*r/(p + r)) if (p+r) > 0 else 0

                # Compute Average Precision (AP) for 1/2 bm25 + 1/2 cross-encoder
                precisions_mix = []
                for j in range(k):
                    #intersection = set(topk_ev_ids_bm25[:max(1,j//2)] + topk_ev_ids_ce[:max(1,j//2)]).intersection(gold_evidence_list)
                    intersection = set(topk_ev_ids_mixed[:j+1]).intersection(gold_evidence_list)
                    if intersection:
                        precisions_mix.append(len(intersection)/(j+1))
                ap_values[i, 2] += sum(precisions_mix) / len(precisions_mix) if precisions_mix else 0        

    # Compute Mean Average Precision (MAP)
    map_values = ap_values / len(claims_dataset)
    # average over all claims
    precision = precision / len(claims_dataset) 
    recall = recall / len(claims_dataset)
    f1 = f1 / len(claims_dataset)

    # convert to dictionary
    if mixed:
        metrics = {k:{"precision":{"bm25":precision[i,0], "cross_encoder":precision[i,1], "mixed":precision[i,2]}, "recall":{"bm25":recall[i,0], "cross_encoder":recall[i,1], "mixed":recall[i,2]}, "f1": {"bm25":f1[i,0], "cross_encoder":f1[i,1], "mixed":f1[i,2]}, "map": {"bm25":map_values[i,0], "cross_encoder":map_values[i,1], "mixed":map_values[i,2]}} for i,k in enumerate(k_values)}
    else:
        metrics = {k:{"precision":{"bm25":precision[i,0], "cross_encoder":precision[i,1]}, "recall":{"bm25":recall[i,0], "cross_encoder":recall[i,1]}, "f1": {"bm25":f1[i,0], "cross_encoder":f1[i,1]}, "map": {"bm25":map_values[i,0], "cross_encoder":map_values[i,1]}} for i,k in enumerate(k_values)}

    return metrics, topk

In [36]:
k_values = [2, 4, 6, 10, 20, 50, 100, 200, 300, 500, 1000]
metrics, topk = reranking_eval(val_data, bm25, model, k_values, mixed=True)

  0%|          | 0/154 [00:00<?, ?it/s]

100%|██████████| 154/154 [24:03<00:00,  9.37s/it]


In [22]:
for k, k_metrics in metrics.items():
    #print(k_metrics)
    print(f"k = {k}")
    bm25_precision = k_metrics["precision"]["bm25"]
    bm25_recall = k_metrics["recall"]["bm25"]
    bm25_f1 = k_metrics["f1"]["bm25"]
    bm25_map = k_metrics["map"]["bm25"]
    ce_precision = k_metrics["precision"]["cross_encoder"]
    ce_recall = k_metrics["recall"]["cross_encoder"]
    ce_f1 = k_metrics["f1"]["cross_encoder"]
    ce_map = k_metrics["map"]["cross_encoder"]
    mixed_precision = k_metrics["precision"]["mixed"]
    mixed_recall = k_metrics["recall"]["mixed"]
    mixed_f1 = k_metrics["f1"]["mixed"]
    mixed_map = k_metrics["map"]["mixed"]
    print(f"\t bm25: precision --> {bm25_precision:.4f}, recall -> {bm25_recall:.4f}, f1 -> {bm25_f1:.4f}, mAP --> {bm25_map:.4f}")
    print(f"\t c-e:  precision --> {ce_precision:.4f}, recall -> {ce_recall:.4f}, f1 -> {ce_f1:.4f}, mAP --> {ce_map:.4f}")
    print(f"\t mixed:  precision --> {mixed_precision:.4f}, recall -> {mixed_recall:.4f}, f1 -> {mixed_f1:.4f}, mAP --> {mixed_map:.4f}")

k = 2
	 bm25: precision --> 0.0882, recall -> 0.1176, f1 -> 0.0980, mAP --> 0.0882
	 c-e:  precision --> 0.1176, recall -> 0.0529, f1 -> 0.0728, mAP --> 0.1471
	 mixed:  precision --> 0.0882, recall -> 0.0559, f1 -> 0.0658, mAP --> 0.1029
k = 4
	 bm25: precision --> 0.0735, recall -> 0.1441, f1 -> 0.0905, mAP --> 0.0997
	 c-e:  precision --> 0.1324, recall -> 0.1324, f1 -> 0.1291, mAP --> 0.1442
	 mixed:  precision --> 0.1029, recall -> 0.1706, f1 -> 0.1183, mAP --> 0.1283
k = 6
	 bm25: precision --> 0.0490, recall -> 0.1441, f1 -> 0.0687, mAP --> 0.0814
	 c-e:  precision --> 0.1176, recall -> 0.2176, f1 -> 0.1427, mAP --> 0.1442
	 mixed:  precision --> 0.0980, recall -> 0.2088, f1 -> 0.1243, mAP --> 0.1167
k = 10
	 bm25: precision --> 0.0353, recall -> 0.1735, f1 -> 0.0564, mAP --> 0.0675
	 c-e:  precision --> 0.0941, recall -> 0.2931, f1 -> 0.1363, mAP --> 0.1365
	 mixed:  precision --> 0.0941, recall -> 0.3471, f1 -> 0.1393, mAP --> 0.1148
k = 20
	 bm25: precision --> 0.0235, recall

#### Save Retreival results to file, will need these later for classifier

In [39]:
# save metrics dict to pickle file
#with open("reranking_metrics_mAP_1.pkl", "wb") as file:
#    pickle.dump(metrics, file)

# save topk ranked documents to pickle file
#with open("topk_reranker.pkl", "wb") as file:
#    pickle.dump(topk, file)

In [42]:
# load topk ranked documents from pickle file
#with open("topk_reranker.pkl", "rb") as file:
#    topk = pickle.load(file)

#### Also compute evaluation metrics on Training Claims

In [26]:
"""
k_values = [2, 4, 6, 10, 20, 50, 100, 200, 300, 500]
metrics_train, topk_train = reranking_eval(train_data, bm25, model, k_values)
"""

'\nk_values = [2, 4, 6, 10, 20, 50, 100, 200, 300, 500]\nmetrics_train, topk_train = reranking_eval(train_data, bm25, model, k_values)\n'