In [None]:
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras

from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix

import random as rn
import seaborn as sns

In [None]:
# Folder holding the held-out test split ('path' is a placeholder to fill in).
directory = 'path'

# Test samples and their labels, stored side by side as NumPy .npy files
# (labels presumably integer class ids — to_categorical below relies on that).
testing_data, testing_labels = (
    np.load(f'{directory}/data_test.npy'),
    np.load(f'{directory}/labels_test.npy'),
)

In [None]:
def _standardize_sample(sample):
    """Min-max rescale one sample to [0, 1], then z-score it (zero mean, unit std).

    Args:
        sample: array-like of numbers for a single example (any shape; min, max,
            mean and std are taken over all of its values).

    Returns:
        A float ndarray with the same shape as ``sample``. A sample whose values
        are all identical has no spread to rescale, so it is returned as zeros
        instead of the NaNs that ``0 / 0`` would produce.
    """
    # Shift so the minimum is 0. Written as a subtraction (instead of adding
    # ``-1 * min``) so unsigned-integer input never has to negate its minimum.
    positive = sample - np.amin(sample)
    peak = np.amax(positive)
    if peak == 0:
        # Constant sample: emit zeros with the dtype the division would yield
        # (float64 for integer input, the input's float type otherwise).
        return np.zeros_like(positive, dtype=np.result_type(positive, 1.0))
    norm = positive / peak
    center = norm - np.mean(norm)
    # The standard deviation is accumulated in float32, as before.
    return (center - np.mean(center)) / np.std(center, dtype=np.float32)


def standardization(data):
    """Standardize each sample of ``data`` independently.

    Every sample ``data[i]`` is shifted to a minimum of 0, scaled by its maximum
    into [0, 1], mean-centred, and divided by its standard deviation. Flat
    (constant) samples map to all-zeros rather than NaN.

    Args:
        data: sequence or array of samples, iterated along its first axis.

    Returns:
        np.ndarray stacking the standardized samples in their original order.
    """
    return np.asarray([_standardize_sample(sample) for sample in data])

In [None]:
# Standardize every test sample before it is handed to the model.
testing_data_std = standardization(testing_data)

# Fail fast if preprocessing produced NaN/inf (e.g. from a degenerate sample);
# otherwise those values would flow silently into model.predict.
assert np.isfinite(testing_data_std).all(), 'standardized test data contains NaN/inf values'

In [None]:
# Folder containing the exported model ('to_model' is a placeholder to fill in).
path = 'to_model'

# `keras` is `tensorflow.keras` (imported at the top), so this is the same loader
# as `tf.keras.models.load_model`.
# NOTE(review): the directory path suggests a TensorFlow SavedModel; Keras 3's
# load_model only reads `.keras`/`.h5` files — confirm the TF/Keras version in use.
model = keras.models.load_model(f'{path}/SRT-Ai/')

In [None]:
# Inference on the standardized test set: one row of model outputs per test
# sample (presumably per-class scores — argmax below treats them that way).
predictions_synthetic = model.predict(x=testing_data_std)

In [None]:
# One-hot encode the (presumably integer) test labels. The width is pinned to the
# model's output size: left alone, to_categorical infers max(label) + 1 columns,
# which silently disagrees with `predictions_synthetic` whenever the highest class
# id is missing from the test labels. Labels outside the model's classes now raise
# instead of producing a mismatched encoding.
y_testing = keras.utils.to_categorical(
    testing_labels, num_classes=predictions_synthetic.shape[1]
)

In [None]:
# Collapse the one-hot labels and the model outputs to class indices
# (argmax = highest-scoring column per sample).
true_classes = np.argmax(y_testing, axis=1)
predicted_classes = np.argmax(predictions_synthetic, axis=1)

# Pin the label set to every model output class so the matrix is always
# n_classes x n_classes, even when a class never appears in either vector.
# Without this the matrix shrinks and the fixed two-name tick labels used when
# plotting no longer match the number of ticks.
confusion_synthetic = confusion_matrix(
    true_classes,
    predicted_classes,
    labels=np.arange(predictions_synthetic.shape[1]),
)

In [None]:
# Raw-count confusion matrix: rows are the actual class, columns the predicted class.
sns.set(rc={'figure.figsize': (10, 10)})
ax = sns.heatmap(
    confusion_synthetic,
    cmap='Blues',
    annot=True,
    fmt='g',  # plain counts; seaborn's default '.2g' would render 123 as 1.2e+02
    cbar=False,
)

axis_title_style = dict(fontsize='24', weight='bold')
ax.set_xlabel('\nPredicted Class', **axis_title_style)
ax.set_ylabel('Actual Class\n', **axis_title_style)

class_names = ['No Termination', 'Contains Termination']
for tick_axis in (ax.xaxis, ax.yaxis):
    tick_axis.set_ticklabels(class_names, fontsize='18')

plt.show()

In [None]:
# Row-normalize the confusion matrix: each cell becomes the share of that actual
# class that received each prediction, so every non-empty row sums to 1.
row_totals = confusion_synthetic.sum(axis=1, keepdims=True)
normalized_synthetic_cm = np.divide(
    confusion_synthetic.astype('float'),
    row_totals,
    out=np.zeros(confusion_synthetic.shape, dtype=float),
    # A class with no test samples has a zero row total; leave that row at 0
    # instead of producing 0/0 = NaN (and a RuntimeWarning).
    where=row_totals != 0,
)

sns.set(rc={'figure.figsize': (10, 10)})
ax = sns.heatmap(normalized_synthetic_cm, annot=True, cmap='Blues', fmt='.2%', cbar=False)

ax.set_xlabel('\nPredicted Class', fontsize='24', weight='bold')
ax.set_ylabel('Actual Class\n', fontsize='24', weight='bold')

ax.xaxis.set_ticklabels(['No Termination', 'Contains Termination'], fontsize='18')
ax.yaxis.set_ticklabels(['No Termination', 'Contains Termination'], fontsize='18')

plt.show()