In [None]:
import time
import torch
import numpy as np
import torch.nn as nn
from torch.autograd import Function
import torch.tensor as Tensor
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

## Neuron Model (core forward pass equation)
\
We use a simple spiking neuron model based on piecewise-linear postsynaptic potentials (PL-PSPs), which has a very low computational cost.\
The membrane potential $v_i(t)$ of neuron $i$ at time $t$ is the weighted sum of the PL-PSPs of its afferent neurons:
$$
v_i(t) = \sum_{j\in J}w_{ij}\epsilon(t-t_j)
$$
where $w_{ij}$ is the synaptic weight connecting presynaptic neuron $j$ to neuron $i$, and $t_j$ is the spike time of neuron $j$. $\epsilon(t-t_j)$ is the PL-PSP kernel, defined by:
$$
\epsilon(t-t_j) = \begin{cases}
\frac{t-t_j}{\tau_1}, & \text{if } t_j \le t < t_j+\tau_1; \\
\frac{t_j+\tau_{1}+\tau_{2}-t}{\tau_{2}}, & \text{if } t_j + \tau_1 \le t < t_j + \tau_1 + \tau_2; \\
0, & \text{otherwise}.
\end{cases}
$$

### Target calculation

In [None]:
def target_firing_time_output(output: Tensor, 
                              tmax: int, label: Tensor, gamma: int, device, dtype)->Tensor:
    """Build per-sample target firing times for the output layer.

    For each sample, the neuron selected by ``label`` gets target
    ``min(output) - gamma`` (floored at 0), and every other neuron gets target
    ``max(output) + gamma`` (capped at ``tmax``).

    :param output: (batch_size, num_classes) actual output firing times
    :param tmax: upper bound for the targets of non-labelled neurons
    :param label: (batch_size,) class index of each sample
    :param gamma: margin subtracted from the earliest / added to the latest spike
    :param device: device forwarded to the tensor factory for the targets
    :param dtype: dtype forwarded to the tensor factory for the targets (None keeps output's dtype)
    :return: (batch_size, num_classes) target firing times
    """
    factory_kwargs = {'device': device, 'dtype': dtype}
    base = torch.zeros_like(output, **factory_kwargs)

    tau_min = torch.min(output, dim=1, keepdim=True).values
    tau_max = torch.max(output, dim=1, keepdim=True).values

    # Labelled neuron: fire gamma steps before the earliest spike, but never before t = 0.
    labeled = (base + tau_min - gamma).clamp(min=0)
    # All other neurons: fire gamma steps after the latest spike, but never after tmax.
    unlabeled = (base + tau_max + gamma).clamp(max=tmax)

    # One-hot mask of the labelled neuron per row. Indices are always integer tensors:
    # building them with the caller's dtype produced float indices (IndexError) for
    # any floating dtype.
    is_label = torch.zeros(output.shape, dtype=torch.bool, device=output.device)
    rows = torch.arange(output.shape[0], device=output.device)
    is_label[rows, label.to(device=output.device).long()] = True

    return torch.where(is_label, labeled, unlabeled)

### Fully connected layer

(1) forward: neuron model\
(2) backward:\
the loss function of each layer $l$ is calculated independently by the following equation:
$$ E^l = \sum_{j}E^l_j = \sum_j\frac{1}{2}(e_j^l)^2 $$
where $e_j^l$ is the temporal error of postsynaptic neuron $j$ in the $l^{th}$ layer, obtained by subtracting its actual firing time $t_j^l$ from its desired firing time $T_j^l$ and normalizing by $T_{max}$:
$$ e_j^l = \frac{T_j^l-t_j^l}{T_{max}}$$
The weight update is obtained by gradient descent on $E_j^l$ through the chain rule, where $\eta$ is the learning rate:
$$
\Delta w_{ji}^l =
\begin{cases}
-\eta \frac{\partial E_j^l}{\partial t_j^l}\frac{\partial t_j^l}{\partial v_j^l(t)}\frac{\partial v_j^l(t)}{\partial w_{ji}^l}, & \text{if } t_i^{l-1} \le t_j^l; \\
0, & \text{otherwise}.
\end{cases}
$$

