# Import libraries

In [24]:
import random
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.optim as optim

from tqdm import trange
from torch.autograd import Variable
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from nltk.tokenize import word_tokenize
import gensim.downloader

import warnings
warnings.filterwarnings('ignore')

In [25]:
word2vec_goog1e_news: gensim.models.keyedvectors.KeyedVectors = gensim.downloader.load('word2vec-google-news-300')
word2vec_goog1e_news.add_vector("<pad>", np.zeros(300))
pad_index = word2vec_goog1e_news.key_to_index["<pad>"]
embedding_weights = torch.FloatTensor(word2vec_goog1e_news.vectors)
vocab = word2vec_goog1e_news.key_to_index

In [26]:
device = torch.device(torch.cuda.current_device() if torch.cuda.is_available() else "cpu")
print(f"Using: {device}")

Using: cpu


# Import Dataset

In [27]:
def tokenize_pd_series_to_lsit(list_of_text):
    tokenized = []
    for sentence in list_of_text:
        tokenized.append(word_tokenize(sentence.lower()))
    return tokenized

def format_label(label):
    return torch.unsqueeze(torch.tensor(label.to_list()), axis=1).tolist()

def indexify(data):
    setences = []
    for sentence in data:
        s = [vocab[token] if token in vocab
            else vocab['UNK']
            for token in sentence]
        setences.append(s)
    return setences


In [28]:
training_data = pd.read_csv(filepath_or_buffer="TREC_dataset/modified_training_data.csv", sep=",") 
test_data = pd.read_csv(filepath_or_buffer="TREC_dataset/modified_test_data.csv", sep=",")

X = training_data["text"]
y = training_data["label-coarse"]
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=500)

X_test = test_data["text"]
y_test = test_data["label-coarse"]

X_train_lst = X_train.to_list()
X_val_lst = X_val.to_list()
X_test_lst = X_test.to_list()

X_train_tokenized = tokenize_pd_series_to_lsit(X_train_lst)
X_val_tokenized = tokenize_pd_series_to_lsit(X_val_lst)
X_test_tokenized = tokenize_pd_series_to_lsit(X_test_lst)

no_of_labels = 5

In [29]:
X_train_tokenized_indexified = indexify(X_train_tokenized)
X_val_tokenized_indexified = indexify(X_val_tokenized)
X_test_tokenized_indexified = indexify(X_test_tokenized)

y_train_formatted = format_label(y_train)
y_val_formatted = format_label(y_val)
y_test_formatted = format_label(y_test)

