This is playing with pytorch framework for EHR modeling. In general, a patient's health record can be represented as a sequence of visits. Each visit has certain features, and can be represented as a list of medical codes.

For simplicity, we are starting with the data structure that a patient's health record is a list of list, following the line of work from Jimeng Sun's lab. We will use codes from Ed Choi to manipulate the data. 

The core model is an RNN , either LSTM, GRU or Vanilla RNN.

# todos:
* None for now

In [1]:
%matplotlib inline
from __future__ import print_function, division
from io import open
import string
import re
import random
import sklearn 
from sklearn.metrics import roc_auc_score
import plotly.plotly as py 
import plotly.graph_objs as go
import torch
import torch.nn as nn
import torch.autograd as autograd
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
import torch.utils.data as Data

use_cuda = torch.cuda.is_available()

import sys, random
import numpy as np
try:
    import cPickle as pickle
except:
    import pickle

from torchviz import make_dot, make_dot_from_trace

# for windows only    
import os
os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin/'

In [2]:
use_cuda

True

In [3]:
# prepare data: load the input file containing list of list of list, and corresponding label file
# and output the splitted training, valid and Test sets

def data_load_split_VT(seqFile = 'data/cerner/hospital_data/h143.visits', labelFile = 'data/cerner/hospital_data/h143.labels' , test_r=0.2 , valid_r=0.1):

    set_x = pickle.load(open(seqFile, 'rb'), encoding='bytes')
    set_y = pickle.load(open(labelFile, 'rb'),encoding='bytes')
    merged_set = [[set_y[i],set_x[i]] for i in range(len(set_x))] # merge the two lists

    # set random seed
    random.seed( 3 )
    
    dataSize = len(merged_set)
    nTest = int(test_r * dataSize)
    nValid = int(valid_r * dataSize)
    
    random.shuffle(merged_set)

    test_set = merged_set[:nTest]
    valid_set = merged_set[nTest:nTest+nValid]
    train_set = merged_set[nTest+nValid:]

    return train_set, valid_set, test_set

In [4]:
train_sl , valid_sl , test_sl = data_load_split_VT()

In [5]:
class EHR_RNN(nn.Module):
    def __init__(self, input_size, hidden_size,embed_dim, n_layers=1,dropout_r=0.1,cell_type='LSTM',bi=False):
        super(EHR_RNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embed_dim = embed_dim
        self.dropout_r = dropout_r
        self.cell_type = cell_type
        self.embedBag = nn.EmbeddingBag(input_size, self.embed_dim,mode= 'sum')
        
        
        
        if self.cell_type == "GRU":
            cell = nn.GRU
        elif self.cell_type == "RNN":
            cell = nn.RNN
        elif self.cell_type == "LSTM":
            cell = nn.LSTM
        else:
            raise NotImplementedError
        
        self.rnn_c = cell(self.embed_dim, hidden_size,num_layers=n_layers, dropout= dropout_r , bidirectional=bi )
        
        if bi:
            self.out = nn.Linear(self.hidden_size*2,1)
        else: 
            self.out = nn.Linear(self.hidden_size,1)
        self.sigmoid = nn.Sigmoid()

        
    def EmbedPatient_MB(self, seq_mini_batch): # x is a ehr_seq_tensor
        
        lp= len(max(seq_mini_batch, key=lambda xmb: len(xmb[1]))[1]) # max number of visitgs within mb ??? verify again
        #print ('longest',lp)
        tb= torch.FloatTensor(len(seq_mini_batch),lp,self.embed_dim) 
        lbt1= torch.FloatTensor(len(seq_mini_batch),1)

        for pt in range(len(seq_mini_batch)):
              
            lbt ,pt_visits =seq_mini_batch[pt]
            lbt1[pt] = torch.FloatTensor([[float(lbt)]])
            ml=(len(max(pt_visits, key=len))) ## getting the visit with max no. of codes ##the max number of visits for pts within the minibatch
            txs= torch.LongTensor(len(pt_visits),ml)
            
            b=0
            for i in pt_visits:
                pd=(0, ml-len(i))
                txs[b] = F.pad(torch.from_numpy(np.asarray(i)).view(1,-1),pd,"constant", 0).data
                b=b+1
            
            if use_cuda:
                txs=txs.cuda()
                
            emb_bp= self.embedBag(Variable(txs)) ### embed will be num_of_visits*max_num_codes*embed_dim 
            #### the embed Bag dim will be num_of_visits*embed_dim
            
            zp= nn.ZeroPad2d((0,0,0,(lp-len(pt_visits))))
            xzp= zp(emb_bp)
            tb[pt]=xzp.data

        tb= tb.permute(1, 0, 2) ### as my final input need to be seq_len x batch_size x input_size
        emb_m=Variable(tb)
        label_tensor = Variable(lbt1)

        if use_cuda:
                label_tensor = label_tensor.cuda()
                emb_m = emb_m.cuda()
        #print (label_tensor)        
        return emb_m , label_tensor

    def forward(self, input):
        
        x_in , lt = self.EmbedPatient_MB(input)
        
        for i in range(self.n_layers):
                output, hidden = self.rnn_c(x_in) # input (seq_len, batch, input_size) need to check torch.nn.utils.rnn.pack_padded_sequence() 
                                                          
        output = self.sigmoid(self.out(output[0]))
        #print (output, lt)
        return output, lt



In [6]:

model = EHR_RNN(input_size=20000, hidden_size=256 ,embed_dim=512, dropout_r=0, cell_type='GRU',bi=True)

if use_cuda:
    model = model.cuda()

In [7]:
def train (mini_batch, criterion, optimizer):  
    
    model.zero_grad()
    output , label_tensor = model(mini_batch,)
    loss = criterion(output, label_tensor)
    loss.backward()
    optimizer.step()
   
    return output, loss.data[0]

In [8]:
# training all samples in random order
import time
import math

def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)





