# Training spiking neural networks, *fast*

Spiking neural networks (SNNs) can be notoriously slow to train. A special case of recurrent neural networks, they work with sequential inputs and rely on a form of gradient computation through time, which in the most common scenario is backpropagation through time. Given that events from event cameras or silicon cochleas have a temporal resolution down to microseconds, the number of time steps in the data used to train SNNs can easily run into the hundreds for a single sample.

This would not be a problem if we trained on the extremely sparse data in continuous time directly, but the legacy of ANN machine learning frameworks has it that we have to work with dense tensors to train our SNN. That means that for a visual event stream input (think video) of spatial size (2, 128, 128) for channels, y and x, we deal not with some 10 frames per second but potentially with hundreds, which increases the input dimensions considerably.
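
To put rough numbers on this, here is a back-of-the-envelope sketch. The sample duration of 1 s, the float32 storage and the two frame rates are assumptions chosen for illustration; only the spatial size comes from the text above.

```python
def dense_sample_bytes(frames_per_second, duration_s=1.0, shape=(2, 128, 128), bytes_per_value=4):
    """Bytes needed to store one dense sample (assumes float32 values by default)."""
    n_frames = int(frames_per_second * duration_s)
    values_per_frame = shape[0] * shape[1] * shape[2]
    return n_frames * values_per_frame * bytes_per_value


# Compare a video-like frame rate with one frame every 4 ms (both assumed).
for fps in (10, 250):
    print(f"{fps:>3} frames/s -> {dense_sample_bytes(fps) / 1e6:.1f} MB per sample")
```

The dense representation grows linearly with the frame rate, regardless of how few events the sample actually contains.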

When training a neural network of any kind, one might think about how the learning rate or model size affects training time. But when it comes to training *faster*, optimizing data movement is crucial. 3 out of the first 4 points in [this list](https://www.reddit.com/r/MachineLearning/comments/kvs1ex/d_here_are_17_ways_of_making_pytorch_training/), which is ranked by potential speed-up, have to do with how data is shaped and moved around between actual computations. It makes a huge difference, because training faster on the same hardware means getting results sooner and being able to iterate more quickly.

For this post we train an SNN using [Sinabs](https://github.com/synsense/sinabs), which is based on PyTorch, and surrogate gradients. That means that in the forward pass we use the heavily quantized output of spiking layers, but in the backward pass we use a smoother surrogate function based on the internal state of the neurons. We'll use the [Spiking Speech Commands](https://zenkelab.org/resources/spiking-heidelberg-datasets-shd/) (SSC) dataset from the Spiking Heidelberg Datasets to train our network to do audio stream classification. We'll benchmark different data loading strategies using [Tonic](https://github.com/neuromorphs/tonic) and show that with the right strategy, we can achieve a speed-up of up to XX times compared to a naïve strategy.
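
The surrogate gradient idea can be sketched for a single neuron in plain Python. The threshold of 1.0, the sigmoid-shaped surrogate and its slope are assumptions for illustration; this is a generic example and not necessarily the surrogate function that Sinabs uses.

```python
import math

THRESHOLD = 1.0  # assumed firing threshold


def spike(v):
    """Forward pass: Heaviside step, 1 if the membrane potential reaches threshold."""
    return 1.0 if v >= THRESHOLD else 0.0


def surrogate_grad(v, slope=5.0):
    """Backward pass: derivative of a sigmoid centred on the threshold.

    It replaces the true derivative of `spike`, which is zero almost everywhere.
    """
    s = 1.0 / (1.0 + math.exp(-slope * (v - THRESHOLD)))
    return slope * s * (1.0 - s)


for v in (0.0, 0.9, 1.0, 1.1, 2.0):
    print(f"v={v:.1f}  spike={spike(v):.0f}  surrogate_grad={surrogate_grad(v):.3f}")
```

The forward output stays binary, while the gradient is largest for membrane potentials close to the threshold, so neurons that almost spiked still receive a learning signal.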

For all our benchmarks, we already assume multiple data loading worker processes and pinned host memory for faster transfers to the GPU. We'll increase throughput by using different forms of caching to disk or GPU. By applying deterministic transformations upfront and saving the resulting tensors, we can save a lot of time during training.

All data from neuromorphic datasets in Tonic is provided as NxD numpy arrays. We need to turn these into tensors to bring them onto the GPU, and we'll also downsample the time steps. Let's define the transform that produces dense tensors; the sparse counterpart follows in the GPU caching section. We know that the input data has 700 channels and that samples last about 0.8-1.2 s at microsecond resolution. We'll downsample each sample to 100 channels, bin every 4 ms into one frame and crop samples that are longer than 1 s. That leaves us with a maximum of 250 time steps per sample.

<!-- Accelerators such as TPUs, Cerebra and the like heavily -->

<!-- The reason why EXODUS is efficient is because it vectorizes samples in time. -->

In [None]:
from tonic import transforms

dt = 4000  # all time units in Tonic in us
encoding_dim = 100

dense_transform = transforms.Compose(
    [
        transforms.Downsample(spatial_factor=encoding_dim / 700),
        transforms.CropTime(max=1e6),
        transforms.ToFrame(
            sensor_size=(encoding_dim, 1, 1), time_window=dt, include_incomplete=True
        ),
    ]
)
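
To make the effect of these three steps concrete, here is a pure-Python sketch of the same idea: downsample channels, crop in time, then count events per time bin. The function name, the `(timestamp, channel)` event layout and the exact boundary handling are assumptions for illustration and do not reproduce Tonic's implementation.

```python
def events_to_frames(events, n_channels_in=700, n_channels_out=100, dt=4000, t_max=1_000_000):
    """Bin (timestamp_us, channel) events into per-time-step spike counts."""
    # Downsample channels with integer arithmetic and drop events at or after t_max.
    kept = [(t, ch * n_channels_out // n_channels_in) for t, ch in events if t < t_max]
    if not kept:
        return []
    # The last, possibly incomplete, time window is kept as well.
    n_steps = max(t for t, _ in kept) // dt + 1
    frames = [[0] * n_channels_out for _ in range(n_steps)]
    for t, ch in kept:
        frames[t // dt][ch] += 1
    return frames


# Five toy events; the last one lies beyond 1 s and is cropped.
events = [(0, 0), (3999, 6), (4000, 699), (999_999, 350), (1_200_000, 10)]
frames = events_to_frames(events)
print(len(frames))  # number of time steps
print(frames[0][0], frames[1][99], frames[249][50])
```

With a 4 ms bin and a 1 s crop, no sample can end up with more than 250 time steps.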

Next we'll import the training dataset and assign the respective transform.

In [None]:
import torch
import tonic

dense_dataset = tonic.datasets.SSC("./data", split="train", transform=dense_transform)
print(f"This dataset has {len(dense_dataset)} samples.")

To give an idea of what one sample now looks like, let's plot one dense sample:

In [None]:
dense_sample, dense_target = dense_dataset[0]

import matplotlib.pyplot as plt

plt.imshow(dense_sample.squeeze().T)
plt.xlabel("Time step")
plt.ylabel("Channel")
plt.title(dense_dataset.classes[dense_target]);

## Naïve dataloading

We start with the first benchmark, where we load every sample from an hdf5 file on disk, which provides us with a numpy array in memory. For each sample, we apply the [ToFrame](https://tonic.readthedocs.io/en/main/reference/generated/tonic.transforms.ToFrame.html) transform (defined earlier) to create a dense array, which we can then batch together with other samples and feed to the network.

![naive caching](images/caching1.png "Naive caching")


In [None]:
import tonic
from tqdm import tqdm
from torch.utils.data import DataLoader

dataloader_kwargs = dict(
    batch_size=128,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    collate_fn=tonic.collation.PadTensors(batch_first=True),
)

naive_dataloader = DataLoader(dense_dataset, **dataloader_kwargs, num_workers=4)

In [None]:
for data, target in tqdm(naive_dataloader):
    data, target = data.squeeze().cuda(), target.cuda()

That's about 6-7 iterations/s, which is not very exciting. We haven't even started training yet! 
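
The iterations/s figures in this post are the ones reported by tqdm. If you prefer to measure throughput explicitly, a small helper along the following lines would do; the function name and the toy generator that stands in for a dataloader are hypothetical.

```python
import time


def iterations_per_second(iterable, max_iters=None):
    """Iterate once over `iterable` and return the measured iterations per second."""
    start = time.perf_counter()
    n = 0
    for _ in iterable:
        n += 1
        if max_iters is not None and n >= max_iters:
            break
    elapsed = time.perf_counter() - start
    return n / elapsed if elapsed > 0 else float("inf")


def slow_batches(n_batches, delay_s=0.01):
    """Stand-in for a dataloader that needs about 10 ms to produce each batch."""
    for i in range(n_batches):
        time.sleep(delay_s)
        yield i


rate = iterations_per_second(slow_batches(20))
print(f"{rate:.1f} iterations/s")
```

Passing `max_iters` limits the measurement to the first few batches, which is handy for quick comparisons between loading strategies.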

## Disk caching
Let's try to be a bit smarter now. ToFrame is a deterministic transform, so for the same sample we'll always receive the same transformed data. Given that we might train for 100 epochs, which means looking at each sample 100 times, that's a lot of wasted compute! Now we're going to cache those transformed samples, that is, save them to disk during the first epoch, so that we don't need to recompute them.

![Disk caching](images/caching2.png "Disk caching")

In [None]:
disk_cached_dataset = tonic.DiskCachedDataset(
    dataset=dense_dataset,
    cache_path=f"cache/{dense_dataset.__class__.__name__}/train/{encoding_dim}/{dt}",
)

disk_cached_dataloader = DataLoader(disk_cached_dataset, **dataloader_kwargs, num_workers=4)

In [None]:
for data, target in tqdm(disk_cached_dataloader):
    data, target = data.squeeze().cuda(), target.cuda()

At 8-9 iterations/s that is barely faster than before. What happened? During the first epoch every sample still has to be transformed and is additionally written to disk, so caching adds overhead instead of saving time. But let's see what happens in the second epoch.

In [None]:
for data, target in tqdm(disk_cached_dataloader):
    data, target = data.squeeze().cuda(), target.cuda()

32 iterations/s? Now this is faster! Every epoch from now on will load data at this speed, at the expense of disk space. How much disk space does it cost, you may ask? Let's compare the size of the original dataset and the cache folder...

In [None]:
from pathlib import Path

size_orig_dataset = (
    sum(f.stat().st_size for f in Path("data").glob("**/*.h5") if f.is_file()) / 1e9
)
size_cache_folder = (
    sum(f.stat().st_size for f in Path("cache").glob("**/*") if f.is_file()) / 1e9
)

print(
    f"The size of the original dataset file is {round(size_orig_dataset, 2)} GB compared to the generated cache folder with {round(size_cache_folder, 2)} GB."
)

This is quite efficient. As a reminder, the original dataset in this case contained numpy events, whereas the cache folder contains dense tensors. We can compress the dense tensors that much because by default Tonic uses lightweight compression during caching. Disk caching is generally applicable and will save you a lot of time in the long run.
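
The principle behind disk caching fits in a few lines. The sketch below is a hypothetical stand-in that uses `pickle` and a temporary directory; it is not how `tonic.DiskCachedDataset` is implemented, and it ignores compression and concurrent workers.

```python
import pickle
import tempfile
from pathlib import Path


class PickleCachedDataset:
    """Hypothetical disk cache around any indexable dataset."""

    def __init__(self, dataset, cache_path):
        self.dataset = dataset
        self.cache_path = Path(cache_path)
        self.cache_path.mkdir(parents=True, exist_ok=True)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        file = self.cache_path / f"{index}.pkl"
        if file.exists():
            # Cache hit: skip the expensive transform entirely.
            with file.open("rb") as f:
                return pickle.load(f)
        # Cache miss: compute the sample once and store it for later epochs.
        sample = self.dataset[index]
        with file.open("wb") as f:
            pickle.dump(sample, f)
        return sample


class CountingDataset:
    """Toy dataset that counts how often its 'transform' is executed."""

    def __init__(self):
        self.calls = 0

    def __len__(self):
        return 3

    def __getitem__(self, index):
        self.calls += 1
        return [index] * 4, index % 2  # (data, target)


source = CountingDataset()
with tempfile.TemporaryDirectory() as tmp:
    cached = PickleCachedDataset(source, tmp)
    for epoch in range(2):
        samples = [cached[i] for i in range(len(cached))]
print(source.calls)  # the underlying dataset was only queried during the first epoch
```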

## GPU caching
We can go even faster! Instead of loading dense tensors from disk, we can try to cram our whole dataset onto the GPU! Now, the issue is that with dense tensors this wouldn't work, as they would occupy too much memory. But events are already an efficient format, right? So we'll store the events on the GPU as sparse tensors and then simply inflate them by calling `to_dense()` for each sample. This method is obviously bound by GPU memory, so it only works with rather small datasets such as the one we're testing. However, once you're set up, you can train with _blazing_ speed. Let's have a look!

In [None]:
sparse_transform = transforms.Compose(
    [
        transforms.Downsample(spatial_factor=encoding_dim / 700),
        transforms.CropTime(max=1e6),
        transforms.ToSparseTensor(
            sensor_size=(encoding_dim, 1, 1), time_window=dt, include_incomplete=True
        ),
    ]
)

sparse_dataset = tonic.datasets.SSC(
    "./data", split="train", transform=sparse_transform, target_transform=torch.tensor
)

sparse_sample = sparse_dataset[0][0]

print(sparse_sample)

In [None]:
data_list = []
target_list = []
for data, targets in tqdm(disk_cached_dataloader):
    data_list.append(data.squeeze().to_sparse().coalesce().cuda())
    target_list.append(targets.byte().cuda())

sparse_tensor_dataset = list(zip(data_list, target_list))

In [None]:
for data, target in tqdm(sparse_tensor_dataset):
    data.to_dense()

The last line returns instantly... 10k batches/s sounds like something we'd like to work with. The dataset now occupies some 8-9 GB of GPU memory, which is quite a lot for a dataset of this size. But the speed speaks for itself, so it might pay off to run your experiments on an earlier-generation GPU with more memory just to really crank up that utilisation percentage!
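
Why sparse storage helps can be illustrated without a GPU. The sketch below mimics a coordinate-style sparse format (only indices and values of non-zero entries are stored) with plain Python lists; the helper names are hypothetical and the code is not PyTorch's implementation.

```python
def to_sparse(dense):
    """Keep only the non-zero entries of a 2D list as (row, col, value) triples."""
    shape = (len(dense), len(dense[0]))
    entries = [(t, c, v) for t, row in enumerate(dense) for c, v in enumerate(row) if v != 0]
    return shape, entries


def to_dense(shape, entries):
    """Inflate the triples back into a zero-filled 2D list."""
    dense = [[0] * shape[1] for _ in range(shape[0])]
    for t, c, v in entries:
        dense[t][c] = v
    return dense


# Toy sample: 250 time steps x 100 channels containing only three spikes.
dense = [[0] * 100 for _ in range(250)]
dense[0][3] = 1
dense[10][42] = 2
dense[249][99] = 1

shape, entries = to_sparse(dense)
print(f"{shape[0] * shape[1]} dense values vs. {len(entries)} stored entries")
assert to_dense(shape, entries) == dense  # the round trip is lossless
```

The sparser the data, the larger the saving; for dense-ish data the stored indices can outweigh the benefit.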


## Fast training
So far, all we've done is load the data onto the device. It would be good to see how fast we can really train a network on the task. We'll again use our three dataloaders (naïve, disk-cached and GPU-cached) to look at actual training times. Let's start by defining a simple integrate-and-fire (IAF) feed-forward (Sequential) architecture using [Sinabs](https://sinabs.readthedocs.io):

In [None]:
import torch.nn as nn
import sinabs.layers as sl
import sinabs.exodus.layers as el


class SNN(nn.Sequential):
    def __init__(self, backend, hidden_dim: int = 128):
        assert backend == sl or backend == el
        super().__init__(
            nn.Linear(encoding_dim, hidden_dim),
            backend.IAF(),
            nn.Linear(hidden_dim, hidden_dim),
            backend.IAF(),
            nn.Linear(hidden_dim, hidden_dim),
            backend.IAF(),
            nn.Linear(hidden_dim, 35),
        )


sinabs_model = SNN(backend=el).cuda()

In [None]:
import sinabs

model = sinabs_model
# Alternative: iterate over sparse_tensor_dataset, call data.to_dense() on each
# batch and cast the targets to long before computing the loss.
train_dataloader = disk_cached_dataloader

optim = torch.optim.Adam(model.parameters())
criterion = torch.nn.functional.cross_entropy

train_loss = []
for epoch in range(3):
    for data, targets in tqdm(train_dataloader):
        sinabs.reset_states(model)  # clear neuron states left over from the previous batch
        optim.zero_grad()
        data = data.cuda().squeeze()
        targets = targets.cuda()
        output = model(data)
        # Sum the output over the time dimension to get one score per class
        loss = criterion(output.sum(1), targets)
        loss.backward()
        optim.step()
        # Detach so the stored loss does not keep the autograd graph alive
        train_loss.append(loss.detach())
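
The loss in the loop above is computed on the network output summed over time. For a single sample, that readout can be sketched in plain Python as follows; the toy output values are made up, and the helper names are hypothetical.

```python
import math


def summed_output(output):
    """Sum a [time][class] list over the time axis, one score per class."""
    n_classes = len(output[0])
    return [sum(step[c] for step in output) for c in range(n_classes)]


def cross_entropy(logits, target):
    """Softmax cross-entropy for one sample, using the log-sum-exp trick."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


# Toy output of the last layer: 4 time steps, 3 classes.
output = [
    [0.1, 0.0, 0.2],
    [0.0, 0.9, 0.1],
    [0.2, 1.1, 0.0],
    [0.0, 0.8, 0.1],
]
logits = summed_output(output)
prediction = max(range(len(logits)), key=logits.__getitem__)
print(prediction, cross_entropy(logits, target=1))
```

The class that accumulates the most output over the whole sample wins, so the prediction does not depend on any single time step.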

In [None]:
data.shape

In [None]:
import matplotlib.pyplot as plt

train_loss = [loss.cpu().detach().numpy() for loss in train_loss]
plt.plot(train_loss)