# SNN on Icub Data

Here we implement autoencoder:

loss = loss_classification + loss_regression

In [82]:
import sys, os
CURRENT_TEST_DIR = os.getcwd()
sys.path.append(CURRENT_TEST_DIR + "/../../../../slayerPytorch/src")

In [97]:
import slayerSNN as snn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from numpy.linalg import norm
from joblib import Parallel, delayed
import torch
import copy
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from tas_utils import get_trainValLoader, get_testLoader

np.random.seed(1)
torch.manual_seed(1)

<torch._C.Generator at 0x7f7713a36d90>

### upload data

In [112]:
data_dir = '../../new_data_folder/'
logDir = 'models_and_stats/'
kfold_number = 0

model_name = 'snn_classify_icub_' + str(kfold_number)
screen_fr = 20

save_dir = logDir + model_name + '.pt'

train_loader, val_loader, train_dataset, val_dataset = get_trainValLoader(data_dir, k=0)
test_loader, test_dataset = get_testLoader(data_dir)

### define spike neuron

In [113]:
params = {
    "neuron": {
        "type": "SRMALPHA",
        "theta": 5, # 10
        "tauSr": 10.0,
        "tauRef": 2.0,
        "scaleRef": 2,
        "tauRho": 1,
        "scaleRho": 1,
    },
    "simulation": {"Ts": 1.0, "tSample": 75, "nSample": 1},
    "training": {
        "error": {
            "type": "NumSpikes",  # "NumSpikes" or "ProbSpikes"
            "probSlidingWin": 20,  # only valid for ProbSpikes
            "tgtSpikeRegion": {  # valid for NumSpikes and ProbSpikes
                "start": 0,
                "stop": 75,
            },
            "tgtSpikeCount": {True: 55, False: 15},
        }
    },
}

In [166]:
# def get_icub_spike(X):
    
#     # parameters
#     C = 0.5
#     p_pos = 1 
#     p_neg = -1
    
    
#     X = X.squeeze()
    
#     # nonzero elements -> log
#     non_zero_indx = np.where(X > 0)
#     log_X = torch.zeros(X.shape)
#     log_X[non_zero_indx] = torch.log( X[ non_zero_indx ] )
#     x_diff = log_X[..., 1:] - log_X[..., :-1]
    
#     brightness_diff = torch.cat([log_X[...,0].reshape([log_X.shape[0], log_X.shape[1],1]),  x_diff], dim=2)
    
#     spike_train_pos = torch.zeros(X.shape)
#     spike_train_neg = torch.zeros(X.shape)

#     spike_train_pos[brightness_diff >= p_pos*C] = 1
#     spike_train_neg[brightness_diff <= p_neg*C] = 1
    
#     res = torch.cat([spike_train_pos, spike_train_neg], dim=1)
    
#     return res.reshape(res.shape[0],res.shape[1],1,1,res.shape[-1])

In [172]:
class SlayerMLP(torch.nn.Module):
    def __init__(self, params, input_size, hidden_size1, hidden_size2, output_size):
        super(SlayerMLP, self).__init__()
        self.output_size = output_size
        self.slayer = snn.layer(params["neuron"], params["simulation"])
        self.fc1 = self.slayer.dense(input_size, hidden_size1)
        self.fc2 = self.slayer.dense(hidden_size1, hidden_size2)
        self.fc3 = self.slayer.dense(hidden_size2, output_size)
        
    def get_spike(self, inp):
        return self.slayer.spike(inp)
        
    def forward(self, spike_input):
        spike_1 = self.slayer.spike(self.slayer.psp(self.fc1(spike_input)))
        spike_2 = self.slayer.spike(self.slayer.psp(self.fc2(spike_1)))
        spike_output = self.slayer.spike(self.slayer.psp(self.fc3(spike_2)))
        
        return spike_output

In [168]:
device = torch.device("cuda:1")
net = SlayerMLP(params, 60, 50, 50, 20).to(device)

In [169]:
error = snn.loss(params).to(device)
optimizer = torch.optim.RMSprop(net.parameters(), lr=0.001, weight_decay=0.5)

In [171]:
train_total_losses=[]
train_class_losses=[]

val_total_losses=[]
val_class_losses=[]

test_total_losses=[]
test_class_losses=[]

train_accs = []
test_accs = []
val_accs = []

max_val_acc = 0

