In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import pickle
from tqdm import tqdm
import re
import string
import json
import sys
#import classificationreport
from sklearn.metrics import classification_report, confusion_matrix
from spacy.lang.en import English
from google.colab import drive
drive.mount('/content/gdrive')

eng = English()
tok = eng.tokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

Mounted at /content/gdrive


device(type='cuda')

In [2]:
#NLI_LSTM with 3 class classification with dense layer
class NLI_LSTM(nn.Module):
    def __init__(self,embedding_matrix, vocab_size, embedding_dim, hidden_dim, output_dim,n_fc_layers, dropout):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.embedding.weight.data.copy_(torch.from_numpy(embedding_matrix))
        self.embedding.weight.requires_grad = False
        self.lstm = nn.LSTM(embedding_dim, hidden_dim,batch_first=True)
        self.fc=nn.ModuleList([nn.Linear(hidden_dim * 2,hidden_dim * 2) for i in range(n_fc_layers)])
        self.linear=nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)
        
    
    # forward pass
    # NLI with 2 text inputs
    def forward(self, text1, text2):
        #text = [sent len, batch size]
        embedded1 = self.dropout(self.embedding(text1))
        embedded2 = self.dropout(self.embedding(text2))
        #embedded = [sent len, batch size, emb dim]
        output1, (hidden1, cell1) = self.lstm(embedded1)
        output2, (hidden2, cell2) = self.lstm(embedded2)

        hidden = torch.cat((hidden1,hidden2),dim=2)
        hidden = torch.squeeze(hidden)

        for i in range(len(self.fc)-1):
            hidden = self.fc[i](hidden)
            hidden = F.relu(hidden)
            hidden = self.dropout(hidden)
        
        hidden = self.fc[len(self.fc)-1](hidden)
        hidden = F.log_softmax(hidden,dim=1)
        hidden = self.dropout(hidden)

        hidden = self.linear(hidden)
        #hidden = [batch size, output dim]
        return hidden

In [3]:
#load glove vectors
def load_glovevectors():
  embeddings_index = {}
  f = open('/content/gdrive/MyDrive/NLPProject/data/glove.42B.300d.txt', encoding="utf-8")
  lines=f.readlines()
  # for line in f:
  #     split = line.split()
  #     embeddings_index[split[0]] = np.array([float(x) for x in split[1:]])
  for line in tqdm(lines):
      values = line.split()
      word = values[0]
      coefs = np.asarray(values[1:], dtype='float32')
      embeddings_index[word] = coefs
  f.close()
  return embeddings_index

embeddings_index=load_glovevectors()

100%|██████████| 1917494/1917494 [01:48<00:00, 17672.32it/s]


In [4]:
# tokenise input and remove punctuation and numbers using only regex
def tokenize(text):
    # text = re.sub(r'[^a-zA-Z ]', '', text)
    # #convert to lowercase
    # text = text.lower()
    # return text.split()
    text = re.sub(r"[^\x00-\x7F]+", " ", text)
    regex = re.compile('[' + re.escape(string.punctuation) + '0-9\\r\\t\\n]') # remove punctuation and numbers
    nopunctext = regex.sub(" ", text.lower())
    return [token.text for token in tok(nopunctext)]

UNK="<UNK>"
PAD="<PAD>"

#import data
def getDataset(dataset_name="mnli"):
    if dataset_name=="mnli":
        filepath_train="/content/gdrive/MyDrive/NLPProject/data/multinli_1.0/multinli_1.0/multinli_1.0_train.jsonl"
        filepath_dev="/content/gdrive/MyDrive/NLPProject/data/multinli_1.0/multinli_1.0/multinli_1.0_dev_matched.jsonl"
        filepath_test="/content/gdrive/MyDrive/NLPProject/data/multinli_1.0/multinli_1.0/multinli_1.0_dev_mismatched.jsonl"
    elif dataset_name=="snli":
        filepath_train="/content/gdrive/MyDrive/NLPProject/data/snli_1.0/snli_1.0/snli_1.0_train.jsonl"
        filepath_dev="/content/gdrive/MyDrive/NLPProject/data/snli_1.0/snli_1.0/snli_1.0_dev.jsonl"
        filepath_test="/content/gdrive/MyDrive/NLPProject/data/snli_1.0/snli_1.0/snli_1.0_test.jsonl"
    else:
        print("Invalid dataset name")
        return None
    
    #read train,dev and test data
    labels = ["contradiction", "entailment", "neutral"]
    f= open(filepath_train, "r")
    data = list(f)
    train_dataset={"premise":[],"hypothesis":[],"label":[]}
    print("train data")
    for line in tqdm(data):
        line = json.loads(line)
        if line['gold_label'] not in labels:
            # print(line['gold_label'])
            continue
        train_dataset["premise"].append(line['sentence1'])
        train_dataset["hypothesis"].append(line['sentence2'])
        train_dataset["label"].append(line['gold_label'])
    f.close()

    f= open(filepath_dev, "r")
    data = list(f)
    dev_dataset={"premise":[],"hypothesis":[],"label":[]}
    print("dev data")
    for line in tqdm(data):
        line = json.loads(line)
        if line['gold_label'] not in labels:
            # print(line['gold_label'])
            continue
        dev_dataset["premise"].append(line['sentence1'])
        dev_dataset["hypothesis"].append(line['sentence2'])
        dev_dataset["label"].append(line['gold_label'])
    f.close()

    f= open(filepath_test, "r")
    data = list(f)
    test_dataset={"premise":[],"hypothesis":[],"label":[]}
    print("test data")
    for line in tqdm(data):
        line = json.loads(line)
        if line['gold_label'] not in labels:
            # print(line['gold_label'])
            continue
        test_dataset["premise"].append(line['sentence1'])
        test_dataset["hypothesis"].append(line['sentence2'])
        test_dataset["label"].append(line['gold_label'])
    f.close()
    
    return train_dataset,dev_dataset,test_dataset

def getWord2index(dataset):
    word2index = {"":0,UNK:1,PAD:2}
    for sentence in dataset["premise"]:
        for word in tokenize(sentence):
            if word not in word2index:
                word2index[word] = len(word2index)
    for sentence in dataset["hypothesis"]:
        for word in tokenize(sentence):
            if word not in word2index:
                word2index[word] = len(word2index)
    return word2index

def getEmbeddingMatrix(word2index,emb_size=300):
    embedding_matrix = np.zeros((len(word2index),emb_size),dtype=np.float32)
    for word, i in word2index.items():
        if i==0:
            embedding_matrix[i] = np.zeros(emb_size)
        elif word in embeddings_index:
            embedding_matrix[i] = embeddings_index[word]
        else:
            embedding_matrix[i] = np.random.uniform(-0.25,0.25,emb_size)
    return embedding_matrix

def getLabel2index(dataset):
    label2index = {"entailment":0,"neutral":1,"contradiction":2}
    return label2index

def getSentence2vector(sentence,word2index,padLength=32):
    sentence = tokenize(sentence)
    vector = []
    for word in sentence:
        if word in word2index:
            vector.append(word2index[word])
        else:
            vector.append(word2index[UNK])
    
    if len(vector)>padLength:
        vector=vector[:padLength]
    else:
        for i in range(padLength-len(vector)):
            vector.append(word2index[PAD])
    
    if(len(vector)!=padLength):
        print("Error in vector length")
    return np.array(vector)   

In [5]:
def preprocess(dataset,word2index,label2index):
    dataset["premise"] = [getSentence2vector(sentence,word2index) for sentence in tqdm(dataset["premise"])]
    dataset["hypothesis"] = [getSentence2vector(sentence,word2index) for sentence in tqdm(dataset["hypothesis"])]
    dataset["label"] = [label2index[label] for label in dataset["label"]]
    return dataset

def getDataloader(dataset,batch_size=32):
    # print(len(dataset["premise"]))
    premise = torch.tensor(dataset["premise"],dtype=torch.long)
    hypothesis = torch.tensor(dataset["hypothesis"],dtype=torch.long)
    labels = torch.tensor(dataset["label"],dtype=torch.long)
    dataset = torch.utils.data.TensorDataset(premise,hypothesis,labels)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader

def prepData(dataset_name="mnli"):
    train_dataset,dev_dataset,test_dataset = getDataset(dataset_name)
    word2index = getWord2index(train_dataset)
    label2index = getLabel2index(train_dataset)
    embedding_matrix = getEmbeddingMatrix(word2index)
    #preprocess datasets
    print("preprocess")
    train_dataset = preprocess(train_dataset,word2index,label2index)
    dev_dataset = preprocess(dev_dataset,word2index,label2index)
    test_dataset = preprocess(test_dataset,word2index,label2index)
    #dataloader
    train_dataloader = getDataloader(train_dataset)
    dev_dataloader = getDataloader(dev_dataset)
    test_dataloader = getDataloader(test_dataset)

    return train_dataloader,dev_dataloader,test_dataloader,embedding_matrix,word2index


