In [None]:
import numpy as np
from tensorflow import keras
import matplotlib.pyplot as plt
import seaborn as sb
import sklearn.metrics as metrics
import keras as ks

In [None]:
# Held-out test split, stored as NumPy arrays next to the project's other exports.
ARRAYS_DIR = '../../Arrays'
test_features = np.load(f'{ARRAYS_DIR}/test_features_hvg_subset.npy')
test_labels = np.load(f'{ARRAYS_DIR}/test_labels_hvg_subset.npy')

# Fail fast if the two files come from different exports: every metric below
# pairs row i of the features with label i.
assert len(test_features) == len(test_labels), (
    f'feature/label row mismatch: {len(test_features)} vs {len(test_labels)}'
)

In [None]:
# Load the trained classifier through the standalone Keras package (`ks`), the same
# package that supplies the LeakyReLU custom object. Loading via `tensorflow.keras`
# while taking the layer class from `keras` can mix two Keras installations
# (tf-keras 2.x vs Keras 3). NOTE(review): the "_jax_" filename suggests a
# JAX-backend model, which only Keras 3 can read — confirm.
MODEL_PATH = "../../Models/granulomas30_hvg_subset_jax_v1.keras"
model = ks.models.load_model(MODEL_PATH, custom_objects={'LeakyReLU': ks.layers.LeakyReLU})

In [None]:
# Print the loaded model's layer-by-layer architecture and parameter counts.
model.summary()

In [None]:
# Per-class scores for every test sample (one row per sample).
# NOTE(review): assumes a single-output model with one column per class — confirm via model.summary().
prediction = model.predict(x=test_features)

In [None]:
# Hard class predictions: column index of the highest score in each row.
max_indices = np.asarray(prediction).argmax(axis=1)

In [None]:
def overall_metrics(y_true, y_pred, average='weighted', zero_division=0):
    """Compute aggregate classification metrics for a set of predictions.

    Parameters
    ----------
    y_true : array-like
        Ground-truth class labels.
    y_pred : array-like
        Predicted class labels, in the same encoding as ``y_true``.
    average : str, default 'weighted'
        Averaging strategy forwarded to the precision, recall and F1 scorers.
    zero_division : {0, 1, 'warn'}, default 0
        Value used when a per-class score is undefined (e.g. a class that is
        never predicted, or never present in ``y_true``). Applied uniformly to
        precision, recall and F1 so the three scores are computed consistently.

    Returns
    -------
    dict
        Mapping with keys ``'accuracy'``, ``'precision'``, ``'recall'`` and
        ``'f1_score'``.
    """
    # One shared kwargs dict keeps the three averaged scorers in sync; previously
    # only precision set zero_division, so recall/F1 emitted UndefinedMetricWarning.
    scorer_kwargs = {'average': average, 'zero_division': zero_division}

    return {
        'accuracy': metrics.accuracy_score(y_true, y_pred),
        'precision': metrics.precision_score(y_true, y_pred, **scorer_kwargs),
        'recall': metrics.recall_score(y_true, y_pred, **scorer_kwargs),
        'f1_score': metrics.f1_score(y_true, y_pred, **scorer_kwargs),
    }

In [None]:
def class_metrics(y_true, y_pred):
    """Return sklearn's plain-text per-class precision/recall/F1 report.

    Undefined per-class scores (e.g. a class that is never predicted) are
    reported as 0 instead of raising an UndefinedMetricWarning.
    """
    report = metrics.classification_report(
        y_true,
        y_pred,
        zero_division=0,
    )
    return report

In [None]:
def create_confusion_matrix(y_true, y_pred):
    """Build the raw-count confusion matrix (rows: true label, columns: predicted label)."""
    counts = metrics.confusion_matrix(y_true=y_true, y_pred=y_pred)
    return counts

