In [215]:
import numpy as np
import scipy.io as sio 
from preprocessing_funcs import get_spikes_with_history
#from LSTM import LSTM
#from train import train
import torch.nn as nn
import torch 

In [216]:
class ThresholdPruning(prune.BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, tensor, default_mask):
        return torch.abs(tensor) > self.threshold

In [217]:
class LSTM(nn.Module):
    
    def __init__(self, TrainX, TrainY, n_hidden= 10 ,n_layers = 1, lr=0.001): # no dropout for now 
        super().__init__()
        self.n_hidden = n_hidden
        self.lr = lr
        self.n_layers = n_layers
        self.input_dim = TrainX.shape[2]
        self.output_dim = TrainY.shape[1]
        self.seq_len = TrainX.shape[1]
        self.batch_size = TrainX.shape[0]

        """self.net = nn.Sequential(nn.LSTM(input_dim, hidden_dim, n_layers, batch_first=True), 
                         nn.LSTM(hidden_dim, hidden_dim, n_layers, batch_first=True), 
                        nn.Linear(n_hidden, (TrainY.shape[1])))"""
        #lstm layers
        self.lstm = nn.LSTM(self.input_dim, self.n_hidden, self.n_layers, batch_first=False)
        self.lstm2 = nn.LSTM(self.n_hidden, self.n_hidden, self.n_layers, batch_first=False)
        
        
        #output layer
        self.fc = nn.Linear(self.n_hidden, self.output_dim)
    
    def binarize_weights(self, ind_layer) : 
        net = self.lstm2
        if ind_layer == 0 :
            net = self.lstm
        weights = net.weight_ih_l0
        for idx, w_ in enumerate(weights) : 
            with torch.no_grad() :
                #net.weight_ih_l0[idx] = net.weight_ih_l0[idx].clone()
                arr = net.weight_ih_l0[idx].numpy()
                arr[arr<0] = -1.
                arr[arr>= 0] = 1.

    def pruning(self) : 
        module = net.lstm
        module = prune.ln_structured(module, name="weight_ih_l0", n = 1, dim = 0,  amount=0.15)
        module = net.lstm2
        module = prune.ln_structured(module, name="weight_ih_l0", n=1, dim = 0,  amount=0.15)
        #module = net.fc
        #module = prune.l1_unstructured(module, name="weight", amount=0.2)
    
    def threshold_pruning(self) :
        parameters_to_prune = ((self.lstm, "weight_ih_l0"), (self.lstm2, "weight_ih_l0"), (self.fc, "weight"))
        prune.global_unstructured(parameters_to_prune, pruning_method=ThresholdPruning, threshold= threshold)
    
    def forward(self, TrainX, hidden):
        ''' Forward pass through the network. 
            These inputs are x, and the hidden/cell state `hidden`. '''
        ## Get the outputs and the new hidden state from the lstm
        TrainX = TrainX.reshape([self.seq_len, self.batch_size, self.input_dim])
        #self.binarize_weights(0)
        r_output, hidden = self.lstm(TrainX, hidden)
        
        # pruning
        #self.pruning(1)
        
        #self.binarize_weights(1)
        r_output, hidden = self.lstm2(r_output, hidden)
        
        # pruning
        #self.pruning(2)
        
        # Stack up LSTM outputs using view
        # reshape the output
        #out = r_output.contiguous().view(-1, self.n_hidden)
       
        ## put x through the fully-connected layer
        #self.binarize_weights(2)
        out = self.fc(r_output)
        return out, hidden
    
    
    def init_hidden(self):
        ''' Initializes hidden state '''
        # Create two new tensors with sizes n_layers x n_hidden,
        # initialized to zero, for hidden state and cell state of LSTM
        hidden_state = torch.randn(n_layers, self.batch_size, self.n_hidden)
        cell_state = torch.randn(n_layers, self.batch_size, self.n_hidden)
        hidden = (hidden_state, cell_state)

        return hidden

