In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.clock_driven import functional
import torchvision
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR


In [7]:
class_num = 11
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
dataset_dir ='./data'
batch_size = 16
split_by = 'number'
T = 20
normalization = None

In [19]:
train_data_loader = torch.utils.data.DataLoader(
    dataset=DVS128Gesture(dataset_dir, train=True, frames_number=T, split_by=split_by),
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    pin_memory=True)
test_data_loader = torch.utils.data.DataLoader(
    dataset=DVS128Gesture(dataset_dir, train=False, frames_number=T, split_by=split_by),
    batch_size=batch_size,
    shuffle=True,
    drop_last=False,
    pin_memory=True)

In [16]:
train_data_loader = torch.utils.data.DataLoader(
    dataset=DVS128Gesture(dataset_dir, train=True, use_frame=True, frames_num=T,
                            split_by=split_by, normalization=normalization),
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    drop_last=True,
    pin_memory=True)
test_data_loader = torch.utils.data.DataLoader(
    dataset=DVS128Gesture(dataset_dir, train=False, use_frame=True, frames_num=T,
                            split_by=split_by, normalization=normalization),
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    drop_last=False,
    pin_memory=True)

TypeError: __init__() got an unexpected keyword argument 'use_frame'

In [20]:
img,labels = next(iter(train_data_loader))
img.shape,labels.shape

RuntimeError: stack expects each tensor to be equal size, but got [94283] at entry 0 and [467891] at entry 1

In [None]:
labels

In [None]:
plt.figure(figsize=(20,10))
idx = 0
for i in range(10):
    plt.subplot(2,5,i+1).set_title('frame: '+str(i*2))
    plt.imshow(img[idx,i*2,1,:,:].cpu().numpy())

In [None]:
plt.figure(figsize=(20,10))
idx = 2
for i in range(10):
    plt.subplot(2,5,i+1).set_title('frame: '+str(i*2))
    plt.imshow(img[idx,i*2,0,:,:].cpu().numpy())

In [None]:
for i in range(10):
    mp_img = F.max_pool2d(img[idx:idx+1,i*2,1,:,:],2,2)
    print(np.max(mp_img[0].cpu().numpy()),np.mean(mp_img[0].cpu().numpy()))

In [None]:
import numpy as np
import torch
import torch.nn as nn
import math
import torch.nn.functional as F

surrograte_type = 'MG'
print('gradient type: ', surrograte_type)

torch.manual_seed(2020)
np.random.seed(200)
thresh = 0.5  # neuronal threshold
b_j0 = 0.1  # neural threshold baseline
R_m = 1  # membrane resistance
lens = 0.5
gamma = 0.5

def gaussian(x, mu=0., sigma=.5):
    return torch.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) / torch.sqrt(2 * torch.tensor(math.pi)) / sigma

# define approximate firing function

class ActFun_adp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):  # input = membrane potential- threshold
        ctx.save_for_backward(input)
        return input.gt(0).float()  # is firing ???

    @staticmethod
    def backward(ctx, grad_output):  # approximate the gradients
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # temp = abs(input) < lens
        scale = 6.0
        hight = .15
        if surrograte_type == 'G':
            temp = torch.exp(-(input**2)/(2*lens**2))/torch.sqrt(2*torch.tensor(math.pi))/lens
        elif surrograte_type == 'MG':
            temp = gaussian(input, mu=0., sigma=lens) * (1. + hight) \
                - gaussian(input, mu=lens, sigma=scale * lens) * hight \
                - gaussian(input, mu=-lens, sigma=scale * lens) * hight
        elif surrograte_type == 'MG1':
            temp = gaussian(input, mu=0., sigma=lens) * (1. + hight) \
                - gaussian(input, mu=lens, sigma=scale * lens) * hight 
        elif surrograte_type == 'MG2':
            temp = gaussian(input, mu=0., sigma=lens) * (1. + hight) \
                - gaussian(input, mu=-lens, sigma=scale * lens) * hight
        elif surrograte_type =='linear':
            temp = F.relu(1-input.abs())
        elif surrograte_type == 'slayer':
            temp = torch.exp(-5*input.abs())
        return grad_input * temp.float() * gamma
    
    
act_fun_adp = ActFun_adp.apply   

In [None]:
def mem_update_adp(inputs, mem, spike, tau_adp, b, tau_m, dt=1, isAdapt=1):
    alpha = torch.exp(-1. * dt / tau_m).cuda()
    ro = torch.exp(-1. * dt / tau_adp).cuda()
    if isAdapt:
        beta = 1.84
    else:
        beta = 0.
    b = ro * b + (1 - ro) * spike
    B = b_j0 + beta * b

    mem = mem * alpha + (1 - alpha) * R_m * inputs - B * spike * dt
    inputs_ = mem - B
    spike = act_fun_adp(inputs_)  # act_fun : approximation firing function
    return mem, spike, B, b

