In [1]:
import numpy as np

import nengo
from nengo.params import Default
from nengo.dists import Uniform
from nengo.solvers import LstsqL2, NoSolver

from nengolib import Lowpass, DoubleExp
from nengolib.signal import s, z, nrmse, LinearSystem

from train import dh_hyperopt, dh_lstsq, gbopt, d_lstsq
from neuron_models import AdaptiveLIFT, WilsonEuler, DurstewitzNeuron, reset_neuron

import neuron

import warnings

import pandas as pd

import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set(context='poster', style='white')

	1 


  % nengo_class)


In [2]:
def norms(t, dt=0.001, stim_func=lambda t: np.cos(t)):
    """Return the peak magnitudes of a stimulus and of its running integral.

    Simulates a network with a single stimulus ``Node`` for ``t`` seconds. The
    node is probed twice: once unfiltered, and once through the integrator
    synapse ``1/s``.

    Parameters
    ----------
    t : float
        Simulation length in seconds.
    dt : float
        Simulator time step.
    stim_func : callable or nengo Process
        Output of the stimulus node.

    Returns
    -------
    (float, float)
        ``(max|stimulus|, max|integral of stimulus|)`` over the simulated run.
    """
    net = nengo.Network()
    with net:
        source = nengo.Node(stim_func)
        probe_raw = nengo.Probe(source, synapse=None)
        # 1/s is the nengolib transfer function of a pure integrator
        probe_integrated = nengo.Probe(source, synapse=1/s)

    with nengo.Simulator(net, progress_bar=False, dt=dt) as simulator:
        simulator.run(t, progress_bar=False)

    def peak(probe):
        return np.max(np.abs(simulator.data[probe]))

    return peak(probe_raw), peak(probe_integrated)

