In [1]:
%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings("ignore")
from src.models.model_m_e_theta import m_e_theta_daily
from src.data.mimic_iii.real_dataset import MIMIC3RealDatasetCollectionCausal,MIMIC3RealDatasetCollection
import yaml
import time
import pytorch_lightning as pl
from src.models.common import *
from torch.utils.data import DataLoader
import torch
import pandas as pd
import os
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
# Device used to load and run the checkpoints. Prefer the GPU, but fall back to
# the CPU so the notebook still runs on a machine without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
# Read the "dataset" section of the MIMIC-III YAML configuration; later cells
# pull the outcome / vital / treatment / static feature lists out of `config`.
# NOTE(review): this is an absolute, machine-specific path — consider making it
# relative to a configurable project root.
with open('/home/thomas/mimic/physionet.org/files/mimiciii/CausalTransformer/config/dataset/mimic3_real.yaml', 'r') as file:
        config = yaml.safe_load(file)["dataset"]
# NOTE(review): `batch_size` is not referenced by the other visible cells (the
# evaluation DataLoader hard-codes its own batch size) — confirm whether it is
# still needed.
batch_size=512

## Eval

In [3]:
def one_shot_pred(model, test_loader, max_seq_length=60):
    """Run ``model`` once per batch and collect its forecasts after each split point.

    For every batch the function:

    1. zeroes the vitals and the observed outputs from ``future_past_split``
       onwards, so the model does not see them;
    2. builds the temporal input by concatenating vitals, a time-step index and
       the treatments along the feature axis;
    3. encodes the last ``model.treatment_max`` channels of that input as one
       integer per time step (``sum_k channel[-1-k] * 2**k``) and then sets those
       channels to ``-1`` from the split point onwards;
    4. calls ``model(windows, index)`` to obtain ``m``, ``e`` and ``theta`` and
       combines them as ``m + (one_hot(treatment) - e) . theta`` (a cumulative
       sum of the one-hot vector is used instead when ``model.is_cdf`` is true);
    5. keeps, for each sample, the ``model.proj_len`` steps that start at its
       split point.

    Args:
        model: callable returning ``(m, e, theta)``; must expose ``treatment_max``,
            ``is_cdf``, ``proj_len`` and ``device``.
        test_loader: iterable of batch dicts with the keys ``vitals``,
            ``static_features``, ``current_treatments``, ``outputs`` and
            ``future_past_split``.
        max_seq_length: a sample is kept only if ``split + model.proj_len`` does
            not exceed this value. The default of 60 reproduces the previously
            hard-coded ``split + proj_len < 61`` test.

    Returns:
        Tuple ``(outcome, m, e, theta)`` of CPU tensors, each stacked over the
        kept samples and restricted to their forecast window.

    Raises:
        RuntimeError: raised by ``torch.stack`` / ``torch.concat`` when no sample
            of the whole loader passes the horizon filter.
    """
    # Explicit reference instead of relying on a star import to provide `F`.
    one_hot = torch.nn.functional.one_hot

    li_outcome, li_m, li_e, li_theta = [], [], [], []
    n_treatment_codes = 2 ** model.treatment_max

    for batch in test_loader:
        index = batch["future_past_split"]
        splits = [int(tau) for tau in index]

        # Clone before masking so the caller's batch tensors are never modified.
        vitals = batch['vitals'].clone().float()
        li_insample_y = batch["outputs"].clone().float()
        static_features = batch['static_features'].float()
        treatments = batch['current_treatments'].float()

        # Time-step index with shape (batch, time, 1).
        n_samples, seq_len = vitals.shape[0], vitals.shape[1]
        position = torch.arange(seq_len).repeat(n_samples, 1, 1)
        position = torch.permute(position, (0, 2, 1)).to(vitals.device)

        # Hide everything observed at or after the split point.
        for i, split in enumerate(splits):
            vitals[i, split:] = 0.0
            li_insample_y[i, split:] = 0

        temporal = torch.concat([vitals, position, treatments], dim=-1)

        # Read each treatment channel *before* masking it: the integer code must
        # cover the full sequence, while the model input is masked with -1.
        treatment = torch.zeros_like(li_insample_y)
        for k in range(model.treatment_max):
            treatment += temporal[:, :, -1 - k].clone().unsqueeze(-1) * (2 ** k)
            for i, split in enumerate(splits):
                temporal[i, split:, -1 - k] = -1

        # Encapsulating inputs
        windows = {
            "insample_y": li_insample_y.to(model.device),
            "multivariate_exog": temporal.to(model.device),
            "stat_exog": static_features.to(model.device),
        }
        with torch.no_grad():
            m, e, theta = model(windows, index.int())
        m, e, theta = m.cpu(), e.cpu(), theta.cpu()

        T_reduced = treatment[:, :, 0].long().cpu()
        T_reduced = one_hot(T_reduced, n_treatment_codes).float()
        if model.is_cdf:
            T_reduced = torch.cumsum(T_reduced, dim=-1)

        # Per time step: dot product of (treatment encoding - e) with theta.
        shift = torch.matmul((T_reduced - e).unsqueeze(-2), theta.unsqueeze(-1)).squeeze(-1)
        outputs_scaled = m + shift

        predicted_outputs = []
        for i, split in enumerate(splits):
            if split + model.proj_len <= max_seq_length:
                horizon = slice(split, split + model.proj_len)
                predicted_outputs.append(outputs_scaled[i, horizon, :])
                li_m.append(m[i, horizon, :])
                li_theta.append(theta[i, horizon, :])
                li_e.append(e[i, horizon, :])
        # torch.stack rejects an empty list, so skip batches where the horizon
        # filter removed every sample.
        if predicted_outputs:
            li_outcome.append(torch.stack(predicted_outputs))

    m = torch.stack(li_m)
    e = torch.stack(li_e)
    theta = torch.stack(li_theta)
    outcome = torch.concat(li_outcome)
    return outcome, m, e, theta
