In [1]:
import random
import numpy as np
import torch
from ignite.engine import Events
from ignite.metrics import Loss
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.handlers.stores import EpochOutputStore
from typing import List
import torch
from ignite.metrics import Loss
from ignite.engine import create_supervised_trainer

from data.tabular_enn_fusion import MIMIC_Structured_Notes
from utils.loss import MultitaskLoss4OneStructured, MultitaskLoss4OneStructured_detached
from utils.logger import create_logger
from utils.metrics import AUPRC, NLL, AUROC, NPV, BalancedAccuracy, BrierScore, F1, Precision, Recall, Specificity
from utils.torch_utils import count_parameters
from utils.options import add_hparams2parser_jn
import torch.nn as nn
from rtdl_revisiting_models import MLP
from models.dst_pytorch import Dempster_Shafer_Module, Dempster_layer, DempsterNormalize_layer

  from torch.distributed.optim import ZeroRedundancyOptimizer


In [2]:
# You can configure the hyperparameters over here
hparams_dict = {
# Common arguments
"seed": 10,
"logger": False,
"devices": "7," ,
"batch_size": 32,
"lr": 1.0e-4,
"max_epochs": 100,
"comments": "",


# Dataset arguments
"outcome":  "icu_death",
"n_class": 2,
"data_path": "./data/datasets/mimic/processed",
"structured_d_in_ls": None, # This will be set in the dataset class
"class_weight": None,   # This will be set in the dataset class
"cv_split": 0,


# Model arguments
"model": "mlp_enn_one_structured",
"pretrained_model": "emilyalsentzer/Bio_ClinicalBERT",
"prototype_dim": 20,
"n_blocks": 3,
"structured_d_hidden": 32,
"notes_d_hidden": 128,
"alpha1": 2,
"alpha2": 1,
"dropout":0.1
}
    


In [3]:
# Initilize the hparams objects by taking arguments from command line
hparams = add_hparams2parser_jn(hparams_dict)

# Initilize logger for saving the metric data
logger = create_logger(hparams)




In [4]:
# ********************** prepare for training
# set gpu and random seed
print("="*15, " Preparing for training ", "="*15)
device = torch.device(int(hparams.devices.split(",")[0]))
random.seed(hparams.seed)
np.random.seed(hparams.seed)
torch.manual_seed(hparams.seed)




<torch._C.Generator at 0x7f7bf7bfb830>

In [5]:
# ********************** dataset
print("="*15, " Building dataset module ", "="*15 )
print(f"* Data path: {hparams.data_path}")
print(f"* Outcome: {hparams.outcome}")
print(f"* CV split: {hparams.cv_split}")

# initilize dataset (read and preprocess the data)
train_dataset = MIMIC_Structured_Notes(hparams.data_path, "train", hparams)
val_dataset = MIMIC_Structured_Notes(hparams.data_path, "val", hparams)
test_dataset = MIMIC_Structured_Notes(hparams.data_path, "test", hparams)

# initilize the dataloader to wrap up the dataset
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=hparams.batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=hparams.batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=hparams.batch_size, shuffle=False)


* Data path: ./data/datasets/mimic/processed
* Outcome: icu_death
* CV split: 0
* Structured input dimensions: [22, 19]


In [6]:
def pignistic(mass, n_class):

    probs = mass[:, :n_class] + (1 / n_class) * mass[:, n_class].unsqueeze(1)
    uncertainty = mass[:, n_class]

    return probs, uncertainty

class MLPENNOneStructured(nn.Module):
    # combine twice
    def __init__(self, hparams) -> None:
        super().__init__()

        self.hparams = hparams

        # structured
        structured_d_in_ls = hparams.structured_d_in_ls
        self.backbone = MLP(
                        d_in=sum(structured_d_in_ls),
                        n_blocks=hparams.n_blocks,
                        d_block=hparams.structured_d_hidden,
                        d_out=hparams.structured_d_hidden,
                        dropout=hparams.dropout,
                        )
        self.structured_cls = nn.Sequential(nn.Dropout(hparams.dropout),
                                            nn.Linear(hparams.structured_d_hidden, hparams.n_class),
                                            )
        self.structured_dsm = Dempster_Shafer_Module(hparams.structured_d_hidden, hparams.n_class, hparams.prototype_dim)

        # notes
        
        self.notes_reducer = nn.Sequential(nn.Dropout(hparams.dropout),
                                            nn.Linear(768, hparams.notes_d_hidden),
                                            )
        self.notes_fcs4logits = nn.Sequential(nn.Dropout(hparams.dropout),
                                            nn.Linear(hparams.notes_d_hidden, hparams.n_class),
                                            )
        self.notes_dsm = Dempster_Shafer_Module(hparams.notes_d_hidden,
                                                                hparams.n_class, hparams.prototype_dim)
        
        # fusion
        self.ds_dempster = Dempster_layer(2, hparams.n_class)
        self.ds_normalize = DempsterNormalize_layer()

    def forward(self, inputs):
        """_summary_

        Args:
            inputs (_type_): a list

        Returns:
            _type_: _description_
        """
        cont_data, cat_data, notes_data = inputs

        structured_feats = self.backbone(torch.cat([cont_data, cat_data], dim=1))
        structured_logits = self.structured_cls(structured_feats)
        structured_mass = self.structured_dsm(structured_feats)

        note_reduced = self.notes_reducer(notes_data)
        notes_logits = self.notes_fcs4logits(note_reduced)
        notes_mass = self.notes_dsm(note_reduced)

        # combine all the mass functions
        mass_ls = [structured_mass, notes_mass]

        # mass_stack: [batch_size, 2, n_class+1]
        mass_stack = torch.stack(mass_ls, dim=1)
        mass_Dempster = self.ds_dempster(mass_stack)
        mass_Dempster_normalize = self.ds_normalize(mass_Dempster)

        probs, uncertainty = pignistic(mass_Dempster_normalize, self.hparams.n_class)

        return probs, structured_logits, notes_logits, uncertainty
    

