# Quickstart Sinabs
If you're familiar with how SNNs work, you might find this quick overview of *Sinabs* useful.

## Sinabs is based on PyTorch
All of Sinabs' layers inherit from `torch.nn.Module`. As a result, you can access their parameters, wrap them in an `nn.Sequential` module and do everything else you would do with a regular PyTorch layer.
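Because every Sinabs layer is an `nn.Module`, stateful layers compose with the rest of PyTorch like any other layer. The sketch below illustrates this with a hypothetical `ToyLeakyLayer` (not part of Sinabs): it sits inside `nn.Sequential`, its trainable decay shows up in `model.parameters()`, and its state is registered as a buffer so that it follows `.to(device)` calls.

```python
import torch
import torch.nn as nn


class ToyLeakyLayer(nn.Module):
    """Hypothetical stateful layer, used only to illustrate nn.Module integration."""

    def __init__(self, num_features: int):
        super().__init__()
        # A trainable per-neuron decay factor, registered as a regular parameter
        self.decay = nn.Parameter(torch.full((num_features,), 0.9))
        # Internal state kept as a buffer so it moves with the model between devices
        self.register_buffer("state", torch.zeros(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Leaky accumulation of the input into the state
        self.state = self.decay * self.state + x
        return self.state


model = nn.Sequential(nn.Linear(16, 8), ToyLeakyLayer(8))

# Parameters of both the Linear layer and the toy layer are visible to optimizers
num_params = sum(p.numel() for p in model.parameters())
print(num_params)  # 144 = 16*8 + 8 (Linear) + 8 (decay)
```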

## How to define your network
We re-use as much PyTorch functionality as possible: `Linear`, `Conv2d` and `AvgPool` layers define the connectivity between neurons, while *Sinabs* layers add the neuron state and the non-linear spiking activation after each of those layers. Here is a simple SNN that takes an input tensor of shape (Batch, Time, Channels):


```python
import torch
import torch.nn as nn

import sinabs.activation
import sinabs.layers as sl

model = nn.Sequential(
    # Weight layer: 16 input channels -> 64 neurons
    nn.Linear(16, 64),
    # Spiking layer: membrane state + spike non-linearity, single-exponential surrogate gradient
    sl.LIF(
        tau_mem=10.,
        norm_input=True,
        activation_fn=sinabs.activation.ActivationFunction(
            surrogate_grad_fn=sinabs.activation.SingleExponential()
        ),
    ),
    # Weight layer: 64 -> 4 output neurons
    nn.Linear(64, 4),
    sl.LIF(
        tau_mem=10.,
        norm_input=True,
        activation_fn=sinabs.activation.ActivationFunction(
            surrogate_grad_fn=sinabs.activation.SingleExponential()
        ),
    ),
)
```
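To make "adding state and a non-linear activation" concrete, here is a simplified leaky integrate-and-fire (LIF) update written in plain PyTorch. It is a sketch under stated assumptions, not Sinabs' implementation: we assume an exponential decay `alpha = exp(-1 / tau_mem)`, input scaling by `(1 - alpha)` as a stand-in for `norm_input=True`, a firing threshold of 1.0 and reset by subtraction. The function name `lif_forward` is hypothetical.

```python
import math

import torch


def lif_forward(x: torch.Tensor, tau_mem: float = 10.0, threshold: float = 1.0) -> torch.Tensor:
    """Hypothetical LIF sketch. x has shape (Batch, Time, Channels); returns spikes of the same shape."""
    alpha = math.exp(-1.0 / tau_mem)  # assumed per-step membrane decay
    v_mem = torch.zeros(x.shape[0], x.shape[2])  # one membrane potential per (batch, channel)
    spikes = []
    for t in range(x.shape[1]):
        # Leaky integration; (1 - alpha) is an assumed input normalization
        v_mem = alpha * v_mem + (1 - alpha) * x[:, t]
        # Non-linearity: emit a spike wherever the threshold is reached
        s = (v_mem >= threshold).float()
        # Reset by subtracting the threshold from neurons that spiked
        v_mem = v_mem - s * threshold
        spikes.append(s)
    return torch.stack(spikes, dim=1)


# A constant input of 2.0 drives the membrane above threshold after a few steps
out = lif_forward(torch.full((1, 20, 3), 2.0))
print(out.shape)  # torch.Size([1, 20, 3])
```

With these assumptions the membrane follows `2 * (1 - alpha**n)` until the first spike, so the first spike appears at the seventh time step (index 6).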


## Inference with SNNs
For inference, you call the model like any other PyTorch model:

```python
# Define a random binary input of shape (Batch, Time, Channels)
input = (torch.rand(1, 100, 16) > 0.2).float()

# Compute the output without tracking gradients
with torch.no_grad():
    output = model(input)

# Total number of output spikes; the output itself has shape (Batch, Time, 4) = (1, 100, 4)
print(output.sum())
```

```
tensor(0.)
```
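The output is a spike raster of shape (Batch, Time, 4). One common way to read out a decision, assumed here for illustration rather than prescribed by Sinabs, is to count the spikes of each output neuron over time and pick the most active neuron. The sketch applies this to a small hand-made raster:

```python
import torch

# Hand-made spike raster with shape (Batch=2, Time=5, Channels=4)
raster = torch.zeros(2, 5, 4)
raster[0, :, 1] = 1.0    # sample 0: neuron 1 fires at every time step
raster[1, ::2, 3] = 1.0  # sample 1: neuron 3 fires at t = 0, 2, 4

spike_counts = raster.sum(dim=1)         # (Batch, Channels): spikes per neuron
prediction = spike_counts.argmax(dim=1)  # most active output neuron per sample
print(spike_counts)
print(prediction)  # tensor([1, 3])
```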


Note that the network state is retained after every forward pass. To reset the neuron states or their gradients, call the corresponding methods `layer.reset_states()` or `layer.zero_grad()` on each spiking layer.
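The effect of retained state can be seen with any stateful module. The sketch below uses a hypothetical `ToyIntegrator` (illustrative only, not a Sinabs class) that, like the Sinabs layers, keeps its state between calls and exposes its own `reset_states()` method:

```python
import torch
import torch.nn as nn


class ToyIntegrator(nn.Module):
    """Hypothetical stateful layer: accumulates its input across forward calls."""

    def __init__(self):
        super().__init__()
        self.v = None  # state is created lazily on the first forward pass

    def reset_states(self):
        self.v = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.v is None:
            self.v = torch.zeros_like(x)
        self.v = self.v + x
        return self.v


layer = ToyIntegrator()
x = torch.ones(3)
first = layer(x)   # tensor([1., 1., 1.])
second = layer(x)  # tensor([2., 2., 2.]): the state carried over
layer.reset_states()
third = layer(x)   # tensor([1., 1., 1.]) again after the reset
```

Calling the same input twice gives different outputs unless the state is reset in between, which is why the training loop below resets the states every iteration.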

## Training with BPTT

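Since the spiking layers are PyTorch modules, the network can be trained with backpropagation through time (BPTT) using a standard PyTorch optimizer. Because the neuron states persist between forward passes, we first define two helper functions that reset the states and their gradients before each training iteration:

```python
# Helper functions to reset the model during the training loop
def reset_model_states(seq_model: nn.Sequential, randomize: bool = False):
    """
    Reset the internal states of all LIF layers in a sequential model.
    """
    for lyr in seq_model:
        if isinstance(lyr, sl.LIF):
            lyr.reset_states(randomize=randomize)


def zero_grad_states(seq_model: nn.Sequential):
    """
    Reset the gradients of the internal states of all LIF layers in a sequential model.
    """
    for lyr in seq_model:
        if isinstance(lyr, sl.LIF):
            lyr.zero_grad()
```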


```python
# Training routine
optim = torch.optim.RMSprop(model.parameters(), lr=1e-3)
num_epochs = 100
target_num_spikes = 10

for epoch in range(num_epochs):
    # Reset the gradients of the parameters
    optim.zero_grad()

    # Also reset the gradients of the neuron states
    zero_grad_states(model)
    # and the states themselves, so that each epoch starts from fresh states
    reset_model_states(model, randomize=False)

    out = model(input)
    print(f"Epoch {epoch}: Output spikes: {out.sum().item()}")
    # Squared error between the total number of output spikes and the target
    loss = (out.sum() - target_num_spikes)**2
    # Sum of the first layer's weights, printed to monitor how the parameters change
    print(model[0].weight.sum().item())

    loss.backward()
    optim.step()

    # Optionally stop early once the target spike count is reached:
    # if loss.item() == 0:
    #     break
```


```
Epoch 0: Output spikes: 0.0
-6.826314449310303
Epoch 1: Output spikes: 0.0
-16.426284790039062
Epoch 2: Output spikes: 33.0
11.343732833862305
Epoch 3: Output spikes: 0.0
-18.771665573120117
Epoch 4: Output spikes: 0.0
-24.664527893066406
Epoch 5: Output spikes: 0.0
-24.771141052246094
Epoch 6: Output spikes: 12.0
-20.896390914916992
Epoch 7: Output spikes: 8.0
-22.420576095581055
Epoch 8: Output spikes: 12.0
-21.05232810974121
Epoch 9: Output spikes: 8.0
-22.574174880981445
Epoch 10: Output spikes: 11.0
-21.21195411682129
Epoch 11: Output spikes: 11.0
-21.968189239501953
Epoch 12: Output spikes: 8.0
-22.677114486694336
Epoch 13: Output spikes: 11.0
-21.308813095092773
Epoch 14: Output spikes: 11.0
-22.068706512451172
Epoch 15: Output spikes: 8.0
-22.781986236572266
Epoch 16: Output spikes: 12.0
-21.404481887817383
Epoch 17: Output spikes: 8.0
-22.94760513305664
Epoch 18: Output spikes: 11.0
-21.57461929321289
Epoch 19: Output spikes: 10.0
-22.338441848754883
Epoch 20: Output spikes: 1
```

After training, the output contains the target number of spikes:

```python
out.sum(), out.shape
```

```
(tensor(10., grad_fn=<SumBackward0>), torch.Size([1, 100, 4]))
```
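Training works even though spiking is a step function because the spiking layers use a surrogate gradient in the backward pass; in the model above it is selected with `SingleExponential()`. The sketch below shows the idea with a custom `torch.autograd.Function`: a Heaviside step in the forward pass and an exponentially decaying pseudo-derivative in the backward pass. The name `SpikeWithSurrogate` and the exact surrogate `exp(-|v - threshold|)` are assumptions for illustration; Sinabs' `SingleExponential` may scale it differently.

```python
import torch


class SpikeWithSurrogate(torch.autograd.Function):
    """Hypothetical spike function: step forward, exponential surrogate backward."""

    @staticmethod
    def forward(ctx, v_mem: torch.Tensor, threshold: float) -> torch.Tensor:
        ctx.save_for_backward(v_mem)
        ctx.threshold = threshold
        return (v_mem >= threshold).float()  # non-differentiable step

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (v_mem,) = ctx.saved_tensors
        # Assumed surrogate: peaks at the threshold and decays exponentially away from it
        surrogate = torch.exp(-torch.abs(v_mem - ctx.threshold))
        return grad_output * surrogate, None  # no gradient for the threshold


v = torch.tensor([0.0, 0.9, 1.0, 2.0], requires_grad=True)
spikes = SpikeWithSurrogate.apply(v, 1.0)
spikes.sum().backward()
print(spikes.tolist())  # [0.0, 0.0, 1.0, 1.0]
print(v.grad)           # non-zero everywhere, largest at v = 1.0
```

Because the surrogate is non-zero even for neurons that did not spike, a loss such as `(out.sum() - target_num_spikes)**2` can push silent neurons towards firing, as seen in the early epochs of the training log.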

## Working with Convolutional networks

When working with convolutional connectivity, an `nn.Conv2d` layer expects an input tensor of shape (Batch, Channels, Height, Width). Feeding it a tensor with an additional time dimension, (Batch, Time, Channels, Height, Width), raises an error. To apply 2D convolutions across time, we use a small trick: we flatten the batch and time dimensions into one before passing the tensor to the convolutional layer. The `Squeeze` variants of the *Sinabs* spiking layers (such as `LIFSqueeze`) accept this flattened input and, given the batch size, restore the time dimension internally, so the model definition barely changes.
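The flattening trick can be checked with plain PyTorch layers: merging batch and time into one dimension, applying `nn.Conv2d`, and splitting the dimensions again gives the same result as convolving each time step separately. The tensor sizes in this sketch are arbitrary.

```python
import torch
import torch.nn as nn

batch_size, time_steps = 2, 5
conv = nn.Conv2d(2, 4, kernel_size=3)
x = torch.rand(batch_size, time_steps, 2, 8, 8)  # (Batch, Time, Channels, Height, Width)

with torch.no_grad():
    # Merge batch and time: (Batch*Time, Channels, Height, Width)
    flat_out = conv(x.reshape(batch_size * time_steps, 2, 8, 8))
    # Split them again: (Batch, Time, Channels_out, Height_out, Width_out)
    y = flat_out.reshape(batch_size, time_steps, 4, 6, 6)

    # Reference: apply the convolution to every time step separately
    y_ref = torch.stack([conv(x[:, t]) for t in range(time_steps)], dim=1)

print(y.shape)                              # torch.Size([2, 5, 4, 6, 6])
print(torch.allclose(y, y_ref, atol=1e-5))  # True
```

This works because the convolution treats every row of the merged dimension independently; only stateful layers need to know which rows belong to which time step.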

```python
batch_size = 8
time_steps = 100

conv_model = nn.Sequential(
    nn.Conv2d(2, 16, kernel_size=3),                    # (B*T, 2, 8, 8) -> (B*T, 16, 6, 6)
    sl.LIFSqueeze(tau_mem=20., batch_size=batch_size),
    nn.Conv2d(16, 32, kernel_size=3),                   # -> (B*T, 32, 4, 4)
    sl.LIFSqueeze(tau_mem=20., batch_size=batch_size),
    nn.Flatten(),                                       # -> (B*T, 32*4*4 = 512)
    nn.Linear(512, 4),
)

# (Batch, Time, Channels, Height, Width)
data = torch.rand(batch_size, time_steps, 2, 8, 8)

# Merge batch and time to fit the flattened model definition: (Batch*Time, Channels, Height, Width)
input = data.reshape(batch_size * time_steps, 2, 8, 8)
```
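Conceptually, a `Squeeze` layer only has to undo the flattening around its stateful time loop. The hypothetical `ToySqueezeLIF` below sketches one way this could work, and it is not Sinabs' actual `LIFSqueeze` code: it unflattens `(Batch*Time, ...)` into `(Batch, Time, ...)` using the known batch size, runs a simplified LIF loop (assumed exponential decay, threshold 1.0, reset by subtraction) over time, and flattens the result back.

```python
import math

import torch
import torch.nn as nn


class ToySqueezeLIF(nn.Module):
    """Hypothetical stand-in for a Squeeze spiking layer; illustration only."""

    def __init__(self, tau_mem: float, batch_size: int, threshold: float = 1.0):
        super().__init__()
        self.alpha = math.exp(-1.0 / tau_mem)  # assumed per-step decay
        self.batch_size = batch_size
        self.threshold = threshold

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (Batch*Time, ...) -> (Batch, Time, ...), assuming batch-major flattening
        x = x.reshape(self.batch_size, -1, *x.shape[1:])
        v_mem = torch.zeros_like(x[:, 0])
        spikes = []
        for t in range(x.shape[1]):
            v_mem = self.alpha * v_mem + x[:, t]  # leaky integration
            s = (v_mem >= self.threshold).float()
            v_mem = v_mem - s * self.threshold    # reset by subtraction
            spikes.append(s)
        # (Batch, Time, ...) -> (Batch*Time, ...) so the next Conv2d accepts it
        return torch.stack(spikes, dim=1).flatten(0, 1)


layer = ToySqueezeLIF(tau_mem=20.0, batch_size=8)
out = layer(torch.rand(8 * 100, 16, 6, 6))
print(out.shape)  # torch.Size([800, 16, 6, 6])
```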

The forward pass and the training loop then remain the same as described in the sections above.

```python
with torch.no_grad():
    output = conv_model(input)
```

The output, of shape (Batch*Time, 4), can then be reshaped to separate the batch and time dimensions again.

```python
# (Batch*Time, 4) -> (Batch, Time, 4)
output_spike_raster = output.reshape(batch_size, time_steps, 4)
```
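Splitting the dimensions with `reshape(batch_size, time_steps, ...)` is only correct because the input was flattened in the same batch-major order, so row `b * time_steps + t` of the flat tensor belongs to sample `b` at time step `t`. The short check below illustrates that index mapping on a small tensor with arbitrary sizes.

```python
import torch

batch_size, time_steps, channels = 3, 4, 2
data = torch.arange(batch_size * time_steps * channels).reshape(batch_size, time_steps, channels)

flat = data.reshape(batch_size * time_steps, channels)     # merge batch and time
restored = flat.reshape(batch_size, time_steps, channels)  # split them again

b, t = 2, 1
assert torch.equal(flat[b * time_steps + t], data[b, t])  # batch-major row index
assert torch.equal(restored, data)                        # the round trip is lossless
```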