#### Claim Label Classification

Recall that the goal of this project is to design and implement a system that takes a given `claim`, retrieves one or more `evidence passages` from a document store, and then uses these evidence passages to classify the claim as one of four labels: `[SUPPORTS, REFUTES, NOT_ENOUGH_INFO, DISPUTED]`.

Our first step will be to train a simple BERT-based classifier that takes as input a sequence containing a `(claim, single evidence passage)` pair in the format `[CLS] claim text [SEP] evidence text [SEP]` and classifies it by passing the `[CLS]` output embedding to a softmax classifier. If a given claim `c` has multiple evidence passages `[e_1, e_2, .., e_n]`, we create separate input pairs `(c, e_1)`, .., `(c, e_n)`, all of which are assigned the claim's label. Then, at inference time, given a claim and its multiple evidence passages, we classify every pair `(c, e_i)` and take a majority vote over the predicted labels. Note that this model assumes the evidence passages for a claim independently determine its label, which may not hold when different passages interact in complex ways to determine the label.

The next step will be to aggregate all the evidence passages into a single passage and classify an input sequence containing `(claim, aggregated evidence passages)`. The simplest form of aggregation is to directly concatenate all the evidence passages into a single passage. However, because BERT imposes a maximum input sequence length, we may need to truncate these passages. This model can learn interactions between the different evidence passages, which may improve performance.


In [None]:
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel, DistilBertTokenizerFast
from collections import Counter
import csv
import random
from tqdm import tqdm
import psutil
import json
import os
from DPR_biencoder_simple import *
import wandb
from utils import *
import matplotlib.pyplot as plt


%load_ext autoreload
%autoreload 2

# Log in to Weights & Biases (wandb) — presumably for experiment tracking later in the notebook.
wandb.login()
# Sanity check: print whether torch can see a CUDA device.
# NOTE(review): `torch` is not imported explicitly in this cell; it presumably comes
# from one of the star imports above (DPR_biencoder_simple / utils) — confirm.
print(torch.cuda.is_available())

In [None]:
# first let's load the data and look at some examples and the label distribution
label_dict = {'SUPPORTS':0, 'REFUTES':1, 'NOT_ENOUGH_INFO':2, 'DISPUTED':3}
document_store, train_data, val_data = load_data(clean=True, clean_threshold=40)
# integer class id of every training claim
train_labels = [label_dict[claim['claim_label']] for claim in train_data.values()]

# Plot the label distribution as one named bar per class. plt.hist(..., bins=4) over the
# integer ids puts bin edges at 0, 0.75, 1.5, 2.25, 3, so the bars don't sit on the class
# ids and the axis shows numbers instead of label names.
label_names = sorted(label_dict, key=label_dict.get)  # label names ordered by class id
label_counts = Counter(train_labels)
plt.bar(label_names, [label_counts[label_dict[name]] for name in label_names])
plt.xlabel('Label')
plt.ylabel('Frequency')
plt.title('Training Label Distribution')
plt.show()

In [None]:
# show some examples of claims, evidence, and labels
# random.sample needs a real sequence (a dict view raises TypeError on Python 3.11+),
# and it raises ValueError if asked for more items than exist, so cap the sample size.
num_examples = min(10, len(train_data))
c = random.sample(list(train_data.items()), num_examples)
for claim_id, claim in c:
    print(f"{claim_id} --> {claim['claim_text']}")
    print("Evidences:")
    for ev in claim['evidences']:
        print(f"\t{document_store[ev]}")
    print(f"Claim Label: {claim['claim_label']}")

#### PyTorch Dataset

In [None]:
# Turn off the HuggingFace tokenizers library's internal parallelism via its environment variable.
os.environ.update(TOKENIZERS_PARALLELISM="false")

