In [1]:
import os
import numpy as np
import pandas as pd
import utils.behavioral_utils as behavioral_utils

from constants.behavioral_constants import *
from constants.decoding_constants import *

import argparse
from scripts.anova_analysis.anova_configs import add_defaults_to_parser, AnovaConfigs
import utils.io_utils as io_utils
import utils.anova_utils as anova_utils
from tqdm import tqdm

In [11]:
# Register tqdm's pandas integration so DataFrame.progress_apply (used for the
# trial-split computation below) shows a progress bar.
tqdm.pandas()

In [23]:
# Start from the default ANOVA configuration, then override what this
# split-generation run needs.
args = argparse.Namespace(
    **AnovaConfigs()._asdict()
)
args.conditions = ["BeliefConf", "BeliefPartition"]
args.beh_filters = {"Response": "Correct", "Choice": "Chose"}
args.window_size = 500  # NOTE(review): units not visible here (ms?) — confirm in AnovaConfigs
args.subject = "SA"

# Per-subject table: one row per feature with the list of sessions it appears in.
FEATS_PATH = "/data/patrick_res/sessions/{sub}/feats_at_least_3blocks.pickle"
# Cap on the number of (feature, session) pairs processed. Set to None to
# process every pair. Any result computed with a cap is only a subset of the
# full set of splits.
MAX_PAIRS = 10

feat_sessions = pd.read_pickle(FEATS_PATH.format(sub=args.subject))
# One row per (feature, session). explode() repeats the original index label for
# every session of a feature, so the index is NOT unique afterwards; use
# positional (.iloc) access on this frame.
feat_sess_pairs = feat_sessions.explode("sessions")
feat_sess_pairs = feat_sess_pairs.iloc[:MAX_PAIRS]


Unnamed: 0,feat,sessions,num_essions
0,CIRCLE,"[20180615, 20180625, 20180709, 20180802, 20180...",22
1,SQUARE,"[20180615, 20180709, 20180801, 20180802, 20180...",22
2,STAR,"[20180709, 20180803, 20180817, 20180821, 20180...",15
3,TRIANGLE,"[20180615, 20180705, 20180801, 20180802, 20180...",11
4,CYAN,"[20180705, 20180801, 20180802, 20180803, 20180...",19
5,GREEN,"[20180802, 20180806, 20180808, 20180813, 20180...",19
6,MAGENTA,"[20180622, 20180705, 20180801, 20180802, 20180...",21
7,YELLOW,"[20180705, 201807250001, 20180801, 20180802, 2...",25
8,ESCHER,"[20180803, 20180808, 20180810, 20180813, 20180...",17
9,POLKADOT,"[20180705, 20180709, 20180802, 20180803, 20180...",19


In [24]:
def split_by_condition(group, rng=None):
    """Randomly split one condition's unique trials into two halves.

    Args:
        group: DataFrame of rows belonging to a single condition. It must have a
            ``TrialNumber`` column; repeated trial numbers are collapsed first.
        rng: seed or ``np.random.Generator``, passed to
            ``np.random.default_rng``. ``None`` (the default) uses a fresh,
            unseeded generator, so every call gives a different split.

    Returns:
        pd.Series with keys ``split_0`` and ``split_1``. Each value is an array
        of trial numbers. When the trial count is odd, ``split_1`` gets the
        extra trial.
    """
    rng = np.random.default_rng(rng)
    # permutation() shuffles a copy, so nothing is mutated in place. For the
    # same generator state it yields the same order as shuffle() did.
    trials = rng.permutation(group["TrialNumber"].unique())
    split_point = len(trials) // 2
    return pd.Series({"split_0": trials[:split_point], "split_1": trials[split_point:]})


