# Import libraries

In [51]:
import random
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.optim as optim

from tqdm import trange
from torch.autograd import Variable
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

from nltk.tokenize import word_tokenize
import nltk
nltk.download('punkt')
import gensim.downloader

import warnings
warnings.filterwarnings('ignore')

In [52]:
word2vec_goog1e_news: gensim.models.keyedvectors.KeyedVectors = gensim.downloader.load('word2vec-google-news-300')
word2vec_goog1e_news.add_vector("<pad>", np.zeros(300))
pad_index = word2vec_goog1e_news.key_to_index["<pad>"]
embedding_weights = torch.FloatTensor(word2vec_goog1e_news.vectors)
vocab = word2vec_goog1e_news.key_to_index

In [53]:
device = torch.device(torch.cuda.current_device() if torch.cuda.is_available() else "cpu")
print(f"Using: {device}")

Using: cpu


# Import Dataset

In [54]:
def tokenize_pd_series_to_lsit(list_of_text):
    tokenized = []
    for sentence in list_of_text:
        tokenized.append(word_tokenize(sentence.lower()))
    return tokenized

def format_label(label):
    return torch.unsqueeze(torch.tensor(label.to_list()), axis=1).tolist()

def indexify(data):
    setences = []
    for sentence in data:
        s = [vocab[token] if token in vocab
            else vocab['UNK']
            for token in sentence]
        setences.append(s)
    return setences


In [55]:
training_data = pd.read_csv(filepath_or_buffer="TREC_dataset/modified_training_data.csv", sep=",") 
test_data = pd.read_csv(filepath_or_buffer="TREC_dataset/modified_test_data.csv", sep=",")

X = training_data["text"]
y = training_data["label-coarse"]
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=500)

X_test = test_data["text"]
y_test = test_data["label-coarse"]

X_train_lst = X_train.to_list()
X_val_lst = X_val.to_list()
X_test_lst = X_test.to_list()

X_train_tokenized = tokenize_pd_series_to_lsit(X_train_lst)
X_val_tokenized = tokenize_pd_series_to_lsit(X_val_lst)
X_test_tokenized = tokenize_pd_series_to_lsit(X_test_lst)

no_of_labels = 5

In [56]:
X_train_tokenized_indexified = indexify(X_train_tokenized)
X_val_tokenized_indexified = indexify(X_val_tokenized)
X_test_tokenized_indexified = indexify(X_test_tokenized)

y_train_formatted = format_label(y_train)
y_val_formatted = format_label(y_val)
y_test_formatted = format_label(y_test)

In [57]:
def data_iterator(sentences, labels, total_size: int, batch_size: int, shuffle: bool=False):
    # make a list that decides the order in which we go over the data- this avoids explicit shuffling of data
    order = list(range(total_size))
    if shuffle:
        random.seed(230)
        random.shuffle(order)

    # one pass over data
    for i in range((total_size+1)//batch_size):
        # fetch sentences and tags
        batch_sentences = [sentences[idx] for idx in order[i*batch_size:(i+1)*batch_size]]
        batch_tags = [labels[idx] for idx in order[i*batch_size:(i+1)*batch_size]]

        # compute length of longest sentence in batch
        batch_max_len = max([len(s) for s in batch_sentences])

        # prepare a numpy array with the data, initialising the data with pad_ind and all labels with -1
        # initialising labels to -1 differentiates tokens with tags from PADding tokens
        batch_data = vocab['<pad>']*np.ones((len(batch_sentences), batch_max_len))
        batch_labels = np.array(batch_tags).squeeze()

        # copy the data to the numpy array
        for j in range(len(batch_sentences)):
            cur_len = len(batch_sentences[j])
            batch_data[j][:cur_len] = batch_sentences[j]

        # since all data are indices, we convert them to torch LongTensors
        batch_data, batch_labels = torch.LongTensor(batch_data), torch.LongTensor(batch_labels)
        # convert them to Variables to record operations in the computational graph
        batch_data, batch_labels = Variable(batch_data), Variable(batch_labels)

        yield batch_data, batch_labels, batch_sentences

In [58]:
class Net(nn.Module):
    def __init__(self, embedding_weights, embedding_dim, lstm_hidden_dim, number_of_tags):
        super(Net, self).__init__()

        # the embedding takes as input the vocab_size and the embedding_dim
        self.embedding = nn.Embedding.from_pretrained(embedding_weights, padding_idx=pad_index)

        # the LSTM takes as input the size of its input (embedding_dim), its hidden size
        self.lstm = nn.LSTM(embedding_dim,
                            lstm_hidden_dim, batch_first=True)

        # the fully connected layer transforms the output to give the final output layer
        self.fc = nn.Linear(lstm_hidden_dim, number_of_tags)

    def forward(self, s, lengths):
        # apply the embedding layer that maps each token to its embedding
        s = self.embedding(s)

        # pack the sequences before feeding them to the LSTM
        packed_input = pack_padded_sequence(s, lengths, batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed_input)

        # unpack the sequences after passing through the LSTM
        padded_output, _ = pad_packed_sequence(packed_output, batch_first=True)
 
        s = torch.max(padded_output, dim=1)[0] # max pooling
        
        # apply the fully connected layer and obtain the output (before softmax) for each token
        s = self.fc(s)

        # apply log softmax on each token's output
        return F.log_softmax(s, dim=1)

In [59]:
def accuracy(outputs, labels):
    outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1)
    labels = labels.squeeze()
    # compare outputs with labels
    return np.sum([1 if first == second else 0 for first, second in zip(labels, outputs)]) / float(len(labels))