In [6]:
def train(model,train_dataloader,dev_dataloader,test_dataloader,optimizer,criterion,datasetName,epochs=5):
    
    
    f=open(f"/content/gdrive/MyDrive/NLPProject/reports_and_results/report_lstm_nli_{datasetName}_foreachepoch.txt",'w')
    
    total_loss=[]
    total_acc=[]
    for epoch in range(1,1+epochs):
        ep_loss=0
        ep_acc=0
        train_total=0
        val_total=0
        test_total=0
        print("\n\nEpoch: ",epoch)
        #training
        print("Training")
        for batch in tqdm(train_dataloader):
            prem, hyp, label = batch
            # print(len(prem),len(hyp),len(label))
            optimizer.zero_grad()
            prem, hyp, label= prem.to(device), hyp.to(device), label.to(device)
            output = model(prem, hyp)
            # print(len(output))
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            ep_loss+=loss.item()
            total_loss.append(loss.item())
            acc = (output.argmax(1) == label).sum().item()
            train_total+=label.size(0)
            ep_acc+=acc
        total_acc.append((ep_acc/train_total))
        
        print("Validation")
        #validation
        val_y_true = []
        val_y_pred = []
        with torch.no_grad():
            ep_val_loss=0
            ep_val_acc=0
            for batch in tqdm(dev_dataloader):
                prem, hyp, label = batch
                prem, hyp, label= prem.to(device), hyp.to(device), label.to(device)
                output = model(prem, hyp)
                loss = criterion(output, label)
                val_y_true.extend(label.cpu())
                val_y_pred.extend(output.cpu().argmax(1))
                ep_val_loss+=loss.item()
                acc = (output.argmax(1) == label).sum().item()
                val_total+=label.size(0)
                ep_val_acc+=acc
        
        print("Test")
        y_true = []
        y_pred = []
        with torch.no_grad():
            ep_test_loss=0
            ep_test_acc=0
            for batch in tqdm(test_dataloader):
                prem, hyp, label = batch
                prem, hyp, label= prem.to(device), hyp.to(device), label.to(device)
                output = model(prem, hyp)
                y_true.extend(label.cpu())
                y_pred.extend(output.cpu().argmax(1))
                loss = criterion(output, label)
                ep_test_loss+=loss.item()
                acc = (output.argmax(1) == label).sum().item()
                test_total+=label.size(0)
                ep_test_acc+=acc
                
        print("Epoch: ",epoch," Train Loss: ",ep_loss/len(train_dataloader),"Train Accuracy: ",ep_acc/train_total)

        print("Epoch: ",epoch," Val Loss: ",ep_val_loss/len(dev_dataloader)," Val Accuracy: ",ep_val_acc/val_total)

        print("Epoch: ",epoch,"Test Loss: ",ep_test_loss/len(test_dataloader)," Test Accuracy: ",ep_test_acc/test_total)

        print()
        print("Validation Classification report")

        print(classification_report(val_y_true, val_y_pred, target_names=["entailment","neutral","contradiction"]))
        print("Test Classification report")
        print(classification_report(y_true, y_pred, target_names=["entailment","neutral","contradiction"]))
                
        # orginal_stdout=sys.stdout
        # sys.stdout=f 
        print("\n\nEpoch: ",epoch, file=f)

        
        print("Epoch: ",epoch," Train Loss: ",ep_loss/len(train_dataloader),"Train Accuracy: ",ep_acc/train_total,file=f)

        print("Epoch: ",epoch," Val Loss: ",ep_val_loss/len(dev_dataloader)," Val Accuracy: ",ep_val_acc/val_total,file=f)

        print("Epoch: ",epoch,"Test Loss: ",ep_test_loss/len(test_dataloader)," Test Accuracy: ",ep_test_acc/test_total,file=f)

        print(file=f)
        print("Validation Classification report",file=f)

        print(classification_report(val_y_true, val_y_pred, target_names=["entailment","neutral","contradiction"]),file=f)
        print("Test Classification report",file=f)
        print(classification_report(y_true, y_pred, target_names=["entailment","neutral","contradiction"]),file=f)
        #save model
        torch.save(model.state_dict(), f"/content/gdrive/MyDrive/NLPProject/models/model_lstm_nli_{datasetName}_ep_{epoch}.pt")
        # sys.stdout=orginal_stdout
        
    # sys.stdout=f
    print("Total Train Loss",sum(total_loss)/len(total_loss)," Total Train Accuracy: ",sum(total_acc)/len(total_acc))
              
    # sys.stdout=orginal_stdout
    print("Total Train Loss",sum(total_loss)/len(total_loss)," Total Train Accuracy: ",sum(total_acc)/len(total_acc),file=f)
    f.close()
    return model
    

In [7]:
def test(model,test_dataloader,criterion):
#test
    y_true = []
    y_pred = []
    test_total=0
    with torch.no_grad():
        ep_test_loss=0
        ep_test_acc=0
        for batch in tqdm(test_dataloader):
            prem, hyp, label = batch
            prem, hyp, label= prem.to(device), hyp.to(device), label.to(device)
            output = model(prem, hyp)
            y_true.extend(label.cpu())
            y_pred.extend(output.cpu().argmax(1))
            loss = criterion(output, label)
            ep_test_loss+=loss.item()
            acc = (output.argmax(1) == label).sum().item()
            ep_test_acc+=acc
            test_total+=label.size(0)

        print("Test Loss: ",ep_test_loss/len(test_dataloader)," Test Accuracy: ",ep_test_acc/test_total)
        print(classification_report(y_true, y_pred, target_names=["entailment","neutral","contradiction"]))
        print(confusion_matrix(y_true, y_pred))
        


In [8]:
EMBEDDING_DIM = 300
HIDDEN_DIM = 100
OUTPUT_DIM = 3
N_FC_LAYERS = 2
DROPOUT = 0.3
EPOCHS = 30
lr=0.001

In [9]:
datasetName="mnli"

mtrain_dataloader,mdev_dataloader,mtest_dataloader,membedding_matrix,mword2index = prepData(datasetName)

INPUT_DIM = len(mword2index)

train data


100%|██████████| 392702/392702 [00:02<00:00, 133877.92it/s]


dev data


100%|██████████| 10000/10000 [00:00<00:00, 131521.64it/s]


test data


100%|██████████| 10000/10000 [00:00<00:00, 127812.38it/s]


preprocess


100%|██████████| 392702/392702 [00:30<00:00, 13030.34it/s]
100%|██████████| 392702/392702 [00:18<00:00, 21090.42it/s]
100%|██████████| 9815/9815 [00:00<00:00, 13947.42it/s]
100%|██████████| 9815/9815 [00:00<00:00, 23576.51it/s]
100%|██████████| 9832/9832 [00:00<00:00, 12975.99it/s]
100%|██████████| 9832/9832 [00:00<00:00, 22929.62it/s]
  premise = torch.tensor(dataset["premise"],dtype=torch.long)


In [10]:
# %%capture captured_output


#initialize the model with above parameters

mnli_model = NLI_LSTM(membedding_matrix,INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, N_FC_LAYERS, DROPOUT)
#adam optimizer
mnli_model.to(device)
optimizer = optim.Adam(mnli_model.parameters())
#loss function
criterion = nn.CrossEntropyLoss()
#train
mnli_model = train(mnli_model,mtrain_dataloader,mdev_dataloader,mtest_dataloader,optimizer,criterion,datasetName,EPOCHS)

#test
test(mnli_model,mtest_dataloader,criterion)

# with open(f"/content/gdrive/MyDrive/NLPProject/reports_and_results/report_lstm_nli_mnli_foreachepoch.txt",'w') as f:
#     f.write(captured_output.stdout)




Epoch:  1
Training


100%|██████████| 12272/12272 [00:47<00:00, 256.90it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 568.50it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 532.67it/s]


Epoch:  1  Train Loss:  1.1339195525902657 Train Accuracy:  0.38280171733273577
Epoch:  1  Val Loss:  0.9871561915556072  Val Accuracy:  0.5001528273051452
Epoch:  1 Test Loss:  0.9650772415198289  Test Accuracy:  0.5075264442636289

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.48      0.56      0.52      3479
      neutral       0.44      0.45      0.44      3123
contradiction       0.61      0.48      0.54      3213

     accuracy                           0.50      9815
    macro avg       0.51      0.50      0.50      9815
 weighted avg       0.51      0.50      0.50      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.49      0.53      0.51      3463
      neutral       0.43      0.49      0.46      3129
contradiction       0.65      0.50      0.56      3240

     accuracy                           0.51      9832
    macro avg       0.52      0.51      0

100%|██████████| 12272/12272 [00:48<00:00, 254.00it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 476.33it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 389.95it/s]


Epoch:  2  Train Loss:  0.9939353745892119 Train Accuracy:  0.5131244556941396
Epoch:  2  Val Loss:  0.9949678136005464  Val Accuracy:  0.5303107488537953
Epoch:  2 Test Loss:  0.9953602877530184  Test Accuracy:  0.5309194467046379

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.45      0.91      0.61      3479
      neutral       0.62      0.20      0.30      3123
contradiction       0.78      0.44      0.56      3213

     accuracy                           0.53      9815
    macro avg       0.61      0.52      0.49      9815
 weighted avg       0.61      0.53      0.49      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.45      0.92      0.61      3463
      neutral       0.60      0.19      0.29      3129
contradiction       0.80      0.45      0.58      3240

     accuracy                           0.53      9832
    macro avg       0.62      0.52      0.

100%|██████████| 12272/12272 [00:48<00:00, 253.76it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 554.69it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 390.90it/s]


Epoch:  3  Train Loss:  0.955123092804342 Train Accuracy:  0.5448508028988902
Epoch:  3  Val Loss:  0.9843822036193326  Val Accuracy:  0.5352012226184412
Epoch:  3 Test Loss:  0.9571623554477444  Test Accuracy:  0.536106590724166

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.64      0.35      0.45      3479
      neutral       0.48      0.53      0.50      3123
