In [176]:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import trange
from torch.autograd import Variable

In [177]:
words_path = "stanford_cs230_data/words.txt"

# The vocabulary file holds one token per line; a token's id is its
# zero-based line number (a repeated token keeps the id of its last line).
with open(words_path) as f:
    vocab = {word: index for index, word in enumerate(f.read().splitlines())}
vocab

{'Thousands': 0,
 'of': 1,
 'demonstrators': 2,
 'have': 3,
 'marched': 4,
 'through': 5,
 'London': 6,
 'to': 7,
 'protest': 8,
 'the': 9,
 'war': 10,
 'in': 11,
 'Iraq': 12,
 'and': 13,
 'demand': 14,
 'withdrawal': 15,
 'British': 16,
 'troops': 17,
 'from': 18,
 'that': 19,
 'country': 20,
 '.': 21,
 'Families': 22,
 'soldiers': 23,
 'killed': 24,
 'conflict': 25,
 'joined': 26,
 'protesters': 27,
 'who': 28,
 'carried': 29,
 'banners': 30,
 'with': 31,
 'such': 32,
 'slogans': 33,
 'as': 34,
 '"': 35,
 'Bush': 36,
 'Number': 37,
 'One': 38,
 'Terrorist': 39,
 'Stop': 40,
 'Bombings': 41,
 'They': 42,
 'Houses': 43,
 'Parliament': 44,
 'a': 45,
 'rally': 46,
 'Hyde': 47,
 'Park': 48,
 'Police': 49,
 'put': 50,
 'number': 51,
 'marchers': 52,
 'at': 53,
 '10,000': 54,
 'while': 55,
 'organizers': 56,
 'claimed': 57,
 'it': 58,
 'was': 59,
 '1,00,000': 60,
 'The': 61,
 'comes': 62,
 'on': 63,
 'eve': 64,
 'annual': 65,
 'conference': 66,
 'Britain': 67,
 "'s": 68,
 'ruling': 69,
 'La

In [178]:
tags_path = "stanford_cs230_data/tags.txt"

# The tags file holds one NER tag per line; a tag's id is its
# zero-based line number.
with open(tags_path) as f:
    tag_map = {tag: index for index, tag in enumerate(f.read().splitlines())}
tag_map

{'O': 0, 'I-geo': 1, 'I-gpe': 2, 'I-per': 3, 'I-org': 4, 'I-tim': 5}

In [179]:
train_sentences_file = "stanford_cs230_data/train/sentences.txt"
train_labels_file = "stanford_cs230_data/train/labels.txt"

# Encode every line of the sentences file as a list of vocabulary ids;
# a token that is missing from vocab falls back to the id stored under 'UNK'.
with open(train_sentences_file) as f:
    train_sentences = [
        [vocab[token] if token in vocab else vocab['UNK'] for token in line.split(' ')]
        for line in f.read().splitlines()
    ]

# Encode every line of the labels file as a list of tag ids
# (a tag that is not in tag_map raises KeyError).
with open(train_labels_file) as f:
    train_labels = [
        [tag_map[tag] for tag in line.split(' ')]
        for line in f.read().splitlines()
    ]

print(f"train_sentences: {train_sentences}")
print(f"train_labels: {train_labels}")

train_sentences: [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 9, 15, 1, 16, 17, 18, 19, 20, 21], [22, 1, 23, 24, 11, 9, 25, 26, 9, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 35, 13, 35, 40, 9, 41, 21, 35], [42, 4, 18, 9, 43, 1, 44, 7, 45, 46, 11, 47, 48, 21], [49, 50, 9, 51, 1, 52, 53, 54, 55, 56, 57, 58, 59, 60, 21], [61, 8, 62, 63, 9, 64, 1, 9, 65, 66, 1, 67, 68, 69, 70, 71, 11, 9, 72, 73, 74, 75, 1, 76, 21], [61, 77, 78, 79, 80, 67, 68, 81, 11, 9, 12, 25, 13, 9, 82, 83, 1, 84, 16, 17, 11, 19, 20, 21], [61, 6, 85, 86, 87, 1, 88, 89, 90, 11, 91, 92, 93, 94, 95, 93, 96, 93, 13, 97, 21], [61, 98, 99, 100, 101, 78, 7, 102, 103, 104, 1, 105, 11, 106, 107, 63, 108, 7, 109, 7, 110, 68, 111, 1, 112, 113, 114, 21], [110, 115, 116, 117, 118, 1, 9, 114, 119, 53, 120, 121, 122, 123, 21], [124, 125, 126, 127, 128, 7, 129, 130, 7, 131, 132, 118, 1, 9, 123, 107, 93, 133, 134, 135, 136, 137, 138, 139, 21]]
train_labels: [[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0,

In [180]:
val_sentences_file = "stanford_cs230_data/val/sentences.txt"
val_labels_file = "stanford_cs230_data/val/labels.txt"

# Encode every line of the sentences file as a list of vocabulary ids;
# a token that is missing from vocab falls back to the id stored under 'UNK'.
with open(val_sentences_file) as f:
    val_sentences = [
        [vocab[token] if token in vocab else vocab['UNK'] for token in line.split(' ')]
        for line in f.read().splitlines()
    ]

# Encode every line of the labels file as a list of tag ids
# (a tag that is not in tag_map raises KeyError).
with open(val_labels_file) as f:
    val_labels = [
        [tag_map[tag] for tag in line.split(' ')]
        for line in f.read().splitlines()
    ]

print(f"val_sentences: {val_sentences}")
print(f"val_labels: {val_labels}")

val_sentences: [[140, 141, 59, 142, 11, 143, 13, 144, 145, 146, 11, 45, 147, 148, 93, 149, 150, 151, 63, 152, 153, 116, 21], [154, 45, 155, 7, 156, 157, 158, 159, 93, 9, 160, 161, 162, 163, 164, 165, 58, 59, 166, 167, 168, 58, 169, 147, 68, 35, 170, 171, 35, 172, 173, 174, 21], [61, 175, 176, 177, 13, 9, 178, 179, 180, 181, 182, 21], [183, 184, 125, 126, 185, 1, 186, 187, 183, 23, 28, 188, 189, 190, 134, 191, 63, 45, 192, 193, 3, 194, 195, 11, 196, 197, 21], [198, 126, 9, 199, 200, 201, 17, 202, 18, 9, 203, 204, 205, 133, 45, 206, 207, 191, 208, 9, 209, 210, 211, 115, 116, 21], [212, 213, 214, 215, 216, 217, 218, 219, 220, 19, 209, 221, 222, 80, 9, 17, 7, 9, 183, 223, 11, 224, 13, 9, 23, 225, 226, 227, 228, 7, 229, 21], [206, 230, 165, 127, 231, 232, 23, 233, 9, 191, 63, 9, 201, 234, 93, 235, 125, 236, 237, 238, 9, 239, 21], [240, 107, 93, 183, 125, 165, 232, 201, 23, 13, 53, 241, 242, 230, 225, 24, 11, 243, 11, 9, 20, 68, 244, 204, 205, 21], [61, 183, 184, 245, 246, 247, 248, 249, 190

In [181]:
def data_iterator(sentences, labels, total_size: int, batch_size: int, shuffle: bool=False, pad_index=None):
    """Yield padded (data, labels) batches covering the first `total_size` examples once.

    Every example is yielded exactly once per pass: the final batch may hold fewer
    than `batch_size` examples, and an empty batch is never produced.

    Args:
        sentences: sequence of sentences, each a list of integer token ids
        labels: sequence of tag-id lists, aligned with `sentences`
        total_size: number of leading examples of `sentences`/`labels` to iterate over
        batch_size: maximum number of sentences per batch (must be >= 1)
        shuffle: if True, visit the examples in a pseudo-random order. The order is
            drawn from a private generator seeded with 230, so it is the same on
            every call and the global `random` state is left untouched.
        pad_index: token id used to fill positions past the end of a sentence.
            Defaults to `vocab['<pad>']` from the notebook's global vocabulary.

    Yields:
        (batch_data, batch_labels): two int64 tensors of shape
        (sentences in batch, longest sentence in batch). Padding positions hold
        `pad_index` in `batch_data` and -1 in `batch_labels`.

    Raises:
        ValueError: if `batch_size` is smaller than 1 (raised when the first batch
            is requested, because this is a generator).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got {}".format(batch_size))
    if pad_index is None:
        pad_index = vocab['<pad>']

    # a list of indices decides the visiting order - this avoids shuffling the data itself
    order = list(range(total_size))
    if shuffle:
        # a private Random instance gives the same permutation as seeding the module-level
        # generator with 230, without resetting the global random state for other code
        random.Random(230).shuffle(order)

    # stepping by batch_size visits every index once and never creates an empty slice
    for start in range(0, total_size, batch_size):
        batch_indices = order[start:start + batch_size]
        batch_sentences = [sentences[idx] for idx in batch_indices]
        batch_tags = [labels[idx] for idx in batch_indices]

        # every row is padded up to the longest sentence of this batch
        batch_max_len = max(len(s) for s in batch_sentences)

        # integer arrays from the start: data pre-filled with the pad id, labels with -1
        # so that padding positions can be told apart from real tokens
        batch_shape = (len(batch_sentences), batch_max_len)
        batch_data = np.full(batch_shape, pad_index, dtype=np.int64)
        batch_labels = np.full(batch_shape, -1, dtype=np.int64)

        # copy each sentence and its tags into the left part of its row
        for row, (sentence, tags) in enumerate(zip(batch_sentences, batch_tags)):
            cur_len = len(sentence)
            batch_data[row, :cur_len] = sentence
            batch_labels[row, :cur_len] = tags

        # int64 arrays convert to long tensors directly; no Variable wrapper is needed
        # because tensors record operations in the computational graph themselves
        yield torch.from_numpy(batch_data), torch.from_numpy(batch_labels)

In [182]:
class Net(nn.Module):
    """
    Recurrent tagger that scores every token of a sentence against every tag.

    The layers are created in __init__ and applied one after another in forward:
    token ids -> embeddings -> LSTM outputs -> linear layer -> log-softmax over tags.

    PyTorch layer documentation: http://pytorch.org/docs/master/nn.html
    """

    def __init__(self, vocab_size, embedding_dim, lstm_hidden_dim, number_of_tags):
        """
        Create the three layers of the tagger.

        Args:
            vocab_size: (int) number of rows of the embedding table
            embedding_dim: (int) size of each token embedding, also the LSTM input size
            lstm_hidden_dim: (int) size of the LSTM output produced for each token
            number_of_tags: (int) number of scores produced for each token
        """
        super().__init__()

        # lookup table: token index -> embedding_dim vector
        self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                      embedding_dim=embedding_dim)

        # batch_first=True: the LSTM consumes and produces batch x seq_len x features
        self.lstm = nn.LSTM(input_size=embedding_dim,
                            hidden_size=lstm_hidden_dim,
                            batch_first=True)

        # maps the LSTM output of one token to one score per tag
        self.fc = nn.Linear(in_features=lstm_hidden_dim,
                            out_features=number_of_tags)

    def forward(self, s):
        """
        Compute per-token log-probabilities over the tags for a batch of sentences.

        Args:
            s: batch of token indices of dimension batch_size x seq_len, one sentence per row

        Returns:
            out: dimension batch_size*seq_len x num_tags, one row of log-probabilities per
                 token position, with sentences laid out one after another
        """
        # batch_size x seq_len -> batch_size x seq_len x embedding_dim
        embedded = self.embedding(s)

        # batch_size x seq_len x lstm_hidden_dim (the final hidden/cell states are not used)
        lstm_out, _ = self.lstm(embedded)

        # flatten to one row per token: batch_size*seq_len x lstm_hidden_dim
        # (view needs contiguous memory, hence the contiguous() call first)
        hidden_dim = lstm_out.shape[2]
        per_token = lstm_out.contiguous().view(-1, hidden_dim)

        # batch_size*seq_len x num_tags, unnormalised scores
        tag_scores = self.fc(per_token)

        # log-softmax over the tag dimension (numerically more stable than log(softmax))
        return F.log_softmax(tag_scores, dim=1)

In [183]:
def loss_fn(outputs, labels):
    """
    Mean negative log-likelihood of the correct tags, ignoring padding positions.

    Args:
        outputs: dimension batch_size*seq_len x num_tags - log softmax output of the model
        labels: dimension batch_size x seq_len - each element is a tag index in
                [0, num_tags-1], or -1 for a padding position

    Returns:
        loss: scalar tensor, the summed negative log-probability of the labelled tags
              divided by the number of non-padding tokens
    """
    # one label per row of outputs
    flat_labels = labels.view(-1)

    # 1.0 for real tokens, 0.0 for padding (label -1)
    keep = (flat_labels >= 0).float()

    # -1 cannot be used as a column index; the modulo maps it onto a valid column, and
    # whatever value is picked there is zeroed out by the mask below
    num_tags = outputs.shape[1]
    column_index = flat_labels % num_tags

    token_count = int(torch.sum(keep))

    # log-probability assigned to the label of every position
    row_index = torch.arange(outputs.shape[0], device=outputs.device)
    picked = outputs[row_index, column_index]

    return -torch.sum(picked * keep) / token_count

In [184]:
class RunningAverage:
    """Accumulates values and reports their arithmetic mean on demand.

    Example:
    ```
    loss_avg = RunningAverage()
    loss_avg.update(2)
    loss_avg.update(4)
    loss_avg()  # -> 3.0
    ```
    """

    def __init__(self):
        # sum of all values seen so far, and how many there were
        self.total = 0
        self.steps = 0

    def update(self, val):
        """Add one more value to the average."""
        self.steps += 1
        self.total += val

    def __call__(self):
        """Return the mean of all values passed to update() so far."""
        count = float(self.steps)
        return self.total / count

In [185]:
def train(model, optimizer, loss_fn, data_iterator, metrics, num_steps):
    """Run `num_steps` optimisation steps, one batch per step, and print averaged metrics.

    Args:
        model: (torch.nn.Module) the neural network
        optimizer: (torch.optim) optimizer for parameters of model
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        data_iterator: (generator) a generator that generates batches of data and labels
        metrics: (dict) maps a metric name to a function of (output array, labels array)
        num_steps: (int) number of batches to pull from data_iterator and train on
    """

    # switch the model to training mode
    model.train()

    # per-batch metric records (filled every 10th step) and the running mean of the loss
    batch_summaries = []
    running_loss = RunningAverage()

    # tqdm progress bar over the steps
    progress = trange(num_steps)
    for step in progress:
        inputs, targets = next(data_iterator)

        # forward pass and loss
        predictions = model(inputs)
        loss = loss_fn(predictions, targets)

        # reset old gradients, backpropagate, then update the parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # metrics are only computed on every 10th batch
        if step % 10 == 0:
            # metric functions work on numpy arrays living on the cpu
            predictions_np = predictions.data.cpu().numpy()
            targets_np = targets.data.cpu().numpy()

            record = {name: metrics[name](predictions_np, targets_np)
                      for name in metrics}
            record['loss'] = loss.item()
            batch_summaries.append(record)

        # the progress bar shows the mean loss over all steps so far
        running_loss.update(loss.item())
        progress.set_postfix(loss='{:05.3f}'.format(running_loss()))

    # average every recorded quantity over the sampled batches
    recorded_names = batch_summaries[0]
    metrics_mean = {name: np.mean([record[name] for record in batch_summaries])
                    for name in recorded_names}
    metrics_string = " ; ".join("{}: {:05.3f}".format(name, value)
                                for name, value in metrics_mean.items())
    print("- Train metrics: " + metrics_string)

In [186]:
def evaluate(model, loss_fn, data_iterator, metrics, num_steps):
    """Evaluate the model on `num_steps` batches and print/return the averaged metrics.

    The forward passes run under `torch.no_grad()`: evaluation never calls
    `backward()`, so no autograd graph needs to be built or kept in memory.

    Args:
        model: (torch.nn.Module) the neural network
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        data_iterator: (generator) a generator that generates batches of data and labels
        metrics: (dict) maps a metric name to a function of (output array, labels array)
        num_steps: (int) number of batches to pull from data_iterator and evaluate on

    Returns:
        (dict) mean over the evaluated batches of every metric plus 'loss';
        an empty dict when `num_steps` is 0.
    """

    # set model to evaluation mode
    model.eval()

    # one metrics record per evaluated batch
    summ = []

    with torch.no_grad():
        for _ in range(num_steps):
            # fetch the next evaluation batch
            data_batch, labels_batch = next(data_iterator)

            # compute model output and loss
            output_batch = model(data_batch)
            loss = loss_fn(output_batch, labels_batch)

            # metric functions work on numpy arrays living on the cpu
            output_batch = output_batch.detach().cpu().numpy()
            labels_batch = labels_batch.detach().cpu().numpy()

            # compute all metrics on this batch
            summary_batch = {metric: metrics[metric](output_batch, labels_batch)
                             for metric in metrics}
            summary_batch['loss'] = loss.item()
            summ.append(summary_batch)

    # average every recorded quantity; with no batches there is nothing to average
    recorded = summ[0] if summ else {}
    metrics_mean = {metric: np.mean([x[metric] for x in summ]) for metric in recorded}
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_mean.items())
    print("- Eval metrics : " + metrics_string)
    return metrics_mean

In [187]:
def train_and_evaluate(
        model,
        train_sentences,
        train_labels,
        val_sentences,
        val_labels,
        num_epochs: int,
        batch_size: int,
        optimizer,
        loss_fn,
        metrics
):
    """Alternate one training pass and one validation pass for `num_epochs` epochs.

    Args:
        model: (torch.nn.Module) the neural network, updated in place
        train_sentences, train_labels: encoded training sentences and their tag ids
        val_sentences, val_labels: encoded validation sentences and their tag ids
        num_epochs: (int) number of train + validate rounds
        batch_size: (int) batch size used for both splits
        optimizer: (torch.optim) optimizer for parameters of model
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        metrics: (dict) maps a metric name to a function of (output array, labels array)
    """
    train_size = len(train_sentences)
    val_size = len(val_sentences)

    for epoch_number in range(1, num_epochs + 1):
        print("Epoch {}/{}".format(epoch_number, num_epochs))

        # training pass: a fresh shuffled iterator every epoch
        train_steps = (train_size + 1) // batch_size
        train_batches = data_iterator(
            train_sentences, train_labels, train_size, batch_size, shuffle=True)
        train(model, optimizer, loss_fn, train_batches, metrics, train_steps)

        # validation pass: data is visited in file order
        val_steps = (val_size + 1) // batch_size
        val_batches = data_iterator(
            val_sentences, val_labels, val_size, batch_size, shuffle=False)
        evaluate(model, loss_fn, val_batches, metrics, val_steps)

In [188]:
def accuracy(outputs, labels):
    """
    Fraction of non-padding tokens whose highest-scoring tag equals the label.

    Args:
        outputs: (np.ndarray) dimension batch_size*seq_len x num_tags - log softmax output of the model
        labels: (np.ndarray) dimension batch_size x seq_len - each element is a tag index in
                [0, num_tags-1], or -1 for a padding position

    Returns: (float) accuracy in [0,1]
    """
    # one label per row of outputs
    flat_labels = labels.ravel()

    # padding positions carry label -1, so only non-negative labels are real tokens
    real_token_count = np.sum(flat_labels >= 0)

    # predicted tag = column with the highest score in each row
    predicted = np.argmax(outputs, axis=1)

    # a prediction is never -1, so padding positions never count as correct
    correct_count = np.sum(predicted == flat_labels)
    return correct_count / float(real_token_count)


# name -> metric function; further metrics (e.g. per-tag accuracy) can be added here
metrics = {
    'accuracy': accuracy,
}

In [189]:
# Hyperparameters of this toy run
EMBEDDING_DIM = 50
LSTM_HIDDEN_DIM = 50
LEARNING_RATE = 0.01
NUM_EPOCHS = 10
BATCH_SIZE = 5

# Derive the layer sizes from the loaded mappings instead of hard-coding them, so the
# model always matches the data files. "largest id + 1" is used rather than len() so
# every id that can appear in a batch is guaranteed to be a valid embedding row / tag column.
VOCAB_SIZE = max(vocab.values()) + 1
NUMBER_OF_TAGS = max(tag_map.values()) + 1

# fix the weight initialisation so that re-running the notebook reproduces the same numbers
torch.manual_seed(230)

model = Net(VOCAB_SIZE, EMBEDDING_DIM, LSTM_HIDDEN_DIM, NUMBER_OF_TAGS)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
train_and_evaluate(model, train_sentences, train_labels, val_sentences, val_labels,
                   NUM_EPOCHS, BATCH_SIZE, optimizer, loss_fn, metrics)

Epoch 1/10


100%|██████████| 2/2 [00:00<00:00, 250.08it/s, loss=1.788]


- Train metrics: accuracy: 0.025 ; loss: 1.860
- Eval metrics : accuracy: 0.695 ; loss: 1.586
Epoch 2/10


100%|██████████| 2/2 [00:00<00:00, 333.32it/s, loss=1.380]


- Train metrics: accuracy: 0.849 ; loss: 1.481
- Eval metrics : accuracy: 0.802 ; loss: 1.143
Epoch 3/10


100%|██████████| 2/2 [00:00<00:00, 399.88it/s, loss=0.812]


- Train metrics: accuracy: 0.857 ; loss: 0.896
- Eval metrics : accuracy: 0.822 ; loss: 0.763
Epoch 4/10


100%|██████████| 2/2 [00:00<00:00, 333.33it/s, loss=0.625]


- Train metrics: accuracy: 0.857 ; loss: 0.581
- Eval metrics : accuracy: 0.822 ; loss: 0.798
Epoch 5/10


100%|██████████| 2/2 [00:00<00:00, 333.26it/s, loss=0.597]


- Train metrics: accuracy: 0.857 ; loss: 0.591
- Eval metrics : accuracy: 0.822 ; loss: 0.800
Epoch 6/10


100%|██████████| 2/2 [00:00<00:00, 333.36it/s, loss=0.515]


- Train metrics: accuracy: 0.857 ; loss: 0.539
- Eval metrics : accuracy: 0.822 ; loss: 0.774
Epoch 7/10


100%|██████████| 2/2 [00:00<00:00, 333.36it/s, loss=0.433]


- Train metrics: accuracy: 0.857 ; loss: 0.462
- Eval metrics : accuracy: 0.822 ; loss: 0.769
Epoch 8/10


100%|██████████| 2/2 [00:00<00:00, 333.29it/s, loss=0.389]


- Train metrics: accuracy: 0.857 ; loss: 0.411
- Eval metrics : accuracy: 0.818 ; loss: 0.787
Epoch 9/10


100%|██████████| 2/2 [00:00<00:00, 399.95it/s, loss=0.359]


- Train metrics: accuracy: 0.874 ; loss: 0.385
- Eval metrics : accuracy: 0.810 ; loss: 0.787
Epoch 10/10


100%|██████████| 2/2 [00:00<00:00, 400.09it/s, loss=0.309]

- Train metrics: accuracy: 0.891 ; loss: 0.340
- Eval metrics : accuracy: 0.814 ; loss: 0.784





In [190]:
test_sentences_file = "stanford_cs230_data/test/sentences.txt"
test_labels_file = "stanford_cs230_data/test/labels.txt"

# Encode every line of the sentences file as a list of vocabulary ids;
# a token that is missing from vocab falls back to the id stored under 'UNK'.
with open(test_sentences_file) as f:
    test_sentences = [
        [vocab[token] if token in vocab else vocab['UNK'] for token in line.split(' ')]
        for line in f.read().splitlines()
    ]

# Encode every line of the labels file as a list of tag ids
# (a tag that is not in tag_map raises KeyError).
with open(test_labels_file) as f:
    test_labels = [
        [tag_map[tag] for tag in line.split(' ')]
        for line in f.read().splitlines()
    ]

print(f"test_sentences: {test_sentences}")
print(f"test_labels: {test_labels}")

test_sentences: [[266, 9, 267, 1, 268, 269, 11, 270, 93, 271, 272, 273, 274, 275, 165, 276, 78, 277, 167, 9, 6, 278, 21], [279, 280, 281, 282, 283, 237, 284, 285, 21], [286, 287, 278, 63, 9, 288, 289, 137, 11, 290, 24, 291, 292, 13, 293, 294, 21], [295, 296, 297, 298, 245, 169, 63, 299, 300, 7, 301, 302, 63, 303, 7, 304, 305, 306, 13, 307, 296, 297, 308, 21], [61, 309, 310, 311, 312, 303, 165, 125, 18, 9, 313, 1, 314, 299, 315, 316, 115, 116, 317, 318, 262, 319, 11, 320, 31, 303, 21], [321, 165, 300, 317, 318, 322, 45, 323, 324, 63, 303, 167, 9, 325, 315, 326, 327, 21], [61, 298, 180, 328, 310, 300, 7, 329, 9, 330, 192, 331, 332, 167, 303, 68, 184, 333, 21], [61, 297, 298, 334, 303, 68, 335, 1, 336, 11, 337, 338, 339, 13, 296, 297, 308, 93, 340, 341, 342, 343, 344, 11, 303, 21], [345, 125, 126, 346, 3, 24, 45, 347, 1, 9, 348, 349, 350, 167, 351, 345, 352, 353, 354, 355, 21], [198, 126, 356, 357, 59, 358, 359, 133, 346, 360, 150, 361, 362, 9, 363, 364, 1, 365, 21]]
test_labels: [[0, 0, 

In [191]:
print(len(test_sentences))

# Take the first one-sentence batch of a shuffled pass over the test split
# and run it through the trained model.
test_data_iterator = data_iterator(
    test_sentences, test_labels, total_size=len(test_sentences), batch_size=1, shuffle=True)
test_batch, labels_batch = next(test_data_iterator)
model_output = model(test_batch)
print(model_output)
print(model_output.shape)

10
tensor([[-1.1183, -1.8462, -1.9015, -2.2424, -1.7853, -2.3856],
        [-0.0347, -4.0720, -4.6996, -6.4464, -5.3254, -6.4676],
        [-0.0620, -3.1846, -4.5866, -6.2881, -5.3532, -6.2509],
        [-0.0664, -3.1607, -4.5125, -6.1850, -5.0840, -5.9278],
        [-0.1254, -2.8633, -3.4542, -5.1714, -4.1124, -4.9431],
        [-0.3470, -1.8092, -2.9626, -4.4258, -3.2586, -3.5999],
        [-0.0321, -3.9309, -4.9840, -7.1018, -5.8339, -6.5657],
        [-0.0643, -3.3626, -4.2797, -6.2738, -4.7983, -5.6078],
        [-0.0248, -4.0293, -5.5215, -7.9063, -6.3390, -7.4565],
        [-0.0183, -4.4299, -5.4817, -8.1877, -6.6653, -7.6226],
        [-0.1449, -2.3843, -3.7928, -5.6601, -4.6798, -4.8985],
        [-0.1270, -2.8042, -3.2322, -5.5873, -4.5766, -5.2498],
        [-0.4357, -1.6742, -2.7314, -4.0512, -3.0342, -3.3497],
        [-0.0762, -3.2254, -3.9211, -5.7970, -4.7833, -6.0405],
        [-0.0490, -3.7194, -4.3494, -6.2713, -4.9353, -6.4255],
        [-0.2149, -2.4598, -2.8075, -

In [197]:
print(f"model_output \n {model_output}")
# Each row of model_output holds log-probabilities (the model ends in log_softmax),
# so the plain argmax of a row is the most likely tag for that token; taking the
# absolute value first would instead select the least likely tag.
predicted_labels = model_output.detach().numpy().argmax(axis=1)
print(f"predicted_labels \n {predicted_labels}")

model_output 
 tensor([[-1.1183, -1.8462, -1.9015, -2.2424, -1.7853, -2.3856],
        [-0.0347, -4.0720, -4.6996, -6.4464, -5.3254, -6.4676],
        [-0.0620, -3.1846, -4.5866, -6.2881, -5.3532, -6.2509],
        [-0.0664, -3.1607, -4.5125, -6.1850, -5.0840, -5.9278],
        [-0.1254, -2.8633, -3.4542, -5.1714, -4.1124, -4.9431],
        [-0.3470, -1.8092, -2.9626, -4.4258, -3.2586, -3.5999],
        [-0.0321, -3.9309, -4.9840, -7.1018, -5.8339, -6.5657],
        [-0.0643, -3.3626, -4.2797, -6.2738, -4.7983, -5.6078],
        [-0.0248, -4.0293, -5.5215, -7.9063, -6.3390, -7.4565],
        [-0.0183, -4.4299, -5.4817, -8.1877, -6.6653, -7.6226],
        [-0.1449, -2.3843, -3.7928, -5.6601, -4.6798, -4.8985],
        [-0.1270, -2.8042, -3.2322, -5.5873, -4.5766, -5.2498],
        [-0.4357, -1.6742, -2.7314, -4.0512, -3.0342, -3.3497],
        [-0.0762, -3.2254, -3.9211, -5.7970, -4.7833, -6.0405],
        [-0.0490, -3.7194, -4.3494, -6.2713, -4.9353, -6.4255],
        [-0.2149, -2.4598