In [1]:
%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings("ignore")
from src.models.baseline import baseline
from src.data.mimic_iii.real_dataset import MIMIC3RealDataset,MIMIC3RealDatasetCollection
import yaml
import pytorch_lightning as pl
from src.models.common import *
from torch.utils.data import DataLoader
import torch
import os
import matplotlib.pyplot as plt
import numpy as np
# Run on the GPU when available; fall back to CPU so checkpoints can still be
# loaded (via map_location=device) on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
# Dataset configuration (outcome / vital / treatment / static feature lists).
# Set the MIMIC3_CONFIG_PATH environment variable to point at your copy instead
# of editing this machine-specific default.
CONFIG_PATH = os.environ.get(
    "MIMIC3_CONFIG_PATH",
    '/home/thomas/mimic/physionet.org/files/mimiciii/CausalTransformer/config/dataset/mimic3_real.yaml',
)
with open(CONFIG_PATH, 'r') as file:
    config = yaml.safe_load(file)["dataset"]
# NOTE(review): batch_size is not referenced by the evaluation cell, which
# builds its test DataLoader with its own batch size.
batch_size=512

## Evaluation: multi-step (per-horizon) RMSE and MAE on the test split, across seeds

In [3]:
def one_shot_pred(model,test_loader):
    """Run ``model`` over every batch of ``test_loader`` and collect its forecasts.

    For each sample, vitals and in-sample outputs at and after the sample's
    ``future_past_split`` time step are zeroed, so the model only sees the
    observed past. Treatments are passed unmasked. The model is called under
    ``torch.no_grad()``.

    Args:
        model: callable as ``model(windows, index)`` that exposes ``device`` and
            ``proj_len`` attributes. ``windows`` is a dict with keys
            ``"insample_y"``, ``"multivariate_exog"`` and ``"stat_exog"``.
        test_loader: iterable of batch dicts with keys ``'vitals'``,
            ``'static_features'``, ``'current_treatments'``, ``'outputs'`` and
            ``'future_past_split'``.

    Returns:
        (outcome, outcome_tot), both CPU tensors in ``test_loader`` order:

        * ``outcome`` holds, for every sample, the ``model.proj_len`` model
          outputs starting at that sample's split index.
        * ``outcome_tot`` is the concatenation of the full model outputs.

        The batch tensors passed in are not modified.
    """
    li_outcome = []
    li_tot = []
    horizon = model.proj_len
    for batch in test_loader:
        splits = batch['future_past_split'].reshape(-1).long()
        n_samples, seq_len = batch['vitals'].shape[:2]
        # (B, T) mask of time steps at/after each sample's split: hidden from the model.
        future_mask = (
            torch.arange(seq_len, device=splits.device)[None, :] >= splits[:, None]
        )
        # clone() so masking never writes into the caller's batch tensors
        # (.float() returns the same tensor when the input is already float32).
        vitals = batch['vitals'].float().clone()
        li_insample_y = batch["outputs"].float().clone()
        vitals[future_mask] = 0
        li_insample_y[future_mask] = 0
        static_features = batch['static_features'].float()
        treatments = batch['current_treatments'].float()
        # Time-index channel of shape (B, T, 1) holding 0..T-1; made float
        # explicitly rather than relying on dtype promotion in torch.concat.
        position = torch.arange(seq_len, device=vitals.device).float()
        position = position[None, :, None].expand(n_samples, seq_len, 1)
        temporal = torch.concat([vitals, position, treatments], dim=-1)
        windows = {
            "insample_y": li_insample_y.to(model.device),
            "multivariate_exog": temporal.to(model.device),
            "stat_exog": static_features.to(model.device),
        }
        with torch.no_grad():
            outputs_scaled = model(windows, batch["future_past_split"].int()).cpu()
        # Size by the model's horizon / output width instead of hard-coded
        # constants so any proj_len works.
        predicted_outputs = torch.zeros((n_samples, horizon, outputs_scaled.shape[-1]))
        for i, split in enumerate(splits.tolist()):
            predicted_outputs[i] = outputs_scaled[i, split:split + horizon]
        li_outcome.append(predicted_outputs)
        li_tot.append(outputs_scaled)
    outcome = torch.concat(li_outcome)
    outcome_tot = torch.concat(li_tot)
    return outcome, outcome_tot
