Using a variational model, creates several outputs for each patient in the validation/test set. Saves into csv. The purpose is to do analyses later on 

# Imports

In [1]:
import os
import json

from rnn_utils import DiagnosesDataset, split_dataset, MYCOLLATE
from rnn_utils import train_one_epoch, eval_model

from mourga_variational.variational_rnn import VariationalRNN

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence, pack_sequence

import numpy as np
import pandas as pd

from config import Settings; settings = Settings()

# display all outputs
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

# Parameters

In [2]:
num_passes = 30 # number of (variational) forward passes for each input
model_name = 'golden-oath-84'
dataset_id = 'diag_only'
experiment_id = 'A'

sanity check

In [3]:
model_folder = os.path.join(settings.data_base,settings.models_folder,model_name)
os.path.exists(model_folder)

dataset_folder = os.path.join(settings.data_base,settings.model_ready_dataset_folder,dataset_id)
os.path.exists(dataset_folder)

experiment_folder = os.path.join(settings.data_base,settings.variational_data_folder,experiment_id)
if os.path.exists(experiment_folder):
    raise Exception(f"Experiment {experiment_folder} exists. If you want to overwrite it, manually delete the directory first")
else:
    os.mkdir(experiment_folder)

True

True

# Reproducibility

In [4]:
# Reproducibility
seed = settings.random_seed
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

<torch._C.Generator at 0x7fcae2329db0>

In [5]:
batch_size = 64 # really doesn't matter for this notebook since we will only to inference
grouping = 'ccs'

In [6]:
dataset = DiagnosesDataset(os.path.join(dataset_folder,'dataset.json'),grouping)

train_dataset = DiagnosesDataset(os.path.join(dataset_folder,'train_subset.json'),grouping)
val_dataset = DiagnosesDataset(os.path.join(dataset_folder,'val_subset.json'),grouping)
test_dataset = DiagnosesDataset(os.path.join(dataset_folder,'test_subset.json'),grouping)


len(train_dataset)
len(val_dataset)
len(test_dataset)


train_dataloader = DataLoader(train_dataset,batch_size=batch_size,collate_fn=MYCOLLATE(dataset),shuffle=True)
val_dataloader = DataLoader(val_dataset,batch_size=batch_size,collate_fn=MYCOLLATE(dataset)) #batch_size here is arbitrary and doesn't affect total validation speed
test_dataloader = DataLoader(test_dataset,batch_size=batch_size,collate_fn=MYCOLLATE(dataset))

5249

1125

1125

In [7]:
# model hyperparameters path
hypp_save_path = os.path.join(model_folder, 'hyper_parameters.json')
with open(hypp_save_path,'r') as f:
    params_loaded = json.load(f)

# weights path
weights_save_path = os.path.join(model_folder,"weights")

new_model = VariationalRNN(**params_loaded)
new_model.load_state_dict(torch.load(weights_save_path))

<All keys matched successfully>

