
Example usage

Training a mortality prediction LSTM model

The fundamental experiment configuration of the benchmark contains three basic elements: (1) the dataset, (2) the prediction task, and (3) the model with its (list of) hyperparameters. These elements can be combined freely, and we provide an interface for extending each of them (Datasets & Preprocessing, Prediction Tasks, Models & Hyperparameters). Below is an example of a complete experiment configuration: predicting 24-hour mortality on MIMIC with an LSTM model.

We demonstrate the complete process of training an LSTM model to predict mortality on the MIMIC-III demo dataset with YAIB.

import icu_benchmarks.data.preprocess
import icu_benchmarks.data.loader
import icu_benchmarks.models.wrappers
import icu_benchmarks.models.encoders

# CLASSIFICATION
NUM_CLASSES = 2
HORIZON = 24

train_common.weight = "balanced"

# DEEP LEARNING
DLWrapper.loss = @cross_entropy

# DATASET AND PREPROCESSING
preprocess.file_names = {
    "DYNAMIC": "dyn.parquet",
    "OUTCOME": "outc.parquet",
    "STATIC": "sta.parquet",
}

vars = {
    "GROUP": "stay_id",
    "LABEL": "label",
    "SEQUENCE": "time",
    "DYNAMIC": ["alb", "alp", "alt", "ast", "be", "bicar", "bili",
       "bili_dir", "bnd", "bun", "ca", "cai", "ck", "ckmb", "cl", "crea",
       "crp", "dbp", "fgn", "fio2", "glu", "hgb", "hr", "inr_pt", "k", "lact",
       "lymph", "map", "mch", "mchc", "mcv", "methb", "mg", "na", "neut",
       "o2sat", "pco2", "ph", "phos", "plt", "po2", "ptt", "resp", "sbp",
       "temp", "tnt", "urine", "wbc"],
    "STATIC": ["age", "sex", "height", "weight"],
}

preprocess.vars = %vars
default_preprocessor.vars = %vars

Dataset.vars = %vars

# CROSS VALIDATION

execute_repeated_cv.cv_repetitions = 5
execute_repeated_cv.cv_folds = 5

The snippet above shows the basic task setup for predicting mortality at 24 hours. We first define the dataset files, which are by default split into three parquet files with the given names, covering the three dataset components: dynamic data, outcome definitions, and static data. The vars dictionary then assigns "roles" (grouping key, label, sequence index, dynamic and static features) to concrete column names and is passed to the preprocessing. Finally, we define the number of cross-validation folds and repetitions.
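To make the file layout concrete, the following minimal sketch (assuming pandas and the demo dataset path used in the training command further below) inspects the three files; the column names come from the vars roles defined above:

import pandas as pd

# Demo dataset shipped with the repository (see the training command below).
base = "demo_data/mortality24/mimic_demo"

dyn = pd.read_parquet(f"{base}/dyn.parquet")    # DYNAMIC: time series per stay
outc = pd.read_parquet(f"{base}/outc.parquet")  # OUTCOME: label per stay
sta = pd.read_parquet(f"{base}/sta.parquet")    # STATIC: one row per stay

# "stay_id" (GROUP) links the files; "time" (SEQUENCE) orders the dynamic rows.
print(dyn[["stay_id", "time", "hr", "map"]].head())
print(outc[["stay_id", "label"]].head())
print(sta[["stay_id", "age", "sex"]].head())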

import gin.torch.external_configurables
import icu_benchmarks.models.wrappers
import icu_benchmarks.models.encoders

default_preprocessor.generate_features = False

# Train params
train_common.model = @DLWrapper()

DLWrapper.encoder = @LSTMNet()
DLWrapper.optimizer_fn = @Adam
DLWrapper.train.epochs = 1000
DLWrapper.train.batch_size = 64
DLWrapper.train.patience = 10
DLWrapper.train.min_delta = 1e-4

# Optimizer params
optimizer/hyperparameter.class_to_tune = @Adam
optimizer/hyperparameter.weight_decay = 1e-6
optimizer/hyperparameter.lr = (1e-5, 3e-4)

# Encoder params
model/hyperparameter.class_to_tune = @LSTMNet
model/hyperparameter.num_classes = %NUM_CLASSES
model/hyperparameter.hidden_dim = (32, 256, "log-uniform", 2)
model/hyperparameter.layer_dim = (1, 3)

# Hyperparameter tuning
tune_hyperparameters.scopes = ["model", "optimizer"]
tune_hyperparameters.n_initial_points = 5
tune_hyperparameters.n_calls = 30
tune_hyperparameters.folds_to_tune_on = 2

Above, we see the configuration for the LSTM model. We first disable feature generation from the dynamic data (this is only relevant for traditional machine learning models). We then bind the LSTM model via a gin reference and set the training parameters, followed by the optimizer and encoder parameters. Note that instead of fixed values, we can specify ranges from which the hyperparameter optimizer samples; its settings are found in the bottom cluster of the snippet.
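The tuple values above are search ranges, not fixed settings. The n_initial_points and n_calls arguments match the interface of Bayesian optimizers such as scikit-optimize's gp_minimize; the sketch below shows how the same ranges could be expressed with skopt. The objective here is a hypothetical stand-in for the actual train-and-validate loop, not YAIB's internal code:

from skopt import gp_minimize
from skopt.space import Integer, Real

# The same ranges as in the gin config above, expressed as an skopt space.
space = [
    Real(1e-5, 3e-4, prior="log-uniform", name="lr"),                  # optimizer lr
    Integer(32, 256, prior="log-uniform", base=2, name="hidden_dim"),  # encoder width
    Integer(1, 3, name="layer_dim"),                                   # encoder depth
]

def objective(params):
    lr, hidden_dim, layer_dim = params
    # Hypothetical stand-in: in YAIB this would train on the tuning folds
    # and return the validation loss for this hyperparameter combination.
    return (lr - 1e-4) ** 2 + abs(hidden_dim - 128) / 1e4 + 0.01 * layer_dim

result = gp_minimize(objective, space, n_initial_points=5, n_calls=30, random_state=2222)
print(result.x)  # best (lr, hidden_dim, layer_dim) found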

icu-benchmarks train \
    -d demo_data/mortality24/mimic_demo \
    -n mimic_demo \
    -t BinaryClassification \
    -tn Mortality24 \
    -m LSTM \
    -gc \
    -lc \
    -s 2222 \
    -l ../yaib_logs/ \
    --tune

The above command trains our LSTM model on the mimic_demo dataset (included in our repository): -d points to the data directory, -n names the dataset, -t and -tn select the task type and task name, -m selects the model, -s sets the random seed, and -l sets the log directory; -gc and -lc generate and load the preprocessing cache, and --tune enables the hyperparameter tuning configured above.
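For intuition about what the @LSTMNet binding configures, the sketch below shows a minimal PyTorch LSTM classifier with the same hyperparameters (hidden_dim, layer_dim, num_classes). This is an illustrative stand-in under standard PyTorch assumptions, not YAIB's actual LSTMNet implementation:

import torch
import torch.nn as nn

class LSTMClassifier(nn.Module):
    """Illustrative stand-in for an LSTM encoder with a classification head."""

    def __init__(self, input_dim, hidden_dim=64, layer_dim=2, num_classes=2):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=layer_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):  # x: (batch, time steps, features)
        out, _ = self.lstm(x)
        return self.head(out[:, -1])  # class logits from the last time step

# Example: the 48 dynamic variables above over a 24-hour horizon, batch of 8 stays.
logits = LSTMClassifier(input_dim=48)(torch.randn(8, 24, 48))
print(logits.shape)  # torch.Size([8, 2])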