In [1]:
import pandas as pd
import os
import numpy as np
import tsai
import torch

from sklearn.model_selection import RepeatedStratifiedKFold, StratifiedKFold, StratifiedGroupKFold, StratifiedShuffleSplit, GroupShuffleSplit
from tsai.all import *
from fastai.metrics import RocAucBinary
from sklearn.metrics import balanced_accuracy_score, roc_auc_score
from fastai.callback.tracker import ReduceLROnPlateau, EarlyStoppingCallback

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.metrics import balanced_accuracy_score, roc_auc_score
from fastai.callback.core import Callback
from fastai.learner import Learner
from fastai.losses import CrossEntropyLossFlat
from fastai.metrics import BalancedAccuracy, RocAucBinary
from sklearn.calibration import CalibratedClassifierCV

In [2]:
# Sanity check: show whether PyTorch can see a CUDA device in this kernel.
torch.cuda.is_available()

True

In [3]:
# Index arrays (plus the matching case ids) that define the training subset.
# NOTE: allow_pickle=True executes pickled payloads; only load trusted files.
data = np.load("../data/mimic_train_ids.npz", allow_pickle=True)
train, train_case_ids = data["ids"], data["case_ids"]

# Same layout for the held-out test subset.
data = np.load("../data/mimic_test_ids.npz", allow_pickle=True)
test, test_case_ids = data["ids"], data["case_ids"]

In [4]:
# ECG array, the two label vectors and the per-record subject ids, all read
# from a single archive in the order X, y_cag, y_revasc, subject_id.
data = np.load("../data/mimic_ecgs_and_labels_age_gender.npz")
X, y_cag, y_revasc, subject_id = (
    data[key][:] for key in ("X", "y_cag", "y_revasc", "subject_id")
)

In [5]:
# Group records by subject for the grouped splits; y_revasc is the label used below.
groups, y = subject_id, y_revasc

In [6]:
class TSC(BaseEstimator, ClassifierMixin):
    """scikit-learn style wrapper around a tsai ``TSClassifier`` (InceptionTime).

    Training data passed to ``fit`` / ``fine_tune`` is split internally into a
    training and a validation part using the first fold of an 8-fold
    ``StratifiedGroupKFold``; the validation part drives early stopping and
    model checkpointing.

    Parameters
    ----------
    random_state : int, default=42
        Seed used for the internal train/validation split.
    epochs : int, default=100
        Maximum number of epochs passed to ``Learner.fit``.
    """

    # Settings shared by ``fit`` and ``fine_tune``.
    _N_SPLITS = 8
    _EARLY_STOPPING_PATIENCE = 5
    _BATCH_SIZE = 64

    def __init__(self, random_state=42, epochs=100):
        self.model = None
        self.classes_ = np.array([0, 1])
        self.random_state = random_state
        self.epochs = epochs
        # NOTE(review): these transform objects are created once per TSC
        # instance and reused by every learner it builds -- confirm that no
        # fitted statistics are carried over when fit/fine_tune are chained.
        self.batch_tfms = [TSStandardize(by_sample=False)]
        self.tfms = [None, TSClassification()]

    @staticmethod
    def _first_split(X, y, groups, seed):
        """Return ``(train_idx, valid_idx)`` of the first stratified group fold."""
        splitter = StratifiedGroupKFold(
            n_splits=TSC._N_SPLITS, shuffle=True, random_state=seed
        )
        return next(splitter.split(X, y, groups=groups))

    def _build_learner(self, X, y, splits, lr, extra_cbs=(), **learner_kwargs):
        """Create the InceptionTime learner with the settings common to all training modes."""
        callbacks = [
            EarlyStoppingCallback(
                min_delta=0.0, patience=self._EARLY_STOPPING_PATIENCE, comp=np.less
            ),
            SaveModel(),
            *extra_cbs,
        ]
        return TSClassifier(
            X, y,
            arch="InceptionTime",
            tfms=self.tfms,
            metrics=[RocAucBinary()],
            train_metrics=True,
            splits=splits,
            bs=self._BATCH_SIZE,
            cbs=callbacks,
            # For a proper balanced accuracy, a class-weighted loss can be passed, e.g.
            # loss_func=CrossEntropyLossFlat(reduction="mean", weight=torch.Tensor((dist[1]/dist.sum(), dist[0]/dist.sum())).to("cuda")),
            # or loss_func=FocalLossFlat(),
            batch_tfms=self.batch_tfms,
            lr=lr,
            verbose=True,
            **learner_kwargs,
        )

    def _require_model(self):
        """Raise a descriptive error when no learner has been trained or loaded."""
        if self.model is None:
            # NotFittedError derives from both ValueError and AttributeError, so
            # callers that caught the former "'NoneType' has no attribute ..."
            # AttributeError keep working.
            from sklearn.exceptions import NotFittedError

            raise NotFittedError(
                "This TSC instance has no model yet. Call 'fit', 'fine_tune' or 'load' first."
            )

    def fit(self, X, y, groups):
        """Train a new model from scratch.

        Parameters
        ----------
        X : array-like
            Input samples; passed unchanged to ``TSClassifier``.
        y : array-like
            Class labels, also used to stratify the internal split.
        groups : array-like
            Group label per sample; samples of one group never end up on both
            sides of the internal train/validation split.

        Returns
        -------
        self : TSC
        """
        # The split seed is offset by 42, so it differs from the one used in
        # ``fine_tune`` for the same ``random_state``.
        splits = self._first_split(X, y, groups, seed=42 + self.random_state)
        lr_schedule = ReduceLROnPlateau(
            monitor='valid_loss', comp=np.less, patience=2, factor=10, min_lr=0,
            reset_on_fit=True,
        )
        self.model = self._build_learner(
            X, y, splits, lr=5e-4, extra_cbs=[lr_schedule]
        )
        self.model.fit(self.epochs)
        return self

    def fine_tune(self, X, y, groups, path):
        """Continue training from previously saved weights with a small learning rate.

        Parameters
        ----------
        X, y, groups
            Same meaning as in ``fit``.
        path : str or path-like
            Location of the weights handed to ``TSClassifier`` as ``weights_path``.

        Returns
        -------
        self : TSC
            Returned for consistency with ``fit`` (earlier versions returned ``None``).
        """
        splits = self._first_split(X, y, groups, seed=self.random_state)
        self.model = self._build_learner(
            X, y, splits, lr=1e-5,
            pretrained=True,
            weights_path=path,
            # presumably keeps the pretrained classification head as well --
            # confirm against the tsai documentation.
            exclude_head=False,
        )
        self.model.fit(self.epochs)
        return self

    def predict(self, X, bs=64):
        """Return the predicted class for every sample in ``X`` as an integer array."""
        self._require_model()
        # Element 2 of get_X_preds is taken to be the decoded class predictions
        # (tsai convention: probabilities, targets, decoded predictions).
        decoded = self.model.get_X_preds(X, bs=bs)[2]
        return decoded.astype(int)

    def predict_proba(self, X, bs=64):
        """Return the per-class scores for ``X`` as a NumPy array, one row per sample."""
        self._require_model()
        return self.model.get_X_preds(X, bs=bs)[0].numpy()

    def export(self, path):
        """Serialize the trained learner to ``path``."""
        self._require_model()
        self.model.export(path)

    def load(self, path):
        """Replace the current model with a learner exported to ``path``.

        NOTE: ``load_learner`` unpickles the file, which can execute arbitrary
        code -- only load exports from a trusted source.
        """
        self.model = load_learner(path, cpu=False)

