# Analysis of IAF and LIF neuron gradients in sinabs

This notebook examines the gradients of an IAF or LIF layer in sinabs mathematically and provides examples.

## Forward computation
Let's start by looking at the computations that happen during a `forward` call. For simplicity, and without loss of generality, we will look at a single neuron only.
At each time step we have an external input, a state and an activation. For IAF neuron dynamics, the state $v_t$ at the current time step is the sum of the state from the previous time step $v_{t-1}$ and the input $I_t$, minus the previous activation $a_{t-1}$ times a subtraction constant $c$. The activation indicates by how many times the current state exceeds a pre-defined threshold $\theta$. Usually $c = \theta$. In case of LIF neurons, $v_{t-1}$ is multiplied by a decay factor $\alpha \in (0, 1)$. For a given membrane time constant $\tau$, $\alpha = \exp(\frac{-\Delta}{\tau})$, with $\Delta$ the simulation time step. By assuming $\alpha \equiv 1$ in the IAF case, we obtain a set of equations that works with both types of neuron dynamics.


$$
v_t = \alpha \cdot v_{t-1} + I_t - c \cdot a_{t-1} \\
a_t = \big( \lfloor \frac{v_t}{\theta} \rfloor \big)_+ \\
$$
where $\lfloor \cdot \rfloor$ rounds down to the nearest integer and $\big( x \big)_+$ is the maximum of $x$ and $0$. We assume initial conditions $v_0 = a_0 = 0$.
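The two update equations can be sketched in plain Python for a single neuron. This is only an illustration of the recursion, not the sinabs implementation; the function name `simulate_neuron` and its arguments are chosen here for the example.

```python
import math

def simulate_neuron(inputs, theta, c, alpha=1.0):
    """Iterate v_t = alpha * v_{t-1} + I_t - c * a_{t-1} and
    a_t = max(floor(v_t / theta), 0) for a single neuron."""
    v_prev, a_prev = 0.0, 0  # initial conditions v_0 = a_0 = 0
    states, activations = [], []
    for current in inputs:
        v = alpha * v_prev + current - c * a_prev
        a = max(math.floor(v / theta), 0)
        states.append(v)
        activations.append(a)
        v_prev, a_prev = v, a
    return states, activations

# IAF case (alpha = 1): the second input is large enough for two spikes in one step
states, activations = simulate_neuron([0.5, 2.0, 0.0], theta=1.0, c=1.0)
print(states)       # [0.5, 2.5, 0.5]
print(activations)  # [0, 2, 0]
```

Passing `alpha < 1` gives the LIF case with the same code.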


## Backward computation
The derivatives of the states with respect to the inputs and the directly preceding activations, as well as the partial derivatives with respect to the directly preceding states, are easy to determine:
$$
\frac{d v_t}{d I_t} = 1 \\
\frac{d v_t}{d a_{t-1}} = - c \\
\frac{\partial v_t}{\partial v_{t-1}} = \alpha
$$

The last equation only gives a partial derivative because $v_t$ does not only depend on $v_{t-1}$ directly through the integration term, but also indirectly through $a_{t-1}$, which is a function of $v_{t-1}$.

To determine the derivative of $a_t$ with respect to $v_t$, we note that the activations are piecewise constant in $v_t$: their derivative is $0$ almost everywhere and undefined at the jumps. We can instead define a surrogate gradient, which we will for now call $s$:
$$
\frac{\partial a_t}{\partial v_t} := s_t := s(v_t)
$$

Now we can calculate the following total derivative:

$$
\frac{d v_t}{d v_{t-1}} = \frac{\partial v_t}{\partial v_{t-1}} + \frac{\partial v_t}{\partial a_{t-1}} \cdot \frac{\partial a_{t-1}}{\partial v_{t-1}} \\
= \alpha - c \cdot s_{t-1}
$$

With this, we can easily find the derivatives of our activations at any given timestep $a_T$ (e.g. the final one) with respect to any previous input $I_{T-n}$. This is helpful, because usually the outputs of our SNN models are functions of $a$. 
The input could be an external input or come from a hidden layer, in which case we could backpropagate through these.

For $n \geq 1$ we get:

$$
\frac{d a_T}{d I_{T-n}} = \frac{d a_T}{d v_T} \cdot \frac{d v_T}{d I_{T-n}}\\
= \frac{d a_T}{d v_T} \cdot \Big( \prod_{i=1}^{n} \frac{ d v_{T-(i-1)} } {d v_{T-i}} \Big) \cdot \frac{d v_{T-n}}{d I_{T-n}} \\
= s_T \cdot \Big( \prod_{i=1}^{n} \big( \alpha - c \cdot s_{T-i} \big) \Big) \cdot 1
$$

Furthermore, for $n = 0$:

$$
\frac{d a_T}{d I_{T}} = \frac{d a_T}{d v_T} \cdot \frac{d v_T}{d I_{T}} \\
= s_T \cdot 1 
$$

The last two equations can also be rewritten as:

<div class="alert alert-success">
$$
\text{For } T > t \text{: } \frac{d a_T}{d I_t} = s_T \cdot \Big( \prod_{i=t}^{T-1} \big( \alpha - c \cdot s_{i} \big) \Big) \\
\text{For } T = t \text{: } \frac{d a_T}{d I_T} = s_T \\
\text{For } T < t \text{: } \frac{d a_T}{d I_t} = 0 \\
$$
</div>
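As a sketch, the three cases can be written as one small function. It assumes the surrogate gradients $s_i = s(v_i)$ have already been evaluated and are passed in as a list with 0-based time indices; the name `activation_input_gradient` and the example values are hypothetical.

```python
def activation_input_gradient(s, T, t, alpha=1.0, c=1.0):
    """d a_T / d I_t, given surrogate gradients s[i] = s(v_i)."""
    if T < t:
        return 0.0  # an activation cannot depend on a future input
    grad = s[T]
    for i in range(t, T):  # i = t, ..., T-1 (empty for T == t)
        grad *= alpha - c * s[i]
    return grad

# Hypothetical surrogate gradient values for four time steps
s = [0.5, 0.0, 1.0, 0.5]
print(activation_input_gradient(s, T=2, t=2))                      # T = t: just s_T
print(activation_input_gradient(s, T=3, t=1, alpha=0.9, c=1.0))    # T > t: product over i = 1, 2
print(activation_input_gradient(s, T=0, t=1))                      # T < t: zero
```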

## Example

We can verify this by running a small simulation and plugging in the required values into the equation.
Let's assume we have only 1 neuron, which we evolve over 5 time steps with some random input between 0 and 1, with $\theta = 0.9$ and $c = 0.8$.

As surrogate gradient we will choose a simple step function: It is $\frac{1}{\theta}$ whenever the state is greater than $\theta - w$ and $0$ otherwise. Here $w > 0$ is some constant that we will call the window. In this example $w$ is the same as $\theta$ (which is the default). To distinguish between these two cases we introduce the characteristic (boolean) function $\chi_w(x)$, which is $1$ if $x > \theta - w$ and $0$ otherwise. Therefore,
$$
\frac{\partial a_t}{\partial v_t} = \frac{1}{\theta} \cdot \chi_w(v_t) \\
\frac{d a_T}{d I_{T}} = \frac{1}{\theta} \cdot \chi_w(v_T)
$$
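A minimal sketch of this step-shaped surrogate gradient, written as a plain Python function. The name `step_surrogate` is made up for this example and is not part of the sinabs API.

```python
def step_surrogate(v, theta, window):
    """Return 1 / theta if v lies above theta - window, and 0 otherwise."""
    return 1.0 / theta if v > theta - window else 0.0

theta = 0.9
# With window == theta the surrogate is active for every positive state
print(step_surrogate(0.3, theta, window=theta))    # active: 1 / theta
print(step_surrogate(-0.2, theta, window=theta))   # inactive: 0.0
# A narrower window only lets gradients through close to the threshold
print(step_surrogate(0.3, theta, window=0.2))      # inactive: 0.0
```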


```python
from sinabs.layers import IAF
import torch

# One neuron, five time steps of random input between 0 and 1
inp = torch.rand((1, 5, 1, 1, 1), requires_grad=True)
lyr = IAF(threshold=0.9, threshold_low=None, membrane_subtract=0.8)

out = lyr(inp)

print(inp.flatten())
print(out.flatten())
```

```
tensor([0.1431, 0.1943, 0.3937, 0.7224, 0.3122], grad_fn=<ViewBackward>)
tensor([0., 0., 0., 1., 1.], grad_fn=<ViewBackward>)
```


We will calculate the gradients of the final output value with respect to the five input values.
First of all, with $w = \theta$ the condition $v_t > \theta - w$ reduces to $v_t > 0$, which holds at every time step in this example, so $\chi_w$ is always 1.

With $\alpha = 1$ and $s_t = \frac{1}{\theta}$ at every time step, we can therefore calculate the gradients with the following equations:
$$
\frac{d a_5}{d I_{5}} = \frac{1}{\theta} = 1.\overline{1} \\
\frac{d a_5}{d I_{5-n}} = \frac{1}{\theta} \cdot \Big( \prod_{i=1}^{n} \big( 1 - \frac{c}{\theta} \big) \Big) = \frac{1}{\theta} \cdot \big( 1 - \frac{c}{\theta} \big)^n = 1.\overline{1} \cdot (0.\overline{1})^n
$$
This gives the following results:
$$
\frac{d a_5}{d I_{1}} = 0.00017 \\
\frac{d a_5}{d I_{2}} = 0.00152 \\
\frac{d a_5}{d I_{3}} = 0.01372 \\
\frac{d a_5}{d I_{4}} = 0.12346 \\
\frac{d a_5}{d I_{5}} = 1.11111 \\
$$
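The arithmetic behind these five numbers can be checked in a few lines of plain Python, without any sinabs or `torch` dependency. The variable names are chosen for this example only.

```python
theta, c = 0.9, 0.8
factor = 1 - c / theta  # alpha - c * s with alpha = 1 and s = 1 / theta

# n = 4, 3, 2, 1, 0 corresponds to d a_5 / d I_1, ..., d a_5 / d I_5
grads = [(1 / theta) * factor ** n for n in range(4, -1, -1)]
rounded = [round(g, 5) for g in grads]
print(rounded)
```

After rounding to five decimals, the entries agree with the values listed above.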

Let's calculate the gradients with `torch` and compare:

```python
# Backpropagate from the activation at the final time step
out.flatten()[-1].backward()

print(inp.grad.flatten())
```

```
ts tensor(4)
data tensor([[[[0.9658]]]])
out tensor([[[[1.]]]])
grad tensor([[[[1.1111]]]])
in tensor([[[[1.1111]]]])
ts tensor(3)
data tensor([[[[1.4536]]]])
out tensor([[[[-0.8889]]]])
grad tensor([[[[1.1111]]]])
in tensor([[[[-0.9877]]]])
ts tensor(2)
data tensor([[[[0.7311]]]])
out tensor([[[[-0.0988]]]])
grad tensor([[[[1.1111]]]])
in tensor([[[[-0.1097]]]])
ts tensor(1)
data tensor([[[[0.3374]]]])
out tensor([[[[-0.0110]]]])
grad tensor([[[[1.1111]]]])
in tensor([[[[-0.0122]]]])
ts tensor(0)
data tensor([[[[0.1431]]]])
out tensor([[[[-0.0012]]]])
grad tensor([[[[1.1111]]]])
in tensor([[[[-0.0014]]]])
tensor([1.6935e-04, 1.5242e-03, 1.3717e-02, 1.2346e-01, 1.1111e+00])
```


Looks about right. Note two things here:
First, in this simple example, the gradients are independent of the input $I$, as long as the states remain above $\theta - w$. In more complex scenarios this will not be the case anymore.
Second, the gradients decay towards the past. The reason is that we chose $c$ and $\theta$ such that $0 < 1 - \frac{c}{\theta} < 1$. Usually, however, $c = \theta$, such that $1 - \frac{c}{\theta} = 0$. In that case the gradient with respect to a past input vanishes as soon as the surrogate gradient is nonzero at any time step between that input and the output, which makes these gradients 0 most of the time.

```python
# Same input as before, but now with c = theta
inp = inp.clone().detach().requires_grad_(True)
lyr = IAF(threshold=0.9, threshold_low=None, membrane_subtract=0.9)

out = lyr(inp)

out.flatten()[-1].backward()
print(inp.grad.flatten())
```

```
ts tensor(4)
data tensor([[[[0.8658]]]])
out tensor([[[[1.]]]])
grad tensor([[[[1.1111]]]])
in tensor([[[[1.1111]]]])
ts tensor(3)
data tensor([[[[1.4536]]]])
out tensor([[[[-1.]]]])
grad tensor([[[[1.1111]]]])
in tensor([[[[-1.1111]]]])
ts tensor(2)
data tensor([[[[0.7311]]]])
out tensor([[[[0.]]]])
grad tensor([[[[1.1111]]]])
in tensor([[[[0.]]]])
ts tensor(1)
data tensor([[[[0.3374]]]])
out tensor([[[[0.]]]])
grad tensor([[[[1.1111]]]])
in tensor([[[[0.]]]])
ts tensor(0)
data tensor([[[[0.1431]]]])
out tensor([[[[0.]]]])
grad tensor([[[[1.1111]]]])
in tensor([[[[0.]]]])
tensor([0.0000, 0.0000, 0.0000, 0.0000, 1.1111])
```


## Lower threshold

Sinabs supports setting a lower bound $\theta_{low}$ on the membrane potential. $v$ and $a$ then evolve as follows:
$$
\tilde{v}_t = \alpha \cdot v_{t-1} + I_t - c \cdot a_{t-1} \\
v_t = max(\tilde{v}_t, \theta_{low}) \\
a_t = \big( \lfloor \frac{v_t}{\theta} \rfloor \big)_+
$$

This introduces a new derivative term, $\frac{d v_{t}}{d \tilde{v}_{t}}$, which is $1$ whenever $\tilde{v}_t$ is strictly greater than $\theta_{low}$ and $0$ otherwise. We will express this relation again by a characteristic function:
$$
\frac{d v_{t}}{d \tilde{v}_{t}} = \chi_{\theta_{low}}(\tilde{v}_t)
$$

<div class="alert alert-info">
To be precise, from a mathematical perspective the derivative $\frac{d v_{t}}{d \tilde{v}_{t}}$ is undefined for $\tilde{v}_t = \theta_{low}$, but in machine learning implementations the convention is to set it to $0$ in this case, allowing for sparser computation.
</div>

We therefore get the following, slightly modified total derivatives, which are gated by $\chi_{\theta_{low}}$:
$$
\frac{d v_t}{d I_t} = \chi_{\theta_{low}}(\tilde{v}_t) \\
\frac{d v_t}{d a_{t-1}} = - c \cdot \chi_{\theta_{low}}(\tilde{v}_t) \\
\frac{d v_t}{d v_{t-1}} = (\alpha - c \cdot s_{t-1} ) \cdot \chi_{\theta_{low}}(\tilde{v}_t)
$$

And therefore, for the activations:

<div class="alert alert-success">
$$
\frac{d a_T}{d I_{T}} = \frac{d a_T}{d v_T} \cdot \frac{d v_T}{d I_{T}} = s_T \cdot \chi_{\theta_{low}}(\tilde{v}_T)
$$

For $T > t$:
$$
\frac{d a_T}{d I_{t}} = \frac{d a_T}{d v_T} \cdot \Big( \prod_{i=t}^{T-1} \frac{ d v_{i+1} } {d v_{i}} \Big) \cdot \frac{d v_{t}}{d I_{t}} \\
= s_T \cdot \Big( \prod_{i=t}^{T-1} \big( \alpha - c \cdot s_{i} \big) \cdot \chi_{\theta_{low}}(\tilde{v}_{i+1}) \Big) \cdot \chi_{\theta_{low}}(\tilde{v}_t) \\
= s_T \cdot \chi_{\theta_{low}}(\tilde{v}_T) \cdot \Big( \prod_{i=t}^{T-1} \big( \alpha - c \cdot s_{i} \big) \cdot \chi_{\theta_{low}}(\tilde{v}_{i}) \Big)
$$
</div>
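The gated gradient can be illustrated with a small pure-Python sketch. It first runs the clipped forward recursion, derives the gates from the intermediate states, and then evaluates the product formula. The function names and the surrogate gradient values are assumptions made for this example; this is not the sinabs code.

```python
import math

def forward_with_lower_bound(inputs, theta, c, theta_low, alpha=1.0):
    """Return intermediate states v_tilde, clipped states v and activations a."""
    v_prev, a_prev = 0.0, 0
    v_tilde, v, a = [], [], []
    for current in inputs:
        vt = alpha * v_prev + current - c * a_prev  # before clipping
        vc = max(vt, theta_low)                     # after clipping
        at = max(math.floor(vc / theta), 0)
        v_tilde.append(vt)
        v.append(vc)
        a.append(at)
        v_prev, a_prev = vc, at
    return v_tilde, v, a

def gated_gradient(s, gate, T, t, alpha=1.0, c=1.0):
    """d a_T / d I_t with gate[i] = chi_{theta_low}(v_tilde_i)."""
    if T < t:
        return 0.0
    grad = s[T] * gate[T]
    for i in range(t, T):
        grad *= (alpha - c * s[i]) * gate[i]
    return grad

inputs = [-2.0, 0.5, 1.0]
v_tilde, v, a = forward_with_lower_bound(inputs, theta=1.0, c=0.5, theta_low=-1.0)
gate = [1 if x > -1.0 else 0 for x in v_tilde]
s = [1.0, 1.0, 1.0]  # hypothetical surrogate gradient values

print(gate)                                      # [0, 1, 1]
print(gated_gradient(s, gate, T=2, t=0, c=0.5))  # 0.0
print(gated_gradient(s, gate, T=2, t=1, c=0.5))  # 0.5
```

The first input drives the state below the bound, so it is clipped and has no influence on later activations, while the second input still does.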

## A few notes on computation

### Backpropagation
Given a (vector-valued) function $F(x)$, its derivative with respect to $x$ can be described by a Jacobian matrix $J_F(x)$, with elements
$$
\big( J_F(x) \big)_{i,j} = \frac{\partial F_i}{\partial x_j} \Big|_{x}
$$

For compositions of functions, e.g. $H(x) := (G \circ F)(x)$, the chain rule states that we can multiply the individual functions' Jacobians to get the Jacobian of the composed function:
$$
J_{H}(x) = J_{G}(F(x)) \cdot J_{F}(x)
$$
This can be extended for compositions of arbitrary lengths.

When working with neural networks, we are looking at such a composition of functions, which are mostly vector valued (layer activations as functions of layer inputs and parameters). The loss however, which is the final function in the composition, is a scalar. Therefore the derivatives of the loss form a gradient vector $v^T_{G}$ rather than a matrix $J_{G}$. Assuming a one-layer network, we can describe the layer outputs as some function $F$ of the layer parameters $x$ and the loss as a function $G$ of the layer outputs. (We don't care about the layer inputs for now.) In order to train our network we need to know the derivatives of the loss with respect to the layer parameters $x$ -- our gradients.
This corresponds to $J_{H}$ in the equation above. Because $v^T_{G}$ is a (row) vector, $J_{H}$ will be so, as well. We will call it $v^T_{H}$. Transposing the equation from above (and omitting the function arguments) gives:
$$
v_{H}(x) = \big(v^T_{G} \cdot J_{F}\big)^T = J_{F}^T \cdot v_G
$$

We now have the gradient vector $v_{H}$. If we had multiple layers, and therefore a composition of more functions, we could get the gradient vectors for parameters deeper in the network by successively multiplying the resulting gradient vector with the transposed Jacobian of the next layer. This is exactly what happens during backpropagation of gradients.
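The product $J_F^T \cdot v_G$ can be checked numerically on a toy problem. The sketch below assumes a linear map $F(x) = W x$ (so $J_F = W$) and a quadratic loss $G(y) = \frac{1}{2} \sum_i y_i^2$ (so $v_G = y$); both are chosen only because their derivatives are easy to write down.

```python
def matvec(M, x):
    """Matrix-vector product for nested lists."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]

def transpose(M):
    return [list(col) for col in zip(*M)]

W = [[1.0, 2.0, 0.0],
     [0.0, -1.0, 3.0]]

def F(x):
    return matvec(W, x)                    # Jacobian J_F = W

def G(y):
    return 0.5 * sum(yi * yi for yi in y)  # gradient v_G = y

x = [1.0, 2.0, 3.0]
v_G = F(x)                                 # gradient of the loss w.r.t. the layer output
v_H = matvec(transpose(W), v_G)            # J_F^T . v_G
print(v_H)                                 # [5.0, 3.0, 21.0]

# Compare against central finite differences of H = G(F(x))
eps = 1e-6
numeric = []
for j in range(len(x)):
    up = [xi + (eps if k == j else 0.0) for k, xi in enumerate(x)]
    down = [xi - (eps if k == j else 0.0) for k, xi in enumerate(x)]
    numeric.append((G(F(up)) - G(F(down))) / (2 * eps))
```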

### Backpropagation in PyTorch
In PyTorch, we can define autograd functions, which do the backward pass for us. All standard PyTorch operations come with the backward pass already defined. For custom functions (e.g. the forward method of a spiking sinabs layer), we need to define the backward pass ourselves. During the backward computation, this function receives the gradient vector as its argument, which would be $v_G$ in the equation above. It is generally referred to as the "output gradient". The job of our backward function is then to compute $v_{H}$, the "input gradient".

We do so by computing the complete Jacobian, transposing it and multiplying it with the output gradient. Here, the Jacobian consists of the derivatives of the layer activations $a$ with respect to the layer input $I$, $\frac{d a_T}{d I_t}$. The diagonal (where $T = t$) is simply given by the surrogate gradients $s$ (and the gating $\chi_{\theta_{low}}$ if there is a lower bound to $v$). Because the derivatives are $0$ for $t > T$, the transposed Jacobian is an upper triangular matrix, which saves us some computation. In order to compute the remaining derivatives, where $t < T$, we first note that even though we have a recursive equation,
$$
\frac{d a_T}{d I_{t}} = s_T \cdot \Big( \prod_{i=t}^{T-1} \big( \alpha - c \cdot s_{i} \big) \cdot \chi_{\theta_{low}}(\tilde{v}_{i+1}) \Big) \cdot \chi_{\theta_{low}}(\tilde{v}_t)
$$
derivatives with different denominators $d I_t$ are completely independent of one another. Therefore we can parallelize computation over them, along with parallelizing over neurons and batches.

Apart from that we see that we can, for a given $I_t$, calculate the derivatives iteratively, starting with $\frac{d a_t}{d I_{t}}$, then $\frac{d a_{t+1}}{d I_{t}}$, and so on. This is how we do it in our implementation of sinabs-slayer for the IAF neurons, which turns out to be very efficient. We furthermore don't compute the complete Jacobian but rather its product with the output gradient directly.
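The iterative scheme can be sketched as follows for the case without a lower bound. For each input time $t$, the loop walks forward over $T = t, t+1, \dots$, keeps a running product, and accumulates the product with the output gradient directly, so the Jacobian is never stored. This is a plain-Python illustration under those assumptions, not the sinabs-slayer code, and `input_gradient` is a made-up name.

```python
def input_gradient(s, grad_out, alpha=1.0, c=1.0):
    """Compute (J^T . grad_out)_t = sum_{T >= t} grad_out[T] * d a_T / d I_t."""
    n = len(s)
    grad_in = [0.0] * n
    for t in range(n):  # iterations over t are independent of each other
        running = 1.0   # prod_{i=t}^{T-1} (alpha - c * s_i), empty product for T == t
        for T in range(t, n):
            grad_in[t] += grad_out[T] * s[T] * running  # s[T] * running = d a_T / d I_t
            running *= alpha - c * s[T]
    return grad_in

# Setting of the example above: theta = 0.9, c = 0.8, surrogate gradient 1 / theta
# at every step, and only the final activation contributes to the output gradient.
s = [1 / 0.9] * 5
grad_out = [0.0, 0.0, 0.0, 0.0, 1.0]
grad_in = input_gradient(s, grad_out, alpha=1.0, c=0.8)
print([round(g, 5) for g in grad_in])
```

With these inputs the result reproduces, up to rounding, the gradient tensor that `torch` printed in the example section.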

Finally, there is one last remark for the case with lower-bounded membrane potentials. Here, we have to compute $\chi_{\theta_{low}}$ first, which is a function of the intermediate states $\tilde{v}_t$. In theory we would have to store all the intermediate states for this. In order to save some GPU memory, however, we use the actual states $v_t$, which we are saving anyway. They work equally well for the following reason:
If a state is greater than $\theta_{low}$, the intermediate state was so as well, and the gating function is $1$. If, on the other hand, it is equal to $\theta_{low}$, this is either because the intermediate state was also equal to $\theta_{low}$, or because it was less, in which case it got clipped. In both of these cases the gating function is $0$.
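A tiny check of this argument in plain Python: the gates computed from the clipped states coincide with the gates computed from the intermediate states, including the boundary case where the intermediate state equals $\theta_{low}$. The values are arbitrary and the helper name `gate` is hypothetical.

```python
def gate(x, theta_low):
    """Characteristic function: 1 if x is strictly above theta_low, else 0."""
    return 1 if x > theta_low else 0

theta_low = -1.0
v_tilde = [-2.0, -1.0, -0.5, 0.3]            # intermediate states
v = [max(x, theta_low) for x in v_tilde]     # clipped states that are stored anyway

gates_from_intermediate = [gate(x, theta_low) for x in v_tilde]
gates_from_clipped = [gate(x, theta_low) for x in v]
print(gates_from_intermediate)  # [0, 0, 1, 1]
print(gates_from_clipped)       # [0, 0, 1, 1]
```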

# Analysis in Slayer

Now let's look at the same scenario in SLAYER.
Here, the goal is to make computation more efficient by computing the forward pass as a temporal convolution over a tensor instead of iterating over individual time steps.

## Forward computation

### PSP
For a full forward pass, computation consists of two major steps. The first is the `psp_function`, which generates a post-synaptic potential (PSP) from the inputs by filtering them with a pre-defined PSP kernel $\kappa$. Both synaptic and membrane dynamics are encoded in the PSP through the kernel. In the case of an IAF neuron, $\kappa$ would simply be a Heaviside function.

$$
PSP_t = (I * \kappa)_t = \sum_{\tau = 0}^t I_{\tau} \cdot \kappa_{t - \tau}
$$
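The causal convolution can be written out directly. The sketch below uses plain lists and a made-up function name `psp`; it is not the SLAYER `psp_function`. The Heaviside kernel turns the PSP into a running sum of the input, and the second kernel is a hypothetical exponentially decaying one.

```python
def psp(inputs, kernel):
    """PSP_t = sum_{tau=0}^{t} I_tau * kappa_{t - tau} (kernel needs len >= len(inputs))."""
    return [
        sum(inputs[tau] * kernel[t - tau] for tau in range(t + 1))
        for t in range(len(inputs))
    ]

inputs = [1.0, 0.0, 2.0, 1.0]

heaviside = [1.0] * len(inputs)
print(psp(inputs, heaviside))  # [1.0, 1.0, 3.0, 4.0]

decaying = [0.5 ** k for k in range(len(inputs))]  # hypothetical kernel with decay 0.5
print(psp(inputs, decaying))   # [1.0, 0.5, 2.25, 2.125]
```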

### Spike mechanism and refractory response
The second step consists of iterating over the PSP in time to determine when the neuron spikes and to apply the refractory response to the PSP. We will refer to the PSP with refractory response applied as the neuron's membrane potential. In the original implementation of SLAYER, the PSP at each time step is compared to a firing threshold. If it is exceeded, an output spike will be registered and a refractory response kernel will be added to the subsequent time steps.

In sinabs-slayer, this mechanism has been slightly extended by two adaptations:
1. Multiple spikes per time step are possible, if the PSP exceeds the threshold by multiple times. The refractory response will then be applied the same number of times as there were spikes.
2. The membrane potential can be lower bounded. Each time the PSP is below some lower bound $\theta_{low}$, the difference to the bound will be added to the current and all subsequent time steps.

We can describe the resulting membrane potential $v$ and output spike train $a$ recursively again. We will refer to the intermediate membrane potential after the reset mechanism but before clipping to the lower bound as $\tilde{v}$:
$$
\Delta_t := max(0, \theta_{low} - \tilde{v}_t) \\
\gamma_t := \chi_{\theta_{low}}(\tilde{v}_t) \\
\tilde{v}_t = PSP_t + \sum_{\tau=0}^{t-1} \big( \Delta_{\tau} + \kappa_{t-\tau} \cdot a_{\tau} \big) \\
v_t = max(\theta_{low}, \tilde{v}_t) = \tilde{v}_t + \Delta_{t}
$$

Here we introduced the following terms:
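$\Delta_t$ is the difference between $\theta_{low}$ and $\tilde{v}_t$ if the intermediate membrane potential is below the lower bound, and $0$ otherwise.
The characteristic function $\gamma_t$ indicates whether at time $t$ the intermediate membrane potential is above the lower bound ($\gamma_t = 1$) or not ($\gamma_t = 0$).
The kernel $\kappa$ in the equation for $\tilde{v}_t$ is the refractory response that is added to the PSP after the neuron spikes. It is a different kernel from the PSP kernel above, although it is denoted by the same symbol.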

The output spike train $a$ is computed the same way as before:
$$
a_t = \big( \lfloor \frac{v_t}{\theta} \rfloor \big)_+
$$
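The recursive definition can be sketched in plain Python as below. The refractory kernel is passed as a list `refr_kernel` indexed by the lag $t - \tau \geq 1$ (index 0 is unused); this name, the function name and the example values are assumptions for illustration, not the sinabs-slayer API. A constant kernel of $-c$ reproduces the IAF reset by subtraction.

```python
import math

def spike_mechanism(psp, refr_kernel, theta, theta_low=None):
    """Apply the refractory response and optional lower bound to a PSP."""
    v, a, delta = [], [], []
    for t in range(len(psp)):
        # v_tilde_t = PSP_t + sum_{tau < t} (Delta_tau + kappa_{t - tau} * a_tau)
        v_tilde = psp[t] + sum(
            delta[tau] + refr_kernel[t - tau] * a[tau] for tau in range(t)
        )
        d = max(0.0, theta_low - v_tilde) if theta_low is not None else 0.0
        vt = v_tilde + d  # equals max(theta_low, v_tilde)
        v.append(vt)
        delta.append(d)
        a.append(max(math.floor(vt / theta), 0))
    return v, a

# IAF-like case: Heaviside PSP of the input [0.5, 2.0, 0.0], constant refractory kernel -1
v, a = spike_mechanism([0.5, 2.5, 2.5], refr_kernel=[0.0, -1.0, -1.0], theta=1.0)
print(v)  # [0.5, 2.5, 0.5]
print(a)  # [0, 2, 0]

# Lower-bounded case: the first step is clipped to theta_low = -1
v_low, a_low = spike_mechanism(
    [-2.0, -1.5, -0.5], refr_kernel=[0.0, -0.5, -0.5], theta=1.0, theta_low=-1.0
)
print(v_low)  # [-1.0, -0.5, 0.5]
```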

Note: For IAF neurons, this implementation, in particular with the lower-bounded membrane potential, is slower than the standard sinabs forward call. Therefore sinabs-slayer uses a fully iterative forward function for IAF layers, just like sinabs. In that case, the backward function is implemented as described in Section "Analysis of IAF and LIF neuron gradients in sinabs". In this section we will focus on the SLAYER-based implementation, which is also used for the LIF layer of sinabs-slayer.

## Backward pass

### PSP
Let's find the derivative of the PSP at some time point $T$ wrt. the input at any given time point $t$:

$$
\frac{d PSP_T}{d I_t} = \frac{d}{d I_t} \sum_{\tau = 0}^T I_{\tau} \cdot \kappa_{T - \tau} \\
= \sum_{\tau = 0}^T \frac{dI_{\tau}}{d I_t} \cdot \kappa_{T - \tau} \\
= \sum_{\tau = 0}^T \delta_{t, \tau} \cdot \kappa_{T - \tau} \\
= \kappa_{T - t} \text{ if } t \leq T \text{ and } 0 \text{ otherwise }
$$

This is also how the backward function of the PSP is implemented in SLAYER.

With this we can calculate the corresponding Jacobian $J_{PSP}$ as $(J_{PSP})_{i,j} = \kappa_{i - j}$ for $j \leq i$ and 0 otherwise. This results in a lower triangular matrix.

When backpropagating gradients through the computational graph, we now have to apply the transpose $J_{PSP}^T$ to the incoming gradient tensor $D_{Out}$. Due to the structure of $J_{PSP}$, the result is the cross-correlation between $D_{Out}$ and $\kappa$ (with $T$ here denoting the last time step):

<div class="alert alert-success">
$$
D_{PSP, t} = (J_{PSP}^T \cdot D_{Out})_t = \sum_{\tau=0}^{T-t} \kappa_{\tau} \cdot D_{Out, t+\tau} = (\kappa \star D_{Out})_t
$$
</div>
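To make the cross-correlation concrete, the sketch below computes it directly and compares it with an explicit multiplication by the transposed lower triangular Jacobian. It uses plain lists and a hypothetical function name `psp_backward`; it is an illustration, not the SLAYER backward function.

```python
def psp_backward(grad_out, kernel):
    """Cross-correlation: D_PSP[t] = sum_{tau=0}^{n-1-t} kernel[tau] * grad_out[t + tau]."""
    n = len(grad_out)
    return [
        sum(kernel[tau] * grad_out[t + tau] for tau in range(n - t))
        for t in range(n)
    ]

kernel = [1.0, 0.5, 0.25]
grad_out = [1.0, 2.0, 4.0]
print(psp_backward(grad_out, kernel))  # [3.0, 4.0, 4.0]

# Explicit check: J[i][j] = kernel[i - j] for j <= i, then apply J^T to grad_out
n = len(grad_out)
J = [[kernel[i - j] if j <= i else 0.0 for j in range(n)] for i in range(n)]
explicit = [sum(J[i][t] * grad_out[i] for i in range(n)) for t in range(n)]
```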

### Spike mechanism and refractory response
Because the spike function in SLAYER takes the PSP as argument and not the spiking input, we need to calculate the derivatives of the output with respect to the PSP, which is different from what is described in Section "Analysis of IAF and LIF neuron gradients in sinabs". (The composition of the backward functions of the PSP and the spike mechanism, however, should return the same gradients as before.)

By differentiating the equations of the forward pass, we obtain the following derivatives:

$$
\frac{d \Delta_t}{d \tilde{v}_t} = (1 - \gamma_t) \cdot (- 1) = (\gamma_t - 1) \\
\frac{d \tilde{v}_t}{d PSP_t} = 1 \\
\frac{d v_t}{d \tilde{v}_t} = \gamma_t
$$

For the spike train we again define (arbitrary) surrogate gradients $s$:
$$
\frac{d a_t}{d v_t} := s_t
$$

Now we can find a recursive description for the derivative of $\tilde{v}$ and $a$ with respect to earlier states of $\tilde{v}$:
$$
\frac{d a_T}{d \tilde{v}_t} = \frac{d a_T}{d v_T} \cdot \frac{d v_T}{d \tilde{v}_T} \cdot \frac{d \tilde{v}_T}{d \tilde{v}_t} = s_T \cdot \gamma_T \cdot \frac{d \tilde{v}_T}{d \tilde{v}_t} \\
\frac{d \tilde{v}_T}{d \tilde{v}_t} = \sum_{\tau=0}^{T-1} \big( \frac{d \Delta_{\tau}}{d \tilde{v}_t} + \kappa_{T-\tau} \cdot \frac{d a_{\tau}}{d \tilde{v}_t} \big) 
= \sum_{\tau=0}^{T-1} \big( (\gamma_{\tau} - 1) + \kappa_{T-\tau} \cdot s_{\tau} \cdot \gamma_{\tau} \big) \cdot \frac{d \tilde{v}_{\tau}}{d \tilde{v}_t}
$$

<div class="alert alert-success">
Therefore, for $T > t$:
$$
\frac{d a_T}{d PSP_t} = \frac{d a_T}{d \tilde{v}_t} \cdot \frac{d \tilde{v}_t}{d PSP_t}
= s_T \cdot \gamma_T \cdot \sum_{\tau=0}^{T-1} \big( (\gamma_{\tau} - 1) + \kappa_{T-\tau} \cdot s_{\tau} \cdot \gamma_{\tau} \big) \cdot \frac{d \tilde{v}_{\tau}}{d \tilde{v}_t} \cdot 1
$$
and
$$
\frac{d a_t}{d PSP_t} = s_t \cdot \gamma_t
$$

If the membrane potential is not lower bounded, we can set $\gamma = 1$, which simplifies the term inside the sum slightly:
$$
\frac{d a_T}{d PSP_t} = s_T \cdot \sum_{\tau=0}^{T-1} \kappa_{T-\tau} \cdot s_{\tau} \cdot \frac{d v_{\tau}}{d v_t}
$$
</div>
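The recursion can be sketched in plain Python as follows. The sum is started at $\tau = t$ because $\frac{d \tilde{v}_{\tau}}{d \tilde{v}_t} = 0$ for $\tau < t$. The list `refr_kernel` (indexed by lag, index 0 unused), the function name and the surrogate gradient values are assumptions made for this example.

```python
def spike_psp_gradient(s, gamma, refr_kernel, T, t):
    """d a_T / d PSP_t via the recursion for d v_tilde_T / d v_tilde_t."""
    if T < t:
        return 0.0
    dv = [1.0]  # dv[k] = d v_tilde_{t+k} / d v_tilde_t, starting with k = 0
    for upper in range(t + 1, T + 1):
        total = 0.0
        for tau in range(t, upper):
            weight = (gamma[tau] - 1) + refr_kernel[upper - tau] * s[tau] * gamma[tau]
            total += weight * dv[tau - t]
        dv.append(total)
    return s[T] * gamma[T] * dv[T - t]

s = [0.5, 0.25, 0.5, 0.5]              # hypothetical surrogate gradients
gamma = [1, 1, 1, 1]                   # membrane potential never hits the lower bound
refr_kernel = [0.0, -1.0, -1.0, -1.0]  # constant refractory response with c = 1

print(spike_psp_gradient(s, gamma, refr_kernel, T=3, t=1))  # -0.0625
print(spike_psp_gradient(s, gamma, refr_kernel, T=1, t=1))  # 0.25
```

For a constant kernel $-c$ and $\gamma = 1$, the first value agrees with $s_T \cdot (-c \cdot s_t) \cdot \prod_{i=t+1}^{T-1} (1 - c \cdot s_i)$, which is the difference $\frac{d a_T}{d I_t} - \frac{d a_T}{d I_{t+1}}$ of the IAF gradients from the first part.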

### Computation

In principle, similar considerations as in Section "Analysis of IAF and LIF neuron gradients in sinabs" also apply here. Again, computation of $\frac{d a_i}{d PSP_j}$ can be parallelized over neurons and batches as well as over the denominators ${d PSP_j}$ of the derivatives.
There, however, the refractory response was constant, so the product over past gradients could easily be accumulated. Now, with an arbitrary refractory kernel $\kappa$, building the sum in the equation for $\frac{d a_i}{d PSP_j}$ is less straightforward. We can still iteratively compute the derivatives for a given $j$, starting with $i=j$. However, we now need to store all the $\frac{d \tilde{v}_i}{d \tilde{v}_j}$ for future computations and, for each $i$, have to convolve $\kappa$ with the previously stored terms. This is more costly in terms of both memory and computation.

Efficiency can be improved if $\kappa$ is constant. In this case the sum can simply be accumulated, without having to store and convolve over all intermediate results.
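A minimal sketch of this accumulation, assuming no lower bound ($\gamma = 1$) and a refractory kernel that takes the same value `kappa_const` for every lag. Only two scalars are carried along instead of the full history. The function name and values are hypothetical.

```python
def dv_constant_kernel(s, kappa_const, t, T):
    """d v_T / d v_t for T >= t with a constant refractory kernel and no lower bound."""
    dv = 1.0       # d v_t / d v_t
    running = 0.0  # sum_{tau=t}^{k} s_tau * d v_tau / d v_t, accumulated step by step
    for k in range(t, T):
        running += s[k] * dv
        dv = kappa_const * running  # d v_{k+1} / d v_t
    return dv

s = [0.5, 0.25, 0.5, 0.5]  # hypothetical surrogate gradients
print(dv_constant_kernel(s, kappa_const=-1.0, t=1, T=1))  # 1.0
print(dv_constant_kernel(s, kappa_const=-1.0, t=1, T=2))  # -0.25
print(dv_constant_kernel(s, kappa_const=-1.0, t=1, T=3))  # -0.125
```

Multiplying the last value by $s_T$ gives $\frac{d a_T}{d PSP_t}$ for this setting.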