# Why we needed more data

Our previous neural data from CO and CST with Earl and Ford suggested that at least one dimension of neural activity specified the upcoming trial type (CO or CST), even though behavior during the center hold period was roughly similar across trial types. Because the monkeys knew which type of trial was coming up, they could in principle prepare before movement in whatever way they needed to optimize their success.

Unfortunately, those data had a few problems for this kind of analysis:

1. Monkeys knew what kind of trial was coming up as they were reaching to the center target. This makes it difficult to study the evolution of this contextual preparation signal because it's confounded with a reach.
2. CO was a much shorter task than CST, so it was possible that the signal we saw was merely anticipation of a quicker reward, rather than contextual preparation.
3. CO behavior was much simpler ahead of the reward: a single moment of visual input was enough to specify the behavior for the entire CO trial, whereas in CST the monkey has to pay attention to the cursor throughout the trial.
4. The split between CO and CST was not quite even during the main part of the recording session, so there may have been some bias in the monkeys' expectations of the upcoming trial type.
5. CO involved movements in both horizontal and vertical axes, whereas CST involved movements in only the horizontal axis.

For these reasons, we collected new data from another monkey, this time with a few modifications:

1. Instead of CO, we introduced the monkey to a new horizontal random target task (RTT). In the RTT, the monkey would reach to a visually presented target. Once the monkey reached the target, a new target would appear. This would continue until the monkey had reached 8 targets sequentially, at which point the trial would end with a reward. Importantly, each of the 8 targets would be selected uniformly from a set of 17 targets lined up on the horizontal axis, so movements would only be horizontal. For this task, 8 targets seemed to be a good number to match the 6-second trial time of CST.
2. At the start of the trial, monkeys would be presented with an ambiguous center target to reach and hold in, without knowing which type of trial was coming up. After a short delay (0.3-0.5s), the hold target would change shape to indicate which of the two tasks (CST or RTT) was coming up. After another short delay (0.5-0.75s), the trial would start.
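
To make the new trial structure concrete, here is a minimal sketch of how one trial's cue timing and RTT target sequence could be simulated. This is not the actual task code: the even task split, the uniform draw of the delays, the target spacing, independent draws of successive targets, and the names `simulate_trial` and `TARGET_POSITIONS` are all assumptions for illustration.

```python
import numpy as np

N_TARGET_POSITIONS = 17   # candidate targets on the horizontal axis (from the text)
N_REACHES = 8             # targets per RTT trial (from the text)
# Hypothetical layout: evenly spaced positions in arbitrary units
TARGET_POSITIONS = np.linspace(-1.0, 1.0, N_TARGET_POSITIONS)

def simulate_trial(rng):
    """Draw the hold timing and, for RTT, the target sequence for one trial."""
    task = str(rng.choice(['CST', 'RTT']))  # assumed even split between tasks
    trial = {
        'task': task,
        'ambiguous_hold': rng.uniform(0.3, 0.5),   # seconds before the task cue
        'cued_hold': rng.uniform(0.5, 0.75),       # seconds between cue and trial start
    }
    if task == 'RTT':
        # Assumption: each target is drawn independently, so repeats are possible
        trial['targets'] = rng.choice(TARGET_POSITIONS, size=N_REACHES)
    return trial

rng = np.random.default_rng(0)
trials = [simulate_trial(rng) for _ in range(200)]
```

Because the task identity is only drawn after the ambiguous hold begins, nothing in the simulated hold timing depends on the task, which is the property the new design is meant to guarantee.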

Ideally, these changes would allow for a more careful examination of the putative contextual preparation signal--this notebook serves as an investigation of that.

In [None]:
import src
import pyaldata
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import k3d
import yaml
from sklearn.decomposition import PCA
from ipywidgets import interact

with open('../params.yaml','r') as f:
    params = yaml.safe_load(f)
    inspection_params = params['inspection']

sns.set_context('talk')

%load_ext autoreload
%autoreload 2

# Data inspection

To start, let's import a dataset (2022/07/20) and preprocess it. Preprocessing takes a number of steps, which should be fairly self-evident from the pipeline below.

In [None]:
td = (
    pyaldata.mat2dataframe('../data/trial_data/Prez_20220720_RTTCSTCO_TD.mat', shift_idx_fields=True, td_name='trial_data')
    .assign(
        date_time=lambda x: pd.to_datetime(x['date_time']),
        session_date=lambda x: pd.DatetimeIndex(x['date_time']).normalize()
    )
    .query('task=="RTT" | task=="CST"')
    .pipe(src.data.remove_aborts, verbose=inspection_params['verbose'])
    .pipe(src.data.remove_artifact_trials, verbose=inspection_params['verbose'])
    .pipe(src.data.filter_unit_guides, filter_func=lambda guide: guide[:,1] > (0 if inspection_params['keep_unsorted'] else 1))
    .pipe(src.data.remove_correlated_units)
    .pipe(pyaldata.remove_low_firing_neurons, 'M1_spikes', threshold=0.1, divide_by_bin_size=True, verbose=inspection_params['verbose'])
    .pipe(pyaldata.remove_low_firing_neurons, 'PMd_spikes', threshold=0.1, divide_by_bin_size=True, verbose=inspection_params['verbose'])
    .pipe(pyaldata.remove_low_firing_neurons, 'MC_spikes', threshold=0.1, divide_by_bin_size=True, verbose=inspection_params['verbose'])
    .pipe(pyaldata.add_firing_rates,method='smooth', std=0.05, backend='convolve')
    .pipe(src.data.trim_nans, ref_signals=['rel_hand_pos'])
    .pipe(src.data.fill_kinematic_signals)
    .pipe(src.data.rebin_data,new_bin_size=inspection_params['bin_size'])
    .pipe(pyaldata.soft_normalize_signal,signals=['M1_rates','PMd_rates','MC_rates'])
    .pipe(pyaldata.dim_reduce,PCA(n_components=3),'M1_rates','M1_pca')
    .pipe(pyaldata.dim_reduce,PCA(n_components=3),'PMd_rates','PMd_pca')
    .pipe(pyaldata.dim_reduce,PCA(n_components=3),'MC_rates','MC_pca')
    .assign(
        **{
            'idx_ctHoldTime': lambda x: x['idx_ctHoldTime'].map(lambda y: y[-1] if y.size>1 else y),
            'Ambiguous Hold Period': lambda x: x['bin_size']*(x['idx_pretaskHoldTime'] - x['idx_ctHoldTime']),
            'Cued Hold Period': lambda x: x['bin_size']*(x['idx_goCueTime'] - x['idx_pretaskHoldTime']),
            'Movement Period': lambda x: x['bin_size']*(x['idx_endTime'] - x['idx_goCueTime']),
        }
    )
)
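
As a rough illustration of the last few steps of the pipeline above (soft normalization followed by a 3-component PCA fit across all trials), here is a self-contained sketch on toy data. The normalization formula (divide each neuron by its firing-rate range plus a constant) and the constant `alpha=5.0` are assumptions about what soft normalization does, and `soft_normalize` is a hypothetical helper rather than the `pyaldata` function.

```python
import numpy as np
from sklearn.decomposition import PCA

def soft_normalize(rates, alpha=5.0):
    """Scale each neuron (column) by 1 / (firing-rate range + alpha)."""
    rate_range = rates.max(axis=0) - rates.min(axis=0)
    return rates / (rate_range + alpha)

rng = np.random.default_rng(0)
# Toy stand-in for smoothed rates: 5 trials, each (time bins x 20 neurons)
trial_rates = [rng.gamma(2.0, 10.0, size=(n_bins, 20)) for n_bins in (80, 100, 90, 120, 110)]

# Normalize and fit PCA on all trials concatenated in time
all_rates = np.vstack(trial_rates)
normed = soft_normalize(all_rates)
pca = PCA(n_components=3).fit(normed)

# Project, then split back into per-trial arrays of shape (time bins x 3)
split_points = np.cumsum([r.shape[0] for r in trial_rates])[:-1]
trial_pca = np.split(pca.transform(normed), split_points)
```

Adding a constant to the range keeps low-firing neurons from being inflated to the same scale as high-firing ones, so they contribute less to the principal components.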

## Individual trials

As a sanity check, let's look at the rasters from a couple of random CST trials and RTT trials:

In [None]:
_,axs = plt.subplots(4,1,figsize=(10,15),sharex=True)
for ax,trial_id in zip(axs.flatten(),td.groupby('task').sample(n=2)['trial_id']):
    trial = td.loc[td['trial_id']==trial_id,:].squeeze()
    src.plot.make_trial_raster(
        trial,
        ax,
        sig='MC_spikes',
        events=[
            'idx_ctHoldTime',
            'idx_pretaskHoldTime',
            'idx_goCueTime',
            'idx_cstEndTime',
            'idx_rtHoldTimes',
        ],
        ref_event_idx=trial['idx_goCueTime'])
    ax.set_xlabel('')
    ax.set_ylabel(trial['task'])
axs[0].set_title('Neuron Rasters')
axs[-1].set_xlabel('Time (s)')

Let's double-check the timing of each trial, just to make sure the task is happening the way we think it is.

In [None]:
g = sns.pairplot(
    data=td.query('task=="CST" | task=="RTT"'),
    vars=['Ambiguous Hold Period','Cued Hold Period','Movement Period'],
    hue='task',
    diag_kind='hist',
    corner=True,
    height=3,
)
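
Beyond eyeballing the pair plot, the hold periods can be checked against the nominal delays described earlier (0.3-0.5 s ambiguous, 0.5-0.75 s cued). The sketch below does this on a toy table that mimics the index fields used above; the one-bin tolerance and the toy values are assumptions.

```python
import pandas as pd

# Toy trials; the last one has a deliberately long ambiguous hold
toy = pd.DataFrame({
    'bin_size': 0.01,
    'idx_ctHoldTime': [50, 62, 41, 10],
    'idx_pretaskHoldTime': [85, 110, 80, 100],
    'idx_goCueTime': [150, 170, 140, 160],
})

# Period length = number of bins between events, converted to seconds
periods = pd.DataFrame({
    'Ambiguous Hold Period': toy['bin_size'] * (toy['idx_pretaskHoldTime'] - toy['idx_ctHoldTime']),
    'Cued Hold Period': toy['bin_size'] * (toy['idx_goCueTime'] - toy['idx_pretaskHoldTime']),
})

nominal = {'Ambiguous Hold Period': (0.3, 0.5), 'Cued Hold Period': (0.5, 0.75)}
tol = toy['bin_size'].max()  # allow one bin of slack for discretization
in_range = pd.DataFrame({
    name: periods[name].between(lo - tol, hi + tol)
    for name, (lo, hi) in nominal.items()
})

# Trials whose timing falls outside the nominal ranges
out_of_range = periods[~in_range.all(axis=1)]
```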

## Neural state space in RTT and CST

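First, let's take a look at individual-trial neural traces for each task, plotted in the first three M1 principal components and colored by the first component of hand velocity (`hand_vel[:,0]`).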

In [None]:
cst_trace_plot = k3d.plot(name='CST smoothed neural traces')
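# 95th percentile of absolute hand velocity (first component) across CST trials, used to set a shared color range
# (np.vstack replaces np.row_stack, which is deprecated in NumPy 2.0)
max_abs_hand_vel = np.percentile(np.abs(np.vstack(td.query('task=="CST"')['hand_vel'])[:,0]),95)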
# plot traces
for _,trial in td.query('task=="CST"').sample(n=10).iterrows():
    neural_trace = trial['M1_pca']
    cst_trace_plot+=k3d.line(
        neural_trace[:,0:3].astype(np.float32),
        shader='mesh',
        width=3e-3,
        attribute=trial['hand_vel'][:,0],
        color_map=k3d.paraview_color_maps.Erdc_divHi_purpleGreen,
        color_range=[-max_abs_hand_vel,max_abs_hand_vel],
    )

cst_trace_plot.display()

rtt_trace_plot = k3d.plot(name='RTT smoothed neural traces')
for _,trial in td.query('task=="RTT"').sample(n=10).iterrows():
    neural_trace = trial['M1_pca']
    rtt_trace_plot+=k3d.line(
        neural_trace[:,0:3].astype(np.float32),
        shader='mesh',
        width=3e-3,
        attribute=trial['hand_vel'][:,0],
        #attribute=150*np.ones(neural_trace.shape[0]),
        color_map=k3d.paraview_color_maps.Erdc_divHi_purpleGreen,
        color_range=[-max_abs_hand_vel,max_abs_hand_vel],
    )
rtt_trace_plot.display()

In [None]:
def extract_td_epochs(td):
    '''
    Split trials into epochs for the hold-time PCA and LDA models (one bin per epoch),
    as well as into longer epochs of smoothed hold/move activity
    
    Arguments:
        td (DataFrame): PyalData formatted structure of neural/behavioral data
        
    Returns:
        td_epochs (DataFrame): PyalData formatted structure containing both the
            binned and the smoothed epochs, concatenated
    '''
    binned_epoch_dict = {
        'ambig_hold': src.util.generate_realtime_epoch_fun(
            'idx_pretaskHoldTime',
            rel_start_time=-0.3,
        ),
        'hold': src.util.generate_realtime_epoch_fun(
            'idx_goCueTime',
            rel_start_time=-0.3,
        ),
        'move': src.util.generate_realtime_epoch_fun(
            'idx_goCueTime',
            rel_start_time=0,
            rel_end_time=0.3,
        ),
    }

    td_binned = (
        td.copy()
        .pipe(src.util.split_trials_by_epoch,binned_epoch_dict)
        .pipe(src.data.rebin_data,new_bin_size=0.3)
        .pipe(pyaldata.add_firing_rates,method='bin')
    )

    spike_fields = [name for name in td.columns.values if name.endswith("_spikes")]
    for field in spike_fields:
        assert td_binned[field].values[0].ndim==1, "Binning didn't work"

    smooth_epoch_dict = {
        'hold_move': src.util.generate_realtime_epoch_fun(
            'idx_goCueTime',
            rel_start_time=-0.8,
            rel_end_time=0.5,
        ),
        'hold_move_ref_cue': src.util.generate_realtime_epoch_fun(
            'idx_pretaskHoldTime',
            rel_start_time=-0.3,
            rel_end_time=1.0,
        ),
        'full': lambda trial : slice(0,trial['hand_pos'].shape[0]),
    }
    td_smooth = (
        td.copy()
        .pipe(pyaldata.add_firing_rates,method='smooth',std=0.05,backend='convolve')
        .pipe(src.util.split_trials_by_epoch,smooth_epoch_dict)
        .pipe(src.data.rebin_data,new_bin_size=0.05)
    )

    td_epochs = pd.concat([td_binned,td_smooth]).reset_index()

    return td_epochs
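
The epoch dictionaries above map an epoch name to a function that turns a trial into a slice of time bins around a reference event. As a hedged sketch of how such a function could work (this is not the implementation of `src.util.generate_realtime_epoch_fun`; `make_epoch_fun` is a hypothetical stand-in, and the default relative end time of 0 is an assumption):

```python
import numpy as np

def make_epoch_fun(ref_event, rel_start_time=0.0, rel_end_time=0.0):
    """Return a function mapping a trial to a slice of bins around `ref_event`."""
    def epoch_fun(trial):
        ref_idx = int(trial[ref_event])
        # Convert times in seconds (relative to the event) into bin offsets
        start = ref_idx + int(round(rel_start_time / trial['bin_size']))
        stop = ref_idx + int(round(rel_end_time / trial['bin_size']))
        return slice(max(start, 0), stop)
    return epoch_fun

# Toy trial: 40 bins of 50 ms, go cue at bin 20
trial = {'bin_size': 0.05, 'idx_goCueTime': 20, 'rates': np.arange(40).reshape(40, 1)}

hold = make_epoch_fun('idx_goCueTime', rel_start_time=-0.3)(trial)  # 0.3 s before go cue
move = make_epoch_fun('idx_goCueTime', rel_end_time=0.3)(trial)     # 0.3 s after go cue

# Averaging within a 0.3 s epoch gives one value per neuron, like rebinning to 0.3 s
hold_mean = trial['rates'][hold].mean(axis=0)
```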

In [None]:
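# Relabel the combined MC rates as 'M1_rates' (keeping the originals as 'M1_orig_rates'),
# so that anything downstream reading 'M1_rates' uses the MC rates
td_epoch = extract_td_epochs(td.rename(columns={'M1_rates':'M1_orig_rates','MC_rates':'M1_rates'}))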

In [None]:
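# Fit the models on the ambiguous hold epoch (the 0.3 s before the task cue), then apply them to the longer smoothed epochs
td_train,td_test = src.context_analysis.apply_models(td_epoch,train_epochs=['ambig_hold'],test_epochs=['hold_move','hold_move_ref_cue','full'])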

fig_gen_dict = {
    'task_M1_pca':src.context_analysis.plot_hold_pca(td_train,array_name='M1',hue_order=['RTT','CST']),
    'task_M1_lda':src.context_analysis.plot_M1_lda(td_train,hue_order=['RTT','CST']),
    'task_beh':src.context_analysis.plot_hold_behavior(td_train,hue_order=['RTT','CST']),
    'task_beh_lda':src.context_analysis.plot_beh_lda(td_train,hue_order=['RTT','CST']),
    # 'task_M1_potent': src.plot_M1_hold_potent(td_train,hue_order=['RTT','CST']),
    # 'task_M1_potent_lda': src.plot_M1_potent_lda(td_train,hue_order=['RTT','CST']),
    # 'task_M1_null_lda': src.plot_M1_null_lda(td_train,hue_order=['RTT','CST']),
    # LDA traces
    'task_lda_trace':src.context_analysis.plot_M1_lda_traces(td_test.query('epoch=="hold_move"'),ref_event='idx_goCueTime',label_colors={'RTT':'r','CST':'b'}),
    'task_lda_trace_pretask':src.context_analysis.plot_M1_lda_traces(td_test.query('epoch=="hold_move_ref_cue"'),ref_event='idx_pretaskHoldTime',label_colors={'RTT':'r','CST':'b'}),
    # distinct key, so this figure does not overwrite the 'task_lda_trace' entry above
    'task_lda_trace_full':src.context_analysis.plot_M1_lda_traces(td_test.query('epoch=="full"'),ref_event='idx_goCueTime',label_colors={'RTT':'r','CST':'b'}),
}

In [None]:
# Same analysis, but fit the models on the cued hold epoch (the 0.3 s before the go cue)
td_train,td_test = src.context_analysis.apply_models(td_epoch,train_epochs=['hold'],test_epochs=['hold_move','hold_move_ref_cue','full'])

fig_gen_dict = {
    'task_M1_pca':src.context_analysis.plot_hold_pca(td_train,array_name='M1',hue_order=['RTT','CST']),
    'task_M1_lda':src.context_analysis.plot_M1_lda(td_train,hue_order=['RTT','CST']),
    'task_beh':src.context_analysis.plot_hold_behavior(td_train,hue_order=['RTT','CST']),
    'task_beh_lda':src.context_analysis.plot_beh_lda(td_train,hue_order=['RTT','CST']),
    # 'task_M1_potent': src.plot_M1_hold_potent(td_train,hue_order=['RTT','CST']),
    # 'task_M1_potent_lda': src.plot_M1_potent_lda(td_train,hue_order=['RTT','CST']),
    # 'task_M1_null_lda': src.plot_M1_null_lda(td_train,hue_order=['RTT','CST']),
    # LDA traces
    'task_lda_trace':src.context_analysis.plot_M1_lda_traces(td_test.query('epoch=="hold_move"'),ref_event='idx_goCueTime',label_colors={'RTT':'r','CST':'b'}),
    'task_lda_trace_pretask':src.context_analysis.plot_M1_lda_traces(td_test.query('epoch=="hold_move_ref_cue"'),ref_event='idx_pretaskHoldTime',label_colors={'RTT':'r','CST':'b'}),
    # distinct key, so this figure does not overwrite the 'task_lda_trace' entry above
    'task_lda_trace_full':src.context_analysis.plot_M1_lda_traces(td_test.query('epoch=="full"'),ref_event='idx_goCueTime',label_colors={'RTT':'r','CST':'b'}),
}

In [None]:
# Same analysis, but fit the models on the early movement epoch (the 0.3 s after the go cue)
td_train,td_test = src.context_analysis.apply_models(td_epoch,train_epochs=['move'],test_epochs=['hold_move','hold_move_ref_cue','full'])

fig_gen_dict = {
    'task_M1_pca':src.context_analysis.plot_hold_pca(td_train,array_name='M1',hue_order=['RTT','CST']),
    'task_M1_lda':src.context_analysis.plot_M1_lda(td_train,hue_order=['RTT','CST']),
    'task_beh':src.context_analysis.plot_hold_behavior(td_train,hue_order=['RTT','CST']),
    'task_beh_lda':src.context_analysis.plot_beh_lda(td_train,hue_order=['RTT','CST']),
    # 'task_M1_potent': src.plot_M1_hold_potent(td_train,hue_order=['RTT','CST']),
    # 'task_M1_potent_lda': src.plot_M1_potent_lda(td_train,hue_order=['RTT','CST']),
    # 'task_M1_null_lda': src.plot_M1_null_lda(td_train,hue_order=['RTT','CST']),
    # LDA traces
    'task_lda_trace':src.context_analysis.plot_M1_lda_traces(td_test.query('epoch=="hold_move"'),ref_event='idx_goCueTime',label_colors={'RTT':'r','CST':'b'}),
    'task_lda_trace_pretask':src.context_analysis.plot_M1_lda_traces(td_test.query('epoch=="hold_move_ref_cue"'),ref_event='idx_pretaskHoldTime',label_colors={'RTT':'r','CST':'b'}),
    # distinct key, so this figure does not overwrite the 'task_lda_trace' entry above
    'task_lda_trace_full':src.context_analysis.plot_M1_lda_traces(td_test.query('epoch=="full"'),ref_event='idx_goCueTime',label_colors={'RTT':'r','CST':'b'}),
}
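
The three cells above differ only in which epoch the models are trained on. The core idea (fit a discriminant on one short epoch, then project the whole time course onto it) can be sketched on synthetic data as follows. This uses scikit-learn's `LinearDiscriminantAnalysis` directly and trains on a post-cue window; `src.context_analysis.apply_models` may well do something different (for example, reducing dimensionality first), and all of the data and parameters here are made up.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

rng = np.random.default_rng(0)
n_trials, n_neurons, n_bins = 100, 15, 30
cue_bin = 10  # bin at which the (synthetic) task cue appears
labels = np.repeat(['RTT', 'CST'], n_trials // 2)

# Synthetic rates: noise everywhere, plus a task-dependent offset after the cue
task_axis = rng.normal(size=n_neurons)
task_axis /= np.linalg.norm(task_axis)
rates = rng.normal(size=(n_trials, n_bins, n_neurons))
sign = np.where(labels == 'CST', 1.0, -1.0)
rates[:, cue_bin:, :] += 2.0 * sign[:, None, None] * task_axis

# Train on the average activity in a short post-cue window (one sample per trial)
train_X = rates[:, cue_bin:cue_bin + 6, :].mean(axis=1)
lda = LinearDiscriminantAnalysis().fit(train_X, labels)

# Test: project every time bin of every trial onto the discriminant axis
lda_trace = lda.transform(rates.reshape(-1, n_neurons)).reshape(n_trials, n_bins)

# Separation between the task-averaged traces at each time bin
separation = np.abs(
    lda_trace[labels == 'CST'].mean(axis=0) - lda_trace[labels == 'RTT'].mean(axis=0)
)
```

In this toy example the separation is near zero before `cue_bin` and large afterwards, which is the pattern one would look for in the real traces if the discriminant reflects information that only becomes available at the cue.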