In [None]:
def plot_confusion_matrix(y_true, y_pred, ax=None):
    """Draw a heatmap of the confusion matrix of ``y_true`` vs ``y_pred``.

    Parameters
    ----------
    y_true : array-like
        Ground-truth class labels.
    y_pred : array-like
        Predicted class labels, in the same encoding as ``y_true``.
    ax : matplotlib Axes, optional
        Axes to draw into. When omitted, a new 8x6 figure is created and
        shown; when given, displaying the figure is left to the caller.
    """
    # Derive the label set from the inputs themselves (sorted union) and pass it
    # to confusion_matrix explicitly, so matrix rows/columns and tick labels are
    # guaranteed to line up. The previous version read the global `test_labels`,
    # which breaks when that global is absent and mislabels the axes whenever a
    # predicted class is missing from it.
    labels = np.union1d(y_true, y_pred)
    counts = metrics.confusion_matrix(y_true, y_pred, labels=labels)

    owns_figure = ax is None
    if owns_figure:
        _, ax = plt.subplots(figsize=(8, 6))

    sb.heatmap(counts, annot=False, cmap='Reds', cbar=True,
               xticklabels=labels, yticklabels=labels, ax=ax)
    ax.set_title('Confusion Matrix')
    ax.set_xlabel('Predicted Label')
    ax.set_ylabel('True Label')

    if owns_figure:
        plt.show()

In [None]:
def roc_auc_ovr(y_true, y_score):
    """Multiclass ROC AUC using the one-vs-rest scheme.

    Each class is scored against all remaining classes and the per-class AUCs
    are combined with sklearn's default averaging.

    NOTE(review): for multiclass input sklearn expects ``y_score`` to be an
    (n_samples, n_classes) array of probabilities whose rows sum to 1; this
    assumes the model ends in a softmax — confirm.
    """
    scheme = 'ovr'
    return metrics.roc_auc_score(y_true, y_score, multi_class=scheme)

In [None]:
def roc_auc_ovo(y_true, y_score):
    """Multiclass ROC AUC using the one-vs-one scheme.

    Every pair of classes is scored separately and the pairwise AUCs are
    combined with sklearn's default averaging.

    NOTE(review): for multiclass input sklearn expects ``y_score`` to be an
    (n_samples, n_classes) array of probabilities whose rows sum to 1; this
    assumes the model ends in a softmax — confirm.
    """
    scheme = 'ovo'
    return metrics.roc_auc_score(y_true, y_score, multi_class=scheme)

In [None]:
def average_precision(y_true, y_score):
    """Average precision (area under the precision-recall curve), support-weighted across classes.

    NOTE(review): passing 1-D multiclass labels as ``y_true`` needs a
    scikit-learn version with multiclass support in average_precision_score
    (1.3+); older versions expect a binarized label matrix — confirm.
    """
    weighting = 'weighted'
    return metrics.average_precision_score(y_true, y_score, average=weighting)

In [None]:
def balanced_accuracy(y_true, y_pred):
    """Mean per-class recall; unlike plain accuracy it is not dominated by the largest classes."""
    score = metrics.balanced_accuracy_score(y_true=y_true, y_pred=y_pred)
    return score

In [None]:
# Aggregate accuracy / precision / recall / F1 on the test split.
overall_metrics(y_true=test_labels, y_pred=max_indices)

In [None]:
# print() so the report's newlines render instead of showing the string repr.
print(class_metrics(y_true=test_labels, y_pred=max_indices))

In [None]:
# Heatmap of true vs predicted class counts.
plot_confusion_matrix(y_true=test_labels, y_pred=max_indices)

In [None]:
# Threshold-free ranking quality, one-vs-rest (uses the raw class scores).
roc_auc_ovr(y_true=test_labels, y_score=prediction)

In [None]:
# Threshold-free ranking quality, one-vs-one (uses the raw class scores).
roc_auc_ovo(y_true=test_labels, y_score=prediction)

In [None]:
# Support-weighted area under the precision-recall curve.
average_precision(y_true=test_labels, y_score=prediction)

In [None]:
# Accuracy corrected for class imbalance (mean per-class recall).
balanced_accuracy(y_true=test_labels, y_pred=max_indices)