# Training an SNN with fewer synops
As in the previous tutorial, we start by defining a spiking model.

In [None]:
import torch
import torch.nn as nn
import sinabs
import sinabs.layers as sl

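class SNN(nn.Sequential):
    def __init__(self, batch_size):
        super().__init__(
            # (batch, time, 1, 28, 28) -> (batch * time, 1, 28, 28)
            sl.FlattenTime(),
            # IAFSqueeze layers work on the flattened (batch * time) dimension
            # and use batch_size to recover the time dimension internally
            nn.Conv2d(1, 16, 5, bias=False),     # -> 16 x 24 x 24
            sl.IAFSqueeze(batch_size=batch_size),
            sl.SumPool2d(2),                     # -> 16 x 12 x 12
            nn.Conv2d(16, 32, 5, bias=False),    # -> 32 x 8 x 8
            sl.IAFSqueeze(batch_size=batch_size),
            sl.SumPool2d(2),                     # -> 32 x 4 x 4
            nn.Conv2d(32, 120, 4, bias=False),   # -> 120 x 1 x 1
            sl.IAFSqueeze(batch_size=batch_size),
            nn.Flatten(),                        # -> 120
            nn.Linear(120, 10, bias=False),      # -> 10
            sl.IAFSqueeze(batch_size=batch_size),
            # (batch * time, 10) -> (batch, time, 10)
            sl.UnflattenTime(batch_size=batch_size),
        )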

batch_size = 5
snn = SNN(batch_size=batch_size)
snn
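The layer sizes are chosen so that a 28 x 28 input shrinks to a 1 x 1 map with 120 channels, which is why the final `nn.Linear` takes 120 inputs. The plain-Python sketch below traces that arithmetic per sample, using PyTorch's output-size formula for convolution and pooling (dilation 1). The helpers `conv_out` and `trace_shapes` are hypothetical, and the sketch assumes that `SumPool2d(2)` uses a stride equal to its kernel size.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output size along one spatial axis of a conv/pool layer (dilation = 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_shapes(size: int = 28) -> list:
    """Per-sample (channels, height, width) after each conv/pool layer of the SNN."""
    shapes = []
    size = conv_out(size, 5)             # Conv2d(1, 16, 5)
    shapes.append((16, size, size))
    size = conv_out(size, 2, stride=2)   # SumPool2d(2), assumed stride 2
    shapes.append((16, size, size))
    size = conv_out(size, 5)             # Conv2d(16, 32, 5)
    shapes.append((32, size, size))
    size = conv_out(size, 2, stride=2)   # SumPool2d(2), assumed stride 2
    shapes.append((32, size, size))
    size = conv_out(size, 4)             # Conv2d(32, 120, 4)
    shapes.append((120, size, size))
    return shapes


print(trace_shapes())
# [(16, 24, 24), (16, 12, 12), (32, 8, 8), (32, 4, 4), (120, 1, 1)]
```

After `nn.Flatten()`, the last shape becomes a vector of 120 * 1 * 1 = 120 features, matching `nn.Linear(120, 10)`.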

The `SNNAnalyzer` class tracks different statistics for spiking layers (such as IAF/LIF) and parameter layers (such as Conv2d/Linear). The number of synaptic operations (synops) is one of the statistics recorded for the parameter layers. By attaching an analyzer to the model, we can use layer-wise or model-wide statistics during training, for optimization or logging purposes.
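To make the quantity concrete: a synaptic operation is counted each time a spike is delivered through one weight. The plain-Python sketch below counts synops for one time step of a stride-1, unpadded convolution with a single input channel, and for a fully connected layer. The helper names are hypothetical; this illustrates the definition and is not how `SNNAnalyzer` computes synops internally, which may differ (for example, in how image borders are handled).

```python
def conv2d_synops(spikes, out_channels, kernel):
    """Count synops of a stride-1, unpadded Conv2d with one input channel.

    `spikes` is an H x W list of lists with the spike count of every input pixel.
    A spike costs one operation for every output neuron whose kernel window
    covers its pixel, in every output channel.
    """
    height, width = len(spikes), len(spikes[0])
    out_h, out_w = height - kernel + 1, width - kernel + 1
    total = 0
    for oi in range(out_h):
        for oj in range(out_w):
            # Input spikes that reach the output neurons at position (oi, oj)
            total += sum(
                spikes[oi + di][oj + dj] for di in range(kernel) for dj in range(kernel)
            )
    # Each output channel has its own weights, so every delivery repeats per channel
    return total * out_channels


def linear_synops(spike_counts, out_features):
    """In a fully connected layer every input spike reaches all outputs."""
    return sum(spike_counts) * out_features


# One time step of all-ones input through the first layer, Conv2d(1, 16, 5)
ones = [[1] * 28 for _ in range(28)]
print(conv2d_synops(ones, out_channels=16, kernel=5))  # 230400
# One spike from each of the 120 inputs of Linear(120, 10)
print(linear_synops([1] * 120, out_features=10))  # 1200
```

Note that pixels near the border are covered by fewer kernel windows than inner pixels, so the exact count (230400) is lower than the naive estimate of 784 input spikes x 16 channels x 25 weights.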

In [None]:
analyser = sinabs.SNNAnalyzer(snn)
print(f"Synops before feeding input: {analyser.get_model_statistics()['synops']}")

# Input of shape (batch, time, channels, height, width): a spike at every pixel and time step
rand_input_spikes = torch.ones((batch_size, 10, 1, 28, 28))
y_hat = snn(rand_input_spikes)
print(f"Synops after feeding input: {analyser.get_model_statistics()['synops']}")

You can break down the statistics for each layer:

In [None]:
analyser.get_layer_statistics()

Once we can calculate the total number of synops, we can choose a target synop count in order to decrease power consumption. As a simple rule of thumb, we take half of the initial number of synops as a constant target.
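The training loop below penalizes the squared deviation from the target, normalized by the squared target. A minimal plain-Python sketch of that loss (the function name `synops_loss` and the synop values are hypothetical) shows two useful properties: it starts at exactly 1 while the model still produces twice the target, and, being symmetric around the target, it penalizes collapsing to zero synops just as much as not reducing them at all.

```python
def synops_loss(synops: float, target: float) -> float:
    """Squared relative deviation from the target: (target - synops)^2 / target^2."""
    return ((target - synops) / target) ** 2


initial_synops = 2_000_000.0    # hypothetical value, for illustration only
target = initial_synops / 2     # half of the initial synops

print(synops_loss(initial_synops, target))  # 1.0
print(synops_loss(target, target))          # 0.0
print(synops_loss(0.0, target))             # 1.0
```

Normalizing by the target keeps the loss on the order of 1 regardless of how many synops the network produces, so the same learning rate is a reasonable starting point for models of different sizes.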

In [None]:
# Use half of the initial synops as a constant target. Detach it so that the
# target is a plain value and not part of the computational graph of the earlier pass.
target_synops = (analyser.get_model_statistics()['synops'] / 2).detach()

optim = torch.optim.Adam(snn.parameters())

n_synops = []
for step in range(10):
    # Reset neuron states and clear gradients of states and parameters
    sinabs.reset_states(snn)
    sinabs.zero_grad(snn)
    optim.zero_grad()

    output = snn(rand_input_spikes)
    synops = analyser.get_model_statistics()['synops']
    # Squared deviation from the target, normalized by the squared target
    synops_loss = (target_synops - synops).square() / target_synops.square()
    synops_loss.backward()
    optim.step()

    n_synops.append(synops.item())
    print(f"Step {step}: synops = {synops.item():.0f}, loss = {synops_loss.item():.4f}")


In [None]:
n_synops
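To see why this loss drives the synop count toward the target rather than toward zero, here is a toy gradient descent in plain Python on a single hypothetical gain `g` that scales the synops linearly (synops = g x base). This linear model is an assumption for illustration only: in the real network, synops depend nonlinearly on all weights and the gradients come from surrogate functions, so the trajectory recorded in `n_synops` will look different.

```python
def toy_synops_descent(base_synops=1000.0, steps=50, lr=0.1):
    """Plain gradient descent on the normalized squared synops loss for synops = g * base."""
    target = base_synops / 2  # half of the initial synops (g starts at 1)
    g = 1.0
    history = []
    for _ in range(steps):
        synops = g * base_synops
        rel_err = (target - synops) / target
        # d/dg of rel_err**2 = 2 * rel_err * (-base_synops / target)
        grad = 2 * rel_err * (-base_synops / target)
        g -= lr * grad
        history.append(g * base_synops)
    return target, history


target, history = toy_synops_descent()
print(round(history[-1], 3))  # 500.0
```

With these values the update reduces to g_new = 0.2 * g + 0.4, so the distance to the fixed point g = 0.5 shrinks by a factor of 5 every step and the synop count settles at the target (600, 520, 504, ...) without overshooting toward zero.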