In [1]:
%load_ext autoreload
%autoreload 2
from src.data.mimic_iii.real_dataset import MIMIC3RealDatasetCollection
from src.models.ct import CT
import torch
from pytorch_lightning import Trainer
from omegaconf import DictConfig, OmegaConf
import yaml
import os
from torch.utils.data import DataLoader
from hydra.utils import instantiate
import warnings
import matplotlib.pyplot as plt
import numpy as np
warnings.filterwarnings("ignore")

In [2]:
# Load the dataset config (its "dataset" section) and the CT hyper-parameter
# config for the diastolic-blood-pressure outcome; paths are relative to the
# current working directory.
with open('config/dataset/mimic3_real.yaml', 'r') as dataset_yaml:
    config = yaml.safe_load(dataset_yaml)["dataset"]

# NOTE(review): config_MODEL is not used in the cells shown here — confirm it is still needed.
with open('config/backbone/ct_hparams/mimic3_real/diastolic_blood_pressure.yaml', 'r') as model_yaml:
    config_MODEL = yaml.safe_load(model_yaml)

In [3]:
# Hydra-style run configuration, hard-coded instead of composed by Hydra.
# '???' marks OmegaConf mandatory values; the evaluation cell fills the
# model.dim_* entries from the loaded data before instantiating the model.
# dataset.seed is an interpolation resolved against exp.seed.
args = DictConfig({
    'model': {
        'dim_treatments': '???',
        'dim_vitals': '???',
        'dim_static_features': '???',
        'dim_outcomes': '???',
        'name': 'CT',
        'multi': {
            '_target_': 'src.models.ct.CT',
            # NOTE(review): a string, not an int — presumably dataset.max_seq_length
            # (60) + projection_horizon (5); confirm CT does not need an int here.
            'max_seq_length': '65',
            'seq_hidden_units': 24,
            'br_size': 22,
            'fc_hidden_units': 22,
            'dropout_rate': 0.2,
            'num_layer': 2,
            'num_heads': 3,
            'max_grad_norm': None,
            'batch_size': 64,
            'attn_dropout': True,
            'disable_cross_attention': False,
            'isolate_subnetwork': '_',
            'self_positional_encoding': {
                'absolute': False,
                'trainable': True,
                'max_relative_position': 30,
            },
            'optimizer': {
                'optimizer_cls': 'adam',
                'learning_rate': 0.0001,
                'weight_decay': 0.0,
                'lr_scheduler': False,
            },
            'augment_with_masked_vitals': True,
            'tune_hparams': False,
            'tune_range': 50,
            'hparams_grid': None,
            'resources_per_trial': None,
        },
    },
    'dataset': {
        'val_batch_size': 512,
        'treatment_mode': 'multilabel',
        '_target_': 'src.data.MIMIC3RealDatasetCollection',
        'seed': '${exp.seed}',
        'name': 'mimic3_real',
        'path': 'data/processed/all_hourly_data.h5',
        'min_seq_length': 30,
        'max_seq_length': 60,
        # NOTE(review): the evaluation cell builds its dataset collection with
        # max_number=10000, not this value — confirm which one is intended.
        'max_number': 5000,
        'projection_horizon': 5,
        'split': {'val': 0.15, 'test': 0.15},
        'autoregressive': True,
        'treatment_list': ['vaso', 'vent'],
        'outcome_list': ['diastolic blood pressure'],
        'vital_list': [
            'heart rate', 'red blood cell count', 'sodium', 'mean blood pressure',
            'systemic vascular resistance', 'glucose', 'chloride urine',
            'glascow coma scale total', 'hematocrit',
            'positive end-expiratory pressure set', 'respiratory rate',
            'prothrombin time pt', 'cholesterol', 'hemoglobin', 'creatinine',
            'blood urea nitrogen', 'bicarbonate', 'calcium ionized',
            'partial pressure of carbon dioxide', 'magnesium', 'anion gap',
            'phosphorous', 'venous pvo2', 'platelets', 'calcium urine',
        ],
        'static_list': ['gender', 'ethnicity', 'age'],
        'drop_first': False,
    },
    'exp': {
        'seed': 10,
        'gpus': [0],
        'max_epochs': 1,
        'logging': False,
        'mlflow_uri': 'http://127.0.0.1:5000',
        'unscale_rmse': True,
        'percentage_rmse': False,
        'alpha': 0.01,
        'update_alpha': True,
        'alpha_rate': 'exp',
        'balancing': 'domain_confusion',
        'bce_weight': False,
        'weights_ema': True,
        'beta': 0.99,
    },
})

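## Evaluation: τ-step-ahead RMSE and MAE across seeds

For each seed, load the trained CT checkpoint and predict the multi-step test split autoregressively. Per-horizon RMSE and MAE are then computed and rescaled by the output standard deviation.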

In [4]:
def compute_tau_step_error(model, dataset):
    """Return autoregressive predictions and the matching ground-truth windows.

    Despite its name, this function does not compute an error; it returns the
    two arrays the caller compares (e.g. for RMSE / MAE).

    For sequence ``i`` with split point ``tau = dataset.data["future_past_split"][i]``
    the ground truth is ``dataset.data["outputs"][i, tau-1 : tau-1+H]``, where
    ``H = model.hparams.dataset.projection_horizon``.

    Args:
        model: trained model exposing ``hparams.dataset.projection_horizon`` and
            ``get_autoregressive_predictions(dataset)``.
        dataset: object supporting ``len()`` with a ``data`` dict holding
            ``"future_past_split"`` and ``"outputs"``. ``outputs`` is indexed as
            (sequence, time, output) — assumed 3-D.

    Returns:
        ``(pred_auto_reg, Y_true)``: the model's predictions exactly as returned by
        ``get_autoregressive_predictions`` and a float array of shape
        ``(len(dataset), H, n_outputs)``.

    Raises:
        ValueError: if a split point leaves fewer than ``H`` time steps for the
            ground-truth window (including ``tau < 1``).
    """
    horizon = model.hparams.dataset.projection_horizon
    outputs = dataset.data["outputs"]
    n_steps = outputs.shape[1]
    # Use the real number of outputs instead of assuming a single outcome.
    Y_true = np.zeros((len(dataset), horizon, outputs.shape[-1]))
    for i, tau in enumerate(dataset.data["future_past_split"]):
        # NOTE(review): the window starts one step before the split point;
        # confirm this matches the alignment of get_autoregressive_predictions.
        start = int(tau) - 1
        stop = start + horizon
        if start < 0 or stop > n_steps:
            # Without this check numpy raises an opaque broadcast error (or a
            # negative start silently wraps to the end of the sequence).
            raise ValueError(
                f"Sequence {i}: split point tau={int(tau)} needs outputs[{start}:{stop}] "
                f"but the sequence has {n_steps} time steps"
            )
        Y_true[i] = outputs[i, start:stop]
    pred_auto_reg = model.get_autoregressive_predictions(dataset)
    return pred_auto_reg, Y_true

In [None]:
# Evaluate one trained CT checkpoint per seed on the multi-step test split and
# record per-horizon RMSE / MAE rescaled by the output standard deviation.
seeds = [10, 101, 1010, 10101, 101010]

# Hydra multirun output directory; sub-directory `{i}/checkpoints` holds job i's checkpoint.
# NOTE(review): assumes multirun job i was trained with seeds[i] — TODO confirm.
# Set CT_MULTIRUN_DIR to point elsewhere instead of editing this path.
MULTIRUN_DIR = os.environ.get(
    "CT_MULTIRUN_DIR",
    "/home/thomas/mimic/physionet.org/files/mimiciii/CausalTransformer/multirun/2024-09-10/02-17-13",
)
# Single source of truth: compute_tau_step_error reads the same value via model.hparams.
projection_horizon = args.dataset.projection_horizon
# Convert a ListConfig such as [0] to a plain list without eval().
gpus = list(args.exp.gpus) if OmegaConf.is_list(args.exp.gpus) else args.exp.gpus

rmse_per_seed = []
mae_per_seed = []
for i, seed in enumerate(seeds):
    # NOTE(review): max_number=10000 differs from args.dataset.max_number (5000) — confirm.
    dataset_collection = MIMIC3RealDatasetCollection(
        "data/processed/all_hourly_data.h5",
        min_seq_length=30,
        max_seq_length=60,
        seed=seed,
        max_number=10000,
        split={"val": 0.15, "test": 0.15},
        projection_horizon=projection_horizon,
        autoregressive=True,
        outcome_list=config["outcome_list"],
        vitals=config["vital_list"],
        treatment_list=config["treatment_list"],
        static_list=config["static_list"],
    )
    dataset_collection.process_data_multi()

    # Fill the '???' placeholders in args with the dimensions of this data split.
    train_data = dataset_collection.train_f.data
    args.model.dim_outcomes = train_data['outputs'].shape[-1]
    args.model.dim_treatments = train_data['current_treatments'].shape[-1]
    args.model.dim_vitals = train_data['vitals'].shape[-1] if dataset_collection.has_vitals else 0
    args.model.dim_static_features = train_data['static_features'].shape[-1]

    multimodel = instantiate(args.model.multi, args, dataset_collection, _recursive_=False)
    multimodel.hparams.exp.weights_ema = False

    # os.listdir order is arbitrary: sort so the chosen checkpoint is deterministic.
    ckpt_dir = os.path.join(MULTIRUN_DIR, str(i), "checkpoints")
    ckpt_files = sorted(name for name in os.listdir(ckpt_dir) if name.endswith(".ckpt"))
    if not ckpt_files:
        raise FileNotFoundError(f"No .ckpt file found in {ckpt_dir}")
    # map_location="cpu": load regardless of the device the checkpoint was saved from;
    # load_state_dict copies the tensors into the freshly built model anyway.
    state_dict = torch.load(os.path.join(ckpt_dir, ckpt_files[0]), map_location="cpu")["state_dict"]
    multimodel.load_state_dict(state_dict)

    multimodel_trainer = Trainer(
        gpus=gpus,
        max_epochs=args.exp.max_epochs,
        gradient_clip_val=args.model.multi.max_grad_norm,
    )
    multimodel.trainer = multimodel_trainer
    multimodel = multimodel.double().eval()

    test_data = dataset_collection.test_f_multi
    pred, truth = compute_tau_step_error(multimodel, test_data)
    # Errors are on scaled outputs; multiplying by output_stds rescales them.
    # Assumes outputs were standardized as (y - mean) / std, so the mean cancels in pred - truth.
    output_stds = test_data.scaling_params['output_stds']
    rmse_per_seed.append((np.sqrt(np.mean((pred - truth) ** 2, axis=0)) * output_stds).flatten())
    mae_per_seed.append((np.mean(np.abs(pred - truth), axis=0) * output_stds).flatten())

# Rows: seeds; columns: projection_horizon * n_outputs (horizon-major after flatten).
losses_rmse = np.stack(rmse_per_seed)
losses_mae = np.stack(mae_per_seed)

In [None]:
# Mean RMSE per forecast step, averaged over seeds.
losses_rmse.mean(axis=0)

In [None]:
# Spread of RMSE across seeds per forecast step (population std, ddof=0).
losses_rmse.std(axis=0)

In [None]:
# Mean MAE per forecast step, averaged over seeds.
losses_mae.mean(axis=0)

In [None]:
# Spread of MAE across seeds per forecast step (population std, ddof=0).
losses_mae.std(axis=0)