contradiction       0.54      0.74      0.62      3213

     accuracy                           0.54      9815
    macro avg       0.55      0.54      0.53      9815
 weighted avg       0.55      0.54      0.52      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.63      0.34      0.45      3463
      neutral       0.46      0.55      0.50      3129
contradiction       0.56      0.73      0.63      3240

     accuracy                           0.54      9832
    macro avg       0.55      0.54      0.53

100%|██████████| 12272/12272 [00:48<00:00, 252.94it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 556.91it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 388.47it/s]


Epoch:  4  Train Loss:  0.9359817954686117 Train Accuracy:  0.5592714068173832
Epoch:  4  Val Loss:  0.8927540680096282  Val Accuracy:  0.5917473255221599
Epoch:  4 Test Loss:  0.885173562285188  Test Accuracy:  0.5914361269324654

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.69      0.62      3479
      neutral       0.55      0.47      0.51      3123
contradiction       0.68      0.60      0.64      3213

     accuracy                           0.59      9815
    macro avg       0.60      0.59      0.59      9815
 weighted avg       0.60      0.59      0.59      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.70      0.62      3463
      neutral       0.54      0.47      0.50      3129
contradiction       0.71      0.60      0.65      3240

     accuracy                           0.59      9832
    macro avg       0.60      0.59      0.5

100%|██████████| 12272/12272 [00:49<00:00, 250.39it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 419.19it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 382.28it/s]


Epoch:  5  Train Loss:  0.921299452348242 Train Accuracy:  0.5690039775707788
Epoch:  5  Val Loss:  0.9964945828875812  Val Accuracy:  0.5242995415180846
Epoch:  5 Test Loss:  0.979661915789951  Test Accuracy:  0.5203417412530512

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.64      0.41      0.50      3479
      neutral       0.60      0.31      0.41      3123
contradiction       0.46      0.86      0.60      3213

     accuracy                           0.52      9815
    macro avg       0.57      0.53      0.50      9815
 weighted avg       0.57      0.52      0.50      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.62      0.40      0.49      3463
      neutral       0.56      0.30      0.39      3129
contradiction       0.47      0.86      0.61      3240

     accuracy                           0.52      9832
    macro avg       0.55      0.52      0.50

100%|██████████| 12272/12272 [00:48<00:00, 251.82it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 452.96it/s]


Test


100%|██████████| 308/308 [00:01<00:00, 286.73it/s]


Epoch:  6  Train Loss:  0.91221976779751 Train Accuracy:  0.5760653116103305
Epoch:  6  Val Loss:  0.8821630868150667  Val Accuracy:  0.6015282730514518
Epoch:  6 Test Loss:  0.8750765414594056  Test Accuracy:  0.605166802278275

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.76      0.64      3479
      neutral       0.56      0.46      0.50      3123
contradiction       0.72      0.57      0.64      3213

     accuracy                           0.60      9815
    macro avg       0.61      0.60      0.60      9815
 weighted avg       0.61      0.60      0.60      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.76      0.64      3463
      neutral       0.56      0.47      0.51      3129
contradiction       0.74      0.57      0.65      3240

     accuracy                           0.61      9832
    macro avg       0.62      0.60      0.60 

100%|██████████| 12272/12272 [00:53<00:00, 227.71it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 379.62it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 565.07it/s]


Epoch:  7  Train Loss:  0.9060540228810118 Train Accuracy:  0.5804630483165352
Epoch:  7  Val Loss:  0.8875146482977107  Val Accuracy:  0.5939887926642894
Epoch:  7 Test Loss:  0.8751668974563673  Test Accuracy:  0.5944873881204231

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.57      0.69      0.62      3479
      neutral       0.58      0.42      0.49      3123
contradiction       0.64      0.66      0.65      3213

     accuracy                           0.59      9815
    macro avg       0.59      0.59      0.59      9815
 weighted avg       0.59      0.59      0.59      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.70      0.62      3463
      neutral       0.57      0.41      0.48      3129
contradiction       0.66      0.66      0.66      3240

     accuracy                           0.59      9832
    macro avg       0.60      0.59      0.

100%|██████████| 12272/12272 [00:49<00:00, 246.81it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 381.43it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 552.60it/s]


Epoch:  8  Train Loss:  0.8956930224177813 Train Accuracy:  0.5872289929768628
Epoch:  8  Val Loss:  0.8725680763247735  Val Accuracy:  0.6012226184411615
Epoch:  8 Test Loss:  0.8703588063453699  Test Accuracy:  0.5961147274206672

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.58      0.64      0.61      3479
      neutral       0.55      0.54      0.55      3123
contradiction       0.69      0.61      0.65      3213

     accuracy                           0.60      9815
    macro avg       0.60      0.60      0.60      9815
 weighted avg       0.60      0.60      0.60      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.57      0.64      0.60      3463
      neutral       0.52      0.54      0.53      3129
contradiction       0.72      0.61      0.66      3240

     accuracy                           0.60      9832
    macro avg       0.60      0.59      0.

100%|██████████| 12272/12272 [00:49<00:00, 247.22it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 560.25it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 547.59it/s]


Epoch:  9  Train Loss:  0.8909064949607911 Train Accuracy:  0.5909340925179908
Epoch:  9  Val Loss:  0.8667466397782491  Val Accuracy:  0.6124299541518085
Epoch:  9 Test Loss:  0.8597485552747528  Test Accuracy:  0.6041497152156224

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.77      0.65      3479
      neutral       0.58      0.47      0.52      3123
contradiction       0.74      0.58      0.65      3213

     accuracy                           0.61      9815
    macro avg       0.63      0.61      0.61      9815
 weighted avg       0.63      0.61      0.61      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.75      0.64      3463
      neutral       0.55      0.48      0.51      3129
contradiction       0.76      0.57      0.65      3240

     accuracy                           0.60      9832
    macro avg       0.62      0.60      0.

100%|██████████| 12272/12272 [00:49<00:00, 246.13it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 545.46it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 508.60it/s]


Epoch:  10  Train Loss:  0.8859390498805583 Train Accuracy:  0.5942826876359173
Epoch:  10  Val Loss:  0.9174848884247025  Val Accuracy:  0.5981660723382578
Epoch:  10 Test Loss:  0.914926820567676  Test Accuracy:  0.5958096013018714

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.53      0.86      0.66      3479
      neutral       0.69      0.28      0.40      3123
contradiction       0.69      0.62      0.65      3213

     accuracy                           0.60      9815
    macro avg       0.64      0.59      0.57      9815
 weighted avg       0.63      0.60      0.57      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.52      0.86      0.65      3463
      neutral       0.66      0.27      0.39      3129
contradiction       0.71      0.62      0.66      3240

     accuracy                           0.60      9832
    macro avg       0.63      0.59      

100%|██████████| 12272/12272 [00:50<00:00, 243.75it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 549.57it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 544.47it/s]


Epoch:  11  Train Loss:  0.8806241412696852 Train Accuracy:  0.59656176948424
Epoch:  11  Val Loss:  0.9099871780274357  Val Accuracy:  0.5736118186449313
Epoch:  11 Test Loss:  0.9102115344691586  Test Accuracy:  0.5725183075671277

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.65      0.42      0.51      3479
      neutral       0.49      0.62      0.55      3123
contradiction       0.61      0.70      0.65      3213

     accuracy                           0.57      9815
    macro avg       0.58      0.58      0.57      9815
 weighted avg       0.59      0.57      0.57      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.63      0.43      0.51      3463
      neutral       0.48      0.60      0.54      3129
contradiction       0.64      0.70      0.66      3240

     accuracy                           0.57      9832
    macro avg       0.58      0.58      0

100%|██████████| 12272/12272 [00:49<00:00, 249.62it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 575.44it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 551.87it/s]


Epoch:  12  Train Loss:  0.8829293975720699 Train Accuracy:  0.5975447031082093
Epoch:  12  Val Loss:  0.8864622888813578  Val Accuracy:  0.6041772796739684
Epoch:  12 Test Loss:  0.8808493918025648  Test Accuracy:  0.5992676973148902

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.60      0.64      0.62      3479
      neutral       0.64      0.41      0.50      3123
contradiction       0.59      0.75      0.66      3213

     accuracy                           0.60      9815
    macro avg       0.61      0.60      0.59      9815
 weighted avg       0.61      0.60      0.60      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.60      0.66      0.63      3463
      neutral       0.62      0.40      0.48      3129
contradiction       0.59      0.73      0.66      3240

     accuracy                           0.60      9832
    macro avg       0.60      0.60     

100%|██████████| 12272/12272 [00:49<00:00, 249.44it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 576.08it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 554.57it/s]


Epoch:  13  Train Loss:  0.873731958756352 Train Accuracy:  0.6013593004364632
Epoch:  13  Val Loss:  0.859783525381492  Val Accuracy:  0.612633723892002
Epoch:  13 Test Loss:  0.858691183100273  Test Accuracy:  0.6119812855980472

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.79      0.65      3479
      neutral       0.62      0.42      0.50      3123
contradiction       0.72      0.61      0.66      3213

     accuracy                           0.61      9815
    macro avg       0.63      0.61      0.60      9815
 weighted avg       0.63      0.61      0.61      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.80      0.65      3463
      neutral       0.61      0.42      0.49      3129
contradiction       0.74      0.60      0.66      3240

     accuracy                           0.61      9832
    macro avg       0.63      0.61      0.6

100%|██████████| 12272/12272 [00:49<00:00, 249.26it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 539.56it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 557.98it/s]