In [9]:
def variational_forward(model, dataloader, dataset, name, num_passes=2):
    """
    """
    # to make sure the dropout mask is turned on. But we won't train here
    model.train()
    
    # eg:: ccs, icd9, etc..
    code_type = dataset.grouping
    
    int2code = dataset.grouping_data[code_type]['int2code']
    
    assert all(np.diff(list(int2code.keys())) == 1), 'Expecting sorted codes, if this fails then it might be time to update column-naming related code'
    col_names = ['diag_' + str(code) for code in int2code.keys()]
    
    total_loss = 0
    total_seq = 0 #total sequences
    
    sigmoid = nn.Sigmoid()
    
    full_df = None
    
    # variational
    
    for i, batch in enumerate(iter(dataloader)):
        # get the inputs; data is a list of [inputs, labels]
        history_sequences, target_sequences = batch['train_sequences'],batch['target_sequences']
        
        _,lengths = pad_packed_sequence(history_sequences['sequence'])
        relevant_positions = [[i+idx*max(lengths) for i in range(e)] for idx,e in enumerate(lengths)]
        
        # just flatten the list
        relevant_positions = [item for sublist in relevant_positions for item in sublist]

        for pass_ in range(1,num_passes+1):

            outs = model(history_sequences['sequence'])
            
            outs_flattened = outs.view(1,-1,outs.size()[2])
            
            relevant_outs = outs_flattened[:,relevant_positions,:]
            
            relevant_outs = sigmoid(relevant_outs).detach().numpy().squeeze()
            
            df = (pd.DataFrame(relevant_outs,
                               columns=col_names)
                  .assign(pat_id=batch['target_pids'],
                          npass=pass_)
                 )
            
            if full_df is None:
                full_df = df
            else:
                full_df = pd.concat([full_df,df])
    
    full_df['adm_index'] = full_df.groupby(['pat_id','npass']).cumcount()+1
    full_df = full_df.reset_index(drop=True)
    full_df[['pat_id','adm_index','npass']] = full_df[['pat_id','adm_index','npass']].astype(int)
    # reorder columns
    full_df = full_df.set_index(['pat_id','adm_index','npass']).sort_index()
    
    # deterministic
    
    full_df_det = None
    model.eval()
    with torch.no_grad():
         for i, batch in enumerate(iter(dataloader)):
            # get the inputs; data is a list of [inputs, labels]
            history_sequences, target_sequences = batch['train_sequences'],batch['target_sequences']

            _,lengths = pad_packed_sequence(history_sequences['sequence'])
            relevant_positions = [[i+idx*max(lengths) for i in range(e)] for idx,e in enumerate(lengths)]

            # just flatten the list
            relevant_positions = [item for sublist in relevant_positions for item in sublist]
            
            outs = model(history_sequences['sequence'])
            
            outs_flattened = outs.view(1,-1,outs.size()[2])
            
            relevant_outs = outs_flattened[:,relevant_positions,:]
            
            relevant_outs = sigmoid(relevant_outs).detach().numpy().squeeze()
            
            df = (pd.DataFrame(relevant_outs,
                               columns=col_names)
                  .assign(pat_id=batch['target_pids'])
                 )
            
            if full_df_det is None:
                full_df_det = df
            else:
                full_df_det = pd.concat([full_df_det,df])
    
    full_df_det['adm_index'] = full_df_det.groupby(['pat_id']).cumcount()+1
    full_df_det = full_df_det.reset_index(drop=True)
    full_df_det[['pat_id','adm_index']] = full_df_det[['pat_id','adm_index']].astype(int)
    # reorder columns
    full_df_det = full_df_det.set_index(['pat_id','adm_index']).sort_index()
    
    
    
    # Now to store the true labels
    
    full_df_golden = None
    for i, batch in enumerate(iter(dataloader)):
        # get the inputs; data is a list of [inputs, labels]
        history_sequences, target_sequences = batch['train_sequences'],batch['target_sequences']

        _,lengths = pad_packed_sequence(history_sequences['sequence'])
        relevant_positions = [[i+idx*max(lengths) for i in range(e)] for idx,e in enumerate(lengths)]

        # just flatten the list
        relevant_positions = [item for sublist in relevant_positions for item in sublist]

        relevant_labels = target_sequences['sequence'].view(1,-1,target_sequences['sequence'].size()[2])[:,relevant_positions,:].squeeze().detach().numpy()
        
        df = pd.DataFrame(relevant_labels,columns=col_names).assign(pat_id=batch['target_pids'])
        
        if full_df_golden is None:
            full_df_golden = df
        else:
            full_df_golden = pd.concat([full_df_golden,df])
            
    full_df_golden['adm_index'] = full_df_golden.groupby(['pat_id']).cumcount()+1
    full_df_golden = full_df_golden.reset_index(drop=True)
    full_df_golden[['pat_id','adm_index']] = full_df_golden[['pat_id','adm_index']].astype(int)
    # reorder columns
    full_df_golden = full_df_golden.set_index(['pat_id','adm_index']).sort_index()
            
    return full_df,full_df_det,full_df_golden

