 This notebook is part of an experiment where I compare the ECE between two models where one was trained for 15 epochs and the other for 1 epoch

Using a deterministic model, creates several outputs for each patient in the validation/test set. Saves into csv. The purpose is to do analyses later on 

# Imports

In [25]:
import os
cwd = os.getcwd()

# protection against running this cell multiple times
assert os.path.dirname(cwd).split('/')[-1] == 'master-thesis','Oops, directory already changed previously as indended. Ignoring...'

# change working directory (if assert passed)
new_cwd = os.path.dirname(cwd) # parent directory
os.chdir(new_cwd)

AssertionError: Oops, directory already changed previously as indended. Ignoring...

In [26]:
import os
import json

from rnn_utils import DiagnosesDataset, split_dataset, MYCOLLATE
from rnn_utils import train_one_epoch, eval_model, RNN

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence, pack_sequence

import numpy as np
import pandas as pd

from config import Settings; settings = Settings()

# display all outputs
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

# Parameters

In [27]:
num_passes = 30 # number of (variational) forward passes for each input
model_name = 'pleasant-music-50_0epoch'
dataset_id = 'diag_only'
experiment_id = 'C'

sanity check

In [28]:
model_folder = os.path.join(settings.data_base,settings.models_folder,model_name)
os.path.exists(model_folder)

dataset_folder = os.path.join(settings.data_base,settings.model_ready_dataset_folder,dataset_id)
os.path.exists(dataset_folder)

deterministic_folder = os.path.join(settings.data_base,settings.deterministic_data_folder)

if not os.path.exists(deterministic_folder):
    os.mkdir(deterministic_folder)

experiment_folder = os.path.join(deterministic_folder,experiment_id)

if os.path.exists(experiment_folder):
    raise Exception(f"Experiment {experiment_folder} exists. If you want to overwrite it, manually delete the directory first")
else:
    os.mkdir(experiment_folder)

True

True

# Reproducibility

In [29]:
# Reproducibility
seed = settings.random_seed
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

<torch._C.Generator at 0x10cd6c790>

In [30]:
batch_size = 64 # really doesn't matter for this notebook since we will only to inference
grouping = 'ccs'

In [31]:
dataset = DiagnosesDataset(os.path.join(dataset_folder,'dataset.json'),grouping)

train_dataset = DiagnosesDataset(os.path.join(dataset_folder,'train_subset.json'),grouping)
val_dataset = DiagnosesDataset(os.path.join(dataset_folder,'val_subset.json'),grouping)
test_dataset = DiagnosesDataset(os.path.join(dataset_folder,'test_subset.json'),grouping)


len(train_dataset)
len(val_dataset)
len(test_dataset)


train_dataloader = DataLoader(train_dataset,batch_size=batch_size,collate_fn=MYCOLLATE(dataset),shuffle=True)
val_dataloader = DataLoader(val_dataset,batch_size=batch_size,collate_fn=MYCOLLATE(dataset)) #batch_size here is arbitrary and doesn't affect total validation speed
test_dataloader = DataLoader(test_dataset,batch_size=batch_size,collate_fn=MYCOLLATE(dataset))

5249

1125

1125

In [32]:
# model hyperparameters path
hypp_save_path = os.path.join(model_folder, 'hyper_parameters.json')
with open(hypp_save_path,'r') as f:
    params_loaded = json.load(f)

# weights path
weights_save_path = os.path.join(model_folder,"weights")

new_model = RNN(**params_loaded)
new_model.load_state_dict(torch.load(weights_save_path))

<All keys matched successfully>

