In [None]:
import numpy as np
from tqdm import tqdm
import seaborn as sn
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.colors import LinearSegmentedColormap

from sklearn.metrics import (precision_recall_curve,
                             PrecisionRecallDisplay,
                             confusion_matrix,
                             roc_curve,
                             auc,
                             roc_auc_score)


In [None]:
# Training-curve plot: loss and accuracy on twin y-axes
def loss_acc_plot(counter, loss_history, valloss_history, acc_history, valacc_history, results_dir):
    """Plot training/validation loss and accuracy against epoch on twin y-axes.

    Loss curves (blue = training, cyan = validation) use the left y-axis and
    accuracy curves (orange = training, gold = validation) use the right
    y-axis. The figure is saved to ``results_dir + 'accuracy_plot.png'`` and
    then shown. This is plain string concatenation, so ``results_dir`` must
    carry its own trailing path separator.

    Parameters
    ----------
    counter : sequence
        x values (epochs) shared by all four series.
    loss_history, valloss_history : sequence
        Training / validation loss per epoch.
    acc_history, valacc_history : sequence
        Training / validation accuracy per epoch.
    results_dir : str
        Output path prefix.
    """
    _fig, loss_axis = plt.subplots()
    loss_axis.set_xlabel('Epoch')
    loss_axis.set_ylabel('Loss', color='black')
    train_loss = loss_axis.plot(counter, loss_history, color='blue', label='Loss')
    val_loss = loss_axis.plot(counter, valloss_history, color='cyan', label='Validation Loss')
    loss_axis.tick_params(axis='y', labelcolor='black')

    # Accuracy gets its own y-scale while sharing the epoch (x) axis.
    acc_axis = loss_axis.twinx()
    acc_axis.set_ylabel('Accuracy', color='black')
    train_acc = acc_axis.plot(counter, acc_history, color='orange', label='Accuracy')
    val_acc = acc_axis.plot(counter, valacc_history, color='gold', label='Validation Accuracy')
    acc_axis.tick_params(axis='y', labelcolor='black')

    # A single legend covering lines from both axes; loc=7 is 'center right'.
    handles = train_loss + train_acc + val_loss + val_acc
    acc_axis.legend(handles, [line.get_label() for line in handles], loc=7)

    plt.savefig(results_dir + 'accuracy_plot.png')
    plt.show()


In [None]:
# Precision-recall curve whose colour encodes the decision threshold
def plot_pr_color_coded(y_true, y_pred, ax, label):
    """Draw a precision-recall curve coloured by decision threshold.

    Each segment between consecutive points returned by
    ``precision_recall_curve`` is coloured by the threshold at its start.
    The colormap is built from five viridis anchor colours listed in reverse
    order, so threshold 0 is yellow and threshold 1 is purple.

    Parameters
    ----------
    y_true : array-like
        Binary ground-truth labels for one class.
    y_pred : array-like
        Scores for the positive class. Thresholds are normalised to [0, 1]
        for colouring; values outside that range get the end colours.
    ax : matplotlib.axes.Axes
        Axes to draw into.
    label : str
        Used as the axes title.

    Returns
    -------
    matplotlib.collections.LineCollection
        The coloured curve. It is a ScalarMappable, so it can be passed to
        ``fig.colorbar`` to get a threshold scale that matches the curve.
    """
    # Reversed viridis anchors: low thresholds yellow, high thresholds purple.
    anchor_colors = ['#fde725', '#5ec962', '#21918c', '#3b528b', '#440154']
    cmap = LinearSegmentedColormap.from_list(colors=anchor_colors, name='viridis')

    precision, recall, thresholds = precision_recall_curve(y_true, y_pred)

    # len(thresholds) == len(precision) - 1, so there is one segment per
    # threshold: (recall[i], precision[i]) -> (recall[i+1], precision[i+1]).
    # A single LineCollection replaces one ax.plot call per segment, which
    # produced thousands of Line2D artists for large evaluation sets.
    points = np.column_stack([recall, precision])
    segments = np.stack([points[:-1], points[1:]], axis=1)
    curve = LineCollection(segments, cmap=cmap,
                           norm=plt.Normalize(vmin=0, vmax=1), linewidths=4)
    # Projecting caps (the Line2D default) keep adjacent segments joined.
    curve.set_capstyle('projecting')
    # Explicit float cast: integer scores must be treated as values in
    # [0, 1], not as colormap lookup-table indices.
    curve.set_array(np.asarray(thresholds, dtype=float))
    ax.add_collection(curve)
    # Collections do not trigger autoscaling the way ax.plot does.
    ax.autoscale_view()

    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title(label)
    return curve

