In [1]:
%load_ext autoreload
%autoreload 2

import re
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from hurry.filesize import size, si

sns.set_style("whitegrid")

In [196]:
# Styling values for the bar-plot comparison figures: figure size, font sizes,
# marker size and fill transparency.
FIG_PARAMS = {
    "fig_size": (14, 5),
    "label_size": 20,
    "legend_size": 14,
    "legend_title_size": 17,
    "title_size": 26,
    "annot_size": 16,
    "tick_size": 16,
    "marker_size": 2,
    "fill_alpha": 0.15,
}

# The heatmaps share every styling value with the bar plots except the figure size.
FIG_PARAMS_HEATMAP = {**FIG_PARAMS, "fig_size": (12, 8)}


# Axis / legend titles for the categorical columns that are used as `x` or `hue`.
X_NAMES = {
    "representation": "Representation type",
    "model": "Model type",
}

# Axis titles for the metric columns (matplotlib mathtext); the "\n" puts the
# unit on its own line.
Y_NAMES = {
    "fscore": "$F_1$",
    "mae": "$MAE$",
    "inference_mean_time": "Inference time\n[$ms$]",
    "memory_complexity": "Memory complexity\n[$bytes$]",
}

# Display labels for the raw `representation` values. The "\n" wraps each label
# onto two lines (the previous "/n" was rendered literally in the figures).
rep2name = {
    "whole_signal_waveforms": "Whole signal\nwaveforms",
    "whole_signal_features": "Whole signal\nfeatures",
    "agg_beat_waveforms": "Aggregated beat\nwaveforms",
    "agg_beat_features": "Aggregated beat\nfeatures",
}

def get_y_name(y):
    """Return the axis label for metric column `y`.

    `y` is either a bare metric key of ``Y_NAMES`` (e.g. ``"mae"``) or a
    ``"<split>/<metric>"`` pair (e.g. ``"test/fscore"``), in which case the
    capitalized split name is prepended to the metric label.

    Raises ``KeyError`` for a metric missing from ``Y_NAMES`` and ``ValueError``
    when `y` contains more than one ``"/"``.
    """
    if "/" in y:
        split_name, metric_name = y.split("/")
        metric_label = Y_NAMES[metric_name]
        return " ".join((split_name.capitalize(), metric_label))
    return Y_NAMES[y]



# **Load the results CSV into a DataFrame**

In [197]:
# Read the raw results and discard the columns that are not used below.
results = pd.read_csv("results_3.csv").drop(
    columns=["Name", "State", "Created", "Runtime"]
)
# Swap the raw representation ids for their display labels; an id missing
# from `rep2name` raises KeyError.
results["representation"] = results["representation"].map(rep2name.__getitem__)

# **Split the results DataFrame into one DataFrame per dataset**

In [179]:
# One frame per dataset, selected on the `dataset` column.
ptbxl_results = results[results["dataset"] == "ptbxl"]
sleep_edf_results = results[results["dataset"] == "sleep_edf"]
mimic_results = results[results["dataset"] == "mimic"]

In [175]:
def compare_results(data, x="representation", y="test/fscore", hue="model", ax=None):
    """Draw a grouped bar plot of metric `y` per `x` category, coloured by `hue`.

    Parameters
    ----------
    data : DataFrame containing the `x`, `y` and `hue` columns.
    x, hue : categorical column names; each must be a key of ``X_NAMES``.
    y : metric column, either ``"<metric>"`` or ``"<split>/<metric>"``
        (resolved to a label by ``get_y_name``).
    ax : optional matplotlib Axes to draw on. When omitted, a new figure of
        size ``FIG_PARAMS["fig_size"]`` is created.

    Returns
    -------
    The newly created Figure when `ax` is None, otherwise None.
    """
    own_fig = None
    if ax is None:
        own_fig, ax = plt.subplots(figsize=FIG_PARAMS["fig_size"])

    sns.barplot(data=data, x=x, y=y, hue=hue, ax=ax)
    # Place the legend outside the axes, to the right of the plot area.
    ax.legend(
        bbox_to_anchor=(1.02, 1),
        loc="upper left",
        borderaxespad=0,
        fontsize=FIG_PARAMS["legend_size"],
        title=X_NAMES[hue],
        title_fontsize=FIG_PARAMS["legend_title_size"],
    )
    ax.set_xlabel(X_NAMES[x], fontsize=FIG_PARAMS["label_size"])
    ax.set_ylabel(get_y_name(y), fontsize=FIG_PARAMS["label_size"])
    ax.tick_params(axis="both", labelsize=FIG_PARAMS["tick_size"])

    if own_fig is not None:
        # Close only the figure created here so it is not rendered a second
        # time inline. A bare plt.close() would close whatever figure is
        # current, i.e. the caller's figure when `ax` was supplied.
        plt.close(own_fig)
    return own_fig

