In [1]:
import pandas as pd
import os
import matplotlib.pyplot as plt
import numpy as np
import tqdm
import gzip
import json

import matplotlib
import pickle

import torch

from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split, KFold
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, r2_score, mean_squared_error, precision_score, recall_score, roc_curve

import sys
from sepsis_utils.models.autoencoder import *
from sepsis_utils.dataset import *

In [19]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Update these
data_dir = "../data/mimiciv"
validation_paths = ['../data/eicu', '../data/ehrshot']
model_dir = "model_checkpoints"
if not os.path.exists(model_dir): os.mkdir(model_dir)

In [20]:
AS_IS_COLUMNS = []
BINARY_COLUMNS = []
LOG_NORM_COLUMNS = [
    "SpO2", "BUN", "Creatinine", "AST", "ALT", 
    "Total Bili", "Direct Bili", "INR",
    "Input Last 4 h", "Input Last 24 h", "Output Last 4 h", "Output Last 24 h",
    "Dopamine", "Epinephrine", "Norepinephrine", "Vasopressin", "Phenylephrine"
]

COMORBIDITIES = [
    'congestive_heart_failure', 'cardiac_arrhythmias', 'valvular_disease',
    'pulmonary_circulation', 'peripheral_vascular', 'hypertension',
    'paralysis', 'other_neurological', 'chronic_pulmonary',
    'diabetes_uncomplicated', 'diabetes_complicated', 'hypothyroidism',
    'renal_failure', 'liver_disease', 'peptic_ulcer',
    'aids', 'lymphoma', 'metastatic_cancer',
    'solid_tumor', 'rheumatoid_arthritis', 'coagulopathy',
    'obesity', 'weight_loss', 'fluid_electrolyte',
    'blood_loss_anemia', 'deficiency_anemias', 'alcohol_abuse',
    'drug_abuse', 'psychoses', 'depression'
]

EXCLUDE_COLUMNS = []

def format_states(df, scaler=None, all_columns=None):
    global BINARY_COLUMNS, LOG_NORM_COLUMNS, EXCLUDE_COLUMNS
    df = df.rename(columns={c: c.replace("Model:", "") for c in df.columns})
    df = df.assign(timestep=df.timestep.astype('datetime64[us]').astype(int) / 10**6)
    dummies = pd.get_dummies(df, 
                        columns=[c for c in df.columns 
                                if pd.api.types.is_object_dtype(df[c].dtype) 
                                or pd.api.types.is_string_dtype(df[c].dtype) 
                                or isinstance(df[c].dtype, pd.CategoricalDtype)])
    if all_columns is not None:
        dummies = dummies.assign(**{c: np.zeros(len(dummies), dtype=np.uint8)
                                    for c in all_columns if c not in dummies.columns})
    if scaler is None:
        EXCLUDE_COLUMNS = []
        # EXCLUDE_COLUMNS = [c for c in dummies.columns
        #                    if c.endswith("Present")]
        dummies = dummies.drop(columns=EXCLUDE_COLUMNS)
        AS_IS_COLUMNS = ["id", "timestep"]
        BINARY_COLUMNS = ["Invasive Ventilation", "Non-invasive Ventilation",
                 *[c for c in dummies.columns 
                   if "_" in c or 
                   " Present" in c or 
                   c.endswith("Current") or 
                   c.endswith("Ever") or 
                   (c.startswith("Positive") and c.endswith("Culture")) or 
                   c in COMORBIDITIES]]
        
        scaler = DataNormalization(dummies,
                                      as_is_columns=AS_IS_COLUMNS,
                                      binary_columns=BINARY_COLUMNS,
                                      log_norm_columns=LOG_NORM_COLUMNS,
                                      norm_columns=[c for c in dummies.columns if c not in BINARY_COLUMNS + AS_IS_COLUMNS + LOG_NORM_COLUMNS])
        return scaler.transform(dummies).fillna(0).assign(id=dummies['id'], timestep=dummies['timestep']), scaler, dummies.columns
    return scaler.transform(dummies.drop(columns=[c for c in EXCLUDE_COLUMNS if c in dummies.columns])).fillna(0).assign(id=dummies['id'], timestep=dummies['timestep'])
    