In [7]:
# ********************** model
print("="*15, " Building model ", "="*15)
print(f"* Model: {hparams.model}")
model = MLPENNOneStructured(hparams)

model.to(device)
print("* Trainable model parameters:")
count_parameters(model)

* Model: mlp_enn_one_structured
* Trainable model parameters:
+----------------------------------------+------------+
|                Modules                 | Parameters |
+----------------------------------------+------------+
|    backbone.blocks.0.linear.weight     |    1312    |
|     backbone.blocks.0.linear.bias      |     32     |
|    backbone.blocks.1.linear.weight     |    1024    |
|     backbone.blocks.1.linear.bias      |     32     |
|    backbone.blocks.2.linear.weight     |    1024    |
|     backbone.blocks.2.linear.bias      |     32     |
|         backbone.output.weight         |    1024    |
|          backbone.output.bias          |     32     |
|        structured_cls.1.weight         |     64     |
|         structured_cls.1.bias          |     2      |
|          structured_dsm.ds1.w          |    640     |
| structured_dsm.ds1_activate.eta.weight |     20     |
| structured_dsm.ds1_activate.xi.weight  |     20     |
|        structured_dsm.ds2.beta         |

106628

In [8]:
# ********************** loss and optimizer
# initialize customized loss function and optimizer
print("="*15, " Building loss, optimizer, and metrics ", "="*15)
criterion = MultitaskLoss4OneStructured(hparams, weight=torch.tensor(hparams.class_weight).to(device))
optimizer = torch.optim.Adam(model.parameters(), lr=hparams.lr)
# scheduler = MultiStepLR(optimizer, milestones=[100], gamma=0.2, verbose=True)



In [9]:
def transform_probs2labels(output: List[torch.Tensor]):
    return torch.argmax(output[0][0], dim=1), output[1]

def transform_probs2probs(output: List[torch.Tensor]):
    return output[0][0][:,1], output[1]

# metrics should be a dictionary
metrics = {
        "precision": Precision(output_transform=transform_probs2labels),
        "recall": Recall(output_transform=transform_probs2labels),
        "specificity": Specificity(output_transform=transform_probs2labels),
        "npv": NPV(output_transform=transform_probs2labels),

        "bacc": BalancedAccuracy(output_transform=transform_probs2labels), 
        "f1": F1(output_transform=transform_probs2labels),
        "aucroc": AUROC(output_transform=transform_probs2probs),
        "auprc": AUPRC(output_transform=transform_probs2probs),

        "brier": BrierScore(output_transform=transform_probs2probs),
        "nll": NLL(output_transform=lambda x: (x[0][0], x[1]))
        }

# add loss into metrics
criterion_detached = MultitaskLoss4OneStructured_detached(hparams, weight=torch.tensor(hparams.class_weight).to(device))
metrics["loss"] = Loss(criterion_detached, output_transform=lambda x: (x[0][:-1], x[1]))

In [10]:
# ********************** trainer and evaluator
print("="*15, " Building trainer and evaluator ", "="*15)
trainer = create_supervised_trainer(model, 
                                    optimizer, 
                                    criterion, 
                                    device=device, 
                                    output_transform=lambda x, y, y_pred, loss: criterion_detached(y_pred, y), 
                                    model_transform=lambda x: x[:-1]
                                    )


evaluator = create_supervised_evaluator(model, 
                                        metrics=metrics, 
                                        device=device)
eos = EpochOutputStore()
eos.attach(evaluator, 'output')

best_metric = -1
save_path = "./best_model.pth"
# ********************** add event handlers
# Log training loss in each iteration
def log_training_results(engine):
    batch_idx = engine.state.iteration % engine.state.epoch_length if engine.state.iteration  % engine.state.epoch_length != 0 else engine.state.epoch_length
    print(f"Epoch: {engine.state.epoch} | Batch: {batch_idx}/{engine.state.epoch_length} - Train loss: {engine.state.output:.4f}")
    logger.log_metrics({"train_loss": engine.state.output})