Epoch:  14  Train Loss:  0.8705957472892636 Train Accuracy:  0.6045245504224578
Epoch:  14  Val Loss:  0.9950636530155468  Val Accuracy:  0.580030565461029
Epoch:  14 Test Loss:  0.9675192703287323  Test Accuracy:  0.5899104963384866

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.74      0.63      3479
      neutral       0.73      0.23      0.34      3123
contradiction       0.57      0.75      0.65      3213

     accuracy                           0.58      9815
    macro avg       0.62      0.57      0.54      9815
 weighted avg       0.62      0.58      0.55      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.77      0.65      3463
      neutral       0.72      0.22      0.34      3129
contradiction       0.59      0.75      0.66      3240

     accuracy                           0.59      9832
    macro avg       0.63      0.58      

100%|██████████| 12272/12272 [00:48<00:00, 250.75it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 561.17it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 549.77it/s]


Epoch:  15  Train Loss:  0.8699438528144414 Train Accuracy:  0.6065515327143738
Epoch:  15  Val Loss:  0.8895130881657429  Val Accuracy:  0.6009169638308711
Epoch:  15 Test Loss:  0.8879072542314406  Test Accuracy:  0.5990642799023597

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.54      0.83      0.65      3479
      neutral       0.65      0.35      0.46      3123
contradiction       0.69      0.59      0.64      3213

     accuracy                           0.60      9815
    macro avg       0.63      0.59      0.58      9815
 weighted avg       0.62      0.60      0.59      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.53      0.83      0.65      3463
      neutral       0.64      0.34      0.44      3129
contradiction       0.71      0.60      0.65      3240

     accuracy                           0.60      9832
    macro avg       0.63      0.59     

100%|██████████| 12272/12272 [00:49<00:00, 247.59it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 544.12it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 553.67it/s]


Epoch:  16  Train Loss:  0.8681552492645527 Train Accuracy:  0.606439488467082
Epoch:  16  Val Loss:  0.8645022587201494  Val Accuracy:  0.615078960774325
Epoch:  16 Test Loss:  0.8646043844811329  Test Accuracy:  0.6139137510170871

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.59      0.67      0.63      3479
      neutral       0.55      0.58      0.56      3123
contradiction       0.73      0.59      0.66      3213

     accuracy                           0.62      9815
    macro avg       0.62      0.61      0.62      9815
 weighted avg       0.62      0.62      0.62      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.59      0.67      0.63      3463
      neutral       0.54      0.58      0.56      3129
contradiction       0.74      0.59      0.66      3240

     accuracy                           0.61      9832
    macro avg       0.63      0.61      0

100%|██████████| 12272/12272 [00:49<00:00, 249.50it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 527.00it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 574.93it/s]


Epoch:  17  Train Loss:  0.8627858251804362 Train Accuracy:  0.6101369486277126
Epoch:  17  Val Loss:  0.9293595719415124  Val Accuracy:  0.5771777890983188
Epoch:  17 Test Loss:  0.9160511333059955  Test Accuracy:  0.580553295362083

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.66      0.44      0.53      3479
      neutral       0.55      0.53      0.54      3123
contradiction       0.55      0.77      0.64      3213

     accuracy                           0.58      9815
    macro avg       0.59      0.58      0.57      9815
 weighted avg       0.59      0.58      0.57      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.65      0.45      0.53      3463
      neutral       0.55      0.54      0.54      3129
contradiction       0.57      0.76      0.65      3240

     accuracy                           0.58      9832
    macro avg       0.59      0.58      

100%|██████████| 12272/12272 [00:49<00:00, 248.24it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 553.50it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 379.69it/s]


Epoch:  18  Train Loss:  0.8623975596112603 Train Accuracy:  0.6098721167704774
Epoch:  18  Val Loss:  0.915097862190843  Val Accuracy:  0.596128374936322
Epoch:  18 Test Loss:  0.8962891634021487  Test Accuracy:  0.6049633848657445

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.81      0.66      3479
      neutral       0.69      0.27      0.39      3123
contradiction       0.63      0.68      0.65      3213

     accuracy                           0.60      9815
    macro avg       0.62      0.59      0.57      9815
 weighted avg       0.62      0.60      0.57      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.83      0.66      3463
      neutral       0.68      0.28      0.40      3129
contradiction       0.66      0.68      0.67      3240

     accuracy                           0.60      9832
    macro avg       0.63      0.60      0

100%|██████████| 12272/12272 [00:49<00:00, 246.21it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 370.78it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 548.78it/s]


Epoch:  19  Train Loss:  0.860902876438964 Train Accuracy:  0.6105723933160514
Epoch:  19  Val Loss:  0.8645170593106397  Val Accuracy:  0.6173204279164544
Epoch:  19 Test Loss:  0.8597949400737688  Test Accuracy:  0.6176769731489016

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.61      0.63      0.62      3479
      neutral       0.55      0.58      0.57      3123
contradiction       0.69      0.64      0.66      3213

     accuracy                           0.62      9815
    macro avg       0.62      0.62      0.62      9815
 weighted avg       0.62      0.62      0.62      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.61      0.65      0.63      3463
      neutral       0.55      0.59      0.57      3129
contradiction       0.71      0.61      0.66      3240

     accuracy                           0.62      9832
    macro avg       0.62      0.62      

100%|██████████| 12272/12272 [00:49<00:00, 247.81it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 551.70it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 555.56it/s]


Epoch:  20  Train Loss:  0.858694066275324 Train Accuracy:  0.6123090791490748
Epoch:  20  Val Loss:  0.8617330313893794  Val Accuracy:  0.6109016811003566
Epoch:  20 Test Loss:  0.8586033729763775  Test Accuracy:  0.6098454027664768

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.61      0.64      0.62      3479
      neutral       0.57      0.52      0.55      3123
contradiction       0.65      0.67      0.66      3213

     accuracy                           0.61      9815
    macro avg       0.61      0.61      0.61      9815
 weighted avg       0.61      0.61      0.61      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.60      0.65      0.62      3463
      neutral       0.55      0.51      0.53      3129
contradiction       0.68      0.66      0.67      3240

     accuracy                           0.61      9832
    macro avg       0.61      0.61      

100%|██████████| 12272/12272 [00:48<00:00, 252.12it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 561.02it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 562.35it/s]


Epoch:  21  Train Loss:  0.8590463239413041 Train Accuracy:  0.6131519574639294
Epoch:  21  Val Loss:  0.9491038180717816  Val Accuracy:  0.6018339276617423
Epoch:  21 Test Loss:  0.9577938624984258  Test Accuracy:  0.5950976403580146

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.52      0.88      0.65      3479
      neutral       0.70      0.31      0.43      3123
contradiction       0.75      0.58      0.65      3213

     accuracy                           0.60      9815
    macro avg       0.65      0.59      0.58      9815
 weighted avg       0.65      0.60      0.58      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.51      0.89      0.65      3463
      neutral       0.68      0.29      0.41      3129
contradiction       0.76      0.57      0.65      3240

     accuracy                           0.60      9832
    macro avg       0.65      0.58     

100%|██████████| 12272/12272 [00:48<00:00, 251.37it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 559.06it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 546.36it/s]


Epoch:  22  Train Loss:  0.8558580754021233 Train Accuracy:  0.6147613202886667
Epoch:  22  Val Loss:  0.8665680767853019  Val Accuracy:  0.6197656647987774
Epoch:  22 Test Loss:  0.8521216298852649  Test Accuracy:  0.6222538649308381

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.59      0.71      0.64      3479
      neutral       0.58      0.52      0.55      3123
contradiction       0.71      0.62      0.66      3213

     accuracy                           0.62      9815
    macro avg       0.63      0.62      0.62      9815
 weighted avg       0.62      0.62      0.62      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.59      0.72      0.65      3463
      neutral       0.57      0.52      0.55      3129
contradiction       0.73      0.62      0.67      3240

     accuracy                           0.62      9832
    macro avg       0.63      0.62     

100%|██████████| 12272/12272 [00:48<00:00, 252.31it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 551.60it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 527.82it/s]


Epoch:  23  Train Loss:  0.8549515115026739 Train Accuracy:  0.6154692362147379
Epoch:  23  Val Loss:  0.9203246346514078  Val Accuracy:  0.5904228222109017
Epoch:  23 Test Loss:  0.9062424605930006  Test Accuracy:  0.5859438567941416

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.65      0.50      0.57      3479
      neutral       0.57      0.52      0.54      3123
contradiction       0.57      0.75      0.65      3213

     accuracy                           0.59      9815
    macro avg       0.60      0.59      0.59      9815
 weighted avg       0.60      0.59      0.59      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.63      0.50      0.56      3463
      neutral       0.56      0.53      0.55      3129
contradiction       0.57      0.73      0.64      3240

     accuracy                           0.59      9832
    macro avg       0.59      0.59     

100%|██████████| 12272/12272 [00:48<00:00, 251.93it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 552.28it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 547.14it/s]


Epoch:  24  Train Loss:  0.8529767571976397 Train Accuracy:  0.6161287693976603
Epoch:  24  Val Loss:  0.8938137153072544  Val Accuracy:  0.6123280692817117
Epoch:  24 Test Loss:  0.9002456533444392  Test Accuracy:  0.6094385679414158

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.54      0.83      0.66      3479
      neutral       0.62      0.42      0.50      3123
contradiction       0.77      0.56      0.65      3213

     accuracy                           0.61      9815
    macro avg       0.64      0.60      0.60      9815
 weighted avg       0.64      0.61      0.60      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.54      0.84      0.66      3463
      neutral       0.61      0.43      0.50      3129
