# BindsNET Learning Techniques

## 1. Table of Contents
1. Table of Contents
2. Overview
3. Import Statements
4. Learning Flow
5. Learning Rules
    1. PostPre
    2. Hebbian
    3. WeightDependentPostPre
    4. MSTDP
    5. MSTDPET
6. Custom Learning Rules


## 2. Overview

Detailed documentation on using learning rules is available [here](https://bindsnet-docs.readthedocs.io/guide/guide_part_ii.html). This document goes into more specific examples of configuring a spiking neural network in BindsNET.

The specified learning rule is passed into a `Connection` object via the `update_rule` argument, and the connection encapsulates the learning rule object. The parameters of the learning rule are also passed to the `Connection`:

* `nu`: a 2-tuple of pre- and post-synaptic learning rates (how quickly synapse weights change)
* `reduction`: specifies how parameter updates are aggregated across the batch dimension
* `weight_decay`: specifies how quickly synapse weights decay toward zero

By default, parameter updates are averaged across the batch dimension, and there is no weight decay.
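
The `reduction` option can be pictured with a framework-free sketch in plain Python. This is not BindsNET's implementation: the hypothetical helper `reduce_updates` takes per-sample weight updates as nested lists and combines them across the batch. `"mean"` corresponds to the averaging default described above, and `"sum"` is shown for comparison. The PostPre parameter table lists `reduction` as a callable, so the string argument here is a simplification.

```python
from statistics import fmean


def reduce_updates(batch_updates, reduction="mean"):
    """Aggregate per-sample weight updates across the batch dimension.

    Hypothetical helper for illustration only: batch_updates[b][i][j] is the
    update that sample b proposes for the synapse from source neuron i to
    target neuron j.
    """
    if reduction not in ("mean", "sum"):
        raise ValueError(f"unknown reduction: {reduction!r}")
    n_source = len(batch_updates[0])
    n_target = len(batch_updates[0][0])
    reduced = []
    for i in range(n_source):
        row = []
        for j in range(n_target):
            # collect this synapse's update from every sample in the batch
            per_sample = [sample[i][j] for sample in batch_updates]
            row.append(fmean(per_sample) if reduction == "mean" else sum(per_sample))
        reduced.append(row)
    return reduced


# Two samples, each proposing updates for a 1x2 weight matrix.
updates = [[[0.5, 1.0]], [[1.5, -1.0]]]
print(reduce_updates(updates))         # [[1.0, 0.0]]
print(reduce_updates(updates, "sum"))  # [[2.0, 0.0]]
```

Averaging keeps the size of a weight step independent of the batch size, whereas summing makes it grow with the number of samples.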


```python
from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection
from bindsnet.learning import PostPre

# Create two populations of neurons, one to act as the "source"
# population, and the other, the "target population".
# Neurons involved in certain learning rules must record synaptic
# traces, a vector of short-term memories of the last emitted spikes.
source_layer = Input(n=100, traces=True)
target_layer = LIFNodes(n=1000, traces=True)

# Connect the two layers.
connection = Connection(
    source=source_layer, target=target_layer, update_rule=PostPre, nu=(1e-4, 1e-2)
)
```
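
A spike trace of this kind can be sketched in plain Python. The sketch assumes an exponentially decaying trace with time constant `tau` that is reset to 1 on every spike; `update_trace`, `tau`, and the default of 20 time steps are assumptions chosen for illustration, not names from the BindsNET API. A learning rule such as `PostPre` can then read the current trace values instead of the full spike history.

```python
import math


def update_trace(trace, spiked, dt=1.0, tau=20.0):
    """Advance a spike trace by one time step (illustrative sketch).

    The trace decays exponentially with time constant tau and is reset
    to 1.0 whenever the neuron emits a spike on this step.
    """
    trace *= math.exp(-dt / tau)
    return 1.0 if spiked else trace


# A neuron that spikes at steps 0 and 3.
spikes = [True, False, False, True, False]
trace = 0.0
history = []
for s in spikes:
    trace = update_trace(trace, s)
    history.append(round(trace, 3))
print(history)  # [1.0, 0.951, 0.905, 1.0, 0.951]
```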

## 3. Import Statements

```python
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import random

from bindsnet.encoding import *
from bindsnet.network import Network
from bindsnet.network.monitors import Monitor
from bindsnet.network.monitors import NetworkMonitor

from bindsnet.analysis.plotting import plot_spikes, plot_voltages, plot_input, plot_weights

from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection
from bindsnet.learning import PostPre, Hebbian, WeightDependentPostPre, MSTDP, MSTDPET

from bindsnet.evaluation import all_activity, proportion_weighting, assign_labels
from bindsnet.utils import get_square_weights, get_square_assignments
```

## 4. Learning Flow

1. Define simulation parameters
2. Create input data
3. Configure network architecture
4. Define simulation variables
5. Perform learning iterations
6. Evaluate classification performance

### 4.1 Simulation Parameters

```python
# configure the number of input neurons (one per pixel of a 3x3 image)
input_layer_name = "Input Layer"
input_neurons = 9

# configure the number of output LIF neurons
lif_layer_name = "LIF Layer"
lif_neurons = 2

# set number of classes
n_classes = 2

# simulation time and time step size
time = 100
dt = 1

# ratio of neurons to classes
per_class = int(lif_neurons / n_classes)
```

### 4.2 Input Configuration

```python
# input data

# list of input images (one flattened 3x3 image per class)
imgs = []

# Class 0 image: an "X" pattern
imgs.append(torch.flatten(torch.FloatTensor([[128, 0, 128], [0, 128, 0], [128, 0, 128]])))

# Class 1 image: a "+" pattern
imgs.append(torch.flatten(torch.FloatTensor([[0, 128, 0], [128, 128, 128], [0, 128, 0]])))

# initialize the encoder
encoder = BernoulliEncoder(time=time, dt=dt)

# list of encoded images for random selection during training;
# index i holds the encoded image for class i
encoded_inputs = []

# encode each image and store it in the list of encoded inputs
for img in imgs:

    # encode the image into a spike train of shape (time, input_neurons)
    encoded_img = encoder(img)

    # add it to the encoded input list, keyed by the input layer name
    encoded_inputs.append({input_layer_name: encoded_img})
```
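
Conceptually, Bernoulli encoding turns each pixel intensity into an independent per-time-step spike probability. The plain-Python sketch below assumes that, when any intensity exceeds 1, intensities are first divided by the image maximum (so the brightest pixel spikes with probability `max_prob`) and then sampled once per time step. The function name `bernoulli_encode` and its arguments are illustrative assumptions, not the `BernoulliEncoder` source.

```python
import random


def bernoulli_encode(pixels, time, max_prob=1.0, rng=random):
    """Encode non-negative intensities as a (time x n_pixels) 0/1 spike train.

    Illustrative sketch: each pixel spikes independently on every time step
    with probability max_prob * pixel / scale.
    """
    if any(p < 0 for p in pixels):
        raise ValueError("inputs must be non-negative")
    peak = max(pixels)
    # if any intensity exceeds 1, divide by the maximum so the
    # brightest pixel gets probability max_prob
    scale = peak if peak > 1.0 else 1.0
    probs = [max_prob * p / scale for p in pixels]
    return [[1 if rng.random() < p else 0 for p in probs] for _ in range(time)]


# The class 0 "X" image, flattened row by row.
x_image = [128, 0, 128, 0, 128, 0, 128, 0, 128]
spikes = bernoulli_encode(x_image, time=5)
print(spikes[0])  # [1, 0, 1, 0, 1, 0, 1, 0, 1]
```

Because the example image only contains the values 0 and 128, every bright pixel spikes on every step and every dark pixel never spikes; intermediate intensities would produce random spike trains.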

### 4.3 Network Configuration

When creating a connection between two layers, specify the learning (update) rule as well as the learning rates (`nu`).

```python
# initialize network
network = Network()

# configure weights for the synapses between the input layer and LIF layer
# (non-negative integers obtained by rounding the magnitude of scaled normal samples)
w = torch.round(torch.abs(2 * torch.randn(input_neurons, lif_neurons)))

# initialize the input and LIF layers;
# spike traces must be recorded because PostPre computes its weight
# updates from the pre- and post-synaptic spike traces

# initialize input layer
input_layer = Input(n=input_neurons, traces=True)

# initialize LIF layer
lif_layer = LIFNodes(n=lif_neurons, traces=True)

# initialize the connection between the input layer and the LIF layer,
# specifying the learning (update) rule and the learning rates (nu)
connection = Connection(
    source=input_layer, target=lif_layer, w=w, update_rule=PostPre, nu=(1e-4, 1e-2)
)

# add input layer to the network
network.add_layer(layer=input_layer, name=input_layer_name)

# add LIF neuron layer to the network
network.add_layer(layer=lif_layer, name=lif_layer_name)

# add connection to network
network.add_connection(
    connection=connection, source=input_layer_name, target=lif_layer_name
)
```

### 4.4 Simulation Variables

```python
# record the spikes of each LIF neuron at every time step
# shape: (n_samples, time steps, n_neurons)
spike_record = torch.zeros(1, int(time / dt), lif_neurons)

# label assigned to each LIF neuron (-1 means "not yet assigned")
assignments = -torch.ones(lif_neurons)

# per-class firing rates of each LIF neuron, updated by assign_labels()
rates = torch.zeros(lif_neurons, n_classes)

# per-class proportions of each LIF neuron's spiking activity
proportions = torch.zeros(lif_neurons, n_classes)

# label(s) of the input(s) being processed
labels = torch.empty(1, dtype=torch.int)

# create a monitor for each layer in the network;
# the recorded spikes are used to assign labels to neurons and to determine the predicted class
layer_monitors = {}
for layer in set(network.layers):

    # record spikes ("s") at every layer and voltages ("v") everywhere
    # except the input layer, whose nodes have no membrane voltage
    state_vars = ["s", "v"] if (layer != input_layer_name) else ["s"]
    layer_monitors[layer] = Monitor(network.layers[layer], state_vars=state_vars, time=time)

    # connect the monitor to the network
    network.add_monitor(layer_monitors[layer], name="%s_spikes" % layer)
```

### 4.5 Training

The following BindsNET functions are used during training to assign labels to neurons and to classify inputs from their spiking activity.


---


`all_activity()`

Classify data with the label with highest average spiking activity over all neurons.

Returns a predictions tensor of shape `(n_samples,)` resulting from the "all activity" classification scheme (`torch.Tensor`)

| Parameter  | Type         | Description                                                                           | Default Value |
|-------------|--------------|---------------------------------------------------------------------------------------|---------|
| spikes      | `torch.Tensor` | Binary tensor of shape `(n_samples, time, n_neurons)` of a layer's spiking activity. |         |
| assignments | `torch.Tensor` | A vector of shape `(n_neurons,)` of neuron label assignments.                       |         |
| n_labels    | `int`          | The number of target labels in the data.                                              |         |
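
For a single sample, the "all activity" scheme can be re-stated in plain Python. This is a sketch of the behavior described in the table, not the BindsNET source: spikes are summed over time per neuron, averaged over the neurons assigned to each label, and the label with the highest average wins. The function name and list-based inputs are assumptions; ties resolve to the lowest label index here, which may differ from BindsNET.

```python
def all_activity_single(spikes, assignments, n_labels):
    """Predict a label for one sample from a (time x n_neurons) 0/1 spike list."""
    # total spike count of each neuron over the whole simulation
    counts = [sum(step[n] for step in spikes) for n in range(len(assignments))]

    rates = []
    for label in range(n_labels):
        members = [n for n, a in enumerate(assignments) if a == label]
        # labels with no assigned neurons get a rate of 0
        rates.append(sum(counts[n] for n in members) / len(members) if members else 0.0)

    # label with the highest average spike count (first one on ties)
    return max(range(n_labels), key=lambda label: rates[label])


# 3 time steps, 3 neurons; neurons 0 and 1 are assigned label 0, neuron 2 label 1.
spikes = [[1, 0, 1], [0, 0, 1], [1, 0, 1]]
print(all_activity_single(spikes, assignments=[0, 0, 1], n_labels=2))  # 1
```

Label 0 averages (2 + 0) / 2 = 1 spike per neuron while label 1 averages 3, so label 1 is predicted.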


----


`proportion_weighting()`

Classify data with the label with highest average spiking activity over all neurons, weighted by class-wise proportion.

Returns a predictions tensor of shape `(n_samples,)` resulting from the "proportion weighting" classification scheme (`torch.Tensor`)

| Parameter   | Type         | Description                                                                                              | Default Value |
|-------------|--------------|----------------------------------------------------------------------------------------------------------|---------------|
| spikes      | `torch.Tensor` | Binary tensor of shape `(n_samples, time, n_neurons)` of a single layer's spiking activity.            |               |
| assignments | `torch.Tensor` | A vector of shape `(n_neurons,)` of neuron label assignments.                                          |               |
| proportions | `torch.Tensor` | A matrix of shape `(n_neurons, n_labels)` giving the per-class proportions of neuron spiking activity. |               |
| n_labels    | `int`          | The number of target labels in the data.                                                                 |               |
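
The proportion-weighted variant differs only in scaling each neuron's spike count by its proportion for the label being scored. The plain-Python sketch below is an assumption-labeled re-statement of the table, not the BindsNET source; the function name and list inputs are illustrative, and ties resolve to the lowest label index.

```python
def proportion_weighting_single(spikes, assignments, proportions, n_labels):
    """Predict a label for one sample, weighting each neuron's spike count
    by its per-class proportion (illustrative sketch)."""
    # total spike count of each neuron over the whole simulation
    counts = [sum(step[n] for step in spikes) for n in range(len(assignments))]

    rates = []
    for label in range(n_labels):
        members = [n for n, a in enumerate(assignments) if a == label]
        # scale each assigned neuron's count by how selective it is for this label
        weighted = sum(proportions[n][label] * counts[n] for n in members)
        rates.append(weighted / len(members) if members else 0.0)

    return max(range(n_labels), key=lambda label: rates[label])


spikes = [[1, 1, 1], [1, 1, 1], [0, 0, 1]]            # counts: 2, 2, 3
assignments = [0, 0, 1]
proportions = [[1.0, 0.0], [1.0, 0.0], [0.4, 0.6]]   # neuron 2 is only weakly selective
print(proportion_weighting_single(spikes, assignments, proportions, 2))  # 0
```

Plain averaging would favor label 1 here because neuron 2 spikes most, but its low selectivity (0.6) reduces label 1's score to 1.8, below label 0's 2.0.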

----

`assign_labels()`

Assign labels to the neurons based on highest average spiking activity.

Returns a Tuple of class assignments, per-class spike proportions, and per-class firing rates (`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`)

| Parameter | Type                     | Description                                                                                   | Default Value |
|------------|--------------------------|-----------------------------------------------------------------------------------------------|---------------|
| spikes     | `torch.Tensor`             | Binary tensor of shape `(n_samples, time, n_neurons)` of a single layer's spiking activity. |                | 
| labels     | `torch.Tensor`             | Vector of shape `(n_samples,)` with data labels corresponding to spiking activity.          |                | 
| n_labels   | `int`                      | The number of target labels in the data.                                                      |                | 
| rates      | `Optional[torch.Tensor]` | If passed, these represent spike rates from a previous `assign_labels()` call.              | None          | 
| alpha      | `float`                    | Rate of decay of label assignments.                                                           | 1             | 
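
Label assignment can be sketched in plain Python as well. Under the assumptions of this sketch (it restates the table and is not the BindsNET source), each neuron's average spike count per label is added to its previous rate scaled by `alpha`, rates are normalized into per-class proportions, and each neuron is assigned the label with the largest proportion. `assign_labels_sketch` is a hypothetical name, and ties resolve to the lowest label index.

```python
def assign_labels_sketch(spikes, labels, n_labels, rates=None, alpha=1.0):
    """Illustrative re-statement of label assignment.

    spikes[s][t][n] is the 0/1 activity of neuron n at time t for sample s;
    labels[s] is the label of sample s. Returns (assignments, proportions, rates).
    """
    n_neurons = len(spikes[0][0])
    if rates is None:
        rates = [[0.0] * n_labels for _ in range(n_neurons)]
    else:
        rates = [row[:] for row in rates]  # work on a copy of the caller's rates

    # spike count of each neuron for each sample (time order is irrelevant)
    counts = [[sum(step[n] for step in sample) for n in range(n_neurons)] for sample in spikes]

    for label in range(n_labels):
        samples = [s for s, lab in enumerate(labels) if lab == label]
        if samples:
            for n in range(n_neurons):
                mean_count = sum(counts[s][n] for s in samples) / len(samples)
                # scale the previous rate by alpha, then add this batch's average
                rates[n][label] = alpha * rates[n][label] + mean_count

    proportions = []
    for row in rates:
        total = sum(row)
        proportions.append([r / total if total > 0 else 0.0 for r in row])

    # each neuron is assigned the label it fires most for
    assignments = [max(range(n_labels), key=lambda label: p[label]) for p in proportions]
    return assignments, proportions, rates


# One label-1 sample of 100 steps: neuron 0 spikes 9 times, neuron 1 on every step.
sample = [[1 if t < 9 else 0, 1] for t in range(100)]
assignments, proportions, rates = assign_labels_sketch([sample], [1], n_labels=2)
print(assignments, proportions, rates)
# [1, 1] [[0.0, 1.0], [0.0, 1.0]] [[0.0, 9.0], [0.0, 100.0]]
```

Passing the returned `rates` back in with `alpha=1` accumulates them (9 becomes 18, 100 becomes 200), which matches the growth visible in the first two iterations of the training-loop output later in this section.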


```python
weight_history = None
first_pass = True
# set to True to record and plot the weights onto LIF neuron 0
# (named so that it does not shadow the imported plot_weights function)
plot_weight_history = False
num_correct = 0.0
epochs = 10

# simulate network on input data
# iterate for epochs
for step in range(epochs):

    # randomly select an input image class
    labels[0] = random.randint(0, n_classes - 1)

    # randomly select one of the per_class output neurons reserved for that class
    choice = np.random.choice(per_class, size=1, replace=False)

    # clamp: mapping of layer names to boolean masks (or, as here, indices) of neurons
    # that should be clamped to spiking. The ``Tensor``s have shape ``[n_neurons]`` or ``[time, n_neurons]``.
    # Clamping the output (LIF) layer forces the neuron reserved for the label's class to spike.
    clamp = {lif_layer_name: per_class * labels[0] + torch.Tensor(choice).long()}

    # get the encoded input image for the selected class
    inputs = encoded_inputs[labels[0].item()]

    ### Step 1: Run the network with the provided inputs ###
    network.run(inputs=inputs, time=time, clamp=clamp)

    ### Step 2: Get the spikes produced at the output layer ###
    spike_record[0] = layer_monitors[lif_layer_name].get("s").view(time, lif_neurons)

    ### Step 3: Assign labels to the neurons based on highest average spiking activity ###
    # Returns a tuple of class assignments, per-class spike proportions, and per-class firing rates
    # Return type: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    assignments, proportions, rates = assign_labels(spike_record, labels, n_classes, rates)

    ### Step 4: Classify data with the label that has the highest average spiking activity ###
    all_activity_pred = all_activity(spike_record, assignments, n_classes)

    ### Step 5: Classify data with the label that has the highest average spiking activity, weighted by class-wise proportion ###
    proportion_pred = proportion_weighting(spike_record, assignments, proportions, n_classes)

    ### Update accuracy (based on the all-activity prediction) ###
    num_correct += 1 if (labels.numpy()[0] == all_activity_pred.numpy()[0]) else 0

    ######## Display Information ########
    print("Actual Label | Predicted Label [All Activity] | Predicted Label [Weighted Proportion]")
    print(labels.numpy(), "|", all_activity_pred.numpy(), "|", proportion_pred.numpy())
    print("Assignments:")
    print(assignments)

    print("Proportions:")
    print(proportions)

    print("Rates:")
    print(rates)
    #####################################

    ### For Weight Plotting ###
    if plot_weight_history:
        # weights from every input neuron onto LIF neuron 0, as a (1, input_neurons) row
        weights = network.connections[(input_layer_name, lif_layer_name)].w[:, 0].numpy().reshape((1, input_neurons))
        weight_history = weights.copy() if first_pass else np.concatenate((weight_history, weights), axis=0)
    first_pass = False
    #############################

    print("Accuracy:", num_correct / (step + 1.0))

    print("====================\n\n")

### For Weight Plotting ###
# plot how each input-to-neuron-0 weight changed over the epochs
if plot_weight_history:
    for idx in range(weight_history.shape[1]):
        plt.plot(weight_history[:, idx])
    plt.show()
#############################
```

```text
Actual Label | Predicted Label [All Activity] | Predicted Label [Weighted Proportion]
[1] | [1] | [1]
Assignments:
tensor([1, 1])
Proportions:
tensor([[0., 1.],
        [0., 1.]])
Rates:
tensor([[  0.,   9.],
        [  0., 100.]])
Accuracy: 1.0


Actual Label | Predicted Label [All Activity] | Predicted Label [Weighted Proportion]
[1] | [1] | [1]
Assignments:
tensor([1, 1])
Proportions:
tensor([[0., 1.],
        [0., 1.]])
Rates:
tensor([[  0.,  18.],
        [  0., 200.]])
Accuracy: 1.0


Actual Label | Predicted Label [All Activity] | Predicted Label [Weighted Proportion]
[1] | [1] | [1]
Assignments:
tensor([1, 1])
Proportions:
tensor([[0., 1.],
        [0., 1.]])
Rates:
tensor([[  0.,  28.],
        [  0., 300.]])
Accuracy: 1.0


Actual Label | Predicted Label [All Activity] | Predicted Label [Weighted Proportion]
[1] | [1] | [1]
Assignments:
tensor([1, 1])
Proportions:
tensor([[0., 1.],
        [0., 1.]])
Rates:
tensor([[  0.,  39.],
        [  0., 400.]])
Accuracy: 1.0


...
```
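
The clamp index in the training loop reserves a contiguous block of `per_class` output neurons for each class, and `per_class * label + choice` picks one neuron inside that block. The stdlib-only sketch below shows the arithmetic with a hypothetical helper `clamped_neuron` and, as an assumption for illustration, a larger layer (10 neurons, 2 classes) than the one used above so that the blocks are visible.

```python
import random


def clamped_neuron(label, per_class, rng=random):
    """Index of the output neuron to force to spike for this label (sketch)."""
    choice = rng.randrange(per_class)  # position inside the label's block
    return per_class * label + choice


lif_neurons, n_classes = 10, 2
per_class = lif_neurons // n_classes  # 5 neurons reserved per class

# Label 0 always maps into neurons 0-4 and label 1 into neurons 5-9.
print(sorted({clamped_neuron(0, per_class) for _ in range(200)}))
print(sorted({clamped_neuron(1, per_class) for _ in range(200)}))
```

With the two-neuron layer used in this notebook, `per_class` is 1, so each class always clamps exactly one fixed neuron.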

### 4.6 Evaluate Performance

```python
# loop through each test example and record performance
```

## 5. Learning Rules

### 5.1 PostPre

A simple STDP rule involving both pre- and post-synaptic spiking activity. By default, the pre-synaptic update is negative and the post-synaptic update is positive.

| Parameter    | Type                                      | Description                                                                            | Default Value |
|--------------|-------------------------------------------|----------------------------------------------------------------------------------------|---------------|
| connection   | `AbstractConnection`                      | An `AbstractConnection` object whose weights the `PostPre` learning rule will modify. |               |
| nu           | `Optional[Union[float, Sequence[float]]]` | Single learning rate, or pair of learning rates for pre- and post-synaptic events.     | None          |
| reduction    | `Optional[callable]`                      | Method for reducing parameter updates along the batch dimension.                       | None          |
| weight_decay | `float`                                   | Constant multiple to decay weights by on each iteration.                               | 0.0           |
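
The PostPre update can be written out per synapse. The plain-Python sketch below assumes the common form of this rule: when the pre-synaptic neuron spikes, the weight is depressed by `nu[0]` times the post-synaptic trace, and when the post-synaptic neuron spikes, the weight is potentiated by `nu[1]` times the pre-synaptic trace. The function name and list-based representation are illustrative; weight clipping, batching, and `weight_decay` are omitted.

```python
def post_pre_step(w, pre_spikes, pre_traces, post_spikes, post_traces, nu):
    """Return updated weights after one time step of a PostPre-style STDP rule.

    w[i][j] is the weight from pre-synaptic neuron i to post-synaptic neuron j.
    Spikes are 0/1 values and traces are non-negative floats.
    """
    nu_pre, nu_post = nu
    new_w = [row[:] for row in w]
    for i in range(len(w)):
        for j in range(len(w[i])):
            # pre-synaptic spike: depression proportional to the post-synaptic trace
            new_w[i][j] -= nu_pre * pre_spikes[i] * post_traces[j]
            # post-synaptic spike: potentiation proportional to the pre-synaptic trace
            new_w[i][j] += nu_post * pre_traces[i] * post_spikes[j]
    return new_w


w = [[1.0, 1.0]]
# The pre-synaptic neuron spikes now (trace 1.0); post-synaptic neuron 0 fired
# earlier (trace 0.5) and post-synaptic neuron 1 fires on this step (trace 1.0).
print(post_pre_step(w, pre_spikes=[1], pre_traces=[1.0],
                    post_spikes=[0, 1], post_traces=[0.5, 1.0], nu=(0.25, 0.5)))
# [[0.875, 1.25]]
```

The synapse onto neuron 0 only sees depression (post before pre), while the synapse onto neuron 1 sees both terms and ends up potentiated because `nu[1]` is larger than `nu[0]`.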

## 6. Custom Learning Rules

Custom learning rules can be implemented by subclassing `bindsnet.learning.LearningRule` and providing implementations for the types of `AbstractConnection` objects intended to be used. 

For example, the `Connection` and `LocalConnection` objects rely on the implementation of a private method, `_connection_update`, whereas the `Conv2dConnection` object uses the `_conv2d_connection_update` version.
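
The dispatch pattern described above can be illustrated without BindsNET. In the plain-Python sketch below, `DenseConnection`, `Conv2dLikeConnection`, `CustomRule`, and its toy weight-shrinking behavior are all hypothetical stand-ins. The sketch binds `update` to a connection-specific private method once, at construction time, and rejects unsupported connection types; as far as we understand it, BindsNET's built-in rules choose between `_connection_update` and `_conv2d_connection_update` in a similar way, but a real rule would subclass `bindsnet.learning.LearningRule` and operate on the connection's tensors.

```python
class DenseConnection:
    """Hypothetical stand-in for a fully connected connection."""

    def __init__(self, w):
        self.w = w  # list of rows of weights


class Conv2dLikeConnection:
    """Hypothetical stand-in for a convolutional connection."""

    def __init__(self, kernels):
        self.w = kernels  # list of 2-D kernels


class CustomRule:
    """Hypothetical learning rule that dispatches on the connection type."""

    def __init__(self, connection, nu=0.1):
        self.connection = connection
        self.nu = nu
        # choose the update implementation once, based on the connection type
        if isinstance(connection, DenseConnection):
            self.update = self._connection_update
        elif isinstance(connection, Conv2dLikeConnection):
            self.update = self._conv2d_connection_update
        else:
            raise NotImplementedError(
                f"CustomRule does not support {type(connection).__name__}"
            )

    def _connection_update(self):
        # toy rule: shrink every dense weight by a factor of (1 - nu)
        self.connection.w = [[x * (1 - self.nu) for x in row] for row in self.connection.w]

    def _conv2d_connection_update(self):
        # same toy rule, but kernels are nested one level deeper
        self.connection.w = [
            [[x * (1 - self.nu) for x in row] for row in kernel]
            for kernel in self.connection.w
        ]


rule = CustomRule(DenseConnection([[1.0, 2.0]]), nu=0.5)
rule.update()
print(rule.connection.w)  # [[0.5, 1.0]]
```

Resolving the method once keeps the per-step `update()` call free of type checks and makes an unsupported connection fail immediately, when the rule is created.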