def loss_fn(outputs, labels):
    loss = F.cross_entropy(outputs, labels.squeeze())
    return loss

In [60]:
class RunningAverage:
    def __init__(self):
        self.steps = 0
        self.total = 0

    def update(self, val):
        self.total += val
        self.steps += 1

    def __call__(self):
        return self.total / float(self.steps)

In [61]:
def train(model, optimizer, loss_fn, data_iterator, num_steps):
    # set model to training mode
    model.train()

    # summary for current training loop and a running average object for loss
    loss_avg = RunningAverage()

    # Use tqdm for progress bar
    t = trange(num_steps)
    for i in t:
        # fetch the next training batch
        train_batch, labels_batch, _ = next(data_iterator)
        train_batch = train_batch.to(device)
        labels_batch = labels_batch.to(device)
        
        # compute model output and loss
        seq_lengths = torch.LongTensor(list(map(len, train_batch)))
        output_batch = model(train_batch, seq_lengths)
        loss = loss_fn(output_batch, labels_batch)

        # clear previous gradients, compute gradients of all variables wrt loss
        optimizer.zero_grad()
        loss.backward()

        # performs updates using calculated gradients
        optimizer.step()

        # update the average loss
        loss_avg.update(loss.item())
        t.set_postfix(loss='{:05.3f}'.format(loss_avg()))

def evaluate(model, loss_fn, data_iterator, num_steps):
    # set model to evaluation mode
    model.eval()

    loss_avg = RunningAverage()
    accuracy_avg = RunningAverage()

    # compute metrics over the dataset
    for _ in range(num_steps):
        # fetch the next evaluation batch
        data_batch, labels_batch, _ = next(data_iterator)
        data_batch = data_batch.to(device)
        labels_batch = labels_batch.to(device)

        # compute model output
        seq_lengths = torch.LongTensor(list(map(len, data_batch)))
        output_batch = model(data_batch, seq_lengths)
        loss = loss_fn(output_batch, labels_batch)
        loss_avg.update(loss.item())
        accuracy_val = accuracy(output_batch, labels_batch)
        accuracy_avg.update(accuracy_val)

    print(f"{loss_avg()=}")
    print(f"{accuracy_avg()=}")
    
def train_and_evaluate(
        model,
        train_sentences,
        train_labels,
        val_sentences,
        val_labels,
        num_epochs: int,
        batch_size: int,
        optimizer,
        loss_fn
):
    for epoch in range(num_epochs):
        # Run one epoch
        print("Epoch {}/{}".format(epoch + 1, num_epochs))

        # compute number of batches in one epoch (one full pass over the training set)
        num_steps = (len(train_sentences) + 1) // batch_size
        train_data_iterator = data_iterator(train_sentences, train_labels, len(train_sentences), batch_size, shuffle=True)
        train(model, optimizer, loss_fn, train_data_iterator,num_steps)

        # Evaluate for one epoch on validation set
        num_steps = (len(val_sentences) + 1) // batch_size
        val_data_iterator = data_iterator(val_sentences, val_labels, len(val_sentences), batch_size, shuffle=False)
        evaluate(model, loss_fn, val_data_iterator, num_steps)

In [62]:
model = Net(embedding_weights, 300, 5, no_of_labels).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

