In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import pickle
from tqdm import tqdm
import re
import string
import json
import sys
#import classificationreport
from sklearn.metrics import classification_report, confusion_matrix
from spacy.lang.en import English
eng = English()
tok = eng.tokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
#NLI_LSTM with 3 class classification with dense layer
class NLI_LSTM(nn.Module):
    def __init__(self,embedding_matrix, vocab_size, embedding_dim, hidden_dim, output_dim,n_fc_layers, dropout):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.embedding.weight.data.copy_(torch.from_numpy(embedding_matrix))
        self.embedding.weight.requires_grad = False
        self.lstm = nn.LSTM(embedding_dim, hidden_dim,batch_first=True)
        self.fc=nn.ModuleList([nn.Linear(hidden_dim * 2,hidden_dim * 2) for i in range(n_fc_layers)])
        self.linear=nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)
        
    
    # forward pass
    # NLI with 2 text inputs
    def forward(self, text1, text2):
        #text = [sent len, batch size]
        embedded1 = self.dropout(self.embedding(text1))
        embedded2 = self.dropout(self.embedding(text2))
        #embedded = [sent len, batch size, emb dim]
        output1, (hidden1, cell1) = self.lstm(embedded1)
        output2, (hidden2, cell2) = self.lstm(embedded2)

        hidden = torch.cat((hidden1,hidden2),dim=2)
        hidden = torch.squeeze(hidden)

        for i in range(len(self.fc)-1):
            hidden = self.fc[i](hidden)
            hidden = F.relu(hidden)
            hidden = self.dropout(hidden)
        
        hidden = self.fc[len(self.fc)-1](hidden)
        hidden = F.log_softmax(hidden,dim=1)
        hidden = self.dropout(hidden)

        hidden = self.linear(hidden)
        #hidden = [batch size, output dim]
        return hidden

In [3]:
#load glove vectors
embeddings_index = {}
f = open('../data/glove.6B.300d.txt')
lines = f.readlines()
for line in tqdm(lines):
    values = line.split()
    word = values[0]
    coefs = np.asarray(values[1:], dtype='float32')
    embeddings_index[word] = coefs
f.close()

100%|████████████████████████████████| 400000/400000 [00:24<00:00, 16257.27it/s]


In [4]:
# tokenise input and remove punctuation and numbers using only regex
def tokenize(text):
    # text = re.sub(r'[^a-zA-Z ]', '', text)
    # #convert to lowercase
    # text = text.lower()
    # return text.split()
    text = re.sub(r"[^\x00-\x7F]+", " ", text)
    regex = re.compile('[' + re.escape(string.punctuation) + '0-9\\r\\t\\n]') # remove punctuation and numbers
    nopunctext = regex.sub(" ", text.lower())
    return [token.text for token in tok(nopunctext)]

UNK="<UNK>"
PAD="<PAD>"

#import data
def getDataset(dataset_name="mnli"):
    if dataset_name=="mnli":
        filepath_train="../data/multinli_1.0/multinli_1.0/multinli_1.0_train.jsonl"
        filepath_dev="../data/multinli_1.0/multinli_1.0/multinli_1.0_dev_matched.jsonl"
        filepath_test="../data/multinli_1.0/multinli_1.0/multinli_1.0_dev_mismatched.jsonl"
    elif dataset_name=="snli":
        filepath_train="../data/snli_1.0/snli_1.0/snli_1.0_train.jsonl"
        filepath_dev="../data/snli_1.0/snli_1.0/snli_1.0_dev.jsonl"
        filepath_test="../data/snli_1.0/snli_1.0/snli_1.0_test.jsonl"
    else:
        print("Invalid dataset name")
        return None
    
    #read train,dev and test data
    labels = ["contradiction", "entailment", "neutral"]
    f= open(filepath_train, "r")
    data = list(f)
    train_dataset={"premise":[],"hypothesis":[],"label":[]}
    print("train data")
    for line in tqdm(data):
        line = json.loads(line)
        if line['gold_label'] not in labels:
            # print(line['gold_label'])
            continue
        train_dataset["premise"].append(line['sentence1'])
        train_dataset["hypothesis"].append(line['sentence2'])
        train_dataset["label"].append(line['gold_label'])
    f.close()

    f= open(filepath_dev, "r")
    data = list(f)
    dev_dataset={"premise":[],"hypothesis":[],"label":[]}
    print("dev data")
    for line in tqdm(data):
        line = json.loads(line)
        if line['gold_label'] not in labels:
            # print(line['gold_label'])
            continue
        dev_dataset["premise"].append(line['sentence1'])
        dev_dataset["hypothesis"].append(line['sentence2'])
        dev_dataset["label"].append(line['gold_label'])
    f.close()

    f= open(filepath_test, "r")
    data = list(f)
    test_dataset={"premise":[],"hypothesis":[],"label":[]}
    print("test data")
    for line in tqdm(data):
        line = json.loads(line)
        if line['gold_label'] not in labels:
            # print(line['gold_label'])
            continue
        test_dataset["premise"].append(line['sentence1'])
        test_dataset["hypothesis"].append(line['sentence2'])
        test_dataset["label"].append(line['gold_label'])
    f.close()
    
    return train_dataset,dev_dataset,test_dataset

def getWord2index(dataset):
    word2index = {"":0,UNK:1,PAD:2}
    for sentence in dataset["premise"]:
        for word in tokenize(sentence):
            if word not in word2index:
                word2index[word] = len(word2index)
    for sentence in dataset["hypothesis"]:
        for word in tokenize(sentence):
            if word not in word2index:
                word2index[word] = len(word2index)
    return word2index

def getEmbeddingMatrix(word2index,emb_size=300):
    embedding_matrix = np.zeros((len(word2index),emb_size),dtype=np.float32)
    for word, i in word2index.items():
        if i==0:
            embedding_matrix[i] = np.zeros(emb_size)
        elif word in embeddings_index:
            embedding_matrix[i] = embeddings_index[word]
        else:
            embedding_matrix[i] = np.random.uniform(-0.25,0.25,emb_size)
    return embedding_matrix

def getLabel2index(dataset):
    label2index = {"entailment":0,"neutral":1,"contradiction":2}
    return label2index

def getSentence2vector(sentence,word2index,padLength=32):
    sentence = tokenize(sentence)
    vector = []
    for word in sentence:
        if word in word2index:
            vector.append(word2index[word])
        else:
            vector.append(word2index[UNK])
    
    if len(vector)>padLength:
        vector=vector[:padLength]
    else:
        for i in range(padLength-len(vector)):
            vector.append(word2index[PAD])
    
    if(len(vector)!=padLength):
        print("Error in vector length")
    return np.array(vector)   

In [5]:
def preprocess(dataset,word2index,label2index):
    dataset["premise"] = [getSentence2vector(sentence,word2index) for sentence in tqdm(dataset["premise"])]
    dataset["hypothesis"] = [getSentence2vector(sentence,word2index) for sentence in tqdm(dataset["hypothesis"])]
    dataset["label"] = [label2index[label] for label in dataset["label"]]
    return dataset

def getDataloader(dataset,batch_size=32):
    # print(len(dataset["premise"]))
    premise = torch.tensor(dataset["premise"],dtype=torch.long)
    hypothesis = torch.tensor(dataset["hypothesis"],dtype=torch.long)
    labels = torch.tensor(dataset["label"],dtype=torch.long)
    dataset = torch.utils.data.TensorDataset(premise,hypothesis,labels)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader

def prepData(dataset_name="mnli"):
    train_dataset,dev_dataset,test_dataset = getDataset(dataset_name)
    word2index = getWord2index(train_dataset)
    label2index = getLabel2index(train_dataset)
    embedding_matrix = getEmbeddingMatrix(word2index)
    #preprocess datasets
    print("preprocess")
    train_dataset = preprocess(train_dataset,word2index,label2index)
    dev_dataset = preprocess(dev_dataset,word2index,label2index)
    test_dataset = preprocess(test_dataset,word2index,label2index)
    #dataloader
    train_dataloader = getDataloader(train_dataset)
    dev_dataloader = getDataloader(dev_dataset)
    test_dataloader = getDataloader(test_dataset)

    return train_dataloader,dev_dataloader,test_dataloader,embedding_matrix,word2index


