This notebook creates the visualisations of the final results.

In [2]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [17]:
# Use matplotlib's Qt backend: figures open in separate interactive windows
# rather than inline, so a Qt binding must be installed for this to run.
%matplotlib qt

## Dataframe

In [30]:
# These results come from the final tables produced at the end of the
# deep learning notebooks.
#
# Row order: detection varies slowest, then metric, then model (fastest).
label_rows = [
    (model, metric, detection)
    for detection in ('Spindle', 'SO', 'SO-Spindle Coupling')
    for metric in ('F1-score', 'Precision', 'Recall')
    for model in ('Raw', 'Raw + filtered', 'Raw + filtered + STFT')
]
model_col, metric_col, detection_col = (list(col) for col in zip(*label_rows))

# Each line below holds one (detection, metric) pair; its three values are
# Raw, Raw + filtered, Raw + filtered + STFT.
mean_col = [
    0.82, 0.89, 0.89,   # Spindle, F1-score
    0.82, 0.88, 0.88,   # Spindle, Precision
    0.85, 0.91, 0.92,   # Spindle, Recall
    0.65, 0.74, 0.71,   # SO, F1-score
    0.70, 0.79, 0.76,   # SO, Precision
    0.65, 0.72, 0.70,   # SO, Recall
    0.24, 0.43, 0.31,   # SO-Spindle Coupling, F1-score
    0.55, 0.51, 0.46,   # SO-Spindle Coupling, Precision
    0.18, 0.38, 0.25,   # SO-Spindle Coupling, Recall
]
std_col = [
    0.04, 0.05, 0.06,   # Spindle, F1-score
    0.12, 0.11, 0.12,   # Spindle, Precision
    0.08, 0.06, 0.04,   # Spindle, Recall
    0.13, 0.07, 0.07,   # SO, F1-score
    0.16, 0.12, 0.14,   # SO, Precision
    0.18, 0.13, 0.12,   # SO, Recall
    0.13, 0.07, 0.10,   # SO-Spindle Coupling, F1-score
    0.10, 0.04, 0.06,   # SO-Spindle Coupling, Precision
    0.14, 0.11, 0.11,   # SO-Spindle Coupling, Recall
]

data_five = {
    'Model': model_col,
    'Metric': metric_col,
    'Detection': detection_col,
    'Mean': mean_col,
    'Std': std_col,
}

df_five = pd.DataFrame(data_five)
df_five

Unnamed: 0,Model,Metric,Detection,Mean,Std
0,Raw,F1-score,Spindle,0.82,0.04
1,Raw + filtered,F1-score,Spindle,0.89,0.05
2,Raw + filtered + STFT,F1-score,Spindle,0.89,0.06
3,Raw,Precision,Spindle,0.82,0.12
4,Raw + filtered,Precision,Spindle,0.88,0.11
5,Raw + filtered + STFT,Precision,Spindle,0.88,0.12
6,Raw,Recall,Spindle,0.85,0.08
7,Raw + filtered,Recall,Spindle,0.91,0.06
8,Raw + filtered + STFT,Recall,Spindle,0.92,0.04
9,Raw,F1-score,SO,0.65,0.13


## Three grouped barplots

In [31]:
# Category orders used for the x-axis (models) and the hue (metrics).
metrics = ['F1-score', 'Precision', 'Recall']
models = ['Raw', 'Raw + filtered', 'Raw + filtered + STFT']

# Blue colours, one per metric.
palette = dict(zip(metrics, ['#1f4e79', '#6497b1', '#a9cce3']))

# One panel per detection type; within a panel, models on the x-axis and
# one bar per metric. Seaborn's own error bars are disabled (errorbar=None).
catplot_options = dict(
    kind="bar",
    x="Model",
    y="Mean",
    hue="Metric",
    col="Detection",
    height=4,
    aspect=1.2,
    errorbar=None,
    palette=palette,
    legend_out=True,
    order=models,
    hue_order=metrics,
)
g = sns.catplot(data=df_five, **catplot_options)

# Remove the legend that catplot attached to the grid, otherwise the figure
# ends up with two legends once the figure-level one is added below.
g._legend.remove()

# Figure-level legend, positioned over the SO-Spindle Coupling panel
# (author's note: that panel has a lot of free space).
g.fig.subplots_adjust(right=0.85)
g.fig.legend(title="Metric", loc='center right', bbox_to_anchor=(0.92, 0.75))

# Titles show only the detection name. The y-range extends past 1 so the
# tallest spindle error bars (mean + SD reaches 1.0) are not clipped.
(
    g.set_titles("{col_name}")
     .set_axis_labels("Model", "Score")
     .set(ylim=(0, 1.1))
)

