In [None]:
import os
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns

import nengo
from nengo.params import Default
from nengo.dists import Uniform
from nengo.solvers import LstsqL2, NoSolver

from nengolib import Lowpass, DoubleExp
from nengolib.signal import s, z, nrmse, LinearSystem

import neuron

from train import dh_hyperopt, gbopt, d_lstsq
from neuron_models import AdaptiveLIFT, WilsonEuler, DurstewitzNeuron, reset_neuron
sns.set(context='poster', style='white')

In [None]:
def go(n_neurons=100, n_neurons_pre=100, t=10, max_rates=Uniform(20, 40), intercepts=Uniform(-1, 1),
       stim_func=lambda t: np.sin(t), seed=1, dt=0.000025, h_tar=Lowpass(0.1), reg=1e-1, gain=None, bias=None, gb_iter=10):
    """Simulate one stimulus feeding several neuron-type populations side by side.

    A ``pre`` ensemble encodes ``stim_func`` and projects (LstsqL2 decoders,
    synapse ``h_tar``) onto five ensembles that share ``max_rates``,
    ``intercepts`` and ``seed``:

    - ``nef``: LIF neurons whose decoded value is probed (filtered by ``h_tar``);
    - ``lif``, ``alif``, ``wilson``, ``durstewitz``: nengo.LIF, AdaptiveLIFT,
      WilsonEuler and DurstewitzNeuron ensembles whose neuron outputs are
      probed unfiltered.

    A Direct-mode ``tar`` ensemble receives the stimulus through ``h_tar``.

    Parameters
    ----------
    n_neurons, n_neurons_pre : int
        Sizes of the comparison ensembles and of ``pre``.
    t : float
        Simulated time passed to ``sim.run``.
    max_rates, intercepts : nengo distributions
        Tuning parameters (``intercepts`` is not applied to ``pre``).
    stim_func : callable or nengo Process
        Output of the input node ``u``.
    seed : int
        Seed for the network, every ensemble/connection and the simulator.
    dt : float
        Simulator time step.
    h_tar : nengolib filter
        Synapse on every pre->ensemble connection, on u->tar and on the nef probe.
    reg : float
        LstsqL2 regularization for the pre->ensemble decoders.
    gain, bias : array-like or None
        Values attached to the pre->durstewitz connection. They are overwritten
        by ``gbopt`` when ``gb_iter`` is non-zero.
    gb_iter : int
        If non-zero, number of ``gbopt`` iterations used to compute ``gain``/``bias``.

    Returns
    -------
    dict
        ``times``; ``u`` (raw stimulus); ``tar`` (stimulus through ``h_tar``);
        ``nef`` (decoded LIF estimate, probed with ``h_tar``); the unfiltered
        neuron outputs ``lif``/``alif``/``wilson``/``durstewitz``; and the
        ``gain``/``bias`` actually used (None when neither supplied nor optimized).
    """
    with nengo.Network(seed=seed) as model:

        # Ensembles
        u = nengo.Node(stim_func)
        pre = nengo.Ensemble(n_neurons_pre, 1, max_rates=max_rates, seed=seed, label='pre')
        # Label was 'lif', duplicating the next ensemble's label.
        nef = nengo.Ensemble(n_neurons, 1, max_rates=max_rates, intercepts=intercepts, neuron_type=nengo.LIF(), seed=seed, label='nef')
        lif = nengo.Ensemble(n_neurons, 1, max_rates=max_rates, intercepts=intercepts, neuron_type=nengo.LIF(), seed=seed, label='lif')
        alif = nengo.Ensemble(n_neurons, 1, max_rates=max_rates, intercepts=intercepts, neuron_type=AdaptiveLIFT(tau_adapt=0.1, inc_adapt=0.1), seed=seed, label='alif')
        wilson = nengo.Ensemble(n_neurons, 1, max_rates=max_rates, intercepts=intercepts, neuron_type=WilsonEuler(), seed=seed, label='wilson')
        durstewitz = nengo.Ensemble(n_neurons, 1, max_rates=max_rates, intercepts=intercepts, neuron_type=DurstewitzNeuron(), seed=seed, label='durstewitz')
        tar = nengo.Ensemble(1, 1, neuron_type=nengo.Direct())

        # Connections
        nengo.Connection(u, pre, synapse=None, seed=seed)
        nengo.Connection(u, tar, synapse=h_tar, seed=seed)
        nengo.Connection(pre, nef, synapse=h_tar, solver=LstsqL2(reg=reg), seed=seed, label='pre_nef')
        nengo.Connection(pre, lif, synapse=h_tar, solver=LstsqL2(reg=reg), seed=seed, label='pre_lif')
        nengo.Connection(pre, alif, synapse=h_tar, solver=LstsqL2(reg=reg), seed=seed, label='pre_alif')
        nengo.Connection(pre, wilson, synapse=h_tar, solver=LstsqL2(reg=reg), seed=seed, label='pre_wilson')
        pre_durstewitz = nengo.Connection(pre, durstewitz, synapse=h_tar, solver=LstsqL2(reg=reg), seed=seed, label='pre_durstewitz')

        # Probes
        p_u = nengo.Probe(u, synapse=None)
        p_tar = nengo.Probe(tar, synapse=None)
        p_nef = nengo.Probe(nef, synapse=h_tar)
        p_lif = nengo.Probe(lif.neurons, synapse=None)
        p_alif = nengo.Probe(alif.neurons, synapse=None)
        p_wilson = nengo.Probe(wilson.neurons, synapse=None)
        p_durstewitz = nengo.Probe(durstewitz.neurons, synapse=None)

    if gb_iter:
        gain, bias = gbopt(pre_durstewitz, gb_iter=gb_iter, h_tar=h_tar, stim_func=stim_func, pt=False)
    # Explicit presence check. np.any(gain) would also silently skip a supplied
    # all-zero (or empty) gain array.
    if gain is not None:
        with warnings.catch_warnings():
            # Presumably silences nengo's warning about adding non-param
            # attributes to a Connection. gain/bias are assumed to be read by
            # the DurstewitzNeuron build code (TODO confirm in neuron_models).
            warnings.simplefilter("ignore")
            pre_durstewitz.gain = gain
            pre_durstewitz.bias = bias

    with nengo.Simulator(model, seed=seed, dt=dt) as sim:
        neuron.h.init()
        try:
            sim.run(t)
        finally:
            # Always reset, even if the run raises or is interrupted, so the
            # NEURON-side state that reset_neuron cleans up does not leak into
            # the next call.
            reset_neuron(sim)

    return dict(
        times=sim.trange(),
        u=sim.data[p_u],
        tar=sim.data[p_tar],
        nef=sim.data[p_nef],
        lif=sim.data[p_lif],
        alif=sim.data[p_alif],
        wilson=sim.data[p_wilson],
        durstewitz=sim.data[p_durstewitz],
        gain=gain,
        bias=bias)