In [3]:
def go(d_lif, d_alif, d_wilson, d_durstewitz, h_lif, h_alif, h_wilson, h_durstewitz,
       n_neurons=100, n_neurons_pre=100, t=4*np.pi, max_rates=Uniform(20, 40), intercepts=Uniform(-1, 1),
       stim_func=lambda t: np.sin(t), seed=0, dt=0.001, h_tar=Lowpass(0.1), reg=1e-1, gbd=None, gb_iter=0, supv=0):
    """Build and simulate one integrator network for five neuron populations.

    An input ``u`` is built from ``stim_func``, normalized by the peak of its
    integral, and sign-flipped halfway through the run. It drives the following
    populations through a feedforward connection with ``synapse=h_tar`` and
    ``transform=h_tar.tau``:

    * an NEF reference LIF ensemble with an ``LstsqL2`` recurrent connection;
    * LIF, adaptive LIF, Wilson and Durstewitz ensembles.

    The connections that close the loop on the last four populations depend on
    ``supv``:

    * ``supv`` truthy: each population is driven by ``pre_x``, an ensemble
      representing the ideal integral of ``u``.
    * ``supv`` falsy: each population gets a recurrent connection that uses the
      supplied decoders ``d_*`` (via ``NoSolver``) and filters ``h_*``.

    Parameters
    ----------
    d_lif, d_alif, d_wilson, d_durstewitz : array
        Recurrent decoders, used only when ``supv`` is falsy.
    h_lif, h_alif, h_wilson, h_durstewitz : synapse
        Recurrent filters, used only when ``supv`` is falsy.
    n_neurons, n_neurons_pre : int
        Sizes of the tested populations and of the ``pre_*`` ensembles.
    t : float
        Simulation length in seconds.
    max_rates, intercepts : nengo Distribution
        Tuning-curve parameters for the ensembles.
    stim_func : callable or nengo Process
        Raw input signal.
    seed : int
        Seed for the network, all ensembles and connections, and the simulator.
    dt : float
        Simulator time step.
    h_tar : synapse
        Target synapse. It must expose ``.tau``, which is used as a transform.
    reg : float
        ``LstsqL2`` regularization for the solver-based connections.
    gbd : dict or None
        Precomputed gain/bias values, indexed as
        ``gbd['pre_u_durstewitz']['gain' | 'bias']``. ``None`` means none.
    gb_iter : int
        If nonzero, ``gbopt`` is run on ``pre_u_durstewitz`` and its result
        replaces ``gbd``.
    supv : int or bool
        Supervised (``pre_x`` driven) versus recurrent mode.

    Returns
    -------
    dict
        ``times``, input ``u``, integrated target ``tar``, filtered NEF
        decode ``nef``, unfiltered neuron outputs of the ``lif``, ``alif``,
        ``wilson`` and ``durstewitz`` populations, and the ``gbd`` actually
        used.
    """
    # Avoid a shared mutable default; the dict is returned to the caller
    if gbd is None:
        gbd = {}

    solver_nef = LstsqL2(reg=reg)
    solver_lif = NoSolver(d_lif)
    solver_alif = NoSolver(d_alif)
    solver_wilson = NoSolver(d_wilson)
    solver_durstewitz = NoSolver(d_durstewitz)

    norm_stim, norm_int = norms(t, dt=dt, stim_func=stim_func)

    with nengo.Network(seed=seed) as model:

        model.T = t

        def flip(time, x):
            # Scale so the integrated input peaks at 1, then reverse the sign
            # for the second half of the run so the integral returns toward 0
            if time < model.T/2:
                return x/norm_int
            return -1.0*x/norm_int

        # Ensembles
        u_raw = nengo.Node(stim_func)
        u = nengo.Node(output=flip, size_in=1)
        pre_u = nengo.Ensemble(n_neurons_pre, 1, max_rates=max_rates, seed=seed, label='pre_u')
        pre_x = nengo.Ensemble(n_neurons_pre, 1, max_rates=max_rates, seed=seed, label='pre_x')
        nef = nengo.Ensemble(n_neurons, 1, max_rates=max_rates, intercepts=intercepts, neuron_type=nengo.LIF(), seed=seed, label='nef')
        lif = nengo.Ensemble(n_neurons, 1, max_rates=max_rates, intercepts=intercepts, neuron_type=nengo.LIF(), seed=seed, label='lif')
        alif = nengo.Ensemble(n_neurons, 1, max_rates=max_rates, intercepts=intercepts, neuron_type=AdaptiveLIFT(tau_adapt=0.1, inc_adapt=0.1), seed=seed, label='alif')
        wilson = nengo.Ensemble(n_neurons, 1, max_rates=max_rates, intercepts=intercepts, neuron_type=WilsonEuler(), seed=seed, label='wilson')
        durstewitz = nengo.Ensemble(n_neurons, 1, max_rates=max_rates, intercepts=intercepts, neuron_type=DurstewitzNeuron(), seed=seed, label='durstewitz')
        tar = nengo.Ensemble(1, 1, neuron_type=nengo.Direct())

        # Normal connections
        nengo.Connection(u_raw, u, synapse=None)
        nengo.Connection(u, pre_u, synapse=None, seed=seed)
        nengo.Connection(u, pre_x, synapse=1/s, seed=seed)
        nengo.Connection(u, tar, synapse=1/s)
        nef_nef = nengo.Connection(nef, nef, synapse=h_tar, solver=solver_nef, seed=seed, label='nef_nef')

        # Feedforward connections: input scaled by tau, the standard NEF
        # mapping of an integrator onto a lowpass recurrent synapse
        pre_u_nef = nengo.Connection(pre_u, nef, synapse=h_tar, transform=h_tar.tau, solver=LstsqL2(reg=reg), seed=seed, label='pre_u_nef')
        pre_u_lif = nengo.Connection(pre_u, lif, synapse=h_tar, transform=h_tar.tau, solver=LstsqL2(reg=reg), seed=seed, label='pre_u_lif')
        pre_u_alif = nengo.Connection(pre_u, alif, synapse=h_tar, transform=h_tar.tau, solver=LstsqL2(reg=reg), seed=seed, label='pre_u_alif')
        pre_u_wilson = nengo.Connection(pre_u, wilson, synapse=h_tar, transform=h_tar.tau, solver=LstsqL2(reg=reg), seed=seed, label='pre_u_wilson')
        pre_u_durstewitz = nengo.Connection(pre_u, durstewitz, synapse=h_tar, transform=h_tar.tau, solver=LstsqL2(reg=reg), seed=seed, label='pre_u_durstewitz')

        # Feedback connections: the ideal integral (supervised) or the
        # learned recurrence (unsupervised)
        if supv:
            pre_x_lif = nengo.Connection(pre_x, lif, synapse=h_tar, solver=LstsqL2(reg=reg), seed=seed, label='pre_x_lif')
            pre_x_alif = nengo.Connection(pre_x, alif, synapse=h_tar, solver=LstsqL2(reg=reg), seed=seed, label='pre_x_alif')
            pre_x_wilson = nengo.Connection(pre_x, wilson, synapse=h_tar, solver=LstsqL2(reg=reg), seed=seed, label='pre_x_wilson')
            pre_x_durstewitz = nengo.Connection(pre_x, durstewitz, synapse=h_tar, solver=LstsqL2(reg=reg), seed=seed, label='pre_x_durstewitz')
        else:
            lif_lif = nengo.Connection(lif, lif, synapse=h_lif, solver=solver_lif, seed=seed, label='lif_lif')
            alif_alif = nengo.Connection(alif, alif, synapse=h_alif, solver=solver_alif, seed=seed, label='alif_alif')
            wilson_wilson = nengo.Connection(wilson, wilson, synapse=h_wilson, solver=solver_wilson, seed=seed, label='wilson_wilson')
            durstewitz_durstewitz = nengo.Connection(durstewitz, durstewitz, synapse=h_durstewitz, solver=solver_durstewitz, seed=seed, label='durstewitz_durstewitz')

        # Probes
        p_u = nengo.Probe(u, synapse=None)
        p_tar = nengo.Probe(tar, synapse=None)
        p_nef = nengo.Probe(nef, synapse=h_tar)
        p_lif = nengo.Probe(lif.neurons, synapse=None)
        p_alif = nengo.Probe(alif.neurons, synapse=None)
        p_wilson = nengo.Probe(wilson.neurons, synapse=None)
        p_durstewitz = nengo.Probe(durstewitz.neurons, synapse=None)

    if gb_iter:
        gbd = gbopt([pre_u_durstewitz], gb_iter=gb_iter, pt=False)
    if gbd:
        # Assigning gain/bias on a Connection is presumably a custom
        # extension consumed by the Durstewitz builder; nengo may warn about
        # unknown attributes, so warnings are silenced here.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            pre_u_durstewitz.gain = gbd['pre_u_durstewitz']['gain']
            # NOTE(review): only the feedforward bias is scaled by tau; the
            # feedback biases below are not. Confirm this asymmetry is intended.
            pre_u_durstewitz.bias = gbd['pre_u_durstewitz']['bias'] * h_tar.tau
            if supv:
                pre_x_durstewitz.gain = gbd['pre_u_durstewitz']['gain']
                pre_x_durstewitz.bias = gbd['pre_u_durstewitz']['bias']
            else:
                durstewitz_durstewitz.gain = gbd['pre_u_durstewitz']['gain']
                durstewitz_durstewitz.bias = gbd['pre_u_durstewitz']['bias']

    with nengo.Simulator(model, seed=seed, dt=dt, progress_bar=False) as sim:
        neuron.h.init()
        try:
            sim.run(t, progress_bar=True)
        finally:
            # Always release NEURON-side state, even if the run fails, so the
            # next call starts clean
            reset_neuron(sim)

    return dict(
        times=sim.trange(),
        u=sim.data[p_u],
        tar=sim.data[p_tar],
        nef=sim.data[p_nef],
        lif=sim.data[p_lif],
        alif=sim.data[p_alif],
        wilson=sim.data[p_wilson],
        durstewitz=sim.data[p_durstewitz],
        gbd=gbd)

