In [29]:
import os
import pickle
import numpy as np
import pandas as po
import matplotlib.pyplot as plt

In [51]:
from sklearn.metrics import confusion_matrix, roc_curve

In [55]:
# Load each model's saved predictions and training history, and tabulate its
# binary confusion matrix into results/confusion_matrix.csv.
#
# Layout read below (one sub-directory per model):
#   results/final_plots/<model>/plots/predictions.npz  -> arrays 'y_true', 'y_pred'
#   results/final_plots/<model>/plots/history.pkl      -> pickled training history
# NOTE(review): 'y_true'/'y_pred' are argmax-ed along axis 1 and column 1 is used
# as the positive-class score, so they are presumably (n_samples, 2) one-hot
# labels / class probabilities — confirm against the training script.
train_hist = {}
cm_rows = []
base_dir = 'results/final_plots/'

# Sort so the model order (and therefore plot colour assignment downstream) is
# reproducible; skip stray files (e.g. .DS_Store) that are not model directories.
model_names = sorted(
    name for name in os.listdir(base_dir)
    if os.path.isdir(os.path.join(base_dir, name))
)

for model in model_names:
    plots_dir = os.path.join(base_dir, model, 'plots')
    train_hist[model] = {}

    # np.load on an .npz returns a lazy NpzFile; read the arrays before the
    # context manager closes it.
    with np.load(os.path.join(plots_dir, 'predictions.npz')) as data:
        y_true = data['y_true']
        y_pred = data['y_pred']

    # SECURITY: pickle.load can execute arbitrary code — only load history files
    # produced by our own training runs.
    with open(os.path.join(plots_dir, 'history.pkl'), 'rb') as f:
        train_hist[model]['history'] = pickle.load(f)

    train_hist[model]['y_true'] = np.argmax(y_true, axis=1)
    train_hist[model]['y_pred'] = np.argmax(y_pred, axis=1)
    train_hist[model]['y_prob_true'] = y_pred[:, 1]

    # Pin the labels so the matrix is always 2x2, even when a class is missing
    # from both y_true and y_pred (otherwise .ravel() cannot unpack into 4).
    tn, fp, fn, tp = confusion_matrix(
        train_hist[model]['y_true'], train_hist[model]['y_pred'], labels=[0, 1]
    ).ravel()
    cm_rows.append({'model': model, 'TN': tn, 'FP': fp, 'FN': fn, 'TP': tp})

# DataFrame.append was removed in pandas 2.0 (and was quadratic in a loop);
# build the frame once from records. Explicit columns keep the CSV header even
# when no models were found.
df_cm = po.DataFrame(cm_rows, columns=['model', 'TN', 'FP', 'FN', 'TP'])
df_cm.to_csv('results/confusion_matrix.csv', index=False)

In [56]:
# Line colours for the per-model plots below, assigned in model order.
colors = 'orange blue green red grey'.split()

In [57]:
# Training-loss curve per model, truncated to the first 15 epochs so runs of
# different lengths share the same x-range; saved to results/train_loss.png.
fig, ax = plt.subplots(figsize=(20, 10))
for i, (model, hist) in enumerate(train_hist.items()):
    # Cycle the palette so more models than colours does not raise IndexError.
    ax.plot(hist['history']['loss'][:15],
            color=colors[i % len(colors)], marker='.', label=model)
ax.set_xlabel('Epochs', fontsize=20)
ax.set_ylabel('Training Loss', fontsize=20)
ax.legend(fontsize=20)
fig.savefig('results/train_loss.png')
plt.close(fig)

In [58]:
# Validation-loss curve per model, truncated to the first 15 epochs so runs of
# different lengths share the same x-range; saved to results/val_loss.png.
fig, ax = plt.subplots(figsize=(20, 10))
for i, (model, hist) in enumerate(train_hist.items()):
    # Cycle the palette so more models than colours does not raise IndexError.
    ax.plot(hist['history']['val_loss'][:15],
            color=colors[i % len(colors)], marker='.', label=model)
ax.set_xlabel('Epochs', fontsize=20)
ax.set_ylabel('Val Loss', fontsize=20)
ax.legend(fontsize=20)
fig.savefig('results/val_loss.png')
plt.close(fig)

In [62]:
# ROC curve per model, computed from the true labels and the column-1 score
# stored as 'y_prob_true'; saved to results/roc_curve.png.
fig, ax = plt.subplots(figsize=(20, 10))
for i, (model, hist) in enumerate(train_hist.items()):
    fpr, tpr, _ = roc_curve(hist['y_true'], hist['y_prob_true'])
    # Cycle the palette so more models than colours does not raise IndexError.
    ax.plot(fpr, tpr, color=colors[i % len(colors)], marker='.', label=model)

# Chance-level reference: a random classifier lies on the diagonal.
ax.plot([0, 1], [0, 1], color='black', linestyle='--', linewidth=1, label='chance')

ax.legend(fontsize=20)
ax.set_xlabel('False Positive Rate', fontsize=20)
ax.set_ylabel('True Positive Rate', fontsize=20)
fig.savefig('results/roc_curve.png')
plt.close(fig)