In [1]:
import torch
import torch.nn as nn
from sinabs.backend.dynapcnn import DynapcnnNetworkGraph
from sinabs.layers import Merge, IAFSqueeze, SumPool2d
import sinabs.layers as sl
from sinabs.activation.surrogate_gradient_fn import PeriodicExponential

In [2]:
# Seed torch's global RNG so the randomly initialised layer weights below are
# reproducible. The trailing ";" stops Jupyter from echoing the returned
# Generator object (which only shows a memory address).
torch.manual_seed(0);

In [3]:
# Shape of one input frame, ordered (channels, height, width).
channels, height, width = 2, 34, 34
input_shape = (channels, height, width)

## Network Module

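We need to define an `nn.Module` implementing the network we want the chip to reproduce. Each node in the diagram below is one convolutional or linear stage, labelled in the comments of `forward`: A = `conv1`, B = `conv2`, C = `conv3`, D = `conv4`, E = `fc1`, F = `fc2`.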

```mermaid
stateDiagram
    [*] --> A
    A --> B
    A --> C
    C --> D
    C --> E
    B --> D
    D --> F
    E --> F
    F --> [*]
```

In [4]:
class SNN(nn.Module):
    """Branching spiking CNN used to demonstrate `DynapcnnNetworkGraph`.

    The letters match the state diagram above and the section comments in
    `forward`:

    - A: conv1 -> iaf1                                    (input: x)
    - B: conv2 -> iaf2 -> pool2                           (input: A)
    - C: conv3 -> iaf3 -> pool3 (to D) and pool3a (to E)  (input: A)
    - D: merge1(pool2, pool3) -> conv4 -> iaf4 -> pool4 -> flat
    - E: flat_a(pool3a) -> fc1 -> iaf1_fc
    - F: merge2(iaf1_fc, flat) -> fc2 -> iaf2_fc          (network output)

    Every spiking layer is an ``IAFSqueeze`` built with the same
    ``batch_size``, ``min_v_mem``, surrogate gradient and spike threshold.

    The input-feature count of ``fc1`` and ``fc2`` (200 = 8 x 5 x 5) is
    hard-coded for a (2, 34, 34) input; a different spatial size requires
    changing it.

    Args:
        nb_classes: Number of output features of ``fc2`` (the network output).
        batch_size: Passed as ``batch_size`` to every ``IAFSqueeze`` layer.
        surrogate_fn: Passed as ``surrogate_grad_fn`` to every ``IAFSqueeze``
            layer; the same object is shared by all of them.
        min_v_mem: Passed as ``min_v_mem`` to every ``IAFSqueeze`` layer.
        spk_thr: Passed as ``spike_threshold`` to every ``IAFSqueeze`` layer.
    """

    def __init__(self, nb_classes, batch_size, surrogate_fn, min_v_mem=-0.313, spk_thr=2.0) -> None:
        super().__init__()

        def make_iaf():
            # All spiking layers share one neuron configuration.
            return IAFSqueeze(
                batch_size=batch_size,
                min_v_mem=min_v_mem,
                surrogate_grad_fn=surrogate_fn,
                spike_threshold=spk_thr,
            )

        # Flattened size of each branch feeding the FC layers for a (2, 34, 34)
        # input (kernel 2 / stride 1 convs without padding, floor pooling):
        #   E: 34 -conv1-> 33 -conv3-> 32 -pool3a(6)-> 5                         => 8 x 5 x 5
        #   D: 34 -conv1-> 33 -conv2-> 32 -pool2(2)-> 16 -conv4-> 15 -pool4(3)-> 5 => 8 x 5 x 5
        # fc2 consumes the output of merge2, so it expects the same 200 features.
        n_flat = 8 * 5 * 5  # 200

        # NOTE: modules are registered in the same order as before; keep it
        # stable so the extracted graph / node numbering does not change.

        # -- A --  (2 input channels: matches `channels` in the config cell)
        self.conv1 = nn.Conv2d(2, 8, 2, 1, bias=False)
        self.iaf1 = make_iaf()

        # -- B --
        self.conv2 = nn.Conv2d(8, 8, 2, 1, bias=False)
        self.iaf2 = make_iaf()
        self.pool2 = sl.SumPool2d(2, 2)

        # -- C --  (pool3 feeds D, pool3a feeds E)
        self.conv3 = nn.Conv2d(8, 8, 2, 1, bias=False)
        self.iaf3 = make_iaf()
        self.pool3 = sl.SumPool2d(2, 2)
        self.pool3a = sl.SumPool2d(6, 6)

        # -- D --
        self.conv4 = nn.Conv2d(8, 8, 2, 1, bias=False)
        self.iaf4 = make_iaf()
        self.pool4 = sl.SumPool2d(3, 3)

        # Two separate Flatten modules so each branch gets its own graph node.
        self.flat = nn.Flatten()
        self.flat_a = nn.Flatten()

        # -- E --
        self.fc1 = nn.Linear(n_flat, 200, bias=False)
        self.iaf1_fc = make_iaf()

        # -- F --
        self.fc2 = nn.Linear(n_flat, nb_classes, bias=False)
        self.iaf2_fc = make_iaf()

        # -- merges --
        self.merge1 = Merge()
        self.merge2 = Merge()

    def forward(self, x):
        """Propagate ``x`` through branches A-F and return the spikes of ``iaf2_fc``.

        ``x`` must have 2 channels (``conv1``'s input) and a 34 x 34 spatial
        size for the hard-coded FC sizes to match. The leading dimension
        presumably packs batch and time as ``IAFSqueeze`` expects -- TODO confirm.
        """
        # conv 1 - A
        conv1_out = self.conv1(x)
        iaf1_out = self.iaf1(conv1_out)

        # conv 2 - B
        conv2_out = self.conv2(iaf1_out)
        iaf2_out = self.iaf2(conv2_out)
        pool2_out = self.pool2(iaf2_out)

        # conv 3 - C (reads A's output, in parallel with B)
        conv3_out = self.conv3(iaf1_out)
        iaf3_out = self.iaf3(conv3_out)
        pool3_out = self.pool3(iaf3_out)
        pool3a_out = self.pool3a(iaf3_out)

        # conv 4 - D (combines B and C)
        merge1_out = self.merge1(pool2_out, pool3_out)
        conv4_out = self.conv4(merge1_out)
        iaf4_out = self.iaf4(conv4_out)
        pool4_out = self.pool4(iaf4_out)
        flat_out = self.flat(pool4_out)

        # fc 1 - E (second output of C)
        flat_a_out = self.flat_a(pool3a_out)
        fc1_out = self.fc1(flat_a_out)
        iaf1_fc_out = self.iaf1_fc(fc1_out)

        # fc 2 - F (combines E and D)
        merge2_out = self.merge2(iaf1_fc_out, flat_out)
        fc2_out = self.fc2(merge2_out)
        iaf2_fc_out = self.iaf2_fc(fc2_out)

        return iaf2_fc_out
    
# 11 output classes, batch size 1, periodic-exponential surrogate gradient;
# min_v_mem and spk_thr keep SNN's defaults.
snn = SNN(nb_classes=11, batch_size=1, surrogate_fn=PeriodicExponential())

## DynapcnnNetworkGraph Class

The constructor of `DynapcnnNetworkGraph` parses the SNN passed as argument (an `nn.Module`) so that each layer is represented in a computational graph (built with `nirtorch.extract_torch_graph`).

The layers are the `nodes` of the graph. Their connectivity (how the outputs of a layer are sent to other layers) forms the `edges`, stored as a `list` of `tuple`s.

Once the constructor finishes, the `hw_model.dynapcnn_layers` property is a dictionary keyed by `DynapcnnLayer` ID (an `int` from `0` to `L`). Each entry holds:

- the `DynapcnnLayer` instance into which a subset of the original SNN's layers has been incorporated;
- the core this instance is assigned to (determined by the `hw_model.to()` call described below);
- the list of IDs of the `DynapcnnLayer` instances this layer targets.

In [5]:
# Parse `snn` into DynapcnnLayer instances for a frame of shape `input_shape`.
# `discretize=True` requests discretised parameters for the chip (compare the
# integer thresholds printed further down with `spk_thr=2.0` in `SNN`).
hw_model = DynapcnnNetworkGraph(snn, input_shape=input_shape, discretize=True)

The `hw_model.to()` call determines which core each `DynapcnnLayer` instance is assigned to. Once this assignment is made, the instance itself is used to configure the `CNNLayerConfig` instance representing that core's configuration.

If the cores' configuration is valid, each `DynapcnnLayer` instance and its destinations are used to build a computational graph. This graph encodes how the `forward` method of `hw_model.network` (an `nn.Module` built from the `DynapcnnLayer` instances) propagates data through the network.

In [6]:
# Assign each DynapcnnLayer to a core, validate the resulting configuration and
# (presumably) deploy it to the device identified below.
# NOTE(review): the saved output shows "RuntimeError: Device is already opened!"
# right after "Network is valid" -- presumably the board was left open by an
# earlier run/kernel. Close it (or restart the kernel) and re-run so this cell
# survives Restart & Run All.
hw_model.to(device="speck2fmodule:0")

Network is valid


RuntimeError: Device is already opened!

The layers comprising our `hw_model` and their respective metadata can be inspected by calling `print` on a `DynapcnnNetworkGraph` instance.

In [None]:
# Render the per-layer summary via str(); a bare expression would use repr instead.
print(str(hw_model))

---- DynapcnnLayer 0 ----------------------------------------------------------
> layer modules: 
(node 0): Conv2d(2, 8, kernel_size=(2, 2), stride=(1, 1), bias=False)
(node 1): IAFSqueeze(spike_threshold=Parameter containing:
tensor(723.), min_v_mem=Parameter containing:
tensor(-113.), batch_size=1, num_timesteps=-1)
> layer destinations: [1, 2]
> assigned core: 0

---- DynapcnnLayer 1 ----------------------------------------------------------
> layer modules: 
(node 2): Conv2d(8, 8, kernel_size=(2, 2), stride=(1, 1), bias=False)
(node 4): IAFSqueeze(spike_threshold=Parameter containing:
tensor(1452.), min_v_mem=Parameter containing:
tensor(-227.), batch_size=1, num_timesteps=-1)
(node 5): SumPool2d(norm_type=1, kernel_size=2, stride=2, ceil_mode=False)
> layer destinations: [3]
> assigned core: 1

---- DynapcnnLayer 2 ----------------------------------------------------------
> layer modules: 
(node 3): Conv2d(8, 8, kernel_size=(2, 2), stride=(1, 1), bias=False)
(node 7): IAFSqueeze(