In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
from common import *

In [3]:
# Metric names kept from each run's dev_evaluate.json; every other entry is dropped.
keys = [
    'roc_auc',
    'pr_auc',
]

In [4]:
def filter_results(yr_df, pattern) :
    """Return only the rows of ``yr_df`` whose index label matches ``pattern``.

    ``pattern`` is a regular expression that is *searched* for in each row
    label (``re.search`` semantics), so it matches anywhere in the name unless
    anchored with ``^``. Characters such as ``+`` are regex metacharacters and
    must be escaped (``\\+``) to be matched literally.
    """
    row_matches = yr_df.filter(axis='index', regex=pattern)
    return row_matches

def get_yr_df_for_surgery(surgery_type, years=(1, 2, 3, 0.5, 0.25),
                          models=('baselines', 'Basic', 'Attention'),
                          metrics=None, output_dir='outputs', show=True) :
    """Collect dev-set metrics for every experiment of one surgery type.

    For each horizon ``yr`` in ``years`` and each model family in ``models``,
    every entry of ``<output_dir>/<surgery_type>_<yr>/<model>`` is treated as
    one experiment. ``get_latest_model`` maps the experiment directory to the
    run directory to read (or ``None`` to skip it); that run's
    ``dev_evaluate.json`` is loaded and only the requested metrics are kept.

    Parameters
    ----------
    surgery_type : str
        Prefix of the output directories, e.g. ``'HipSurgery'``.
    years : iterable of numbers
        Horizons to load; each needs a ``<surgery_type>_<yr>`` directory.
    models : iterable of str
        Model-family sub-directories scanned under each horizon directory.
    metrics : collection of str, optional
        Metric names kept from each ``dev_evaluate.json``; defaults to the
        notebook-level ``keys``.
    output_dir : str
        Root directory holding the experiment outputs.
    show : bool
        If True (the default), render the resulting table as HTML.

    Returns
    -------
    pandas.DataFrame
        One row per experiment name; columns are a (metric, year) MultiIndex
        sorted by metric then year; values rounded to 3 decimals.

    Raises
    ------
    ValueError
        If ``years`` or ``models`` is empty.
    FileNotFoundError
        If an expected ``<surgery_type>_<yr>/<model>`` directory is missing.
    """
    if metrics is None :
        metrics = keys
    # Materialise once: ``models`` is re-scanned for every year, so a one-shot
    # iterable would otherwise silently produce nothing after the first year.
    years = list(years)
    models = list(models)
    if not years or not models :
        raise ValueError('years and models must each be non-empty')

    per_year = {}
    for yr in years :
        model_frames = []
        for model in models :
            model_dir = os.path.join(output_dir, '{}_{}'.format(surgery_type, yr), model)
            runs = {}
            for exp_name in sorted(os.listdir(model_dir)) :
                run_dir = get_latest_model(os.path.join(model_dir, exp_name))
                if run_dir is None :
                    continue
                # Use a context manager so the file handle is closed promptly.
                with open(os.path.join(run_dir, 'dev_evaluate.json')) as fh :
                    scores = json.load(fh)
                runs[exp_name] = {k: v for k, v in scores.items() if k in metrics}
            # Rows = experiments, columns = metrics.
            model_frames.append(pd.DataFrame(runs).transpose())
        per_year[yr] = pd.concat(model_frames, axis=0)

    table = pd.concat(list(per_year.values()), axis=1, keys=list(per_year.keys())).round(3)
    # Put the metric on the outer column level so each metric's horizons sit together.
    table.columns = table.columns.swaplevel(0, 1)
    table = table.sort_index(axis=1, level=0)
    if show :
        display(HTML(table.to_html()))
    return table

Hip Surgery
===========

In [6]:
# NOTE(review): reads `hip_yr_df`, which is only assigned in the next cell
# (In [5]); on Restart & Run All this raises NameError — move this cell below it.
# Highest value in each (metric, horizon) column across all models.
hip_yr_df.max()

pr_auc   0.25    0.579
         0.50    0.556
         1.00    0.495
         2.00    0.501
         3.00    0.536
roc_auc  0.25    0.883
         0.50    0.895
         1.00    0.868
         2.00    0.832
         3.00    0.854
dtype: float64

In [5]:
# Load and display the dev-set metrics for every hip-surgery experiment.
hip_yr_df = get_yr_df_for_surgery(surgery_type='HipSurgery')