In [4]:
def trials(n_neurons=100, t=4*np.pi, h_tar=Lowpass(0.1), dt=0.001, n_trains=1, n_trials=1, gb_iter=0, h_iter=0, lstsq_iter=0, order=1):
    """Train recurrent integrators for four neuron models and evaluate them.

    Pipeline:

    1. Run ``go`` once in supervised mode. If ``gb_iter`` is nonzero, this
       also optimizes gain/bias for the Durstewitz population.
    2. Run ``n_trains`` supervised simulations on white-noise inputs and
       stack their spike outputs, inputs and integrated targets.
    3. Fit readout filters and decoders for each population, using the first
       method that applies:
       * ``dh_hyperopt`` if ``h_iter`` is nonzero;
       * ``dh_lstsq`` if ``lstsq_iter`` is nonzero;
       * otherwise ``d_lstsq`` with the fixed filter ``h_tar``.
    4. Plot the impulse responses of the fitted filters.
    5. Run ``n_trials`` unsupervised (recurrent) tests, plot each one, and
       score each population against the ``h_tar``-filtered target.

    Returns
    -------
    ndarray, shape (5, n_trials)
        NRMSE per trial. Rows are NEF, LIF, ALIF, Wilson, Durstewitz.
    """
    # NOTE(review): this overwrites the time constant of the caller's filter
    # (and of the shared default ``Lowpass(0.1)``) with 0.1. ``go`` reads
    # ``h_tar.tau`` for its feedforward transform. Behaviour is preserved
    # here; confirm this override is intended.
    h_tar.tau = 0.1
    h_lif, h_alif, h_wilson, h_durstewitz = h_tar, h_tar, h_tar, h_tar
    d_lif, d_alif, d_wilson, d_durstewitz = np.zeros((n_neurons, 1)), np.zeros((n_neurons, 1)), np.zeros((n_neurons, 1)), np.zeros((n_neurons, 1))

    # nengo's Simulator.run executes int(round(t/dt)) steps. Plain int()
    # can truncate one step short (e.g. 0.3/0.001 == 299.999...), which would
    # make the buffer assignments below fail with a shape mismatch.
    n_steps = int(np.round(t / dt))

    print('optimizing gain/bias')
    stim_func = nengo.processes.WhiteSignal(period=t/2, high=1, rms=1, seed=0)
    data = go(d_lif, d_alif, d_wilson, d_durstewitz, h_lif, h_alif, h_wilson, h_durstewitz,
            n_neurons=n_neurons, t=t, h_tar=h_tar, dt=dt, gb_iter=gb_iter, stim_func=stim_func, supv=1)

    print('gathering training data')
    lifs = np.zeros((n_trains, n_steps, n_neurons))
    alifs = np.zeros((n_trains, n_steps, n_neurons))
    wilsons = np.zeros((n_trains, n_steps, n_neurons))
    durstewitzs = np.zeros((n_trains, n_steps, n_neurons))
    times = np.zeros((n_trains, n_steps))
    nefs = np.zeros((n_trains, n_steps, 1))
    us = np.zeros((n_trains, n_steps, 1))
    tars = np.zeros((n_trains, n_steps, 1))

    for n in range(n_trains):
        stim_func = nengo.processes.WhiteSignal(period=t/2, high=1, rms=1, seed=n)
        # Reuse the gain/bias found above (it is carried forward in data['gbd'])
        data = go(d_lif, d_alif, d_wilson, d_durstewitz, h_lif, h_alif, h_wilson, h_durstewitz,
            n_neurons=n_neurons, t=t, h_tar=h_tar, dt=dt, gbd=data['gbd'], stim_func=stim_func, supv=1)
        lifs[n] = data['lif']
        alifs[n] = data['alif']
        wilsons[n] = data['wilson']
        durstewitzs[n] = data['durstewitz']
        times[n] = data['times']
        nefs[n] = data['nef']
        us[n] = data['u']
        tars[n] = data['tar']

    # Concatenate all training runs along time
    n_samples = n_trains * n_steps
    lifs = lifs.reshape((n_samples, n_neurons))
    alifs = alifs.reshape((n_samples, n_neurons))
    wilsons = wilsons.reshape((n_samples, n_neurons))
    durstewitzs = durstewitzs.reshape((n_samples, n_neurons))
    times = times.reshape((n_samples, 1))
    nefs = nefs.reshape((n_samples, 1))
    us = us.reshape((n_samples, 1))
    tars = tars.reshape((n_samples, 1))

    print('optimizing filters and decoders')
    if h_iter:
        d_lif, h_lif = dh_hyperopt(h_tar.filt(tars, dt=dt), lifs, order=order, h_iter=h_iter, dt=dt, name='integrator_lif')
        d_alif, h_alif = dh_hyperopt(h_tar.filt(tars, dt=dt), alifs, order=order, h_iter=h_iter, dt=dt, name='integrator_alif')
        d_wilson, h_wilson = dh_hyperopt(h_tar.filt(tars, dt=dt), wilsons, order=order, h_iter=h_iter, dt=dt, name='integrator_wilson')
        d_durstewitz, h_durstewitz = dh_hyperopt(h_tar.filt(tars, dt=dt), durstewitzs, order=order, h_iter=h_iter, dt=dt, name='integrator_durstewitz')
    elif lstsq_iter:
        d_lif, h_lif = dh_lstsq(us, tars, lifs, order=order, h_tar=h_tar, dt=dt, lstsq_iter=lstsq_iter)
        d_alif, h_alif = dh_lstsq(us, tars, alifs, order=order, h_tar=h_tar, dt=dt, lstsq_iter=lstsq_iter)
        d_wilson, h_wilson = dh_lstsq(us, tars, wilsons, order=order, h_tar=h_tar, dt=dt, lstsq_iter=lstsq_iter)
        d_durstewitz, h_durstewitz = dh_lstsq(us, tars, durstewitzs, order=order, h_tar=h_tar, dt=dt, lstsq_iter=lstsq_iter)
    else:
        h_lif, h_alif, h_wilson, h_durstewitz = h_tar, h_tar, h_tar, h_tar
        d_lif = d_lstsq(tars, lifs, h_lif, h_tar, reg=1e-1, dt=dt)
        d_alif = d_lstsq(tars, alifs, h_alif, h_tar, reg=1e-1, dt=dt)
        d_wilson = d_lstsq(tars, wilsons, h_wilson, h_tar, reg=1e-1, dt=dt)
        d_durstewitz = d_lstsq(tars, durstewitzs, h_durstewitz, h_tar, reg=1e-1, dt=dt)

    # Impulse responses of the target and the fitted filters over 1 s, on a
    # finer grid than the simulation. A separate name is used so the training
    # time array is not overwritten.
    dt_impulse = 0.0001
    t_impulse = np.arange(0, 1, dt_impulse)
    fig, ax = plt.subplots(figsize=(12, 8))
    for label, h in (("h_tar", h_tar), ("h_lif", h_lif), ("h_alif", h_alif),
                     ("h_wilson", h_wilson), ("h_durstewitz", h_durstewitz)):
        ax.plot(t_impulse, h.impulse(len(t_impulse), dt=dt_impulse), label=label)
    ax.set(xlabel='time (seconds)', ylabel='impulse response', ylim=(0, 10))
    ax.legend(loc='upper right')
    plt.tight_layout()
    plt.show()

    print('running experimental trials')
    nrmses = np.zeros((5, n_trials))
    for trial in range(n_trials):
        stim_func = nengo.processes.WhiteSignal(period=t/2, high=1, rms=1, seed=trial)
        data = go(d_lif, d_alif, d_wilson, d_durstewitz, h_lif, h_alif, h_wilson, h_durstewitz,
            n_neurons=n_neurons, t=t, h_tar=h_tar, dt=dt, gbd=data['gbd'], gb_iter=0, stim_func=stim_func, supv=0)

        # Filter spikes with each population's fitted filter, then decode
        a_lif = h_lif.filt(data['lif'], dt=dt)
        a_alif = h_alif.filt(data['alif'], dt=dt)
        a_wilson = h_wilson.filt(data['wilson'], dt=dt)
        a_durstewitz = h_durstewitz.filt(data['durstewitz'], dt=dt)
        target = h_tar.filt(data['tar'], dt=dt)
        xhat_nef = data['nef']
        xhat_lif = np.dot(a_lif, d_lif)
        xhat_alif = np.dot(a_alif, d_alif)
        xhat_wilson = np.dot(a_wilson, d_wilson)
        xhat_durstewitz = np.dot(a_durstewitz, d_durstewitz)
        nrmses[0, trial] = nrmse(xhat_nef, target=target)
        nrmses[1, trial] = nrmse(xhat_lif, target=target)
        nrmses[2, trial] = nrmse(xhat_alif, target=target)
        nrmses[3, trial] = nrmse(xhat_wilson, target=target)
        nrmses[4, trial] = nrmse(xhat_durstewitz, target=target)

        fig, ax = plt.subplots(figsize=(12, 8))
        ax.plot(data['times'], target, linestyle="--", label='target')
        ax.plot(data['times'], xhat_nef, label='NEF, nrmse=%.3f' % nrmses[0, trial])
        ax.plot(data['times'], xhat_lif, label='LIF, nrmse=%.3f' % nrmses[1, trial])
        ax.plot(data['times'], xhat_alif, label='ALIF, nrmse=%.3f' % nrmses[2, trial])
        ax.plot(data['times'], xhat_wilson, label='Wilson, nrmse=%.3f' % nrmses[3, trial])
        ax.plot(data['times'], xhat_durstewitz, label='Durstewitz, nrmse=%.3f' % nrmses[4, trial])
        ax.set(xlabel='time (s)', ylabel=r'$\mathbf{x}$', title="test %s" % trial)
        ax.legend(loc='upper right')
        plt.show()

    if n_trials > 1:
        # Row order of nrmses: NEF (static LIF decoders), LIF, ALIF, Wilson, Durstewitz
        nt_names = ['LIF\n(static)', 'LIF\n(temporal)', 'ALIF', 'Wilson', 'Durstewitz']
        fig, ax = plt.subplots(1, 1, figsize=(14, 6))
        sns.barplot(data=nrmses.T, ax=ax)
        ax.set(ylabel='NRMSE')
        ax.set_xticks(np.arange(len(nt_names)))
        ax.set_xticklabels(nt_names, rotation=0)
        plt.show()

    return nrmses