# BindsNET Learning Techniques

## 1. Table of Contents
1. Table of Contents
2. Overview
3. Import Statements
4. Learning Rules
    1. PostPre
    2. Hebbian
    3. WeightDependentPostPre
    4. MSTDP
    5. MSTDPET
5. Custom Learning Rules


## 2. Overview

Detailed documentation on using learning rules is available [here](https://bindsnet-docs.readthedocs.io/guide/guide_part_ii.html). This document goes into more specific examples of configuring learning in a spiking neural network in BindsNET.

A learning rule is passed to a `Connection` object via the `update_rule` argument. The following keyword arguments of the `Connection` configure the rule:

* `nu`: a 2-tuple of pre- and post-synaptic learning rates (how quickly synapse weights change)
* `reduction`: how parameter updates are aggregated across the batch dimension
* `weight_decay`: a constant multiple by which synapse weights decay toward zero on each iteration

By default, parameter updates are averaged across the batch dimension, and `weight_decay` is `0.0`, so weights do not decay.
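To make `reduction` and `weight_decay` more concrete, here is a minimal NumPy sketch (not BindsNET's code) of how a batch of per-sample weight updates could be collapsed into one update and how a multiplicative decay could then be applied. It assumes a mean reduction, matching the default described above, and assumes the decay is applied as `w *= 1 - weight_decay` once per iteration; `apply_update` is a hypothetical name.

```python
import numpy as np


def apply_update(w, batch_updates, reduction=np.mean, weight_decay=0.0):
    """Hypothetical helper: reduce per-sample updates over the batch, then decay.

    w:             weight matrix, shape (n_source, n_target)
    batch_updates: per-sample updates, shape (batch, n_source, n_target)
    reduction:     callable that collapses axis 0 (the batch dimension)
    weight_decay:  assumed fraction of each weight removed per iteration
    """
    w = w + reduction(batch_updates, axis=0)
    if weight_decay:
        # Assumption: decay is a constant multiplicative shrink toward zero.
        w = w * (1.0 - weight_decay)
    return w


w = np.ones((2, 3))
updates = np.stack([np.full((2, 3), 0.2), np.full((2, 3), 0.4)])  # batch of 2

print(apply_update(w, updates))                    # mean: every weight is about 1.3
print(apply_update(w, updates, reduction=np.sum))  # sum: every weight is about 1.6
print(apply_update(w, updates, weight_decay=0.1))  # about 1.3 * 0.9 = 1.17
```

Averaging keeps the size of an update independent of the batch size, whereas summing makes larger batches produce proportionally larger steps.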


```python
from bindsnet.learning import PostPre
from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection

# Create two populations of neurons, one to act as the "source"
# population, and the other, the "target" population.
# Neurons involved in certain learning rules must record synaptic
# traces, a vector of short-term memories of the last emitted spikes.
source_layer = Input(n=100, traces=True)
target_layer = LIFNodes(n=1000, traces=True)

# Connect the two layers; nu holds the pre- and post-synaptic learning rates.
connection = Connection(
    source=source_layer, target=target_layer, update_rule=PostPre, nu=(1e-4, 1e-2)
)
```
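The comment above describes synaptic traces as short-term memories of recent spikes. As a rough illustration (a NumPy sketch, not BindsNET's implementation), a trace can decay exponentially with time constant `tau` and jump whenever the neuron spikes. The sketch resets the trace to 1 on a spike by default; the additive variant instead adds 1 per spike. The values `tau=20.0` and `dt=1.0` and the function name `update_trace` are assumptions made for illustration.

```python
import numpy as np


def update_trace(x, spikes, dt=1.0, tau=20.0, additive=False):
    """Hypothetical single-step trace update.

    x:      current trace values, one per neuron
    spikes: boolean array, True where a neuron spiked this step
    """
    x = x * np.exp(-dt / tau)  # exponential decay toward zero
    if additive:
        x = x + spikes  # each spike adds 1 to the trace
    else:
        x = np.where(spikes, 1.0, x)  # each spike resets the trace to 1
    return x


# One neuron that spikes at step 0 and then stays silent.
x = np.zeros(1)
history = []
for t in range(5):
    x = update_trace(x, np.array([t == 0]))
    history.append(float(x[0]))

print(history)  # starts at 1.0, then shrinks by a factor of exp(-1/20) each step
```

Because the trace encodes how recently a neuron fired, a learning rule can look up a single number per neuron instead of storing full spike histories.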

## 3. Import Statements

```python
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

from bindsnet.encoding import *
from bindsnet.network import Network
from bindsnet.network.monitors import Monitor
from bindsnet.network.monitors import NetworkMonitor

from bindsnet.analysis.plotting import plot_spikes, plot_voltages, plot_input, plot_weights

from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection
from bindsnet.learning import PostPre, Hebbian, WeightDependentPostPre, MSTDP, MSTDPET
```

## 4. Learning Rules

### 4a. PostPre

A simple STDP rule driven by both pre- and post-synaptic spiking activity. By default, the pre-synaptic update is negative (depression) and the post-synaptic update is positive (potentiation).

| Parameters   | Type                                    | Description                                                                               | Default Value |
|--------------|-----------------------------------------|-------------------------------------------------------------------------------------------|---------------|
| connection   | AbstractConnection                      | An ``AbstractConnection`` object whose weights the ``PostPre`` learning rule will modify. | (required)    |
| nu           | Optional[Union[float, Sequence[float]]] | Single or pair of learning rates for pre- and post-synaptic events.                       | None          |
| reduction    | Optional[callable]                      | Method for reducing parameter updates along the batch dimension.                          | None          |
| weight_decay | float                                   | Constant multiple to decay weights by on each iteration.                                  | 0.0           |
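The sign convention described above can be illustrated with a NumPy sketch of one trace-based pre/post STDP step for a single sample. It assumes that each pre-synaptic spike depresses weights in proportion to the post-synaptic trace (scaled by `nu[0]`) and that each post-synaptic spike potentiates them in proportion to the pre-synaptic trace (scaled by `nu[1]`). This is a simplified stand-in, not BindsNET's `PostPre` code; `postpre_step` is a hypothetical name.

```python
import numpy as np


def postpre_step(w, pre_s, post_s, pre_x, post_x, nu=(1e-4, 1e-2)):
    """Hypothetical single-sample PostPre-style update.

    w:      weights, shape (n_pre, n_post)
    pre_s:  pre-synaptic spikes this step, shape (n_pre,)
    post_s: post-synaptic spikes this step, shape (n_post,)
    pre_x:  pre-synaptic traces, shape (n_pre,)
    post_x: post-synaptic traces, shape (n_post,)
    """
    # Pre-synaptic spike after recent post-synaptic activity: depression.
    w = w - nu[0] * np.outer(pre_s, post_x)
    # Post-synaptic spike after recent pre-synaptic activity: potentiation.
    w = w + nu[1] * np.outer(pre_x, post_s)
    return w


w = np.full((1, 1), 0.5)

# The pre-synaptic neuron fired recently (trace 0.8); the post-synaptic neuron fires now.
print(postpre_step(w, np.array([0.0]), np.array([1.0]), np.array([0.8]), np.array([0.0])))
# about 0.5 + 1e-2 * 0.8 = 0.508

# The post-synaptic neuron fired recently (trace 0.8); the pre-synaptic neuron fires now.
print(postpre_step(w, np.array([1.0]), np.array([0.0]), np.array([0.0]), np.array([0.8])))
# about 0.5 - 1e-4 * 0.8 = 0.49992
```

With `nu=(1e-4, 1e-2)`, as in the examples in this document, potentiation steps are 100 times larger than depression steps for the same trace value.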

```python
# initialize network
network = Network()

# set number of neurons
input_neurons = 1
lif_neurons = 1

# simulation time (number of timesteps) and timestep size
time = 100
dt = 1

# configure weights for the synapses between the input layer and LIF layer
w = torch.round(torch.abs(2 * torch.randn(input_neurons, lif_neurons)))

# initialize input and LIF layers; PostPre requires both layers to record spike traces
input_layer = Input(n=input_neurons, traces=True)
lif_layer = LIFNodes(n=lif_neurons, traces=True)

# connection between the input layer and the LIF layer
connection = Connection(
    source=input_layer, target=lif_layer, w=w, update_rule=PostPre, nu=(1e-4, 1e-2)
)

# create a monitor
lif_layer_monitor = Monitor(
    obj=lif_layer,
    state_vars=("s", "v"),  # Record spikes and voltages.
    time=time,  # Length of simulation (if known ahead of time).
)

# add layers to network
network.add_layer(layer=input_layer, name="Input Layer")
network.add_layer(layer=lif_layer, name="LIF Layer")

# add connection to network
network.add_connection(
    connection=connection, source="Input Layer", target="LIF Layer"
)

# add monitor to the network
network.add_monitor(monitor=lif_layer_monitor, name="LIF Layer")

# create input spike data, where each spike is distributed according to Bernoulli(0.1)
input_data = torch.bernoulli(0.1 * torch.ones(time, input_layer.n)).byte()
encoded_image = input_data
inputs = {"Input Layer": input_data}

# simulate network on input data
network.run(inputs=inputs, time=time)

# retrieve and plot simulation spike, voltage data from monitors
spikes = {"LIF Layer": lif_layer_monitor.get("s")}
voltages = {"LIF Layer": lif_layer_monitor.get("v")}

# plot spikes and voltages of the LIF layer
# TODO: plot axes
plot_spikes(spikes)
plot_voltages(voltages, plot_type="line")
# plot_weights(w)

# plot image
# TODO: use a standard input image and encode it
# e_img = encoded_image.view(int(time / dt), 1, 1, input_layer.n, 1)
# inpt = e_img.view(int(time / dt), input_layer.n).sum(0).view(input_layer.n, 1)
# plot_input(input_data,inpt)

plt.show()
```

## 5. Custom Learning Rules

Custom learning rules can be implemented by subclassing `bindsnet.learning.LearningRule` and providing an update implementation for each type of `AbstractConnection` object the rule should support.

For example, the `Connection` and `LocalConnection` objects rely on the implementation of a private method, `_connection_update`, whereas the `Conv2dConnection` object uses the `_conv2d_connection_update` version.
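The dispatch pattern described above, where a rule binds its update method according to the type of connection it is attached to, can be sketched in plain Python. The classes below (`ToyAbstractConnection`, `ToyConnection`, `ToyConv2dConnection`, `ToyRule`) and their placeholder update logic are hypothetical stand-ins, not BindsNET classes; in BindsNET itself one would subclass `LearningRule` and work with the real connection objects.

```python
import numpy as np


class ToyAbstractConnection:
    """Hypothetical stand-in for an abstract connection holding weights."""

    def __init__(self, w):
        self.w = np.asarray(w, dtype=float)


class ToyConnection(ToyAbstractConnection):
    """Hypothetical stand-in for a dense connection."""


class ToyConv2dConnection(ToyAbstractConnection):
    """Hypothetical stand-in for a 2D convolutional connection."""


class ToyRule:
    """Hypothetical rule that picks its update method from the connection type."""

    def __init__(self, connection, nu=0.1):
        self.connection = connection
        self.nu = nu
        if isinstance(connection, ToyConnection):
            self.update = self._connection_update
        elif isinstance(connection, ToyConv2dConnection):
            self.update = self._conv2d_connection_update
        else:
            raise NotImplementedError(
                f"ToyRule does not support {type(connection).__name__}"
            )

    def _connection_update(self, **kwargs):
        # Placeholder dense update: nudge every weight up by nu.
        self.connection.w += self.nu

    def _conv2d_connection_update(self, **kwargs):
        # Placeholder convolutional update: this toy rule leaves weights unchanged.
        pass


dense = ToyConnection(np.zeros((2, 2)))
rule = ToyRule(dense)
rule.update()
print(dense.w)  # every entry is now 0.1
```

Binding the method once in the constructor means callers only ever invoke `update()`, and an unsupported connection type fails immediately rather than partway through a simulation.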