In [1]:
import torch
import torch.nn as nn
import networkx as nx
import matplotlib.pyplot as plt
from sinabs.from_torch import from_model
from sinabs.backend.dynapcnn import DynapcnnNetwork, DynapcnnNetworkGraph

In [2]:
# Fix the seed so the randomly initialised network weights are reproducible.
SEED = 0
torch.manual_seed(SEED);  # trailing ';' suppresses the Generator repr in the cell output

<torch._C.Generator at 0x7fb8db634c90>

In [3]:
# Single-channel 28x28 input frames, laid out as (channels, height, width).
channels, height, width = 1, 28, 28
input_shape = (channels, height, width)

## Network Module (pure PyTorch)

In [4]:
# Bias-free CNN for 1x28x28 inputs: three Conv2d -> ReLU -> 2x2 AvgPool stages
# reduce the feature map to 128x1x1, followed by a two-layer fully connected head.
ann = nn.Sequential(
    # Stage 1: 1 -> 20 channels, 5x5 kernel
    nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1, bias=False),
    nn.ReLU(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # Stage 2: 20 -> 32 channels, 5x5 kernel
    nn.Conv2d(in_channels=20, out_channels=32, kernel_size=5, stride=1, bias=False),
    nn.ReLU(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # Stage 3: 32 -> 128 channels, 3x3 kernel
    nn.Conv2d(in_channels=32, out_channels=128, kernel_size=3, stride=1, bias=False),
    nn.ReLU(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # Classifier head
    nn.Flatten(),
    nn.Linear(in_features=128, out_features=500, bias=False),
    nn.ReLU(),
    # NOTE(review): this Flatten is a no-op on the (batch, 500) Linear output; the
    # edge listings below show it traced as its own node, so it may be kept on
    # purpose to exercise Flatten handling — confirm before removing.
    nn.Flatten(),
    nn.Linear(in_features=500, out_features=10, bias=False),
)

## Sinabs Model

In [5]:
# Convert the PyTorch ANN into a sinabs spiking model (batch_size=1). The printout
# below shows each ReLU replaced by an IAFSqueeze layer at the same position.
sinabs_model = from_model(ann, add_spiking_output=True, batch_size=1)

# List the spiking model's layers with their positional index (enumerate replaces
# the hand-maintained counter).
for idx, layer in enumerate(sinabs_model.spiking_model):
    print(idx, layer)

0 Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1), bias=False)
1 IAFSqueeze(spike_threshold=Parameter containing:
tensor(1.), min_v_mem=Parameter containing:
tensor(-1.), batch_size=1, num_timesteps=-1)
2 AvgPool2d(kernel_size=2, stride=2, padding=0)
3 Conv2d(20, 32, kernel_size=(5, 5), stride=(1, 1), bias=False)
4 IAFSqueeze(spike_threshold=Parameter containing:
tensor(1.), min_v_mem=Parameter containing:
tensor(-1.), batch_size=1, num_timesteps=-1)
5 AvgPool2d(kernel_size=2, stride=2, padding=0)
6 Conv2d(32, 128, kernel_size=(3, 3), stride=(1, 1), bias=False)
7 IAFSqueeze(spike_threshold=Parameter containing:
tensor(1.), min_v_mem=Parameter containing:
tensor(-1.), batch_size=1, num_timesteps=-1)
8 AvgPool2d(kernel_size=2, stride=2, padding=0)
9 Flatten(start_dim=1, end_dim=-1)
10 Linear(in_features=128, out_features=500, bias=False)
11 IAFSqueeze(spike_threshold=Parameter containing:
tensor(1.), min_v_mem=Parameter containing:
tensor(-1.), batch_size=1, num_timesteps=-1)
12 Flatten(

In [6]:
# Build the DynapcnnNetworkGraph representation of the spiking model for the
# configured input shape, with discretize=True.
hw_model = DynapcnnNetworkGraph(
    sinabs_model,
    input_shape=input_shape,
    discretize=True,
)

0 Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1), bias=False)
1 IAFSqueeze(spike_threshold=Parameter containing:
tensor(1.), min_v_mem=Parameter containing:
tensor(-1.), batch_size=1, num_timesteps=-1)
2 AvgPool2d(kernel_size=2, stride=2, padding=0)
3 Conv2d(20, 32, kernel_size=(5, 5), stride=(1, 1), bias=False)
4 IAFSqueeze(spike_threshold=Parameter containing:
tensor(1.), min_v_mem=Parameter containing:
tensor(-1.), batch_size=1, num_timesteps=-1)
5 AvgPool2d(kernel_size=2, stride=2, padding=0)
6 Conv2d(32, 128, kernel_size=(3, 3), stride=(1, 1), bias=False)
7 IAFSqueeze(spike_threshold=Parameter containing:
tensor(1.), min_v_mem=Parameter containing:
tensor(-1.), batch_size=1, num_timesteps=-1)
8 AvgPool2d(kernel_size=2, stride=2, padding=0)
9 Linear(in_features=128, out_features=500, bias=False)
10 IAFSqueeze(spike_threshold=Parameter containing:
tensor(1.), min_v_mem=Parameter containing:
tensor(-1.), batch_size=1, num_timesteps=-1)
11 Linear(in_features=500, out_features=10, bia

In [7]:
# Edges recorded by the graph tracer, printed one per line.
tracer_edges = hw_model.graph_tracer.edges_list
for src_dst in tracer_edges:
    print(src_dst)

(0, 1)
(1, 2)
(2, 3)
(3, 4)
(4, 5)
(5, 6)
(6, 7)
(7, 8)
(8, 9)
(9, 10)
(10, 11)
(11, 12)
(12, 13)


In [8]:
# Edges between sinabs layers, printed one per line. Compared with the tracer
# edges, nodes 9 and 12 (the Flatten layers in the In [5] printout) do not appear.
layer_edges = hw_model.sinabs_edges
for src_dst in layer_edges:
    print(src_dst)

(0, 1)
(1, 2)
(2, 3)
(3, 4)
(4, 5)
(5, 6)
(6, 7)
(7, 8)
(8, 10)
(10, 11)
(11, 13)