In [238]:
def train(TrainX, TrainY, net, lossfunc, optimizer, num_epoch, clip = 5, pruning_weights= True, threshold = None):
    seq_len = TrainX.shape[1]

    for epoch in range(num_epoch):
        # TODO: Step 1 - create torch variables corresponding to features and labels
        

        #x = TrainX.reshape([seq_len, TrainX.shape[0],TrainX.shape[1]])
        x = torch.from_numpy(TrainX).float()
        y = torch.from_numpy(TrainY).float()
        
        # initialize hidden state 
        h = net.init_hidden()
        # TODO: Step 2 - compute model predictions and loss
        pred, h = net(x, h)
        
        #target = torch.reshape(y, (-1,)).long()

        #loss = lossfunc(pred[-1,:,:], y)
        #The main idea is to calculate the correlation as dot product between two features. Larger result indicates more similar
        loss = corr_coeff(pred[-1,:,:], y)
        
        # TODO: Step 3 - do a backward pass and a gradient update step
        optimizer.zero_grad()
        loss.backward()
        # gradient clipping - prevents gradient explosion 
        nn.utils.clip_grad_norm_(net.parameters(), clip)
        
        # pruning after the foward pass 
        if(pruning_weights) :
            if (threshold == None) : 
            # with pre-implemented prune method 
                net.pruning()
            # with the threshold class 
            else : 
                parameters_to_prune = ((net.lstm, "weight_ih_l0"), (net.lstm2, "weight_ih_l0"))
                prune.global_unstructured(parameters_to_prune, pruning_method=ThresholdPruning, threshold= threshold)
                print(net.lstm.weight_ih_l0.shape)
                print(net.lstm2.weight_ih_l0.shape)
                print(net.fc.weight.shape)
        optimizer.step()
        
        
        if epoch == num_epoch-1:
            print(pred[pred>0])
            corrcoef = np.corrcoef(pred[-1,:,:].detach().numpy().reshape((-1,)),y.detach().numpy().reshape((-1,)))
            print ('Epoch [%d/%d], Loss: %.4f' %(epoch+1, num_epoch, loss.item()))
            print ('Correlation coefficient : {corrcoef}'.format(corrcoef=corrcoef))
            
            corr_coef.append(np.corrcoef(pred[-1,:,:].detach().numpy().reshape((-1,)),y.detach().numpy().reshape((-1,)))[0,1])
            #print('weights', net.lstm.weight_ih_l0)
            # count the number of zeros 
            number_weights_pruned = np.count_nonzero(net.lstm.weight_ih_l0.detach().numpy() ==0)
            print('number of pruned lstm: ', number_weights_pruned) 
            print('total number : ', len(net.lstm.weight_ih_l0.detach().numpy().reshape((-1,))) )
            
            number_weights_pruned = np.count_nonzero(net.lstm2.weight_ih_l0.detach().numpy() ==0)
            print('number of pruned lstm 2: ', number_weights_pruned) 
            print('total number : ', len(net.lstm2.weight_ih_l0.detach().numpy().reshape((-1,))) )
            
            number_weights_pruned = np.count_nonzero(net.fc.weight.detach().numpy() ==0)
            print('number of pruned fc: ', number_weights_pruned) 
            print('total number : ', len(net.fc.weight.detach().numpy().reshape((-1,))) )

In [239]:
# train using correlation coefficient for loss 
def corr_coeff(x, y):
    vx = x - torch.mean(x)
    vy = y - torch.mean(y)

    cost = torch.mean(vx * vy) / (torch.sqrt(torch.mean(vx ** 2)) * torch.sqrt(torch.mean(vy ** 2)))
    return (1/cost)

In [240]:
# preprocessing 
from sklearn.preprocessing import StandardScaler
# maybe try robustscaler 
def preprocessing(X) : 
    scaler = StandardScaler()
    # fit and transform the data 
    scaled_data = scaler.fit_transform(X)
    return scaled_data

In [241]:
def analytics(analysis) : 
    nb_pruned = len(net.lstm.weight_ih_l0.detach().numpy().reshape((-1,))) + len(net.lstm2.weight_ih_l0.detach().numpy().reshape((-1,))) + len(net.fc.weight.detach().numpy().reshape((-1,))) 
    total_weights = 46260
    
    per_pruned = nb_pruned/total_weights
    
    analysis.append(per_pruned)

In [242]:
#list_ = [0.01, 0.05, 0.1, 0.125, 0.15, 0.175]
#list_ = [0, 0.06, 0.07, 0.075, 0.08]
list_ = [0.075]
#should contained percentage of pruned for each threshold 
analysis = []
corr_coef = []

