In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import nengo
import scipy
import pandas as pd
import scipy.io
import h5py
from model_1 import *

palette = sns.color_palette('tab10')
sns.set_palette(palette)
sns.set(context='paper', style='whitegrid', font="cmr10", font_scale=1.0)
plt.rcParams['axes.formatter.use_mathtext'] = True

In [10]:
monkey = 'V'
session = 0
seed_network = session + 10 if monkey=='V' else session
bid = 1
trials = 8
filter_width = 10    # 10ms filter
box_filter = np.ones(filter_width)
env = Environment(monkey=monkey, session=session)
net = build_network(env, seed_network=seed_network)
sim = nengo.Simulator(net, dt=net.env.dt, progress_bar=False)
probes = [net.s_v, net.s_w, net.s_a, net.s_vwa, net.s_evc, net.s_drel]
labels = ['value', 'omega', 'action', 'mixed', 'error', 'reliability']
arrays = [[], [], [], [], [], []]

with sim:
    for trial in env.empirical.query("monkey==@monkey & session==@session & bid==@bid")['trial'].unique()[:trials]:
        print(f"running monkey {env.monkey}, session {session}, block {bid}, trial {trial}")
        t_start = sim.trange().shape[0]
        net.env.set_cue(bid, trial)
        sim.run(net.env.t_cue)
        t_end = sim.trange().shape[0]
        for p in range(len(probes)):
            spikes = sim.data[probes[p]][t_start:t_end] / 1000
            binned = scipy.ndimage.convolve1d(spikes, box_filter, mode='nearest')[::filter_width]
            arrays[p].append(binned)
        env.set_action(sim, net)
        env.set_reward(bid, trial)
        sim.run(net.env.t_reward)

spike_dict = {}
for p in range(len(probes)):
    label = labels[p]
    data = np.stack(arrays[p], axis=2)
    spike_dict[label] = data
scipy.io.savemat(f"data/monkey{monkey}_session{session}_block{bid}_trial{trial}.mat", spike_dict)

running monkey V, session 0, block 1, trial 1
running monkey V, session 0, block 1, trial 2
running monkey V, session 0, block 1, trial 3
running monkey V, session 0, block 1, trial 4
running monkey V, session 0, block 1, trial 5
running monkey V, session 0, block 1, trial 6
running monkey V, session 0, block 1, trial 7
running monkey V, session 0, block 1, trial 8


In [8]:
spike_dict

{'value': array([[[3., 1., 2., ..., 2., 0., 0.],
         [4., 1., 2., ..., 2., 0., 0.],
         [4., 2., 2., ..., 2., 0., 0.],
         ...,
         [1., 2., 1., ..., 0., 4., 0.],
         [1., 1., 0., ..., 0., 5., 0.],
         [1., 1., 0., ..., 0., 6., 0.]],
 
        [[6., 1., 0., ..., 5., 1., 1.],
         [5., 1., 0., ..., 4., 1., 1.],
         [4., 1., 0., ..., 3., 1., 1.],
         ...,
         [1., 1., 4., ..., 0., 0., 0.],
         [1., 1., 5., ..., 0., 0., 0.],
         [1., 1., 6., ..., 0., 0., 0.]],
 
        [[1., 2., 0., ..., 6., 5., 0.],
         [1., 2., 0., ..., 5., 4., 0.],
         [1., 2., 0., ..., 4., 3., 0.],
         ...,
         [4., 5., 0., ..., 5., 0., 5.],
         [5., 5., 0., ..., 5., 0., 5.],
         [6., 6., 0., ..., 6., 0., 6.]],
 
        ...,
 
        [[0., 0., 1., ..., 1., 2., 2.],
         [0., 0., 1., ..., 1., 3., 2.],
         [0., 0., 1., ..., 1., 3., 2.],
         ...,
         [1., 0., 1., ..., 0., 1., 1.],
         [1., 0., 1., ..., 0., 