## Test and compare trained models

In [None]:
import argparse
import copy
import glob
import json
import os
import pickle
from datetime import date

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn.metrics as metrics
import torch
import torch.nn as nn
from sklearn.model_selection import KFold

import models
import prepare_data
import utils
from utils import AverageMeterSet

# --- shared evaluation objects ---------------------------------------------
mse_loss = nn.MSELoss()
# Deterministic (unshuffled) 5-fold splitter over the train+dev stays.
kf = KFold(n_splits=5, random_state=None, shuffle=False)

# Today's date as "MMDD". NOTE: this rebinds `date`; from here on it is a
# string, no longer the `datetime.date` class imported above.
today = date.today()
date = today.strftime("%m%d")

# --- plotting defaults -------------------------------------------------------
matplotlib.rcParams["figure.dpi"] = 300
plt.style.use('bmh')
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
legend_properties = dict(weight='bold', size=14)
legend_properties_s = dict(weight='bold', size=10)

# --- data --------------------------------------------------------------------
# Root folder of the extracted MEEP/MIMIC arrays per platform; this notebook reads 'colab'.
dir_data = {
    'satori': '/nobackup/users/weiliao',
    'colab': '/content/drive/MyDrive/ColabNotebooks/MIMIC/Extract/MEEP/Extracted_sep_2022/0910',
}

# Each .npy holds a pickled Python object (unwrapped with .item()); allow_pickle
# executes arbitrary code on load, so only point these paths at trusted files.
meep_mimic = np.load(f"{dir_data['colab']}/MIMIC_compile_0911_2022.npy",
                     allow_pickle=True).item()
train_vital, dev_vital, test_vital = (
    meep_mimic[split] for split in ('train_head', 'dev_head', 'test_head')
)
mimic_static = np.load(f"{dir_data['colab']}/MIMIC_static_0922_2022.npy",
                       allow_pickle=True).item()
mimic_target = np.load(f"{dir_data['colab']}/MIMIC_target_0922_2022.npy",
                       allow_pickle=True).item()
 
class Args:
    """Attribute-style view of a parameter mapping (e.g. a loaded params.json).

    Every key of ``d`` becomes an attribute of the instance, so ``d['lr']`` is
    reachable as ``args.lr``. Passing ``None`` (the default) yields an instance
    with no extra attributes.
    """

    def __init__(self, d=None):
        if d is None:
            return
        for name, val in d.items():
            setattr(self, name, val)

base_dir = '/content/drive/My Drive/ColabNotebooks/MIMIC/TCN/VAE/checkpoints/'

# One checkpoint directory per trained configuration; each holds params.json and *.pt files.
workname_list = ['0505_lr1e-4beta.001_res_regrtheta_5_mlp_regr_nonsens',
                 '0506_lr1e-4beta.001_res_regrtheta_5_mlp_regr_nonsens_sens_1',
                 '0507_lr1e-4beta.001_res_regrtheta_5_mlp_regr_nonsens_21',
                 '0511_lr1e-4beta.001_res_regrtheta_5_mlp_regr_nonsens_sens0_bs64',
                 '0511_lr1e-4beta.001_res_regrtheta_5_mlp_regr_nonsens_sens1_bs64',
                 '0511_lr1e-4beta.001_res_regrtheta_5_mlp_regr_nonsens_sens21_bs64'
                 ]

# Stay IDs of the 20 test-set patients whose SOFA trajectories are plotted (4 x 5 grid).
stay_ids = [30534026, 32238007, 30134473, 37221316, 36824639, 30397073, 39298593, 34590507, 36437173,
            38656645, 36688715, 37461900, 31176570, 38392389, 37603454, 37741158, 37605174, 33792670, 36939508, 37602211]

# Tick-label letters for the two groups of each supported sensitive attribute (args.sens_ind).
_SENS_CLASS_NAMES = {0: ['F', 'M'], 1: ['Y', 'E'], 21: ['W', 'B']}

