In [None]:
import numpy as np

from scipy import signal

import nengo
from nengo.params import Default
from nengo.dists import Uniform
from nengo.solvers import LstsqL2, NoSolver

from nengolib import Lowpass, DoubleExp
from nengolib.signal import s, z, nrmse, LinearSystem

from train import dh_hyperopt, gbopt, d_lstsq, norms
from neuron_models import AdaptiveLIFT, WilsonEuler, DurstewitzNeuron, reset_neuron

import neuron

import os
import warnings

import pandas as pd

import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set(context='poster', style='white')

In [None]:
def go(fx, d_lif, d_alif, d_wilson, d_durstewitz, h_lif, h_alif, h_wilson, h_durstewitz,
        n_neurons=100, n_neurons_pre=100, t=10, max_rates=Uniform(20, 40), intercepts=Uniform(-1, 1), intercepts2=Uniform(-1, 0),
        stim_func=lambda t: np.sin(t), seed=0, dt=0.000025, h_tar=Lowpass(0.1), reg=1e-1, gain=None, bias=None, gb_iter=10, gain2=None, bias2=None, gb_iter2=10):
    """Simulate a two-layer network for each neuron model and return its probe data.

    The input node `u` (driven by `stim_func`) feeds the population `pre`,
    which drives the first-layer ensembles `nef`, `lif`, `alif`, `wilson` and
    `durstewitz` through LstsqL2 decoders and the synapse `h_tar`.  Each first
    layer then projects to its second layer (suffix `2`, intercepts
    `intercepts2`):
      * `nef -> nef2` lets nengo solve for decoders of `fx` (LstsqL2, `h_tar`);
      * the other models use the fixed decoders `d_*` (via NoSolver) and the
        synapses `h_*` supplied by the caller.
    Direct-mode ensembles provide reference signals: `tar` is `fx(u)` filtered
    by `h_tar`, and `tar2` is `tar` filtered by `h_tar` once more.

    Parameters
    ----------
    fx : callable
        Function computed on the `u -> tar` and `nef -> nef2` connections.
    d_lif, d_alif, d_wilson, d_durstewitz : array_like
        Decoders for each model's layer-1 -> layer-2 connection.
    h_lif, h_alif, h_wilson, h_durstewitz : nengo synapse
        Synapses for those same connections.
    t, dt : float
        Simulated duration and simulator time step (seconds).
    seed : int
        Seed for the network, every ensemble and connection, and the simulator.
    gain, bias : array_like or None
        Values attached to the `pre -> durstewitz` connection; replaced by the
        result of `gbopt` when `gb_iter` is truthy.
    gain2, bias2 : array_like or None
        Same for `durstewitz -> durstewitz2`, optimized when `gb_iter2` is truthy.
    The remaining parameters are passed through to the ensembles, connections,
    `gbopt` and the simulator.

    Returns
    -------
    dict
        `times`; `u`, `tar`, `tar2` (unfiltered probes); `nef`, `nef2`
        (decoded values filtered by `h_tar`); `lif`, `alif`, `wilson`,
        `durstewitz` and their `2` counterparts (unfiltered neuron output);
        plus the `gain`, `bias`, `gain2`, `bias2` that were used.
    """
    solver_nef = LstsqL2(reg=reg)
    solver_lif = NoSolver(d_lif)
    solver_alif = NoSolver(d_alif)
    solver_wilson = NoSolver(d_wilson)
    solver_durstewitz = NoSolver(d_durstewitz)

    with nengo.Network(seed=seed) as model:

        # Ensembles. Every population shares the same seed so that the
        # different neuron models are built with matching parameters.
        u = nengo.Node(stim_func)
        pre = nengo.Ensemble(n_neurons_pre, 1, max_rates=max_rates, seed=seed, label='pre')
        nef = nengo.Ensemble(n_neurons, 1, max_rates=max_rates, intercepts=intercepts, neuron_type=nengo.LIF(), seed=seed, label='nef')
        lif = nengo.Ensemble(n_neurons, 1, max_rates=max_rates, intercepts=intercepts, neuron_type=nengo.LIF(), seed=seed, label='lif')
        alif = nengo.Ensemble(n_neurons, 1, max_rates=max_rates, intercepts=intercepts, neuron_type=AdaptiveLIFT(tau_adapt=0.1, inc_adapt=0.1), seed=seed, label='alif')
        wilson = nengo.Ensemble(n_neurons, 1, max_rates=max_rates, intercepts=intercepts, neuron_type=WilsonEuler(), seed=seed, label='wilson')
        durstewitz = nengo.Ensemble(n_neurons, 1, max_rates=max_rates, intercepts=intercepts, neuron_type=DurstewitzNeuron(), seed=seed, label='durstewitz')
        nef2 = nengo.Ensemble(n_neurons, 1, max_rates=max_rates, intercepts=intercepts2, neuron_type=nengo.LIF(), seed=seed, label='nef2')
        lif2 = nengo.Ensemble(n_neurons, 1, max_rates=max_rates, intercepts=intercepts2, neuron_type=nengo.LIF(), seed=seed, label='lif2')
        alif2 = nengo.Ensemble(n_neurons, 1, max_rates=max_rates, intercepts=intercepts2, neuron_type=AdaptiveLIFT(tau_adapt=0.1, inc_adapt=0.1), seed=seed, label='alif2')
        wilson2 = nengo.Ensemble(n_neurons, 1, max_rates=max_rates, intercepts=intercepts2, neuron_type=WilsonEuler(), seed=seed, label='wilson2')
        durstewitz2 = nengo.Ensemble(n_neurons, 1, max_rates=max_rates, intercepts=intercepts2, neuron_type=DurstewitzNeuron(), seed=seed, label='durstewitz2')
        tar = nengo.Ensemble(1, 1, neuron_type=nengo.Direct())
        tar2 = nengo.Ensemble(1, 1, neuron_type=nengo.Direct())

        # Connections: reference path (u -> tar -> tar2) and input to layer 1.
        nengo.Connection(u, pre, synapse=None, seed=seed)
        nengo.Connection(u, tar, synapse=h_tar, seed=seed, function=fx)
        nengo.Connection(tar, tar2, synapse=h_tar, seed=seed)
        nengo.Connection(pre, nef, synapse=h_tar, solver=LstsqL2(reg=reg), seed=seed, label='pre_nef')
        nengo.Connection(pre, lif, synapse=h_tar, solver=LstsqL2(reg=reg), seed=seed, label='pre_lif')
        nengo.Connection(pre, alif, synapse=h_tar, solver=LstsqL2(reg=reg), seed=seed, label='pre_alif')
        nengo.Connection(pre, wilson, synapse=h_tar, solver=LstsqL2(reg=reg), seed=seed, label='pre_wilson')
        pre_durstewitz = nengo.Connection(pre, durstewitz, synapse=h_tar, solver=LstsqL2(reg=reg), seed=seed, label='pre_durstewitz')
        # Layer 1 -> layer 2: nengo solves for fx on the NEF path; the other
        # models use the caller-supplied decoders and synapses unchanged.
        nengo.Connection(nef, nef2, synapse=h_tar, solver=solver_nef, function=fx, seed=seed, label='nef_nef2')
        nengo.Connection(lif, lif2, synapse=h_lif, solver=solver_lif, seed=seed, label='lif_lif2')
        nengo.Connection(alif, alif2, synapse=h_alif, solver=solver_alif, seed=seed, label='alif_alif2')
        nengo.Connection(wilson, wilson2, synapse=h_wilson, solver=solver_wilson, seed=seed, label='wilson_wilson2')
        durstewitz_durstewitz2 = nengo.Connection(durstewitz, durstewitz2, synapse=h_durstewitz, solver=solver_durstewitz, seed=seed, label='durstewitz_durstewitz2')

        # Probes: NEF ensembles are probed as decoded values; the others as raw
        # neuron output so the caller can apply its own filters and decoders.
        p_u = nengo.Probe(u, synapse=None)
        p_tar = nengo.Probe(tar, synapse=None)
        p_tar2 = nengo.Probe(tar2, synapse=None)
        p_nef = nengo.Probe(nef, synapse=h_tar)
        p_lif = nengo.Probe(lif.neurons, synapse=None)
        p_alif = nengo.Probe(alif.neurons, synapse=None)
        p_wilson = nengo.Probe(wilson.neurons, synapse=None)
        p_durstewitz = nengo.Probe(durstewitz.neurons, synapse=None)
        p_nef2 = nengo.Probe(nef2, synapse=h_tar)
        p_lif2 = nengo.Probe(lif2.neurons, synapse=None)
        p_alif2 = nengo.Probe(alif2.neurons, synapse=None)
        p_wilson2 = nengo.Probe(wilson2.neurons, synapse=None)
        p_durstewitz2 = nengo.Probe(durstewitz2.neurons, synapse=None)

    if gb_iter:
        gain, bias = gbopt(pre_durstewitz, h_tar=h_tar, gb_iter=gb_iter, stim_func=stim_func, pt=False)
    if gb_iter2:
        gain2, bias2 = gbopt(durstewitz_durstewitz2, h_tar=h_tar, gb_iter=gb_iter2,
            delta_bias=5e-6, fx=fx, gain_pre=gain, bias_pre=bias, stim_func=stim_func, pt=False)
    # Gains/biases are attached as plain attributes on the connections, with
    # warnings suppressed around the assignment (presumably nengo's warning
    # about new attributes). NOTE(review): these are presumably read by the
    # DurstewitzNeuron builder in neuron_models — confirm there.
    if np.any(gain):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            pre_durstewitz.gain = gain
            pre_durstewitz.bias = bias
    if np.any(gain2):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            durstewitz_durstewitz2.gain = gain2
            durstewitz_durstewitz2.bias = bias2

    with nengo.Simulator(model, seed=seed, dt=dt) as sim:
        # Initialize NEURON before running; reset_neuron afterwards presumably
        # tears down NEURON state so the next simulation starts clean.
        neuron.h.init()
        sim.run(t)
        reset_neuron(sim)

    return dict(
        times=sim.trange(),
        u=sim.data[p_u],
        tar=sim.data[p_tar],
        tar2=sim.data[p_tar2],
        nef=sim.data[p_nef],
        lif=sim.data[p_lif],
        alif=sim.data[p_alif],
        wilson=sim.data[p_wilson],
        durstewitz=sim.data[p_durstewitz],
        nef2=sim.data[p_nef2],
        lif2=sim.data[p_lif2],
        alif2=sim.data[p_alif2],
        wilson2=sim.data[p_wilson2],
        durstewitz2=sim.data[p_durstewitz2],
        gain=gain,
        bias=bias,
        gain2=gain2,
        bias2=bias2)