In [None]:
# neuron model
class Spiking_linear(Function):
    """Autograd function for a fully connected layer of PL-PSP spiking neurons.

    Tensors carry one firing *time* per neuron, not spike trains. ``forward``
    simulates the postsynaptic potential on the integer time grid 0..tmax and
    reports each neuron's first threshold crossing. ``backward`` returns a
    hand-derived gradient; the discrete simulation itself is not differentiated.
    """
    @staticmethod
    def forward(ctx, firing_time: Tensor, weight: Tensor, tau1: int, tau2: int, tmax: int, threshold: int, beta: int, device, dtype) -> Tensor:
        """
        Simulate the layer and return the first firing time of each postsynaptic neuron.

        :param ctx: autograd context; stores tensors and scalars for backward
        :param firing_time: (batch_size, n_in) presynaptic firing times; values above
            tmax are clamped to tmax and fractional values are truncated to integer steps
        :param weight: (n_out, n_in) synaptic weights
        :param tau1: rise duration of the PSP kernel, in time steps
        :param tau2: fall duration of the PSP kernel, in time steps
        :param tmax: last simulated time step; a neuron that never crosses the
            threshold is reported as firing at tmax
        :param threshold: firing threshold on the simulated potential
        :param beta: scale applied to the gradient sent to the previous layer
            (only used in backward)
        :param device: device for the temporary tensors built here
        :param dtype: dtype for the temporary tensors (the PSP grid is cast to float32)
        :return:
            :param output: (batch_size, n_out) float tensor of firing times in [0, tmax]
        """
        
        batch_size = firing_time.shape[0]
        neuron_num = firing_time.shape[1]
        factory_kwargs = {'device': device, 'dtype': dtype}
        
        # Clamp late spikes to tmax so every kernel window fits in the grid below.
        real_firing = torch.where(firing_time > tmax, torch.full_like(firing_time, tmax), firing_time)
        
        # Rasterize every presynaptic PSP onto a time grid: spike01[b, n, t] is the
        # kernel value of input n at step t. The grid is tau1+tau2 steps longer than
        # tmax+1 so that kernels starting near tmax still fit.
        spike01 = torch.zeros((batch_size, neuron_num, tmax+1+tau1+tau2), **factory_kwargs).float()
        # (batch, neuron, time) advanced-index triples covering each kernel window
        # [t_j, t_j + tau1 + tau2].
        label1 = torch.arange(batch_size, **factory_kwargs).unsqueeze(1).unsqueeze(2).tolist()
        label2 = torch.meshgrid(torch.arange(batch_size, **factory_kwargs), torch.arange(neuron_num, **factory_kwargs))[1].unsqueeze(2).tolist()
        label3 = real_firing.unsqueeze(2).int() + torch.arange(tau1+tau2+1, **factory_kwargs).unsqueeze(0).unsqueeze(0)
        
        # Kernel samples: rise 0 -> 1 over tau1 steps, then fall 1 -> 0 over tau2
        # steps (b[1:] drops the duplicated peak), tau1+tau2+1 values in total.
        a = torch.linspace(0,1,steps=tau1+1, **factory_kwargs)
        b = torch.linspace(1,0,steps=tau2+1, **factory_kwargs)
        c = torch.cat((a,b[1:]), 0)
        spike01[label1, label2, label3] = c
        # Keep only the observable window 0..tmax: (batch_size, n_in, tmax+1).
        epsilon = spike01[:,:,:tmax+1]
         
        # (n_out, n_in) @ (batch_size, n_in, tmax+1) -> (batch_size, n_out, tmax+1),
        # followed by a running sum over time.
        # NOTE(review): the cumsum makes the potential the time-integral of
        # sum_j w_ij * eps(t - t_j), while the markdown model and the gradient in
        # backward use sum_j w_ij * eps(t - t_j) directly -- confirm this is intended.
        Voltage = torch.matmul(weight, epsilon).cumsum(dim = 2)
        
        # Threshold crossings; forcing the last step to 1 guarantees every neuron
        # fires, at tmax at the latest.
        Spike = Voltage > threshold
        Spike[:,:,-1] = 1
        # cumsum(cumsum(Spike)) equals exactly 1 only at the first True entry, so the
        # argmax of that indicator is the index of the first crossing.
        output = torch.argmax(torch.eq(Spike.cumsum(axis=2).cumsum(axis=2), 1).int(), axis=2)
        output = output.float()
        
        # Tensors needed by backward; scalar hyper-parameters are stored on ctx directly.
        ctx.save_for_backward(output, firing_time, weight)
        
        ctx.tau1, ctx.tau2, ctx.tmax, ctx.threshold, ctx.beta = tau1, tau2, tmax, threshold, beta
        
        return output.requires_grad_(True)
    
    @staticmethod
    def backward(ctx, grad_outputs: Tensor) -> Tensor:
        """
        Hand-derived gradients w.r.t. the input firing times and the weights.

        :param ctx: autograd context populated by forward
        :param grad_outputs: gradient of the loss w.r.t. this layer's output firing
            times, same shape as the forward output (batch_size, n_out)
        :return:
            grad: tuple (grad_input, grad_weight, None x 7) -- grad_input is
            (batch_size, n_in), grad_weight is (n_out, n_in); the seven
            non-tensor forward arguments receive None
        """
        """
        output: batch_size * i-th neuron
        firing_time: batch_size * (i-1)-th neuron
        weight: i-th neuron * (i-1)-th neuron
        """
        output, firing_time, weight = ctx.saved_tensors
        tau1, tau2, tmax, threshold, beta = ctx.tau1, ctx.tau2, ctx.tmax, ctx.threshold, ctx.beta
        # beta must be a Python int here.
        assert type(beta) == int
        
        # hasfired[b, i, o] is True when input i fired strictly before output o;
        # shape (batch_size, n_in, n_out). Only those inputs get an input gradient.
        hasfired = (firing_time.transpose(0,1) < output.transpose(0,1).unsqueeze(1)).transpose(0,2).contiguous()
        
        # epsilon[b, o, i] = PL-PSP kernel evaluated at output[b, o] - firing_time[b, i],
        # used as dv_o/dw_oi; final shape (batch_size, n_out, n_in).
        # NOTE(review): the torch.where steps rewrite the same tensor one after another.
        # This is correct only while tau1 >= 1, so that rising-edge values rescaled into
        # [0, 1) are not matched again by the falling-edge test.
        epsilon = (output.transpose(0,1) - firing_time.transpose(0,1).unsqueeze(1)).float()
        epsilon = torch.where(epsilon<0, torch.full_like(epsilon, 0), epsilon)
        epsilon = torch.where((epsilon>=0)&(epsilon<tau1), epsilon/tau1, epsilon)
        epsilon = torch.where((epsilon>=tau1)&(epsilon<tau1+tau2), (tau1+tau2-epsilon)/tau2, epsilon)
        epsilon = torch.where((epsilon>=tau1+tau2), torch.full_like(epsilon,0), epsilon).float().transpose(0,2).contiguous()

        # grad_weight[o, i] = sum_b grad_outputs[b, o] * (-output[b, o] / threshold) * epsilon[b, o, i]
        grad_weight = torch.sum(((grad_outputs * -(output.float()/threshold)).unsqueeze(2) * epsilon), axis = 0)
        
        # Per-output factor for the previous layer's gradient: (batch_size, n_out).
        tmp = beta * grad_outputs * -(output.float()/threshold)
        # dv[o, i, b] = output[b, o] - firing_time[b, i]
        dv = (output.transpose(0,1) - firing_time.transpose(0,1).unsqueeze(1)).transpose(0,1)
        # Whether input i sits on the rising (tochange1) or falling (tochange2) edge of
        # its kernel at the output firing time.
        tochange1 = (dv >= 0) & (dv < tau1)
        tochange2 = (dv >= tau1) & (dv < tau1 + tau2)
        # Derivative of w * eps(t_o - t_i) w.r.t. the input time t_i: -w/tau1 on the
        # rising edge, +w/tau2 on the falling edge, 0 elsewhere; (n_out, n_in, batch_size).
        dvdt = - tochange1.int() * (weight / tau1).unsqueeze(2) + tochange2.int() * (weight / tau2).unsqueeze(2)
        # Reorder to (batch_size, n_out, n_in).
        dvdt = dvdt.transpose(0,2).contiguous()
        dvdt = dvdt.transpose(1,2).contiguous()
        # Scale by the per-output factor and reorder to (batch_size, n_in, n_out).
        deltat = (dvdt*tmp.unsqueeze(2)).transpose(1,2).contiguous()
        # Normalize by tmax^2, mask inputs that had not fired before the output
        # (note dv == 0 passes tochange1 but is masked by the strict hasfired test),
        # then sum over output neurons -> (batch_size, n_in).
        grad_input = deltat/(tmax*tmax)*hasfired.float()
        grad_input = grad_input.sum(axis = 2)
        
        # One entry per forward argument: tau1, tau2, tmax, threshold, beta, device, dtype get None.
        return grad_input, grad_weight, None, None, None, None, None, None, None