contradiction       0.79      0.54      0.64      3240

     accuracy                           0.61      9832
    macro avg       0.65      0.60     

100%|██████████| 12272/12272 [00:48<00:00, 253.93it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 563.53it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 544.50it/s]


Epoch:  25  Train Loss:  0.8525959626957288 Train Accuracy:  0.616767930899257
Epoch:  25  Val Loss:  0.916709728660335  Val Accuracy:  0.5801324503311258
Epoch:  25 Test Loss:  0.9056701356327379  Test Accuracy:  0.5783157038242474

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.68      0.41      0.51      3479
      neutral       0.55      0.58      0.56      3123
contradiction       0.56      0.77      0.65      3213

     accuracy                           0.58      9815
    macro avg       0.59      0.58      0.57      9815
 weighted avg       0.60      0.58      0.57      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.66      0.42      0.51      3463
      neutral       0.52      0.58      0.55      3129
contradiction       0.58      0.75      0.65      3240

     accuracy                           0.58      9832
    macro avg       0.59      0.58      0

100%|██████████| 12272/12272 [00:48<00:00, 253.68it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 541.04it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 377.66it/s]


Epoch:  26  Train Loss:  0.8519464958910836 Train Accuracy:  0.6184027583256515
Epoch:  26  Val Loss:  0.9179715605434455  Val Accuracy:  0.6055017829852267
Epoch:  26 Test Loss:  0.9251190527499497  Test Accuracy:  0.6043531326281529

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.52      0.87      0.65      3479
      neutral       0.65      0.39      0.49      3123
contradiction       0.81      0.52      0.63      3213

     accuracy                           0.61      9815
    macro avg       0.66      0.60      0.59      9815
 weighted avg       0.65      0.61      0.59      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.52      0.88      0.65      3463
      neutral       0.64      0.38      0.48      3129
contradiction       0.82      0.53      0.64      3240

     accuracy                           0.60      9832
    macro avg       0.66      0.60     

100%|██████████| 12272/12272 [00:48<00:00, 254.04it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 562.81it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 536.93it/s]


Epoch:  27  Train Loss:  0.851361037874315 Train Accuracy:  0.6173077804543904
Epoch:  27  Val Loss:  0.9362718565844558  Val Accuracy:  0.6012226184411615
Epoch:  27 Test Loss:  0.9498773010133149  Test Accuracy:  0.596826688364524

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.79      0.66      3479
      neutral       0.53      0.56      0.55      3123
contradiction       0.85      0.44      0.58      3213

     accuracy                           0.60      9815
    macro avg       0.65      0.60      0.59      9815
 weighted avg       0.65      0.60      0.60      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.79      0.65      3463
      neutral       0.52      0.55      0.54      3129
contradiction       0.86      0.44      0.58      3240

     accuracy                           0.60      9832
    macro avg       0.65      0.59      0

100%|██████████| 12272/12272 [00:48<00:00, 251.76it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 563.04it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 554.93it/s]


Epoch:  28  Train Loss:  0.8495493013304907 Train Accuracy:  0.6193347627463064
Epoch:  28  Val Loss:  1.0464669720357715  Val Accuracy:  0.5224656138563424
Epoch:  28 Test Loss:  1.0456511322167013  Test Accuracy:  0.5159682668836453

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.72      0.21      0.32      3479
      neutral       0.41      0.86      0.56      3123
contradiction       0.74      0.53      0.62      3213

     accuracy                           0.52      9815
    macro avg       0.63      0.53      0.50      9815
 weighted avg       0.63      0.52      0.50      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.72      0.20      0.31      3463
      neutral       0.40      0.87      0.55      3129
contradiction       0.77      0.51      0.62      3240

     accuracy                           0.52      9832
    macro avg       0.63      0.53     

100%|██████████| 12272/12272 [00:48<00:00, 251.43it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 541.66it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 561.68it/s]


Epoch:  29  Train Loss:  0.8476797116524509 Train Accuracy:  0.6201954662823209
Epoch:  29  Val Loss:  0.855889134570131  Val Accuracy:  0.6156902699949057
Epoch:  29 Test Loss:  0.85329698239054  Test Accuracy:  0.6192026037428804

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.57      0.73      0.64      3479
      neutral       0.57      0.54      0.56      3123
contradiction       0.75      0.56      0.64      3213

     accuracy                           0.62      9815
    macro avg       0.63      0.61      0.61      9815
 weighted avg       0.63      0.62      0.61      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.57      0.75      0.65      3463
      neutral       0.56      0.54      0.55      3129
contradiction       0.78      0.56      0.65      3240

     accuracy                           0.62      9832
    macro avg       0.64      0.62      0.

100%|██████████| 12272/12272 [00:49<00:00, 248.73it/s]


Validation


100%|██████████| 307/307 [00:00<00:00, 549.33it/s]


Test


100%|██████████| 308/308 [00:00<00:00, 540.64it/s]


Epoch:  30  Train Loss:  0.8484391510350607 Train Accuracy:  0.6196097804442046
Epoch:  30  Val Loss:  0.8859996085058206  Val Accuracy:  0.6202750891492613
Epoch:  30 Test Loss:  0.8852921820112637  Test Accuracy:  0.6137103336045565

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.57      0.76      0.65      3479
      neutral       0.57      0.57      0.57      3123
contradiction       0.80      0.52      0.63      3213

     accuracy                           0.62      9815
    macro avg       0.65      0.62      0.62      9815
 weighted avg       0.65      0.62      0.62      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.76      0.65      3463
      neutral       0.55      0.56      0.56      3129
contradiction       0.83      0.51      0.63      3240

     accuracy                           0.61      9832
    macro avg       0.65      0.61     

100%|██████████| 308/308 [00:00<00:00, 472.54it/s]


Test Loss:  0.8868703764754456  Test Accuracy:  0.6141171684296176
               precision    recall  f1-score   support

   entailment       0.56      0.75      0.64      3463
      neutral       0.56      0.57      0.56      3129
contradiction       0.83      0.51      0.63      3240

     accuracy                           0.61      9832
    macro avg       0.65      0.61      0.61      9832
 weighted avg       0.65      0.61      0.61      9832

[[2612  729  122]
 [1130 1771  228]
 [ 897  688 1655]]


In [11]:
datasetName="snli"

strain_dataloader,sdev_dataloader,stest_dataloader,sembedding_matrix,sword2index = prepData(datasetName)
#intialize the model
INPUT_DIM = len(sword2index)

train data


100%|██████████| 550152/550152 [00:04<00:00, 123920.91it/s]


dev data


100%|██████████| 10000/10000 [00:00<00:00, 86322.09it/s]


test data


100%|██████████| 10000/10000 [00:00<00:00, 95045.36it/s]


preprocess


100%|██████████| 549367/549367 [00:22<00:00, 24501.34it/s]
100%|██████████| 549367/549367 [00:19<00:00, 28672.33it/s]
100%|██████████| 9842/9842 [00:00<00:00, 23686.03it/s]
100%|██████████| 9842/9842 [00:00<00:00, 31595.81it/s]
100%|██████████| 9824/9824 [00:00<00:00, 23695.50it/s]
100%|██████████| 9824/9824 [00:00<00:00, 31308.32it/s]


In [12]:
#initialize the model with above parameters
snli_model = NLI_LSTM(sembedding_matrix,INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, N_FC_LAYERS, DROPOUT)
snli_model.to(device)
#adam optimizer
optimizer = optim.Adam(snli_model.parameters())
#loss function
criterion = nn.CrossEntropyLoss()
#train
snli_model = train(snli_model,strain_dataloader,sdev_dataloader,stest_dataloader,optimizer,criterion,datasetName,EPOCHS)

#test
test(snli_model,stest_dataloader,criterion)



Epoch:  1
Training


100%|██████████| 17168/17168 [01:07<00:00, 253.00it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 526.12it/s]


Test


100%|██████████| 307/307 [00:01<00:00, 269.80it/s]


Epoch:  1  Train Loss:  1.1699728949858224 Train Accuracy:  0.3331598002792305
Epoch:  1  Val Loss:  1.1507832916913094  Val Accuracy:  0.33275756959967484
Epoch:  1 Test Loss:  1.1521396712682146  Test Accuracy:  0.33377442996742673

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.33      0.04      0.08      3329
      neutral       0.32      0.10      0.15      3235
contradiction       0.33      0.86      0.48      3278

     accuracy                           0.33      9842
    macro avg       0.33      0.33      0.24      9842
 weighted avg       0.33      0.33      0.24      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.37      0.05      0.08      3368
      neutral       0.34      0.10      0.16      3219
contradiction       0.33      0.86      0.48      3237

     accuracy                           0.33      9824
    macro avg       0.35      0.34      

100%|██████████| 17168/17168 [01:10<00:00, 243.29it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 534.34it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 536.09it/s]


Epoch:  2  Train Loss:  1.162398468175311 Train Accuracy:  0.33435572213110726
Epoch:  2  Val Loss:  1.1176165768078394  Val Accuracy:  0.33722820564925826
Epoch:  2 Test Loss:  1.1232075095176697  Test Accuracy:  0.3277687296416938

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.34      0.17      0.23      3329
      neutral       0.33      0.18      0.24      3235
contradiction       0.34      0.66      0.45      3278

     accuracy                           0.34      9842
    macro avg       0.34      0.34      0.30      9842
 weighted avg       0.34      0.34      0.30      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.32      0.17      0.22      3368
      neutral       0.34      0.19      0.24      3219
