## Test and compare trained models

In [None]:
# checkpoint file 
import glob 
base_dir = '/content/drive/My Drive/ColabNotebooks/MIMIC/TCN/VAE/checkpoints/'
workname = '0507_lr1e-4beta.001_res_regrtheta_5_mlp_regr_nonsens_21/'
all_path = glob.glob(base_dir+workname + '*.pt')
all_path = [p for p in all_path if "stage2" not in p]

In [None]:
import json 
with open(base_dir+workname+'params.json') as f: 
    params = json.load(f)
params['platform'] = 'colab'
class Args:
    def __init__(self, d=None):
        if d is not None:
            for key, value in d.items():
                setattr(self, key, value)
args = Args(params)

In [None]:
import argparse
import os 
import copy 
import pickle
import numpy as np
import pandas as pd
import torch
import utils
from utils import AverageMeterSet
import prepare_data
import models
from sklearn.model_selection import KFold
kf = KFold(n_splits=5, random_state=None, shuffle=False)
from datetime import date
today = date.today()
date = today.strftime("%m%d")
import matplotlib.pyplot as plt
import matplotlib 
matplotlib.rcParams["figure.dpi"] = 300
plt.style.use('bmh')
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
legend_properties = {'weight':'bold', 'size': 14}
dir_data = {'satori': '/nobackup/users/weiliao', 'colab':'/content/drive/MyDrive/ColabNotebooks/MIMIC/Extract/MEEP/Extracted_sep_2022/0910'}


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Parser for time series VAE models")
    parser.add_argument("--device_id", type=int, default=0, help="GPU id")
    parser.add_argument("--platform", type=str, default='satori', choices=['satori', 'colab'], help='Platform to run the code')
    # data/loss parameters
    parser.add_argument("--use_sepsis3", action = 'store_false', default= True, help="Whethe only use sepsis3 subset")
    parser.add_argument("--bucket_size", type=int, default=300, help="bucket size to group different length of time-series data")
    parser.add_argument("--beta", type=float, default=0.001, help="coefficent for the elbo loss")
    parser.add_argument("--gamma", type=float, default=0.5, help="coefficent for the total_corr loss")
    parser.add_argument("--alpha", type=float, default=0.5, help="coefficent for the clf loss")
    parser.add_argument("--theta", type=float, default=5, help="coefficent for the sofa loss in stage 1")
    parser.add_argument("--zdim", type=int, default=20, help="dimension of the latent space")
    parser.add_argument("--sens_ind", type=int, default=21, help="index of the sensitive feature")
    parser.add_argument("--scale_elbo", action = 'store_true', default=False, help="Whether to scale the ELBO loss")
    # model parameters
    parser.add_argument("--kernel_size", type=int, default=3, help="kernel size")
    parser.add_argument("--drop_out", type=float, default=0.2, help="drop out rate")
    parser.add_argument("--enc_channels", default=[256, 128, 64, 40],  help="number of channels in the encoder")
    parser.add_argument("--dec_channels", default = [64, 128, 256, 200], help="number of channels in the decoder")
    parser.add_argument("--num_inputs", type=int, default=200, help="number of features in the inputs")
    # discriminator parameters
    parser.add_argument("--disc_channels",  type=int, default=200, help="number of channels in the discriminator")
    # regressor parameters
    parser.add_argument("--regr_model",  type=str, default='mlp', choices=['mlp', 'tcn'], help='Model choice in sofa prediction')
    parser.add_argument("--regr_channels",  type=int, default=200, help="number of channels in the regressor")
    parser.add_argument("--regr_tcn_channels",  nargs='+', type=int, help="number of channels in the regressor")
    parser.add_argument("--regr_only_nonsens", action = 'store_false', default=True, help="Whether only using nonsens latents to predict sofa")
    # training parameters
    parser.add_argument("--epochs", type=int, default=300, help="Number of training epochs")
    parser.add_argument("--data_batching", type=str, default='close', choices=['same', 'close', 'random'], help='How to batch data')
    parser.add_argument("--bs", type=int, default=16, help="batch size")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--patience", type=int, default=20, help="Patience epochs for early stopping.")
    parser.add_argument("--checkpoint", type=str, default='test', help=" name of checkpoint model")

    args = parser.parse_known_args()[0]
    device = torch.device("cuda:%d"%args.device_id if torch.cuda.is_available() else "cpu")
    arg_dict = vars(args)
    # workname = date + "_" +  args.checkpoint
    # utils.creat_checkpoint_folder('/home/weiliao/FR-TSVAE/checkpoints/' + workname, 'params.json', arg_dict)

    # load data
    meep_mimic = np.load(dir_data[args.platform] + '/MIMIC_compile_0911_2022.npy', \
                    allow_pickle=True).item()
    train_vital = meep_mimic ['train_head']
    dev_vital = meep_mimic ['dev_head']
    test_vital = meep_mimic ['test_head']
    mimic_static = np.load(dir_data[args.platform] + '/MIMIC_static_0922_2022.npy', \
                            allow_pickle=True).item()
    mimic_target = np.load(dir_data[args.platform] + '/MIMIC_target_0922_2022.npy', \
                            allow_pickle=True).item()

In [None]:
static_key = 'static_' + 'test'
idx = pd.IndexSlice
length = [i.shape[-1] for i in test_vital]
all_train_id = list(mimic_target['test'].keys())
stayids = [all_train_id[i] for i, m in enumerate(length) if m >24]
sofa_tail = [mimic_target['test'][j][24:]/15 for j in stayids]
# list of array [array([5]), array([1]), array([4]),
train_target = [np.nonzero(mimic_static[static_key].loc[idx[:, :, j]].iloc[:, 21:].values)[1] for j in stayids]
sub_ind = [i for i, m in enumerate(train_target) if m == 2 or m == 5]
race_dict = {2: 1, 5:0}
# a list of target class
train_targets = [race_dict[train_target[i][0]] for i in sub_ind]
# train_filters = [train_filter[i] for i in sub_ind]
# sofa_tails = [sofa_tail[i] for i in sub_ind]
# stayidss = [stayids[i] for i in sub_ind]