In [30]:
def data_iterator(sentences, labels, total_size: int, batch_size: int, shuffle: bool=False):
    # make a list that decides the order in which we go over the data- this avoids explicit shuffling of data
    order = list(range(total_size))
    if shuffle:
        random.seed(230)
        random.shuffle(order)

    # one pass over data
    for i in range((total_size+1)//batch_size):
        # fetch sentences and tags
        batch_sentences = [sentences[idx] for idx in order[i*batch_size:(i+1)*batch_size]]
        batch_tags = [labels[idx] for idx in order[i*batch_size:(i+1)*batch_size]]

        # compute length of longest sentence in batch
        batch_max_len = max([len(s) for s in batch_sentences])

        # prepare a numpy array with the data, initialising the data with pad_ind and all labels with -1
        # initialising labels to -1 differentiates tokens with tags from PADding tokens
        batch_data = vocab['<pad>']*np.ones((len(batch_sentences), batch_max_len))
        batch_labels = np.array(batch_tags).squeeze()

        # copy the data to the numpy array
        for j in range(len(batch_sentences)):
            cur_len = len(batch_sentences[j])
            batch_data[j][:cur_len] = batch_sentences[j]

        # since all data are indices, we convert them to torch LongTensors
        batch_data, batch_labels = torch.LongTensor(batch_data), torch.LongTensor(batch_labels)
        # convert them to Variables to record operations in the computational graph
        batch_data, batch_labels = Variable(batch_data), Variable(batch_labels)

        yield batch_data, batch_labels, batch_sentences

In [31]:
class Net(nn.Module):
    def __init__(self, embedding_weights, embedding_dim, lstm_hidden_dim, number_of_tags):
        super(Net, self).__init__()

        # the embedding takes as input the vocab_size and the embedding_dim
        self.embedding = nn.Embedding.from_pretrained(embedding_weights, padding_idx=pad_index)

        # the LSTM takes as input the size of its input (embedding_dim), its hidden size
        self.lstm = nn.LSTM(embedding_dim,
                            lstm_hidden_dim, batch_first=True)

        # the fully connected layer transforms the output to give the final output layer
        self.fc = nn.Linear(lstm_hidden_dim, number_of_tags)

    def forward(self, s, lengths):
        # apply the embedding layer that maps each token to its embedding
        s = self.embedding(s)

        # pack the sequences before feeding them to the LSTM
        packed_input = pack_padded_sequence(s, lengths, batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed_input)

        # unpack the sequences after passing through the LSTM
        padded_output, _ = pad_packed_sequence(packed_output, batch_first=True)
 
        s = torch.max(padded_output, dim=1)[0] # max pooling
        
        # apply the fully connected layer and obtain the output (before softmax) for each token
        s = self.fc(s)

        # apply log softmax on each token's output
        return F.log_softmax(s, dim=1)

In [32]:
def accuracy(outputs, labels):
    outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1)
    labels = labels.squeeze()
    # compare outputs with labels
    return np.sum([1 if first == second else 0 for first, second in zip(labels, outputs)]) / float(len(labels))

def loss_fn(outputs, labels):
    loss = F.cross_entropy(outputs, labels.squeeze())
    return loss

In [33]:
class RunningAverage:
    def __init__(self):
        self.steps = 0
        self.total = 0

    def update(self, val):
        self.total += val
        self.steps += 1

    def __call__(self):
        return self.total / float(self.steps)

In [34]:
def train(model, optimizer, loss_fn, data_iterator, num_steps):
    # set model to training mode
    model.train()

    # summary for current training loop and a running average object for loss
    loss_avg = RunningAverage()

    # Use tqdm for progress bar
    t = trange(num_steps)
    for i in t:
        # fetch the next training batch
        train_batch, labels_batch, _ = next(data_iterator)
        train_batch = train_batch.to(device)
        labels_batch = labels_batch.to(device)
        
        # compute model output and loss
        seq_lengths = torch.LongTensor(list(map(len, train_batch)))
        output_batch = model(train_batch, seq_lengths)
        loss = loss_fn(output_batch, labels_batch)

        # clear previous gradients, compute gradients of all variables wrt loss
        optimizer.zero_grad()
        loss.backward()

        # performs updates using calculated gradients
        optimizer.step()

        # Evaluate summaries only once in a while
        if i % 10 == 0:
            # extract data from torch Variable, move to cpu, convert to numpy arrays
            output_batch = output_batch.data.cpu().numpy()
            labels_batch = labels_batch.data.cpu().numpy()

        # update the average loss
        loss_avg.update(loss.item())
        t.set_postfix(loss='{:05.3f}'.format(loss_avg()))

def evaluate(model, loss_fn, data_iterator, num_steps):
    # set model to evaluation mode
    model.eval()

    loss_avg = RunningAverage()
    accuracy_avg = RunningAverage()

    # compute metrics over the dataset
    for _ in range(num_steps):
        # fetch the next evaluation batch
        data_batch, labels_batch, _ = next(data_iterator)
        data_batch = data_batch.to(device)
        labels_batch = labels_batch.to(device)

        # compute model output
        seq_lengths = torch.LongTensor(list(map(len, data_batch)))
        output_batch = model(data_batch, seq_lengths)
        loss = loss_fn(output_batch, labels_batch)
        loss_avg.update(loss.item())
        accuracy_val = accuracy(output_batch, labels_batch)
        accuracy_avg.update(accuracy_val)
        
        # extract data from torch Variable, move to cpu, convert to numpy arrays
        output_batch = output_batch.data.cpu().numpy()
        labels_batch = labels_batch.data.cpu().numpy()

    print(f"{loss_avg()=}")
    print(f"{accuracy_avg()=}")
    
