In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Move to the project root so that `lib.*` imports and relative `data/` paths resolve.
# Presumably the notebook lives one directory below the root — TODO confirm.
# Guarded so that re-running this cell does not keep climbing the directory tree.
if not os.path.isdir('lib'):
    os.chdir('../')

In [6]:
import pytorch_lightning as pl
from lib.lightning.airl import AIRLLightning, AIRL_NODEGAM_Lightning
from lib.lightning.bc import BC_MIMIC3_Lightning
import numpy as np
import pandas as pd
import torch
from lib.sepsis_simulator.sepsisSimDiabetes.State import State

from lib.vis_utils import vis_main_effects
from lib.lightning.utils import load_best_model_from_trained_dir
import pickle
import cvxpy as cvx
from lib.utils import Timer
from lib.mimic3.dataset import HypotensionDataset
from lib.nodegam.utils import bin_data
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

In [None]:
from lib.nodegam.utils import extract_GAM, process_in_chunks, check_numpy

In [9]:
device = 'cuda'

model = load_best_model_from_trained_dir('0909_bc_s55_lr0.001_wd1e-05_bs128_nh64_nl1_dr0.0_fnh256_fnl4_fdr0.5')
# eval(): the network contains Dropout and BatchNorm layers, which must run in inference
# mode when producing probabilities. This is a no-op if the loader already did it.
model.to(device).eval()

BC_MIMIC3_Lightning(
  (gru): GRU(74, 64, batch_first=True)
  (out): Sequential(
    (0): Swapaxes()
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Swapaxes()
    (3): Linear(in_features=64, out_features=256, bias=True)
    (4): ELU(alpha=1.0)
    (5): Swapaxes()
    (6): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Swapaxes()
    (8): Dropout(p=0.5, inplace=False)
    (9): Linear(in_features=256, out_features=256, bias=True)
    (10): ELU(alpha=1.0)
    (11): Swapaxes()
    (12): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): Swapaxes()
    (14): Dropout(p=0.5, inplace=False)
    (15): Linear(in_features=256, out_features=256, bias=True)
    (16): ELU(alpha=1.0)
    (17): Swapaxes()
    (18): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (19): Swapaxes()
    (20): Dropout(p=0.5, inplace=False)
    (21): Linea

In [10]:
# Every ICU stay ('all' split), in a fixed order so that the per-stay results can later
# be zipped against loader.dataset.icustay_ids.
loader = HypotensionDataset.make_loader(split='all', batch_size=128,
                                        shuffle=False, num_workers=0)

In [22]:
# Run the BC model over every sequence and keep its per-step action probabilities.
# no_grad: this is pure inference. Without it, each stored `prob` tensor keeps the whole
# autograd graph (including GPU activations) alive across all batches, and that graph
# would also be carried into torch.save below.
results = []
with torch.no_grad():
    for batch in loader:
        x_len = [v.size(0) for v in batch]
        x_pad = pad_sequence(batch, batch_first=True).to(device)

        x_packed = pack_padded_sequence(x_pad, x_len, enforce_sorted=False, batch_first=True)
        out, _ = model.gru(x_packed)
        out_padded, _ = pad_packed_sequence(out, batch_first=True)
        # ^-- [batch_size, max_len, hidden dim]
        pred = model.out(out_padded)

        prob = F.softmax(pred, dim=-1).cpu()
        # Drop the padded tail so each entry has one row per real time step.
        results.extend([p[:the_len] for p, the_len in zip(prob, x_len)])

In [23]:
# One entry was appended per sequence yielded by the loader.
n_results = len(results)
n_results

9404

In [25]:
# Should match the number of collected probability tensors.
n_stays = len(loader.dataset.icustay_ids)
n_stays

9404

In [27]:
# Map each ICU stay id to its probability tensor. This relies on the loader using
# shuffle=False, so that `results` follows the order of icustay_ids. zip() would
# silently truncate on a length mismatch, so check the lengths explicitly first.
assert len(results) == len(loader.dataset.icustay_ids), \
    (len(results), len(loader.dataset.icustay_ids))
bc_prob = dict(zip(loader.dataset.icustay_ids, results))

In [33]:
# Cache the icustay_id -> probability mapping with torch's serializer.
bc_prob_path = 'data/model-data3/bc_probs.pkl'
torch.save(bc_prob, bc_prob_path)

In [5]:
# Reload the cached mapping, so the inference cells above need not be re-run.
bc_prob_path = 'data/model-data3/bc_probs.pkl'
bc_prob = torch.load(bc_prob_path)

In [8]:
# Stack every stay's per-step probabilities along dim 0, in dict insertion order.
arr = torch.cat(list(bc_prob.values()))

In [9]:
arr.size()

torch.Size([262578, 16])

In [10]:
# Average predicted probability of each action over all time steps.
torch.mean(arr, dim=0)

tensor([0.6563, 0.0229, 0.0271, 0.0225, 0.1012, 0.0029, 0.0084, 0.0113, 0.0684,
        0.0022, 0.0045, 0.0069, 0.0516, 0.0029, 0.0039, 0.0072])

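Check that the cached BC probabilities load correctly, and compute their accuracy (the mean probability assigned to the expert's action) on the fold-0 test split.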

In [12]:
from lib.mimic3.dataset import HypotensionWithBCProbDataset

In [13]:
the_cls = HypotensionWithBCProbDataset

# Fold-0 test split with quantile preprocessing, iterated in a fixed order.
test_data_kwargs = {'fold': 0, 'preprocess': 'quantile'}
loader = the_cls.make_loader(
    data_kwargs=test_data_kwargs,
    split='test',
    batch_size=256,
    shuffle=False,
    num_workers=0,
)

Finish "Load cached normalized dataset: ./data/model-data3/normalized_states.pth" in 4.0s


In [35]:
from torch.nn.utils.rnn import pad_sequence, pack_sequence


# Accumulate the probability that the cached BC policy assigns to the expert's action
# at every valid transition. The per-batch probabilities use their own name, so the
# `bc_prob` dict loaded earlier is not overwritten by this loop.
bc_prob_total, total = 0, 0

for x in loader:
    x_list = x['x_list']
    batch_bc_prob = x['bc_prob']

    x_len = [v.size(0) for v in x_list]
    x_pad = pad_sequence(x_list, batch_first=True)
    bc_prob_pad = pad_sequence(batch_bc_prob, batch_first=True)

    states = HypotensionWithBCProbDataset.extract_cur_s(x_pad, state_type='all')
    actions = HypotensionWithBCProbDataset.extract_cur_a(x_pad, form='act_idx')

    # Construct the validity mask: the first (l - 1) steps of a length-l sequence are
    # real transitions; the rest is padding.
    is_valid = states.new_zeros(*states.shape[:2]).bool()
    for idx, l in enumerate(x_len):
        is_valid[idx, :(l-1)] = True

    actions = actions[is_valid]
    bc_prob_pad = bc_prob_pad[:, :-1, :][is_valid]

    # Pick out the probability of the action the expert actually took.
    the_bc_prob = bc_prob_pad.gather(1, actions.unsqueeze(-1)).squeeze(-1)

    bc_prob_total += the_bc_prob.sum()
    total += the_bc_prob.shape[0]

In [36]:
# Mean probability assigned to the expert's action over all valid transitions.
mean_expert_prob = bc_prob_total / total
mean_expert_prob

tensor(0.7141)

## Models trained on the other folds: what is their accuracy (mean probability of the expert's action)?

In [4]:
fold = 0  # NOTE(review): unused below — bc_acc takes `fold` as a parameter and the fold loop uses its own loop variable.

In [13]:
def bc_acc(fold=0, device='cuda'):
    """Mean probability the fold's best BC model assigns to the expert's actions.

    Loads the best BC checkpoint trained on `fold` and runs it over that fold's test
    split. It then averages softmax(pred)[expert_action] over every valid (non-padded)
    transition.

    Args:
        fold: Cross-validation fold. Selects both the checkpoint directory and the
            data split.
        device: Torch device the model and inputs are moved to.

    Returns:
        A 0-d tensor on `device` holding the mean expert-action probability.

    Raises:
        ValueError: if the test split yields no valid transitions.
    """
    with torch.no_grad():
        model = load_best_model_from_trained_dir(f'1018_bc_best_f{fold}__bc_s55_lr0.001_wd1e-05_bs128_nh64_nl1_dr0.0_fnh256_fnl4_fdr0.5')
        # Inference mode: disable Dropout and use BatchNorm running statistics, in case
        # the checkpoint loader leaves the module in training mode.
        model.to(device).eval()

        loader = HypotensionDataset.make_loader(
            data_kwargs=dict(
                fold=fold,
                preprocess='quantile',
            ),
            split='test',
            batch_size=256,
            shuffle=False,
            num_workers=0,
        )

        bc_prob_total, total = 0, 0
        for batch in loader:
            ## Compute the model's action probabilities
            x_len = [v.size(0) for v in batch['x_list']]
            x_pad = pad_sequence(batch['x_list'], batch_first=True)
            x_pad = x_pad.to(device)

            x_packed = pack_padded_sequence(x_pad, x_len, enforce_sorted=False, batch_first=True)
            out, _ = model.gru(x_packed)
            out_padded, _ = pad_packed_sequence(out, batch_first=True)
            # ^-- [batch_size, max_len, hidden dim]
            pred = model.out(out_padded)

            bc_prob_pad = F.softmax(pred, dim=-1)

            ## Extract the expert's actions
            states = HypotensionDataset.extract_cur_s(x_pad, state_type='all')
            actions = HypotensionDataset.extract_cur_a(x_pad, form='act_idx')

            # Construct the validity mask: the first (l - 1) steps of a length-l
            # sequence are real transitions; the rest is padding.
            is_valid = states.new_zeros(*states.shape[:2]).bool()
            for idx, l in enumerate(x_len):
                is_valid[idx, :(l-1)] = True

            actions = actions[is_valid]
            bc_prob_pad = bc_prob_pad[:, :-1, :][is_valid]

            # Probability the model gives to the action the expert actually took.
            the_bc_prob = bc_prob_pad.gather(1, actions.unsqueeze(-1)).squeeze(-1)

            bc_prob_total += the_bc_prob.sum()
            total += the_bc_prob.shape[0]

        if total == 0:
            raise ValueError(f'No valid transitions found in the test split of fold {fold}')
        return bc_prob_total / total

In [17]:
# Evaluate the best BC model of each of the 5 folds on that fold's test split.
accs = []
for fold_idx in range(5):
    accs.append(bc_acc(fold_idx))
accs

rsync -avzL v:/h/kingsley/irl_nodegam/logs/1018_bc_best_f2__bc_s55_lr0.001_wd1e-05_bs128_nh64_nl1_dr0.0_fnh256_fnl4_fdr0.5 ./logs/
rsync -avzL v:/h/kingsley/irl_nodegam/logs/hparams/1018_bc_best_f3__bc_s55_lr0.001_wd1e-05_bs128_nh64_nl1_dr0.0_fnh256_fnl4_fdr0.5 ./logs/hparams/
rsync -avzL v:/h/kingsley/irl_nodegam/logs/1018_bc_best_f3__bc_s55_lr0.001_wd1e-05_bs128_nh64_nl1_dr0.0_fnh256_fnl4_fdr0.5 ./logs/
rsync -avzL v:/h/kingsley/irl_nodegam/logs/hparams/1018_bc_best_f4__bc_s55_lr0.001_wd1e-05_bs128_nh64_nl1_dr0.0_fnh256_fnl4_fdr0.5 ./logs/hparams/
rsync -avzL v:/h/kingsley/irl_nodegam/logs/1018_bc_best_f4__bc_s55_lr0.001_wd1e-05_bs128_nh64_nl1_dr0.0_fnh256_fnl4_fdr0.5 ./logs/


[tensor(0.7276, device='cuda:0'),
 tensor(0.7013, device='cuda:0'),
 tensor(0.7188, device='cuda:0'),
 tensor(0.7278, device='cuda:0'),
 tensor(0.7258, device='cuda:0')]

In [18]:
# float() accepts both 0-d tensors and plain floats. Re-running this cell after `accs`
# has already been converted therefore no longer fails (a Python float has no .item()).
accs = [float(a) for a in accs]
np.mean(accs), np.std(accs)

(0.7202484011650085, 0.010020756865764154)