In [None]:
import importlib 
importlib.reload(utils)
train_head, train_static, train_sofa, train_id =  utils.crop_data_target('mimic', train_vital, mimic_target, mimic_static, 'train', args.sens_ind)
dev_head, dev_static, dev_sofa, dev_id =  utils.crop_data_target('mimic', dev_vital , mimic_target, mimic_static, 'dev',  args.sens_ind)
test_head, test_static, test_sofa, test_id =  utils.crop_data_target('mimic', test_vital, mimic_target, mimic_static, 'test',  args.sens_ind)

In [None]:
if args.use_sepsis3 == True:
    train_head, train_static, train_sofa, train_id = utils.filter_sepsis('mimic', train_head, train_static, train_sofa, train_id, args.platform)
    dev_head, dev_static, dev_sofa, dev_id = utils.filter_sepsis('mimic', dev_head, dev_static, dev_sofa, dev_id, args.platform)
    test_head, test_static, test_sofa, test_id = utils.filter_sepsis('mimic', test_head, test_static, test_sofa, test_id, args.platform)

In [None]:
trainval_head = train_head + dev_head
trainval_static = train_static + dev_static
trainval_stail = train_sofa + dev_sofa
trainval_ids = train_id + dev_id

# prepare data
torch.autograd.set_detect_anomaly(True)
for c_fold, (train_index, test_index) in enumerate(kf.split(trainval_head)):
    # best_loss = 1e4
    # patience = 0
    # if c_fold >= 1:
    #     model.load_state_dict(torch.load('/home/weiliao/FR-TSVAE/start_weights.pt'))
    print('Starting Fold %d' % c_fold)
    print("TRAIN:", len(train_index), "TEST:", len(test_index))
    train_head, val_head = utils.slice_data(trainval_head, train_index), utils.slice_data(trainval_head, test_index)
    train_static, val_static = utils.slice_data(trainval_static, train_index), utils.slice_data(trainval_static, test_index)
    train_stail, val_stail = utils.slice_data(trainval_stail, train_index), utils.slice_data(trainval_stail, test_index)
    train_id, val_id = utils.slice_data(trainval_ids, train_index), utils.slice_data(trainval_ids, test_index)

In [None]:
# check if the return id and the sensitive labels are aligned
import random 
id_20 = random.choices(test_id, k=20)
id_20

In [None]:
importlib.reload(prepare_data)
train_dataloader, dev_dataloader, test_dataloader = prepare_data.get_data_loader(args, train_head, val_head,
                                                                                            test_head, 
                                                                                            train_stail, val_stail,
                                                                                            test_sofa,
                                                                                            train_static=train_static,
                                                                                            dev_static=val_static,
                                                                                            test_static=test_static,
                                                                                            train_id=train_id,
                                                                                            dev_id=val_id,
                                                                                            test_id=test_id)

In [None]:
v = []
s = []
t = []
i = []
k = []

for vitals, static, target, train_ids, key_mask in train_dataloader:
    v.append(vitals)
    s.append(static)
    t.append(target)
    i.append(train_ids)
    k.append(key_mask)

In [None]:
total_s = torch.cat(s, dim=0)
total_i = torch.cat(i, dim=0)
ind_20 = random.choices(range(len(total_i)), k=20)
for j in ind_20: 
    curr = total_i[j]
    print("current id is %d "%curr)
    if len(mimic_static['static_train'].loc[idx[:, :, curr.item()]]) == 1: 
        print(mimic_static['static_train'].loc[idx[:, :, curr.item()]].iloc[:, 0].values[0], mimic_static['static_train'].loc[idx[:, :, curr.item()]].iloc[:, 1].values[0], 
              mimic_static['static_train'].loc[idx[:, :, curr.item()]].iloc[:, 21:].values[0])
    else:
        print(mimic_static['static_dev'].loc[idx[:, :, curr.item()]].iloc[:, 0].values[0], mimic_static['static_dev'].loc[idx[:, :, curr.item()]].iloc[:, 1].values[0], 
              mimic_static['static_dev'].loc[idx[:, :, curr.item()]].iloc[:, 21:].values[0])   
    
    print(total_s[j])
    
#     np.nonzero(.values)[1] for j in stayids]

In [None]:
mimic_static['static_test'].loc[idx[:, :, curr.item()]].iloc[:, 21:].values[0]

In [None]:
len(mimic_static['static_train'].loc[idx[:, :, curr.item()]])

In [None]:
mimic_static['static_test'].loc[idx[:, :, 38775862]]

In [None]:
mimic_static['static_test'].loc[idx[:, :, curr.item()]]

In [None]:
test_id.shape

In [None]:
# deal with CM and AUC plots
import sklearn.metrics as metrics
import seaborn as sns 
import torch.nn as nn 

for p in all_path: 
    model.load_state_dict(torch.load(p))
    print(p)
    #  test on test loader 
    model.eval()
    logits = []
    stt = []
    with torch.no_grad():
        for vitals, static, target, train_ids, key_mask in test_dataloader:
            # (bs, feature_dim, T)
            vitals = vitals.to(device)
            # (bs)
            static = static.to(device)
            # (bs, T, 1)
            target = target.to(device)
            # (bs, T)
            key_mask = key_mask.to(device)

            # _mu shape [bs, zdim, T]
            _mu, _logvar = model.encoder(vitals)
            # b_logits [bs, 1]
            b_logits = _mu[:, model.sens_idx]
            logits.extend(torch.stack([b_logits[i].squeeze(0).mean() for i in range(len(b_logits))]))
            stt.extend(static)
    logits = torch.stack(logits)
    stt = torch.stack(stt) 
    metrics.RocCurveDisplay.from_predictions(stt.cpu(),  nn.Sigmoid()(logits).cpu())  
    plt.show()
    wname = p.split('/')[-1].split('.')[0]
    # plt.savefig('./checkpoints/' + workname + 'auc_curve/' + '%s_auc.eps'%wname, format='eps', bbox_inches = 'tight', pad_inches = 0.1, dpi=1200)

    pred =  (nn.Sigmoid()(logits).cpu() > 0.5).float()
    cm = metrics.confusion_matrix(stt.cpu(), pred)
    cf_matrix = cm/np.repeat(np.expand_dims(np.sum(cm, axis=1), axis=-1), 2, axis=1)
    group_counts = ['{0:0.0f}'.format(value) for value in cm.flatten()]
    # percentage based on true label 
    gr = (cm/np.repeat(np.expand_dims(np.sum(cm, axis=1), axis=-1), 2, axis=1)).flatten()
    group_percentages = ['{0:.2%}'.format(value) for value in gr]

    labels = [f'{v1}\n{v2}' for v1, v2 in zip(group_percentages, group_counts)]

    labels = np.asarray(labels).reshape(2, 2)


    xlabel = ['Pred-%d'%i for i in range(2)]
    ylabel = ['%d'%i for i in range(2)]

    sns.set(font_scale = 1.5)

    hm = sns.heatmap(cf_matrix, annot=labels, fmt='', cmap = 'OrRd', \
    annot_kws={"fontsize": 16}, xticklabels=xlabel, yticklabels=ylabel, cbar=False)
    # hm.set(title=title)
    fig = plt.gcf()
    plt.show()  
    # plt.savefig('./checkpoints/' + workname + 'cm_matrix/' + '%s_cm.eps'%wname, format='eps', bbox_inches = 'tight', pad_inches = 0.1, dpi=1200)