def train_and_evaluate(
        model,
        train_sentences,
        train_labels,
        val_sentences,
        val_labels,
        num_epochs: int,
        batch_size: int,
        optimizer,
        loss_fn
):
    for epoch in range(num_epochs):
        # Run one epoch
        print("Epoch {}/{}".format(epoch + 1, num_epochs))

        # compute number of batches in one epoch (one full pass over the training set)
        num_steps = (len(train_sentences) + 1) // batch_size
        train_data_iterator = data_iterator(train_sentences, train_labels, len(train_sentences), batch_size, shuffle=True)
        train(model, optimizer, loss_fn, train_data_iterator,num_steps)

        # Evaluate for one epoch on validation set
        num_steps = (len(val_sentences) + 1) // batch_size
        val_data_iterator = data_iterator(val_sentences, val_labels, len(val_sentences), batch_size, shuffle=False)
        evaluate(model, loss_fn, val_data_iterator, num_steps)

In [35]:
model = Net(embedding_weights, 300, 5, no_of_labels).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

if (os.path.isfile("model_weights_max_pooling.pth")):
    model.load_state_dict(torch.load('model_weights_max_pooling.pth'))
else:
    train_and_evaluate(model, X_train_tokenized_indexified , y_train_formatted , X_val_tokenized_indexified  , y_val_formatted, 100, 32, optimizer, loss_fn)
    torch.save(model.state_dict(), 'model_weights_max_pooling.pth')

Epoch 1/100


100%|██████████| 154/154 [00:02<00:00, 57.08it/s, loss=1.490]


loss_avg()=1.3798008839289346
accuracy_avg()=0.41041666666666665
Epoch 2/100


100%|██████████| 154/154 [00:02<00:00, 62.81it/s, loss=1.388]


loss_avg()=1.3442778587341309
accuracy_avg()=0.41041666666666665
Epoch 3/100


100%|██████████| 154/154 [00:02<00:00, 55.12it/s, loss=1.340]


loss_avg()=1.2637571652730306
accuracy_avg()=0.42916666666666664
Epoch 4/100


100%|██████████| 154/154 [00:02<00:00, 62.07it/s, loss=1.164]


loss_avg()=1.0765936374664307
accuracy_avg()=0.5583333333333333
Epoch 5/100


100%|██████████| 154/154 [00:02<00:00, 57.14it/s, loss=1.009]


loss_avg()=0.9581860860188802
accuracy_avg()=0.6291666666666667
Epoch 6/100


100%|██████████| 154/154 [00:02<00:00, 52.61it/s, loss=0.904]


loss_avg()=0.8764602899551391
accuracy_avg()=0.6520833333333333
Epoch 7/100


100%|██████████| 154/154 [00:02<00:00, 60.13it/s, loss=0.825]


loss_avg()=0.8112151940663656
accuracy_avg()=0.7
Epoch 8/100


100%|██████████| 154/154 [00:02<00:00, 65.64it/s, loss=0.759]


loss_avg()=0.7622329314549764
accuracy_avg()=0.725
Epoch 9/100


100%|██████████| 154/154 [00:02<00:00, 68.81it/s, loss=0.707]


loss_avg()=0.7246369481086731
accuracy_avg()=0.74375
Epoch 10/100


100%|██████████| 154/154 [00:02<00:00, 59.20it/s, loss=0.672]


loss_avg()=0.703320582707723
accuracy_avg()=0.74375
Epoch 11/100


100%|██████████| 154/154 [00:02<00:00, 73.79it/s, loss=0.636]