In [6]:
def train(model,train_dataloader,dev_dataloader,test_dataloader,optimizer,criterion,datasetName,epochs=5):
    
    
    f=open(f"../reports_and_results/report_lstm_nli_{datasetName}_foreachepoch.txt",'w')
    
    total_loss=[]
    total_acc=[]
    for epoch in range(1,1+epochs):
        ep_loss=0
        ep_acc=0
        train_total=0
        val_total=0
        test_total=0
        print("\n\nEpoch: ",epoch)
        #training
        print("Training")
        for batch in tqdm(train_dataloader):
            prem, hyp, label = batch
            # print(len(prem),len(hyp),len(label))
            optimizer.zero_grad()
            prem, hyp, label= prem.to(device), hyp.to(device), label.to(device)
            output = model(prem, hyp)
            # print(len(output))
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            ep_loss+=loss.item()
            total_loss.append(loss.item())
            acc = (output.argmax(1) == label).sum().item()
            train_total+=label.size(0)
            ep_acc+=acc
        total_acc.append((ep_acc/train_total))
        
        print("Validation")
        #validation
        val_y_true = []
        val_y_pred = []
        with torch.no_grad():
            ep_val_loss=0
            ep_val_acc=0
            for batch in tqdm(dev_dataloader):
                prem, hyp, label = batch
                prem, hyp, label= prem.to(device), hyp.to(device), label.to(device)
                output = model(prem, hyp)
                loss = criterion(output, label)
                val_y_true.extend(label.cpu())
                val_y_pred.extend(output.cpu().argmax(1))
                ep_val_loss+=loss.item()
                acc = (output.argmax(1) == label).sum().item()
                val_total+=label.size(0)
                ep_val_acc+=acc
        
        print("Test")
        y_true = []
        y_pred = []
        with torch.no_grad():
            ep_test_loss=0
            ep_test_acc=0
            for batch in tqdm(test_dataloader):
                prem, hyp, label = batch
                prem, hyp, label= prem.to(device), hyp.to(device), label.to(device)
                output = model(prem, hyp)
                y_true.extend(label.cpu())
                y_pred.extend(output.cpu().argmax(1))
                loss = criterion(output, label)
                ep_test_loss+=loss.item()
                acc = (output.argmax(1) == label).sum().item()
                test_total+=label.size(0)
                ep_test_acc+=acc
                
        print("Epoch: ",epoch," Train Loss: ",ep_loss/len(train_dataloader),"Train Accuracy: ",ep_acc/train_total)

        print("Epoch: ",epoch," Val Loss: ",ep_val_loss/len(dev_dataloader)," Val Accuracy: ",ep_val_acc/val_total)

        print("Epoch: ",epoch,"Test Loss: ",ep_test_loss/len(test_dataloader)," Test Accuracy: ",ep_test_acc/test_total)

        print()
        print("Validation Classification report")

        print(classification_report(val_y_true, val_y_pred, target_names=["entailment","neutral","contradiction"]))
        print("Test Classification report")
        print(classification_report(y_true, y_pred, target_names=["entailment","neutral","contradiction"]))
                
        orginal_stdout=sys.stdout
        sys.stdout=f 
        print("\n\nEpoch: ",epoch)

        
        print("Epoch: ",epoch," Train Loss: ",ep_loss/len(train_dataloader),"Train Accuracy: ",ep_acc/train_total)

        print("Epoch: ",epoch," Val Loss: ",ep_val_loss/len(dev_dataloader)," Val Accuracy: ",ep_val_acc/val_total)

        print("Epoch: ",epoch,"Test Loss: ",ep_test_loss/len(test_dataloader)," Test Accuracy: ",ep_test_acc/test_total)

        print()
        print("Validation Classification report")

        print(classification_report(val_y_true, val_y_pred, target_names=["entailment","neutral","contradiction"]))
        print("Test Classification report")
        print(classification_report(y_true, y_pred, target_names=["entailment","neutral","contradiction"]))
        #save model
        torch.save(model.state_dict(), f"../models/model_lstm_nli_{datasetName}_ep_{epoch}.pt")
        sys.stdout=orginal_stdout
        
    sys.stdout=f
    print("Total Train Loss",sum(total_loss)/len(total_loss)," Total Train Accuracy: ",sum(total_acc)/len(total_acc))
              
    sys.stdout=orginal_stdout
    print("Total Train Loss",sum(total_loss)/len(total_loss)," Total Train Accuracy: ",sum(total_acc)/len(total_acc))
    f.close()
    return model
    

In [7]:
def test(model,test_dataloader,criterion):
#test
    y_true = []
    y_pred = []
    test_total=0
    with torch.no_grad():
        ep_test_loss=0
        ep_test_acc=0
        for batch in tqdm(test_dataloader):
            prem, hyp, label = batch
            prem, hyp, label= prem.to(device), hyp.to(device), label.to(device)
            output = model(prem, hyp)
            y_true.extend(label.cpu())
            y_pred.extend(output.cpu().argmax(1))
            loss = criterion(output, label)
            ep_test_loss+=loss.item()
            acc = (output.argmax(1) == label).sum().item()
            ep_test_acc+=acc
            test_total+=label.size(0)
        print("Test Loss: ",ep_test_loss/len(test_dataloader)," Test Accuracy: ",ep_test_acc/test_total)
        print(classification_report(y_true, y_pred, target_names=["entailment","neutral","contradiction"]))
        print(confusion_matrix(y_true, y_pred))
        


In [8]:
EMBEDDING_DIM = 300
HIDDEN_DIM = 100
OUTPUT_DIM = 3
N_FC_LAYERS = 2
DROPOUT = 0.3
EPOCHS = 30
lr=0.001

In [9]:
datasetName="mnli"

mtrain_dataloader,mdev_dataloader,mtest_dataloader,membedding_matrix,mword2index = prepData(datasetName)

INPUT_DIM = len(mword2index)

train data


100%|███████████████████████████████| 392702/392702 [00:02<00:00, 142146.51it/s]


dev data


100%|█████████████████████████████████| 10000/10000 [00:00<00:00, 126808.85it/s]


test data


100%|█████████████████████████████████| 10000/10000 [00:00<00:00, 124214.95it/s]


preprocess


100%|████████████████████████████████| 392702/392702 [00:30<00:00, 13087.71it/s]
100%|████████████████████████████████| 392702/392702 [00:19<00:00, 20141.12it/s]
100%|████████████████████████████████████| 9815/9815 [00:00<00:00, 11985.46it/s]
100%|████████████████████████████████████| 9815/9815 [00:00<00:00, 19911.14it/s]
100%|████████████████████████████████████| 9832/9832 [00:00<00:00, 11025.80it/s]
100%|████████████████████████████████████| 9832/9832 [00:00<00:00, 18068.43it/s]
  premise = torch.tensor(dataset["premise"],dtype=torch.long)


In [10]:


#initialize the model with above parameters
mnli_model = NLI_LSTM(membedding_matrix,INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, N_FC_LAYERS, DROPOUT)
#adam optimizer
mnli_model.to(device)
optimizer = optim.Adam(mnli_model.parameters())
#loss function
criterion = nn.CrossEntropyLoss()
#train
mnli_model = train(mnli_model,mtrain_dataloader,mdev_dataloader,mtest_dataloader,optimizer,criterion,datasetName,EPOCHS)

#test
test(mnli_model,mtest_dataloader,criterion)




Epoch:  1
Training


100%|████████████████████████████████████| 12272/12272 [00:57<00:00, 214.83it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 518.62it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 499.40it/s]


Epoch:  1  Train Loss:  1.1012832468722837 Train Accuracy:  0.41839868398938634
Epoch:  1  Val Loss:  1.0391180895827103  Val Accuracy:  0.46673458991339783
Epoch:  1 Test Loss:  1.0156671304207343  Test Accuracy:  0.47030105777054515

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.23      0.32      3479
      neutral       0.47      0.41      0.44      3123
contradiction       0.45      0.78      0.57      3213

     accuracy                           0.47      9815
    macro avg       0.49      0.47      0.44      9815
 weighted avg       0.49      0.47      0.44      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.24      0.33      3463
      neutral       0.44      0.43      0.43      3129
contradiction       0.46      0.76      0.57      3240

     accuracy                           0.47      9832
    macro avg       0.49      0.48     

100%|████████████████████████████████████| 12272/12272 [00:54<00:00, 225.76it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 427.44it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 384.14it/s]


Epoch:  2  Train Loss:  0.9873785929805374 Train Accuracy:  0.5185458693869652
Epoch:  2  Val Loss:  0.995045176740578  Val Accuracy:  0.5073866530820174
Epoch:  2 Test Loss:  0.9725893983593235  Test Accuracy:  0.5102725793327909

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.61      0.30      0.40      3479
      neutral       0.49      0.47      0.48      3123
contradiction       0.48      0.77      0.59      3213

     accuracy                           0.51      9815
    macro avg       0.53      0.51      0.49      9815
 weighted avg       0.53      0.51      0.49      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.59      0.30      0.40      3463
      neutral       0.48      0.48      0.48      3129
contradiction       0.50      0.76      0.61      3240

     accuracy                           0.51      9832
    macro avg       0.52      0.51      0.4

100%|████████████████████████████████████| 12272/12272 [00:59<00:00, 206.97it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 413.91it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:01<00:00, 296.00it/s]


Epoch:  3  Train Loss:  0.9539278550887046 Train Accuracy:  0.5461698692647351
Epoch:  3  Val Loss:  0.9502856319812688  Val Accuracy:  0.552827305145186
Epoch:  3 Test Loss:  0.9466895292718689  Test Accuracy:  0.5502441008950366

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.59      0.45      0.51      3479
      neutral       0.45      0.71      0.55      3123
contradiction       0.73      0.52      0.60      3213

     accuracy                           0.55      9815
    macro avg       0.59      0.56      0.55      9815
 weighted avg       0.59      0.55      0.55      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.58      0.44      0.50      3463
      neutral       0.44      0.72      0.55      3129
contradiction       0.77      0.51      0.61      3240

     accuracy                           0.55      9832
    macro avg       0.60      0.55      0.5

100%|████████████████████████████████████| 12272/12272 [01:16<00:00, 161.30it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 337.60it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:01<00:00, 305.53it/s]


Epoch:  4  Train Loss:  0.9360061957509375 Train Accuracy:  0.5588283227485472
Epoch:  4  Val Loss:  0.9714779436394135  Val Accuracy:  0.5522159959246052
Epoch:  4 Test Loss:  0.9636900688146616  Test Accuracy:  0.559499593165175

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.48      0.91      0.62      3479
      neutral       0.64      0.21      0.32      3123
