In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
import argparse, os, sys,time
from typing import Callable, Union

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
parser = argparse.ArgumentParser(description='STDP learning')
parser.add_argument('-T', default=100, type=int, help='simulating time-steps')
parser.add_argument('-device', default='cuda:0', help='device')
parser.add_argument('-b', default=64, type=int, help='batch size')
parser.add_argument('-epochs', default=1, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-j', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('-data-dir', default='/data/tanghao/datasets/', type=str, help='root dir of dataset')
parser.add_argument('-opt', type=str, choices=['sgd', 'adam'], default='adam', help='use which optimizer. SGD or Adam')
parser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD')
parser.add_argument('-lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('-tau', default=2.0, type=float, help='parameter tau of LIF neuron')

args = parser.parse_args(args=[])
print(args)

os.environ['CUDA_LAUNCH_BLOCKING'] = '0'

Namespace(T=100, b=64, data_dir='/data/tanghao/datasets/', device='cuda:0', epochs=1, j=4, lr=0.001, momentum=0.9, opt='adam', tau=2.0)


In [3]:
train_dataset = torchvision.datasets.MNIST(
    root=args.data_dir,
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True
)
test_dataset = torchvision.datasets.MNIST(
    root=args.data_dir,
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True
)

train_data_loader = data.DataLoader(
    dataset=train_dataset,
    batch_size=args.b,
    shuffle=True,
    drop_last=True,
    num_workers=args.j,
    pin_memory=True
)
test_data_loader = data.DataLoader(
    dataset=test_dataset,
    batch_size=args.b,
    shuffle=False,
    drop_last=False,
    num_workers=args.j,
    pin_memory=True
)

In [4]:
print(test_dataset[0][0].shape)
print(test_dataset[0][0].device)

torch.Size([1, 28, 28])
cpu


设像素密度为$\lambda$，$N(\Delta t)\sim P(\lambda \Delta t)$，则$\Delta t$时刻后发射脉冲的概率为$P(N(\Delta t)=1)=\lambda \Delta t(e^{-\lambda \Delta t})\approx \lambda \Delta t+o(\Delta t)$
因此在$\Delta t$时刻后发射脉冲的概率正比像素密度，设$x\sim U(0,1)$，则$P(N(\Delta t)=1)=P(x<\lambda \Delta t)=\lambda \Delta t$，由此，可以使用代码实现为`torch.rand_like(x).le(x)`。
    <!-- def spike_generator(self, rate, dt):
        return np.random.uniform(0, 1, size=rate.shape) < rate * dt -->

In [5]:
def Possion_Encoder(x):
    return torch.rand_like(x).le(x).to(x)

In [6]:
print(Possion_Encoder(test_dataset[0][0]))

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 0., 0., 0.,

In [7]:
def mem_update(ops,x,mem,tau,spike):
    mem=mem*tau*(1-spike)+ops(x)

In [8]:
class ActFun(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.gt(0).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

In [9]:
# class LIFNeuron(nn.Module):
#     def __init__(self, tau=2.0,v_threshold=1.0, v_reset=0.0,soft_reset=True):
#         super(LIFNeuron, self).__init__()
#         self.tau = tau
#         self.v = None
#         self.s = None
#         self.v_threshold = v_threshold
#         self.v_reset = v_reset
#         self.soft_reset=soft_reset

#     def reset(self):
#         self.v = None
#         self.s = None
    
#     def charge(self, x):
#         if self.v is None:
#             self.v = torch.zeros_like(x)
#             self.s = torch.zeros_like(x)
#         self.v = self.v + (1.0 - self.v) / self.tau * x

#         return self

#     def spike(self):
#         if self.soft_reset:
#             self.v = self.v * (self.v < self.v_threshold).float() + self.v_reset * (self.v >= self.v_threshold).float()
#         else:
#             self.v=self.v*(self.v<self.v_threshold)
#         return (self.v >= self.v_threshold).to(self.v)

In [10]:
class LIFNeuron(nn.Module):
    def __init__(self, tau=2.0, v_threshold=1.0, v_reset=0.0, soft_reset=True):
        super(LIFNeuron, self).__init__()
        self.tau = tau
        self.v = None
        self.s = None
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.soft_reset = soft_reset

    def reset(self):
        self.v = None
        self.s = None

    def forward(self, x):
        if self.v is None:
            self.v = torch.zeros_like(x)
            self.s = torch.zeros_like(x)
        self.v = self.v + (1.0 - self.v) / self.tau * x
        spike = (self.v >= self.v_threshold).to(x)
        if self.soft_reset:
            self.v = self.v-self.v_threshold*spike
        else:
            self.v = self.v * (1-spike) + self.v_reset * spike

        return (self.v >= self.v_threshold).to(self.v)


In [11]:
# class Linear_Spiking(nn.Module):
#     def __init__(self, N_in, N_out, tau=2.0, v_threshold=1.0, v_reset=0.0, soft_reset=True):
#         super(Linear_Spiking, self).__init__()
#         self.prototype = nn.Sequential(
#             nn.Flatten(),
#             nn.Linear(N_in, N_out,bias=False)
#         )
#         self.lif = LIFNeuron(tau=tau, v_threshold=v_threshold,
#                              v_reset=v_reset, soft_reset=soft_reset)
        

#     def reset(self):
#         self.lif.reset()
        

#     def forward(self, x):
#         x = self.prototype(x)
#         y = self.lif.charge(x).spike()
#         return y


$$
tr_{pre}[i][t]=tr_{pre}[i][t]-\frac{tr_{pre}[i][t-1]}{\tau_{pre}}+s[i][t]\\
tr_{post}[j][t]=tr_{post}[j][t]-\frac{tr_{post}[j][t-1]}{\tau_{post}}+s[j][t]\\
$$

$$
\Delta W[i][j][t]=F_{post}(w[i][j][t])\cdot tr_{pre}[i][t]\cdot s[j][t]-F_{pre}(w[i][j][t])\cdot tr_{post}[j][t]\cdot s[i][t]
$$

In [None]:
def stdp_linear_single_step(
    fc: nn.Linear, in_spike: torch.Tensor, out_spike: torch.Tensor,
    trace_pre: Union[float, torch.Tensor, None],
    trace_post: Union[float, torch.Tensor, None],
    tau_pre: float, tau_post: float, w_min: float, w_max: float,
    f_pre: Callable = lambda x: x, f_post: Callable = lambda x: x
):
    if trace_pre is None:
        trace_pre = 0.

    if trace_post is None:
        trace_post = 0.

    weight = fc.weight.data
    trace_pre = trace_pre - trace_pre / tau_pre + in_spike      # shape = [batch_size, N_in]
    trace_post = trace_post - trace_post / tau_post + out_spike # shape = [batch_size, N_out]

    # [batch_size, N_out, N_in] -> [N_out, N_in]
    # 此处对照更新公式，使用unsqueeze添加更新公式中所缺失的一维
    delta_w_pre = -f_pre(weight, w_min) * (trace_post.unsqueeze(2) * in_spike.unsqueeze(1)).sum(0)
    delta_w_post = f_post(weight, w_max) * (trace_pre.unsqueeze(1) * out_spike.unsqueeze(2)).sum(0)
    return trace_pre, trace_post, delta_w_pre + delta_w_post

In [12]:
class Linear_Spiking(nn.Module):
    def __init__(self, N_in, N_out, tau=2.0, v_threshold=1.0, v_reset=0.0, soft_reset=True):
        super(Linear_Spiking, self).__init__()
        self.lif = LIFNeuron(tau=tau, v_threshold=v_threshold,
                             v_reset=v_reset, soft_reset=soft_reset)
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(N_in, N_out, bias=False),
            LIFNeuron(tau=tau, v_threshold=v_threshold,
                      v_reset=v_reset, soft_reset=soft_reset)
        )

    def reset(self):
        self.lif.reset()

    def forward(self, x):
        y = self.net(x)
        if self.synapse.weight.grad is None:
            self.synapse.weight.grad = -delta_w
        else:
            self.synapse.weight.grad = self.synapse.weight.grad - delta_w
        # if on_grad:
        #     if self.synapse.weight.grad is None:
        #         self.synapse.weight.grad = -delta_w
        #     else:
        #         self.synapse.weight.grad = self.synapse.weight.grad - delta_w
        # else:
        #     return delta_w
        return y


In [None]:
class weight_update(nn.Module):
    def __init__(self, synapse, tau_pre, tau_post, w_min, w_max, f_pre, f_post):
        super(weight_update, self).__init__()
        self.synapse = synapse
        self.tau_pre = tau_pre
        self.tau_post = tau_post
        self.w_min = w_min
        self.w_max = w_max
        self.f_pre = f_pre
        self.f_post = f_post
        self.trace_pre = None
        self.trace_post = None

    

In [13]:
model=Linear_Spiking(784,10)
# if args.opt == 'adam':
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
model=model.to(args.device)

In [14]:
for item in model.parameters():
    print(item.shape)

torch.Size([10, 784])


In [15]:
acc_max = 0
epoch_max = 0
for epoch in range(args.epochs):
    start_time = time.time()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    model.train()
    for imag, label in train_data_loader:
        imag = imag.to(args.device)
        label = label.to(args.device)
        label_onehot = F.one_hot(label, 10)
        out_fr = 0.
        for t in range(args.T):
            imag_possion = Possion_Encoder(imag)
            out_fr += model(imag_possion)

        out_fr = out_fr / args.T
        optimizer.zero_grad()
        loss = F.mse_loss(out_fr, label_onehot)
        
        # loss.backward()
        optimizer.step()

        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        train_acc += (out_fr.argmax(1) == label).float().sum().item()


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn