In [None]:
import os

import numpy as np
import pandas as pd
import torch, torch.nn as nn, torch.nn.functional as F, torch.utils.data as data
import lightning as L
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
import optuna
import toml

import models
from datasets.loader.datamodule import EhrDataModule
from datasets.loader.load_los_info import get_los_info
from datasets.loader.unpad import unpad_y
from losses import get_simple_loss
from metrics import get_all_metrics

In [None]:
class Pipeline(L.LightningModule):
    """Lightning module pairing an EHR encoder from ``models`` with a task head.

    Args:
        config: Experiment configuration dict. ``__init__`` reads
            ``"hidden_dim"``, ``"input_dim"``, ``"output_dim"``,
            ``"model_name"`` and ``"task"``; the validation hooks also read
            ``"los_info"`` and ``"main_metric"``. The whole dict is forwarded
            to the encoder constructor as keyword arguments.

    Raises:
        ValueError: if ``config["task"]`` is not ``"outcome"``, ``"los"`` or
            ``"multitask"``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.save_hyperparameters()
        self.hidden_dim = config["hidden_dim"]
        self.input_dim = config["input_dim"]
        self.output_dim = config["output_dim"]
        # The encoder class is looked up by name in the project's `models`
        # package and receives the entire config as keyword arguments.
        model_class = getattr(models, config['model_name'])
        self.ehr_encoder = model_class(**config)
        if config["task"] == "outcome":
            # Sigmoid at the end: the head emits values in (0, 1).
            self.head = nn.Sequential(nn.Linear(self.hidden_dim, self.output_dim), nn.Dropout(0.0), nn.Sigmoid())
        elif config["task"] == "los":
            # No final activation: the head emits an unbounded value.
            self.head = nn.Sequential(nn.Linear(self.hidden_dim, self.output_dim), nn.Dropout(0.0))
        elif config["task"] == "multitask":
            self.head = models.heads.MultitaskHead(self.hidden_dim, self.output_dim, drop=0.0)
        else:
            # Fail at construction time instead of leaving `self.head`
            # undefined and crashing later inside `forward`.
            raise ValueError(
                f'Unsupported task {config["task"]!r}; expected "outcome", "los" or "multitask".'
            )

        # Per-batch validation results, collected until on_validation_epoch_end.
        self.validation_step_outputs = []

    def forward(self, x):
        """Encode ``x`` and apply the task head.

        Returns:
            Tuple ``(y_hat, embedding)``: the head output and the encoder
            output it was computed from.
        """
        embedding = self.ehr_encoder(x)
        y_hat = self.head(embedding)
        return y_hat, embedding

    def _shared_step(self, batch):
        """Run the forward pass and loss shared by the train and validation steps."""
        x, y, lens, pid = batch
        y_hat, embedding = self(x)
        # NOTE(review): unpad_y presumably drops padded positions using `lens`
        # -- confirm against datasets.loader.unpad.
        y_hat, y = unpad_y(y_hat, y, lens)
        loss = get_simple_loss(y_hat, y, self.config["task"])
        return loss, y_hat, y

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss, _, _ = self._shared_step(batch)
        self.log("train_loss_step", loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and stash predictions for epoch-level metrics."""
        loss, y_hat, y = self._shared_step(batch)
        self.log("val_loss_step", loss, on_step=True, on_epoch=True)
        outs = {'y_pred': y_hat, 'y_true': y, 'val_loss_step': loss}
        self.validation_step_outputs.append(outs)
        return loss

    def on_validation_epoch_end(self):
        """Aggregate the stored batch outputs, log every metric, and reset the buffer.

        Returns:
            The value of ``config["main_metric"]`` for this validation epoch.
        """
        outputs = self.validation_step_outputs
        y_pred = torch.cat([out['y_pred'] for out in outputs])
        y_true = torch.cat([out['y_true'] for out in outputs])
        # Unweighted mean of the per-batch losses.
        loss = torch.stack([out['val_loss_step'] for out in outputs]).mean()
        self.log("val_loss_epoch", loss, on_step=False, on_epoch=True)
        metrics = get_all_metrics(y_pred, y_true, self.config["task"], self.config["los_info"])
        for k, v in metrics.items():
            self.log(k, v, on_step=False, on_epoch=True)
        main_metric = metrics[self.config["main_metric"]]
        # Empty the buffer so the next validation epoch starts from scratch.
        self.validation_step_outputs.clear()
        return main_metric

    def configure_optimizers(self):
        """Use Adam over all parameters with a fixed learning rate of 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


In [None]:
# --- Experiment selection ---------------------------------------------------
model_name = "GRU"
stage = "tune"
dataset = "tjh"
task = "outcome"  # one of: "outcome", "los", "multitask"
fold = 0

# --- Per-dataset feature dimensions -----------------------------------------
tjh_config = dict(demo_dim=2, lab_dim=73, input_dim=75)
cdsl_config = dict(demo_dim=2, lab_dim=97, input_dim=99)
# An unrecognised dataset name falls back to an empty dict.
dataset_config = {"tjh": tjh_config, "cdsl": cdsl_config}.get(dataset, {})

# --- Training settings ------------------------------------------------------
output_dim = 1
# "mae" for length-of-stay, "auroc" for every other task.
# NOTE(review): objective() monitors "auprc" for outcome/multitask while this
# is "auroc" -- confirm which metric is intended.
main_metric = {"los": "mae"}.get(task, "auroc")
epochs = 100
patience = 10

# --- Assembled config: experiment settings first, dataset dims appended -----
config = dict(
    stage=stage,
    task=task,
    dataset=dataset,
    output_dim=output_dim,
    fold=fold,
    epochs=epochs,
    patience=patience,
    model_name=model_name,
    main_metric=main_metric,
)
config = {**config, **dataset_config}

In [None]:
"""
- tune: hyperparameter search (Only the first fold)
- train: train model with the best hyperparameters (K-fold / repeat with random seeds)
- test: test model on the test set with the saved checkpoints (on best epoch)
"""

def objective(trial: optuna.trial.Trial):
    """Optuna objective: train one ``Pipeline`` with the trial's hyperparameters.

    Side effects: rebinds the module-level ``config`` so that it carries this
    trial's ``hidden_dim`` / ``batch_size`` and the ``los_info`` entry; writes
    checkpoints under ``./checkpoints`` and CSV logs under ``logs``.

    Args:
        trial: The Optuna trial supplying ``hidden_dim`` and ``batch_size``.

    Returns:
        The best monitored validation score as a float ("auprc" for
        outcome/multitask, "mae" for los), or ``None`` when the checkpoint
        callback recorded no score.

    Raises:
        ValueError: if ``config["task"]`` is not "outcome", "multitask" or "los".
    """
    global config
    # config
    trial_config = {
        "hidden_dim": trial.suggest_int("hidden_dim", 16, 1024),
        # Upper bound is 64 so that the grid values used by the study (32, 64)
        # fall inside the declared range; GridSampler otherwise warns that the
        # sampled value is out of range on every trial.
        "batch_size": trial.suggest_int("batch_size", 1, 64),
    }
    config = config | trial_config
    fold_dir = f'datasets/{config["dataset"]}/processed/fold_{config["fold"]}'
    los_config = get_los_info(fold_dir)
    config["los_info"] = los_config

    # data
    dm = EhrDataModule(fold_dir, batch_size=config["batch_size"])

    # callbacks: early stopping and checkpointing watch the same metric
    if config["task"] in ["outcome", "multitask"]:
        monitor, mode = "auprc", "max"
    elif config["task"] == "los":
        monitor, mode = "mae", "min"
    else:
        # Without this branch the callbacks below would be unbound.
        raise ValueError(
            f'Unsupported task {config["task"]!r}; expected "outcome", "multitask" or "los".'
        )
    checkpoint_filename = f'{config["model_name"]}-fold{config["fold"]}'
    early_stopping_callback = EarlyStopping(monitor=monitor, patience=config["patience"], mode=mode,)
    checkpoint_callback = ModelCheckpoint(
        save_top_k=1,
        monitor=monitor,
        mode=mode,
        dirpath=f'./checkpoints/{config["stage"]}/{config["dataset"]}/{config["task"]}',
        filename=checkpoint_filename,
    )

    # logger
    logger = CSVLogger(save_dir="logs", name=f'{config["stage"]}/{config["dataset"]}/{config["task"]}', version=checkpoint_filename, flush_logs_every_n_steps=4)

    # train/val
    pipeline = Pipeline(config)
    trainer = L.Trainer(max_epochs=config["epochs"], logger=logger, callbacks=[early_stopping_callback, checkpoint_callback])
    trainer.fit(pipeline, dm)

    # return best metric score as a plain float (the callback stores a tensor);
    # None is passed through unchanged so Optuna marks the trial as failed.
    best_metric_score = checkpoint_callback.best_model_score
    return None if best_metric_score is None else float(best_metric_score)

# Length-of-stay is scored by an error to minimise; the other tasks by a score to maximise.
if config["task"] == "los":
    direction = "minimize"
else:
    direction = "maximize"

# Exhaustive grid over the two tuned hyperparameters.
search_space = {
    "hidden_dim": [64, 128],
    "batch_size": [32, 64],
}
grid_sampler = optuna.samplers.GridSampler(search_space)
study = optuna.create_study(direction=direction, sampler=grid_sampler)
# n_trials is only an upper bound: per Optuna's GridSampler documentation the
# study stops once every grid combination has been evaluated.
study.optimize(objective, n_trials=100)


In [None]:
# Fold the winning trial's hyperparameters into the experiment config.
trial = study.best_trial
config = {**config, **trial.params}

# Write the merged config next to the checkpoints as {model_name}_best.toml.
best_config_path = f'./checkpoints/{config["stage"]}/{config["dataset"]}/{config["task"]}/{config["model_name"]}_best.toml'
with open(best_config_path, 'w') as toml_file:
    toml.dump(config, toml_file)