In [None]:
# Population keys as they appear in the dict returned by `go`, in the row
# order of the NRMSE array returned by `trials`.
_MODELS = ('nef', 'lif', 'alif', 'wilson', 'durstewitz')
_LABELS = ('NEF', 'LIF', 'ALIF', 'Wilson', 'Durstewitz')
# Populations whose readout filters/decoders are fitted here ('nef' uses
# nengo's own decoders and is probed already decoded).
_SPIKING = ('lif', 'alif', 'wilson', 'durstewitz')


def _go_args(decoders, filters):
    """Return decoder/filter arguments in `go`'s (d_lif, ..., h_durstewitz) positional order."""
    return [decoders[key] for key in _SPIKING] + [filters[key] for key in _SPIKING]


def _plot_filters(h_tar, filters, decoders, suffix, n_neurons):
    """Plot the impulse responses of `h_tar` and each fitted filter, labelled with decoder sparsity."""
    times = np.arange(0, 1, 0.0001)
    fig, ax = plt.subplots(figsize=((12, 8)))
    ax.plot(times, h_tar.impulse(len(times), dt=0.0001), label="h_tar")
    for key in _SPIKING:
        label = "h_%s%s, nonzero d: %s/%s" % (key, suffix, np.count_nonzero(decoders[key]), n_neurons)
        ax.plot(times, filters[key].impulse(len(times), dt=0.0001), label=label)
    ax.set(xlabel='time (seconds)', ylabel='impulse response', ylim=((0, 10)))
    ax.legend(loc='upper right')
    plt.tight_layout()
    plt.show()