In [None]:
def trials(n_neurons=100, t=10, h_tar=Lowpass(0.1), dt=0.000025, n_trials=10, gb_iter=10, h_iter=100, order=1, stim_func = lambda t: np.sin(t)):
    """Fit readouts on one training run, then score them on white-noise test runs.

    Training: ``go`` is simulated once with ``stim_func``. For each of the LIF,
    ALIF, Wilson and Durstewitz populations, a readout filter and decoders are
    obtained in one of two ways:

    - if ``h_iter`` is truthy, jointly with ``dh_hyperopt`` (filter order
      ``order``), and the fitted impulse responses are plotted;
    - otherwise with ``d_lstsq``, using ``h_tar`` as the readout filter.

    Testing: each trial re-simulates ``go`` on a ``nengo.processes.WhiteSignal``
    (period ``t``, high 1, rms 0.5, seeded by the trial index). The trial reuses
    the training gain/bias and does no further gain/bias optimization. Every
    estimate is scored against the ``h_tar``-filtered target and plotted. When
    ``n_trials > 1``, a bar plot summarizes the scores.

    Returns
    -------
    nrmses : ndarray, shape (5, n_trials)
        Row 0 is the ``nef`` decoded LIF estimate. Rows 1-4 are LIF, ALIF,
        Wilson and Durstewitz with their fitted readouts.
    """
    names = ('lif', 'alif', 'wilson', 'durstewitz')  # go() keys; nrmses rows 1..4
    plot_labels = ('NEF', 'LIF', 'ALIF', 'Wilson', 'Durstewitz')  # nrmses rows 0..4
    bar_labels = ('LIF\n(static)', 'LIF\n(temporal)', 'ALIF', 'Wilson', 'Durstewitz')

    train = go(n_neurons=n_neurons, t=t, h_tar=h_tar, dt=dt, gb_iter=gb_iter, stim_func=stim_func)
    # Every test network reuses the gain/bias found (or supplied) for training.
    gain, bias = train['gain'], train['bias']

    filters, decoders = {}, {}
    if h_iter:
        print('optimizing filters and decoders')
        target_train = h_tar.filt(train['tar'], dt=dt)  # identical for every population
        for name in names:
            decoders[name], filters[name] = dh_hyperopt(
                target_train, train[name], order=order, h_iter=h_iter, dt=dt, name='feedforward_' + name)
        impulse_dt = 0.0001
        times = np.arange(0, 1, impulse_dt)
        fig, ax = plt.subplots(figsize=(12, 8))
        ax.plot(times, h_tar.impulse(len(times), dt=impulse_dt), label="h_tar")
        for name in names:
            ax.plot(times, filters[name].impulse(len(times), dt=impulse_dt), label="h_" + name)
        ax.set(xlabel='time (seconds)', ylabel='impulse response', ylim=(0, 10))
        ax.legend(loc='upper right')
        plt.tight_layout()
        plt.show()
    else:
        for name in names:
            filters[name] = h_tar
            decoders[name] = d_lstsq(train['tar'], train[name], h_tar, h_tar, reg=1e-1, dt=dt)

    print('running experimental trials')
    nrmses = np.zeros((len(plot_labels), n_trials))
    for trial in range(n_trials):
        print('trial %s' % trial)
        # The test stimulus is separate from ``stim_func``, which only drives training.
        test_stim = nengo.processes.WhiteSignal(period=t, high=1, rms=0.5, seed=trial)
        test_data = go(n_neurons=n_neurons, t=t, h_tar=h_tar, dt=dt, gain=gain, bias=bias, gb_iter=0, stim_func=test_stim)

        target = h_tar.filt(test_data['tar'], dt=dt)
        xhats = [test_data['nef']]
        for name in names:
            activities = filters[name].filt(test_data[name], dt=dt)
            xhats.append(np.dot(activities, decoders[name]))
        for row, xhat in enumerate(xhats):
            nrmses[row, trial] = nrmse(xhat, target=target)

        fig, ax = plt.subplots(figsize=(12, 8))
        ax.plot(test_data['times'], target, linestyle="--", label='target')
        for row, (label, xhat) in enumerate(zip(plot_labels, xhats)):
            ax.plot(test_data['times'], xhat, label='%s, nrmse=%.3f' % (label, nrmses[row, trial]))
        ax.set(xlabel='time (s)', ylabel=r'$\mathbf{x}$', title="test %s" % trial)
        ax.legend(loc='upper right')
        plt.show()

    if n_trials > 1:
        fig, ax = plt.subplots(1, 1, figsize=(14, 6))
        sns.barplot(data=nrmses.T, ax=ax)
        ax.set(ylabel='NRMSE')
        ax.set_xticks(np.arange(len(bar_labels)))
        ax.set_xticklabels(bar_labels, rotation=0)
        plt.show()

    return nrmses

