In [None]:
import os

import numpy as np

import nengo
from nengo.solvers import LstsqL2, NoSolver
from nengo.utils.matplotlib import rasterplot
from nengo.utils.ensemble import tuning_curves
from nengo.dists import Uniform
from nengo.params import Default
from nengo.utils.numpy import rmse

from nengolib.signal import s, nrmse, LinearSystem
from nengolib.synapses import Lowpass

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(context='poster', style='whitegrid')
%matplotlib inline

from hyperopt import fmin, tpe, hp, STATUS_OK, Trials

from scipy.signal import convolve

from nengo_bioneurons import BahlNeuron

# Signals

In [None]:
def make_stimulus(signal, freq, amp, seed):
    """Create the nengo Node that provides the input signal.

    Must be called inside a ``nengo.Network`` context (as every caller in this
    notebook does).

    Parameters
    ----------
    signal : str
        ``'cos'`` -> ``cos(freq*t)``, ``'sin'`` -> ``sin(freq*t)``, or
        ``'white_noise'`` -> ``nengo.processes.WhiteSignal`` with
        ``period=100``, ``high=freq``, ``rms=amp`` and ``seed=seed``.
    freq : float
        Multiplier of ``t`` for cos/sin; passed as ``high`` for white noise.
    amp : float
        RMS of the white-noise signal. NOTE: ignored for 'cos' and 'sin',
        which always have unit amplitude.
    seed : int
        Seed of the white-noise process (unused for cos/sin).

    Returns
    -------
    nengo.Node

    Raises
    ------
    ValueError
        If ``signal`` is not one of the names above.
    """
    if signal == 'cos':
        return nengo.Node(output=lambda t: np.cos(freq*t))
    elif signal == 'sin':
        return nengo.Node(output=lambda t: np.sin(freq*t))
    elif signal == 'white_noise':
        return nengo.Node(nengo.processes.WhiteSignal(
            period=100,
            high=freq,
            rms=amp,
            seed=seed))
    # Fail loudly here instead of returning None and breaking later on.
    raise ValueError("unknown signal %r; expected 'cos', 'sin' or 'white_noise'" % (signal,))

In [None]:
def norms(signal, freq, amp, ss, tau, t, plot=False):
    """Simulate the stimulus and its integral and return their peak magnitudes.

    A throwaway network containing only the stimulus node is run for ``t``
    seconds. The raw stimulus and its integral (probe synapse ``1/s``) are
    recorded. The integral is also passed through ``Lowpass(tau)``.

    Parameters
    ----------
    signal, freq, amp : forwarded to ``make_stimulus``.
    ss : int
        Seed forwarded to ``make_stimulus``.
    tau : float
        Time constant of the lowpass applied to the integral.
    t : float
        Simulation length.
    plot : bool
        If True, plot the stimulus, the integral, and the lowpass of the
        integral divided by its filtered peak.

    Returns
    -------
    (norm, norm_s, norm_f)
        Max |integral|, max |stimulus|, max |lowpass(integral)|.
    """
    smoother = Lowpass(tau)
    with nengo.Network() as net:
        source = make_stimulus(signal, freq, amp, ss)
        probe_raw = nengo.Probe(source, synapse=None)
        probe_int = nengo.Probe(source, synapse=1/s)
    with nengo.Simulator(net, progress_bar=False) as simulator:
        simulator.run(t, progress_bar=False)

    raw = simulator.data[probe_raw]
    integral = simulator.data[probe_int]
    integral_smoothed = smoother.filt(integral)
    peak_raw, peak_int, peak_smoothed = [
        np.max(np.abs(arr)) for arr in (raw, integral, integral_smoothed)]

    if plot:
        times = simulator.trange()
        plt.plot(times, raw)
        plt.plot(times, integral)
        plt.plot(times, smoother.filt(simulator.data[probe_int]/peak_smoothed))
        plt.show()
    return peak_int, peak_raw, peak_smoothed

In [None]:
def bin_activities_values_1d(
    xhat_pre,
    act_bio,
    x_min=-1,
    x_max=1,
    n_neurons=10,
    n_eval_points=20):
    """Estimate tuning curves by binning activities on the represented value.

    Every time step is assigned to the point of
    ``np.linspace(x_min, x_max, n_eval_points)`` nearest to ``xhat_pre[t]``.
    Ties go to the lower index, as with ``argmin``. For each of the first
    ``n_neurons`` columns of ``act_bio``, the function computes the mean and
    the (ddof=0) standard deviation of the activities that fall in each bin.

    Parameters
    ----------
    xhat_pre : array_like
        One scalar value per time step. A trailing singleton axis, e.g.
        shape (T, 1), is accepted. Only the first ``act_bio.shape[0]``
        entries are used.
    act_bio : ndarray, shape (T, >= n_neurons)
        Activity of each neuron at each time step.
    x_min, x_max : float
        Range of the bin centres.
    n_neurons : int
        Number of leading columns of ``act_bio`` to process.
    n_eval_points : int
        Number of bins.

    Returns
    -------
    x_bins : ndarray, shape (n_eval_points,)
    hz_means, hz_stds : ndarray, shape (n_neurons, n_eval_points)
        NaN for bins that received no samples.
    """
    x_bins = np.linspace(x_min, x_max, num=n_eval_points)
    act = np.asarray(act_bio)
    n_steps = act.shape[0]
    # The bin assignment depends only on the value, not on the neuron, so it
    # is computed once for all time steps instead of once per neuron.
    values = np.asarray(xhat_pre)[:n_steps].reshape(n_steps)
    bin_idx = np.abs(x_bins[np.newaxis, :] - values[:, np.newaxis]).argmin(axis=1)

    hz_means = np.full((n_neurons, n_eval_points), np.nan)
    hz_stds = np.full((n_neurons, n_eval_points), np.nan)
    for b in range(n_eval_points):
        in_bin = act[bin_idx == b, :n_neurons]
        if in_bin.shape[0] > 0:  # empty bins stay NaN (no empty-slice warning)
            hz_means[:, b] = in_bin.mean(axis=0)
            hz_stds[:, b] = in_bin.std(axis=0)

    return x_bins, hz_means, hz_stds

# Encoder Learning

