### Script for generating subtrials to use in decoding on Hyak

In [4]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import utils.spike_utils as spike_utils
import utils.classifier_utils as classifier_utils
import utils.visualization_utils as visualization_utils
import utils.behavioral_utils as behavioral_utils

import os
import pandas as pd
import matplotlib

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
# --- Configuration ----------------------------------------------------------

# Feature dimensions (not referenced by the cells below).
FEATURE_DIMS = ["Color", "Shape", "Pattern"]

# Inputs: the pickled table of valid sessions (read below, needs a
# `session_name` column) and the per-session behavior CSV template, which is
# filled in with .format(sess_name=...).
SESSIONS_PATH = "/data/patrick_res/multi_sess/valid_sessions_rpe.pickle"
SESS_BEHAVIOR_PATH = "/data/rawdata/sub-SA/sess-{sess_name}/behavior/sub-SA_sess-{sess_name}_object_features.csv"

# Output directory (not referenced by the cells below; the pickles written
# later use their own hard-coded paths).
OUTPUT_DIR = "/data/patrick_res/information"

# Event / time-window settings (not referenced by the cells below).
# NOTE(review): intervals are presumably milliseconds around EVENT — confirm.
EVENT = "StimOnset"
PRE_INTERVAL, POST_INTERVAL = 500, 500
INTERVAL_SIZE = 100
SMOOTH = 1

### Generate balanced correct vs. incorrect subtrials

In [13]:
SEED = 42


def _assign_rpe_group(rpe, neg_med, pos_med):
    """Bin one RPE_FE value into an RPE group label.

    Checked in order (first match wins):
      rpe <  neg_med          -> "more neg"
      neg_med <= rpe < 0      -> "less neg"
      0 <= rpe < pos_med      -> "less pos"
      rpe >= pos_med          -> "more pos"

    Returns None when nothing matches (e.g. ``rpe`` or a median is NaN).
    """
    if rpe < neg_med:
        return "more neg"
    if neg_med <= rpe < 0:
        return "less neg"
    if 0 <= rpe < pos_med:
        return "less pos"
    # `>=` rather than `>`: a trial whose RPE equals the positive median
    # (always the case for some trial when the sample size is odd) must be
    # binned, not left with no group.
    if rpe >= pos_med:
        return "more pos"
    return None


def process_session(sess_name):
    """Build a correct/incorrect-balanced, RPE-grouped trial table for one session.

    Reads the session's behavior CSV (``SESS_BEHAVIOR_PATH``), keeps valid
    trials via ``behavioral_utils.get_valid_trials``, and attaches RPEs via
    ``behavioral_utils.get_rpes_per_session``. Correct and Incorrect trials are
    each sampled (``random_state=SEED``) down to the size of the smaller set.
    Every kept trial then gets an ``RPEGroup`` label by comparing ``RPE_FE``
    against 0, the median of the sampled incorrect trials, and the median of the
    sampled correct trials (see ``_assign_rpe_group``).

    Args:
        sess_name: session identifier, used to format the behavior path, to
            look up RPEs, and stored in the ``session`` column.

    Returns:
        DataFrame of the sampled correct trials followed by the sampled
        incorrect trials, with added ``RPEGroup`` and ``session`` columns.
    """
    behavior_path = SESS_BEHAVIOR_PATH.format(sess_name=sess_name)
    beh = pd.read_csv(behavior_path)

    # filter trials and attach per-trial RPEs
    valid_beh = behavioral_utils.get_valid_trials(beh)
    valid_beh = behavioral_utils.get_rpes_per_session(sess_name, valid_beh)
    cor_beh = valid_beh[valid_beh.Response == "Correct"]
    inc_beh = valid_beh[valid_beh.Response == "Incorrect"]

    # balance the classes: sample both down to the smaller class's size
    min_num_trials = min(len(cor_beh), len(inc_beh))
    cor_beh = cor_beh.sample(min_num_trials, random_state=SEED)
    inc_beh = inc_beh.sample(min_num_trials, random_state=SEED)

    # group boundaries come from the *sampled* trials
    pos_med = cor_beh.RPE_FE.median()
    neg_med = inc_beh.RPE_FE.median()

    balanced_beh = pd.concat([cor_beh, inc_beh])
    # Map over the single RPE_FE column instead of a row-wise apply, which
    # rebuilds every row as a Series and skips the column on empty input.
    balanced_beh["RPEGroup"] = balanced_beh.RPE_FE.map(
        lambda rpe: _assign_rpe_group(rpe, neg_med, pos_med)
    )
    balanced_beh["session"] = sess_name
    return balanced_beh

In [14]:
valid_sess = pd.read_pickle(SESSIONS_PATH)
# Stack one balanced, RPE-grouped trial table per valid session, in session order.
all_trials = pd.concat([process_session(name) for name in valid_sess["session_name"]])

In [26]:
# Persist only the trials labelled "more pos" (RPE_FE >= the session's positive median).
all_trials.loc[all_trials["RPEGroup"].eq("more pos")].to_pickle("/data/patrick_res/more_pos_trials.pickle")

In [9]:
# Split the pooled, balanced trials back into their two response classes.
cor_bal_trials = all_trials.loc[all_trials["Response"].eq("Correct")]
inc_bal_trials = all_trials.loc[all_trials["Response"].eq("Incorrect")]

In [10]:
# Write each response class to its own pickle for the decoding runs.
for trials, out_path in (
    (cor_bal_trials, "/data/patrick_res/cor_bal_trials.pickle"),
    (inc_bal_trials, "/data/patrick_res/inc_bal_trials.pickle"),
):
    trials.to_pickle(out_path)