In [1]:
import itertools
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix

In [2]:
def plot_confusion_matrix(cm,
                          classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues
):
    """Draw a confusion matrix as an annotated heatmap on the current figure.

    Parameters
    ----------
    cm : array-like of shape (n_classes, n_classes)
        Matrix with true labels on rows and predicted labels on columns
        (the layout returned by ``sklearn.metrics.confusion_matrix``).
    classes : sequence of str
        Tick labels, in the same order as the rows/columns of ``cm``.
    normalize : bool, default False
        If True, divide each row by its total so cells show per-true-class
        fractions. Rows with no samples are shown as NaN instead of
        triggering a divide-by-zero warning.
    title : str
        Axes title.
    cmap : matplotlib colormap
        Colormap passed to ``plt.imshow``.

    Notes
    -----
    Draws through the pyplot state machine; the caller creates, saves and
    closes the figure.
    """
    cm = np.asarray(cm)
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Guarded division: empty rows become NaN without a RuntimeWarning.
        cm = np.divide(cm, row_totals, out=np.full(cm.shape, np.nan),
                       where=row_totals > 0)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Integer counts print with 'd'; anything fractional (normalized, or a
    # float matrix passed in directly) prints with two decimals. Formatting a
    # float with 'd' would raise ValueError.
    fmt = 'd' if not normalize and np.issubdtype(cm.dtype, np.integer) else '.2f'
    # nanmax so a single empty (NaN) row does not turn the threshold into NaN,
    # which would force every annotation to black.
    thresh = np.nanmax(cm) / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()

In [3]:
# Out-of-fold prediction CSVs, one per training configuration in the model
# zoo. The run directory (second-to-last path component) identifies the run.
MODEL_ZOO_DIR = '/kaggle/input/hms-hbac-model-zoo'
OOF_CSV_NAME = 'resnet1d_gru_oof_df_ver-82_stage-2.csv'
df_path_list = [
    '/'.join((MODEL_ZOO_DIR, run_dir, OOF_CSV_NAME))
    for run_dir in (
        'opt-adamw-bandhigh-10',
        'opt-adamw-bandhigh-20',
        'opt-adan-bandhigh-20',
    )
]

In [4]:
# Per-class probability columns of the OOF CSVs, in alphabetical order.
pred_col = sorted(
    f'{cls}_vote_pred'
    for cls in ('seizure', 'lpd', 'gpd', 'lrda', 'grda', 'other')
)
# Minimum probability 'other_vote_pred' needs before it may win the argmax.
other_thr = 0.6

In [5]:
# Evaluate every OOF file fold by fold: turn class probabilities into hard
# predictions, plot/save a row-normalized confusion matrix per fold, and
# collect the normalized matrices in `cm_list` for the per-class table.
np.set_printoptions(precision=2)
cm_list = list()

for df_path in df_path_list:
    df = pd.read_csv(df_path, index_col=0)

    # Pair each probability column with its target label by name
    # ('lpd_vote_pred' -> 'LPD'). Zipping two independently sorted lists
    # silently mis-pairs as soon as a fold is missing a class.
    label_col = sorted(df['target'].unique().tolist())
    label_dict = {
        col: lbl
        for col in pred_col
        for lbl in label_col
        if col[:-len('_vote_pred')] == lbl.lower()
    }
    assert len(label_dict) == len(pred_col), (
        f'could not match {pred_col} to target labels {label_col}'
    )
    # One fixed class order for every fold, so matrix rows/columns, tick
    # labels and the per-class table refer to the same classes even when a
    # fold never predicts (or never contains) some class.
    class_names = sorted(label_dict.values())

    for f in df['fold'].unique():
        df_fold = df.loc[df['fold'] == f]
        label = df_fold['target']

        # Explicit copy so the threshold below never writes through a slice
        # of `df` (SettingWithCopyWarning).
        probs = df_fold[pred_col].copy()
        # 'Other' may only win the argmax when it is confident; at or below
        # the threshold its probability is zeroed.
        other = probs['other_vote_pred']
        probs['other_vote_pred'] = other.where(other > other_thr, 0)
        pred = probs.idxmax(axis=1).map(label_dict)

        cm = confusion_matrix(label, pred, labels=class_names)

        figname = f'{df_path.split("/")[-2]}fold={f}'
        fig = plt.figure(figsize=(5, 5))
        plot_confusion_matrix(
            cm,
            classes=class_names,
            normalize=True,
            title=f"Confusion matrix\n{figname}"
        )
        fig.savefig(figname + ".png")
        plt.close(fig)

        # Row-normalize (per-true-class fractions). A class absent from this
        # fold gets a NaN row instead of a divide-by-zero warning.
        row_totals = cm.sum(axis=1, keepdims=True)
        cm_norm = np.divide(cm, row_totals, out=np.full(cm.shape, np.nan),
                            where=row_totals > 0)
        cm_list.append((figname, cm_norm))

In [6]:
# Per-run/per-fold table of the diagonal entries of each row-normalized
# confusion matrix (the fraction of each true class predicted correctly),
# one column per entry of `class_names`.
names = [run_name for run_name, _ in cm_list]
class_matrix = pd.DataFrame(
    [
        [norm_cm[k, k] for k in range(len(class_names))]
        for _, norm_cm in cm_list
    ],
    columns=class_names,
    index=names,
)
class_matrix.to_csv("class_matrix.csv")
class_matrix

Unnamed: 0,GPD,GRDA,LPD,LRDA,Other,Seizure
opt-adamw-bandhigh-10fold=0,0.735043,0.561644,0.653125,0.530201,0.598628,0.438596
opt-adamw-bandhigh-10fold=1,0.697802,0.719101,0.78,0.395349,0.651403,0.413793
opt-adamw-bandhigh-10fold=2,0.834254,0.545455,0.775591,0.322581,0.584375,0.33871
opt-adamw-bandhigh-10fold=3,0.837838,0.792453,0.731788,0.282051,0.643082,0.462687
opt-adamw-bandhigh-10fold=4,0.829493,0.642857,0.732673,0.378378,0.475262,0.37037
opt-adamw-bandhigh-20fold=0,0.760684,0.547945,0.6875,0.436242,0.622642,0.473684
opt-adamw-bandhigh-20fold=1,0.71978,0.696629,0.812,0.348837,0.5613,0.37931
opt-adamw-bandhigh-20fold=2,0.767956,0.707071,0.748031,0.370968,0.617188,0.306452
opt-adamw-bandhigh-20fold=3,0.804054,0.660377,0.741722,0.384615,0.559748,0.507463
opt-adamw-bandhigh-20fold=4,0.83871,0.559524,0.712871,0.405405,0.587706,0.259259
