In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
import math
from torch.utils.data import Dataset, DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid
import warnings
warnings.filterwarnings("ignore")

from typing import List, Dict, Tuple, Optional
import copy
import math
from omegaconf import OmegaConf,DictConfig

from functools import partial
from tqdm import tqdm
import torch.nn.init as init

from typing import Dict, List, Union, Callable, Optional
from IPython.display import display
from matplotlib import rcParams
from matplotlib.ticker import FuncFormatter

import itertools
from scipy import stats

In [None]:
# Run on the first CUDA GPU when one is visible to torch, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

## Data preparation

In [None]:
# Window lengths in time steps: the first `past_months` steps are model input and
# the following `future_months` steps are the forecast horizon.
# NOTE(review): the names say "months" but the plots below label steps as '15min' - confirm the unit.
past_months = 75
future_months = 25

# The trailing separator is deliberate: other cells build paths as `data_dir + filename`,
# which silently produced "example_filesstatic.csv" when the separator was missing.
data_dir = os.path.join("example_files", "")

# Every table is keyed by this encounter identifier; sorting on it keeps rows aligned across files.
ID_COL = "PatientEncounterCSNID"

# load static variables: one numeric column (BMI), the rest are cast to integer codes
static = (
    pd.read_csv(os.path.join(data_dir, "static.csv"), index_col=0)
    .fillna(0)
    .sort_values(by=ID_COL)
    .set_index(ID_COL)
)
static_numeric = static["PatientBMI"].values.reshape(static.shape[0], 1)
static_categoric = static.drop(["PatientBMI"], axis=1).astype(int).values

# load time series numeric conditionals, also known into the future
Age = (
    pd.read_csv(os.path.join(data_dir, "age.csv"))
    .sort_values(by=ID_COL)
    .set_index(ID_COL)
)

# load time series categorical conditionals, not known into the future
glasgow = (
    pd.read_csv(os.path.join(data_dir, "glasgow.csv"))
    .sort_values(by=ID_COL)
    .set_index(ID_COL)
)
# Cap scores at 15, then shift by 3 so the codes start at 0 (max 12); this avoids
# out-of-bound embedding indices in the model. Missing values stay NaN here.
glasgow = glasgow.clip(upper=15) - 3

In [None]:
# Load the features to predict; in this example each has 100 time points (75 past + 25 future).
def _load_vital(filename):
    """Read one vital-sign CSV from `data_dir`, sorted and indexed by encounter ID.

    `os.path.join` is used so the path is valid whether or not `data_dir`
    ends with a separator. Missing values are left as NaN so that the
    observed/imputed masks can be derived from them later.
    """
    frame = pd.read_csv(os.path.join(data_dir, filename))
    return frame.sort_values(by="PatientEncounterCSNID").set_index("PatientEncounterCSNID")


mean_BP = _load_vital("meanBP.csv")
pulse = _load_vital("pulse.csv")
SpO2 = _load_vital("SpO2.csv")
Resp = _load_vital("Resp.csv")
Temp = _load_vital("Temp.csv")

print(static_numeric.shape, static_categoric.shape, glasgow.shape, Age.shape)
print(mean_BP.shape, pulse.shape, SpO2.shape, Resp.shape, Temp.shape)

In [None]:
# Load labs, which are conditional numeric time series. They are imputed outright
# (forward then backward fill) because they are inputs only and are never predicted.
labs_count = 3

# os.listdir returns names in arbitrary order; sort so that each lab always lands
# in the same channel from run to run.
lab_files = sorted(name for name in os.listdir(data_dir) if name.startswith("Lab"))
if len(lab_files) > labs_count:
    raise ValueError(
        f"found {len(lab_files)} Lab* files in {data_dir!r} but labs_count is {labs_count}"
    )

labs = np.zeros((static_numeric.shape[0], past_months, labs_count))
for channel, filename in enumerate(lab_files):
    lab = (
        pd.read_csv(os.path.join(data_dir, filename))
        .sort_values(by='PatientEncounterCSNID')
        # Move the ID out of the data columns BEFORE filling: a row-wise ffill over the
        # raw frame would copy the encounter ID into leading missing lab values.
        # NOTE(review): assumes the ID is the only non-time column in the Lab files - confirm.
        .set_index('PatientEncounterCSNID')
        .ffill(axis=1)
        .bfill(axis=1)
    )
    labs[:, :, channel] = lab.values
print(labs.shape)