In [None]:
# SOFA Test set results, including MSE, analyze on 20 paticular patients, MSE variation by ICU stay length 


In [None]:
# Compare the same with second stage SOFA training 

In [None]:
# check the 3 sense case
import argparse
import os 
import json
import glob 
import copy 
import pickle
import random 
import numpy as np
import pandas as pd
import torch
import torch.nn as nn 
mse_loss = nn.MSELoss()
import utils
from utils import AverageMeterSet
import prepare_data
import models
from sklearn.model_selection import KFold
import sklearn.metrics as metrics
kf = KFold(n_splits=5, random_state=None, shuffle=False)
from datetime import date
today = date.today()
date = today.strftime("%m%d")
import matplotlib.pyplot as plt
import matplotlib 
import seaborn as sns 
matplotlib.rcParams["figure.dpi"] = 300
plt.style.use('bmh')
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
legend_properties = {'weight':'bold', 'size': 14}
legend_properties_s = {'weight':'bold', 'size': 10}
dir_data = {'satori': '/nobackup/users/weiliao', 'colab':'/content/drive/MyDrive/ColabNotebooks/MIMIC/Extract/MEEP/Extracted_sep_2022/0910'}
# load data
meep_mimic = np.load(dir_data['satori'] + '/MIMIC_compile_0911_2022.npy', \
                allow_pickle=True).item()
train_vital = meep_mimic ['train_head']
dev_vital = meep_mimic ['dev_head']
test_vital = meep_mimic ['test_head']
mimic_static = np.load(dir_data['satori'] + '/MIMIC_static_0922_2022.npy', \
                        allow_pickle=True).item()
mimic_target = np.load(dir_data['satori'] + '/MIMIC_target_0922_2022.npy', \
                        allow_pickle=True).item()
 
class Args:
    def __init__(self, d=None):
        if d is not None:
            for key, value in d.items():
                setattr(self, key, value)

base_dir = '/home/weiliao/FR-TSVAE/checkpoints/'
workname_list = [
    '0519_lr1e-4beta.001_res_regrtheta_1_mlp_regr_nonsens_sens0_mask_3sens'
]

with_mask = True
for wn in workname_list:

    all_path = glob.glob(base_dir + wn + '/*.pt')
    # all_path = [p for p in all_path if "stage2" not in p]
    with open(base_dir+wn+'/params.json') as f: 
        params = json.load(f)
    params['platform'] = 'satori'
    params['device_id'] = 0 
    args = Args(params)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    train_head, train_static, train_sofa, train_id =  utils.crop_data_target('mimic', train_vital, mimic_target, mimic_static, 'train', args.sens_ind)
    dev_head, dev_static, dev_sofa, dev_id =  utils.crop_data_target('mimic', dev_vital , mimic_target, mimic_static, 'dev',  args.sens_ind)
    test_head, test_static, test_sofa, test_id =  utils.crop_data_target('mimic', test_vital, mimic_target, mimic_static, 'test',  args.sens_ind)

    if args.use_sepsis3 == True:
        train_head, train_static, train_sofa, train_id = utils.filter_sepsis('mimic', train_head, train_static, train_sofa, train_id, args.platform)
        dev_head, dev_static, dev_sofa, dev_id = utils.filter_sepsis('mimic', dev_head, dev_static, dev_sofa, dev_id, args.platform)
        test_head, test_static, test_sofa, test_id = utils.filter_sepsis('mimic', test_head, test_static, test_sofa, test_id, args.platform)

    # build model
    model = models.Ffvae(args)
    # torch.save(model.state_dict(), '/home/weiliao/FR-TSVAE/start_weights.pt')

    # 10-fold cross validation
    trainval_head = train_head + dev_head
    trainval_static = train_static + dev_static
    trainval_stail = train_sofa + dev_sofa
    trainval_ids = train_id + dev_id

    # prepare data
    torch.autograd.set_detect_anomaly(True)
    for c_fold, (train_index, test_index) in enumerate(kf.split(trainval_head)):
        # best_loss = 1e4
        # patience = 0
        # if c_fold >= 1:
        #     model.load_state_dict(torch.load('/home/weiliao/FR-TSVAE/start_weights.pt'))
        print('Starting Fold %d' % c_fold)
        print("TRAIN:", len(train_index), "TEST:", len(test_index))
        train_head, val_head = utils.slice_data(trainval_head, train_index), utils.slice_data(trainval_head, test_index)
        train_static, val_static = utils.slice_data(trainval_static, train_index), utils.slice_data(trainval_static, test_index)
        train_stail, val_stail = utils.slice_data(trainval_stail, train_index), utils.slice_data(trainval_stail, test_index)
        train_id, val_id = utils.slice_data(trainval_ids, train_index), utils.slice_data(trainval_ids, test_index)

        train_dataloader, dev_dataloader, test_dataloader = prepare_data.get_data_loader(args, train_head, val_head,
                                                                                            test_head, 
                                                                                            train_stail, val_stail,
                                                                                            test_sofa,
                                                                                            train_static=train_static,
                                                                                            dev_static=val_static,
                                                                                            test_static=test_static,
                                                                                            train_id=train_id,
                                                                                            dev_id=val_id,
                                                                                            test_id=test_id)
    # prepare id_20 switch 34590507 to 32882608
    stay_ids = [33013986, 37338822, 31972580, 35364108, 34994922, 
                32087563, 36691371, 31438123, 33266445, 36424894, 
                32613134, 38247671, 33280018, 33676100, 30932864, 
                30525046, 33267162, 36431990, 31303162, 37216041]
    idmap = {}
    for i, ids in enumerate(test_id): 
        idmap[ids] = i 
    # convert id to index 
    id_20 = []
    for ids in stay_ids: 
        id_20.append(idmap[ids])
    # deal with CM and AUC plots, 3d 
    for p in all_path: 
        # creat a subfolder to save the test results
        curr_dir = base_dir + wn + '/'+ p.split('/')[-1].split('.')[0]
        if not os.path.exists(curr_dir):
            os.makedirs(curr_dir)
        model.load_state_dict(torch.load(p, map_location='cuda:0'))
        print(p)
        #  test on test loader 
        model.eval()
        logits = []
        stt = []
        sofa_loss = []
        with torch.no_grad():
            for vitals, static, target, train_ids, key_mask in test_dataloader:
                # (bs, feature_dim, T)
                vitals = vitals.to(device)
                # (bs, 3)
                static = static.to(device)
                # (bs, T, 1)
                target = target.to(device)
                # (bs, T)
                key_mask = key_mask.to(device)

                # _mu shape [bs, zdim, T]
                _mu, _logvar = model.encoder(vitals)
                # b_logits [bs, 3, T]
                b_logits = _mu[:, model.sens_idx, :]
                mu = _mu[:, model.nonsens_idx, :]
                # (bs, T, 1)
                sofa_p = model.regr(mu.transpose(1, 2), "classify")
                sofa_loss.extend([mse_loss(sofa_p[i][key_mask[i]==0], target[i][key_mask[i]==0]) for i in range(len(sofa_p))])
                # for static info prediction 
                if with_mask: 
                    logits.extend(torch.stack([b_logits[i][:, key_mask[i]==0].mean(dim=-1)  for i in range(len(b_logits))]))
                else:
                    logits.extend(torch.stack([b_logits[i].squeeze(0).mean() for i in range(len(b_logits))]))
                stt.extend(static)
        # save the sofa test result to file 
        test_loss = torch.mean(torch.stack(sofa_loss)).cpu().numpy()
        # compute SOFA loss CI 
        n_bootstraps = 1000
        size_bstr = 3000
        loss_np = torch.stack(sofa_loss).cpu().numpy()
        bootstrapped_losses = []
        for i in range(n_bootstraps):
            b = random.choices(loss_np, k=size_bstr)
            bootstrapped_losses.append(np.mean(b))
        bootstrapped_losses.sort()
        # a 95% confidence interval
        loss_ci_l = bootstrapped_losses[int(0.025 * len(bootstrapped_losses))]
        loss_ci_h = bootstrapped_losses[int(0.975 * len(bootstrapped_losses))]
        
        logits = torch.stack(logits)
        stt = torch.stack(stt) 
        # AUC CM section needs enumeration 
        for sens_i, sens_ind in enumerate([0, 1, 21]):
            # display AUC 
            metrics.RocCurveDisplay.from_predictions(stt[:, sens_i].cpu(),  nn.Sigmoid()(logits[:, sens_i]).cpu())  
