# 2024 COMP90042 Project

# Readme

#### **BERT Classifier Model**: 

In this notebook, we implement our main classifier model, which performs multi-class classification of encoded $(claim, concatenated\_relevant\_evidences)$ pairs into one of the four claim classes `[SUPPORTS, REFUTES, NOT_ENOUGH_INFO, DISPUTED]` by applying softmax classification to the `[CLS]` embedding.

Our classifier model consists of a classification head (a single linear layer) attached on top of our custom BERT model, which is initialized from a pre-training checkpoint.
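
As a rough illustration of what "softmax classification on the `[CLS]` embedding" means, the sketch below applies a linear head to a toy 3-dimensional vector and picks the most probable class. The embedding, weights, and bias are made-up numbers chosen for illustration only; the actual model uses a 512-dimensional embedding and learned parameters.

```python
import math

LABELS = ["SUPPORTS", "REFUTES", "NOT_ENOUGH_INFO", "DISPUTED"]

def linear(x, weight, bias):
    """Compute weight @ x + bias for one input vector (one weight row per class)."""
    return [sum(w_i * x_i for w_i, x_i in zip(row, x)) + b
            for row, b in zip(weight, bias)]

def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Toy "[CLS] embedding" and hypothetical head parameters (illustrative values only).
cls_embedding = [0.5, -1.0, 2.0]
weight = [
    [0.1, 0.0, 0.2],
    [0.0, 0.3, -0.1],
    [0.2, 0.1, 0.5],
    [-0.3, 0.2, 0.0],
]
bias = [0.0, 0.1, -0.2, 0.0]

logits = linear(cls_embedding, weight, bias)   # one logit per claim class
probs = softmax(logits)                        # probabilities that sum to 1
predicted = LABELS[max(range(len(probs)), key=probs.__getitem__)]
print(predicted, [round(p, 3) for p in probs])
```

Because softmax is monotonic, taking the argmax of the logits gives the same prediction as taking the argmax of the probabilities, which is why the evaluation code can skip the softmax at inference time.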

**PLEASE NOTE**: We import the helper functions that we implemented for pre-processing/cleaning our data from the Python script `utils.py`. Our custom BERT model implementation is contained in the Python script `min_bert_multi.py`, and its pre-trained weights are loaded from a checkpoint file. Our custom WordPiece tokenizer implementation is contained in the Python script `wordpiece_tokenizer.py`.

In [1]:
%load_ext autoreload
%autoreload 2

# install required packages
!pip install unidecode
!python -m nltk.downloader stopwords
!pip install wandb

from utils import *
from wordpiece_tokenizer import *
from min_bert_multi import *

from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR

from tqdm import tqdm
import pickle 
import wandb
import psutil
import random

#wandb.login()
print(torch.cuda.is_available())

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: tanzids. Use `wandb login --relogin` to force relogin


True


# 1. Dataset Processing

#### Load the Claims Dataset with Knowledge Source and Clean the Text

In [2]:
# load dataset and prepare corpus
knowledge_source, train_data, val_data = load_dataset()      
print(f"Number of evidence passages: {len(knowledge_source)}")
print(f"Number of training instances: {len(train_data)}")  
print(f"Number of validation instances: {len(val_data)}")

# clean all sentences in the dataset (convert Unicode to ASCII, remove URLs, collapse repeated non-alphanumeric characters, etc.; these are things that are unlikely to be useful for the claim classification task)
cleaner = SentenceCleaner()
knowledge_source, train_data, val_data = cleaner.clean_dataset(knowledge_source, train_data, val_data)
print(f"\nNumber of evidence passages after cleaning: {len(knowledge_source)}")
print(f"Number of training instances after cleaning: {len(train_data)}")  
print(f"Number of validation instances after cleaning: {len(val_data)}")

# dictionary for mapping integer to document id 
int2docID = {i:evidence_id for i,evidence_id in enumerate(list(knowledge_source.keys()))}

claim_ids = [claim_id for claim_id in train_data.keys()]

# load trained wordpiece tokenizer from file
with open('tokenizer_worpiece_20000_aug.pkl', 'rb') as f:
    tokenizer = pickle.load(f)


Number of evidence passages: 1208827
Number of training instances: 1228
Number of validation instances: 154

Number of evidence passages after cleaning: 1206800
Number of training instances after cleaning: 1228
Number of validation instances after cleaning: 154
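
The cleaning steps named in the comment above (Unicode to ASCII, URL removal, collapsing repeated non-alphanumeric characters) can be sketched with the standard library alone. This is not the `SentenceCleaner` from `utils.py`, whose exact rules are not shown in this notebook; `clean_sentence` is a hypothetical stand-in that only illustrates the kind of normalization involved.

```python
import re
import unicodedata

def clean_sentence(text):
    """Illustrative cleaner: ASCII-fold, drop URLs, collapse repeated punctuation."""
    # Decompose accented characters, then drop anything that is not ASCII.
    text = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode("ascii")
    # Remove http(s) and www-style URLs.
    text = re.sub(r"(https?://|www\.)\S+", " ", text)
    # Collapse runs of the same non-alphanumeric, non-space character ("!!!" -> "!").
    text = re.sub(r"([^\w\s])\1+", r"\1", text)
    # Normalize whitespace.
    return re.sub(r"\s+", " ", text).strip()

example = "Café prices rose!!!  See https://example.com/report for more..."
print(clean_sentence(example))
```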


#### Set up the Classification Task Dataset

In [12]:
class Classifier_Dataset(Dataset):
    def __init__(self, claims_data, document_store, tokenizer, block_size):
        self.claims_data = claims_data          
        self.document_store = document_store    # document store
        self.tokenizer = tokenizer    # wordpiece tokenizer
        self.block_size = block_size  # truncation/max length of sentences
        self.vocab_size = tokenizer.vocab_size()

        self.negative_label = 0
        self.positive_label = 1
        self.claim_label2int = {'SUPPORTS':0, 'REFUTES':1, 'NOT_ENOUGH_INFO':2, 'DISPUTED':3}
        self.document_ids = list(document_store.keys())
        self.claim_ids = list(claims_data.keys())        
        
    def __len__(self):
        return len(self.claim_ids)


    def __getitem__(self, idx):
        # get claim text
        claim_id = self.claim_ids[idx]
        claim_text = self.claims_data[claim_id]['claim_text']
        # get concatenated evidence texts
        evidence_ids = self.claims_data[claim_id]['evidences']
        evidence_texts = [self.document_store[evidence_id] for evidence_id in evidence_ids] 
        evidence_text = " ".join(evidence_texts)
        # get claim class
        claim_label = self.claim_label2int[self.claims_data[claim_id]['claim_label']]

        # encode both sentences using the tokenizer
        s1_idx = self.tokenizer.encode([claim_text])[0]
        s2_idx = self.tokenizer.encode([evidence_text])[0]

        # check whether the combined length, including the 3 special tokens ([CLS] and two [SEP]), exceeds block_size
        if len(s1_idx) + len(s2_idx) + 3 > self.block_size:
            # calculate the space available for each sentence
            available_space = (self.block_size - 3) // 2
            if len(s1_idx) < available_space:
                # if s1 is shorter than available space, allocate the remaining space to s2
                available_space_s2 = self.block_size - 3 - len(s1_idx)
                if len(s2_idx) > available_space_s2:
                    # if s2 is longer than the available space, crop it
                    start = random.randint(0, len(s2_idx) - available_space_s2)
                    s2_idx = s2_idx[start:start+available_space_s2]
            elif len(s2_idx) < available_space:
                # if s2 is shorter than available space, allocate the remaining space to s1
                available_space_s1 = self.block_size - 3 - len(s2_idx)
                if len(s1_idx) > available_space_s1:
                    # if s1 is longer than the available space, crop it
                    start = random.randint(0, len(s1_idx) - available_space_s1)
                    s1_idx = s1_idx[start:start+available_space_s1]
            else:
                # if both sentences are longer than available space, crop both
                start = random.randint(0, len(s1_idx) - available_space)
                s1_idx = s1_idx[start:start+available_space]
                start = random.randint(0, len(s2_idx) - available_space)
                s2_idx = s2_idx[start:start+available_space]

        # combine the two sentences with a separator token into single sequence        
        s = torch.cat([torch.tensor([self.tokenizer.cls_token_id()]), torch.tensor(s1_idx, dtype=torch.long), torch.tensor([self.tokenizer.sep_token_id()]), torch.tensor(s2_idx, dtype=torch.long), torch.tensor([self.tokenizer.sep_token_id()])])
        # apply padding
        pad_len = max(0,self.block_size-len(s))
        s = torch.cat([s, torch.full((pad_len,), self.tokenizer.pad_token_id())])
        attention_mask = torch.cat([torch.ones(self.block_size-pad_len), torch.zeros(pad_len)])
        claim_label = torch.tensor(claim_label)

        # create segment ids
        segment_ids = torch.zeros(self.block_size)
        sep_idx = (s == self.tokenizer.sep_token_id()).nonzero(as_tuple=False)
        segment_ids[sep_idx[0]+1:] = 1
        segment_ids = segment_ids.long()

        return s, attention_mask, segment_ids, claim_label

    def on_epoch_end(self):
        # nothing to rebuild between epochs: each example is encoded (and randomly cropped) on the fly in __getitem__
        pass

    # function for encoding a custom (out-of-corpus) sentence pair
    def encode_custom(self, sent_1, sent_2):
        # encode both sentences using the tokenizer
        s1_idx = self.tokenizer.encode([sent_1])[0]
        s2_idx = self.tokenizer.encode([sent_2])[0]

        # check whether the combined length, including the 3 special tokens ([CLS] and two [SEP]), exceeds block_size
        if len(s1_idx) + len(s2_idx) + 3 > self.block_size:
            # calculate the space available for each sentence
            available_space = (self.block_size - 3) // 2
            if len(s1_idx) < available_space:
                # if s1 is shorter than available space, allocate the remaining space to s2
                available_space_s2 = self.block_size - 3 - len(s1_idx)
                if len(s2_idx) > available_space_s2:
                    # if s2 is longer than the available space, crop it
                    start = random.randint(0, len(s2_idx) - available_space_s2)
                    s2_idx = s2_idx[start:start+available_space_s2]
            elif len(s2_idx) < available_space:
                # if s2 is shorter than available space, allocate the remaining space to s1
                available_space_s1 = self.block_size - 3 - len(s2_idx)
                if len(s1_idx) > available_space_s1:
                    # if s1 is longer than the available space, crop it
                    start = random.randint(0, len(s1_idx) - available_space_s1)
                    s1_idx = s1_idx[start:start+available_space_s1]
            else:
                # if both sentences are longer than available space, crop both
                start = random.randint(0, len(s1_idx) - available_space)
                s1_idx = s1_idx[start:start+available_space]
                start = random.randint(0, len(s2_idx) - available_space)
                s2_idx = s2_idx[start:start+available_space]


        # combine the two sentences with a separator token into single sequence        
        s = torch.cat([torch.tensor([self.tokenizer.cls_token_id()]), torch.tensor(s1_idx, dtype=torch.long), torch.tensor([self.tokenizer.sep_token_id()]), torch.tensor(s2_idx, dtype=torch.long), torch.tensor([self.tokenizer.sep_token_id()])])
        # apply padding
        pad_len = max(0,self.block_size-len(s))
        s = torch.cat([s, torch.full((pad_len,), self.tokenizer.pad_token_id())])
        attention_mask = torch.cat([torch.ones(self.block_size-pad_len), torch.zeros(pad_len)])

        # create segment ids
        segment_ids = torch.zeros(self.block_size)
        sep_idx = (s == self.tokenizer.sep_token_id()).nonzero(as_tuple=False)
        segment_ids[sep_idx[0]+1:] = 1
        segment_ids = segment_ids.long()

        return s, attention_mask, segment_ids
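
The truncation branches in `__getitem__` and `encode_custom` decide how many tokens of each sentence survive when the pair does not fit into `block_size`. The sketch below reproduces only that length bookkeeping, assuming 3 special tokens as in the code above; the random choice of crop offset is left out, and `token_budgets` is a hypothetical helper name rather than part of the dataset class.

```python
def token_budgets(len_s1, len_s2, block_size, num_special=3):
    """Return how many tokens of each sentence are kept after truncation."""
    budget = block_size - num_special          # room left after [CLS] and two [SEP]
    if len_s1 + len_s2 <= budget:
        return len_s1, len_s2                  # everything fits, nothing is cropped
    half = budget // 2
    if len_s1 < half:
        # s1 is short: s2 takes whatever space s1 leaves unused
        return len_s1, min(len_s2, budget - len_s1)
    if len_s2 < half:
        # s2 is short: s1 takes whatever space s2 leaves unused
        return min(len_s1, budget - len_s2), len_s2
    # both are long: each gets half of the budget
    return half, half

for lengths in [(20, 60), (20, 300), (200, 30), (200, 300)]:
    print(lengths, "->", token_budgets(*lengths, block_size=128))
```

With `block_size = 128` the budget is 125 tokens, so when both inputs are long each keeps 62 tokens and one position is left for padding. Note also that the crop offset is drawn with `random.randint`, so an over-long input can be encoded differently on each call, including at evaluation time.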

# 2. Model Implementation

#### Define the Classifier Model

This is just a classifier head on top of the custom BERT.

In [4]:
class BERTClassifier(torch.nn.Module):
    def __init__(self, bert_pretrained, dropout_rate=0.2):
        super().__init__()
        # load pretrained BERT model
        self.bert_encoder = bert_pretrained
        self.dropout = torch.nn.Dropout(dropout_rate)
        # define classifier head which is a single linear layer
        self.classifier_head = torch.nn.Sequential(torch.nn.Linear(bert_pretrained.embedding_dim, 4))

        # make sure BERT parameters are trainable
        for param in self.bert_encoder.parameters():
            param.requires_grad = True

    def forward(self, input_idx, input_attn_mask, token_type_idx, targets=None):
        # compute BERT encodings and extract the [CLS] representation (either the raw [CLS] embedding or the pooler output, i.e. the [CLS] embedding fed through a feedforward network); dropout is applied afterwards
        #MLM_logits, entailment_logits, claim_class_logits = self.bert_encoder(input_idx, input_attn_mask, segment_idx=token_type_idx) # shape: (batch_size, hidden_size)
        bert_cls_encoding = self.bert_encoder(input_idx, input_attn_mask, segment_idx=token_type_idx, return_cls=True) # shape: (batch_size, hidden_size)
        bert_cls_encoding = self.dropout(bert_cls_encoding) # shape: (batch_size, hidden_size)
        # compute output logits
        logits = self.classifier_head(bert_cls_encoding) # shape: (batch_size, 4)
        
        # compute cross-entropy loss on the claim class logits
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
# training loop
def train(model, optimizer, train_dataloader, val_dataloader, scheduler=None, device="cpu", num_epochs=10, accumulation_steps=1, val_every=100, save_every=None, log_metrics=None):
    avg_loss = 0
    train_acc = 0
    val_loss = 0
    val_acc = 0
    model.train()
    # reset gradients
    optimizer.zero_grad()
    for epoch in range(num_epochs):
        num_correct = 0
        num_total = 0
        pbar = tqdm(train_dataloader, desc="Epochs")
        for i, batch in enumerate(pbar):
            input_idx, input_attn_mask, token_type_idx, targets = batch
            # move batch to device
            input_idx, input_attn_mask, token_type_idx, targets = input_idx.to(device), input_attn_mask.to(device), token_type_idx.to(device), targets.to(device)
            # forward pass
            logits, loss = model(input_idx, input_attn_mask, token_type_idx, targets)
            # backward pass
            loss.backward()
            # apply gradient step 
            if (i+1) % accumulation_steps == 0:
                # optimizer step
                optimizer.step()
                # reset gradients
                optimizer.zero_grad()
   
            avg_loss = 0.9* avg_loss + 0.1*loss.item()
            B, _ = input_idx.shape
            y_pred = logits.argmax(dim=-1).view(-1) # shape (B,)
            num_correct += (y_pred.eq(targets.view(-1))).sum().item()      
            num_total += B
            train_acc = num_correct / num_total     
               
            if val_every is not None:
                if (i+1)%val_every == 0:
                    # compute validation loss
                    val_loss, val_acc = validation(model, val_dataloader, device=device)

            pbar.set_description(f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train Accuracy: {train_acc: .3f}, Val Loss: {val_loss: .3f}, Val Accuracy: {val_acc: .3f}")  

            if log_metrics:
                metrics = {"Batch loss":loss.item(), "Moving Avg Loss":avg_loss, "Train Accuracy":train_acc, "Val Loss": val_loss, "Val Accuracy":val_acc}
                log_metrics(metrics)

        # run optimizer step for remainder batches
        if len(train_dataloader) % accumulation_steps != 0:
            optimizer.step()
            optimizer.zero_grad()

        if scheduler is not None:
            scheduler.step()

        if save_every is not None:
            if (epoch+1) % save_every == 0:
                save_classifier_model_checkpoint(model, optimizer, epoch, avg_loss)


def validation(model, val_dataloader, device="cpu"):
    model.eval()
    val_losses = torch.zeros(len(val_dataloader))
    with torch.no_grad():
        num_correct = 0
        num_total = 0
        for i,batch in enumerate(val_dataloader):
            input_idx, input_attn_mask, token_type_idx, targets = batch
            input_idx, input_attn_mask, token_type_idx, targets = input_idx.to(device), input_attn_mask.to(device), token_type_idx.to(device), targets.to(device)
            logits, loss = model(input_idx, input_attn_mask, token_type_idx, targets)
            B, _ = input_idx.shape
            y_pred = logits.argmax(dim=-1).view(-1) # shape (B,)
            num_correct += (y_pred.eq(targets.view(-1))).sum().item()      
            num_total += B
            val_losses[i] = loss.item()

    model.train()
    val_loss = val_losses.mean().item()
    val_accuracy = num_correct / num_total

    return val_loss, val_accuracy


def save_classifier_model_checkpoint(model, optimizer, epoch=None, loss=None, filename=None):
    # Save the model and optimizer state_dict
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    # Save the checkpoint to a file
    if filename:
        torch.save(checkpoint, filename)
    else:
        torch.save(checkpoint, 'classifier_checkpoint.pth')
    print(f"Saved classifier model checkpoint!")


def load_classifier_model_checkpoint(model, optimizer=None, filename=None):
    if filename:
        checkpoint = torch.load(filename)
    else:
        checkpoint = torch.load('classifier_checkpoint.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    print("Loaded classifier model from checkpoint!")
    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        model.train()
        return model, optimizer          
    else:
        return model
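
To see when `train()` actually updates the weights under gradient accumulation, the following sketch lists the batch indices after which `optimizer.step()` is called in one epoch. It mirrors the `(i+1) % accumulation_steps == 0` check and the end-of-epoch remainder step; `optimizer_step_batches` is a hypothetical helper written for illustration. The example call uses the 62 training batches and `accumulation_steps=20` that appear later in this notebook.

```python
def optimizer_step_batches(num_batches, accumulation_steps):
    """Return the 1-based batch indices after which an optimizer step happens in one epoch."""
    steps = [i + 1 for i in range(num_batches) if (i + 1) % accumulation_steps == 0]
    # Leftover batches at the end of the epoch get one extra step, as in train().
    if num_batches % accumulation_steps != 0:
        steps.append(num_batches)
    return steps

print(optimizer_step_batches(num_batches=62, accumulation_steps=20))
```

So each epoch performs four weight updates, the last one covering only two batches. Since the loss is not divided by `accumulation_steps` before `backward()`, the gradients of the accumulated batches are summed rather than averaged, which makes the effective step size depend on how many batches were accumulated.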

#### Prepare the Training and Validation Dataloaders

In [13]:
# create dataset
block_size = 128
batch_size = 20

train_dataset = Classifier_Dataset(train_data, knowledge_source, tokenizer, block_size=block_size)
val_dataset = Classifier_Dataset(val_data, knowledge_source, tokenizer, block_size=block_size)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=2)  # set pin_memory for faster pre-fetching
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=2)  # set pin_memory for faster pre-fetching 
print(f"Total number of training batches: {len(train_dataloader)}")
print(f"Total number of validation batches: {len(val_dataloader)}")
print(len(train_dataset), len(val_dataset))

Total number of training batches: 62
Total number of validation batches: 8
1228 154


#### Instantiate the Classifier BERT model, load pre-trained custom BERT

In [57]:
# BERT model hyperparameters
embedding_dim = 512
head_size = embedding_dim
num_heads = 16
num_layers = 8
dropout_rate = 0.2
num_epochs = 10
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu' 

bert = BERTModel(vocab_size=tokenizer.vocab_size(), block_size=block_size, embedding_dim=embedding_dim, head_size=head_size, num_heads=num_heads, num_layers=num_layers, pad_token_id=tokenizer.pad_token_id(), device=device)
bert = bert.to(device)

# load pretrained BERT model
bert, _, _ =  load_bert_model_checkpoint(bert, name="BERT_multitask_checkpoint_entaiment_claims_long_600_epochs", device=device, strict=False)

# create Classifier model
model = BERTClassifier(bert)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
model = model.to(device)

# learning rate scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=1e-8)

# load pretrained Classifier model
#model = load_classifier_model_checkpoint(model, filename='classifier_checkpoint_pretrain_basic.pth')

num_params = sum(p.numel() for p in model.parameters())
print(f"Device: {device}")
print(f"Total number of parameters in transformer network: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")


Loaded pretrained BERT model checkpoint at epoch 40 with loss 1.8120101885718962
Device: cuda
Total number of parameters in transformer network: 45.032273 M
RAM used: 5411.65 MB
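
The learning rates produced by the cosine schedule configured above can be worked out by hand. The sketch below assumes the closed form given in the PyTorch documentation for `CosineAnnealingLR`, with the notebook's settings (`learning_rate = 1e-3`, `T_max=5`, `eta_min=1e-8`) as defaults, and assumes the scheduler is stepped once per epoch as in `train()`; `cosine_annealing_lr` is a hypothetical helper, not a PyTorch function.

```python
import math

def cosine_annealing_lr(epoch, base_lr=1e-3, eta_min=1e-8, t_max=5):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

schedule = [cosine_annealing_lr(epoch) for epoch in range(6)]
for epoch, lr in enumerate(schedule):
    print(epoch, f"{lr:.3e}")
```

Under these assumptions, a 3-epoch fine-tuning run only uses the first three values of the schedule, so the learning rate never gets close to `eta_min`.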


In [9]:
"""
run = wandb.init(
    project="BERT Pretrain MLM", 
    config={
        "model": "Cross Encoder",
        "learning_rate": learning_rate, 
        "epochs": num_epochs,
        "batch_size": batch_size, 
        "corpus": "Climate Claims"},)   


def log_metrics(metrics):
    wandb.log(metrics)
"""


#### Finetune for 3 epochs

In [61]:
train(model, optimizer, train_dataloader, val_dataloader, scheduler=scheduler, device=device, num_epochs=3, accumulation_steps=20, val_every=60, save_every=None, log_metrics=None)

Epoch 1, EMA Train Loss: 1.166, Train Accuracy:  0.384, Val Loss:  1.240, Val Accuracy:  0.448: 100%|██████████| 62/62 [00:15<00:00,  4.08it/s]
Epoch 2, EMA Train Loss: 0.944, Train Accuracy:  0.577, Val Loss:  1.124, Val Accuracy:  0.487: 100%|██████████| 62/62 [00:15<00:00,  4.00it/s]
Epoch 3, EMA Train Loss: 0.793, Train Accuracy:  0.658, Val Loss:  1.122, Val Accuracy:  0.565: 100%|██████████| 62/62 [00:15<00:00,  4.02it/s]


#### Save model checkpoint to file

In [70]:
#save_classifier_model_checkpoint(model, optimizer, filename='classifier_checkpoint_pretrain_basic.pth')

Saved classifier model checkpoint!


# 3. Testing and Evaluation

#### Our main validation metric for evaluating the classifier is the classification accuracy on the Validation Claims. We will run the following 3 different **experiments**:

(1) Evaluate the classification accuracy by classifying the input sequence containing a claim paired with a concatenation of all its gold evidence passages: `[[CLS], claim, [SEP], concatenated gold evidences]`

(2) Evaluate the classification accuracy by classifying the input sequence containing a claim paired with a concatenation of the top-5 BM25 retrieved evidence passages: `[[CLS], claim, [SEP], concatenated top-5 BM25 evidences]`

(3) Evaluate the classification accuracy by classifying the input sequence containing a claim paired with a concatenation of the top-5 monoBERT re-ranked evidence passages: `[[CLS], claim, [SEP], concatenated top-5 monoBERT evidences]` 

We load the pre-computed BM25 and monoBERT ranked document ids from a pickle file.
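
The input layouts listed above can be made concrete with a small packing sketch. It follows the same steps as `encode_custom` (which also appends a closing `[SEP]` after the evidence), but uses plain lists and hypothetical token ids `CLS`, `SEP`, `PAD` instead of the real tokenizer, and it assumes the inputs have already been truncated to fit.

```python
CLS, SEP, PAD = 101, 102, 0   # hypothetical special-token ids, for illustration only

def pack_pair(claim_ids, evidence_ids, block_size):
    """Build [CLS] claim [SEP] evidence [SEP] plus padding, attention mask and segment ids."""
    tokens = [CLS] + claim_ids + [SEP] + evidence_ids + [SEP]
    assert len(tokens) <= block_size, "truncate the inputs before packing"
    pad_len = block_size - len(tokens)
    # 1 marks real tokens, 0 marks padding positions.
    attention_mask = [1] * len(tokens) + [0] * pad_len
    tokens = tokens + [PAD] * pad_len
    # Segment 0 covers [CLS], the claim and the first [SEP]; everything after is segment 1.
    first_sep = tokens.index(SEP)
    segment_ids = [0] * (first_sep + 1) + [1] * (block_size - first_sep - 1)
    return tokens, attention_mask, segment_ids

tokens, mask, segments = pack_pair([7, 8, 9], [21, 22], block_size=10)
print(tokens)
print(mask)
print(segments)
```

As in the dataset code, padding positions receive segment id 1; they are excluded by the attention mask instead.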

In [27]:
claim_int2label = {0:'SUPPORTS', 1:'REFUTES', 2:'NOT_ENOUGH_INFO', 3:'DISPUTED'}

# define a helper function for predicting the claim class for a given claim and evidence list
def predict_claim_class(val_claim_id, evidence_ids, model, device='cpu'):
    # get the claim sentence 
    claim_sentence = val_data[val_claim_id]['claim_text']
    # get the evidence sentences
    evidence_texts = [knowledge_source[evidence_id] for evidence_id in evidence_ids] 
    evidence_text = " ".join(evidence_texts)
    # encode the claim and evidence sentences
    input_idx, input_attn_mask, token_type_idx = val_dataset.encode_custom(claim_sentence, evidence_text)
    # classify the pair using the model
    model.eval()
    with torch.no_grad():
        logits, _ = model(input_idx.unsqueeze(0).to(device), input_attn_mask.unsqueeze(0).to(device), token_type_idx.unsqueeze(0).to(device))

    y_pred = logits.argmax(dim=-1).view(1) # shape (1,)
    # convert claim label from int to string
    claim_class = claim_int2label[y_pred.item()]
    return claim_class

#### Test the helper function to make sure it works correctly.

In [40]:
# pick a random claim from the validation set
val_claim_id = random.choice(list(val_data.keys()))
# get the evidence ids for the claim
evidence_ids = val_data[val_claim_id]['evidences']
# get the claim label
claim_label = val_data[val_claim_id]['claim_label']
# predict the claim class
predicted_claim_class = predict_claim_class(val_claim_id, evidence_ids, model, device=device)
# get the claim text
claim_text = val_data[val_claim_id]['claim_text']
# get the evidence texts
evidence_texts = [knowledge_source[evidence_id] for evidence_id in evidence_ids]
evidence_text = " ".join(evidence_texts)
print(f"Claim ID: {val_claim_id}")
print(f"Claim Text: {claim_text}")
print(f"Claim Label: {claim_label}")
print(f"Predicted Claim Class: {predicted_claim_class}")


Claim ID: claim-1292
Claim Text: Any reasonable person can recognize both positives and negatives among the policy proposals of both Tories and Labour.
Claim Label: NOT_ENOUGH_INFO
Predicted Claim Class: NOT_ENOUGH_INFO


#### **Experiment 1**: Evaluate the classification accuracy by classifying the input sequence containing a claim paired with a concatenation of all its gold evidence passages: `[[CLS], claim, [SEP], concatenated gold evidences]`

In [69]:
# predict claim classes for all claims in the validation set with gold evidences and compute accuracy
num_correct = 0
num_total = 0
pbar = tqdm(val_data.keys(), desc="Validation")
for val_claim_id in pbar:
    gold_evidence_ids = val_data[val_claim_id]['evidences']
    claim_label = val_data[val_claim_id]['claim_label']
    predicted_claim_class = predict_claim_class(val_claim_id, gold_evidence_ids, model, device=device)
    num_correct += (predicted_claim_class == claim_label)
    num_total += 1

accuracy = num_correct / num_total

print(f"\nExperiment 1 --> Validation Accuracy = {accuracy:.3f}")

Validation: 100%|██████████| 154/154 [00:05<00:00, 28.89it/s]


Experiment 1 --> Validation Accuracy = 0.506





#### **Experiment 2**: Evaluate the classification accuracy by classifying the input sequence containing a claim paired with a concatenation of the top-5 BM25 retrieved evidence passages: `[[CLS], claim, [SEP], concatenated top-5 BM25 evidences]`

In [30]:
# load precomputed topk BM25 and reranked documents from pickle file
with open("topk_reranker.pkl", "rb") as file:
    topk = pickle.load(file)

In [64]:
# get the top-5 BM25 evidences for each validation claim
top5_ev_ids_bm25 = {claim_id: ev_dict["bm25"][0][:5] for claim_id, ev_dict in topk.items()}
top5_ev_scores_bm25 = {claim_id: ev_dict["bm25"][1][:5] for claim_id, ev_dict in topk.items()}

# predict claim classes for all claims in the validation set with top-5 BM25 evidences and compute accuracy
num_correct = 0
num_total = 0
pbar = tqdm(val_data.keys(), desc="Validation")
for val_claim_id in pbar:
    bm25_evidence_ids = top5_ev_ids_bm25[val_claim_id]
    claim_label = val_data[val_claim_id]['claim_label']
    predicted_claim_class = predict_claim_class(val_claim_id, bm25_evidence_ids, model, device=device)
    num_correct += (predicted_claim_class == claim_label)
    num_total += 1

accuracy = num_correct / num_total

print(f"\nExperiment 2 --> Validation Accuracy = {accuracy:.3f}")


Validation: 100%|██████████| 154/154 [00:03<00:00, 42.01it/s]


Experiment 2 --> Validation Accuracy = 0.422
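
The dictionary comprehensions in the cell above imply a particular layout for the pickled rankings: each claim id maps to a dict keyed by ranker (`"bm25"` here, `"ce"` for the re-ranker), and each entry is a pair of parallel sequences (ranked evidence ids, scores). The sketch below illustrates top-k selection on that layout; the claim id, evidence ids, and scores are invented placeholders, and `top_k_ids` is a hypothetical helper.

```python
# Hypothetical stand-in for the contents of the pickle file (placeholder ids and scores).
topk_example = {
    "claim-1": {
        "bm25": (["ev-3", "ev-9", "ev-1", "ev-7", "ev-4", "ev-2"],
                 [9.1, 8.7, 8.0, 7.2, 6.5, 6.1]),
        "ce": (["ev-9", "ev-3", "ev-4", "ev-1", "ev-2", "ev-7"],
               [0.98, 0.91, 0.70, 0.55, 0.31, 0.12]),
    },
}

def top_k_ids(rankings_by_claim, ranker, k=5):
    """Keep the k best-ranked evidence ids per claim for the given ranker."""
    return {claim_id: rankings[ranker][0][:k]
            for claim_id, rankings in rankings_by_claim.items()}

bm25_top5 = top_k_ids(topk_example, "bm25")
ce_top5 = top_k_ids(topk_example, "ce")
print(bm25_top5["claim-1"])
print(ce_top5["claim-1"])
```

This relies on the id lists already being sorted from best to worst, which is assumed here rather than verified against the pickle file.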





#### **Experiment 3**: Evaluate the classification accuracy by classifying the input sequence containing a claim paired with a concatenation of the top-5 monoBERT re-ranked evidence passages: `[[CLS], claim, [SEP], concatenated top-5 monoBERT evidences]` 

In [65]:
# get the top-5 monoBERT evidences for each validation claim
top5_ev_ids_monoBERT = {claim_id: ev_dict["ce"][0][:5] for claim_id, ev_dict in topk.items()}
top5_ev_scores_monoBERT = {claim_id: ev_dict["ce"][1][:5] for claim_id, ev_dict in topk.items()}

# predict claim classes for all claims in the validation set with top-5 monoBERT evidences and compute accuracy
num_correct = 0
num_total = 0
pbar = tqdm(val_data.keys(), desc="Validation")
for val_claim_id in pbar:
    monoBERT_evidence_ids = top5_ev_ids_monoBERT[val_claim_id]
    claim_label = val_data[val_claim_id]['claim_label']
    predicted_claim_class = predict_claim_class(val_claim_id, monoBERT_evidence_ids, model, device=device)
    num_correct += (predicted_claim_class == claim_label)
    num_total += 1

accuracy = num_correct / num_total

print(f"\nExperiment 3 --> Validation Accuracy = {accuracy:.3f}")


Validation: 100%|██████████| 154/154 [00:06<00:00, 23.19it/s]


Experiment 3 --> Validation Accuracy = 0.435