In [None]:
# layer model
class Spike_linear(nn.Module):
    """Fully connected spiking layer mapping input firing times to output firing times.

    Owns an (out_features, in_features) weight matrix and delegates the
    simulation and its hand-written gradient to ``Spiking_linear``.
    """

    __constants__ = ['in_features', 'out_features', 'tau1', 'tau2', 'tmax', 'threshold', 'beta']

    def __init__(self, in_features:int, out_features:int, 
                 tau1:int, tau2:int, tmax:int, threshold:int, beta:int, device, dtype) -> None:
        """
        :param in_features: number of presynaptic neurons
        :param out_features: number of neurons in this layer
        :param tau1: rise duration of the PSP kernel, in time steps
        :param tau2: fall duration of the PSP kernel, in time steps
        :param tmax: last simulated time step
        :param threshold: firing threshold
        :param beta: scale for the gradient sent to the previous layer
        :param device: device of the weight tensor
        :param dtype: dtype of the weight tensor
        """
        # nn.Module.__init__ must run before any attribute is assigned on the module.
        super(Spike_linear, self).__init__()
        self.factory_kwargs = {'device': device, 'dtype': dtype}
        self.device = device
        self.dtype = dtype
        # Stored so every name listed in __constants__ actually exists on the instance.
        self.in_features, self.out_features = in_features, out_features
        self.tmax, self.tau1, self.tau2 = tmax, tau1, tau2
        self.threshold, self.beta = threshold, beta
        # Default initialization: uniform in [0, 0.5).
        self.weight = nn.Parameter(0.5 * torch.rand((out_features, in_features), **self.factory_kwargs))

    def reset_parameters(self, upperbound: float, lowerbound: float) -> None:
        """Re-draw the weights uniformly from [lowerbound, upperbound).

        The existing Parameter is overwritten in place instead of being replaced,
        so an optimizer (or hook) that already references ``self.weight`` keeps
        tracking the live tensor.
        """
        with torch.no_grad():
            self.weight.copy_((upperbound - lowerbound) * torch.rand_like(self.weight) + lowerbound)

    def forward(self, firing_time: Tensor) -> Tensor:
        """Map (batch_size, in_features) firing times to (batch_size, out_features) firing times."""
        return Spiking_linear.apply(firing_time, self.weight, self.tau1, self.tau2, self.tmax, self.threshold, self.beta, self.device, self.dtype)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, tau1={}, tau2={}, tmax={}, threshold={}, beta={}'.format(
            self.in_features, self.out_features, self.tau1, self.tau2, self.tmax, self.threshold, self.beta)