def output_Neuron(inputs, mem, tau_m, dt=1):
    """
    The read out neuron is leaky integrator without spike
    """
    alpha = torch.exp(-1. * dt / tau_m)
    mem = mem *alpha +  (1-alpha)*inputs
    return mem

In [None]:
class spike_cnn(nn.Module):
    def __init__(self,
                 input_size,output_dim, kernel_size=5,strides=1,
                 pooling_type = None,pool_size = 2, pool_strides =2,
                 tauM = 20,tauAdp_inital =100, tau_initializer = 'normal',tauM_inital_std = 5,tauAdp_inital_std = 5,
                 is_adaptive=1,device='cuda:0'):
        
        super(spike_cnn, self).__init__()
        # input_size = [c,w,h]
        self.input_size = input_size
        self.input_dim = input_size[0]
        self.output_dim = output_dim
        self.is_adaptive = is_adaptive
        self.device = device
        
        if pooling_type is not None: 
            if pooling_type =='max':
                self.pooling = nn.MaxPool2d(kernel_size=pool_size, stride=pool_strides, padding=1)
            elif pooling_type =='avg':
                self.pooling = nn.AvgPool2d(kernel_size=pool_size, stride=pool_strides, padding=1)
        else:
            self.pooling = None
        self.BN = nn.BatchNorm2d(output_dim)
        self.conv= nn.Conv2d(self.input_dim,output_dim,kernel_size=kernel_size,stride=strides)
        
        self.output_size = self.compute_output_size()
        
        self.tau_m = nn.Parameter(torch.Tensor(self.output_size))
        self.tau_adp = nn.Parameter(torch.Tensor(self.output_size))
        
        if tau_initializer == 'normal':
            nn.init.normal_(self.tau_m,tauM,tauM_inital_std)
            nn.init.normal_(self.tau_adp,tauAdp_inital,tauAdp_inital_std)
    
    def set_neuron_state(self,batch_size):
        self.mem = torch.rand(batch_size,self.output_size[0],self.output_size[1],self.output_size[2]).to(self.device)
        self.spike = torch.zeros(batch_size,self.output_size[0],self.output_size[1],self.output_size[2]).to(self.device)
        self.b = (torch.ones(batch_size,self.output_size[0],self.output_size[1],self.output_size[2])*b_j0).to(self.device)


    def forward(self,input_spike):
        d_input = self.conv(input_spike.float())
        d_input = self.BN(d_input)
        if self.pooling is not None: 
            d_input = self.pooling(d_input)
        self.mem,self.spike,theta,self.b = mem_update_adp(d_input,self.mem,self.spike,self.tau_adp,self.b,self.tau_m,isAdapt=self.is_adaptive)
        
        return self.mem,self.spike
    
    def compute_output_size(self):
        x_emp = torch.randn([1,self.input_size[0],self.input_size[1],self.input_size[2]])   
        out = self.conv(x_emp)
        if self.pooling is not None: out=self.pooling(out)
        # print(self.name+'\'s size: ', out.shape[1:])
        return out.shape[1:]