In [None]:
# Train/test with first-order readout filters and persist the per-trial NRMSEs.
nrmses = trials(order=1)
os.makedirs('data', exist_ok=True)  # np.savez does not create missing directories
np.savez('data/nrmses_feedforward_order1.npz', nrmses=nrmses)

In [None]:
# Central 95% range (2.5th-97.5th percentile) of the per-trial NRMSEs per row.
# This is the same computation as seaborn's internal ``utils.ci``, done with the
# public numpy API. It is a spread of trial values, not a CI of the mean.
# The context manager closes the .npz archive that np.load keeps open.
with np.load('data/nrmses_feedforward_order1.npz') as npz:
    nrmses = npz['nrmses']
labels = ('LIF (static)', 'LIF (temporal)', 'ALIF', 'Wilson', 'Durstewitz')
assert nrmses.shape[0] == len(labels), "unexpected number of NRMSE rows"
for label, row in zip(labels, nrmses):
    print(label, np.nanpercentile(row, (2.5, 97.5)))

In [None]:
# Train/test with second-order readout filters and persist the per-trial NRMSEs.
nrmses2 = trials(order=2)
os.makedirs('data', exist_ok=True)  # np.savez does not create missing directories
np.savez('data/nrmses_feedforward_order2.npz', nrmses=nrmses2)

