In [1]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import wandb
import seaborn as sns
import pickle
import itertools
import yaml
from torch.utils.data import DataLoader
from logger import Logger

from seqeval.metrics import classification_report
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import f1_score as f1

from warnings import simplefilter

from dataset_simpler import CustomSimplerDataset, NERSimplerDocuments
import baseline_lstm_model
import lstm_mha_attn_model
import bilstm_model
import bilstmcrf_model

with open('config.yaml', 'r') as file:
        config = yaml.safe_load(file)

config_settings = config['model_settings']

VALID_BATCH_SIZE = config_settings['batch_size']
test_sample_frac = config_settings['test_sample_frac']

learning_rate = config_settings['lr']
decay = config_settings['decay']

test_params = {'batch_size': VALID_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0,
                'drop_last': True,
                'pin_memory': True
                }

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
    
print(f'Device: {device}')

Device: cuda


## Loading Test Data

In [2]:
dataset = NERSimplerDocuments()
vocab = dataset.get_vocab()
labels_to_id = dataset.get_labels_to_id()
ids_to_labels = dict(map(reversed, labels_to_id.items()))

test_data = dataset.load_test_data()

testing_set = CustomSimplerDataset(test_data, labels_to_id, vocab, test_sample_frac)
testing_loader = DataLoader(testing_set, **test_params)

tokens_to_id = {ch:i for i,ch in enumerate(vocab)}
id_to_tokens = dict(map(reversed, tokens_to_id.items()))

## Test Function

In [3]:
def test(model, testing_loader, ids_to_labels):
    with torch.no_grad():

        accuracies = []
        sentences = []
        y_true = []
        y_pred = []
        f1_scores = []

        model.eval()

        print(f'Testing {int(len(testing_loader.dataset)/VALID_BATCH_SIZE)} batches')

        for idx, batch in enumerate(testing_loader):

            inputs = batch["tokens"]
            targets = batch["labels"]
            attention_mask = batch["attention_mask"]

            inputs = inputs.to(device)
            targets = targets.to(device)
    
            logits = model(inputs, attention_mask, targets)

            softmax = torch.nn.Softmax(dim=0)
            predictions = softmax(logits)

            test_predictions = torch.argmax(predictions, dim = 1)

            f1_score = f1(targets.flatten().cpu().detach().numpy(), test_predictions.flatten().cpu().detach().numpy(), average='weighted')

            sentences.append(inputs)
            y_true.append(targets)
            y_pred.append(test_predictions)
            f1_scores.append(f1_score)

            accuracy = torch.sum(torch.eq(test_predictions, targets)).item()/test_predictions.nelement()

            accuracies.append(accuracy)
    
    # if logger != '':
    #     logger.log({'test_accuracy': np.sum(accuracies) / len(accuracies),
    #                 'test_f1_score': np.sum(f1_scores) / len(f1_scores)})

    print(f'Avg f1 score: {np.sum(f1_scores) / len(f1_scores)}')

    y_true_labels = []
    y_pred_labels = []

    for labels in y_true:
        y_true_labels.append([[ids_to_labels.get(np.int64(label.cpu().item())) for label in tens_labels] for tens_labels in labels])

    for labels in y_pred:
        y_pred_labels.append([[ids_to_labels.get(np.int64(label.cpu().item())) for label in tens_labels] for tens_labels in labels])

    matrix = confusion_matrix(list(itertools.chain.from_iterable(y_true_labels[0])), 
                              list(itertools.chain.from_iterable(y_pred_labels[0])), 
                              labels=list(ids_to_labels.values()))

    cm = ConfusionMatrixDisplay(matrix/np.sum(matrix), display_labels=list(ids_to_labels.values()))
    print(f'Confusion Matrix:\n{cm}')

    fig, ax = plt.subplots(figsize=(10, 10))
    cm.plot(ax=ax, cmap=plt.cm.Blues)
    plt.xticks(rotation=90)
    plt.tight_layout()
    plt.show()
    plt.close()

    report = classification_report(y_true_labels[0], y_pred_labels[0], output_dict = True, zero_division = 0)
    print(f'classification_report:\n{classification_report(y_true_labels[0], y_pred_labels[0], output_dict = False, zero_division = 0)}')

    plt.figure(figsize = (30, 15))
    ax = sns.heatmap(pd.DataFrame(report).iloc[:-1, :].T, cmap = 'coolwarm', annot=True)
    plt.tight_layout()
    plt.show()
    plt.close()

    # if logger != '':
    #     logger.log({'classification_report': wandb.Image(ax.figure),
    #                 'confusion_matrix': wandb.Image(fig.figure)})
    
    return sentences, y_true, y_pred