In [None]:
class RNN_s(nn.Module):
    def __init__(self,criterion):
        super(RNN_s, self).__init__()
        self.criterion = criterion

        self.n = 128
        dim = 128
        self.dim = dim
        self.dense_i = nn.Linear(dim*4*4,self.n,bias=False)
        self.dense_i2r = nn.Linear(self.n,self.n,bias=False)
        self.dense_r = nn.Linear(self.n,self.n,bias=False)
        self.dense_o = nn.Linear(self.n,11,bias=False)

        self.conv1 = nn.Sequential(nn.Conv2d(2,dim,3,padding=1),
                                   nn.BatchNorm2d(dim),
                                   nn.ReLU(),
                                   nn.MaxPool2d(2,2))
        self.conv2 = nn.Sequential(nn.Conv2d(dim,dim,3,padding=1),
                                   nn.BatchNorm2d(dim),
                                   nn.ReLU(),
                                   nn.MaxPool2d(2,2))
        self.conv3 = nn.Sequential(nn.Conv2d(dim,dim,3,padding=1),
                                   nn.BatchNorm2d(dim),
                                   nn.ReLU(),
                                   nn.MaxPool2d(2,2))
        self.conv4 = nn.Sequential(nn.Conv2d(dim,dim,3,padding=1),
                                   nn.BatchNorm2d(dim),
                                   nn.ReLU(),
                                   nn.MaxPool2d(2,2))
        self.conv5 = nn.Sequential(nn.Conv2d(dim,dim,3,padding=1),
                                   nn.BatchNorm2d(dim),
                                   nn.ReLU(),
                                   nn.MaxPool2d(2,2))
        
        
        self.dp = nn.Dropout(0.5)
        self.tau_adp_i = nn.Parameter(torch.Tensor(self.n))
        self.tau_adp_r = nn.Parameter(torch.Tensor(self.n))
        self.tau_adp_o = nn.Parameter(torch.Tensor(11))
        
        self.tau_m_i = nn.Parameter(torch.Tensor(self.n))
        self.tau_m_r = nn.Parameter(torch.Tensor(self.n))
        self.tau_m_o = nn.Parameter(torch.Tensor(11))

        nn.init.xavier_uniform_(self.dense_r.weight)
        nn.init.xavier_uniform_(self.dense_i.weight)
        nn.init.xavier_uniform_(self.dense_i2r.weight)

        # nn.init.kaiming_uniform_(self.conv1[0].weight)
        # nn.init.xavier_uniform_(self.conv2[0].weight)
        # nn.init.xavier_uniform_(self.conv3[0].weight)
        # nn.init.xavier_uniform_(self.conv4[0].weight)
        # nn.init.xavier_uniform_(self.conv5[0].weight)

        nn.init.kaiming_uniform_(self.conv1[0].weight)
        nn.init.kaiming_uniform_(self.conv2[0].weight)
        nn.init.kaiming_uniform_(self.conv3[0].weight)
        nn.init.kaiming_uniform_(self.conv4[0].weight)
        nn.init.kaiming_uniform_(self.conv5[0].weight)
      

        
        nn.init.normal_(self.tau_adp_i,25,5)
        nn.init.normal_(self.tau_adp_r,25,5)
        nn.init.normal_(self.tau_adp_o,25,5)
        
        nn.init.normal_(self.tau_m_i,20,5)
        nn.init.normal_(self.tau_m_r,20,5)
        nn.init.normal_(self.tau_m_o,10,2)

        # nn.init.constant_(self.tau_adp_i,20)
        # nn.init.constant_(self.tau_adp_r,20)
        # nn.init.constant_(self.tau_adp_o,10)
        
        # nn.init.constant_(self.tau_m_i,20)
        # nn.init.constant_(self.tau_m_r,20)
        # nn.init.constant_(self.tau_m_o,3)
        
        self.b_h = self.b_o = b_j0

    def forward(self, input,labels=None,sub_length =5,output_type='integrator'):
        b,s,c,h,w = input.shape
        mem_layer1 = spike_layer1 = torch.zeros(b, self.n).cuda()
        mem_layer2 = spike_layer2 = torch.zeros(b, self.n).cuda()
        mem_layer3 = spike_layer3 = mem_output = torch.zeros(b, 11).cuda()
        # print(self.conv1.output_size)

        self.b_i = self.b_o=self.b_r = b_j0
        output = torch.zeros(b, 11).cuda()
        loss = 0
        predictions = []
        fr = []

        input_ = input.reshape(b*s,c,h,w)
        conv1_out = self.conv1(input_)
        conv2_out = self.conv2(conv1_out)
        conv3_out = self.conv3(conv2_out)
        conv4_out = self.conv4(conv3_out)
        conv5_out = self.conv5(conv4_out).reshape(b,s,self.dim,4,4)
        for i in range(s):
            input_x = conv5_out[:,i,:,:,:]
            # input_x= input[:,i,:,:,:]

            # conv1_out = self.conv1(input_x)
            # conv2_out = self.conv2(conv1_out)
            # conv3_out = self.conv3(conv2_out)
            # conv4_out = self.conv3(conv3_out)
            # conv5_out = self.conv3(conv4_out)
  
            # print(conv5_out.shape)
            # snn_in = conv5_out.view(-1,self.dim*2*2)
            snn_in = self.dp(input_x.view(-1,self.dim*4*4))

            d1_output = self.dense_i(snn_in)+self.dense_r(spike_layer2)
            mem_layer1, spike_layer1, theta_i, self.b_i = mem_update_adp(d1_output, mem_layer1, spike_layer1, self.tau_adp_i, self.b_i,self.tau_m_i)
            r_input = self.dense_i2r(spike_layer1)
            mem_layer2, spike_layer2, theta_r, self.b_r = mem_update_adp(r_input, mem_layer2, spike_layer2, self.tau_adp_r, self.b_r,self.tau_m_r)
            o_input = self.dense_o(spike_layer2)
            if output_type == 'adp-mem':
                mem_layer3, spike_layer3, theta_o, self.b_o = mem_update_adp(o_input, mem_layer3, spike_layer3, self.tau_adp_o, self.b_o,self.tau_m_o)
            elif output_type == 'integrator':
                mem_layer3 = output_Neuron(o_input, mem_layer3, self.tau_m_o)
            output = output+ mem_layer3
            output = F.log_softmax(output,dim=-1)#
            
            # output_  = F.log_softmax(output,dim=1)
            predictions.append(output.data.cpu().numpy())
            
            
            fr.append([spike_layer1.detach().mean().cpu().numpy(),
                        spike_layer2.detach().mean().cpu().numpy()])

        if labels is not None and i > 5:
            loss += self.criterion(output, labels[:,i])*(1+(i-5)/5)
                # if i==s-1:
                #     loss += self.criterion(output, labels[:,i])
        predictions = torch.tensor(predictions)
        return predictions, loss,np.array(fr)

    def predict(self,input):
        prediction, _= self.forward(input)
        return prediction