def compute_tau_step_error(model,dataloader):
    """Return aligned multi-step predictions and ground truth for a dataset.

    The ground truth for sample ``i`` is ``outputs[i, tau:tau + model.proj_len]``,
    where ``tau`` is that sample's ``future_past_split``. Predictions come from
    ``one_shot_pred`` over the same loader.

    Predictions and truths are aligned by position only, so ``dataloader`` must
    iterate the dataset in its stored order (e.g. ``shuffle=False``).

    Args:
        model: forecasting model exposing ``proj_len`` (see ``one_shot_pred``).
        dataloader: loader whose ``dataset.data`` dict holds ``"outputs"``
            (indexed as ``[sample, time, feature]``) and ``"future_past_split"``.

    Returns:
        (Y_hat, Y_true): NumPy arrays of shape
        ``(n_samples, model.proj_len, n_outputs)``.

    Raises:
        ValueError: if the loader samples randomly, which would misalign rows.
    """
    if isinstance(getattr(dataloader, "sampler", None), torch.utils.data.RandomSampler):
        raise ValueError(
            "compute_tau_step_error needs an unshuffled dataloader (shuffle=False)"
        )
    data = dataloader.dataset.data
    outputs = data["outputs"]
    horizon = model.proj_len
    # Output width taken from the data instead of a hard-coded 1.
    Y_true = torch.zeros((len(dataloader.dataset), horizon, outputs.shape[-1]))
    for i, tau in enumerate(data["future_past_split"]):
        tau = int(tau)
        Y_true[i] = torch.tensor(outputs[i, tau:tau + horizon])
    Y_hat = one_shot_pred(model, dataloader)[0]
    return Y_hat.numpy(), Y_true.numpy()

In [4]:
# Evaluate one trained checkpoint per seed on that seed's test split and record
# per-horizon-step RMSE / MAE. Row i of losses_* corresponds to seeds[i].
seeds = [10,101,1001,10010,10110]
PROJECTION_HORIZON = 5   # forecast steps; also the column count of losses_*
EVAL_BATCH_SIZE = 1024
losses_rmse = np.zeros((len(seeds), PROJECTION_HORIZON))
losses_mae = np.zeros((len(seeds), PROJECTION_HORIZON))
for i, seed in enumerate(seeds):
    # NOTE(review): assumes checkpoint directory baseline_large_multi_{i} was
    # trained on the split produced by seeds[i]; confirm.
    dataset_collection = MIMIC3RealDatasetCollection(
        "data/processed/all_hourly_data.h5", min_seq_length=30, max_seq_length=60,
        seed=seed, max_number=10000, split={"val": 0.15, "test": 0.15},
        projection_horizon=PROJECTION_HORIZON, autoregressive=True,
        outcome_list=config["outcome_list"],
        vitals=config["vital_list"],
        treatment_list=config["treatment_list"],
        static_list=config["static_list"],
    )
    dataset_collection.process_data_multi()

    dataset = dataset_collection.test_f_multi
    # shuffle=False is required: predictions are aligned with dataset rows by position.
    test_loader = DataLoader(dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False)

    checkpoint_dir = f"TFT/baseline_large_multi_{i}/checkpoints"
    # sorted() makes the choice deterministic; os.listdir order is unspecified.
    checkpoints = sorted(os.listdir(checkpoint_dir))
    if not checkpoints:
        raise FileNotFoundError(f"no checkpoint found in {checkpoint_dir}")
    path = os.path.join(checkpoint_dir, checkpoints[0])
    model = baseline.load_from_checkpoint(path, map_location=device).to(device)
    model.eval()
    pred, truth = compute_tau_step_error(model, test_loader)
    # Errors are averaged over samples (axis 0), giving one value per horizon
    # step. Presumably the outputs were standardised, so multiplying by
    # output_stds maps them back to original units; confirm.
    output_stds = dataset.scaling_params['output_stds']
    losses_rmse[i] = (np.sqrt(np.mean((pred - truth) ** 2, axis=0)) * output_stds).flatten()
    losses_mae[i] = (np.mean(np.abs(pred - truth), axis=0) * output_stds).flatten()

In [None]:
# Mean RMSE per forecast step, averaged over seeds (rows of losses_rmse).
losses_rmse.mean(axis=0)

In [None]:
# Spread (population std, ddof=0) of the per-step RMSE across seeds.
losses_rmse.std(axis=0)

In [None]:
# Mean MAE per forecast step, averaged over seeds (rows of losses_mae).
losses_mae.mean(axis=0)

In [None]:
# Spread (population std, ddof=0) of the per-step MAE across seeds.
losses_mae.std(axis=0)