contradiction       0.75      0.50      0.60      3213

     accuracy                           0.55      9815
    macro avg       0.62      0.54      0.51      9815
 weighted avg       0.62      0.55      0.52      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.48      0.92      0.63      3463
      neutral       0.64      0.21      0.32      3129
contradiction       0.79      0.51      0.62      3240

     accuracy                           0.56      9832
    macro avg       0.64      0.55      0.5

100%|████████████████████████████████████| 12272/12272 [01:14<00:00, 165.59it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 392.93it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 428.72it/s]


Epoch:  5  Train Loss:  0.9247146720792783 Train Accuracy:  0.5675881457186365
Epoch:  5  Val Loss:  0.9398071198976001  Val Accuracy:  0.5574121242995416
Epoch:  5 Test Loss:  0.9164400723847476  Test Accuracy:  0.5666192026037429

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.60      0.48      0.53      3479
      neutral       0.55      0.46      0.50      3123
contradiction       0.53      0.75      0.62      3213

     accuracy                           0.56      9815
    macro avg       0.56      0.56      0.55      9815
 weighted avg       0.56      0.56      0.55      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.60      0.49      0.54      3463
      neutral       0.54      0.47      0.50      3129
contradiction       0.56      0.75      0.64      3240

     accuracy                           0.57      9832
    macro avg       0.57      0.57      0.

100%|████████████████████████████████████| 12272/12272 [01:03<00:00, 193.88it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 393.83it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:01<00:00, 289.98it/s]


Epoch:  6  Train Loss:  0.9173656905111232 Train Accuracy:  0.5734017142769836
Epoch:  6  Val Loss:  0.955836085038387  Val Accuracy:  0.5752419765664799
Epoch:  6 Test Loss:  0.93343186378479  Test Accuracy:  0.584519934906428

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.51      0.85      0.64      3479
      neutral       0.68      0.24      0.35      3123
contradiction       0.67      0.61      0.64      3213

     accuracy                           0.58      9815
    macro avg       0.62      0.56      0.54      9815
 weighted avg       0.62      0.58      0.55      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.51      0.87      0.64      3463
      neutral       0.69      0.24      0.35      3129
contradiction       0.70      0.62      0.66      3240

     accuracy                           0.58      9832
    macro avg       0.63      0.57      0.55  

100%|████████████████████████████████████| 12272/12272 [01:33<00:00, 130.93it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:01<00:00, 300.23it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 336.22it/s]


Epoch:  7  Train Loss:  0.9054850880129474 Train Accuracy:  0.5814357960998416
Epoch:  7  Val Loss:  0.8984677983805877  Val Accuracy:  0.5916454406520631
Epoch:  7 Test Loss:  0.8875788572159681  Test Accuracy:  0.5965215622457283

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.74      0.63      3479
      neutral       0.60      0.38      0.47      3123
contradiction       0.65      0.64      0.64      3213

     accuracy                           0.59      9815
    macro avg       0.60      0.59      0.58      9815
 weighted avg       0.60      0.59      0.58      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.76      0.63      3463
      neutral       0.59      0.38      0.46      3129
contradiction       0.68      0.63      0.65      3240

     accuracy                           0.60      9832
    macro avg       0.61      0.59      0.

100%|████████████████████████████████████| 12272/12272 [01:21<00:00, 150.73it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 376.29it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 363.48it/s]


Epoch:  8  Train Loss:  0.8996846384017921 Train Accuracy:  0.5847181832534594
Epoch:  8  Val Loss:  1.0146143401484535  Val Accuracy:  0.4935303107488538
Epoch:  8 Test Loss:  1.0058879409130517  Test Accuracy:  0.4968470301057771

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.68      0.13      0.22      3479
      neutral       0.39      0.84      0.54      3123
contradiction       0.72      0.54      0.62      3213

     accuracy                           0.49      9815
    macro avg       0.60      0.51      0.46      9815
 weighted avg       0.60      0.49      0.45      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.69      0.12      0.21      3463
      neutral       0.39      0.87      0.54      3129
contradiction       0.76      0.54      0.63      3240

     accuracy                           0.50      9832
    macro avg       0.62      0.51      0.

100%|████████████████████████████████████| 12272/12272 [01:09<00:00, 176.63it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:01<00:00, 293.42it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 438.81it/s]


Epoch:  9  Train Loss:  0.8954710398801885 Train Accuracy:  0.5884232827945872
Epoch:  9  Val Loss:  0.9203405160857334  Val Accuracy:  0.589302088639837
Epoch:  9 Test Loss:  0.9057335756815873  Test Accuracy:  0.5938771358828315

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.52      0.83      0.64      3479
      neutral       0.66      0.32      0.43      3123
contradiction       0.70      0.59      0.64      3213

     accuracy                           0.59      9815
    macro avg       0.62      0.58      0.57      9815
 weighted avg       0.62      0.59      0.57      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.52      0.84      0.64      3463
      neutral       0.65      0.32      0.43      3129
contradiction       0.72      0.59      0.65      3240

     accuracy                           0.59      9832
    macro avg       0.63      0.59      0.5

100%|████████████████████████████████████| 12272/12272 [01:08<00:00, 178.42it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:01<00:00, 268.43it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:01<00:00, 298.37it/s]


Epoch:  10  Train Loss:  0.8893078930428279 Train Accuracy:  0.5926656854306828
Epoch:  10  Val Loss:  0.9396341497424371  Val Accuracy:  0.5375445746306673
Epoch:  10 Test Loss:  0.9330330063383301  Test Accuracy:  0.5360048820179008

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.66      0.29      0.40      3479
      neutral       0.43      0.79      0.55      3123
contradiction       0.72      0.56      0.63      3213

     accuracy                           0.54      9815
    macro avg       0.60      0.55      0.53      9815
 weighted avg       0.60      0.54      0.53      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.66      0.28      0.39      3463
      neutral       0.42      0.80      0.55      3129
contradiction       0.74      0.55      0.63      3240

     accuracy                           0.54      9832
    macro avg       0.61      0.55     

100%|████████████████████████████████████| 12272/12272 [01:14<00:00, 165.60it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 466.56it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 458.37it/s]


Epoch:  11  Train Loss:  0.8873752898802542 Train Accuracy:  0.5944533004670208
Epoch:  11  Val Loss:  0.9317555231457813  Val Accuracy:  0.5699439633214468
Epoch:  11 Test Loss:  0.9140590849247846  Test Accuracy:  0.575772986167616

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.63      0.45      0.53      3479
      neutral       0.53      0.55      0.54      3123
contradiction       0.57      0.71      0.63      3213

     accuracy                           0.57      9815
    macro avg       0.58      0.57      0.57      9815
 weighted avg       0.58      0.57      0.57      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.64      0.46      0.53      3463
      neutral       0.51      0.56      0.53      3129
contradiction       0.59      0.72      0.65      3240

     accuracy                           0.58      9832
    macro avg       0.58      0.58      

100%|████████████████████████████████████| 12272/12272 [01:18<00:00, 156.27it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 480.26it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 449.15it/s]


Epoch:  12  Train Loss:  0.8824080383184097 Train Accuracy:  0.597216209746831
Epoch:  12  Val Loss:  0.9406909967866705  Val Accuracy:  0.599490575649516
Epoch:  12 Test Loss:  0.9162251749983081  Test Accuracy:  0.6005899104963385

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.79      0.65      3479
      neutral       0.69      0.31      0.43      3123
contradiction       0.64      0.67      0.66      3213

     accuracy                           0.60      9815
    macro avg       0.63      0.59      0.58      9815
 weighted avg       0.62      0.60      0.58      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.54      0.81      0.65      3463
      neutral       0.68      0.31      0.42      3129
contradiction       0.67      0.66      0.66      3240

     accuracy                           0.60      9832
    macro avg       0.63      0.59      0

100%|████████████████████████████████████| 12272/12272 [01:10<00:00, 174.09it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:01<00:00, 294.09it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 424.58it/s]


Epoch:  13  Train Loss:  0.8796656166447382 Train Accuracy:  0.5990343823051576
Epoch:  13  Val Loss:  0.9189552534674977  Val Accuracy:  0.5865511971472236
Epoch:  13 Test Loss:  0.8977151189918642  Test Accuracy:  0.5975386493083807

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.60      0.61      0.60      3479
      neutral       0.59      0.41      0.48      3123
contradiction       0.57      0.74      0.64      3213

     accuracy                           0.59      9815
    macro avg       0.59      0.58      0.58      9815
 weighted avg       0.59      0.59      0.58      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.60      0.64      0.62      3463
      neutral       0.59      0.40      0.48      3129
contradiction       0.60      0.74      0.66      3240

     accuracy                           0.60      9832
    macro avg       0.60      0.59     

100%|████████████████████████████████████| 12272/12272 [01:18<00:00, 155.85it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 327.50it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 367.97it/s]


Epoch:  14  Train Loss:  0.8754686147673966 Train Accuracy:  0.6021105061853518
Epoch:  14  Val Loss:  0.9147572426143609  Val Accuracy:  0.5967396841569027
Epoch:  14 Test Loss:  0.8943829583850774  Test Accuracy:  0.6079129373474369

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.77      0.64      3479
      neutral       0.67      0.32      0.43      3123
contradiction       0.62      0.68      0.65      3213

     accuracy                           0.60      9815
    macro avg       0.61      0.59      0.58      9815
 weighted avg       0.61      0.60      0.58      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.79      0.66      3463
      neutral       0.67      0.33      0.44      3129
contradiction       0.65      0.68      0.66      3240

     accuracy                           0.61      9832
    macro avg       0.63      0.60     

