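Uses a trained variational model to generate several stochastic outputs (one per forward pass) for each patient in the validation/test set, alongside the deterministic outputs and the true labels. Everything is saved to CSV files so that the analyses can be done later on.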

# Imports

In [1]:
import os
import json

from rnn_utils import DiagnosesDataset, split_dataset, MYCOLLATE
from rnn_utils import train_one_epoch, eval_model

from mourga_variational.variational_rnn import VariationalRNN

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence, pack_sequence

import numpy as np
import pandas as pd

from config import Settings; settings = Settings()

# display all outputs
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

# Parameters

In [2]:
num_passes = 30 # number of (variational) forward passes for each input
model_name = 'golden-oath-84' # sub-folder of the models folder that holds the hyper-parameters and weights to load
dataset_id = 'diag_only' # sub-folder of the model-ready dataset folder to read the subsets from
experiment_id = 'A' # sub-folder of the variational data folder where this run's outputs are written

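Sanity check: verify that the model and dataset folders exist, and create the experiment folder (refusing to overwrite an existing one).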

In [6]:
# Resolve the input folders first and fail fast if either is missing.
# Merely displaying True/False would let the run continue, create the
# experiment folder, crash later while loading the model/dataset, and leave
# behind an empty experiment folder that blocks the next run.
model_folder = os.path.join(settings.data_base,settings.models_folder,model_name)
dataset_folder = os.path.join(settings.data_base,settings.model_ready_dataset_folder,dataset_id)

for required_folder in (model_folder, dataset_folder):
    if not os.path.exists(required_folder):
        raise FileNotFoundError(f"Required folder {required_folder} does not exist")

# The experiment folder is only created once the inputs are known to exist.
experiment_folder = os.path.join(settings.data_base,settings.variational_data_folder,experiment_id)
if os.path.exists(experiment_folder):
    raise Exception(f"Experiment {experiment_folder} exists. If you want to overwrite it, manually delete the directory first")
else:
    # makedirs (rather than mkdir) also creates the parent folder on a first run.
    os.makedirs(experiment_folder)

True

True

Exception: Experiment data/variational/Experiment A exists. If you want to overwrite it, manually delete the directory first

# Reproducibility

In [7]:
# Reproducibility
# Seed numpy and torch (including the CUDA generator) with the seed taken from the settings.
seed = settings.random_seed
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Request deterministic cuDNN kernels and turn off the cuDNN benchmark mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

<torch._C.Generator at 0x7fede5b14db0>

In [8]:
batch_size = 64 # really doesn't matter for this notebook since we will only do inference
grouping = 'ccs' # code grouping passed to every DiagnosesDataset below

In [9]:
# Full dataset; it is also the object handed to every collate function below.
dataset = DiagnosesDataset(os.path.join(dataset_folder,'dataset.json'),grouping)

# Train / validation / test subsets, each read from its own json file in the dataset folder.
train_dataset = DiagnosesDataset(os.path.join(dataset_folder,'train_subset.json'),grouping)
val_dataset = DiagnosesDataset(os.path.join(dataset_folder,'val_subset.json'),grouping)
test_dataset = DiagnosesDataset(os.path.join(dataset_folder,'test_subset.json'),grouping)


# Displayed as cell output: size of each subset.
len(train_dataset)
len(val_dataset)
len(test_dataset)


# Only the training loader shuffles; the validation and test loaders keep the dataset order.
train_dataloader = DataLoader(train_dataset,batch_size=batch_size,collate_fn=MYCOLLATE(dataset),shuffle=True)
val_dataloader = DataLoader(val_dataset,batch_size=batch_size,collate_fn=MYCOLLATE(dataset)) #batch_size here is arbitrary and doesn't affect total validation speed
test_dataloader = DataLoader(test_dataset,batch_size=batch_size,collate_fn=MYCOLLATE(dataset))

5249

1125

1125

In [10]:
# model hyperparameters path
hypp_save_path = os.path.join(model_folder, 'hyper_parameters.json')
with open(hypp_save_path,'r') as f:
    params_loaded = json.load(f)

# weights path
weights_save_path = os.path.join(model_folder,"weights")

# Rebuild the architecture from the stored hyper-parameters, then restore the weights.
new_model = VariationalRNN(**params_loaded)
# map_location='cpu': the model is only used on CPU in this notebook (its outputs are
# converted with .numpy() directly), so the checkpoint must also load on machines
# without a GPU even if it was saved from one.
new_model.load_state_dict(torch.load(weights_save_path, map_location='cpu'))

<All keys matched successfully>

In [29]:
# `int2code` must be defined in this cell: relying on a leftover kernel variable
# makes the notebook fail with a NameError on Restart & Run All.
int2code = dataset.grouping_data[dataset.grouping]['int2code']
# The column naming used further down expects the integer codes to be consecutive.
assert all(np.diff(list(int2code.keys())) == 1)

In [34]:
def _relevant_positions(packed_sequence):
    """Return the flat indices of the real (non-padded) steps of a packed batch.

    The indices address a padded batch flattened from (batch, max_len, ...) to
    (batch * max_len, ...): step ``t`` of sequence ``b`` is at ``b * max_len + t``.
    """
    _, lengths = pad_packed_sequence(packed_sequence)
    lengths = [int(length) for length in lengths]
    max_len = max(lengths)
    return [b * max_len + t for b, length in enumerate(lengths) for t in range(length)]


def _select_rows(padded, positions):
    """Flatten a 3-D tensor to (rows, n_codes) and keep only ``positions``.

    The result is always 2-D, even when a single row is selected, so it can be
    turned into a DataFrame with one column per code.
    """
    n_codes = padded.size(2)
    return padded.reshape(-1, n_codes)[positions]


def _predicted_probabilities(model, packed_history, positions):
    """Run one forward pass and return sigmoid outputs of the real steps as a 2-D numpy array."""
    outs = model(packed_history)
    return torch.sigmoid(_select_rows(outs, positions)).detach().numpy()


def _finalize(frames, group_keys, label):
    """Concatenate per-batch frames once, number the admissions and tag the rows with ``label``."""
    if not frames:
        raise ValueError('No outputs were produced: the dataloader is empty or num_passes < 1')
    # A single concat at the end avoids re-copying the growing frame on every batch.
    full = pd.concat(frames, ignore_index=True)
    # Rows of one patient (and pass) appear in admission order, so a running count indexes them.
    full['adm_index'] = full.groupby(group_keys).cumcount() + 1
    full['name'] = label
    return full


def variational_forward(model, dataloader, dataset, name, num_passes=2):
    """Collect variational outputs, deterministic outputs and true labels for a dataloader.

    Parameters
    ----------
    model : torch.nn.Module
        Called on ``batch['train_sequences']['sequence']`` (a packed sequence).
        Assumed to return a (batch, max_len, n_codes) tensor of logits on CPU
        -- TODO confirm against the model definition.
    dataloader : iterable of dict
        Each batch provides ``'train_sequences'``, ``'target_sequences'`` (both with a
        ``'sequence'`` entry) and ``'target_pids'``. ``'target_pids'`` must hold one id
        per non-padded step, in the same order as the selected rows. The loader is
        iterated three times; if it shuffles, the row order of the three returned
        frames will not match each other.
    dataset : object
        Provides ``grouping`` and ``grouping_data[grouping]['int2code']``; the keys of
        ``int2code`` name the output columns (``diag_<code>``).
    name : str
        Prefix of the ``name`` column (``<name>_variational`` / ``_deterministic`` / ``_golden``).
    num_passes : int, default 2
        Number of stochastic forward passes per batch.

    Returns
    -------
    (variational, deterministic, golden) : tuple of pandas.DataFrame
        One row per non-padded step. Columns: one ``diag_<code>`` per code, ``pat_id``,
        ``npass`` (variational frame only), ``adm_index`` (1-based position of the row
        within its patient) and ``name``.

    Raises
    ------
    ValueError
        If no rows are produced (empty dataloader or ``num_passes < 1``).

    Notes
    -----
    The model is left in eval mode when the function returns.
    """
    # eg:: ccs, icd9, etc..
    code_type = dataset.grouping
    int2code = dataset.grouping_data[code_type]['int2code']

    assert all(np.diff(list(int2code.keys())) == 1), 'Expecting sorted codes, if this fails then it might be time to update column-naming related code'
    col_names = ['diag_' + str(code) for code in int2code.keys()]

    # variational: train mode keeps the dropout masks switched on; no_grad because
    # nothing is trained here, so no autograd graph needs to be kept.
    model.train()
    variational_frames = []
    with torch.no_grad():
        for batch in dataloader:
            packed_history = batch['train_sequences']['sequence']
            positions = _relevant_positions(packed_history)
            for pass_ in range(1, num_passes + 1):
                probabilities = _predicted_probabilities(model, packed_history, positions)
                variational_frames.append(
                    pd.DataFrame(probabilities, columns=col_names)
                    .assign(pat_id=batch['target_pids'], npass=pass_)
                )
    full_df = _finalize(variational_frames, ['pat_id', 'npass'], name + '_variational')

    # deterministic: eval mode switches dropout off, one pass per batch.
    model.eval()
    deterministic_frames = []
    with torch.no_grad():
        for batch in dataloader:
            packed_history = batch['train_sequences']['sequence']
            positions = _relevant_positions(packed_history)
            probabilities = _predicted_probabilities(model, packed_history, positions)
            deterministic_frames.append(
                pd.DataFrame(probabilities, columns=col_names)
                .assign(pat_id=batch['target_pids'])
            )
    full_df_det = _finalize(deterministic_frames, ['pat_id'], name + '_deterministic')

    # Now to store the true labels
    golden_frames = []
    for batch in dataloader:
        positions = _relevant_positions(batch['train_sequences']['sequence'])
        labels = _select_rows(batch['target_sequences']['sequence'], positions).detach().numpy()
        golden_frames.append(
            pd.DataFrame(labels, columns=col_names)
            .assign(pat_id=batch['target_pids'])
        )
    full_df_golden = _finalize(golden_frames, ['pat_id'], name + '_golden')

    return full_df, full_df_det, full_df_golden

In [35]:
# Run on the validation loader: `var` holds num_passes stochastic outputs per row, `det` the eval-mode outputs, `golden` the true labels.
var,det,golden = variational_forward(new_model,val_dataloader,dataset,'validation',num_passes=num_passes)

# save into file

In [39]:
# One CSV per returned frame, all inside the experiment folder.
save_var_path = os.path.join(experiment_folder,'variational_forward.csv')
save_det_path = os.path.join(experiment_folder,'deterministic_forward.csv')
save_golden_path = os.path.join(experiment_folder,'golden.csv')

In [40]:
# Destination path -> frame to write there.
mapping = {
    save_var_path: var,
    save_det_path: det,
    save_golden_path: golden,
}
for save_path, frame in mapping.items():
    frame.to_csv(save_path, index=False)

# also save metadata
# Records which model, number of passes and dataset produced the CSV files.
metadata = {
    'model': model_name,
    'num_passes': num_passes,
    'dataset_id': dataset_id,
}

metadata_path = os.path.join(experiment_folder, 'metadata.json')
with open(metadata_path, 'w') as f:
    json.dump(metadata, f)

print('Saved!')

# test it out

In [43]:
# Read the variational CSV back from disk and display its first rows and its shape.
df = pd.read_csv(save_var_path)

df.head(3)
df.shape

Unnamed: 0,diag_0,diag_1,diag_2,diag_3,diag_4,diag_5,diag_6,diag_7,diag_8,diag_9,...,diag_266,diag_267,diag_268,diag_269,diag_270,diag_271,pat_id,npass,adm_index,name
0,0.002717,0.29644,0.14781,0.060906,0.0009,0.007469,0.004407,0.006102,0.000658,0.022718,...,0.071754,0.117667,0.004013,0.000732,0.00122,0.033151,21,1,1,validation_variational
1,0.0011,0.102634,0.12115,0.013699,0.000671,0.000878,0.017694,0.004429,0.000642,0.008332,...,0.099391,0.048329,0.000303,0.000429,0.003517,0.017984,23,1,1,validation_variational
2,0.016359,0.196292,0.08693,0.078287,0.014079,0.083415,0.047873,0.013352,0.000362,0.01713,...,0.162894,0.105277,0.003154,0.001287,0.002007,0.032937,61,1,1,validation_variational


(58290, 276)

In [44]:
# Read the deterministic CSV back from disk and display its first rows and its shape.
df = pd.read_csv(save_det_path)

df.head(3)
df.shape

Unnamed: 0,diag_0,diag_1,diag_2,diag_3,diag_4,diag_5,diag_6,diag_7,diag_8,diag_9,...,diag_265,diag_266,diag_267,diag_268,diag_269,diag_270,diag_271,pat_id,adm_index,name
0,0.005948,0.174385,0.10825,0.041608,0.009844,0.004907,0.002712,0.006547,0.000586,0.011575,...,0.004836,0.074164,0.111899,0.001364,0.000471,0.002073,0.016945,21,1,validation_deterministic
1,0.003394,0.106665,0.101157,0.01842,0.001011,0.001089,0.014337,0.003634,0.000576,0.014183,...,0.004345,0.099369,0.050602,0.001039,0.000802,0.006341,0.019711,23,1,validation_deterministic
2,0.013099,0.187023,0.066681,0.045532,0.008071,0.059498,0.053763,0.014653,0.000972,0.020505,...,0.007737,0.099287,0.104157,0.007525,0.003345,0.003494,0.028888,61,1,validation_deterministic


(1943, 275)

In [45]:
# Read the true-label (golden) CSV back from disk and display its first rows and its shape.
df = pd.read_csv(save_golden_path)

df.head(3)
df.shape

Unnamed: 0,diag_0,diag_1,diag_2,diag_3,diag_4,diag_5,diag_6,diag_7,diag_8,diag_9,...,diag_265,diag_266,diag_267,diag_268,diag_269,diag_270,diag_271,pat_id,adm_index,name
0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,21,1,validation_golden
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,23,1,validation_golden
2,0.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,61,1,validation_golden


(1943, 275)