In [7]:
# Class balance of the selected target (NaNs counted too). In the recorded run
# the positive class is rare: 1,084 positives vs 179,602 negatives.
pd.Series(y).value_counts(dropna=False)

0    179602
1      1084
Name: count, dtype: int64

In [8]:
# Train `n_runs` models on the fixed train split and score each one on the
# fixed test split (ROC AUC). Per-run outputs are collected so the runs can
# later be combined into an ensemble.
scores = []
Xt = np.nan_to_num(X)  # replace NaN/inf values before they reach the network
print(Xt.shape)


y_pred_proba = []
y_preds = []
ytests = []
tests = []

n_runs = 1  # 10 for the original ensemble model

# The train/test partition is the same for every run, so the (large)
# fancy-indexed copies are built once instead of once per run.
X_train, X_test = Xt[train], Xt[test]
y_train, y_test = y[train], y[test]
groups_train = groups[train]

for run in range(n_runs):
    # `run` seeds the internal train/validation split, so every run trains
    # and validates on a different partition of the training data.
    base_model = TSC(random_state=run, epochs=100)
    base_model.fit(X_train, y_train, groups=groups_train)

    y_pred = base_model.predict(X_test)
    y_pred_proba.append(base_model.predict_proba(X_test)[:, 1])
    y_preds.append(y_pred)

    ytests.append(y_test)
    tests.append(test)

    print(f"{run} - {n_runs}")
    scores.append(roc_auc_score(y_score=y_pred_proba[-1], y_true=y_test))
    print(scores)
    print(np.mean(scores))

    # to save each run and create the ensemble model
    #base_model.export(f"/media/seaweedmnt/projects/caidiology/er3.0/data_exploration/publication/revasc_model/exported_models/revasc_mimic_chest_pain/{run}.pkl")

(180686, 12, 1000)
arch: InceptionTime(c_in=12 c_out=2 seq_len=1000 arch_config={} kwargs={})


epoch,train_loss,train_roc_auc_score,valid_loss,valid_roc_auc_score,time
0,0.030359,0.78602,0.026074,0.876108,02:43
1,0.035314,0.854046,0.025561,0.876621,02:39
2,0.032045,0.876958,0.025318,0.887592,02:42
3,0.025449,0.882974,0.025152,0.883421,02:42
4,0.029554,0.889673,0.025998,0.889895,02:43
5,0.02522,0.900235,0.025669,0.871141,02:39
6,0.025827,0.918992,0.024751,0.897277,02:42
7,0.027923,0.928434,0.025148,0.899119,02:43
8,0.022111,0.929291,0.025207,0.895497,02:42
9,0.021129,0.933965,0.025127,0.895638,02:39


Epoch 5: reducing lr to 5e-05
Epoch 8: reducing lr to 5e-06
Epoch 10: reducing lr to 5.000000000000001e-07
No improvement since epoch 6: early stopping


0 - 1
[0.901873440655456]
0.901873440655456
