In [1]:
import os
os.chdir('/home/roobz/Jupyter/afib-detector/src/')

In [2]:
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
import glob
from pathlib import Path
from IPython.display import display_html, Image, display, HTML

from evaluate import get_performance_tables

%load_ext autoreload
%autoreload 2

## ETL Pipeline

To split the data into labeled 10-second samples for 3-fold cross-validation, I extracted each unique occurrence of an annotation and recorded the record it came from, when the occurrence began, and when it ended. I then discarded any occurrence shorter than 30 s (three times the sample length), split each remaining occurrence into three equal-length segments, and randomly assigned the segments one-to-one to the three folds. Finally, each segment was sliced into 10-second windows with 50% overlap, discarding any leftover signal at the end. This resulted in 54,989 samples (22,020 AFIB, 32,969 N) per fold.
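The splitting procedure can be sketched in plain Python as follows. This is a minimal illustration working on sample indices, not the project's actual ETL code: the function names `split_occurrence` and `windows` are hypothetical, and the 250 Hz sampling rate (`FS`) is an assumption, since the text does not state one. Any remainder left when an occurrence does not divide evenly into three is simply dropped.

```python
import random

# Assumed sampling rate in Hz (hypothetical; the text does not specify it).
FS = 250
WIN_S = 10            # window length in seconds
MIN_S = 3 * WIN_S     # occurrences shorter than 30 s are discarded
N_FOLDS = 3


def split_occurrence(start, end, rng):
    """Split one annotated occurrence [start, end) (sample indices) into
    N_FOLDS equal segments and map them one-to-one to folds 1..N_FOLDS at random.

    Returns {fold: (seg_start, seg_end)}, or {} if the occurrence is too short.
    Any remainder from the integer division is dropped from the end.
    """
    length = end - start
    if length < MIN_S * FS:
        return {}
    seg_len = length // N_FOLDS
    folds = list(range(1, N_FOLDS + 1))
    rng.shuffle(folds)  # random one-to-one segment -> fold assignment
    return {
        fold: (start + i * seg_len, start + (i + 1) * seg_len)
        for i, fold in enumerate(folds)
    }


def windows(seg_start, seg_end, win=WIN_S * FS, step=WIN_S * FS // 2):
    """Slice a segment into fixed-length windows with 50% overlap.

    A trailing partial window that does not fit is discarded.
    """
    return [(s, s + win) for s in range(seg_start, seg_end - win + 1, step)]


# Example: a 60 s occurrence yields three 20 s segments, each giving
# three overlapping 10 s windows (at offsets 0 s, 5 s and 10 s).
rng = random.Random(0)
for fold, (s, e) in sorted(split_occurrence(0, 60 * FS, rng).items()):
    print(fold, (s, e), len(windows(s, e)))
```

Because the segments are assigned to folds before windowing, overlapping windows from the same segment never end up in different folds.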

## Models and Training
When considering what type of model to apply to this problem, I immediately turned to Convolutional Neural Networks (CNNs). CNNs have proven to be capable signal classifiers in a variety of other tasks, so they seemed like a natural first choice here. However, CNN architectures vary widely, so this project compares several designs. The only elements common to all of the models are that each takes a 10-second, 2-lead ECG as input and outputs a prediction of Normal Sinus Rhythm (0) or AFib (1).

To establish a baseline, I used two existing models: a 1-D variant of the PyTorch MobileNetV2 implementation, and the model described in Hsieh et al. (2020). I also built a custom CNN of my own (though admittedly I have very little experience with them).

To train the models, I held out one fold for validation and trained on the remaining two folds, repeating this for every combination of architecture and held-out fold (3 architectures × 3 folds = 9 trained models).
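The fold rotation can be written as a small loop over architectures and held-out folds. This is a sketch only: `train_fn` and `eval_fn` are hypothetical placeholders for whatever training and validation routines the project actually uses.

```python
from typing import Any, Callable, Dict, Iterator, List, Tuple

FOLDS = [1, 2, 3]
ARCHITECTURES = ["Custom", "Hsieh", "MobileNetV2"]


def cv_splits(folds: List[int]) -> Iterator[Tuple[int, List[int]]]:
    """Yield (held_out_fold, training_folds) for each fold in turn."""
    for held_out in folds:
        yield held_out, [f for f in folds if f != held_out]


def run_cv(
    train_fn: Callable[[str, List[int]], Any],
    eval_fn: Callable[[Any, int], Any],
) -> Dict[Tuple[str, int], Any]:
    """Train one model per (architecture, held-out fold) and validate it
    on the held-out fold only. Returns results keyed by (arch, fold)."""
    results = {}
    for arch in ARCHITECTURES:
        for held_out, train_folds in cv_splits(FOLDS):
            model = train_fn(arch, train_folds)  # never sees the held-out fold
            results[(arch, held_out)] = eval_fn(model, held_out)
    return results
```

Keying results by `(architecture, fold)` maps directly onto a per-fold results table with one column group per architecture.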

### Ensembles
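Lastly, for each architecture I combined its three fold models into an ensemble by averaging their outputs, without any further training. I then evaluated each ensemble on the entire dataset to determine whether averaging outputs is an effective way to merge the models. Because each fold model was trained on two of the three folds, every sample was seen during training by two of the three ensemble members, so these ensemble scores are not a held-out estimate and are likely optimistic.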
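Averaging outputs without retraining amounts to taking the mean of the members' predictions sample by sample. The sketch below assumes each fold model emits a probability of AFib (e.g. after a sigmoid) and uses a 0.5 decision threshold; both are assumptions, since the text does not say whether probabilities or raw logits were averaged, and `ensemble_predict` is a hypothetical helper.

```python
import numpy as np


def ensemble_predict(member_probs, threshold=0.5):
    """Average per-model AFib probabilities and threshold the mean.

    member_probs: array-like of shape (n_models, n_samples); row i holds
    fold model i's predicted P(AFib) for the same samples.
    Returns (mean_probs, labels) with labels 1 = AFib, 0 = normal.
    """
    probs = np.asarray(member_probs, dtype=float)
    mean_probs = probs.mean(axis=0)          # simple unweighted average
    labels = (mean_probs >= threshold).astype(int)
    return mean_probs, labels


# Three fold models scoring the same three ECG windows (illustrative numbers).
mean_probs, labels = ensemble_predict([
    [0.9, 0.2, 0.6],
    [0.8, 0.1, 0.3],
    [0.7, 0.4, 0.3],
])
print(mean_probs, labels)
```

The averaged probabilities suit `roc_auc_score`, while the thresholded labels suit `accuracy_score` and `f1_score`. Averaging logits instead of probabilities would generally produce slightly different ensemble decisions.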

In [3]:
model_overview, model_agg, merged_model_metrics, metrics_agg = get_performance_tables()

model_overview_styler = (
    model_overview.style.format("{:.4f}")
#     .set_table_attributes(
#         "style='display:inline; margin-right:20px; margin-left: 5%; vertical-align: middle;'"
#     )
    .set_caption('Individual Model Performance')
)


display(HTML("""
<style>
.level0 {
    text-align: center !important;
}

#ind_model_perf {
    margin: auto !important;
    display: flex;
    justify-content: center;
}
</style>
"""))

model_overview_html = f"<div id='ind_model_perf'>{model_overview_styler._repr_html_()}</div>"

display_html(model_overview_html, raw=True)

**Individual Model Performance**

| Fold | Custom AUC | Custom Accuracy | Custom F1 Score | Hsieh AUC | Hsieh Accuracy | Hsieh F1 Score | MobileNetV2 AUC | MobileNetV2 Accuracy | MobileNetV2 F1 Score |
|---|---|---|---|---|---|---|---|---|---|
| 1 | 0.9997 | 0.9943 | 0.9929 | 0.9997 | 0.9949 | 0.9937 | 0.9997 | 0.9959 | 0.9949 |
| 2 | 0.9993 | 0.9930 | 0.9913 | 0.9997 | 0.9952 | 0.9940 | 0.9997 | 0.9951 | 0.9939 |
| 3 | 0.9989 | 0.9904 | 0.9879 | 0.9998 | 0.9943 | 0.9930 | 0.9998 | 0.9927 | 0.9909 |


In [4]:
model_agg_styler = (
    model_agg.style.format("{:.4f}")
    .set_table_attributes(
        "style='margin:10px !important;'"
    )
    .set_caption('Averaged Performance')
)
merged_model_metrics_styler = (
    merged_model_metrics.style.format("{:.4f}")
    .set_table_attributes(
        "style='margin:10px !important;'"
    )
    .set_caption('Averaged Ensemble Performance')
)

perf_html = model_agg_styler._repr_html_()+merged_model_metrics_styler._repr_html_()
perf_html = (
    '''
    <style>
    #perf_table_wrap {
        display: flex !important;
        justify-content: center;
        margin: auto !important;
        align-items: center !important;
        padding: 10px;
    }
    </style>

    <div id='perf_table_wrap'>
    ''' +
    perf_html +
    '</div>'
)

display_html(perf_html, raw=True)

**Averaged Performance**

| Model | Accuracy | F1 Score | AUC |
|---|---|---|---|
| Custom | 0.9925 | 0.9907 | 0.9993 |
| Hsieh | 0.9948 | 0.9936 | 0.9997 |
| MobileNetV2 | 0.9946 | 0.9932 | 0.9998 |

**Averaged Ensemble Performance**

| Model | Accuracy | F1 Score | AUC |
|---|---|---|---|
| Custom | 0.9938 | 0.9922 | 0.9994 |
| Hsieh | 0.9956 | 0.9945 | 0.9998 |
| MobileNetV2 | 0.9963 | 0.9954 | 0.9999 |
