
Adding a new model

Robin van de Water edited this page Nov 14, 2023 · 2 revisions

Prediction models can easily be added to YAIB and integrated into a PyTorch Lightning module, which provides advanced logging and debugging capabilities as well as built-in parallelism. Our interface derives from the BaseModule.

Adding a model consists of three steps:

  1. Add a model by subclassing the existing MLPredictionWrapper or DLPredictionWrapper.
  2. Add a GIN config file to bind hyperparameters.
  3. Execute YAIB using a simple command.

An example RNN model is provided in the repository folder docs/adding_model. Placing the RNN.gin file in configs/prediction_models and the rnn.py file in icu_benchmarks/models lets you run the model end to end. We detail the individual steps below and describe the specifics of each case.

Regular ML

For standard scikit-learn-style models (e.g., LightGBM), you can simply wrap the model with MLPredictionWrapper, with minimal code overhead. Many ML (and some DL) models can be incorporated this way. See the random forest example below.

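import gin
from sklearn.ensemble import RandomForestClassifier
from icu_benchmarks.models.wrappers import MLWrapper


@gin.configurable
class RFClassifier(MLWrapper):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Build the scikit-learn estimator with the hyperparameters bound via gin
        self.model = self.model_args()

    @gin.configurable(module="RFClassifier")
    def model_args(self, *args, **kwargs):
        # gin injects the RandomForestClassifier arguments from the config file
        return RandomForestClassifier(*args, **kwargs)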
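Conceptually, gin injects the bound hyperparameters into model_args, so the wrapper ends up holding a fully configured estimator and simply delegates to it. The framework-free sketch below mimics that binding with functools.partial instead of gin; ToyEstimator, ToyWrapper, and bound_model_args are hypothetical stand-ins, not YAIB or scikit-learn classes.

```python
from functools import partial


class ToyEstimator:
    """Hypothetical stand-in for a scikit-learn estimator such as RandomForestClassifier."""

    def __init__(self, n_estimators=100, max_depth=None):
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.majority_ = None

    def fit(self, X, y):
        # Toy "training": remember the most frequent label
        self.majority_ = max(set(y), key=list(y).count)
        return self

    def predict(self, X):
        return [self.majority_ for _ in X]


# gin would inject these keyword arguments from the .gin file; partial binds them by hand
bound_model_args = partial(ToyEstimator, n_estimators=50, max_depth=4)


class ToyWrapper:
    """Hypothetical wrapper: builds its model from the bound factory and delegates to it."""

    def __init__(self, model_factory):
        self.model = model_factory()

    def fit(self, X, y):
        self.model.fit(X, y)
        return self

    def predict(self, X):
        return self.model.predict(X)


wrapper = ToyWrapper(bound_model_args).fit([[0], [1], [2]], [1, 0, 1])
print(wrapper.model.n_estimators, wrapper.predict([[5], [6]]))
```

Keeping the estimator construction in a separate, configurable factory is what lets the hyperparameters live in the config file rather than in the wrapper code.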

Adding DL models

It is relatively straightforward to add new PyTorch models to YAIB. We first present a standard RNN model that needs no extra components. Then, we show the implementation of the Temporal Fusion Transformer (TFT) model.

Standard RNN-model

DL models are defined by subclassing DLPredictionWrapper, which provides the standard methods needed to train deep learning models. PyTorch Lightning significantly reduces the code overhead.

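import gin
import torch.nn as nn
from icu_benchmarks.models.wrappers import DLPredictionWrapper


@gin.configurable
class RNNet(DLPredictionWrapper):
    """Torch standard RNN model"""

    def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs):
        super().__init__(
            *args, input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, **kwargs
        )
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        # input_size[2] is the number of input features per time step
        self.rnn = nn.RNN(input_size[2], hidden_dim, layer_dim, batch_first=True)
        # Maps the hidden state at each time step to class logits
        self.logit = nn.Linear(hidden_dim, num_classes)

    def init_hidden(self, x):
        # Zero initial hidden state of shape (num_layers, batch, hidden_dim), on x's device and dtype
        h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim)
        return h0

    def forward(self, x):
        h0 = self.init_hidden(x)
        out, hn = self.rnn(x, h0)
        pred = self.logit(out)  # one prediction per time step
        return pred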
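To check the tensor shapes RNNet works with, the sketch below rebuilds its two layers in plain PyTorch (no YAIB imports) and runs them on a random batch. It assumes, as the input_size[2] indexing suggests, that input_size is a (batch, sequence length, features) shape; the concrete sizes are arbitrary.

```python
import torch
import torch.nn as nn

batch, seq_len, n_features = 8, 24, 5  # arbitrary example sizes
hidden_dim, layer_dim, num_classes = 16, 2, 2
input_size = (batch, seq_len, n_features)

# Same layers as in RNNet.__init__
rnn = nn.RNN(input_size[2], hidden_dim, layer_dim, batch_first=True)
logit = nn.Linear(hidden_dim, num_classes)

x = torch.randn(*input_size)
# Same as init_hidden: zeros of shape (num_layers, batch, hidden_dim), even with batch_first=True
h0 = x.new_zeros(layer_dim, x.size(0), hidden_dim)
out, hn = rnn(x, h0)
pred = logit(out)

print(out.shape)   # one hidden state per time step
print(pred.shape)  # one prediction per time step
```

Because the linear layer is applied to the output at every time step, the model emits a prediction per time step rather than a single prediction per sequence.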

Adding a SOTA model: Temporal Fusion Transformer

There are two main questions when you want to add a more complex model:

  • Do you want to manually define the model or use an existing library? This might require adapting the DLPredictionWrapper.
  • Does the model expect the data to be in a certain format? This might require adapting the PredictionDataset.

By adapting, we mean creating a new subclass that inherits most of the existing functionality. This avoids code duplication, keeps the code future-proof, and follows good coding practices.

First, you can add new building blocks for your model to models/layers.py.

from typing import Tuple

import torch.nn as nn
from torch import Tensor

# VariableSelectionNetwork and GRN are assumed to be defined in models/layers.py as well


class StaticCovariateEncoder(nn.Module):
    """
    Network to produce 4 context vectors to enrich static variables
    Variable selection Network --> GRNs
    """

    def __init__(self, num_static_vars, hidden, dropout):
        super().__init__()
        self.vsn = VariableSelectionNetwork(hidden, dropout, num_static_vars)
        self.context_grns = nn.ModuleList([GRN(hidden, hidden, dropout=dropout) for _ in range(4)])

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        variable_ctx, sparse_weights = self.vsn(x)

        # Context vectors:
        # cs: variable selection context
        # ce: enrichment context
        # ch: state_h context
        # cc: state_c context
        cs, ce, ch, cc = [m(variable_ctx) for m in self.context_grns]

        return cs, ce, ch, cc

Note that modules can themselves be composed of other modules: StaticCovariateEncoder is built from a VariableSelectionNetwork and four GRNs.
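For readers unfamiliar with the GRN building block, the following is a simplified gated residual network in the style of the Temporal Fusion Transformer paper (ELU hidden layer, GLU gate, residual connection, layer normalization), composed from smaller modules and then stacked four times in an nn.ModuleList, as StaticCovariateEncoder does. It is a sketch, not YAIB's GRN: the class name SimpleGRN is hypothetical, input and output sizes are assumed equal, and the optional context input is omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleGRN(nn.Module):
    """Simplified gated residual network: LayerNorm(x + GLU(W1 ELU(W2 x)))."""

    def __init__(self, hidden: int, dropout: float = 0.1):
        super().__init__()
        self.fc2 = nn.Linear(hidden, hidden)       # W2: input -> hidden
        self.fc1 = nn.Linear(hidden, hidden)       # W1: hidden -> hidden
        self.dropout = nn.Dropout(dropout)
        self.gate = nn.Linear(hidden, 2 * hidden)  # produces both halves for the GLU
        self.norm = nn.LayerNorm(hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        eta2 = F.elu(self.fc2(x))
        eta1 = self.dropout(self.fc1(eta2))
        gated = F.glu(self.gate(eta1), dim=-1)  # first half * sigmoid(second half)
        return self.norm(x + gated)             # residual connection + normalization


# Composition: four GRNs in a ModuleList, one per context vector
context_grns = nn.ModuleList([SimpleGRN(32) for _ in range(4)])
static_ctx = torch.randn(8, 32)
cs, ce, ch, cc = [m(static_ctx) for m in context_grns]
print(cs.shape)
```

Registering the sub-GRNs in an nn.ModuleList (rather than a plain Python list) ensures their parameters are tracked by the parent module and moved to the right device by Lightning.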

Adapting the DLPredictionWrapper

The next step is to build the model in models/dl_models.py, either from the building blocks defined in layers.py or from modules of an existing library. In this case, we use the PyTorch Forecasting library (https://github.com/jdb78/pytorch-forecasting):

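import gin
import torch.nn as nn
from pytorch_forecasting import TemporalFusionTransformer
from icu_benchmarks.contants import RunMode  # YAIB's RunMode enum (module path may differ by version)
from icu_benchmarks.models.wrappers import DLPredictionWrapper


@gin.configurable
class TFTpytorch(DLPredictionWrapper):

    supported_run_modes = [RunMode.classification, RunMode.regression]

    def __init__(self, dataset, hidden, dropout, n_heads, dropout_att, lr, optimizer, num_classes, *args, **kwargs):
        super().__init__(lr=lr, optimizer=optimizer, *args, **kwargs)
        # Build the TFT from a pytorch-forecasting TimeSeriesDataSet and pass on the tuned hyperparameters
        # (dropout_att has no direct counterpart in pytorch-forecasting's TFT and is unused here)
        self.model = TemporalFusionTransformer.from_dataset(
            dataset, hidden_size=hidden, dropout=dropout, attention_head_size=n_heads
        )
        # 7 = default output_size of the pytorch-forecasting TFT (one output per quantile)
        self.logit = nn.Linear(7, num_classes)

    def forward(self, x):
        out = self.model(x)
        pred = self.logit(out["prediction"])
        return pred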
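The nn.Linear(7, num_classes) head maps the last dimension of out["prediction"] to the number of classes. The sketch below reproduces only that step, with a dummy tensor in place of the pytorch-forecasting output; it assumes the prediction has shape (batch, prediction length, 7), where 7 is the default output_size of pytorch-forecasting's TemporalFusionTransformer (one output per quantile of its default QuantileLoss). The sizes are arbitrary.

```python
import torch
import torch.nn as nn

batch, pred_len, n_quantiles, num_classes = 4, 1, 7, 2  # example sizes
# Dummy stand-in for the output returned by the pytorch-forecasting model
out = {"prediction": torch.randn(batch, pred_len, n_quantiles)}

logit = nn.Linear(n_quantiles, num_classes)  # same role as self.logit in TFTpytorch
pred = logit(out["prediction"])
print(pred.shape)
```

If the TFT is configured with a different loss or output_size, the input dimension of the linear head has to change accordingly.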

Adapting the PredictionDataset

Some models require an adjusted data loader, for example, to support explainability methods. In this case, changes need to be made to data/loader.py so that the data loader returns the data in the format the model expects. This can be done by creating a class that inherits from PredictionDataset and overriding its __getitem__ method.

import gin
from typing import Tuple
from torch import Tensor


@gin.configurable("PredictionDatasetTFT")
class PredictionDatasetTFT(PredictionDataset):  # PredictionDataset is defined in data/loader.py
    def __init__(self, *args, ram_cache: bool = True, **kwargs):
        # Always cache the data in RAM, regardless of the ram_cache argument
        super().__init__(*args, ram_cache=True, **kwargs)

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
        """Function to sample from the data split of choice. Used for TFT.
        The data needs to be given to the model in the following order
        [static categorical, static continuous, known categorical, known continuous, observed categorical, observed continuous, target, id]
        """
        ...  # body not shown on this page

Then, you must check models/wrapper.py, particularly the step_fn method, to ensure the data is correctly transferred to the device.
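The PredictionDatasetTFT docstring fixes the order in which the TFT inputs must be returned. A minimal, framework-free sketch of assembling a sample in that order could look as follows; the group keys and the helper order_tft_inputs are hypothetical and do not reflect YAIB's actual implementation.

```python
from typing import Dict, List, Sequence

# Order required by the TFT, as listed in the PredictionDatasetTFT docstring
TFT_INPUT_ORDER: List[str] = [
    "static_categorical",
    "static_continuous",
    "known_categorical",
    "known_continuous",
    "observed_categorical",
    "observed_continuous",
    "target",
    "id",
]


def order_tft_inputs(groups: Dict[str, Sequence]) -> List[Sequence]:
    """Return the feature groups in TFT order; fail loudly if a group is missing."""
    missing = [name for name in TFT_INPUT_ORDER if name not in groups]
    if missing:
        raise KeyError(f"missing input groups: {missing}")
    return [groups[name] for name in TFT_INPUT_ORDER]


# Build the groups in reverse order to show that the helper restores the required order
sample = {name: [i] for i, name in enumerate(reversed(TFT_INPUT_ORDER))}
print(order_tft_inputs(sample))
```

Raising on a missing group, instead of silently skipping it, keeps a malformed sample from shifting every later input into the wrong slot.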

Adding the model config GIN file

To define the hyperparameters of each model in a standardized manner, we use gin-config. A GIN file binds the parameters used to train the model and defines the hyperparameter choices from which it is optimized. Parameters can be scoped, e.g., optimizer/ for the optimizer (here the Adam optimizer) and model/ for the model. Ranges are specified in round brackets "()", while square brackets "[]" define a set of options from which a value is sampled uniformly at random.

# Hyperparameters for TFT model.

# Common settings for DL models
include "configs/prediction_models/common/DLCommon.gin"

# Model to train
train_common.model = @TFT

# Optimizer params
optimizer/hyperparameter.class_to_tune = @Adam
optimizer/hyperparameter.weight_decay = 1e-6
optimizer/hyperparameter.lr = (1e-5, 3e-4)

# Encoder params
model/hyperparameter.class_to_tune = @TFT
model/hyperparameter.encoder_length = 24
model/hyperparameter.hidden = 256
model/hyperparameter.num_classes = %NUM_CLASSES
model/hyperparameter.dropout = (0.0, 0.4)
model/hyperparameter.dropout_att = (0.0, 0.4)
model/hyperparameter.n_heads = 4
model/hyperparameter.example_length = 25
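To illustrate the bracket semantics, the sketch below draws one configuration from a dictionary of search-space entries: a tuple (low, high) is sampled uniformly from the range, a list is treated as a discrete set of options sampled uniformly, and any other value is kept fixed. This is only an illustration under those assumptions; YAIB's actual tuner may sample ranges differently (for example on a log scale or via Bayesian optimization), and sample_config is a hypothetical helper.

```python
import random
from typing import Any, Dict


def sample_config(space: Dict[str, Any], rng: random.Random) -> Dict[str, Any]:
    """Draw one hypothetical configuration from a gin-style search space."""
    config = {}
    for name, value in space.items():
        if isinstance(value, tuple):    # "(low, high)": continuous range
            low, high = value
            config[name] = rng.uniform(low, high)
        elif isinstance(value, list):   # "[a, b, ...]": uniform discrete choice
            config[name] = rng.choice(value)
        else:                           # fixed value
            config[name] = value
    return config


# Values mirror the TFT config; n_heads as a list is a hypothetical variation
space = {"lr": (1e-5, 3e-4), "dropout": (0.0, 0.4), "hidden": 256, "n_heads": [2, 4, 8]}
print(sample_config(space, random.Random(0)))
```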

Training the model

After these steps, your model should be trainable with the following command:

# Replace the values of -d (cohort dataset), -t (task) and -m (model) with your own.
icu-benchmarks train \
    -d demo_data/mortality24/mimic_demo \
    -n mimic_demo \
    -t BinaryClassification \
    -tn Mortality24 \
    -m TFT \
    -s 2222 \
    -l ../yaib_logs/ \
    --tune