100%|████████████████████████████████████| 12272/12272 [01:18<00:00, 156.29it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 462.10it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 459.93it/s]


Epoch:  15  Train Loss:  0.8761703544059547 Train Accuracy:  0.6023320482197697
Epoch:  15  Val Loss:  0.8832220095376627  Val Accuracy:  0.6050942435048395
Epoch:  15 Test Loss:  0.8764262739327047  Test Accuracy:  0.6070992676973149

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.76      0.65      3479
      neutral       0.64      0.38      0.47      3123
contradiction       0.65      0.66      0.65      3213

     accuracy                           0.61      9815
    macro avg       0.62      0.60      0.59      9815
 weighted avg       0.62      0.61      0.59      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.77      0.65      3463
      neutral       0.63      0.38      0.48      3129
contradiction       0.68      0.65      0.66      3240

     accuracy                           0.61      9832
    macro avg       0.62      0.60     

100%|████████████████████████████████████| 12272/12272 [01:12<00:00, 169.60it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 351.69it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 347.57it/s]


Epoch:  16  Train Loss:  0.8723976217875696 Train Accuracy:  0.6040127119291473
Epoch:  16  Val Loss:  0.8776708390891358  Val Accuracy:  0.6057055527254203
Epoch:  16 Test Loss:  0.8719387418263919  Test Accuracy:  0.6075061025223759

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.79      0.65      3479
      neutral       0.61      0.41      0.49      3123
contradiction       0.71      0.59      0.65      3213

     accuracy                           0.61      9815
    macro avg       0.62      0.60      0.60      9815
 weighted avg       0.62      0.61      0.60      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.54      0.80      0.65      3463
      neutral       0.61      0.40      0.48      3129
contradiction       0.74      0.60      0.66      3240

     accuracy                           0.61      9832
    macro avg       0.63      0.60     

100%|████████████████████████████████████| 12272/12272 [01:12<00:00, 168.78it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 407.00it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 429.87it/s]


Epoch:  17  Train Loss:  0.8713184230301573 Train Accuracy:  0.6040763734332903
Epoch:  17  Val Loss:  0.8806752997423228  Val Accuracy:  0.6107997962302598
Epoch:  17 Test Loss:  0.8674070304864413  Test Accuracy:  0.6113710333604556

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.78      0.65      3479
      neutral       0.63      0.41      0.50      3123
contradiction       0.68      0.62      0.65      3213

     accuracy                           0.61      9815
    macro avg       0.62      0.60      0.60      9815
 weighted avg       0.62      0.61      0.60      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.79      0.65      3463
      neutral       0.63      0.40      0.49      3129
contradiction       0.71      0.63      0.66      3240

     accuracy                           0.61      9832
    macro avg       0.63      0.60     

100%|████████████████████████████████████| 12272/12272 [01:10<00:00, 174.49it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:01<00:00, 277.16it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 495.39it/s]


Epoch:  18  Train Loss:  0.8712509759304732 Train Accuracy:  0.6055940636920617
Epoch:  18  Val Loss:  0.9082037796803328  Val Accuracy:  0.5753438614365767
Epoch:  18 Test Loss:  0.8911697581990973  Test Accuracy:  0.5795362082994304

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.64      0.47      0.54      3479
      neutral       0.51      0.57      0.54      3123
contradiction       0.60      0.69      0.64      3213

     accuracy                           0.58      9815
    macro avg       0.58      0.58      0.57      9815
 weighted avg       0.58      0.58      0.57      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.62      0.46      0.53      3463
      neutral       0.50      0.59      0.54      3129
contradiction       0.63      0.70      0.66      3240

     accuracy                           0.58      9832
    macro avg       0.58      0.58     

100%|████████████████████████████████████| 12272/12272 [01:07<00:00, 182.03it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:01<00:00, 157.81it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 449.24it/s]


Epoch:  19  Train Loss:  0.868676036246597 Train Accuracy:  0.6070633712076842
Epoch:  19  Val Loss:  0.9429933147244034  Val Accuracy:  0.5622007131940907
Epoch:  19 Test Loss:  0.9353368978995782  Test Accuracy:  0.5547192839707079

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.65      0.40      0.50      3479
      neutral       0.45      0.76      0.56      3123
contradiction       0.74      0.55      0.63      3213

     accuracy                           0.56      9815
    macro avg       0.61      0.57      0.56      9815
 weighted avg       0.62      0.56      0.56      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.64      0.37      0.47      3463
      neutral       0.44      0.77      0.56      3129
contradiction       0.76      0.54      0.63      3240

     accuracy                           0.55      9832
    macro avg       0.61      0.56      

100%|████████████████████████████████████| 12272/12272 [01:11<00:00, 172.14it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:01<00:00, 276.30it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:01<00:00, 301.73it/s]


Epoch:  20  Train Loss:  0.8663002040062338 Train Accuracy:  0.6082703933262372
Epoch:  20  Val Loss:  0.8783093275387046  Val Accuracy:  0.6052980132450331
Epoch:  20 Test Loss:  0.8581834278710476  Test Accuracy:  0.6079129373474369

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.61      0.62      0.61      3479
      neutral       0.57      0.54      0.55      3123
contradiction       0.64      0.66      0.65      3213

     accuracy                           0.61      9815
    macro avg       0.60      0.60      0.60      9815
 weighted avg       0.60      0.61      0.60      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.61      0.64      0.62      3463
      neutral       0.55      0.53      0.54      3129
contradiction       0.67      0.65      0.66      3240

     accuracy                           0.61      9832
    macro avg       0.61      0.61     

100%|████████████████████████████████████| 12272/12272 [01:06<00:00, 184.88it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 466.13it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 459.25it/s]


Epoch:  21  Train Loss:  0.8647528190719452 Train Accuracy:  0.6094264862414758
Epoch:  21  Val Loss:  1.010086209940988  Val Accuracy:  0.5620988283239939
Epoch:  21 Test Loss:  1.0332269434417998  Test Accuracy:  0.5574654190398698

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.49      0.90      0.64      3479
      neutral       0.57      0.40      0.47      3123
contradiction       0.87      0.36      0.50      3213

     accuracy                           0.56      9815
    macro avg       0.65      0.55      0.54      9815
 weighted avg       0.64      0.56      0.54      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.49      0.90      0.64      3463
      neutral       0.55      0.40      0.46      3129
contradiction       0.90      0.35      0.50      3240

     accuracy                           0.56      9832
    macro avg       0.65      0.55      

100%|████████████████████████████████████| 12272/12272 [01:06<00:00, 185.63it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 484.56it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 452.00it/s]


Epoch:  22  Train Loss:  0.8642010579742928 Train Accuracy:  0.6098542915493173
Epoch:  22  Val Loss:  0.8889391292189931  Val Accuracy:  0.601935812531839
Epoch:  22 Test Loss:  0.8725722530832538  Test Accuracy:  0.6022172497965825

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.62      0.57      0.59      3479
      neutral       0.52      0.61      0.56      3123
contradiction       0.67      0.64      0.66      3213

     accuracy                           0.60      9815
    macro avg       0.61      0.60      0.60      9815
 weighted avg       0.61      0.60      0.60      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.62      0.57      0.59      3463
      neutral       0.51      0.61      0.56      3129
contradiction       0.70      0.63      0.66      3240

     accuracy                           0.60      9832
    macro avg       0.61      0.60      

100%|████████████████████████████████████| 12272/12272 [01:09<00:00, 177.43it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 459.60it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 457.84it/s]


Epoch:  23  Train Loss:  0.861685373974573 Train Accuracy:  0.6116393601254896
Epoch:  23  Val Loss:  0.8856899965469535  Val Accuracy:  0.6066225165562914
Epoch:  23 Test Loss:  0.8746532581069253  Test Accuracy:  0.6068958502847844

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.77      0.65      3479
      neutral       0.61      0.41      0.49      3123
contradiction       0.69      0.62      0.65      3213

     accuracy                           0.61      9815
    macro avg       0.62      0.60      0.60      9815
 weighted avg       0.62      0.61      0.60      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.80      0.65      3463
      neutral       0.60      0.38      0.47      3129
contradiction       0.71      0.62      0.66      3240

     accuracy                           0.61      9832
    macro avg       0.62      0.60      

100%|████████████████████████████████████| 12272/12272 [01:11<00:00, 171.61it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:01<00:00, 278.99it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:01<00:00, 292.50it/s]


Epoch:  24  Train Loss:  0.8621818065385891 Train Accuracy:  0.610170052609867
Epoch:  24  Val Loss:  0.9633699321591505  Val Accuracy:  0.5660723382577687
Epoch:  24 Test Loss:  0.942217127262772  Test Accuracy:  0.5687550854353133

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.62      0.57      0.59      3479
      neutral       0.65      0.32      0.43      3123
contradiction       0.51      0.80      0.62      3213

     accuracy                           0.57      9815
    macro avg       0.59      0.56      0.55      9815
 weighted avg       0.59      0.57      0.55      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.61      0.58      0.60      3463
      neutral       0.64      0.30      0.41      3129
contradiction       0.52      0.81      0.63      3240

     accuracy                           0.57      9832
    macro avg       0.59      0.57      0

100%|████████████████████████████████████| 12272/12272 [01:10<00:00, 173.59it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:01<00:00, 280.04it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 384.59it/s]