for epoch in range(10001):
    net.train()
    correct = 0
    loss_train = 0
    for i, (tact, _,  target, label) in enumerate(train_loader):
        
        tact = get_icub_spike(tact)
        
        tact = tact.to(device)
        target = target.to(device)
        
        
        output = net.forward(tact)
        
        correct += torch.sum(snn.predict.getClass(output) == label).data.item()
        loss = error.numSpikes(output, target)
        
        loss_train += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
                
    if epoch%screen_fr == 0:
        print('Epoch: ', epoch, ' --------------------------')
        print('Train loss :', 
              loss_train/len(train_dataset))
        print('Train accuracy:', correct/len(train_dataset))
    train_accs.append(correct/len(train_dataset))
    train_total_losses.append(loss_train/len(train_dataset))
    
#     net.eval()
    correct = 0
    loss_val = 0
    with torch.no_grad():
        for i, (tact, _, target, label) in enumerate(val_loader):

            tact = get_icub_spike(tact)
            
            tact = tact.to(device)
            target = target.to(device)

            

            output = net.forward(tact)

            correct += torch.sum(snn.predict.getClass(output) == label).data.item()
            loss = error.numSpikes(output, target)

            loss_val += loss.item()

    #         optimizer.zero_grad()
    #         loss.backward()
    #         optimizer.step()

        
    if epoch%screen_fr == 0:
        print('Val loss (all, class, reg):', 
              loss_val/len(val_dataset))
        print('Val accuracy:', correct/len(val_dataset))
    val_accs.append(correct/len(val_dataset))
    val_total_losses.append(loss_val/len(val_dataset))
    
    if correct/len(val_dataset) >= max_val_acc:
        print('Saving model at ', epoch, ' epoch')
        max_val_acc = correct/len(val_dataset)
        torch.save(net.state_dict(), save_dir)

Epoch:  0  --------------------------
Train loss : 10.910644454956055
Train accuracy: 0.115
Val loss (all, class, reg): 10.946500091552734
Val accuracy: 0.07
Saving model at  0  epoch
Saving model at  1  epoch
Saving model at  2  epoch
Saving model at  3  epoch
Saving model at  4  epoch
Saving model at  5  epoch
Saving model at  6  epoch
Saving model at  10  epoch
Saving model at  12  epoch
Saving model at  14  epoch
Saving model at  16  epoch
Saving model at  17  epoch
Epoch:  20  --------------------------
Train loss : 8.968511034647623
Train accuracy: 0.395
Val loss (all, class, reg): 10.200366554260254
Val accuracy: 0.235
Saving model at  20  epoch
Saving model at  21  epoch
Saving model at  23  epoch
Saving model at  25  epoch
Saving model at  26  epoch
Saving model at  27  epoch
Saving model at  29  epoch
Saving model at  34  epoch
Saving model at  38  epoch
Saving model at  39  epoch
Epoch:  40  --------------------------
Train loss : 7.729033279418945
Train accuracy: 0.625
Val 

KeyboardInterrupt: 

In [118]:
# save stats
import pickle
all_stats = [
    train_total_losses,
    val_total_losses,
    train_accs,
    val_accs
]

pickle.dump(all_stats, open(logDir + model_name + '_stats.pkl', 'wb'))

In [None]:
fig, ax = plt.subplots(2, figsize=(15,15))

ax[0].set_title('Total loss')
ax[0].plot(train_total_losses)
ax[0].plot(val_total_losses)
ax[0].set_ylabel('Loss')
ax[0].legend(['Train', 'Validation'])

ax[1].set_title('Accuracy')
ax[1].plot(train_accs)
ax[1].plot(val_accs)
ax[1].set_xlabel('Epoch')
ax[1].set_ylabel('Accuracy')
ax[1].legend(['Train', 'Validation'])

plt.show()

In [121]:
# testing set check
net_trained = SlayerMLP(params, 60, 50, 50, 20).to(device)
net_trained.load_state_dict(torch.load(save_dir))
net_trained.eval()

SlayerMLP(
  (slayer): spikeLayer()
  (fc1): _denseLayer(60, 50, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
  (fc2): _denseLayer(50, 50, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
  (fc3): _denseLayer(50, 20, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
)

In [122]:
correct = 0
loss_test = 0
with torch.no_grad():
    for i, (tact, _, target, label) in enumerate(test_loader):

        tact = tact.to(device)
        target = target.to(device)
        
        tact = net.get_spike(tact)
        
        output = net_trained.forward(tact)

        correct += torch.sum(snn.predict.getClass(output) == label).data.item()

In [123]:
print(correct/len(test_loader.dataset))

0.825