def get_comparison_figs(results_df, metric="fscore", split="test", prefix=""):
    """Build the two bar-plot views of one metric.

    The plotted column is ``"<split>/<metric>"``, or just ``metric`` when
    `split` is None. Returns a dict with two figures:

    * ``"<prefix>models"`` - representations on the x axis, models as hue;
    * ``"<prefix>reps"``   - models on the x axis, representations as hue.
    """
    y = metric if split is None else f"{split}/{metric}"

    # (key suffix, x column, hue column) for each view, in output order.
    views = (
        ("models", "representation", "model"),
        ("reps", "model", "representation"),
    )
    figs = {}
    for suffix, x_col, hue_col in views:
        figs[f"{prefix}{suffix}"] = compare_results(
            x=x_col, hue=hue_col, data=results_df, y=y
        )
    return figs

def plot_results_heatmap(results_df, x="model", y="representation", metric="fscore", split="test", ax=None):
    """Plot an annotated heatmap of one metric over a (`y`, `x`) grid.

    Parameters
    ----------
    results_df : DataFrame with the `x`, `y` and metric columns. It is pivoted
        with `y` as rows and `x` as columns; pandas ``pivot`` requires every
        (`y`, `x`) pair to be unique.
    x, y : categorical column names; each must be a key of ``X_NAMES``.
    metric, split : the plotted column is ``"<split>/<metric>"``, or just
        ``metric`` when `split` is None.
    ax : optional matplotlib Axes to draw on. When omitted, a new figure is
        created whose height grows with the number of heatmap rows.

    Returns
    -------
    The newly created Figure when `ax` is None, otherwise None.
    """
    z = metric if split is None else f"{split}/{metric}"

    # Keyword arguments are required: passing index/columns/values positionally
    # is deprecated in pandas 1.5 and rejected by pandas 2.0.
    df = results_df.pivot(index=y, columns=x, values=z)
    # Memory sizes are annotated with default formatting; every other metric
    # is shown with 3 significant digits.
    fmt = "" if metric == "memory_complexity" else ".3g"

    own_fig = None
    if ax is None:
        own_fig, ax = plt.subplots(figsize=(12, 2 * len(df)))

    axx = sns.heatmap(
        df,
        cmap="coolwarm",
        annot=True,
        fmt=fmt,
        annot_kws={"fontsize": FIG_PARAMS_HEATMAP["tick_size"]},
        cbar=True,
        ax=ax,
        square=False,
    )
    # Style the colorbar and use the metric label as its title.
    cb = axx.collections[0].colorbar
    cb.ax.tick_params(labelsize=FIG_PARAMS_HEATMAP["legend_size"])
    cb.ax.set_title(get_y_name(z), fontsize=FIG_PARAMS_HEATMAP["legend_title_size"])

    ax.set_xlabel(X_NAMES[x], fontsize=FIG_PARAMS_HEATMAP["label_size"])
    ax.set_ylabel(X_NAMES[y], fontsize=FIG_PARAMS_HEATMAP["label_size"])
    ax.tick_params(axis="x", labelrotation=15)
    ax.tick_params(axis="y", labelrotation=0)
    ax.tick_params(axis="both", labelsize=FIG_PARAMS_HEATMAP["tick_size"])

    if own_fig is not None:
        # Close only the figure created here; a bare plt.close() would close
        # the caller's current figure when `ax` was supplied.
        plt.close(own_fig)
    return own_fig
    
def get_all_heatmaps(results_df, x="model", y="representation", split="test", pred_qual_metric="fscore"):
    """Build the three heatmaps for one dataset's results.

    `x` and `y` select the heatmap columns and rows and are forwarded to
    ``plot_results_heatmap`` (previously they were accepted but ignored).
    The prediction-quality heatmap uses ``"<split>/<pred_qual_metric>"``;
    inference time and memory complexity are read from columns without a
    split prefix.

    Returns a dict with the keys ``"pred_qual_heatmap"``,
    ``"infer_time_heatmap"`` and ``"mem_comp_heatmap"``.
    """
    pred_qual_fig = plot_results_heatmap(
        results_df, x=x, y=y, split=split, metric=pred_qual_metric
    )
    inference_time_fig = plot_results_heatmap(
        results_df, x=x, y=y, split=None, metric="inference_mean_time"
    )
    memory_complexity_fig = plot_results_heatmap(
        results_df, x=x, y=y, split=None, metric="memory_complexity"
    )
    return {
        "pred_qual_heatmap": pred_qual_fig,
        "infer_time_heatmap": inference_time_fig,
        "mem_comp_heatmap": memory_complexity_fig,
    }