if (os.path.isfile("model_weights_max_pooling.pth")):
    model.load_state_dict(torch.load('model_weights_max_pooling.pth'))
else:
    train_and_evaluate(model, X_train_tokenized_indexified , y_train_formatted , X_val_tokenized_indexified  , y_val_formatted, 100, 32, optimizer, loss_fn)
    torch.save(model.state_dict(), 'model_weights_max_pooling.pth')

Epoch 1/100


100%|██████████| 154/154 [00:02<00:00, 60.49it/s, loss=1.571]


loss_avg()=1.4334396680196126
accuracy_avg()=0.39375
Epoch 2/100


100%|██████████| 154/154 [00:02<00:00, 65.37it/s, loss=1.381]


loss_avg()=1.255910841623942
accuracy_avg()=0.4666666666666667
Epoch 3/100


100%|██████████| 154/154 [00:02<00:00, 63.69it/s, loss=1.141]


loss_avg()=0.9791324694951375
accuracy_avg()=0.6416666666666667
Epoch 4/100


100%|██████████| 154/154 [00:02<00:00, 56.84it/s, loss=0.932]


loss_avg()=0.853459898630778
accuracy_avg()=0.65
Epoch 5/100


100%|██████████| 154/154 [00:02<00:00, 61.29it/s, loss=0.825]


loss_avg()=0.7597200314203898
accuracy_avg()=0.675
Epoch 6/100


100%|██████████| 154/154 [00:02<00:00, 64.06it/s, loss=0.749]


loss_avg()=0.7052918950716655
accuracy_avg()=0.7020833333333333
Epoch 7/100


100%|██████████| 154/154 [00:02<00:00, 62.40it/s, loss=0.706]


loss_avg()=0.6629039764404296
accuracy_avg()=0.7520833333333333
Epoch 8/100


100%|██████████| 154/154 [00:02<00:00, 60.20it/s, loss=0.659]


loss_avg()=0.6412413636843364
accuracy_avg()=0.7583333333333333
Epoch 9/100


100%|██████████| 154/154 [00:02<00:00, 60.19it/s, loss=0.626]


loss_avg()=0.6160445610682169
accuracy_avg()=0.7604166666666666
Epoch 10/100


100%|██████████| 154/154 [00:02<00:00, 62.37it/s, loss=0.598]


loss_avg()=0.5988352596759796
accuracy_avg()=0.7604166666666666
Epoch 11/100


100%|██████████| 154/154 [00:02<00:00, 62.84it/s, loss=0.574]


loss_avg()=0.580945897102356
accuracy_avg()=0.76875
Epoch 12/100


100%|██████████| 154/154 [00:02<00:00, 62.96it/s, loss=0.550]


loss_avg()=0.5673103729883829
accuracy_avg()=0.7708333333333334
Epoch 13/100


100%|██████████| 154/154 [00:02<00:00, 64.03it/s, loss=0.531]


loss_avg()=0.5592840472857158
accuracy_avg()=0.7770833333333333
Epoch 14/100


100%|██████████| 154/154 [00:02<00:00, 60.36it/s, loss=0.513]


loss_avg()=0.5536128540833791
accuracy_avg()=0.7770833333333333
Epoch 15/100


100%|██████████| 154/154 [00:02<00:00, 64.21it/s, loss=0.496]


loss_avg()=0.540140547355016
accuracy_avg()=0.78125
Epoch 16/100


100%|██████████| 154/154 [00:02<00:00, 73.46it/s, loss=0.481]


loss_avg()=0.5360224664211273
accuracy_avg()=0.7791666666666667
Epoch 17/100


100%|██████████| 154/154 [00:02<00:00, 68.44it/s, loss=0.468]


loss_avg()=0.5213641623655955
accuracy_avg()=0.79375
Epoch 18/100


100%|██████████| 154/154 [00:02<00:00, 73.42it/s, loss=0.453]


loss_avg()=0.5165885229905446
accuracy_avg()=0.79375
Epoch 19/100


100%|██████████| 154/154 [00:02<00:00, 70.81it/s, loss=0.441]


loss_avg()=0.5056721429030101
accuracy_avg()=0.7875
Epoch 20/100


100%|██████████| 154/154 [00:02<00:00, 68.67it/s, loss=0.430]


loss_avg()=0.5003799597422282
accuracy_avg()=0.7958333333333333
Epoch 21/100


