In [1]:
import os
import time
import argparse
import sys
import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.cuda import amp
from torch.utils.tensorboard import SummaryWriter
import torchvision
import numpy as np

from spikingjelly.activation_based import neuron, encoding, functional, surrogate, layer

In [2]:
class SNN(nn.Module):
    """Single-layer spiking classifier: flatten -> linear (784 -> 10, no bias) -> LIF neurons.

    The whole pipeline lives in ``self.layer`` (an ``nn.Sequential``) so that
    ``self.layer[-1]`` is the output LIF layer.

    Args:
        tau: membrane time constant forwarded to ``neuron.LIFNode``.
    """

    def __init__(self, tau):
        super().__init__()
        # Modules are created in the same order as they are applied, so weight
        # initialisation consumes the RNG in that order.
        flatten = layer.Flatten()
        readout = layer.Linear(28 * 28, 10, bias=False)
        lif = neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan())
        self.layer = nn.Sequential(flatten, readout, lif)

    def forward(self, x: torch.Tensor):
        """Run a single simulation step and return the LIF layer's output (spikes)."""
        spikes = self.layer(x)
        return spikes

In [3]:
# Hyperparameters, declared CLI-style. Inside the notebook an empty argv
# (`args=[]`) is parsed, so the defaults below are what actually run; pass a
# list of flags to `parse_args` to experiment.
parser = argparse.ArgumentParser(description='LIF MNIST Training')
parser.add_argument('-T', default=100, type=int, help='simulating time-steps')
parser.add_argument('-device', default='cuda:0', help='device')
parser.add_argument('-b', default=64, type=int, help='batch size')
parser.add_argument('-epochs', default=1, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-j', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: 8)')
parser.add_argument('-data-dir', default='./datasets', type=str, help='root dir of MNIST dataset')
parser.add_argument('-resume', default=False, type=str, help='resume from the checkpoint path')
# AMP is on by default; `-amp` is kept for compatibility (it cannot change the
# default). `-no-amp` writes to the same `amp` dest and is the way to disable it.
parser.add_argument('-amp', default=True, action='store_true', help='automatic mixed precision training')
parser.add_argument('-no-amp', dest='amp', action='store_false', help='disable automatic mixed precision training')
parser.add_argument('-opt', type=str, choices=['sgd', 'adam'], default='adam', help='use which optimizer. SGD or Adam')
parser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD')
parser.add_argument('-lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('-tau', default=2.0, type=float, help='parameter tau of LIF neuron')

args = parser.parse_args(args=[])
print(args)

# '0' is CUDA's default (asynchronous kernel launches); set it to '1' before
# the first CUDA call if you need synchronous launches for debugging.
os.environ['CUDA_LAUNCH_BLOCKING'] = '0'

Namespace(T=100, amp=True, b=64, data_dir='./datasets', device='cuda:0', epochs=1, j=8, lr=0.001, momentum=0.9, opt='adam', resume=False, tau=2.0)


In [4]:
# Build the model and move it to the target device. `Module.to` works in place
# and returns the module, so chaining it is safe; ending the cell on a
# statement (not a bare `net.to(...)` expression) prints the repr only once.
net = SNN(tau=args.tau).to(args.device)
print(net)

SNN(
  (layer): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1, step_mode=s)
    (1): Linear(in_features=784, out_features=10, bias=False)
    (2): LIFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch, tau=2.0
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
  )
)


SNN(
  (layer): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1, step_mode=s)
    (1): Linear(in_features=784, out_features=10, bias=False)
    (2): LIFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch, tau=2.0
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
  )
)

In [5]:
# Initialise the MNIST datasets (downloaded into args.data_dir if missing) and their loaders.
def _mnist_split(is_train):
    """Return the MNIST train or test split with images converted to tensors."""
    return torchvision.datasets.MNIST(
        root=args.data_dir,
        train=is_train,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )


def _make_loader(dataset, is_train):
    """Wrap `dataset` in a DataLoader; the training loader shuffles and drops the last partial batch."""
    return data.DataLoader(
        dataset=dataset,
        batch_size=args.b,
        shuffle=is_train,
        drop_last=is_train,
        num_workers=args.j,
        pin_memory=True,
    )


train_dataset = _mnist_split(True)
test_dataset = _mnist_split(False)