Epoch:  25  Train Loss:  0.8608515489278203 Train Accuracy:  0.6103661300426277
Epoch:  25  Val Loss:  0.9471624961505107  Val Accuracy:  0.574223127865512
Epoch:  25 Test Loss:  0.9267593887332198  Test Accuracy:  0.5768917819365338

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.59      0.66      0.62      3479
      neutral       0.67      0.28      0.39      3123
contradiction       0.54      0.77      0.63      3213

     accuracy                           0.57      9815
    macro avg       0.60      0.57      0.55      9815
 weighted avg       0.60      0.57      0.55      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.58      0.68      0.63      3463
      neutral       0.65      0.27      0.38      3129
contradiction       0.55      0.76      0.64      3240

     accuracy                           0.58      9832
    macro avg       0.59      0.57      

100%|████████████████████████████████████| 12272/12272 [01:08<00:00, 179.85it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 325.34it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 336.05it/s]


Epoch:  26  Train Loss:  0.8615216623185472 Train Accuracy:  0.611598616762838
Epoch:  26  Val Loss:  0.9548149590383523  Val Accuracy:  0.5468160978094753
Epoch:  26 Test Loss:  0.9365113932978023  Test Accuracy:  0.5458706265256306

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.66      0.28      0.40      3479
      neutral       0.45      0.72      0.56      3123
contradiction       0.63      0.66      0.65      3213

     accuracy                           0.55      9815
    macro avg       0.58      0.56      0.53      9815
 weighted avg       0.58      0.55      0.53      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.67      0.27      0.39      3463
      neutral       0.44      0.73      0.55      3129
contradiction       0.65      0.66      0.66      3240

     accuracy                           0.55      9832
    macro avg       0.59      0.55      

100%|████████████████████████████████████| 12272/12272 [01:08<00:00, 179.93it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 489.32it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 458.91it/s]


Epoch:  27  Train Loss:  0.8578566845598712 Train Accuracy:  0.6134804508253078
Epoch:  27  Val Loss:  0.9790344044129312  Val Accuracy:  0.5323484462557311
Epoch:  27 Test Loss:  0.972311552662354  Test Accuracy:  0.5270545158665582

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.71      0.25      0.37      3479
      neutral       0.43      0.80      0.56      3123
contradiction       0.68      0.58      0.63      3213

     accuracy                           0.53      9815
    macro avg       0.61      0.54      0.52      9815
 weighted avg       0.61      0.53      0.51      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.68      0.24      0.36      3463
      neutral       0.42      0.81      0.55      3129
contradiction       0.71      0.56      0.63      3240

     accuracy                           0.53      9832
    macro avg       0.60      0.54      

100%|████████████████████████████████████| 12272/12272 [01:07<00:00, 181.97it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 381.59it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 406.40it/s]


Epoch:  28  Train Loss:  0.858389700084167 Train Accuracy:  0.6132538158705584
Epoch:  28  Val Loss:  0.9650114953323762  Val Accuracy:  0.5681100356597045
Epoch:  28 Test Loss:  0.9503468155086815  Test Accuracy:  0.5751627339300244

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.61      0.56      0.58      3479
      neutral       0.69      0.33      0.45      3123
contradiction       0.51      0.81      0.62      3213

     accuracy                           0.57      9815
    macro avg       0.60      0.57      0.55      9815
 weighted avg       0.60      0.57      0.55      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.61      0.57      0.59      3463
      neutral       0.67      0.34      0.45      3129
contradiction       0.52      0.80      0.63      3240

     accuracy                           0.58      9832
    macro avg       0.60      0.57      

100%|████████████████████████████████████| 12272/12272 [01:10<00:00, 174.22it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 358.07it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 463.02it/s]


Epoch:  29  Train Loss:  0.8557937037772857 Train Accuracy:  0.6152451477201543
Epoch:  29  Val Loss:  0.9315669295842174  Val Accuracy:  0.5732042791645441
Epoch:  29 Test Loss:  0.9240148348080648  Test Accuracy:  0.5763832384052074

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.66      0.41      0.51      3479
      neutral       0.48      0.71      0.57      3123
contradiction       0.65      0.62      0.63      3213

     accuracy                           0.57      9815
    macro avg       0.60      0.58      0.57      9815
 weighted avg       0.60      0.57      0.57      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.66      0.42      0.51      3463
      neutral       0.47      0.71      0.57      3129
contradiction       0.68      0.61      0.64      3240

     accuracy                           0.58      9832
    macro avg       0.60      0.58     

100%|████████████████████████████████████| 12272/12272 [01:11<00:00, 171.82it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:01<00:00, 281.46it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:01<00:00, 267.31it/s]


Epoch:  30  Train Loss:  0.8556207113157389 Train Accuracy:  0.6148631786952957
Epoch:  30  Val Loss:  0.8984044167040225  Val Accuracy:  0.6069281711665817
Epoch:  30 Test Loss:  0.8809550278759622  Test Accuracy:  0.6084214808787632

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.54      0.81      0.65      3479
      neutral       0.66      0.37      0.47      3123
contradiction       0.69      0.62      0.65      3213

     accuracy                           0.61      9815
    macro avg       0.63      0.60      0.59      9815
 weighted avg       0.63      0.61      0.59      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.54      0.83      0.65      3463
      neutral       0.64      0.36      0.46      3129
contradiction       0.72      0.62      0.66      3240

     accuracy                           0.61      9832
    macro avg       0.63      0.60     

100%|████████████████████████████████████████| 308/308 [00:01<00:00, 294.98it/s]


Test Loss:  0.8827209006269257  Test Accuracy:  0.60913344182262
               precision    recall  f1-score   support

   entailment       0.54      0.83      0.65      3463
      neutral       0.64      0.36      0.46      3129
contradiction       0.72      0.62      0.66      3240

     accuracy                           0.61      9832
    macro avg       0.63      0.60      0.59      9832
 weighted avg       0.63      0.61      0.60      9832

[[2861  354  248]
 [1466 1130  533]
 [ 956  286 1998]]


In [11]:
datasetName="snli"

strain_dataloader,sdev_dataloader,stest_dataloader,sembedding_matrix,sword2index = prepData(datasetName)
#intialize the model
INPUT_DIM = len(sword2index)

train data


100%|████████████████████████████████| 550152/550152 [00:07<00:00, 75061.04it/s]


dev data


100%|██████████████████████████████████| 10000/10000 [00:00<00:00, 67383.68it/s]


test data


100%|██████████████████████████████████| 10000/10000 [00:00<00:00, 67923.51it/s]


preprocess


100%|████████████████████████████████| 549367/549367 [00:29<00:00, 18633.43it/s]
100%|████████████████████████████████| 549367/549367 [00:21<00:00, 25397.48it/s]
100%|████████████████████████████████████| 9842/9842 [00:00<00:00, 19120.38it/s]
100%|████████████████████████████████████| 9842/9842 [00:00<00:00, 25347.01it/s]
100%|████████████████████████████████████| 9824/9824 [00:00<00:00, 19469.94it/s]
100%|████████████████████████████████████| 9824/9824 [00:00<00:00, 25865.49it/s]


In [12]:
#initialize the model with above parameters
snli_model = NLI_LSTM(sembedding_matrix,INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, N_FC_LAYERS, DROPOUT)
snli_model.to(device)
#adam optimizer
optimizer = optim.Adam(snli_model.parameters())
#loss function
criterion = nn.CrossEntropyLoss()
#train
snli_model = train(snli_model,strain_dataloader,sdev_dataloader,stest_dataloader,optimizer,criterion,datasetName,EPOCHS)

#test
test(snli_model,stest_dataloader,criterion)



Epoch:  1
Training


100%|████████████████████████████████████| 17168/17168 [01:16<00:00, 223.71it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 455.40it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 465.71it/s]


Epoch:  1  Train Loss:  1.1413172060295675 Train Accuracy:  0.37204819364832614
Epoch:  1  Val Loss:  0.9950530130367774  Val Accuracy:  0.5439951229424914
Epoch:  1 Test Loss:  0.9936889105200379  Test Accuracy:  0.5457043973941368

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.73      0.34      0.46      3329
      neutral       0.57      0.57      0.57      3235
contradiction       0.47      0.72      0.57      3278

     accuracy                           0.54      9842
    macro avg       0.59      0.55      0.54      9842
 weighted avg       0.59      0.54      0.53      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.72      0.36      0.48      3368
      neutral       0.57      0.57      0.57      3219
contradiction       0.47      0.71      0.57      3237

     accuracy                           0.55      9824
    macro avg       0.59      0.55      0

100%|████████████████████████████████████| 17168/17168 [01:27<00:00, 195.48it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 346.57it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:01<00:00, 302.18it/s]


Epoch:  2  Train Loss:  0.8821629602965471 Train Accuracy:  0.6044265491010563
Epoch:  2  Val Loss:  0.8375505971444117  Val Accuracy:  0.6234505181873603
Epoch:  2 Test Loss:  0.8294789785282619  Test Accuracy:  0.6244910423452769

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.66      0.63      0.65      3329
      neutral       0.54      0.77      0.63      3235
contradiction       0.75      0.47      0.58      3278

     accuracy                           0.62      9842
    macro avg       0.65      0.62      0.62      9842
 weighted avg       0.65      0.62      0.62      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.67      0.65      0.66      3368
      neutral       0.54      0.76      0.63      3219
contradiction       0.75      0.47      0.58      3237

     accuracy                           0.62      9824
    macro avg       0.65      0.62      0.

100%|████████████████████████████████████| 17168/17168 [01:30<00:00, 190.06it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 444.12it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 440.54it/s]


