In [1]:
import pandas as pd
import numpy as np
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.metrics import roc_auc_score
from clockwork_helperfunc import *
from clockwork_helperfunc import evaluation
import clockwork_helperfunc 
from imp import reload  
reload(clockwork_helperfunc)
import time

#configuration
batch_size = 25
num_epochs = 50
number_units = 7

### Data

In [2]:
training = pd.read_pickle('/Users/leilei/Documents/DS1005/CW/truncate_train_data.p')
val = pd.read_pickle('/Users/leilei/Documents/DS1005/CW/truncate_valid_data.p')

In [3]:
patient_feature_train, label_train, label_time_train = downsample(training, proportional = 2)

In [4]:
sum(label_time_train), sum(label_train*label_time_train)

(519228, 232732)

In [5]:
patient_feature_val = val[0] #before cleaning, original data
label_val = val[1]
label_time_val = val[2]

In [6]:
Counter(label_train), Counter(label_val)

(Counter({0: 3535, 1: 1790}), Counter({0: 4089, 1: 511}))

In [7]:
len(patient_feature_train)

5325

In [8]:
5325/25

213.0

### Model

In [9]:
### Model
#forward
#cell_class, step
class Clock_NN(nn.Module):
    def __init__(self, scale,batch_size, group_size = 1, activation_fun =nn.Tanh, mean = 0, std = 1, input_dim = 48,mode = 'shift'):
        super(Clock_NN, self).__init__()
        '''
        scale: the updating frequency, a list. [1,2,4,8,16,32]
        batch_size: the size of batch
        group_size: the number of nodes in each scale, default is 1.
        activation_function
        mean: the mean of Gaussian distribution for initialize weights for hidden layer
        std: the standard devation of the Gaussian distribution for initialize weights for hidden layer
        input_dim: the feature dimension of each time step
        '''
        self.scale = scale
        self.group_size = group_size
        self.batch_size = batch_size
        self.mode = mode
        if mode == 'original':
            self.num_units = len(self.scale)*self.group_size
            self.index_li = {self.scale[i]: i for i in range(len(self.scale))}
        elif mode == 'shift':
            self.num_units = sum(self.scale)*self.group_size
            self.index_li = {i:i-1 for i in self.scale}
            
        self.class_dim = 2
        self.input_dim = input_dim
        self.linear_h = nn.Linear(self.num_units,self.num_units)
        self.linear_o = nn.Linear(self.num_units,self.class_dim)
        self.linear_i = nn.Linear(self.input_dim, self.num_units)
        self.activation_fun = activation_fun
        self.connect = torch.from_numpy(block_tri(self.group_size, self.scale, self.num_units, self.mode)).float()
        self.time_step = 0
        
        self.initial_weights(mean, std)
        
        #the connectivity, when we disabled the weight, this should not change
        self.linear_h.weight.data = self.linear_h.weight.data*self.connect#here needs transpose since previously left multiplication, activate mtrx doesn't need as rewrite and select cols.
        #self.linear_i.weight.data = self.linear_i.weight.data
        
    def forward(self, sequence, hidden):#depends on what passed for model.train(), to be filled)
        '''
        sequence: batch  x timestep x number_feature matrix
        hidden: should be h0
        '''     
        #sequence = sequence.view(48,-1)when this is only one batch
        hidden_output = []
        length = sequence.size()[1]
        logit = []
        for i in range(length):
            #print('this is the timestep ' + str(self.time_step))
            self.time_step += 1
            #backwards, want discharge/dead time aligns
            #print(sequence[:,:,-i].size())#would be batch*48
            hidden = self.CW_RNN_Cell(sequence[:,i,:].contiguous(), hidden)
            hidden_output.append(hidden)#become batch_size x hidden_dim
            out = self.linear_o(hidden)
            logit.append(F.log_softmax(out))
        return hidden_output, logit
            
                
    def CW_RNN_Cell(self, x_input, hidden):
        '''
        x_input: number_feature x batch vector, representing one time stamp
        hidden: output of the last cell, should be hidden_dim(i.e. num_units) x batch
        '''
        #which time bloack to change
        activate = activate_index(self.time_step, self.num_units, self.group_size, self.scale,self.index_li,batch_size,self.mode, self.input_dim)
        activate_re = torch.from_numpy(np.ones((self.batch_size,self.num_units))).float() - activate

        hidden_next = self.linear_h(hidden) + self.linear_i(x_input) #should be batch_size x hidden_dim       
        hidden_next.data = activate*hidden_next.data + activate_re*hidden.data
        hidden_next = self.activation_fun(hidden_next)
        return hidden_next

    def init_hidden(self):
        h0 = Variable(torch.zeros(self.batch_size,self.num_units))
        return h0
        
    def initial_weights(self, mean, std):
        lin_layers = [self.linear_h, self.linear_o, self.linear_i]
        for layer in lin_layers:
            layer.weight.data.normal_(mean, std**2)
            layer.bias.data.fill_(0) 