#             fig = plt.gcf()
#             plt.show()
#             fig.savefig(curr_dir + '/auc.eps', format='eps', bbox_inches = 'tight', pad_inches = 0.1, dpi=1200)

            # calculate optimal threshold
            fpr, tpr, thresholds = metrics.roc_curve(stt[:, sens_i].cpu(),  nn.Sigmoid()(logits[:, sens_i]).cpu())
            gmeans = np.sqrt(tpr * (1-fpr))
            opt_th = thresholds[np.argmax(gmeans)]

            # AUC CI 
            bootstrapped_aucs = []
            for i in range(n_bootstraps):
                # bootstrap by sampling with replacement on the prediction indices
                indices = random.choices(range(len(stt)), k=1000)
                score = metrics.roc_auc_score(stt[:, sens_i].cpu()[indices], nn.Sigmoid()(logits[:, sens_i]).cpu()[indices])
                bootstrapped_aucs.append(score)
            bootstrapped_aucs.sort()
            # a 95% confidence interval
            auc_ci_l = bootstrapped_aucs[int(0.025 * len(bootstrapped_aucs))]
            auc_ci_h = bootstrapped_aucs[int(0.975 * len(bootstrapped_aucs))]
            msg = 'test loss %.5f, sofa loss ci (%.5f - %.5f), auc ci (%.5f - %.5f)'%(test_loss, loss_ci_l, loss_ci_h, auc_ci_l, auc_ci_h)
            print(msg)

#             with open(curr_dir + '/sofa_test.json', 'w') as f:
#                 msg = 'test loss %.5f, sofa loss ci (%.5f - %.5f), auc ci (%.5f - %.5f)'%(test_loss, loss_ci_l, loss_ci_h, auc_ci_l, auc_ci_h)
#                 json.dump(msg, f)

            for th in [0.5, opt_th]:

                pred =  (nn.Sigmoid()(logits[:, sens_i]).cpu() > th).float()
                cm = metrics.confusion_matrix(stt[:, sens_i].cpu(), pred)
                cf_matrix = cm/np.repeat(np.expand_dims(np.sum(cm, axis=1), axis=-1), 2, axis=1)
                group_counts = ['{0:0.0f}'.format(value) for value in cm.flatten()]
                # percentage based on true label 
                gr = (cm/np.repeat(np.expand_dims(np.sum(cm, axis=1), axis=-1), 2, axis=1)).flatten()
                group_percentages = ['{0:.2%}'.format(value) for value in gr]

                labels = [f'{v1}\n{v2}' for v1, v2 in zip(group_percentages, group_counts)]

                labels = np.asarray(labels).reshape(2, 2)

                if sens_ind == 0:
                    xlabel = ['Pred-%s'%l for l in ['F', 'M']]
                    ylabel = ['%s'%l for l in ['F', 'M']]   
                elif sens_ind == 1: 
                    xlabel = ['Pred-%s'%l for l in ['Y', 'E']]
                    ylabel = ['%s'%l for l in ['Y', 'E']]   
                elif sens_ind == 21: 
                    xlabel = ['Pred-%s'%l for l in ['W', 'B']]
                    ylabel = ['%s'%l for l in ['W', 'B']]   

                sns.set(font_scale = 1.5)

                hm = sns.heatmap(cf_matrix, annot=labels, fmt='', cmap = 'OrRd', \
                annot_kws={"fontsize": 16}, xticklabels=xlabel, yticklabels=ylabel, cbar=False)
                # hm.set(title=title)
