In [2]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import os
import sys
from tqdm import tqdm
import utils
import torch.nn.functional as F
import math
import random
import utils, models

In [4]:
# Load every object with its label, then split into training / validation / test sets.
# NOTE(review): if utils.split_sets shuffles without a fixed seed, the test split produced
# here may differ from the one held out when the checkpoint was trained — confirm.
all_objects, labels = utils.get_sets(path_enter="Data/")
(
    X_training, Y_training,
    X_val, Y_val,
    X_test, Y_test,
) = utils.split_sets(all_objects, labels)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7848/7848 [04:07<00:00, 31.77it/s]


In [91]:
# Display the label ordering stored on the model.
# NOTE(review): `model` is only created in the "# Define model" cell further down this
# notebook, so on a fresh kernel run top-to-bottom this line raises NameError.
model.labeling_order

(6, 15, 16, 42, 52, 53, 62, 64, 65, 67, 88, 90, 92, 95)

In [99]:
# Inverse-frequency class weights, normalised so they sum to 1.
# np.unique returns the distinct labels in ascending order, so entry k of every array
# below refers to the k-th smallest label value found in Y_training.
unique_training, counts_training = np.unique(
    torch.tensor(Y_training).reshape(-1).numpy(),
    return_counts=True,
)

inverse_freqs = 1.0 / counts_training
class_weights = inverse_freqs / inverse_freqs.sum()
weights_tensor = torch.from_numpy(class_weights).float()

weights_tensor

tensor([0.0774, 0.0254, 0.0132, 0.0102, 0.0677, 0.4384, 0.0268, 0.1083, 0.0128,
        0.0575, 0.0334, 0.0053, 0.0535, 0.0703])

In [5]:
# Define model
# These hyperparameters are passed straight to the constructor; they have to describe the
# same architecture the checkpoint was saved from, otherwise load_state_dict (strict by
# default) raises on missing/unexpected keys or mismatched shapes.
input_dim = 6
n_classes = 14
d_model = 64
nhead = 4
num_layers = 2

model = models.TransformerClassifier(input_dim, n_classes, d_model, nhead, num_layers)

# map_location="cpu" makes the checkpoint loadable on machines that lack the device it was
# saved from (e.g. a GPU-trained model opened on a CPU-only laptop).
state_dict = torch.load(
    "Different_Tests/Test_1/Models_Data/Model_Saved/model_epoch_48.pt",
    map_location="cpu",
)
# Kept as the last expression so the cell still displays the key-matching report.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [11]:
# Keep the model's label ordering in a notebook-level name; a later cell indexes into it
# with argmax positions to turn output indices back into label values.
labeling_order = model.labeling_order
print(labeling_order)

(6, 15, 16, 42, 52, 53, 62, 64, 65, 67, 88, 90, 92, 95)


In [18]:
# Run the classifier over the test set one sample at a time and stack the raw outputs
# along dim 0.
#
# model.eval() switches off training-only behaviour (dropout, batch-norm updates, if the
# model has any) so repeated runs give the same predictions. torch.no_grad() disables
# autograd tracking, which inference does not need and which would otherwise keep a
# computation graph alive for every sample.
model.eval()
Y_pred_labels = []
with torch.no_grad():
    for i in tqdm( range( len(X_test) ) ):
        Y_pred = model(X_test[i])
        Y_pred_labels.append( Y_pred )
Y_pred_labels = torch.cat(Y_pred_labels, dim = 0)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1570/1570 [00:37<00:00, 41.82it/s]


In [22]:
# Position of the highest score along dim 1 for every row, kept as a column vector
max_indices = Y_pred_labels.argmax(dim=1).reshape(-1, 1)
max_indices.shape

torch.Size([1570, 1])

In [35]:
# Translate argmax positions into label values with a single tensor lookup instead of a
# Python loop over rows. Building the lookup table as a tensor also keeps the result an
# integer tensor when there are zero predictions.
label_lookup = torch.tensor(labeling_order)
predicted_labeling = label_lookup[max_indices.reshape(-1)].view(-1, 1)
predicted_labeling

tensor([[65],
        [90],
        [16],
        ...,
        [90],
        [16],
        [90]])

In [37]:
# Ground-truth test labels as a column vector, the same layout as predicted_labeling
real_labeling = torch.tensor(Y_test).reshape(-1, 1)
real_labeling

tensor([[65],
        [90],
        [92],
        ...,
        [65],
        [16],
        [90]])

In [65]:
# NumPy views of the predicted and true labels, each laid out as a single row (1, N)
y_pred = predicted_labeling.numpy().reshape(1, -1)
y_true = real_labeling.numpy().reshape(1, -1)

In [70]:
# Distinct true labels in the test set and how many samples carry each one
unique_values, unique_counts = np.unique(y_true, return_counts=True)

(unique_values, unique_counts)

(array([ 6, 15, 16, 42, 52, 53, 62, 64, 65, 67, 88, 90, 92, 95]),
 array([ 26, 111, 174, 220,  36,   9, 112,  11, 223,  34,  77, 452,  55,
         30]))

In [84]:
# Distinct predicted labels and how often each was predicted.
# Rebinds unique_values / unique_counts, which the true-label cell also assigns.
unique_values, unique_counts = np.unique(y_pred, return_counts=True)

(unique_values, unique_counts)

(array([16, 65, 88, 90]), array([ 275,  225,   61, 1009]))

In [83]:
# Predictions made for the test samples whose true label is 90
y_pred[np.equal(y_true, 90)]

array([90, 65, 90, 90, 90, 90, 90, 65, 90, 90, 90, 90, 90, 90, 65, 90, 90,
       90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90,
       90, 90, 65, 90, 90, 90, 90, 90, 90, 90, 65, 90, 90, 90, 90, 90, 90,
       90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90,
       90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90,
       90, 65, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 16, 90,
       90, 90, 90, 65, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90,
       90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90,
       90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 65, 90, 90,
       90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 65, 90, 90, 90, 90,
       90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90,
       90, 90, 90, 65, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90,
       90, 16, 65, 90, 90, 65, 90, 90, 90, 90, 90, 90, 90, 90, 90, 65, 90,
       90, 90, 90, 90, 90