100%|██████████| 154/154 [00:02<00:00, 73.91it/s, loss=0.418]


loss_avg()=0.4924763103326162
accuracy_avg()=0.8020833333333334
Epoch 22/100


100%|██████████| 154/154 [00:02<00:00, 75.27it/s, loss=0.407]


loss_avg()=0.4891092419624329
accuracy_avg()=0.7958333333333333
Epoch 23/100


100%|██████████| 154/154 [00:02<00:00, 71.79it/s, loss=0.401]


loss_avg()=0.493764728307724
accuracy_avg()=0.8041666666666667
Epoch 24/100


100%|██████████| 154/154 [00:02<00:00, 74.71it/s, loss=0.388]


loss_avg()=0.48485292196273805
accuracy_avg()=0.8041666666666667
Epoch 25/100


100%|██████████| 154/154 [00:02<00:00, 73.93it/s, loss=0.383]


loss_avg()=0.475200488169988
accuracy_avg()=0.80625
Epoch 26/100


100%|██████████| 154/154 [00:02<00:00, 71.68it/s, loss=0.373]


loss_avg()=0.4711080809434255
accuracy_avg()=0.80625
Epoch 27/100


100%|██████████| 154/154 [00:02<00:00, 74.12it/s, loss=0.363]


loss_avg()=0.4674129903316498
accuracy_avg()=0.80625
Epoch 28/100


100%|██████████| 154/154 [00:02<00:00, 73.30it/s, loss=0.354]


loss_avg()=0.48639526069164274
accuracy_avg()=0.8041666666666667
Epoch 29/100


100%|██████████| 154/154 [00:02<00:00, 74.04it/s, loss=0.357]


loss_avg()=0.463215047121048
accuracy_avg()=0.8145833333333333
Epoch 30/100


100%|██████████| 154/154 [00:02<00:00, 74.12it/s, loss=0.343]


loss_avg()=0.4606945772965749
accuracy_avg()=0.8104166666666667
Epoch 31/100


100%|██████████| 154/154 [00:02<00:00, 72.79it/s, loss=0.335]


loss_avg()=0.46891224682331084
accuracy_avg()=0.8166666666666667
Epoch 32/100


100%|██████████| 154/154 [00:02<00:00, 69.25it/s, loss=0.335]


loss_avg()=0.45115579664707184
accuracy_avg()=0.8104166666666667
Epoch 33/100


100%|██████████| 154/154 [00:02<00:00, 72.80it/s, loss=0.323]


loss_avg()=0.44502198199431103
accuracy_avg()=0.8208333333333333
Epoch 34/100


100%|██████████| 154/154 [00:02<00:00, 72.48it/s, loss=0.317]


loss_avg()=0.43390544056892394
accuracy_avg()=0.8416666666666667
Epoch 35/100


100%|██████████| 154/154 [00:02<00:00, 74.71it/s, loss=0.312]


loss_avg()=0.43792554338773093
accuracy_avg()=0.84375
Epoch 36/100


100%|██████████| 154/154 [00:02<00:00, 75.27it/s, loss=0.305]


loss_avg()=0.4387392928202947
accuracy_avg()=0.8416666666666667
Epoch 37/100


100%|██████████| 154/154 [00:02<00:00, 71.58it/s, loss=0.305]


loss_avg()=0.44625896513462066
accuracy_avg()=0.8270833333333333
Epoch 38/100


100%|██████████| 154/154 [00:02<00:00, 73.43it/s, loss=0.306]


loss_avg()=0.4406376302242279
accuracy_avg()=0.8416666666666667
Epoch 39/100


100%|██████████| 154/154 [00:02<00:00, 71.99it/s, loss=0.294]


loss_avg()=0.4235330045223236
accuracy_avg()=0.8416666666666667
Epoch 40/100


100%|██████████| 154/154 [00:02<00:00, 73.17it/s, loss=0.291]


loss_avg()=0.4613529016574224
accuracy_avg()=0.8375
Epoch 41/100


100%|██████████| 154/154 [00:02<00:00, 73.76it/s, loss=0.297]


loss_avg()=0.433574178814888
accuracy_avg()=0.8416666666666667
Epoch 42/100


100%|██████████| 154/154 [00:02<00:00, 69.20it/s, loss=0.287]


loss_avg()=0.4308346221844355
accuracy_avg()=0.8458333333333333
Epoch 43/100


100%|██████████| 154/154 [00:02<00:00, 72.98it/s, loss=0.296]


