In [4]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import utils.spike_utils as spike_utils
import utils.classifier_utils as classifier_utils
import utils.visualization_utils as visualization_utils
import utils.behavioral_utils as behavioral_utils

import os
import pandas as pd
import matplotlib

from sklearn.linear_model import LinearRegression

# --- Alignment window used to pick the firing-rate file -------------------
EVENT = "FeedbackOnset"   # behavioral event the data is aligned on
PRE_INTERVAL = 1300       # ms before the event
POST_INTERVAL = 1500      # ms after the event
INTERVAL_SIZE = 100       # interval (bin) size in ms

# --- Path templates; the {placeholders} are filled in with str.format ------
SESS_BEHAVIOR_PATH = (
    "/data/rawdata/sub-SA/sess-{sess_name}/behavior/"
    "sub-SA_sess-{sess_name}_object_features.csv"
)
SESS_SPIKES_PATH = (
    "/data/patrick_res/firing_rates/"
    "{sess_name}_firing_rates_{pre_interval}_{event}_{post_interval}_{interval_size}_bins_1_smooth.pickle"
)
SESSIONS_PATH = "/data/patrick_res/sessions/valid_sessions_rpe.pickle"

# Feature dimensions of a card.
feature_dims = ["Color", "Shape", "Pattern"]

# Maps each feature name to the dimension it belongs to. Built from a
# per-dimension grouping; insertion order is Shape, Color, Pattern features.
rule_to_dim = {
    rule: dim
    for dim, rules in (
        ("Shape", ("CIRCLE", "SQUARE", "STAR", "TRIANGLE")),
        ("Color", ("CYAN", "GREEN", "MAGENTA", "YELLOW")),
        ("Pattern", ("ESCHER", "POLKADOT", "RIPPLE", "SWIRL")),
    )
    for rule in rules
}

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
# Session whose model values are inspected in the cells below.
session = "20180709"
model_path = f"/data/082023_Feat_RLDE_HV/sess-{session}_hv.csv"

# Feature names in the order used to rename the model's feat_<i> columns.
feat_names = np.array(
    ['CIRCLE', 'SQUARE', 'STAR', 'TRIANGLE']
    + ['CYAN', 'GREEN', 'MAGENTA', 'YELLOW']
    + ['ESCHER', 'POLKADOT', 'RIPPLE', 'SWIRL']
)

# feat_<i> -> i-th feature name, then apply the rename on the loaded table.
renames = {f"feat_{idx}": name for idx, name in enumerate(feat_names)}
model_vals = pd.read_csv(model_path).rename(columns=renames)

In [6]:
# Behavior table for the session, restricted by get_valid_trials.
behavior_path = SESS_BEHAVIOR_PATH.format(sess_name=session)
beh = pd.read_csv(behavior_path)
valid_beh = behavioral_utils.get_valid_trials(beh)

# Pickled firing rates for the alignment window configured above.
spikes_path = SESS_SPIKES_PATH.format(
    sess_name=session,
    event=EVENT,
    pre_interval=PRE_INTERVAL,
    post_interval=POST_INTERVAL,
    interval_size=INTERVAL_SIZE,
)
frs = pd.read_pickle(spikes_path)

# Attach the features of the selected card, then the model values, per trial.
feature_selections = behavioral_utils.get_selection_features(valid_beh)
valid_beh = valid_beh.merge(feature_selections, on="TrialNumber", how="inner")
valid_beh_vals = valid_beh.merge(
    model_vals, left_on="TrialNumber", right_on="trial", how="inner"
)

# The inner join on model values must not change the number of trials.
assert len(valid_beh_vals) == len(valid_beh)

In [7]:
def get_highest_val_feat(row):
    """Label a trial with the selected-card feature that has the highest value.

    ``row`` holds the chosen card's feature names under "Color", "Shape" and
    "Pattern", and is also indexed by those feature names to get each
    feature's value. The feature with the largest value is stored in
    ``row["MaxFeat"]`` and its dimension (looked up in ``rule_to_dim``) in
    ``row["MaxFeatDim"]``. (value, name) pairs are compared, so equal values
    are broken by the feature name that compares greatest.

    Returns the same ``row`` object with the two entries added.
    """
    chosen_feats = [row["Color"], row["Shape"], row["Pattern"]]
    # Pair each value with its name so the name rides along through max().
    _, top_feat = max((row[feat], feat) for feat in chosen_feats)
    row["MaxFeat"] = top_feat
    row["MaxFeatDim"] = rule_to_dim[top_feat]
    return row


