In [1]:
# # Training tempotrons for the simulation
import os
from os.path import join
import numpy as np
import pandas as pd
from library.Tempotron import Tempotron
from library.script_wrappers import datagen_jitter
from library.utils import save_pickle, load_pickle

# ====================================== Global params and paths ==================================
jitter_times = 5
jitter_ms = 2
project_tag = 'Jit10_2ms_gau'
sim_tag = 'fig6_TrainStand_Icompen2a6'
data_dir = 'sim_results/%s' % sim_tag
save_dir = 'sim_results/%s/%s' % (sim_tag, project_tag)
os.makedirs(save_dir, exist_ok=True)
legendsize = 8
# ======================== Construct training and testing set =================
exintags = ['in']

for exintag in exintags:
    print(exintag)
    if exintag == 'in':
        center_x, center_y = 0, 20
    elif exintag == 'ex':
        center_x, center_y = 0, -20
    else:
        raise ValueError

    simdata = load_pickle(join(data_dir, 'fig6_%s.pkl' % exintag))

    BehDF = simdata['BehDF']
    SpikeDF = simdata['SpikeDF']
    NeuronDF = simdata['NeuronDF']
    MetaData = simdata['MetaData']
    config_dict = simdata['Config']

    theta_phase_plot = BehDF['theta_phase_plot']
    traj_x = BehDF['traj_x'].to_numpy()
    traj_y = BehDF['traj_y'].to_numpy()
    traj_a = BehDF['traj_a'].to_numpy()
    t = BehDF['t'].to_numpy()
    theta_phase = BehDF['theta_phase'].to_numpy()

    nn_ca3 = MetaData['nn_ca3']

    xxtun1d = NeuronDF['neuronx'].to_numpy()
    yytun1d = NeuronDF['neurony'].to_numpy()
    aatun1d = NeuronDF['neurona'].to_numpy()

    xxtun1d_ca3 = xxtun1d[:nn_ca3]
    yytun1d_ca3 = yytun1d[:nn_ca3]
    aatun1d_ca3 = aatun1d[:nn_ca3]
    nx_ca3, ny_ca3 = config_dict['nx_ca3'], config_dict['ny_ca3']
    xxtun2d_ca3 = xxtun1d_ca3.reshape(nx_ca3, nx_ca3)  # Assuming nx = ny
    yytun2d_ca3 = yytun1d_ca3.reshape(nx_ca3, nx_ca3)  # Assuming nx = ny
    aatun2d_ca3 = aatun1d_ca3.reshape(nx_ca3, nx_ca3)  # Assuming nx = ny

    Ipos_max_compen = config_dict['Ipos_max_compen']
    Iangle_diff = config_dict['Iangle_diff']
    Iangle_kappa = config_dict['Iangle_kappa']
    xmin, xmax, ymin, ymax = config_dict['xmin'], config_dict['xmax'], config_dict['ymin'], config_dict['ymax']
    theta_f = config_dict['theta_f']  # in Hz
    theta_T = 1 / theta_f * 1e3  # in ms
    dt = config_dict['dt']
    traj_d = np.append(0, np.cumsum(np.sqrt(np.diff(traj_x) ** 2 + np.diff(traj_y) ** 2)))

    # Find all the neurons in the input space
    all_nidx = np.where((np.abs(xxtun1d_ca3 - center_x) < 10) & (np.abs(yytun1d_ca3 - center_y) < 10))[0]
    all_nidx = np.sort(all_nidx)

    # Trim down SpikeDF
    SpikeDF['tsp'] = SpikeDF['tidxsp'].apply(lambda x: t[x])
    spdftmplist = []
    for nidx in all_nidx:
        spdftmplist.append(SpikeDF[SpikeDF['neuronid'] == nidx])
    SpikeDF_subset = pd.concat(spdftmplist, ignore_index=True)

    # Loop for theta cycles (patterns)
    data_M = []
    label_M = []
    trajtype = []
    theta_bounds = []
    tmax = t.max()
    overlap_r = 2
    cycle_i = 0
    while (cycle_i * theta_T) < tmax:
        print('\rCurrent cycle %d' % cycle_i, end='', flush=True)

        # Create input data - spikes
        theta_tstart, theta_tend = cycle_i * theta_T, (cycle_i + 1) * theta_T
        spdf_M = SpikeDF_subset[(SpikeDF_subset['tsp'] > theta_tstart) & (SpikeDF_subset['tsp'] <= theta_tend)]
        if spdf_M.shape[0] < 1:
            cycle_i += 1
            continue
        data_MN = []
        for nidx in all_nidx:
            spdf_MN = spdf_M[spdf_M['neuronid'] == nidx]
            tsp = spdf_MN['tsp'].to_numpy() - theta_tstart
            data_MN.append(tsp)
        t_intheta = (t > theta_tstart) & (t <= theta_tend)
        data_M.append(data_MN)

        # Create Labels
        traj_r = np.sqrt((traj_x[t_intheta] - center_x) ** 2 + (traj_y[t_intheta] - center_y) ** 2)
        r05 = np.median(traj_r)
        if r05 < overlap_r:
            label = True
        else:
            label = False
        label_M.append(label)

        # Traj type
        behdf_M = BehDF[(BehDF['t'] > theta_tstart) & (BehDF['t'] <= theta_tend)]
        traj_type = int(behdf_M['traj_type'].median())
        trajtype.append(traj_type)
        theta_bounds.append(np.array([theta_tstart, theta_tend]))

        cycle_i += 1
    print()
    theta_bounds = np.stack(theta_bounds)
    data_M = np.array(data_M, dtype=object)
    trajtype = np.array(trajtype)
    labels = np.array(label_M)

    # # Separate train/test set
    train_idx = np.where(trajtype == -1)[0].astype(int)
    test_idx = np.setdiff1d(np.arange(trajtype.shape[0]), train_idx)
    X_train_ori = data_M[train_idx]
    X_test_ori = data_M[test_idx]
    Y_train_ori = labels[train_idx]
    Y_test_ori = labels[test_idx]
    trajtype_train_ori = trajtype[train_idx]
    trajtype_test_ori = trajtype[test_idx]

    # # Jittering
    X_train, Y_train, trajtype_train, Marr_train, jitbatch_train = datagen_jitter(X_train_ori, Y_train_ori,
                                                                                  trajtype_train_ori, jitter_times,
                                                                                  jitter_ms, 0)

in
Current cycle 34

KeyboardInterrupt: 

In [7]:
X_train[0]

array([array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64

In [5]:
X_train[0]

array([array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64),
       array([], dtype=float64), array([], dtype=float64