In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
from sklearn.metrics import confusion_matrix
import seaborn as sn
from pathlib import Path


import slayerSNN as snn

from models.snn import SlayerMLP
from utils.utils import letters
from utils.train_utils import get_datasets

In [2]:
# --- Evaluation configuration ------------------------------------------------
# Which data fold and which saved checkpoint variant ('bestLoss', ...) to evaluate.
fold_number = 1
which_model = 'bestLoss'

# Input data location and output directory for figures (relative paths).
data_dir = 'data/preprocessed/'
save_dir = 'analysis/imgs/'

# NOTE(review): GPU index 2 is hardcoded; adjust for the machine this runs on.
device = torch.device('cuda:2')

# SLAYER network parameters loaded from the YAML config file.
network_config = 'configs/network.yml'
params = snn.params(network_config)

In [3]:
# Load the trained checkpoint for this fold and put the network in eval mode.
input_size = 160  # input size of the first dense layer (matches the checkpoint)
output_size = len(letters)  # one output neuron per entry in `letters`
net = SlayerMLP(params, input_size, output_size).to(device)

checkpoint_path = Path('checkpoints') / f'model_{fold_number}_{which_model}.pt'
# map_location loads the tensors directly onto `device`, so a checkpoint saved
# from a different GPU (or one not present on this machine) still loads.
state_dict = torch.load(checkpoint_path, map_location=device)
net.load_state_dict(state_dict)
net.eval()