In [None]:
class EncoderNode(nengo.Node):
    """Node that compares bioneuron activity with a target and learns encoders.

    As a nengo Node it takes a ``2*n_bio`` input laid out as
    ``[a_bio, a_target]``. It caches both halves and outputs the error
    ``a_bio - a_target``. ``update_encoders`` uses the cached activities to
    nudge a single synaptic encoder by a random step and returns the
    resulting weight. ``update_weights`` recomputes every weight as
    decoder * encoder.

    NOTE(review): in this notebook the node is attached to a connection via
    ``conn.learning_node``; the code that calls ``update_encoders`` /
    ``update_weights`` (presumably the bioneuron builder) is not visible here.
    """

    def __init__(
            self,
            conn,
            n_bio,
            n_syn,
            dim,
            d_pre,
            eta,  # learning rate
            seed, # learning seed
            syn_encoders_init):
        """
        Parameters
        ----------
        conn : nengo.Connection
            The learned connection (stored; not used by this class itself).
        n_bio : int
            Number of postsynaptic neurons.
        n_syn : int
            Synapses per (post, pre) neuron pair.
        dim : int
            Represented dimensionality (stored; not used by this class itself).
        d_pre : ndarray
            Presynaptic decoders; the first axis indexes presynaptic neurons.
        eta : float
            Learning rate scaling the random step size.
        seed : int
            Seed of the RNG drawing the step sizes.
        syn_encoders_init : ndarray, shape (n_bio, n_pre, n_syn)
            Initial synaptic encoders. The array is copied, so learning never
            mutates it. Callers often initialise several connections from the
            same array, and in-place updates would otherwise leak between them.
        """
        self.conn = conn
        self.n_syn = n_syn
        self.dim = dim
        self.d_pre = d_pre
        self.eta = eta
        self.rng = np.random.RandomState(seed=seed)
        self.n_pre = self.d_pre.shape[0]
        self.n_bio = n_bio
        self.syn_encoders = np.array(syn_encoders_init, copy=True)
        self.syn_weights = np.zeros_like(self.syn_encoders)
        self.a_target = np.array([])
        self.a_bio = np.array([])

        super(EncoderNode, self).__init__(self.update, 
            size_in=2*self.n_bio,  # [a_bio, a_lif]
            size_out=self.n_bio)

    def update(self, t, x):
        """Node function: cache both activity vectors and return their difference."""
        self.a_bio = x[0:self.n_bio]
        self.a_target = x[self.n_bio:2*self.n_bio]
        return self.a_bio - self.a_target
    
    def update_encoders(self, bio, pre, syn):
        """Randomly step one synaptic encoder to reduce the activity error.

        The step size is drawn uniformly from ``[0, 2*eta*|error|)``. Its
        direction pushes the weight (decoder * encoder) down when the neuron
        is overactive and up when it is underactive. Returns the new weight.
        """
        a_error = self.a_bio[bio] - self.a_target[bio]
        d_syn = self.d_pre[pre]
        e_old = self.syn_encoders[bio, pre, syn]
        # The RNG is drawn even when no branch below applies (zero error or
        # zero decoder), which keeps the random stream aligned.
        delta = self.rng.uniform(0, 2 * self.eta * np.abs(a_error))
        if a_error > 0 and d_syn > 0:  # overactive, positive dec => reduce enc reduce weight
            self.syn_encoders[bio, pre, syn] += -delta
        if a_error > 0 and d_syn < 0:  # overactive, negative dec => increase enc reduce weight
            self.syn_encoders[bio, pre, syn] += +delta
        if a_error < 0 and d_syn > 0:  # underactive, positive dec => increase enc increase weight
            self.syn_encoders[bio, pre, syn] += +delta
        if a_error < 0 and d_syn < 0:  # underactive, negative dec => reduce enc increase weight
            self.syn_encoders[bio, pre, syn] += -delta
        w_new = np.dot(d_syn, self.syn_encoders[bio, pre, syn])
        return w_new

    def update_weights(self):
        """Recompute every synaptic weight as decoder * encoder and return them."""
        for bio in range(self.n_bio):
            for pre in range(self.n_pre):
                for syn in range(self.n_syn):
                    self.syn_weights[bio, pre, syn] = np.dot(self.d_pre[pre], self.syn_encoders[bio, pre, syn])
        return self.syn_weights

# Readout Filter and Decoder Optimization

In [None]:
def optimize_elephys(
    save_data,
    save_ensemble,
    n_neurons,
    tau_lpf,
    n_zeros=1,
    n_poles=3,
    z_min=1e2,
    z_max=1e3,
    p_min=-1e2,
    p_max=-1e0,
    reg=0.1,
    max_evals=200,
    normalize=True,
    seed=5):
    """Fit per-neuron readout filters and linear decoders with hyperopt (TPE).

    Spikes are loaded from ``save_data + save_ensemble`` (key ``'spikes'``)
    and the target from ``save_data + 'target.npz'`` (key ``'target'``). The
    spikes are then lowpass-filtered with ``tau_lpf``.

    Each hyperopt evaluation does the following:
    - builds one ``LinearSystem((zeros, poles, 1.0))`` per neuron, with zeros
      sampled in [z_min, z_max] and poles in [p_min, p_max];
    - optionally divides each filter by its DC gain;
    - filters each neuron's activity;
    - solves ``LstsqL2(reg)`` decoders against the target;
    - scores the decode by NRMSE, or by RMSE when the target is all zeros.

    Returns
    -------
    (best_h_eps, best_d_eps)
        The filters (list of length ``n_neurons``) and decoders of the
        lowest-loss evaluation.
    """
    print("running optimization ...")
    hyps = {}  # hyperparameter search space
    for bio in range(n_neurons):  # all neurons in one space so the decoders act on all of them
        for z in range(n_zeros):
            key = '%s_bio_%s_zero' % (bio, z)
            hyps[key] = hp.uniform(key, z_min, z_max)
        for p in range(n_poles):
            key = '%s_bio_%s_pole' % (bio, p)
            hyps[key] = hp.uniform(key, p_min, p_max)

    # The data and the lowpass-filtered spikes do not depend on the
    # hyperparameters, so load and filter them once, not on every evaluation.
    with np.load(save_data + save_ensemble) as spike_file:
        spikes = spike_file['spikes']
    with np.load(save_data + "target.npz") as target_file:
        target = target_file['target']
    act_lpf = Lowpass(tau_lpf).filt(spikes)
    # nrmse normalizes by the target's RMS, which is zero only for an all-zero target.
    use_nrmse = np.any(target != 0)

    def objective(params):
        # Build a fresh filter list for every evaluation, so that each trial
        # uses (and reports) its own filters.
        h_eps = []
        for bio in range(n_neurons):
            zeros = np.array([params['%s_bio_%s_zero' % (bio, z)] for z in range(n_zeros)])
            poles = np.array([params['%s_bio_%s_pole' % (bio, p)] for p in range(n_poles)])
            h_ep = LinearSystem((zeros, poles, 1.0))
            if normalize:
                h_ep /= h_ep.dcgain
            h_eps.append(h_ep)

        act_eps = np.zeros_like(act_lpf)
        for n in range(n_neurons):
            act_eps[:, n] = h_eps[n].filt(act_lpf[:, n])
        if np.any(act_eps != 0):
            d_eps = nengo.solvers.LstsqL2(reg=reg)(act_eps, target)[0]
        else:
            d_eps = np.zeros((n_neurons, 1))
        xhat_eps = np.dot(act_eps, d_eps)
        if use_nrmse:
            loss = nrmse(xhat_eps, target=target)
        else:
            loss = rmse(xhat_eps, target)

        return {'loss': loss,
            'h_eps': h_eps,
            'd_eps': d_eps,
            'status': STATUS_OK }

    trials = Trials()

    fmin(objective,
         rstate=np.random.RandomState(seed=seed),
         space=hyps,
         algo=tpe.suggest,
         max_evals=max_evals,
         trials=trials)

    best_idx = np.array(trials.losses()).argmin()
    best_h_eps = trials.trials[best_idx]['result']['h_eps']
    best_d_eps = trials.trials[best_idx]['result']['d_eps']

    return best_h_eps, best_d_eps

## Helper