loss_avg()=0.6869298617045084
accuracy_avg()=0.75
Epoch 12/100


100%|██████████| 154/154 [00:02<00:00, 71.99it/s, loss=0.618]


loss_avg()=0.6443212270736695
accuracy_avg()=0.7729166666666667
Epoch 13/100


100%|██████████| 154/154 [00:02<00:00, 72.43it/s, loss=0.587]


loss_avg()=0.6317625085512797
accuracy_avg()=0.7770833333333333
Epoch 14/100


100%|██████████| 154/154 [00:02<00:00, 67.17it/s, loss=0.568]


loss_avg()=0.6205099205176036
accuracy_avg()=0.7791666666666667
Epoch 15/100


100%|██████████| 154/154 [00:02<00:00, 73.09it/s, loss=0.551]


loss_avg()=0.6066420674324036
accuracy_avg()=0.7833333333333333
Epoch 16/100


100%|██████████| 154/154 [00:02<00:00, 71.90it/s, loss=0.532]


loss_avg()=0.5954086641470592
accuracy_avg()=0.7895833333333333
Epoch 17/100


100%|██████████| 154/154 [00:02<00:00, 74.77it/s, loss=0.516]


loss_avg()=0.5948229213555654
accuracy_avg()=0.7833333333333333
Epoch 18/100


100%|██████████| 154/154 [00:02<00:00, 63.01it/s, loss=0.500]


loss_avg()=0.5687028884887695
accuracy_avg()=0.7979166666666667
Epoch 19/100


100%|██████████| 154/154 [00:02<00:00, 66.14it/s, loss=0.486]


loss_avg()=0.5627172986666361
accuracy_avg()=0.8041666666666667
Epoch 20/100


100%|██████████| 154/154 [00:02<00:00, 66.88it/s, loss=0.474]


loss_avg()=0.5569022536277771
accuracy_avg()=0.8125
Epoch 21/100


100%|██████████| 154/154 [00:02<00:00, 69.07it/s, loss=0.464]


loss_avg()=0.554031103849411
accuracy_avg()=0.80625
Epoch 22/100


100%|██████████| 154/154 [00:02<00:00, 67.93it/s, loss=0.465]


loss_avg()=0.5910986423492431
accuracy_avg()=0.7791666666666667
Epoch 23/100


100%|██████████| 154/154 [00:02<00:00, 68.34it/s, loss=0.443]


loss_avg()=0.5355316738287608
accuracy_avg()=0.8208333333333333
Epoch 24/100


100%|██████████| 154/154 [00:02<00:00, 65.88it/s, loss=0.430]


loss_avg()=0.528932144244512
accuracy_avg()=0.825
Epoch 25/100


100%|██████████| 154/154 [00:02<00:00, 75.82it/s, loss=0.421]


loss_avg()=0.5176396052042643
accuracy_avg()=0.8291666666666667
Epoch 26/100


100%|██████████| 154/154 [00:02<00:00, 65.06it/s, loss=0.409]


loss_avg()=0.5161626716454824
accuracy_avg()=0.8270833333333333
Epoch 27/100


100%|██████████| 154/154 [00:02<00:00, 66.22it/s, loss=0.402]


loss_avg()=0.5052206377188365
accuracy_avg()=0.8354166666666667
Epoch 28/100


100%|██████████| 154/154 [00:02<00:00, 73.95it/s, loss=0.401]


loss_avg()=0.5015108048915863
accuracy_avg()=0.8333333333333334
Epoch 29/100


100%|██████████| 154/154 [00:02<00:00, 71.71it/s, loss=0.393]


loss_avg()=0.4948892295360565
accuracy_avg()=0.8354166666666667
Epoch 30/100


100%|██████████| 154/154 [00:02<00:00, 68.52it/s, loss=0.381]


loss_avg()=0.4853109876314799
accuracy_avg()=0.83125
Epoch 31/100