## Baseline LSTM Model

In [4]:
with open(f'./trained_models/baseline_lstm.pkl', 'rb') as file:
    baseline_lstm_model = pickle.load(file)

baseline_lstm_model.eval()

loss_function = torch.nn.CrossEntropyLoss(ignore_index=labels_to_id['[PAD]'])
optimiser = torch.optim.SGD(baseline_lstm_model.parameters(), lr = learning_rate, momentum = 0.9, weight_decay = decay)

baseline_lstm_sentences, baseline_lstm_y_true, baseline_lstm_y_pred = test(baseline_lstm_model, testing_loader, ids_to_labels)

Testing 129 batches


  label_ids = [self.labels_to_id[label] for label in labels_as_lst[index]]
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  label_ids = [self.labels_to_id[label] for label in labels_as_lst[index]]
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  label_ids = [self.labels_to_id[label] for label in labels_as_lst[index]]
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  label_ids = [self.labels_to_id[label] for label in labels_as_lst[index]]
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  label_ids = [self.labels_to_id[label] for label in labels_as_lst[index]]
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  label_ids = [self.labels_to_id[label] for label in labels_as_lst[index]]
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  label_ids = [self.labels_to_id[label] for label in labels_as_lst[index]]
 

Avg f1 score: 0.273988587399842
Confusion Matrix:
<sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay object at 0x0000021582FEF3D0>
classification_report:
              precision    recall  f1-score   support

    CARDINAL       0.00      0.29      0.01         7
        CLS]       0.00      0.00      0.00         0
        DATE       0.00      0.00      0.00        14
         FAC       0.00      0.00      0.00         1
         GPE       0.00      0.00      0.00        13
    LANGUAGE       0.00      0.00      0.00         0
         LAW       0.00      0.00      0.00         1
         LOC       0.00      0.00      0.00         1
       MONEY       0.00      0.00      0.00         5
        NORP       0.00      0.00      0.00         4
     ORDINAL       0.00      0.00      0.00         0
         ORG       0.00      0.00      0.00        14
        PAD]       0.00      0.00      0.00         0
     PERCENT       0.00      0.00      0.00         0
      PERSON       0.00



In [11]:
print(f'Original Sentence:{baseline_lstm_sentences[0][0].cpu().detach().numpy()} \nLabels{baseline_lstm_y_true[0][0].cpu().detach().numpy()} \nPredictions{baseline_lstm_y_pred[0][0].cpu().detach().numpy()}')

Original Sentence:[34093   631    31 29771    31  3081  3082  4943   478  3962   158    21
     8 34094 34095 34095 34095 34095 34095 34095 34095 34095 34095 34095
 34095 34095 34095 34095 34095 34095 34095 34095 34095 34095 34095 34095
 34095 34095 34095 34095 34095 34095 34095 34095 34095 34095 34095 34095
 34095 34095] 
Labels[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0] 
Predictions[ 8  8  8  8  8  8  8  8  8  8  7 39  9 39 39 39  9 39 39 39 10 10  9 10
 10 10  9 10 10 10  9  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  8
  8  8]


## LSTM with Multi Head Attention

In [12]:
with open(f'./trained_models/LSTM_MHA.pkl', 'rb') as file:
    lstm_attn_model = pickle.load(file)

lstm_attn_model.eval()

loss_function = torch.nn.CrossEntropyLoss(ignore_index=labels_to_id['[PAD]'])
optimiser = torch.optim.SGD(lstm_attn_model.parameters(), lr = learning_rate, momentum = 0.9, weight_decay = decay)

lstm_attn_sentences, lstm_attn_y_true, lstm_attn_y_pred = test(lstm_attn_model, testing_loader, ids_to_labels)

Testing 129 batches


  label_ids = [self.labels_to_id[label] for label in labels_as_lst[index]]
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  label_ids = [self.labels_to_id[label] for label in labels_as_lst[index]]
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  label_ids = [self.labels_to_id[label] for label in labels_as_lst[index]]
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  label_ids = [self.labels_to_id[label] for label in labels_as_lst[index]]
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  label_ids = [self.labels_to_id[label] for label in labels_as_lst[index]]
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  label_ids = [self.labels_to_id[label] for label in labels_as_lst[index]]
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  label_ids = [self.labels_to_id[label] for label in labels_as_lst[index]]
 

