# Choosing the right neuron model

In [None]:
import sinabs
import sinabs.layers as sl
import torch
import matplotlib.pyplot as plt

Let's start by creating a helpful plotting function and some constant current input.

In [None]:
def plot_evolution(neuron_model: torch.nn.Module, input: torch.Tensor) -> None:
    """Step a neuron layer through ``input`` and plot its recorded state.

    The layer's state is reset, then the slice ``input[:, step]`` is fed to it
    for every index along dimension 1. After each step the membrane potential
    and, when available, the spike threshold and synaptic current are
    recorded, together with the summed layer output. Everything is drawn on
    the current matplotlib figure.

    Args:
        neuron_model: Layer exposing ``reset_states()``, ``v_mem``,
            ``spike_threshold`` and ``activation_fn``; ``i_syn`` is optional.
        input: Stimulus tensor whose dimension 1 is iterated step by step.
    """
    time_steps = input.shape[1]

    neuron_model.reset_states()
    v_mem = []
    spikes = []
    i_syn = []
    spike_thresholds = []
    for step in range(time_steps):
        output = neuron_model(input[:, step])
        v_mem.append(neuron_model.v_mem.detach().numpy())

        # Compare against None rather than relying on tensor truthiness:
        # truthiness drops samples whose value is exactly zero (shifting the
        # remaining samples in time) and raises for multi-element tensors.
        threshold = neuron_model.spike_threshold
        if threshold is not None:
            spike_thresholds.append(torch.as_tensor(threshold).detach().numpy())
        synaptic_current = getattr(neuron_model, "i_syn", None)
        if synaptic_current is not None:
            i_syn.append(synaptic_current.detach().numpy())

        spikes.append(output.sum().detach().numpy())

    plt.plot(v_mem, drawstyle="steps-post", label="v_mem")
    if neuron_model.activation_fn is not None:
        if spike_thresholds:
            plt.plot(spike_thresholds, "--", label="spike threshold")
        # Only the first spike line carries the legend label; labels starting
        # with an underscore are ignored by the legend. This avoids reading
        # loop variables after the loop, which failed when no spike occurred.
        spike_label = "spikes"
        for step, spike in enumerate(spikes):
            if spike > 0:
                plt.axvline(
                    x=step,
                    ymax=float(spike),
                    label=spike_label,
                    color="black",
                    linewidth=3,
                )
                spike_label = "_nolegend_"
    if i_syn:
        plt.plot(i_syn, drawstyle="steps-post", linewidth=3, label="i_syn", color="C6")
    plt.xlabel("time")
    plt.title(f"{neuron_model.__class__.__name__} neuron dynamics")
    plt.legend()


# Two stimuli of shape (1, 100, 1); plot_evolution iterates over dimension 1.
const_current = torch.full((1, 100, 1), 0.03)  # 0.03 at every step
single_current = torch.zeros_like(const_current)
single_current[:, 0] = 0.03  # one pulse at the first step, zero afterwards

## Integrate and Fire neuron
This neuron has no leakage and simply integrates all the input it receives.

In [None]:
# IAF layer with default arguments, driven by the constant 0.03 input
# over all 100 steps.
iaf_neuron = sl.IAF()
plot_evolution(iaf_neuron, const_current)

We can enable a synaptic state in this neuron model with `use_synaptic_state=True`. All inputs are then integrated into its `i_syn` state, which is added to the membrane potential at every step. Because this synaptic current is static (it does not decay), a single current injection at the first step gives essentially the same result as the constant input current in the previous plot.

In [None]:
# IAF layer with its synaptic state enabled; the stimulus is a single 0.03
# pulse at the first step and zero afterwards.
iaf_neuron = sl.IAF(use_synaptic_state=True)
plot_evolution(iaf_neuron, single_current)

## Leaky Integrate and Fire neuron
This neuron integrates the input and decays its state at every time step. It emits a spike whenever the membrane potential is above the spike threshold.

In [None]:
# LIF layer with tau_mem=40 and no tau_syn, driven by the constant 0.03 input.
# NOTE(review): norm_input=False presumably disables scaling of the input by
# the time constant — confirm against the sinabs LIF documentation.
lif_neuron = sl.LIF(tau_mem=40.0, norm_input=False)
plot_evolution(lif_neuron, const_current)

By default, no synaptic dynamics are used. We can enable them by setting `tau_syn`. Note that instead of a constant current, we now inject current only at the first time step.

In [None]:
# Same LIF layer, now with tau_syn set; the stimulus is a single pulse of
# 3 * 0.03 at the first step and zero afterwards.
lif_neuron = sl.LIF(tau_mem=40.0, tau_syn=30.0, norm_input=False)
plot_evolution(lif_neuron, single_current * 3)

## Leaky integrator neuron
Same as LIF, but without an activation function.

In [None]:
# ExpLeak layer with tau_mem=60, driven by the constant 0.03 input.
exp_leak_neuron = sl.ExpLeak(tau_mem=60.0)
plot_evolution(exp_leak_neuron, const_current)

## Adaptive Leaky Integrate and Fire neuron
This is a LIF neuron with an adaptive threshold.

In [None]:
# ALIF layer driven by a constant input of 4 * 0.03 at every step.
# NOTE(review): tau_adapt and adapt_scale presumably control how the spike
# threshold adapts — confirm against the sinabs ALIF documentation.
alif_neuron = sl.ALIF(tau_mem=40.0, tau_adapt=40.0, adapt_scale=20, norm_input=False)
plot_evolution(alif_neuron, const_current * 4)