In [None]:
# Get masks for real recorded values vs. imputed values.
def _observed_mask_and_filled(frame):
    """Return `(mask, filled)` arrays for one vital-sign frame.

    mask:   1 where the value was actually recorded, 0 where it was missing.
    filled: the same values with gaps forward-filled, then backward-filled along
            time so that leading gaps are imputed as well.
    """
    mask = frame.notnull().astype('int').values
    filled = frame.ffill(axis=1).bfill(axis=1).values
    return mask, filled


age = Age.values

# NOTE(review): mean_BP alone skips its first two columns; presumably meanBP.csv carries
# two extra leading columns - confirm against the file, since the other vitals use every column.
# It is now backward-filled like the other vitals; forward-fill alone left leading NaNs,
# which would propagate into the model inputs and the loss.
mean_BP_mask, mean_BP_filled = _observed_mask_and_filled(mean_BP.iloc[:, 2:])
pulse_mask, pulse_filled = _observed_mask_and_filled(pulse)
SpO2_mask, SpO2_filled = _observed_mask_and_filled(SpO2)
Resp_mask, Resp_filled = _observed_mask_and_filled(Resp)
Temp_mask, Temp_filled = _observed_mask_and_filled(Temp)

# Glasgow score is an input only, so no mask is kept for it.
glasgow_filled = glasgow.ffill(axis=1).bfill(axis=1)
glasgow_filled = glasgow_filled.fillna(12).values ##fillna with 15-3 because of cardinality

print(mean_BP_mask.shape, pulse_mask.shape, SpO2_mask.shape, Resp_mask.shape, Temp_mask.shape)

# Remaining NaN counts per array after imputation.
print(np.count_nonzero(np.isnan(mean_BP_filled)), np.count_nonzero(np.isnan(pulse_filled)), np.count_nonzero(np.isnan(SpO2_filled)),
      np.count_nonzero(np.isnan(Resp_filled)), np.count_nonzero(np.isnan(Temp_filled)), np.count_nonzero(np.isnan(glasgow_filled)))

In [None]:
# Channel order is fixed: mean BP, pulse, SpO2, respiration, temperature.
n_patients = static_numeric.shape[0]
filled_channels = (mean_BP_filled, pulse_filled, SpO2_filled, Resp_filled, Temp_filled)
mask_channels = (mean_BP_mask, pulse_mask, SpO2_mask, Resp_mask, Temp_mask)

# Imputed values over the full window (past + future), one channel per vital.
targets = np.zeros((n_patients, past_months + future_months, 5))
for channel, values in enumerate(filled_channels):
    targets[..., channel] = values

# Observed/imputed flags are only needed over the forecast horizon.
targets_masks = np.zeros((n_patients, future_months, 5))
for channel, observed in enumerate(mask_channels):
    targets_masks[..., channel] = observed[:, past_months:]

targets.shape, np.count_nonzero(np.isnan(targets)), targets_masks.shape, targets_masks.sum().sum()

## Dataloader

In [None]:
class TimeSeriesDataset(Dataset):
    """Per-patient samples for the TFT: static features, history window, known future inputs and targets.

    Args:
        static_numeric:   array [patients, n_static_numeric].
        static_categoric: array [patients, n_static_categorical] of integer codes.
        labs:             array [patients, >=past_steps, n_labs]; only the history window is used.
        age:              array [patients, past_steps + future_steps]; used both as a history
                          channel and as the known-future numeric input.
        g_score:          array [patients, >=past_steps] of integer codes (historical categorical).
        target_arr:       array [patients, past_steps + future_steps, n_targets]; the history part
                          is fed back as input, the rest is the prediction target.
        target_mask:      array aligned with the future part of `target_arr`; stored as given.
        past_steps:       length of the history window. Defaults to the notebook-level `past_months`.
        future_steps:     length of the forecast horizon. Defaults to the notebook-level `future_months`.

    The history numeric channels are concatenated in the order: labs, targets, age.
    """

    def __init__(self, static_numeric, static_categoric, labs, age, g_score, target_arr, target_mask,
                 past_steps=None, future_steps=None):
        # Fall back to the notebook globals so existing positional calls keep working.
        if past_steps is None:
            past_steps = past_months
        if future_steps is None:
            future_steps = future_months

        self.static_categorical = static_categoric
        self.static_numerical = static_numeric

        cohort = age.shape[0]
        self.historical_ts_numeric = np.concatenate((labs[:, :past_steps, :],
                                                     target_arr[:, :past_steps, :],
                                                     age[:, :past_steps].reshape(cohort, past_steps, 1)),
                                                    axis=-1)
        self.historical_ts_categorical = g_score[:, :past_steps].reshape(cohort, past_steps, 1)
        self.future_ts_numeric = age[:, past_steps:].reshape(cohort, future_steps, 1)

        self.target = target_arr[:, past_steps:]
        self.target_mask = target_mask

    def __len__(self):
        """Number of patients in the dataset."""
        return len(self.target)

    def __getitem__(self, idx):
        """Return one patient's tensors as a dict keyed by the model's input names."""
        return {
            'static_feats_categorical': torch.tensor(self.static_categorical[idx, ...], dtype=torch.int32),
            'static_feats_numeric': torch.tensor(self.static_numerical[idx, ...], dtype=torch.float32),
            'historical_ts_categorical': torch.tensor(self.historical_ts_categorical[idx, ...], dtype=torch.int32),
            'historical_ts_numeric': torch.tensor(self.historical_ts_numeric[idx, ...], dtype=torch.float32),
            'future_ts_numeric': torch.tensor(self.future_ts_numeric[idx, ...], dtype=torch.float32),
            'target': torch.tensor(self.target[idx], dtype=torch.float32),
            'target_mask': torch.tensor(self.target_mask[idx], dtype=torch.int32),
        }