In [None]:
class linear_loss_f(Function):
    """Per-sample squared temporal error: sum_j (input_j - target_j)^2 / (2 * tmax^2).

    Returns a (batch_size, 1) tensor. The gradient w.r.t. ``input`` is
    ``grad_output * (input - target) / tmax^2``; ``target`` and ``tmax`` get none.
    """
    @staticmethod
    def forward(ctx, input, target, tmax):
        ctx.save_for_backward(input, target)
        ctx.tmax = tmax
        return torch.sum((input-target)*(input-target)/(tmax*tmax), dim = 1, keepdim = True)/2

    @staticmethod
    def backward(ctx, grad_output):
        input, target = ctx.saved_tensors
        tmax = ctx.tmax
        # Chain rule: scale the local derivative by the upstream gradient. grad_output
        # is (batch_size, 1) and broadcasts over the class dimension.
        grad_input = grad_output * (input - target) / (tmax*tmax)
        return grad_input, None, None
    
def linear_loss(firing, target, tmax):
    """Apply ``linear_loss_f``: per-sample loss of shape (batch_size, 1)."""
    return linear_loss_f.apply(firing, target, tmax)

In [None]:
class S4NN(nn.Module):
    """Two-layer fully connected spiking network built from ``Spike_linear`` layers.

    Hidden layer: firing threshold 50, weights re-drawn from [0, 0.25).
    Output layer: firing threshold 10, weights re-drawn from [0, 0.5).
    """

    def __init__(self, input_size: int = 784, hidden_size: int = 400, classes: int=10, tau1: int = 40, tau2: int = 40, tmax: int=256, beta: int=1,  device= 'cpu', dtype = None) -> None:
        super().__init__()
        kernel_and_horizon = (tau1, tau2, tmax)

        hidden_threshold, hidden_weight_high = 50, 0.25
        self.layer1 = Spike_linear(input_size, hidden_size, *kernel_and_horizon, hidden_threshold, beta, device, dtype)
        self.layer1.reset_parameters(hidden_weight_high, 0)

        output_threshold, output_weight_high = 10, 0.5
        self.layer2 = Spike_linear(hidden_size, classes, *kernel_and_horizon, output_threshold, beta, device, dtype)
        self.layer2.reset_parameters(output_weight_high, 0)

    def forward(self, input: Tensor) -> Tensor:
        """
        :param input: (batch_size, input_size) input firing times
        :return: (batch_size, classes) output firing times
        """
        hidden_times = self.layer1(input)
        return self.layer2(hidden_times)