# Shared savefig options for every exported figure.
_EPS_KW = dict(format='eps', bbox_inches='tight', pad_inches=0.1, dpi=1200)


def _sens_tick_labels(sens_ind):
    """Return (xticklabels, yticklabels) for the sensitive-attribute confusion matrix.

    Raises:
        ValueError: if ``sens_ind`` has no class names in ``_SENS_CLASS_NAMES``
            (previously this surfaced as a NameError, or silently reused the
            labels of a previously evaluated work directory).
    """
    try:
        names = _SENS_CLASS_NAMES[sens_ind]
    except KeyError:
        raise ValueError('No class names defined for sens_ind=%r' % (sens_ind,)) from None
    return ['Pred-%s' % n for n in names], ['%s' % n for n in names]


def _evaluate_test_set(model, dataloader, device):
    """Run ``model`` over ``dataloader`` without gradients.

    Returns:
        (test_loss, logits, stt): the mean per-stay SOFA MSE over unmasked time
        steps as a 0-d numpy array, one sensitive-attribute logit per stay
        (stacked tensor), and the matching static labels (stacked tensor).
    """
    model.eval()
    logits, stt, sofa_loss = [], [], []
    with torch.no_grad():
        for vitals, static, target, _batch_ids, key_mask in dataloader:
            vitals = vitals.to(device)      # (bs, feature_dim, T)
            static = static.to(device)      # (bs)
            target = target.to(device)      # (bs, T, 1)
            key_mask = key_mask.to(device)  # (bs, T)

            # _mu shape [bs, zdim, T]
            _mu, _logvar = model.encoder(vitals)
            b_logits = _mu[:, model.sens_idx]
            mu = _mu[:, model.nonsens_idx, :]
            # (bs, T, 1)
            sofa_p = model.regr(mu.transpose(1, 2), "classify")
            # Per-stay MSE restricted to time steps where key_mask == 0.
            sofa_loss.extend(mse_loss(pred[mask == 0], tgt[mask == 0])
                             for pred, tgt, mask in zip(sofa_p, target, key_mask))
            # One logit per stay: mean over the remaining dims of its sensitive latent(s).
            logits.extend(row.squeeze(0).mean() for row in b_logits)
            stt.extend(static)
    test_loss = torch.mean(torch.stack(sofa_loss)).cpu().numpy()
    return test_loss, torch.stack(logits), torch.stack(stt)


def _save_confusion_matrix(probs, y_true, th, sens_ind, out_path):
    """Show and save a confusion matrix of ``probs > th`` against ``y_true``.

    Cells are annotated with the percentage of the true-label row and the raw count.
    """
    xlabel, ylabel = _sens_tick_labels(sens_ind)
    pred = (probs > th).float()
    cm = metrics.confusion_matrix(y_true, pred)
    # Row-normalise: each cell is a fraction of the stays with that true label.
    cf_matrix = cm / cm.sum(axis=1, keepdims=True)
    group_counts = ['{0:0.0f}'.format(value) for value in cm.flatten()]
    group_percentages = ['{0:.2%}'.format(value) for value in cf_matrix.flatten()]
    labels = np.asarray([f'{v1}\n{v2}' for v1, v2 in zip(group_percentages, group_counts)]).reshape(2, 2)

    # NOTE: sns.set changes the global matplotlib theme, so it also affects every
    # figure drawn after this point (not only this heatmap).
    sns.set(font_scale=1.5)
    fig, ax = plt.subplots()
    sns.heatmap(cf_matrix, annot=labels, fmt='', cmap='OrRd',
                annot_kws={"fontsize": 16}, xticklabels=xlabel, yticklabels=ylabel,
                cbar=False, ax=ax)
    fig.savefig(out_path, **_EPS_KW)
    plt.show()
    plt.close(fig)