SlayerMLP(
  (slayer): spikeLayer()
  (fc): _denseLayer(160, 27, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
)

In [5]:
# Recording trial selected for this evaluation; passed to get_datasets.
trial_number = 1

In [7]:
# Build the data splits for this fold/trial; the six returned objects are
# unpacked as (dataset, loader) pairs for train, validation and test.
(
    train_dataset, train_loader,
    val_dataset, val_loader,
    test_dataset, test_loader,
) = get_datasets(
    data_dir,
    fold_number,
    output_size,
    trial_number,
    batch_size=32,
    test=True,
)

train
validation
test


In [8]:
# Total number of samples across all three splits.
sum(len(split) for split in (train_dataset, test_dataset, val_dataset))

1350

In [9]:
# Average number of test samples per class (output_size == len(letters) classes).
len(test_dataset) / output_size

10.0

In [10]:
# Re-run evaluation: accuracy plus true/predicted labels for the confusion matrices.
def calculate_loss_acc(_dataset, _loader, model=None, dev=None):
    """Evaluate a network on one data split and collect its predictions.

    Despite the name, no loss is computed; only accuracy and labels are returned.

    Args:
        _dataset: The split's dataset. Only ``len(_dataset)`` is used, as the
            accuracy denominator.
        _loader: Iterable yielding ``(data, target, label)`` batches for the split.
            ``label`` must be a CPU tensor (it is converted with ``.numpy()``).
        model: Network to evaluate. Defaults to the notebook-level ``net``.
        dev: Device the input batches are moved to. Defaults to the
            notebook-level ``device``.

    Returns:
        ``(acc_value, true_labels, pred_labels)``:
        - ``acc_value`` is the number of correct predictions divided by
          ``len(_dataset)``, or ``nan`` if the dataset is empty.
        - ``true_labels`` and ``pred_labels`` are the per-sample true and predicted
          class indices concatenated over all batches. Both are empty int64 arrays
          if the loader yields nothing.
    """
    model = net if model is None else model
    dev = device if dev is None else dev

    correct = 0
    true_batches = []
    pred_batches = []
    with torch.no_grad():
        # `target` is not needed for accuracy, so it is not transferred to the device.
        for data, _target, label in _loader:
            output = model(data.to(dev))
            # Compare on the CPU so predictions and labels are on the same device.
            pred_label = snn.predict.getClass(output).cpu()
            label = label.cpu()
            correct += int((pred_label == label).sum().item())
            true_batches.append(label.numpy())
            pred_batches.append(pred_label.numpy())

    n_samples = len(_dataset)
    acc_value = correct / n_samples if n_samples else float('nan')

    # np.concatenate([]) raises, so an empty loader gets explicit empty arrays.
    if true_batches:
        true_labels = np.concatenate(true_batches)
        pred_labels = np.concatenate(pred_batches)
    else:
        true_labels = np.array([], dtype=np.int64)
        pred_labels = np.array([], dtype=np.int64)

    return acc_value, true_labels, pred_labels

In [11]:
# Evaluate each split in turn (test, then validation, then train).
# NOTE(review): the last run failed with `KeyError: 0` raised inside
# dataset.py `__getitem__` (`letters[index]` looked up with an integer index);
# that fix belongs in dataset.py, not in this cell.
test_acc, test_true_labels, test_pred_labels = calculate_loss_acc(
    _dataset=test_dataset, _loader=test_loader
)
val_acc, val_true_labels, val_pred_labels = calculate_loss_acc(
    _dataset=val_dataset, _loader=val_loader
)
train_acc, train_true_labels, train_pred_labels = calculate_loss_acc(
    _dataset=train_dataset, _loader=train_loader
)

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/tasbolat/tas_python_env/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/tasbolat/tas_python_env/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/tasbolat/tas_python_env/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/tasbolat/some_python_examples/letter_spike/dataset.py", line 47, in __getitem__
    class_label = int(self.info[letters[index], 1])
KeyError: 0


In [None]:
# Report accuracy for each split, one line per split.
for split_name, split_acc in (
    ('Train', train_acc),
    ('Validation', val_acc),
    ('Test', test_acc),
):
    print(f'{split_name} accuracy: {split_acc}')

In [12]:
# What do the confusion matrices look like?
# Passing `labels` pins each matrix to output_size x output_size, one row/column
# per output neuron. Without it, sklearn drops classes that appear in neither
# y_true nor y_pred, and the matrix no longer matches the letter tick labels.
# Assumes class index i corresponds to the i-th key of `letters` (TODO confirm).
class_ids = np.arange(output_size)
test_cm = confusion_matrix(test_true_labels, test_pred_labels, labels=class_ids)
val_cm = confusion_matrix(val_true_labels, val_pred_labels, labels=class_ids)
train_cm = confusion_matrix(train_true_labels, train_pred_labels, labels=class_ids)

NameError: name 'test_true_labels' is not defined

In [None]:
# Test-split confusion matrix (rows: true letter, columns: predicted letter).
# Axis labels are set BEFORE saving so they appear in the written PNG.
# fmt='d' prints integer counts (the default '.2g' turns 100+ into 1e+02).
Path(save_dir).mkdir(parents=True, exist_ok=True)
fig, ax = plt.subplots(figsize=(15, 12))
sn.heatmap(test_cm, annot=True, fmt='d',
           xticklabels=list(letters), yticklabels=list(letters), ax=ax)
ax.set(xlabel='Predicted', ylabel='True',
       title=f'Test confusion matrix (fold {fold_number})')
fig.savefig(Path(save_dir) / f'model_{fold_number}_test_cm.png')
plt.show()

In [None]:
# Train-split confusion matrix (rows: true letter, columns: predicted letter).
# Axis labels are set BEFORE saving so they appear in the written PNG.
# fmt='d' prints integer counts (the default '.2g' turns 100+ into 1e+02).
Path(save_dir).mkdir(parents=True, exist_ok=True)
fig, ax = plt.subplots(figsize=(15, 12))
sn.heatmap(train_cm, annot=True, fmt='d',
           xticklabels=list(letters), yticklabels=list(letters), ax=ax)
ax.set(xlabel='Predicted', ylabel='True',
       title=f'Train confusion matrix (fold {fold_number})')
fig.savefig(Path(save_dir) / f'model_{fold_number}_train_cm.png')
plt.show()

In [None]:
# Validation-split confusion matrix (rows: true letter, columns: predicted letter).
# Axis labels are set BEFORE saving so they appear in the written PNG.
# fmt='d' prints integer counts (the default '.2g' turns 100+ into 1e+02).
Path(save_dir).mkdir(parents=True, exist_ok=True)
fig, ax = plt.subplots(figsize=(15, 12))
sn.heatmap(val_cm, annot=True, fmt='d',
           xticklabels=list(letters), yticklabels=list(letters), ax=ax)
ax.set(xlabel='Predicted', ylabel='True',
       title=f'Validation confusion matrix (fold {fold_number})')
fig.savefig(Path(save_dir) / f'model_{fold_number}_val_cm.png')
plt.show()