Backpropagation in spiking neural networks
==========================================
This notebook shows how the gradient descent algorithm can be adapted to train spiking neural networks with backpropagation, even though their spiking activation function is non-differentiable.

Copyright (c) 2019, NECOTIS  
All rights reserved.  
Author: Ismael Balafrej  

Work inspired and adapted from 
1. Surrogate Gradient Learning in Spiking Neural Networks by Neftci, Mostafa & Zenke (2019) https://arxiv.org/pdf/1901.09948.pdf
2. SLAYER: Spike Layer Error Reassignment in Time by Shrestha & Orchard (2018) https://arxiv.org/pdf/1810.08646.pdf
3. Biologically inspired alternatives to backpropagation through time for learning in recurrent neural nets by Bellec et al. (2019) https://arxiv.org/pdf/1901.09049.pdf

In [0]:
!pip install quantities sparse > /dev/null

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets, model_selection, utils
import torch
import quantities as units
from sparse import COO

In [4]:
# Reproducibility
torch.manual_seed(0)
np.random.seed(0)

# Use the GPU unless there is none available. If you don't have a CUDA-enabled GPU, I recommend using Google Colab,
# available here: https://colab.research.google.com.
# Create a new notebook and then go to Runtime -> Change runtime type -> Hardware accelerator -> GPU
# This will give you access to a fairly recent GPU for free, for up to 12h continuously
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Let's download the MNIST dataset, available at https://www.openml.org/d/554
# You can edit the argument data_home to the directory of your choice.
# The dataset will be downloaded there; the default directory is ~/scikit_learn_data/
# as_frame=False asks for plain numpy arrays: recent scikit-learn versions otherwise return a pandas
# DataFrame, which breaks the (row, column) fancy indexing used further below.
X, y = datasets.fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False, data_home=None)
nb_of_samples, nb_of_features = X.shape
# X = 70k samples, 28*28 features, y = 70k samples, 1 label (string)

# Shuffle the dataset (no random_state given, so numpy's global random state seeded above is used)
X, y = utils.shuffle(X, y)

# Convert the labels (string) to integers for convenience.
# np.int was removed in NumPy 1.24, so an explicit integer dtype is used.
y = np.array(y, dtype=np.int64)
nb_of_ouputs = int(np.max(y)) + 1 # Number of classes (name kept as-is: it is used by later cells)

# We'll normalize our input data in the range [0., 1[.
X = X / pow(2, 8)

# And convert the data to a spike train
dt = 1*units.ms
duration_per_image = 100*units.ms
# `.simplified` makes the ratio a plain number even if the two durations use different time units
absolute_duration = int(round(float((duration_per_image / dt).simplified)))

time_of_spike = (1 - X) * absolute_duration # The brighter the white, the earlier the spike
time_of_spike[X < .25] = 0 # "Remove" the spikes associated with darker pixels (Presumably less information)

sample_id, neuron_idx = np.nonzero(time_of_spike)
# COO coordinates must be integer indices: truncate each continuous spike time to its time-step bin
spike_time_step = time_of_spike[sample_id, neuron_idx].astype(np.int64)

# We use a sparse COO array to store the spikes for memory requirements
# You can use the spike_train variable as if it were a tensor of shape (nb_of_samples, nb_of_features, absolute_duration)
spike_train = COO((sample_id, neuron_idx, spike_time_step),
                  np.ones_like(sample_id), shape=(nb_of_samples, nb_of_features, absolute_duration))

# We create a 2 layer network (1 hidden, 1 output)
nb_hidden = 128 # Number of hidden neurons

w1 = torch.empty((nb_of_features, nb_hidden), device=device, dtype=torch.float, requires_grad=True)
torch.nn.init.normal_(w1, mean=0., std=.1)

w2 = torch.empty((nb_hidden, nb_of_ouputs), device=device, dtype=torch.float, requires_grad=True)
torch.nn.init.normal_(w2, mean=0., std=.1)

# Split in train/test (the dataset was shuffled above, so a contiguous split is random)
nb_of_train_samples = int(nb_of_samples * 0.85) # Keep 15% of the dataset for testing
train_indices = np.arange(nb_of_train_samples)
test_indices = np.arange(nb_of_train_samples, nb_of_samples)