inputs_df, normalizer, all_cols = format_states(pd.read_csv(os.path.join(data_dir, "extracted_model_features.csv")))
inputs_df = inputs_df.clip(-10, 10).assign(id=inputs_df['id'], timestep=inputs_df['timestep'])
target_df = pd.read_csv(os.path.join(data_dir, "predictive_targets.csv")).fillna(0)
target_df = target_df.assign(timestep=target_df.timestep.astype('datetime64[us]').astype(int) / 10**6)
# target_scaler = DataNormalization(target_df, as_is_columns=['id', 'timestep', 'VasopressorNeed', 'Mortality'], norm_columns=['DischargeTime', 'SOFA'])
# target_df = target_scaler.transform(target_df.fillna(0))
print("Excluding columns:", EXCLUDE_COLUMNS)

# filter to a maximum of 14 days per patient (84 timesteps), minimum of 12 hours (3 timesteps)
max_steps = 14 * 6
print('max steps:', max_steps)
min_days = 0.5
time_res = 4
inputs_mask = ((inputs_df['timestep'].groupby(inputs_df['id']).transform(lambda g: np.argsort(g)) < max_steps)
              & (pd.Series(np.ones(len(inputs_df)), index=inputs_df.index).groupby(inputs_df['id']).transform("sum") >= min_days * 24 / time_res))

print("Before:", len(inputs_df))
inputs_df = inputs_df[inputs_mask].reset_index(drop=True)
target_df = target_df[inputs_mask].reset_index(drop=True)
print("After:", len(inputs_df))

KeyboardInterrupt: 

In [None]:
# Validation datasets
validation_datasets = []
for path in validation_paths:
    print(path)
    val_X = format_states(pd.read_csv(os.path.join(path, "extracted_model_features.csv")), normalizer, all_cols)
    val_y = pd.read_csv(os.path.join(path, "predictive_targets.csv")).fillna(0)
    print(len(val_X), len(val_y))
    
    val_X = val_X.clip(-10, 10).assign(id=val_X['id'], timestep=val_X['timestep'])
    val_y = val_y.assign(timestep=val_y.timestep.astype('datetime64[us]').astype(int) / 10**6)
    # val_y = target_scaler.transform(val_y.fillna(0))

    # filter to a maximum of 14 days per patient (84 timesteps), minimum of 12 hours (3 timesteps)
    val_mask = ((val_X['timestep'].groupby(val_X['id']).transform(lambda g: np.argsort(g)) < max_steps)
                & (pd.Series(np.ones(len(val_X)), index=val_X.index).groupby(val_X['id']).transform("sum") >= min_days * 24 / time_res))

    print("Before:", len(val_X))
    val_X = val_X[val_mask].reset_index(drop=True)
    val_y = val_y[val_mask].reset_index(drop=True)
    print("After:", len(val_X))
    validation_datasets.append((os.path.basename(path), val_X, val_y))

../data/eicu


  val_X = format_states(pd.read_csv(os.path.join(path, "extracted_model_features.csv")), normalizer, all_cols)


334268 334268
Before: 334268
After: 333766
../data/ehrshot
38712 38712
Before: 38712
After: 38686


In [None]:
weights = []
for feature in inputs_df.columns:
    if feature in ('id', 'timestep'): continue
    uniques, counts = np.unique(inputs_df[feature], return_counts=True)
    if len(uniques) < 10:
        weights.append((
            uniques[1:],
            len(inputs_df) / (len(uniques) * counts)
        ))
    else:
        counts, bin_cutoffs = np.histogram(inputs_df[feature], bins=10)
        weights.append((
            bin_cutoffs[1:-1],
            (len(inputs_df) + len(counts)) / (len(counts) * (counts + 1))
        ))
    print(feature, weights[-1])