def _draw_sofa_trajectory(ax, sofa, sofa_p, labelsize):
    """Plot the recorded SOFA (from hour 0) and predicted SOFA (from hour 24) on ``ax``.

    The y-axis is fixed to [0, 12] unless either curve exceeds 11, in which case
    it extends to one above the larger maximum.
    """
    n = len(sofa)
    ax.plot(range(n), sofa, label='Current SOFA')
    ax.plot(range(24, n), sofa_p, c="tab:green", label='Predicted SOFA')
    ax.set_xlim((0, n))
    ax.tick_params(axis='both', labelsize=labelsize)
    # np.max works for both 1-D and (T, 1) predictions and yields a scalar.
    top = max(max(sofa), int(np.max(sofa_p)))
    ax.set_ylim((0, 12) if top <= 11 else (0, top + 1))


for wn in workname_list:
    work_dir = base_dir + wn
    # Stage-2 checkpoints are excluded from this comparison.
    all_path = [p for p in glob.glob(work_dir + '/*.pt') if "stage2" not in p]
    with open(work_dir + '/params.json') as f:
        params = json.load(f)
    params['platform'] = 'colab'
    args = Args(params)
    device = torch.device("cuda:%d" % args.device_id if torch.cuda.is_available() else "cpu")

    train_head, train_static, train_sofa, train_id = utils.crop_data_target('mimic', train_vital, mimic_target, mimic_static, 'train', args.sens_ind)
    dev_head, dev_static, dev_sofa, dev_id = utils.crop_data_target('mimic', dev_vital, mimic_target, mimic_static, 'dev', args.sens_ind)
    test_head, test_static, test_sofa, test_id = utils.crop_data_target('mimic', test_vital, mimic_target, mimic_static, 'test', args.sens_ind)

    # Exact `== True` comparison kept on purpose: the value comes straight from params.json.
    if args.use_sepsis3 == True:  # noqa: E712
        train_head, train_static, train_sofa, train_id = utils.filter_sepsis('mimic', train_head, train_static, train_sofa, train_id, args.platform)
        dev_head, dev_static, dev_sofa, dev_id = utils.filter_sepsis('mimic', dev_head, dev_static, dev_sofa, dev_id, args.platform)
        test_head, test_static, test_sofa, test_id = utils.filter_sepsis('mimic', test_head, test_static, test_sofa, test_id, args.platform)

    # Build the model on the same device the batches are sent to, so evaluation
    # also works when CUDA is available.
    model = models.Ffvae(args)
    model.to(device)

    # 5-fold split of train+dev. Only `test_dataloader` is used below, but
    # get_data_loader also needs train/val portions; use the LAST KFold split so
    # the loaders match those left over after iterating every fold.
    trainval_head = train_head + dev_head
    trainval_static = train_static + dev_static
    trainval_stail = train_sofa + dev_sofa
    trainval_ids = train_id + dev_id
    *_, (train_index, test_index) = kf.split(trainval_head)
    print("TRAIN:", len(train_index), "TEST:", len(test_index))
    train_head, val_head = utils.slice_data(trainval_head, train_index), utils.slice_data(trainval_head, test_index)
    train_static, val_static = utils.slice_data(trainval_static, train_index), utils.slice_data(trainval_static, test_index)
    train_stail, val_stail = utils.slice_data(trainval_stail, train_index), utils.slice_data(trainval_stail, test_index)
    train_id, val_id = utils.slice_data(trainval_ids, train_index), utils.slice_data(trainval_ids, test_index)

    train_dataloader, dev_dataloader, test_dataloader = prepare_data.get_data_loader(args, train_head, val_head,
                                                                                        test_head,
                                                                                        train_stail, val_stail,
                                                                                        test_sofa,
                                                                                        train_static=train_static,
                                                                                        dev_static=val_static,
                                                                                        test_static=test_static,
                                                                                        train_id=train_id,
                                                                                        dev_id=val_id,
                                                                                        test_id=test_id)

    # Positions of the showcase stays within the test set.
    idmap = {stay_id: pos for pos, stay_id in enumerate(test_id)}
    id_20 = [idmap[stay_id] for stay_id in stay_ids]

    # Confusion matrices, ROC/AUC and SOFA trajectory plots for every checkpoint.
    for p in all_path:
        # Create a subfolder named after the checkpoint file to save the test results.
        # splitext keeps dots inside the name (e.g. "loss0.12"), so such checkpoints
        # no longer collide in the same folder.
        ckpt_name = os.path.splitext(os.path.basename(p))[0]
        curr_dir = os.path.join(work_dir, ckpt_name)
        os.makedirs(curr_dir, exist_ok=True)
        model.load_state_dict(torch.load(p, map_location=device))
        print(p)

        # SOFA regression loss on the test loader, saved to file.
        test_loss, logits, stt = _evaluate_test_set(model, test_dataloader, device)
        with open(curr_dir + '/sofa_test.json', 'w') as f:
            json.dump(str(test_loss), f)
        print(test_loss)

        probs = torch.sigmoid(logits).cpu()
        y_true = stt.cpu()

        # ROC curve / AUC for predicting the sensitive attribute.
        fig, ax = plt.subplots()
        metrics.RocCurveDisplay.from_predictions(y_true, probs, ax=ax)
        fig.savefig(curr_dir + '/auc.eps', **_EPS_KW)
        plt.show()
        plt.close(fig)

        # Threshold maximising the G-mean sqrt(TPR * (1 - FPR)) along the ROC curve.
        fpr, tpr, thresholds = metrics.roc_curve(y_true, probs)
        opt_th = thresholds[np.argmax(np.sqrt(tpr * (1 - fpr)))]
        for th in [0.5, opt_th]:
            _save_confusion_matrix(probs, y_true, th, args.sens_ind, curr_dir + '/cm_%f.eps' % th)

        # Predict each showcase stay once; the 'smooth' view only rounds these values.
        showcase = []
        with torch.no_grad():
            for pos in id_20:
                stay_id = test_id[pos]
                x = torch.FloatTensor(test_head[pos]).unsqueeze(0).to(device)
                _mu, _logvar = model.encoder(x)
                mu = _mu[:, model.nonsens_idx, :]
                # NOTE(review): *15 presumably undoes a 0-1 scaling of the SOFA target; confirm.
                sofa_p = model.regr(mu.transpose(1, 2), "classify").squeeze(0).cpu().numpy() * 15
                showcase.append((mimic_target['test'][stay_id], sofa_p))

        for v in ['raw', 'smooth']:
            fig, ax = plt.subplots(4, 5, figsize=(25, 12))
            axes = ax.flatten()
            for i, (sofa, raw_p) in enumerate(showcase):
                sofa_p = np.round(raw_p) if v == 'smooth' else raw_p
                _draw_sofa_trajectory(axes[i], sofa, sofa_p, labelsize=8)
                if i == 0:
                    axes[i].set_ylabel('SOFA score', size=14, fontweight='bold')
                if i == 19:
                    axes[i].set_xlabel('ICU_in Hours', size=14, fontweight='bold')
                if i == 4:
                    axes[i].legend(loc='upper right', prop=legend_properties)

                # Save each panel as its own small figure as well.
                fig_s, ax_s = plt.subplots(1, 1, figsize=(5, 3))
                _draw_sofa_trajectory(ax_s, sofa, sofa_p, labelsize=6)
                ax_s.set_ylabel('SOFA score', size=12, fontweight='bold')
                ax_s.set_xlabel('ICU_in Hours', size=12, fontweight='bold')
                ax_s.legend(loc='upper right', prop=legend_properties_s)
                fig_s.savefig(curr_dir + '/indiv_sofa_%s_%d.eps' % (v, i), **_EPS_KW)
                # Close immediately; otherwise 40 figures per checkpoint stay open,
                # accumulate memory and are all dumped by the next plt.show().
                plt.close(fig_s)

            # Save the big 4 x 5 figure.
            fig.savefig(curr_dir + '/sofa_%s.eps' % v, **_EPS_KW)
            plt.show()
            plt.close(fig)

In [None]:
# Display the test-set positions (indices into test_id) of the 20 showcase
# stays, as computed for the last entry of workname_list.
id_20