Unnamed: 0_level_0,pr_auc,pr_auc,pr_auc,pr_auc,pr_auc,roc_auc,roc_auc,roc_auc,roc_auc,roc_auc
Unnamed: 0_level_1,0.25,0.50,1.00,2.00,3.00,0.25,0.50,1.00,2.00,3.00
LR+BOW+norm=None,0.519,0.463,0.435,0.448,0.415,0.799,0.813,0.78,0.808,0.777
LR+BOW+norm=None+Structured,0.497,0.417,0.46,0.489,0.488,0.804,0.815,0.794,0.832,0.81
LR+BOW+norm=l1,0.1,0.1,0.1,0.1,0.1,0.5,0.5,0.5,0.5,0.5
LR+BOW+norm=l1+Structured,0.221,0.199,0.176,0.195,0.186,0.643,0.66,0.63,0.623,0.614
LR+BOW+norm=l2,0.481,0.491,0.479,0.453,0.453,0.852,0.856,0.83,0.808,0.829
LR+BOW+norm=l2+Structured,0.333,0.357,0.277,0.247,0.237,0.754,0.772,0.73,0.696,0.684
LR+BinaryBOW+norm=None,0.519,0.463,0.435,0.447,0.415,0.799,0.813,0.78,0.808,0.777
LR+BinaryBOW+norm=None+Structured,0.497,0.417,0.46,0.488,0.487,0.804,0.814,0.794,0.832,0.81
LR+BinaryBOW+norm=l1,0.1,0.1,0.1,0.1,0.1,0.5,0.5,0.5,0.5,0.5
LR+BinaryBOW+norm=l1+Structured,0.221,0.199,0.176,0.194,0.186,0.643,0.66,0.63,0.624,0.614


In [6]:
# Logistic-regression rows: names starting with the literal prefix "LR+"
# (`+` escaped, since unescaped it is a regex "one or more" quantifier).
filter_results(hip_yr_df, r'^LR\+')

Unnamed: 0_level_0,pr_auc,pr_auc,pr_auc,pr_auc,pr_auc,roc_auc,roc_auc,roc_auc,roc_auc,roc_auc
Unnamed: 0_level_1,0.25,0.50,1.00,2.00,3.00,0.25,0.50,1.00,2.00,3.00
LR+BOW+norm=None,0.519,0.463,0.435,0.448,0.415,0.799,0.813,0.78,0.808,0.777
LR+BOW+norm=None+Structured,0.497,0.417,0.46,0.489,0.488,0.804,0.815,0.794,0.832,0.81
LR+BOW+norm=l1,0.1,0.1,0.1,0.1,0.1,0.5,0.5,0.5,0.5,0.5
LR+BOW+norm=l1+Structured,0.221,0.199,0.176,0.195,0.186,0.643,0.66,0.63,0.623,0.614
LR+BOW+norm=l2,0.481,0.491,0.479,0.453,0.453,0.852,0.856,0.83,0.808,0.829
LR+BOW+norm=l2+Structured,0.333,0.357,0.277,0.247,0.237,0.754,0.772,0.73,0.696,0.684
LR+BinaryBOW+norm=None,0.519,0.463,0.435,0.447,0.415,0.799,0.813,0.78,0.808,0.777
LR+BinaryBOW+norm=None+Structured,0.497,0.417,0.46,0.488,0.487,0.804,0.814,0.794,0.832,0.81
LR+BinaryBOW+norm=l1,0.1,0.1,0.1,0.1,0.1,0.5,0.5,0.5,0.5,0.5
LR+BinaryBOW+norm=l1+Structured,0.221,0.199,0.176,0.194,0.186,0.643,0.66,0.63,0.624,0.614


In [7]:
# LSTM-encoder rows: anchored so only names starting with "LSTM" match.
filter_results(hip_yr_df, r'^LSTM')

Unnamed: 0_level_0,pr_auc,pr_auc,pr_auc,pr_auc,pr_auc,roc_auc,roc_auc,roc_auc,roc_auc,roc_auc
Unnamed: 0_level_1,0.25,0.50,1.00,2.00,3.00,0.25,0.50,1.00,2.00,3.00
LSTM(hs=128),0.203,0.17,0.186,0.243,0.152,0.702,0.681,0.678,0.658,0.61
LSTM(hs=128)+Structured,0.343,0.307,0.28,0.269,0.249,0.736,0.733,0.718,0.717,0.701
LSTM(hs=128)+Attention(additive)(hs=128),0.429,0.499,0.461,0.359,0.391,0.874,0.877,0.868,0.789,0.825
LSTM(hs=128)+Attention(additive)(hs=128)+Structured,0.461,0.355,0.371,0.381,0.338,0.822,0.809,0.795,0.798,0.819


In [8]:
# Averaging-encoder rows: anchored so only names starting with "Average" match.
filter_results(hip_yr_df, r'^Average')