loss_avg()=0.4442545572916667
accuracy_avg()=0.8416666666666667
Epoch 44/100


100%|██████████| 154/154 [00:02<00:00, 73.71it/s, loss=0.281]


loss_avg()=0.4343184341986974
accuracy_avg()=0.8375
Epoch 45/100


100%|██████████| 154/154 [00:02<00:00, 69.88it/s, loss=0.267]


loss_avg()=0.4450596198439598
accuracy_avg()=0.8354166666666667
Epoch 46/100


100%|██████████| 154/154 [00:02<00:00, 72.18it/s, loss=0.265]


loss_avg()=0.4349478170275688
accuracy_avg()=0.84375
Epoch 47/100


100%|██████████| 154/154 [00:02<00:00, 71.67it/s, loss=0.258]


loss_avg()=0.44789917320013045
accuracy_avg()=0.8395833333333333
Epoch 48/100


100%|██████████| 154/154 [00:02<00:00, 71.68it/s, loss=0.258]


loss_avg()=0.4268653243780136
accuracy_avg()=0.8541666666666666
Epoch 49/100


100%|██████████| 154/154 [00:02<00:00, 72.58it/s, loss=0.255]


loss_avg()=0.4507989843686422
accuracy_avg()=0.8375
Epoch 50/100


100%|██████████| 154/154 [00:02<00:00, 73.43it/s, loss=0.251]


loss_avg()=0.4486429909865061
accuracy_avg()=0.8520833333333333
Epoch 51/100


100%|██████████| 154/154 [00:02<00:00, 71.87it/s, loss=0.247]


loss_avg()=0.4329277008771896
accuracy_avg()=0.85625
Epoch 52/100


100%|██████████| 154/154 [00:02<00:00, 60.65it/s, loss=0.241]


loss_avg()=0.44834806124369303
accuracy_avg()=0.8479166666666667
Epoch 53/100


100%|██████████| 154/154 [00:02<00:00, 62.51it/s, loss=0.242]


loss_avg()=0.4465209235747655
accuracy_avg()=0.8479166666666667
Epoch 54/100


100%|██████████| 154/154 [00:02<00:00, 53.01it/s, loss=0.270]


loss_avg()=0.5390212466319402
accuracy_avg()=0.8375
Epoch 55/100


100%|██████████| 154/154 [00:02<00:00, 60.36it/s, loss=0.257]


loss_avg()=0.4082042306661606
accuracy_avg()=0.8541666666666666
Epoch 56/100


100%|██████████| 154/154 [00:02<00:00, 59.21it/s, loss=0.244]


loss_avg()=0.4161121234297752
accuracy_avg()=0.85625
Epoch 57/100


100%|██████████| 154/154 [00:02<00:00, 59.99it/s, loss=0.238]


loss_avg()=0.43374770184357964
accuracy_avg()=0.84375
Epoch 58/100


100%|██████████| 154/154 [00:02<00:00, 63.90it/s, loss=0.229]


loss_avg()=0.4349418779214223
accuracy_avg()=0.8479166666666667
Epoch 59/100


100%|██████████| 154/154 [00:02<00:00, 64.27it/s, loss=0.225]


loss_avg()=0.42901795109113056
accuracy_avg()=0.85
Epoch 60/100


100%|██████████| 154/154 [00:02<00:00, 67.40it/s, loss=0.221]


loss_avg()=0.4307725300391515
accuracy_avg()=0.85625
Epoch 61/100


100%|██████████| 154/154 [00:02<00:00, 65.32it/s, loss=0.218]


loss_avg()=0.4310761700073878
accuracy_avg()=0.8583333333333333
Epoch 62/100


100%|██████████| 154/154 [00:02<00:00, 65.26it/s, loss=0.213]


loss_avg()=0.434921932220459
accuracy_avg()=0.8583333333333333
Epoch 63/100


100%|██████████| 154/154 [00:02<00:00, 64.60it/s, loss=0.211]


loss_avg()=0.43792787194252014
accuracy_avg()=0.85
Epoch 64/100


100%|██████████| 154/154 [00:02<00:00, 61.53it/s, loss=0.209]


loss_avg()=0.4337731922666232
accuracy_avg()=0.85625
Epoch 65/100


100%|██████████| 154/154 [00:02<00:00, 63.65it/s, loss=0.206]


loss_avg()=0.43369911511739095
accuracy_avg()=0.8541666666666666
Epoch 66/100