100%|██████████| 154/154 [00:02<00:00, 70.16it/s, loss=0.379]


loss_avg()=0.4854000429312388
accuracy_avg()=0.83125
Epoch 32/100


100%|██████████| 154/154 [00:02<00:00, 69.17it/s, loss=0.366]


loss_avg()=0.5084718565146128
accuracy_avg()=0.8270833333333333
Epoch 33/100


100%|██████████| 154/154 [00:02<00:00, 67.13it/s, loss=0.384]


loss_avg()=0.49086901744206746
accuracy_avg()=0.8270833333333333
Epoch 34/100


100%|██████████| 154/154 [00:02<00:00, 60.78it/s, loss=0.353]


loss_avg()=0.48065735697746276
accuracy_avg()=0.83125
Epoch 35/100


100%|██████████| 154/154 [00:02<00:00, 65.90it/s, loss=0.347]


loss_avg()=0.48389073411623634
accuracy_avg()=0.8291666666666667
Epoch 36/100


100%|██████████| 154/154 [00:02<00:00, 71.84it/s, loss=0.340]


loss_avg()=0.47183237671852113
accuracy_avg()=0.8354166666666667
Epoch 37/100


100%|██████████| 154/154 [00:02<00:00, 64.12it/s, loss=0.341]


loss_avg()=0.48389280438423155
accuracy_avg()=0.8291666666666667
Epoch 38/100


100%|██████████| 154/154 [00:02<00:00, 76.88it/s, loss=0.336]


loss_avg()=0.48860645095507305
accuracy_avg()=0.825
Epoch 39/100


100%|██████████| 154/154 [00:02<00:00, 74.27it/s, loss=0.325]


loss_avg()=0.4965680023034414
accuracy_avg()=0.8354166666666667
Epoch 40/100


100%|██████████| 154/154 [00:02<00:00, 71.48it/s, loss=0.320]


loss_avg()=0.4946302056312561
accuracy_avg()=0.825
Epoch 41/100


100%|██████████| 154/154 [00:02<00:00, 73.09it/s, loss=0.315]


loss_avg()=0.504873804251353
accuracy_avg()=0.8270833333333333
Epoch 42/100


100%|██████████| 154/154 [00:02<00:00, 73.26it/s, loss=0.319]


loss_avg()=0.5023643513520558
accuracy_avg()=0.8270833333333333
Epoch 43/100


100%|██████████| 154/154 [00:02<00:00, 61.75it/s, loss=0.312]


loss_avg()=0.4957024296124776
accuracy_avg()=0.825
Epoch 44/100


100%|██████████| 154/154 [00:02<00:00, 68.75it/s, loss=0.350]


loss_avg()=0.501507314046224
accuracy_avg()=0.8270833333333333
Epoch 45/100


100%|██████████| 154/154 [00:02<00:00, 70.62it/s, loss=0.303]


loss_avg()=0.5004149456818898
accuracy_avg()=0.8375
Epoch 46/100


100%|██████████| 154/154 [00:02<00:00, 60.16it/s, loss=0.297]


loss_avg()=0.4990192363659541
accuracy_avg()=0.8458333333333333
Epoch 47/100


100%|██████████| 154/154 [00:02<00:00, 59.57it/s, loss=0.295]


loss_avg()=0.5156989594300588
accuracy_avg()=0.8291666666666667
Epoch 48/100


100%|██████████| 154/154 [00:02<00:00, 60.54it/s, loss=0.288]


loss_avg()=0.49862363934516907
accuracy_avg()=0.8375
Epoch 49/100


100%|██████████| 154/154 [00:02<00:00, 71.49it/s, loss=0.284]


loss_avg()=0.4967386782169342
accuracy_avg()=0.8333333333333334
Epoch 50/100


100%|██████████| 154/154 [00:02<00:00, 60.98it/s, loss=0.273]


loss_avg()=0.491346408923467
accuracy_avg()=0.8354166666666667
Epoch 51/100