Unnamed: 0_level_0,pr_auc,pr_auc,pr_auc,pr_auc,pr_auc,roc_auc,roc_auc,roc_auc,roc_auc,roc_auc
Unnamed: 0_level_1,0.25,0.50,1.00,2.00,3.00,0.25,0.50,1.00,2.00,3.00
Average(hs=256),0.505,0.45,0.451,0.41,0.469,0.852,0.835,0.838,0.819,0.82
Average(hs=256)+Structured,0.382,0.335,0.358,0.357,0.304,0.786,0.76,0.756,0.758,0.76
Average(hs=256)+Attention(additive)(hs=128),0.441,0.505,0.456,0.397,0.408,0.855,0.895,0.844,0.791,0.846
Average(hs=256)+Attention(additive)(hs=128)+Structured,0.471,0.43,0.458,0.342,0.324,0.819,0.824,0.811,0.8,0.822


In [9]:
# CNN-encoder rows: anchored so only names starting with "CNN" match.
filter_results(hip_yr_df, r'^CNN')

Unnamed: 0_level_0,pr_auc,pr_auc,pr_auc,pr_auc,pr_auc,roc_auc,roc_auc,roc_auc,roc_auc,roc_auc
Unnamed: 0_level_1,0.25,0.50,1.00,2.00,3.00,0.25,0.50,1.00,2.00,3.00
"CNN(hs=64)(kernels=3,5,7,9)",0.539,0.486,0.395,0.45,0.428,0.876,0.878,0.845,0.814,0.846
"CNN(hs=64)(kernels=3,5,7,9)+Structured",0.538,0.501,0.429,0.467,0.401,0.883,0.887,0.854,0.83,0.854
"CNN(hs=64)(kernels=3,5,7,9)+Attention(additive)(hs=128)",0.468,0.504,0.495,0.425,0.434,0.852,0.875,0.839,0.819,0.847
"CNN(hs=64)(kernels=3,5,7,9)+Attention(additive)(hs=128)+Structured",0.487,0.468,0.481,0.421,0.365,0.847,0.84,0.841,0.809,0.829


Knee Surgery
============

In [9]:
# NOTE(review): reads `knee_yr_df`, which is only assigned in the next cell
# (In [7]); on Restart & Run All this raises NameError — move this cell below it.
# Highest value in each (metric, horizon) column across all models.
knee_yr_df.max()

pr_auc   0.25    0.753
         0.50    0.759
         1.00    0.685
         2.00    0.637
         3.00    0.603
roc_auc  0.25    0.933
         0.50    0.947
         1.00    0.922
         2.00    0.915
         3.00    0.876
dtype: float64

In [7]:
# Load and display the dev-set metrics for every knee-surgery experiment.
knee_yr_df = get_yr_df_for_surgery(surgery_type='KneeSurgery')

Unnamed: 0_level_0,pr_auc,pr_auc,pr_auc,pr_auc,pr_auc,roc_auc,roc_auc,roc_auc,roc_auc,roc_auc
Unnamed: 0_level_1,0.25,0.50,1.00,2.00,3.00,0.25,0.50,1.00,2.00,3.00
LR+BOW+norm=None,0.598,0.665,0.638,0.533,0.441,0.885,0.896,0.889,0.852,0.827
LR+BOW+norm=None+Structured,0.61,0.609,0.623,0.573,0.482,0.893,0.883,0.873,0.863,0.833
LR+BOW+norm=l1,0.136,0.121,0.121,0.121,0.121,0.523,0.5,0.5,0.5,0.5
LR+BOW+norm=l1+Structured,0.257,0.338,0.346,0.328,0.321,0.658,0.689,0.678,0.668,0.662
LR+BOW+norm=l2,0.753,0.726,0.679,0.637,0.581,0.933,0.925,0.911,0.912,0.866
LR+BOW+norm=l2+Structured,0.591,0.641,0.586,0.531,0.496,0.874,0.878,0.847,0.812,0.786
LR+BinaryBOW+norm=None,0.598,0.665,0.638,0.533,0.441,0.885,0.896,0.889,0.852,0.827
LR+BinaryBOW+norm=None+Structured,0.61,0.609,0.622,0.574,0.482,0.893,0.883,0.873,0.863,0.833
LR+BinaryBOW+norm=l1,0.136,0.121,0.121,0.121,0.121,0.523,0.5,0.5,0.5,0.5
LR+BinaryBOW+norm=l1+Structured,0.257,0.338,0.346,0.328,0.321,0.658,0.689,0.678,0.668,0.662


In [11]:
# LSTM-encoder rows: anchored so only names starting with "LSTM" match.
filter_results(knee_yr_df, r'^LSTM')