class ClaimsDatasetSingle(Dataset):
    """One example per (claim, evidence passage) pair, labelled with the claim's label.

    Each item is encoded as ``[CLS] claim [SEP] evidence [SEP] [PAD] ...``, padded
    to exactly ``block_size`` tokens.

    Args:
        claims_data: mapping claim_id -> claim dict; the keys read here are
            'claim_text', 'claim_label' and 'evidences' (a list of evidence ids).
        document_store: mapping evidence_id -> evidence passage text.
        label_dict: mapping label string -> integer class id.
        block_size: total encoded length, including special tokens and padding.
        hard_negatives: optional; stored on the instance but not used by this class.

    Note:
        Evidence that does not fit is cut to a *random* window on every
        ``__getitem__`` call, so repeated reads of the same index can differ.
    """

    def __init__(self, claims_data, document_store, label_dict, block_size=192, hard_negatives=None):
        self.claims_data = claims_data
        self.hard_negatives = hard_negatives
        self.document_store = document_store
        self.label_dict = label_dict
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.block_size = block_size
        self.claim_pairs = self.create_pairs()

    def create_pairs(self):
        """Return a list of (claim_id, evidence_id) tuples, one per listed evidence passage."""
        return [
            (claim_id, evidence_id)
            for claim_id, claim in self.claims_data.items()
            for evidence_id in claim['evidences']
        ]

    def __len__(self):
        return len(self.claim_pairs)

    def __getitem__(self, idx):
        """Return (input_idx, input_attn_mask, token_type_idx, target_label) tensors for pair ``idx``."""
        claim_id, evidence_id = self.claim_pairs[idx]
        claim = self.claims_data[claim_id]
        target_label = self.label_dict[claim['claim_label']]
        input_idx, input_attn_mask, token_type_idx = self.tokenize_and_encode(
            claim['claim_text'], self.document_store[evidence_id]
        )
        return input_idx, input_attn_mask, token_type_idx, torch.tensor(target_label)

    def tokenize_and_encode(self, claim_text, evidence_text):
        """Encode one claim/evidence pair into fixed-length BERT input tensors.

        If the claim alone would overflow ``block_size`` it is truncated (keeping
        its start). Evidence that doesn't fit in the remaining space is cut to a
        randomly positioned contiguous window.

        Returns:
            (input_idx, input_attn_mask, token_type_idx): 1-D tensors of length
            ``block_size``. Token type ids are 0 for ``[CLS] claim [SEP]``, 1 for
            ``evidence [SEP]`` and 0 for padding.

        Raises:
            ValueError: if the encoded sequence still exceeds ``block_size``
                (only possible when ``block_size < 3``).
        """
        tok = self.tokenizer
        encode_kwargs = dict(add_special_tokens=False, return_offsets_mapping=False,
                             return_attention_mask=False, return_token_type_ids=False)
        claim_idx = tok.encode_plus(claim_text, **encode_kwargs)['input_ids']
        evidence_idx = tok.encode_plus(evidence_text, **encode_kwargs)['input_ids']

        # Reserve 3 slots for [CLS] + 2x[SEP]; truncating an over-long claim keeps the
        # evidence budget non-negative instead of producing an oversize sequence.
        claim_idx = claim_idx[:max(0, self.block_size - 3)]
        max_evidence_size = max(0, self.block_size - len(claim_idx) - 3)
        if len(evidence_idx) > max_evidence_size:
            # random contiguous window of the evidence that fits the remaining space
            start_pos = random.randint(0, len(evidence_idx) - max_evidence_size)
            evidence_idx = evidence_idx[start_pos:start_pos + max_evidence_size]

        input_idx = [tok.cls_token_id] + claim_idx + [tok.sep_token_id] + evidence_idx + [tok.sep_token_id]
        if len(input_idx) > self.block_size:
            raise ValueError(f"Input sequence length {len(input_idx)} is longer than max_length {self.block_size}!")

        num_real = len(input_idx)
        num_pad = self.block_size - num_real
        claim_len = len(claim_idx) + 2  # [CLS] claim [SEP]
        # Mask and segment ids come from positions rather than comparing ids with
        # pad_token_id, so a real token that happens to equal the pad id stays visible.
        input_attn_mask = [1] * num_real + [0] * num_pad
        token_type_idx = [0] * claim_len + [1] * (num_real - claim_len) + [0] * num_pad
        input_idx = input_idx + [tok.pad_token_id] * num_pad

        # token type ids aren't needed for roberta
        return torch.tensor(input_idx), torch.tensor(input_attn_mask), torch.tensor(token_type_idx)