In [None]:
# Continuous inputs and targets go to float32; categorical codes and masks go to int.
static_numeric, labs, age, targets = (
    arr.astype(np.float32) for arr in (static_numeric, labs, age, targets)
)
static_categoric, g_score, targets_masks = (
    arr.astype(int) for arr in (static_categoric, glasgow_filled, targets_masks)
)

In [None]:
# Hold out 5% of patients for testing. The split is seeded: a later cell reloads a saved
# model and evaluates it on `test_set`, which only stays a true hold-out if the same
# patients are drawn on every run.
SPLIT_SEED = 42
split_rng = np.random.default_rng(SPLIT_SEED)

n_patients_total = age.shape[0]
test_set = split_rng.choice(n_patients_total, size=int(np.floor(n_patients_total * 0.05)), replace=False)
train_set = np.setdiff1d(np.arange(n_patients_total), test_set)
print(len(train_set), len(test_set))

b_size = 800


def _dataset_for(patient_rows):
    """Build a TimeSeriesDataset restricted to the given patient row indices."""
    return TimeSeriesDataset(static_numeric[patient_rows, :],
                             static_categoric[patient_rows, :],
                             labs[patient_rows, :],
                             age[patient_rows, :],
                             g_score[patient_rows, :],
                             targets[patient_rows, :],
                             targets_masks[patient_rows, :])


train_dataset = _dataset_for(train_set)
train_dataloader = DataLoader(train_dataset, batch_size=b_size, shuffle=True)

test_dataset = _dataset_for(test_set)
# Evaluation does not need shuffling; a fixed order keeps per-patient results comparable between runs.
test_dataloader = DataLoader(test_dataset, batch_size=b_size, shuffle=False)

In [None]:
# Sanity check: print the tensor shapes of a single training batch.
# next(iter(...)) pulls exactly one batch instead of looping and breaking on the second.
sample_batch = next(iter(train_dataloader))
print(sample_batch['static_feats_categorical'].shape, sample_batch['static_feats_numeric'].shape,
      sample_batch['historical_ts_categorical'].shape, sample_batch['historical_ts_numeric'].shape,
      sample_batch['future_ts_numeric'].shape, sample_batch['target'].shape, sample_batch['target_mask'].shape)

## TFT Model

In [None]:
import model as TFT_model

## Training

In [None]:
class QueueAggregator:
    """Fixed-capacity FIFO buffer: keeps only the most recent `max_size` appended items."""

    def __init__(self, max_size):
        self.max_size = max_size
        self._queued_list = []

    def append(self, elem):
        """Add `elem`; if the buffer then exceeds `max_size`, drop the oldest entry."""
        buffer = self._queued_list
        buffer.append(elem)
        if len(buffer) > self.max_size:
            del buffer[0]

    def get(self):
        """Return the underlying list, oldest item first (the live list, not a copy)."""
        return self._queued_list

In [None]:
# Every static categorical feature is treated as binary (cardinality 2).
static_card = [2] * static_categoric.shape[1]

# Cardinality of the single historical categorical input (glasgow score, codes 0..12).
numeric_card = [13]

# Feature counts are derived from the arrays actually built above, so the model
# configuration cannot drift from the data (they were previously hard-coded as 28 and 74).
# Historical numeric channels are: labs + fed-back targets + age.
data_props = {'num_historical_numeric': labs.shape[-1] + targets.shape[-1] + 1,
              'num_historical_categorical': len(numeric_card),
              'historical_categorical_cardinalities': numeric_card,
              'num_static_numeric': static_numeric.shape[1],
              'num_static_categorical': static_categoric.shape[1],
              'static_categorical_cardinalities': static_card,
              'num_future_numeric': 1,
              'num_feature_predicted': targets.shape[-1]
              }

