In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import neuron,functional,layer

In [2]:
# Create an IF neuron in multi-step mode ('m').
net = neuron.IFNode(step_mode='m')
# The step mode is a plain attribute and can be changed after construction.
net.step_mode = 's'
# 's' selects single-step mode.

In [3]:
# Single-step neuron: every call receives the input of one time step.
net_s = neuron.IFNode(step_mode='s')

# The input sequence has shape [T, N, C, H, W]; the loop below walks the
# leading T axis one slice at a time.
T, N, C, H, W = 4, 1, 3, 8, 8
x_seq = torch.rand([T, N, C, H, W])

y_seq = []
for t in range(T):
    x = x_seq[t]           # x.shape = [N, C, H, W]
    y = net_s(x)           # y.shape = [N, C, H, W]
    y_seq.append(y[None])  # prepend a length-1 axis -> [1, N, C, H, W]

# Each collected element still carries the extra leading axis.
print(y_seq[1].shape)

# Concatenating along that axis rebuilds the full [T, N, C, H, W] sequence.
y_seq = torch.cat(y_seq, dim=0)
print(y_seq.shape)

torch.Size([1, 1, 3, 8, 8])
torch.Size([4, 1, 3, 8, 8])


In [4]:
# The neuron is stateful: its membrane potential `v` persists between forward
# calls. Clear it first so this run starts from the initial state instead of
# continuing from whatever an earlier cell left behind; otherwise the result
# depends on execution history.
net_s.reset()

print(x_seq.shape)
# multi_step_forward applies the single-step module to the whole sequence and
# returns the per-step outputs as one tensor.
y_seq = functional.multi_step_forward(x_seq, net_s)
print(y_seq.shape)

torch.Size([4, 1, 3, 8, 8])
torch.Size([4, 1, 3, 8, 8])


In [5]:
# A neuron created in multi-step mode ('m') takes the whole sequence in one call.
net_m = neuron.IFNode(step_mode='m')
# x_seq is the sequence tensor defined in an earlier cell; the printed output
# shape matches the input shape.
y_seq = net_m(x_seq)
print(y_seq.shape)

torch.Size([4, 1, 3, 8, 8])


In [6]:
# Start from a clean state so the printed initial `v` is the reset value.
net_s.reset()
x = torch.rand(4)

# Show the module configuration and its membrane potential before any input.
print(net_s)
print(f'the initial v={net_s.v}')

# One single-step forward pass; afterwards `v` has been updated by the input.
y = net_s(x)
print(f'x={x}', f'y={y}', f'v={net_s.v}', sep='\n')

IFNode(
  v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch
  (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
the initial v=0.0
x=tensor([0.8799, 0.1571, 0.1993, 0.9120])
y=tensor([0., 0., 0., 0.])
v=tensor([0.8799, 0.1571, 0.1993, 0.9120])


In [8]:
# Reference run: drive the single-step neuron over the sequence with the
# functional helper, starting from a freshly reset state.
net_s.reset()
T, N, C, H, W = 4, 1, 3, 8, 8
x_seq = torch.rand([T, N, C, H, W])
y_seq = functional.multi_step_forward(x_seq, net_s)
print(y_seq.shape)

# Same computation through the container wrapper. Reset again so both runs
# start from the same membrane potential.
# NOTE: wrapping an IFNode this way logs a warning (visible in the cell
# output); that is expected for this demonstration.
net_s.reset()
net_m = layer.MultiStepContainer(net_s)
z_seq = net_m(x_seq)
print(z_seq.shape)

# z_seq is identical to y_seq -- verify it instead of only claiming it.
assert torch.equal(z_seq, y_seq), 'MultiStepContainer output differs from multi_step_forward'

  v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch
  (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
) supports for step_mode == 's', which should not be contained by MultiStepContainer!


torch.Size([4, 1, 3, 8, 8])
torch.Size([4, 1, 3, 8, 8])


In [11]:
with torch.no_grad():
    T, N, C, H, W = 4, 1, 3, 8, 8
    x_seq = torch.rand([T, N, C, H, W])

    conv = nn.Conv2d(C, 8, kernel_size=3, padding=1, bias=False)
    bn = nn.BatchNorm2d(8)
    # BatchNorm in training mode normalizes with the statistics of the batch
    # it receives. The step-by-step helpers hand it one time step at a time,
    # whereas the seq-to-ANN helpers process the time and batch dimensions
    # merged into one batch, so in training mode the two paths would normalize
    # with different statistics and produce different values. Eval mode uses
    # the fixed running statistics, which makes all four paths comparable.
    bn.eval()

    # Step-by-step application of the stateless layers (functional form).
    y_seq = functional.multi_step_forward(x_seq, (conv, bn))
    print(y_seq.shape)

    # Step-by-step application (container form).
    net = layer.MultiStepContainer(conv, bn)
    z_seq = net(x_seq)
    print(z_seq.shape)
    # z_seq is identical to y_seq

    # Whole-sequence application (functional form).
    p_seq = functional.seq_to_ann_forward(x_seq, (conv, bn))
    # p_seq.shape = [T, N, 8, H, W]

    # Whole-sequence application (container form).
    net = layer.SeqToANNContainer(conv, bn)
    q_seq = net(x_seq)
    print(q_seq.shape)

    # q_seq is identical to p_seq, and (with bn in eval mode) matches y_seq
    # and z_seq up to floating-point rounding. Check the claims explicitly.
    assert torch.equal(z_seq, y_seq)
    assert torch.equal(q_seq, p_seq)
    assert torch.allclose(q_seq, y_seq, atol=1e-5)

torch.Size([4, 1, 8, 8, 8])
torch.Size([4, 1, 8, 8, 8])
torch.Size([4, 1, 8, 8, 8])