Epoch:  3  Train Loss:  0.8267890822059039 Train Accuracy:  0.6384893886964452
Epoch:  3  Val Loss:  0.7800269186883778  Val Accuracy:  0.6657183499288762
Epoch:  3 Test Loss:  0.777587072177508  Test Accuracy:  0.6677524429967426

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.69      0.66      0.67      3329
      neutral       0.61      0.72      0.66      3235
contradiction       0.71      0.62      0.66      3278

     accuracy                           0.67      9842
    macro avg       0.67      0.67      0.67      9842
 weighted avg       0.67      0.67      0.67      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.69      0.67      0.68      3368
      neutral       0.62      0.72      0.66      3219
contradiction       0.70      0.62      0.66      3237

     accuracy                           0.67      9824
    macro avg       0.67      0.67      0.6

100%|████████████████████████████████████| 17168/17168 [01:33<00:00, 183.78it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 385.99it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 459.53it/s]


Epoch:  4  Train Loss:  0.8014741822181466 Train Accuracy:  0.653037040812426
Epoch:  4  Val Loss:  0.7653290364262345  Val Accuracy:  0.666023166023166
Epoch:  4 Test Loss:  0.7610187255792586  Test Accuracy:  0.6650040716612378

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.62      0.82      0.71      3329
      neutral       0.75      0.50      0.60      3235
contradiction       0.67      0.67      0.67      3278

     accuracy                           0.67      9842
    macro avg       0.68      0.66      0.66      9842
 weighted avg       0.68      0.67      0.66      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.62      0.83      0.71      3368
      neutral       0.75      0.50      0.60      3219
contradiction       0.67      0.66      0.67      3237

     accuracy                           0.67      9824
    macro avg       0.68      0.66      0.66

100%|████████████████████████████████████| 17168/17168 [01:36<00:00, 177.26it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:01<00:00, 282.11it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 448.97it/s]


Epoch:  5  Train Loss:  0.7862899141526761 Train Accuracy:  0.6625061206807107
Epoch:  5  Val Loss:  0.8973945745012977  Val Accuracy:  0.6275147327778907
Epoch:  5 Test Loss:  0.8911722575220301  Test Accuracy:  0.6260179153094463

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.71      0.67      0.69      3329
      neutral       0.51      0.85      0.63      3235
contradiction       0.92      0.37      0.53      3278

     accuracy                           0.63      9842
    macro avg       0.71      0.63      0.62      9842
 weighted avg       0.71      0.63      0.62      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.71      0.67      0.69      3368
      neutral       0.51      0.84      0.64      3219
contradiction       0.91      0.36      0.52      3237

     accuracy                           0.63      9824
    macro avg       0.71      0.63      0.

100%|████████████████████████████████████| 17168/17168 [01:35<00:00, 179.89it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 389.49it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 407.24it/s]


Epoch:  6  Train Loss:  0.7738367887222506 Train Accuracy:  0.6688133797625266
Epoch:  6  Val Loss:  0.7360244641056308  Val Accuracy:  0.6929485876854298
Epoch:  6 Test Loss:  0.7202084535302091  Test Accuracy:  0.6941164495114006

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.67      0.78      0.72      3329
      neutral       0.70      0.62      0.66      3235
contradiction       0.71      0.67      0.69      3278

     accuracy                           0.69      9842
    macro avg       0.70      0.69      0.69      9842
 weighted avg       0.70      0.69      0.69      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.67      0.79      0.72      3368
      neutral       0.70      0.61      0.66      3219
contradiction       0.72      0.67      0.70      3237

     accuracy                           0.69      9824
    macro avg       0.70      0.69      0.

100%|████████████████████████████████████| 17168/17168 [01:41<00:00, 169.33it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 318.65it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 346.43it/s]


Epoch:  7  Train Loss:  0.7591521930893864 Train Accuracy:  0.6774360309228621
Epoch:  7  Val Loss:  0.7130398000409077  Val Accuracy:  0.6999593578540947
Epoch:  7 Test Loss:  0.707038195203104  Test Accuracy:  0.7061278501628665

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.69      0.78      0.73      3329
      neutral       0.65      0.68      0.67      3235
contradiction       0.77      0.64      0.70      3278

     accuracy                           0.70      9842
    macro avg       0.71      0.70      0.70      9842
 weighted avg       0.71      0.70      0.70      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.69      0.79      0.74      3368
      neutral       0.67      0.68      0.67      3219
contradiction       0.77      0.65      0.70      3237

     accuracy                           0.71      9824
    macro avg       0.71      0.70      0.7

100%|████████████████████████████████████| 17168/17168 [01:34<00:00, 180.75it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 311.87it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 327.22it/s]


Epoch:  8  Train Loss:  0.7402065372544115 Train Accuracy:  0.6880446040624938
Epoch:  8  Val Loss:  0.7425250995468784  Val Accuracy:  0.6878683194472668
Epoch:  8 Test Loss:  0.7381296843193254  Test Accuracy:  0.6886197068403909

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.76      0.66      0.70      3329
      neutral       0.59      0.79      0.67      3235
contradiction       0.77      0.62      0.69      3278

     accuracy                           0.69      9842
    macro avg       0.71      0.69      0.69      9842
 weighted avg       0.71      0.69      0.69      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.75      0.66      0.70      3368
      neutral       0.59      0.79      0.67      3219
contradiction       0.78      0.62      0.69      3237

     accuracy                           0.69      9824
    macro avg       0.71      0.69      0.

100%|████████████████████████████████████| 17168/17168 [01:37<00:00, 176.97it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 356.59it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 470.89it/s]


Epoch:  9  Train Loss:  0.7293607881032312 Train Accuracy:  0.6936492363028722
Epoch:  9  Val Loss:  0.7373156888144357  Val Accuracy:  0.6924405608616135
Epoch:  9 Test Loss:  0.7298975079183858  Test Accuracy:  0.6968648208469055

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.79      0.60      0.68      3329
      neutral       0.61      0.77      0.68      3235
contradiction       0.72      0.71      0.72      3278

     accuracy                           0.69      9842
    macro avg       0.71      0.69      0.69      9842
 weighted avg       0.71      0.69      0.69      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.79      0.61      0.69      3368
      neutral       0.62      0.77      0.69      3219
contradiction       0.72      0.71      0.72      3237

     accuracy                           0.70      9824
    macro avg       0.71      0.70      0.

100%|████████████████████████████████████| 17168/17168 [01:37<00:00, 176.55it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 412.43it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:01<00:00, 294.49it/s]


Epoch:  10  Train Loss:  0.7211840810065001 Train Accuracy:  0.6977393982528983
Epoch:  10  Val Loss:  0.7143503901633349  Val Accuracy:  0.7003657793131477
Epoch:  10 Test Loss:  0.7104439419915699  Test Accuracy:  0.6978827361563518

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.68      0.81      0.74      3329
      neutral       0.76      0.52      0.62      3235
contradiction       0.69      0.76      0.73      3278

     accuracy                           0.70      9842
    macro avg       0.71      0.70      0.69      9842
 weighted avg       0.71      0.70      0.69      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.67      0.81      0.73      3368
      neutral       0.76      0.52      0.62      3219
contradiction       0.69      0.76      0.73      3237

     accuracy                           0.70      9824
    macro avg       0.71      0.70     

100%|████████████████████████████████████| 17168/17168 [01:31<00:00, 186.95it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 436.57it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 440.12it/s]


Epoch:  11  Train Loss:  0.7163228721836129 Train Accuracy:  0.700522601466779
Epoch:  11  Val Loss:  0.7623795837938011  Val Accuracy:  0.6900020321072953
Epoch:  11 Test Loss:  0.7593511230394195  Test Accuracy:  0.6893322475570033

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.74      0.71      0.73      3329
      neutral       0.74      0.53      0.62      3235
contradiction       0.62      0.83      0.71      3278

     accuracy                           0.69      9842
    macro avg       0.70      0.69      0.69      9842
 weighted avg       0.70      0.69      0.69      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.74      0.72      0.73      3368
      neutral       0.74      0.52      0.61      3219
contradiction       0.62      0.83      0.71      3237

     accuracy                           0.69      9824
    macro avg       0.70      0.69      

100%|████████████████████████████████████| 17168/17168 [01:36<00:00, 177.46it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 328.92it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 319.98it/s]


Epoch:  12  Train Loss:  0.7099008666297624 Train Accuracy:  0.7035679245386054
Epoch:  12  Val Loss:  0.7008231390606273  Val Accuracy:  0.7135744767323715
Epoch:  12 Test Loss:  0.6942111272182837  Test Accuracy:  0.7156962540716613

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.68      0.80      0.74      3329
      neutral       0.66      0.71      0.69      3235
contradiction       0.84      0.62      0.71      3278

     accuracy                           0.71      9842
    macro avg       0.73      0.71      0.71      9842
 weighted avg       0.73      0.71      0.71      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.68      0.80      0.74      3368
      neutral       0.67      0.72      0.69      3219
contradiction       0.84      0.63      0.72      3237

     accuracy                           0.72      9824
    macro avg       0.73      0.71     

100%|████████████████████████████████████| 17168/17168 [01:35<00:00, 180.45it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 332.12it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 334.82it/s]


Epoch:  13  Train Loss:  0.7051800196627563 Train Accuracy:  0.7059397451976548
Epoch:  13  Val Loss:  0.7248974884678792  Val Accuracy:  0.6982320666531193
Epoch:  13 Test Loss:  0.731532781733752  Test Accuracy:  0.6929967426710097

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.73      0.77      0.75      3329
      neutral       0.78      0.50      0.61      3235