# Overlay the standard deviations from df_five as error bars, centred on the
# bars that catplot has already drawn.
#
# Seaborn draws grouped bars one hue level at a time, so within each panel
# ax.patches is ordered metric-major: all models for metrics[0], then all
# models for metrics[1], and so on. The patch index therefore decomposes as
#     i == metric_idx * len(models) + model_idx
# Decoding it the other way round (model-major) attaches the wrong Std to
# every bar where model_idx != metric_idx.
for detection, ax in g.axes_dict.items():
    # axes_dict maps each value of the `col` variable to its own axes, so the
    # panel and the data subset cannot get out of step.
    detection_data = df_five[df_five['Detection'] == detection]

    # Ignore zero-width patches (e.g. legend proxy artists) that are not bars.
    bars = [patch for patch in ax.patches if patch.get_width() > 0]

    for i, bar in enumerate(bars):
        metric_idx, model_idx = divmod(i, len(models))
        if metric_idx >= len(metrics):
            # any patch beyond the models x metrics grid is not a data bar
            continue

        model = models[model_idx]
        metric = metrics[metric_idx]

        subset = detection_data[
            (detection_data['Model'] == model) & (detection_data['Metric'] == metric)
        ]
        if subset.empty:
            continue

        # Error bar at the horizontal centre of the bar, on top of it.
        ax.errorbar(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height(),
            yerr=subset['Std'].values[0],
            fmt='none',
            c='black',
            capsize=5,
            capthick=1,
            elinewidth=1
        )

plt.tight_layout()
plt.show()


## Only F1-scores

In [48]:
# Restrict the results table to the F1-score rows.
f1_mask = df_five['Metric'].eq('F1-score')
df_f1 = df_five.loc[f1_mask]

# Here the detections go on the x-axis and the models take the hue role
# (the role the metrics played in the previous figure).
detections = ['Spindle', 'SO', 'SO-Spindle Coupling']
models = ['Raw', 'Raw + filtered', 'Raw + filtered + STFT']

# Orange colours, one per model.
palette_models = dict(zip(models, ['#d35400', '#e67e22', '#f0b27a']))

plt.figure(figsize=(8,5))
barplot_options = dict(
    x='Detection',
    y='Mean',
    hue='Model',
    order=detections,
    hue_order=models,
    palette=palette_models,
    errorbar=None,
)
ax = sns.barplot(data=df_f1, **barplot_options)

# Hex -> RGB conversion, needed because the colours of the drawn bars differ
# slightly from the palette and have to be compared numerically.
def hex_to_rgb(h):
    """Convert a hex colour string to an (r, g, b) tuple of floats in [0, 1].

    Leading '#' characters are stripped; the first six hex digits are read
    as three two-digit channels, each divided by 255.
    """
    digits = h.lstrip('#')
    channels = []
    for start in range(0, 6, 2):
        channels.append(int(digits[start:start + 2], 16) / 255)
    return tuple(channels)

# RGB tuple for every model colour in palette_models, keyed by model name.
palette_rgb = {}
for model_name, hex_code in palette_models.items():
    palette_rgb[model_name] = hex_to_rgb(hex_code)

def color_distance(c1, c2):
    """Return the squared Euclidean distance between two colours.

    Only the first three components of ``c1`` are used, so an RGBA tuple
    can be compared against an RGB tuple ``c2``.
    """
    squared_diffs = [(left - right) ** 2 for left, right in zip(c1[:3], c2)]
    return sum(squared_diffs)

# Overlay the standard deviations from df_f1 as error bars on the drawn bars.

# Std for each (detection, model) pair, built once instead of re-filtering
# the DataFrame for every bar. setdefault keeps the first row per pair.
std_by_key = {}
for det_name, model_name, std_value in zip(df_f1['Detection'], df_f1['Model'], df_f1['Std']):
    std_by_key.setdefault((det_name, model_name), std_value)

# x positions of the category ticks; they are the same for every bar, so they
# are read once before the loop.
tick_positions = [tick.get_position()[0] for tick in ax.get_xticklabels()]

for bar in ax.patches:
    y = bar.get_height()
    # skip patches that are not data bars (non-positive height or too narrow)
    if y <= 0 or bar.get_width() < 0.2:
        continue

    x = bar.get_x() + bar.get_width() / 2

    # the detection is the category whose x-tick is closest to the bar centre
    detection_idx = min(
        range(len(tick_positions)),
        key=lambda i: abs(tick_positions[i] - x)
    )
    detection = detections[detection_idx]

    # the model is found by matching the bar colour to the nearest palette
    # colour; an exact comparison would fail because the drawn colours
    # differ slightly from the palette
    bar_color = bar.get_facecolor()
    model = min(palette_rgb.keys(), key=lambda k: color_distance(bar_color, palette_rgb[k]))

    yerr = std_by_key.get((detection, model))
    if yerr is not None:
        ax.errorbar(
            x,
            y,
            yerr=yerr,
            fmt='none',
            c='black',
            capsize=5,
            capthick=1,
            elinewidth=1
        )

# y-range extends past 1 to leave headroom for the error bars
ax.set_ylim(0, 1.1)
ax.set_title("F1-Scores by Event Detection and Model")
ax.set_ylabel("F1-score")
ax.set_xlabel("Event Detection")
plt.legend(title='Model', loc='upper right')
plt.tight_layout()
plt.show()