#                 fig = plt.gcf()
#                 plt.show()  
#                 fig.savefig(curr_dir + '/cm_%f.eps'%th, format='eps', bbox_inches = 'tight', pad_inches = 0.1, dpi=1200)
        
        for v in ['raw', 'smooth']: 
            fig, ax = plt.subplots(4, 5, figsize=(25, 12))
            axes = ax.flatten()
            for i in range(20):
                id = test_id[id_20[i]]
                sofa = mimic_target['test'][id]
                _mu, _logvar = model.encoder(torch.FloatTensor(test_head[id_20[i]]).unsqueeze(0).to(device))
                mu = _mu[:, model.nonsens_idx, :]
                sofa_p = model.regr(mu.transpose(1, 2), "classify").squeeze(0).cpu().detach().numpy()*15
                if v == 'smooth': 
                    sofa_p = [np.round(i) for i in sofa_p]
                n = len(sofa)
                axes[i].plot(range(len(sofa)), sofa, label='Current SOFA')
                axes[i].plot(range(24, n), sofa_p, c="tab:green", label ='Predicted SOFA')
                axes[i].set_xlim((0, len(sofa)))
                axes[i].tick_params(axis='both', labelsize=8)
                if max(sofa) <= 11 and int(max(sofa_p)) <=11:
                    axes[i].set_ylim((0, 13))
                else: 
                    axes[i].set_ylim((0, max(max(sofa), int(max(sofa_p)))+2))
                if i == 0: 
                    axes[i].set_ylabel('SOFA score', size=14,  fontweight='bold')
                if i == 19:
                    axes[i].set_xlabel('ICU_in Hours', size=14,  fontweight='bold')
                if i == 4:
                    axes[i].legend(loc='upper right',  prop=legend_properties)
                # save each small figure 
                fig_s, ax_s = plt.subplots(1, 1, figsize=(5, 3))
                ax_s.plot(range(len(sofa)), sofa, label='Current SOFA')
                ax_s.plot(range(24, n), sofa_p, c="tab:green", label ='Predicted SOFA')
                ax_s.set_xlim((0, len(sofa)))
                ax_s.tick_params(axis='both', labelsize=6)
                if max(sofa) <= 11 and int(max(sofa_p)) <=11:
                    ax_s.set_ylim((0, 13))
                else: 
                    ax_s.set_ylim((0, max(max(sofa), int(max(sofa_p)))+2))
                ax_s.set_ylabel('SOFA score', size=12,  fontweight='bold')
                ax_s.set_xlabel('ICU_in Hours', size=12,  fontweight='bold')
                ax_s.legend(loc='upper right',  prop=legend_properties_s)
#                 fig_s.savefig(curr_dir + '/indiv_sofa_%s_%d.eps'%(v, i), format='eps', bbox_inches = 'tight', pad_inches = 0.1, dpi=1200)

#             # save the big figure 
#             fig.savefig(curr_dir + '/sofa_%.5f_%s.eps'%(test_loss, v), format='eps', bbox_inches = 'tight', pad_inches = 0.1, dpi=1200)


In [3]:
# 3 sens version 2 with weights 
# check the 3 sense case
import argparse
import os 
import json
import glob 
import copy 
import pickle
import random 
import numpy as np
import pandas as pd
import torch
import torch.nn as nn 
mse_loss = nn.MSELoss()
import utils
from utils import AverageMeterSet
import prepare_data
import models
from sklearn.model_selection import KFold
import sklearn.metrics as metrics
kf = KFold(n_splits=5, random_state=None, shuffle=False)
from datetime import date
today = date.today()
date = today.strftime("%m%d")
import matplotlib.pyplot as plt
import matplotlib 
import seaborn as sns 
matplotlib.rcParams["figure.dpi"] = 300
plt.style.use('bmh')
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
legend_properties = {'weight':'bold', 'size': 14}
legend_properties_s = {'weight':'bold', 'size': 10}
dir_data = {'satori': '/nobackup/users/weiliao', 'colab':'/content/drive/MyDrive/ColabNotebooks/MIMIC/Extract/MEEP/Extracted_sep_2022/0910'}
# load data
meep_mimic = np.load(dir_data['satori'] + '/MIMIC_compile_0911_2022.npy', \
                allow_pickle=True).item()
train_vital = meep_mimic ['train_head']
dev_vital = meep_mimic ['dev_head']
test_vital = meep_mimic ['test_head']
mimic_static = np.load(dir_data['satori'] + '/MIMIC_static_0922_2022.npy', \
                        allow_pickle=True).item()
mimic_target = np.load(dir_data['satori'] + '/MIMIC_target_0922_2022.npy', \
                        allow_pickle=True).item()
 
class Args:
    def __init__(self, d=None):
        if d is not None:
            for key, value in d.items():
                setattr(self, key, value)

base_dir = '/home/weiliao/FR-TSVAE/checkpoints/'
workname_list = [
    '0519_lr1e-4beta.001_res_regrtheta_1_mlp_regr_nonsens_sens0_mask_3sens'
]