def _fit_layer(data, suffix, h_tar, fallback_filters, n_neurons, h_iter, order, dt, name_prefix=None):
    """Fit a readout filter and decoders for each spiking population of one layer.

    `suffix` selects the layer in `data` ('' for layer 1, '2' for layer 2).
    With `h_iter` truthy, `dh_hyperopt` fits decoders and filters against the
    `h_tar`-filtered target and the resulting filters are plotted. Otherwise
    `fallback_filters` are kept and only decoders are fitted with `d_lstsq`.
    `name_prefix`, when given, is passed to `dh_hyperopt` as
    `name=name_prefix + key + suffix`.

    Returns `(decoders, filters)`, two dicts keyed by population name.
    """
    decoders, filters = {}, {}
    if h_iter:
        for key in _SPIKING:
            extra = {} if name_prefix is None else {'name': name_prefix + key + suffix}
            decoders[key], filters[key] = dh_hyperopt(
                h_tar.filt(data['tar' + suffix], dt=dt), data[key + suffix],
                order=order, h_iter=h_iter, dt=dt, **extra)
        _plot_filters(h_tar, filters, decoders, suffix, n_neurons)
    else:
        for key in _SPIKING:
            filters[key] = fallback_filters[key]
            decoders[key] = d_lstsq(data['tar' + suffix], data[key + suffix], filters[key], h_tar, reg=1e-1, dt=dt)
    return decoders, filters