class SpikeFunction(torch.autograd.Function):
    """Spike non-linearity with a surrogate gradient.

    Forward: a Heaviside step, 1. where the input (potential - threshold) is strictly positive, 0. elsewhere.
    Backward: the step is not differentiable, so the gradient of a ReLU is used instead
    (the incoming gradient is blocked where the input is negative and passed through elsewhere).
    """

    @staticmethod
    def forward(ctx, input):
        # Keep the pre-activation around: the surrogate gradient depends on its sign
        ctx.save_for_backward(input)
        is_spiking = input > 0 # We spike when the (potential-threshold) > 0
        return is_spiking.to(input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (pre_activation,) = ctx.saved_tensors
        # ReLU derivative as surrogate: zero the gradient for negative inputs (returns a new tensor)
        return grad_output.masked_fill(pre_activation < 0, 0.0)

def run_spiking_layer(input_spike_train, layer_weights, tau_v=20*units.ms, tau_i=5*units.ms, v_threshold=1.0):
    """Here we implement a current-LIF dynamic in pytorch.

    Simulates one layer of current-based leaky integrate-and-fire (LIF) neurons for
    `absolute_duration` time steps of length `dt` (both defined at notebook level). At every step:
      1. the synaptic current and the membrane potential leak with the exact exponential
         solution over one step, x <- x * exp(-dt / tau);
      2. the weighted input spikes of that step are added to the synaptic current;
      3. the synaptic current is added to the membrane potential;
      4. a spike is emitted where (potential - v_threshold) > 0, through `SpikeFunction`
         (surrogate gradient), and the potential of those neurons is reset to 0.

    Args:
        input_spike_train: tensor of shape (batch, nb of input neurons, time steps); the time
            axis must have at least `absolute_duration` entries.
        layer_weights: tensor of shape (nb of input neurons, nb of output neurons).
        tau_v: time constant of the membrane potential leak (a `quantities` time quantity).
        tau_i: time constant of the synaptic current leak (a `quantities` time quantity).
        v_threshold: membrane potential above which a neuron spikes.

    Returns:
        Tensor of output spikes (0. or 1.) of shape (batch, nb of output neurons, absolute_duration).
    """

    # First, we multiply the input spike train by the weights of the current layer to get the current that will be added
    # We can calculate this beforehand because the weights are constant in the forward pass (no plasticity)
    input_current = torch.einsum("abc,bd->adc", (input_spike_train, layer_weights)) # Equivalent to a matrix multiplication for tensors of dim > 2 using Einstein's Notation

    # Per-step decay factors of the exact solution of dx/dt = -x / tau, computed once outside the time loop.
    # `.simplified` turns dt / tau into a plain dimensionless number even if dt and tau use different time units.
    potential_decay = float(np.exp(-float((dt / tau_v).simplified)))
    current_decay = float(np.exp(-float((dt / tau_i).simplified)))

    recorded_spikes = [] # Array of the output spikes at each time t
    # The neuron state lives on the same device as the input current
    state_shape = (input_spike_train.shape[0], layer_weights.shape[-1])
    membrane_potential_at_t = torch.zeros(state_shape, device=input_current.device, dtype=torch.float)
    membrane_current_at_t = torch.zeros(state_shape, device=input_current.device, dtype=torch.float)

    for t in range(absolute_duration): # For every timestep
        # Apply the leak (exact method). Out-of-place multiplications create new tensors for this step.
        membrane_potential_at_t = membrane_potential_at_t * potential_decay
        membrane_current_at_t = membrane_current_at_t * current_decay

        # Select the input current at time t
        input_at_t = input_current[:, :, t]

        # Integrate the input current
        membrane_current_at_t += input_at_t

        # Integrate the input to the membrane potential
        membrane_potential_at_t += membrane_current_at_t

        # Apply the non-differentiable function
        recorded_spikes_at_t = SpikeFunction.apply(membrane_potential_at_t - v_threshold)
        recorded_spikes.append(recorded_spikes_at_t)

        # Reset the spiked neurons (same condition as the spike: potential > threshold)
        membrane_potential_at_t[membrane_potential_at_t > v_threshold] = 0

    recorded_spikes = torch.stack(recorded_spikes, dim=2) # Stack over time axis (Array -> Tensor)
    return recorded_spikes


# Set-up training
nb_of_epochs = 20
batch_size = 256 # The backpropagation is done after every batch, but a batch here is also used for memory requirements
number_of_batches = max(1, len(train_indices) // batch_size) # At least one batch, even for a tiny training set

# Target spike counts, as in Shrestha & Orchard (2018) https://arxiv.org/pdf/1810.08646.pdf
# (code available at https://github.com/bamsumit/slayerPytorch): every output neuron should fire
# wrong_label_spike_count times, except the neuron of the true label, which should fire true_label_spike_count times.
wrong_label_spike_count = 10.0
true_label_spike_count = 100.0

params = [w1, w2] # Trainable parameters
optimizer = torch.optim.Adam(params, lr=0.01, amsgrad=True)
loss_fn = torch.nn.MSELoss(reduction='mean')

for e in range(nb_of_epochs):
    epoch_loss = 0
    for batch in np.array_split(train_indices, number_of_batches):
        # Select batch and convert to tensors
        batch_spike_train = torch.as_tensor(spike_train[batch].todense(), dtype=torch.float, device=device)
        batch_labels = torch.as_tensor(y[batch, np.newaxis], dtype=torch.long, device=device)

        # One-hot style target spike count: one column per output neuron (nb_of_ouputs, not a hard-coded 10)
        target_output = torch.full((batch.shape[0], int(nb_of_ouputs)), wrong_label_spike_count,
                                   device=device, dtype=torch.float)
        target_output.scatter_(1, batch_labels, true_label_spike_count)

        # Forward propagation
        layer_1_spikes = run_spiking_layer(batch_spike_train, w1)
        layer_2_spikes = run_spiking_layer(layer_1_spikes, w2)
        network_output = torch.sum(layer_2_spikes, 2) # Count the spikes over time axis
        loss = loss_fn(network_output, target_output)

        # Backward propagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    print("Epoch %i -- loss : %.4f" %(e+1, epoch_loss / number_of_batches))

Epoch 1 -- loss : 328.7854
Epoch 2 -- loss : 280.9454
Epoch 3 -- loss : 274.2292
Epoch 4 -- loss : 272.9392
Epoch 5 -- loss : 272.9115
Epoch 6 -- loss : 272.9176
Epoch 7 -- loss : 272.7327
Epoch 8 -- loss : 272.4144
Epoch 9 -- loss : 272.1312
Epoch 10 -- loss : 271.8605
Epoch 11 -- loss : 271.4815
Epoch 12 -- loss : 270.7438
Epoch 13 -- loss : 270.3649
Epoch 14 -- loss : 270.3311
Epoch 15 -- loss : 270.2173
Epoch 16 -- loss : 270.2439
Epoch 17 -- loss : 270.1217
Epoch 18 -- loss : 270.1146
Epoch 19 -- loss : 270.0847
Epoch 20 -- loss : 270.1059


In [5]:
# Test the accuracy of the model
correct_label_count = 0
# We only need to batchify the test set for memory requirements (at least one batch, even if the
# test set is smaller than batch_size, otherwise np.array_split would raise)
number_of_test_batches = max(1, len(test_indices) // batch_size)

# No gradient is needed for evaluation: no_grad avoids building and storing the autograd graph
# of every time step, which would otherwise hold on to a lot of memory
with torch.no_grad():
    for batch in np.array_split(test_indices, number_of_test_batches):
        test_spike_train = torch.as_tensor(spike_train[batch].todense(), dtype=torch.float, device=device)

        # Same forward propagation as before
        layer_1_spikes = run_spiking_layer(test_spike_train, w1)
        layer_2_spikes = run_spiking_layer(layer_1_spikes, w2)
        network_output = torch.sum(layer_2_spikes, 2) # Count the spikes over time axis

        # Do the prediction by selecting the output neuron with the most number of spikes
        _, predicted_labels = torch.max(network_output, 1)
        correct_label_count += np.sum(predicted_labels.cpu().numpy() == y[batch])

print("Model accuracy on test set: %.3f" % (correct_label_count / len(test_indices)))

Model accuracy on test set: 0.869