Invasive Ventilation (array([0.5]), array([0.77526925, 1.40820169]))
Non-invasive Ventilation (array([0.5]), array([ 0.50509609, 49.55726416]))
ALT Present (array([0.5]), array([0.54931096, 5.56986639]))
AST Present (array([0.5]), array([0.54927838, 5.57321869]))
Arterial BE Present (array([0.5]), array([0.64419483, 2.23376529]))
Arterial pH Present (array([0.5]), array([0.67832359, 1.90194578]))
BUN Present (array([0.5]), array([0.69906248, 1.75588713]))
Bicarbonate Present (array([0.5]), array([0.70027981, 1.74825362]))
CO2 Calc Present (array([0.5]), array([0.64418468, 2.23388744]))
CVP Present (array([0.5]), array([0.58544817, 3.42575027]))
Calcium Present (array([0.5]), array([0.67588764, 1.92136198]))
Cardioversion/Defibrillation Current (array([0.5]), array([5.00274811e-01, 9.10216907e+02]))
Cardioversion/Defibrillation Ever (array([0.5]), array([ 0.50712906, 35.56773021]))
Chloride Present (array([0.5]), array([0.71180372, 1.6803381 ]))
Creatinine Present (array([0.5]), array([

In [None]:
import itertools

parameter_sets = [
    *(dict(zip(['architecture', 'nencoder', 'nhid', 'nembed', 'lr', 'dropout', 'lr_decay', 
                'batch_size', 'mask_prob', 'noise_factor', 'loss_weighting', 'past_value_reconstruction', 'first_value_reconstruction'],
               x))
      for x in itertools.product(
          ['dense'],
          [4],
          [128],
          [32],
          [1e-3],
          [0.1],
          [0.98],
          [32],
          [0, 0.1, 0.2],
          [0, 0.1, 0.2],
          ['none', 'unusualness'],
          [False, True],
          [False, True]
      )),
    *(dict(zip(['architecture', 'nencoder', 'nhid', 'nhead', 'nembed', 'lr', 'dropout', 'lr_decay', 
                'batch_size', 'mask_prob', 'noise_factor', 'loss_weighting', 'past_value_reconstruction', 'first_value_reconstruction'],
               x))
      for x in itertools.product(
          ['transformer'],
          [4],
          [128],
          [8],
          [32],
          [1e-3],
          [0.1],
          [0.98],
          [32],
          [0, 0.1, 0.2],
          [0, 0.1, 0.2],
          ['none', 'unusualness'],
          [False, True],
          [False, True]
      )),
]
len(parameter_sets), parameter_sets[:5]

(144,
 [{'architecture': 'dense',
   'nencoder': 4,
   'nhid': 128,
   'nembed': 32,
   'lr': 0.001,
   'dropout': 0.1,
   'lr_decay': 0.98,
   'batch_size': 32,
   'mask_prob': 0,
   'noise_factor': 0,
   'loss_weighting': 'none',
   'past_value_reconstruction': False,
   'first_value_reconstruction': False},
  {'architecture': 'dense',
   'nencoder': 4,
   'nhid': 128,
   'nembed': 32,
   'lr': 0.001,
   'dropout': 0.1,
   'lr_decay': 0.98,
   'batch_size': 32,
   'mask_prob': 0,
   'noise_factor': 0,
   'loss_weighting': 'none',
   'past_value_reconstruction': False,
   'first_value_reconstruction': True},
  {'architecture': 'dense',
   'nencoder': 4,
   'nhid': 128,
   'nembed': 32,
   'lr': 0.001,
   'dropout': 0.1,
   'lr_decay': 0.98,
   'batch_size': 32,
   'mask_prob': 0,
   'noise_factor': 0,
   'loss_weighting': 'none',
   'past_value_reconstruction': True,
   'first_value_reconstruction': False},
  {'architecture': 'dense',
   'nencoder': 4,
   'nhid': 128,
   'nembed': 32,

In [None]:
task_types = {
    "Mortality": "classification",
    "DischargeTime": "regression",
    "VasopressorNeed": "classification",
    "SOFA": "regression"
}

loss_histories = []
downstream_task_metrics = []
trajectory_ids = inputs_df['id'].unique()
seed = 0
kfold = KFold(n_splits=5, shuffle=True, random_state=seed)
splits = list(kfold.split(trajectory_ids))
    
for i, param_set in enumerate(reversed(parameter_sets)):
    for trial, (train_idx, test_idx) in enumerate(splits):
        print(trajectory_ids[train_idx], trajectory_ids[test_idx])
        
        train_df = inputs_df[inputs_df['id'].isin(trajectory_ids[train_idx])]
        val_df = inputs_df[inputs_df['id'].isin(trajectory_ids[test_idx])]
        model_params = {k: v for k, v in param_set.items() if k not in ('batch_size', 'loss_weighting')}
        if param_set['loss_weighting'] == 'unusualness':
            model_params['train_weights'] = weights
            model_params['val_weights'] = weights
            model_params['test_weights'] = weights
        model_trainer = TimeSeriesAutoencoderTrainer(
            train_df,
            val_df,
            val_df,
            time_col='timestep',
            device=device,
            checkpoint_path=os.path.join(model_dir, f"autoencoder_{i}_trial_{trial}.pt"),
            **model_params
        )
        print(i, sum(p.numel() for p in model_trainer.model.parameters()))
        losses = []
        model_trainer.fit(epochs=25, patience=10, batch_size=param_set['batch_size'], loss_callback=lambda t, v: losses.append((t, v)))
        loss_histories.append(np.array(losses))
        with open(os.path.join(model_dir, "losses.pkl"), "wb") as file:
            pickle.dump(loss_histories, file)
        
        task_inputs = model_trainer.encode(model_trainer.val_dataset)
        # Now evaluate on the validation datasets
        for val_name, val_X, val_y in validation_datasets:
            true = val_y[task]
            val_inputs = model_trainer.encode(model_trainer.make_dataset(val_X, id_col='id', time_col='timestep'))

            for task, task_type in task_types.items():
                print(task)
                task_outputs = target_df.loc[target_df['id'].isin(trajectory_ids[test_idx]), task]
                clf = LogisticRegression() if task_type == 'classification' else LinearRegression()
                clf.fit(task_inputs, task_outputs)
                
                if task_type == "classification":
                    # choose threshold
                    fpr, tpr, thresholds = roc_curve(task_outputs, clf.predict_proba(task_inputs)[:,1])
                    opt_threshold = thresholds[np.argmax(tpr - fpr)]
            
                if task_type == "classification":
                    preds_prob = clf.predict_proba(val_inputs)[:,1]
                    preds = preds_prob >= opt_threshold
                    metrics = {
                        "accuracy": accuracy_score(true, preds),
                        "f1": f1_score(true, preds),
                        "precision": precision_score(true, preds),
                        "recall": recall_score(true, preds),
                        "auroc": roc_auc_score(true, preds_prob),
                    }
                else:
                    preds = clf.predict(val_inputs)
                    # inverse transform
                    metrics = {
                        "r2": r2_score(true, preds),
                        "mse": mean_squared_error(true, preds)
                    }
                downstream_task_metrics.append({
                    **param_set,
                    "dataset": val_name,
                    "task": task,
                    "trial": trial,
                    **metrics
                })
        with open(os.path.join(model_dir, "task_metrics.pkl"), "wb") as file:
            pickle.dump(downstream_task_metrics, file)

[30000646 30000831 30001396 ... 39999172 39999230 39999552] [30000484 30001148 30003202 ... 39998012 39998871 39999301]
0 912960
Epoch 0


7.203375:  29%|██▉       | 407/1399 [01:14<03:02,  5.43it/s]


KeyboardInterrupt: 