def _decode_layer(data, suffix, filters, decoders, h_tar, dt):
    """Decode every population of one layer and score it against the filtered target.

    Returns `(target, xhats, errors)`: `target` is `data['tar' + suffix]`
    filtered by `h_tar`. `xhats` maps each population to its estimate: the
    probed NEF value, or filtered spikes times decoders. `errors` maps each
    population to the nengolib `nrmse` of that estimate.
    """
    target = h_tar.filt(data['tar' + suffix], dt=dt)
    xhats = {'nef': data['nef' + suffix]}
    for key in _SPIKING:
        activities = filters[key].filt(data[key + suffix], dt=dt)
        xhats[key] = np.dot(activities, decoders[key])
    errors = {key: nrmse(xhats[key], target=target) for key in _MODELS}
    return target, xhats, errors


def _plot_decodes(times, target, xhats, errors, title, include_nef=True):
    """Plot the target and each population's estimate, labelled with its NRMSE."""
    fig, ax = plt.subplots(figsize=((12, 8)))
    ax.plot(times, target, linestyle="--", label='target')
    for key, label in zip(_MODELS, _LABELS):
        if key == 'nef' and not include_nef:
            continue
        ax.plot(times, xhats[key], label='%s, nrmse=%.3f' % (label, errors[key]))
    ax.set(xlabel='time (s)', ylabel=r'$\mathbf{x}$', title=title)
    plt.legend(loc='upper right')
    plt.show()