In [None]:
train_dataset = datasets.MNIST(root='./data/', train=True, transform = transforms.ToTensor(), download = True)
test_dataset = datasets.MNIST(root='./data/', train=False, transform = transforms.ToTensor(), download = True)
training_batch, testing_batch = 1, 1000
train_loader = DataLoader(dataset = train_dataset, batch_size = training_batch, shuffle = True, drop_last = True)
# Evaluation needs no shuffling; a fixed order keeps test passes comparable across epochs.
# drop_last stays on because evaluation reshapes each batch with view(testing_batch, -1),
# which requires every batch to be full.
test_loader = DataLoader(dataset = test_dataset, batch_size = testing_batch, shuffle = False, drop_last = True)

In [None]:
# hyperparameters
import os

# MNIST images flatten to 28*28 = 784 pixels and have 10 classes: the first layer's
# matmul and the target one-hot indexing fail if these sizes do not match the data.
input_size, hidden_size, num_class = 28 * 28, 6, 10
I_max, tmax = 1, 256  # constant for computing spike time from pixel value
nepoch = 1000  # n of epochs
gamma = 10  # the constant for computing target
Dropout = [0, 0]  # dropout is not implemented (unused)
lr, lamda = 1, 0  # learning rate, weight decay (unused: the optimizer sets per-layer rates)
tau1, tau2 = 40, 40
seed = 0  # fixes weight initialization and data shuffling
torch.manual_seed(seed)
# device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = S4NN(input_size = input_size, hidden_size = hidden_size, classes = num_class, tau1 = tau1, tau2 = tau2, tmax = tmax, beta = 1, device = device, dtype = None)

tensorboard_path = './log/figures'
writer = SummaryWriter(tensorboard_path)

model_path = './log/models'
# torch.save does not create missing directories.
os.makedirs(model_path, exist_ok=True)


In [None]:
np.set_printoptions(threshold = np.inf)
# Truncate the per-step debug log; the training loop appends to it.
open("test.txt", 'w').close()

