In [1]:
import torch
import torch.nn as nn
from sinabs.backend.dynapcnn import DynapcnnNetworkGraph
from sinabs.layers import Merge, IAFSqueeze, SumPool2d
import sinabs.layers as sl
from sinabs.activation.surrogate_gradient_fn import PeriodicExponential

In [2]:
# Fix PyTorch's global RNG so the randomly initialised layer weights of the SNN defined below are reproducible.
torch.manual_seed(seed=0)

<torch._C.Generator at 0x7fac3a52bad0>

In [3]:
# Shape of a single input sample as (channels, height, width); passed to DynapcnnNetworkGraph below.
channels, height, width = 2, 34, 34
input_shape = (channels, height, width)

## Network Module

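We need to define an `nn.Module` implementing the network we want the chip to reproduce. The diagram below shows its topology: two convolutional branches (A → B → C and D → E → F) both receive the input, their flattened outputs are merged, and the result goes through three fully-connected layers (G → H → I).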

```mermaid
stateDiagram
    [*] --> A
    [*] --> D
    A --> B
    B --> C
    D --> E
    E --> F
    C --> G
    F --> G
    G --> H
    H --> I
    I --> [*]
```

In [4]:
class SNN(nn.Module):
    """Two-branch spiking CNN followed by three fully-connected spiking layers.

    Both convolutional branches (A -> B -> C and D -> E -> F) read the same
    input ``x``. Their flattened outputs are combined by ``Merge`` and fed
    through the fully-connected stack (G -> H -> I).

    Args:
        nb_classes: number of output features of the last linear layer (``fc3``).
        batch_size: forwarded as ``batch_size`` to every ``IAFSqueeze`` layer.
        surrogate_fn: surrogate gradient function shared by every ``IAFSqueeze`` layer.
        min_v_mem: forwarded as ``min_v_mem`` to every ``IAFSqueeze`` layer.
        spk_thr: forwarded as ``spike_threshold`` to every ``IAFSqueeze`` layer.
    """

    def __init__(self, nb_classes, batch_size, surrogate_fn, min_v_mem=-0.313, spk_thr=2.0) -> None:
        super().__init__()

        def spiking():
            # All spiking layers of this network use identical neuron settings.
            return IAFSqueeze(
                batch_size=batch_size,
                min_v_mem=min_v_mem,
                surrogate_grad_fn=surrogate_fn,
                spike_threshold=spk_thr,
            )

        # Keep the creation order below unchanged: PyTorch draws each layer's
        # initial weights from the global RNG when the layer is constructed, so
        # reordering would change the weights obtained after `torch.manual_seed`.

        # Branch 1: A -> B -> C
        self.conv_A = nn.Conv2d(2, 4, 2, 1, bias=False)
        self.iaf_A = spiking()

        self.conv_B = nn.Conv2d(4, 4, 2, 1, bias=False)
        self.iaf_B = spiking()
        self.pool_B = SumPool2d(2, 2)

        self.conv_C = nn.Conv2d(4, 4, 2, 1, bias=False)
        self.iaf_C = spiking()
        self.pool_C = SumPool2d(2, 2)

        # Branch 2: D -> E -> F
        self.conv_D = nn.Conv2d(2, 4, 2, 1, bias=False)
        self.iaf_D = spiking()

        self.conv_E = nn.Conv2d(4, 4, 2, 1, bias=False)
        self.iaf_E = spiking()
        self.pool_E = SumPool2d(2, 2)

        self.conv_F = nn.Conv2d(4, 4, 2, 1, bias=False)
        self.iaf_F = spiking()
        self.pool_F = SumPool2d(2, 2)

        # Flatten each branch, then combine them ("brach" spelling kept: attribute names are public).
        self.flat_brach1 = nn.Flatten()
        self.flat_brach2 = nn.Flatten()
        self.merge = Merge()

        # Fully-connected stack: G -> H -> I.
        # 196 = 4 * 7 * 7, the flattened size of each branch for a (2, 34, 34) input.
        self.fc1 = nn.Linear(196, 200, bias=False)
        self.iaf1_fc = spiking()

        self.fc2 = nn.Linear(200, 200, bias=False)
        self.iaf2_fc = spiking()

        self.fc3 = nn.Linear(200, nb_classes, bias=False)
        self.iaf3_fc = spiking()

    def forward(self, x):
        """Run both branches on ``x``, merge them, and apply the fully-connected stack.

        Layers are called one at a time in a fixed order: branch A-B-C, then
        branch D-E-F, then flatten/merge, then G-H-I. Each intermediate result
        is kept in its own local variable until the method returns.
        """
        # Branch 1: A -> B -> C
        a_conv = self.conv_A(x)
        a_spikes = self.iaf_A(a_conv)
        b_conv = self.conv_B(a_spikes)
        b_spikes = self.iaf_B(b_conv)
        b_pooled = self.pool_B(b_spikes)
        c_conv = self.conv_C(b_pooled)
        c_spikes = self.iaf_C(c_conv)
        c_pooled = self.pool_C(c_spikes)

        # Branch 2: D -> E -> F (also fed directly from the input x)
        d_conv = self.conv_D(x)
        d_spikes = self.iaf_D(d_conv)
        e_conv = self.conv_E(d_spikes)
        e_spikes = self.iaf_E(e_conv)
        e_pooled = self.pool_E(e_spikes)
        f_conv = self.conv_F(e_pooled)
        f_spikes = self.iaf_F(f_conv)
        f_pooled = self.pool_F(f_spikes)

        # Join the two branches.
        branch1_flat = self.flat_brach1(c_pooled)
        branch2_flat = self.flat_brach2(f_pooled)
        merged = self.merge(branch1_flat, branch2_flat)

        # Fully-connected stack: G -> H -> I
        g_linear = self.fc1(merged)
        g_spikes = self.iaf1_fc(g_linear)
        h_linear = self.fc2(g_spikes)
        h_spikes = self.iaf2_fc(h_linear)
        i_linear = self.fc3(h_spikes)
        i_spikes = self.iaf3_fc(i_linear)

        return i_spikes
    
# 11 output classes, batch size 1, periodic-exponential surrogate gradient for every spiking layer.
snn = SNN(nb_classes=11, batch_size=1, surrogate_fn=PeriodicExponential())

## DynapcnnNetworkGraph Class

The `DynapcnnNetworkGraph` constructor parses the SNN passed as argument (defined as an `nn.Module`) so that each layer is represented in a computational graph (built with `nirtorch.extract_torch_graph`).

The layers are the `nodes` of the graph. Their connectivity (how the outputs of a layer are sent to other layers) is encoded as `edges`, stored as a `list` of `tuples`.

Once the constructor finishes, the `hw_model.dynapcnn_layers` property is a dictionary keyed by `DynapcnnLayer` ID (an `int` from `0` to `L`). Each entry contains three things: the `DynapcnnLayer` instance, which incorporates a subset of the layers of the original SNN; the core that instance has been assigned to; and the list of IDs of the `DynapcnnLayer` instances that the layer targets.

In [5]:
# Parse the SNN into DynapcnnLayer instances; discretize=True is requested for the hardware deployment.
hw_model = DynapcnnNetworkGraph(snn, input_shape=input_shape, discretize=True)

The `hw_model.to()` call determines which core each `DynapcnnLayer` instance is assigned to. Once this assignment is made, the instance itself is used to configure the `CNNLayerConfig` instance representing that core's configuration.

If the cores' configuration is valid, each `DynapcnnLayer` instance and its destinations are used to create a computational graph. That graph encodes how the `forward` method of `hw_model.network` (an `nn.Module` built from the `DynapcnnLayer` instances) propagates data through the network.

In [6]:
# Device identifier handed to `hw_model.to()`, which assigns cores and validates the configuration.
target_device = "speck2fmodule:0"
hw_model.to(device=target_device)

Network is valid


<sinabs.backend.dynapcnn.dynapcnn_network_graph.DynapcnnNetworkGraph at 0x7fac1bdc5250>

The layers comprising our `hw_model` and their respective metadata can be inspected by calling `print` on a `DynapcnnNetworkGraph` instance.

In [7]:
# The repr of DynapcnnNetworkGraph is the default object repr (see the previous cell's output),
# but its str() lists every DynapcnnLayer, so print the string form explicitly.
print(str(hw_model))

---- DynapcnnLayer 0 ----------------------------------------------------------
> layer modules: 
(node 0): Conv2d(2, 4, kernel_size=(2, 2), stride=(1, 1), bias=False)
(node 1): IAFSqueeze(spike_threshold=Parameter containing:
tensor(758.), min_v_mem=Parameter containing:
tensor(-119.), batch_size=1, num_timesteps=-1)
> layer destinations: [1]
> assigned core: 0

---- DynapcnnLayer 1 ----------------------------------------------------------
> layer modules: 
(node 2): Conv2d(4, 4, kernel_size=(2, 2), stride=(1, 1), bias=False)
(node 3): IAFSqueeze(spike_threshold=Parameter containing:
tensor(1022.), min_v_mem=Parameter containing:
tensor(-160.), batch_size=1, num_timesteps=-1)
(node 4): SumPool2d(norm_type=1, kernel_size=2, stride=2, ceil_mode=False)
> layer destinations: [2]
> assigned core: 1

---- DynapcnnLayer 2 ----------------------------------------------------------
> layer modules: 
(node 5): Conv2d(4, 4, kernel_size=(2, 2), stride=(1, 1), bias=False)
(node 6): IAFSqueeze(spi