100%|██████████| 154/154 [00:02<00:00, 67.66it/s, loss=0.279]


loss_avg()=0.5075637936592102
accuracy_avg()=0.8270833333333333
Epoch 52/100


100%|██████████| 154/154 [00:02<00:00, 68.71it/s, loss=0.271]


loss_avg()=0.5015876253445943
accuracy_avg()=0.8354166666666667
Epoch 53/100


100%|██████████| 154/154 [00:02<00:00, 72.86it/s, loss=0.384]


loss_avg()=0.518190469344457
accuracy_avg()=0.84375
Epoch 54/100


100%|██████████| 154/154 [00:02<00:00, 68.39it/s, loss=0.298]


loss_avg()=0.5136682987213135
accuracy_avg()=0.8375
Epoch 55/100


100%|██████████| 154/154 [00:02<00:00, 74.81it/s, loss=0.285]


loss_avg()=0.4931497673193614
accuracy_avg()=0.8479166666666667
Epoch 56/100


100%|██████████| 154/154 [00:02<00:00, 61.16it/s, loss=0.271]


loss_avg()=0.4995055158933004
accuracy_avg()=0.8479166666666667
Epoch 57/100


100%|██████████| 154/154 [00:02<00:00, 67.18it/s, loss=0.268]


loss_avg()=0.5006751596927643
accuracy_avg()=0.84375
Epoch 58/100


100%|██████████| 154/154 [00:02<00:00, 64.27it/s, loss=0.259]


loss_avg()=0.4952219307422638
accuracy_avg()=0.85
Epoch 59/100


100%|██████████| 154/154 [00:02<00:00, 63.58it/s, loss=0.257]


loss_avg()=0.48640621701876324
accuracy_avg()=0.8520833333333333
Epoch 60/100


100%|██████████| 154/154 [00:02<00:00, 68.09it/s, loss=0.253]


loss_avg()=0.48890081445376077
accuracy_avg()=0.8458333333333333
Epoch 61/100


100%|██████████| 154/154 [00:02<00:00, 72.71it/s, loss=0.246]


loss_avg()=0.49238970279693606
accuracy_avg()=0.8520833333333333
Epoch 62/100


100%|██████████| 154/154 [00:02<00:00, 70.95it/s, loss=0.243]


loss_avg()=0.49493619402249656
accuracy_avg()=0.8520833333333333
Epoch 63/100


100%|██████████| 154/154 [00:02<00:00, 72.33it/s, loss=0.241]


loss_avg()=0.4885060985883077
accuracy_avg()=0.8520833333333333
Epoch 64/100


100%|██████████| 154/154 [00:02<00:00, 73.27it/s, loss=0.236]


loss_avg()=0.49188764492670695
accuracy_avg()=0.85
Epoch 65/100


100%|██████████| 154/154 [00:02<00:00, 66.56it/s, loss=0.233]


loss_avg()=0.48771565755208335
accuracy_avg()=0.85625
Epoch 66/100


100%|██████████| 154/154 [00:02<00:00, 53.75it/s, loss=0.230]


loss_avg()=0.4970814327398936
accuracy_avg()=0.8520833333333333
Epoch 67/100


100%|██████████| 154/154 [00:02<00:00, 53.60it/s, loss=0.234]


loss_avg()=0.5010046839714051
accuracy_avg()=0.85
Epoch 68/100


100%|██████████| 154/154 [00:02<00:00, 58.82it/s, loss=0.227]


loss_avg()=0.5001096347967784
accuracy_avg()=0.85
Epoch 69/100


100%|██████████| 154/154 [00:02<00:00, 60.34it/s, loss=0.223]


loss_avg()=0.512130731344223
accuracy_avg()=0.8520833333333333
Epoch 70/100


100%|██████████| 154/154 [00:02<00:00, 61.05it/s, loss=0.223]


loss_avg()=0.5050235529740651
accuracy_avg()=0.8541666666666666
Epoch 71/100