contradiction       0.63      0.82      0.71      3278

     accuracy                           0.70      9842
    macro avg       0.71      0.70      0.69      9842
 weighted avg       0.71      0.70      0.69      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.72      0.78      0.75      3368
      neutral       0.78      0.49      0.60      3219
contradiction       0.63      0.81      0.71      3237

     accuracy                           0.69      9824
    macro avg       0.71      0.69      

100%|████████████████████████████████████| 17168/17168 [01:34<00:00, 182.61it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 375.49it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 384.41it/s]


Epoch:  14  Train Loss:  0.7012605330241082 Train Accuracy:  0.7086301142951797
Epoch:  14  Val Loss:  0.771540142595768  Val Accuracy:  0.6811623653728917
Epoch:  14 Test Loss:  0.7724230441868499  Test Accuracy:  0.6792548859934854

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.82      0.55      0.66      3329
      neutral       0.56      0.84      0.67      3235
contradiction       0.77      0.66      0.71      3278

     accuracy                           0.68      9842
    macro avg       0.72      0.68      0.68      9842
 weighted avg       0.72      0.68      0.68      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.82      0.55      0.66      3368
      neutral       0.56      0.83      0.67      3219
contradiction       0.77      0.66      0.71      3237

     accuracy                           0.68      9824
    macro avg       0.72      0.68      

100%|████████████████████████████████████| 17168/17168 [01:29<00:00, 191.43it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:01<00:00, 300.61it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 322.53it/s]


Epoch:  15  Train Loss:  0.6975934533645353 Train Accuracy:  0.7104176261042254
Epoch:  15  Val Loss:  0.7198925635644368  Val Accuracy:  0.6895956106482423
Epoch:  15 Test Loss:  0.7255540187273429  Test Accuracy:  0.6860749185667753

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.81      0.61      0.69      3329
      neutral       0.56      0.83      0.67      3235
contradiction       0.82      0.64      0.72      3278

     accuracy                           0.69      9842
    macro avg       0.73      0.69      0.69      9842
 weighted avg       0.73      0.69      0.69      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.79      0.60      0.69      3368
      neutral       0.56      0.82      0.66      3219
contradiction       0.82      0.64      0.72      3237

     accuracy                           0.69      9824
    macro avg       0.72      0.69     

100%|████████████████████████████████████| 17168/17168 [01:37<00:00, 175.18it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 309.56it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 330.02it/s]


Epoch:  16  Train Loss:  0.694764021360502 Train Accuracy:  0.7118155986799353
Epoch:  16  Val Loss:  0.703687916134859  Val Accuracy:  0.7187563503352977
Epoch:  16 Test Loss:  0.6933654731182012  Test Accuracy:  0.7151872964169381

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.66      0.86      0.75      3329
      neutral       0.77      0.58      0.66      3235
contradiction       0.76      0.72      0.74      3278

     accuracy                           0.72      9842
    macro avg       0.73      0.72      0.71      9842
 weighted avg       0.73      0.72      0.72      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.66      0.86      0.75      3368
      neutral       0.76      0.57      0.65      3219
contradiction       0.75      0.72      0.74      3237

     accuracy                           0.72      9824
    macro avg       0.73      0.71      0

100%|████████████████████████████████████| 17168/17168 [01:30<00:00, 190.43it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 374.28it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 397.91it/s]


Epoch:  17  Train Loss:  0.6907658303609571 Train Accuracy:  0.7142420276427234
Epoch:  17  Val Loss:  0.7305901751115724  Val Accuracy:  0.7132696606380817
Epoch:  17 Test Loss:  0.7133450198445336  Test Accuracy:  0.7169177524429967

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.67      0.84      0.75      3329
      neutral       0.80      0.53      0.63      3235
contradiction       0.71      0.77      0.74      3278

     accuracy                           0.71      9842
    macro avg       0.73      0.71      0.71      9842
 weighted avg       0.73      0.71      0.71      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.68      0.84      0.75      3368
      neutral       0.80      0.52      0.63      3219
contradiction       0.72      0.78      0.75      3237

     accuracy                           0.72      9824
    macro avg       0.73      0.71     

100%|████████████████████████████████████| 17168/17168 [01:34<00:00, 182.43it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:01<00:00, 302.04it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 324.27it/s]


Epoch:  18  Train Loss:  0.6884842875768593 Train Accuracy:  0.7149919816807344
Epoch:  18  Val Loss:  0.6739162677874813  Val Accuracy:  0.7273928063401748
Epoch:  18 Test Loss:  0.6695142878577452  Test Accuracy:  0.7316775244299675

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.70      0.80      0.75      3329
      neutral       0.68      0.71      0.70      3235
contradiction       0.82      0.67      0.74      3278

     accuracy                           0.73      9842
    macro avg       0.74      0.73      0.73      9842
 weighted avg       0.74      0.73      0.73      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.71      0.80      0.75      3368
      neutral       0.69      0.72      0.70      3219
contradiction       0.82      0.68      0.74      3237

     accuracy                           0.73      9824
    macro avg       0.74      0.73     

100%|████████████████████████████████████| 17168/17168 [01:32<00:00, 186.39it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 452.42it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 443.70it/s]


Epoch:  19  Train Loss:  0.6870004307706072 Train Accuracy:  0.7159712905944478
Epoch:  19  Val Loss:  0.7036692807813744  Val Accuracy:  0.7152001625685837
Epoch:  19 Test Loss:  0.7060522214791674  Test Accuracy:  0.7192589576547231

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.79      0.66      0.72      3329
      neutral       0.66      0.71      0.68      3235
contradiction       0.71      0.78      0.75      3278

     accuracy                           0.72      9842
    macro avg       0.72      0.72      0.71      9842
 weighted avg       0.72      0.72      0.72      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.79      0.66      0.72      3368
      neutral       0.66      0.72      0.69      3219
contradiction       0.72      0.78      0.75      3237

     accuracy                           0.72      9824
    macro avg       0.72      0.72     

100%|████████████████████████████████████| 17168/17168 [01:34<00:00, 180.72it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 375.96it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 379.80it/s]


Epoch:  20  Train Loss:  0.6856059378702176 Train Accuracy:  0.7164682261584696
Epoch:  20  Val Loss:  0.6978121589530598  Val Accuracy:  0.7190611664295875
Epoch:  20 Test Loss:  0.6942784250365018  Test Accuracy:  0.7184446254071661

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.78      0.68      0.73      3329
      neutral       0.63      0.77      0.70      3235
contradiction       0.78      0.70      0.74      3278

     accuracy                           0.72      9842
    macro avg       0.73      0.72      0.72      9842
 weighted avg       0.73      0.72      0.72      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.77      0.68      0.72      3368
      neutral       0.64      0.76      0.69      3219
contradiction       0.77      0.71      0.74      3237

     accuracy                           0.72      9824
    macro avg       0.73      0.72     

100%|████████████████████████████████████| 17168/17168 [01:32<00:00, 186.57it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 482.08it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 455.97it/s]


Epoch:  21  Train Loss:  0.6846769093470744 Train Accuracy:  0.7169451386777873
Epoch:  21  Val Loss:  0.6708640684555103  Val Accuracy:  0.7259703312334891
Epoch:  21 Test Loss:  0.6619591953700062  Test Accuracy:  0.7288273615635179

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.76      0.72      0.74      3329
      neutral       0.68      0.70      0.69      3235
contradiction       0.74      0.75      0.75      3278

     accuracy                           0.73      9842
    macro avg       0.73      0.73      0.73      9842
 weighted avg       0.73      0.73      0.73      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.76      0.73      0.74      3368
      neutral       0.69      0.71      0.70      3219
contradiction       0.75      0.75      0.75      3237

     accuracy                           0.73      9824
    macro avg       0.73      0.73     

100%|████████████████████████████████████| 17168/17168 [01:40<00:00, 171.00it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 327.78it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 335.82it/s]


Epoch:  22  Train Loss:  0.683601212383136 Train Accuracy:  0.7181774660654899
Epoch:  22  Val Loss:  0.7440542742222934  Val Accuracy:  0.7125584230847388
Epoch:  22 Test Loss:  0.7451346915591424  Test Accuracy:  0.7164087947882736

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.65      0.87      0.74      3329
      neutral       0.69      0.68      0.68      3235
contradiction       0.89      0.58      0.70      3278

     accuracy                           0.71      9842
    macro avg       0.74      0.71      0.71      9842
 weighted avg       0.74      0.71      0.71      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.65      0.88      0.74      3368
      neutral       0.70      0.69      0.69      3219
contradiction       0.90      0.58      0.70      3237

     accuracy                           0.72      9824
    macro avg       0.75      0.71      

100%|████████████████████████████████████| 17168/17168 [01:35<00:00, 180.70it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 370.26it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 394.70it/s]


Epoch:  23  Train Loss:  0.6818603198013623 Train Accuracy:  0.7190311758806044
Epoch:  23  Val Loss:  0.6829626421843257  Val Accuracy:  0.7223125381020118
Epoch:  23 Test Loss:  0.6783871051251694  Test Accuracy:  0.7216001628664495

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.68      0.85      0.76      3329
      neutral       0.78      0.56      0.65      3235
contradiction       0.73      0.75      0.74      3278

     accuracy                           0.72      9842
    macro avg       0.73      0.72      0.72      9842
 weighted avg       0.73      0.72      0.72      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.67      0.86      0.75      3368
      neutral       0.78      0.55      0.65      3219