100%|██████████| 154/154 [00:02<00:00, 64.49it/s, loss=0.202]


loss_avg()=0.4344914466142654
accuracy_avg()=0.85
Epoch 67/100


100%|██████████| 154/154 [00:02<00:00, 59.87it/s, loss=0.207]


loss_avg()=0.4467654978235563
accuracy_avg()=0.85
Epoch 68/100


100%|██████████| 154/154 [00:02<00:00, 59.28it/s, loss=0.222]


loss_avg()=0.44987585047880807
accuracy_avg()=0.8416666666666667
Epoch 69/100


100%|██████████| 154/154 [00:02<00:00, 58.42it/s, loss=0.205]


loss_avg()=0.45296814242998756
accuracy_avg()=0.8479166666666667
Epoch 70/100


100%|██████████| 154/154 [00:02<00:00, 59.60it/s, loss=0.206]


loss_avg()=0.4520209019382795
accuracy_avg()=0.85
Epoch 71/100


100%|██████████| 154/154 [00:02<00:00, 60.55it/s, loss=0.196]


loss_avg()=0.45228375097115836
accuracy_avg()=0.8479166666666667
Epoch 72/100


100%|██████████| 154/154 [00:02<00:00, 61.37it/s, loss=0.191]


loss_avg()=0.46350215474764506
accuracy_avg()=0.8541666666666666
Epoch 73/100


100%|██████████| 154/154 [00:02<00:00, 61.97it/s, loss=0.188]


loss_avg()=0.4664701193571091
accuracy_avg()=0.8541666666666666
Epoch 74/100


100%|██████████| 154/154 [00:02<00:00, 58.26it/s, loss=0.192]


loss_avg()=0.4696506440639496
accuracy_avg()=0.8458333333333333
Epoch 75/100


100%|██████████| 154/154 [00:02<00:00, 69.05it/s, loss=0.192]


loss_avg()=0.4546044677495956
accuracy_avg()=0.85625
Epoch 76/100


100%|██████████| 154/154 [00:02<00:00, 68.29it/s, loss=0.190]


loss_avg()=0.45694035092989604
accuracy_avg()=0.8541666666666666
Epoch 77/100


100%|██████████| 154/154 [00:02<00:00, 72.71it/s, loss=0.202]


loss_avg()=0.47419564227263133
accuracy_avg()=0.85
Epoch 78/100


100%|██████████| 154/154 [00:02<00:00, 75.16it/s, loss=0.193]


loss_avg()=0.4605573147535324
accuracy_avg()=0.8541666666666666
Epoch 79/100


100%|██████████| 154/154 [00:02<00:00, 72.30it/s, loss=0.187]


loss_avg()=0.465994065006574
accuracy_avg()=0.8520833333333333
Epoch 80/100


100%|██████████| 154/154 [00:02<00:00, 72.57it/s, loss=0.184]


loss_avg()=0.46257298489411675
accuracy_avg()=0.8541666666666666
Epoch 81/100


100%|██████████| 154/154 [00:02<00:00, 74.82it/s, loss=0.185]


loss_avg()=0.4644245515267054
accuracy_avg()=0.85
Epoch 82/100


100%|██████████| 154/154 [00:02<00:00, 74.07it/s, loss=0.176]


loss_avg()=0.4668502966562907
accuracy_avg()=0.85625
Epoch 83/100


100%|██████████| 154/154 [00:02<00:00, 67.33it/s, loss=0.182]


loss_avg()=0.4538414925336838
accuracy_avg()=0.8604166666666667
Epoch 84/100


100%|██████████| 154/154 [00:02<00:00, 72.09it/s, loss=0.171]


loss_avg()=0.4668213456869125
accuracy_avg()=0.8541666666666666
Epoch 85/100


100%|██████████| 154/154 [00:02<00:00, 73.50it/s, loss=0.167]


loss_avg()=0.4770985891421636
accuracy_avg()=0.8541666666666666
Epoch 86/100


100%|██████████| 154/154 [00:02<00:00, 74.72it/s, loss=0.163]


loss_avg()=0.476742230852445
accuracy_avg()=0.8541666666666666
Epoch 87/100


100%|██████████| 154/154 [00:02<00:00, 73.81it/s, loss=0.171]


loss_avg()=0.5039520134528478
accuracy_avg()=0.8520833333333333
Epoch 88/100