# Annotate every trial row with its highest-valued selected feature.
valid_beh_max = valid_beh_vals.apply(get_highest_val_feat, axis=1)


In [8]:
def check_session(row, min_trials_per_feat=15):
    """Flag whether a session has enough trials for every max-valued feature.

    For the session named by ``row.session_name`` this loads the model values
    and the behavior table, keeps the trials returned by
    ``behavioral_utils.get_valid_trials``, attaches the selected card's
    features and the model values per trial, and labels each trial with
    ``MaxFeat`` via the notebook-level ``get_highest_val_feat`` (previously an
    identical copy of that function was nested here).

    Args:
        row: one row of the sessions table; must expose ``session_name``.
        min_trials_per_feat: minimum number of trials each ``MaxFeat`` value
            must have. Defaults to 15, the threshold used originally.

    Returns:
        The same ``row`` with a plain-bool ``enough_data`` entry added. It is
        True only when every feature in ``rule_to_dim`` occurs as ``MaxFeat``
        and each occurs on at least ``min_trials_per_feat`` trials. A session
        with no valid trials yields False.

    Raises:
        AssertionError: if merging in the model values changes the number of
            trials.
    """
    session = row.session_name

    # Rename the model's feat_<i> columns to feature names (index order matters).
    feat_names = [
        'CIRCLE', 'SQUARE', 'STAR', 'TRIANGLE',
        'CYAN', 'GREEN', 'MAGENTA', 'YELLOW',
        'ESCHER', 'POLKADOT', 'RIPPLE', 'SWIRL',
    ]
    model_path = f"/data/082023_Feat_RLDE_HV/sess-{session}_hv.csv"
    model_vals = pd.read_csv(model_path).rename(
        columns={f"feat_{i}": feat_name for i, feat_name in enumerate(feat_names)}
    )

    beh = pd.read_csv(SESS_BEHAVIOR_PATH.format(sess_name=session))
    valid_beh = behavioral_utils.get_valid_trials(beh)

    # Grab the features of the selected card, then the model values, per trial.
    feature_selections = behavioral_utils.get_selection_features(valid_beh)
    valid_beh = pd.merge(valid_beh, feature_selections, on="TrialNumber", how="inner")
    valid_beh_vals = pd.merge(
        valid_beh, model_vals, left_on="TrialNumber", right_on="trial", how="inner"
    )
    assert len(valid_beh_vals) == len(valid_beh), (
        f"session {session}: {len(valid_beh)} valid trials but "
        f"{len(valid_beh_vals)} rows after merging model values"
    )

    # Nothing to label: a row-wise apply on an empty frame would not create
    # the MaxFeat column, so report "not enough data" explicitly.
    if valid_beh_vals.empty:
        row["enough_data"] = False
        return row

    valid_beh_max = valid_beh_vals.apply(get_highest_val_feat, axis=1)
    trials_per_feat = valid_beh_max.groupby("MaxFeat").TrialNumber.count()

    # Every known feature must show up as MaxFeat, each often enough.
    has_all_feats = len(trials_per_feat) == len(rule_to_dim)
    enough_per_feat = (trials_per_feat >= min_trials_per_feat).all()
    row["enough_data"] = bool(has_all_feats and enough_per_feat)
    return row

In [9]:
# Load the sessions table and run check_session on each row (one row per session).
# NOTE(review): apply() returns every input row with an added `enough_data`
# entry; despite its name, `valid_session` is not filtered on that flag here.
sessions = pd.read_pickle(SESSIONS_PATH)
valid_session = sessions.apply(check_session, axis=1)

In [10]:
# Number of rows that came back from the per-session check (no enough_data filter applied).
len(valid_session)

27