In [1]:
import numpy as np
import scipy.io as sio 
from preprocessing_funcs import get_spikes_with_history
from LSTM import LSTM
from train import train
import torch.nn as nn
import torch 

In [2]:
class LSTM(nn.Module):
    
    def __init__(self, TrainX, TrainY, n_hidden= 10 ,n_layers = 1, lr=0.001): # no dropout for now 
        super().__init__()
        self.n_hidden = n_hidden
        self.lr = lr
        self.n_layers = n_layers
        self.input_dim = TrainX.shape[2]
        self.output_dim = TrainY.shape[1]
        self.seq_len = TrainX.shape[1]
        self.batch_size = TrainX.shape[0]

        """self.net = nn.Sequential(nn.LSTM(input_dim, hidden_dim, n_layers, batch_first=True), 
                         nn.LSTM(hidden_dim, hidden_dim, n_layers, batch_first=True), 
                        nn.Linear(n_hidden, (TrainY.shape[1])))"""
        #lstm layers
        self.lstm = nn.LSTM(self.input_dim, self.n_hidden, self.n_layers, batch_first=False)
        self.lstm2 = nn.LSTM(self.n_hidden, self.n_hidden, self.n_layers, batch_first=False)
        #output layer
        self.fc = nn.Linear(self.n_hidden, self.output_dim)
    
    def binarize_weights(self, ind_layer) : 
        weights = self.net[ind_layer].weight_ih_l[0] 
        for w in weights : 
            if w >= 0 : 
                w = 1
            else : 
                w = -1 
        self.net[ind_layer].weight_ih_l[k]  = weights 

    
    def forward(self, TrainX, hidden):
        ''' Forward pass through the network. 
            These inputs are x, and the hidden/cell state `hidden`. '''
        ## Get the outputs and the new hidden state from the lstm
        TrainX = TrainX.reshape([self.seq_len, self.batch_size, self.input_dim])
        #self.binarize_weights(0)
        r_output, hidden = self.lstm(TrainX, hidden)
        #self.binarize_weights(1)
        r_output, hidden = self.lstm2(r_output, hidden)
        # Stack up LSTM outputs using view
        # reshape the output
        #out = r_output.contiguous().view(-1, self.n_hidden)
       
        ## put x through the fully-connected layer
        #self.binarize_weights(2)
        out = self.fc(r_output)
        return out, hidden
    
    
    def init_hidden(self):
        ''' Initializes hidden state '''
        # Create two new tensors with sizes n_layers x n_hidden,
        # initialized to zero, for hidden state and cell state of LSTM
        hidden_state = torch.randn(n_layers, self.batch_size, self.n_hidden)
        cell_state = torch.randn(n_layers, self.batch_size, self.n_hidden)
        hidden = (hidden_state, cell_state)

        return hidden

In [6]:
def train(TrainX, TrainY, net, lossfunc, optimizer, num_epoch, clip = 5):
    seq_len = TrainX.shape[1]

    for epoch in range(num_epoch):
        # TODO: Step 1 - create torch variables corresponding to features and labels
        

        #x = TrainX.reshape([seq_len, TrainX.shape[0],TrainX.shape[1]])
        x = torch.from_numpy(TrainX).float()
        y = torch.from_numpy(TrainY).float()
        
        # initialize hidden state 
        h = net.init_hidden()
        # TODO: Step 2 - compute model predictions and loss
        pred, h = net(x, h)
        
        #target = torch.reshape(y, (-1,)).long()
       
        loss = lossfunc(pred, y)
        # TODO: Step 3 - do a backward pass and a gradient update step
        optimizer.zero_grad()
        loss.backward()
        # gradient clipping - prevents gradient explosion 
        nn.utils.clip_grad_norm_(net.parameters(), clip)
        optimizer.step()
        
        if epoch == num_epoch-1:
            print(pred[pred>0])
            corrcoef = np.corrcoef(pred[-1,:,:].detach().numpy().reshape((-1,)),y.detach().numpy().reshape((-1,)))
            print ('Epoch [%d/%d], Loss: %.4f' %(epoch+1, num_epoch, loss.item()))
            print ('Correlation coefficient : {corrcoef}'.format(corrcoef=corrcoef))