In [None]:
def get_kwargs(
    n_neurons=10,
    n_pre=100,
    n_syn=1,
    sec='tuft',
    taus={'network': 0.05,
          'readout': 0.05},
    seeds={'ns': 1, 'ss':2, 'es': 3, 'cs': 4, 'ls': 5},
    neuron_type=BahlNeuron(bias_method='weights_fixed')):
    """Build the keyword-argument dicts shared by the network builders.

    Returns
    -------
    (pre_kwargs, lif_kwargs, conn_kwargs, bio_kwargs)
        pre_kwargs : 1-D input ensemble of ``n_pre`` neurons, max rates 20-40.
        lif_kwargs : 1-D ``LIFRate`` reference ensemble labelled 'lif'.
        conn_kwargs : bioneuron connection settings (section ``sec``,
            ``n_syn`` ExpSyn synapses with tau ``taus['network']``).
        bio_kwargs : 1-D ensemble of ``neuron_type`` labelled 'bio' with zero
            gain and bias and uniform(-1, 1) encoders.

    All ensembles share the encoder seed ``seeds['es']``.
    ``taus['readout']`` is not used here.
    """
    tau_net = taus['network']
    # Keys common to every ensemble.
    ensemble_base = {'dimensions': 1, 'seed': seeds['es']}

    pre_kwargs = dict(ensemble_base, n_neurons=n_pre, max_rates=Uniform(20, 40))
    lif_kwargs = dict(ensemble_base,
                      n_neurons=n_neurons,
                      max_rates=Uniform(20, 40),
                      neuron_type=nengo.LIFRate(),  # adaptiveLIF?
                      label='lif')
    bio_kwargs = dict(ensemble_base,
                      n_neurons=n_neurons,
                      encoders=Uniform(-1, 1),
                      gain=Uniform(0, 0),
                      bias=Uniform(0, 0),
                      neuron_type=neuron_type,
                      label='bio')
    conn_kwargs = {
        'sec': sec,
        'n_syn': n_syn,
        'syn_type': 'ExpSyn',
        'tau_list': [tau_net],
        'synapse': tau_net,
        'seed': seeds['cs'],
    }
    return pre_kwargs, lif_kwargs, conn_kwargs, bio_kwargs

In [None]:
def get_syn_encoders_init(
    n_neurons=10,
    n_pre=100,
    n_syn=1,
    signal='cos',
    freq=1,
    amp=1,
    t=1.0,
    sec='tuft',
    taus={'network': 0.05,
          'readout': 0.05},
    seeds={'ns': 1, 'ss':2, 'es': 3, 'cs': 4, 'ls': 5},
    T_u=1,
    T_x=1):
    """Derive initial synaptic encoders and presynaptic decoders from a LIF network.

    Builds pre_u -> lif and pre_x -> lif (transforms ``T_u`` / ``T_x``). The
    pre radii come from ``norms`` of the given signal, and the ensemble and
    connection kwargs come from ``get_kwargs``. The function then reads the
    transposed decoded-connection weights (returned as ``d_pre_u`` and
    ``d_pre_x``) and the lif encoders. Each lif neuron's encoder is copied to
    all of its synapses.

    Returns
    -------
    syn_encoders_pre_bio : ndarray, shape (n_neurons, n_pre, n_syn)
    syn_encoders_bio_bio : ndarray, shape (n_neurons, n_neurons, n_syn)
    d_pre_u, d_pre_x : ndarray
        ``sim.data[conn].weights.T`` for the two pre -> lif connections.
    """
    norm, norm_s, _ = norms(signal, freq, amp, seeds['ss'], taus['network'], t)

    pre_kwargs, lif_kwargs, conn_kwargs, _ = get_kwargs(n_neurons, n_pre, n_syn, sec, taus, seeds)

    # Build a network to collect encoders and decoders from the target LIF
    with nengo.Network(seed=seeds['ns']) as pre_model:
        pre_u = nengo.Ensemble(radius=norm_s, **pre_kwargs)
        pre_x = nengo.Ensemble(radius=norm, **pre_kwargs)
        lif = nengo.Ensemble(**lif_kwargs)
        pre_u_lif = nengo.Connection(pre_u, lif, transform=T_u, **conn_kwargs)
        pre_x_lif = nengo.Connection(pre_x, lif, transform=T_x, **conn_kwargs)
    # Only the build is needed (no run). The context manager closes the
    # simulator once the parameters have been read.
    with nengo.Simulator(pre_model, seed=seeds['ss']) as sim:
        d_pre_u = sim.data[pre_u_lif].weights.T
        d_pre_x = sim.data[pre_x_lif].weights.T
        e_target = sim.data[lif].encoders

    syn_encoders_pre_bio = np.zeros((n_neurons, n_pre, n_syn))
    syn_encoders_bio_bio = np.zeros((n_neurons, n_neurons, n_syn))
    for bio in range(n_neurons):
        syn_encoders_pre_bio[bio] = e_target[bio] * np.ones((n_pre, n_syn))
        syn_encoders_bio_bio[bio] = e_target[bio] * np.ones((n_neurons, n_syn))

    return syn_encoders_pre_bio, syn_encoders_bio_bio, d_pre_u, d_pre_x

In [None]:
def save_decoders_filters(
    save_dir,
    d_pre_u,
    d_pre_x,
    d_eps_dict,
    syn_weights_dict,
    syn_encoders_dict,
    h_eps_dict):
    """Persist decoders, synaptic weights, synaptic encoders and readout filters.

    Four archives are written. ``save_dir`` is used as a plain string prefix,
    so it should end with a path separator:

    - ``decoders.npz``: d_pre_u, d_pre_x, d_eps_bio, d_eps_supv
    - ``syn_weights.npz``: the five ``syn_weights_*`` entries of ``syn_weights_dict``
    - ``syn_encoders.npz``: the five ``syn_encoders_*`` entries of ``syn_encoders_dict``
    - ``filters.npz``: numerator/denominator coefficients (``.num`` / ``.den``)
      of each filter in ``h_eps_dict['h_eps_bio']`` and ``['h_eps_supv']``
    """
    print("saving data ...")

    decoders = {'d_pre_u': d_pre_u, 'd_pre_x': d_pre_x}
    for key in ('d_eps_bio', 'd_eps_supv'):
        decoders[key] = d_eps_dict[key]
    np.savez(save_dir+'decoders.npz', **decoders)

    weight_keys = ('syn_weights_pre_u_supv', 'syn_weights_pre_x_supv',
                   'syn_weights_pre_u_bio', 'syn_weights_supv_bio',
                   'syn_weights_bio_bio')
    np.savez(save_dir+'syn_weights.npz',
             **dict((key, syn_weights_dict[key]) for key in weight_keys))

    encoder_keys = ('syn_encoders_pre_u_supv', 'syn_encoders_pre_x_supv',
                    'syn_encoders_pre_u_bio', 'syn_encoders_supv_bio',
                    'syn_encoders_bio_bio')
    np.savez(save_dir+'syn_encoders.npz',
             **dict((key, syn_encoders_dict[key]) for key in encoder_keys))

    bio_filters = h_eps_dict['h_eps_bio']
    supv_filters = h_eps_dict['h_eps_supv']
    coefficients = {
        'h_eps_bio_num': [np.array(f.num) for f in bio_filters],
        'h_eps_bio_den': [np.array(f.den) for f in bio_filters],
        'h_eps_supv_num': [np.array(f.num) for f in supv_filters],
        'h_eps_supv_den': [np.array(f.den) for f in supv_filters],
    }
    np.savez(save_dir+'filters.npz', **coefficients)

