# Live-coding script for Telluride Rockpool / Xylo demo June 2023
Dylan Muir / Felix Bauer
## Outline
 1. How to define and configure an ``LIF`` module containing a spiking neuron
 2. How to compose a network
    - `Linear` weights
    - `Sequential` combinator
    - `Residual` combinator
 3. Audio task: Spiking Heidelberg Digits
 4. Network architecture
 5. Training
 6. Xylo architecture
 7. Mapping, quantization, deployment
 8. Inference using ``XyloSim``
 9. Inference on Xylo HDK
 10. Recording audio with the Xylo Audio front-end (``AFESamna``)
 11. Free-running inference with ``XyloMonitor``


This live-coding script demonstrates working with Rockpool to train spiking neural networks (SNNs) for Xylo on an audio task.

First we need to install the required packages.

## Setup

In [None]:
# - Install requirements for this notebook
%pip install --quiet rockpool matplotlib torch tonic rich jax jaxlib xylosim samna bitstruct

# - Import and configure matplotlib
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [12, 6]
plt.rcParams["figure.dpi"] = 300

# - Nice printing
from rich import print

# - Torch and numpy
import torch
import numpy as np

# - For displaying images
from IPython.display import Image


Rockpool is a deep learning library for SNNs, designed to make it very easy to design, train and deploy applications to neuromorphic hardware.

Documentation: https://rockpool.ai

## Getting started with a single LIF neuron

In [None]:
# - The LIF module is a Leaky Integrate and Fire spiking neuron
from rockpool.nn.modules import LIF

In [None]:
# - Create a single LIF neuron to examine
lif = LIF(1, threshold=10.)
print(lif)

In [None]:
# - Generate some Poissonian spiking input to the neuron
f = 0.02
T = 500
Nin = 1
input_sp = np.random.rand(T, Nin) < f
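Drawing an independent Bernoulli sample with probability `f` in every time bin, as in the cell above, approximates a Poisson spike train with a rate of `f` events per time step. The sketch below uses only NumPy and a hypothetical helper `bernoulli_spikes` to check that the observed event count is close to the expected value `p * n_steps`. It uses new variable names so that the notebook's `T`, `f` and `Nin` are left untouched.

```python
import numpy as np

def bernoulli_spikes(n_steps: int, n_channels: int, p: float, seed: int = 0) -> np.ndarray:
    """Hypothetical helper: a (n_steps, n_channels) boolean raster with P(event) = p per bin."""
    rng = np.random.default_rng(seed)
    return rng.random((n_steps, n_channels)) < p

# - Compare the observed number of events with the expected value p * n_steps
n_steps, p = 10_000, 0.02
demo_spikes = bernoulli_spikes(n_steps, 1, p)
print("expected:", p * n_steps, "observed:", int(demo_spikes.sum()))
```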

In [None]:
# - Evolve the neuron by passing the data through
#   `record = True` records and returns internal state
out, _, rec_dict = lif(input_sp, record = True)

In [None]:
# - plot the output events
plt.plot(out.squeeze())

In [None]:
# - Let's look at the recorded state. What did we get back from the evolution?
rec_dict.keys()

In [None]:
# - Let's plot the synaptic current `isyn` and membrane potential `vmem`
plt.plot(rec_dict['isyn'].squeeze(), label='$I_{syn}$')
plt.plot(rec_dict['vmem'].squeeze(), label='$V_{mem}$')
plt.plot([0, 500], [10, 10], 'k:', label='threshold')
plt.legend()
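For intuition about the `isyn` and `vmem` traces above, here is a minimal NumPy sketch of one common discrete-time LIF formulation. Both the synaptic current and the membrane potential decay exponentially, each input event adds one unit of current, a spike is emitted when `vmem` reaches the threshold, and the threshold is then subtracted from `vmem`. This is an illustrative assumption, not Rockpool's exact update rule; `lif_sketch` and its default time constants are hypothetical.

```python
import numpy as np

def lif_sketch(input_sp, tau_mem=20e-3, tau_syn=10e-3, threshold=10.0, dt=1e-3):
    """Simulate one LIF neuron per input channel (hypothetical helper).

    input_sp: array of shape (T, N) with input events per time step.
    Returns (spikes, isyn, vmem), each of shape (T, N); vmem is recorded after reset.
    """
    input_sp = np.asarray(input_sp, dtype=float)
    alpha_mem = np.exp(-dt / tau_mem)  # per-step membrane decay factor
    alpha_syn = np.exp(-dt / tau_syn)  # per-step synaptic decay factor
    n_steps, n_neurons = input_sp.shape
    isyn = np.zeros(n_neurons)
    vmem = np.zeros(n_neurons)
    spikes = np.zeros((n_steps, n_neurons))
    isyn_rec = np.zeros((n_steps, n_neurons))
    vmem_rec = np.zeros((n_steps, n_neurons))
    for t in range(n_steps):
        isyn = alpha_syn * isyn + input_sp[t]  # each input event adds one unit of current
        vmem = alpha_mem * vmem + isyn         # the membrane integrates the synaptic current
        fired = vmem >= threshold
        spikes[t] = fired
        vmem = np.where(fired, vmem - threshold, vmem)  # subtractive reset
        isyn_rec[t], vmem_rec[t] = isyn, vmem
    return spikes, isyn_rec, vmem_rec

# - Constant drive: one input event per time step for 50 steps
spikes_c, isyn_c, vmem_c = lif_sketch(np.ones((50, 1)))
print("spikes with constant drive:", int(spikes_c.sum()))
```

For example, `lif_sketch(input_sp)` runs the sketch on the Poisson input generated earlier, so its traces can be compared qualitatively with `rec_dict['isyn']` and `rec_dict['vmem']`.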

In [None]:
# - Rockpool modules all have a `state()` method which returns the internal module state
print(lif.state())

In [None]:
# - Rockpool modules all have a `parameters()` method which returns the trainable parameters of a module
print(lif.parameters())

In [None]:
# - Rockpool modules all have a `simulation_parameters()` method which returns the non-trainable parameters
print(lif.simulation_parameters())

## The data

``tonic`` is a package for managing neuromorphic datasets (https://tonic.readthedocs.io)

We'll use tonic to download the Spiking Heidelberg Digits dataset (https://zenkelab.org/resources/spiking-heidelberg-datasets-shd/), which ``tonic`` provides as a convenient PyTorch-style dataset.

``tonic`` also provides data transformations and caching.

In [None]:
# - Import tonic, download and import the SHD dataset
import tonic
train_data = tonic.datasets.SHD('./data')
shd_timestep = 1e-6
shd_channels = 700
shd_classes = 20

In [None]:
# - Let's examine one sample of the dataset
events, label = train_data[1]
times = events['t'] * shd_timestep
channels = events['x']
plt.plot(times, channels, '|')

In [None]:
# - We need to downsample the data to use it (to make the network and training simpler)
net_channels = 16
net_dt = 10e-3
sample_T = 250

In [None]:
# - We'll use `tonic` to downsample the data, using a transformation pipeline

import tonic.transforms as T

transform = T.Compose([
    # - Downsample in time and space
    T.Downsample(
        time_factor=shd_timestep / net_dt,
        spatial_factor=net_channels / shd_channels
        ),

    # - Rasterise the events
    T.ToFrame(
        sensor_size=(net_channels, 1, 1), time_window=1
    ),
    
    # - Convert to a tensor
    torch.Tensor,

    # - Make sure the samples are not too long in time
    lambda m: torch.squeeze(m)[:sample_T, :],
])
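Conceptually, the pipeline above rescales each event's timestamp and channel index, then counts events per (time bin, channel). The NumPy sketch below reproduces that idea on a structured event array with fields `t` and `x`, as in the SHD samples. It approximates what `tonic`'s `Downsample` and `ToFrame` do rather than copying their implementation, which may differ at bin edges and in how out-of-range events are handled; `downsample_and_rasterise` is a hypothetical name.

```python
import numpy as np

def downsample_and_rasterise(events, time_factor, spatial_factor, n_channels, n_steps):
    """Bin events into a (n_steps, n_channels) count raster (hypothetical helper)."""
    t_bin = np.floor(events["t"] * time_factor).astype(int)      # new time index
    ch_bin = np.floor(events["x"] * spatial_factor).astype(int)  # new channel index
    keep = (t_bin < n_steps) & (ch_bin < n_channels)             # drop out-of-range events
    raster = np.zeros((n_steps, n_channels), dtype=int)
    np.add.at(raster, (t_bin[keep], ch_bin[keep]), 1)            # count events per bin
    return raster

# - Three toy events, times in microseconds, on a 700-channel sensor
event_dtype = np.dtype([("t", np.int64), ("x", np.int64)])
toy_events = np.array([(5_000, 0), (15_000, 699), (15_500, 650)], dtype=event_dtype)
toy_raster = downsample_and_rasterise(toy_events, 1e-6 / 10e-3, 16 / 700, 16, 5)
print(toy_raster.sum(axis=1))
```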

In [None]:
# - Reload the dataset with these transformations
train_data = tonic.datasets.SHD('./data', transform=transform)

In [None]:
# - Get one training sample
raster, label = train_data[1]

# - Extract spike times
times, channels = torch.where(raster)

In [None]:
# - Plot this sample
plt.plot(times * net_dt, channels, '|')

Now we create a data loader to use in training. This is a standard `torch` data loader, so I'm going to gloss over this cell.

We will select only the first 8 classes, since Xylo supports at most 8 output channels.

We will create a data loader, using ``tonic`` to cache the transformed data on disk. ``tonic`` also supports in-memory caching, which is not used here.

In [None]:
# - Create a class which subsets a dataset to a list of matching labels
class SubsetClasses(torch.utils.data.Dataset):
    def __init__(self, dataset, match_labels):
        indices = []
        for idx in range(len(dataset)):
            _, label = dataset[idx]
            if label in match_labels:
                indices.append(idx)

        self._subset_ds = torch.utils.data.Subset(dataset, indices)
        self._len = len(indices)

    def __getitem__(self, index):
        return self._subset_ds[index]
    
    def __len__(self):
        return self._len

In [None]:
# - Define arguments for the data loader
dataloader_kwargs = dict(
    batch_size=128,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    collate_fn=tonic.collation.PadTensors(batch_first=True),
    num_workers=0,
)

# - Create the data loader, using `tonic` to provide a disk cache
train_dl = torch.utils.data.DataLoader(
    tonic.DiskCachedDataset(
        dataset=SubsetClasses(train_data, range(8)),
        cache_path=f"cache/{train_data.__class__.__name__}/train/{net_channels}/{net_dt}",
        reset_cache = False,
    ),
    **dataloader_kwargs
)
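Samples in a batch can have different lengths, so `tonic.collation.PadTensors(batch_first=True)` pads them to a common length before stacking them. A minimal NumPy sketch of that idea follows, assuming zero-padding at the end of the time axis; the helper `pad_collate` is hypothetical and `tonic`'s own implementation may differ in detail.

```python
import numpy as np

def pad_collate(samples):
    """Zero-pad (T_i, C) rasters to the longest T_i and stack them (hypothetical helper).

    samples: list of (raster, label) pairs.
    Returns (batch, labels) with batch of shape (B, T_max, C), i.e. batch-first.
    """
    rasters, labels = zip(*samples)
    t_max = max(r.shape[0] for r in rasters)
    n_channels = rasters[0].shape[1]
    batch = np.zeros((len(rasters), t_max, n_channels), dtype=rasters[0].dtype)
    for i, r in enumerate(rasters):
        batch[i, : r.shape[0]] = r  # copy the sample; the tail stays zero
    return batch, np.array(labels)

batch, batch_labels = pad_collate([(np.ones((3, 16)), 0), (np.ones((5, 16)), 4)])
print(batch.shape, batch_labels)
```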

Now we'll define and train an SNN for the SHD task, to deploy to Xylo. We'll use the ``torch`` backend of Rockpool, which uses the PyTorch automatic differentiation pipeline to train NNs.

## Network definition

In [None]:
# - Show an image of the target network architecture
Image('images/network-layout-shd.svg.png')

In [None]:
# - Import the required torch-backed modules and combinators
from rockpool.nn.modules import LIFTorch, LinearTorch
from rockpool.nn.combinators import Sequential, Residual

# - Define a simple network architecture
Nin = net_channels
Nhid = 20
Nout = 8

net = Sequential(
    LinearTorch((Nin, Nhid)),
    LIFTorch(Nhid),

    Residual(
        LinearTorch((Nhid, Nhid)),
        LIFTorch(Nhid),
    ),

    LinearTorch((Nhid, Nout)),
    LIFTorch(Nout),
)
print(net)
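`Sequential` feeds the output of each module into the next one, while `Residual` adds its input to the output of the sub-network it wraps. The sketch below illustrates those semantics with plain functions standing in for modules; `sequential` and `residual` are hypothetical helpers, not Rockpool's combinators, and real Rockpool modules also return state and record dictionaries. Because the skip connection adds the input to the output, the wrapped block must preserve the shape, which is why the residual block in the network maps `Nhid` to `Nhid`.

```python
import numpy as np

def sequential(*fns):
    """Compose functions left to right: sequential(f, g)(x) == g(f(x))."""
    def run(x):
        for fn in fns:
            x = fn(x)
        return x
    return run

def residual(*fns):
    """Wrap a sequential block with a skip connection: x + block(x)."""
    block = sequential(*fns)
    return lambda x: x + block(x)

# - Toy stand-ins for a linear layer and a thresholding nonlinearity
W = np.eye(3) * 2.0

def linear(x):
    return x @ W

def step(x):
    return (x > 1.0).astype(float)

toy_net = sequential(linear, step, residual(linear, step))
print(toy_net(np.array([0.4, 0.6, 1.0])))
```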

In [None]:
# - By default all parameters are trainable
print(
    {
        module_name: list(module_parameters.keys())
        for module_name, module_parameters in net.parameters().items()
    }
)

In [None]:
# - Import the `Constant` decorator, so we can specify non-trainable parameters
from rockpool.parameters import Constant

# - Define shared neuron parameters to use
neuron_parameters = {
    'tau_mem': Constant(50e-3),
    'tau_syn': Constant(20e-3),
    'bias': Constant(0.),
    'threshold': Constant(1.),
    'dt': net_dt,
}

# - Define the network with shared parameters
net = Sequential(
    LinearTorch((Nin, Nhid)),
    LIFTorch(Nhid, **neuron_parameters),

    Residual(
        LinearTorch((Nhid, Nhid)),
        LIFTorch(Nhid, **neuron_parameters),
    ),

    LinearTorch((Nhid, Nout)),
    LIFTorch(Nout, **neuron_parameters),
)

In [None]:
# - Now only weights are trainable
print(
    "Trainable parameters:",
    {
        module_name: list(module_parameters.keys())
        for module_name, module_parameters in net.parameters().items()
    }
)
print(
    "Non-trainable parameters:",
    {
        module_name: list(module_parameters.keys())
        for module_name, module_parameters in net.simulation_parameters().items()
    }
)

## Training loop

In [None]:
# - Import optimizer and loss function from pytorch
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

# - Get the optimiser functions
optimizer = Adam(net.parameters().astorch(), lr=1e-3)

# - Loss function
loss_fun = CrossEntropyLoss()

In [None]:
train = False

if train:
    # - Training Loop
    num_epochs = 10
    for e in range(num_epochs):

        # - Loop over dataset, getting batches
        for events, labels in train_dl:
            # - Zero the optimizer gradients
            optimizer.zero_grad()

            # - Evolve the network with this batch
            output, _, _ = net(events)

            # - Get the prediction -- number of spikes in each channel
            pred = torch.sum(output, dim=1)

            # - Get the loss value for this batch
            loss = loss_fun(pred, labels)

            # - Compute gradients with backward step and update parameters
            loss.backward()
            optimizer.step()

        # - Print the current loss
        print(f'Epoch {e}/{num_epochs}, loss {loss.item():.2e}')
else:
    # - Load a pre-trained version
    net.load('pretrained-37ke.json')

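    # - Show the loss curve recorded while training this pre-trained version
    #   (`display` is required, since an expression inside an `if` block is not rendered automatically)
    from IPython.display import display
    display(Image('loss-pretrained-37ke.png'))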
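The training loop uses the spike count of each output neuron, summed over time, as a logit and applies cross-entropy. The NumPy sketch below computes the same quantity as `CrossEntropyLoss` with its default mean reduction (log-softmax, then the negative log-likelihood of the target), to make the arithmetic explicit; `spike_count_cross_entropy` is a hypothetical name.

```python
import numpy as np

def spike_count_cross_entropy(output, labels):
    """Mean cross-entropy with spike counts as logits (hypothetical helper).

    output: (B, T, N_out) spike raster; labels: (B,) integer class indices.
    """
    logits = output.sum(axis=1)                          # (B, N_out) spike counts
    logits = logits - logits.max(axis=1, keepdims=True)  # stabilise the softmax
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

# - Two output neurons over four time steps; neuron 0 fires three times
toy_output = np.zeros((1, 4, 2))
toy_output[0, :3, 0] = 1
print(spike_count_cross_entropy(toy_output, np.array([0])))  # target: neuron 0
print(spike_count_cross_entropy(toy_output, np.array([1])))  # target: neuron 1
```

The loss is small when the target neuron fires the most and grows with the spike-count margin of the wrong neurons, which is the pressure that shapes the output spike counts during training.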

## Inference in simulation

In [None]:
# - Evolve the trained network over a training sample
events, label = train_data[2]
out, _, rd = net(events, record = True)

In [None]:
# - Plot the output of the network
times, channels = torch.where(out[0])
plt.plot(times * net_dt, channels, '|')

# - Indicate the target label
plt.plot(0.01, label, '>', ms=18)
plt.ylim([-1, 8])
plt.xlim([0, 1.]);
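To turn an output raster into a class prediction, pick the output neuron with the most spikes, consistent with the spike-count logits used in training. A small sketch for a batch of rasters follows, with a hypothetical helper name; `np.argmax` resolves ties towards the lowest index.

```python
import numpy as np

def predict_and_score(output, labels):
    """Return predicted classes and accuracy for (B, T, N_out) rasters (hypothetical helper)."""
    counts = output.sum(axis=1)          # spikes per output neuron
    predictions = counts.argmax(axis=1)  # the most active neuron wins
    accuracy = float((predictions == labels).mean())
    return predictions, accuracy

toy_out = np.zeros((2, 5, 8))
toy_out[0, :, 3] = 1   # sample 0: neuron 3 fires on every step
toy_out[1, :2, 6] = 1  # sample 1: neuron 6 fires twice
preds, acc = predict_and_score(toy_out, np.array([3, 5]))
print(preds, acc)
```

For the sample above, `predict_and_score(out.detach().numpy(), np.array([label]))` gives the network's prediction, since `out` has shape `(1, T, 8)`.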

## Deployment on hardware

In [None]:
# Hardware specific imports
import rockpool.devices.xylo.syns61201 as xylo
from rockpool.transform.quantize_methods import channel_quantize

# From software model to hardware

## Extract computational graph
graph = net.as_graph()
## Map graph to hardware specifications
spec = xylo.mapper(graph)
## Quantize parameters
spec.update(channel_quantize(**spec))
## Generate configuration of the hardware, and check that it is valid
config, is_valid, msg = xylo.config_from_specification(**spec)
assert is_valid, msg
## Deploy module - for now in precise simulation
mod = xylo.XyloSim.from_config(config, dt = net_dt)
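The `channel_quantize` step converts floating-point weights and thresholds to the integer ranges that Xylo supports. The NumPy sketch below shows the general idea of symmetric per-channel quantization, as an assumption for intuition only. Each target neuron gets one scale factor that maps its largest absolute incoming weight to the largest signed 8-bit value, and the same factor is applied to its threshold, so comparing the integrated input to the threshold is approximately preserved. Rockpool's `channel_quantize` may differ in detail (for example in threshold bit-width and clipping); `per_channel_quantize` is hypothetical.

```python
import numpy as np

def per_channel_quantize(weights, thresholds, n_bits=8):
    """Symmetric per-column quantization (hypothetical helper).

    weights: (N_pre, N_post) float matrix; thresholds: (N_post,) floats.
    Each post-synaptic neuron has its own scale, shared by its weights and threshold.
    """
    max_int = 2 ** (n_bits - 1) - 1                  # 127 for 8 bits
    max_abs = np.abs(weights).max(axis=0)            # largest weight magnitude per neuron
    scale = max_int / np.where(max_abs > 0, max_abs, 1.0)
    w_q = np.round(weights * scale).astype(int)
    th_q = np.round(thresholds * scale).astype(int)
    return w_q, th_q, scale

w = np.array([[0.3, -0.02], [-1.0, 0.015]])
th = np.array([1.0, 0.04])
w_q, th_q, scale = per_channel_quantize(w, th)
print(w_q, th_q)
```

The quantized parameters actually used in this notebook are whatever `channel_quantize` returned into `spec`.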

In [None]:
# - Let's look at the mapped weights
plt.subplot(1, 3, 1)
plt.imshow(spec['weights_in'].T)
plt.title('$W_{in}$')

plt.subplot(1, 3, 2)
plt.imshow(spec['weights_rec'].T)
plt.title('$W_{rec}$')

plt.subplot(1, 3, 3)
plt.imshow(spec['weights_out'].T)
plt.title('$W_{out}$');

In [None]:
# - Compare this against our assumed solution from before
Image('images/mapped-weights.svg.png')

In [None]:
# - Let's look at the output of the simulated HDK on a training sample
events, label = train_data[2]
out_xsim, _, rd_xsim = mod(events.numpy(), record = True)
times, channels = np.nonzero(out_xsim)
plt.plot(times * net_dt, channels, '|')
plt.plot(0.01, label, '>', ms=18)
plt.ylim([-0.5, 7.5])
plt.xlim([0, 1.]);

In [None]:
# - Let's plot the membrane potential of the output neurons
times = np.arange(out_xsim.shape[0]) * net_dt
plt.plot(times, rd_xsim['Vmem_out']);

In [None]:
# - Import the helper function to connect to a Xylo HDK
from rockpool.devices.xylo import find_xylo_hdks

# - Locate an HDK, if one is available
hdks, _, _ = find_xylo_hdks()
assert len(hdks) > 0, 'The rest of this notebook needs a connected Xylo HDK.'

# - We'll use the first connected HDK
hdk = hdks[0]

# - Now we can create a Rockpool module that wraps the HDK, by providing the hardware configuration generated above
mod_hdk = xylo.XyloSamna(hdk, config, net_dt)

In [None]:
# - Evolve with a single training sample and also record power consumption
events, label = train_data[2]
out_xhdk, _, rd_xhdk = mod_hdk(events.numpy().astype(int), record = True, record_power=True)

print(f'Power measurement: {np.mean(rd_xhdk["io_power"]) * 1e6:.2f} muW')
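Mean power alone does not give the energy cost of one inference; multiplying by the duration of the sample does. A short sketch of that arithmetic follows, assuming the recorded power trace is sampled uniformly over the evolution. The helper name and the example numbers are illustrative, not measurements.

```python
import numpy as np

def energy_joules(power_trace_w, duration_s):
    """Energy (J) = mean power (W) x duration (s), assuming uniform sampling."""
    return float(np.mean(power_trace_w)) * duration_s

# - Illustrative values only: 100 µW sustained over a 2.5 s sample
energy = energy_joules(np.full(10, 100e-6), 2.5)
print(f"{energy * 1e6:.1f} µJ")
```

For the sample above, `energy_joules(rd_xhdk['io_power'], events.shape[0] * net_dt)` would estimate the I/O energy over the sample, assuming `io_power` is reported in watts, as the `* 1e6` conversion to µW above suggests.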

In [None]:
# - Plot the output
times, channels = np.nonzero(out_xhdk)
plt.plot(times * net_dt, channels, '|')
plt.plot(0.01, label, '>', ms=18)
plt.ylim([-0.5, 7.5])
plt.xlim([0, 1.]);

In [None]:
# - Let's compare the output Vmem between the HDK and the simulator
times = np.arange(out_xhdk.shape[0]) * net_dt
plt.subplot(1, 2, 1)
plt.plot(times, rd_xhdk['Vmem_out'])
plt.subplot(1, 2, 2)
plt.plot(times, rd_xsim['Vmem_out']);
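Beyond visual comparison, a simple numerical check is the fraction of (time step, neuron) entries where two output rasters agree, together with the per-neuron spike-count difference. The sketch below uses a hypothetical helper `compare_rasters` and assumes both rasters have the same shape `(T, N_out)`.

```python
import numpy as np

def compare_rasters(a, b):
    """Return (agreement fraction, per-neuron spike-count difference a - b)."""
    a = np.asarray(a)
    b = np.asarray(b)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    agreement = float((a == b).mean())         # fraction of identical entries
    count_diff = a.sum(axis=0) - b.sum(axis=0)  # spike-count difference per neuron
    return agreement, count_diff

agreement, count_diff = compare_rasters([[1, 0], [0, 1]], [[1, 0], [1, 1]])
print(agreement, count_diff)
```

In this notebook, `compare_rasters(out_xsim, out_xhdk)` applies the check to the simulator and HDK outputs, provided both have the same number of time steps.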

## Using the Xylo-Audio v2 audio front-end interface


The AFE (Audio Front-End) preprocesses a single-channel audio signal and converts it into spikes.
The audio signal comes either from the microphone mounted on the hardware dev kit (HDK) or from an external differential analog audio input.
The AFE has 16 output channels, and its behaviour can be configured through the arguments of the `.AFESamna` class.

`.AFESamna` gives you access to the audio front-end on the HDK, and records encoded audio either from the on-board microphone or from analog audio injected into the HDK.

`.AFESamna` also accepts a custom configuration, in which case no auto-calibration is performed.
If you do not provide a custom configuration, we strongly recommend setting ``auto_calibrate = True`` on instantiation, which helps to mitigate the effects of background and mechanical noise.
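To build intuition for the kind of output the AFE produces, the toy model below mirrors only the interface described above: one audio channel in, 16 event channels out. Its internals are hypothetical choices made for illustration and are not how the Xylo AFE is implemented. It splits the audio into frames, measures the spectral magnitude in 16 linearly spaced frequency bands, and converts each band's accumulated magnitude into events with an integrate-and-fire rule. `toy_afe` and all of its parameters are assumptions.

```python
import numpy as np

def toy_afe(audio, fs, n_channels=16, frame_dt=10e-3, threshold=1.0):
    """Hypothetical software stand-in for an audio front-end.

    Returns an integer array of shape (n_frames, n_channels) with events per frame.
    """
    frame_len = int(round(frame_dt * fs))
    n_frames = len(audio) // frame_len
    frames = np.asarray(audio[: n_frames * frame_len], dtype=float).reshape(n_frames, frame_len)

    # - Magnitude spectrum per frame, standing in for band-pass filtering plus rectification
    spectrum = np.abs(np.fft.rfft(frames, axis=1))
    freqs = np.fft.rfftfreq(frame_len, d=1.0 / fs)
    edges = np.linspace(0.0, fs / 2, n_channels + 1)  # linearly spaced band edges

    band_mag = np.zeros((n_frames, n_channels))
    for c in range(n_channels):
        in_band = (freqs >= edges[c]) & (freqs < edges[c + 1])
        band_mag[:, c] = spectrum[:, in_band].sum(axis=1) / frame_len

    # - Integrate-and-fire encoding with subtractive reset, per channel
    acc = np.zeros(n_channels)
    afe_out = np.zeros((n_frames, n_channels), dtype=int)
    for t in range(n_frames):
        acc += band_mag[t]
        afe_out[t] = np.floor(acc / threshold).astype(int)
        acc -= afe_out[t] * threshold
    return afe_out

# - A 1.2 kHz tone sampled at 16 kHz for one second
fs = 16_000
tone = np.sin(2 * np.pi * 1200 * np.arange(fs) / fs)
afe_events = toy_afe(tone, fs)
print(afe_events.sum(axis=0))
```

A pure tone drives a single channel, which is the kind of frequency-selective event pattern visible in the AFE recordings that follow.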

In [None]:
Image('AFESamna.png', width=400)

In [None]:
# - Set the time resolution and duration to record encoded audio
dt = 10e-3
timesteps = 1000

In [None]:
# - Create an AFESamna module, which wraps the AFE on the Xylo A2 HDK
mod = xylo.AFESamna(hdk, None, dt=dt, auto_calibrate=True, amplify_level='high')

In [None]:
# - Evolve the module to record encoded real-time audio as events
spikes_ts, _, _ = mod(np.zeros([0, timesteps, 0]))

In [None]:
# - Plot some encoded audio events recorded from the AFE
plt.imshow(spikes_ts.T, aspect='auto', interpolation='none')
plt.title('#Spikes in AFE output channels')
plt.xlabel('Time')
plt.ylabel('Channel')
plt.show()


## Deploying the AFE and SNN cores in free-running inference mode


Once you have a complete hardware configuration for the chip, you can deploy it in real-time inference mode using the class `.XyloMonitor`.
This mode uses the AFE core to pre-process audio signals in real time, then sends the encoded audio to the SNN core for inference.
In this mode you only read the output events from the SNN core, without providing input.

In [None]:
Image('XyloMonitor.png', width=400)

In [None]:
# - Use XyloMonitor to deploy to the HDK
# - You need to wait 45s until the AFE auto-calibration is done

output_mode = "Vmem"
amplify_level = "high"
hibernation = False
DN = False
T = 10

modMonitor = xylo.XyloMonitor(
    hdk,
    config,
    dt=net_dt,
    output_mode=output_mode,
    amplify_level=amplify_level,
    hibernation_mode=hibernation,
    divisive_norm=DN
)

In [None]:
from collections import deque

# - A fixed-length FIFO buffer that keeps only the most recent results

class ResultList:
    def __init__(self, max_len=100):
        self.max_len = max_len
        self._list = deque(maxlen=max_len)

    def reset(self):
        self._list = deque(maxlen=self.max_len)

    def append(self, num):
        self._list.append(num)

    def counts(self, features=None):
        # - Count how many buffered entries are contained in `features`
        features = features or []
        return sum(1 for item in self._list if item in features)

    def __len__(self):
        return len(self._list)

    def print_result(self):
        # - Returns (does not print) the buffered results, oldest first
        return self._list


In [None]:
import numpy as np
from IPython import display
from time import time

# - Buffers holding the most recent output value of each class, and their time stamps
lines = [ResultList(max_len=10) for _ in range(Nout)]
time_base = ResultList(max_len=10)
tt = 0.
t_inference = 10.

t_start = time()

while (time() - t_start) < t_inference:
    # - Perform inference on the Xylo A2 HDK, reading `T` time steps of output
    output, _, _ = modMonitor(input_data=np.zeros((T, Nin)))
    if output is not None:
        # - Keep the maximum value of each output channel over this read
        output = np.max(output, axis=0)
        for i in range(Nout):
            lines[i].append(output[i])

        # - Each read covers `T` time steps of `net_dt` seconds
        time_base.append(tt)
        tt += T * net_dt
        ax_time = time_base.print_result()

        # - Redraw the figure, replacing the previous one
        display.clear_output(wait=True)
        for i in range(Nout):
            plt.plot(ax_time, lines[i].print_result(), label=f"class {i}")

        plt.xlabel('Time (s)')
        plt.ylabel('Vmem')
        plt.legend()
        plt.show()

## Next steps

__Talk to Felix, Gregor, or Sadique and start playing with Xylo__

Check out further tutorials and documentation on the Rockpool website (https://rockpool.ai)!

For more information about Xylo™, see (https://rockpool.ai/devices/xylo-overview.html)

For information about the Xylo™ HDK, see (https://www.synsense.ai/products/xylo/)

For a more in-depth published example, see (https://ieeexplore.ieee.org/document/9967462) (https://doi.org/10.48550/arXiv.2208.12991)