model.train()
# One parameter group per layer so the layers' learning rates can be tuned separately.
# NOTE: these rates are hard-coded and do not read the `lr` hyperparameter.
optimizer = torch.optim.SGD([
            {'params': model.layer1.parameters(), 'lr': 0.2},
            {'params': model.layer2.parameters(), 'lr': 0.2}])

# Guards the one-time step taken at the end of an epoch once testing_loss < 0.35.
has_set = False

import os


def _to_spike_times(images, batch_size):
    """Encode pixel intensities as integer first-spike times.

    Images are flattened to (batch_size, -1); intensity I_max maps to time 0 and
    intensity 0 maps to tmax (brighter pixels fire earlier), truncated by int().
    """
    flat = images.view(batch_size, -1)
    return (((I_max - flat) / I_max) * tmax).int().float()


def _dump_parameters(log_file):
    """Write every model parameter, its requires_grad flag and its current grad to log_file."""
    for name, parms in model.named_parameters():
        print('-->name:', name, file = log_file)
        print('-->para:', parms, file = log_file)
        print('-->grad_requirs:',parms.requires_grad, file = log_file)
        print('-->grad_value:',parms.grad, file = log_file)
        print("=====================================", file = log_file)


def _accuracy(loader, batch_size, max_batches=None):
    """Fraction of samples whose earliest-firing output neuron equals the label.

    Stops after ``max_batches`` batches when given. Intended to run under torch.no_grad().
    """
    all_correct, all_samples = 0, 0
    for count, (datas, labels) in enumerate(loader, start=1):
        datas, labels = datas.to(device), labels.to(device)
        output = model(_to_spike_times(datas, batch_size))
        predict = torch.argmin(output, dim=1)
        all_correct += torch.sum(predict == labels)
        all_samples += batch_size
        if count == max_batches:
            break
    return all_correct / all_samples


train_eval_batches = 10000  # training-set batches scored per epoch, to bound evaluation cost

for epoch in range(nepoch):
    print('begin epoch {}'.format(epoch + 1))
    start = time.time()
    
    right, al = 0, 0
    # One handle per epoch instead of re-opening the debug log several times per step.
    with open("test.txt", "a") as file1:
        for datas, labels in train_loader:
            datas, labels = datas.to(device), labels.to(device)
            datas = _to_spike_times(datas, training_batch).requires_grad_()
            print("datas = ", datas, file = file1)

            output = model(datas)
            # Targets are constants for the loss; no gradient flows through them.
            target = target_firing_time_output(output = output, tmax = tmax, label = labels, gamma = gamma, device = device, dtype = None).detach()
            print('output = ', output, file = file1)
            print('target = ', target, file = file1)

            loss = linear_loss(output, target, tmax)
            optimizer.zero_grad()
            loss.backward(torch.ones(loss.size()).to(device))
            _dump_parameters(file1)
            print("after step:", file = file1)
            optimizer.step()
            _dump_parameters(file1)

            # The predicted class is the earliest-firing output neuron.
            right = right + (torch.min(output, axis=1, keepdim=False)[1]==labels).sum()
            al = al + training_batch
        
    end = time.time()
    print("training accuracy: {}".format(right/al))
    print('Epoch {} finished in {} seconds ({} datas training)'.format(epoch+1, end-start, al))

    with torch.no_grad():
        test_accuracy = _accuracy(test_loader, testing_batch)
        testing_loss = 1 - test_accuracy
        print('After epoch {}, the testing accuracy is {}%'.format(epoch + 1, 100 * test_accuracy))

        train_accuracy = _accuracy(train_loader, training_batch, max_batches=train_eval_batches)
        training_loss = 1 - train_accuracy
        print("After epoch {}, the training accuracy is {}%".format(epoch + 1, 100 * train_accuracy))

        # torch.save does not create missing directories.
        os.makedirs(model_path, exist_ok=True)
        torch.save(model.state_dict(), model_path+'/model_after_epoch{}.pkl'.format(epoch+1))

        if (testing_loss < 0.35) and not has_set:
            has_set = True
            # Decay the learning rate of every parameter group once. The previous code
            # only built a tuple `(optimizer, 0.01)` and never changed the optimizer.
            for param_group in optimizer.param_groups:
                param_group['lr'] = 0.01