def get_all_comparison_figs(results_df, split="test", pred_qual_metric="fscore", prefix=""):
    """Collect every figure for one dataset into a single dict.

    Produces the bar-plot pairs for prediction quality, inference time and
    memory complexity, followed by the three heatmaps. Every key of the
    returned dict is prepended with `prefix`.
    """
    collected = {}
    collected.update(
        get_comparison_figs(results_df, split=split, metric=pred_qual_metric, prefix="pred_qual_")
    )
    collected.update(
        get_comparison_figs(results_df, split=None, metric="inference_mean_time", prefix="infer_time_")
    )
    collected.update(
        get_comparison_figs(results_df, split=None, metric="memory_complexity", prefix="mem_comp_")
    )
    collected.update(
        get_all_heatmaps(results_df, x="model", y="representation", split=split, pred_qual_metric=pred_qual_metric)
    )
    return {f"{prefix}{name}": fig for name, fig in collected.items()}

In [176]:
# Reload the results so this cell does not depend on which earlier cell last
# touched `results`.
results = pd.read_csv("results_3.csv").drop(["Name", "State", "Created", "Runtime"], axis=1)
results['representation'] = results['representation'].apply(lambda rep: rep2name[rep])

# (dataset name, prediction-quality metric) for every dataset that is plotted.
info = [
    ("ptbxl", "fscore"),
    ("sleep_edf", "fscore"),
    ("mimic", "mae"),
]

all_results = {name: {"data": results.query(f"dataset == '{name}'"), "metric": metric} for name, metric in info}

# savefig does not create missing directories, so make sure the target exists.
plots_dir = Path("plots")
plots_dir.mkdir(parents=True, exist_ok=True)

for ds_name, ds_results in all_results.items():
    data = ds_results['data']
    metric = ds_results['metric']
    figs = get_all_comparison_figs(data, split="test", pred_qual_metric=metric, prefix=f"{ds_name}_")
    for fig_name, fig in figs.items():
        fig.savefig(plots_dir / f"{fig_name}.pdf", bbox_inches="tight")

  df = results_df.pivot(y, x, z)
  df = results_df.pivot(y, x, z)
  df = results_df.pivot(y, x, z)


In [None]:
# Build the full figure set for each dataset with its prediction-quality metric.
ptbxl_figs, sleep_edf_figs, mimic_figs = (
    get_all_comparison_figs(frame, split="test", pred_qual_metric=quality_metric)
    for frame, quality_metric in (
        (ptbxl_results, "fscore"),
        (sleep_edf_results, "fscore"),
        (mimic_results, "mae"),
    )
)

In [270]:
# Short representation codes used in the LaTeX tables.
rep2name_for_table = dict(
    whole_signal_waveforms="WSW",
    whole_signal_features="WSF",
    agg_beat_waveforms="ABW",
    agg_beat_features="ABF",
)

In [314]:
# Reload the raw results and shorten the values for the LaTeX tables.
results = pd.read_csv("results_3.csv").drop(
    columns=["Name", "State", "Created", "Runtime"]
)
# Representation ids -> short codes; an unknown id raises KeyError.
results["representation"] = results["representation"].map(rep2name_for_table.__getitem__)
# Only "Decision Tree" is abbreviated; other model names pass through unchanged.
results["model"] = results["model"].map(
    lambda model: "DT" if model == "Decision Tree" else model
)
# Byte counts -> strings produced by hurry.filesize with the SI system.
results["memory_complexity"] = results["memory_complexity"].map(
    lambda byte_size: size(byte_size, system=si)
)
# Drop the metrics that are not reported in the tables, then round to 3 decimals.
df = results.drop(
    columns=["test/auroc", "val/fscore", "val/auroc", "fit_time", "inference_std_time", "val/mae"]
).round(3)

In [315]:
# Give the remaining columns their table headers.
df = df.rename(
    columns={
        "inference_mean_time": "Inference time [$ms$]",
        "memory_complexity": "Memory complexity [$bytes$]",
        "test/fscore": "F1",
        "test/mae": "MAE",
    }
)

In [316]:
# Index by (dataset, model, representation), keep the reported columns in
# table order, and sort the rows by index.
df = (
    df
    .set_index(["dataset", "model", "representation"])
    [["F1", "MAE", "Inference time [$ms$]", "Memory complexity [$bytes$]"]]
    .sort_index()
)

