## Feedforward Spiking Neural Network

This notebook builds a feedforward spiking neural network (SNN) with `snntorch`. The network has two fully connected layers, each followed by a layer of leaky integrate-and-fire (LIF) neurons, and is simulated on random input spikes.

In [None]:
!pip install snntorch

The cell above installs the `snntorch` library, which provides tools for building and training SNNs in PyTorch.

In [None]:
# Imports
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

We start by importing the necessary libraries. `snntorch` provides the spiking neuron models, and its `spikeplot` and `spikegen` modules handle plotting and spike generation. `torch` and `torch.nn` supply the tensors and the fully connected layers, and `matplotlib.pyplot` draws the figures.

In [None]:
# Plotting Settings
def plot_cur_mem_spk(cur, mem, spk, thr_line=False, vline=False, title=False, ylim_max1=1.25, ylim_max2=1.25):
  # Generate Plots
  fig, ax = plt.subplots(3, figsize=(8,6), sharex=True,
                        gridspec_kw = {'height_ratios': [1, 1, 0.4]})

  # Plot input current
  ax[0].plot(cur, c="tab:orange")
  ax[0].set_ylim([0, ylim_max1])
  ax[0].set_xlim([0, 200])
  ax[0].set_ylabel("Input Current ($I_{in}$)")
  if title:
    ax[0].set_title(title)

  # Plot membrane potential
  ax[1].plot(mem)
  ax[1].set_ylim([0, ylim_max2])
  ax[1].set_ylabel("Membrane Potential ($U_{mem}$)")
  if thr_line:
    ax[1].axhline(y=thr_line, alpha=0.25, linestyle="dashed", c="black", linewidth=2)
  plt.xlabel("Time step")

  # Plot output spike using spikeplot
  splt.raster(spk, ax[2], s=400, c="black", marker="|")
  if vline:
    ax[2].axvline(x=vline, ymin=0, ymax=6.75, alpha = 0.15, linestyle="dashed", c="black", linewidth=2, zorder=0, clip_on=False)
  plt.ylabel("Output spikes")
  plt.yticks([])

  plt.show()

def plot_snn_spikes(spk_in, spk1_rec, spk2_rec, title):
  # Generate Plots
  fig, ax = plt.subplots(3, figsize=(8,7), sharex=True,
                        gridspec_kw = {'height_ratios': [1, 1, 0.4]})

  # Plot input spikes
  splt.raster(spk_in[:,0], ax[0], s=0.03, c="black")
  ax[0].set_ylabel("Input Spikes")
  ax[0].set_title(title)

  # Plot hidden layer spikes
  splt.raster(spk1_rec.reshape(num_steps, -1), ax[1], s = 0.05, c="black")
  ax[1].set_ylabel("Hidden Layer")

  # Plot output spikes
  splt.raster(spk2_rec.reshape(num_steps, -1), ax[2], c="black", marker="|")
  ax[2].set_ylabel("Output Spikes")
  ax[2].set_ylim([0, 10])

  plt.show()

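The helper functions `plot_cur_mem_spk` and `plot_snn_spikes` visualize the input current, membrane potential, and spike trains of the neurons in our network. Note that `plot_snn_spikes` reads the global variable `num_steps`, so it must be defined before the function is called.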

In [None]:
# Layer parameters
num_inputs = 784  # Number of input neurons (e.g., for an image of size 28x28)
num_hidden = 1000 # Number of neurons in the hidden layer
num_outputs = 10  # Number of output neurons (e.g., for 10 classes)
beta = 0.99       # Decay rate for the leaky integrate-and-fire neurons

# Initialize layers
fc1 = nn.Linear(num_inputs, num_hidden)  # First fully connected layer
lif1 = snn.Leaky(beta=beta)              # LIF neuron layer after fc1
fc2 = nn.Linear(num_hidden, num_outputs) # Second fully connected layer
lif2 = snn.Leaky(beta=beta)              # LIF neuron layer after fc2

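Here, we define the network parameters and initialize the layers. The network maps 784 inputs to a hidden layer of 1000 neurons and then to an output layer of 10 neurons. Both the hidden and output layers use LIF neurons, and `beta` sets the fraction of membrane potential each neuron retains from one time step to the next.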
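
To make the LIF update concrete, here is a minimal plain-Python sketch of a single neuron driven by a constant current. It assumes a firing threshold of 1.0 and reset by subtraction, which to our knowledge match the `snn.Leaky` defaults. The names `lif_step` and `simulate` are hypothetical and exist only for this illustration.

```python
def lif_step(cur, mem, beta=0.99, threshold=1.0):
    """One time step of a leaky integrate-and-fire neuron (illustrative sketch).

    Assumes reset by subtraction: if the previous membrane potential exceeded
    the threshold, the threshold is subtracted on this step.
    """
    reset = 1.0 if mem > threshold else 0.0     # did the neuron fire on the previous step?
    mem = beta * mem + cur - reset * threshold  # decay, integrate input, apply reset
    spk = 1 if mem > threshold else 0           # fire when the threshold is crossed
    return spk, mem


def simulate(cur, num_steps, beta=0.99, threshold=1.0):
    """Drive one neuron with a constant current and record spikes and potentials."""
    mem = 0.0
    spk_rec, mem_rec = [], []
    for _ in range(num_steps):
        spk, mem = lif_step(cur, mem, beta, threshold)
        spk_rec.append(spk)
        mem_rec.append(mem)
    return spk_rec, mem_rec


spk_rec, mem_rec = simulate(cur=0.3, num_steps=10)
print(spk_rec)  # → [0, 0, 0, 1, 0, 0, 1, 0, 0, 0]
```

The membrane potential climbs over several steps, crosses the threshold, and drops by the threshold amount on the following step. A smaller `beta` makes the potential leak faster, so the same input produces fewer spikes.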

In [None]:
# Initialize hidden states
mem1 = lif1.init_leaky() # Initial membrane potential for the hidden layer
mem2 = lif2.init_leaky() # Initial membrane potential for the output layer

# Record outputs
mem2_rec = [] # Record of membrane potentials for the output layer
spk1_rec = [] # Record of spikes for the hidden layer
spk2_rec = [] # Record of spikes for the output layer

We initialize the hidden states (membrane potentials) for the LIF neurons and create lists to record the membrane potentials and spikes during the simulation.

In [None]:
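num_steps = 200 # Number of simulation time steps (also read by plot_snn_spikes and the simulation loop)
spk_in = spikegen.rate_conv(torch.rand((num_steps, num_inputs))).unsqueeze(1) # Random input spikes; unsqueeze adds a batch dimension
print(f"Dimensions of spk_in: {spk_in.size()}")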

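We generate random input spikes for 200 time steps and 784 inputs. `torch.rand` draws values in [0, 1), and `spikegen.rate_conv` treats each value as the probability of a spike at that time step. `unsqueeze(1)` adds a batch dimension of size 1, so `spk_in` has shape `[200, 1, 784]`.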
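
The idea behind rate coding can be sketched in plain Python: each input value in [0, 1] is used as the per-step probability of emitting a spike. This illustrates the concept and is not the `spikegen.rate_conv` implementation. The helper name `rate_encode` and its `seed` parameter are hypothetical.

```python
import random


def rate_encode(values, num_steps, seed=0):
    """Turn values in [0, 1] into spike trains of shape [num_steps][len(values)].

    Each value is treated as the probability that its input spikes on a given step.
    """
    rng = random.Random(seed)  # seeded generator so the example is repeatable
    return [[1 if rng.random() < p else 0 for p in values] for _ in range(num_steps)]


spikes = rate_encode([0.0, 0.5, 1.0], num_steps=200)
# Firing rate of each input: fraction of time steps on which it spiked
rates = [sum(step[i] for step in spikes) / len(spikes) for i in range(3)]
print(rates)
```

An input of 0.0 never spikes, an input of 1.0 spikes on every step, and intermediate values spike on roughly that fraction of steps.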

In [None]:
# Network simulation
for step in range(num_steps):
    cur1 = fc1(spk_in[step]) # Calculate post-synaptic current for the hidden layer
    spk1, mem1 = lif1(cur1, mem1) # Update hidden layer's spikes and membrane potentials
    cur2 = fc2(spk1) # Calculate post-synaptic current for the output layer
    spk2, mem2 = lif2(cur2, mem2) # Update output layer's spikes and membrane potentials

    mem2_rec.append(mem2) # Record membrane potential for output layer
    spk1_rec.append(spk1) # Record spikes for hidden layer
    spk2_rec.append(spk2) # Record spikes for output layer

# Convert lists to tensors
mem2_rec = torch.stack(mem2_rec)
spk1_rec = torch.stack(spk1_rec)
spk2_rec = torch.stack(spk2_rec)

plot_snn_spikes(spk_in, spk1_rec, spk2_rec, "Fully Connected Spiking Neural Network")

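We simulate the network for `num_steps` time steps. At each step, a fully connected layer turns the incoming spikes into a post-synaptic current, the following LIF layer updates its membrane potentials and emits spikes, and the results are recorded. After the loop, `torch.stack` converts the recorded lists into tensors with time as the first dimension.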

In [None]:
from IPython.display import HTML

fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))
labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
spk2_rec = spk2_rec.squeeze(1).detach().cpu()

# Plot spike count histogram
anim = splt.spike_count(spk2_rec, fig, ax, labels=labels, animate=True)
HTML(anim.to_html5_video())
# anim.save("spike_bar.mp4")

We use `spikeplot.spike_count` to animate a histogram of the output layer's spike counts, which shows how often each output neuron fires over time. Rendering the animation with `to_html5_video()` requires FFmpeg to be installed.
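
The histogram is built from per-neuron spike counts, that is, the number of time steps on which each output neuron fired. Below is a minimal plain-Python sketch of that tally. The helper name `spike_counts` and the toy spike record are hypothetical. On the recorded tensor, the same tally is `spk2_rec.sum(dim=0)`.

```python
def spike_counts(spk_rec):
    """Count spikes per neuron from a record of shape [num_steps][num_neurons]."""
    num_neurons = len(spk_rec[0])
    return [sum(step[n] for step in spk_rec) for n in range(num_neurons)]


# Toy record: 4 time steps, 3 output neurons (made-up values for illustration)
toy_rec = [
    [0, 1, 0],
    [1, 1, 0],
    [0, 1, 0],
    [1, 0, 0],
]
counts = spike_counts(toy_rec)
print(counts)  # → [2, 3, 0]

# Index of the neuron that fired most often
most_active = max(range(len(counts)), key=counts.__getitem__)
print(most_active)  # → 1
```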

`spikeplot.traces` plots membrane potential traces. By default it lays them out in a 3×3 grid, so only 9 of the 10 output neurons are shown.
Compare the traces with the animation and raster plot above and see whether you can match each trace to its neuron.

In [None]:
# Plot membrane potential traces (detached from the autograd graph before plotting)
splt.traces(mem2_rec.squeeze(1).detach().cpu(), spk=spk2_rec.squeeze(1))
fig = plt.gcf()
fig.set_size_inches(8, 6)