In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import neuron,functional,layer

In [2]:
torch.manual_seed(0)  # fixed seed so Restart & Run All reproduces the same random inputs

net_s = neuron.IFNode(step_mode='s')
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])

# Drive the single-step neuron over the leading time axis with the functional helper.
y_seq = functional.multi_step_forward(x_seq, net_s)
print(y_seq.shape)

# IFNode is stateful: clear its state so the second run starts from the same
# initial state as the first one.
net_s.reset()
# Wrapping a module that itself supports step_mode emits the warning captured in the
# output below; the warning suggests configuring step_mode on the module directly instead.
net_m = layer.MultiStepContainer(net_s)
z_seq = net_m(x_seq)
print(z_seq.shape)

# Both paths run the same per-step computation from the same initial state,
# so their spike outputs must match exactly.
assert torch.equal(z_seq, y_seq), 'MultiStepContainer output differs from multi_step_forward'

  v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch
  (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
) supports for step_mode == 's', which should not be contained by MultiStepContainer!


torch.Size([4, 1, 3, 8, 8])
torch.Size([4, 1, 3, 8, 8])


In [3]:
with torch.no_grad():
    T = 4
    N = 1
    C = 3
    H = 8
    W = 8
    x_seq = torch.rand([T, N, C, H, W])

    conv = nn.Conv2d(C, 8, kernel_size=3, padding=1, bias=False)
    bn = nn.BatchNorm2d(8)
    # In training mode BatchNorm2d normalises with the statistics of whatever batch it
    # receives. multi_step_forward / MultiStepContainer feed it one time step ([N, ...])
    # at a time, while seq_to_ann_forward / SeqToANNContainer feed all T*N samples at
    # once, so the results would NOT be identical. Every training-mode call would also
    # update running_mean / running_var. eval() switches BN to its fixed running
    # statistics, which makes all four code paths compute the same function.
    bn.eval()

    y_seq = functional.multi_step_forward(x_seq, (conv, bn))
    print(y_seq.shape)

    net_multi_step = layer.MultiStepContainer(conv, bn)
    z_seq = net_multi_step(x_seq)
    print(z_seq.shape)

    p_seq = functional.seq_to_ann_forward(x_seq, (conv, bn))
    # p_seq.shape = [T, N, 8, H, W]

    net_seq_to_ann = layer.SeqToANNContainer(conv, bn)
    q_seq = net_seq_to_ann(x_seq)
    print(q_seq.shape)

    # Per-step vs. batched convolution may differ in the last float32 bits,
    # so compare with a small tolerance instead of exact equality.
    assert torch.allclose(z_seq, y_seq, atol=1e-5)
    assert torch.allclose(q_seq, p_seq, atol=1e-5)
    assert torch.allclose(p_seq, y_seq, atol=1e-5)

torch.Size([4, 1, 8, 8, 8])
torch.Size([4, 1, 8, 8, 8])
torch.Size([4, 1, 8, 8, 8])


In [4]:
ann = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU()
)
print(f'ann.state_dict.keys()={ann.state_dict().keys()}')

# SeqToANNContainer adds one level of nesting, so the conv/BN keys gain an extra index
# prefix ('0.0.weight' instead of '0.weight') and no longer match the ANN's keys.
net_container = nn.Sequential(
    layer.SeqToANNContainer(
        nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(8),
    ),
    neuron.IFNode(step_mode='m')
)
print(f'net_container.state_dict.keys()={net_container.state_dict().keys()}')

# SpikingJelly's layer.* modules keep the flat layout, so the printed keys line up with
# the ANN's. step_mode='m' on every layer lets this network consume [T, N, C, H, W]
# sequences like net_container does; a plain nn.BatchNorm2d only accepts 4-D
# [N, C, H, W] input and would break the multi-step IFNode pipeline.
net_origin = nn.Sequential(
    layer.Conv2d(3, 8, kernel_size=3, padding=1, bias=False, step_mode='m'),
    layer.BatchNorm2d(8, step_mode='m'),
    neuron.IFNode(step_mode='m')
)
print(f'net_origin.state_dict.keys()={net_origin.state_dict().keys()}')


def try_load_state_dict(name, net, state_dict):
    """Attempt ``net.load_state_dict(state_dict)`` and print whether it succeeded.

    ``load_state_dict`` reports missing/unexpected keys by raising ``RuntimeError``.
    Only that exception is caught, so e.g. a KeyboardInterrupt still stops the cell.
    """
    print(f'{name} is trying to load state dict from ann...')
    try:
        net.load_state_dict(state_dict)
    except RuntimeError as e:
        print(f'{name} can not load! The error message is\n', e)
        return
    print('Load success!')


try_load_state_dict('net_container', net_container, ann.state_dict())
try_load_state_dict('net_origin', net_origin, ann.state_dict())

ann.state_dict.keys()=odict_keys(['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked'])
net_container.state_dict.keys()=odict_keys(['0.0.weight', '0.1.weight', '0.1.bias', '0.1.running_mean', '0.1.running_var', '0.1.num_batches_tracked'])
net_origin.state_dict.keys()=odict_keys(['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked'])
net_container is trying to load state dict from ann...
net_container can not load! The error message is
 Error(s) in loading state_dict for Sequential:
	Missing key(s) in state_dict: "0.0.weight", "0.1.weight", "0.1.bias", "0.1.running_mean", "0.1.running_var". 
	Unexpected key(s) in state_dict: "0.weight", "1.weight", "1.bias", "1.running_mean", "1.running_var", "1.num_batches_tracked". 
net_origin is trying to load state dict from ann...
Load success!


In [5]:
with torch.no_grad():
    T, N, C, H, W = 4, 2, 4, 8, 8
    x_seq = torch.rand([T, N, C, H, W])

    # The first positional argument is presumably the container's `stateful` flag
    # (False here: conv/BN keep no state between time steps) — confirm against the API.
    net = layer.StepModeContainer(
        False,
        nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(C),
    )
    # Report the container's initial mode (printed as 's' in the output below).
    print(f'net.step_mode={net.step_mode}')

    # Multi-step: the whole [T, N, C, H, W] sequence goes in at once,
    # and the output keeps shape [T, N, C, H, W].
    net.step_mode = 'm'
    y_seq = net(x_seq)
    print(y_seq.shape)

    # Single-step: one time slice [N, C, H, W] goes in,
    # and the output keeps shape [N, C, H, W].
    net.step_mode = 's'
    y = net(x_seq[0])
    print(y.shape)

net.step_mode=s
torch.Size([4, 2, 4, 8, 8])
torch.Size([2, 4, 8, 8])


In [6]:
with torch.no_grad():
    # Wrapping IFNode, which already supports step_mode, triggers the warning shown below.
    net = layer.StepModeContainer(True, neuron.IFNode())
    functional.set_step_mode(net, 'm')

    # As the captured output shows, only the container switches to 'm';
    # the wrapped IFNode keeps its own step_mode ('s').
    for label, module in (('net', net), ('net[0]', net[0])):
        print(f'{label}.step_mode={module.step_mode}')

  v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch
  (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
) supports for step_mode == 's', which should not be contained by StepModeContainer!


net.step_mode=m
net[0].step_mode=s