In [None]:
# ### hyperparameter tuning
# hyperparams = {
#     'batch_sizes': [400, 800, 1024, 2048],
#     'learning_rates': [1e-3, 1e-5],
#     'max_grad_norms': [100],
#     'dropout': [0.1, 0.3, 0.9],
#     'state_size': [120, 240],
#     'lstm_layers': [2,3,4],
#     'attention_heads': [2,4]
# }

# a = hyperparams.values()
# combinations = list(itertools.product(*a))
# hyperparam_tune_results = {}

# for c in combinations:
#     hyperparam_tune_results[c] = [0,0]
    
# hyperparam_tune_results
# # # np.save("test.npy", hyperparam_tune_results)

In [None]:
## Loss functions
def compute_quantile_loss_instance_wise(outputs: torch.Tensor,
                                        targets: torch.Tensor,
                                        masks: torch.Tensor,
                                        desired_quantiles: torch.Tensor) -> torch.Tensor:
    """Element-wise masked quantile (pinball) loss.

    Args:
        outputs: predicted quantiles, [num_samples x num_horizons x num_features x num_quantiles].
        targets: true values, [num_samples x num_horizons x num_features].
        masks: same shape as `targets`; entries equal to 0 remove that position from the loss.
        desired_quantiles: quantile levels, [num_quantiles].

    Returns:
        Loss tensor with the same shape as `outputs`; masked positions contribute 0.
    """
    errors = targets.unsqueeze(-1) - outputs
    # errors: [num_samples x num_horizons x num_features x num_quantiles]

    # Keep errors only where a real value was recorded. Broadcasting the mask over the
    # quantile axis replaces the per-feature/per-quantile Python loops; the cast keeps
    # the error dtype unchanged whatever dtype the mask has.
    errors = errors * masks.unsqueeze(-1).to(errors.dtype)

    # compute the loss separately for each sample, time-step, feature and quantile
    losses_array = torch.max((desired_quantiles - 1) * errors, desired_quantiles * errors)
    # losses_array: [num_samples x num_horizons x num_features x num_quantiles]

    return losses_array