In [None]:
num_epochs = 500
criterion = nn.NLLLoss()#nn.CrossEntropyLoss()#
model = RNN_s(criterion=criterion)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:",device)
# model=torch.load('./models/0.958-MG.pth')
# model=torch.load('./models/0.9583-24--relu-MG.pth')

# model=torch.load('./models/0.9548-153--relu-MG.pth')
model=torch.load('./models/0.9652-22--relu-MG.pth')
model.to(device)

In [None]:
def test(data_loader,after_num_frames=0,is_show=0):
    model.eval()
    test_acc = 0.
    sum_samples = 0
    fr_list = []
    for i, (images, labels) in enumerate(data_loader):
        images = images.view(-1, 20,2,128,128).to(device)
        labels = labels.view(-1,1).repeat_interleave(20,dim=1).long().to(device)#labels.long().to(device)
        predictions, _,fr = model(images)
        _, predicted = torch.max(predictions.data, 2)
        labels = labels.cpu()
        predicted = predicted.cpu().t()
        fr_list.append(fr)
        # print(predicted.shape)
        test_acc += (predicted[:,-1] == labels[:,-1]).sum()
        
        sum_samples = sum_samples + predicted.numel()/20
        torch.cuda.empty_cache()
    if is_show:
        print('Mean FR: ',np.mean(fr_list),np.array(fr_list).mean(axis=(0,1)))
        return test_acc.data.cpu().numpy() / sum_samples,np.mean(fr_list)
    else:
        return test_acc.data.cpu().numpy() / sum_samples

def test_frame(data_loader,after_num_frames=0,is_show=0):
    model.eval()
    test_acc = 0.
    sum_samples = 0
    test_acc_classes = np.zeros((11,20))
    test_acc_count = np.zeros((11,1))                                       
    fr_list = []
    for i, (images, labels) in enumerate(data_loader):
        images = images.view(-1, 20,2,128,128).to(device)
        labels = labels.view(-1,1).repeat_interleave(20,dim=1).long().to(device)#labels.long().to(device)
        predictions, _,fr = model(images)
        _, predicted = torch.max(predictions.data, 2)
        labels = labels.cpu()
        predicted = predicted.cpu().t()
        fr_list.append(fr)
        
        test_acc += (predicted[:,after_num_frames:] == labels[:,after_num_frames:]).float().mean(axis=0)
        f_test = predicted[:,after_num_frames:] == labels[:,after_num_frames:]
        for i in range(f_test.shape[0]):
            tmp = labels[i,0]
            test_acc_classes[tmp] += f_test[i].float().cpu().numpy()
            test_acc_count[tmp] += 1
        if i==1:
            print(f_test.shape)
        
        sum_samples = sum_samples + predicted.numel()
        torch.cuda.empty_cache()
    if is_show:
        print('Mean FR: ',np.mean(fr_list),np.array(fr_list).mean(axis=(0,1)))
        return test_acc.data.cpu().numpy() / i,test_acc_classes/test_acc_count,np.mean(fr_list)
    else:
        return test_acc.data.cpu().numpy() / i,test_acc_classes/test_acc_count

In [None]:
test_acc = test(test_data_loader,is_show=1)
print(test_acc)

In [None]:
test_acc_,test_acc_classes = test_frame(test_data_loader)
last_acc = np.mean(test_acc_classes[:,-1])
print('last frame',last_acc)

In [None]:
test_acc_classes.mean(axis=0)