100%|██████████| 154/154 [00:02<00:00, 59.51it/s, loss=0.223]


loss_avg()=0.5333589911460876
accuracy_avg()=0.8333333333333334
Epoch 72/100


100%|██████████| 154/154 [00:02<00:00, 60.46it/s, loss=0.227]


loss_avg()=0.5063299556573232
accuracy_avg()=0.8520833333333333
Epoch 73/100


100%|██████████| 154/154 [00:02<00:00, 68.98it/s, loss=0.216]


loss_avg()=0.5251919011274974
accuracy_avg()=0.8416666666666667
Epoch 74/100


100%|██████████| 154/154 [00:02<00:00, 69.06it/s, loss=0.213]


loss_avg()=0.5200437764326732
accuracy_avg()=0.8479166666666667
Epoch 75/100


100%|██████████| 154/154 [00:02<00:00, 62.55it/s, loss=0.208]


loss_avg()=0.5207803944746653
accuracy_avg()=0.85
Epoch 76/100


100%|██████████| 154/154 [00:02<00:00, 74.48it/s, loss=0.212]


loss_avg()=0.5692525704701742
accuracy_avg()=0.8291666666666667
Epoch 77/100


100%|██████████| 154/154 [00:02<00:00, 74.83it/s, loss=0.226]


loss_avg()=0.5149901727835338
accuracy_avg()=0.84375
Epoch 78/100


100%|██████████| 154/154 [00:02<00:00, 73.61it/s, loss=0.214]


loss_avg()=0.518990836540858
accuracy_avg()=0.85
Epoch 79/100


100%|██████████| 154/154 [00:02<00:00, 69.18it/s, loss=0.205]


loss_avg()=0.5257384498914083
accuracy_avg()=0.8354166666666667
Epoch 80/100


100%|██████████| 154/154 [00:02<00:00, 60.03it/s, loss=0.200]


loss_avg()=0.5169296979904174
accuracy_avg()=0.84375
Epoch 81/100


100%|██████████| 154/154 [00:02<00:00, 65.66it/s, loss=0.210]


loss_avg()=0.5118446369965871
accuracy_avg()=0.85
Epoch 82/100


100%|██████████| 154/154 [00:02<00:00, 53.69it/s, loss=0.201]


loss_avg()=0.5054792096217473
accuracy_avg()=0.84375
Epoch 83/100


100%|██████████| 154/154 [00:02<00:00, 58.11it/s, loss=0.198]


loss_avg()=0.5220555027325948
accuracy_avg()=0.85
Epoch 84/100


100%|██████████| 154/154 [00:02<00:00, 64.61it/s, loss=0.195]


loss_avg()=0.5235499630371729
accuracy_avg()=0.8458333333333333
Epoch 85/100


100%|██████████| 154/154 [00:02<00:00, 64.28it/s, loss=0.190]


loss_avg()=0.5232146312793096
accuracy_avg()=0.8354166666666667
Epoch 86/100


100%|██████████| 154/154 [00:02<00:00, 65.95it/s, loss=0.205]


loss_avg()=0.5533677736918131
accuracy_avg()=0.8375
Epoch 87/100


100%|██████████| 154/154 [00:02<00:00, 65.35it/s, loss=0.191]


loss_avg()=0.5429877241452535
accuracy_avg()=0.8354166666666667
Epoch 88/100


100%|██████████| 154/154 [00:02<00:00, 62.70it/s, loss=0.189]


loss_avg()=0.5108259409666062
accuracy_avg()=0.84375
Epoch 89/100


100%|██████████| 154/154 [00:02<00:00, 59.99it/s, loss=0.195]


loss_avg()=0.555565865834554
accuracy_avg()=0.8375
Epoch 90/100


100%|██████████| 154/154 [00:02<00:00, 64.17it/s, loss=0.181]


loss_avg()=0.544446180264155
accuracy_avg()=0.8416666666666667
Epoch 91/100


