In [1]:
import torch
import torch.nn as nn
from spikingjelly.activation_based import monitor, neuron, functional, layer, surrogate

In [2]:
# Two-layer spiking MLP: Linear(8->4) -> IFNode -> Linear(4->2) -> IFNode.
net = nn.Sequential(
    layer.Linear(8, 4),
    neuron.IFNode(),
    layer.Linear(4, 2),
    neuron.IFNode()
)

# Make every weight and bias non-negative so that, with the non-negative inputs
# produced by torch.rand below, the IF neurons are driven towards threshold.
# NOTE: each parameter is printed *before* abs_() is applied, so the printed
# values are the original signed initialisation, not the weights actually used.
# In-place edits of leaf parameters go through torch.no_grad() rather than the
# discouraged `.data` attribute.
with torch.no_grad():
    for param in net.parameters():
        print(param)
        param.abs_()

# Multi-step mode: the network consumes whole sequences shaped [T, N, ...]
# (see x_seq in the next cell).
functional.set_step_mode(net, 'm')

Parameter containing:
tensor([[ 0.3061,  0.3096,  0.2182,  0.0596, -0.0272,  0.2302, -0.0408,  0.0239],
        [-0.2832,  0.0125,  0.1958, -0.2635,  0.2015, -0.2020, -0.1491,  0.3531],
        [ 0.3143,  0.0347, -0.0431,  0.2495,  0.1523, -0.1337, -0.3041, -0.2078],
        [-0.2687, -0.2032,  0.2507, -0.0075,  0.2419,  0.2392,  0.1708,  0.0301]],
       requires_grad=True)
Parameter containing:
tensor([ 0.2052, -0.2712,  0.1804,  0.3093], requires_grad=True)
Parameter containing:
tensor([[ 0.3422,  0.0617, -0.0562, -0.2766],
        [-0.4885, -0.3900, -0.4209, -0.2670]], requires_grad=True)
Parameter containing:
tensor([-0.0510, -0.4630], requires_grad=True)


In [3]:
# Record the outputs of every IFNode (spike sequences) and of every Linear layer.
spike_seq_monitor = monitor.OutputMonitor(net, neuron.IFNode)
linear_monitor = monitor.OutputMonitor(net, layer.Linear)

T = 4  # number of time steps (first dim of x_seq)
N = 1  # second (batch) dim of x_seq
x_seq = torch.rand([T, N, 8])

with torch.no_grad():
    net(x_seq)
    # Return the neurons to their initial state so the next run does not
    # inherit this run's membrane potentials (later cells do the same).
    functional.reset_net(net)

print(f'spike_seq_monitor.records=\n{spike_seq_monitor.records}')
print(f'linear_monitor.records=\n{linear_monitor.records}')

spike_seq_monitor.records=
[tensor([[[0., 0., 0., 0.]],

        [[1., 1., 1., 1.]],

        [[0., 1., 1., 1.]],

        [[1., 0., 0., 0.]]]), tensor([[[0., 0.]],

        [[0., 1.]],

        [[1., 1.]],

        [[0., 0.]]])]
linear_monitor.records=
[tensor([[[0.6820, 0.9185, 0.6898, 0.7219]],

        [[0.8513, 0.9923, 0.7236, 1.0778]],

        [[0.8396, 1.0755, 1.0126, 1.0248]],

        [[0.5967, 0.9662, 0.7659, 0.9066]]]), tensor([[[0.0510, 0.4630]],

        [[0.7877, 2.0296]],

        [[0.4455, 1.5410]],

        [[0.3931, 0.9516]]])]


In [4]:
# Integer indexing returns the i-th recorded tensor (same as records[i]).
print('spike_seq_monitor[0]={}'.format(spike_seq_monitor[0]))

spike_seq_monitor[0]=tensor([[[0., 0., 0., 0.]],

        [[1., 1., 1., 1.]],

        [[0., 1., 1., 1.]],

        [[1., 0., 0., 0.]]])


In [5]:
# Show the network, then the submodule names each monitor attached to.
print('net={}'.format(net))
for label, mon in (('spike_seq_monitor', spike_seq_monitor),
                   ('linear_monitor', linear_monitor)):
    print('{}.monitored_layers={}'.format(label, mon.monitored_layers))

