# 2024 COMP90042 Project

# Readme

#### **BERT Multitask Pretraining**:

In this notebook, we pre-train out custom BERT model jointly on the three tasks: (i) Masked Language Modeling (ii) Sentence entailment (iii) Claim Classification 

We use claims paired with gold passage as positive entailment samples and claims paired with randomly chosen evidences from the knowledge source as negatvie samples.


*** **PLEASE NOTE**: We import helper functions that we implemented for pre-processing/cleaning our data from the python script called `utils.py`. Our ***custom BERT model implementation*** is contained in the python script called `min_bert_multi.py`. And, our **custom WordPiece Tokenizer** implementation is contained is the python script `wordpiece_tokenizer.py`.

In [1]:
%load_ext autoreload
%autoreload 2

# install required packages
!pip install unidecode
!python -m nltk.downloader stopwords
!pip install wandb

from utils import *
from wordpiece_tokenizer import *
from min_bert_multi import *

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.nn as nn

from tqdm import tqdm
import pickle 
import wandb
import psutil
import random

#wandb.login()
print(torch.cuda.is_available())

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtanzids[0m. Use [1m`wandb login --relogin`[0m to force relogin


True


# 1.DataSet Processing

#### Load the Claims Dataset with Knowledge Source and Clean the Text

In [2]:
# load dataset and prepare corpus
knowledge_source, train_data, val_data = load_dataset()      
print(f"Number of evidence passages: {len(knowledge_source)}")
print(f"Number of training instances: {len(train_data)}")  
print(f"Number of validation instances: {len(val_data)}")

# clean all senteneces in the dataset (this involves converting from unicode to asc-ii, removing URLS, removing repeating non-alphanumeric characters, etc. Just a bunch of thing that are most likely will not be useful for claim classification task)
cleaner = SentenceCleaner()
knowledge_source, train_data, val_data = cleaner.clean_dataset(knowledge_source, train_data, val_data)
print(f"\nNumber of evidence passages after cleaning: {len(knowledge_source)}")
print(f"Number of training instances after cleaning: {len(train_data)}")  
print(f"Number of validation instances after cleaning: {len(val_data)}")

# dictionary for mapping integer to document id 
int2docID = {i:evidence_id for i,evidence_id in enumerate(list(knowledge_source.keys()))}

claim_ids = [claim_id for claim_id in train_data.keys()]

# load trained wordpiece tokenizer from file
with open('tokenizer_worpiece_20000_aug.pkl', 'rb') as f:
    tokenizer = pickle.load(f)

# load hard negatives from file if available
#with open("hard_negatives_2.pkl", "rb") as file:
#    hard_negatives = pickle.load(file)  
hard_negatives = None


Number of evidence passages: 1208827
Number of training instances: 1228
Number of validation instances: 154

Number of evidence passages after cleaning: 1206800
Number of training instances after cleaning: 1228
Number of validation instances after cleaning: 154


#### Set up the Multi-Task Pre-Training Dataset.

We have two different types of datasets. In the first dataset `BERTDataset`, each instance is a claim paired with a single evidence passage. In the second dataset `BERTDataset_Long`, each claim is paired with a concatenation of multiple evidence passages.

In [3]:
# pre-training dataset 1
class BERTDataset(Dataset):
    def __init__(self, claims_data, document_store, tokenizer, block_size, hard_negatives=None, mlm_prob=0.20, mask_fraction=0.8, span_mask=False, max_span_len=2):
        self.claims_data = claims_data          # corpus sentences
        self.document_store = document_store    # document store
        self.hard_negatives = hard_negatives    # hard negatives
        self.tokenizer = tokenizer    # wordpiece tokenizer
        self.block_size = block_size  # truncation/max length of sentences
        self.mlm_prob = mlm_prob
        self.mask_fraction = mask_fraction
        self.vocab_size = tokenizer.vocab_size()
        self.span_mask = span_mask
        self.max_span_len = max_span_len
        self.negative_label = 0
        self.positive_label = 1
        self.claim_label2int = {'SUPPORTS':0, 'REFUTES':1, 'NOT_ENOUGH_INFO':2, 'DISPUTED':3}
        self.document_ids = list(document_store.keys())
        self.sent_pairs = self.create_pairs()        
        

    # create positive and negative sentence entailment pairs    
    def create_pairs(self):
        sent_pairs = []
        for claim_id in self.claims_data.keys():
            # get claim label
            claim_label = self.claim_label2int[self.claims_data[claim_id]['claim_label']]
            # shuffle the gold evidences list
            gold_evidences = self.claims_data[claim_id]['evidences']
            random.shuffle(gold_evidences)
            # create a positive entailment pair between each claim and its gold evidences evidence and each possible pair of gold evidences
            for i, evidence_id_1 in enumerate(gold_evidences):
                # randomize the order of claim and evidence in the pair
                if random.random() > 0.5:
                    sent_pairs.append((claim_id, evidence_id_1, self.positive_label, claim_label))  
                else:
                    sent_pairs.append((evidence_id_1, claim_id, self.positive_label, claim_label))
                for evidence_id_2 in gold_evidences[i+1:]:
                    if random.random() > 0.5:
                        sent_pairs.append((evidence_id_1, evidence_id_2, self.positive_label, claim_label))
                    else:
                        sent_pairs.append((evidence_id_2, evidence_id_1, self.positive_label, claim_label))
                
                # create a negative pair with the claim and a randomly chosen document, assign NOT_ENOUGH_INFO claim label to negative pairs
                negative_ev_id = random.choice(self.document_ids)
                if random.random() > 0.5:
                    sent_pairs.append((claim_id, negative_ev_id, self.negative_label, self.claim_label2int['NOT_ENOUGH_INFO']))  
                else:
                    sent_pairs.append((negative_ev_id, claim_id, self.negative_label, self.claim_label2int['NOT_ENOUGH_INFO']))
                
                # also sample a hard negative
                if self.hard_negatives is not None:
                    hard_negative_id = random.choice(self.hard_negatives[claim_id])
                    if random.random() > 0.5:
                        sent_pairs.append((claim_id, hard_negative_id, self.negative_label, self.claim_label2int['NOT_ENOUGH_INFO']))  
                    else:
                        sent_pairs.append((hard_negative_id, claim_id, self.negative_label, self.claim_label2int['NOT_ENOUGH_INFO']))

                # create a negative pair with each gold evidence with a randomly chosen document
                for evidence_id in gold_evidences:
                    negative_ev_id = random.choice(self.document_ids)
                    if random.random() > 0.5:
                        sent_pairs.append((evidence_id, negative_ev_id, self.negative_label, self.claim_label2int['NOT_ENOUGH_INFO']))  
                    else:
                        sent_pairs.append((negative_ev_id, evidence_id, self.negative_label, self.claim_label2int['NOT_ENOUGH_INFO']))

                # create a negative pair with two random documents
                negative_ev_id_1 = random.choice(self.document_ids)
                negative_ev_id_2 = random.choice(self.document_ids)
                sent_pairs.append((negative_ev_id_1, negative_ev_id_2, self.negative_label, self.claim_label2int['NOT_ENOUGH_INFO']))        

        # shuffle the sentence pairs 
        random.shuffle(sent_pairs)                
        return sent_pairs
    

    def __len__(self):
        return len(self.sent_pairs)

    def get_pair_text(self, idx):
         # get sentence ids and entailment label
        sent_1_id, sent_2_id, entailment_label, claim_label = self.sent_pairs[idx]
        # get the actual sentences based on ids
        sent_1_type = sent_1_id.split("-")[0]
        sent_2_type = sent_2_id.split("-")[0]
        if sent_1_type == "claim":
            sent_1 = self.claims_data[sent_1_id]['claim_text']
        else:
            sent_1 = self.document_store[sent_1_id]
        if sent_2_type == "claim":
            sent_2 = self.claims_data[sent_2_id]['claim_text']
        else:
            sent_2 = self.document_store[sent_2_id]            
        return sent_1, sent_2, entailment_label, claim_label


    def __getitem__(self, idx):
        # get sentence ids and entailment label
        sent_1_id, sent_2_id, entailment_label, claim_label = self.sent_pairs[idx]
        # get the actual sentences based on ids
        sent_1_type = sent_1_id.split("-")[0]
        sent_2_type = sent_2_id.split("-")[0]
        if sent_1_type == "claim":
            sent_1 = self.claims_data[sent_1_id]['claim_text']
        else:
            sent_1 = self.document_store[sent_1_id]
        if sent_2_type == "claim":
            sent_2 = self.claims_data[sent_2_id]['claim_text']
        else:
            sent_2 = self.document_store[sent_2_id]            

        # encode both sentence using the tokenizer
        s1_idx = self.tokenizer.encode([sent_1])[0]
        s2_idx = self.tokenizer.encode([sent_2])[0]

        # check if combined length is within block_size-2
        if len(s1_idx) + len(s2_idx) + 3 > self.block_size:
            # calculate the space available for each sentence
            available_space = (self.block_size - 3) // 2
            if len(s1_idx) < available_space:
                # if s1 is shorter than available space, allocate the remaining space to s2
                available_space_s2 = self.block_size - 3 - len(s1_idx)
                if len(s2_idx) > available_space_s2:
                    # if s2 is longer than the available space, crop it
                    start = random.randint(0, len(s2_idx) - available_space_s2)
                    s2_idx = s2_idx[start:start+available_space_s2]
            elif len(s2_idx) < available_space:
                # if s2 is shorter than available space, allocate the remaining space to s1
                available_space_s1 = self.block_size - 3 - len(s2_idx)
                if len(s1_idx) > available_space_s1:
                    # if s1 is longer than the available space, crop it
                    start = random.randint(0, len(s1_idx) - available_space_s1)
                    s1_idx = s1_idx[start:start+available_space_s1]
            else:
                # if both sentences are longer than available space, crop both
                start = random.randint(0, len(s1_idx) - available_space)
                s1_idx = s1_idx[start:start+available_space]
                start = random.randint(0, len(s2_idx) - available_space)
                s2_idx = s2_idx[start:start+available_space]

        # apply masking
        if self.span_mask:
            s1_idx, MLM_label_s1 = self.replace_tokens_span(s1_idx)
            s2_idx, MLM_label_s2 = self.replace_tokens_span(s2_idx)
        else:
            s1_idx, MLM_label_s1 = self.replace_tokens(s1_idx)
            s2_idx, MLM_label_s2 = self.replace_tokens(s2_idx)


        # combine the two sentences with a separator token into single sequence        
        s = torch.cat([torch.tensor([self.tokenizer.cls_token_id()]), s1_idx, torch.tensor([self.tokenizer.sep_token_id()]), s2_idx, torch.tensor([self.tokenizer.sep_token_id()])])
        MLM_label = torch.cat([torch.tensor([-100]), MLM_label_s1, torch.tensor([-100]), MLM_label_s2, torch.tensor([-100])])
        # apply padding
        pad_len = max(0,self.block_size-len(s))
        s = torch.cat([s, torch.full((pad_len,), self.tokenizer.pad_token_id())])
        MLM_label = torch.cat([MLM_label,  torch.full((pad_len,),-100)])    
        attention_mask = torch.cat([torch.ones(self.block_size-pad_len), torch.zeros(pad_len)])
        entailment_label = torch.tensor(entailment_label)
        claim_label = torch.tensor(claim_label)

        # create segment ids
        segment_ids = torch.zeros(self.block_size)
        sep_idx = (s == self.tokenizer.sep_token_id()).nonzero(as_tuple=False)
        segment_ids[sep_idx[0]+1:] = 1
        segment_ids = segment_ids.long()

        return {"masked_input" : s, "MLM_label" : MLM_label, "entailment_label" : entailment_label, "claim_label": claim_label, "attention_mask" : attention_mask, "segment_ids" : segment_ids}


    def on_epoch_end(self):
        self.sent_pairs = self.create_pairs()

    # function for encoding a custom out of corpus sentence
    def encode_custom(self, sent_1, sent_2):
       # encode both sentence using the tokenizer
        s1_idx = self.tokenizer.encode([sent_1])[0]
        s2_idx = self.tokenizer.encode([sent_2])[0]

        # check if combined length is within block_size-2
        if len(s1_idx) + len(s2_idx) + 3 > self.block_size:
            # calculate the space available for each sentence
            available_space = (self.block_size - 3) // 2
            if len(s1_idx) < available_space:
                # if s1 is shorter than available space, allocate the remaining space to s2
                available_space_s2 = self.block_size - 3 - len(s1_idx)
                if len(s2_idx) > available_space_s2:
                    # if s2 is longer than the available space, crop it
                    start = random.randint(0, len(s2_idx) - available_space_s2)
                    s2_idx = s2_idx[start:start+available_space_s2]
            elif len(s2_idx) < available_space:
                # if s2 is shorter than available space, allocate the remaining space to s1
                available_space_s1 = self.block_size - 3 - len(s2_idx)
                if len(s1_idx) > available_space_s1:
                    # if s1 is longer than the available space, crop it
                    start = random.randint(0, len(s1_idx) - available_space_s1)
                    s1_idx = s1_idx[start:start+available_space_s1]
            else:
                # if both sentences are longer than available space, crop both
                start = random.randint(0, len(s1_idx) - available_space)
                s1_idx = s1_idx[start:start+available_space]
                start = random.randint(0, len(s2_idx) - available_space)
                s2_idx = s2_idx[start:start+available_space]

        # apply masking
        if self.span_mask:
            s1_idx, MLM_label_s1 = self.replace_tokens_span(s1_idx)
            s2_idx, MLM_label_s2 = self.replace_tokens_span(s2_idx)
        else:
            s1_idx, MLM_label_s1 = self.replace_tokens(s1_idx)
            s2_idx, MLM_label_s2 = self.replace_tokens(s2_idx)


        # combine the two sentences with a separator token into single sequence        
        s = torch.cat([torch.tensor([self.tokenizer.cls_token_id()]), s1_idx, torch.tensor([self.tokenizer.sep_token_id()]), s2_idx, torch.tensor([self.tokenizer.sep_token_id()])])
        MLM_label = torch.cat([torch.tensor([-100]), MLM_label_s1, torch.tensor([-100]), MLM_label_s2, torch.tensor([-100])])
        # apply padding
        pad_len = max(0,self.block_size-len(s))
        s = torch.cat([s, torch.full((pad_len,), self.tokenizer.pad_token_id())])
        MLM_label = torch.cat([MLM_label,  torch.full((pad_len,),-100)])    
        attention_mask = torch.cat([torch.ones(self.block_size-pad_len), torch.zeros(pad_len)])

        # create segment ids
        segment_ids = torch.zeros(self.block_size)
        sep_idx = (s == self.tokenizer.sep_token_id()).nonzero(as_tuple=False)
        segment_ids[sep_idx[0]+1:] = 1
        segment_ids = segment_ids.long()

        return {"masked_input" : s, "MLM_label" : MLM_label, "attention_mask" : attention_mask, "segment_ids" : segment_ids}

    # randomly replace tokens with mlm_prob probability
    def replace_tokens(self, s):
        # the labels for a masked token is the original token index and -100 for non-masked tokens
        s = torch.tensor(s, dtype=torch.long)
        label = torch.full_like(s, -100)
        # select tokens to be masked with mlm_prob probability
        mask = torch.rand(len(s)) < self.mlm_prob # mask for replacing tokens
        selected_idx = mask.nonzero(as_tuple=False).flatten()
        
        # mask_fraction percenttage of the selected tokens are replaced with mask tokens
        num_masked = int(self.mask_fraction * len(selected_idx))
        # half of the remaining selected tokens are replaced with random tokens, remaining half will be untouched
        num_replaced = int(0.5*(len(selected_idx) - num_masked))

        # shuffle the selcted indices
        shuffled_idx = selected_idx[torch.randperm(len(selected_idx))]
        # make a copy of original tokens before masking
        s_original = s[shuffled_idx].clone()
        # replace with mask tokens
        mask_idx = shuffled_idx[:num_masked]
        s[mask_idx] = self.tokenizer.mask_token_id()
        # replace with random tokens
        replace_idx = shuffled_idx[num_masked:num_masked+num_replaced]
        s[replace_idx] = torch.randint(self.vocab_size, (num_replaced,))
        
        # set labels
        label[shuffled_idx] = s_original
        return s, label
    
    # span masking
    def replace_tokens_span(self, s):
        # the labels for a masked token is the original token index and -100 for non-masked tokens
        s = torch.tensor(s, dtype=torch.long)
        label = torch.full_like(s, -100)
        # determine the number of spans to mask, ensuring at least one span is selected
        avg_span_length = (1+self.max_span_len)/2 # expected span length
        num_spans = max(1, int(self.mlm_prob * len(s) / avg_span_length))
        # generate all possible start positions
        all_starts = torch.arange(0, len(s))
        # shuffle the start positions
        perm = torch.randperm(all_starts.nelement())
        all_starts = all_starts.view(-1)[perm].view(all_starts.size())
        # select the first num_spans positions
        span_starts = all_starts[:num_spans]
        # sort the selected positions
        span_starts, _ = torch.sort(span_starts)

        for i, start in enumerate(span_starts):
            # select a random span length between 1 and 3
            span_length = torch.randint(1, self.max_span_len+1, (1,)).item()
            end = start + span_length
            # if there's a next start position and the current end position is greater, adjust the span length
            if i < num_spans - 1 and end > span_starts[i+1]:
                end = span_starts[i+1]
                span_length = end - start
            # if the number of tokens is less than the span length, adjust the span length
            if end > len(s):
                end = len(s)
                span_length = end - start
            # make a copy of original tokens before masking
            s_original = s[start:end].clone()
            # generate a random number to decide whether to mask or replace the span
            rand_num = torch.rand(1).item()
            if rand_num < self.mask_fraction:
                # replace with mask tokens
                s[start:end] = self.tokenizer.mask_token_id()
            else:
                # replace with random tokens
                s[start:end] = torch.randint(self.vocab_size, (span_length,))
            # set labels
            label[start:end] = s_original

        return s, label


# pre-training dataset 2
class BERTDataset_Long(Dataset):
    def __init__(self, claims_data, document_store, tokenizer, block_size, hard_negatives=None, mlm_prob=0.20, mask_fraction=0.8, span_mask=False, max_span_len=2):
        self.claims_data = claims_data          # corpus sentences
        self.document_store = document_store    # document store
        self.hard_negatives = hard_negatives    # hard negatives
        self.tokenizer = tokenizer    # wordpiece tokenizer
        self.block_size = block_size  # truncation/max length of sentences
        self.mlm_prob = mlm_prob
        self.mask_fraction = mask_fraction
        self.vocab_size = tokenizer.vocab_size()
        self.span_mask = span_mask
        self.max_span_len = max_span_len
        self.negative_label = 0
        self.positive_label = 1
        self.claim_label2int = {'SUPPORTS':0, 'REFUTES':1, 'NOT_ENOUGH_INFO':2, 'DISPUTED':3}
        self.document_ids = list(document_store.keys())
        self.sent_pairs = self.create_pairs()        
        

    # create positive and negative sentence entailment pairs    
    def create_pairs(self):
        sent_pairs = []
        for claim_id in self.claims_data.keys():
            # get claim label
            claim_label = self.claim_label2int[self.claims_data[claim_id]['claim_label']]
            gold_evidences = self.claims_data[claim_id]['evidences']
            
            # create positive example with claim and all its gold evidences
            sent_pairs.append(([claim_id],gold_evidences, self.positive_label, claim_label))
            
            # create a positive example with only gold evidences if there are more than 1
            if len(gold_evidences) > 1:
                sent_pairs.append((gold_evidences[0:1], gold_evidences[1:], self.positive_label, claim_label))
            
            # create a negative example with claim and 3 hard negatives
            if self.hard_negatives is not None:
                hard_negatives = random.sample(self.hard_negatives[claim_id], 3)
                sent_pairs.append(([claim_id], hard_negatives, self.negative_label, self.claim_label2int['NOT_ENOUGH_INFO']))
            
            # create a negative example with each gold evidence and 3 random documents
            for evidence_id in gold_evidences:
                negs = random.sample(self.document_ids, 3)
                sent_pairs.append(([evidence_id], negs, self.negative_label, self.claim_label2int['NOT_ENOUGH_INFO']))

            # create a negative example with 4 random documents
            negs = random.sample(self.document_ids, 4)
            sent_pairs.append((negs[0:1], negs[1:], self.negative_label, self.claim_label2int['NOT_ENOUGH_INFO']))

        # shuffle the sentence pairs 
        random.shuffle(sent_pairs)                
        return sent_pairs
    

    def __len__(self):
        return len(self.sent_pairs)

    def get_pair_text(self, idx):
        # get sentence ids and entailment label
        sent_1_id, sent_2_ids, entailment_label, claim_label = self.sent_pairs[idx]

        # get segment 1 text
        sent_1_id = sent_1_id[0]    
        sent_1_type = sent_1_id.split("-")[0]
        if sent_1_type == "claim":
            sent_1_text = self.claims_data[sent_1_id]['claim_text']
        else:
            sent_1_text = self.document_store[sent_1_id]

        # get segment 2 text
        sent_2_texts = [self.document_store[sent_2_id] for sent_2_id in sent_2_ids]
        sent_2_text = " ".join(sent_2_texts)      
       
        return sent_1_text, sent_2_text, entailment_label, claim_label


    def __getitem__(self, idx):
        # get sentence ids and entailment label
        sent_1_id, sent_2_ids, entailment_label, claim_label = self.sent_pairs[idx]

        # get segment 1 text
        sent_1_id = sent_1_id[0]    
        sent_1_type = sent_1_id.split("-")[0]
        if sent_1_type == "claim":
            sent_1_text = self.claims_data[sent_1_id]['claim_text']
        else:
            sent_1_text = self.document_store[sent_1_id]

        # get segment 2 text
        sent_2_texts = [self.document_store[sent_2_id] for sent_2_id in sent_2_ids]
        sent_2_text = " ".join(sent_2_texts)   

        # encode both sentence using the tokenizer
        s1_idx = self.tokenizer.encode([sent_1_text])[0]
        s2_idx = self.tokenizer.encode([sent_2_text])[0]

        # check if combined length is within block_size-2
        if len(s1_idx) + len(s2_idx) + 3 > self.block_size:
            # calculate the space available for each sentence
            available_space = (self.block_size - 3) // 2
            if len(s1_idx) < available_space:
                # if s1 is shorter than available space, allocate the remaining space to s2
                available_space_s2 = self.block_size - 3 - len(s1_idx)
                if len(s2_idx) > available_space_s2:
                    # if s2 is longer than the available space, crop it
                    start = random.randint(0, len(s2_idx) - available_space_s2)
                    s2_idx = s2_idx[start:start+available_space_s2]
            elif len(s2_idx) < available_space:
                # if s2 is shorter than available space, allocate the remaining space to s1
                available_space_s1 = self.block_size - 3 - len(s2_idx)
                if len(s1_idx) > available_space_s1:
                    # if s1 is longer than the available space, crop it
                    start = random.randint(0, len(s1_idx) - available_space_s1)
                    s1_idx = s1_idx[start:start+available_space_s1]
            else:
                # if both sentences are longer than available space, crop both
                start = random.randint(0, len(s1_idx) - available_space)
                s1_idx = s1_idx[start:start+available_space]
                start = random.randint(0, len(s2_idx) - available_space)
                s2_idx = s2_idx[start:start+available_space]

        # apply masking
        if self.span_mask:
            s1_idx, MLM_label_s1 = self.replace_tokens_span(s1_idx)
            s2_idx, MLM_label_s2 = self.replace_tokens_span(s2_idx)
        else:
            s1_idx, MLM_label_s1 = self.replace_tokens(s1_idx)
            s2_idx, MLM_label_s2 = self.replace_tokens(s2_idx)


        # combine the two sentences with a separator token into single sequence        
        s = torch.cat([torch.tensor([self.tokenizer.cls_token_id()]), s1_idx, torch.tensor([self.tokenizer.sep_token_id()]), s2_idx, torch.tensor([self.tokenizer.sep_token_id()])])
        MLM_label = torch.cat([torch.tensor([-100]), MLM_label_s1, torch.tensor([-100]), MLM_label_s2, torch.tensor([-100])])
        # apply padding
        pad_len = max(0,self.block_size-len(s))
        s = torch.cat([s, torch.full((pad_len,), self.tokenizer.pad_token_id())])
        MLM_label = torch.cat([MLM_label,  torch.full((pad_len,),-100)])    
        attention_mask = torch.cat([torch.ones(self.block_size-pad_len), torch.zeros(pad_len)])
        entailment_label = torch.tensor(entailment_label)
        claim_label = torch.tensor(claim_label)

        # create segment ids
        segment_ids = torch.zeros(self.block_size)
        sep_idx = (s == self.tokenizer.sep_token_id()).nonzero(as_tuple=False)
        segment_ids[sep_idx[0]+1:] = 1
        segment_ids = segment_ids.long()

        return {"masked_input" : s, "MLM_label" : MLM_label, "entailment_label" : entailment_label, "claim_label": claim_label, "attention_mask" : attention_mask, "segment_ids" : segment_ids}

    def on_epoch_end(self):
        self.sent_pairs = self.create_pairs()

    # function for encoding a custom out of corpus sentence
    def encode_custom(self, sent_1, sent_2):
       # encode both sentence using the tokenizer
        s1_idx = self.tokenizer.encode([sent_1])[0]
        s2_idx = self.tokenizer.encode([sent_2])[0]

        # check if combined length is within block_size-2
        if len(s1_idx) + len(s2_idx) + 3 > self.block_size:
            # calculate the space available for each sentence
            available_space = (self.block_size - 3) // 2
            if len(s1_idx) < available_space:
                # if s1 is shorter than available space, allocate the remaining space to s2
                available_space_s2 = self.block_size - 3 - len(s1_idx)
                if len(s2_idx) > available_space_s2:
                    # if s2 is longer than the available space, crop it
                    start = random.randint(0, len(s2_idx) - available_space_s2)
                    s2_idx = s2_idx[start:start+available_space_s2]
            elif len(s2_idx) < available_space:
                # if s2 is shorter than available space, allocate the remaining space to s1
                available_space_s1 = self.block_size - 3 - len(s2_idx)
                if len(s1_idx) > available_space_s1:
                    # if s1 is longer than the available space, crop it
                    start = random.randint(0, len(s1_idx) - available_space_s1)
                    s1_idx = s1_idx[start:start+available_space_s1]
            else:
                # if both sentences are longer than available space, crop both
                start = random.randint(0, len(s1_idx) - available_space)
                s1_idx = s1_idx[start:start+available_space]
                start = random.randint(0, len(s2_idx) - available_space)
                s2_idx = s2_idx[start:start+available_space]

        # apply masking
        if self.span_mask:
            s1_idx, MLM_label_s1 = self.replace_tokens_span(s1_idx)
            s2_idx, MLM_label_s2 = self.replace_tokens_span(s2_idx)
        else:
            s1_idx, MLM_label_s1 = self.replace_tokens(s1_idx)
            s2_idx, MLM_label_s2 = self.replace_tokens(s2_idx)


        # combine the two sentences with a separator token into single sequence        
        s = torch.cat([torch.tensor([self.tokenizer.cls_token_id()]), s1_idx, torch.tensor([self.tokenizer.sep_token_id()]), s2_idx, torch.tensor([self.tokenizer.sep_token_id()])])
        MLM_label = torch.cat([torch.tensor([-100]), MLM_label_s1, torch.tensor([-100]), MLM_label_s2, torch.tensor([-100])])
        # apply padding
        pad_len = max(0,self.block_size-len(s))
        s = torch.cat([s, torch.full((pad_len,), self.tokenizer.pad_token_id())])
        MLM_label = torch.cat([MLM_label,  torch.full((pad_len,),-100)])    
        attention_mask = torch.cat([torch.ones(self.block_size-pad_len), torch.zeros(pad_len)])

        # create segment ids
        segment_ids = torch.zeros(self.block_size)
        sep_idx = (s == self.tokenizer.sep_token_id()).nonzero(as_tuple=False)
        segment_ids[sep_idx[0]+1:] = 1
        segment_ids = segment_ids.long()

        return {"masked_input" : s, "MLM_label" : MLM_label, "attention_mask" : attention_mask, "segment_ids" : segment_ids}

    # randomly replace tokens with mlm_prob probability
    def replace_tokens(self, s):
        # the labels for a masked token is the original token index and -100 for non-masked tokens
        s = torch.tensor(s, dtype=torch.long)
        label = torch.full_like(s, -100)
        # select tokens to be masked with mlm_prob probability
        mask = torch.rand(len(s)) < self.mlm_prob # mask for replacing tokens
        selected_idx = mask.nonzero(as_tuple=False).flatten()
        
        # mask_fraction percenttage of the selected tokens are replaced with mask tokens
        num_masked = int(self.mask_fraction * len(selected_idx))
        # half of the remaining selected tokens are replaced with random tokens, remaining half will be untouched
        num_replaced = int(0.5*(len(selected_idx) - num_masked))

        # shuffle the selcted indices
        shuffled_idx = selected_idx[torch.randperm(len(selected_idx))]
        # make a copy of original tokens before masking
        s_original = s[shuffled_idx].clone()
        # replace with mask tokens
        mask_idx = shuffled_idx[:num_masked]
        s[mask_idx] = self.tokenizer.mask_token_id()
        # replace with random tokens
        replace_idx = shuffled_idx[num_masked:num_masked+num_replaced]
        s[replace_idx] = torch.randint(self.vocab_size, (num_replaced,))
        
        # set labels
        label[shuffled_idx] = s_original
        return s, label
    
    # span masking
    def replace_tokens_span(self, s):
        # the labels for a masked token is the original token index and -100 for non-masked tokens
        s = torch.tensor(s, dtype=torch.long)
        label = torch.full_like(s, -100)
        # determine the number of spans to mask, ensuring at least one span is selected
        avg_span_length = (1+self.max_span_len)/2 # expected span length
        num_spans = max(1, int(self.mlm_prob * len(s) / avg_span_length))
        # generate all possible start positions
        all_starts = torch.arange(0, len(s))
        # shuffle the start positions
        perm = torch.randperm(all_starts.nelement())
        all_starts = all_starts.view(-1)[perm].view(all_starts.size())
        # select the first num_spans positions
        span_starts = all_starts[:num_spans]
        # sort the selected positions
        span_starts, _ = torch.sort(span_starts)

        for i, start in enumerate(span_starts):
            # select a random span length between 1 and 3
            span_length = torch.randint(1, self.max_span_len+1, (1,)).item()
            end = start + span_length
            # if there's a next start position and the current end position is greater, adjust the span length
            if i < num_spans - 1 and end > span_starts[i+1]:
                end = span_starts[i+1]
                span_length = end - start
            # if the number of tokens is less than the span length, adjust the span length
            if end > len(s):
                end = len(s)
                span_length = end - start
            # make a copy of original tokens before masking
            s_original = s[start:end].clone()
            # generate a random number to decide whether to mask or replace the span
            rand_num = torch.rand(1).item()
            if rand_num < self.mask_fraction:
                # replace with mask tokens
                s[start:end] = self.tokenizer.mask_token_id()
            else:
                # replace with random tokens
                s[start:end] = torch.randint(self.vocab_size, (span_length,))
            # set labels
            label[start:end] = s_original

        return s, label


# 2. Model Implementation

#### Prepare the Training and Validation Dataloaders

In [4]:
# create dataset
block_size = 128
batch_size = 48

train_dataset = BERTDataset_Long(train_data, knowledge_source, tokenizer, block_size,  hard_negatives=hard_negatives, mlm_prob=0.15, span_mask=False)
val_dataset = BERTDataset_Long(val_data, knowledge_source, tokenizer, block_size, mlm_prob=0.15, span_mask=False)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=2)  # set pin_memory for faster pre-fetching
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=1)  # set pin_memory for faster pre-fetching 
print(f"Total number of training batches: {len(train_dataloader)}")
print(f"Total number of validation batches: {len(val_dataloader)}")

Total number of training batches: 184
Total number of validation batches: 20


#### Instantiate the custom BERT model, set the hyperparameter values

In [7]:
# model hyperparameters
embedding_dim = 512
head_size = embedding_dim
num_heads = 16
num_layers = 8
dropout_rate = 0.1
num_epochs = 5
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu' 

model = BERTModel(vocab_size=tokenizer.vocab_size(), block_size=block_size, embedding_dim=embedding_dim, head_size=head_size, num_heads=num_heads, num_layers=num_layers, pad_token_id=tokenizer.pad_token_id(), device=device)
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# load from previous checkpoint
#model, optimizer, epoch, loss = load_bert_model_checkpoint(model, optimizer, name="BERT_multitask_checkpoint_entaiment_claims_long_600_epochs", device=device, strict=False)

# learning rate scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=1e-6)

num_params = sum(p.numel() for p in model.parameters())
print(f"Device: {device}")
print(f"Total number of parameters in transformer network: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

  _torch_pytree._register_pytree_node(


Loaded pretrained BERT model checkpoint at epoch 20 with loss 2.0593682737607804
Device: cuda
Total number of parameters in transformer network: 45.030221 M
RAM used: 3426.89 MB


In [8]:
"""
run = wandb.init(
    project="BERT Pretrain MLM", 
    config={
        "model": "BERT-like Transformer Encoder",
        "learning_rate": learning_rate, 
        "epochs": num_epochs,
        "batch_size": batch_size, 
        "corpus": "Climate Claims"},)   

def log_metrics(metrics):
    wandb.log(metrics)
"""

#### Start pre-training the model.

In [8]:
train(model, 200, train_dataloader, val_dataloader, optimizer, grad_accumulation_steps=20, val_every=1000, save_every=4, device=device, log_metrics=None, mixed_precision=True, checkpoint_name="BERT_multitask", scheduler=scheduler, include_claim_loss=False)

Epoch 1, Train Loss(Total, MLM, Entailment, Claim): (2.355, 2.199, 0.157, 1.247), Train Accuracy (MLM, Entailment, Claim): (0.603, 0.948, 0.463), Val Loss: 6.003, Val Accuracy (MLM, Entailment, Claim): (0.376, 0.864, 0.514): 100%|██████████| 1238/1238 [03:54<00:00,  5.29it/s]
Epoch 2, Train Loss(Total, MLM, Entailment, Claim): (2.259, 2.124, 0.135, 1.253), Train Accuracy (MLM, Entailment, Claim): (0.606, 0.952, 0.470), Val Loss: 5.963, Val Accuracy (MLM, Entailment, Claim): (0.376, 0.867, 0.594): 100%|██████████| 1238/1238 [04:00<00:00,  5.15it/s]
Epoch 3, Train Loss(Total, MLM, Entailment, Claim): (2.127, 2.028, 0.099, 1.252), Train Accuracy (MLM, Entailment, Claim): (0.608, 0.955, 0.472), Val Loss: 6.017, Val Accuracy (MLM, Entailment, Claim): (0.378, 0.868, 0.555): 100%|██████████| 1238/1238 [03:58<00:00,  5.18it/s]
Epoch 4, Train Loss(Total, MLM, Entailment, Claim): (2.171, 2.076, 0.094, 1.256), Train Accuracy (MLM, Entailment, Claim): (0.610, 0.957, 0.471), Val Loss: 6.032, Val Ac

Saved BERT model checkpoint!


Epoch 5, Train Loss(Total, MLM, Entailment, Claim): (2.190, 2.081, 0.109, 1.240), Train Accuracy (MLM, Entailment, Claim): (0.611, 0.959, 0.472), Val Loss: 6.054, Val Accuracy (MLM, Entailment, Claim): (0.376, 0.858, 0.557): 100%|██████████| 1238/1238 [04:01<00:00,  5.14it/s]
Epoch 6, Train Loss(Total, MLM, Entailment, Claim): (2.231, 2.124, 0.107, 1.272), Train Accuracy (MLM, Entailment, Claim): (0.612, 0.960, 0.472), Val Loss: 6.060, Val Accuracy (MLM, Entailment, Claim): (0.379, 0.858, 0.552): 100%|██████████| 1238/1238 [04:01<00:00,  5.12it/s]
Epoch 7, Train Loss(Total, MLM, Entailment, Claim): (2.152, 2.048, 0.103, 1.245), Train Accuracy (MLM, Entailment, Claim): (0.612, 0.961, 0.472), Val Loss: 6.042, Val Accuracy (MLM, Entailment, Claim): (0.378, 0.864, 0.557): 100%|██████████| 1238/1238 [04:01<00:00,  5.12it/s]
Epoch 8, Train Loss(Total, MLM, Entailment, Claim): (2.162, 2.092, 0.070, 1.222), Train Accuracy (MLM, Entailment, Claim): (0.612, 0.962, 0.474), Val Loss: 6.121, Val Ac

Saved BERT model checkpoint!


Epoch 9, Train Loss(Total, MLM, Entailment, Claim): (2.233, 2.147, 0.086, 1.206), Train Accuracy (MLM, Entailment, Claim): (0.611, 0.962, 0.477), Val Loss: 6.137, Val Accuracy (MLM, Entailment, Claim): (0.371, 0.865, 0.618): 100%|██████████| 1238/1238 [04:01<00:00,  5.12it/s]
Epoch 10, Train Loss(Total, MLM, Entailment, Claim): (2.250, 2.198, 0.052, 1.248), Train Accuracy (MLM, Entailment, Claim): (0.609, 0.963, 0.481), Val Loss: 6.014, Val Accuracy (MLM, Entailment, Claim): (0.368, 0.855, 0.685): 100%|██████████| 1238/1238 [04:04<00:00,  5.07it/s]
Epoch 11, Train Loss(Total, MLM, Entailment, Claim): (2.255, 2.190, 0.065, 1.289), Train Accuracy (MLM, Entailment, Claim): (0.606, 0.964, 0.477), Val Loss: 6.203, Val Accuracy (MLM, Entailment, Claim): (0.376, 0.849, 0.496): 100%|██████████| 1238/1238 [04:03<00:00,  5.09it/s]
Epoch 12, Train Loss(Total, MLM, Entailment, Claim): (2.154, 2.117, 0.037, 1.343), Train Accuracy (MLM, Entailment, Claim): (0.604, 0.965, 0.468), Val Loss: 6.573, Val

Saved BERT model checkpoint!


Epoch 13, Train Loss(Total, MLM, Entailment, Claim): (2.161, 2.125, 0.036, 1.285), Train Accuracy (MLM, Entailment, Claim): (0.604, 0.967, 0.460), Val Loss: 6.404, Val Accuracy (MLM, Entailment, Claim): (0.375, 0.842, 0.415): 100%|██████████| 1238/1238 [04:01<00:00,  5.13it/s]
Epoch 14, Train Loss(Total, MLM, Entailment, Claim): (2.063, 2.031, 0.033, 1.350), Train Accuracy (MLM, Entailment, Claim): (0.604, 0.968, 0.455), Val Loss: 6.260, Val Accuracy (MLM, Entailment, Claim): (0.382, 0.881, 0.424): 100%|██████████| 1238/1238 [04:02<00:00,  5.10it/s]
Epoch 15, Train Loss(Total, MLM, Entailment, Claim): (1.967, 1.945, 0.022, 1.386), Train Accuracy (MLM, Entailment, Claim): (0.605, 0.970, 0.448), Val Loss: 6.269, Val Accuracy (MLM, Entailment, Claim): (0.382, 0.879, 0.406): 100%|██████████| 1238/1238 [04:14<00:00,  4.86it/s]
Epoch 16, Train Loss(Total, MLM, Entailment, Claim): (2.002, 1.982, 0.020, 1.378), Train Accuracy (MLM, Entailment, Claim): (0.607, 0.972, 0.443), Val Loss: 6.314, Va

Saved BERT model checkpoint!


Epoch 17, Train Loss(Total, MLM, Entailment, Claim): (1.941, 1.922, 0.019, 1.378), Train Accuracy (MLM, Entailment, Claim): (0.608, 0.973, 0.437), Val Loss: 6.307, Val Accuracy (MLM, Entailment, Claim): (0.382, 0.877, 0.394): 100%|██████████| 1238/1238 [03:58<00:00,  5.19it/s]
Epoch 18, Train Loss(Total, MLM, Entailment, Claim): (1.977, 1.965, 0.012, 1.393), Train Accuracy (MLM, Entailment, Claim): (0.609, 0.974, 0.433), Val Loss: 6.456, Val Accuracy (MLM, Entailment, Claim): (0.379, 0.866, 0.381): 100%|██████████| 1238/1238 [03:56<00:00,  5.23it/s]
Epoch 19, Train Loss(Total, MLM, Entailment, Claim): (2.023, 2.006, 0.017, 1.363), Train Accuracy (MLM, Entailment, Claim): (0.609, 0.975, 0.430), Val Loss: 6.436, Val Accuracy (MLM, Entailment, Claim): (0.379, 0.861, 0.434): 100%|██████████| 1238/1238 [03:56<00:00,  5.24it/s]
Epoch 20, Train Loss(Total, MLM, Entailment, Claim): (2.146, 2.107, 0.039, 1.418), Train Accuracy (MLM, Entailment, Claim): (0.609, 0.976, 0.426), Val Loss: 6.388, Va

Saved BERT model checkpoint!


Epoch 21, Train Loss(Total, MLM, Entailment, Claim): (2.180, 2.125, 0.055, 1.406), Train Accuracy (MLM, Entailment, Claim): (0.608, 0.976, 0.421), Val Loss: 6.240, Val Accuracy (MLM, Entailment, Claim): (0.375, 0.874, 0.387): 100%|██████████| 1238/1238 [03:57<00:00,  5.21it/s]
Epoch 22, Train Loss(Total, MLM, Entailment, Claim): (2.085, 2.063, 0.022, 1.306), Train Accuracy (MLM, Entailment, Claim): (0.608, 0.976, 0.420), Val Loss: 6.341, Val Accuracy (MLM, Entailment, Claim): (0.372, 0.873, 0.436): 100%|██████████| 1238/1238 [03:58<00:00,  5.19it/s]
Epoch 23, Train Loss(Total, MLM, Entailment, Claim): (1.945, 1.931, 0.014, 1.345), Train Accuracy (MLM, Entailment, Claim): (0.608, 0.977, 0.417), Val Loss: 6.439, Val Accuracy (MLM, Entailment, Claim): (0.378, 0.855, 0.335): 100%|██████████| 1238/1238 [03:56<00:00,  5.22it/s]
Epoch 24, Train Loss(Total, MLM, Entailment, Claim): (1.901, 1.893, 0.008, 1.304), Train Accuracy (MLM, Entailment, Claim): (0.609, 0.978, 0.414), Val Loss: 6.221, Va

Saved BERT model checkpoint!


Epoch 25, Train Loss(Total, MLM, Entailment, Claim): (1.926, 1.859, 0.067, 1.333), Train Accuracy (MLM, Entailment, Claim): (0.610, 0.979, 0.413), Val Loss: 6.346, Val Accuracy (MLM, Entailment, Claim): (0.387, 0.876, 0.423): 100%|██████████| 1238/1238 [03:58<00:00,  5.19it/s]
Epoch 26, Train Loss(Total, MLM, Entailment, Claim): (1.870, 1.861, 0.009, 1.346), Train Accuracy (MLM, Entailment, Claim): (0.611, 0.979, 0.411), Val Loss: 6.362, Val Accuracy (MLM, Entailment, Claim): (0.387, 0.867, 0.423): 100%|██████████| 1238/1238 [03:58<00:00,  5.20it/s]
Epoch 27, Train Loss(Total, MLM, Entailment, Claim): (1.909, 1.900, 0.009, 1.364), Train Accuracy (MLM, Entailment, Claim): (0.612, 0.980, 0.410), Val Loss: 6.265, Val Accuracy (MLM, Entailment, Claim): (0.386, 0.883, 0.415): 100%|██████████| 1238/1238 [03:57<00:00,  5.22it/s]
Epoch 28, Train Loss(Total, MLM, Entailment, Claim): (1.919, 1.912, 0.007, 1.395), Train Accuracy (MLM, Entailment, Claim): (0.613, 0.981, 0.407), Val Loss: 6.373, Va

Saved BERT model checkpoint!


Epoch 29, Train Loss(Total, MLM, Entailment, Claim): (1.905, 1.877, 0.028, 1.357), Train Accuracy (MLM, Entailment, Claim): (0.614, 0.981, 0.404), Val Loss: 6.282, Val Accuracy (MLM, Entailment, Claim): (0.386, 0.871, 0.339): 100%|██████████| 1238/1238 [03:58<00:00,  5.18it/s]
Epoch 30, Train Loss(Total, MLM, Entailment, Claim): (2.018, 1.976, 0.043, 1.347), Train Accuracy (MLM, Entailment, Claim): (0.614, 0.981, 0.401), Val Loss: 6.398, Val Accuracy (MLM, Entailment, Claim): (0.376, 0.871, 0.291): 100%|██████████| 1238/1238 [03:59<00:00,  5.17it/s]
Epoch 31, Train Loss(Total, MLM, Entailment, Claim): (2.014, 1.996, 0.018, 1.316), Train Accuracy (MLM, Entailment, Claim): (0.614, 0.982, 0.400), Val Loss: 6.313, Val Accuracy (MLM, Entailment, Claim): (0.378, 0.852, 0.409): 100%|██████████| 1238/1238 [03:57<00:00,  5.22it/s]
Epoch 32, Train Loss(Total, MLM, Entailment, Claim): (2.048, 1.998, 0.050, 1.430), Train Accuracy (MLM, Entailment, Claim): (0.614, 0.982, 0.396), Val Loss: 6.433, Va

Saved BERT model checkpoint!


Epoch 33, Train Loss(Total, MLM, Entailment, Claim): (1.932, 1.916, 0.016, 1.403), Train Accuracy (MLM, Entailment, Claim): (0.614, 0.982, 0.393), Val Loss: 6.447, Val Accuracy (MLM, Entailment, Claim): (0.382, 0.850, 0.315): 100%|██████████| 1238/1238 [04:00<00:00,  5.15it/s]
Epoch 34, Train Loss(Total, MLM, Entailment, Claim): (1.884, 1.867, 0.016, 1.483), Train Accuracy (MLM, Entailment, Claim): (0.614, 0.982, 0.390), Val Loss: 6.409, Val Accuracy (MLM, Entailment, Claim): (0.387, 0.864, 0.282): 100%|██████████| 1238/1238 [03:57<00:00,  5.22it/s]
Epoch 35, Train Loss(Total, MLM, Entailment, Claim): (1.766, 1.760, 0.005, 1.441), Train Accuracy (MLM, Entailment, Claim): (0.615, 0.983, 0.388), Val Loss: 6.427, Val Accuracy (MLM, Entailment, Claim): (0.389, 0.873, 0.276): 100%|██████████| 1238/1238 [04:01<00:00,  5.13it/s]
Epoch 36, Train Loss(Total, MLM, Entailment, Claim): (1.741, 1.738, 0.003, 1.414), Train Accuracy (MLM, Entailment, Claim): (0.616, 0.983, 0.385), Val Loss: 6.373, Va

Saved BERT model checkpoint!


Epoch 37, Train Loss(Total, MLM, Entailment, Claim): (1.810, 1.807, 0.003, 1.414), Train Accuracy (MLM, Entailment, Claim): (0.617, 0.984, 0.382), Val Loss: 6.341, Val Accuracy (MLM, Entailment, Claim): (0.392, 0.883, 0.282): 100%|██████████| 1238/1238 [03:57<00:00,  5.22it/s]
Epoch 38, Train Loss(Total, MLM, Entailment, Claim): (1.854, 1.844, 0.010, 1.416), Train Accuracy (MLM, Entailment, Claim): (0.618, 0.984, 0.380), Val Loss: 6.400, Val Accuracy (MLM, Entailment, Claim): (0.392, 0.869, 0.326): 100%|██████████| 1238/1238 [03:57<00:00,  5.21it/s]
Epoch 39, Train Loss(Total, MLM, Entailment, Claim): (1.810, 1.799, 0.011, 1.393), Train Accuracy (MLM, Entailment, Claim): (0.619, 0.984, 0.378), Val Loss: 6.480, Val Accuracy (MLM, Entailment, Claim): (0.391, 0.846, 0.324): 100%|██████████| 1238/1238 [04:00<00:00,  5.15it/s]
Epoch 40, Train Loss(Total, MLM, Entailment, Claim): (1.881, 1.861, 0.020, 1.456), Train Accuracy (MLM, Entailment, Claim): (0.619, 0.985, 0.376), Val Loss: 6.277, Va

Saved BERT model checkpoint!


Epoch 41, Train Loss(Total, MLM, Entailment, Claim): (1.999, 1.974, 0.025, 1.322), Train Accuracy (MLM, Entailment, Claim): (0.619, 0.985, 0.375), Val Loss: 6.314, Val Accuracy (MLM, Entailment, Claim): (0.384, 0.853, 0.437): 100%|██████████| 1238/1238 [03:56<00:00,  5.23it/s]
Epoch 42, Train Loss(Total, MLM, Entailment, Claim): (1.926, 1.912, 0.014, 1.240), Train Accuracy (MLM, Entailment, Claim): (0.619, 0.985, 0.375), Val Loss: 6.171, Val Accuracy (MLM, Entailment, Claim): (0.383, 0.872, 0.607): 100%|██████████| 1238/1238 [03:56<00:00,  5.24it/s]
Epoch 43, Train Loss(Total, MLM, Entailment, Claim): (1.851, 1.830, 0.020, 1.320), Train Accuracy (MLM, Entailment, Claim): (0.619, 0.985, 0.375), Val Loss: 6.276, Val Accuracy (MLM, Entailment, Claim): (0.384, 0.875, 0.356): 100%|██████████| 1238/1238 [03:58<00:00,  5.18it/s]
Epoch 44, Train Loss(Total, MLM, Entailment, Claim): (1.771, 1.760, 0.011, 1.378), Train Accuracy (MLM, Entailment, Claim): (0.620, 0.985, 0.375), Val Loss: 6.268, Va

Saved BERT model checkpoint!


Epoch 45, Train Loss(Total, MLM, Entailment, Claim): (1.768, 1.754, 0.014, 1.378), Train Accuracy (MLM, Entailment, Claim): (0.621, 0.986, 0.374), Val Loss: 6.226, Val Accuracy (MLM, Entailment, Claim): (0.394, 0.881, 0.388): 100%|██████████| 1238/1238 [03:57<00:00,  5.22it/s]
Epoch 46, Train Loss(Total, MLM, Entailment, Claim): (1.778, 1.772, 0.006, 1.380), Train Accuracy (MLM, Entailment, Claim): (0.622, 0.986, 0.374), Val Loss: 6.292, Val Accuracy (MLM, Entailment, Claim): (0.394, 0.873, 0.398): 100%|██████████| 1238/1238 [03:53<00:00,  5.31it/s]
Epoch 47, Train Loss(Total, MLM, Entailment, Claim): (1.729, 1.725, 0.004, 1.373), Train Accuracy (MLM, Entailment, Claim): (0.622, 0.986, 0.373), Val Loss: 6.336, Val Accuracy (MLM, Entailment, Claim): (0.396, 0.872, 0.392): 100%|██████████| 1238/1238 [03:54<00:00,  5.28it/s]
Epoch 48, Train Loss(Total, MLM, Entailment, Claim): (1.713, 1.692, 0.021, 1.373), Train Accuracy (MLM, Entailment, Claim): (0.623, 0.986, 0.372), Val Loss: 6.360, Va

Saved BERT model checkpoint!


Epoch 49, Train Loss(Total, MLM, Entailment, Claim): (1.872, 1.855, 0.018, 1.318), Train Accuracy (MLM, Entailment, Claim): (0.624, 0.987, 0.371), Val Loss: 6.354, Val Accuracy (MLM, Entailment, Claim): (0.389, 0.880, 0.392): 100%|██████████| 1238/1238 [03:47<00:00,  5.44it/s]
Epoch 50, Train Loss(Total, MLM, Entailment, Claim): (1.855, 1.834, 0.020, 1.291), Train Accuracy (MLM, Entailment, Claim): (0.624, 0.987, 0.371), Val Loss: 6.451, Val Accuracy (MLM, Entailment, Claim): (0.387, 0.843, 0.436): 100%|██████████| 1238/1238 [03:49<00:00,  5.40it/s]
Epoch 51, Train Loss(Total, MLM, Entailment, Claim): (1.871, 1.850, 0.021, 1.378), Train Accuracy (MLM, Entailment, Claim): (0.624, 0.987, 0.370), Val Loss: 6.726, Val Accuracy (MLM, Entailment, Claim): (0.383, 0.835, 0.291): 100%|██████████| 1238/1238 [03:50<00:00,  5.38it/s]
Epoch 52, Train Loss(Total, MLM, Entailment, Claim): (1.901, 1.881, 0.020, 1.465), Train Accuracy (MLM, Entailment, Claim): (0.624, 0.987, 0.369), Val Loss: 6.348, Va

Saved BERT model checkpoint!


Epoch 53, Train Loss(Total, MLM, Entailment, Claim): (1.787, 1.776, 0.011, 1.476), Train Accuracy (MLM, Entailment, Claim): (0.625, 0.987, 0.367), Val Loss: 6.431, Val Accuracy (MLM, Entailment, Claim): (0.396, 0.873, 0.169): 100%|██████████| 1238/1238 [03:56<00:00,  5.25it/s]
Epoch 54, Train Loss(Total, MLM, Entailment, Claim): (1.719, 1.713, 0.006, 1.492), Train Accuracy (MLM, Entailment, Claim): (0.625, 0.987, 0.364), Val Loss: 6.452, Val Accuracy (MLM, Entailment, Claim): (0.393, 0.864, 0.179): 100%|██████████| 1238/1238 [03:55<00:00,  5.25it/s]
Epoch 55, Train Loss(Total, MLM, Entailment, Claim): (1.765, 1.760, 0.005, 1.492), Train Accuracy (MLM, Entailment, Claim): (0.626, 0.987, 0.362), Val Loss: 6.451, Val Accuracy (MLM, Entailment, Claim): (0.395, 0.870, 0.206): 100%|██████████| 1238/1238 [03:58<00:00,  5.20it/s]
Epoch 56, Train Loss(Total, MLM, Entailment, Claim): (1.701, 1.695, 0.007, 1.459), Train Accuracy (MLM, Entailment, Claim): (0.627, 0.988, 0.359), Val Loss: 6.487, Va

Saved BERT model checkpoint!


Epoch 57, Train Loss(Total, MLM, Entailment, Claim): (1.665, 1.663, 0.002, 1.512), Train Accuracy (MLM, Entailment, Claim): (0.627, 0.988, 0.357), Val Loss: 6.387, Val Accuracy (MLM, Entailment, Claim): (0.400, 0.879, 0.189): 100%|██████████| 1238/1238 [03:56<00:00,  5.24it/s]
Epoch 58, Train Loss(Total, MLM, Entailment, Claim): (1.678, 1.675, 0.003, 1.398), Train Accuracy (MLM, Entailment, Claim): (0.628, 0.988, 0.355), Val Loss: 6.487, Val Accuracy (MLM, Entailment, Claim): (0.398, 0.867, 0.234): 100%|██████████| 1238/1238 [03:54<00:00,  5.27it/s]
Epoch 59, Train Loss(Total, MLM, Entailment, Claim): (1.844, 1.828, 0.016, 1.449), Train Accuracy (MLM, Entailment, Claim): (0.629, 0.988, 0.354), Val Loss: 6.434, Val Accuracy (MLM, Entailment, Claim): (0.392, 0.883, 0.192): 100%|██████████| 1238/1238 [03:53<00:00,  5.31it/s]
Epoch 60, Train Loss(Total, MLM, Entailment, Claim): (1.842, 1.828, 0.014, 1.451), Train Accuracy (MLM, Entailment, Claim): (0.629, 0.988, 0.352), Val Loss: 6.356, Va

Saved BERT model checkpoint!


Epoch 61, Train Loss(Total, MLM, Entailment, Claim): (1.874, 1.837, 0.037, 1.394), Train Accuracy (MLM, Entailment, Claim): (0.629, 0.988, 0.351), Val Loss: 6.295, Val Accuracy (MLM, Entailment, Claim): (0.388, 0.884, 0.310): 100%|██████████| 1238/1238 [03:51<00:00,  5.36it/s]
Epoch 62, Train Loss(Total, MLM, Entailment, Claim): (1.782, 1.757, 0.025, 1.382), Train Accuracy (MLM, Entailment, Claim): (0.629, 0.988, 0.350), Val Loss: 6.600, Val Accuracy (MLM, Entailment, Claim): (0.391, 0.839, 0.276): 100%|██████████| 1238/1238 [03:51<00:00,  5.34it/s]
Epoch 63, Train Loss(Total, MLM, Entailment, Claim): (1.735, 1.727, 0.008, 1.353), Train Accuracy (MLM, Entailment, Claim): (0.629, 0.989, 0.349), Val Loss: 6.454, Val Accuracy (MLM, Entailment, Claim): (0.395, 0.857, 0.308): 100%|██████████| 1238/1238 [03:51<00:00,  5.35it/s]
Epoch 64, Train Loss(Total, MLM, Entailment, Claim): (1.711, 1.694, 0.016, 1.415), Train Accuracy (MLM, Entailment, Claim): (0.630, 0.989, 0.348), Val Loss: 6.334, Va

Saved BERT model checkpoint!


Epoch 65, Train Loss(Total, MLM, Entailment, Claim): (1.629, 1.627, 0.001, 1.437), Train Accuracy (MLM, Entailment, Claim): (0.631, 0.989, 0.348), Val Loss: 6.468, Val Accuracy (MLM, Entailment, Claim): (0.400, 0.873, 0.301): 100%|██████████| 1238/1238 [03:50<00:00,  5.36it/s]
Epoch 66, Train Loss(Total, MLM, Entailment, Claim): (1.605, 1.604, 0.002, 1.429), Train Accuracy (MLM, Entailment, Claim): (0.631, 0.989, 0.347), Val Loss: 6.462, Val Accuracy (MLM, Entailment, Claim): (0.401, 0.876, 0.285): 100%|██████████| 1238/1238 [03:51<00:00,  5.35it/s]
Epoch 67, Train Loss(Total, MLM, Entailment, Claim): (1.669, 1.668, 0.001, 1.411), Train Accuracy (MLM, Entailment, Claim): (0.632, 0.989, 0.346), Val Loss: 6.360, Val Accuracy (MLM, Entailment, Claim): (0.400, 0.877, 0.323): 100%|██████████| 1238/1238 [03:51<00:00,  5.35it/s]
Epoch 68, Train Loss(Total, MLM, Entailment, Claim): (1.653, 1.649, 0.004, 1.376), Train Accuracy (MLM, Entailment, Claim): (0.633, 0.989, 0.345), Val Loss: 6.447, Va

Saved BERT model checkpoint!


Epoch 69, Train Loss(Total, MLM, Entailment, Claim): (1.664, 1.654, 0.010, 1.321), Train Accuracy (MLM, Entailment, Claim): (0.633, 0.989, 0.345), Val Loss: 6.263, Val Accuracy (MLM, Entailment, Claim): (0.397, 0.882, 0.474): 100%|██████████| 1238/1238 [03:50<00:00,  5.37it/s]
Epoch 70, Train Loss(Total, MLM, Entailment, Claim): (1.675, 1.669, 0.005, 1.297), Train Accuracy (MLM, Entailment, Claim): (0.633, 0.989, 0.346), Val Loss: 6.301, Val Accuracy (MLM, Entailment, Claim): (0.390, 0.878, 0.444): 100%|██████████| 1238/1238 [03:51<00:00,  5.35it/s]
Epoch 71, Train Loss(Total, MLM, Entailment, Claim): (1.769, 1.741, 0.027, 1.341), Train Accuracy (MLM, Entailment, Claim): (0.634, 0.990, 0.347), Val Loss: 6.472, Val Accuracy (MLM, Entailment, Claim): (0.392, 0.854, 0.430): 100%|██████████| 1238/1238 [03:50<00:00,  5.38it/s]
Epoch 72, Train Loss(Total, MLM, Entailment, Claim): (1.737, 1.723, 0.014, 1.424), Train Accuracy (MLM, Entailment, Claim): (0.634, 0.990, 0.346), Val Loss: 6.374, Va

Saved BERT model checkpoint!


Epoch 73, Train Loss(Total, MLM, Entailment, Claim): (1.740, 1.728, 0.012, 1.411), Train Accuracy (MLM, Entailment, Claim): (0.634, 0.990, 0.346), Val Loss: 6.313, Val Accuracy (MLM, Entailment, Claim): (0.392, 0.869, 0.372): 100%|██████████| 1238/1238 [03:50<00:00,  5.37it/s]
Epoch 74, Train Loss(Total, MLM, Entailment, Claim): (1.599, 1.597, 0.002, 1.403), Train Accuracy (MLM, Entailment, Claim): (0.634, 0.990, 0.346), Val Loss: 6.299, Val Accuracy (MLM, Entailment, Claim): (0.402, 0.863, 0.343): 100%|██████████| 1238/1238 [03:52<00:00,  5.32it/s]
Epoch 75, Train Loss(Total, MLM, Entailment, Claim): (1.634, 1.632, 0.002, 1.413), Train Accuracy (MLM, Entailment, Claim): (0.635, 0.990, 0.345), Val Loss: 6.397, Val Accuracy (MLM, Entailment, Claim): (0.399, 0.876, 0.385): 100%|██████████| 1238/1238 [03:56<00:00,  5.23it/s]
Epoch 76, Train Loss(Total, MLM, Entailment, Claim): (1.667, 1.647, 0.020, 1.392), Train Accuracy (MLM, Entailment, Claim): (0.636, 0.990, 0.345), Val Loss: 6.407, Va

Saved BERT model checkpoint!


Epoch 77, Train Loss(Total, MLM, Entailment, Claim): (1.588, 1.584, 0.004, 1.385), Train Accuracy (MLM, Entailment, Claim): (0.636, 0.990, 0.345), Val Loss: 6.414, Val Accuracy (MLM, Entailment, Claim): (0.403, 0.874, 0.363): 100%|██████████| 1238/1238 [03:56<00:00,  5.23it/s]
Epoch 78, Train Loss(Total, MLM, Entailment, Claim): (1.634, 1.630, 0.004, 1.401), Train Accuracy (MLM, Entailment, Claim): (0.637, 0.990, 0.344), Val Loss: 6.523, Val Accuracy (MLM, Entailment, Claim): (0.402, 0.872, 0.316): 100%|██████████| 1238/1238 [03:53<00:00,  5.31it/s]
Epoch 79, Train Loss(Total, MLM, Entailment, Claim): (1.609, 1.604, 0.006, 1.394), Train Accuracy (MLM, Entailment, Claim): (0.637, 0.990, 0.344), Val Loss: 6.545, Val Accuracy (MLM, Entailment, Claim): (0.399, 0.854, 0.335): 100%|██████████| 1238/1238 [03:50<00:00,  5.38it/s]
Epoch 80, Train Loss(Total, MLM, Entailment, Claim): (1.781, 1.761, 0.020, 1.358), Train Accuracy (MLM, Entailment, Claim): (0.638, 0.990, 0.343), Val Loss: 6.418, Va

Saved BERT model checkpoint!


Epoch 81, Train Loss(Total, MLM, Entailment, Claim): (1.785, 1.726, 0.058, 1.471), Train Accuracy (MLM, Entailment, Claim): (0.638, 0.990, 0.343), Val Loss: 6.397, Val Accuracy (MLM, Entailment, Claim): (0.396, 0.843, 0.269): 100%|██████████| 1238/1238 [03:47<00:00,  5.44it/s]
Epoch 82, Train Loss(Total, MLM, Entailment, Claim): (1.765, 1.750, 0.016, 1.477), Train Accuracy (MLM, Entailment, Claim): (0.638, 0.990, 0.342), Val Loss: 6.469, Val Accuracy (MLM, Entailment, Claim): (0.397, 0.847, 0.222): 100%|██████████| 1238/1238 [03:46<00:00,  5.46it/s]
Epoch 83, Train Loss(Total, MLM, Entailment, Claim): (1.636, 1.625, 0.011, 1.393), Train Accuracy (MLM, Entailment, Claim): (0.638, 0.991, 0.341), Val Loss: 6.404, Val Accuracy (MLM, Entailment, Claim): (0.399, 0.858, 0.215): 100%|██████████| 1238/1238 [04:01<00:00,  5.13it/s]
Epoch 84, Train Loss(Total, MLM, Entailment, Claim): (1.617, 1.614, 0.002, 1.420), Train Accuracy (MLM, Entailment, Claim): (0.639, 0.991, 0.340), Val Loss: 6.385, Va

Saved BERT model checkpoint!


Epoch 85, Train Loss(Total, MLM, Entailment, Claim): (1.552, 1.549, 0.003, 1.434), Train Accuracy (MLM, Entailment, Claim): (0.639, 0.991, 0.339), Val Loss: 6.475, Val Accuracy (MLM, Entailment, Claim): (0.400, 0.870, 0.234): 100%|██████████| 1238/1238 [04:01<00:00,  5.13it/s]
Epoch 86, Train Loss(Total, MLM, Entailment, Claim): (1.604, 1.594, 0.010, 1.445), Train Accuracy (MLM, Entailment, Claim): (0.640, 0.991, 0.338), Val Loss: 6.455, Val Accuracy (MLM, Entailment, Claim): (0.410, 0.865, 0.233): 100%|██████████| 1238/1238 [04:02<00:00,  5.11it/s]
Epoch 87, Train Loss(Total, MLM, Entailment, Claim): (1.587, 1.586, 0.001, 1.440), Train Accuracy (MLM, Entailment, Claim): (0.640, 0.991, 0.337), Val Loss: 6.482, Val Accuracy (MLM, Entailment, Claim): (0.401, 0.876, 0.232): 100%|██████████| 1238/1238 [04:10<00:00,  4.95it/s]
Epoch 88, Train Loss(Total, MLM, Entailment, Claim): (1.537, 1.537, 0.001, 1.429), Train Accuracy (MLM, Entailment, Claim): (0.641, 0.991, 0.336), Val Loss: 6.389, Va

Saved BERT model checkpoint!


Epoch 89, Train Loss(Total, MLM, Entailment, Claim): (1.564, 1.560, 0.005, 1.443), Train Accuracy (MLM, Entailment, Claim): (0.641, 0.991, 0.336), Val Loss: 6.641, Val Accuracy (MLM, Entailment, Claim): (0.405, 0.863, 0.256): 100%|██████████| 1238/1238 [04:04<00:00,  5.07it/s]
Epoch 90, Train Loss(Total, MLM, Entailment, Claim): (1.728, 1.720, 0.008, 1.388), Train Accuracy (MLM, Entailment, Claim): (0.641, 0.991, 0.335), Val Loss: 6.415, Val Accuracy (MLM, Entailment, Claim): (0.403, 0.876, 0.255): 100%|██████████| 1238/1238 [04:23<00:00,  4.71it/s]
Epoch 91, Train Loss(Total, MLM, Entailment, Claim): (1.700, 1.691, 0.009, 1.326), Train Accuracy (MLM, Entailment, Claim): (0.642, 0.991, 0.335), Val Loss: 6.426, Val Accuracy (MLM, Entailment, Claim): (0.395, 0.860, 0.383): 100%|██████████| 1238/1238 [04:25<00:00,  4.66it/s]
Epoch 92, Train Loss(Total, MLM, Entailment, Claim): (1.659, 1.646, 0.013, 1.335), Train Accuracy (MLM, Entailment, Claim): (0.642, 0.991, 0.335), Val Loss: 6.427, Va

Saved BERT model checkpoint!


Epoch 93, Train Loss(Total, MLM, Entailment, Claim): (1.567, 1.563, 0.005, 1.358), Train Accuracy (MLM, Entailment, Claim): (0.642, 0.991, 0.335), Val Loss: 6.387, Val Accuracy (MLM, Entailment, Claim): (0.407, 0.858, 0.391): 100%|██████████| 1238/1238 [04:39<00:00,  4.43it/s]
Epoch 94, Train Loss(Total, MLM, Entailment, Claim): (1.560, 1.551, 0.009, 1.422), Train Accuracy (MLM, Entailment, Claim): (0.643, 0.991, 0.335), Val Loss: 6.474, Val Accuracy (MLM, Entailment, Claim): (0.404, 0.866, 0.292): 100%|██████████| 1238/1238 [04:30<00:00,  4.58it/s]
Epoch 95, Train Loss(Total, MLM, Entailment, Claim): (1.523, 1.522, 0.001, 1.404), Train Accuracy (MLM, Entailment, Claim): (0.643, 0.991, 0.334), Val Loss: 6.478, Val Accuracy (MLM, Entailment, Claim): (0.402, 0.875, 0.328): 100%|██████████| 1238/1238 [04:09<00:00,  4.97it/s]
Epoch 96, Train Loss(Total, MLM, Entailment, Claim): (1.567, 1.566, 0.001, 1.375), Train Accuracy (MLM, Entailment, Claim): (0.644, 0.991, 0.334), Val Loss: 6.458, Va

Saved BERT model checkpoint!


Epoch 97, Train Loss(Total, MLM, Entailment, Claim): (1.493, 1.492, 0.001, 1.411), Train Accuracy (MLM, Entailment, Claim): (0.644, 0.992, 0.334), Val Loss: 6.408, Val Accuracy (MLM, Entailment, Claim): (0.407, 0.871, 0.306): 100%|██████████| 1238/1238 [04:20<00:00,  4.76it/s]
Epoch 98, Train Loss(Total, MLM, Entailment, Claim): (1.546, 1.545, 0.001, 1.425), Train Accuracy (MLM, Entailment, Claim): (0.644, 0.992, 0.333), Val Loss: 6.427, Val Accuracy (MLM, Entailment, Claim): (0.409, 0.876, 0.352): 100%|██████████| 1238/1238 [04:34<00:00,  4.51it/s]
Epoch 99, Train Loss(Total, MLM, Entailment, Claim): (1.525, 1.523, 0.003, 1.407), Train Accuracy (MLM, Entailment, Claim): (0.645, 0.992, 0.333), Val Loss: 6.316, Val Accuracy (MLM, Entailment, Claim): (0.404, 0.878, 0.331): 100%|██████████| 1238/1238 [04:30<00:00,  4.58it/s]
Epoch 100, Train Loss(Total, MLM, Entailment, Claim): (1.605, 1.594, 0.010, 1.409), Train Accuracy (MLM, Entailment, Claim): (0.645, 0.992, 0.333), Val Loss: 6.435, V

Saved BERT model checkpoint!


Epoch 101, Train Loss(Total, MLM, Entailment, Claim): (1.652, 1.634, 0.017, 1.372), Train Accuracy (MLM, Entailment, Claim): (0.645, 0.992, 0.333), Val Loss: 6.362, Val Accuracy (MLM, Entailment, Claim): (0.397, 0.871, 0.354): 100%|██████████| 1238/1238 [04:05<00:00,  5.04it/s]
Epoch 102, Train Loss(Total, MLM, Entailment, Claim): (1.701, 1.691, 0.010, 1.365), Train Accuracy (MLM, Entailment, Claim): (0.646, 0.992, 0.333), Val Loss: 6.455, Val Accuracy (MLM, Entailment, Claim): (0.400, 0.860, 0.334): 100%|██████████| 1238/1238 [04:06<00:00,  5.01it/s]
Epoch 103, Train Loss(Total, MLM, Entailment, Claim): (1.519, 1.515, 0.003, 1.406), Train Accuracy (MLM, Entailment, Claim): (0.646, 0.992, 0.332), Val Loss: 6.596, Val Accuracy (MLM, Entailment, Claim): (0.401, 0.864, 0.278): 100%|██████████| 1238/1238 [04:25<00:00,  4.67it/s]
Epoch 104, Train Loss(Total, MLM, Entailment, Claim): (1.567, 1.547, 0.020, 1.414), Train Accuracy (MLM, Entailment, Claim): (0.646, 0.992, 0.332), Val Loss: 6.390

Saved BERT model checkpoint!


Epoch 105, Train Loss(Total, MLM, Entailment, Claim): (1.487, 1.485, 0.002, 1.411), Train Accuracy (MLM, Entailment, Claim): (0.647, 0.992, 0.331), Val Loss: 6.455, Val Accuracy (MLM, Entailment, Claim): (0.406, 0.875, 0.255): 100%|██████████| 1238/1238 [04:28<00:00,  4.61it/s]
Epoch 106, Train Loss(Total, MLM, Entailment, Claim): (1.513, 1.510, 0.003, 1.406), Train Accuracy (MLM, Entailment, Claim): (0.647, 0.992, 0.331), Val Loss: 6.457, Val Accuracy (MLM, Entailment, Claim): (0.410, 0.868, 0.267): 100%|██████████| 1238/1238 [04:08<00:00,  4.97it/s]
Epoch 107, Train Loss(Total, MLM, Entailment, Claim): (1.478, 1.477, 0.001, 1.412), Train Accuracy (MLM, Entailment, Claim): (0.648, 0.992, 0.331), Val Loss: 6.471, Val Accuracy (MLM, Entailment, Claim): (0.410, 0.872, 0.267): 100%|██████████| 1238/1238 [04:28<00:00,  4.62it/s]
Epoch 108, Train Loss(Total, MLM, Entailment, Claim): (1.439, 1.437, 0.002, 1.424), Train Accuracy (MLM, Entailment, Claim): (0.648, 0.992, 0.330), Val Loss: 6.539

Saved BERT model checkpoint!


Epoch 109, Train Loss(Total, MLM, Entailment, Claim): (1.425, 1.420, 0.004, 1.385), Train Accuracy (MLM, Entailment, Claim): (0.648, 0.992, 0.330), Val Loss: 6.412, Val Accuracy (MLM, Entailment, Claim): (0.411, 0.870, 0.262): 100%|██████████| 1238/1238 [04:29<00:00,  4.59it/s]
Epoch 110, Train Loss(Total, MLM, Entailment, Claim): (1.588, 1.574, 0.014, 1.383), Train Accuracy (MLM, Entailment, Claim): (0.649, 0.992, 0.330), Val Loss: 6.445, Val Accuracy (MLM, Entailment, Claim): (0.402, 0.865, 0.311): 100%|██████████| 1238/1238 [04:32<00:00,  4.55it/s]
Epoch 111, Train Loss(Total, MLM, Entailment, Claim): (1.589, 1.569, 0.020, 1.346), Train Accuracy (MLM, Entailment, Claim): (0.649, 0.992, 0.329), Val Loss: 6.391, Val Accuracy (MLM, Entailment, Claim): (0.399, 0.881, 0.291): 100%|██████████| 1238/1238 [04:15<00:00,  4.85it/s]
Epoch 112, Train Loss(Total, MLM, Entailment, Claim): (1.591, 1.580, 0.010, 1.362), Train Accuracy (MLM, Entailment, Claim): (0.649, 0.992, 0.329), Val Loss: 6.642

Saved BERT model checkpoint!


Epoch 113, Train Loss(Total, MLM, Entailment, Claim): (1.500, 1.480, 0.020, 1.401), Train Accuracy (MLM, Entailment, Claim): (0.649, 0.992, 0.329), Val Loss: 6.505, Val Accuracy (MLM, Entailment, Claim): (0.403, 0.872, 0.254): 100%|██████████| 1238/1238 [04:33<00:00,  4.53it/s]
Epoch 114, Train Loss(Total, MLM, Entailment, Claim): (1.468, 1.467, 0.001, 1.383), Train Accuracy (MLM, Entailment, Claim): (0.650, 0.992, 0.328), Val Loss: 6.416, Val Accuracy (MLM, Entailment, Claim): (0.405, 0.874, 0.243): 100%|██████████| 1238/1238 [04:10<00:00,  4.94it/s]
Epoch 115, Train Loss(Total, MLM, Entailment, Claim): (1.446, 1.444, 0.001, 1.384), Train Accuracy (MLM, Entailment, Claim): (0.650, 0.992, 0.328), Val Loss: 6.496, Val Accuracy (MLM, Entailment, Claim): (0.410, 0.863, 0.264): 100%|██████████| 1238/1238 [04:03<00:00,  5.08it/s]
Epoch 116, Train Loss(Total, MLM, Entailment, Claim): (1.460, 1.460, 0.001, 1.416), Train Accuracy (MLM, Entailment, Claim): (0.651, 0.993, 0.327), Val Loss: 6.516

Saved BERT model checkpoint!


Epoch 117, Train Loss(Total, MLM, Entailment, Claim): (1.430, 1.421, 0.009, 1.410), Train Accuracy (MLM, Entailment, Claim): (0.651, 0.993, 0.327), Val Loss: 6.572, Val Accuracy (MLM, Entailment, Claim): (0.411, 0.865, 0.261): 100%|██████████| 1238/1238 [04:03<00:00,  5.09it/s]
Epoch 118, Train Loss(Total, MLM, Entailment, Claim): (1.449, 1.442, 0.007, 1.375), Train Accuracy (MLM, Entailment, Claim): (0.651, 0.993, 0.326), Val Loss: 6.405, Val Accuracy (MLM, Entailment, Claim): (0.409, 0.878, 0.238): 100%|██████████| 1238/1238 [04:01<00:00,  5.12it/s]
Epoch 119, Train Loss(Total, MLM, Entailment, Claim): (1.512, 1.499, 0.013, 1.352), Train Accuracy (MLM, Entailment, Claim): (0.652, 0.993, 0.326), Val Loss: 6.439, Val Accuracy (MLM, Entailment, Claim): (0.406, 0.860, 0.320): 100%|██████████| 1238/1238 [03:59<00:00,  5.17it/s]
Epoch 120, Train Loss(Total, MLM, Entailment, Claim): (1.564, 1.556, 0.008, 1.387), Train Accuracy (MLM, Entailment, Claim): (0.652, 0.993, 0.326), Val Loss: 6.549

Saved BERT model checkpoint!


Epoch 121, Train Loss(Total, MLM, Entailment, Claim): (1.597, 1.593, 0.005, 1.331), Train Accuracy (MLM, Entailment, Claim): (0.652, 0.993, 0.326), Val Loss: 6.438, Val Accuracy (MLM, Entailment, Claim): (0.403, 0.855, 0.332): 100%|██████████| 1238/1238 [03:59<00:00,  5.17it/s]
Epoch 122, Train Loss(Total, MLM, Entailment, Claim): (1.533, 1.519, 0.014, 1.325), Train Accuracy (MLM, Entailment, Claim): (0.652, 0.993, 0.326), Val Loss: 6.496, Val Accuracy (MLM, Entailment, Claim): (0.399, 0.868, 0.304): 100%|██████████| 1238/1238 [04:01<00:00,  5.13it/s]
Epoch 123, Train Loss(Total, MLM, Entailment, Claim): (1.546, 1.541, 0.005, 1.307), Train Accuracy (MLM, Entailment, Claim): (0.653, 0.993, 0.326), Val Loss: 6.375, Val Accuracy (MLM, Entailment, Claim): (0.407, 0.862, 0.328): 100%|██████████| 1238/1238 [04:05<00:00,  5.05it/s]
Epoch 124, Train Loss(Total, MLM, Entailment, Claim): (1.430, 1.428, 0.002, 1.391), Train Accuracy (MLM, Entailment, Claim): (0.653, 0.993, 0.325), Val Loss: 6.493

Saved BERT model checkpoint!


Epoch 125, Train Loss(Total, MLM, Entailment, Claim): (1.442, 1.439, 0.003, 1.402), Train Accuracy (MLM, Entailment, Claim): (0.653, 0.993, 0.325), Val Loss: 6.525, Val Accuracy (MLM, Entailment, Claim): (0.411, 0.870, 0.291): 100%|██████████| 1238/1238 [04:02<00:00,  5.10it/s]
Epoch 126, Train Loss(Total, MLM, Entailment, Claim): (1.435, 1.434, 0.001, 1.385), Train Accuracy (MLM, Entailment, Claim): (0.654, 0.993, 0.325), Val Loss: 6.511, Val Accuracy (MLM, Entailment, Claim): (0.414, 0.869, 0.294): 100%|██████████| 1238/1238 [04:00<00:00,  5.15it/s]
Epoch 127, Train Loss(Total, MLM, Entailment, Claim): (1.465, 1.463, 0.003, 1.402), Train Accuracy (MLM, Entailment, Claim): (0.654, 0.993, 0.325), Val Loss: 6.506, Val Accuracy (MLM, Entailment, Claim): (0.411, 0.879, 0.279): 100%|██████████| 1238/1238 [04:00<00:00,  5.16it/s]
Epoch 128, Train Loss(Total, MLM, Entailment, Claim): (1.440, 1.439, 0.001, 1.379), Train Accuracy (MLM, Entailment, Claim): (0.655, 0.993, 0.324), Val Loss: 6.563

Saved BERT model checkpoint!


Epoch 129, Train Loss(Total, MLM, Entailment, Claim): (1.457, 1.450, 0.007, 1.424), Train Accuracy (MLM, Entailment, Claim): (0.655, 0.993, 0.324), Val Loss: 6.643, Val Accuracy (MLM, Entailment, Claim): (0.411, 0.849, 0.243): 100%|██████████| 1238/1238 [04:00<00:00,  5.16it/s]
Epoch 130, Train Loss(Total, MLM, Entailment, Claim): (1.504, 1.493, 0.012, 1.332), Train Accuracy (MLM, Entailment, Claim): (0.655, 0.993, 0.324), Val Loss: 6.437, Val Accuracy (MLM, Entailment, Claim): (0.404, 0.863, 0.358): 100%|██████████| 1238/1238 [03:59<00:00,  5.17it/s]
Epoch 131, Train Loss(Total, MLM, Entailment, Claim): (1.581, 1.566, 0.015, 1.292), Train Accuracy (MLM, Entailment, Claim): (0.655, 0.993, 0.324), Val Loss: 6.399, Val Accuracy (MLM, Entailment, Claim): (0.397, 0.875, 0.359): 100%|██████████| 1238/1238 [03:59<00:00,  5.17it/s]
Epoch 132, Train Loss(Total, MLM, Entailment, Claim): (1.521, 1.509, 0.012, 1.346), Train Accuracy (MLM, Entailment, Claim): (0.656, 0.993, 0.325), Val Loss: 6.335

Saved BERT model checkpoint!


Epoch 133, Train Loss(Total, MLM, Entailment, Claim): (1.449, 1.443, 0.005, 1.306), Train Accuracy (MLM, Entailment, Claim): (0.656, 0.993, 0.325), Val Loss: 6.371, Val Accuracy (MLM, Entailment, Claim): (0.409, 0.871, 0.396): 100%|██████████| 1238/1238 [03:59<00:00,  5.17it/s]
Epoch 134, Train Loss(Total, MLM, Entailment, Claim): (1.434, 1.432, 0.002, 1.294), Train Accuracy (MLM, Entailment, Claim): (0.656, 0.993, 0.325), Val Loss: 6.316, Val Accuracy (MLM, Entailment, Claim): (0.414, 0.890, 0.436): 100%|██████████| 1238/1238 [03:59<00:00,  5.16it/s]
Epoch 135, Train Loss(Total, MLM, Entailment, Claim): (1.385, 1.380, 0.005, 1.315), Train Accuracy (MLM, Entailment, Claim): (0.657, 0.993, 0.326), Val Loss: 6.427, Val Accuracy (MLM, Entailment, Claim): (0.413, 0.877, 0.422): 100%|██████████| 1238/1238 [03:58<00:00,  5.19it/s]
Epoch 136, Train Loss(Total, MLM, Entailment, Claim): (1.345, 1.341, 0.004, 1.307), Train Accuracy (MLM, Entailment, Claim): (0.657, 0.993, 0.326), Val Loss: 6.409

Saved BERT model checkpoint!


Epoch 137, Train Loss(Total, MLM, Entailment, Claim): (1.400, 1.399, 0.001, 1.307), Train Accuracy (MLM, Entailment, Claim): (0.657, 0.993, 0.326), Val Loss: 6.406, Val Accuracy (MLM, Entailment, Claim): (0.413, 0.875, 0.398): 100%|██████████| 1238/1238 [03:58<00:00,  5.18it/s]
Epoch 138, Train Loss(Total, MLM, Entailment, Claim): (1.340, 1.339, 0.001, 1.285), Train Accuracy (MLM, Entailment, Claim): (0.658, 0.993, 0.327), Val Loss: 6.345, Val Accuracy (MLM, Entailment, Claim): (0.415, 0.885, 0.419): 100%|██████████| 1238/1238 [03:59<00:00,  5.17it/s]
Epoch 139, Train Loss(Total, MLM, Entailment, Claim): (1.429, 1.411, 0.018, 1.258), Train Accuracy (MLM, Entailment, Claim): (0.658, 0.993, 0.327), Val Loss: 6.593, Val Accuracy (MLM, Entailment, Claim): (0.412, 0.845, 0.401): 100%|██████████| 1238/1238 [03:59<00:00,  5.17it/s]
Epoch 140, Train Loss(Total, MLM, Entailment, Claim): (1.517, 1.500, 0.017, 1.255), Train Accuracy (MLM, Entailment, Claim): (0.658, 0.993, 0.328), Val Loss: 6.520

Saved BERT model checkpoint!


Epoch 141, Train Loss(Total, MLM, Entailment, Claim): (1.555, 1.549, 0.006, 1.287), Train Accuracy (MLM, Entailment, Claim): (0.658, 0.993, 0.329), Val Loss: 6.672, Val Accuracy (MLM, Entailment, Claim): (0.407, 0.862, 0.419): 100%|██████████| 1238/1238 [03:58<00:00,  5.19it/s]
Epoch 142, Train Loss(Total, MLM, Entailment, Claim): (1.554, 1.545, 0.010, 1.227), Train Accuracy (MLM, Entailment, Claim): (0.659, 0.993, 0.329), Val Loss: 6.411, Val Accuracy (MLM, Entailment, Claim): (0.404, 0.849, 0.409): 100%|██████████| 1238/1238 [03:59<00:00,  5.17it/s]
Epoch 143, Train Loss(Total, MLM, Entailment, Claim): (1.474, 1.468, 0.006, 1.271), Train Accuracy (MLM, Entailment, Claim): (0.659, 0.994, 0.330), Val Loss: 6.342, Val Accuracy (MLM, Entailment, Claim): (0.410, 0.865, 0.546): 100%|██████████| 1238/1238 [03:59<00:00,  5.16it/s]
Epoch 144, Train Loss(Total, MLM, Entailment, Claim): (1.307, 1.305, 0.002, 1.278), Train Accuracy (MLM, Entailment, Claim): (0.659, 0.994, 0.331), Val Loss: 6.345

Saved BERT model checkpoint!


Epoch 145, Train Loss(Total, MLM, Entailment, Claim): (1.373, 1.369, 0.004, 1.305), Train Accuracy (MLM, Entailment, Claim): (0.660, 0.994, 0.331), Val Loss: 6.330, Val Accuracy (MLM, Entailment, Claim): (0.414, 0.872, 0.463): 100%|██████████| 1238/1238 [03:58<00:00,  5.19it/s]
Epoch 146, Train Loss(Total, MLM, Entailment, Claim): (1.388, 1.387, 0.002, 1.302), Train Accuracy (MLM, Entailment, Claim): (0.660, 0.994, 0.332), Val Loss: 6.408, Val Accuracy (MLM, Entailment, Claim): (0.415, 0.864, 0.448): 100%|██████████| 1238/1238 [03:58<00:00,  5.18it/s]
Epoch 147, Train Loss(Total, MLM, Entailment, Claim): (1.404, 1.398, 0.005, 1.286), Train Accuracy (MLM, Entailment, Claim): (0.660, 0.994, 0.332), Val Loss: 6.411, Val Accuracy (MLM, Entailment, Claim): (0.418, 0.868, 0.436): 100%|██████████| 1238/1238 [03:57<00:00,  5.22it/s]
Epoch 148, Train Loss(Total, MLM, Entailment, Claim): (1.320, 1.316, 0.004, 1.254), Train Accuracy (MLM, Entailment, Claim): (0.661, 0.994, 0.333), Val Loss: 6.366

Saved BERT model checkpoint!


Epoch 149, Train Loss(Total, MLM, Entailment, Claim): (1.385, 1.379, 0.006, 1.284), Train Accuracy (MLM, Entailment, Claim): (0.661, 0.994, 0.333), Val Loss: 6.407, Val Accuracy (MLM, Entailment, Claim): (0.415, 0.865, 0.432): 100%|██████████| 1238/1238 [03:58<00:00,  5.19it/s]
Epoch 150, Train Loss(Total, MLM, Entailment, Claim): (1.476, 1.471, 0.006, 1.282), Train Accuracy (MLM, Entailment, Claim): (0.661, 0.994, 0.333), Val Loss: 6.417, Val Accuracy (MLM, Entailment, Claim): (0.404, 0.869, 0.405): 100%|██████████| 1238/1238 [03:58<00:00,  5.20it/s]
Epoch 151, Train Loss(Total, MLM, Entailment, Claim): (1.488, 1.474, 0.013, 1.364), Train Accuracy (MLM, Entailment, Claim): (0.661, 0.994, 0.334), Val Loss: 6.423, Val Accuracy (MLM, Entailment, Claim): (0.405, 0.870, 0.388): 100%|██████████| 1238/1238 [03:58<00:00,  5.18it/s]
Epoch 152, Train Loss(Total, MLM, Entailment, Claim): (1.447, 1.436, 0.011, 1.357), Train Accuracy (MLM, Entailment, Claim): (0.662, 0.994, 0.334), Val Loss: 6.437

Saved BERT model checkpoint!


Epoch 153, Train Loss(Total, MLM, Entailment, Claim): (1.383, 1.376, 0.007, 1.374), Train Accuracy (MLM, Entailment, Claim): (0.662, 0.994, 0.333), Val Loss: 6.494, Val Accuracy (MLM, Entailment, Claim): (0.408, 0.860, 0.272): 100%|██████████| 1238/1238 [03:58<00:00,  5.20it/s]
Epoch 154, Train Loss(Total, MLM, Entailment, Claim): (1.390, 1.389, 0.001, 1.402), Train Accuracy (MLM, Entailment, Claim): (0.662, 0.994, 0.333), Val Loss: 6.367, Val Accuracy (MLM, Entailment, Claim): (0.417, 0.879, 0.264): 100%|██████████| 1238/1238 [03:59<00:00,  5.18it/s]
Epoch 155, Train Loss(Total, MLM, Entailment, Claim): (1.277, 1.273, 0.003, 1.389), Train Accuracy (MLM, Entailment, Claim): (0.662, 0.994, 0.333), Val Loss: 6.562, Val Accuracy (MLM, Entailment, Claim): (0.412, 0.872, 0.255): 100%|██████████| 1238/1238 [03:58<00:00,  5.20it/s]
Epoch 156, Train Loss(Total, MLM, Entailment, Claim): (1.339, 1.333, 0.007, 1.443), Train Accuracy (MLM, Entailment, Claim): (0.663, 0.994, 0.332), Val Loss: 6.616

Saved BERT model checkpoint!


Epoch 157, Train Loss(Total, MLM, Entailment, Claim): (1.332, 1.331, 0.001, 1.413), Train Accuracy (MLM, Entailment, Claim): (0.663, 0.994, 0.332), Val Loss: 6.640, Val Accuracy (MLM, Entailment, Claim): (0.410, 0.864, 0.242): 100%|██████████| 1238/1238 [03:59<00:00,  5.17it/s]
Epoch 158, Train Loss(Total, MLM, Entailment, Claim): (1.329, 1.328, 0.001, 1.450), Train Accuracy (MLM, Entailment, Claim): (0.663, 0.994, 0.332), Val Loss: 6.531, Val Accuracy (MLM, Entailment, Claim): (0.416, 0.876, 0.282): 100%|██████████| 1238/1238 [03:59<00:00,  5.18it/s]
Epoch 159, Train Loss(Total, MLM, Entailment, Claim): (1.374, 1.371, 0.003, 1.428), Train Accuracy (MLM, Entailment, Claim): (0.664, 0.994, 0.331), Val Loss: 6.514, Val Accuracy (MLM, Entailment, Claim): (0.415, 0.871, 0.270): 100%|██████████| 1238/1238 [03:57<00:00,  5.21it/s]
Epoch 160, Train Loss(Total, MLM, Entailment, Claim): (1.428, 1.423, 0.004, 1.396), Train Accuracy (MLM, Entailment, Claim): (0.664, 0.994, 0.331), Val Loss: 6.551

Saved BERT model checkpoint!


Epoch 161, Train Loss(Total, MLM, Entailment, Claim): (1.483, 1.468, 0.015, 1.276), Train Accuracy (MLM, Entailment, Claim): (0.664, 0.994, 0.331), Val Loss: 6.153, Val Accuracy (MLM, Entailment, Claim): (0.406, 0.875, 0.565): 100%|██████████| 1238/1238 [03:57<00:00,  5.21it/s]
Epoch 162, Train Loss(Total, MLM, Entailment, Claim): (1.388, 1.385, 0.003, 1.308), Train Accuracy (MLM, Entailment, Claim): (0.664, 0.994, 0.331), Val Loss: 6.345, Val Accuracy (MLM, Entailment, Claim): (0.406, 0.865, 0.368): 100%|██████████| 1238/1238 [03:56<00:00,  5.23it/s]
Epoch 163, Train Loss(Total, MLM, Entailment, Claim): (1.403, 1.395, 0.008, 1.341), Train Accuracy (MLM, Entailment, Claim): (0.665, 0.994, 0.331), Val Loss: 6.349, Val Accuracy (MLM, Entailment, Claim): (0.409, 0.869, 0.394): 100%|██████████| 1238/1238 [03:58<00:00,  5.19it/s]
Epoch 164, Train Loss(Total, MLM, Entailment, Claim): (1.365, 1.361, 0.004, 1.390), Train Accuracy (MLM, Entailment, Claim): (0.665, 0.994, 0.331), Val Loss: 6.359

Saved BERT model checkpoint!


Epoch 165, Train Loss(Total, MLM, Entailment, Claim): (1.318, 1.317, 0.001, 1.397), Train Accuracy (MLM, Entailment, Claim): (0.665, 0.994, 0.331), Val Loss: 6.585, Val Accuracy (MLM, Entailment, Claim): (0.417, 0.872, 0.315): 100%|██████████| 1238/1238 [03:58<00:00,  5.20it/s]
Epoch 166, Train Loss(Total, MLM, Entailment, Claim): (1.317, 1.316, 0.001, 1.342), Train Accuracy (MLM, Entailment, Claim): (0.665, 0.994, 0.331), Val Loss: 6.449, Val Accuracy (MLM, Entailment, Claim): (0.419, 0.877, 0.314): 100%|██████████| 1238/1238 [03:58<00:00,  5.19it/s]
Epoch 167, Train Loss(Total, MLM, Entailment, Claim): (1.364, 1.363, 0.001, 1.371), Train Accuracy (MLM, Entailment, Claim): (0.666, 0.994, 0.331), Val Loss: 6.529, Val Accuracy (MLM, Entailment, Claim): (0.413, 0.876, 0.323): 100%|██████████| 1238/1238 [03:58<00:00,  5.18it/s]
Epoch 168, Train Loss(Total, MLM, Entailment, Claim): (1.302, 1.301, 0.000, 1.403), Train Accuracy (MLM, Entailment, Claim): (0.666, 0.994, 0.331), Val Loss: 6.467

Saved BERT model checkpoint!


Epoch 169, Train Loss(Total, MLM, Entailment, Claim): (1.307, 1.304, 0.003, 1.354), Train Accuracy (MLM, Entailment, Claim): (0.666, 0.994, 0.330), Val Loss: 6.526, Val Accuracy (MLM, Entailment, Claim): (0.416, 0.877, 0.281): 100%|██████████| 1238/1238 [03:57<00:00,  5.20it/s]
Epoch 170, Train Loss(Total, MLM, Entailment, Claim): (1.411, 1.410, 0.001, 1.349), Train Accuracy (MLM, Entailment, Claim): (0.667, 0.994, 0.330), Val Loss: 6.349, Val Accuracy (MLM, Entailment, Claim): (0.413, 0.875, 0.262): 100%|██████████| 1238/1238 [03:58<00:00,  5.20it/s]
Epoch 171, Train Loss(Total, MLM, Entailment, Claim): (1.452, 1.437, 0.015, 1.328), Train Accuracy (MLM, Entailment, Claim): (0.667, 0.994, 0.330), Val Loss: 6.631, Val Accuracy (MLM, Entailment, Claim): (0.407, 0.872, 0.263): 100%|██████████| 1238/1238 [03:57<00:00,  5.21it/s]
Epoch 172, Train Loss(Total, MLM, Entailment, Claim): (1.408, 1.392, 0.016, 1.284), Train Accuracy (MLM, Entailment, Claim): (0.667, 0.994, 0.330), Val Loss: 6.402

Saved BERT model checkpoint!


Epoch 173, Train Loss(Total, MLM, Entailment, Claim): (1.320, 1.312, 0.008, 1.292), Train Accuracy (MLM, Entailment, Claim): (0.667, 0.994, 0.331), Val Loss: 6.347, Val Accuracy (MLM, Entailment, Claim): (0.411, 0.863, 0.423): 100%|██████████| 1238/1238 [03:57<00:00,  5.20it/s]
Epoch 174, Train Loss(Total, MLM, Entailment, Claim): (1.322, 1.318, 0.004, 1.317), Train Accuracy (MLM, Entailment, Claim): (0.667, 0.994, 0.331), Val Loss: 6.393, Val Accuracy (MLM, Entailment, Claim): (0.412, 0.885, 0.441): 100%|██████████| 1238/1238 [03:56<00:00,  5.23it/s]
Epoch 175, Train Loss(Total, MLM, Entailment, Claim): (1.275, 1.274, 0.001, 1.297), Train Accuracy (MLM, Entailment, Claim): (0.668, 0.994, 0.331), Val Loss: 6.473, Val Accuracy (MLM, Entailment, Claim): (0.420, 0.865, 0.429): 100%|██████████| 1238/1238 [03:57<00:00,  5.21it/s]
Epoch 176, Train Loss(Total, MLM, Entailment, Claim): (1.308, 1.305, 0.003, 1.281), Train Accuracy (MLM, Entailment, Claim): (0.668, 0.994, 0.332), Val Loss: 6.451

Saved BERT model checkpoint!


Epoch 177, Train Loss(Total, MLM, Entailment, Claim): (1.256, 1.255, 0.001, 1.309), Train Accuracy (MLM, Entailment, Claim): (0.668, 0.994, 0.332), Val Loss: 6.466, Val Accuracy (MLM, Entailment, Claim): (0.421, 0.867, 0.420): 100%|██████████| 1238/1238 [03:57<00:00,  5.22it/s]
Epoch 178, Train Loss(Total, MLM, Entailment, Claim): (1.223, 1.219, 0.004, 1.303), Train Accuracy (MLM, Entailment, Claim): (0.669, 0.994, 0.332), Val Loss: 6.324, Val Accuracy (MLM, Entailment, Claim): (0.417, 0.885, 0.443): 100%|██████████| 1238/1238 [03:56<00:00,  5.24it/s]
Epoch 179, Train Loss(Total, MLM, Entailment, Claim): (1.269, 1.268, 0.001, 1.350), Train Accuracy (MLM, Entailment, Claim): (0.669, 0.994, 0.332), Val Loss: 6.447, Val Accuracy (MLM, Entailment, Claim): (0.412, 0.887, 0.324): 100%|██████████| 1238/1238 [03:57<00:00,  5.20it/s]
Epoch 180, Train Loss(Total, MLM, Entailment, Claim): (1.325, 1.319, 0.006, 1.335), Train Accuracy (MLM, Entailment, Claim): (0.669, 0.994, 0.332), Val Loss: 6.537

Saved BERT model checkpoint!


Epoch 181, Train Loss(Total, MLM, Entailment, Claim): (1.467, 1.458, 0.009, 1.360), Train Accuracy (MLM, Entailment, Claim): (0.669, 0.995, 0.332), Val Loss: 6.511, Val Accuracy (MLM, Entailment, Claim): (0.407, 0.864, 0.298): 100%|██████████| 1238/1238 [03:57<00:00,  5.21it/s]
Epoch 182, Train Loss(Total, MLM, Entailment, Claim): (1.403, 1.394, 0.009, 1.311), Train Accuracy (MLM, Entailment, Claim): (0.670, 0.995, 0.332), Val Loss: 6.281, Val Accuracy (MLM, Entailment, Claim): (0.416, 0.870, 0.335): 100%|██████████| 1238/1238 [03:57<00:00,  5.21it/s]
Epoch 183, Train Loss(Total, MLM, Entailment, Claim): (1.363, 1.358, 0.005, 1.344), Train Accuracy (MLM, Entailment, Claim): (0.670, 0.995, 0.332), Val Loss: 6.565, Val Accuracy (MLM, Entailment, Claim): (0.411, 0.873, 0.336): 100%|██████████| 1238/1238 [03:57<00:00,  5.21it/s]
Epoch 184, Train Loss(Total, MLM, Entailment, Claim): (1.243, 1.242, 0.001, 1.389), Train Accuracy (MLM, Entailment, Claim): (0.670, 0.995, 0.332), Val Loss: 6.505

Saved BERT model checkpoint!


Epoch 185, Train Loss(Total, MLM, Entailment, Claim): (1.255, 1.254, 0.002, 1.375), Train Accuracy (MLM, Entailment, Claim): (0.670, 0.995, 0.332), Val Loss: 6.551, Val Accuracy (MLM, Entailment, Claim): (0.417, 0.879, 0.290): 100%|██████████| 1238/1238 [03:57<00:00,  5.22it/s]
Epoch 186, Train Loss(Total, MLM, Entailment, Claim): (1.295, 1.294, 0.001, 1.375), Train Accuracy (MLM, Entailment, Claim): (0.671, 0.995, 0.332), Val Loss: 6.533, Val Accuracy (MLM, Entailment, Claim): (0.418, 0.876, 0.301): 100%|██████████| 1238/1238 [03:56<00:00,  5.24it/s]
Epoch 187, Train Loss(Total, MLM, Entailment, Claim): (1.262, 1.261, 0.001, 1.370), Train Accuracy (MLM, Entailment, Claim): (0.671, 0.995, 0.332), Val Loss: 6.504, Val Accuracy (MLM, Entailment, Claim): (0.420, 0.873, 0.278): 100%|██████████| 1238/1238 [04:03<00:00,  5.09it/s]
Epoch 188, Train Loss(Total, MLM, Entailment, Claim): (1.359, 1.351, 0.008, 1.358), Train Accuracy (MLM, Entailment, Claim): (0.671, 0.995, 0.332), Val Loss: 6.515

Saved BERT model checkpoint!


Epoch 189, Train Loss(Total, MLM, Entailment, Claim): (1.330, 1.329, 0.002, 1.418), Train Accuracy (MLM, Entailment, Claim): (0.671, 0.995, 0.331), Val Loss: 6.554, Val Accuracy (MLM, Entailment, Claim): (0.416, 0.875, 0.273): 100%|██████████| 1238/1238 [03:57<00:00,  5.21it/s]
Epoch 190, Train Loss(Total, MLM, Entailment, Claim): (1.364, 1.357, 0.007, 1.359), Train Accuracy (MLM, Entailment, Claim): (0.672, 0.995, 0.331), Val Loss: 6.465, Val Accuracy (MLM, Entailment, Claim): (0.413, 0.876, 0.229): 100%|██████████| 1238/1238 [03:56<00:00,  5.23it/s]
Epoch 191, Train Loss(Total, MLM, Entailment, Claim): (1.390, 1.377, 0.012, 1.344), Train Accuracy (MLM, Entailment, Claim): (0.672, 0.995, 0.331), Val Loss: 6.530, Val Accuracy (MLM, Entailment, Claim): (0.408, 0.857, 0.285): 100%|██████████| 1238/1238 [03:56<00:00,  5.24it/s]
Epoch 192, Train Loss(Total, MLM, Entailment, Claim): (1.373, 1.362, 0.011, 1.359), Train Accuracy (MLM, Entailment, Claim): (0.672, 0.995, 0.331), Val Loss: 6.408

Saved BERT model checkpoint!


Epoch 193, Train Loss(Total, MLM, Entailment, Claim): (1.332, 1.326, 0.007, 1.387), Train Accuracy (MLM, Entailment, Claim): (0.672, 0.995, 0.331), Val Loss: 6.373, Val Accuracy (MLM, Entailment, Claim): (0.422, 0.871, 0.293): 100%|██████████| 1238/1238 [03:55<00:00,  5.25it/s]
Epoch 194, Train Loss(Total, MLM, Entailment, Claim): (1.237, 1.237, 0.000, 1.373), Train Accuracy (MLM, Entailment, Claim): (0.672, 0.995, 0.331), Val Loss: 6.432, Val Accuracy (MLM, Entailment, Claim): (0.420, 0.872, 0.311): 100%|██████████| 1238/1238 [03:56<00:00,  5.23it/s]
Epoch 195, Train Loss(Total, MLM, Entailment, Claim): (1.305, 1.303, 0.002, 1.357), Train Accuracy (MLM, Entailment, Claim): (0.673, 0.995, 0.331), Val Loss: 6.571, Val Accuracy (MLM, Entailment, Claim): (0.419, 0.870, 0.300): 100%|██████████| 1238/1238 [03:56<00:00,  5.23it/s]
Epoch 196, Train Loss(Total, MLM, Entailment, Claim): (1.215, 1.214, 0.001, 1.379), Train Accuracy (MLM, Entailment, Claim): (0.673, 0.995, 0.331), Val Loss: 6.514

Saved BERT model checkpoint!


Epoch 197, Train Loss(Total, MLM, Entailment, Claim): (1.237, 1.237, 0.000, 1.346), Train Accuracy (MLM, Entailment, Claim): (0.673, 0.995, 0.331), Val Loss: 6.507, Val Accuracy (MLM, Entailment, Claim): (0.424, 0.867, 0.314): 100%|██████████| 1238/1238 [03:56<00:00,  5.23it/s]
Epoch 198, Train Loss(Total, MLM, Entailment, Claim): (1.220, 1.218, 0.001, 1.420), Train Accuracy (MLM, Entailment, Claim): (0.674, 0.995, 0.330), Val Loss: 6.529, Val Accuracy (MLM, Entailment, Claim): (0.416, 0.885, 0.293): 100%|██████████| 1238/1238 [03:56<00:00,  5.24it/s]
Epoch 199, Train Loss(Total, MLM, Entailment, Claim): (1.247, 1.246, 0.002, 1.435), Train Accuracy (MLM, Entailment, Claim): (0.674, 0.995, 0.330), Val Loss: 6.753, Val Accuracy (MLM, Entailment, Claim): (0.418, 0.855, 0.229): 100%|██████████| 1238/1238 [03:56<00:00,  5.24it/s]
Epoch 200, Train Loss(Total, MLM, Entailment, Claim): (1.280, 1.277, 0.002, 1.394), Train Accuracy (MLM, Entailment, Claim): (0.674, 0.995, 0.330), Val Loss: 6.395

Saved BERT model checkpoint!
Training done!


In [10]:
train(model, 150, train_dataloader, val_dataloader, optimizer, grad_accumulation_steps=20, val_every=1000, save_every=4, device=device, log_metrics=log_metrics, mixed_precision=True, checkpoint_name="BERT_multitask", scheduler=scheduler, include_claim_loss=False)

Epoch 1, Train Loss(Total, MLM, Entailment, Claim): (1.376, 1.368, 0.008, 1.306), Train Accuracy (MLM, Entailment, Claim): (0.707, 0.997, 0.288), Val Loss: 6.558, Val Accuracy (MLM, Entailment, Claim): (0.413, 0.872, 0.278): 100%|██████████| 1238/1238 [03:54<00:00,  5.27it/s]
Epoch 2, Train Loss(Total, MLM, Entailment, Claim): (1.319, 1.314, 0.005, 1.350), Train Accuracy (MLM, Entailment, Claim): (0.706, 0.997, 0.297), Val Loss: 6.599, Val Accuracy (MLM, Entailment, Claim): (0.413, 0.855, 0.332): 100%|██████████| 1238/1238 [03:55<00:00,  5.27it/s]
Epoch 3, Train Loss(Total, MLM, Entailment, Claim): (1.305, 1.303, 0.002, 1.397), Train Accuracy (MLM, Entailment, Claim): (0.709, 0.997, 0.294), Val Loss: 6.578, Val Accuracy (MLM, Entailment, Claim): (0.417, 0.871, 0.290): 100%|██████████| 1238/1238 [03:57<00:00,  5.22it/s]
Epoch 4, Train Loss(Total, MLM, Entailment, Claim): (1.266, 1.264, 0.002, 1.353), Train Accuracy (MLM, Entailment, Claim): (0.713, 0.998, 0.297), Val Loss: 6.434, Val Ac

Saved BERT model checkpoint!


Epoch 5, Train Loss(Total, MLM, Entailment, Claim): (1.227, 1.227, 0.001, 1.360), Train Accuracy (MLM, Entailment, Claim): (0.716, 0.998, 0.301), Val Loss: 6.525, Val Accuracy (MLM, Entailment, Claim): (0.427, 0.864, 0.303): 100%|██████████| 1238/1238 [03:56<00:00,  5.23it/s]
Epoch 6, Train Loss(Total, MLM, Entailment, Claim): (1.260, 1.259, 0.001, 1.351), Train Accuracy (MLM, Entailment, Claim): (0.719, 0.998, 0.303), Val Loss: 6.553, Val Accuracy (MLM, Entailment, Claim): (0.423, 0.872, 0.302): 100%|██████████| 1238/1238 [03:56<00:00,  5.24it/s]
Epoch 7, Train Loss(Total, MLM, Entailment, Claim): (1.227, 1.223, 0.003, 1.359), Train Accuracy (MLM, Entailment, Claim): (0.721, 0.998, 0.303), Val Loss: 6.562, Val Accuracy (MLM, Entailment, Claim): (0.420, 0.871, 0.293): 100%|██████████| 1238/1238 [03:55<00:00,  5.25it/s]
Epoch 8, Train Loss(Total, MLM, Entailment, Claim): (1.277, 1.275, 0.003, 1.384), Train Accuracy (MLM, Entailment, Claim): (0.722, 0.999, 0.305), Val Loss: 6.542, Val Ac

Saved BERT model checkpoint!


Epoch 9, Train Loss(Total, MLM, Entailment, Claim): (1.271, 1.269, 0.001, 1.434), Train Accuracy (MLM, Entailment, Claim): (0.723, 0.999, 0.303), Val Loss: 6.404, Val Accuracy (MLM, Entailment, Claim): (0.422, 0.878, 0.317): 100%|██████████| 1238/1238 [03:57<00:00,  5.21it/s]
Epoch 10, Train Loss(Total, MLM, Entailment, Claim): (1.298, 1.291, 0.008, 1.375), Train Accuracy (MLM, Entailment, Claim): (0.722, 0.999, 0.300), Val Loss: 6.570, Val Accuracy (MLM, Entailment, Claim): (0.417, 0.864, 0.295): 100%|██████████| 1238/1238 [03:58<00:00,  5.18it/s]
Epoch 11, Train Loss(Total, MLM, Entailment, Claim): (1.347, 1.323, 0.023, 1.426), Train Accuracy (MLM, Entailment, Claim): (0.721, 0.998, 0.298), Val Loss: 6.765, Val Accuracy (MLM, Entailment, Claim): (0.414, 0.860, 0.209): 100%|██████████| 1238/1238 [03:58<00:00,  5.18it/s]
Epoch 12, Train Loss(Total, MLM, Entailment, Claim): (1.347, 1.342, 0.005, 1.438), Train Accuracy (MLM, Entailment, Claim): (0.720, 0.998, 0.294), Val Loss: 6.706, Val

Saved BERT model checkpoint!


Epoch 13, Train Loss(Total, MLM, Entailment, Claim): (1.263, 1.258, 0.005, 1.393), Train Accuracy (MLM, Entailment, Claim): (0.720, 0.998, 0.291), Val Loss: 6.555, Val Accuracy (MLM, Entailment, Claim): (0.421, 0.864, 0.236): 100%|██████████| 1238/1238 [04:00<00:00,  5.15it/s]
Epoch 14, Train Loss(Total, MLM, Entailment, Claim): (1.310, 1.306, 0.004, 1.445), Train Accuracy (MLM, Entailment, Claim): (0.720, 0.998, 0.288), Val Loss: 6.527, Val Accuracy (MLM, Entailment, Claim): (0.427, 0.886, 0.222): 100%|██████████| 1238/1238 [03:56<00:00,  5.23it/s]
Epoch 15, Train Loss(Total, MLM, Entailment, Claim): (1.173, 1.172, 0.000, 1.435), Train Accuracy (MLM, Entailment, Claim): (0.721, 0.998, 0.284), Val Loss: 6.729, Val Accuracy (MLM, Entailment, Claim): (0.425, 0.863, 0.215): 100%|██████████| 1238/1238 [03:58<00:00,  5.20it/s]
Epoch 16, Train Loss(Total, MLM, Entailment, Claim): (1.214, 1.212, 0.003, 1.443), Train Accuracy (MLM, Entailment, Claim): (0.722, 0.998, 0.281), Val Loss: 6.663, Va

Saved BERT model checkpoint!


Epoch 17, Train Loss(Total, MLM, Entailment, Claim): (1.186, 1.186, 0.000, 1.449), Train Accuracy (MLM, Entailment, Claim): (0.723, 0.998, 0.278), Val Loss: 6.719, Val Accuracy (MLM, Entailment, Claim): (0.422, 0.870, 0.197): 100%|██████████| 1238/1238 [03:56<00:00,  5.23it/s]
Epoch 18, Train Loss(Total, MLM, Entailment, Claim): (1.206, 1.206, 0.000, 1.439), Train Accuracy (MLM, Entailment, Claim): (0.723, 0.998, 0.276), Val Loss: 6.650, Val Accuracy (MLM, Entailment, Claim): (0.421, 0.892, 0.209): 100%|██████████| 1238/1238 [03:57<00:00,  5.21it/s]
Epoch 19, Train Loss(Total, MLM, Entailment, Claim): (1.251, 1.250, 0.001, 1.385), Train Accuracy (MLM, Entailment, Claim): (0.724, 0.999, 0.275), Val Loss: 6.701, Val Accuracy (MLM, Entailment, Claim): (0.420, 0.871, 0.260): 100%|██████████| 1238/1238 [03:57<00:00,  5.21it/s]
Epoch 20, Train Loss(Total, MLM, Entailment, Claim): (1.312, 1.310, 0.003, 1.407), Train Accuracy (MLM, Entailment, Claim): (0.724, 0.999, 0.275), Val Loss: 6.668, Va

Saved BERT model checkpoint!


Epoch 21, Train Loss(Total, MLM, Entailment, Claim): (1.353, 1.337, 0.016, 1.413), Train Accuracy (MLM, Entailment, Claim): (0.723, 0.998, 0.272), Val Loss: 6.532, Val Accuracy (MLM, Entailment, Claim): (0.418, 0.866, 0.209): 100%|██████████| 1238/1238 [03:57<00:00,  5.20it/s]
Epoch 22, Train Loss(Total, MLM, Entailment, Claim): (1.333, 1.326, 0.006, 1.459), Train Accuracy (MLM, Entailment, Claim): (0.723, 0.998, 0.271), Val Loss: 6.678, Val Accuracy (MLM, Entailment, Claim): (0.417, 0.854, 0.183): 100%|██████████| 1238/1238 [03:59<00:00,  5.16it/s]
Epoch 23, Train Loss(Total, MLM, Entailment, Claim): (1.226, 1.219, 0.007, 1.479), Train Accuracy (MLM, Entailment, Claim): (0.722, 0.998, 0.268), Val Loss: 6.776, Val Accuracy (MLM, Entailment, Claim): (0.423, 0.851, 0.176): 100%|██████████| 1238/1238 [03:57<00:00,  5.21it/s]
Epoch 24, Train Loss(Total, MLM, Entailment, Claim): (1.245, 1.240, 0.004, 1.467), Train Accuracy (MLM, Entailment, Claim): (0.723, 0.998, 0.266), Val Loss: 6.605, Va

Saved BERT model checkpoint!


Epoch 25, Train Loss(Total, MLM, Entailment, Claim): (1.182, 1.180, 0.001, 1.461), Train Accuracy (MLM, Entailment, Claim): (0.723, 0.998, 0.264), Val Loss: 6.770, Val Accuracy (MLM, Entailment, Claim): (0.423, 0.863, 0.164): 100%|██████████| 1238/1238 [03:59<00:00,  5.17it/s]
Epoch 26, Train Loss(Total, MLM, Entailment, Claim): (1.253, 1.251, 0.002, 1.470), Train Accuracy (MLM, Entailment, Claim): (0.724, 0.998, 0.262), Val Loss: 6.771, Val Accuracy (MLM, Entailment, Claim): (0.422, 0.866, 0.179): 100%|██████████| 1238/1238 [03:57<00:00,  5.21it/s]
Epoch 27, Train Loss(Total, MLM, Entailment, Claim): (1.188, 1.186, 0.002, 1.455), Train Accuracy (MLM, Entailment, Claim): (0.724, 0.999, 0.261), Val Loss: 6.697, Val Accuracy (MLM, Entailment, Claim): (0.426, 0.871, 0.173): 100%|██████████| 1238/1238 [04:00<00:00,  5.16it/s]
Epoch 28, Train Loss(Total, MLM, Entailment, Claim): (1.226, 1.220, 0.006, 1.424), Train Accuracy (MLM, Entailment, Claim): (0.725, 0.999, 0.260), Val Loss: 6.563, Va

Saved BERT model checkpoint!


Epoch 29, Train Loss(Total, MLM, Entailment, Claim): (1.160, 1.159, 0.001, 1.426), Train Accuracy (MLM, Entailment, Claim): (0.725, 0.999, 0.259), Val Loss: 6.609, Val Accuracy (MLM, Entailment, Claim): (0.419, 0.873, 0.231): 100%|██████████| 1238/1238 [03:59<00:00,  5.17it/s]
Epoch 30, Train Loss(Total, MLM, Entailment, Claim): (1.271, 1.262, 0.009, 1.373), Train Accuracy (MLM, Entailment, Claim): (0.725, 0.999, 0.260), Val Loss: 6.736, Val Accuracy (MLM, Entailment, Claim): (0.421, 0.849, 0.243): 100%|██████████| 1238/1238 [03:59<00:00,  5.17it/s]
Epoch 31, Train Loss(Total, MLM, Entailment, Claim): (1.349, 1.323, 0.026, 1.384), Train Accuracy (MLM, Entailment, Claim): (0.725, 0.999, 0.261), Val Loss: 6.641, Val Accuracy (MLM, Entailment, Claim): (0.419, 0.853, 0.260): 100%|██████████| 1238/1238 [03:59<00:00,  5.16it/s]
Epoch 32, Train Loss(Total, MLM, Entailment, Claim): (1.246, 1.242, 0.005, 1.403), Train Accuracy (MLM, Entailment, Claim): (0.725, 0.998, 0.261), Val Loss: 6.469, Va

Saved BERT model checkpoint!


Epoch 33, Train Loss(Total, MLM, Entailment, Claim): (1.200, 1.197, 0.003, 1.452), Train Accuracy (MLM, Entailment, Claim): (0.724, 0.998, 0.261), Val Loss: 6.526, Val Accuracy (MLM, Entailment, Claim): (0.420, 0.863, 0.232): 100%|██████████| 1238/1238 [03:58<00:00,  5.19it/s]
Epoch 34, Train Loss(Total, MLM, Entailment, Claim): (1.190, 1.182, 0.008, 1.407), Train Accuracy (MLM, Entailment, Claim): (0.725, 0.998, 0.261), Val Loss: 6.611, Val Accuracy (MLM, Entailment, Claim): (0.418, 0.887, 0.242): 100%|██████████| 1238/1238 [03:59<00:00,  5.17it/s]
Epoch 35, Train Loss(Total, MLM, Entailment, Claim): (1.205, 1.205, 0.001, 1.412), Train Accuracy (MLM, Entailment, Claim): (0.725, 0.998, 0.261), Val Loss: 6.719, Val Accuracy (MLM, Entailment, Claim): (0.423, 0.870, 0.229): 100%|██████████| 1238/1238 [03:57<00:00,  5.21it/s]
Epoch 36, Train Loss(Total, MLM, Entailment, Claim): (1.118, 1.117, 0.002, 1.395), Train Accuracy (MLM, Entailment, Claim): (0.725, 0.999, 0.261), Val Loss: 6.673, Va

Saved BERT model checkpoint!


Epoch 37, Train Loss(Total, MLM, Entailment, Claim): (1.141, 1.140, 0.001, 1.416), Train Accuracy (MLM, Entailment, Claim): (0.726, 0.999, 0.261), Val Loss: 6.697, Val Accuracy (MLM, Entailment, Claim): (0.424, 0.863, 0.232): 100%|██████████| 1238/1238 [03:59<00:00,  5.17it/s]
Epoch 38, Train Loss(Total, MLM, Entailment, Claim): (1.181, 1.179, 0.002, 1.432), Train Accuracy (MLM, Entailment, Claim): (0.726, 0.999, 0.262), Val Loss: 6.649, Val Accuracy (MLM, Entailment, Claim): (0.421, 0.888, 0.223): 100%|██████████| 1238/1238 [03:58<00:00,  5.19it/s]
Epoch 39, Train Loss(Total, MLM, Entailment, Claim): (1.168, 1.166, 0.001, 1.412), Train Accuracy (MLM, Entailment, Claim): (0.726, 0.999, 0.261), Val Loss: 6.732, Val Accuracy (MLM, Entailment, Claim): (0.428, 0.857, 0.255): 100%|██████████| 1238/1238 [03:58<00:00,  5.19it/s]
Epoch 40, Train Loss(Total, MLM, Entailment, Claim): (1.289, 1.276, 0.013, 1.381), Train Accuracy (MLM, Entailment, Claim): (0.726, 0.999, 0.261), Val Loss: 6.731, Va

Saved BERT model checkpoint!


Epoch 41, Train Loss(Total, MLM, Entailment, Claim): (1.305, 1.300, 0.005, 1.413), Train Accuracy (MLM, Entailment, Claim): (0.726, 0.999, 0.261), Val Loss: 6.655, Val Accuracy (MLM, Entailment, Claim): (0.417, 0.851, 0.253): 100%|██████████| 1238/1238 [03:59<00:00,  5.18it/s]
Epoch 42, Train Loss(Total, MLM, Entailment, Claim): (1.230, 1.227, 0.003, 1.419), Train Accuracy (MLM, Entailment, Claim): (0.726, 0.998, 0.262), Val Loss: 6.378, Val Accuracy (MLM, Entailment, Claim): (0.418, 0.884, 0.304): 100%|██████████| 1238/1238 [04:00<00:00,  5.15it/s]
Epoch 43, Train Loss(Total, MLM, Entailment, Claim): (1.203, 1.201, 0.002, 1.387), Train Accuracy (MLM, Entailment, Claim): (0.726, 0.998, 0.262), Val Loss: 6.593, Val Accuracy (MLM, Entailment, Claim): (0.419, 0.858, 0.256): 100%|██████████| 1238/1238 [03:59<00:00,  5.17it/s]
Epoch 44, Train Loss(Total, MLM, Entailment, Claim): (1.152, 1.151, 0.002, 1.456), Train Accuracy (MLM, Entailment, Claim): (0.726, 0.999, 0.262), Val Loss: 6.572, Va

Saved BERT model checkpoint!


Epoch 45, Train Loss(Total, MLM, Entailment, Claim): (1.136, 1.135, 0.001, 1.429), Train Accuracy (MLM, Entailment, Claim): (0.727, 0.999, 0.262), Val Loss: 6.747, Val Accuracy (MLM, Entailment, Claim): (0.425, 0.861, 0.223): 100%|██████████| 1238/1238 [03:59<00:00,  5.16it/s]
Epoch 46, Train Loss(Total, MLM, Entailment, Claim): (1.186, 1.183, 0.003, 1.411), Train Accuracy (MLM, Entailment, Claim): (0.727, 0.999, 0.261), Val Loss: 6.734, Val Accuracy (MLM, Entailment, Claim): (0.423, 0.874, 0.225): 100%|██████████| 1238/1238 [03:59<00:00,  5.18it/s]
Epoch 47, Train Loss(Total, MLM, Entailment, Claim): (1.130, 1.129, 0.001, 1.381), Train Accuracy (MLM, Entailment, Claim): (0.727, 0.999, 0.261), Val Loss: 6.702, Val Accuracy (MLM, Entailment, Claim): (0.424, 0.875, 0.237): 100%|██████████| 1238/1238 [03:59<00:00,  5.17it/s]
Epoch 48, Train Loss(Total, MLM, Entailment, Claim): (1.185, 1.185, 0.000, 1.413), Train Accuracy (MLM, Entailment, Claim): (0.728, 0.999, 0.261), Val Loss: 6.641, Va

Saved BERT model checkpoint!


Epoch 49, Train Loss(Total, MLM, Entailment, Claim): (1.198, 1.197, 0.001, 1.384), Train Accuracy (MLM, Entailment, Claim): (0.728, 0.999, 0.262), Val Loss: 6.653, Val Accuracy (MLM, Entailment, Claim): (0.425, 0.868, 0.292): 100%|██████████| 1238/1238 [04:04<00:00,  5.06it/s]
Epoch 50, Train Loss(Total, MLM, Entailment, Claim): (1.237, 1.225, 0.012, 1.443), Train Accuracy (MLM, Entailment, Claim): (0.728, 0.999, 0.262), Val Loss: 6.803, Val Accuracy (MLM, Entailment, Claim): (0.419, 0.859, 0.224): 100%|██████████| 1238/1238 [03:59<00:00,  5.17it/s]
Epoch 51, Train Loss(Total, MLM, Entailment, Claim): (1.270, 1.266, 0.004, 1.382), Train Accuracy (MLM, Entailment, Claim): (0.728, 0.999, 0.263), Val Loss: 6.688, Val Accuracy (MLM, Entailment, Claim): (0.419, 0.850, 0.268): 100%|██████████| 1238/1238 [03:59<00:00,  5.17it/s]
Epoch 52, Train Loss(Total, MLM, Entailment, Claim): (1.285, 1.267, 0.018, 1.435), Train Accuracy (MLM, Entailment, Claim): (0.728, 0.999, 0.263), Val Loss: 6.699, Va

Saved BERT model checkpoint!


Epoch 53, Train Loss(Total, MLM, Entailment, Claim): (1.241, 1.238, 0.003, 1.404), Train Accuracy (MLM, Entailment, Claim): (0.727, 0.999, 0.263), Val Loss: 6.755, Val Accuracy (MLM, Entailment, Claim): (0.422, 0.869, 0.260): 100%|██████████| 1238/1238 [04:04<00:00,  5.07it/s]
Epoch 54, Train Loss(Total, MLM, Entailment, Claim): (1.148, 1.148, 0.000, 1.398), Train Accuracy (MLM, Entailment, Claim): (0.728, 0.999, 0.263), Val Loss: 6.681, Val Accuracy (MLM, Entailment, Claim): (0.425, 0.872, 0.279): 100%|██████████| 1238/1238 [04:04<00:00,  5.06it/s]
Epoch 55, Train Loss(Total, MLM, Entailment, Claim): (1.169, 1.166, 0.003, 1.439), Train Accuracy (MLM, Entailment, Claim): (0.728, 0.999, 0.263), Val Loss: 6.777, Val Accuracy (MLM, Entailment, Claim): (0.427, 0.865, 0.267): 100%|██████████| 1238/1238 [04:00<00:00,  5.15it/s]
Epoch 56, Train Loss(Total, MLM, Entailment, Claim): (1.122, 1.119, 0.003, 1.426), Train Accuracy (MLM, Entailment, Claim): (0.728, 0.999, 0.263), Val Loss: 6.742, Va

Saved BERT model checkpoint!


Epoch 57, Train Loss(Total, MLM, Entailment, Claim): (1.161, 1.161, 0.000, 1.427), Train Accuracy (MLM, Entailment, Claim): (0.729, 0.999, 0.264), Val Loss: 6.696, Val Accuracy (MLM, Entailment, Claim): (0.426, 0.872, 0.258): 100%|██████████| 1238/1238 [04:15<00:00,  4.85it/s]
Epoch 58, Train Loss(Total, MLM, Entailment, Claim): (1.123, 1.122, 0.001, 1.457), Train Accuracy (MLM, Entailment, Claim): (0.729, 0.999, 0.264), Val Loss: 6.644, Val Accuracy (MLM, Entailment, Claim): (0.428, 0.874, 0.249): 100%|██████████| 1238/1238 [04:13<00:00,  4.89it/s]
Epoch 59, Train Loss(Total, MLM, Entailment, Claim): (1.194, 1.192, 0.003, 1.403), Train Accuracy (MLM, Entailment, Claim): (0.729, 0.999, 0.264), Val Loss: 6.709, Val Accuracy (MLM, Entailment, Claim): (0.424, 0.865, 0.253): 100%|██████████| 1238/1238 [04:19<00:00,  4.78it/s]
Epoch 60, Train Loss(Total, MLM, Entailment, Claim): (1.263, 1.261, 0.002, 1.410), Train Accuracy (MLM, Entailment, Claim): (0.729, 0.999, 0.264), Val Loss: 6.593, Va

Saved BERT model checkpoint!


Epoch 61, Train Loss(Total, MLM, Entailment, Claim): (1.221, 1.193, 0.029, 1.415), Train Accuracy (MLM, Entailment, Claim): (0.729, 0.999, 0.264), Val Loss: 6.627, Val Accuracy (MLM, Entailment, Claim): (0.412, 0.874, 0.270): 100%|██████████| 1238/1238 [04:21<00:00,  4.74it/s]
Epoch 62, Train Loss(Total, MLM, Entailment, Claim): (1.232, 1.224, 0.007, 1.332), Train Accuracy (MLM, Entailment, Claim): (0.729, 0.999, 0.264), Val Loss: 6.387, Val Accuracy (MLM, Entailment, Claim): (0.421, 0.872, 0.314): 100%|██████████| 1238/1238 [04:18<00:00,  4.79it/s]
Epoch 63, Train Loss(Total, MLM, Entailment, Claim): (1.199, 1.192, 0.007, 1.341), Train Accuracy (MLM, Entailment, Claim): (0.729, 0.999, 0.264), Val Loss: 6.661, Val Accuracy (MLM, Entailment, Claim): (0.419, 0.875, 0.303): 100%|██████████| 1238/1238 [04:13<00:00,  4.88it/s]
Epoch 64, Train Loss(Total, MLM, Entailment, Claim): (1.141, 1.140, 0.001, 1.354), Train Accuracy (MLM, Entailment, Claim): (0.729, 0.999, 0.265), Val Loss: 6.474, Va

Saved BERT model checkpoint!


Epoch 65, Train Loss(Total, MLM, Entailment, Claim): (1.188, 1.188, 0.000, 1.362), Train Accuracy (MLM, Entailment, Claim): (0.729, 0.999, 0.265), Val Loss: 6.686, Val Accuracy (MLM, Entailment, Claim): (0.427, 0.870, 0.289): 100%|██████████| 1238/1238 [04:21<00:00,  4.73it/s]
Epoch 66, Train Loss(Total, MLM, Entailment, Claim): (1.106, 1.105, 0.001, 1.365), Train Accuracy (MLM, Entailment, Claim): (0.730, 0.999, 0.266), Val Loss: 6.663, Val Accuracy (MLM, Entailment, Claim): (0.423, 0.872, 0.295): 100%|██████████| 1238/1238 [04:17<00:00,  4.81it/s]
Epoch 67, Train Loss(Total, MLM, Entailment, Claim): (1.086, 1.086, 0.000, 1.365), Train Accuracy (MLM, Entailment, Claim): (0.730, 0.999, 0.267), Val Loss: 6.679, Val Accuracy (MLM, Entailment, Claim): (0.423, 0.874, 0.303): 100%|██████████| 1238/1238 [04:15<00:00,  4.84it/s]
Epoch 68, Train Loss(Total, MLM, Entailment, Claim): (1.164, 1.162, 0.002, 1.332), Train Accuracy (MLM, Entailment, Claim): (0.730, 0.999, 0.267), Val Loss: 6.611, Va

Saved BERT model checkpoint!


Epoch 69, Train Loss(Total, MLM, Entailment, Claim): (1.210, 1.209, 0.002, 1.356), Train Accuracy (MLM, Entailment, Claim): (0.730, 0.999, 0.268), Val Loss: 6.528, Val Accuracy (MLM, Entailment, Claim): (0.424, 0.876, 0.319): 100%|██████████| 1238/1238 [04:16<00:00,  4.83it/s]
Epoch 70, Train Loss(Total, MLM, Entailment, Claim): (1.256, 1.243, 0.012, 1.387), Train Accuracy (MLM, Entailment, Claim): (0.730, 0.999, 0.268), Val Loss: 6.595, Val Accuracy (MLM, Entailment, Claim): (0.425, 0.877, 0.284): 100%|██████████| 1238/1238 [04:11<00:00,  4.91it/s]
Epoch 71, Train Loss(Total, MLM, Entailment, Claim): (1.239, 1.233, 0.005, 1.377), Train Accuracy (MLM, Entailment, Claim): (0.730, 0.999, 0.269), Val Loss: 6.583, Val Accuracy (MLM, Entailment, Claim): (0.418, 0.870, 0.365): 100%|██████████| 1238/1238 [04:06<00:00,  5.03it/s]
Epoch 72, Train Loss(Total, MLM, Entailment, Claim): (1.188, 1.184, 0.004, 1.380), Train Accuracy (MLM, Entailment, Claim): (0.730, 0.999, 0.269), Val Loss: 6.484, Va

Saved BERT model checkpoint!


Epoch 73, Train Loss(Total, MLM, Entailment, Claim): (1.162, 1.160, 0.002, 1.360), Train Accuracy (MLM, Entailment, Claim): (0.730, 0.999, 0.270), Val Loss: 6.490, Val Accuracy (MLM, Entailment, Claim): (0.426, 0.863, 0.286): 100%|██████████| 1238/1238 [04:08<00:00,  4.99it/s]
Epoch 74, Train Loss(Total, MLM, Entailment, Claim): (1.072, 1.071, 0.000, 1.330), Train Accuracy (MLM, Entailment, Claim): (0.731, 0.999, 0.270), Val Loss: 6.549, Val Accuracy (MLM, Entailment, Claim): (0.421, 0.875, 0.386): 100%|██████████| 1238/1238 [04:08<00:00,  4.98it/s]
Epoch 75, Train Loss(Total, MLM, Entailment, Claim): (1.110, 1.109, 0.002, 1.320), Train Accuracy (MLM, Entailment, Claim): (0.731, 0.999, 0.271), Val Loss: 6.589, Val Accuracy (MLM, Entailment, Claim): (0.430, 0.873, 0.372): 100%|██████████| 1238/1238 [04:04<00:00,  5.06it/s]
Epoch 76, Train Loss(Total, MLM, Entailment, Claim): (1.092, 1.091, 0.001, 1.327), Train Accuracy (MLM, Entailment, Claim): (0.731, 0.999, 0.272), Val Loss: 6.527, Va

Saved BERT model checkpoint!


Epoch 77, Train Loss(Total, MLM, Entailment, Claim): (1.124, 1.113, 0.011, 1.363), Train Accuracy (MLM, Entailment, Claim): (0.731, 0.999, 0.273), Val Loss: 6.637, Val Accuracy (MLM, Entailment, Claim): (0.423, 0.876, 0.342): 100%|██████████| 1238/1238 [04:02<00:00,  5.11it/s]
Epoch 78, Train Loss(Total, MLM, Entailment, Claim): (1.090, 1.090, 0.000, 1.326), Train Accuracy (MLM, Entailment, Claim): (0.732, 0.999, 0.274), Val Loss: 6.768, Val Accuracy (MLM, Entailment, Claim): (0.425, 0.869, 0.370): 100%|██████████| 1238/1238 [04:07<00:00,  5.00it/s]
Epoch 79, Train Loss(Total, MLM, Entailment, Claim): (1.101, 1.101, 0.000, 1.336), Train Accuracy (MLM, Entailment, Claim): (0.732, 0.999, 0.275), Val Loss: 6.524, Val Accuracy (MLM, Entailment, Claim): (0.427, 0.895, 0.348): 100%|██████████| 1238/1238 [04:19<00:00,  4.78it/s]
Epoch 80, Train Loss(Total, MLM, Entailment, Claim): (1.165, 1.151, 0.015, 1.358), Train Accuracy (MLM, Entailment, Claim): (0.732, 0.999, 0.275), Val Loss: 6.525, Va

Saved BERT model checkpoint!


Epoch 81, Train Loss(Total, MLM, Entailment, Claim): (1.200, 1.197, 0.002, 1.323), Train Accuracy (MLM, Entailment, Claim): (0.732, 0.999, 0.276), Val Loss: 6.815, Val Accuracy (MLM, Entailment, Claim): (0.422, 0.856, 0.268): 100%|██████████| 1238/1238 [04:08<00:00,  4.97it/s]
Epoch 82, Train Loss(Total, MLM, Entailment, Claim): (1.220, 1.217, 0.003, 1.397), Train Accuracy (MLM, Entailment, Claim): (0.732, 0.999, 0.276), Val Loss: 6.546, Val Accuracy (MLM, Entailment, Claim): (0.424, 0.859, 0.326): 100%|██████████| 1238/1238 [04:05<00:00,  5.05it/s]
Epoch 83, Train Loss(Total, MLM, Entailment, Claim): (1.169, 1.166, 0.003, 1.353), Train Accuracy (MLM, Entailment, Claim): (0.732, 0.999, 0.276), Val Loss: 6.637, Val Accuracy (MLM, Entailment, Claim): (0.427, 0.865, 0.315): 100%|██████████| 1238/1238 [04:08<00:00,  4.97it/s]
Epoch 84, Train Loss(Total, MLM, Entailment, Claim): (1.142, 1.137, 0.004, 1.347), Train Accuracy (MLM, Entailment, Claim): (0.732, 0.999, 0.277), Val Loss: 6.611, Va

Saved BERT model checkpoint!


Epoch 85, Train Loss(Total, MLM, Entailment, Claim): (1.123, 1.117, 0.006, 1.350), Train Accuracy (MLM, Entailment, Claim): (0.732, 0.999, 0.277), Val Loss: 6.695, Val Accuracy (MLM, Entailment, Claim): (0.429, 0.862, 0.331): 100%|██████████| 1238/1238 [04:13<00:00,  4.88it/s]
Epoch 86, Train Loss(Total, MLM, Entailment, Claim): (1.064, 1.061, 0.003, 1.364), Train Accuracy (MLM, Entailment, Claim): (0.732, 0.999, 0.278), Val Loss: 6.701, Val Accuracy (MLM, Entailment, Claim): (0.427, 0.864, 0.332): 100%|██████████| 1238/1238 [04:15<00:00,  4.85it/s]
Epoch 87, Train Loss(Total, MLM, Entailment, Claim): (1.062, 1.053, 0.009, 1.325), Train Accuracy (MLM, Entailment, Claim): (0.733, 0.999, 0.278), Val Loss: 6.800, Val Accuracy (MLM, Entailment, Claim): (0.429, 0.859, 0.325): 100%|██████████| 1238/1238 [04:12<00:00,  4.91it/s]
Epoch 88, Train Loss(Total, MLM, Entailment, Claim): (1.100, 1.097, 0.003, 1.371), Train Accuracy (MLM, Entailment, Claim): (0.733, 0.999, 0.278), Val Loss: 6.630, Va

Saved BERT model checkpoint!


Epoch 89, Train Loss(Total, MLM, Entailment, Claim): (1.148, 1.135, 0.013, 1.402), Train Accuracy (MLM, Entailment, Claim): (0.733, 0.999, 0.279), Val Loss: 6.714, Val Accuracy (MLM, Entailment, Claim): (0.421, 0.879, 0.331): 100%|██████████| 1238/1238 [04:07<00:00,  5.00it/s]
Epoch 90, Train Loss(Total, MLM, Entailment, Claim): (1.279, 1.166, 0.113, 1.406), Train Accuracy (MLM, Entailment, Claim): (0.733, 0.999, 0.279), Val Loss: 6.560, Val Accuracy (MLM, Entailment, Claim): (0.425, 0.871, 0.364): 100%|██████████| 1238/1238 [04:07<00:00,  5.01it/s]
Epoch 91, Train Loss(Total, MLM, Entailment, Claim): (1.197, 1.175, 0.022, 1.350), Train Accuracy (MLM, Entailment, Claim): (0.733, 0.999, 0.279), Val Loss: 6.900, Val Accuracy (MLM, Entailment, Claim): (0.422, 0.842, 0.254): 100%|██████████| 1238/1238 [04:10<00:00,  4.95it/s]
Epoch 92, Train Loss(Total, MLM, Entailment, Claim): (1.226, 1.216, 0.010, 1.326), Train Accuracy (MLM, Entailment, Claim): (0.733, 0.999, 0.279), Val Loss: 6.529, Va

Saved BERT model checkpoint!


Epoch 93, Train Loss(Total, MLM, Entailment, Claim): (1.106, 1.102, 0.004, 1.327), Train Accuracy (MLM, Entailment, Claim): (0.733, 0.999, 0.279), Val Loss: 6.585, Val Accuracy (MLM, Entailment, Claim): (0.427, 0.868, 0.260): 100%|██████████| 1238/1238 [04:08<00:00,  4.98it/s]
Epoch 94, Train Loss(Total, MLM, Entailment, Claim): (1.107, 1.106, 0.001, 1.361), Train Accuracy (MLM, Entailment, Claim): (0.733, 0.999, 0.279), Val Loss: 6.710, Val Accuracy (MLM, Entailment, Claim): (0.430, 0.871, 0.316): 100%|██████████| 1238/1238 [04:08<00:00,  4.98it/s]
Epoch 95, Train Loss(Total, MLM, Entailment, Claim): (1.050, 1.050, 0.001, 1.385), Train Accuracy (MLM, Entailment, Claim): (0.733, 0.999, 0.280), Val Loss: 6.841, Val Accuracy (MLM, Entailment, Claim): (0.424, 0.869, 0.278): 100%|██████████| 1238/1238 [04:17<00:00,  4.82it/s]
Epoch 96, Train Loss(Total, MLM, Entailment, Claim): (1.094, 1.093, 0.001, 1.365), Train Accuracy (MLM, Entailment, Claim): (0.734, 0.999, 0.280), Val Loss: 6.809, Va

Saved BERT model checkpoint!


Epoch 97, Train Loss(Total, MLM, Entailment, Claim): (1.090, 1.088, 0.002, 1.404), Train Accuracy (MLM, Entailment, Claim): (0.734, 0.999, 0.280), Val Loss: 6.800, Val Accuracy (MLM, Entailment, Claim): (0.430, 0.861, 0.287): 100%|██████████| 1238/1238 [04:15<00:00,  4.85it/s]
Epoch 98, Train Loss(Total, MLM, Entailment, Claim): (1.100, 1.099, 0.001, 1.409), Train Accuracy (MLM, Entailment, Claim): (0.734, 0.999, 0.280), Val Loss: 6.551, Val Accuracy (MLM, Entailment, Claim): (0.431, 0.892, 0.269): 100%|██████████| 1238/1238 [04:12<00:00,  4.90it/s]
Epoch 99, Train Loss(Total, MLM, Entailment, Claim): (1.095, 1.093, 0.002, 1.408), Train Accuracy (MLM, Entailment, Claim): (0.734, 0.999, 0.280), Val Loss: 6.676, Val Accuracy (MLM, Entailment, Claim): (0.427, 0.892, 0.267): 100%|██████████| 1238/1238 [04:09<00:00,  4.96it/s]
Epoch 100, Train Loss(Total, MLM, Entailment, Claim): (1.149, 1.147, 0.002, 1.375), Train Accuracy (MLM, Entailment, Claim): (0.734, 0.999, 0.280), Val Loss: 6.818, V

Saved BERT model checkpoint!


Epoch 101, Train Loss(Total, MLM, Entailment, Claim): (1.180, 1.165, 0.014, 1.372), Train Accuracy (MLM, Entailment, Claim): (0.734, 0.999, 0.280), Val Loss: 6.523, Val Accuracy (MLM, Entailment, Claim): (0.419, 0.885, 0.394): 100%|██████████| 1238/1238 [04:12<00:00,  4.89it/s]
Epoch 102, Train Loss(Total, MLM, Entailment, Claim): (1.143, 1.135, 0.008, 1.315), Train Accuracy (MLM, Entailment, Claim): (0.734, 0.999, 0.280), Val Loss: 6.642, Val Accuracy (MLM, Entailment, Claim): (0.424, 0.851, 0.262): 100%|██████████| 1238/1238 [04:11<00:00,  4.92it/s]
Epoch 103, Train Loss(Total, MLM, Entailment, Claim): (1.151, 1.150, 0.001, 1.343), Train Accuracy (MLM, Entailment, Claim): (0.734, 0.999, 0.281), Val Loss: 6.575, Val Accuracy (MLM, Entailment, Claim): (0.425, 0.876, 0.319): 100%|██████████| 1238/1238 [04:10<00:00,  4.95it/s]
Epoch 104, Train Loss(Total, MLM, Entailment, Claim): (1.096, 1.095, 0.001, 1.328), Train Accuracy (MLM, Entailment, Claim): (0.735, 0.999, 0.281), Val Loss: 6.631

Saved BERT model checkpoint!


Epoch 105, Train Loss(Total, MLM, Entailment, Claim): (1.038, 1.036, 0.001, 1.324), Train Accuracy (MLM, Entailment, Claim): (0.735, 0.999, 0.281), Val Loss: 6.771, Val Accuracy (MLM, Entailment, Claim): (0.429, 0.865, 0.298): 100%|██████████| 1238/1238 [04:09<00:00,  4.96it/s]
Epoch 106, Train Loss(Total, MLM, Entailment, Claim): (1.044, 1.044, 0.000, 1.359), Train Accuracy (MLM, Entailment, Claim): (0.735, 0.999, 0.281), Val Loss: 6.712, Val Accuracy (MLM, Entailment, Claim): (0.432, 0.869, 0.300): 100%|██████████| 1238/1238 [04:05<00:00,  5.04it/s]
Epoch 107, Train Loss(Total, MLM, Entailment, Claim): (1.061, 1.061, 0.001, 1.374), Train Accuracy (MLM, Entailment, Claim): (0.735, 0.999, 0.282), Val Loss: 6.763, Val Accuracy (MLM, Entailment, Claim): (0.428, 0.871, 0.316): 100%|██████████| 1238/1238 [03:58<00:00,  5.19it/s]
Epoch 108, Train Loss(Total, MLM, Entailment, Claim): (1.113, 1.098, 0.015, 1.356), Train Accuracy (MLM, Entailment, Claim): (0.735, 0.999, 0.282), Val Loss: 6.678

Saved BERT model checkpoint!


Epoch 109, Train Loss(Total, MLM, Entailment, Claim): (1.082, 1.082, 0.000, 1.344), Train Accuracy (MLM, Entailment, Claim): (0.736, 0.999, 0.282), Val Loss: 6.658, Val Accuracy (MLM, Entailment, Claim): (0.428, 0.880, 0.360): 100%|██████████| 1238/1238 [04:06<00:00,  5.02it/s]
Epoch 110, Train Loss(Total, MLM, Entailment, Claim): (1.117, 1.104, 0.013, 1.351), Train Accuracy (MLM, Entailment, Claim): (0.736, 0.999, 0.282), Val Loss: 6.856, Val Accuracy (MLM, Entailment, Claim): (0.423, 0.872, 0.280): 100%|██████████| 1238/1238 [04:32<00:00,  4.54it/s]
Epoch 111, Train Loss(Total, MLM, Entailment, Claim): (1.128, 1.121, 0.007, 1.393), Train Accuracy (MLM, Entailment, Claim): (0.736, 0.999, 0.282), Val Loss: 6.746, Val Accuracy (MLM, Entailment, Claim): (0.424, 0.861, 0.263): 100%|██████████| 1238/1238 [04:26<00:00,  4.65it/s]
Epoch 112, Train Loss(Total, MLM, Entailment, Claim): (1.175, 1.170, 0.005, 1.403), Train Accuracy (MLM, Entailment, Claim): (0.736, 0.999, 0.282), Val Loss: 6.753

Saved BERT model checkpoint!


Epoch 113, Train Loss(Total, MLM, Entailment, Claim): (1.154, 1.152, 0.002, 1.414), Train Accuracy (MLM, Entailment, Claim): (0.736, 0.999, 0.282), Val Loss: 6.814, Val Accuracy (MLM, Entailment, Claim): (0.427, 0.862, 0.258): 100%|██████████| 1238/1238 [04:24<00:00,  4.69it/s]
Epoch 114, Train Loss(Total, MLM, Entailment, Claim): (1.076, 1.070, 0.006, 1.380), Train Accuracy (MLM, Entailment, Claim): (0.736, 0.999, 0.282), Val Loss: 6.728, Val Accuracy (MLM, Entailment, Claim): (0.426, 0.885, 0.232): 100%|██████████| 1238/1238 [04:26<00:00,  4.65it/s]
Epoch 115, Train Loss(Total, MLM, Entailment, Claim): (1.070, 1.045, 0.026, 1.405), Train Accuracy (MLM, Entailment, Claim): (0.736, 0.999, 0.281), Val Loss: 6.927, Val Accuracy (MLM, Entailment, Claim): (0.425, 0.864, 0.243): 100%|██████████| 1238/1238 [04:24<00:00,  4.69it/s]
Epoch 116, Train Loss(Total, MLM, Entailment, Claim): (1.020, 1.020, 0.000, 1.391), Train Accuracy (MLM, Entailment, Claim): (0.736, 0.999, 0.281), Val Loss: 6.847

Saved BERT model checkpoint!


Epoch 117, Train Loss(Total, MLM, Entailment, Claim): (1.005, 0.999, 0.006, 1.425), Train Accuracy (MLM, Entailment, Claim): (0.736, 0.999, 0.281), Val Loss: 6.827, Val Accuracy (MLM, Entailment, Claim): (0.432, 0.869, 0.230): 100%|██████████| 1238/1238 [04:26<00:00,  4.65it/s]
Epoch 118, Train Loss(Total, MLM, Entailment, Claim): (1.019, 1.018, 0.001, 1.382), Train Accuracy (MLM, Entailment, Claim): (0.737, 0.999, 0.281), Val Loss: 6.707, Val Accuracy (MLM, Entailment, Claim): (0.431, 0.876, 0.294): 100%|██████████| 1238/1238 [04:33<00:00,  4.52it/s]
Epoch 119, Train Loss(Total, MLM, Entailment, Claim): (1.027, 1.026, 0.001, 1.398), Train Accuracy (MLM, Entailment, Claim): (0.737, 0.999, 0.281), Val Loss: 6.800, Val Accuracy (MLM, Entailment, Claim): (0.427, 0.878, 0.290): 100%|██████████| 1238/1238 [04:15<00:00,  4.84it/s]
Epoch 120, Train Loss(Total, MLM, Entailment, Claim): (1.126, 1.112, 0.015, 1.423), Train Accuracy (MLM, Entailment, Claim): (0.737, 0.999, 0.281), Val Loss: 6.850

Saved BERT model checkpoint!


Epoch 121, Train Loss(Total, MLM, Entailment, Claim): (1.137, 1.126, 0.012, 1.421), Train Accuracy (MLM, Entailment, Claim): (0.737, 0.999, 0.281), Val Loss: 6.732, Val Accuracy (MLM, Entailment, Claim): (0.425, 0.866, 0.244): 100%|██████████| 1238/1238 [04:19<00:00,  4.76it/s]
Epoch 122, Train Loss(Total, MLM, Entailment, Claim): (1.189, 1.179, 0.009, 1.368), Train Accuracy (MLM, Entailment, Claim): (0.737, 0.999, 0.280), Val Loss: 6.600, Val Accuracy (MLM, Entailment, Claim): (0.423, 0.881, 0.275): 100%|██████████| 1238/1238 [04:17<00:00,  4.81it/s]
Epoch 123, Train Loss(Total, MLM, Entailment, Claim): (1.121, 1.114, 0.007, 1.339), Train Accuracy (MLM, Entailment, Claim): (0.737, 0.999, 0.280), Val Loss: 6.619, Val Accuracy (MLM, Entailment, Claim): (0.429, 0.874, 0.248): 100%|██████████| 1238/1238 [04:15<00:00,  4.84it/s]
Epoch 124, Train Loss(Total, MLM, Entailment, Claim): (1.045, 1.045, 0.000, 1.379), Train Accuracy (MLM, Entailment, Claim): (0.737, 0.999, 0.280), Val Loss: 6.783

Saved BERT model checkpoint!


Epoch 125, Train Loss(Total, MLM, Entailment, Claim): (1.023, 1.021, 0.001, 1.379), Train Accuracy (MLM, Entailment, Claim): (0.737, 0.999, 0.280), Val Loss: 6.803, Val Accuracy (MLM, Entailment, Claim): (0.431, 0.867, 0.251): 100%|██████████| 1238/1238 [04:09<00:00,  4.97it/s]
Epoch 126, Train Loss(Total, MLM, Entailment, Claim): (1.031, 1.030, 0.000, 1.404), Train Accuracy (MLM, Entailment, Claim): (0.737, 0.999, 0.280), Val Loss: 6.776, Val Accuracy (MLM, Entailment, Claim): (0.429, 0.867, 0.253): 100%|██████████| 1238/1238 [04:16<00:00,  4.83it/s]
Epoch 127, Train Loss(Total, MLM, Entailment, Claim): (1.029, 1.027, 0.002, 1.358), Train Accuracy (MLM, Entailment, Claim): (0.738, 0.999, 0.280), Val Loss: 6.821, Val Accuracy (MLM, Entailment, Claim): (0.429, 0.868, 0.272): 100%|██████████| 1238/1238 [04:12<00:00,  4.90it/s]
Epoch 128, Train Loss(Total, MLM, Entailment, Claim): (1.017, 1.017, 0.000, 1.397), Train Accuracy (MLM, Entailment, Claim): (0.738, 0.999, 0.280), Val Loss: 6.871

Saved BERT model checkpoint!


Epoch 129, Train Loss(Total, MLM, Entailment, Claim): (1.047, 1.028, 0.019, 1.405), Train Accuracy (MLM, Entailment, Claim): (0.738, 0.999, 0.280), Val Loss: 6.839, Val Accuracy (MLM, Entailment, Claim): (0.432, 0.879, 0.223): 100%|██████████| 1238/1238 [04:12<00:00,  4.89it/s]
Epoch 130, Train Loss(Total, MLM, Entailment, Claim): (1.121, 1.117, 0.004, 1.356), Train Accuracy (MLM, Entailment, Claim): (0.738, 0.999, 0.280), Val Loss: 6.712, Val Accuracy (MLM, Entailment, Claim): (0.421, 0.893, 0.234): 100%|██████████| 1238/1238 [04:12<00:00,  4.91it/s]
Epoch 131, Train Loss(Total, MLM, Entailment, Claim): (1.082, 1.080, 0.002, 1.356), Train Accuracy (MLM, Entailment, Claim): (0.738, 0.999, 0.280), Val Loss: 6.596, Val Accuracy (MLM, Entailment, Claim): (0.421, 0.884, 0.258): 100%|██████████| 1238/1238 [04:08<00:00,  4.98it/s]
Epoch 132, Train Loss(Total, MLM, Entailment, Claim): (1.106, 1.104, 0.002, 1.341), Train Accuracy (MLM, Entailment, Claim): (0.738, 0.999, 0.280), Val Loss: 6.842

Saved BERT model checkpoint!


Epoch 133, Train Loss(Total, MLM, Entailment, Claim): (1.075, 1.063, 0.012, 1.352), Train Accuracy (MLM, Entailment, Claim): (0.738, 0.999, 0.280), Val Loss: 6.887, Val Accuracy (MLM, Entailment, Claim): (0.427, 0.847, 0.237): 100%|██████████| 1238/1238 [04:15<00:00,  4.85it/s]
Epoch 134, Train Loss(Total, MLM, Entailment, Claim): (1.043, 1.042, 0.001, 1.394), Train Accuracy (MLM, Entailment, Claim): (0.738, 0.999, 0.279), Val Loss: 6.810, Val Accuracy (MLM, Entailment, Claim): (0.424, 0.876, 0.255): 100%|██████████| 1238/1238 [04:04<00:00,  5.07it/s]
Epoch 135, Train Loss(Total, MLM, Entailment, Claim): (1.010, 1.005, 0.005, 1.396), Train Accuracy (MLM, Entailment, Claim): (0.738, 0.999, 0.279), Val Loss: 6.872, Val Accuracy (MLM, Entailment, Claim): (0.429, 0.872, 0.256): 100%|██████████| 1238/1238 [04:06<00:00,  5.02it/s]
Epoch 136, Train Loss(Total, MLM, Entailment, Claim): (0.994, 0.992, 0.002, 1.381), Train Accuracy (MLM, Entailment, Claim): (0.739, 0.999, 0.279), Val Loss: 6.863

Saved BERT model checkpoint!


Epoch 137, Train Loss(Total, MLM, Entailment, Claim): (1.010, 1.010, 0.000, 1.405), Train Accuracy (MLM, Entailment, Claim): (0.739, 0.999, 0.279), Val Loss: 6.864, Val Accuracy (MLM, Entailment, Claim): (0.432, 0.865, 0.262): 100%|██████████| 1238/1238 [04:03<00:00,  5.09it/s]
Epoch 138, Train Loss(Total, MLM, Entailment, Claim): (1.032, 1.029, 0.003, 1.351), Train Accuracy (MLM, Entailment, Claim): (0.739, 0.999, 0.279), Val Loss: 6.820, Val Accuracy (MLM, Entailment, Claim): (0.426, 0.873, 0.246): 100%|██████████| 1238/1238 [04:05<00:00,  5.05it/s]
Epoch 139, Train Loss(Total, MLM, Entailment, Claim): (1.017, 1.015, 0.002, 1.358), Train Accuracy (MLM, Entailment, Claim): (0.739, 0.999, 0.279), Val Loss: 6.680, Val Accuracy (MLM, Entailment, Claim): (0.429, 0.893, 0.236): 100%|██████████| 1238/1238 [04:04<00:00,  5.06it/s]
Epoch 140, Train Loss(Total, MLM, Entailment, Claim): (1.090, 1.080, 0.010, 1.340), Train Accuracy (MLM, Entailment, Claim): (0.739, 0.999, 0.279), Val Loss: 6.777

Saved BERT model checkpoint!


Epoch 141, Train Loss(Total, MLM, Entailment, Claim): (1.171, 1.161, 0.011, 1.330), Train Accuracy (MLM, Entailment, Claim): (0.739, 0.999, 0.279), Val Loss: 6.997, Val Accuracy (MLM, Entailment, Claim): (0.426, 0.848, 0.280): 100%|██████████| 1238/1238 [04:02<00:00,  5.11it/s]
Epoch 142, Train Loss(Total, MLM, Entailment, Claim): (1.087, 1.085, 0.002, 1.301), Train Accuracy (MLM, Entailment, Claim): (0.739, 0.999, 0.279), Val Loss: 6.840, Val Accuracy (MLM, Entailment, Claim): (0.426, 0.862, 0.302): 100%|██████████| 1238/1238 [04:01<00:00,  5.12it/s]
Epoch 143, Train Loss(Total, MLM, Entailment, Claim): (1.042, 1.023, 0.019, 1.285), Train Accuracy (MLM, Entailment, Claim): (0.739, 0.999, 0.280), Val Loss: 6.543, Val Accuracy (MLM, Entailment, Claim): (0.428, 0.876, 0.379): 100%|██████████| 1238/1238 [04:02<00:00,  5.11it/s]
Epoch 144, Train Loss(Total, MLM, Entailment, Claim): (1.079, 1.078, 0.001, 1.339), Train Accuracy (MLM, Entailment, Claim): (0.739, 0.999, 0.280), Val Loss: 6.753

Saved BERT model checkpoint!


Epoch 145, Train Loss(Total, MLM, Entailment, Claim): (0.991, 0.990, 0.000, 1.310), Train Accuracy (MLM, Entailment, Claim): (0.740, 0.999, 0.280), Val Loss: 6.761, Val Accuracy (MLM, Entailment, Claim): (0.432, 0.870, 0.334): 100%|██████████| 1238/1238 [04:01<00:00,  5.13it/s]
Epoch 146, Train Loss(Total, MLM, Entailment, Claim): (0.997, 0.994, 0.002, 1.313), Train Accuracy (MLM, Entailment, Claim): (0.740, 0.999, 0.281), Val Loss: 6.696, Val Accuracy (MLM, Entailment, Claim): (0.430, 0.876, 0.329): 100%|██████████| 1238/1238 [04:01<00:00,  5.13it/s]
Epoch 147, Train Loss(Total, MLM, Entailment, Claim): (0.975, 0.972, 0.003, 1.340), Train Accuracy (MLM, Entailment, Claim): (0.740, 0.999, 0.281), Val Loss: 6.816, Val Accuracy (MLM, Entailment, Claim): (0.430, 0.863, 0.304): 100%|██████████| 1238/1238 [04:01<00:00,  5.13it/s]
Epoch 148, Train Loss(Total, MLM, Entailment, Claim): (1.015, 1.015, 0.001, 1.369), Train Accuracy (MLM, Entailment, Claim): (0.740, 0.999, 0.281), Val Loss: 6.804

Saved BERT model checkpoint!


Epoch 149, Train Loss(Total, MLM, Entailment, Claim): (0.996, 0.995, 0.001, 1.395), Train Accuracy (MLM, Entailment, Claim): (0.740, 0.999, 0.281), Val Loss: 6.865, Val Accuracy (MLM, Entailment, Claim): (0.431, 0.868, 0.340): 100%|██████████| 1238/1238 [04:02<00:00,  5.11it/s]
Epoch 150, Train Loss(Total, MLM, Entailment, Claim): (1.053, 1.051, 0.002, 1.340), Train Accuracy (MLM, Entailment, Claim): (0.740, 0.999, 0.281), Val Loss: 6.682, Val Accuracy (MLM, Entailment, Claim): (0.431, 0.853, 0.385): 100%|██████████| 1238/1238 [04:02<00:00,  5.10it/s]

Training done!





In [9]:
train(model, 40, train_dataloader, val_dataloader, optimizer, grad_accumulation_steps=20, val_every=180, save_every=4, device=device, log_metrics=log_metrics, mixed_precision=True, checkpoint_name="BERT_multitask", scheduler=scheduler, include_claim_loss=True, claim_loss_weight=0.5)

Epoch 1, Train Loss(Total, MLM, Entailment, Claim): (2.106, 2.069, 0.033, 0.041), Train Accuracy (MLM, Entailment, Claim): (0.636, 0.982, 0.988), Val Loss: 4.384, Val Accuracy (MLM, Entailment, Claim): (0.510, 0.870, 0.844): 100%|██████████| 184/184 [01:41<00:00,  1.82it/s]
Epoch 2, Train Loss(Total, MLM, Entailment, Claim): (2.117, 2.091, 0.024, 0.028), Train Accuracy (MLM, Entailment, Claim): (0.636, 0.986, 0.989), Val Loss: 4.405, Val Accuracy (MLM, Entailment, Claim): (0.512, 0.839, 0.835): 100%|██████████| 184/184 [01:39<00:00,  1.85it/s]
Epoch 3, Train Loss(Total, MLM, Entailment, Claim): (2.077, 2.068, 0.011, 0.008), Train Accuracy (MLM, Entailment, Claim): (0.637, 0.988, 0.991), Val Loss: 4.420, Val Accuracy (MLM, Entailment, Claim): (0.521, 0.859, 0.836): 100%|██████████| 184/184 [01:39<00:00,  1.86it/s]
Epoch 4, Train Loss(Total, MLM, Entailment, Claim): (2.133, 2.126, 0.009, 0.004), Train Accuracy (MLM, Entailment, Claim): (0.638, 0.990, 0.993), Val Loss: 4.539, Val Accuracy

Saved BERT model checkpoint!


Epoch 5, Train Loss(Total, MLM, Entailment, Claim): (2.061, 2.052, 0.008, 0.010), Train Accuracy (MLM, Entailment, Claim): (0.639, 0.991, 0.994), Val Loss: 4.677, Val Accuracy (MLM, Entailment, Claim): (0.513, 0.844, 0.843): 100%|██████████| 184/184 [01:37<00:00,  1.89it/s]
Epoch 6, Train Loss(Total, MLM, Entailment, Claim): (2.046, 2.038, 0.012, 0.005), Train Accuracy (MLM, Entailment, Claim): (0.639, 0.992, 0.994), Val Loss: 4.653, Val Accuracy (MLM, Entailment, Claim): (0.514, 0.846, 0.843): 100%|██████████| 184/184 [01:40<00:00,  1.83it/s]
Epoch 7, Train Loss(Total, MLM, Entailment, Claim): (2.049, 2.043, 0.007, 0.005), Train Accuracy (MLM, Entailment, Claim): (0.640, 0.993, 0.995), Val Loss: 4.700, Val Accuracy (MLM, Entailment, Claim): (0.512, 0.850, 0.837): 100%|██████████| 184/184 [01:45<00:00,  1.75it/s]
Epoch 8, Train Loss(Total, MLM, Entailment, Claim): (2.073, 2.061, 0.018, 0.006), Train Accuracy (MLM, Entailment, Claim): (0.640, 0.993, 0.995), Val Loss: 4.749, Val Accuracy

Saved BERT model checkpoint!


Epoch 9, Train Loss(Total, MLM, Entailment, Claim): (2.013, 2.007, 0.005, 0.006), Train Accuracy (MLM, Entailment, Claim): (0.641, 0.994, 0.995), Val Loss: 4.911, Val Accuracy (MLM, Entailment, Claim): (0.515, 0.854, 0.834): 100%|██████████| 184/184 [01:38<00:00,  1.87it/s]
Epoch 10, Train Loss(Total, MLM, Entailment, Claim): (2.013, 2.002, 0.007, 0.014), Train Accuracy (MLM, Entailment, Claim): (0.641, 0.994, 0.996), Val Loss: 5.021, Val Accuracy (MLM, Entailment, Claim): (0.516, 0.848, 0.835): 100%|██████████| 184/184 [01:37<00:00,  1.88it/s]
Epoch 11, Train Loss(Total, MLM, Entailment, Claim): (2.012, 2.004, 0.007, 0.008), Train Accuracy (MLM, Entailment, Claim): (0.642, 0.994, 0.996), Val Loss: 5.098, Val Accuracy (MLM, Entailment, Claim): (0.509, 0.855, 0.836): 100%|██████████| 184/184 [01:43<00:00,  1.78it/s]
Epoch 12, Train Loss(Total, MLM, Entailment, Claim): (2.003, 1.992, 0.012, 0.009), Train Accuracy (MLM, Entailment, Claim): (0.642, 0.995, 0.996), Val Loss: 4.899, Val Accur

Saved BERT model checkpoint!


Epoch 13, Train Loss(Total, MLM, Entailment, Claim): (1.974, 1.960, 0.015, 0.011), Train Accuracy (MLM, Entailment, Claim): (0.643, 0.995, 0.996), Val Loss: 4.891, Val Accuracy (MLM, Entailment, Claim): (0.523, 0.858, 0.845): 100%|██████████| 184/184 [01:38<00:00,  1.86it/s]
Epoch 14, Train Loss(Total, MLM, Entailment, Claim): (1.961, 1.957, 0.007, 0.003), Train Accuracy (MLM, Entailment, Claim): (0.643, 0.995, 0.996), Val Loss: 4.860, Val Accuracy (MLM, Entailment, Claim): (0.524, 0.857, 0.847): 100%|██████████| 184/184 [01:38<00:00,  1.86it/s]
Epoch 15, Train Loss(Total, MLM, Entailment, Claim): (1.981, 1.976, 0.005, 0.004), Train Accuracy (MLM, Entailment, Claim): (0.644, 0.995, 0.996), Val Loss: 4.830, Val Accuracy (MLM, Entailment, Claim): (0.524, 0.861, 0.844): 100%|██████████| 184/184 [01:38<00:00,  1.88it/s]
Epoch 16, Train Loss(Total, MLM, Entailment, Claim): (1.938, 1.936, 0.001, 0.004), Train Accuracy (MLM, Entailment, Claim): (0.644, 0.995, 0.996), Val Loss: 4.838, Val Accu

Saved BERT model checkpoint!


Epoch 17, Train Loss(Total, MLM, Entailment, Claim): (1.976, 1.973, 0.003, 0.002), Train Accuracy (MLM, Entailment, Claim): (0.645, 0.996, 0.996), Val Loss: 4.883, Val Accuracy (MLM, Entailment, Claim): (0.522, 0.868, 0.843): 100%|██████████| 184/184 [01:38<00:00,  1.86it/s]
Epoch 18, Train Loss(Total, MLM, Entailment, Claim): (1.947, 1.941, 0.004, 0.007), Train Accuracy (MLM, Entailment, Claim): (0.645, 0.996, 0.996), Val Loss: 4.913, Val Accuracy (MLM, Entailment, Claim): (0.525, 0.851, 0.836): 100%|██████████| 184/184 [01:37<00:00,  1.89it/s]
Epoch 19, Train Loss(Total, MLM, Entailment, Claim): (1.948, 1.943, 0.003, 0.008), Train Accuracy (MLM, Entailment, Claim): (0.646, 0.996, 0.997), Val Loss: 5.008, Val Accuracy (MLM, Entailment, Claim): (0.524, 0.855, 0.834): 100%|██████████| 184/184 [01:34<00:00,  1.94it/s]
Epoch 20, Train Loss(Total, MLM, Entailment, Claim): (1.965, 1.960, 0.005, 0.005), Train Accuracy (MLM, Entailment, Claim): (0.646, 0.996, 0.997), Val Loss: 4.783, Val Accu

Saved BERT model checkpoint!


Epoch 21, Train Loss(Total, MLM, Entailment, Claim): (1.953, 1.949, 0.002, 0.005), Train Accuracy (MLM, Entailment, Claim): (0.647, 0.996, 0.997), Val Loss: 5.095, Val Accuracy (MLM, Entailment, Claim): (0.516, 0.846, 0.844): 100%|██████████| 184/184 [01:39<00:00,  1.85it/s]
Epoch 22, Train Loss(Total, MLM, Entailment, Claim): (1.922, 1.918, 0.007, 0.003), Train Accuracy (MLM, Entailment, Claim): (0.647, 0.996, 0.997), Val Loss: 4.962, Val Accuracy (MLM, Entailment, Claim): (0.524, 0.858, 0.849): 100%|██████████| 184/184 [01:40<00:00,  1.84it/s]
Epoch 23, Train Loss(Total, MLM, Entailment, Claim): (1.908, 1.905, 0.002, 0.004), Train Accuracy (MLM, Entailment, Claim): (0.648, 0.996, 0.997), Val Loss: 5.004, Val Accuracy (MLM, Entailment, Claim): (0.520, 0.845, 0.846): 100%|██████████| 184/184 [01:40<00:00,  1.84it/s]
Epoch 24, Train Loss(Total, MLM, Entailment, Claim): (1.912, 1.907, 0.005, 0.005), Train Accuracy (MLM, Entailment, Claim): (0.648, 0.996, 0.997), Val Loss: 5.029, Val Accu

Saved BERT model checkpoint!


Epoch 25, Train Loss(Total, MLM, Entailment, Claim): (1.911, 1.907, 0.005, 0.002), Train Accuracy (MLM, Entailment, Claim): (0.648, 0.997, 0.997), Val Loss: 5.066, Val Accuracy (MLM, Entailment, Claim): (0.522, 0.860, 0.839): 100%|██████████| 184/184 [01:42<00:00,  1.80it/s]
Epoch 26, Train Loss(Total, MLM, Entailment, Claim): (1.854, 1.852, 0.001, 0.003), Train Accuracy (MLM, Entailment, Claim): (0.649, 0.997, 0.997), Val Loss: 4.976, Val Accuracy (MLM, Entailment, Claim): (0.519, 0.860, 0.849): 100%|██████████| 184/184 [01:40<00:00,  1.83it/s]
Epoch 27, Train Loss(Total, MLM, Entailment, Claim): (1.875, 1.872, 0.002, 0.004), Train Accuracy (MLM, Entailment, Claim): (0.649, 0.997, 0.997), Val Loss: 5.057, Val Accuracy (MLM, Entailment, Claim): (0.519, 0.855, 0.838): 100%|██████████| 184/184 [01:39<00:00,  1.85it/s]
Epoch 28, Train Loss(Total, MLM, Entailment, Claim): (1.924, 1.922, 0.003, 0.001), Train Accuracy (MLM, Entailment, Claim): (0.650, 0.997, 0.997), Val Loss: 4.940, Val Accu

Saved BERT model checkpoint!


Epoch 29, Train Loss(Total, MLM, Entailment, Claim): (1.889, 1.888, 0.001, 0.001), Train Accuracy (MLM, Entailment, Claim): (0.650, 0.997, 0.997), Val Loss: 5.176, Val Accuracy (MLM, Entailment, Claim): (0.519, 0.859, 0.838): 100%|██████████| 184/184 [01:40<00:00,  1.84it/s]
Epoch 30, Train Loss(Total, MLM, Entailment, Claim): (1.903, 1.901, 0.003, 0.001), Train Accuracy (MLM, Entailment, Claim): (0.650, 0.997, 0.997), Val Loss: 5.252, Val Accuracy (MLM, Entailment, Claim): (0.521, 0.852, 0.842): 100%|██████████| 184/184 [01:40<00:00,  1.82it/s]
Epoch 31, Train Loss(Total, MLM, Entailment, Claim): (1.890, 1.887, 0.003, 0.002), Train Accuracy (MLM, Entailment, Claim): (0.651, 0.997, 0.997), Val Loss: 5.421, Val Accuracy (MLM, Entailment, Claim): (0.519, 0.838, 0.841): 100%|██████████| 184/184 [01:41<00:00,  1.82it/s]
Epoch 32, Train Loss(Total, MLM, Entailment, Claim): (1.886, 1.884, 0.001, 0.002), Train Accuracy (MLM, Entailment, Claim): (0.651, 0.997, 0.997), Val Loss: 5.166, Val Accu

Saved BERT model checkpoint!


Epoch 33, Train Loss(Total, MLM, Entailment, Claim): (1.867, 1.858, 0.008, 0.010), Train Accuracy (MLM, Entailment, Claim): (0.651, 0.997, 0.997), Val Loss: 5.087, Val Accuracy (MLM, Entailment, Claim): (0.529, 0.860, 0.843): 100%|██████████| 184/184 [01:40<00:00,  1.84it/s]
Epoch 34, Train Loss(Total, MLM, Entailment, Claim): (1.862, 1.859, 0.001, 0.004), Train Accuracy (MLM, Entailment, Claim): (0.652, 0.997, 0.997), Val Loss: 5.004, Val Accuracy (MLM, Entailment, Claim): (0.521, 0.864, 0.845): 100%|██████████| 184/184 [01:39<00:00,  1.85it/s]
Epoch 35, Train Loss(Total, MLM, Entailment, Claim): (1.861, 1.860, 0.001, 0.003), Train Accuracy (MLM, Entailment, Claim): (0.652, 0.997, 0.997), Val Loss: 5.125, Val Accuracy (MLM, Entailment, Claim): (0.527, 0.863, 0.838): 100%|██████████| 184/184 [01:41<00:00,  1.82it/s]
Epoch 36, Train Loss(Total, MLM, Entailment, Claim): (1.831, 1.828, 0.002, 0.004), Train Accuracy (MLM, Entailment, Claim): (0.652, 0.997, 0.997), Val Loss: 5.104, Val Accu

Saved BERT model checkpoint!


Epoch 37, Train Loss(Total, MLM, Entailment, Claim): (1.818, 1.817, 0.001, 0.002), Train Accuracy (MLM, Entailment, Claim): (0.653, 0.997, 0.998), Val Loss: 5.217, Val Accuracy (MLM, Entailment, Claim): (0.522, 0.854, 0.839): 100%|██████████| 184/184 [01:40<00:00,  1.83it/s]
Epoch 38, Train Loss(Total, MLM, Entailment, Claim): (1.835, 1.833, 0.002, 0.001), Train Accuracy (MLM, Entailment, Claim): (0.653, 0.997, 0.998), Val Loss: 5.307, Val Accuracy (MLM, Entailment, Claim): (0.528, 0.852, 0.839): 100%|██████████| 184/184 [01:40<00:00,  1.83it/s]
Epoch 39, Train Loss(Total, MLM, Entailment, Claim): (1.824, 1.823, 0.002, 0.001), Train Accuracy (MLM, Entailment, Claim): (0.653, 0.997, 0.998), Val Loss: 5.020, Val Accuracy (MLM, Entailment, Claim): (0.521, 0.866, 0.838): 100%|██████████| 184/184 [01:42<00:00,  1.79it/s]
Epoch 40, Train Loss(Total, MLM, Entailment, Claim): (1.812, 1.809, 0.002, 0.003), Train Accuracy (MLM, Entailment, Claim): (0.654, 0.997, 0.998), Val Loss: 5.229, Val Accu

Saved BERT model checkpoint!
Training done!


In [None]:
save_bert_model_checkpoint(model, optimizer, name='BERT_multitask_checkpoint_entaiment_claims_long_600_epochs')

# 3.Testing and Evaluation

#### We now run some quick tests to check the performance of the pretrained language model on each of the three tasks: (i) Masked token prediction, (ii) Sentence Entailment Prediction and (iii) Claim Class Prediction

In [22]:
# get pair of sentences from training dataset
idx = 275
sentence_1, sentence_2, entailment_label, claim_label = val_dataset.get_pair_text(idx)
print(sentence_1)
print(sentence_2)
print(entailment_label, claim_label)

# encode the sentence
input = train_dataset.encode_custom(sentence_1, sentence_2)
encoded_idx = input['masked_input'].tolist()

# convert to subword tokens
sentence_1_subwords = tokenizer.encode([sentence_1], return_subwords=True)[1][0]
sentence_2_subwords = tokenizer.encode([sentence_2], return_subwords=True)[1][0]
combined_sentence_subwords =  ['[CLS]']  + sentence_1_subwords + ['[SEP]'] + sentence_2_subwords + ['[SEP]']
masked_subwords = [tokenizer.int2word[i] for i in encoded_idx]

"Dominance hierarchies, diversity and species richness of vascular plants in an alpine meadow: contrasting short and medium term responses to simulated global change".
In 2017, they held seven Tag Team Championships from seven different promotions at the same time, including the ROH World Tag Team Championship. Local elections are to be held for Wyre Borough Council on 7 May 2015, the same day as the United Kingdom general election, 2015 and other United Kingdom local elections, 2015. It was declared as ` sanctuary ' under the National Parks and Wildlife Act 1972 by the Government of South Australia on 26 October 1995.
0 2


In [23]:
print(f"Original Sequence Subword tokens: {combined_sentence_subwords}")
print(f"Maksed Sentence Subword tokens:   {masked_subwords}")
print(f"Entailment label: {entailment_label}\n")

# move to device and unsqueeze to add batch dimension
x = {k: v.clone().unsqueeze(0).to(device) for k, v in input.items()}

# forward pass
model.eval()
with torch.no_grad():   
    MLM_logits, entailment_logits, claim_class_logits = model(x['masked_input'], x['attention_mask'], x['segment_ids'])
    logits = MLM_logits.squeeze(0)


# argmax logits
logits_argmax_idx = logits.argmax(dim=-1).tolist()
logits_argmax_subwords = [tokenizer.int2word[i] for i in logits_argmax_idx]

#print(f"Prediction Subword tokens: {logits_argmax_subwords}\n")

for i in range(len(logits)):
    if x['MLM_label'][0][i] != -100:
        # show top 10 predictions for each masked/replaced token
        topk_idx = torch.topk(logits[i], k=10)
        topk_subwords = [tokenizer.int2word[idx] for idx in topk_idx.indices.tolist()]
        print(f"Actual token --> {combined_sentence_subwords[i]}, Masked/Replaced Token --> {masked_subwords[i]}, Predicted Token Top 10 --> {topk_subwords}")
print("")

# compute argmax predictions for masked/replaced tokens
preds = logits.argmax(dim=-1)
mask = (x['MLM_label'][0] != -100)
masked_preds = preds[mask]

# compute argmax entailment prediction
entailment_pred = entailment_logits.argmax(dim=-1).item()

actual_tokens = [tokenizer.int2word[idx] for idx in x['MLM_label'][0][mask].tolist()] 
predicted_tokens = [tokenizer.int2word[idx] for idx in masked_preds.tolist()]

print(f"Actual tokens: {actual_tokens}")
print(f"Predicted tokens: {predicted_tokens}")
print(f"Entailment Logits --> {entailment_logits.tolist()[0]}, Prediction: {entailment_pred}")



Original Sequence Subword tokens: ['[CLS]', '``', 'do', '##min', '##ance', 'hier', '##arch', '##i', '##es', ',', 'diversity', 'and', 'species', 'rich', '##ne', '##s', '##s', 'of', 'va', '##scu', '##lar', 'plants', 'in', 'an', 'alpine', 'me', '##a', '##d', '##ow', ':', 'contrast', '##ing', 'short', 'and', 'medium', 'term', 'responses', 'to', 'si', '##mu', '##l', '##ated', 'global', 'change', "''", '.', '[SEP]', 'in', '2017', ',', 'they', 'held', 'seven', 'tag', 'team', 'championships', 'from', 'seven', 'different', 'promotion', '##s', 'at', 'the', 'same', 'time', ',', 'including', 'the', 'ro', '##h', 'world', 'tag', 'team', 'championship', '.', 'local', 'elections', 'are', 'to', 'be', 'held', 'for', 'wy', '##r', '##e', 'borough', 'council', 'on', '7', 'may', '2015', ',', 'the', 'same', 'day', 'as', 'the', 'united', 'kingdom', 'general', 'election', ',', '2015', 'and', 'other', 'united', 'kingdom', 'local', 'elections', ',', '2015', '##.', 'it', 'was', 'declared', 'as', 'sanctuary', "'",