Avg f1 score: 0.9263999438023569
Confusion Matrix:
<sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay object at 0x000002158D585F90>
classification_report:
              precision    recall  f1-score   support

    CARDINAL       0.00      0.00      0.00         8
        DATE       0.00      0.00      0.00        14
         FAC       0.00      0.00      0.00         1
         GPE       0.00      0.00      0.00        24
         LOC       0.00      0.00      0.00         4
       MONEY       0.00      0.00      0.00         2
        NORP       0.00      0.00      0.00        12
         ORG       0.00      0.00      0.00        13
     PERCENT       0.00      0.00      0.00         5
      PERSON       0.00      0.00      0.00        11
    QUANTITY       0.00      0.00      0.00         6

   micro avg       0.00      0.00      0.00       100
   macro avg       0.00      0.00      0.00       100
weighted avg       0.00      0.00      0.00       100



In [13]:
print(f'Original Sentence:{lstm_attn_sentences[0][0].cpu().detach().numpy()} \nLabels{lstm_attn_y_true[0][0].cpu().detach().numpy()} \nPredictions{lstm_attn_y_pred[0][0].cpu().detach().numpy()}')

Original Sentence:[34093   784 23055   143   250   451   515   216     8 34094 34095 34095
 34095 34095 34095 34095 34095 34095 34095 34095 34095 34095 34095 34095
 34095 34095 34095 34095 34095 34095 34095 34095 34095 34095 34095 34095
 34095 34095 34095 34095 34095 34095 34095 34095 34095 34095 34095 34095
 34095 34095] 
Labels[0 0 7 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0] 
Predictions[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0]


## Bi-directional LSTM model

In [14]:
with open(f'./trained_models/bilstm.pkl', 'rb') as file:
    bilstm_model = pickle.load(file)

bilstm_model.eval()

loss_function = torch.nn.CrossEntropyLoss(ignore_index=labels_to_id['[PAD]'])
optimiser = torch.optim.SGD(bilstm_model.parameters(), lr = learning_rate, momentum = 0.9, weight_decay = decay)

bilstm_sentences, bilstm_y_true, bilstm_y_pred = test(bilstm_model, testing_loader, ids_to_labels)

Testing 129 batches


  label_ids = [self.labels_to_id[label] for label in labels_as_lst[index]]
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  label_ids = [self.labels_to_id[label] for label in labels_as_lst[index]]
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  label_ids = [self.labels_to_id[label] for label in labels_as_lst[index]]
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  label_ids = [self.labels_to_id[label] for label in labels_as_lst[index]]
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  label_ids = [self.labels_to_id[label] for label in labels_as_lst[index]]
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  label_ids = [self.labels_to_id[label] for label in labels_as_lst[index]]
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  label_ids = [self.labels_to_id[label] for label in labels_as_lst[index]]
 

Avg f1 score: 0.5754528211991176
Confusion Matrix:
<sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay object at 0x0000021596187A10>
classification_report:
              precision    recall  f1-score   support

    CARDINAL       0.00      0.00      0.00         8
        CLS]       0.00      0.00      0.00         0
        DATE       0.00      0.00      0.00        11
       EVENT       0.00      0.00      0.00         0
         FAC       0.00      0.00      0.00         6
         GPE       0.00      0.00      0.00        18
    LANGUAGE       0.00      0.00      0.00         0
         LAW       0.00      0.00      0.00         1
         LOC       0.00      0.00      0.00         0
       MONEY       0.00      0.00      0.00         1
        NORP       0.00      0.00      0.00         7
     ORDINAL       0.00      0.00      0.00         1
         ORG       0.00      0.00      0.00        17
        PAD]       0.00      0.00      0.00         0
     PERCENT       0.0



In [15]:
print(f'Original Sentence:{bilstm_sentences[0][0].cpu().detach().numpy()} \nLabels{bilstm_y_true[0][0].cpu().detach().numpy()} \nPredictions{bilstm_y_pred[0][0].cpu().detach().numpy()}')

Original Sentence:[34093  1398  4377  4378  1935  6332  6333    79  3137    31  7527     8
    82   849   172    88  7054    31    95  1493   180   279    26  6466
   180   505   231    26  5231  1398    31   180   338   451    10    97
   439     8   106 34094 34095 34095 34095 34095 34095 34095 34095 34095
 34095 34095] 
Labels[ 0 11 12 12  0  4  5  0  6  0  7  8  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0 11 12 12  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0] 
Predictions[15 15 15 15 15 15 15 15 15 15 26 29 26 26 26 29 26 26 26 26 29 28 29 29
 29 28 29 29 29 29 28 23 28 28 28 28 28 28 28 28 23 14 14 14 14 14 14 14
 14 14]