In [9]:
def run_model_train(dataset,batch_size,learning_rate = 0.01, l2=0.0001,epsl=1e-08 ):
    
    #optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    #optimizer = optim.Adadelta(model.parameters(), lr=learning_rate, weight_decay=0)
    #optimizer = optim.Adam(model.parameters(),lr=learning_rate, weight_decay=learning_rate) #'''lr=learning_rate,''' 
    optimizer = optim.Adamax(model.parameters(), lr=learning_rate, weight_decay=l2 ,eps=epsl) ### Beta defaults (0.9, 0.999)
    #optimizer = optim.RMSprop (model.parameters())
    #optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=learning_rate)
    dataset.sort(key=lambda pt:len(pt[1])) 
    # Keep track of losses for plotting
    current_loss = 0
    all_losses = []
    print_every = 10#int(batch_size/2)
    plot_every = 5
    iter=0
    n_batches = int(np.ceil(int(len(dataset)) / int(batch_size)))
    #print('number of Batches',n_batches)
    start = time.time()

    for index in random.sample(range(n_batches), n_batches):
            batch = dataset[index*batch_size:(index+1)*batch_size]
            output, loss = train(batch, criterion = nn.BCELoss(), optimizer = optimizer)
            current_loss += loss
            iter +=1
            # Print iter number, loss, name and guess
            #if iter % print_every == 0:
               #print('%d %d%% (%s) %.4f ' % ( iter, iter/ n_batches * 100, timeSince(start), loss))

            # Add current loss avg to list of losses
            if iter % plot_every == 0:
                all_losses.append(current_loss / plot_every)
                current_loss = 0
                
    return current_loss,all_losses


In [10]:
def calculate_auc(test_model, dataset, batch_size=200):

    n_batches = int(np.ceil(int(len(dataset)) / int(batch_size)))
    labelVec =[]
    y_hat= []
    
    for index in range(n_batches):
            batch = dataset[index*batch_size:(index+1)*batch_size]
            output, label_t = test_model(batch)
            y_hat.extend(output.cpu().data.view(-1).numpy())
            labelVec.extend(label_t.cpu().data.view(-1).numpy())
    auc = roc_auc_score(labelVec, y_hat)
    
    return auc

In [12]:
epochs=20
batch_size=200
current_loss_l=[]
all_losses_l=[]
train_auc_allep =[]
valid_auc_allep =[]
test_auc_allep=[]
bestValidAuc = 0.0
bestTestAuc = 0.0
bestValidEpoch = 0

