# Notebook description

This notebook implements and studies the quadratic integrate-and-fire (QIF) neuron model.

In [None]:
import numpy as np
from numpy.random import lognormal
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
class QIF_heun(QIF):
    """A spiking quadratic integrate-and-fire neuron stepped with Heun's method.

    Each step takes a forward-Euler predictor step, re-evaluates the
    derivative at the predicted voltage, and advances the voltage by the
    average of the two increments. The model parameters ``tau``, ``km``,
    ``kbg``, ``ibg`` and ``kin`` are read from ``self``; presumably they are
    set up by the parent ``QIF`` class (not visible here).
    """

    def step_math(self, dt, J, spiked, voltage):
        """Advance all neurons by one time step using Heun's method.

        The voltage obeys
        ``dv/dt = (-v + km * v**2 / 2 + kbg * ibg + kin * J) / tau``.
        Neurons whose new voltage exceeds ``QIFRate.threshold`` spike and
        have their voltage reset to 0.

        Parameters
        ----------
        dt : float
            simulation time step
        J : array of floats
            input current to neurons
        spiked : array of 0 or 1/dt's
            written in place: 1/dt for neurons that spiked during this
            time step, 0 otherwise
        voltage : array of floats
            voltage of neurons; updated in place
        """
        def euler_increment(u):
            # dt * dv/dt evaluated at voltage u
            return dt / self.tau * (
                -u + self.km*u**2/2 + self.kbg*self.ibg + self.kin*J)

        # Predictor (Euler) step, then corrector evaluated at the prediction.
        predictor = euler_increment(voltage)
        corrector = euler_increment(voltage + predictor)
        new_voltage = voltage + (predictor + corrector) / 2.

        above_threshold = new_voltage > QIFRate.threshold
        spiked[:] = np.where(above_threshold, 1.0 / dt, 0.0)

        # Reset every neuron that spiked, then write the state back in place.
        new_voltage[spiked > 0] = 0.
        voltage[:] = new_voltage

# Register the QIF_heun model with Nengo
# You should not have to modify this
@Builder.register(QIF_heun)
def build_qif_heun(model, qif_heun, neurons):
    """Add the signals and operator for a ``QIF_heun`` population to a model.

    Allocates a zero-initialised voltage signal with ``neurons.size_in``
    entries, stores it under ``model.sig[neurons]['voltage']``, and adds a
    ``SimNeurons`` operator. The operator reads the population's ``'in'``
    signal, writes its ``'out'`` signal, and carries the voltage signal as
    its state.
    """
    neuron_sigs = model.sig[neurons]
    voltage_sig = Signal(np.zeros(neurons.size_in),
                         name="%s.voltage" % neurons)
    neuron_sigs['voltage'] = voltage_sig

    operator = SimNeurons(neurons=qif_heun,
                          J=neuron_sigs['in'],
                          output=neuron_sigs['out'],
                          states=[voltage_sig])
    model.add_op(operator)

In [None]:
# Simulate a single QIF neuron driven by a constant input of 1 and plot its
# voltage trace for several simulation time steps, to see how the
# integration accuracy depends on dt.
# NOTE(review): ptau_list, tau, a_list, b_list, c_list and ibg are notebook
# globals defined in another cell.
idx = 2  # which entry of the parameter lists to simulate
dts = [.01, .001, .0001]
T = .2
# Seed the network once. An unseeded network draws new ensemble gain/bias on
# every nengo.Simulator build, so each dt panel would show a different neuron.
net = nengo.Network(seed=np.random.randint(2**31 - 1))
with net:
    stim = nengo.Node(1.)
    neuron_type = QIF(
        ptau_list[idx] * tau, a_list[idx], b_list[idx], c_list[idx], ibg)
    ens = nengo.Ensemble(1, 1, neuron_type=neuron_type,
                         encoders=np.ones((1, 1)))
    nengo.Connection(stim, ens, synapse=0)
    probe_v = nengo.Probe(ens.neurons, 'voltage')

fig, axs = plt.subplots(ncols=len(dts), figsize=(15, 4), squeeze=False)
axs = axs[0]  # squeeze=False keeps a 2-D array even when there is one dt
# Iterate axes and dts together rather than reusing `idx`, which above
# selects the neuron parameters.
for ax, dt in zip(axs, dts):
    sim = nengo.Simulator(net, dt=dt)
    sim.run(T, progress_bar=False)
    t = sim.trange()
    ax.plot(t, sim.data[probe_v])
    ax.set_title('dt = %g' % dt, fontsize=14)
    ax.set_ylabel('voltage', fontsize=14)
    ax.set_xlabel(r'$t$', fontsize=20)

