# Training a Spiking Convolutional Neural Network for analysing DVS data

## Loading and understanding the training data

In [None]:
from aermanager import AERFolderDataset
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

# Location of the training recordings and the mini-batch size used for training.
FOLDER = '/home/martino/Work/synoploss/mnist_dvs/data/train/'
BATCH_SIZE = 256

# Dataset over the contents of FOLDER; every sample goes through ToTensor().
# NOTE(review): from_spiketrain=False presumably selects accumulated frames
# rather than raw spike trains -- confirm against the aermanager documentation.
train_dataset = AERFolderDataset(root=FOLDER, from_spiketrain=False, transform=ToTensor())
print(f"Number of training frames: {len(train_dataset)}")

# Shuffled mini-batches of BATCH_SIZE samples for the training loop.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [None]:
# The dataset object contains all our training images and labels.
# Have a look at how they look like.

sample, label = train_dataset[0]

import matplotlib.pyplot as plt
%matplotlib inline

plt.imshow(sample.squeeze())  # remove the extra dimension
print(label)

## Defining a model

In [None]:
import torch

class MNISTClassifier(torch.nn.Module):
    """Small convolutional classifier with ten outputs, built from bias-free layers.

    The network clamps its input with a ReLU and then applies ``self.seq``:
    three stages of 3x3 convolution -> ReLU -> 2x2 average pooling
    (1 -> 8 -> 12 -> 12 channels), followed by 2-D dropout, flattening and a
    bias-free linear layer that maps 432 features to 10 outputs, with a final
    ReLU. 432 = 12 * 6 * 6, which is the flattened size obtained for a
    single-channel 64x64 input.

    Args:
        quantize: Accepted for interface compatibility; this class does not
            read it.
    """

    def __init__(self, quantize=False):
        super().__init__()
        nn = torch.nn

        # Kept outside `seq`, so it is applied in forward() but is not part of
        # the layer stack exposed as `self.seq`.
        self.input_relu = nn.ReLU()

        def conv_stage(n_in, n_out):
            """Return one conv -> ReLU -> average-pool stage as a list of layers."""
            return [
                nn.Conv2d(in_channels=n_in, out_channels=n_out,
                          kernel_size=(3, 3), bias=False),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2)),
            ]

        layers = [
            *conv_stage(1, 8),
            *conv_stage(8, 12),
            *conv_stage(12, 12),
            nn.Dropout2d(0.5),
            nn.Flatten(),
            nn.Linear(432, 10, bias=False),
            # A trailing ReLU is unusual for a classifier but is needed here
            # (presumably so that the output layer can also be turned into a
            # spiking layer -- confirm).
            nn.ReLU(),
        ]
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        """Clamp negative inputs to zero, then run the layer stack on the result."""
        rectified = self.input_relu(x)
        return self.seq(rectified)


In [None]:
# Build the classifier with its default arguments, then move it to the GPU.
# The move stays the last expression so the cell still echoes the model.
model = MNISTClassifier()
model.to('cuda')

## Main training phase

In [None]:
# Adam over all parameters of the model, with a learning rate of 0.001.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
# Cross-entropy between the model outputs and the integer class labels.
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Set up a training loop
from tqdm.notebook import tqdm
n_epochs = 3

# Send each batch to whichever device the model's parameters live on, instead
# of hard-coding the GPU here a second time.
device = next(model.parameters()).device

# Make sure training-only behaviour (the Dropout2d layer of the classifier) is
# switched on, even if this cell is re-run after the model was put in eval mode.
model.train()

for epoch in range(n_epochs):
    print("Epoch", epoch+1)
    progress_bar = tqdm(train_dataloader)
    for (images, labels) in progress_bar:
        # move the batch to the model's device
        images = images.to(device)
        labels = labels.to(device)

        # reset the gradients accumulated by the previous step
        optimizer.zero_grad()

        # forward pass through the network
        outputs = model(images)

        # compute and backpropagate the loss, then update the weights
        loss_value = criterion(outputs, labels)
        loss_value.backward()
        optimizer.step()

        # show the loss of the current batch next to the progress bar
        progress_bar.set_postfix(LOSS=loss_value.item())

## Live demo

In [None]:
from sinabs.from_torch import from_model

# Convert the trained layer stack into a spiking network and move it to the GPU.
# Only `model.seq` is passed, so the classifier's separate input ReLU is not
# part of `net`.
net = from_model(
    model.seq,
    input_shape=(1, 64, 64),
    threshold=1.0,
    membrane_subtract=1.0,
    threshold_low=-5.0,
).cuda()

# The demo only runs inference. Switch to evaluation mode so that training-only
# behaviour (the source stack contains Dropout2d(0.5)) cannot randomly zero
# feature maps while classifying live data.
net.eval()

# we resize and crop our input so that it matches the training data
adaptivepool = torch.nn.AdaptiveAvgPool2d((64, 64))
# Average pooling divides every window by its area; multiplying by 16 undoes
# that for 4x4 windows, i.e. when the cropped frame is 256x256.
# NOTE(review): the live frame size is not visible here -- confirm.
resize_factor = 16


def transform(x):
    """Crop a batch of frames and pool it down to 64x64 for the network.

    Args:
        x: 4-D array or tensor; the last two axes are cropped by 2 on each
            side and by 45 on each side respectively.

    Returns:
        A float32 tensor whose last two dimensions are 64x64, scaled by
        ``resize_factor``. It lives on the GPU when CUDA is available and on
        the CPU otherwise.
    """
    x = x[:, :, 2:-2, 45:-45]  # crop
    # Fall back to the CPU so the preprocessing also works without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # as_tensor accepts arrays and tensors alike and, unlike torch.tensor,
    # does not warn about (or force) a copy when x is already a tensor.
    x = torch.as_tensor(x, dtype=torch.float32, device=device)
    return adaptivepool(x) * resize_factor

In [None]:
from aermanager import LiveDv

# Connection settings for the live event source.
# NOTE(review): qlen presumably sets how many frames are queued per batch --
# confirm against the aermanager documentation.
live_settings = dict(host='localhost', port=7777, qlen=64)
live = LiveDv(**live_settings)

In [None]:
from IPython.display import display, clear_output

# Minimum value the strongest summed output must exceed before a digit is
# shown; below it a '.' placeholder is displayed instead.
THR = 30

# Inference only: without no_grad() every pass through this endless loop would
# record autograd history that is never used.
with torch.no_grad():
    # Runs until the kernel is interrupted.
    while True:
        batch = live.get_batch()
        batch = transform(batch)

        out = net(batch)
        # Sum the outputs over the first axis, then take the strongest entry
        # and its index as the predicted label.
        maxval, pred_label = torch.max(out.sum(0), dim=0)

        clear_output()
        if maxval > THR:
            display(pred_label.item())
        else:
            display('.')