for ep in range(epochs):
    
    start = time.time()
    current_loss_la,all_losses_la = run_model_train(train_sl,batch_size)
    train_time = timeSince(start)
    eval_start = time.time()
    train_auc = calculate_auc(model,train_sl,batch_size)
    test_auc = calculate_auc(model,test_sl,batch_size)
    valid_auc = calculate_auc(model,valid_sl,batch_size)
    eval_time = timeSince(eval_start)
    all_losses_l.append (all_losses_la)
    avg_loss = np.mean(all_losses_la)
    train_auc_allep.append(train_auc)
    valid_auc_allep.append(valid_auc)
    test_auc_allep.append(test_auc)
    current_loss_l.append(current_loss_la)
    print ("Epoch ", ep," Train_auc :", train_auc, " , Valid_auc : ", valid_auc, " ,& Test_auc : " , test_auc, " Avg Loss: ", avg_loss, 'Train Time (%s) Eval Time (%s)'%(train_time,eval_time) )
    
    if valid_auc > bestValidAuc: 
        bestValidAuc = valid_auc
        bestValidEpoch = ep
        bestTestAuc = test_auc
    if ep - bestValidEpoch >3:
          break
            
print ('bestValidAuc %f has a TestAuc of %f at epoch %d ' % (bestValidAuc, bestTestAuc, bestValidEpoch))
    


Epoch  0  Train_auc : 0.8265810791616597  , Valid_auc :  0.7822775638535217  ,& Test_auc :  0.8043696971889033  Avg Loss:  0.31459326048692066 Train Time (0m 24s) Eval Time (0m 38s)
Epoch  1  Train_auc : 0.8422385182055991  , Valid_auc :  0.7872117555067054  ,& Test_auc :  0.8044121210374154  Avg Loss:  0.3041466940442721 Train Time (0m 24s) Eval Time (0m 38s)
Epoch  2  Train_auc : 0.8563318623721132  , Valid_auc :  0.7780059090086974  ,& Test_auc :  0.7998262661819107  Avg Loss:  0.2969696709513664 Train Time (0m 24s) Eval Time (0m 38s)
Epoch  3  Train_auc : 0.865132950305252  , Valid_auc :  0.7836381549073896  ,& Test_auc :  0.7948940858707645  Avg Loss:  0.29456270466248197 Train Time (0m 24s) Eval Time (0m 38s)
Epoch  4  Train_auc : 0.8591881307274746  , Valid_auc :  0.7741689329523678  ,& Test_auc :  0.7852971178705839  Avg Loss:  0.2902241370081902 Train Time (0m 24s) Eval Time (0m 38s)
Epoch  5  Train_auc : 0.8670956978524216  , Valid_auc :  0.7724337695421036  ,& Test_auc :  0.

In [13]:
import plotly.plotly as py 
import plotly.graph_objs as go
py.sign_in('LailaRasmy','mzNHzVvwYjcZwBDZx3B7')

train_auc_fg= go.Scatter(x= np.arange(epochs), y=train_auc_allep, name='train')
test_auc_fg= go.Scatter(x= np.arange(epochs), y=test_auc_allep, name='test')
valid_auc_fg= go.Scatter(x= np.arange(epochs), y=valid_auc_allep, name='valid')
valid_max = max(valid_auc_allep)
test_max = max(test_auc_allep)
data = [train_auc_fg,test_auc_fg,valid_auc_fg]#,valid_auc_allep,test_auc_allep] 
layout = go.Layout(xaxis=dict(dtick=1))
layout.update(dict(annotations=[go.Annotation(text="Max Valid", x=valid_auc_allep.index(valid_max), y=valid_max)]))
#layout.update(dict(annotations=[go.Annotation(text="Max Test", x=test_auc_allep.index(test_max), y=test_max)]))
fig = go.Figure(data=data, layout=layout)
py.iplot(fig, filename='DRNN_Auc')
#url = py.plot(data, filename='some-data')  # gen. online plot
#py.image.save_as(data, 'some-data.png') 