class ClaimsDatasetAggregate(Dataset):
    """One example per claim, pairing the claim with all of its evidence passages.

    The claim's evidence passages are concatenated into a single segment and encoded
    as ``[CLS] claim [SEP] evidence_1 evidence_2 ... [SEP] [PAD] ...`` with length
    ``block_size``. If the passages don't fit, each one is cut to a random window
    whose length is proportional to its share of the total evidence tokens.

    Args:
        claims_data: mapping claim_id -> claim dict; the keys read here are
            'claim_text', 'claim_label' and 'evidences' (a list of evidence ids).
        document_store: mapping evidence_id -> evidence passage text.
        label_dict: mapping label string -> integer class id.
        block_size: total encoded length, including special tokens and padding.
        hard_negatives: optional; stored on the instance but not used by this class.

    Note:
        Windowing is random on every ``__getitem__`` call, so repeated reads of the
        same index can differ when the evidence does not fit.
    """

    def __init__(self, claims_data, document_store, label_dict, block_size=192, hard_negatives=None):
        # Materialise as a list: __getitem__ indexes it, and dict views are not subscriptable.
        self.claims_data = list(claims_data.items())
        self.hard_negatives = hard_negatives
        self.document_store = document_store
        self.label_dict = label_dict
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.block_size = block_size

    def __len__(self):
        return len(self.claims_data)

    def __getitem__(self, idx):
        """Return (input_idx, input_attn_mask, token_type_idx, target_label) tensors for claim ``idx``."""
        _, claim = self.claims_data[idx]
        evidences_text = [self.document_store[evidence_id] for evidence_id in claim['evidences']]
        target_label = self.label_dict[claim['claim_label']]
        input_idx, input_attn_mask, token_type_idx = self.tokenize_and_encode(claim['claim_text'], evidences_text)
        return input_idx, input_attn_mask, token_type_idx, torch.tensor(target_label)

    def tokenize_and_encode(self, claim_text, evidences_text):
        """Encode a claim and its list of evidence texts into fixed-length BERT input tensors.

        If the claim alone would overflow ``block_size`` it is truncated (keeping its
        start). When the evidence passages don't fit in the remaining space, each
        passage keeps a random contiguous window sized in proportion to its length.

        Returns:
            (input_idx, input_attn_mask, token_type_idx): 1-D tensors of length
            ``block_size``. Token type ids are 0 for ``[CLS] claim [SEP]``, 1 for
            ``evidences [SEP]`` and 0 for padding.

        Raises:
            ValueError: if the encoded sequence still exceeds ``block_size``
                (only possible when ``block_size < 3``).
        """
        tok = self.tokenizer
        encode_kwargs = dict(add_special_tokens=False, return_offsets_mapping=False,
                             return_attention_mask=False, return_token_type_ids=False)
        claim_idx = tok.encode_plus(claim_text, **encode_kwargs)['input_ids']
        # a claim with no evidence ids yields an empty evidence segment
        evidence_idx = tok.batch_encode_plus(evidences_text, **encode_kwargs)['input_ids'] if evidences_text else []

        # Reserve 3 slots for [CLS] + 2x[SEP]; truncating an over-long claim keeps the
        # evidence budget non-negative instead of producing an oversize sequence.
        claim_idx = claim_idx[:max(0, self.block_size - 3)]
        max_evidence_size = max(0, self.block_size - len(claim_idx) - 3)
        total_evidence_length = sum(len(evidence) for evidence in evidence_idx)
        if total_evidence_length > max_evidence_size:
            windowed_evidence_idx = []
            for evidence in evidence_idx:
                # floor keeps the sum of window lengths <= max_evidence_size
                desired_length = max_evidence_size * len(evidence) // total_evidence_length
                start_pos = random.randint(0, max(0, len(evidence) - desired_length))
                windowed_evidence_idx.append(evidence[start_pos:start_pos + desired_length])
            evidence_idx = windowed_evidence_idx

        # concatenate the (possibly windowed) evidences into one segment
        evidence_idx = [idx for evidence in evidence_idx for idx in evidence]

        input_idx = [tok.cls_token_id] + claim_idx + [tok.sep_token_id] + evidence_idx + [tok.sep_token_id]
        if len(input_idx) > self.block_size:
            raise ValueError(f"Input sequence length {len(input_idx)} is longer than max_length {self.block_size}!")

        num_real = len(input_idx)
        num_pad = self.block_size - num_real
        claim_len = len(claim_idx) + 2  # [CLS] claim [SEP]
        # Mask and segment ids come from positions rather than comparing ids with
        # pad_token_id, so a real token that happens to equal the pad id stays visible.
        input_attn_mask = [1] * num_real + [0] * num_pad
        token_type_idx = [0] * claim_len + [1] * (num_real - claim_len) + [0] * num_pad
        input_idx = input_idx + [tok.pad_token_id] * num_pad

        # token type ids aren't needed for roberta
        return torch.tensor(input_idx), torch.tensor(input_attn_mask), torch.tensor(token_type_idx)

#### Classifier Model

In [None]:
class ClaimClassifier(torch.nn.Module):
    """Pretrained BERT encoder with a linear classification head on the [CLS] embedding.

    Args:
        dropout_rate: dropout probability applied to the [CLS] embedding.
        num_classes: number of output classes (default 4: SUPPORTS, REFUTES,
            NOT_ENOUGH_INFO, DISPUTED).
    """

    def __init__(self, dropout_rate=0.1, num_classes=4):
        super().__init__()
        # load pretrained BERT model
        self.bert_encoder = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = torch.nn.Dropout(dropout_rate)
        # size the head from the encoder's config rather than hard-coding 768
        self.classifier_head = torch.nn.Linear(self.bert_encoder.config.hidden_size, num_classes)
        # make sure BERT parameters are trainable (the whole encoder is fine-tuned)
        for param in self.bert_encoder.parameters():
            param.requires_grad = True

    def forward(self, input_idx, input_attn_mask, token_type_idx, targets=None):
        """Classify a batch of encoded sequences.

        Args:
            input_idx: token ids, passed to BERT as ``input_ids``.
            input_attn_mask: attention mask, passed as ``attention_mask``.
            token_type_idx: segment ids, passed as ``token_type_ids``.
            targets: optional integer class ids; when given, the loss is computed.

        Returns:
            (logits, loss): logits of shape (batch_size, num_classes); loss is the
            mean cross-entropy over the batch, or None when ``targets`` is None.
        """
        bert_output = self.bert_encoder(input_idx, attention_mask=input_attn_mask, token_type_ids=token_type_idx)
        # Raw final-layer hidden state at position 0 ([CLS]); BERT's pooler output is not used.
        cls_embedding = self.dropout(bert_output.last_hidden_state[:, 0])  # (batch_size, hidden_size)
        logits = self.classifier_head(cls_embedding)  # (batch_size, num_classes)
        loss = None
        if targets is not None:
            # Call torch.nn.functional directly instead of relying on an `F` alias,
            # which is not imported explicitly in this file.
            loss = torch.nn.functional.cross_entropy(logits, targets)
        return logits, loss