Unnamed: 0_level_0,pr_auc,pr_auc,pr_auc,pr_auc,pr_auc,roc_auc,roc_auc,roc_auc,roc_auc,roc_auc
Unnamed: 0_level_1,0.25,0.50,1.00,2.00,3.00,0.25,0.50,1.00,2.00,3.00
LSTM(hs=128),0.216,0.246,0.23,0.208,0.154,0.662,0.657,0.654,0.642,0.572
LSTM(hs=128)+Structured,0.379,0.378,0.391,0.338,0.328,0.722,0.718,0.714,0.687,0.672
LSTM(hs=128)+Attention(additive)(hs=128),0.619,0.687,0.634,0.558,0.545,0.905,0.92,0.889,0.891,0.859
LSTM(hs=128)+Attention(additive)(hs=128)+Structured,0.659,0.718,0.632,0.612,0.56,0.885,0.92,0.888,0.868,0.848


In [12]:
# Averaging-encoder rows: anchored so only names starting with "Average" match.
filter_results(knee_yr_df, r'^Average')

Unnamed: 0_level_0,pr_auc,pr_auc,pr_auc,pr_auc,pr_auc,roc_auc,roc_auc,roc_auc,roc_auc,roc_auc
Unnamed: 0_level_1,0.25,0.50,1.00,2.00,3.00,0.25,0.50,1.00,2.00,3.00
Average(hs=256),0.653,0.7,0.684,0.619,0.603,0.927,0.947,0.922,0.893,0.876
Average(hs=256)+Structured,0.492,0.563,0.503,0.499,0.398,0.811,0.818,0.782,0.767,0.74
Average(hs=256)+Attention(additive)(hs=128),0.62,0.673,0.649,0.582,0.507,0.895,0.907,0.898,0.911,0.876
Average(hs=256)+Attention(additive)(hs=128)+Structured,0.59,0.621,0.657,0.613,0.553,0.892,0.908,0.896,0.855,0.846


In [13]:
# CNN-encoder rows: anchored so only names starting with "CNN" match.
filter_results(knee_yr_df, r'^CNN')

Unnamed: 0_level_0,pr_auc,pr_auc,pr_auc,pr_auc,pr_auc,roc_auc,roc_auc,roc_auc,roc_auc,roc_auc
Unnamed: 0_level_1,0.25,0.50,1.00,2.00,3.00,0.25,0.50,1.00,2.00,3.00
"CNN(hs=64)(kernels=3,5,7,9)",0.704,0.683,0.583,0.596,0.525,0.917,0.908,0.893,0.906,0.876
"CNN(hs=64)(kernels=3,5,7,9)+Structured",0.666,0.678,0.671,0.615,0.583,0.927,0.913,0.915,0.886,0.866
"CNN(hs=64)(kernels=3,5,7,9)+Attention(additive)(hs=128)",0.638,0.668,0.563,0.587,0.526,0.916,0.903,0.91,0.897,0.874
"CNN(hs=64)(kernels=3,5,7,9)+Attention(additive)(hs=128)+Structured",0.694,0.686,0.621,0.616,0.597,0.898,0.914,0.882,0.878,0.855


In [14]:
# Logistic-regression rows: names starting with the literal prefix "LR+"
# (`+` escaped, since unescaped it is a regex "one or more" quantifier).
filter_results(knee_yr_df, r'^LR\+')

Unnamed: 0_level_0,pr_auc,pr_auc,pr_auc,pr_auc,pr_auc,roc_auc,roc_auc,roc_auc,roc_auc,roc_auc
Unnamed: 0_level_1,0.25,0.50,1.00,2.00,3.00,0.25,0.50,1.00,2.00,3.00
LR+BOW+norm=None,0.598,0.665,0.638,0.533,0.441,0.885,0.896,0.889,0.852,0.827
LR+BOW+norm=None+Structured,0.61,0.609,0.623,0.573,0.482,0.893,0.883,0.873,0.863,0.833
LR+BOW+norm=l1,0.136,0.121,0.121,0.121,0.121,0.523,0.5,0.5,0.5,0.5
LR+BOW+norm=l1+Structured,0.257,0.338,0.346,0.328,0.321,0.658,0.689,0.678,0.668,0.662
LR+BOW+norm=l2,0.753,0.726,0.679,0.637,0.581,0.933,0.925,0.911,0.912,0.866
LR+BOW+norm=l2+Structured,0.591,0.641,0.586,0.531,0.496,0.874,0.878,0.847,0.812,0.786
LR+BinaryBOW+norm=None,0.598,0.665,0.638,0.533,0.441,0.885,0.896,0.889,0.852,0.827
LR+BinaryBOW+norm=None+Structured,0.61,0.609,0.622,0.574,0.482,0.893,0.883,0.873,0.863,0.833
LR+BinaryBOW+norm=l1,0.136,0.121,0.121,0.121,0.121,0.523,0.5,0.5,0.5,0.5
LR+BinaryBOW+norm=l1+Structured,0.257,0.338,0.346,0.328,0.321,0.658,0.689,0.678,0.668,0.662