net=Sequential(
  (0): Linear(in_features=8, out_features=4, bias=True)
  (1): IFNode(
    v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch
    (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
  )
  (2): Linear(in_features=4, out_features=2, bias=True)
  (3): IFNode(
    v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch
    (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
  )
)
spike_seq_monitor.monitored_layers=['1', '3']
linear_monitor.monitored_layers=['0', '2']


In [6]:
# String indexing returns the list of records produced by the named submodule.
print("spike_seq_monitor['1']={}".format(spike_seq_monitor['1']))

spike_seq_monitor['1']=[tensor([[[0., 0., 0., 0.]],

        [[1., 1., 1., 1.]],

        [[0., 1., 1., 1.]],

        [[1., 0., 0., 0.]]])]


In [7]:
spike_seq_monitor.clear_recorded_data()
# After clearing, both the flat record list and the per-layer view are empty.
print('spike_seq_monitor.records={}'.format(spike_seq_monitor.records))
print("spike_seq_monitor['1']={}".format(spike_seq_monitor['1']))

spike_seq_monitor.records=[]
spike_seq_monitor['1']=[]


In [8]:
# Unregister the hooks of both monitors from `net`.
for mon in (spike_seq_monitor, linear_monitor):
    mon.remove_hooks()

In [9]:
def cal_firing_rate(s_seq: torch.Tensor):
    """Mean firing rate at each time step.

    Args:
        s_seq: spike tensor shaped [T, N, *] with 0/1 entries. Boolean and
            integer tensors are accepted and cast to float before averaging.

    Returns:
        1-D tensor of length T holding, for each time step, the mean over the
        N samples and all remaining (neuron) dimensions.
    """
    if not (s_seq.is_floating_point() or s_seq.is_complex()):
        # Tensor.mean() is not defined for bool/integer dtypes.
        s_seq = s_seq.float()
    # Flatten everything after the time dim, then average per time step.
    return s_seq.flatten(1).mean(1)

In [10]:
# Monitor IFNode outputs, but store only their per-time-step firing rate.
fr_monitor = monitor.OutputMonitor(net, neuron.IFNode, cal_firing_rate)

with torch.no_grad():
    functional.reset_net(net)

    # While disabled, a forward pass adds nothing to fr_monitor.records.
    fr_monitor.disable()
    net(x_seq)
    functional.reset_net(net)
    print(f'after call fr_monitor.disable(), fr_monitor.records=\n{fr_monitor.records}')

    # Once re-enabled, the next forward pass is recorded again.
    fr_monitor.enable()
    net(x_seq)
    print(f'after call fr_monitor.enable(), fr_monitor.records=\n{fr_monitor.records}')
    functional.reset_net(net)

    # NOTE(review): the hooks registered on `net` presumably reference the
    # monitor, so `del` alone may leave it alive and still recording on every
    # later forward pass. Unregister the hooks explicitly (as done for the
    # earlier monitors) before dropping the name.
    fr_monitor.remove_hooks()
    del fr_monitor

after call fr_monitor.disable(), fr_monitor.records=
[]
after call fr_monitor.enable(), fr_monitor.records=
[tensor([0.0000, 1.0000, 0.7500, 0.2500]), tensor([0.0000, 0.5000, 1.0000, 0.0000])]


In [11]:
# Turn on `store_v_seq` for every IFNode so its `v_seq` attribute can be monitored.
for if_node in filter(lambda mod: isinstance(mod, neuron.IFNode), net.modules()):
    if_node.store_v_seq = True

In [12]:
# Read the `v_seq` attribute of every IFNode after (pre_forward=False) its forward.
v_seq_monitor = monitor.AttributeMonitor('v_seq', pre_forward=False, net=net, instance=neuron.IFNode)
with torch.no_grad():
    net(x_seq)
    print(f'v_seq_monitor.records=\n{v_seq_monitor.records}')
    functional.reset_net(net)
    # Unregister hooks explicitly; `del` alone may leave them attached to `net`.
    v_seq_monitor.remove_hooks()
    del v_seq_monitor

v_seq_monitor.records=
[tensor([[[0.6820, 0.9185, 0.6898, 0.7219]],

        [[0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.8396, 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.9662, 0.7659, 0.9066]]]), tensor([[[0.0510, 0.4630]],

        [[0.8387, 0.0000]],

        [[0.0000, 0.0000]],

        [[0.3931, 0.9516]]])]


In [13]:
# Record the inputs fed to every IFNode (here: the outputs of the Linear layers).
input_monitor = monitor.InputMonitor(net, neuron.IFNode)
with torch.no_grad():
    net(x_seq)
    print(f'input_monitor.records=\n{input_monitor.records}')
    functional.reset_net(net)
    # Unregister hooks explicitly; `del` alone may leave them attached to `net`.
    input_monitor.remove_hooks()
    del input_monitor

input_monitor.records=
[tensor([[[0.6820, 0.9185, 0.6898, 0.7219]],

        [[0.8513, 0.9923, 0.7236, 1.0778]],

        [[0.8396, 1.0755, 1.0126, 1.0248]],

        [[0.5967, 0.9662, 0.7659, 0.9066]]]), tensor([[[0.0510, 0.4630]],

        [[0.7877, 2.0296]],

        [[0.4455, 1.5410]],

        [[0.3931, 0.9516]]])]


In [14]:
# Record the gradient w.r.t. each IFNode's output during backward.
spike_seq_grad_monitor = monitor.GradOutputMonitor(net, neuron.IFNode)
net(x_seq).sum().backward()
# Backward visits layers last-to-first, so records[0] is for IFNode '3'
# (shape [T, N, 2]) and records[1] is for IFNode '1' (shape [T, N, 4]).
print(f'spike_seq_grad_monitor.records=\n{spike_seq_grad_monitor.records}')
functional.reset_net(net)
# Unregister hooks explicitly; `del` alone may leave them attached to `net`.
spike_seq_grad_monitor.remove_hooks()
del spike_seq_grad_monitor

spike_seq_grad_monitor.records=
[tensor([[[1., 1.]],

        [[1., 1.]],

        [[1., 1.]],

        [[1., 1.]]]), tensor([[[ 0.5635,  0.2185,  0.2244,  0.4061]],

        [[ 0.3540,  0.0684,  0.0633,  0.2842]],

        [[ 0.0602, -0.0479, -0.0564,  0.0736]],

        [[ 0.5860,  0.4048,  0.4338,  0.3470]]])]


In [15]:
# Deeper network: 10 x (Linear(8->8) -> IFNode).
layers = []
for _ in range(10):
    layers.append(layer.Linear(8, 8))
    layers.append(neuron.IFNode())

net = nn.Sequential(*layers)

functional.set_step_mode(net, 'm')

T = 4
N = 1
x_seq = torch.rand([T, N, 8])

# Record the norm of the gradient w.r.t. every IFNode's input during backward.
# NOTE(review): as with GradOutputMonitor above, records presumably come in
# backward order (last IFNode first).
input_grad_monitor = monitor.GradInputMonitor(net, neuron.IFNode, function_on_grad_input=torch.norm)

for alpha in [0.1, 0.5, 2, 4, 8]:
    # Set `alpha` of every Sigmoid surrogate function in the network.
    for m in net.modules():
        if isinstance(m, surrogate.Sigmoid):
            m.alpha = alpha
    net(x_seq).sum().backward()
    print(f'alpha={alpha}, input_grad_monitor.records=\n{input_grad_monitor.records}\n')
    functional.reset_net(net)
    # Parameter gradients accumulate across backward() calls; zero them
    # before the next alpha.
    for param in net.parameters():
        param.grad.zero_()

    # clear_recorded_data() empties both `records` and the per-layer views
    # (monitor['name']); a bare records.clear() may leave the per-layer
    # index pointing at records that no longer exist.
    input_grad_monitor.clear_recorded_data()

alpha=0.1, input_grad_monitor.records=
[tensor(0.3861), tensor(0.0124), tensor(0.0004), tensor(5.4694e-06), tensor(1.0116e-07), tensor(2.4601e-09), tensor(4.0712e-11), tensor(7.2066e-13), tensor(1.1894e-14), tensor(2.7602e-16)]

alpha=0.5, input_grad_monitor.records=
[tensor(1.7764), tensor(0.2671), tensor(0.0431), tensor(0.0026), tensor(0.0002), tensor(2.4132e-05), tensor(1.8576e-06), tensor(1.7545e-07), tensor(1.3205e-08), tensor(1.4386e-09)]

alpha=2, input_grad_monitor.records=
[tensor(3.7181), tensor(1.2200), tensor(0.5242), tensor(0.0952), tensor(0.0183), tensor(0.0020), tensor(0.0003), tensor(7.2221e-05), tensor(9.8139e-06), tensor(2.5349e-06)]

alpha=4, input_grad_monitor.records=
[tensor(3.6314), tensor(1.0233), tensor(0.4522), tensor(0.0771), tensor(0.0056), tensor(0.0009), tensor(0.0001), tensor(6.2765e-05), tensor(8.8703e-06), tensor(3.6065e-06)]

alpha=8, input_grad_monitor.records=
[tensor(2.4827), tensor(0.5324), tensor(0.2613), tensor(0.0909), tensor(0.0322), tensor(0.0