with_mask = True
for wn in workname_list:

    all_path = glob.glob(base_dir + wn + '/*.pt')
    # all_path = [p for p in all_path if "stage2" not in p]
    with open(base_dir+wn+'/params.json') as f: 
        params = json.load(f)
    params['platform'] = 'satori'
    params['device_id'] = 0 
    args = Args(params)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    train_head, train_static, train_sofa, train_id =  utils.crop_data_target('mimic', train_vital, mimic_target, mimic_static, 'train', args.sens_ind)
    dev_head, dev_static, dev_sofa, dev_id =  utils.crop_data_target('mimic', dev_vital , mimic_target, mimic_static, 'dev',  args.sens_ind)
    test_head, test_static, test_sofa, test_id =  utils.crop_data_target('mimic', test_vital, mimic_target, mimic_static, 'test',  args.sens_ind)

    if args.use_sepsis3 == True:
        train_head, train_static, train_sofa, train_id = utils.filter_sepsis('mimic', train_head, train_static, train_sofa, train_id, args.platform)
        dev_head, dev_static, dev_sofa, dev_id = utils.filter_sepsis('mimic', dev_head, dev_static, dev_sofa, dev_id, args.platform)
        test_head, test_static, test_sofa, test_id = utils.filter_sepsis('mimic', test_head, test_static, test_sofa, test_id, args.platform)

    # build model
   
    # torch.save(model.state_dict(), '/home/weiliao/FR-TSVAE/start_weights.pt')

    # 10-fold cross validation
    trainval_head = train_head + dev_head
    trainval_static = train_static + dev_static
    trainval_stail = train_sofa + dev_sofa
    trainval_ids = train_id + dev_id

    # prepare data
    torch.autograd.set_detect_anomaly(True)
    for c_fold, (train_index, test_index) in enumerate(kf.split(trainval_head)):
        # best_loss = 1e4
        # patience = 0
        # if c_fold >= 1:
        #     model.load_state_dict(torch.load('/home/weiliao/FR-TSVAE/start_weights.pt'))
        print('Starting Fold %d' % c_fold)
        print("TRAIN:", len(train_index), "TEST:", len(test_index))
        train_head, val_head = utils.slice_data(trainval_head, train_index), utils.slice_data(trainval_head, test_index)
        train_static, val_static = utils.slice_data(trainval_static, train_index), utils.slice_data(trainval_static, test_index)
        train_stail, val_stail = utils.slice_data(trainval_stail, train_index), utils.slice_data(trainval_stail, test_index)
        train_id, val_id = utils.slice_data(trainval_ids, train_index), utils.slice_data(trainval_ids, test_index)
        weights_per_class = []
        for i in range(3):
            ctype, count= np.unique(np.asarray(val_static)[:, i], return_counts=True)
            total_dev_samples = len(val_static)
            curr = torch.FloatTensor([ total_dev_samples / k / len(ctype) for k in count]).to(device)
            weights_per_class.append(curr)

        train_dataloader, dev_dataloader, test_dataloader = prepare_data.get_data_loader(args, train_head, val_head,
                                                                                            test_head, 
                                                                                            train_stail, val_stail,
                                                                                            test_sofa,
                                                                                            train_static=train_static,
                                                                                            dev_static=val_static,
                                                                                            test_static=test_static,
                                                                                            train_id=train_id,
                                                                                            dev_id=val_id,
                                                                                            test_id=test_id)
        model = models.Ffvae(args, weights_per_class)
    # prepare id_20 switch 34590507 to 32882608
    stay_ids = [33013986, 37338822, 31972580, 35364108, 34994922, 
                32087563, 36691371, 31438123, 33266445, 36424894, 
                32613134, 38247671, 33280018, 33676100, 30932864, 
                30525046, 33267162, 36431990, 31303162, 37216041]
    idmap = {}
    for i, ids in enumerate(test_id): 
        idmap[ids] = i 
    # convert id to index 
    id_20 = []
    for ids in stay_ids: 
        id_20.append(idmap[ids])
    # deal with CM and AUC plots, 3d 
    for p in all_path: 
        # creat a subfolder to save the test results
        curr_dir = base_dir + wn + '/'+ p.split('/')[-1].split('.')[0]
        if not os.path.exists(curr_dir):
            os.makedirs(curr_dir)
        model.load_state_dict(torch.load(p, map_location='cuda:0'))
        print(p)
        #  test on test loader 
        model.eval()
        logits = []
        stt = []
        sofa_loss = []
        with torch.no_grad():
            for vitals, static, target, train_ids, key_mask in test_dataloader:
                # (bs, feature_dim, T)
                vitals = vitals.to(device)
                # (bs, 3)
                static = static.to(device)
                # (bs, T, 1)
                target = target.to(device)
                # (bs, T)
                key_mask = key_mask.to(device)

                # _mu shape [bs, zdim, T]
                _mu, _logvar = model.encoder(vitals)
                # b_logits [bs, 3, T]
                b_logits = _mu[:, model.sens_idx, :]
                mu = _mu[:, model.nonsens_idx, :]
                # (bs, T, 1)
                sofa_p = model.regr(mu.transpose(1, 2), "classify")
                sofa_loss.extend([mse_loss(sofa_p[i][key_mask[i]==0], target[i][key_mask[i]==0]) for i in range(len(sofa_p))])
                # for static info prediction 
                if with_mask: 
                    logits.extend(torch.stack([b_logits[i][:, key_mask[i]==0].mean(dim=-1)  for i in range(len(b_logits))]))
                else:
                    logits.extend(torch.stack([b_logits[i].squeeze(0).mean() for i in range(len(b_logits))]))
                stt.extend(static)
        # save the sofa test result to file 
        test_loss = torch.mean(torch.stack(sofa_loss)).cpu().numpy()
        # compute SOFA loss CI 
        n_bootstraps = 1000
        size_bstr = 3000
        loss_np = torch.stack(sofa_loss).cpu().numpy()
        bootstrapped_losses = []
        for i in range(n_bootstraps):
            b = random.choices(loss_np, k=size_bstr)
            bootstrapped_losses.append(np.mean(b))
        bootstrapped_losses.sort()
        # a 95% confidence interval
        loss_ci_l = bootstrapped_losses[int(0.025 * len(bootstrapped_losses))]
        loss_ci_h = bootstrapped_losses[int(0.975 * len(bootstrapped_losses))]
        
        logits = torch.stack(logits)
        stt = torch.stack(stt) 
        # AUC CM section needs enumeration 
        for sens_i, sens_ind in enumerate([0, 1, 21]):
            # display AUC 
            metrics.RocCurveDisplay.from_predictions(stt[:, sens_i].cpu(),  nn.Sigmoid()(logits[:, sens_i]).cpu())  