def get_quantiles_loss_and_q_risk(outputs: torch.Tensor,
                                  targets: torch.Tensor,
                                  masks: torch.Tensor,
                                  desired_quantiles: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """Aggregate the masked quantile loss into a scalar training loss and per-quantile q-risk.

    Args:
        outputs: flat model output, [batch x horizons x (features * quantiles)].
        targets: true values, [batch x horizons x features].
        masks: same shape as `targets`; 0 excludes a position from the loss.
        desired_quantiles: quantile levels, [quantiles].

    Returns:
        (q_loss, q_risk, losses_array): the scalar loss, a [quantiles] tensor of q-risk
        values, and the element-wise losses [batch x horizons x features x quantiles].
    """
    # Take the unflattened shape from the inputs rather than hard-coded notebook constants.
    num_quantiles = desired_quantiles.numel()
    outputs = outputs.reshape((outputs.shape[0], *targets.shape[1:], num_quantiles))
    losses_array = compute_quantile_loss_instance_wise(outputs=outputs,
                                                       targets=targets,
                                                       masks=masks,
                                                       desired_quantiles=desired_quantiles)

    # sum losses over quantiles and features, then average across time and observations
    q_loss = (losses_array.sum(dim=-1)).sum(dim=-1).mean(dim=-1).mean()

    # compute q_risk for each quantile
    # NOTE(review): the denominator sums |targets| over every position, including ones the
    # mask removed from the numerator - confirm this normalization is intended.
    q_risk = 2 * (losses_array.sum(dim=1).sum(dim=0)) / (targets.abs().sum().unsqueeze(-1))
    q_risk = q_risk.sum(dim=0)

    return q_loss, q_risk, losses_array

def process_batch(batch: Dict[str, torch.Tensor],
                  model: nn.Module,
                  quantiles_tensor: torch.Tensor,
                  device: torch.device):
    """Run the model on one batch and return `(q_loss, q_risk)`.

    Args:
        batch: tensors keyed by input name; must contain 'target' and 'target_mask'.
        model: module whose output dict contains 'predicted_quantiles'.
        quantiles_tensor: quantile levels matching the model's outputs.
        device: device the batch is moved to before the forward pass.
    """
    # Move into a new dict: works for any device type (`.to` is a no-op when the tensor
    # is already there) and leaves the caller's batch untouched.
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    batch_outputs = model(batch)
    labels = batch['target'] # [batch, future_months, num_feat]
    target_masks = batch['target_mask']

    # [batch, future_months, num_feat*num_quantiles]
    predicted_quantiles = batch_outputs['predicted_quantiles']

    q_loss, q_risk, _ = get_quantiles_loss_and_q_risk(outputs=predicted_quantiles,
                                                      targets=labels,
                                                      masks=target_masks,
                                                      desired_quantiles=quantiles_tensor)
    return q_loss, q_risk

In [None]:
historical_steps = past_months
future_steps = future_months
num_epochs = 2000

# Size of the running window that the QueueAggregator averages over when reporting training loss.
ma_queue_size = 50
# Number of consecutive non-improving rounds tolerated before training stops early.
patience_limit = 20

# a = hyperparams.values()
# combinations = list(itertools.product(*a))

# for c in combinations:
#     print(c)

# Optimizer-side settings.
optimization_settings = {
    'batch_size': b_size,
    'learning_rate': 1e-3,
    'max_grad_norm': 100,
}

# Architecture settings passed to the model.
model_settings = {
    'dropout': 0.3,
    'state_size': 240,
    'output_quantiles': [0.1, 0.5, 0.9],
    'lstm_layers': 2,
    'attention_heads': 2,
}

configuration = {
    'optimization': optimization_settings,
    'model': model_settings,
    # these arguments are related to possible extensions of the model class
    'task_type': 'regression',
    'target_window_start': None,
    'data_props': data_props,
}

In [None]:
tft_model = TFT_model.TemporalFusionTransformer(OmegaConf.create(configuration))

# Use data parallelism only when more than one GPU is actually present; asking for
# device_ids=[0, 1] on a single-GPU host fails with an invalid device ordinal.
MAX_GPUS = 2
visible_gpus = torch.cuda.device_count() if device.type == "cuda" else 0
if visible_gpus > 1:
    tft_model = nn.DataParallel(tft_model, device_ids=list(range(min(visible_gpus, MAX_GPUS))))
tft_model.to(device)
print(device)

train_dataloader = DataLoader(train_dataset, batch_size=b_size, shuffle=True, num_workers=4)
# Evaluation order does not need to be random.
test_dataloader = DataLoader(test_dataset, batch_size=b_size, shuffle=False, num_workers=4)

# Optimize only the parameters that require gradients.
opt = optim.Adam((p for p in tft_model.parameters() if p.requires_grad),
                 lr=configuration['optimization']['learning_rate'])

# initialize the loss aggregator for running window performance estimation
loss_aggregator = QueueAggregator(max_size=ma_queue_size)
quantiles_tensor = torch.tensor(configuration['model']['output_quantiles']).to(device)

In [None]:
loss_arr = []        # summed training loss per epoch
loss_arr_test = []   # summed test loss per epoch
patience = 0
min_loss = float("inf")
# Snapshot of the weights at the lowest training loss seen so far. A plain
# `best_model = tft_model` would only alias the live model, so the "best" model
# would always equal the most recent one.
best_state = None

for epoch in range(num_epochs):
    loss_e = 0
    loss_e_test = 0

    tft_model.train()
    for data in train_dataloader:
        opt.zero_grad()

        loss,_ = process_batch(batch=data,
                               model=tft_model,
                               quantiles_tensor=quantiles_tensor,
                               device=device)
        loss_e += loss.item()
        loss.backward()

        if configuration['optimization']['max_grad_norm'] > 0:
            nn.utils.clip_grad_norm_(tft_model.parameters(), configuration['optimization']['max_grad_norm'])

        opt.step()
        loss_aggregator.append(loss.item())
    loss_arr.append(loss_e)

    # early stopping on the epoch's summed training loss
    if loss_e < min_loss:
        min_loss = loss_e
        # Copy to CPU so the snapshot does not hold a second set of weights on the GPU.
        best_state = {name: value.detach().cpu().clone()
                      for name, value in tft_model.state_dict().items()}
        patience = 0
    elif patience > patience_limit:
        print("Patient max reached, exiting training.")
        break
    else:
        patience += 1

    # periodic checkpoint of the current (not necessarily best) model
    if epoch%500 == 499:
        torch.save(tft_model, "TFT-multi_"+str(epoch)+".pt")
        np.savetxt("TFT-multi_train_loss",loss_arr)
        np.savetxt("TFT-multi_test_loss",loss_arr_test)

    ## evaluation round for test dataset
    tft_model.eval()
    with torch.no_grad():
        q_loss_eval, q_risk_eval = [], []

        for test_data in test_dataloader:
            batch_loss,batch_q_risk = process_batch(batch=test_data,
                                                    model=tft_model,
                                                    quantiles_tensor=quantiles_tensor,
                                                    device=device)
            loss_e_test += batch_loss.item()
            q_loss_eval.append(batch_loss)
            q_risk_eval.append(batch_q_risk)

        eval_loss = torch.stack(q_loss_eval).mean(axis=0)
        eval_q_risk = torch.stack(q_risk_eval,axis=0).mean(axis=0)
        loss_arr_test.append(loss_e_test)


    print(f"Epoch: {epoch}, Train Loss = {np.mean(loss_aggregator.get())}" +
          f" Test q_loss = {eval_loss:.5f} , " +
          " , ".join([f"q_risk_{q:.1} = {risk:.5f}" for q,risk in zip(quantiles_tensor,eval_q_risk)]))

# Always finish by restoring the best weights and writing the final artifacts, whether
# training stopped early or ran for all epochs; the visualization section loads
# "TFT-multi.pt" from disk and would fail if it were written only on early stop.
if best_state is not None:
    tft_model.load_state_dict(best_state)
best_model = tft_model
torch.save(best_model, "TFT-multi.pt")
np.savetxt("TFT-multi_train_loss",loss_arr)
np.savetxt("TFT-multi_test_loss",loss_arr_test)

# hyperparam_tune_results[c] = [loss_arr[-1], loss_arr_test[-1]]
# np.save("hyperparam_tune_results.npy", hyperparam_tune_results)

## Visualization

In [None]:
import visualization as vis

# The checkpoint stores the whole pickled module, so no fresh model needs to be built first.
# SECURITY NOTE: torch.load unpickles arbitrary objects - only load checkpoints you created.
# NOTE(review): recent PyTorch releases default to weights_only=True, which rejects full-module
# checkpoints; pass weights_only=False there if this load fails.
# map_location lets a checkpoint saved on GPU be opened on a CPU-only machine.
tft_model = torch.load("TFT-multi.pt", map_location=device)

# A DataParallel wrapper saved with GPU device ids cannot run on CPU; use the wrapped module.
if isinstance(tft_model, nn.DataParallel) and device.type != "cuda":
    tft_model = tft_model.module

tft_model.to(device)
tft_model.eval()

In [None]:
# Notebook-wide matplotlib defaults: automatic layout, 10x5 figures, larger font.
rcParams['figure.autolayout'] = True
rcParams['figure.figsize'] = [10, 5]
rcParams['font.size'] = 17

def process_test_batch(batch: Dict[str, torch.Tensor],
                       model: nn.Module,
                       quantiles_tensor: torch.Tensor,
                       device: torch.device):
    """Run the model on one batch and also return predictions and labels as NumPy arrays.

    Args:
        batch: tensors keyed by input name; must contain 'target' and 'target_mask'.
        model: module whose output dict contains 'predicted_quantiles'.
        quantiles_tensor: quantile levels matching the model's outputs.
        device: device the batch is moved to before the forward pass.

    Returns:
        (q_loss, q_risk, predictions, labels) where predictions is
        [batch, future_months, num_feat*num_quantiles] and labels is
        [batch, future_months, num_feat], both on CPU as NumPy arrays.
    """
    # Move into a new dict: works for any device type and leaves the caller's batch untouched.
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    batch_outputs = model(batch)
    labels = batch['target'] # [batch, future_months, num_feat]
    target_masks = batch['target_mask']

    # [batch, future_months, num_feat*num_quantiles]
    predicted_quantiles = batch_outputs['predicted_quantiles']

    q_loss, q_risk, _ = get_quantiles_loss_and_q_risk(outputs=predicted_quantiles,
                                                      targets=labels,
                                                      masks=target_masks,
                                                      desired_quantiles=quantiles_tensor)
    # detach() so the conversion also works when called outside torch.no_grad().
    return q_loss, q_risk, predicted_quantiles.detach().cpu().numpy(), labels.detach().cpu().numpy()

In [None]:
## calculate metric
# Per measure, over test patients that have at least one usable future point
# (recorded, i.e. mask == 1, and with a positive true value):
#   * MAE / MAPE of each predicted quantile, averaged per patient and then across patients
#   * share of usable points falling inside the [lowest, highest] predicted quantile band
# Every statistic for a measure uses that measure's own patient count as denominator.

measure_names = ['mean_BP', 'pulse', 'SpO2', 'resp', 'temp']  # channel order of the targets
n_measures = len(measure_names)
n_quantiles = len(quantiles_tensor)

total_counter = 0

# -10 marks patients without any usable point for a measure; they are dropped below.
individual_percent_bounded = np.full((n_measures, len(test_set)), -10.0)
MAE = np.zeros((n_measures, n_quantiles))
MAPE = np.zeros((n_measures, n_quantiles))

with torch.no_grad():
    for data in test_dataloader:
        mask = data['target_mask'].cpu().numpy()  # [batch, horizon, n_measures]
        batch_loss,batch_q_risk,prediction,true = process_test_batch(batch=data,
                                                                     model=tft_model,
                                                                     quantiles_tensor=quantiles_tensor,
                                                                     device=device)
        # prediction: [batch, horizon, n_measures*n_quantiles], quantiles contiguous per measure
        # true:       [batch, horizon, n_measures]
        batch_n, horizon = true.shape[0], true.shape[1]
        pred = prediction.reshape(batch_n, horizon, n_measures, n_quantiles)

        usable = (mask == 1) & (true > 0)           # [batch, horizon, n_measures]
        valid_time_points = usable.sum(axis=1)       # [batch, n_measures]
        has_points = valid_time_points > 0
        divisor = np.where(has_points, valid_time_points, 1)  # avoid 0/0 for empty patients

        # np.where (rather than multiplying by the mask) keeps NaN labels from leaking in.
        abs_err = np.where(usable[..., None], np.abs(pred - true[..., None]), 0.0)
        safe_true = np.where(usable, true, 1.0)[..., None]
        abs_pct_err = abs_err / safe_true

        # Patients without usable points contribute exactly 0 to these sums.
        MAE += (abs_err.sum(axis=1) / divisor[..., None]).sum(axis=0)
        MAPE += (abs_pct_err.sum(axis=1) / divisor[..., None]).sum(axis=0)

        within_bound = usable & (pred[..., 0] <= true) & (pred[..., -1] >= true)
        percent_bounded = np.where(has_points, within_bound.sum(axis=1) * 100.0 / divisor, -10.0)
        individual_percent_bounded[:, total_counter:total_counter + batch_n] = percent_bounded.T

        total_counter += batch_n


coverage = pd.DataFrame(np.transpose(individual_percent_bounded), columns=measure_names)

for c, name in enumerate(measure_names):
    measure = coverage[name]
    measure = measure.loc[measure>=0]  # drop the -10 sentinels for this measure only
    n_valid_patients = measure.shape[0]

    print("for measure "+ name+" with total count "+str(n_valid_patients))
    if n_valid_patients == 0:
        print("no usable observations for this measure")
        print("===============================")
        continue

    print("MAE for percentiles: ", MAE[c,:]*1.0/n_valid_patients)
    print("MAPE for percentiles: ", MAPE[c,:]*1.0/n_valid_patients)

    for i in range(50,110,10):
        num = measure.loc[measure>=i].shape[0]
        print("num patients with correct bound "+str(i)+"% of trajectory: " + str(num)
             + " (" + str(num*100.0/n_valid_patients) +"%)")

    print("===============================")

In [None]:
## sample visualization on test set
num_features_predicted = 5
n_examples_per_measure = 3
quantiles_per_feature = len(configuration['model']['output_quantiles'])

with torch.no_grad():
    for data in test_dataloader:
        mask = data['target_mask']

        batch_loss,batch_q_risk,prediction,true = process_test_batch(batch=data,
                                                                     model=tft_model,
                                                                     quantiles_tensor=quantiles_tensor,
                                                                     device=device)
        history = data['historical_ts_numeric'].cpu().numpy()

        for meas in range(num_features_predicted):
            # History channels are ordered labs, predicted features, age (see TimeSeriesDataset),
            # so counted from the end the predicted features sit just before the single age
            # channel. The previous fixed offset of -7 pointed one channel too early.
            signal_history_arr = history[..., -(num_features_predicted + 1) + meas]
            print("=====================================")
            first_col = quantiles_per_feature * meas
            prediction_this_measure = prediction[..., first_col:first_col + quantiles_per_feature]
            true_this_measure = true[...,meas]
            # Not used for plotting; kept in the notebook namespace as before.
            mask_this_measure = mask[...,meas]

            # Plot at most n_examples_per_measure patients, fewer if the batch is smaller.
            for i in range(min(n_examples_per_measure, true.shape[0])):
                vis.display_target_trajectory(signal_history=signal_history_arr,
                                              signal_future=true_this_measure,
                                              model_preds=prediction_this_measure,
                                              observation_index=i,
                                              model_quantiles=configuration['model']['output_quantiles'],
                                              unit='15min')

In [None]:
# plot bland_altman
def bland_altman_plot(ground_truth, predicted, predicted_upper, predicted_lower,
                      measure_name='Mean BP', xlim=(50, 118), ylim=(-70, 40)):
    """Draw a Bland-Altman plot of predictions against ground truth.

    Each point is placed at x = (ground_truth + predicted) / 2 and
    y = predicted - ground_truth, with a vertical bar spanning the predicted
    lower-to-upper interval. Horizontal lines mark the mean difference and
    mean +/- 1.96 standard deviations (limits of agreement).

    Args:
        ground_truth: 1-D sequence of observed values.
        predicted: 1-D sequence of point predictions, same length.
        predicted_upper: upper bound of each prediction interval.
        predicted_lower: lower bound of each prediction interval.
        measure_name: name shown in the title.
        xlim, ylim: axis limits; the defaults are the ranges previously hard-coded here.
    """
    ground_truth = np.asarray(ground_truth, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    predicted_upper = np.asarray(predicted_upper, dtype=float)
    predicted_lower = np.asarray(predicted_lower, dtype=float)

    # Calculate the means and differences
    means = (ground_truth + predicted) / 2
    differences = predicted - ground_truth

    # Calculate the mean and standard deviation of the differences
    mean_diff = np.mean(differences)
    std_diff = np.std(differences, ddof=1)

    # Error-bar lengths must be non-negative; crossed quantiles (bound on the wrong side
    # of the point prediction) would otherwise make errorbar raise.
    lower_err = np.clip(predicted - predicted_lower, 0, None)
    upper_err = np.clip(predicted_upper - predicted, 0, None)

    # Plot the points ('.' markers, no connecting line)
    plt.figure(figsize=(10, 5))
    plt.errorbar(means, differences, yerr=[lower_err, upper_err],
                 fmt='.', color='blue', ecolor='green', alpha=0.5)

    # Plot the mean difference and limits of agreement
    plt.axhline(mean_diff, color='black', linestyle='-')
    plt.axhline(mean_diff + 1.96 * std_diff, color='black', linestyle='--')
    plt.axhline(mean_diff - 1.96 * std_diff, color='black', linestyle='--')

    # Labels and title; the x axis shows the pairwise mean, not the raw ground truth
    plt.title(f'Bland-Altman {measure_name}: {mean_diff:.2f} +/- {std_diff:.2f}')
    plt.xlabel('Mean of Ground Truth and Prediction')
    plt.ylabel('Prediction - Ground Truth')

    # Adding the legend
    plt.legend([f'N={len(ground_truth)}'], loc='upper left')

    plt.xlim(*xlim)
    plt.ylim(*ylim)

    plt.grid(True)
    plt.show()

In [None]:
# Collect every recorded (mask == 1) future point of one measure across the test set.
# Channel 0 is mean BP in the target layout, which is what the plot title and axis
# limits describe; the previous index 1 selected pulse.
meas = 0
quantiles_per_feature = len(configuration['model']['output_quantiles'])

truth_parts, median_parts, upper_parts, lower_parts = [], [], [], []

with torch.no_grad():
    for data in test_dataloader:
        # Take the mask from the CURRENT batch; reusing a mask left over from an earlier
        # cell would pair these predictions with another batch's (and measure's) flags.
        observed = data['target_mask'][..., meas].cpu().numpy() == 1

        batch_loss,batch_q_risk,prediction,true = process_test_batch(batch=data,
                                                                     model=tft_model,
                                                                     quantiles_tensor=quantiles_tensor,
                                                                     device=device)
        first_col = quantiles_per_feature * meas
        prediction_this_measure = prediction[..., first_col:first_col + quantiles_per_feature]

        # Boolean indexing flattens in row-major order (patient, then time step).
        truth_parts.append(true[..., meas][observed])
        lower_parts.append(prediction_this_measure[..., 0][observed])
        median_parts.append(prediction_this_measure[..., 1][observed])
        upper_parts.append(prediction_this_measure[..., 2][observed])


def _join(parts):
    """Concatenate per-batch pieces into one float64 vector (empty if there were no batches)."""
    return np.concatenate(parts).astype(np.float64) if parts else np.array([], dtype=np.float64)


GT1, P1, U1, L1 = _join(truth_parts), _join(median_parts), _join(upper_parts), _join(lower_parts)
print(GT1.shape, P1.shape, U1.shape, L1.shape)

# Same values under the names this cell exposed before.
MGT, MP, MU, ML = GT1, P1, U1, L1

bland_altman_plot(MGT, MP, MU, ML)