# Choosing a neuron model

In [None]:
import sinabs
import sinabs.layers as sl
import torch
import matplotlib.pyplot as plt

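Let's start by creating a helper function that plots a neuron's state over time, and two input signals: a constant current (`const_current`) and a single pulse at the first time step (`single_current`).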

In [None]:
def plot_evolution(neuron_model: torch.nn.Module, input: torch.Tensor):
    """Run ``neuron_model`` over ``input`` one time step at a time and plot its states.

    The layer's states are reset first. The membrane potential is always
    plotted. The spike threshold and the output spikes are added when the
    layer has a spike function. The synaptic current is added when the layer
    exposes a non-empty ``i_syn`` state.

    Args:
        neuron_model: A sinabs neuron layer (e.g. ``sl.IAF``, ``sl.LIF``,
            ``sl.ExpLeak``, ``sl.ALIF``) providing ``reset_states()`` and a
            ``v_mem`` state. ``spike_threshold``, ``spike_fn`` and ``i_syn``
            are used only when present and not None.
        input: Input tensor with time along dim 1, fed to the layer one slice
            ``input[:, step]`` at a time. The inputs in this notebook have
            shape (batch=1, time, neurons=1). NOTE(review): per-step traces
            are concatenated along dim 0, so a batch size > 1 would interleave
            batch entries in the plot.
    """
    neuron_model.reset_states()
    v_mem, spikes, i_syn, thresholds = [], [], [], []

    # Plotting never needs gradients. Without no_grad, the layer would build
    # an autograd graph that grows with every simulated time step.
    with torch.no_grad():
        for step_input in input.unbind(dim=1):
            output = neuron_model(step_input)
            # Clone so each entry is a snapshot of this step, even if a layer
            # were to update its state tensors in place.
            v_mem.append(neuron_model.v_mem.clone())
            spikes.append(output)

            # Compare against None instead of using truthiness. A tensor's
            # truth value raises for >1 element and is False when it is exactly
            # 0, which would silently drop samples and shift the trace in time.
            syn = getattr(neuron_model, "i_syn", None)
            if syn is not None and syn.numel() > 0:
                i_syn.append(syn.clone())

            threshold = getattr(neuron_model, "spike_threshold", None)
            if threshold is not None:
                # Flatten so that scalar and tensor thresholds both stack into
                # a 2-D (time, n) array; matplotlib cannot plot 3-D arrays.
                thresholds.append(torch.as_tensor(threshold).flatten())

    plt.figure(figsize=(10, 3))
    plt.plot(
        torch.cat(v_mem).detach().cpu().numpy(),
        drawstyle="steps-post",
        label="v_mem",
    )

    if getattr(neuron_model, "spike_fn", None) is not None:
        # A layer may have a spike function but no recorded threshold, and
        # torch.stack([]) raises, so only plot thresholds if any were recorded.
        if thresholds:
            plt.plot(
                torch.stack(thresholds).detach().cpu().numpy(),
                "--",
                label="spike threshold",
            )
        plt.plot(
            torch.cat(spikes).detach().cpu().numpy(),
            label="output",
            drawstyle="steps",
            color="black",
        )

    if i_syn:
        plt.plot(
            torch.cat(i_syn).detach().cpu().numpy(),
            drawstyle="steps-post",
            linewidth=3,
            label="i_syn",
            color="C6",
        )

    plt.xlabel("time")
    plt.title(f"{neuron_model.__class__.__name__} neuron dynamics")
    plt.legend()


# Two stimuli, each shaped (batch=1, time=100, neurons=1):
#   const_current  -- 0.03 injected at every time step
#   single_current -- one 0.1 pulse at time step 0, zero afterwards
const_current = torch.full((1, 100, 1), 0.03)
single_current = torch.zeros(1, 100, 1)
single_current[:, 0, :] = 0.1

## Integrate and Fire neuron
This neuron has no leakage and simply integrates all the input it receives. It emits a spike whenever the membrane potential is above the spike threshold.

In [None]:
# Default IAF layer (no leak), driven by the constant current.
iaf_neuron = sl.IAF()
plot_evolution(neuron_model=iaf_neuron, input=const_current)

We can activate synaptic currents in this neuron model by setting `tau_syn`. All inputs are then integrated into its `i_syn` state, which is decayed and added to the membrane potential at every step. In the following plot, we only provide an input at the first time step.

In [None]:
# IAF with a synaptic current state (tau_syn), driven by a single input pulse.
iaf_neuron = sl.IAF(tau_syn=15.0)
plot_evolution(neuron_model=iaf_neuron, input=single_current)

## Leaky Integrate and Fire neuron
This neuron integrates the input, and its membrane potential decays (leaks) at every time step with time constant `tau_mem`.

In [None]:
# LIF layer with membrane time constant tau_mem, driven by the constant current.
lif_neuron = sl.LIF(tau_mem=40.0, norm_input=False)
plot_evolution(neuron_model=lif_neuron, input=const_current)

By default, no synaptic dynamics are used. We can enable them by setting `tau_syn`. Note that instead of a constant current, we now provide input only at the first time step.

In [None]:
# LIF layer with both membrane (tau_mem) and synaptic (tau_syn) time constants,
# driven by a single input pulse.
lif_neuron = sl.LIF(tau_mem=40.0, tau_syn=30.0, norm_input=False)
plot_evolution(neuron_model=lif_neuron, input=single_current)

## Leaky Integrator neuron
Same as LIF, but without a spiking activation function: the membrane potential integrates the input and leaks, but the neuron never emits spikes.

In [None]:
# Leaky integrator (no spiking), driven by the constant current.
exp_leak_neuron = sl.ExpLeak(tau_mem=60.0)
plot_evolution(neuron_model=exp_leak_neuron, input=const_current)

## Adaptive Leaky Integrate and Fire neuron
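This is a LIF neuron with an adaptive threshold, whose adaptation is controlled by `tau_adapt` and `adapt_scale`. Here we drive it with a current four times stronger than the constant current used above.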

In [None]:
# Adaptive-threshold LIF, driven by 4x the constant current (0.12 per step).
alif_neuron = sl.ALIF(tau_mem=40.0, tau_adapt=40.0, adapt_scale=20, norm_input=False)
plot_evolution(neuron_model=alif_neuron, input=const_current * 4)