In [None]:
### Training original
model = Clock_NN([1,2,4,8,16], batch_size, group_size = 2, activation_fun = F.tanh, mean = 0, std = 0.1, input_dim = 40, mode = 'shift')

loss = torch.nn.NLLLoss(ignore_index=-1)  
optimizer = torch.optim.Adam(model.parameters(), lr=0.000005)
accuracy_list = []
train_loader, validation_loader = reload_data(batch_size, patient_feature_train, label_train, label_time_train,patient_feature_val, label_val, label_time_val)
start = time.time()
for epoch in range(10):
    for step, (data, label,label_time_list) in enumerate(train_loader):
        data, label = Variable(data), Variable(label)
        model.zero_grad()
        hidden= model.init_hidden()
        model.time_step = 0
        hidden, output = model(data, hidden)
        #now get a list of hidden and a list of outputs
        label = label.transpose(0,1).contiguous().view(-1) 
        #should be flatten, batch_size x hidden. transpose due to below order, was batch, seq => follow up 2 down. get size batch*seq          
        output = torch.stack(output, dim=1).view(-1, 2) 
        #print(output[-1])
        lossy = loss(output, label)
        lossy.backward()
        model.linear_h.weight.grad.data = model.linear_h.weight.grad.data*model.connect
        optimizer.step()
                #for now debugging, to be removed
    #print('----------------weight--------------------------------------------------------')
    #print(model.linear_h.weight.data)
    #print(model.linear_h.bias.data)
        if step%100 ==0 :
            print("Epoch: {}; Loss: {}".format(epoch, lossy.data[0]))
                #print('accuracy_on_training: {}'.format(evaluation(train_loader))) 
            acc0, acc1, val_acc, auc = evaluation(validation_loader, model)
            print('accuracy_on_validation: {}, the acc for LIVE is {}, the acc for DEAD is {}'.format(val_acc, acc0, acc1)) 
            print('the auc is ' + str(auc))
    #accuracy_list.append(val_acc)
    

    '''
    if ((epoch > 5) and ((accuracy_list[-1] < (accuracy_list[-2] - 0.01)) or (accuracy_list[-1] < (accuracy_list[-3] - 0.01)))):
        print("early stop, accuracy = ", accuracy_list[-2])
        break
    '''
end = time.time()             

Epoch: 0; Loss: 0.6930655837059021
accuracy_on_validation: 0.40543478260869564, the acc for LIVE is 0.3846906334067009, the acc for DEAD is 0.5714285714285714
the auc is 0.478059602418
Epoch: 0; Loss: 0.6932099461555481
accuracy_on_validation: 0.43478260869565216, the acc for LIVE is 0.43237955490339935, the acc for DEAD is 0.45401174168297453
the auc is 0.443195648293
Epoch: 0; Loss: 0.6931344270706177


In [None]:
print(end - start)