In [None]:
# Row-normalised confusion-matrix heatmap
def cfm(y_true, y_pred, results_dir, class_mapping, labels=None):
    """Plot and save a row-normalised confusion-matrix heatmap.

    Each row is divided by the number of true samples of that class. A cell
    therefore reads "fraction of samples with this row's label that were
    predicted as this column". The figure is written to
    ``results_dir + "confusion_mtx_calls.png"``. This is plain string
    concatenation, so ``results_dir`` must carry its own trailing separator.

    Parameters
    ----------
    y_true, y_pred : array-like
        True and predicted labels.
    results_dir : str
        Output path prefix.
    class_mapping : sequence of str
        Tick labels for rows and columns. They must follow the order of the
        matrix:
        - When ``labels`` is None, that is the sorted set of labels appearing
          in ``y_true`` / ``y_pred``.
        - Otherwise it is the order of ``labels``.
    labels : array-like, optional
        Forwarded to ``confusion_matrix`` to fix the set and order of
        classes. For example, pass ``range(len(class_mapping))`` for
        integer-encoded labels, so the matrix still matches
        ``class_mapping`` when a class is absent from this data.
        ``None`` (default) keeps the previous behaviour.
    """
    # normalize='true' divides each row by its sum. Rows with no true
    # samples become zeros instead of 0/0 = NaN with a RuntimeWarning.
    cmn = confusion_matrix(y_true, y_pred, labels=labels, normalize='true')

    fig, ax = plt.subplots(figsize=(10, 10))
    sn.heatmap(cmn, annot=True, fmt='.2f', xticklabels=class_mapping,
               yticklabels=class_mapping, cmap='Blues', ax=ax)
    ax.set_xlabel('Prediction')
    ax.set_ylabel('Label')

    fig.savefig(results_dir + "confusion_mtx_calls.png")

In [None]:
# Per-class colour-coded precision-recall (PR) curves and combined ROC curves
def plot_pr_roc_curve(target_list, prediction_list, results_dir, class_mapping, n_classes=6):
    """Plot per-class PR curves and a combined ROC figure, and save both.

    ``target_list`` and ``prediction_list`` are paired per sample: element
    ``n`` of each entry is the ground truth / score for class ``n``. The
    first ``n_classes`` columns are plotted. In this project they are, in
    order: Croak, Groan, Growl, Moan, Rumble, Whoops.

    Two files are written. ``results_dir`` is a string prefix, so it must
    carry its own trailing separator.
    - ``Color-coded_Precision-recall.png``: one threshold-coloured PR
      subplot per class, with a shared threshold colour bar.
    - ``roc_curve.png`` (300 dpi): all ROC curves with their AUC.

    Parameters
    ----------
    target_list : sequence of indexable
        Per-sample binary targets, indexable by class position.
    prediction_list : sequence of indexable
        Per-sample scores, indexable by class position. Threshold colouring
        assumes scores in [0, 1].
    results_dir : str
        Output path prefix.
    class_mapping : sequence of str
        Class names. Only the first ``n_classes`` are used; any extra
        trailing entries are ignored.
    n_classes : int, default 6
        Number of leading class columns to plot. Subplots are laid out in
        two columns.
    """
    # Transpose per-sample entries into per-class columns. zip stops at the
    # shorter of the two lists, as before.
    pairs = list(zip(prediction_list, target_list))
    predictions_by_class = [[pred[n] for pred, _ in pairs] for n in range(n_classes)]
    targets_by_class = [[target[n] for _, target in pairs] for n in range(n_classes)]

    # --- Colour-coded precision-recall grid ---
    n_rows = (n_classes + 1) // 2
    fig, axs = plt.subplots(n_rows, 2, figsize=(10, 10), squeeze=False)
    flat_axes = axs.ravel()
    for n in range(n_classes):
        plot_pr_color_coded(targets_by_class[n], predictions_by_class[n],
                            flat_axes[n], class_mapping[n])
    for unused_ax in flat_axes[n_classes:]:
        unused_ax.set_visible(False)

    # The colour bar must use the same colormap as plot_pr_color_coded
    # (reversed viridis anchors). The previous mappable=None fell back to
    # matplotlib's default colormap, which runs the opposite way to the
    # curve colours.
    threshold_cmap = LinearSegmentedColormap.from_list(
        colors=['#fde725', '#5ec962', '#21918c', '#3b528b', '#440154'],
        name='viridis')
    threshold_scale = plt.cm.ScalarMappable(norm=plt.Normalize(vmin=0, vmax=1),
                                            cmap=threshold_cmap)
    threshold_scale.set_array([])  # needed by older matplotlib for a standalone mappable
    fig.colorbar(threshold_scale, ax=axs.ravel().tolist(), label='Threshold')
    fig.savefig(results_dir + 'Color-coded_Precision-recall.png')

    # --- ROC curves, all classes on one axes ---
    # Index the first n_classes names directly. The old class_mapping[:-1]
    # silently dropped the last ROC curve when exactly n_classes names were
    # passed.
    roc_fig, roc_ax = plt.subplots()
    for n in range(n_classes):
        fpr, tpr, _ = roc_curve(targets_by_class[n], predictions_by_class[n])
        roc_ax.plot(fpr, tpr,
                    label='Class {0} (area = {1:0.2f})'.format(class_mapping[n], auc(fpr, tpr)))

    roc_ax.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
    roc_ax.set_xlim([0.0, 1.0])
    roc_ax.set_ylim([0.0, 1.05])
    roc_ax.set_xlabel('False Positive Rate')
    roc_ax.set_ylabel('True Positive Rate')
    roc_ax.set_title('ROC CURVE')
    roc_ax.legend(loc="lower right")
    roc_fig.savefig(results_dir + "roc_curve.png", dpi=300)