contradiction       0.74      0.75      0.75      3237

     accuracy                           0.72      9824
    macro avg       0.73      0.72     

100%|████████████████████████████████████| 17168/17168 [01:34<00:00, 182.61it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 366.43it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 382.36it/s]


Epoch:  24  Train Loss:  0.6789260437279828 Train Accuracy:  0.7207713604930766
Epoch:  24  Val Loss:  0.7203870573407644  Val Accuracy:  0.711440764072343
Epoch:  24 Test Loss:  0.7184824338759196  Test Accuracy:  0.711421009771987

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.74      0.75      0.75      3329
      neutral       0.61      0.81      0.69      3235
contradiction       0.88      0.57      0.69      3278

     accuracy                           0.71      9842
    macro avg       0.74      0.71      0.71      9842
 weighted avg       0.74      0.71      0.71      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.74      0.75      0.74      3368
      neutral       0.60      0.80      0.69      3219
contradiction       0.88      0.58      0.70      3237

     accuracy                           0.71      9824
    macro avg       0.74      0.71      0

100%|████████████████████████████████████| 17168/17168 [01:35<00:00, 180.14it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 461.43it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 459.70it/s]


Epoch:  25  Train Loss:  0.6789243235448089 Train Accuracy:  0.720440070117062
Epoch:  25  Val Loss:  0.6707301367219393  Val Accuracy:  0.7351148140621825
Epoch:  25 Test Loss:  0.6750525909836983  Test Accuracy:  0.7303542345276873

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.69      0.86      0.76      3329
      neutral       0.72      0.68      0.70      3235
contradiction       0.83      0.67      0.74      3278

     accuracy                           0.74      9842
    macro avg       0.74      0.73      0.73      9842
 weighted avg       0.74      0.74      0.73      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.68      0.84      0.75      3368
      neutral       0.71      0.68      0.69      3219
contradiction       0.83      0.67      0.74      3237

     accuracy                           0.73      9824
    macro avg       0.74      0.73      

100%|████████████████████████████████████| 17168/17168 [01:32<00:00, 185.27it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 358.17it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 383.91it/s]


Epoch:  26  Train Loss:  0.6762467468294548 Train Accuracy:  0.7213738721109932
Epoch:  26  Val Loss:  0.7724729432881653  Val Accuracy:  0.6937614306035359
Epoch:  26 Test Loss:  0.772441395233819  Test Accuracy:  0.6928949511400652

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.68      0.82      0.74      3329
      neutral       0.84      0.42      0.56      3235
contradiction       0.65      0.83      0.73      3278

     accuracy                           0.69      9842
    macro avg       0.72      0.69      0.68      9842
 weighted avg       0.72      0.69      0.68      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.68      0.82      0.75      3368
      neutral       0.83      0.42      0.56      3219
contradiction       0.65      0.83      0.73      3237

     accuracy                           0.69      9824
    macro avg       0.72      0.69      

100%|████████████████████████████████████| 17168/17168 [01:32<00:00, 185.68it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 441.96it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 431.85it/s]


Epoch:  27  Train Loss:  0.6760631127322904 Train Accuracy:  0.7225024437215923
Epoch:  27  Val Loss:  0.6531030680839117  Val Accuracy:  0.732168258484048
Epoch:  27 Test Loss:  0.6477794894760517  Test Accuracy:  0.7335097719869706

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.73      0.80      0.76      3329
      neutral       0.69      0.70      0.69      3235
contradiction       0.78      0.70      0.74      3278

     accuracy                           0.73      9842
    macro avg       0.73      0.73      0.73      9842
 weighted avg       0.73      0.73      0.73      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.73      0.80      0.76      3368
      neutral       0.69      0.70      0.70      3219
contradiction       0.79      0.70      0.74      3237

     accuracy                           0.73      9824
    macro avg       0.74      0.73      

100%|████████████████████████████████████| 17168/17168 [01:40<00:00, 171.24it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 393.68it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 392.65it/s]


Epoch:  28  Train Loss:  0.673331680830884 Train Accuracy:  0.7231704853039953
Epoch:  28  Val Loss:  0.6860545108264143  Val Accuracy:  0.7248526722210933
Epoch:  28 Test Loss:  0.6783484349615799  Test Accuracy:  0.7263843648208469

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.72      0.79      0.75      3329
      neutral       0.75      0.60      0.67      3235
contradiction       0.71      0.78      0.74      3278

     accuracy                           0.72      9842
    macro avg       0.73      0.72      0.72      9842
 weighted avg       0.73      0.72      0.72      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.72      0.80      0.76      3368
      neutral       0.75      0.60      0.67      3219
contradiction       0.71      0.78      0.74      3237

     accuracy                           0.73      9824
    macro avg       0.73      0.73      

100%|████████████████████████████████████| 17168/17168 [01:30<00:00, 188.75it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 431.45it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 444.38it/s]


Epoch:  29  Train Loss:  0.6734519656374356 Train Accuracy:  0.7228373746511895
Epoch:  29  Val Loss:  0.66365189192357  Val Accuracy:  0.7318634423897582
Epoch:  29 Test Loss:  0.6645120550250386  Test Accuracy:  0.7315757328990228

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.76      0.74      0.75      3329
      neutral       0.68      0.69      0.69      3235
contradiction       0.76      0.76      0.76      3278

     accuracy                           0.73      9842
    macro avg       0.73      0.73      0.73      9842
 weighted avg       0.73      0.73      0.73      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.75      0.75      0.75      3368
      neutral       0.69      0.69      0.69      3219
contradiction       0.75      0.75      0.75      3237

     accuracy                           0.73      9824
    macro avg       0.73      0.73      0

100%|████████████████████████████████████| 17168/17168 [01:33<00:00, 183.24it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 438.90it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 450.84it/s]


Epoch:  30  Train Loss:  0.6722821749352745 Train Accuracy:  0.7239167987884237
Epoch:  30  Val Loss:  0.669052271390116  Val Accuracy:  0.7260719365982524
Epoch:  30 Test Loss:  0.6685523192734982  Test Accuracy:  0.7290309446254072

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.75      0.74      0.74      3329
      neutral       0.68      0.68      0.68      3235
contradiction       0.75      0.76      0.75      3278

     accuracy                           0.73      9842
    macro avg       0.73      0.73      0.73      9842
 weighted avg       0.73      0.73      0.73      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.75      0.75      0.75      3368
      neutral       0.68      0.68      0.68      3219
contradiction       0.75      0.75      0.75      3237

     accuracy                           0.73      9824
    macro avg       0.73      0.73      

100%|████████████████████████████████████████| 307/307 [00:00<00:00, 448.05it/s]


Test Loss:  0.6656459579638627  Test Accuracy:  0.7292345276872965
               precision    recall  f1-score   support

   entailment       0.76      0.74      0.75      3368
      neutral       0.69      0.69      0.69      3219
contradiction       0.75      0.76      0.75      3237

     accuracy                           0.73      9824
    macro avg       0.73      0.73      0.73      9824
 weighted avg       0.73      0.73      0.73      9824

[[2481  568  319]
 [ 468 2231  520]
 [ 335  450 2452]]


In [13]:
INPUT_DIM = len(mword2index)
mnli_model=NLI_LSTM(membedding_matrix,INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, N_FC_LAYERS, DROPOUT)
mnli_model.to(device)
mnli_model.load_state_dict(torch.load("../models/model_lstm_nli_mnli_ep_1.pt"))


#adam optimizer
optimizer = optim.Adam(mnli_model.parameters())
#loss function
criterion = nn.CrossEntropyLoss()
# #train
# model = train(model,train_dataloader,dev_dataloader,test_dataloader,optimizer,criterion,datasetName,EPOCHS)

#test
test(mnli_model,mtest_dataloader,criterion)

100%|████████████████████████████████████████| 308/308 [00:01<00:00, 225.52it/s]


Test Loss:  1.0207105209301044  Test Accuracy:  0.47447111472742065
               precision    recall  f1-score   support

   entailment       0.54      0.24      0.33      3463
      neutral       0.46      0.44      0.45      3129
contradiction       0.46      0.76      0.58      3240

     accuracy                           0.47      9832
    macro avg       0.49      0.48      0.45      9832
 weighted avg       0.49      0.47      0.45      9832

[[ 835 1120 1508]
 [ 433 1374 1322]
 [ 265  519 2456]]


In [14]:
INPUT_DIM = len(sword2index)
snli_model=NLI_LSTM(sembedding_matrix,INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, N_FC_LAYERS, DROPOUT)
snli_model.to(device)
snli_model.load_state_dict(torch.load("../models/model_lstm_nli_snli_ep_10.pt"))


#adam optimizer
optimizer = optim.Adam(snli_model.parameters())
#loss function
criterion = nn.CrossEntropyLoss()
# #train
# model = train(model,train_dataloader,dev_dataloader,test_dataloader,optimizer,criterion,datasetName,EPOCHS)

#test
test(snli_model,stest_dataloader,criterion)

100%|████████████████████████████████████████| 307/307 [00:00<00:00, 375.48it/s]


Test Loss:  0.7142397785613902  Test Accuracy:  0.7032776872964169
               precision    recall  f1-score   support

   entailment       0.68      0.81      0.74      3368
      neutral       0.77      0.53      0.62      3219
contradiction       0.69      0.77      0.73      3237

     accuracy                           0.70      9824
    macro avg       0.71      0.70      0.70      9824
 weighted avg       0.71      0.70      0.70      9824

[[2725  255  388]
 [ 796 1693  730]
 [ 484  262 2491]]
