In [13]:
import numpy as np
import scipy.io as sio 

import torch.nn as nn
import torch 
from torch.utils.data import Dataset, DataLoader

from preprocessing_funcs import get_spikes_with_history, standardize, remove_outliers
from model import LSTM
from trainer import train
from evaluator import test
from FingerDataset import FingerDataset
from Loss import corr_coeff, corr_coeff_loss

In [19]:
class LSTM(nn.Module):
    """Three stacked ``nn.LSTM`` modules followed by a per-time-step linear read-out.

    Parameters
    ----------
    input_dim : int
        Number of input features per time step.
    output_dim : int
        Number of outputs produced by the linear layer at every time step.
    batch_size : int
        Default batch size used by :meth:`init_hidden`.
    seq_len : int
        Expected sequence length (stored for reference; :meth:`forward` uses
        whatever length the input actually has).
    n_hidden : int
        Hidden size shared by all three LSTM modules.
    n_layers : int
        ``num_layers`` passed to *each* of the three ``nn.LSTM`` modules, so the
        network contains ``3 * n_layers`` recurrent layers in total.
    """

    def __init__(self, input_dim, output_dim, batch_size, seq_len, n_hidden=10, n_layers=1):  # no dropout for now
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.n_hidden = n_hidden
        self.n_layers = n_layers

        # Recurrent stack; every module consumes (seq_len, batch, features).
        self.lstm = nn.LSTM(self.input_dim, self.n_hidden, self.n_layers, batch_first=False)
        self.lstm2 = nn.LSTM(self.n_hidden, self.n_hidden, self.n_layers, batch_first=False)
        self.lstm3 = nn.LSTM(self.n_hidden, self.n_hidden, self.n_layers, batch_first=False)
        # Output layer, applied independently at every time step.
        self.fc = nn.Linear(self.n_hidden, self.output_dim)
        # Not used by forward(); kept so existing attribute access keeps working.
        self.act = nn.ReLU()

    def binarize_weights(self, ind_layer):
        """Replace the input-to-hidden weights of one LSTM module by their sign.

        Parameters
        ----------
        ind_layer : int
            Index into ``(self.lstm, self.lstm2, self.lstm3)``.

        Every ``weight_ih_l{k}`` of the selected module becomes +1 where the
        weight is >= 0 and -1 elsewhere (zeros therefore map to +1).
        """
        layer = (self.lstm, self.lstm2, self.lstm3)[ind_layer]
        # Parameters are leaf tensors: mutate them in place outside autograd.
        with torch.no_grad():
            for k in range(layer.num_layers):
                weight = getattr(layer, 'weight_ih_l{}'.format(k))
                ones = torch.ones_like(weight)
                weight.copy_(torch.where(weight >= 0, ones, -ones))

    def forward(self, input, hidden):
        """Forward pass through the network.

        Parameters
        ----------
        input : Tensor
            Batch-first sequences, presumably (batch, seq_len, input_dim).
        hidden : tuple(Tensor, Tensor)
            Initial ``(h_0, c_0)`` for the first LSTM module, e.g. from
            :meth:`init_hidden`.

        Returns
        -------
        out : Tensor
            Shape (seq_len, batch, output_dim) -- sequence-first.
        hidden : tuple(Tensor, Tensor)
            Final ``(h_n, c_n)`` of the last LSTM module.
        """
        # The LSTMs are sequence-first, so swap the batch and time axes. A
        # transpose replaces the former step-by-step copy, works for any batch
        # size and stays on the input's device. The cast keeps the former
        # implicit conversion to the model's dtype.
        seq_first = input.transpose(0, 1).to(dtype=self.fc.weight.dtype).contiguous()

        r_output, hidden = self.lstm(seq_first, hidden)
        # Each module starts from the final (h, c) of the previous one.
        r_output, hidden = self.lstm2(r_output, hidden)
        r_output, hidden = self.lstm3(r_output, hidden)
        out = self.fc(r_output)
        return out, hidden

    def init_hidden(self, batch_size=None):
        """Return zero-initialised ``(hidden_state, cell_state)``.

        Each tensor has shape (n_layers, batch_size, n_hidden) and matches the
        dtype/device of the model parameters. ``batch_size`` defaults to the
        value given at construction time.
        """
        if batch_size is None:
            batch_size = self.batch_size
        reference = self.fc.weight
        # Use the instance's layer count, not a notebook-level global.
        hidden_state = torch.zeros(self.n_layers, batch_size, self.n_hidden,
                                   dtype=reference.dtype, device=reference.device)
        cell_state = torch.zeros_like(hidden_state)
        return (hidden_state, cell_state)

