# 1. Define a spiking network

In [38]:
%load_ext blackcellmagic
import torch
import torch.nn as nn
import sinabs.layers as sl
import matplotlib.pyplot as plt

The blackcellmagic extension is already loaded. To reload it, use:
  %reload_ext blackcellmagic


In [3]:
class MySNN(nn.Module):
    """Small convolutional spiking network with 10 output units.

    Two (spiking conv -> sum pooling) stages are followed by a 1x1 spiking
    convolution that plays the role of a dense read-out layer.

    Spatial sizes, as declared by the layers below, for a (1, 64, 64) input:
    64x64 -> conv 5x5 -> 60x60 -> pool 3x3 -> 20x20 -> conv 5x5 -> 16x16
    -> pool 4x4 -> 4x4, i.e. 6 * 4 * 4 = 96 features after flattening.
    """

    def __init__(self):
        super().__init__()
        # Spiking Input Layer
        self.input1 = sl.InputLayer(input_shape=(1, 64, 64), layer_name="input_1")

        # Spiking Conv layer
        self.conv1 = sl.SpikingConv2dLayer(
            channels_in=1,
            image_shape=(64, 64),
            channels_out=6,
            kernel_shape=(5, 5),
            layer_name="conv_1",
        )

        # Spiking SumPooling layer
        self.pool1 = sl.SumPooling2dLayer(
            image_shape=(60, 60), pool_size=(3, 3), layer_name="pool_1"
        )

        # Spiking Conv layer
        self.conv2 = sl.SpikingConv2dLayer(
            channels_in=6,
            image_shape=(20, 20),
            channels_out=6,
            kernel_shape=(5, 5),
            layer_name="conv_2",
        )

        # Spiking SumPooling layer
        self.pool2 = sl.SumPooling2dLayer(
            image_shape=(16, 16), pool_size=(4, 4), layer_name="pool_2"
        )

        # Generating an Equivalent Spiking Dense Layer:
        # the 6 * 4 * 4 = 96 flattened features become the channels of a
        # 1x1 "image", so a 1x1 convolution acts as a dense layer.
        self.flatten1 = sl.FlattenLayer(input_shape=(6, 4, 4), layer_name="flatten_1")
        self.conv3 = sl.SpikingConv2dLayer(
            channels_in=96,
            image_shape=(1, 1),
            channels_out=10,
            kernel_shape=(1, 1),
            layer_name="conv_3",
        )

    def forward(self, x):
        """Run the network on ``x`` and return one 10-element row per leading index.

        Args:
            x: input tensor; the notebook feeds tensors of shape
                (N, 1, 64, 64), where N is the batch / time-step dimension.

        Returns:
            Tensor of shape (N, 10).
        """
        x = self.input1(x)
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.flatten1(x)
        # Append two singleton spatial dims so the 1x1 conv can consume the
        # flattened features as channels.
        x = x.unsqueeze(-1).unsqueeze(-1)
        x = self.conv3(x)
        # Remove only the two trailing singleton spatial dims added above.
        # A bare squeeze() would also drop the leading dimension when N == 1.
        out = x.squeeze(-1).squeeze(-1)
        return out


In [4]:
# Instantiate the network and print its layer structure
snn = MySNN()
print(snn)

MySNN(
  (input1): InputLayer()
  (conv1): SpikingConv2dLayer(
    (conv): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  )
  (pool1): SumPooling2dLayer(
    (pool): LPPool2d(norm_type=1, kernel_size=(3, 3), stride=(3, 3), ceil_mode=False)
  )
  (conv2): SpikingConv2dLayer(
    (conv): Conv2d(6, 6, kernel_size=(5, 5), stride=(1, 1))
  )
  (pool2): SumPooling2dLayer(
    (pool): LPPool2d(norm_type=1, kernel_size=(4, 4), stride=(4, 4), ceil_mode=False)
  )
  (flatten1): FlattenLayer()
  (conv3): SpikingConv2dLayer(
    (conv): Conv2d(96, 10, kernel_size=(1, 1), stride=(1, 1))
  )
)


### Input and Output Size

In [60]:
# Feed a normally distributed random tensor of shape (100, 1, 64, 64)
# through the network and print the shape of the result
input_data = torch.randn(100, 1, 64, 64)
output_data = snn(input_data)
print(output_data.shape)

torch.Size([100, 10])


### Generate rate-based spike trains from normalised float values

In [80]:
def get_spike_train(time_win, input_image=None):
    """Convert a normalised image into a rate-coded spike train.

    For every time step an independent uniform random number is drawn per
    pixel; the pixel emits a spike (1.0) when that number is lower than the
    pixel value, so the pixel value is its per-step firing probability.
    The L2 distance between the image and the resulting firing rates is
    printed as a measure of conversion precision.

    Args:
        time_win: number of time steps to generate; must be at least 1.
        input_image: optional tensor with values in [0, 1] to convert. When
            omitted, a random single-channel 64x64 image is generated.

    Returns:
        Float tensor of shape ``(time_win, *input_image.shape)`` containing
        only 0.0 and 1.0.

    Raises:
        ValueError: if ``time_win`` is smaller than 1 (the firing rate would
            otherwise be a division by zero).
    """
    if time_win < 1:
        raise ValueError("time_win must be at least 1, got {}".format(time_win))
    if input_image is None:
        # randomize an image: 1 channel, 64*64 resolution
        input_image = torch.rand(1, 64, 64)
    # one uniform random number per pixel and per time step
    random_tensor = torch.rand(
        time_win, *input_image.shape, device=input_image.device
    )
    # emit a spike (1.0) where the random number is lower than the pixel value
    converted_spike_train = (random_tensor < input_image).float()
    # firing rate: spikes counted over the time window divided by its length
    img_converted = converted_spike_train.sum(0) / time_win
    # the L2 distance between the original image and its rate-coded version
    dist = torch.dist(input_image, img_converted, 2).item()
    print("L2 distance between original image and converted spike trains: ", dist)
    return converted_spike_train

# Longer time_win results in more precise conversion: the L2 distance
# printed by each call shrinks as the number of time steps grows
time_win_list = [10, 100, 1000]
for time_win in time_win_list:
    get_spike_train(time_win)


L2 distance between original image and converted spike trains:  8.112542152404785
L2 distance between original image and converted spike trains:  2.5845723152160645
L2 distance between original image and converted spike trains:  0.825164794921875


### Read out a spike train

In [81]:
# Generate a 100-step spike train and run it through the network
input_data = get_spike_train(100)
output_data = snn(input_data)
# Sum over dimension 0 (the time steps of the spike train) to get one
# accumulated value per output unit
output_spike_count = output_data.sum(0)
print(output_spike_count.shape)

L2 distance between original image and converted spike trains:  2.5908772945404053
torch.Size([10])