#             fig = plt.gcf()
#             plt.show()
#             fig.savefig(curr_dir + '/auc.eps', format='eps', bbox_inches = 'tight', pad_inches = 0.1, dpi=1200)

            # calculate optimal threshold
            fpr, tpr, thresholds = metrics.roc_curve(stt[:, sens_i].cpu(),  nn.Sigmoid()(logits[:, sens_i]).cpu())
            gmeans = np.sqrt(tpr * (1-fpr))
            opt_th = thresholds[np.argmax(gmeans)]
            all_auc = metrics.auc(fpr, tpr)

            # AUC CI 
            bootstrapped_aucs = []
            for i in range(n_bootstraps):
                # bootstrap by sampling with replacement on the prediction indices
                indices = random.choices(range(len(stt)), k=1000)
                score = metrics.roc_auc_score(stt[:, sens_i].cpu()[indices], nn.Sigmoid()(logits[:, sens_i]).cpu()[indices])
                bootstrapped_aucs.append(score)
            bootstrapped_aucs.sort()
            # a 95% confidence interval
            auc_ci_l = bootstrapped_aucs[int(0.025 * len(bootstrapped_aucs))]
            auc_ci_h = bootstrapped_aucs[int(0.975 * len(bootstrapped_aucs))]
            msg = 'test loss %.5f, auc %.5f, sofa loss ci (%.5f - %.5f), auc ci (%.5f - %.5f)'%(test_loss, all_auc, loss_ci_l, loss_ci_h, auc_ci_l, auc_ci_h)
            print(msg)

#             with open(curr_dir + '/sofa_test.json', 'w') as f:
#                 msg = 'test loss %.5f, sofa loss ci (%.5f - %.5f), auc ci (%.5f - %.5f)'%(test_loss, loss_ci_l, loss_ci_h, auc_ci_l, auc_ci_h)
#                 json.dump(msg, f)

            for th in [0.5, opt_th]:

                pred =  (nn.Sigmoid()(logits[:, sens_i]).cpu() > th).float()
                cm = metrics.confusion_matrix(stt[:, sens_i].cpu(), pred)
                cf_matrix = cm/np.repeat(np.expand_dims(np.sum(cm, axis=1), axis=-1), 2, axis=1)
                group_counts = ['{0:0.0f}'.format(value) for value in cm.flatten()]
                # percentage based on true label 
                gr = (cm/np.repeat(np.expand_dims(np.sum(cm, axis=1), axis=-1), 2, axis=1)).flatten()
                group_percentages = ['{0:.2%}'.format(value) for value in gr]

                labels = [f'{v1}\n{v2}' for v1, v2 in zip(group_percentages, group_counts)]

                labels = np.asarray(labels).reshape(2, 2)

                if sens_ind == 0:
                    xlabel = ['Pred-%s'%l for l in ['F', 'M']]
                    ylabel = ['%s'%l for l in ['F', 'M']]   
                elif sens_ind == 1: 
                    xlabel = ['Pred-%s'%l for l in ['Y', 'E']]
                    ylabel = ['%s'%l for l in ['Y', 'E']]   
                elif sens_ind == 21: 
                    xlabel = ['Pred-%s'%l for l in ['W', 'B']]
                    ylabel = ['%s'%l for l in ['W', 'B']]   

                sns.set(font_scale = 1.5)

                hm = sns.heatmap(cf_matrix, annot=labels, fmt='', cmap = 'OrRd', \
                annot_kws={"fontsize": 16}, xticklabels=xlabel, yticklabels=ylabel, cbar=False)
                # hm.set(title=title)
#                 fig = plt.gcf()
#                 plt.show()  
#                 fig.savefig(curr_dir + '/cm_%f.eps'%th, format='eps', bbox_inches = 'tight', pad_inches = 0.1, dpi=1200)
        
        for v in ['raw', 'smooth']: 
            fig, ax = plt.subplots(4, 5, figsize=(25, 12))
            axes = ax.flatten()
            for i in range(20):
                id = test_id[id_20[i]]
                sofa = mimic_target['test'][id]
                _mu, _logvar = model.encoder(torch.FloatTensor(test_head[id_20[i]]).unsqueeze(0).to(device))
                mu = _mu[:, model.nonsens_idx, :]
                sofa_p = model.regr(mu.transpose(1, 2), "classify").squeeze(0).cpu().detach().numpy()*15
                if v == 'smooth': 
                    sofa_p = [np.round(i) for i in sofa_p]
                n = len(sofa)
                axes[i].plot(range(len(sofa)), sofa, label='Current SOFA')
                axes[i].plot(range(24, n), sofa_p, c="tab:green", label ='Predicted SOFA')
                axes[i].set_xlim((0, len(sofa)))
                axes[i].tick_params(axis='both', labelsize=8)
                if max(sofa) <= 11 and int(max(sofa_p)) <=11:
                    axes[i].set_ylim((0, 13))
                else: 
                    axes[i].set_ylim((0, max(max(sofa), int(max(sofa_p)))+2))
                if i == 0: 
                    axes[i].set_ylabel('SOFA score', size=14,  fontweight='bold')
                if i == 19:
                    axes[i].set_xlabel('ICU_in Hours', size=14,  fontweight='bold')
                if i == 4:
                    axes[i].legend(loc='upper right',  prop=legend_properties)
                # save each small figure 
                fig_s, ax_s = plt.subplots(1, 1, figsize=(5, 3))
                ax_s.plot(range(len(sofa)), sofa, label='Current SOFA')
                ax_s.plot(range(24, n), sofa_p, c="tab:green", label ='Predicted SOFA')
                ax_s.set_xlim((0, len(sofa)))
                ax_s.tick_params(axis='both', labelsize=6)
                if max(sofa) <= 11 and int(max(sofa_p)) <=11:
                    ax_s.set_ylim((0, 13))
                else: 
                    ax_s.set_ylim((0, max(max(sofa), int(max(sofa_p)))+2))
                ax_s.set_ylabel('SOFA score', size=12,  fontweight='bold')
                ax_s.set_xlabel('ICU_in Hours', size=12,  fontweight='bold')
                ax_s.legend(loc='upper right',  prop=legend_properties_s)
#                 fig_s.savefig(curr_dir + '/indiv_sofa_%s_%d.eps'%(v, i), format='eps', bbox_inches = 'tight', pad_inches = 0.1, dpi=1200)

#             # save the big figure 
#             fig.savefig(curr_dir + '/sofa_%.5f_%s.eps'%(test_loss, v), format='eps', bbox_inches = 'tight', pad_inches = 0.1, dpi=1200)


Starting Fold 0
TRAIN: 8964 TEST: 2242
8964
Starting Fold 1
TRAIN: 8965 TEST: 2241
8965
Starting Fold 2
TRAIN: 8965 TEST: 2241
8965
Starting Fold 3
TRAIN: 8965 TEST: 2241
8965
Starting Fold 4
TRAIN: 8965 TEST: 2241
8965
/home/weiliao/FR-TSVAE/checkpoints/0519_lr1e-4beta.001_res_regrtheta_1_mlp_regr_nonsens_sens0_mask_3sens/stage1_clfw_fold_0_epoch149.pt
test loss 0.01601, sofa loss ci (0.01514 - 0.01689), auc ci (0.71922 - 0.78086)
test loss 0.01601, sofa loss ci (0.01514 - 0.01689), auc ci (0.73456 - 0.79252)
test loss 0.01601, sofa loss ci (0.01514 - 0.01689), auc ci (0.70116 - 0.79493)