In [10]:
var,det,golden = variational_forward(new_model,val_dataloader,dataset,'validation',num_passes=num_passes)

# save into file

In [11]:
save_var_path = os.path.join(experiment_folder,'variational_forward.csv')
save_det_path = os.path.join(experiment_folder,'deterministic_forward.csv')
save_golden_path = os.path.join(experiment_folder,'golden.csv')

In [12]:
mapping = {save_var_path:var,
       save_det_path:det,
       save_golden_path:golden
      }
for save_path in mapping:
    mapping[save_path].to_csv(save_path)
    
# also save metadata
metadata = dict(model=model_name,
                num_passes=num_passes,
                dataset_id=dataset_id
               )

metadata_path = os.path.join(experiment_folder,'metadata.json')
with open(metadata_path,'w') as f:
    json.dump(metadata,f)
    
print('Saved!')

Saved!


# test it out

In [13]:
df_var = pd.read_csv(save_var_path,index_col=[0,1,2])

df_var.head(3)
df_var.shape

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,diag_0,diag_1,diag_2,diag_3,diag_4,diag_5,diag_6,diag_7,diag_8,diag_9,...,diag_262,diag_263,diag_264,diag_265,diag_266,diag_267,diag_268,diag_269,diag_270,diag_271
pat_id,adm_index,npass,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1
21,1,1,0.001646,0.2196,0.087425,0.084163,0.026606,0.029044,0.006477,0.005765,0.000793,0.014757,...,0.000132,0.003995,0.000508,0.002002,0.046862,0.086731,0.001672,0.000341,0.001392,0.023608
21,1,2,0.001627,0.241448,0.127487,0.038899,0.002566,0.016236,0.004588,0.016988,0.000358,0.010856,...,5.5e-05,0.002041,0.000502,0.004578,0.077088,0.094556,0.00143,0.000183,0.001322,0.017402
21,1,3,0.004938,0.295236,0.131692,0.051219,0.020193,0.026516,0.006584,0.005307,0.000923,0.008159,...,8.3e-05,0.002266,0.000214,0.004752,0.185865,0.095267,0.002175,0.000482,0.001545,0.019843


(58290, 272)

In [14]:
df_det = pd.read_csv(save_det_path,index_col=[0,1])

df_det.head(3)
df_det.shape

Unnamed: 0_level_0,Unnamed: 1_level_0,diag_0,diag_1,diag_2,diag_3,diag_4,diag_5,diag_6,diag_7,diag_8,diag_9,...,diag_262,diag_263,diag_264,diag_265,diag_266,diag_267,diag_268,diag_269,diag_270,diag_271
pat_id,adm_index,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
21,1,0.005465,0.281523,0.145583,0.050645,0.005677,0.007363,0.008749,0.007137,0.000349,0.012388,...,0.000104,0.003191,0.000439,0.010623,0.082972,0.117599,0.001562,0.000729,0.002455,0.027174
23,1,0.002017,0.127589,0.118386,0.019256,0.000759,0.001934,0.023269,0.003224,0.000435,0.013632,...,0.000307,0.000991,0.000544,0.005259,0.105743,0.056455,0.001037,0.000765,0.006583,0.019564
61,1,0.055304,0.231112,0.080319,0.055422,0.02971,0.058429,0.056654,0.011114,0.000952,0.027585,...,0.001459,0.002902,0.00134,0.009572,0.127062,0.118838,0.006779,0.002197,0.002455,0.027425


(1943, 272)

In [15]:
df_golden = pd.read_csv(save_golden_path,index_col=[0,1])

df_golden.head(3)
df_golden.shape

Unnamed: 0_level_0,Unnamed: 1_level_0,diag_0,diag_1,diag_2,diag_3,diag_4,diag_5,diag_6,diag_7,diag_8,diag_9,...,diag_262,diag_263,diag_264,diag_265,diag_266,diag_267,diag_268,diag_269,diag_270,diag_271
pat_id,adm_index,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
21,1,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
23,1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
61,1,0.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


(1943, 272)