Using a variational model, creates several outputs for each patient in the validation/test set. Saves into csv. The purpose is to do analyses later on 

In [1]:
model_name = 'golden-oath-84'

In [2]:
import os
import json

from rnn_utils import DiagnosesDataset, split_dataset, MYCOLLATE
from rnn_utils import train_one_epoch, eval_model

from mourga_variational.variational_rnn import VariationalRNN

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence, pack_sequence

from sklearn.model_selection import ParameterGrid, ParameterSampler

import numpy as np
import pandas as pd

import wandb

In [7]:
# Reproducibility
np.random.seed(546)
torch.manual_seed(546)
torch.cuda.manual_seed(546)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

<torch._C.Generator at 0x12597e750>

In [8]:
grouping = 'ccs'
batch_size=64

In [9]:
dataset = DiagnosesDataset('data/model_data.json',grouping)
test_size = 0.15
eval_size=0.15
eval_size_corrected = eval_size/(1-test_size)

whole_train_dataset,test_dataset = split_dataset(dataset,test_size)
train_dataset, val_dataset = split_dataset(whole_train_dataset,eval_size_corrected)

len(whole_train_dataset)
len(train_dataset)
len(val_dataset)
len(test_dataset)

whole_train_dataset = DataLoader(whole_train_dataset,batch_size=batch_size,collate_fn=MYCOLLATE(dataset),shuffle=True)
train_dataloader = DataLoader(train_dataset,batch_size=batch_size,collate_fn=MYCOLLATE(dataset),shuffle=True)
val_dataloader = DataLoader(val_dataset,batch_size=batch_size,collate_fn=MYCOLLATE(dataset)) #batch_size here is arbitrary and doesn't affect total validation speed
test_dataloader = DataLoader(test_dataset,batch_size=batch_size,collate_fn=MYCOLLATE(dataset))

6375

5250

1125

1124

In [10]:
models_base_path = 'models'
model_path = os.path.join('models',model_name)

# model hyperparameters path
hypp_save_path = os.path.join(model_path, "_".join([model_name,'hypp.json']))
with open(hypp_save_path,'r') as f:
    params_loaded = json.load(f)

    # weights path
weights_save_path = os.path.join('models',
                         model_name,
                         "_".join([model_name,'weights'])
                        )

new_model = VariationalRNN(**params_loaded)
new_model.load_state_dict(torch.load(weights_save_path))

<All keys matched successfully>

In [64]:
def variational_forward(model, dataloader, dataset, name, num_passes=2):
    """
    """
    # to make sure the dropout mask is turned on. But we won't train here
    model.train()
    
    # eg:: ccs, icd9, etc..
    code_type = dataset.grouping
    
    int2code = dataset.grouping_data[code_type]['int2code']
    
    total_loss = 0
    total_seq = 0 #total sequences
    
    sigmoid = nn.Sigmoid()
    
    full_df = None
    
    
    for i, batch in enumerate(iter(dataloader)):
        # get the inputs; data is a list of [inputs, labels]
        history_sequences, target_sequences = batch['train_sequences'],batch['target_sequences']
        
        _,lengths = pad_packed_sequence(history_sequences['sequence'])
        relevant_positions = [[i+idx*max(lengths) for i in range(e)] for idx,e in enumerate(lengths)]
        
        # just flatten the list
        relevant_positions = [item for sublist in relevant_positions for item in sublist]

        for pass_ in range(1,num_passes+1):

            outs = model(history_sequences['sequence'])
            
            outs_flattened = outs.view(1,-1,outs.size()[2])
            
            relevant_outs = outs_flattened[:,relevant_positions,:]
            
            relevant_outs = sigmoid(relevant_outs).detach().numpy().squeeze()
            
            df = pd.DataFrame(relevant_outs).assign(pat_id=batch['target_pids'],npass=pass_)
            
            if full_df is None:
                full_df = df
            else:
                full_df = pd.concat([full_df,df])
    
    full_df['adm_index'] = full_df.groupby(['pat_id','npass']).cumcount()+1
    full_df['name'] = name
    return full_df

In [67]:
res = variational_forward(new_model,val_dataloader,dataset,'validation',num_passes=25)

In [75]:
res

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,266,267,268,269,270,271,pat_id,npass,adm_index,name
0,0.001572,0.114833,0.173299,0.033462,0.000003,0.000570,0.003331,0.002747,0.000202,0.001452,...,0.413414,0.054392,0.001635,0.000894,0.001998,0.022542,57073,1,1,validation
1,0.012642,0.255788,0.281264,0.041398,0.000017,0.000329,0.005451,0.002127,0.000141,0.004886,...,0.310735,0.107714,0.004553,0.006023,0.003653,0.022294,57073,1,2,validation
2,0.003325,0.139801,0.244897,0.025378,0.000007,0.000379,0.002486,0.003731,0.000109,0.002961,...,0.374171,0.051147,0.000797,0.001656,0.002402,0.034179,57073,1,3,validation
3,0.000026,0.071885,0.035722,0.014308,0.000222,0.003418,0.011078,0.000816,0.002010,0.015609,...,0.038718,0.058820,0.008519,0.005279,0.008863,0.011524,73673,1,1,validation
4,0.000576,0.186803,0.123017,0.030666,0.000021,0.014493,0.002726,0.000599,0.000256,0.023996,...,0.083053,0.040633,0.009487,0.002283,0.005041,0.021554,22466,1,1,validation
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
61,0.000620,0.098280,0.201488,0.096891,0.000004,0.004767,0.005971,0.003174,0.000704,0.021479,...,0.076961,0.085979,0.013306,0.004263,0.005999,0.033735,27910,25,2,validation
62,0.000825,0.125201,0.233957,0.126330,0.000002,0.003679,0.006167,0.002800,0.000427,0.025058,...,0.080110,0.077173,0.010986,0.006907,0.004292,0.034802,27910,25,3,validation
63,0.001163,0.077228,0.088242,0.028118,0.000655,0.010456,0.008625,0.002257,0.001251,0.025570,...,0.121316,0.047093,0.007006,0.000432,0.004299,0.030367,13208,25,1,validation
64,0.000149,0.091359,0.104006,0.022404,0.000238,0.005860,0.009920,0.003847,0.001215,0.009053,...,0.117061,0.112291,0.006702,0.004740,0.009132,0.034590,27387,25,1,validation


# save into file

In [76]:
save_path = os.path.join('var_runs',model_name,'forward.csv')

In [96]:
current = ''
for idx,subpath in enumerate(save_path.split('/')[:-1]):
    current = os.path.join(current,subpath)
    if not os.path.isdir(current):
        #print(current)
        os.mkdir(current)
res.to_csv(save_path,index=False)

# test it out

In [98]:
df = pd.read_csv(save_path)

In [99]:
df.head(3)
df.shape

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,266,267,268,269,270,271,pat_id,npass,adm_index,name
0,0.001572,0.114833,0.173299,0.033462,3e-06,0.00057,0.003331,0.002747,0.000202,0.001452,...,0.413414,0.054392,0.001635,0.000894,0.001998,0.022542,57073,1,1,validation
1,0.012642,0.255788,0.281264,0.041398,1.7e-05,0.000329,0.005451,0.002127,0.000141,0.004886,...,0.310735,0.107714,0.004553,0.006023,0.003653,0.022294,57073,1,2,validation
2,0.003325,0.139801,0.244897,0.025378,7e-06,0.000379,0.002486,0.003731,0.000109,0.002961,...,0.374171,0.051147,0.000797,0.001656,0.002402,0.034179,57073,1,3,validation


(46450, 276)