In [7]:
for Idx_subject in list([10]):#,11,12]): # 3 subjects index 10-12
        for Finger in list([0,1,2,3,4]): # 5 fingers for each subject. 0:thumb, 1:index, 2:middle ...

            #load training data (TrainX: feature vectors, TrainY: labels)
            matData = sio.loadmat('data/BCImoreData_Subj_'+str(Idx_subject)+'_200msLMP.mat')
            TrainX = matData['Data_Feature'].transpose()
            TrainY = matData['SmoothedFinger']
            TrainY = TrainY [:,Finger]
            TrainY = TrainY.reshape(TrainY.shape[0],1)
            #load testing data (TestX: feature vectors, TestY: labels)
            matData = sio.loadmat('data/BCImoreData_Subj_'+str(Idx_subject)+'_200msLMPTest.mat')
            TestX = matData['Data_Feature'].transpose()
            TestY = matData['SmoothedFinger']
            TestY = TestY[:,Finger]
            TestY = TestY.reshape(TestY.shape[0],1)
            
            
            
            # from here, we reconstruct the input by "looking back" a few steps
            bins_before= 20 #How many bins of neural data prior to the output are used for decoding
            bins_current=1 #Whether to use concurrent time bin of neural data
            bins_after=0 #How many bins of neural data after the output are used for decoding
            
            TrainX=get_spikes_with_history(TrainX,bins_before,bins_after,bins_current)

            TrainX, TrainY = TrainX[bins_before:,:,:], TrainY[bins_before:,]
         
            TestX=get_spikes_with_history(TestX,bins_before,bins_after,bins_current)
            TestX, TestY = TestX[bins_before:,:,:], TestY[bins_before:,]
            
            # Now, we reconstructed TrainX/TestX to have a shape (num_of_samples, sequence_length, input_size)
            # You can fit this to the LSTM

            print("run for finger ", Finger)

            n_hidden = 20
            n_layers = 5
            n_epochs =  50 # start small 

            net = LSTM(TrainX, TrainY,  n_hidden, n_layers)

            lossfunc = nn.L1Loss()
            #lossfunc = nn.NLLLoss()
            optimizer = torch.optim.Adamax(net.parameters())
            train(TrainX, TrainY, net, lossfunc, optimizer, n_epochs, clip = 5)
            # Preprocess the data may leed to better performance. e.g. StandardScaler 


              
       
    




run for finger  0
tensor([[[-0.0029],
         [-0.0031],
         [-0.0023],
         ...,
         [-0.0022],
         [-0.0028],
         [-0.0027]],

        [[-0.0073],
         [-0.0074],
         [-0.0067],
         ...,
         [-0.0068],
         [-0.0073],
         [-0.0072]],

        [[-0.0068],
         [-0.0068],
         [-0.0062],
         ...,
         [-0.0064],
         [-0.0069],
         [-0.0066]],

        ...,

        [[-0.0027],
         [-0.0027],
         [-0.0027],
         ...,
         [-0.0027],
         [-0.0027],
         [-0.0027]],

        [[-0.0027],
         [-0.0027],
         [-0.0027],
         ...,
         [-0.0027],
         [-0.0027],
         [-0.0027]],

        [[-0.0027],
         [-0.0027],
         [-0.0027],
         ...,
         [-0.0027],
         [-0.0027],
         [-0.0027]]], grad_fn=<AddBackward0>)
Epoch [50/50], Loss: 0.2104
Correlation coefficient : [[1.      0.01229]
 [0.01229 1.     ]]
run for finger  1


KeyboardInterrupt: 