train_data_loader = _make_loader(train_dataset, True)
test_data_loader = _make_loader(test_dataset, False)

In [6]:
# Mixed-precision gradient scaler; `None` tells the training loop to use full precision.
scaler = amp.GradScaler() if args.amp else None

# Defaults for a fresh run; replaced below when resuming from a checkpoint.
start_epoch = 0
max_test_acc = -1

optimizer = None
# Map each `-opt` choice to a factory taking the model parameters.
_optimizer_builders = {
    'sgd': lambda params: torch.optim.SGD(params, lr=args.lr, momentum=args.momentum),
    'adam': lambda params: torch.optim.Adam(params, lr=args.lr),
}
if args.opt not in _optimizer_builders:
    raise NotImplementedError(args.opt)
optimizer = _optimizer_builders[args.opt](net.parameters())

if args.resume:
    # torch.load unpickles the file: only resume from checkpoints you trust.
    checkpoint = torch.load(args.resume, map_location='cpu')
    net.load_state_dict(checkpoint['net'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch'] + 1
    max_test_acc = checkpoint['max_test_acc']


# Poisson encoder, applied to the input image at every simulated time step.
encoder = encoding.PoissonEncoder()

In [7]:
def _mean_output_rate(img):
    """Run `net` for `args.T` steps, each on a fresh `encoder(img)`, and return the time-averaged output.

    Because the network's output layer emits spikes, the average is the
    per-neuron firing rate. The caller must call `functional.reset_net(net)`
    afterwards to clear neuron state before the next batch.
    """
    out_fr = 0.
    for _ in range(args.T):
        out_fr += net(encoder(img))
    return out_fr / args.T


for epoch in range(start_epoch, args.epochs):
    start_time = time.time()

    # ---- training ----
    net.train()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    for img, label in train_data_loader:
        optimizer.zero_grad()
        img = img.to(args.device)
        label = label.to(args.device)
        label_onehot = F.one_hot(label, 10).float()

        # autocast(enabled=False) is a no-op, so one forward pass serves both
        # the AMP and the full-precision path.
        with amp.autocast(enabled=scaler is not None):
            out_fr = _mean_output_rate(img)
            loss = F.mse_loss(out_fr, label_onehot)

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        train_acc += (out_fr.argmax(1) == label).float().sum().item()

        # LIF neurons keep their membrane state between forward calls; clear it
        # so the next batch starts from rest.
        functional.reset_net(net)

    train_time = time.time()
    train_speed = train_samples / (train_time - start_time)
    train_loss /= train_samples
    train_acc /= train_samples

    # ---- evaluation ----
    net.eval()
    test_loss = 0
    test_acc = 0
    test_samples = 0
    with torch.no_grad():
        for img, label in test_data_loader:
            img = img.to(args.device)
            label = label.to(args.device)
            label_onehot = F.one_hot(label, 10).float()
            out_fr = _mean_output_rate(img)
            loss = F.mse_loss(out_fr, label_onehot)

            test_samples += label.numel()
            test_loss += loss.item() * label.numel()
            test_acc += (out_fr.argmax(1) == label).float().sum().item()
            functional.reset_net(net)
    test_time = time.time()
    test_speed = test_samples / (test_time - train_time)
    test_loss /= test_samples
    test_acc /= test_samples

    save_max = False
    if test_acc > max_test_acc:
        max_test_acc = test_acc
        save_max = True

    # NOTE(review): `checkpoint` and `save_max` are prepared but never written to
    # disk in this notebook, so `-resume` has nothing produced here to load.
    # TODO: persist with torch.save(...) once an output directory is configured.
    checkpoint = {
        'net': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'max_test_acc': max_test_acc
    }

    print(args)
    print(f'epoch ={epoch}, train_loss ={train_loss: .4f}, train_acc ={train_acc: .4f}, test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}, max_test_acc ={max_test_acc: .4f}')
    print(f'train speed ={train_speed: .4f} images/s, test speed ={test_speed: .4f} images/s')
    # Estimated finish time = this epoch's duration x epochs still to run.
    # Epochs are 0-indexed, so after finishing `epoch`, `args.epochs - epoch - 1` remain.
    remaining_epochs = args.epochs - epoch - 1
    eta = datetime.datetime.now() + datetime.timedelta(
        seconds=(time.time() - start_time) * remaining_epochs)
    print(f'escape time = {eta.strftime("%Y-%m-%d %H:%M:%S")}\n')

Namespace(T=100, amp=True, b=64, data_dir='./datasets', device='cuda:0', epochs=1, j=8, lr=0.001, momentum=0.9, opt='adam', resume=False, tau=2.0)
epoch =0, train_loss = 0.0273, train_acc = 0.8701, test_loss = 0.0192, test_acc = 0.9117, max_test_acc = 0.9117
train speed = 756.2034 images/s, test speed = 2408.4661 images/s
escape time = 2022-11-25 23:29:19



In [8]:
# Record the output layer's membrane potential and output at every time step
# for a single test image, as data for plotting.
net.eval()
output_layer = net.layer[-1]  # the LIF neuron layer at the end of net.layer
output_layer.v_seq = []
output_layer.s_seq = []


def save_hook(m, x, y):
    """Forward hook: append this step's membrane potential `m.v` and output `y`, each with a leading time axis."""
    m.v_seq.append(m.v.unsqueeze(0))
    m.s_seq.append(y.unsqueeze(0))


# Keep the handle so the hook can be removed afterwards; otherwise re-running
# this cell stacks another hook and each step gets recorded more than once.
hook_handle = output_layer.register_forward_hook(save_hook)
functional.reset_net(net)  # start the probe from a rest state
try:
    with torch.no_grad():
        img, label = test_dataset[0]
        # ToTensor yields (1, 28, 28) for a grayscale MNIST image; add an
        # explicit batch dimension -> (1, 1, 28, 28). Flatten still gives (1, 784).
        img = img.unsqueeze(0).to(args.device)
        out_fr = 0.
        for t in range(args.T):
            encoded_img = encoder(img)
            out_fr += net(encoded_img)
        out_spikes_counter_frequency = (out_fr / args.T).cpu().numpy()
        print(f'Firing rate: {out_spikes_counter_frequency}')

        # Each recorded entry is (1, 1, 10); concatenating gives (T, 1, 10).
        output_layer.v_seq = torch.cat(output_layer.v_seq)
        output_layer.s_seq = torch.cat(output_layer.s_seq)
        # Shape (T, 10): v_t_array[t][i] is the membrane potential of output neuron i
        # at step t (recorded after the step's update; spiking neurons appear at 0.,
        # which suggests it is post-reset — verify against LIFNode).
        v_t_array = output_layer.v_seq.cpu().numpy().squeeze()
        print('v_t_array:',v_t_array)
        # Shape (T, 10): s_t_array[t][i] is the output of neuron i at step t (a spike, 0 or 1).
        s_t_array = output_layer.s_seq.cpu().numpy().squeeze()
        print('s_t_array:',s_t_array)
finally:
    hook_handle.remove()
    functional.reset_net(net)

Firing rate: [[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]]
v_t_array: [[-0.29920655 -2.1234877  -0.34667656  0.27111456 -0.91487545 -0.5903152
  -1.081855    0.         -0.12994415 -0.09245802]
 [-0.56135046 -3.2449474  -0.5342432   0.32430953 -1.3813639  -0.7676717
  -1.7292999   0.         -0.2601941  -0.07444678]
 [-0.42563343 -3.3545718  -0.47962427  0.31095773 -1.5317702  -0.7029953
  -1.8540742   0.         -0.38154185 -0.16001144]
 [-0.45681864 -3.6512332  -0.3617103   0.46627215 -1.4908309  -0.904159
  -2.1455832   0.         -0.39970565 -0.2900252 ]
 [-0.5011374  -3.6133394  -0.4066478   0.27096352 -1.6788766  -0.6792593
  -2.170682    0.         -0.43814945 -0.42537618]
 [-0.61798954 -3.5834594  -0.4156293   0.47524625 -1.7780392  -0.8914727
  -2.3014255   0.         -0.40963766 -0.28636733]
 [-0.4824766  -3.4112496  -0.39720678  0.2831207  -2.0067127  -0.8851657
  -2.1426694   0.         -0.54436094 -0.22954482]
 [-0.5222386  -3.540553   -0.4803449   0.23832893 -1.9021075  -0.8203033
  