contradiction       0.33      0.63      0.43      3237

     accuracy                           0.33      9824
    macro avg       0.33      0.33      0

100%|██████████| 17168/17168 [01:09<00:00, 247.09it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 550.22it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 549.16it/s]


Epoch:  3  Train Loss:  1.1616942390577325 Train Accuracy:  0.3335129339767405
Epoch:  3  Val Loss:  1.1331661242943305  Val Accuracy:  0.3386506807559439
Epoch:  3 Test Loss:  1.1348194582842848  Test Accuracy:  0.33418159609120524

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.37      0.09      0.14      3329
      neutral       0.34      0.76      0.47      3235
contradiction       0.33      0.18      0.23      3278

     accuracy                           0.34      9842
    macro avg       0.35      0.34      0.28      9842
 weighted avg       0.35      0.34      0.28      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.37      0.08      0.14      3368
      neutral       0.33      0.75      0.46      3219
contradiction       0.34      0.18      0.24      3237

     accuracy                           0.33      9824
    macro avg       0.35      0.34      0

100%|██████████| 17168/17168 [01:08<00:00, 250.89it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 555.14it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 448.10it/s]


Epoch:  4  Train Loss:  1.1619245509329736 Train Accuracy:  0.33357118283406173
Epoch:  4  Val Loss:  1.152089176433427  Val Accuracy:  0.32625482625482627
Epoch:  4 Test Loss:  1.1459083522181557  Test Accuracy:  0.33112785016286644

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.32      0.02      0.04      3329
      neutral       0.32      0.25      0.28      3235
contradiction       0.33      0.71      0.45      3278

     accuracy                           0.33      9842
    macro avg       0.32      0.33      0.26      9842
 weighted avg       0.32      0.33      0.26      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.40      0.02      0.04      3368
      neutral       0.33      0.25      0.29      3219
contradiction       0.33      0.73      0.46      3237

     accuracy                           0.33      9824
    macro avg       0.35      0.34      

100%|██████████| 17168/17168 [01:08<00:00, 249.65it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 555.70it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 532.03it/s]


Epoch:  5  Train Loss:  1.161086866317902 Train Accuracy:  0.3342173810949693
Epoch:  5  Val Loss:  1.142228619812371  Val Accuracy:  0.4266409266409266
Epoch:  5 Test Loss:  1.1331287486545427  Test Accuracy:  0.43699104234527686

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.40      0.83      0.54      3329
      neutral       0.49      0.45      0.47      3235
contradiction       0.28      0.00      0.00      3278

     accuracy                           0.43      9842
    macro avg       0.39      0.42      0.34      9842
 weighted avg       0.39      0.43      0.34      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.41      0.84      0.55      3368
      neutral       0.51      0.45      0.48      3219
contradiction       0.38      0.00      0.00      3237

     accuracy                           0.44      9824
    macro avg       0.43      0.43      0.3

100%|██████████| 17168/17168 [01:09<00:00, 245.89it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 527.59it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 554.13it/s]


Epoch:  6  Train Loss:  0.997449380717998 Train Accuracy:  0.5208631024433575
Epoch:  6  Val Loss:  0.8899080757196848  Val Accuracy:  0.5948994106888844
Epoch:  6 Test Loss:  0.8754816894034221  Test Accuracy:  0.6039291530944625

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.61      0.61      0.61      3329
      neutral       0.58      0.63      0.60      3235
contradiction       0.60      0.54      0.57      3278

     accuracy                           0.59      9842
    macro avg       0.60      0.59      0.59      9842
 weighted avg       0.60      0.59      0.59      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.63      0.65      0.64      3368
      neutral       0.58      0.64      0.61      3219
contradiction       0.61      0.53      0.57      3237

     accuracy                           0.60      9824
    macro avg       0.60      0.60      0.6

100%|██████████| 17168/17168 [01:09<00:00, 248.08it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 480.97it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 525.48it/s]


Epoch:  7  Train Loss:  0.8970327290644712 Train Accuracy:  0.5943149115254466
Epoch:  7  Val Loss:  0.8775813736311802  Val Accuracy:  0.6008941272099166
Epoch:  7 Test Loss:  0.8687125907464602  Test Accuracy:  0.6050488599348535

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.80      0.66      3329
      neutral       0.75      0.40      0.52      3235
contradiction       0.58      0.60      0.59      3278

     accuracy                           0.60      9842
    macro avg       0.63      0.60      0.59      9842
 weighted avg       0.63      0.60      0.59      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.57      0.81      0.67      3368
      neutral       0.76      0.40      0.53      3219
contradiction       0.58      0.59      0.59      3237

     accuracy                           0.61      9824
    macro avg       0.64      0.60      0.

100%|██████████| 17168/17168 [01:08<00:00, 248.89it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 546.34it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 489.22it/s]


Epoch:  8  Train Loss:  0.8578639078732556 Train Accuracy:  0.6182206066254434
Epoch:  8  Val Loss:  0.8459704946000854  Val Accuracy:  0.6088193456614509
Epoch:  8 Test Loss:  0.8386265080215877  Test Accuracy:  0.620928338762215

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.67      0.57      0.62      3329
      neutral       0.68      0.50      0.58      3235
contradiction       0.53      0.76      0.63      3278

     accuracy                           0.61      9842
    macro avg       0.63      0.61      0.61      9842
 weighted avg       0.63      0.61      0.61      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.69      0.59      0.64      3368
      neutral       0.69      0.52      0.59      3219
contradiction       0.54      0.75      0.63      3237

     accuracy                           0.62      9824
    macro avg       0.64      0.62      0.6

100%|██████████| 17168/17168 [01:09<00:00, 248.37it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 541.73it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 536.70it/s]


Epoch:  9  Train Loss:  0.8351178014701283 Train Accuracy:  0.6332124062784987
Epoch:  9  Val Loss:  0.8113451876810619  Val Accuracy:  0.6510871774029668
Epoch:  9 Test Loss:  0.7982748262462864  Test Accuracy:  0.6553338762214984

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.62      0.78      0.69      3329
      neutral       0.74      0.51      0.60      3235
contradiction       0.63      0.66      0.65      3278

     accuracy                           0.65      9842
    macro avg       0.66      0.65      0.65      9842
 weighted avg       0.66      0.65      0.65      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.63      0.79      0.70      3368
      neutral       0.74      0.52      0.61      3219
contradiction       0.63      0.66      0.64      3237

     accuracy                           0.66      9824
    macro avg       0.67      0.65      0.

100%|██████████| 17168/17168 [01:11<00:00, 241.50it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 532.12it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 534.95it/s]


Epoch:  10  Train Loss:  0.8196105385652265 Train Accuracy:  0.642754297218435
Epoch:  10  Val Loss:  0.7743290106390978  Val Accuracy:  0.6611461085145296
Epoch:  10 Test Loss:  0.767527274569005  Test Accuracy:  0.662764657980456

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.66      0.70      0.68      3329
      neutral       0.67      0.63      0.65      3235
contradiction       0.65      0.65      0.65      3278

     accuracy                           0.66      9842
    macro avg       0.66      0.66      0.66      9842
 weighted avg       0.66      0.66      0.66      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.67      0.71      0.69      3368
      neutral       0.68      0.62      0.65      3219
contradiction       0.65      0.65      0.65      3237

     accuracy                           0.66      9824
    macro avg       0.66      0.66      0.

100%|██████████| 17168/17168 [01:10<00:00, 244.35it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 537.40it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 373.32it/s]


Epoch:  11  Train Loss:  0.8058844716457402 Train Accuracy:  0.651493446093413
Epoch:  11  Val Loss:  0.7687037342361042  Val Accuracy:  0.6705954074375127
Epoch:  11 Test Loss:  0.7706192893003407  Test Accuracy:  0.6725366449511401

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.66      0.76      0.70      3329
      neutral       0.62      0.73      0.67      3235
contradiction       0.78      0.53      0.63      3278

     accuracy                           0.67      9842
    macro avg       0.69      0.67      0.67      9842
 weighted avg       0.69      0.67      0.67      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.66      0.76      0.71      3368
      neutral       0.62      0.73      0.67      3219
contradiction       0.78      0.53      0.63      3237

     accuracy                           0.67      9824
    macro avg       0.69      0.67      

100%|██████████| 17168/17168 [01:09<00:00, 246.14it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 314.55it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 498.29it/s]


Epoch:  12  Train Loss:  0.7974897143374591 Train Accuracy:  0.6565428939124484
Epoch:  12  Val Loss:  0.7834705367877886  Val Accuracy:  0.6567770778297094
Epoch:  12 Test Loss:  0.7776324348069169  Test Accuracy:  0.6628664495114006

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.68      0.72      0.70      3329
      neutral       0.76      0.50      0.60      3235
contradiction       0.59      0.75      0.66      3278

     accuracy                           0.66      9842
    macro avg       0.68      0.66      0.65      9842
 weighted avg       0.68      0.66      0.65      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.68      0.73      0.70      3368
      neutral       0.77      0.50      0.61      3219
contradiction       0.59      0.76      0.67      3237

     accuracy                           0.66      9824
    macro avg       0.68      0.66     

100%|██████████| 17168/17168 [01:09<00:00, 248.24it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 391.43it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 387.49it/s]


Epoch:  13  Train Loss:  0.7877276213188679 Train Accuracy:  0.661878125187716
Epoch:  13  Val Loss:  0.7784232569785862  Val Accuracy:  0.6781142044299939
Epoch:  13 Test Loss:  0.7712444425211668  Test Accuracy:  0.6767100977198697

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.62      0.85      0.71      3329
      neutral       0.68      0.65      0.67      3235