In [None]:
def train(model,loader,optimizer,scheduler=None,num_epochs=10,file_name='-relu-MG.pth'):
    best_acc = .87
    path = 'models/'  # .pth'
    acc_list=[]
    test_list = []
    for epoch in range(num_epochs):
        model.train()
        train_acc = 0
        train_loss_sum = 0
        sum_samples = 0
        fr_list = []
        for i, (images, labels) in enumerate(loader):
            images = images.view(-1, 20,2,128,128).to(device)
            labels = labels.view(-1,1).repeat_interleave(20,dim=1).long().to(device)#labels.long().to(device)
            optimizer.zero_grad()
            
            predictions, train_loss,fr_ = model(images, labels)
            _, predicted = torch.max(predictions.data, 2)
            
            train_loss.backward()
            train_loss_sum += train_loss
            fr_list.append(fr_)
            optimizer.step()

            labels = labels.cpu()
            predicted = predicted.cpu().t()
            train_acc += (predicted == labels).sum()
            sum_samples = sum_samples + predicted.numel()
            torch.cuda.empty_cache()
        if scheduler is not None:
            scheduler.step()
            
        train_acc = train_acc.data.cpu().numpy() / sum_samples
        valid_acc = test(test_data_loader)
        
        if valid_acc>best_acc:
            best_acc = valid_acc
            torch.save(model, path+str(best_acc)[:6]+'-'+str(epoch)+'-'+file_name)

        test_list.append(valid_acc)
        acc_list.append(train_acc)
        if epoch%1==0:
            fr_ = np.array(fr_list).mean(axis=(0,1))
            print(fr_,best_acc)
            print('epoch: {:3d}, Train Loss: {:.4f}, Train Acc: {:.4f},Valid Acc: {:.4f},fr: {:.4f}'.format(epoch,
                                                                            train_loss_sum.item()/len(loader),
                                                                            train_acc,valid_acc,mean(fr_)), flush=True)
    return [acc_list,test_list]

In [None]:
# base_params = []
# for name, param in model.named_parameters():
#     if param.requires_grad:
#         print(name)
  

In [None]:
learning_rate = 1.e-3#3e-3
cnn_param = [model.conv1[1].weight, model.conv1[1].bias, model.conv1[0].weight, model.conv1[0].bias,
             model.conv2[1].weight, model.conv2[1].bias, model.conv2[0].weight, model.conv2[0].bias,
             model.conv3[1].weight, model.conv3[1].bias, model.conv3[0].weight, model.conv3[0].bias,
             model.conv4[1].weight, model.conv4[1].bias, model.conv4[0].weight, model.conv4[0].bias,
             model.conv5[1].weight, model.conv5[1].bias, model.conv5[0].weight, model.conv5[0].bias,
             ]

base_params = [model.dense_i.weight,#model.dense_i.bias, 
               model.dense_o.weight, #model.dense_o.bias,
               model.dense_r.weight, #model.dense_r.bias, 
               model.dense_i2r.weight,# model.dense_i2r.bias
               ]+cnn_param

optimizer = torch.optim.Adam([
    {'params': base_params},
    {'params': model.tau_adp_i, 'lr': learning_rate * 2},
    {'params': model.tau_adp_r, 'lr': learning_rate * 2},
    {'params': model.tau_adp_o, 'lr': learning_rate * 2},
    {'params': model.tau_m_i, 'lr': learning_rate * 2},
    {'params': model.tau_m_r, 'lr': learning_rate * 2},
    {'params': model.tau_m_o, 'lr': learning_rate * 2}],
    lr=learning_rate)

# scheduler = StepLR(optimizer, step_size=100, gamma=.5) 
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=64)

# training network

# with sechdual
acc_list = train(model,train_data_loader,optimizer,scheduler,num_epochs=num_epochs)

In [None]:
print(model.tau_adp_i.mean(),model.tau_adp_i.std())
print(model.tau_adp_r.mean(),model.tau_adp_r.std())
print(model.tau_adp_o.mean(),model.tau_adp_o.std())
print(model.tau_m_i.mean(),model.tau_m_i.std())
print(model.tau_m_r.mean(),model.tau_m_r.std())
print(model.tau_m_o.mean(),model.tau_m_o.std())

In [None]:
plt.plot(acc_list[1])
# plt.plot(acc_list[0])

In [None]:
torch.cuda.empty_cache()

In [None]:
test_acc = test(test_data_loader,is_show=1)
print(test_acc)

In [None]:
test_acc_,test_acc_classes = test_frame(test_data_loader)
last_acc = np.mean(test_acc_classes[:,-1])
print('last frame',last_acc)
print(test_acc_classes[:,-1])

In [None]:
test_acc_classes.mean(axis=0)

In [None]:
for i in range(11):
    plt.plot(test_acc_classes[i,:])