In [None]:
def _rate_from_spike_times(spike_times, duration):
    """Estimate one neuron's firing rate (Hz) from its spike times.

    No spikes gives 0, a single spike gives ``1 / duration``, and two or
    more spikes give the inverse of the mean inter-spike interval.
    """
    if len(spike_times) == 0:
        return 0.
    if len(spike_times) == 1:
        return 1. / duration
    return 1. / np.mean(np.diff(spike_times))


def check_tuning_curve_accuracy(neuron_model, seed=None):
    """Takes in a neuron model (e.g. QIF) and checks the accuracy.

    Picks 5 random parameter sets and builds a 1-D ensemble of those neurons
    with random +/-1 encoders. For 30 inputs ``x`` in [-1, 1], it simulates
    1 s at ``dt = 1e-4`` and estimates each neuron's rate ``a_hat`` from its
    spikes. The left panel plots ``a_hat`` against the rates ``a`` from
    ``tuning_curves`` for the same ensemble; the right panel plots
    ``a - a_hat``.

    Parameters
    ----------
    neuron_model : callable
        Neuron type, called as ``neuron_model(tau, a, b, c, ibg)`` with the
        per-neuron parameter arrays.
    seed : int or None, optional
        Network seed shared by every simulator built in this call, so all
        simulations and the reference tuning curves use the same ensemble
        parameters. If None, one is drawn from ``np.random``.
    """
    xs = np.linspace(-1., 1., 30)
    N = 5
    T = 1.
    dt = .0001
    # NOTE(review): num_neurons, ptau_list, tau, a_list, b_list, c_list and
    # ibg are notebook globals defined in another cell.
    param_idx = np.random.choice(np.arange(num_neurons), N, replace=False)
    if seed is None:
        # An unseeded network gets new gain/bias on every Simulator build,
        # which would make each x (and the reference curves) use different
        # neurons.
        seed = np.random.randint(2**31 - 1)

    net = nengo.Network(seed=seed)
    with net:
        stim = nengo.Node(0.)
        encoders = np.random.choice([-1, 1], N, replace=True).reshape((N, 1))
        neuron_type = neuron_model(
            ptau_list[param_idx] * tau, a_list[param_idx],
            b_list[param_idx], c_list[param_idx], ibg)
        ens = nengo.Ensemble(N, 1, encoders=encoders, neuron_type=neuron_type)
        nengo.Connection(stim, ens, synapse=0.)
        probe_s = nengo.Probe(ens.neurons, 'spikes')

    # Empirical rates: rebuild and rerun the network for every input value.
    a_hat = np.zeros((len(xs), N))
    for x_idx, x in enumerate(xs):
        stim.output = x
        sim = nengo.Simulator(net, dt=dt)
        sim.run(T, progress_bar=False)
        t = sim.trange()
        spikes = sim.data[probe_s]
        for n in range(N):
            spike_times = t[np.nonzero(spikes[:, n])[0]]
            a_hat[x_idx, n] = _rate_from_spike_times(spike_times, T)

    # Reference rates for the same (identically seeded) ensemble.
    sim = nengo.Simulator(net)
    _, a = tuning_curves(ens, sim, inputs=xs.reshape((-1, 1)))

    a_error = a - a_hat

    # Explicit per-neuron colours keep each neuron's curve, samples and
    # error in the same colour (Axes.set_color_cycle was removed in
    # matplotlib 2.0).
    colors = ['r', 'g', 'b', 'c', 'm']
    _, axs = plt.subplots(ncols=2, figsize=(15, 5))
    for n in range(N):
        axs[0].plot(xs, a[:, n], color=colors[n % len(colors)])
    for n in range(N):
        axs[0].plot(xs, a_hat[:, n], 'o', ms=6,
                    color=colors[n % len(colors)])
    axs[0].set_xlabel(r'$x$', fontsize=20)
    axs[0].set_ylabel(r'$a,\ \hat{a}$', fontsize=20)
    for n in range(N):
        axs[1].plot(xs, a_error[:, n], color=colors[n % len(colors)])
    axs[1].set_xlabel(r'$x$', fontsize=20)
    axs[1].set_ylabel(r'$a-\hat{a}$', fontsize=20)