In [None]:
def load_decoders_filters(save_dir):
    """Load everything written by ``save_decoders_filters`` from ``save_dir``.

    ``save_dir`` is used as a plain string prefix and should end with a path
    separator. Readout filters are rebuilt as ``LinearSystem((num, den))``.

    Returns
    -------
    (d_pre_u, d_pre_x, d_eps_dict, syn_weights_dict, syn_encoders_dict, h_eps_dict)

    NOTE(review): if the saved filter coefficient lists were ragged, numpy
    stored them as object arrays. Newer numpy versions refuse to load those
    unless ``allow_pickle=True`` is passed — confirm for your numpy version.
    """
    def _read(filename, keys):
        # Open each archive once and close it as soon as the arrays are read.
        with np.load(save_dir + filename) as archive:
            return dict((key, archive[key]) for key in keys)

    decoders = _read('decoders.npz', ('d_pre_u', 'd_pre_x', 'd_eps_bio', 'd_eps_supv'))
    d_pre_u = decoders['d_pre_u']
    d_pre_x = decoders['d_pre_x']
    d_eps_dict_fb = {'d_eps_bio': decoders['d_eps_bio'], 'd_eps_supv': decoders['d_eps_supv']}

    syn_weights_dict_fb = _read('syn_weights.npz', (
        'syn_weights_pre_u_supv',
        'syn_weights_pre_x_supv',
        'syn_weights_pre_u_bio',
        'syn_weights_supv_bio',
        'syn_weights_bio_bio',
    ))

    syn_encoders_dict_fb = _read('syn_encoders.npz', (
        'syn_encoders_pre_u_supv',
        'syn_encoders_pre_x_supv',
        'syn_encoders_pre_u_bio',
        'syn_encoders_supv_bio',
        'syn_encoders_bio_bio',
    ))

    filters = _read('filters.npz', (
        'h_eps_bio_num', 'h_eps_bio_den', 'h_eps_supv_num', 'h_eps_supv_den'))
    h_eps_bio_num = filters['h_eps_bio_num']
    h_eps_bio_den = filters['h_eps_bio_den']
    h_eps_supv_num = filters['h_eps_supv_num']
    h_eps_supv_den = filters['h_eps_supv_den']
    h_eps_bio = [LinearSystem((h_eps_bio_num[i], h_eps_bio_den[i])) for i in range(len(h_eps_bio_num))]
    h_eps_supv = [LinearSystem((h_eps_supv_num[i], h_eps_supv_den[i])) for i in range(len(h_eps_supv_num))]
    h_eps_dict_fb = {'h_eps_bio': h_eps_bio, 'h_eps_supv': h_eps_supv}

    # Return the freshly loaded decoders (previously an unrelated name was returned).
    return d_pre_u, d_pre_x, d_eps_dict_fb, syn_weights_dict_fb, syn_encoders_dict_fb, h_eps_dict_fb

# Network