In [20]:
def batch_train(batch, net, lossfunc, optimizer, clip = 5):
    """Perform a single optimisation step on one mini-batch.

    Parameters
    ----------
    batch : dict
        Mini-batch with ``'input'`` and ``'target'`` entries.
    net : module
        Model exposing ``init_hidden()`` and ``forward(x, hidden)``.
    lossfunc : callable
        NOTE(review): currently ignored -- the loss is always
        ``corr_coeff_loss`` on the last time step's prediction.
    optimizer : torch.optim.Optimizer
        Optimizer over ``net``'s parameters.
    clip : float
        Maximum global gradient norm.

    Returns
    -------
    tuple
        ``(pred, target, loss)``: the model output for every time step, the
        batch target exactly as delivered (not cast to float), and the loss.
    """
    features = batch['input'].float()
    raw_target = batch['target']
    labels = raw_target.float()

    # A new initial state from init_hidden() is used for every batch.
    initial_state = net.init_hidden()
    pred, _ = net(features, initial_state)

    # Only the prediction at the final time step is scored.
    final_step = pred[-1, :, :]
    loss = corr_coeff_loss(final_step, labels)

    optimizer.zero_grad()
    loss.backward()
    # Rescale gradients whose global norm exceeds `clip` (exploding-gradient guard).
    nn.utils.clip_grad_norm_(net.parameters(), clip)
    optimizer.step()

    return pred, raw_target, loss
        


In [21]:
def train(net, dataset, num_epoch=10, batch_size=32, print_every=10):
    """Train ``net`` on ``dataset`` with SGD + momentum and report progress.

    Parameters
    ----------
    net : nn.Module
        Model exposing ``init_hidden()`` and ``forward(x, hidden)``.
    dataset : Dataset
        Map-style dataset yielding dicts with ``'input'`` and ``'target'``.
    num_epoch : int
        Number of passes over the data.
    batch_size : int
        Mini-batch size; an incomplete trailing batch is dropped.
    print_every : int
        Report the mean batch loss and the correlation coefficient every
        ``print_every`` epochs. The final epoch is always reported; a falsy
        value reports only the final epoch.

    Raises
    ------
    ValueError
        If ``dataset`` holds fewer than ``batch_size`` samples (no batches).
    """
    dataloader = DataLoader(dataset, batch_size, drop_last=True)
    if len(dataloader) == 0:
        raise ValueError(
            'dataset has {} samples, fewer than batch_size={}; no batch to train on'.format(
                len(dataset), batch_size))

    # NOTE(review): batch_train currently ignores this and optimises corr_coeff_loss.
    lossfunc = nn.L1Loss()
    #optimizer = torch.optim.Adamax(net.parameters(),lr=0.01)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.02, momentum=0.9)
    net.train()

    for epoch in range(num_epoch):
        epoch_num = epoch + 1
        report = bool(print_every and epoch_num % print_every == 0) or epoch_num == num_epoch

        losses = []
        last_step_preds = []
        targets = []
        for batch in dataloader:
            pred, y, loss = batch_train(batch, net, lossfunc, optimizer)
            losses.append(loss.item())
            if report:
                # Keep only the last time step, detached. Accumulating the raw
                # `pred` would keep every batch's autograd graph alive all epoch.
                last_step_preds.append(pred[-1].detach())
                targets.append(y)

        if report:
            preds = torch.cat(last_step_preds, dim=0)
            ys = torch.cat(targets, dim=0)
            corrcoef = corr_coeff(preds, ys).item()
            print('Epoch [%d/%d], Average Batch Loss: %.4f' % (epoch_num, num_epoch, np.mean(losses)))
            print('Correlation Coefficient : {corrcoef}'.format(corrcoef=corrcoef))