def trials(fx, n_neurons=100, t_train=30, t=10, h_tar=Lowpass(0.1), dt=0.000025, n_trials=10, gb_iter=10, gb_iter2=10, h_iter=100, order=1):
    """Train the two-layer networks built by `go`, then test them on white-noise inputs.

    1. Layer-1 stage: simulate `t_train` seconds of a `sin(2 t)` input and
       optimize the pre -> durstewitz gain/bias (`gb_iter`). Then fit each
       spiking population's layer-1 readout to the `h_tar`-filtered `fx`
       target (via `dh_hyperopt` when `h_iter`, otherwise `d_lstsq` with
       `h_tar`).
    2. Layer-2 stage: resimulate with those decoders on the ens -> ens2
       connections and optimize the durstewitz -> durstewitz2 gain/bias
       (`gb_iter2`). Then fit the layer-2 readouts against the filtered `tar2`.
    3. Test stage: run `n_trials` simulations of `t` seconds with
       `WhiteSignal` inputs (seed `1 + trial`) and record layer-2 NRMSEs.

    `order` is passed through to `dh_hyperopt`. Figures are shown for every
    stage.

    Returns
    -------
    np.ndarray
        Shape `(5, n_trials)`. Rows are NEF, LIF, ALIF, Wilson and Durstewitz
        layer-2 NRMSE per test trial.
    """
    def train_stim(time):
        return np.sin(2*time)

    initial_decoders = {key: np.zeros((n_neurons, 1)) for key in _SPIKING}
    initial_filters = {key: h_tar for key in _SPIKING}

    print('optimizing gain/bias and ens-ens2 filters/decoders')
    data = go(fx, *_go_args(initial_decoders, initial_filters),
        n_neurons=n_neurons, t=t_train, h_tar=h_tar, dt=dt, gb_iter=gb_iter, gb_iter2=0, stim_func=train_stim)
    # Capture the optimized layer-1 gains once instead of re-reading them from
    # whichever `data` dict happens to be current later on.
    gain, bias = data['gain'], data['bias']
    d1, h1 = _fit_layer(data, '', h_tar, initial_filters, n_neurons, h_iter, order, dt)
    # NOTE(review): the layer-1 'nef' probe decodes the identity of the input,
    # while `target` is the filtered fx(u), so its NRMSE in the layer-1 plots
    # is not a like-for-like comparison (it is hidden here) — confirm intent.
    target, xhats, errors = _decode_layer(data, '', h1, d1, h_tar, dt)
    _plot_decodes(data['times'], target, xhats, errors, "train", include_nef=False)

    print('optimizing gain/bias2 and readout filters/decoders')
    data = go(fx, *_go_args(d1, h1),
        n_neurons=n_neurons, t=t_train, h_tar=h_tar, dt=dt, gain=gain, bias=bias, gb_iter=0, gb_iter2=gb_iter2, stim_func=train_stim)
    gain2, bias2 = data['gain2'], data['bias2']
    # Without hyperopt, layer 2 reuses the layer-1 filters.
    d2, h2 = _fit_layer(data, '2', h_tar, h1, n_neurons, h_iter, order, dt, name_prefix='function_')
    target, xhats, errors = _decode_layer(data, '', h1, d1, h_tar, dt)
    _plot_decodes(data['times'], target, xhats, errors, "xhat1 train2")
    target2, xhats2, errors2 = _decode_layer(data, '2', h2, d2, h_tar, dt)
    _plot_decodes(data['times'], target2, xhats2, errors2, "xhat2 train2")

    print('running experimental trials')
    nrmses = np.zeros((len(_MODELS), n_trials))
    for trial in range(n_trials):
        print('trial %s' %trial)
        stim_func = nengo.processes.WhiteSignal(period=t, high=1, rms=0.5, seed=1+trial)
        data = go(fx, *_go_args(d1, h1),
            n_neurons=n_neurons, t=t, h_tar=h_tar, dt=dt, gain=gain, bias=bias, gb_iter=0, gain2=gain2, bias2=bias2, gb_iter2=0, stim_func=stim_func)

        target, xhats, errors = _decode_layer(data, '', h1, d1, h_tar, dt)
        _plot_decodes(data['times'], target, xhats, errors, "xhat1 test")

        target2, xhats2, errors2 = _decode_layer(data, '2', h2, d2, h_tar, dt)
        nrmses[:, trial] = [errors2[key] for key in _MODELS]
        _plot_decodes(data['times'], target2, xhats2, errors2, "test %s" % trial)

    if n_trials > 1:
        nt_names = ['LIF\n(static)', 'LIF\n(temporal)', 'ALIF', 'Wilson', 'Durstewitz']
        fig, ax = plt.subplots(1, 1, figsize=(14, 6))
        sns.barplot(data=nrmses.T, ax=ax)
        ax.set(ylabel='NRMSE')
        plt.xticks(np.arange(len(nt_names)), tuple(nt_names), rotation=0)
        plt.show()

    return nrmses

In [None]:
def fx(x):
    """Target nonlinearity handed to `trials`: the elementwise square of `x`."""
    return np.square(x)

# Quick check: 10 neurons per population, 30 s of training, 3 test trials.
nrmses = trials(
    fx,
    n_neurons=10,
    t_train=30,
    order=2,
    n_trials=3,
    gb_iter=20,
    gb_iter2=20,
    dt=0.000025,
)

In [None]:
def fx(x):
    """Target nonlinearity handed to `trials`: the elementwise square of `x`."""
    return np.square(x)

# 100 neurons per population, 10 s of training, 3 test trials.
nrmses = trials(
    fx,
    n_neurons=100,
    t_train=10,
    order=2,
    n_trials=3,
    gb_iter=20,
    gb_iter2=20,
    dt=0.000025,
)

In [None]:
def fx(x):
    """Target nonlinearity handed to `trials`: the elementwise square of `x`."""
    return np.square(x)

# Create the output directory up front so the long run below cannot finish
# only to fail on a missing `data/` folder when saving.
os.makedirs('data', exist_ok=True)
nrmses = trials(fx, order=1)
np.savez('data/nrmses_function_order1.npz', nrmses=nrmses)

In [None]:
def fx(x):
    """Target nonlinearity handed to `trials`: the elementwise square of `x`."""
    return np.square(x)

# Create the output directory up front so the long run below cannot finish
# only to fail on a missing `data/` folder when saving.
os.makedirs('data', exist_ok=True)
nrmses2 = trials(fx, order=2)
np.savez('data/nrmses_function_order2.npz', nrmses=nrmses2)