def find_trial_splits(args, row, seed=None):
    """Split one (feature, session) pair's trials in two, per belief partition.

    Args:
        args: run configuration namespace. Side effect: ``args.feat`` is
            overwritten with ``row.feat``. This is kept because ``args`` is
            passed to ``load_behavior_from_args``, which presumably reads it
            (TODO confirm).
        row: one row of the exploded feature/session table. Reads ``row.feat``
            and ``row.sessions`` (presumably a single session id after explode).
        seed: optional seed or Generator for the shuffles. ``None`` keeps the
            original behaviour: a fresh random split on every call.

    Returns:
        pd.Series with ``split_0`` and ``split_1``, arrays of TrialNumber values
        that survive ``args.beh_filters``. Within each BeliefPartition, the
        unique trials are halved between the two splits. Both arrays are empty
        when no trials remain after filtering.
    """
    feat = row.feat
    args.feat = feat
    beh = behavioral_utils.load_behavior_from_args(row.sessions, args)
    beh = behavioral_utils.get_belief_partitions(beh, feat)
    # Vectorised labels. A row-wise apply(axis=1) is slow and fails on an
    # empty frame.
    chose_feat = beh[FEATURE_TO_DIM[feat]] == feat
    beh["Choice"] = np.where(chose_feat, "Chose", "Not Chose")
    beh["FeatPreferred"] = np.where(beh["PreferredBelief"] == feat, "Preferred", "Not Preferred")
    beh = behavioral_utils.filter_behavior(beh, args.beh_filters)

    # Only TrialNumber is needed per group. Selecting it also keeps the grouping
    # column out of apply(), which pandas >= 2.2 deprecates.
    grouped = beh.groupby("BeliefPartition")[["TrialNumber"]]
    if grouped.ngroups == 0:
        # Nothing survived filtering. np.concatenate([]) would raise, so return
        # empty splits that keep TrialNumber's dtype.
        empty = beh["TrialNumber"].to_numpy()[:0]
        return pd.Series({"split_0": empty, "split_1": empty.copy()})

    # One generator shared across partitions: reproducible when seeded.
    rng = np.random.default_rng(seed)
    cond_splits = grouped.apply(lambda group: split_by_condition(group, rng)).reset_index()
    return pd.Series({
        "split_0": np.concatenate(cond_splits["split_0"].to_list()),
        "split_1": np.concatenate(cond_splits["split_1"].to_list()),
    })

# For every (feature, session) pair, compute its two trial splits. Each call
# returns a Series with split_0 / split_1, which become two new columns.
feat_sess_pairs[["split_0", "split_1"]] = feat_sess_pairs.progress_apply(
    lambda pair: find_trial_splits(args, pair),
    axis=1,
)


100%|██████████| 10/10 [00:02<00:00,  3.70it/s]


In [32]:
# Sanity check: number of trials in split_0 of the first (feature, session) pair.
len(feat_sess_pairs["split_0"].iloc[0])

54

In [28]:
# Run name: subject, then each behavioural filter as "<key>_<value>", then a
# fixed suffix. Empty components are dropped before joining.
filt_str = "_".join(f"{key}_{value}" for key, value in args.beh_filters.items())
components = [args.subject, filt_str, "belief_partition_splits"]
run_name = "_".join(filter(None, components))

In [29]:
# Display the run name built from the subject and behavioural filters.
run_name

'SA_Response_Correct_Choice_Chose_belief_partition_splits'

In [33]:
# Inspect the (feature, session) pairs together with their two trial splits.
feat_sess_pairs

Unnamed: 0,feat,sessions,num_essions,split_0,split_1
0,CIRCLE,20180615,22,"[453, 287, 655, 654, 659, 447, 657, 661, 451, ...","[452, 443, 442, 660, 450, 651, 448, 656, 658, ..."
0,CIRCLE,20180625,22,"[492, 635, 443, 446, 430, 442, 626, 637, 503, ...","[433, 486, 494, 625, 504, 440, 629, 445, 496, ..."
0,CIRCLE,20180709,22,"[763, 822, 764, 706, 472, 770, 699, 466, 754, ...","[708, 762, 766, 825, 705, 826, 823, 704, 829, ..."
0,CIRCLE,20180802,22,"[729, 815, 55, 821, 823, 225, 217, 223, 256, 7...","[226, 53, 219, 730, 54, 899, 816, 733, 253, 82..."
0,CIRCLE,20180803,22,"[840, 482, 844, 478, 487, 481, 477, 483, 845, ...","[841, 476, 843, 485, 490, 931, 489, 475, 484, ..."
0,CIRCLE,20180806,22,"[520, 1221, 523, 524, 314, 1248, 1223, 519, 53...","[1240, 1237, 1245, 304, 1242, 1246, 315, 1220,..."
0,CIRCLE,20180808,22,"[714, 710, 246, 250, 845, 252, 713, 249, 251, ...","[712, 839, 247, 244, 848, 849, 840, 253, 844, ..."
0,CIRCLE,20180810,22,"[1065, 1071, 778, 772, 1070, 1118, 781, 1115, ...","[449, 1061, 1066, 1069, 1119, 1064, 780, 771, ..."
0,CIRCLE,20180813,22,"[751, 1234, 1252, 1248, 747, 1251, 956, 965, 9...","[1240, 955, 1247, 752, 963, 1245, 1241, 959, 1..."
0,CIRCLE,20180821,22,"[157, 127, 143, 304, 119, 289, 290, 274, 161, ...","[300, 158, 147, 306, 153, 134, 299, 286, 283, ..."