In [33]:
def deterministic_forward(model, dataloader, dataset, name):
    """
    """
    model.eval()
    
    # eg:: ccs, icd9, etc..
    code_type = dataset.grouping
    
    int2code = dataset.grouping_data[code_type]['int2code']
    
    assert all(np.diff(list(int2code.keys())) == 1), 'Expecting sorted codes, if this fails then it might be time to update column-naming related code'
    col_names = ['diag_' + str(code) for code in int2code.keys()]
    
    total_loss = 0
    total_seq = 0 #total sequences
    
    sigmoid = nn.Sigmoid()
    
    full_df = None
    
    # deterministic
    
    full_df_det = None
    model.eval()
    with torch.no_grad():
         for i, batch in enumerate(iter(dataloader)):
            # get the inputs; data is a list of [inputs, labels]
            history_sequences, target_sequences = batch['train_sequences'],batch['target_sequences']

            _,lengths = pad_packed_sequence(history_sequences['sequence'])
            
            # we want to ignore the padded sequences
            relevant_positions = [[i+idx*max(lengths) for i in range(e)] for idx,e in enumerate(lengths)]

            # # just flattens the list
            relevant_positions = [item for sublist in relevant_positions for item in sublist]
            
            outs = model(history_sequences['sequence'])
            
            outs_flattened = outs.view(1,-1,outs.size()[2])
            
            relevant_outs = outs_flattened[:,relevant_positions,:]
            
            relevant_outs = sigmoid(relevant_outs).detach().numpy().squeeze()
            
            df = (pd.DataFrame(relevant_outs,
                               columns=col_names)
                  .assign(pat_id=batch['target_pids'])
                 )
            
            if full_df_det is None:
                full_df_det = df
            else:
                full_df_det = pd.concat([full_df_det,df])
    
    full_df_det['adm_index'] = full_df_det.groupby(['pat_id']).cumcount()+1
    full_df_det = full_df_det.reset_index(drop=True)
    full_df_det[['pat_id','adm_index']] = full_df_det[['pat_id','adm_index']].astype(int)
    # reorder columns
    full_df_det = full_df_det.set_index(['pat_id','adm_index']).sort_index()
    
    
    
    # Now to store the true labels
    
    full_df_golden = None
    for i, batch in enumerate(iter(dataloader)):
        # get the inputs; data is a list of [inputs, labels]
        history_sequences, target_sequences = batch['train_sequences'],batch['target_sequences']

        _,lengths = pad_packed_sequence(history_sequences['sequence'])
        relevant_positions = [[i+idx*max(lengths) for i in range(e)] for idx,e in enumerate(lengths)]

        # just flatten the list
        relevant_positions = [item for sublist in relevant_positions for item in sublist]

        relevant_labels = target_sequences['sequence'].view(1,-1,target_sequences['sequence'].size()[2])[:,relevant_positions,:].squeeze().detach().numpy()
        
        df = pd.DataFrame(relevant_labels,columns=col_names).assign(pat_id=batch['target_pids'])
        
        if full_df_golden is None:
            full_df_golden = df
        else:
            full_df_golden = pd.concat([full_df_golden,df])
            
    full_df_golden['adm_index'] = full_df_golden.groupby(['pat_id']).cumcount()+1
    full_df_golden = full_df_golden.reset_index(drop=True)
    full_df_golden[['pat_id','adm_index']] = full_df_golden[['pat_id','adm_index']].astype(int)
    
    # reorder columns
    full_df_golden = full_df_golden.set_index(['pat_id','adm_index']).sort_index()
            
    return full_df_det,full_df_golden

In [34]:
det,golden = deterministic_forward(new_model,val_dataloader,dataset,'validation')

# save into file

In [35]:
save_det_path = os.path.join(experiment_folder,'deterministic_forward.csv')
save_golden_path = os.path.join(experiment_folder,'golden.csv')

In [36]:
mapping = {save_det_path:det,
           save_golden_path:golden
          }
for save_path in mapping:
    mapping[save_path].to_csv(save_path)
    
# also save metadata
metadata = dict(model=model_name,
                num_passes=num_passes,
                dataset_id=dataset_id
               )

metadata_path = os.path.join(experiment_folder,'metadata.json')
with open(metadata_path,'w') as f:
    json.dump(metadata,f)
    
print('Saved!')

Saved!


# test it out

In [37]:
df_det = pd.read_csv(save_det_path,index_col=[0,1])

df_det.head(3)
df_det.shape

Unnamed: 0_level_0,Unnamed: 1_level_0,diag_0,diag_1,diag_2,diag_3,diag_4,diag_5,diag_6,diag_7,diag_8,diag_9,...,diag_262,diag_263,diag_264,diag_265,diag_266,diag_267,diag_268,diag_269,diag_270,diag_271
pat_id,adm_index,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
21,1,0.535628,0.525377,0.527379,0.497576,0.488512,0.485052,0.47947,0.513265,0.488949,0.526202,...,0.485728,0.464404,0.486183,0.505239,0.518107,0.519951,0.49605,0.49119,0.512884,0.487418
23,1,0.511217,0.520478,0.525117,0.499467,0.487114,0.491753,0.483988,0.516149,0.493881,0.527973,...,0.484553,0.467816,0.486824,0.508698,0.522956,0.517079,0.504047,0.505217,0.521186,0.495382
61,1,0.513139,0.526945,0.517097,0.497945,0.496405,0.482006,0.486107,0.516806,0.487517,0.511737,...,0.488322,0.465248,0.481619,0.505342,0.518992,0.524232,0.500959,0.493123,0.511251,0.483938


(1943, 272)

In [38]:
df_golden = pd.read_csv(save_golden_path,index_col=[0,1])

df_golden.head(3)
df_golden.shape

Unnamed: 0_level_0,Unnamed: 1_level_0,diag_0,diag_1,diag_2,diag_3,diag_4,diag_5,diag_6,diag_7,diag_8,diag_9,...,diag_262,diag_263,diag_264,diag_265,diag_266,diag_267,diag_268,diag_269,diag_270,diag_271
pat_id,adm_index,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
21,1,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
23,1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
61,1,0.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


(1943, 272)