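Using a trained variational RNN, this notebook produces several outputs for each patient in the validation set (the same procedure can be applied to the test set): sigmoid outputs from repeated stochastic forward passes with dropout kept active, one deterministic forward pass, and the true labels. Each set of outputs is saved to a CSV file so it can be analysed later.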

In [1]:
# Run name of the trained model; used below to locate its saved files and to name the output folder.
model_name = "golden-oath-84"

In [2]:
import os
import json

from rnn_utils import DiagnosesDataset, split_dataset, MYCOLLATE
from rnn_utils import train_one_epoch, eval_model

from mourga_variational.variational_rnn import VariationalRNN

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence, pack_sequence

from sklearn.model_selection import ParameterGrid, ParameterSampler

import numpy as np
import pandas as pd

import wandb

In [3]:
# Reproducibility: seed NumPy and PyTorch (CPU and current CUDA device) with one
# shared seed, and force deterministic cuDNN kernel selection.
SEED = 546
for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(SEED)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

<torch._C.Generator at 0x12fa16710>

In [4]:
# Diagnosis-code grouping passed to DiagnosesDataset, and the DataLoader batch size.
grouping = "ccs"
batch_size = 64

In [5]:
dataset = DiagnosesDataset('data/model_data.json', grouping)
test_size = 0.15
eval_size = 0.15
# split_dataset is applied twice: first to carve out the test split, then on the
# remainder. Rescaling eval_size keeps the validation split at ~eval_size of the
# *full* dataset rather than of what is left (assuming split_dataset's second
# argument is the fraction assigned to the second returned split).
eval_size_corrected = eval_size / (1 - test_size)

whole_train_dataset, test_dataset = split_dataset(dataset, test_size)
train_dataset, val_dataset = split_dataset(whole_train_dataset, eval_size_corrected)

len(whole_train_dataset)
len(train_dataset)
len(val_dataset)
len(test_dataset)

# Every split gets its loader under a *_dataloader name so that *_dataset names
# always refer to datasets (and re-running cells cannot wrap a loader in a loader).
whole_train_dataloader = DataLoader(whole_train_dataset, batch_size=batch_size, collate_fn=MYCOLLATE(dataset), shuffle=True)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=MYCOLLATE(dataset), shuffle=True)
# batch_size here is arbitrary and doesn't affect total validation speed; no shuffling keeps row order stable.
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=MYCOLLATE(dataset))
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=MYCOLLATE(dataset))

6375

5250

1125

1124

In [6]:
models_base_path = 'models'
# Everything for one run lives under <models_base_path>/<model_name>/.
model_path = os.path.join(models_base_path, model_name)

# Hyperparameters saved at training time: <model_name>_hypp.json
hypp_save_path = os.path.join(model_path, f"{model_name}_hypp.json")
with open(hypp_save_path, 'r') as f:
    params_loaded = json.load(f)

# Weights saved at training time: <model_name>_weights
weights_save_path = os.path.join(model_path, f"{model_name}_weights")

new_model = VariationalRNN(**params_loaded)
# map_location='cpu' lets weights saved from a GPU run load on a CPU-only machine;
# the freshly constructed model lives on the CPU either way.
new_model.load_state_dict(torch.load(weights_save_path, map_location='cpu'))

<All keys matched successfully>

In [7]:
# Number of trainable parameters (those with requires_grad=True) in the loaded model.
params = sum(p.numel() for p in new_model.parameters() if p.requires_grad)

In [8]:
params  # display the trainable-parameter count computed in the previous cell

139672

In [19]:
def _step_lengths(packed_sequences):
    """Return the per-sequence lengths of a PackedSequence as a list of Python ints."""
    _, lengths = pad_packed_sequence(packed_sequences)
    return lengths.tolist()


def _valid_step_rows(padded, lengths):
    """Select the non-padding time steps of a padded 3-D tensor, flattened to 2-D.

    Sequence ``b`` contributes its first ``lengths[b]`` steps, in order, so the
    result has shape ``(sum(lengths), padded.size(-1))``. The row stride is
    ``max(lengths)``, which assumes ``padded`` is batch-first
    ``(batch, max(lengths), C)`` -- TODO confirm for the model's output layout.
    """
    stride = max(lengths)
    positions = [row * stride + step
                 for row, n_steps in enumerate(lengths)
                 for step in range(n_steps)]
    flat = padded.reshape(-1, padded.size(-1))
    # Indexing with a list always gives a 2-D result, even when only one step is
    # valid (a plain .squeeze() would collapse that case to 1-D).
    return flat[positions].detach().cpu()


def _stack_outputs(frames, group_keys, label):
    """Concatenate per-batch frames once, add a 1-based ``adm_index`` per group and a ``name`` tag."""
    if not frames:
        return pd.DataFrame(columns=[*group_keys, 'adm_index', 'name'])
    stacked = pd.concat(frames, ignore_index=True)
    # Rows of one group appear in the order they were produced, so a running
    # count numbers each patient's admissions.
    stacked['adm_index'] = stacked.groupby(list(group_keys)).cumcount() + 1
    stacked['name'] = label
    return stacked


def variational_forward(model, dataloader, dataset, name, num_passes=2):
    """Run ``model`` over ``dataloader`` and collect per-admission outputs.

    Three DataFrames are built. Each has one column per output/label index,
    plus ``pat_id``, ``adm_index`` (1-based position of the row within its
    group) and ``name``:

    * variational -- ``num_passes`` forward passes in train mode (dropout
      active), sigmoid outputs, with an extra ``npass`` column (1..num_passes);
      ``adm_index`` is counted per (pat_id, npass);
    * deterministic -- one forward pass in eval mode, sigmoid outputs;
    * golden -- the target labels at the same (non-padding) positions.

    Args:
        model: module called as ``model(packed_history)``. Its output is
            assumed to be batch-first ``(batch, max_len, n_outputs)`` -- TODO confirm.
        dataloader: iterable of batch dicts with ``'train_sequences'`` and
            ``'target_sequences'`` (each a dict holding ``'sequence'``) and
            ``'target_pids'`` (presumably one patient id per valid time step,
            since it is assigned row-wise).
        dataset: not read; kept so existing calls keep working.
        name: prefix for the ``name`` column (``<name>_variational``, ...).
        num_passes: number of stochastic forward passes per batch.

    Returns:
        ``(variational_df, deterministic_df, golden_df)``, each with a fresh
        RangeIndex. A frame is empty when there is nothing to collect (no
        batches, or ``num_passes < 1`` for the variational frame).

    Side effect: ``model`` is left in eval mode.
    """
    # Variational: train() keeps dropout active so every pass samples new masks.
    # no_grad() only skips autograd bookkeeping; it does not disable dropout.
    model.train()
    var_frames = []
    with torch.no_grad():
        for batch in dataloader:
            history = batch['train_sequences']['sequence']
            lengths = _step_lengths(history)
            for pass_ in range(1, num_passes + 1):
                probs = torch.sigmoid(_valid_step_rows(model(history), lengths))
                var_frames.append(
                    pd.DataFrame(probs.numpy()).assign(pat_id=batch['target_pids'], npass=pass_)
                )
    full_df = _stack_outputs(var_frames, ['pat_id', 'npass'], name + '_variational')

    # Deterministic: eval() disables dropout, giving one reference pass.
    model.eval()
    det_frames = []
    with torch.no_grad():
        for batch in dataloader:
            history = batch['train_sequences']['sequence']
            lengths = _step_lengths(history)
            probs = torch.sigmoid(_valid_step_rows(model(history), lengths))
            det_frames.append(pd.DataFrame(probs.numpy()).assign(pat_id=batch['target_pids']))
    full_df_det = _stack_outputs(det_frames, ['pat_id'], name + '_deterministic')

    # Golden: the true labels, selected with the same valid-step positions.
    golden_frames = []
    for batch in dataloader:
        lengths = _step_lengths(batch['train_sequences']['sequence'])
        labels = _valid_step_rows(batch['target_sequences']['sequence'], lengths)
        golden_frames.append(pd.DataFrame(labels.numpy()).assign(pat_id=batch['target_pids']))
    full_df_golden = _stack_outputs(golden_frames, ['pat_id'], name + '_golden')

    return full_df, full_df_det, full_df_golden

In [20]:
# 25 stochastic passes per patient (plus one deterministic pass and the labels)
# over the validation split.
var, det, golden = variational_forward(
    new_model,
    val_dataloader,
    dataset,
    'validation',
    num_passes=25,
)

# Save the outputs to CSV files

In [24]:
# All outputs of this model go under var_runs/<model_name>/.
var_runs_dir = os.path.join('var_runs', model_name)
save_var_path = os.path.join(var_runs_dir, 'variational_forward.csv')
save_det_path = os.path.join(var_runs_dir, 'deterministic_forward.csv')
save_golden_path = os.path.join(var_runs_dir, 'golden.csv')

In [25]:
# Write each frame to its CSV, creating the output directory tree if needed.
outputs_by_path = {
    save_var_path: var,
    save_det_path: det,
    save_golden_path: golden,
}
for save_path, frame in outputs_by_path.items():
    out_dir = os.path.dirname(save_path)
    if out_dir:
        # makedirs creates every missing parent and, unlike splitting on '/',
        # works with whatever separator os.path.join used on this OS.
        os.makedirs(out_dir, exist_ok=True)
    frame.to_csv(save_path, index=False)

# Check the saved files by reading them back

In [116]:
# Reload the variational outputs from disk.
df = pd.read_csv(filepath_or_buffer=save_var_path)

In [117]:
# Round-trip check: the CSV must hold as many rows and columns as the saved frame.
assert df.shape == var.shape, "variational CSV does not match the in-memory frame"
df.head(3)
df.shape

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,266,267,268,269,270,271,pat_id,npass,adm_index,name
0,0.001856,0.138134,0.11462,0.017878,2.8e-05,0.000119,0.012129,0.004681,0.000156,0.001759,...,0.298088,0.071813,0.00246,0.00118,0.000948,0.008432,57073,1,1,validation_variational
1,0.007063,0.150442,0.127057,0.020311,6.6e-05,4.5e-05,0.010244,0.001438,0.000118,0.003473,...,0.326944,0.193534,0.002313,0.001776,0.001629,0.007794,57073,1,2,validation_variational
2,0.001551,0.147782,0.10462,0.014406,1.6e-05,2.7e-05,0.009318,0.001628,5.2e-05,0.00365,...,0.251702,0.144012,0.001188,0.00041,0.001214,0.005411,57073,1,3,validation_variational


(46450, 276)

In [118]:
# Reload the deterministic outputs from disk.
df = pd.read_csv(filepath_or_buffer=save_det_path)

In [119]:
# Round-trip check: the CSV must hold as many rows and columns as the saved frame.
assert df.shape == det.shape, "deterministic CSV does not match the in-memory frame"
df.head(3)
df.shape

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,265,266,267,268,269,270,271,pat_id,adm_index,name
0,0.001139,0.075773,0.084585,0.023194,1.8e-05,0.000688,0.008969,0.006719,0.000329,0.001288,...,0.001133,0.344245,0.039821,0.001351,0.002167,0.000913,0.015739,57073,1,validation_deterministic
1,0.004444,0.141498,0.099273,0.02531,0.000119,0.000456,0.009971,0.003447,0.000193,0.002948,...,0.002121,0.233177,0.116428,0.002726,0.010303,0.0015,0.013493,57073,2,validation_deterministic
2,0.000986,0.0783,0.087702,0.020546,1.1e-05,0.000237,0.004596,0.004172,0.000128,0.001062,...,0.000851,0.257091,0.073811,0.000787,0.001843,0.000998,0.015166,57073,3,validation_deterministic


(1858, 275)