In [9]:
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
use_cuda = torch.cuda.is_available()
torch.cuda.set_device(0)
import sys, random
import numpy as np
try:
    import cPickle as pickle
except:
    import pickle
from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend
import math

In [54]:
class TPLSTM(nn.Module):

    """
    An implementation of Hochreiter & Schmidhuber:
    'Long-Short Term Memory'
    http://www.bioinf.jku.at/publications/older/2604.pdf
    Special args:
    dropout_method: one of
            * pytorch: default dropout implementation
            * gal: uses GalLSTM's dropout
            * moon: uses MoonLSTM's dropout
            * semeniuta: uses SemeniutaLSTM's dropout
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super(TPLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.weight_ih = Parameter(torch.Tensor(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.Tensor(4 * hidden_size, hidden_size))
        self.W_decomp = Parameter(torch.Tensor(hidden_size, hidden_size))
        if bias:
            self.bias_ih = Parameter(torch.Tensor(4 * hidden_size))
            self.bias_hh = Parameter(torch.Tensor(4 * hidden_size))
            self.b_decomp = Parameter(torch.Tensor(hidden_size))

        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)
            self.register_parameter('b_decomp', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, input, hx):
        return self.TPLSTMCell(
            input, hx,
            self.weight_ih, self.weight_hh,self.W_decomp,
            self.bias_ih, self.bias_hh,self.b_decomp)
    
    
    def TPLSTMCell(input, hidden, w_ih, w_hh,w_decomp, b_ih=None, b_hh=None,b_decomp=None):
        if input.is_cuda:
            igates = F.linear(input, w_ih)
            hgates = F.linear(hidden[0], w_hh)
            state = fusedBackend.LSTMFused.apply
            return state(igates, hgates, hidden[1]) if b_ih is None else state(igates, hgates, hidden[1], b_ih, b_hh)

        hx, cx = hidden

        ## Map elapse time in days or months
        T = self.map_elapse_time(t) #----> need to define that

        # Decompose the previous cell if there is a elapse time
        C_ST = F.tanh(F.Linear(cx, w_decomp, b_decomp) ) 
        C_ST_dis = torch.mm(T, C_ST)
            # if T is 0, then the weight is one

        cpt = cx - C_ST + C_ST_dis
        gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)

        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = F.sigmoid(ingate)
        forgetgate = F.sigmoid(forgetgate)
        cellgate = F.tanh(cellgate)
        outgate = F.sigmoid(outgate)
            #C_ST = F.tanh(torch.mm(prev_cell, self.W_decomp) + self.b_decomp)  
            #C_ST_dis = torch.mm(T, C_ST)
            ## if T is 0, then the weight is one
            #prev_cell = prev_cell - C_ST + C_ST_dis

        #cy = (forgetgate * cx) + (ingate * cellgate)
        ct = (forgetgate * cpt) + (ingate * cellgate)         ## Current Memory cell with time
        ht = outgate * F.tanh(ct)

        return ht, ct
    
    def map_elapse_time(self, t):

        c1 = torch.constant(1, dtype=float32)
        c2 = torch.constant(2.7183, dtype=float32)
        T = torch.div(c1, torch.log(t + c2))#, name='Log_elapse_time')
        Ones = torch.ones([1, self.hidden_dim], dtype=float32)
        T = torch.matmul(T, Ones)

        return T

In [80]:
class EHR_TLSTM(nn.Module):
    def __init__(self, input_size,embed_dim, hidden_size, n_layers=1,dropout_r=0.1,cell_type='TLSTM'):#,bi=False , preTrainEmb=''):
        super(EHR_TLSTM,self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embed_dim = embed_dim
        self.dropout_r = dropout_r
        self.cell_type = cell_type
        self.preTrainEmb=preTrainEmb=''
        bi=False
        if bi: self.bi=2 
        else: self.bi=1
              
        if len(self.preTrainEmb)>0:
            emb_t= torch.FloatTensor(np.asmatrix(self.preTrainEmb))
            self.embed= nn.Embedding.from_pretrained(emb_t)#,freeze=False)  
        else:
            self.embed= nn.Embedding(input_size, self.embed_dim,padding_idx=0)
        
        if self.cell_type == "GRU":
            cell = nn.GRU
        elif self.cell_type == "RNN":
            cell = nn.RNN
        elif self.cell_type == "LSTM":
            cell = nn.LSTM
        elif self.cell_type == "BNLSTM":
            cell = bnlstm.LSTM    
        elif self.cell_type == "TLSTM":
            cell = TPLSTM 
        else:
            raise NotImplementedError
      
        if self.cell_type == "BNLSTM":
            self.rnn_c = cell(bnlstm.BNLSTMCell, self.embed_dim, hidden_size,num_layers=n_layers,use_bias=False, dropout= dropout_r,max_length=30)
        elif self.cell_type == "TLSTM":
            self.bi=1 
            #self.rnn_c = cell(self.embed_dim, 1, hidden_size, hidden_size/2)
            self.rnn_c = cell(self.embed_dim, hidden_size)

        else:
            self.rnn_c = cell(self.embed_dim, hidden_size,num_layers=n_layers, dropout= dropout_r , bidirectional=bi  )
        
        self.out = nn.Linear(self.hidden_size*self.bi,1)
        self.sigmoid = nn.Sigmoid()

        
    def EmbedPatient_MB(self, input): # x is a ehr_seq_tensor
        
        mb=[]
        lbt=[]
        seq_l=[]
        self.bsize=len(input)
        lp= len(max(input, key=lambda xmb: len(xmb[1]))[1])
        self.max_len_bn=lp
        #print (lp)
        llv=0
        for x in input:
            lv= len(max(x[1], key=lambda xmb: len(xmb)))
            if llv< lv:
                llv=lv
        #print (llv)
        for pt in input:
            label, ehr_seq_l = pt
            lpx=len(ehr_seq_l)
            seq_l.append(lpx)
            label_tensor = Variable(torch.FloatTensor([[float(label)]]))
            if use_cuda:
                label_tensor = label_tensor.cuda()
            lbt.append(label_tensor)
            ml=(len(max(ehr_seq_l, key=len)))
            ehr_seq_tl=[]
            for ehr_seq in ehr_seq_l: 
                pd=(0, llv-len(ehr_seq))
                result = F.pad(torch.from_numpy(np.asarray(ehr_seq,dtype=int)).type(torch.cuda.LongTensor),pd,"constant", 0)
                if use_cuda:
                    result.cuda()
                ehr_seq_tl.append(result)
            ehr_seq_t= Variable(torch.stack(ehr_seq_tl,0))     
            lpp= lp-lpx
            zp= nn.ZeroPad2d((0,0,lpp,0))
            ehr_seq_t= zp(ehr_seq_t)
            mb.append(ehr_seq_t)
                
        mb_t= Variable(torch.stack(mb,0)) 
        if use_cuda:
            mb_t.cuda()
        embedded = self.embed(mb_t)
        embedded = torch.sum(embedded, dim=2) 
        lbt_t= Variable(torch.stack(lbt,0)) 
        return embedded, lbt_t,seq_l
    
    def init_hidden(self):
        
        h_0 = Variable(torch.rand(self.n_layers*self.bi,self.bsize, self.hidden_size))
        if use_cuda:
            h_0= h_0.cuda()
        if self.cell_type == "LSTM"or self.cell_type == "TLSTM":
            result = (h_0,h_0)
        else: 
            result = h_0

        return result
    
    def forward(self, input):
        
        x_in , lt ,x_lens = self.EmbedPatient_MB(input)
        x_in = x_in.permute(1,0,2) ## QRNN not support batch first
        #x_inp = nn.utils.rnn.pack_padded_sequence(x_in,x_lens,batch_first=True)
        h_0 = self.init_hidden()
        output, hidden = self.rnn_c(x_in,h_0) 
        if self.cell_type == "LSTM" or self.cell_type == "TLSTM":
            hidden=hidden[0]
        if self.bi==2:
            output = self.sigmoid(self.out(torch.cat((hidden[-2],hidden[-1]),1)))
        #elif self.cell_type == "TLSTM":
            #output = hidden
        else:
            output = self.sigmoid(self.out(hidden[-1]))
        return output.squeeze(), lt.squeeze()

In [81]:
model = EHR_TLSTM(input_size=16000, hidden_size=128 ,embed_dim=256, dropout_r=0, cell_type='TLSTM', n_layers=1)
if use_cuda:
    model = model.cuda()


In [82]:
train_sl = pickle.load(open('pdata_3hosp/h143_train', 'rb'), encoding='bytes')
valid_sl = pickle.load(open('pdata_3hosp/h143_valid', 'rb'), encoding='bytes')
test_sl = pickle.load(open('pdata_3hosp/h143_test', 'rb'), encoding='bytes')

In [83]:
def train (tmodel,mini_batch, criterion, optimizer):  
    
    tmodel.train()
    tmodel.zero_grad()
    output , label_tensor = tmodel(mini_batch)
    loss = criterion(output, label_tensor)
    loss.backward()
    optimizer.step()
   
    return output, loss.item()

In [84]:
# training all samples in random order
import time
import math

def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

In [85]:
def run_model_train(tmodel,dataset,batch_size,learning_rate = 0.01, l2=1e-04,epsl=1e-06 ):
    
    #optimizer = optim.SGD(tmodel.parameters(), lr=learning_rate)#, weight_decay=l2)
    #optimizer = optim.Adadelta(tmodel.parameters(), lr=learning_rate, weight_decay=l2)
    #optimizer = optim.ASGD(tmodel.parameters(), lr=learning_rate, weight_decay=l2 )
    #optimizer = optim.SparseAdam (tmodel.parameters(),lr=learning_rate) #'''lr=learning_rate,''' 
    optimizer = optim.Adagrad (tmodel.parameters(),lr=learning_rate, weight_decay=l2) #'''lr=learning_rate,''' 
    #optimizer = optim.Adamax(tmodel.parameters(), lr=learning_rate, weight_decay=l2 ,eps=epsl)
    #optimizer = optim.Adamax(filter(lambda p: p.requires_grad, tmodel.parameters()), lr=learning_rate, weight_decay=l2 ,eps=epsl) ### Beta defaults (0.9, 0.999)
    #optimizer = optim.RMSprop (tmodel.parameters(),lr=learning_rate, weight_decay=l2 ,eps=epsl)
    #optimizer = optim.Adam(tmodel.parameters(), lr=learning_rate, weight_decay=learning_rate)
    dataset.sort(key=lambda pt:len(pt[1]),reverse=True) 
    # Keep track of losses for plotting
    current_loss = 0
    all_losses = []
    print_every = 10#int(batch_size/2)
    plot_every = 5
    iter=0
    n_batches = int(np.ceil(int(len(dataset)) / int(batch_size)))
    start = time.time()

    for index in random.sample(range(n_batches), n_batches):
            batch = dataset[index*batch_size:(index+1)*batch_size]
            output, loss = train(tmodel,batch, criterion = nn.BCELoss(), optimizer = optimizer)
            current_loss += loss
            iter +=1
            # Add current loss avg to list of losses
            if iter % plot_every == 0:
                all_losses.append(current_loss / plot_every)
                current_loss = 0
                
    return current_loss,all_losses


In [86]:
def calculate_auc(test_model, dataset, batch_size=200):
    test_model.eval()
    dataset.sort(key=lambda pt:len(pt[1]),reverse=True) 
    n_batches = int(np.ceil(int(len(dataset)) / int(batch_size)))
    labelVec =[]
    y_hat= []
    
    for index in range(n_batches):
            batch = dataset[index*batch_size:(index+1)*batch_size]
            output, label_t = test_model(batch)
            y_hat.extend(output.cpu().data.view(-1).numpy())
            labelVec.extend(label_t.cpu().data.view(-1).numpy())
    auc = roc_auc_score(labelVec, y_hat)
    
    return auc

In [87]:
epochs=100
batch_size=128
current_loss_l=[]
all_losses_l=[]
train_auc_allep =[]
valid_auc_allep =[]
test_auc_allep=[]
bestValidAuc = 0.0
bestTestAuc = 0.0
bestValidEpoch = 0

                    


### Run Epochs    
for ep in range(epochs):
    
    #print (model.embed.weight.data[135] )
    start = time.time()
    current_loss_la,all_losses_la = run_model_train(model,train_sl,batch_size)
    train_time = timeSince(start)
    eval_start = time.time()
    train_auc = calculate_auc(model,train_sl,batch_size)
    test_auc = calculate_auc(model,test_sl,batch_size)
    valid_auc = calculate_auc(model,valid_sl,batch_size)
    eval_time = timeSince(eval_start)
    all_losses_l.append (all_losses_la)
    avg_loss = np.mean(all_losses_la)
    train_auc_allep.append(train_auc)
    valid_auc_allep.append(valid_auc)
    test_auc_allep.append(test_auc)
    current_loss_l.append(current_loss_la)
    print ("Epoch ", ep," Train_auc :", train_auc, " , Valid_auc : ", valid_auc, " ,& Test_auc : " , test_auc," Avg Loss: ", avg_loss, 'Train Time (%s) Eval Time (%s)'%(train_time,eval_time) )
     
    if valid_auc > bestValidAuc: 
        bestValidAuc = valid_auc
        bestValidEpoch = ep
        bestTestAuc = test_auc
        best_model = model
        torch.save(best_model, bmodel_pth)
        torch.save(best_model.state_dict(), bmodel_st)
    if ep - bestValidEpoch >12: break
            
print ('bestValidAuc %f has a TestAuc of %f at epoch %d ' % (bestValidAuc, bestTestAuc, bestValidEpoch))



TypeError: TPLSTMCell() takes from 5 to 8 positional arguments but 9 were given