In [22]:
def test(net, dataset, batch_size=32):
    dataloader = DataLoader(dataset, batch_size, drop_last=True)

    for batch_idx, batch in enumerate(dataloader):
        with torch.no_grad():
            input, target = batch['input'], batch['target']
            h = net.init_hidden()
            pred, _ = net(input.float(), h)
            target = batch['target']
            print(pred)
    

In [23]:
# Experiment configuration.
SUBJECTS = [10]            # subjects with data: 10, 11, 12
FINGERS = [0, 1, 2, 3, 4]  # 5 fingers per subject. 0: thumb, 1: index, 2: middle, ...
SEED = 0                   # makes weight initialisation reproducible for each finger

# History window used to build each input sequence.
bins_before = 20   # how many bins of neural data prior to the output are used for decoding
bins_current = 1   # whether to use the concurrent time bin of neural data
bins_after = 0     # how many bins of neural data after the output are used for decoding

for Idx_subject in SUBJECTS:
    # The features do not depend on the finger, so load and preprocess them once
    # per subject instead of once per finger.
    # Training data (TrainX: feature vectors, TrainY_all: labels for all fingers).
    matData = sio.loadmat('data/BCImoreData_Subj_' + str(Idx_subject) + '_200msLMP.mat')
    TrainX_raw = matData['Data_Feature'].transpose()
    TrainY_all = matData['SmoothedFinger']
    # Testing data (TestX: feature vectors, TestY_all: labels for all fingers).
    matData = sio.loadmat('data/BCImoreData_Subj_' + str(Idx_subject) + '_200msLMPTest.mat')
    TestX_raw = matData['Data_Feature'].transpose()
    TestY_all = matData['SmoothedFinger']

    # Standardize and remove outliers.
    # NOTE(review): the test set is standardized with its own statistics, not the
    # training set's -- confirm this is intended.
    TrainX = remove_outliers(standardize(TrainX_raw))
    TestX = standardize(TestX_raw)

    # Reconstruct the input by "looking back" bins_before steps, then drop the first
    # bins_before samples (presumably those without a full history window).
    TrainX = get_spikes_with_history(TrainX, bins_before, bins_after, bins_current)[bins_before:, :, :]
    TestX = get_spikes_with_history(TestX, bins_before, bins_after, bins_current)[bins_before:, :, :]
    # TrainX/TestX now have shape (num_of_samples, sequence_length, input_size).
    print(TrainX.shape)

    input_dim = TrainX.shape[2]
    seq_len = TrainX.shape[1]
    batch_size = 32
    n_hidden = 50
    n_layers = 10  # num_layers of *each* nn.LSTM module in the model

    for Finger in FINGERS:
        # Labels for this finger, aligned with the trimmed inputs, as column vectors.
        TrainY = TrainY_all[bins_before:, Finger].reshape(-1, 1)
        TestY = TestY_all[bins_before:, Finger].reshape(-1, 1)

        print("run for finger ", Finger)
        output_dim = TrainY.shape[1]

        torch.manual_seed(SEED)
        net = LSTM(input_dim, output_dim, batch_size, seq_len, n_hidden, n_layers)

        train_dataset = FingerDataset(TrainX, TrainY)
        train(net, train_dataset, num_epoch=100, batch_size=batch_size)

        # TODO: evaluate on the held-out set once training is satisfactory:
        # test_dataset = FingerDataset(TestX, TestY)
        # test(net, test_dataset, batch_size=batch_size)

(9976, 21, 558)
run for finger  0