def compute_tau_step_error(model,dataloader):
    """Return ``(predictions, targets)`` as NumPy arrays for the forecast window.

    The targets are read from ``dataloader.dataset.data["outputs"]``: for each
    sample, the ``model.proj_len`` values starting at its ``future_past_split``.
    The predictions are the first element returned by ``one_shot_pred``.

    NOTE(review): the targets hold one row per dataset sample, whereas
    ``one_shot_pred`` may drop samples whose window is too long — this assumes
    no sample is dropped, so that both arrays line up; confirm for the dataset
    in use.
    """
    records = dataloader.dataset.data
    horizon = model.proj_len

    # Pre-allocated float32 buffer: (n_samples, horizon, 1).
    targets = torch.zeros((len(dataloader.dataset), horizon, 1))
    for row, split in enumerate(records["future_past_split"]):
        start = int(split)
        window = records["outputs"][row, start:start + horizon]
        targets[row] = torch.tensor(window)

    predictions = one_shot_pred(model, dataloader)[0]
    return predictions.numpy(), targets.numpy()

In [8]:
# Evaluate one checkpoint per seed index and store the per-step errors.
# Each row of `losses_rmse` / `losses_mae` corresponds to one entry of `seeds`,
# each column to one step of the projection horizon.
seeds = [10,101,1001,10010,10110]
projection_horizon = 5
losses_rmse = np.zeros((len(seeds), projection_horizon))
losses_mae = np.zeros((len(seeds), projection_horizon))

# Checkpoint directory prefix; run `i` is stored under f"{model_path}_{i}/checkpoints".
# NOTE(review): absolute, machine-specific path — consider a configurable root.
model_path = "/home/thomas/fork_causal_transformer/Causal-forecasting-training-and-evaluation/TFT_deterministic/m_e_density_repro_pipeline"

for i in tqdm(range(len(seeds))):
    # NOTE(review): the dataset is always built with seeds[0], so every
    # checkpoint is scored on the same split. Confirm this is intended; if each
    # run was trained on its own split, this should be seeds[i].
    dataset_collection = MIMIC3RealDatasetCollection("data/processed/all_hourly_data.h5",min_seq_length=30,max_seq_length=60,
                                                     seed=seeds[0],max_number=10000,split = {"val":0.15,"test":0.15}, projection_horizon=projection_horizon,autoregressive=True,
                                                     outcome_list = config["outcome_list"],
                                                     vitals = config["vital_list"],
                                                     treatment_list = config["treatment_list"],
                                                     static_list = config["static_list"]
                                                     )
    dataset_collection.process_data_multi()

    dataset = dataset_collection.test_f_multi
    test_loader = DataLoader(dataset, batch_size=1024,shuffle=False)

    # os.listdir returns entries in arbitrary order; sort so that the selected
    # checkpoint is deterministic when the directory holds more than one file.
    checkpoint_dir = f"{model_path}_{i}/checkpoints"
    file = sorted(os.listdir(checkpoint_dir))[0]
    path = os.path.join(checkpoint_dir,file)

    model = m_e_theta_daily.load_from_checkpoint(path,map_location=device).to(device)
    model.training_m_e = False
    model.eval()

    pred,truth = compute_tau_step_error(model,test_loader)
    # Errors are computed on scaled outputs; multiply by the output standard
    # deviation to express them in the original units.
    output_stds = dataset.scaling_params['output_stds']
    loss = np.sqrt(np.mean((pred-truth)**2,axis=0))*output_stds
    losses_rmse[i] = loss.flatten()
    loss = np.mean(np.abs(pred-truth),axis=0)*output_stds
    losses_mae[i] = loss.flatten()
    # No early `break`: every seed must be evaluated, otherwise the remaining
    # rows stay at zero and the mean/std summaries are meaningless.

    

  0%|          | 0/5 [00:00<?, ?it/s]

In [None]:
# Raw RMSE table: one row per entry of `seeds`, one column per forecast step.
losses_rmse

In [6]:
# RMSE averaged over axis 0 (the rows, one per entry of `seeds`): one value per forecast step.
np.mean(losses_rmse,axis=0)

array([ 8.90449958,  9.56974695,  9.87960325, 10.13996955, 10.36254403])

In [7]:
# Standard deviation of the RMSE over axis 0 (the rows, one per entry of `seeds`), per forecast step.
np.std(losses_rmse,axis=0)

array([0.01070116, 0.00456903, 0.00316494, 0.00854463, 0.0132813 ])

In [8]:
# MAE averaged over axis 0 (the rows, one per entry of `seeds`): one value per forecast step.
np.mean(losses_mae,axis=0)

array([6.14121738, 6.7888407 , 7.08935911, 7.33855885, 7.53462494])

In [9]:
# Standard deviation of the MAE over axis 0 (the rows, one per entry of `seeds`), per forecast step.
np.std(losses_mae,axis=0)

array([0.01371462, 0.0124739 , 0.00643761, 0.02186196, 0.02822455])