contradiction       0.81      0.53      0.64      3278

     accuracy                           0.68      9842
    macro avg       0.70      0.68      0.67      9842
 weighted avg       0.70      0.68      0.67      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.61      0.84      0.71      3368
      neutral       0.68      0.65      0.66      3219
contradiction       0.82      0.53      0.65      3237

     accuracy                           0.68      9824
    macro avg       0.70      0.67      

100%|██████████| 17168/17168 [01:10<00:00, 242.72it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 503.17it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 511.03it/s]


Epoch:  14  Train Loss:  0.7742509291735578 Train Accuracy:  0.670957665822665
Epoch:  14  Val Loss:  0.7516938116062771  Val Accuracy:  0.6834992887624467
Epoch:  14 Test Loss:  0.7382401493939204  Test Accuracy:  0.6936074918566775

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.73      0.67      0.70      3329
      neutral       0.64      0.69      0.67      3235
contradiction       0.69      0.69      0.69      3278

     accuracy                           0.68      9842
    macro avg       0.69      0.68      0.68      9842
 weighted avg       0.69      0.68      0.68      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.74      0.68      0.71      3368
      neutral       0.65      0.71      0.68      3219
contradiction       0.70      0.69      0.70      3237

     accuracy                           0.69      9824
    macro avg       0.70      0.69      

100%|██████████| 17168/17168 [01:09<00:00, 247.28it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 544.73it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 537.63it/s]


Epoch:  15  Train Loss:  0.7585036697687246 Train Accuracy:  0.6795311695096357
Epoch:  15  Val Loss:  0.7734397313037475  Val Accuracy:  0.668360089412721
Epoch:  15 Test Loss:  0.7742933427083765  Test Accuracy:  0.6699918566775245

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.63      0.85      0.72      3329
      neutral       0.82      0.42      0.55      3235
contradiction       0.65      0.74      0.69      3278

     accuracy                           0.67      9842
    macro avg       0.70      0.67      0.65      9842
 weighted avg       0.70      0.67      0.66      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.62      0.84      0.72      3368
      neutral       0.82      0.42      0.55      3219
contradiction       0.66      0.74      0.70      3237

     accuracy                           0.67      9824
    macro avg       0.70      0.67      

100%|██████████| 17168/17168 [01:09<00:00, 245.94it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 530.12it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 547.51it/s]


Epoch:  16  Train Loss:  0.7480729942011255 Train Accuracy:  0.6844277140782027
Epoch:  16  Val Loss:  0.7523268656684207  Val Accuracy:  0.6934566145092461
Epoch:  16 Test Loss:  0.7390511681474381  Test Accuracy:  0.7029723127035831

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.66      0.81      0.73      3329
      neutral       0.76      0.54      0.63      3235
contradiction       0.68      0.73      0.70      3278

     accuracy                           0.69      9842
    macro avg       0.70      0.69      0.69      9842
 weighted avg       0.70      0.69      0.69      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.68      0.82      0.74      3368
      neutral       0.76      0.55      0.64      3219
contradiction       0.70      0.73      0.71      3237

     accuracy                           0.70      9824
    macro avg       0.71      0.70     

100%|██████████| 17168/17168 [01:09<00:00, 245.59it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 535.28it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 543.05it/s]


Epoch:  17  Train Loss:  0.740153727643296 Train Accuracy:  0.6890439360209113
Epoch:  17  Val Loss:  0.7652201892493607  Val Accuracy:  0.6890875838244259
Epoch:  17 Test Loss:  0.7665928703177636  Test Accuracy:  0.6867874592833876

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.68      0.79      0.73      3329
      neutral       0.79      0.49      0.61      3235
contradiction       0.65      0.78      0.71      3278

     accuracy                           0.69      9842
    macro avg       0.71      0.69      0.68      9842
 weighted avg       0.71      0.69      0.68      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.67      0.79      0.73      3368
      neutral       0.78      0.48      0.60      3219
contradiction       0.65      0.78      0.71      3237

     accuracy                           0.69      9824
    macro avg       0.70      0.68      

100%|██████████| 17168/17168 [01:09<00:00, 247.10it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 380.75it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 390.52it/s]


Epoch:  18  Train Loss:  0.7325785213392245 Train Accuracy:  0.6938385450891662
Epoch:  18  Val Loss:  0.70621854589357  Val Accuracy:  0.7054460475513107
Epoch:  18 Test Loss:  0.7044012447521819  Test Accuracy:  0.7060260586319218

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.69      0.78      0.73      3329
      neutral       0.66      0.69      0.67      3235
contradiction       0.79      0.64      0.71      3278

     accuracy                           0.71      9842
    macro avg       0.71      0.70      0.70      9842
 weighted avg       0.71      0.71      0.71      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.68      0.80      0.74      3368
      neutral       0.66      0.68      0.67      3219
contradiction       0.79      0.64      0.71      3237

     accuracy                           0.71      9824
    macro avg       0.71      0.70      0

100%|██████████| 17168/17168 [01:08<00:00, 248.84it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 537.65it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 534.97it/s]


Epoch:  19  Train Loss:  0.7269228081542648 Train Accuracy:  0.6967054810354463
Epoch:  19  Val Loss:  0.8948443304602202  Val Accuracy:  0.6657183499288762
Epoch:  19 Test Loss:  0.8952326709556269  Test Accuracy:  0.6632736156351792

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.94      0.70      3329
      neutral       0.78      0.47      0.59      3235
contradiction       0.84      0.58      0.68      3278

     accuracy                           0.67      9842
    macro avg       0.73      0.66      0.66      9842
 weighted avg       0.72      0.67      0.66      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.94      0.70      3368
      neutral       0.77      0.46      0.58      3219
contradiction       0.84      0.57      0.68      3237

     accuracy                           0.66      9824
    macro avg       0.72      0.66     

100%|██████████| 17168/17168 [01:08<00:00, 248.84it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 530.89it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 560.28it/s]


Epoch:  20  Train Loss:  0.7221880797913015 Train Accuracy:  0.6982927623974502
Epoch:  20  Val Loss:  0.7066784426569939  Val Accuracy:  0.7098150782361309
Epoch:  20 Test Loss:  0.7015439559271747  Test Accuracy:  0.7166123778501629

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.68      0.81      0.74      3329
      neutral       0.72      0.60      0.66      3235
contradiction       0.73      0.71      0.72      3278

     accuracy                           0.71      9842
    macro avg       0.71      0.71      0.71      9842
 weighted avg       0.71      0.71      0.71      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.68      0.81      0.74      3368
      neutral       0.73      0.62      0.67      3219
contradiction       0.75      0.72      0.73      3237

     accuracy                           0.72      9824
    macro avg       0.72      0.72     

100%|██████████| 17168/17168 [01:09<00:00, 246.49it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 536.69it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 367.65it/s]


Epoch:  21  Train Loss:  0.7185404002041541 Train Accuracy:  0.7013617490675632
Epoch:  21  Val Loss:  0.7901805541538572  Val Accuracy:  0.6594188173135541
Epoch:  21 Test Loss:  0.7866152877333886  Test Accuracy:  0.66439332247557

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.80      0.55      0.65      3329
      neutral       0.71      0.59      0.65      3235
contradiction       0.57      0.83      0.67      3278

     accuracy                           0.66      9842
    macro avg       0.69      0.66      0.66      9842
 weighted avg       0.69      0.66      0.66      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.80      0.56      0.66      3368
      neutral       0.72      0.59      0.65      3219
contradiction       0.57      0.85      0.68      3237

     accuracy                           0.66      9824
    macro avg       0.70      0.67      0

100%|██████████| 17168/17168 [01:10<00:00, 245.23it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 371.16it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 531.08it/s]


Epoch:  22  Train Loss:  0.7161783827104209 Train Accuracy:  0.7022700671864164
Epoch:  22  Val Loss:  0.7873486439128975  Val Accuracy:  0.6731355415565942
Epoch:  22 Test Loss:  0.7799667465570307  Test Accuracy:  0.683428338762215

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.81      0.54      0.65      3329
      neutral       0.62      0.69      0.65      3235
contradiction       0.64      0.79      0.71      3278

     accuracy                           0.67      9842
    macro avg       0.69      0.67      0.67      9842
 weighted avg       0.69      0.67      0.67      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.81      0.56      0.66      3368
      neutral       0.63      0.70      0.67      3219
contradiction       0.65      0.80      0.72      3237

     accuracy                           0.68      9824
    macro avg       0.70      0.69      

100%|██████████| 17168/17168 [01:09<00:00, 246.56it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 376.50it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 511.27it/s]


Epoch:  23  Train Loss:  0.7114411041718619 Train Accuracy:  0.7048093533102644
Epoch:  23  Val Loss:  0.8598735736949104  Val Accuracy:  0.6672424304003252
Epoch:  23 Test Loss:  0.8667517509250765  Test Accuracy:  0.6608306188925082

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.94      0.70      3329
      neutral       0.72      0.54      0.62      3235
contradiction       0.90      0.52      0.66      3278

     accuracy                           0.67      9842
    macro avg       0.73      0.67      0.66      9842
 weighted avg       0.73      0.67      0.66      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.94      0.70      3368
      neutral       0.72      0.53      0.61      3219
contradiction       0.90      0.51      0.65      3237

     accuracy                           0.66      9824
    macro avg       0.73      0.66     