/home/weiliao/FR-TSVAE/checkpoints/0519_lr1e-4beta.001_res_regrtheta_1_mlp_regr_nonsens_sens0_mask_3sens/stage1_corr_fold_0_epoch149.pt
test loss 0.01410, sofa loss ci (0.01335 - 0.01499), auc ci (0.75009 - 0.80580)
test loss 0.01410, sofa loss ci (0.01335 - 0.01499), auc ci (0.75196 - 0.81071)
test loss 0.01410, sofa loss ci (0.01335 - 0.01499), auc ci (0.49296 - 0.61692)
/home/weiliao/FR-TSVAE/checkpoints/0519_lr1e-4beta.001_res_regrtheta_1_mlp_regr_nonsens_sens0_mask_3sens/stage1_fold_0_epoch149.pt
test loss 0.01388, sofa loss ci (0.01319 - 0.01467), auc ci (0.71852 - 0.77926)
test loss 0.01388, sofa loss ci (0.01319 - 0.01467), auc ci (0.74256 - 0.80050)
test loss 0.01388, sofa loss ci (0.01319 - 0.01467), auc ci (0.47569 - 0.60137)
/home/weiliao/FR-TSVAE/checkpoints/0519_lr1e-4beta.001_res_regrtheta_1_mlp_regr_nonsens_sens0_mask_3sens/stage1_clf_fold_0_epoch149.pt
test loss 0.01497, sofa loss ci (0.01425 - 0.01574), auc ci (0.71177 - 0.76787)
test loss 0.01497, sofa loss ci (0.014

In [None]:
for c_fold, (train_index, test_index) in enumerate(kf.split(trainval_head)):
    # best_loss = 1e4
    # patience = 0
    # if c_fold >= 1:
    #     model.load_state_dict(torch.load('/home/weiliao/FR-TSVAE/start_weights.pt'))
    print('Starting Fold %d' % c_fold)
    print("TRAIN:", len(train_index), "TEST:", len(test_index))
    train_head, val_head = utils.slice_data(trainval_head, train_index), utils.slice_data(trainval_head, test_index)
    train_static, val_static = utils.slice_data(trainval_static, train_index), utils.slice_data(trainval_static, test_index)
    train_stail, val_stail = utils.slice_data(trainval_stail, train_index), utils.slice_data(trainval_stail, test_index)
    train_id, val_id = utils.slice_data(trainval_ids, train_index), utils.slice_data(trainval_ids, test_index)

    print(plt.hist(np.asarray(train_static)))
    print(plt.hist(np.asarray(val_static)))
#     train_dataloader, dev_dataloader, test_dataloader = prepare_data.get_data_loader(args, train_head, val_head,
#                                                                                         test_head, 
#                                                                                         train_stail, val_stail,
#                                                                                         test_sofa,
#                                                                                         train_static=train_static,
#                                                                                         dev_static=val_static,
#                                                                                         test_static=test_static,
#                                                                                         train_id=train_id,
#                                                                                         dev_id=val_id,
#                                                                                         test_id=test_id)

In [None]:
weights_per_class = []
for i in range(3):
    ctype, count= np.unique(np.asarray(val_static)[:, i], return_counts=True)
    total_dev_samples = len(val_static)
    curr = torch.FloatTensor([ total_dev_samples / k / len(ctype) for k in count]).to(device)
    weights_per_class.append(curr)
#     ce_val_loss = nn.CrossEntropyLoss(weight = weights_per_class)

In [6]:
all_auc = metrics.auc(fpr, tpr)
msg = 'test loss %.5f, auc %.5f, sofa loss ci (%.5f - %.5f), auc ci (%.5f - %.5f)'%(test_loss, all_auc, loss_ci_l, loss_ci_h, auc_ci_l, auc_ci_h)
print(msg)

test loss 0.01381, auc 0.55354, sofa loss ci (0.01306 - 0.01456), auc ci (0.49240 - 0.62051)


In [None]:
# use b_logits and static 
b_squeeze.shape

In [None]:
static.shape

In [82]:
b_squeeze = torch.stack([b_logits[i][:, key_mask[i]==0].mean(dim=-1)  for i in range(len(b_logits))])
clf_losses = [
            nn.BCEWithLogitsLoss(pos_weight = weights_per_class[k][1]/weights_per_class[k][0])(_b_logit.to(device), _a_sens.to(device))
            for k, (_b_logit, _a_sens) in enumerate(zip(
            b_squeeze.t(), static.type(torch.FloatTensor).t()))]

In [83]:
clf_losses

[tensor(0.6690, device='cuda:0'),
 tensor(0.3711, device='cuda:0'),
 tensor(0.6934, device='cuda:0')]

In [None]:
weights_per_class

In [79]:
 for k, (_b_logit, _a_sens) in enumerate(zip(b_squeeze.t(), static.type(torch.FloatTensor).t())):
        print(_b_logit, _a_sens)
        c = nn.BCEWithLogitsLoss(pos_weight = weights_per_class[k][1]/weights_per_class[k][0])
        print(c(_b_logit.to(device), _a_sens.to(device)))
        

tensor([0.3125, 0.3751, 0.1742, 0.2965], device='cuda:0') tensor([1., 0., 1., 0.])
tensor(0.6690, device='cuda:0')
tensor([0.0493, 0.0386, 2.7230, 4.4408], device='cuda:0') tensor([0., 0., 1., 1.])
tensor(0.3711, device='cuda:0')
tensor([6.1603e-06, 1.7574e-03, 0.0000e+00, 0.0000e+00], device='cuda:0') tensor([0., 0., 0., 0.])
tensor(0.6934, device='cuda:0')


In [None]:
b_squeeze

In [None]:
_b_logit.size()

In [None]:
_a_sens.size()

In [None]:
pos_weight = torch.ones([64]) 

In [None]:
pos_weight

tensor(0.7986, device='cuda:0')

In [None]:
c