In [None]:
def _plot_readout_summary(
        label, trange, spikes, act_lpf, h_eps, tau_readout,
        lif_tuning, eps_tuning, target, xhat_lif, nrmse_lif,
        xhat_eps, nrmse_eps, n_neurons, n_neurons_plot):
    """Draw the diagnostic figures for one population ('supv' or 'bio').

    The figures are:
    1. Spike raster plus a histogram of the lowpass-filtered activities.
    2. Impulse responses of the readout lowpass, of each h_eps filter, and
       of each h_eps applied to the lowpass impulse.
    3. Binned tuning curves (lif vs. elephys) for the first
       ``min(n_neurons_plot, n_neurons)`` neurons.
    4. Target vs. the lif and elephys decodes.

    ``lif_tuning`` and ``eps_tuning`` are ``(x_bins, hz_means, hz_stds)``
    tuples as returned by ``bin_activities_values_1d``.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 8))
    rasterplot(trange, spikes, ax=ax1)
    ax1.set(xlabel='time', ylabel='neuron')
    sns.distplot(np.ravel(act_lpf), ax=ax2)
    ax2.set(xlim=((1, 50)), ylim=((0, 0.05)), xlabel='activity', ylabel='frequency', title=label)
    plt.tight_layout()

    # Impulse responses over 1 s sampled every 1 ms.
    times = np.arange(0, 1, 0.001)
    lowpass_impulse = Lowpass(tau_readout).impulse(len(times))
    fig, (ax, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 4))
    ax.plot(times, lowpass_impulse)
    for h_ep in h_eps[:n_neurons]:
        ax2.plot(times, h_ep.impulse(len(times)))
        ax3.plot(times, h_ep.filt(lowpass_impulse))
    ax.set(xlabel='time', ylabel='amplitude', title='lowpass')
    ax2.set(xlabel='time', title='h_eps_%s' % label)
    ax3.set(xlabel='time', title='h_eps_%s.filt(lowpass)' % label)
    plt.tight_layout()

    n_curves = min(n_neurons_plot, n_neurons)
    cmap = sns.color_palette('hls', n_curves)
    fig, (ax4, ax5) = plt.subplots(1, 2, figsize=(8, 8), sharey=True)
    for ax_tc, (x_bins, hz_means, hz_stds) in ((ax4, lif_tuning), (ax5, eps_tuning)):
        for n in range(n_curves):
            ax_tc.plot(x_bins, hz_means[n], c=cmap[n])
            ax_tc.fill_between(x_bins,
                hz_means[n]+hz_stds[n],
                hz_means[n]-hz_stds[n],
                alpha=0.5, facecolor=cmap[n])
    ax4.set(xlim=((-1,1)), ylim=((0, 50)), xlabel=r'$\mathbf{x}$', ylabel='activity (Hz)', title='lif')
    ax5.set(xlim=((-1,1)), ylim=((0, 50)), xlabel=r'$\mathbf{x}$', title='elephys %s' % label)
    plt.tight_layout()

    fig, ax = plt.subplots(1, 1, figsize=(8, 4))
    ax.plot(trange, target, label='target', linestyle='--')
    ax.plot(trange, xhat_lif, alpha=0.5, label='lif, nrmse=%.3f' % nrmse_lif)
    ax.plot(trange, xhat_eps, alpha=0.5, label='elephys, nrmse=%.3f' % nrmse_eps)
    ax.set(xlabel='time', ylabel=r'$\mathbf{x}$', title=label)
    ax.legend(loc='lower left')
    plt.tight_layout()

    plt.show()


def double_optimize_alternate(
    d_pre_u,
    d_pre_x,    
    d_eps_bio,
    d_eps_supv,
    h_eps_bio,
    h_eps_supv,
    syn_encoders_pre_u_supv,
    syn_encoders_pre_x_supv,
    syn_encoders_pre_u_bio,
    syn_encoders_supv_bio,
    syn_encoders_bio_bio,
    syn_weights_pre_u_supv,
    syn_weights_pre_x_supv,
    syn_weights_pre_u_bio,
    syn_weights_supv_bio,
    syn_weights_bio_bio,
    t=1,
    n_neurons=100,
    n_neurons_plot=10,
    n_pre=100,
    n_syn=1,
    signal='cos',
    freq=1,
    amp=1,
    sec='tuft',
    taus={'network': 0.05,
          'readout': 0.05},
    regs={'pre-bio': 0.1,
          'bio-out': 0.1},
    seeds={'ns': 1, 'ss':2, 'es': 3, 'cs': 4, 'ls': 5},
    neuron_type=BahlNeuron(bias_method='weights_fixed'),
    save_dir='/home/pduggins/nengo_bioneurons/nengo_bioneurons/tests/data/double_optimize_alternate/',
    save_suffix='default/',
    eta=0,
    learn_pre_supv=False,
    learn_supv_bio=False,
    learn_bio_bio=False,
    optimize_supv=False,
    optimize_bio=False,
    plot_supv=False,
    plot_bio=False,
    sim_supv=False,
    sim_bio=False,
    save_df=True,
    max_evals=200):
    """Run one stage of the alternating supv/bio training procedure.

    The network is built as follows:
    - Stimulus ``u`` drives ``pre_u`` directly and ``pre_x`` through an
      integrator (synapse ``1/s``).
    - A Direct-mode ``tar`` ensemble receives ``integral(u) / norm_f``.
    - A LIF reference ensemble receives pre_u (transform ``tau/norm_f``) and
      pre_x (transform ``1/norm_f``).
    - With ``sim_supv``, a ``supv`` ensemble receives pre_u and pre_x.
    - With ``sim_bio``, a ``bio`` ensemble receives pre_u, recurrent input
      bio->bio, and, when both are simulated, input from supv.
    - Bioneuron connections use the given ``syn_weights_*``.

    The ``learn_*`` flags attach an ``EncoderNode`` (lif activity as the
    target) to the matching connections. They only take effect when the
    populations involved are simulated.

    After the run, the function saves the spikes, target and lif decode under
    ``save_dir + save_suffix``. The ``optimize_*`` flags then refit h_eps/d_eps
    with ``optimize_elephys`` (``max_evals`` evaluations). Finally it decodes
    the simulated populations with the current h_eps/d_eps, optionally saves
    all parameters (``save_df``) and plots (``plot_*``).

    ``regs`` is accepted for compatibility but is not used.

    Returns
    -------
    (d_eps_new, h_eps_new, syn_encoders_new, syn_weights_new)
        Dicts keyed like the arguments (e.g. ``'d_eps_supv'``,
        ``'syn_weights_bio_bio'``). Entries for connections that were not
        learned are the inputs passed in.
    """
    out_dir = save_dir + save_suffix
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # transform input signal u so that the integral x is normalized to np.max(x)==1
    norm, norm_s, norm_f = norms(signal, freq, amp, seeds['ss'], taus['network'], t)
    # keyword arguments for ensembles and connections
    pre_kwargs, lif_kwargs, conn_kwargs, bio_kwargs = get_kwargs(
        n_neurons, n_pre, n_syn, sec, taus, seeds, neuron_type)

    # Simulate the full network with encoder learning rules. Skip simulating supv or bio when possible.
    with nengo.Network(seed=seeds['ns']) as model:
        u = make_stimulus(signal, freq, amp, seed=seeds['ss'])
        pre_u = nengo.Ensemble(radius=norm_s, **pre_kwargs)
        pre_x = nengo.Ensemble(radius=norm, **pre_kwargs)
        if sim_supv: supv = nengo.Ensemble(**bio_kwargs)
        if sim_bio: bio = nengo.Ensemble(**bio_kwargs)
        lif = nengo.Ensemble(**lif_kwargs)
        tar = nengo.Ensemble(1, 1, neuron_type=nengo.Direct())

        # normal connections
        nengo.Connection(u, pre_u, synapse=None, seed=seeds['cs'])
        nengo.Connection(u, pre_x, synapse=1/s, seed=seeds['cs'])
        nengo.Connection(u, tar, synapse=1/s, transform=1.0/norm_f)
        nengo.Connection(pre_u, lif, synapse=taus['network'], transform=taus['network']/norm_f)
        nengo.Connection(pre_x, lif, synapse=taus['network'], transform=1.0/norm_f)  # proxy for accuracy
        # NOTE: a lif->lif recurrent connection (true recurrence on training
        # spikes) is intentionally not built here.

        # bioneuron connections (learned)
        if sim_supv:
            pre_u_supv = nengo.Connection(pre_u, supv, syn_weights=syn_weights_pre_u_supv, **conn_kwargs)
            pre_x_supv = nengo.Connection(pre_x, supv, syn_weights=syn_weights_pre_x_supv, **conn_kwargs)
        if sim_bio:
            pre_u_bio = nengo.Connection(pre_u, bio, syn_weights=syn_weights_pre_u_bio, **conn_kwargs)
            bio_bio = nengo.Connection(bio, bio, syn_weights=syn_weights_bio_bio, **conn_kwargs)
        if sim_supv and sim_bio:
            supv_bio = nengo.Connection(supv, bio, syn_weights=syn_weights_supv_bio, **conn_kwargs)

        # associate encoder learning nodes with each learned connection.
        # Each flag only takes effect when the populations it needs exist.
        learning_pre_supv = learn_pre_supv and sim_supv
        learning_supv_bio = learn_supv_bio and sim_supv and sim_bio
        learning_bio_bio = learn_bio_bio and sim_bio

        if learning_pre_supv:
            enc_node_pre_u_supv = EncoderNode(
                pre_u_supv, n_neurons, n_syn, 1, d_pre_u, eta, seeds['ls'], syn_encoders_pre_u_supv)
            pre_u_supv.learning_node = enc_node_pre_u_supv
            nengo.Connection(supv.neurons, enc_node_pre_u_supv[0:n_neurons], synapse=taus['readout'])
            nengo.Connection(lif.neurons, enc_node_pre_u_supv[n_neurons:2*n_neurons], synapse=taus['readout'])

            enc_node_pre_x_supv = EncoderNode(
                pre_x_supv, n_neurons, n_syn, 1, d_pre_x, eta, seeds['ls'], syn_encoders_pre_x_supv)
            pre_x_supv.learning_node = enc_node_pre_x_supv
            nengo.Connection(supv.neurons, enc_node_pre_x_supv[0:n_neurons], synapse=taus['readout'])
            nengo.Connection(lif.neurons, enc_node_pre_x_supv[n_neurons:2*n_neurons], synapse=taus['readout'])

        if learning_supv_bio:
            enc_node_pre_u_bio = EncoderNode(
                pre_u_bio, n_neurons, n_syn, 1, d_pre_u, eta, seeds['ls'], syn_encoders_pre_u_bio)
            pre_u_bio.learning_node = enc_node_pre_u_bio
            nengo.Connection(bio.neurons, enc_node_pre_u_bio[0:n_neurons], synapse=taus['readout'])
            nengo.Connection(lif.neurons, enc_node_pre_u_bio[n_neurons:2*n_neurons], synapse=taus['readout'])

            enc_node_supv_bio = EncoderNode(
                supv_bio, n_neurons, n_syn, 1, d_eps_supv, eta, seeds['ls'], syn_encoders_supv_bio)
            supv_bio.learning_node = enc_node_supv_bio
            nengo.Connection(bio.neurons, enc_node_supv_bio[0:n_neurons], synapse=taus['readout'])
            nengo.Connection(lif.neurons, enc_node_supv_bio[n_neurons:2*n_neurons], synapse=taus['readout'])

        if learning_bio_bio:
            enc_node_bio_bio = EncoderNode(
                bio_bio, n_neurons, n_syn, 1, d_eps_bio, eta, seeds['ls'], syn_encoders_bio_bio)
            bio_bio.learning_node = enc_node_bio_bio
            nengo.Connection(bio.neurons, enc_node_bio_bio[0:n_neurons], synapse=taus['readout'])
            nengo.Connection(lif.neurons, enc_node_bio_bio[n_neurons:2*n_neurons], synapse=taus['readout'])

        # probes
        p_stim = nengo.Probe(u, synapse=None)
        p_target = nengo.Probe(tar, synapse=None)
        if sim_supv:
            p_spk_supv = nengo.Probe(supv.neurons, synapse=None)
            p_act_supv = nengo.Probe(supv.neurons, synapse=taus['readout'])
        if sim_bio:
            p_spk_bio = nengo.Probe(bio.neurons, synapse=None)
            p_act_bio = nengo.Probe(bio.neurons, synapse=taus['readout'])
        p_act_lif = nengo.Probe(lif.neurons, synapse=taus['readout'])
        p_lif = nengo.Probe(lif, synapse=taus['readout'])

    # RUN the simulation
    with nengo.Simulator(model, seed=seeds['ss']) as sim:
        sim.run(t)

    # collect spikes, lowpass activities, and lif decodes
    lpf = Lowpass(taus['readout'])
    stim = lpf.filt(sim.data[p_stim])
    target = lpf.filt(sim.data[p_target])
    if sim_supv:
        spikes_supv = sim.data[p_spk_supv]
        act_lpf_supv = sim.data[p_act_supv]
        np.savez(out_dir+"spikes_supv.npz", spikes=spikes_supv)
    if sim_bio:
        spikes_bio = sim.data[p_spk_bio]
        act_lpf_bio = sim.data[p_act_bio]
        np.savez(out_dir+"spikes_bio.npz", spikes=spikes_bio)
    act_lif = sim.data[p_act_lif]
    xhat_lif = sim.data[p_lif]
    nrmse_lif = nrmse(xhat_lif, target=target)
    np.savez(out_dir+"target.npz", target=target)
    np.savez(out_dir+"lif.npz", act=act_lif, xhat=xhat_lif)

    # optimize d_eps and h_eps given spikes and a target.
    # NOTE: optimize_elephys reads the spikes file from disk; this call only
    # writes spikes_<pop>.npz when that population is simulated.
    if optimize_supv:
        h_eps_supv, d_eps_supv = optimize_elephys(
            out_dir,
            "spikes_supv.npz",
            n_neurons,
            taus['readout'],
            max_evals=max_evals,
            normalize=True,
            seed=seeds['ls'])

    if optimize_bio:
        h_eps_bio, d_eps_bio = optimize_elephys(
            out_dir,
            "spikes_bio.npz",
            n_neurons,
            taus['readout'],
            max_evals=max_evals,
            normalize=True,
            seed=seeds['ls'])

    # Compute activities and xhat from h_eps and d_eps, then bin data for tuning curve estimation
    act_eps_supv = np.zeros_like(act_lif)
    act_eps_bio = np.zeros_like(act_lif)
    x_bins_lif, hz_means_lif, hz_stds_lif = bin_activities_values_1d(
        target, act_lif, n_neurons=n_neurons)
    if sim_supv:
        for n in range(n_neurons):
            act_eps_supv[:,n] = h_eps_supv[n].filt(act_lpf_supv[:,n])
        xhat_eps_supv = np.dot(act_eps_supv, d_eps_supv)
        nrmse_eps_supv = nrmse(xhat_eps_supv, target=target)
        x_bins_eps_supv, hz_means_eps_supv, hz_stds_eps_supv = bin_activities_values_1d(
            target, act_eps_supv, n_neurons=n_neurons)
        np.savez(out_dir+"supv.npz", act=act_eps_supv, xhat=xhat_eps_supv)

    if sim_bio:
        for n in range(n_neurons):
            act_eps_bio[:,n] = h_eps_bio[n].filt(act_lpf_bio[:,n])
        xhat_eps_bio = np.dot(act_eps_bio, d_eps_bio)
        nrmse_eps_bio = nrmse(xhat_eps_bio, target=target)
        x_bins_eps_bio, hz_means_eps_bio, hz_stds_eps_bio = bin_activities_values_1d(
            target, act_eps_bio, n_neurons=n_neurons)
        np.savez(out_dir+"bio.npz", act=act_eps_bio, xhat=xhat_eps_bio)

    # Update synaptic encoders and weights for any connection that was actually
    # learned. The guards match the ones used to create the EncoderNodes above,
    # so the enc_node_* / connection names are always defined here.
    if learning_pre_supv:
        syn_encoders_pre_u_supv_new = enc_node_pre_u_supv.syn_encoders
        syn_encoders_pre_x_supv_new = enc_node_pre_x_supv.syn_encoders
        syn_weights_pre_u_supv_new = sim.data[pre_u_supv].weights
        syn_weights_pre_x_supv_new = sim.data[pre_x_supv].weights
    else:
        syn_encoders_pre_u_supv_new = syn_encoders_pre_u_supv
        syn_encoders_pre_x_supv_new = syn_encoders_pre_x_supv
        syn_weights_pre_u_supv_new = syn_weights_pre_u_supv
        syn_weights_pre_x_supv_new = syn_weights_pre_x_supv
    if learning_supv_bio:
        syn_encoders_pre_u_bio_new = enc_node_pre_u_bio.syn_encoders
        syn_encoders_supv_bio_new = enc_node_supv_bio.syn_encoders
        syn_weights_pre_u_bio_new = sim.data[pre_u_bio].weights
        syn_weights_supv_bio_new = sim.data[supv_bio].weights
    else:
        syn_encoders_pre_u_bio_new = syn_encoders_pre_u_bio
        syn_encoders_supv_bio_new = syn_encoders_supv_bio
        syn_weights_pre_u_bio_new = syn_weights_pre_u_bio
        syn_weights_supv_bio_new = syn_weights_supv_bio
    if learning_bio_bio:
        syn_encoders_bio_bio_new = enc_node_bio_bio.syn_encoders
        syn_weights_bio_bio_new = sim.data[bio_bio].weights
    else:
        syn_encoders_bio_bio_new = syn_encoders_bio_bio
        syn_weights_bio_bio_new = syn_weights_bio_bio

    # Save dictionary of filters, decoders, encoders, and weights
    h_eps_new = {'h_eps_supv': h_eps_supv, 'h_eps_bio': h_eps_bio}
    d_eps_new = {'d_eps_supv': d_eps_supv, 'd_eps_bio': d_eps_bio}
    syn_encoders_new = {
        'syn_encoders_pre_u_supv': syn_encoders_pre_u_supv_new,
        'syn_encoders_pre_x_supv': syn_encoders_pre_x_supv_new,
        'syn_encoders_pre_u_bio': syn_encoders_pre_u_bio_new,
        'syn_encoders_supv_bio': syn_encoders_supv_bio_new,
        'syn_encoders_bio_bio': syn_encoders_bio_bio_new,
        }
    syn_weights_new = {
        'syn_weights_pre_u_supv': syn_weights_pre_u_supv_new,
        'syn_weights_pre_x_supv': syn_weights_pre_x_supv_new,
        'syn_weights_pre_u_bio': syn_weights_pre_u_bio_new,
        'syn_weights_supv_bio': syn_weights_supv_bio_new,
        'syn_weights_bio_bio': syn_weights_bio_bio_new,
        }
    if save_df:
        save_decoders_filters(out_dir,
            d_pre_u, d_pre_x, d_eps_new, syn_weights_new, syn_encoders_new, h_eps_new)

    # Plots (both populations honour n_neurons_plot for the tuning curves)
    if plot_supv and sim_supv:
        _plot_readout_summary(
            'supv', sim.trange(), spikes_supv, act_lpf_supv, h_eps_supv, taus['readout'],
            (x_bins_lif, hz_means_lif, hz_stds_lif),
            (x_bins_eps_supv, hz_means_eps_supv, hz_stds_eps_supv),
            target, xhat_lif, nrmse_lif, xhat_eps_supv, nrmse_eps_supv,
            n_neurons, n_neurons_plot)

    if plot_bio and sim_bio:
        _plot_readout_summary(
            'bio', sim.trange(), spikes_bio, act_lpf_bio, h_eps_bio, taus['readout'],
            (x_bins_lif, hz_means_lif, hz_stds_lif),
            (x_bins_eps_bio, hz_means_eps_bio, hz_stds_eps_bio),
            target, xhat_lif, nrmse_lif, xhat_eps_bio, nrmse_eps_bio,
            n_neurons, n_neurons_plot)

    return d_eps_new, h_eps_new, syn_encoders_new, syn_weights_new

# Simulations

In [None]:
n_neurons = 100
n_pre = 100
max_evals = 100  # NOTE(review): only used to label save_dir in this cell; not passed to the optimizer here
n_syn = 1
freq = 1
taus = {'network': 0.05, 'readout': 0.05}
eta = 5e-5

t_pre_supv = 16*np.pi
t_supv_bio = 16*np.pi
t_bio_bio = 16*np.pi
t_cos = 4*np.pi
t_white_noise = 4*np.pi
signal = 'cos'

save_dir='/home/pduggins/nengo_bioneurons/nengo_bioneurons/tests/data/double_optimize_alternate/%s_neurons_%s_evals_%s_nsyn_%.3f_freq_%s_eta_%s_signal/' %(n_neurons, max_evals, n_syn, freq, eta, signal)

# Initial readouts: zero decoders and plain lowpass filters for every neuron.
d_eps_bio = np.zeros((n_neurons, 1))
d_eps_supv = np.zeros((n_neurons, 1))

h_eps_bio = [Lowpass(taus['readout']) for _ in range(n_neurons)]
h_eps_supv = [Lowpass(taus['readout']) for _ in range(n_neurons)]

# Compute initial encoders / pre-decoders for the same signal and time
# constants that the training runs below simulate.
syn_encoders_pre_bio, syn_encoders_bio_bio, d_pre_u, d_pre_x = get_syn_encoders_init(
    n_neurons, n_pre, n_syn, signal=signal, t=t_pre_supv, freq=freq, taus=taus, T_u=0.1, T_x=1.0)
# Give every connection its own copy: encoder learning updates these arrays
# in place, so shared references would leak learning between connections.
syn_encoders_pre_u_supv = syn_encoders_pre_bio.copy()
syn_encoders_pre_x_supv = syn_encoders_pre_bio.copy()
syn_encoders_pre_u_bio = syn_encoders_pre_bio.copy()
syn_encoders_supv_bio = syn_encoders_bio_bio.copy()
syn_encoders_bio_bio = syn_encoders_bio_bio.copy()

syn_weights_pre_u_supv = np.zeros((n_neurons, n_pre, n_syn))
syn_weights_pre_x_supv = np.zeros((n_neurons, n_pre, n_syn))
syn_weights_pre_u_bio = np.zeros((n_neurons, n_pre, n_syn))
syn_weights_supv_bio = np.zeros((n_neurons, n_neurons, n_syn))
syn_weights_bio_bio = np.zeros((n_neurons, n_neurons, n_syn))

In [None]:
# Settings shared by both supv training stages: only supv is simulated and
# plotted, and nothing downstream of supv is learned or optimized.
supv_stage_kwargs = dict(
    t=t_pre_supv,
    signal=signal,
    freq=freq,
    n_neurons=n_neurons,
    n_syn=n_syn,
    save_dir=save_dir,
    eta=eta,
    learn_supv_bio=False,
    learn_bio_bio=False,
    optimize_bio=False,
    plot_supv=True,
    plot_bio=False,
    sim_supv=True,
    sim_bio=False)

# Stage 1: learn the synaptic encoders of pre_u -> supv and pre_x -> supv.
print("Training syn_encoders for pre_u_supv and pre_x_supv ...")
_, _, syn_encoders_dict_new, syn_weights_dict_new = double_optimize_alternate(
    d_pre_u, d_pre_x, d_eps_bio, d_eps_supv, h_eps_bio, h_eps_supv,
    syn_encoders_pre_u_supv, syn_encoders_pre_x_supv, syn_encoders_pre_u_bio,
    syn_encoders_supv_bio, syn_encoders_bio_bio,
    syn_weights_pre_u_supv, syn_weights_pre_x_supv, syn_weights_pre_u_bio,
    syn_weights_supv_bio, syn_weights_bio_bio,
    save_suffix="syn_enc_supv/",
    learn_pre_supv=True,
    optimize_supv=False,
    **supv_stage_kwargs)
syn_weights_pre_u_supv, syn_weights_pre_x_supv = (
    syn_weights_dict_new['syn_weights_pre_u_supv'],
    syn_weights_dict_new['syn_weights_pre_x_supv'])
syn_encoders_pre_u_supv, syn_encoders_pre_x_supv = (
    syn_encoders_dict_new['syn_encoders_pre_u_supv'],
    syn_encoders_dict_new['syn_encoders_pre_x_supv'])

# Stage 2: with the learned weights fixed, fit the supv readout (d_eps, h_eps).
print("Training d_eps/h_eps for supv ...")
d_eps_dict, h_eps_dict, _, _ = double_optimize_alternate(
    d_pre_u, d_pre_x, d_eps_bio, d_eps_supv, h_eps_bio, h_eps_supv,
    syn_encoders_pre_u_supv, syn_encoders_pre_x_supv, syn_encoders_pre_u_bio,
    syn_encoders_supv_bio, syn_encoders_bio_bio,
    syn_weights_pre_u_supv, syn_weights_pre_x_supv, syn_weights_pre_u_bio,
    syn_weights_supv_bio, syn_weights_bio_bio,
    save_suffix="d_eps_supv/",
    learn_pre_supv=False,
    optimize_supv=True,
    **supv_stage_kwargs)
d_eps_supv, h_eps_supv = d_eps_dict['d_eps_supv'], h_eps_dict['h_eps_supv']

In [None]:
# Stage 2 -- bio population driven by pre_u and supv.
# Step A learns the encoders/weights on pre_u->bio and supv->bio.
# Step B then keeps those fixed and optimizes the bio readout (d_eps / h_eps).
print("Training syn_encoders for pre_u_bio and supv_bio ...")
stage2_args = (
    d_pre_u, d_pre_x,
    d_eps_bio, d_eps_supv,
    h_eps_bio, h_eps_supv,
    syn_encoders_pre_u_supv, syn_encoders_pre_x_supv,
    syn_encoders_pre_u_bio, syn_encoders_supv_bio, syn_encoders_bio_bio,
    syn_weights_pre_u_supv, syn_weights_pre_x_supv,
    syn_weights_pre_u_bio, syn_weights_supv_bio, syn_weights_bio_bio,
)
# Keyword arguments identical in both calls of this cell.
stage2_common = dict(
    t=t_supv_bio, signal=signal, freq=freq,
    n_neurons=n_neurons, n_syn=n_syn, save_dir=save_dir, eta=eta,
    learn_pre_supv=False, learn_bio_bio=False, optimize_supv=False,
    plot_supv=False, plot_bio=True, sim_supv=True, sim_bio=True)
_, _, syn_encoders_dict_new, syn_weights_dict_new = double_optimize_alternate(
    *stage2_args,
    **dict(stage2_common, save_suffix="syn_enc_bio/",
           learn_supv_bio=True, optimize_bio=False))
syn_weights_pre_u_bio = syn_weights_dict_new['syn_weights_pre_u_bio']
syn_weights_supv_bio = syn_weights_dict_new['syn_weights_supv_bio']
syn_encoders_pre_u_bio = syn_encoders_dict_new['syn_encoders_pre_u_bio']
syn_encoders_supv_bio = syn_encoders_dict_new['syn_encoders_supv_bio']

print("Training d_eps/h_eps for bio ...")
# Re-collect the positional state: the pre_u->bio / supv->bio values were just replaced.
stage2_args = (
    d_pre_u, d_pre_x,
    d_eps_bio, d_eps_supv,
    h_eps_bio, h_eps_supv,
    syn_encoders_pre_u_supv, syn_encoders_pre_x_supv,
    syn_encoders_pre_u_bio, syn_encoders_supv_bio, syn_encoders_bio_bio,
    syn_weights_pre_u_supv, syn_weights_pre_x_supv,
    syn_weights_pre_u_bio, syn_weights_supv_bio, syn_weights_bio_bio,
)
d_eps_dict, h_eps_dict, _, _ = double_optimize_alternate(
    *stage2_args,
    **dict(stage2_common, save_suffix="d_eps_bio/",
           learn_supv_bio=False, optimize_bio=True))
d_eps_bio = d_eps_dict['d_eps_bio']
h_eps_bio = h_eps_dict['h_eps_bio']

In [None]:
# Stage 3 -- recurrent bio->bio connection.
# Step A learns the bio->bio encoders/weights.
# Step B re-optimizes the bio readout (d_eps / h_eps) using the updated
# bio->bio values.
# Both steps use t_bio_bio, matching how stages 1 and 2 use their own durations
# (previously t_supv_bio was passed here and t_bio_bio went unused).
print("Training syn_encoders for bio_bio ...")
stage3_args = (
    d_pre_u, d_pre_x,
    d_eps_bio, d_eps_supv,
    h_eps_bio, h_eps_supv,
    syn_encoders_pre_u_supv, syn_encoders_pre_x_supv,
    syn_encoders_pre_u_bio, syn_encoders_supv_bio, syn_encoders_bio_bio,
    syn_weights_pre_u_supv, syn_weights_pre_x_supv,
    syn_weights_pre_u_bio, syn_weights_supv_bio, syn_weights_bio_bio,
)
# Keyword arguments identical in both calls of this cell.
stage3_common = dict(
    t=t_bio_bio, signal=signal, freq=freq,
    n_neurons=n_neurons, n_syn=n_syn, save_dir=save_dir, eta=eta,
    learn_pre_supv=False, learn_supv_bio=False, optimize_supv=False,
    plot_supv=False, plot_bio=True, sim_supv=False, sim_bio=True)
_, _, syn_encoders_dict_new, syn_weights_dict_new = double_optimize_alternate(
    *stage3_args,
    **dict(stage3_common, save_suffix="syn_enc_bio_bio/",
           learn_bio_bio=True, optimize_bio=False))
syn_weights_bio_bio = syn_weights_dict_new['syn_weights_bio_bio']
syn_encoders_bio_bio = syn_encoders_dict_new['syn_encoders_bio_bio']

print("Training d_eps/h_eps for bio again ...")
# Re-collect the positional state: the bio->bio encoders/weights were just replaced.
stage3_args = (
    d_pre_u, d_pre_x,
    d_eps_bio, d_eps_supv,
    h_eps_bio, h_eps_supv,
    syn_encoders_pre_u_supv, syn_encoders_pre_x_supv,
    syn_encoders_pre_u_bio, syn_encoders_supv_bio, syn_encoders_bio_bio,
    syn_weights_pre_u_supv, syn_weights_pre_x_supv,
    syn_weights_pre_u_bio, syn_weights_supv_bio, syn_weights_bio_bio,
)
# Use a suffix distinct from stage 2's "d_eps_bio/" so this refit does not
# collide with the earlier fit's output location (presumably save_dir +
# save_suffix -- confirm against double_optimize_alternate).
d_eps_dict, h_eps_dict, _, _ = double_optimize_alternate(
    *stage3_args,
    **dict(stage3_common, save_suffix="d_eps_bio_again/",
           learn_bio_bio=False, optimize_bio=True))
d_eps_bio = d_eps_dict['d_eps_bio']
h_eps_bio = h_eps_dict['h_eps_bio']

In [None]:
# Evaluate the trained network on held-out stimuli.
# Every learn_* and optimize_* flag is off; sim_supv is disabled, and sim_bio and
# plot_bio are enabled. Each row is (banner, duration, stimulus, save_suffix).
held_out_tests = [
    ("Testing with sinusoid ...", t_cos, 'cos', "test_cos/"),
    ("Testing with white noise ...", t_white_noise, 'white_noise', "test_white_noise/"),
]
for test_banner, test_duration, test_signal, test_suffix in held_out_tests:
    print(test_banner)
    _, _, _, _ = double_optimize_alternate(
        d_pre_u, d_pre_x,
        d_eps_bio, d_eps_supv,
        h_eps_bio, h_eps_supv,
        syn_encoders_pre_u_supv, syn_encoders_pre_x_supv,
        syn_encoders_pre_u_bio, syn_encoders_supv_bio, syn_encoders_bio_bio,
        syn_weights_pre_u_supv, syn_weights_pre_x_supv,
        syn_weights_pre_u_bio, syn_weights_supv_bio, syn_weights_bio_bio,
        t=test_duration,
        signal=test_signal,
        freq=freq,
        n_neurons=n_neurons,
        n_syn=n_syn,
        save_dir=save_dir,
        save_suffix=test_suffix,
        eta=eta,
        learn_pre_supv=False,
        learn_supv_bio=False,
        learn_bio_bio=False,
        optimize_supv=False,
        optimize_bio=False,
        plot_supv=False,
        plot_bio=True,
        sim_supv=False,
        sim_bio=True)