In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
import argparse, os, sys,time
from typing import Callable, Union

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Command-line style configuration; parsed with an empty argv so the notebook uses the defaults.
parser = argparse.ArgumentParser(description='STDP learning')
parser.add_argument('-T', default=100, type=int, help='simulating time-steps')
# Fall back to the CPU so the notebook still runs on machines without CUDA.
parser.add_argument('-device', default='cuda:0' if torch.cuda.is_available() else 'cpu', help='device')
parser.add_argument('-b', default=200, type=int, help='batch size')
parser.add_argument('-epochs', default=5, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-j', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
# NOTE(review): absolute, machine-specific default path; override -data-dir on other machines.
parser.add_argument('-data-dir', default='/data/tanghao/datasets/', type=str, help='root dir of dataset')
parser.add_argument('-opt', type=str, choices=['sgd', 'adam'], default='adam', help='use which optimizer. SGD or Adam')
parser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD')
parser.add_argument('-lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('-tau', default=2.0, type=float, help='parameter tau of LIF neuron')

args = parser.parse_args(args=[])
print(args)

# '0' keeps CUDA kernel launches asynchronous; set to '1' when debugging CUDA errors.
os.environ['CUDA_LAUNCH_BLOCKING'] = '0'

Namespace(T=100, b=200, data_dir='/data/tanghao/datasets/', device='cuda:0', epochs=5, j=4, lr=0.001, momentum=0.9, opt='adam', tau=2.0)


In [3]:
def _mnist_split(train):
    """Return the MNIST train (train=True) or test split, converted with ToTensor() and downloaded if missing."""
    return torchvision.datasets.MNIST(
        root=args.data_dir,
        train=train,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )


train_dataset = _mnist_split(train=True)
test_dataset = _mnist_split(train=False)

# Settings shared by both loaders; only shuffling and drop_last differ.
_common_loader_kwargs = dict(batch_size=args.b, num_workers=args.j, pin_memory=True)
# Training: reshuffle every epoch and drop a ragged final batch.
train_data_loader = data.DataLoader(train_dataset, shuffle=True, drop_last=True, **_common_loader_kwargs)
# Evaluation: fixed order, keep every sample.
test_data_loader = data.DataLoader(test_dataset, shuffle=False, drop_last=False, **_common_loader_kwargs)

In [4]:
def Possion_Encoder(x):
    """Rate-encode ``x`` into a single Bernoulli spike sample per element.

    Each element fires (1) when a uniform draw is <= its value, so values are
    presumably intensities in [0, 1]. The result has the dtype/device of ``x``.
    The (misspelled) name is kept because callers refer to it.
    """
    uniform_draw = torch.rand_like(x)
    fired = uniform_draw <= x
    return fired.to(x)

In [5]:
def f_pre(x, w_min, alpha=0.):
    """Weight-dependence factor ``(x - w_min) ** alpha`` (distance to the lower bound).

    With the default ``alpha=0`` it is 1 everywhere, i.e. weight-independent.
    """
    gap_to_floor = x - w_min
    return gap_to_floor ** alpha


def f_post(x, w_max, alpha=0.):
    """Weight-dependence factor ``(w_max - x) ** alpha`` (distance to the upper bound).

    With the default ``alpha=0`` it is 1 everywhere, i.e. weight-independent.
    """
    gap_to_ceiling = w_max - x
    return gap_to_ceiling ** alpha

此处定义的trace与其对应脉冲的维度一致：trace_pre与输入脉冲同形（[batch_size, N_in]），trace_post与输出脉冲同形（[batch_size, N_out]）。

In [6]:
def trace_update(SpikingNeuron, in_spike: torch.Tensor, out_spike: torch.Tensor, tau_pre: float, tau_post: float):
    """Advance the pre- and post-synaptic STDP traces stored on ``SpikingNeuron`` by one step.

    Each trace decays by ``1/tau`` and then adds the current spikes:
    ``trace <- trace - trace / tau + spike``. A trace that is still ``None`` starts
    from zeros shaped like its spike tensor, so ``trace_pre`` follows ``in_spike``
    and ``trace_post`` follows ``out_spike``.
    """
    previous_pre = SpikingNeuron.trace_pre
    if previous_pre is None:
        previous_pre = torch.zeros_like(in_spike)

    previous_post = SpikingNeuron.trace_post
    if previous_post is None:
        previous_post = torch.zeros_like(out_spike)

    SpikingNeuron.trace_pre = previous_pre - previous_pre / tau_pre + in_spike
    SpikingNeuron.trace_post = previous_post - previous_post / tau_post + out_spike


In [7]:
def stdp_grad(
    weight, in_spike: torch.Tensor, out_spike: torch.Tensor,
    trace_pre: Union[float, torch.Tensor, None],
    trace_post: Union[float, torch.Tensor, None],
    w_min: float, w_max: float,
    f_pre: Callable = lambda x, bound: x, f_post: Callable = lambda x, bound: x
):
    """Trace-based STDP weight change of a linear layer for one time step.

    Args:
        weight: current weights, [N_out, N_in].
        in_spike: presynaptic spikes, [batch_size, N_in].
        out_spike: postsynaptic spikes, [batch_size, N_out].
        trace_pre: presynaptic trace shaped like ``in_spike``; must be a tensor
            (``.unsqueeze`` is called on it).
        trace_post: postsynaptic trace shaped like ``out_spike``; must be a tensor.
        w_min, w_max: weight bounds, passed as the second argument of ``f_pre`` / ``f_post``.
        f_pre, f_post: weight-dependence callables invoked as ``f(weight, bound)``.
            The defaults accept and ignore the bound and return ``weight`` unchanged
            (the old one-argument defaults raised TypeError when used).

    Returns:
        delta_w, shaped like ``weight``. It is the sum of
        depression ``-f_pre(w) * sum_b trace_post[b, o] * in_spike[b, i]`` and
        potentiation ``f_post(w) * sum_b out_spike[b, o] * trace_pre[b, i]``.
    """
    # unsqueeze builds [batch_size, N_out, N_in] outer products; sum(0) reduces over the batch.
    delta_w_pre = -f_pre(weight, w_min) * (trace_post.unsqueeze(2) * in_spike.unsqueeze(1)).sum(0)
    delta_w_post = f_post(weight, w_max) * (trace_pre.unsqueeze(1) * out_spike.unsqueeze(2)).sum(0)
    return delta_w_pre + delta_w_post

https://zhuanlan.zhihu.com/p/359524837
此处定义：膜电位大于等于阈值时发射脉冲。

https://blog.csdn.net/winycg/article/details/100695373
对于PyTorch来说，为获取中间层的输出数值，需要使用hook函数；hook函数包括tensor的hook和nn.Module的hook。而此处为获取脉冲神经元的输入，需要得到上一层的输出，因此需要注册Module对象的hook函数。

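有register_forward_hook(hook)和register_backward_hook(hook)两种方法，分别对应前向传播和反向传播的hook函数（新版PyTorch中register_backward_hook已弃用，应改用register_full_backward_hook）。注意：下面的实现最终没有使用hook，而是直接把linear层和神经元对象作为参数传入ActFun。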

In [8]:
class ActFun(torch.autograd.Function):
    """One LIF time step driven by a linear layer, with STDP traces updated as a side effect.

    ``forward`` advances the neuron state stored on ``SpikingNeuron`` (``v``, ``s``,
    ``trace_pre``, ``trace_post``) by one step and returns 0/1 spikes.
    ``backward`` produces a gradient for ``input`` only; the other eight forward
    arguments (module, neuron object, hyper-parameters, callables) get ``None``.
    """

    @staticmethod
    def forward(ctx, input, fc, SpikingNeuron, tau_pre, tau_post, w_min, w_max, f_pre, f_post):
        """Advance the neuron by one step.

        Args:
            input: presynaptic spikes fed to ``fc`` (assumed [batch_size, N_in] — TODO confirm).
            fc: the synapse; ``fc(input)`` is the membrane input current.
            SpikingNeuron: object holding ``v``, ``s``, ``v_threshold``, ``v_reset``,
                ``soft_reset``, ``trace_pre`` and ``trace_post``; mutated in place.
            tau_pre, tau_post: trace time constants; ``tau_post`` also sets the membrane decay.
            w_min, w_max, f_pre, f_post: stored for ``stdp_grad`` in ``backward``.

        Returns:
            Spikes (1.0 where ``v >= v_threshold``, else 0.0), shaped like ``fc(input)``.
        """
        current = fc(input)
        if SpikingNeuron.v is None:
            SpikingNeuron.v = torch.zeros_like(current)
            SpikingNeuron.s = torch.zeros_like(current)

        # Leaky integration toward the input current.
        # TODO (from original): tau_post is reused as the membrane time constant.
        SpikingNeuron.v = SpikingNeuron.v + (-SpikingNeuron.v + current) / tau_post
        spike = SpikingNeuron.v.ge(SpikingNeuron.v_threshold).to(SpikingNeuron.v)
        if SpikingNeuron.soft_reset:
            # Soft reset: subtract the threshold from neurons that fired.
            SpikingNeuron.v = SpikingNeuron.v - SpikingNeuron.v_threshold * spike
        else:
            # Hard reset: neurons that fired jump to v_reset.
            SpikingNeuron.v = SpikingNeuron.v * (1 - spike) + SpikingNeuron.v_reset * spike

        # trace_pre follows the linear layer's *input* spikes, as in spikingjelly's STDPLearner.
        trace_update(SpikingNeuron, input, spike, tau_pre, tau_post)
        ctx.save_for_backward(fc.weight.data, input, spike, SpikingNeuron.trace_pre, SpikingNeuron.trace_post)

        ctx.w_min = w_min
        ctx.w_max = w_max
        ctx.f_pre = f_pre
        ctx.f_post = f_post

        return spike

    @staticmethod
    def backward(ctx, grad_output):
        """Map ``grad_output`` through the STDP weight change onto ``input``.

        Returns one entry per ``forward`` input (nine in total), as
        ``torch.autograd.Function`` requires: the gradient for ``input`` followed
        by ``None`` for the eight non-differentiable arguments.
        """
        weight, input, output, trace_pre, trace_post = ctx.saved_tensors
        grad_input = None
        if ctx.needs_input_grad[0]:
            # grad_output: [batch, N_out]; delta_w: [N_out, N_in] -> grad_input: [batch, N_in]
            delta_w = stdp_grad(weight, input, output, trace_pre, trace_post,
                                ctx.w_min, ctx.w_max, ctx.f_pre, ctx.f_post)
            grad_input = torch.mm(grad_output, delta_w)
        return (grad_input,) + (None,) * 8


对于STDP，当使用trace的时候，需要记录前后神经元的信息，主要需要tau的值，因此重构Linear类，加入tau属性。

同时需要考察trace_pre到底应取linear层的输入（即输入脉冲），还是linear层的输出（即LIF层的输入）。此处与下面spikingjelly的实现一致，取linear层的输入，并把trace作为属性保存在神经元对象上，方便trace_update函数的调用。

在spikingjelly中，此部分代码为：
```python
class STDPLearner(base.MemoryModule):
    def __init__(
        self, step_mode: str,
        synapse: Union[nn.Conv2d, nn.Linear], sn: neuron.BaseNode,
        ...
        self.in_spike_monitor = monitor.InputMonitor(synapse)
        self.out_spike_monitor = monitor.OutputMonitor(sn)

def stdp_linear_single_step(
    fc: nn.Linear, in_spike: torch.Tensor, out_spike: torch.Tensor,
    trace_pre: Union[float, torch.Tensor, None],
    trace_post: Union[float, torch.Tensor, None],
    ...
    trace_pre = trace_pre - trace_pre / tau_pre + in_spike      # shape = [batch_size, N_in]
    trace_post = trace_post - trace_post / tau_post + out_spike # shape = [batch_size, N_out]
```
因此，spikingjelly中的trace_pre由linear层的输入（in_spike）计算。

当使用`loss.backward()`时，需要计算全连接层权重的梯度，即计算
$$
\frac{\partial L}{\partial w}=\frac{\partial L}{\partial z}\frac{\partial z}{\partial y}\frac{\partial y}{\partial w}
$$
其中，$z$为LIF模型输出，$y$为全连接层输出。
$\frac{\partial L}{\partial z}$为自定义反向传播中传入的grad_output，自定义的反向传播需要给出的是$\frac{\partial z}{\partial y}$。

对于nn.Linear来说，$y=X\cdot W^T$，其中，$X$为$m\times n$的矩阵，$W$为$p\times n$的矩阵，$y$为$m\times p$的矩阵。

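则$\frac{\partial y}{\partial W}$是一个四阶张量（形状为$m\times p\times p\times n$），其元素为$\frac{\partial y_{ij}}{\partial W_{kl}}=\delta_{jk}X_{il}$，其中$i,j$为$y$的行、列下标，$k,l$为$W$的行、列下标。代入链式法则可得$\frac{\partial L}{\partial W}=\left(\frac{\partial L}{\partial y}\right)^{T}X$，它是$p\times n$的矩阵，与$W$同形。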

In [9]:
class Linear_Spiking(nn.Module):
    """Bias-free linear synapse followed by a LIF neuron whose step is computed by ``ActFun``.

    The neuron state (``v``, ``s``) and the STDP traces (``trace_pre``,
    ``trace_post``) are plain attributes that start as ``None``. ``ActFun``
    lazily creates and updates them during ``forward``; ``reset`` clears them again.
    """

    def __init__(self, in_feature, out_feature, f_pre, f_post, tau_pre=2.0, tau_post=2.0, v_threshold=1.0, v_reset=0.0, soft_reset=True, w_min=-1.0, w_max=1.0):
        super().__init__()
        # trace time constants (ActFun also uses tau_post for the membrane)
        self.tau_pre, self.tau_post = tau_pre, tau_post
        # per-run neuron state, created lazily by ActFun
        self.v, self.s = None, None
        # firing / reset behaviour
        self.v_threshold, self.v_reset, self.soft_reset = v_threshold, v_reset, soft_reset
        # STDP traces, created lazily by trace_update
        self.trace_pre, self.trace_post = None, None
        # weight bounds and the weight-dependence callables handed to stdp_grad
        self.w_min, self.w_max = w_min, w_max
        self.f_pre, self.f_post = f_pre, f_post
        self.linear = nn.Linear(in_feature, out_feature, bias=False)

    def reset(self):
        """Forget the membrane state and STDP traces (call between independent samples)."""
        self.v = self.s = None
        self.trace_pre = self.trace_post = None

    def forward(self, x):
        """Run one time step on input spikes ``x`` and return the output spikes."""
        # keep the weights inside [w_min, w_max] before they are used
        self.linear.weight.data.clamp_(self.w_min, self.w_max)
        return ActFun.apply(
            x, self.linear, self, self.tau_pre, self.tau_post,
            self.w_min, self.w_max, self.f_pre, self.f_post,
        )


In [10]:
# Flatten the 28x28 image and feed it to a single 784 -> 10 spiking layer.
# NOTE(review): the -tau flag is parsed but not passed here (tau_pre/tau_post keep their defaults).
model = nn.Sequential(
    nn.Flatten(),
    Linear_Spiking(784, 10, f_pre, f_post)
).to(args.device)

# Build the optimizer after moving the model to its device, honouring the -opt / -momentum flags.
if args.opt == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
else:
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

In [11]:
acc_max = 0
epoch_max = 0
# model[0] is nn.Flatten; model[1] is the Linear_Spiking layer holding the neuron
# state, the STDP traces and the synaptic weight trained here.
spiking_layer = model[1]
synapse_weight = spiking_layer.linear.weight
for epoch in range(args.epochs):
    start_time = time.time()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    model.train()
    for imag, label in train_data_loader:
        imag = imag.to(args.device)
        label = label.to(args.device)
        label_onehot = F.one_hot(label, 10).float()

        # v and the traces are only zero-initialised while they are None; clear them so
        # state from the previous (unrelated) batch does not leak into this one.
        spiking_layer.reset()
        optimizer.zero_grad()
        # Accumulates -delta_w over the T steps: the optimizer descends along .grad,
        # so storing -delta_w moves the weight along delta_w.
        synapse_weight.grad = torch.zeros_like(synapse_weight)

        out_fr = 0.
        # The weight is not an input of ActFun (the module is passed instead), so
        # loss.backward() could never reach it; STDP writes the update into .grad directly.
        with torch.no_grad():
            for t in range(args.T):
                in_spike = Possion_Encoder(imag)
                out_spike = model(in_spike)
                out_fr += out_spike
                # The forward pass has just refreshed trace_pre / trace_post for this step.
                # nn.Flatten (start_dim=1) flattened the input, so flatten in_spike the same way.
                delta_w = stdp_grad(
                    synapse_weight.data, in_spike.flatten(1), out_spike,
                    spiking_layer.trace_pre, spiking_layer.trace_post,
                    spiking_layer.w_min, spiking_layer.w_max,
                    spiking_layer.f_pre, spiking_layer.f_post,
                )
                # NOTE: delta_w is zero while no output neuron fires (trace_post stays 0
                # and out_spike is 0), so a silent layer receives no STDP update.
                synapse_weight.grad -= delta_w
        out_fr = out_fr / args.T
        optimizer.step()

        # Training is unsupervised STDP; the MSE loss is reported for monitoring only.
        loss = F.mse_loss(out_fr, label_onehot)

        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        train_acc += (out_fr.argmax(1) == label).float().sum().item()

    print('Epoch: {}, Train Loss: {:.4f}, Train Acc: {:.4f}, Time: {:.4f}'.format(
        epoch, train_loss / train_samples, train_acc / train_samples, time.time() - start_time))


Epoch: 0, Train Loss: 0.1000, Train Acc: 0.0987, Time: 10.8497
Epoch: 1, Train Loss: 0.1000, Train Acc: 0.0987, Time: 14.3037
Epoch: 2, Train Loss: 0.1000, Train Acc: 0.0987, Time: 11.4937
Epoch: 3, Train Loss: 0.1000, Train Acc: 0.0987, Time: 11.0757
Epoch: 4, Train Loss: 0.1000, Train Acc: 0.0987, Time: 10.7054