100%|██████████| 154/154 [00:02<00:00, 73.63it/s, loss=0.181]


loss_avg()=0.44912300010522205
accuracy_avg()=0.8708333333333333
Epoch 89/100


100%|██████████| 154/154 [00:02<00:00, 70.95it/s, loss=0.162]


loss_avg()=0.45888587335745495
accuracy_avg()=0.8583333333333333
Epoch 90/100


100%|██████████| 154/154 [00:02<00:00, 67.70it/s, loss=0.160]


loss_avg()=0.4849009205897649
accuracy_avg()=0.85625
Epoch 91/100


100%|██████████| 154/154 [00:02<00:00, 53.64it/s, loss=0.162]


loss_avg()=0.4697985122601191
accuracy_avg()=0.8604166666666667
Epoch 92/100


100%|██████████| 154/154 [00:02<00:00, 60.89it/s, loss=0.167]


loss_avg()=0.4509247789780299
accuracy_avg()=0.86875
Epoch 93/100


100%|██████████| 154/154 [00:02<00:00, 61.90it/s, loss=0.166]


loss_avg()=0.5290861825148264
accuracy_avg()=0.8479166666666667
Epoch 94/100


100%|██████████| 154/154 [00:02<00:00, 60.94it/s, loss=0.165]


loss_avg()=0.5192854672670364
accuracy_avg()=0.84375
Epoch 95/100


100%|██████████| 154/154 [00:02<00:00, 62.60it/s, loss=0.159]


loss_avg()=0.4970144838094711
accuracy_avg()=0.8541666666666666
Epoch 96/100


100%|██████████| 154/154 [00:02<00:00, 62.99it/s, loss=0.154]


loss_avg()=0.5066810766855876
accuracy_avg()=0.8520833333333333
Epoch 97/100


100%|██████████| 154/154 [00:02<00:00, 60.07it/s, loss=0.152]


loss_avg()=0.47695570488770805
accuracy_avg()=0.8625
Epoch 98/100


100%|██████████| 154/154 [00:02<00:00, 61.81it/s, loss=0.153]


loss_avg()=0.48437717159589133
accuracy_avg()=0.8645833333333334
Epoch 99/100


100%|██████████| 154/154 [00:02<00:00, 62.30it/s, loss=0.146]


loss_avg()=0.49904317359129585
accuracy_avg()=0.85625
Epoch 100/100


100%|██████████| 154/154 [00:02<00:00, 58.62it/s, loss=0.157]


loss_avg()=0.4494189163049062
accuracy_avg()=0.88125


## Final Test Accuracy

In [63]:
# Simple check with test dataset
model.eval()
test_data_iterator = data_iterator(X_test_tokenized_indexified, y_test_formatted, len(X_test_tokenized_indexified), len(X_test_tokenized_indexified), shuffle=False)
test_batch, labels_batch, test_sentences = next(test_data_iterator)

seq_lengths = torch.LongTensor(list(map(len, test_batch)))
output_batch = model(test_batch.to(device),seq_lengths)
final_test_accuracy = accuracy(output_batch, labels_batch.to(device))
print(f"{final_test_accuracy=}")

final_test_accuracy=0.866


In [64]:
def print_sentence_label(sentence: str) -> int:
    model.eval()
    sentence_tokenized = word_tokenize(sentence.lower())
    sentence_as_id = [
        vocab[token] if token in vocab
        else vocab['UNK']
        for token in sentence_tokenized
    ]
    seq_lengths = torch.LongTensor([len(sentence_as_id)])
    input = torch.tensor(sentence_as_id).unsqueeze(0).to(device)
    output = model(input, seq_lengths).to(device)
    label = np.argmax(output.detach().cpu().numpy())
    print(f"sentence = {sentence}, label = {label}")

# Checking results
print_sentence_label("What is a squirrel?")
print_sentence_label("Is Singapore located in Southeast Asia?")
print_sentence_label("Is Singapore in China?")
print_sentence_label("Name 11 famous martyrs .")
print_sentence_label("What ISPs exist in the Caribbean ?")
print_sentence_label("How many cars are manufactured every day?")

sentence = What is a squirrel?, label = 1
sentence = Is Singapore located in Southeast Asia?, label = 3
sentence = Is Singapore in China?, label = 1
sentence = Name 11 famous martyrs ., label = 4
sentence = What ISPs exist in the Caribbean ?, label = 1
sentence = How many cars are manufactured every day?, label = 4
