# BindsNET Learning Techniques

## 1. Table of Contents
1. Table of Contents
2. Overview
3. Import Statements
4. Learning Flow
    1. Define Simulation Parameters
    2. Create Input Data
    3. Configure Network Architecture
    4. Define Simulation Variables
    5. Perform Learning Iterations
    6. Evaluate Classification Performance
5. Learning Rules
    1. PostPre
    2. WeightDependentPostPre
    3. Hebbian
    4. MSTDP
    5. MSTDPET
    6. Rmax
6. Custom Learning Rules


## 2. Overview

Detailed documentation on using learning rules is available [here](https://bindsnet-docs.readthedocs.io/guide/guide_part_ii.html). This document goes into more specific examples of configuring a spiking neural network in BindsNET.

The learning rule class is passed to a `Connection` object via the `update_rule` argument. The connection then creates and holds the learning rule object. The following keyword arguments, also passed to `Connection`, configure the learning rule:

* `nu`: a 2-tuple of pre- and post-synaptic learning rates (how quickly synapse weights change)
* `reduction`: specifies how parameter updates are aggregated across the batch dimension
* `weight_decay`: specifies the time constant of the rate of decay of synapse weights to zero

By default, parameter updates are averaged across the batch dimension, and there is no weight decay.


```python
from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection
from bindsnet.learning import PostPre

# Create two populations of neurons, one to act as the "source"
# population, and the other as the "target" population.
# Neurons involved in certain learning rules must record synaptic
# traces, a vector of short-term memories of the last emitted spikes.
source_layer = Input(n=100, traces=True)
target_layer = LIFNodes(n=1000, traces=True)

# Connect the two layers. PostPre is passed as a class, and nu gives the
# pre- and post-synaptic learning rates.
connection = Connection(
    source=source_layer, target=target_layer, update_rule=PostPre, nu=(1e-4, 1e-2)
)
```

## 3. Import Statements

```python
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import random

from bindsnet.encoding import *
from bindsnet.network import Network
from bindsnet.network.monitors import Monitor
from bindsnet.network.monitors import NetworkMonitor

from bindsnet.analysis.plotting import plot_spikes, plot_voltages, plot_input, plot_weights

from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection
from bindsnet.learning import PostPre, Hebbian, WeightDependentPostPre, MSTDP, MSTDPET

from bindsnet.evaluation import all_activity, proportion_weighting, assign_labels
from bindsnet.utils import get_square_weights, get_square_assignments
```

## 4. Learning Flow

1. Define Simulation Parameters
2. Create Input Data
3. Configure Network Architecture
4. Define Simulation Variables
5. Perform Learning Iterations
6. Evaluate Classification Performance

### 4.1 Simulation Parameters

```python
### Input Data Parameters ###

# number of training and testing samples
training_samples = 1
testing_samples = 10

# set number of classes
n_classes = 2

### Network Configuration Parameters ###

# configure the input layer (one neuron per pixel of a 3x3 image)
input_layer_name = "Input Layer"
input_neurons = 9

# configure the number of output LIF neurons
lif_layer_name = "LIF Layer"
lif_neurons = 2

### Simulation Parameters ###

# simulation time and timestep
time = 10
dt = 1

# number of training iterations
epochs = 1

# number of output neurons per class
per_class = lif_neurons // n_classes
```

### 4.2 Input Configuration

```python
# store the unique images in a list
imgs = []

# Class 0 Image: a ring of ones around a zero
img0 = {"Label": 0, "Image": torch.FloatTensor([[1, 1, 1], [1, 0, 1], [1, 1, 1]])}
imgs.append(img0)

# Class 1 Image: a vertical bar
img1 = {"Label": 1, "Image": torch.FloatTensor([[0, 1, 0], [0, 1, 0], [0, 1, 0]])}
imgs.append(img1)

# initialize list of inputs for training
training_dataset = []

# for the number of specified training samples
for i in range(training_samples):

    # randomly select a training sample
    # rand_sample = random.randint(0, n_classes - 1)

    # alternatively, cycle through the classes so that each class is represented evenly
    rand_sample = i % n_classes

    # add the sample to the list of training samples
    training_dataset.append(imgs[rand_sample])

# initialize the encoder
encoder = BernoulliEncoder(time=time, dt=dt)

# list of encoded training images
encoded_train_inputs = []

# encode each training image and store it together with its label
for sample in training_dataset:

    # encode the flattened image into a spike train of shape (time, input_neurons)
    encoded_img = encoder(torch.flatten(sample["Image"]))

    # map the input layer name to the encoded image, as expected by network.run()
    encoded_img_input = {input_layer_name: encoded_img}

    # encoded image label
    encoded_img_label = sample["Label"]

    # add to the list of encoded training inputs
    encoded_train_inputs.append({"Label": encoded_img_label, "Inputs": encoded_img_input})

# initialize list of inputs for testing
testing_dataset = []

# for the number of specified testing samples
for i in range(testing_samples):

    # randomly select a testing sample
    rand_sample = random.randint(0, n_classes - 1)

    # add the sample to the list of testing samples
    testing_dataset.append(imgs[rand_sample])

# list of encoded testing images
encoded_test_inputs = []

# encode each testing image and store it together with its label
for sample in testing_dataset:

    # encode the flattened image into a spike train of shape (time, input_neurons)
    encoded_img = encoder(torch.flatten(sample["Image"]))

    # map the input layer name to the encoded image, as expected by network.run()
    encoded_img_input = {input_layer_name: encoded_img}

    # encoded image label
    encoded_img_label = sample["Label"]

    # add to the list of encoded testing inputs
    encoded_test_inputs.append({"Label": encoded_img_label, "Inputs": encoded_img_input})
```
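As a rough illustration of what `BernoulliEncoder` produces, the sketch below reimplements Bernoulli spike encoding in plain Python. Each input's intensity is treated as its spike probability at every timestep, and intensities are rescaled only if the maximum exceeds 1. The function `bernoulli_encode` and its `max_prob` and `rng` arguments are hypothetical: they are modeled on the encoder's behavior, not taken from BindsNET's code. One consequence is worth noting. For the binary images used here, the encoding is deterministic: pixels equal to 1 spike at every timestep, and pixels equal to 0 never spike.

```python
import random


def bernoulli_encode(pixels, time, dt=1.0, max_prob=1.0, rng=None):
    """Hypothetical sketch of Bernoulli spike encoding.

    pixels: flat list of non-negative intensities.
    Returns int(time / dt) rows, each a list of 0/1 spikes (one per input).
    """
    rng = rng or random.Random()
    peak = max(pixels)
    # rescale so the largest intensity maps to probability 1 (only when it exceeds 1)
    probs = [p / peak for p in pixels] if peak > 1.0 else list(pixels)
    steps = int(time / dt)
    # at every timestep, each input spikes independently with probability max_prob * p
    return [[1 if rng.random() < max_prob * p else 0 for p in probs] for _ in range(steps)]


# the flattened class 0 image: a ring of ones around a zero
ring = [1, 1, 1, 1, 0, 1, 1, 1, 1]
spikes = bernoulli_encode(ring, time=10, rng=random.Random(0))
print(len(spikes), [sum(col) for col in zip(*spikes)])  # 10 [10, 10, 10, 10, 0, 10, 10, 10, 10]
```

Non-binary (grey-level) inputs are where the encoding becomes genuinely stochastic. A pixel of 0.5 would spike on roughly half of the timesteps.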

### 4.3 Network Configuration

When creating a connection between two layers, specify the learning (update) rule with `update_rule` and the learning rates with `nu`.

```python
# initialize network
network = Network()

# initial weights for the synapses between the input layer and the LIF layer
# all-zero weights mean the network starts with no learned associations;
# a random alternative is w = torch.round(torch.abs(2 * torch.randn(input_neurons, lif_neurons)))
w = torch.zeros(input_neurons, lif_neurons)

# initialize the input and LIF layers
# spike traces must be recorded because PostPre computes its weight
# updates from the pre- and post-synaptic spike traces

# initialize input layer
input_layer = Input(n=input_neurons, traces=True)

# initialize LIF layer
lif_layer = LIFNodes(n=lif_neurons, traces=True)

# initialize connection between the input layer and the LIF layer
# specify the learning (update) rule and learning rates (nu);
# nu=(1, 1) is large, so weights change quickly, while values such as
# nu=(1e-4, 1e-2) change them more gradually
connection = Connection(
    source=input_layer, target=lif_layer, w=w, update_rule=PostPre, nu=(1, 1)
)

# add input layer to the network
network.add_layer(
    layer=input_layer, name=input_layer_name
)

# add LIF neuron layer to the network
network.add_layer(
    layer=lif_layer, name=lif_layer_name
)

# add connection to network
network.add_connection(
    connection=connection, source=input_layer_name, target=lif_layer_name
)
```

### 4.4 Simulation Variables

```python
# buffer for the output layer's spikes: shape (n_samples, time steps, n_neurons)
spike_record = torch.zeros(1, int(time / dt), lif_neurons)

# the label assigned to each output neuron (-1 means unassigned)
assignments = -torch.ones(lif_neurons)

# how frequently each neuron fires for each input class
rates = torch.zeros(lif_neurons, n_classes)

# the share of each neuron's firing attributed to each input class
proportions = torch.zeros(lif_neurons, n_classes)

# label(s) of the input(s) being processed
labels = torch.empty(1, dtype=torch.int)

# create a monitor for each layer in the network
# this allows us to read the spikes in order to assign labels to neurons and determine the predicted class
layer_monitors = {}
for layer in network.layers:

    # record spikes ("s") at every layer, and voltages ("v") everywhere except the input layer
    state_vars = ["s", "v"] if (layer != input_layer_name) else ["s"]
    layer_monitors[layer] = Monitor(network.layers[layer], state_vars=state_vars, time=int(time / dt))

    # connect the monitor to the network
    network.add_monitor(layer_monitors[layer], name="%s_spikes" % layer)
```

### 4.5 Training

The functions below, imported from `bindsnet.evaluation`, assign labels to neurons and classify inputs based on the spiking activity of an SNN in BindsNET.


---


`all_activity()`

Classify data with the label with highest average spiking activity over all neurons.

Returns a predictions tensor of shape `(n_samples,)` resulting from the "all activity" classification scheme (`torch.Tensor`)

| Parameter  | Type         | Description                                                                           | Default Value |
|-------------|--------------|---------------------------------------------------------------------------------------|---------|
| spikes      | `torch.Tensor` | Binary tensor of shape `(n_samples, time, n_neurons)` of a layer's spiking activity. |         |
| assignments | `torch.Tensor` | A vector of shape `(n_neurons,)` of neuron label assignments.                       |         |
| n_labels    | `int`          | The number of target labels in the data.                                              |         |


----


`proportion_weighting()`

Classify data with the label with highest average spiking activity over all neurons, weighted by class-wise proportion.

Returns a predictions tensor of shape `(n_samples,)` resulting from the "proportion weighting" classification scheme (`torch.Tensor`)

| Parameter   | Type         | Description                                                                                              | Default Value |
|-------------|--------------|----------------------------------------------------------------------------------------------------------|---------------|
| spikes      | `torch.Tensor` | Binary tensor of shape `(n_samples, time, n_neurons)` of a single layer's spiking activity.            |               |
| assignments | `torch.Tensor` | A vector of shape `(n_neurons,)` of neuron label assignments.                                          |               |
| proportions | `torch.Tensor` | A matrix of shape `(n_neurons, n_labels)` giving the per-class proportions of neuron spiking activity. |               |
| n_labels    | `int`          | The number of target labels in the data.                                                                 |               |

----

`assign_labels()`

Assign labels to the neurons based on highest average spiking activity.

Returns a Tuple of class assignments, per-class spike proportions, and per-class firing rates (`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`)

| Parameter  | Type                     | Description                                                                                   | Default Value |
|------------|--------------------------|-----------------------------------------------------------------------------------------------|---------------|
| spikes     | `torch.Tensor`             | Binary tensor of shape `(n_samples, time, n_neurons)` of a single layer's spiking activity. |                | 
| labels     | `torch.Tensor`             | Vector of shape `(n_samples,)` with data labels corresponding to spiking activity.          |                | 
| n_labels   | `int`                      | The number of target labels in the data.                                                      |                | 
| rates      | `Optional[torch.Tensor]` | If passed, these represent spike rates from a previous `assign_labels()` call.              | None          | 
| alpha      | `float`                    | Rate of decay of label assignments.                                                           | 1             | 
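A plain-Python sketch can show how the three functions fit together, using nested lists instead of tensors. `assign_labels` turns spike counts into per-class rates, proportions, and assignments. The two classifiers then average spike counts over each class's neurons, with `proportion_weighting` also weighting each count by the neuron's proportion for that class. This is an illustration written from the descriptions above, not BindsNET's implementation. In particular, ties are broken toward the lowest label here, which may differ from the library.

```python
def assign_labels(spikes, labels, n_labels, rates=None, alpha=1.0):
    """spikes: nested list [n_samples][time][n_neurons] of 0/1; labels: list of ints."""
    n_neurons = len(spikes[0][0])
    if rates is None:
        rates = [[0.0] * n_labels for _ in range(n_neurons)]
    # total spike count per sample and neuron (spike timing is ignored)
    counts = [[sum(col) for col in zip(*sample)] for sample in spikes]
    for c in range(n_labels):
        members = [s for s, label in enumerate(labels) if label == c]
        if members:
            for n in range(n_neurons):
                # previous rate scaled by alpha, plus the average count over samples of class c
                rates[n][c] = alpha * rates[n][c] + sum(counts[s][n] for s in members) / len(members)
    # share of each neuron's activity attributed to each class (0 for silent neurons)
    proportions = [[r / sum(row) if sum(row) else 0.0 for r in row] for row in rates]
    # each neuron is assigned the class with its largest proportion
    assignments = [max(range(n_labels), key=row.__getitem__) for row in proportions]
    return assignments, proportions, rates


def _classify(spikes, assignments, n_labels, proportions=None):
    predictions = []
    for sample in spikes:
        counts = [sum(col) for col in zip(*sample)]  # spike count per neuron
        scores = []
        for c in range(n_labels):
            neurons = [n for n, a in enumerate(assignments) if a == c]
            # optional per-neuron weight: the neuron's proportion for class c
            total = sum(counts[n] * (proportions[n][c] if proportions is not None else 1.0)
                        for n in neurons)
            scores.append(total / len(neurons) if neurons else 0.0)
        predictions.append(max(range(n_labels), key=scores.__getitem__))
    return predictions


def all_activity(spikes, assignments, n_labels):
    return _classify(spikes, assignments, n_labels)


def proportion_weighting(spikes, assignments, proportions, n_labels):
    return _classify(spikes, assignments, n_labels, proportions)


# two samples, 3 timesteps, 2 output neurons
spikes = [
    [[1, 0], [1, 0], [0, 0]],  # class 0 sample: neuron 0 fires twice
    [[0, 1], [0, 1], [0, 1]],  # class 1 sample: neuron 1 fires three times
]
assignments, proportions, rates = assign_labels(spikes, [0, 1], n_labels=2)
print(assignments, rates)  # [0, 1] [[2.0, 0.0], [0.0, 3.0]]
print(all_activity(spikes, assignments, 2))  # [0, 1]
print(proportion_weighting(spikes, assignments, proportions, 2))  # [0, 1]
```

Passing the previous `rates` back in, as the training loop does, lets the assignments accumulate evidence across samples. With `alpha` below 1, older evidence fades.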


```python
weight_history = None
num_correct = 0.0

### DEBUG ###
### can be used to force the network to learn the inputs in a specific way
supervised = True
### used to determine if status messages are printed out at each sample
log_messages = True
### used to show weight changes
### (named show_weights so it does not shadow the imported plot_weights function)
show_weights = True
###############

# iterate for epochs
for step in range(epochs):
    for sample in encoded_train_inputs:

        # get the label for the current image
        labels[0] = sample["Label"]

        # randomly decide which output neuron should spike if more than one neuron corresponds to the class
        # choice will always be 0 if there is one neuron per output class
        choice = np.random.choice(per_class, size=1, replace=False)

        # clamp on the output layer forces the node corresponding to the label's class to spike
        # this is necessary in order for the network to learn which neurons correspond to which classes
        # clamp: mapping of layer names to the neurons that should be clamped to spiking, given as
        # boolean masks of shape [n_neurons] or [time, n_neurons]; here an index tensor is used
        clamp = {lif_layer_name: per_class * labels[0] + torch.Tensor(choice).long()} if supervised else {}

        ### Step 1: Run the network with the provided inputs ###
        network.run(inputs=sample["Inputs"], time=time, clamp=clamp)

        ### Step 2: Get the spikes produced at the output layer ###
        spike_record[0] = layer_monitors[lif_layer_name].get("s").view(int(time / dt), lif_neurons)

        ### Step 3: Assign labels to the neurons based on highest average spiking activity ###
        # Returns a tuple of class assignments, per-class spike proportions, and per-class firing rates
        # Return Type: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        assignments, proportions, rates = assign_labels(spike_record, labels, n_classes, rates)

        ### Step 4: Classify with the label whose neurons have the highest average spiking activity ###
        all_activity_pred = all_activity(spike_record, assignments, n_classes)

        ### Step 5: Same as Step 4, but weighted by class-wise proportion ###
        proportion_pred = proportion_weighting(spike_record, assignments, proportions, n_classes)

        ### Update Accuracy
        num_correct += 1 if (labels.numpy()[0] == all_activity_pred.numpy()[0]) else 0

        # reset voltages, traces, and monitor recordings before the next sample
        network.reset_state_variables()

        ######## Display Information ########
        if log_messages:
            print("Actual Label:", labels.numpy(), "|", "Predicted Label:", all_activity_pred.numpy(),
                  "|", "Proportionally Predicted Label:", proportion_pred.numpy())

            print("Neuron Label Assignments:")
            for idx in range(assignments.numel()):
                print("\t Output Neuron[", idx, "]:", assignments[idx],
                      "Proportions:", proportions[idx], "Rates:", rates[idx])
            print("\n")
        #####################################

    ### For Weight Plotting ###
    if show_weights:
        conn_w = network.connections[(input_layer_name, lif_layer_name)].w
        # weights into output neuron 0, stored as one row per epoch
        weights = conn_w[:, 0].detach().cpu().numpy().reshape((1, input_neurons))
        weight_history = weights.copy() if step == 0 else np.concatenate((weight_history, weights), axis=0)
        print("Neuron 0 Weights:\n", conn_w[:, 0])
        print("Neuron 1 Weights:\n", conn_w[:, 1])
        print("====================")
    #############################

    if log_messages:
        print("Epoch #", step, "\tAccuracy:", num_correct / ((step + 1) * len(encoded_train_inputs)))
        print("===========================\n\n")

### For Weight Plotting ###
# plot how each input weight of output neuron 0 changes across epochs
if show_weights:
    for idx in range(weight_history.shape[1]):
        plt.plot(weight_history[:, idx])
    plt.show()
#############################

### Print Final Class Assignments and Proportions ###
print("Neuron Label Assignments:")
for idx in range(assignments.numel()):
    print("\t Output Neuron[", idx, "]:", assignments[idx],
          "Proportions:", proportions[idx], "Rates:", rates[idx])
```

### 4.6 Evaluate Performance

```python
num_correct = 0

log_messages = True

# disable training mode so that weights are not updated during testing
network.train(False)

# loop through each test example and record performance
for sample in encoded_test_inputs:

    # get the label for the current image
    labels[0] = sample["Label"]

    ### Step 1: Run the network with the provided inputs (no clamping during testing) ###
    network.run(inputs=sample["Inputs"], time=time)

    ### Step 2: Get the spikes produced at the output layer ###
    spike_record[0] = layer_monitors[lif_layer_name].get("s").view(int(time / dt), lif_neurons)

    ### Step 3: Classify using the assignments and proportions learned during training ###
    # assign_labels is not called here: it would use the test labels to re-assign neurons
    all_activity_pred = all_activity(spike_record, assignments, n_classes)
    proportion_pred = proportion_weighting(spike_record, assignments, proportions, n_classes)

    ### Update Accuracy
    num_correct += 1 if (labels.numpy()[0] == all_activity_pred.numpy()[0]) else 0

    # reset voltages, traces, and monitor recordings before the next sample
    network.reset_state_variables()

    ######## Display Information ########
    if log_messages:
        print("Actual Label:", labels.numpy(), "|", "Predicted Label:", all_activity_pred.numpy(),
              "|", "Proportionally Predicted Label:", proportion_pred.numpy())
    #####################################
print("Accuracy:", num_correct / len(encoded_test_inputs))
```

## 5. Learning Rules

### Introduction:

#### [Basic STDP Model:](http://www.scholarpedia.org/article/Spike-timing_dependent_plasticity)

The weight change $\Delta w_j$ of a synapse from a presynaptic neuron $j$ depends on the relative timing between presynaptic spike arrivals and postsynaptic spikes.

Presynaptic spike arrival times at synapse $j$ are denoted by $t^f_j$, where $f = 1, 2, 3, \dots$ counts the presynaptic spikes.

Postsynaptic firing times are denoted by $t^n_i$, where $n = 1, 2, 3, \dots$ counts the postsynaptic spikes.

The total weight change $\Delta w_j$ induced by a stimulation protocol with pairs of pre- and postsynaptic spikes is then:

$$
\Delta w_j = \sum_{f=1}^{N} \sum_{n=1}^{N} W(t_i^n - t_j^f)
$$

where **$W(x)$** denotes one of the STDP functions (also called learning window).

A popular choice for the STDP function **$W(x)$** is:

$$
W(x) = \begin{cases} A_+ e^{-x/\tau_+} & \text{for } x > 0 \\ -A_- e^{x/\tau_-} & \text{for } x < 0 \end{cases}
$$

The parameters $A_+$ and $A_-$ may depend on the current value of the synaptic weight $w_j$. The time constants are on the order of $\tau_+ = 10$ ms and $\tau_- = 10$ ms.

In summary:

If the postsynaptic spike follows the presynaptic spike ($x = t^n_i - t^f_j > 0$), the weight change $\Delta w_j$ is positive. It shrinks toward zero the longer $t^n_i$ (postsynaptic firing time) lags behind $t^f_j$ (presynaptic spike time). This is referred to as long-term potentiation (LTP).

If the presynaptic spike follows the postsynaptic spike ($x < 0$), the weight change $\Delta w_j$ is negative. Its magnitude shrinks toward zero the longer $t^f_j$ (presynaptic spike time) lags behind $t^n_i$ (postsynaptic firing time). This is referred to as long-term depression (LTD).

Source: http://www.scholarpedia.org/article/Spike-timing_dependent_plasticity
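The window $W(x)$ and the pairwise sum for $\Delta w_j$ can be evaluated directly. In this sketch the amplitudes default to 1 and the time constants to 10 ms, both illustrative values. $W(0)$ is treated as 0 because the two cases of the window leave it undefined, which is an assumption of the sketch.

```python
import math


def stdp_window(x, a_plus=1.0, a_minus=1.0, tau_plus=10.0, tau_minus=10.0):
    """Exponential STDP window W(x), with x = t_post - t_pre in ms."""
    if x > 0:
        return a_plus * math.exp(-x / tau_plus)    # pre before post: potentiation (LTP)
    if x < 0:
        return -a_minus * math.exp(x / tau_minus)  # post before pre: depression (LTD)
    return 0.0  # W(0) is not covered by the two cases; treated as 0 here


def total_weight_change(pre_times, post_times, **window_kwargs):
    """Delta w_j: sum of W over every pair of post- and presynaptic spike times."""
    return sum(stdp_window(t_post - t_pre, **window_kwargs)
               for t_pre in pre_times for t_post in post_times)


print(round(stdp_window(10.0), 4))   # 0.3679 (= e^-1)
print(round(stdp_window(-10.0), 4))  # -0.3679
print(round(total_weight_change([0.0], [5.0, 20.0]), 4))  # 0.7419 (= e^-0.5 + e^-2)
```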

### 5a. PostPre

#### Summary

A simple STDP rule involving both pre- and post-synaptic spiking activity. By default, the pre-synaptic update is negative and the post-synaptic update is positive.

The rule follows the equation below for each timestep $t$:

$$
\Delta w (t) = \eta_1 e^{-\frac{t - t_{pre}}{\tau}} S_{post}(t) - \eta_0 e^{-\frac{t - t_{post}}{\tau}} S_{pre}(t)
$$

Here, $S_{pre}(t)$ and $S_{post}(t)$ indicate whether the pre-synaptic or post-synaptic neuron, respectively, spiked at time $t$.

Additionally $t_{pre}$ is the timestamp when the pre-synaptic neuron last fired, and $t_{post}$ is the timestamp when the post-synaptic neuron last fired.

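Each exponential term is a spike trace kept by the corresponding layer. With default settings, a neuron's trace is set to 1 when the neuron spikes and is multiplied by a decay factor at every timestep, so it decays exponentially with the time since that neuron's last spike. This decay factor, `trace_decay`, is computed from two quantities: the trace time constant $\tau$, which is the `tc_trace` argument when creating a `Nodes` layer, and the simulation timestep $\Delta t$ (`dt`):

$$
trace\_decay = e^{-\frac{\Delta t}{\tau}}
$$

For `dt = 1`, this reduces to $e^{-\frac{1}{\tau}}$.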

The value $\Delta w$ is calculated and applied at each synapse for every timestep.
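For a single synapse, the PostPre update can be sketched with two traces instead of explicit last-spike times. Each trace is set to 1 when its neuron spikes and multiplied by `trace_decay` otherwise, so after a spike at $t_{last}$ it equals $e^{-(t - t_{last})/\tau}$. The function `simulate_postpre`, its default `nu` and `tau`, and the single-synapse setting are illustrative assumptions. BindsNET applies the same idea to whole weight matrices.

```python
import math


def simulate_postpre(pre_spikes, post_spikes, nu=(1e-4, 1e-2), tau=20.0, dt=1.0, w=0.0):
    """Single-synapse sketch of the PostPre equation using spike traces.

    pre_spikes / post_spikes: equal-length lists of 0/1 per timestep.
    Returns the weight after the last timestep.
    """
    decay = math.exp(-dt / tau)  # trace_decay
    x_pre = x_post = 0.0
    for s_pre, s_post in zip(pre_spikes, post_spikes):
        # a trace is reset to 1 on its neuron's spike, otherwise it decays
        x_pre = 1.0 if s_pre else x_pre * decay
        x_post = 1.0 if s_post else x_post * decay
        # post-synaptic spike: potentiate by nu[1] * pre trace
        # pre-synaptic spike: depress by nu[0] * post trace
        w += nu[1] * x_pre * s_post - nu[0] * x_post * s_pre
    return w


pre_first = simulate_postpre([1, 0, 0, 0], [0, 0, 0, 1])   # pre at t=0, post at t=3
post_first = simulate_postpre([0, 0, 0, 1], [1, 0, 0, 0])  # post at t=0, pre at t=3
print(pre_first > 0, post_first < 0)  # True True
```

Because the default `nu` puts a larger rate on the post-synaptic term, a pre-before-post pairing changes the weight 100 times more than the mirrored pairing.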

#### Table of Parameters

| Parameters   | Type                                    | Description                                                                               | Default Value |
|--------------|-----------------------------------------|-------------------------------------------------------------------------------------------|---------------|
| connection   | AbstractConnection                      | An `AbstractConnection` object whose weights the `PostPre` learning rule will modify. |               |
| nu           | Optional\[Union\[float, Sequence\[float]]] | Single or pair of learning rates for pre- and post-synaptic events.                       | None          |
| reduction    | Optional\[callable]                      | Method for reducing parameter updates along the batch                                     | None          |
| weight_decay | float                                   | Constant multiple to decay weights by on each iteration.                                  | 0.0           |

### 5b. WeightDependentPostPre

#### Summary

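STDP rule involving both pre- and post-synaptic spiking activity. The post-synaptic update is positive and the pre-synaptic update is negative. Both depend on the synaptic weight relative to the connection's bounds $w_{min}$ and $w_{max}$, given by its `wmin` and `wmax` arguments, which must be finite for this rule.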

The rule follows the equation below for each timestep $t$:

$$
\Delta w (t) = \eta_1 e^{-\frac{t - t_{pre}}{\tau}} S_{post}(t) (w_{max} - w) - \eta_0 e^{-\frac{t - t_{post}}{\tau}} S_{pre}(t) (w - w_{min})
$$
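The $(w_{max} - w)$ and $(w - w_{min})$ factors act as soft bounds: each update shrinks as the weight approaches the corresponding bound. The hypothetical helper below applies one timestep of the equation to a single synapse. The values `nu = (0.1, 0.1)`, `wmin = 0`, and `wmax = 1` are illustrative. As long as the learning rate times the trace is at most 1, repeated potentiation moves $w$ toward $w_{max}$ without exceeding it.

```python
def weight_dependent_step(w, x_pre, x_post, s_pre, s_post, nu=(0.1, 0.1), wmin=0.0, wmax=1.0):
    """One timestep of the weight-dependent PostPre equation for a single synapse.

    x_pre / x_post: pre- and post-synaptic traces; s_pre / s_post: 0/1 spikes at this step.
    """
    potentiation = nu[1] * x_pre * s_post * (wmax - w)  # shrinks as w approaches wmax
    depression = nu[0] * x_post * s_pre * (w - wmin)    # shrinks as w approaches wmin
    return w + potentiation - depression


# 50 post-synaptic spikes, each seeing a pre-synaptic trace of 1.0, starting from wmin
w = 0.0
for _ in range(50):
    w = weight_dependent_step(w, x_pre=1.0, x_post=0.0, s_pre=0, s_post=1)
print(w < 1.0, w > 0.99)  # True True
```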

#### Table of Parameters

| Parameters   | Type                                    | Description                                                                               | Default Value |
|--------------|-----------------------------------------|-------------------------------------------------------------------------------------------|---------------|
| connection   | AbstractConnection                      | An `AbstractConnection` object whose weights the `WeightDependentPostPre` learning rule will modify. |               |
| nu           | Optional\[Union\[float, Sequence\[float]]] | Single or pair of learning rates for pre- and post-synaptic events.                       | None          |
| reduction    | Optional\[callable]                      | Method for reducing parameter updates along the batch                                     | None          |
| weight_decay | float                                   | Constant multiple to decay weights by on each iteration.                                  | 0.0           |

### 5c. Hebbian

#### Summary

Simple Hebbian learning rule. Pre- and post-synaptic updates are both positive.

The rule follows the equation below for each timestep $t$:

$$
\Delta w (t) = \eta_1 e^{-\frac{t - t_{pre}}{\tau}} S_{post}(t) + \eta_0 e^{-\frac{t - t_{post}}{\tau}} S_{pre}(t)
$$
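The only change from PostPre is the sign of the pre-synaptic term. As a result, the order of pre- and post-synaptic spikes no longer decides whether the synapse strengthens. The following single-synapse sketch uses a hypothetical function name and illustrative parameter values:

```python
import math


def hebbian_change(pre_spikes, post_spikes, nu=(1e-2, 1e-2), tau=20.0, dt=1.0):
    """Single-synapse sketch of the Hebbian equation: both terms are positive."""
    decay = math.exp(-dt / tau)
    x_pre = x_post = dw = 0.0
    for s_pre, s_post in zip(pre_spikes, post_spikes):
        # a trace is reset to 1 on its neuron's spike, otherwise it decays
        x_pre = 1.0 if s_pre else x_pre * decay
        x_post = 1.0 if s_post else x_post * decay
        # unlike PostPre, a pre-synaptic spike adds nu[0] * post trace instead of subtracting it
        dw += nu[1] * x_pre * s_post + nu[0] * x_post * s_pre
    return dw


# both spike orders strengthen the synapse
print(hebbian_change([1, 0, 0], [0, 0, 1]) > 0, hebbian_change([0, 0, 1], [1, 0, 0]) > 0)  # True True
```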


#### Table of Parameters

| Parameters   | Type                                    | Description                                                                               | Default Value |
|--------------|-----------------------------------------|-------------------------------------------------------------------------------------------|---------------|
| connection   | AbstractConnection                      | An `AbstractConnection` object whose weights the `Hebbian` learning rule will modify. |               |
| nu           | Optional\[Union\[float, Sequence\[float]]] | Single or pair of learning rates for pre- and post-synaptic events.                       | None          |
| reduction    | Optional\[callable]                      | Method for reducing parameter updates along the batch                                     | None          |
| weight_decay | float                                   | Constant multiple to decay weights by on each iteration.                                  | 0.0           |

### 5d. MSTDP
Reward-modulated STDP. Adapted from [Florian 2007](https://florian.io/papers/2007_Florian_Modulated_STDP.pdf).

### 5e. MSTDPET
Reward-modulated STDP with eligibility trace. Adapted from [Florian 2007](https://florian.io/papers/2007_Florian_Modulated_STDP.pdf).
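Both reward-modulated rules multiply an STDP-like eligibility by a reward signal. They differ in how long that eligibility is remembered. The plain-Python sketch below approximately follows the structure of Florian's discrete-time equations for a single synapse:

* traces $P^+$ and $P^-$ (with a negative amplitude $A_-$) produce the eligibility $\zeta$;
* MSTDP scales the next step's weight change by $r \cdot \zeta$;
* MSTDPET first low-pass filters $\zeta$ into an eligibility trace $z$.

The function name, parameter names, default values, and exact discretization are assumptions. They may differ from both the paper and BindsNET's `MSTDP` and `MSTDPET` classes.

```python
import math


def modulated_stdp(pre, post, reward, gamma=0.1, a_plus=1.0, a_minus=-1.0,
                   tau_plus=20.0, tau_minus=20.0, tau_z=25.0, dt=1.0, trace=False):
    """Single-synapse sketch of MSTDP (trace=False) and MSTDPET (trace=True).

    pre, post, reward: equal-length per-timestep lists (spikes as 0/1).
    Returns the total weight change.
    """
    p_plus = p_minus = zeta = z = dw = 0.0
    for s_pre, s_post, r in zip(pre, post, reward):
        # the weight change at this step uses the eligibility from the previous step
        dw += gamma * r * (z if trace else zeta)
        # P+ tracks pre-synaptic spikes; P- (negative amplitude) tracks post-synaptic spikes
        p_plus = p_plus * math.exp(-dt / tau_plus) + a_plus * s_pre
        p_minus = p_minus * math.exp(-dt / tau_minus) + a_minus * s_post
        # STDP-like eligibility: pre trace at a post spike, post trace at a pre spike
        zeta = p_plus * s_post + p_minus * s_pre
        # MSTDPET only: low-pass filter zeta into the eligibility trace z
        z = z * math.exp(-dt / tau_z) + zeta / tau_z
    return dw


pre = [1, 0, 0] + [0] * 8   # pre-synaptic spike at t = 0
post = [0, 0, 1] + [0] * 8  # post-synaptic spike at t = 2
late = [0] * 10 + [1]       # reward only at t = 10
print(modulated_stdp(pre, post, late), modulated_stdp(pre, post, late, trace=True) > 0)  # 0.0 True
```

In this sketch, a reward that arrives several steps after a pre-before-post pairing leaves the weight unchanged under MSTDP. Under MSTDPET, the same delayed reward still potentiates the synapse, because the eligibility trace bridges the delay.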

### 5f. Rmax
Reward-modulated learning rule derived from reward maximization principles. Adapted from [Vasilaki et al., 2009](https://intranet.physio.unibe.ch/Publikationen/Dokumente/Vasilaki2009PloSComputBio_1.pdf).

## 6. Custom Learning Rules

Custom learning rules can be implemented by subclassing `bindsnet.learning.LearningRule` and providing implementations for the types of `AbstractConnection` objects intended to be used. 

For example, the `Connection` and `LocalConnection` objects rely on the implementation of a private method, `_connection_update`, whereas the `Conv2dConnection` object uses the `_conv2d_connection_update` version.
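Built-in rules such as `PostPre` appear to choose the connection-specific method in their constructor by checking the connection type. The sketch below mimics that dispatch pattern without importing BindsNET, so it runs anywhere. `StubAbstractConnection`, `StubConnection`, `StubConv2dConnection`, `StubLearningRule`, and `PostTraceRule` are all hypothetical stand-ins. Passing traces and spikes as arguments is a further simplification. A real rule would subclass `bindsnet.learning.LearningRule`, check against `Connection`, `LocalConnection`, and `Conv2dConnection`, and update `self.connection.w` from the layers' spikes (`s`) and traces (`x`).

```python
class StubAbstractConnection:
    """Stand-in for a connection base class: weights w[i][j] from source i to target j."""

    def __init__(self, n_source, n_target):
        self.w = [[0.0] * n_target for _ in range(n_source)]


class StubConnection(StubAbstractConnection):
    """Stand-in for a dense connection."""


class StubConv2dConnection(StubAbstractConnection):
    """Stand-in for a convolutional connection."""


class StubLearningRule:
    """Stand-in for a learning-rule base class that stores the connection and rates."""

    def __init__(self, connection, nu=(0.0, 0.0)):
        self.connection = connection
        self.nu = nu


class PostTraceRule(StubLearningRule):
    """Hypothetical rule: on each post-synaptic spike, add nu[1] * pre-synaptic trace."""

    def __init__(self, connection, nu=(0.0, 0.01)):
        super().__init__(connection, nu)
        # choose the update method that matches the connection type
        if isinstance(connection, StubConnection):
            self.update = self._connection_update
        elif isinstance(connection, StubConv2dConnection):
            self.update = self._conv2d_connection_update
        else:
            raise NotImplementedError("This rule does not support this connection type.")

    def _connection_update(self, source_x, target_s):
        for i, x in enumerate(source_x):
            for j, s in enumerate(target_s):
                self.connection.w[i][j] += self.nu[1] * x * s

    def _conv2d_connection_update(self, source_x, target_s):
        raise NotImplementedError("A convolutional update is not implemented in this sketch.")


rule = PostTraceRule(StubConnection(n_source=2, n_target=2))
rule.update(source_x=[1.0, 0.5], target_s=[1, 0])
print(rule.connection.w)  # [[0.01, 0.0], [0.005, 0.0]]
```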