In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
import argparse, os, sys,time
from typing import Callable, Union

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
parser = argparse.ArgumentParser(description='STDP learning')
parser.add_argument('-T', default=10, type=int, help='simulating time-steps')
parser.add_argument('-device', default='cuda:0', help='device')
parser.add_argument('-b', default=200, type=int, help='batch size')
parser.add_argument('-epochs', default=5, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-j', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('-data-dir', default='/data/tanghao/datasets/', type=str, help='root dir of dataset')
parser.add_argument('-opt', type=str, choices=['sgd', 'adam'], default='adam', help='use which optimizer. SGD or Adam')
parser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD')
parser.add_argument('-lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('-tau', default=2.0, type=float, help='parameter tau of LIF neuron')

args = parser.parse_args(args=[])
print(args)

os.environ['CUDA_LAUNCH_BLOCKING'] = '0'

Namespace(T=10, b=200, data_dir='/data/tanghao/datasets/', device='cuda:0', epochs=5, j=4, lr=0.001, momentum=0.9, opt='adam', tau=2.0)


In [3]:
train_dataset = torchvision.datasets.MNIST(
    root=args.data_dir,
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True
)
test_dataset = torchvision.datasets.MNIST(
    root=args.data_dir,
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True
)

train_data_loader = data.DataLoader(
    dataset=train_dataset,
    batch_size=args.b,
    shuffle=True,
    drop_last=True,
    num_workers=args.j,
    pin_memory=True
)
test_data_loader = data.DataLoader(
    dataset=test_dataset,
    batch_size=args.b,
    shuffle=False,
    drop_last=False,
    num_workers=args.j,
    pin_memory=True
)

In [4]:
print(test_dataset[0][0].shape)
print(test_dataset[0][0].device)

torch.Size([1, 28, 28])
cpu


In [5]:
def Possion_Encoder(x):
    return torch.rand_like(x).le(x).to(x)

In [6]:
print(Possion_Encoder(test_dataset[0][0]))

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0.,

In [7]:
def f_pre(x, w_min, alpha=0.):
    return (x - w_min) ** alpha

def f_post(x, w_max, alpha=0.):
    return (w_max - x) ** alpha

In [8]:
def stdp_linear_single_step(
    fc: nn.Linear, in_spike: torch.Tensor, out_spike: torch.Tensor,
    trace_pre: Union[float, torch.Tensor, None],
    trace_post: Union[float, torch.Tensor, None],
    tau_pre: float, tau_post: float, w_min: float, w_max: float,
    f_pre: Callable = lambda x: x, f_post: Callable = lambda x: x
):
    if trace_pre is None:
        trace_pre = torch.zeros_like(in_spike)

    if trace_post is None:
        trace_post = torch.zeros_like(out_spike)

    weight = fc.weight.data
    trace_pre = trace_pre - trace_pre / tau_pre + in_spike      # shape = [batch_size, N_in]
    trace_post = trace_post - trace_post / tau_post + out_spike # shape = [batch_size, N_out]

    # [batch_size, N_out, N_in] -> [N_out, N_in]
    # 此处对照更新公式，使用unsqueeze添加更新公式中所缺失的一维
    # print(trace_pre.shape,trace_post.shape,weight.shape,trace_pre.unsqueeze(1).shape)
    delta_w_pre = -f_pre(weight, w_min) * (trace_post.unsqueeze(2) * in_spike.unsqueeze(1)).sum(0)
    delta_w_post = f_post(weight, w_max) * (trace_pre.unsqueeze(1) * out_spike.unsqueeze(2)).sum(0)
    return trace_pre, trace_post, delta_w_pre + delta_w_post

此处定义每个神经元的trace都是同输入维度一致。

In [9]:
def trace_update(fc,SpikingNeuron,in_linear ,in_spike: torch.Tensor, out_spike: torch.Tensor, tau_pre: float, tau_post: float):
    if SpikingNeuron.trace_pre is None:
        # print('input:',in_spike)
        SpikingNeuron.trace_pre = torch.zeros_like(in_spike)
        fc.trace_pre = torch.zeros_like(in_spike)

    if SpikingNeuron.trace_post is None:
        SpikingNeuron.trace_post = torch.zeros_like(out_spike)

    if fc.trace is None:
        fc.trace = torch.zeros_like(in_linear)

    # SpikingNeuron.trace_pre = SpikingNeuron.trace_pre - SpikingNeuron.trace_pre / \
    #     tau_pre + in_spike      # shape = [batch_size, N_in]
    # SpikingNeuron.trace_post = SpikingNeuron.trace_post - \
    #     SpikingNeuron.trace_post / tau_post + \
    #     out_spike  # shape = [batch_size, N_out]

    fc.trace = fc.trace - fc.trace / fc.tau + in_linear      # shape = [batch_size, N_in]
    SpikingNeuron.trace_post = SpikingNeuron.trace_post - \
        SpikingNeuron.trace_post / tau_post + \
        out_spike  # shape = [batch_size, N_out]


In [10]:
def stdp_grad(
    weight, in_spike: torch.Tensor, out_spike: torch.Tensor,
    trace_pre: Union[float, torch.Tensor, None],
    trace_post: Union[float, torch.Tensor, None],
    w_min: float, w_max: float,
    f_pre: Callable = lambda x: x, f_post: Callable = lambda x: x
):#: nn.Linear
    

    # [batch_size, N_out, N_in] -> [N_out, N_in]
    # 此处对照更新公式，使用unsqueeze添加更新公式中所缺失的一维
    # print(trace_pre.shape,trace_post.shape,weight.shape,trace_pre.unsqueeze(1).shape)
    # [200, 784] [200, 10] [10, 784] [200, 1, 784]
    # torch.Size([200, 784]) torch.Size([200, 10]) torch.Size([10, 784]) torch.Size([200, 1, 784])
    # trace_pre shape = [batch_size, N_in] = [200, 784]
    # trace_post shape = [batch_size, N_out] = [200, 10]
    # in_spike shape = [batch_size, N_in] = [200, 784]
    # out_spike shape = [batch_size, N_out] = [200, 10]
    delta_w_pre = -f_pre(weight, w_min) * (trace_post.unsqueeze(2) * in_spike.unsqueeze(1)).sum(0)
    delta_w_post = f_post(weight, w_max) * (trace_pre.unsqueeze(1) * out_spike.unsqueeze(2)).sum(0)
    return delta_w_pre + delta_w_post

https://zhuanlan.zhihu.com/p/359524837
此处定义大于等于阈值则发射脉冲。

https://blog.csdn.net/winycg/article/details/100695373
对于pytroch来说，为获取中间层的输出数值，需要使用hook函数，hook函数包括tensor的hook和nn.Module的hook。而此处为获取脉冲神经元的输入，需要得到上一层的输出，由此需要在此注册Modelu对象的hook函数。

有register_forward_hook(hook)和register_backward_hook(hook)两种方法，分别对应前向传播和反向传播的hook函数。

In [11]:
def hook(module, input, output):
    return output.detach()

In [12]:
class STDP(torch.autograd.Function):

    @staticmethod
    def forward(ctx,fc,SpikingNeuron,tau_pre,tau_post, w_min, w_max, f_pre, f_post, input):
        if SpikingNeuron.v is None:
            SpikingNeuron.v = torch.zeros_like(input)
            SpikingNeuron.s = torch.zeros_like(input)
        SpikingNeuron.v = SpikingNeuron.v + \
            (1.0 - SpikingNeuron.v) / SpikingNeuron.tau * input
        spike = (SpikingNeuron.v >= SpikingNeuron.v_threshold).to(input)
        if SpikingNeuron.soft_reset:
            SpikingNeuron.v = SpikingNeuron.v-SpikingNeuron.v_threshold*spike
        else:
            SpikingNeuron.v = SpikingNeuron.v * \
                (1-spike) + SpikingNeuron.v_reset * spike

        spike = SpikingNeuron.v.ge(
            SpikingNeuron.v_threshold).to(SpikingNeuron.v)
        
        trace_update(SpikingNeuron, input, spike, tau_pre, tau_post)
        ctx.save_for_backward(fc.weight.data,input, spike,SpikingNeuron.trace_pre,SpikingNeuron.trace_post)

        ctx.w_min = w_min
        ctx.w_max = w_max
        ctx.f_pre = f_pre
        ctx.f_post = f_post


        return spike

    @staticmethod
    def backward(ctx, grad_output):
        weight,input, output,trace_pre,trace_post = ctx.saved_tensors
        # grad_input = grad_output.clone()
        return grad_output*stdp_grad(weight, input, output,trace_pre, trace_post, ctx.w_min, ctx.w_max, ctx.f_pre, ctx.f_post)


In [13]:
class ActFun(torch.autograd.Function):

    @staticmethod
    def forward(ctx,in_linear ,input,fc,SpikingNeuron,tau_pre,tau_post, w_min, w_max, f_pre, f_post):
        if SpikingNeuron.v is None:
            SpikingNeuron.v = torch.zeros_like(input)
            SpikingNeuron.s = torch.zeros_like(input)
        SpikingNeuron.v = SpikingNeuron.v + (- SpikingNeuron.v+input) / SpikingNeuron.tau
        # spike = (SpikingNeuron.v >= SpikingNeuron.v_threshold).to(input)
        spike = SpikingNeuron.v.ge(
            SpikingNeuron.v_threshold).to(SpikingNeuron.v)
        if SpikingNeuron.soft_reset:
            SpikingNeuron.v = SpikingNeuron.v-SpikingNeuron.v_threshold*spike
        else:
            SpikingNeuron.v = SpikingNeuron.v * \
                (1-spike) + SpikingNeuron.v_reset * spike
        
        trace_update(fc,SpikingNeuron,in_linear ,input, spike, tau_pre, tau_post)# todo: 需要考察一下trace_pre到底是linuear层的输出还是lif层的输入 fc.trace
        ctx.save_for_backward(fc.weight.data,in_linear, spike,fc.trace,SpikingNeuron.trace_post)#SpikingNeuron.trace_pre

        ctx.w_min = w_min
        ctx.w_max = w_max
        ctx.f_pre = f_pre
        ctx.f_post = f_post


        return spike

    @staticmethod
    def backward(ctx, grad_output):
        weight,input, output,trace_pre,trace_post = ctx.saved_tensors
        # grad_input = grad_output.clone()
        # grad_output.shape=[200, 10], stdp_grad.shape=[10, 784]
        print(grad_output.shape, stdp_grad(weight, input, output,trace_pre, trace_post, ctx.w_min, ctx.w_max, ctx.f_pre, ctx.f_post).shape)
        return grad_output*stdp_grad(weight, input, output,trace_pre, trace_post, ctx.w_min, ctx.w_max, ctx.f_pre, ctx.f_post)


In [14]:
# output=STDP.apply(fc,SpikingNeuron,tau_pre,tau_post, w_min, w_max, f_pre, f_post, input)

In [15]:
# def mem_update(ops,x,mem,tau,spike):
#     mem=mem*tau*(1-spike)+ops(x)
#     spike = act_fun(mem)
#     return mem, spike

对于STDP，当使用trace的时候，需要记录前后神经元的信息，主要需要tau的值，因此重构Linear类，加入tau属性。

同时需要考察一下trace_pre到底是linuear层的输出还是lif层的输入，此处作为linear层的输出，加入该属性，方便trace_update函数的调用。

在spikingjelly中，此部分代码为：
```python
class STDPLearner(base.MemoryModule):
    def __init__(
        self, step_mode: str,
        synapse: Union[nn.Conv2d, nn.Linear], sn: neuron.BaseNode,
        ...
        self.in_spike_monitor = monitor.InputMonitor(synapse)
        self.out_spike_monitor = monitor.OutputMonitor(sn)

def stdp_linear_single_step(
    fc: nn.Linear, in_spike: torch.Tensor, out_spike: torch.Tensor,
    trace_pre: Union[float, torch.Tensor, None],
    trace_post: Union[float, torch.Tensor, None],
    ...
    trace_pre = trace_pre - trace_pre / tau_pre + in_spike      # shape = [batch_size, N_in]
    trace_post = trace_post - trace_post / tau_post + out_spike # shape = [batch_size, N_out]
```
因此他认为，trace_pre是linear层的输入。

In [16]:
class Linear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True,tau=2.):
        super(Linear, self).__init__(in_features, out_features, bias)
        self.tau = tau
        self.trace = None

    def reset(self):
        self.trace = None

In [17]:
# class LIFNeuron(nn.Module):
#     def __init__(self ,weight,tau=2.0,tau_pre=2.0, v_threshold=1.0, v_reset=0.0, soft_reset=True):
#         super(LIFNeuron, self).__init__()
#         self.tau = tau
#         self.tau_pre = tau_pre
#         self.v = None
#         self.s = None
#         self.v_threshold = v_threshold
#         self.v_reset = v_reset
#         self.soft_reset = soft_reset
#         self.trace_pre = None
#         self.trace_post = None
        

#     def reset(self):
#         self.v = None
#         self.s = None
#         self.trace_pre = None
#         self.trace_post = None

#     def forward(self, x):
#         spike=STDP.apply(x,self)

#         return spike


In [18]:
# class LIFNeuron(nn.Module):
#     def __init__(self, tau=2.0, v_threshold=1.0, v_reset=0.0, soft_reset=True):
#         super(LIFNeuron, self).__init__()
#         self.tau = tau
#         self.v = None
#         self.s = None
#         self.v_threshold = v_threshold
#         self.v_reset = v_reset
#         self.soft_reset = soft_reset
#         self.trace_pre = None
#         self.trace_post = None

#     def reset(self):
#         self.v = None
#         self.s = None

#     def forward(self, x):
#         if self.v is None:
#             self.v = torch.zeros_like(x)
#             self.s = torch.zeros_like(x)
#         self.v = self.v + (1.0 - self.v) / self.tau * x
#         spike = (self.v >= self.v_threshold).to(x)
#         if self.soft_reset:
#             self.v = self.v-self.v_threshold*spike
#         else:
#             self.v = self.v * (1-spike) + self.v_reset * spike

#         return (self.v >= self.v_threshold).to(self.v)


当使用`loss.backward()`时，需要计算全连接层权重的梯度，即计算
$$
\frac{\partial L}{\partial w}=\frac{\partial L}{\partial z}\frac{\partial z}{\partial y}\frac{\partial y}{\partial w}
$$
其中，$z$为LIF模型输出，$y$为全连接层输出。
$\frac{\partial L}{\partial z}$为自定义反向传播的grad_output，自定义的反向传播值为$\frac{\partial y}{\partial z}$.

对于nn.Linear来说，$y=X\cdot W^T$，其中，$X$为$m\times n$的矩阵，$W$为$p\times n$的矩阵，$y$为$m\times p$的矩阵。

则$\frac{\partial y}{\partial w}$为$m\times p$的矩阵，其每个元素为$\frac{\partial y_{ij}}{\partial w_{kl}}$，其中，$i$为行，$j$为列，$k$为行，$l$为列。

In [19]:
class Linear_Spiking(nn.Module):
    def __init__(self, in_feature, out_feature, f_pre, f_post, tau=2.0, tau_pre=2.0, v_threshold=1.0, v_reset=0.0, soft_reset=True, w_min=-1.0, w_max=1.0):
        super(Linear_Spiking, self).__init__()
        self.tau = tau
        self.v = None
        self.s = None
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.soft_reset = soft_reset
        self.trace_pre = None
        self.trace_post = None
        self.w_min = w_min
        self.w_max = w_max
        self.f_pre = f_pre
        self.f_post = f_post
        self.flatten = nn.Flatten()
        self.linear = Linear(in_feature, out_feature, bias=False, tau=tau_pre)
        # self.lif=LIFNeuron(tau=tau, v_threshold=v_threshold, v_reset=v_reset, soft_reset=soft_reset)

    def reset(self):
        self.linear.reset()
        self.v = None
        self.s = None
        self.trace_pre = None
        self.trace_post = None

    def forward(self, x):
        y = self.flatten(x)
        self.linear.weight.data.clamp_(self.w_min, self.w_max)
        in_spike = self.linear(y)
        # print(input.shape)
        # if self.v is None:
        #     self.v = torch.zeros_like(input)
        #     self.s = torch.zeros_like(input)
        # self.v = self.v + (1.0 - self.v) / self.tau * input
        # spike = (self.v >= self.v_threshold).to(input)
        # if self.soft_reset:
        #     self.v = self.v-self.v_threshold*spike
        # else:
        #     self.v = self.v * (1-spike) + self.v_reset * spike
        # (ctx, input,fc,SpikingNeuron,tau_pre,tau_post, w_min, w_max, f_pre, f_post)
        spike = ActFun.apply(y,in_spike, self.linear, self, self.linear.tau,
                             self.tau, self.w_min, self.w_max, self.f_pre, self.f_post)

        return spike


In [20]:
# class Linear_Spiking(nn.Module):
#     def __init__(self, N_in, N_out, tau=2.0, v_threshold=1.0, v_reset=0.0, soft_reset=True):
#         super(Linear_Spiking, self).__init__()
#         self.lif = LIFNeuron(tau=tau, v_threshold=v_threshold,
#                              v_reset=v_reset, soft_reset=soft_reset)
#         self.net = nn.Sequential(
#             # nn.Flatten(),
#             Linear(N_in, N_out, bias=False,tau=2.),
#             LIFNeuron(tau=tau, v_threshold=v_threshold,
#                       v_reset=v_reset, soft_reset=soft_reset)
#         )
#         # self.handle=self.net[0].register_forward_hook(hook)

#     def reset(self):
#         self.lif.reset()

#     def forward(self, x):
#         y = self.net(x)
#         # if on_grad:
#         #     if self.synapse.weight.grad is None:
#         #         self.synapse.weight.grad = -delta_w
#         #     else:
#         #         self.synapse.weight.grad = self.synapse.weight.grad - delta_w
#         # else:
#         #     return delta_w
#         return y


In [21]:
model=Linear_Spiking(784,10,f_pre, f_post)
# if args.opt == 'adam':
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
model=model.to(args.device)

In [22]:
for item in model.parameters():
    print(item.shape)

torch.Size([10, 784])


In [23]:
model.linear.weight.grad=torch.rand_like(model.linear.weight)
print(model.linear.weight.grad)
print(model.linear.weight)
model.linear.weight.backward(torch.ones_like(model.linear.weight))
print(model.linear.weight)

tensor([[0.7171, 0.5731, 0.0279,  ..., 0.9072, 0.8416, 0.5684],
        [0.3181, 0.1979, 0.6324,  ..., 0.5656, 0.8090, 0.6182],
        [0.2453, 0.8775, 0.4082,  ..., 0.4226, 0.4794, 0.9095],
        ...,
        [0.7169, 0.0875, 0.5168,  ..., 0.3958, 0.0653, 0.9119],
        [0.1037, 0.6083, 0.0148,  ..., 0.5965, 0.9038, 0.9610],
        [0.4198, 0.6720, 0.6086,  ..., 0.4907, 0.4397, 0.3188]],
       device='cuda:0')
Parameter containing:
tensor([[-0.0137, -0.0128,  0.0069,  ..., -0.0109, -0.0007,  0.0099],
        [-0.0193, -0.0287, -0.0151,  ..., -0.0254, -0.0139, -0.0271],
        [-0.0329,  0.0083,  0.0298,  ..., -0.0232, -0.0212,  0.0169],
        ...,
        [-0.0099, -0.0263,  0.0098,  ..., -0.0223,  0.0094, -0.0254],
        [-0.0216,  0.0084, -0.0029,  ..., -0.0162, -0.0305,  0.0315],
        [-0.0056, -0.0263,  0.0056,  ..., -0.0016,  0.0349, -0.0276]],
       device='cuda:0', requires_grad=True)
Parameter containing:
tensor([[-0.0137, -0.0128,  0.0069,  ..., -0.0109, -0.00

In [24]:
acc_max = 0
epoch_max = 0
for epoch in range(args.epochs):
    start_time = time.time()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    model.train()
    for imag, label in train_data_loader:
        optimizer.zero_grad()
        imag = imag.to(args.device)
        label = label.to(args.device)
        label_onehot = F.one_hot(label, 10).float().to(args.device)
        out_fr = 0.
        for t in range(args.T):
            imag_possion = Possion_Encoder(imag)
            out_fr += model(imag_possion)
            # model.net[1].trace_pre, model.net[1].trece_post, grad = stdp_linear_single_step(
            #     model.net[0], imag_possion, out_fr, model.net[1].trace_pre, model.net[1].trace_post, 1, 1, 0, 1, f_pre, f_post)
            # model.net[0].weight.grad += grad

        out_fr = out_fr / args.T

        loss = F.mse_loss(out_fr, label_onehot)
        # model.net[1].trace_pre, model.net[1].trece_post, model.net[0].weight.grad = stdp_linear_single_step(
        #     model.net[0], imag_possion, out_fr, model.net[1].trace_pre, model.net[1].trace_post, 1, 1, 0, 1, f_pre, f_post)
        loss.backward()
        optimizer.step()

        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        train_acc += (out_fr.argmax(1) == label).float().sum().item()

    print('Epoch: {}, Train Loss: {:.4f}, Train Acc: {:.4f}, Time: {:.4f}'.format(
        epoch, train_loss / train_samples, train_acc / train_samples, time.time() - start_time))


torch.Size([200, 10]) torch.Size([10, 784])


RuntimeError: The size of tensor a (10) must match the size of tensor b (784) at non-singleton dimension 1