In [13]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import utils.spike_utils as spike_utils
import utils.classifier_utils as classifier_utils
import utils.visualization_utils as visualization_utils
import utils.behavioral_utils as behavioral_utils

import os
import pandas as pd
import matplotlib

SESS_BEHAVIOR_PATH = "/data/rawdata/sub-SA/sess-{sess_name}/behavior/sub-SA_sess-{sess_name}_object_features.csv"

feature_dims = ["Color", "Shape", "Pattern"]

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
session = "20180802"
model_path = f"/data/082023_Feat_RLDE_HV/sess-{session}_hv.csv"
model_vals = pd.read_csv(model_path)
feat_names = np.array([
    'CIRCLE', 'SQUARE', 'STAR', 'TRIANGLE', 
    'CYAN', 'GREEN', 'MAGENTA', 'YELLOW', 
    'ESCHER', 'POLKADOT', 'RIPPLE', 'SWIRL'
])

In [4]:
renames = {}
for i, feat_name in enumerate(feat_names):
    renames[f"feat_{i}"] = feat_name

In [8]:
model_vals = model_vals.rename(columns=renames)

In [9]:
model_vals

Unnamed: 0,trial,CIRCLE,SQUARE,STAR,TRIANGLE,CYAN,GREEN,MAGENTA,YELLOW,ESCHER,POLKADOT,RIPPLE,SWIRL
0,35,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
1,36,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
2,37,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
3,38,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
4,39,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1702,1738,0.130760,0.181514,0.007635,0.218069,-0.017239,0.540891,-0.000036,-0.131027,0.193760,0.032039,0.110023,0.196207
1703,1739,0.126029,0.174946,0.007359,0.245983,-0.016615,0.568805,-0.000035,-0.126286,0.186749,0.030880,0.137937,0.189108
1704,1740,0.121469,0.168616,0.057369,0.237082,-0.016014,0.618815,-0.000034,-0.121716,0.179992,0.029762,0.132946,0.239118
1705,1741,0.117074,0.162515,0.075414,0.228504,-0.015434,0.636860,-0.000033,-0.117312,0.173479,0.028685,0.128135,0.257163


In [11]:
behavior_path = SESS_BEHAVIOR_PATH.format(sess_name=session)
beh = pd.read_csv(behavior_path)

# filter trials 
valid_beh = behavioral_utils.get_valid_trials(beh)

# grab the features of the selected card
feature_selections = behavioral_utils.get_selection_features(valid_beh)
valid_beh = pd.merge(valid_beh, feature_selections, on="TrialNumber", how="inner")
valid_beh_vals = pd.merge(valid_beh, model_vals, left_on="TrialNumber", right_on="trial", how="inner")
assert(len(valid_beh_vals) == len(valid_beh))

In [14]:
def get_highest_val_feature(row):
    highest_val_feat = None
    highest_val = -1
    for feature_dim in feature_dims:
        feature = row[feature_dim]
        val = row[feature]
        if val > highest_val:
            highest_val_feat = feature
    row["highest_val_feature"] = highest_val_feat
    return row
beh_vals = valid_beh_vals.apply(get_highest_val_feature, axis=1)

In [15]:
beh_vals

Unnamed: 0,TrialNumber,BlockNumber,TrialAfterRuleChange,TaskInterrupt,ConditionNumber,Response,ItemChosen,TrialType,CurrentRule,LastRule,...,TRIANGLE,CYAN,GREEN,MAGENTA,YELLOW,ESCHER,POLKADOT,RIPPLE,SWIRL,highest_val_feature
0,35,2,0,,641,Incorrect,3.0,8,CIRCLE,CYAN,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,ESCHER
1,36,2,1,,627,Incorrect,3.0,8,CIRCLE,CYAN,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,POLKADOT
2,37,2,2,,808,Incorrect,3.0,8,CIRCLE,CYAN,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,SWIRL
3,38,2,3,,783,Incorrect,1.0,8,CIRCLE,CYAN,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,ESCHER
4,39,2,4,,1116,Incorrect,1.0,8,CIRCLE,CYAN,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,POLKADOT
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1702,1738,54,31,,4072,Correct,0.0,7,GREEN,CYAN,...,0.218069,-0.017239,0.540891,-0.000036,-0.131027,0.193760,0.032039,0.110023,0.196207,RIPPLE
1703,1739,54,32,,4429,Correct,0.0,7,GREEN,CYAN,...,0.245983,-0.016615,0.568805,-0.000035,-0.126286,0.186749,0.030880,0.137937,0.189108,SWIRL
1704,1740,54,33,,4457,Correct,0.0,7,GREEN,CYAN,...,0.237082,-0.016014,0.618815,-0.000034,-0.121716,0.179992,0.029762,0.132946,0.239118,SWIRL
1705,1741,54,34,,4380,Correct,0.0,7,GREEN,CYAN,...,0.228504,-0.015434,0.636860,-0.000033,-0.117312,0.173479,0.028685,0.128135,0.257163,RIPPLE