In [None]:
# Central 95% range (2.5th-97.5th percentile) of the per-trial NRMSEs per row.
# This is the same computation as seaborn's internal ``utils.ci``, done with the
# public numpy API. It is a spread of trial values, not a CI of the mean.
# The context manager closes the .npz archive that np.load keeps open.
with np.load('data/nrmses_feedforward_order2.npz') as npz:
    nrmses2 = npz['nrmses']
labels = ('LIF (static)', 'LIF (temporal)', 'ALIF', 'Wilson', 'Durstewitz')
assert nrmses2.shape[0] == len(labels), "unexpected number of NRMSE rows"
for label, row in zip(labels, nrmses2):
    print(label, np.nanpercentile(row, (2.5, 97.5)))

In [None]:
def db_train(n_neurons=5, t=10, h_tar=Lowpass(0.1), dt=0.001, gb_iter=0, h_iter=100, order=1, seed=0, stim_func=lambda t: np.sin(t)):
    """Overlay binned tuning curves of Durstewitz neurons and their LIF counterparts.

    Runs ``go`` once and fits a readout filter per population with
    ``dh_hyperopt``. Each population's neuron output is filtered with its
    fitted filter and binned against the stimulus ``u`` (20 bins, via
    ``bin_activities_values_single``).

    For every neuron index, the plot shows:

    - the Durstewitz bin means as a solid line with a +/-1 std band;
    - the LIF bin means as a dashed line in the same color.

    Parameters
    ----------
    n_neurons : int
        Ensemble size; one curve pair is drawn per neuron.
    t, dt, seed, stim_func, h_tar, gb_iter
        Forwarded to ``go``.
    h_iter, order
        Forwarded to ``dh_hyperopt``.
    """
    # Imported before the (slow) simulation so a missing helper fails early.
    from utils import bin_activities_values_single

    data = go(n_neurons=n_neurons, t=t, h_tar=h_tar, dt=dt, gb_iter=gb_iter, seed=seed, stim_func=stim_func)
    target = h_tar.filt(data['tar'], dt=dt)

    # Only the fitted filters are used here; the decoders are discarded.
    _, h_lif = dh_hyperopt(target, data['lif'], order=order, h_iter=h_iter, dt=dt, name='feedforward_lif')
    _, h_durstewitz = dh_hyperopt(target, data['durstewitz'], order=order, h_iter=h_iter, dt=dt, name='feedforward_durstewitz')
    a_lif = h_lif.filt(data['lif'], dt=dt)
    a_durstewitz = h_durstewitz.filt(data['durstewitz'], dt=dt)

    # Size the palette to the ensemble. The default palette has a fixed length,
    # so indexing it with n >= that length raised IndexError. n_colors cycles.
    colors = sns.color_palette(n_colors=n_neurons)
    x = data['u'][:, 0]
    fig, ax = plt.subplots(figsize=(8, 8))
    for n in range(n_neurons):
        lif_bins, lif_means, _ = bin_activities_values_single(x, a_lif[:, n], bins=20)
        dur_bins, dur_means, dur_stds = bin_activities_values_single(x, a_durstewitz[:, n], bins=20)
        ax.plot(dur_bins, dur_means, c=colors[n])
        ax.fill_between(dur_bins, dur_means + dur_stds, dur_means - dur_stds, alpha=0.25, color=colors[n])
        ax.plot(lif_bins, lif_means, linestyle='--', c=colors[n])

    # Raw string: '\m' in a plain literal is an invalid escape sequence.
    # The title distinguishes runs that differ only in gb_iter.
    ax.set(xlim=(-1, 1), ylim=(0, 40), xlabel=r'$\mathbf{x}$', ylabel='a (Hz)',
           title='gb_iter=%s' % gb_iter)
    plt.tight_layout()
    plt.show()

In [None]:
# Tuning curves without (gb_iter=0) and with (gb_iter=10) gain/bias optimization.
for n_gb in (0, 10):
    db_train(t=100, seed=1, gb_iter=n_gb)