# Log validation and test results in each epoch
def log_validation_results(engine):
    print("Validating...")
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print("val completed")

    # save the best model with best metric
    if metrics["auprc"] > best_metric:
        torch.save(model.state_dict(), save_path)
        print("Best model saved!")

    logger.log_metrics(metrics, prefix="val")


def log_test_results(engine):
    print("Testing...")
    evaluator.run(test_loader)
    metrics = evaluator.state.metrics
    print("test completed")
    logger.log_metrics(metrics, prefix="test")

    # scheduler.step()


trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: logger.set_epoch(engine.state.epoch))
trainer.add_event_handler(Events.ITERATION_STARTED, lambda engine: logger.new_step())
trainer.add_event_handler(Events.ITERATION_COMPLETED, log_training_results)

trainer.add_event_handler(Events.EPOCH_COMPLETED, log_validation_results)
trainer.add_event_handler(Events.EPOCH_COMPLETED, log_test_results)
# trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda engine: scheduler.step())





<ignite.engine.events.RemovableEventHandle at 0x7f7a810242b0>

In [11]:
# ********************** training start
print("="*15, " Start training ", "="*15)
trainer.run(train_loader, max_epochs=hparams.max_epochs)

Epoch: 1 | Batch: 1/722 - Train loss: 6.0477
Epoch: 1 | Batch: 2/722 - Train loss: 4.3211
Epoch: 1 | Batch: 3/722 - Train loss: 3.5381
Epoch: 1 | Batch: 4/722 - Train loss: 2.9692
Epoch: 1 | Batch: 5/722 - Train loss: 3.1851
Epoch: 1 | Batch: 6/722 - Train loss: 2.9523
Epoch: 1 | Batch: 7/722 - Train loss: 2.9848
Epoch: 1 | Batch: 8/722 - Train loss: 2.7853
Epoch: 1 | Batch: 9/722 - Train loss: 2.8166
Epoch: 1 | Batch: 10/722 - Train loss: 2.7618
Epoch: 1 | Batch: 11/722 - Train loss: 2.8602
Epoch: 1 | Batch: 12/722 - Train loss: 2.7928
Epoch: 1 | Batch: 13/722 - Train loss: 2.6888
Epoch: 1 | Batch: 14/722 - Train loss: 2.6940
Epoch: 1 | Batch: 15/722 - Train loss: 2.7057
Epoch: 1 | Batch: 16/722 - Train loss: 2.8146
Epoch: 1 | Batch: 17/722 - Train loss: 2.6352
Epoch: 1 | Batch: 18/722 - Train loss: 2.6686
Epoch: 1 | Batch: 19/722 - Train loss: 2.7496
Epoch: 1 | Batch: 20/722 - Train loss: 3.0270
Epoch: 1 | Batch: 21/722 - Train loss: 2.6332
Epoch: 1 | Batch: 22/722 - Train loss: 3.12

State:
	iteration: 72200
	epoch: 100
	epoch_length: 722
	max_epochs: 100
	output: <class 'torch.Tensor'>
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>

In [12]:
# load the best the model
model = MLPENNOneStructured(hparams)
model.load_state_dict(torch.load(save_path))
model.to(device)

  model.load_state_dict(torch.load(save_path))


MLPENNOneStructured(
  (backbone): MLP(
    (blocks): ModuleList(
      (0): Sequential(
        (linear): Linear(in_features=41, out_features=32, bias=True)
        (activation): ReLU()
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (1-2): 2 x Sequential(
        (linear): Linear(in_features=32, out_features=32, bias=True)
        (activation): ReLU()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (output): Linear(in_features=32, out_features=32, bias=True)
  )
  (structured_cls): Sequential(
    (0): Dropout(p=0.1, inplace=False)
    (1): Linear(in_features=32, out_features=2, bias=True)
  )
  (structured_dsm): Dempster_Shafer_Module(
    (ds1): Distance_layer()
    (ds1_activate): DistanceActivation_layer(
      (eta): Linear(in_features=20, out_features=1, bias=False)
      (xi): Linear(in_features=20, out_features=1, bias=False)
    )
    (ds2): Belief_layer()
    (ds2_omega): Omega_layer()
    (ds3_dempster): Dempster_layer()
    (ds3_normalize

In [13]:
# configure the evaluator for testing seperately
evaluator = create_supervised_evaluator(model, 
                                        metrics=metrics, 
                                        device=device)
print("Testing with the best model")
evaluator.run(test_loader)
res = evaluator.state.metrics

Testing with the best model


In [14]:
# print out the res
for k, v in res.items():
    print(k, ": ", v, sep="")

precision: 0.33953997809419495
recall: 0.6828193832599119
specificity: 0.8222811671087533
npv: 0.950920245398773
bacc: 0.7525502751843326
f1: 0.4535479151426481
aucroc: 0.8521551237258362
auprc: 0.48234734347281333
brier: 0.12682986326526213
nll: 0.38304057717323303
loss: 2.136039192471406