100%|██████████| 17168/17168 [01:09<00:00, 247.80it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 552.30it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 474.97it/s]


Epoch:  24  Train Loss:  0.7083005223280897 Train Accuracy:  0.7064730862975024
Epoch:  24  Val Loss:  0.7122614349831234  Val Accuracy:  0.7092054460475513
Epoch:  24 Test Loss:  0.7063358816339449  Test Accuracy:  0.7085708469055375

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.70      0.80      0.75      3329
      neutral       0.64      0.75      0.69      3235
contradiction       0.84      0.58      0.69      3278

     accuracy                           0.71      9842
    macro avg       0.73      0.71      0.71      9842
 weighted avg       0.73      0.71      0.71      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.70      0.79      0.74      3368
      neutral       0.63      0.74      0.68      3219
contradiction       0.85      0.59      0.70      3237

     accuracy                           0.71      9824
    macro avg       0.73      0.71     

100%|██████████| 17168/17168 [01:08<00:00, 248.88it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 540.29it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 537.15it/s]


Epoch:  25  Train Loss:  0.7063158527042679 Train Accuracy:  0.7081349990079492
Epoch:  25  Val Loss:  0.7297546367173071  Val Accuracy:  0.6997561471245681
Epoch:  25 Test Loss:  0.7288241966927868  Test Accuracy:  0.7082654723127035

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.74      0.73      0.73      3329
      neutral       0.60      0.80      0.68      3235
contradiction       0.84      0.58      0.69      3278

     accuracy                           0.70      9842
    macro avg       0.73      0.70      0.70      9842
 weighted avg       0.73      0.70      0.70      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.74      0.74      0.74      3368
      neutral       0.61      0.79      0.69      3219
contradiction       0.85      0.60      0.70      3237

     accuracy                           0.71      9824
    macro avg       0.73      0.71     

100%|██████████| 17168/17168 [01:10<00:00, 243.76it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 525.98it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 528.89it/s]


Epoch:  26  Train Loss:  0.7031131324341793 Train Accuracy:  0.709696796494875
Epoch:  26  Val Loss:  0.7176796230789902  Val Accuracy:  0.7081893923999187
Epoch:  26 Test Loss:  0.7163637563926001  Test Accuracy:  0.7088762214983714

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.69      0.81      0.75      3329
      neutral       0.64      0.74      0.68      3235
contradiction       0.86      0.57      0.69      3278

     accuracy                           0.71      9842
    macro avg       0.73      0.71      0.71      9842
 weighted avg       0.73      0.71      0.71      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.69      0.81      0.75      3368
      neutral       0.64      0.74      0.68      3219
contradiction       0.87      0.58      0.69      3237

     accuracy                           0.71      9824
    macro avg       0.73      0.71      

100%|██████████| 17168/17168 [01:09<00:00, 245.81it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 552.67it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 524.48it/s]


Epoch:  27  Train Loss:  0.7014766431492575 Train Accuracy:  0.7110492621508027
Epoch:  27  Val Loss:  0.7509362906604619  Val Accuracy:  0.6732371469213575
Epoch:  27 Test Loss:  0.7551865891253132  Test Accuracy:  0.6735545602605864

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.81      0.52      0.63      3329
      neutral       0.56      0.80      0.66      3235
contradiction       0.76      0.70      0.73      3278

     accuracy                           0.67      9842
    macro avg       0.71      0.67      0.67      9842
 weighted avg       0.71      0.67      0.67      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.81      0.53      0.64      3368
      neutral       0.56      0.80      0.66      3219
contradiction       0.76      0.70      0.73      3237

     accuracy                           0.67      9824
    macro avg       0.71      0.68     

100%|██████████| 17168/17168 [01:09<00:00, 247.73it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 386.68it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 380.30it/s]


Epoch:  28  Train Loss:  0.6987412295322466 Train Accuracy:  0.7119885249751077
Epoch:  28  Val Loss:  0.7317728392489544  Val Accuracy:  0.6991465149359887
Epoch:  28 Test Loss:  0.7242776974792977  Test Accuracy:  0.7011400651465798

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.79      0.64      0.71      3329
      neutral       0.65      0.69      0.67      3235
contradiction       0.68      0.77      0.72      3278

     accuracy                           0.70      9842
    macro avg       0.71      0.70      0.70      9842
 weighted avg       0.71      0.70      0.70      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.78      0.64      0.70      3368
      neutral       0.64      0.69      0.67      3219
contradiction       0.69      0.78      0.73      3237

     accuracy                           0.70      9824
    macro avg       0.71      0.70     

100%|██████████| 17168/17168 [01:09<00:00, 248.26it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 543.06it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 516.48it/s]


Epoch:  29  Train Loss:  0.6968420819923601 Train Accuracy:  0.7134975344350862
Epoch:  29  Val Loss:  0.7024740637703375  Val Accuracy:  0.7207884576305629
Epoch:  29 Test Loss:  0.7012636713950564  Test Accuracy:  0.716307003257329

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.69      0.82      0.75      3329
      neutral       0.67      0.71      0.69      3235
contradiction       0.84      0.63      0.72      3278

     accuracy                           0.72      9842
    macro avg       0.73      0.72      0.72      9842
 weighted avg       0.73      0.72      0.72      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.68      0.82      0.74      3368
      neutral       0.67      0.70      0.68      3219
contradiction       0.84      0.62      0.72      3237

     accuracy                           0.72      9824
    macro avg       0.73      0.71      

100%|██████████| 17168/17168 [01:09<00:00, 248.20it/s]


Validation


100%|██████████| 308/308 [00:00<00:00, 531.22it/s]


Test


100%|██████████| 307/307 [00:00<00:00, 541.85it/s]


Epoch:  30  Train Loss:  0.6943787157210936 Train Accuracy:  0.7137414515251189
Epoch:  30  Val Loss:  0.6856795443923442  Val Accuracy:  0.7249542775858565
Epoch:  30 Test Loss:  0.6898439505201209  Test Accuracy:  0.7182410423452769

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.72      0.79      0.75      3329
      neutral       0.66      0.73      0.69      3235
contradiction       0.82      0.65      0.73      3278

     accuracy                           0.72      9842
    macro avg       0.73      0.72      0.72      9842
 weighted avg       0.73      0.72      0.73      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.72      0.78      0.75      3368
      neutral       0.65      0.72      0.68      3219
contradiction       0.81      0.65      0.72      3237

     accuracy                           0.72      9824
    macro avg       0.73      0.72     

100%|██████████| 307/307 [00:00<00:00, 457.23it/s]


Test Loss:  0.6841637003887748  Test Accuracy:  0.7250610749185668
               precision    recall  f1-score   support

   entailment       0.72      0.77      0.74      3368
      neutral       0.66      0.75      0.70      3219
contradiction       0.82      0.66      0.73      3237

     accuracy                           0.73      9824
    macro avg       0.73      0.72      0.73      9824
 weighted avg       0.73      0.73      0.73      9824

[[2600  584  184]
 [ 542 2400  277]
 [ 479  635 2123]]


In [13]:
INPUT_DIM = len(mword2index)
mnli_model=NLI_LSTM(membedding_matrix,INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, N_FC_LAYERS, DROPOUT)
mnli_model.to(device)
mnli_model.load_state_dict(torch.load("/content/gdrive/MyDrive/NLPProject/models/model_lstm_nli_mnli_ep_1.pt"))


#adam optimizer
optimizer = optim.Adam(mnli_model.parameters())
#loss function
criterion = nn.CrossEntropyLoss()
# #train
# model = train(model,train_dataloader,dev_dataloader,test_dataloader,optimizer,criterion,datasetName,EPOCHS)

#test
test(mnli_model,mtest_dataloader,criterion)

100%|██████████| 308/308 [00:00<00:00, 387.76it/s]


Test Loss:  0.9635263044725765  Test Accuracy:  0.504373474369406
               precision    recall  f1-score   support

   entailment       0.49      0.53      0.51      3463
      neutral       0.42      0.48      0.45      3129
contradiction       0.64      0.51      0.57      3240

     accuracy                           0.50      9832
    macro avg       0.52      0.50      0.51      9832
 weighted avg       0.52      0.50      0.51      9832

[[1821 1248  394]
 [1123 1490  516]
 [ 794  798 1648]]


In [14]:
INPUT_DIM = len(sword2index)
snli_model=NLI_LSTM(sembedding_matrix,INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, N_FC_LAYERS, DROPOUT)
snli_model.to(device)
snli_model.load_state_dict(torch.load("/content/gdrive/MyDrive/NLPProject/models/model_lstm_nli_snli_ep_10.pt"))


#adam optimizer
optimizer = optim.Adam(snli_model.parameters())
#loss function
criterion = nn.CrossEntropyLoss()
# #train
# model = train(model,train_dataloader,dev_dataloader,test_dataloader,optimizer,criterion,datasetName,EPOCHS)

#test
test(snli_model,stest_dataloader,criterion)

100%|██████████| 307/307 [00:00<00:00, 360.04it/s]


Test Loss:  0.7772045640293084  Test Accuracy:  0.6585912052117264
               precision    recall  f1-score   support

   entailment       0.67      0.70      0.68      3368
      neutral       0.66      0.63      0.65      3219
contradiction       0.64      0.64      0.64      3237

     accuracy                           0.66      9824
    macro avg       0.66      0.66      0.66      9824
 weighted avg       0.66      0.66      0.66      9824

[[2360  462  546]
 [ 567 2044  608]
 [ 603  568 2066]]