100%|██████████| 154/154 [00:02<00:00, 62.45it/s, loss=0.177]


loss_avg()=0.5341030915578207
accuracy_avg()=0.8395833333333333
Epoch 92/100


100%|██████████| 154/154 [00:02<00:00, 59.94it/s, loss=0.176]


loss_avg()=0.542215887705485
accuracy_avg()=0.8416666666666667
Epoch 93/100


100%|██████████| 154/154 [00:02<00:00, 64.03it/s, loss=0.189]


loss_avg()=0.5338416794935862
accuracy_avg()=0.8416666666666667
Epoch 94/100


100%|██████████| 154/154 [00:02<00:00, 57.87it/s, loss=0.180]


loss_avg()=0.5515288929144542
accuracy_avg()=0.8458333333333333
Epoch 95/100


100%|██████████| 154/154 [00:02<00:00, 52.22it/s, loss=0.181]


loss_avg()=0.5273735413948695
accuracy_avg()=0.8479166666666667
Epoch 96/100


100%|██████████| 154/154 [00:02<00:00, 58.21it/s, loss=0.190]


loss_avg()=0.5567000766595205
accuracy_avg()=0.8395833333333333
Epoch 97/100


100%|██████████| 154/154 [00:02<00:00, 63.78it/s, loss=0.221]


loss_avg()=0.5522946745157242
accuracy_avg()=0.8395833333333333
Epoch 98/100


100%|██████████| 154/154 [00:02<00:00, 72.26it/s, loss=0.181]


loss_avg()=0.5277779658635458
accuracy_avg()=0.8458333333333333
Epoch 99/100


100%|██████████| 154/154 [00:02<00:00, 67.39it/s, loss=0.172]


loss_avg()=0.5273652702569962
accuracy_avg()=0.84375
Epoch 100/100


100%|██████████| 154/154 [00:02<00:00, 74.04it/s, loss=0.170]


loss_avg()=0.5296146074930826
accuracy_avg()=0.8458333333333333


## Final Test Accuracy

In [36]:
# Simple check with test dataset
model.eval()
test_data_iterator = data_iterator(X_test_tokenized_indexified, y_test_formatted, len(X_test_tokenized_indexified), len(X_test_tokenized_indexified), shuffle=False)
test_batch, labels_batch, test_sentences = next(test_data_iterator)

seq_lengths = torch.LongTensor(list(map(len, test_batch)))
output_batch = model(test_batch.to(device),seq_lengths)
final_test_accuracy = accuracy(output_batch, labels_batch.to(device))
print(f"{final_test_accuracy=}")

final_test_accuracy=0.866


In [37]:
def print_sentence_label(sentence: str) -> int:
    model.eval()
    sentence_tokenized = word_tokenize(sentence.lower())
    sentence_as_id = [
        vocab[token] if token in vocab
        else vocab['UNK']
        for token in sentence_tokenized
    ]
    seq_lengths = torch.LongTensor([len(sentence_as_id)])
    input = torch.tensor(sentence_as_id).unsqueeze(0).to(device)
    output = model(input, seq_lengths).to(device)
    label = np.argmax(output.detach().cpu().numpy())
    print(f"sentence = {sentence}, label = {label}")

# Checking results
print_sentence_label("What is a squirrel?")
print_sentence_label("Is Singapore located in Southeast Asia?")
print_sentence_label("Is Singapore in China?")
print_sentence_label("Name 11 famous martyrs .")
print_sentence_label("What ISPs exist in the Caribbean ?")
print_sentence_label("How many cars are manufactured every day?")

sentence = What is a squirrel?, label = 0
sentence = Is Singapore located in Southeast Asia?, label = 3
sentence = Is Singapore in China?, label = 1
sentence = Name 11 famous martyrs ., label = 4
sentence = What ISPs exist in the Caribbean ?, label = 0
sentence = How many cars are manufactured every day?, label = 4