In [317]:
# Show the combined table, indexed by (dataset, model, representation).
df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,F1,MAE,Inference time [$ms$],Memory complexity [$bytes$]
dataset,model,representation,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
mimic,CNN,ABF,,17.166,0.826,3M
mimic,CNN,ABW,,16.963,0.811,3M
mimic,CNN,WSF,,16.005,1.913,3M
mimic,CNN,WSW,,16.147,12.243,3M
mimic,DT,ABF,,23.63,0.064,1M
mimic,DT,ABW,,23.008,0.074,1M
mimic,DT,WSF,,23.035,0.091,1M
mimic,DT,WSW,,24.357,0.073,1M
mimic,LGBM,ABF,,16.743,0.129,290K
mimic,LGBM,ABW,,16.755,0.13,299K


In [318]:
# Slice one table per dataset and drop the metric column that was not selected
# for it: F1 is kept for ptbxl and sleep_edf, MAE for mimic.
ptbxl_df = df.loc["ptbxl"].drop(columns="MAE")
sleep_edf_df = df.loc["sleep_edf"].drop(columns="MAE")
mimic_df = df.loc["mimic"].drop(columns="F1")

In [319]:
def get_latex_table(df):
    """Render `df` as a compact LaTeX ``tabular`` string.

    Starting from ``df.to_latex()``, the output is post-processed so that:

    * every ``NaN`` is replaced by ``-``;
    * escaped dollar signs (``\\$``) are un-escaped, so math-mode column
      headers such as ``[$ms$]`` survive;
    * runs of spaces are collapsed to a single space.
    """
    # The raw string keeps the backslash literal; '\$' is an invalid escape
    # sequence and triggers a warning on current Python versions.
    latex_table = df.to_latex().replace("NaN", "-").replace(r"\$", "$")
    return re.sub(" +", " ", latex_table)

In [320]:
# Render the LaTeX source for each per-dataset table.
ptbxl_table, sleep_edf_table, mimic_table = map(
    get_latex_table, (ptbxl_df, sleep_edf_df, mimic_df)
)

  latex_table = df.to_latex().replace("NaN", "-").replace('\$', '$')


In [321]:
# Show the ptbxl slice of the table (the MAE column has been dropped).
ptbxl_df

Unnamed: 0_level_0,Unnamed: 1_level_0,F1,Inference time [$ms$],Memory complexity [$bytes$]
model,representation,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
CNN,ABF,0.698,2.378,3M
CNN,ABW,0.743,0.754,3M
CNN,WSF,0.592,3.834,3M
CNN,WSW,0.73,2.056,3M
DT,ABF,0.6,0.037,436K
DT,ABW,0.59,0.039,432K
DT,WSF,0.592,0.054,399K
DT,WSW,0.432,0.049,527K
LGBM,ABF,0.765,0.202,1M
LGBM,ABW,0.744,0.218,1M


In [308]:
# Print rather than display, so the LaTeX source appears with real line breaks
# and unescaped backslashes.
print(ptbxl_table)

\begin{tabular}{llrrl}
\toprule
 & & F1 & Inference time [$ms$] & Memory complexity [$bytes$] \\
model & representation & & & \\
\midrule
DT & WSW & 0.432 & 0.049 & 527K \\
CNN & ABF & 0.698 & 2.378 & 3M \\
 & ABW & 0.743 & 0.754 & 3M \\
 & WSF & 0.592 & 3.834 & 3M \\
 & WSW & 0.730 & 2.056 & 3M \\
MLP & ABF & 0.715 & 0.275 & 3M \\
 & ABW & 0.701 & 0.247 & 3M \\
 & WSF & 0.673 & 0.207 & 6M \\
 & WSW & 0.436 & 0.536 & 14M \\
Regression & ABF & 0.690 & 0.037 & 38K \\
 & ABW & 0.663 & 0.041 & 48K \\
 & WSF & 0.615 & 0.044 & 146K \\
 & WSW & 0.407 & 0.055 & 480K \\
DT & ABF & 0.600 & 0.037 & 436K \\
 & ABW & 0.590 & 0.039 & 432K \\
 & WSF & 0.592 & 0.054 & 399K \\
LGBM & ABF & 0.765 & 0.202 & 1M \\
 & ABW & 0.744 & 0.218 & 1M \\
 & WSF & 0.766 & 0.239 & 1M \\
 & WSW & 0.558 & 0.277 & 2M \\
\bottomrule
\end{tabular}