for threshold in list_ : 
    print('threshold = ', threshold)
    for Idx_subject in list([10,11,12]): # 3 subjects index 10-12
        print("SUBJECT : ", Idx_subject)
        for Finger in list([0,1,2,3,4]): # 5 fingers for each subject. 0:thumb, 1:index, 2:middle ...

            #load training data (TrainX: feature vectors, TrainY: labels)
            matData = sio.loadmat('data/BCImoreData_Subj_'+str(Idx_subject)+'_200msLMP.mat')
            TrainX = matData['Data_Feature'].transpose()
            TrainY = matData['SmoothedFinger']
            TrainY = TrainY [:,Finger]
            TrainY = TrainY.reshape(TrainY.shape[0],1)
            #load testing data (TestX: feature vectors, TestY: labels)
            matData = sio.loadmat('data/BCImoreData_Subj_'+str(Idx_subject)+'_200msLMPTest.mat')
            TestX = matData['Data_Feature'].transpose()
            TestY = matData['SmoothedFinger']
            TestY = TestY[:,Finger]
            TestY = TestY.reshape(TestY.shape[0],1)

            # preprocessing 
            print("preprocessing...")
            TrainX = preprocessing(TrainX)
            TestX = preprocessing(TestX)

            # from here, we reconstruct the input by "looking back" a few steps
            bins_before= 20 #How many bins of neural data prior to the output are used for decoding
            bins_current=1 #Whether to use concurrent time bin of neural data
            bins_after=0 #How many bins of neural data after the output are used for decoding

            TrainX=get_spikes_with_history(TrainX,bins_before,bins_after,bins_current)

            TrainX, TrainY = TrainX[bins_before:,:,:], TrainY[bins_before:,]

            TestX=get_spikes_with_history(TestX,bins_before,bins_after,bins_current)
            TestX, TestY = TestX[bins_before:,:,:], TestY[bins_before:,]

            # Now, we reconstructed TrainX/TestX to have a shape (num_of_samples, sequence_length, input_size)
            # You can fit this to the LSTM


            print("run for finger ", Finger)

            n_hidden = 20
            n_layers = 5


            n_epochs = 50 # start small 

            net = LSTM(TrainX, TrainY,  n_hidden, n_layers)

            lossfunc = nn.L1Loss()


            optimizer = torch.optim.Adamax(net.parameters())
            train(TrainX, TrainY, net, lossfunc, optimizer, n_epochs, clip = 5, pruning_weights = False)

            net.threshold_pruning()
            
            # store them in a sparse tensor 
            #net.lstm.weight_ih_l0 = torch.nn.Parameter(net.lstm.weight_ih_l0.data.to_sparse())
            #net.lstm2.weight_ih_l0 = torch.nn.Parameter(net.lstm2.weight_ih_l0.data.to_sparse()) 
                
                
            #analytics(analysis)

            # retrain it the pruned network to improve accuracy  
            print("retrain the pruned network ")
            n_epochs =  50

            train(TrainX, TrainY, net, lossfunc, optimizer, n_epochs, clip = 5, pruning_weights = False)




    




threshold =  0.075
SUBJECT :  10
preprocessing...
run for finger  0
Epoch [1/50], Loss: 33.8061
Correlation coefficient : [[1.         0.02958043]
 [0.02958043 1.        ]]
Epoch [2/50], Loss: 12.1101
Correlation coefficient : [[1.         0.08257581]
 [0.08257581 1.        ]]
Epoch [3/50], Loss: 8.9355
Correlation coefficient : [[1.         0.11191343]
 [0.11191343 1.        ]]
Epoch [4/50], Loss: 7.1546
Correlation coefficient : [[1.         0.13977033]
 [0.13977033 1.        ]]
Epoch [5/50], Loss: 5.9779
Correlation coefficient : [[1.         0.16728323]
 [0.16728323 1.        ]]
Epoch [6/50], Loss: 5.4367
Correlation coefficient : [[1.        0.1839343]
 [0.1839343 1.       ]]
Epoch [7/50], Loss: 5.0748
Correlation coefficient : [[1.         0.19705372]
 [0.19705372 1.        ]]
Epoch [8/50], Loss: 4.6990
Correlation coefficient : [[1.         0.21281151]
 [0.21281151 1.        ]]
Epoch [9/50], Loss: 4.4405
Correlation coefficient : [[1.         0.22519895]
 [0.22519895 1.        ]

KeyboardInterrupt: 

In [237]:
len(analysis)
#analysis_stored = analysis
len(corr_coef)
#corr_coef_stored = corr_coef 

30

# Pruning 

In [None]:
import torch.nn.utils.prune as prune

# select a pruning technique from pytorch 

# choose the percentage of connections hat you would like to prune 

# it has do be pruned and then retrained on the remaining weights so the accuracy can go up again.

In [77]:
# pruning does work, more and more weight units are set to zero while running. 

# todo : 
- compute the number of pruned connection for each layer and globally 
- try with pruning on the linear layer 
- understand how to store the sparse weights tensor (ask Bingzhao about the paper)
- find the right lethod of pruning 
- see if we can adapt it to the finger number 

# done : 
- try with multiple layers 
- find the right percentage of connections to prune 
- see if there is a dimension (channel) that is better for pruning than another 
- try to do our own class with a threshold instead of a percentage of 