In [1]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import utils.pseudo_utils as pseudo_utils
import utils.pseudo_classifier_utils as pseudo_classifier_utils
import utils.behavioral_utils as behavioral_utils
from utils.session_data import SessionData
import utils.io_utils as io_utils
from utils.constants import *
import json

from spike_tools import (
    general as spike_general,
    analysis as spike_analysis,
)

import matplotlib.pyplot as plt
import matplotlib

In [2]:
# ---- Analysis configuration: input/output locations and feature dims ----

# Root directory for this analysis's outputs.
# NOTE(review): not referenced by the cells visible here — confirm it is still used.
OUTPUT_DIR = "/data/patrick_res/pseudo"

# Pickled DataFrame of sessions to analyze; its `session_name` column drives the loop below.
# (An earlier session list lived at /data/patrick_scratch/multi_sess/valid_sessions.pickle.)
SESSIONS_PATH = "/data/patrick_res/sessions/valid_sessions_rpe.pickle"

# Per-session path templates — fill the `{...}` placeholders with str.format().
# Behavior CSV for one session (placeholder: sess_name).
SESS_BEHAVIOR_PATH = "/data/rawdata/sub-SA/sess-{sess_name}/behavior/sub-SA_sess-{sess_name}_object_features.csv"
# Spikes pre-aligned to an event time and binned
# (placeholders: sess_name, pre_interval, event, post_interval, interval_size).
SESS_SPIKES_PATH = "/data/patrick_res/firing_rates/{sess_name}_firing_rates_{pre_interval}_{event}_{post_interval}_{interval_size}_bins_1_smooth.pickle"

# Feature dimensions considered in the analysis.
FEATURE_DIMS = [
    "Color",
    "Shape",
    "Pattern",
]

In [4]:
def get_labels_for_session(session):
    """Load one session's valid trials and label each trial's confidence.

    Reads the session's behavior CSV (``SESS_BEHAVIOR_PATH``), keeps the trials
    returned by ``behavioral_utils.get_valid_trials``, inner-merges the
    per-trial selected features on ``TrialNumber``, and passes the result to
    ``behavioral_utils.get_rpes_per_session``, whose output must contain a
    ``Prob_FE`` column. Trials are then split at the session median of
    ``Prob_FE``.

    Args:
        session: session name; fills ``{sess_name}`` in ``SESS_BEHAVIOR_PATH``
            and is forwarded to ``get_rpes_per_session``.

    Returns:
        A new DataFrame of the session's trials with two appended columns:
        ``Conf`` -- "high" when ``Prob_FE`` is strictly greater than the
        session median, otherwise "low" (ties with the median and missing
        ``Prob_FE`` values fall into "low");
        ``session`` -- the ``session`` argument.
    """
    behavior_path = SESS_BEHAVIOR_PATH.format(sess_name=session)

    beh = pd.read_csv(behavior_path)
    valid_beh = behavioral_utils.get_valid_trials(beh)
    feature_selections = behavioral_utils.get_selection_features(valid_beh)
    valid_beh_merged = pd.merge(valid_beh, feature_selections, on="TrialNumber", how="inner")
    valid_beh_vals = behavioral_utils.get_rpes_per_session(session, valid_beh_merged)

    prob_fe = valid_beh_vals["Prob_FE"]
    # Series.median skips NaN. np.median would return NaN if any trial lacked a
    # value, making every comparison below False and labelling the whole
    # session "low".
    med_conf = prob_fe.median()

    # Vectorized median split (replaces a per-row apply, which was slow and
    # round-tripped every row through an object Series). The strict ">" keeps
    # the original tie rule: trials equal to the median are "low".
    # .assign returns a copy, so the frame from get_rpes_per_session is not mutated.
    return valid_beh_vals.assign(
        Conf=np.where(prob_fe > med_conf, "high", "low"),
        session=session,
    )

In [5]:
# Label every valid session and stack the per-session trial tables into one
# frame. Each session's frame keeps its own index, so index values can repeat
# across sessions.
valid_sessions = pd.read_pickle(SESSIONS_PATH)
res = pd.concat(
    [get_labels_for_session(session_name) for session_name in valid_sessions.session_name]
)

In [7]:
# Persist only the low-confidence trials (Conf == "low"), pooled across sessions.
# NOTE(review): this path sits outside OUTPUT_DIR and has no file extension — confirm intended.
res.loc[
    res["Conf"].eq("low")
].to_pickle("/data/patrick_res/low_conf")

Unnamed: 0,TrialNumber,BlockNumber,TrialAfterRuleChange,TaskInterrupt,ConditionNumber,Response,ItemChosen,TrialType,CurrentRule,LastRule,...,trial,fb,Prob_FE,Prob_FD,Prob_FRL,RPE_FE,RPE_FD,RPE_FRL,Conf,session
0,49,2,0,,2258,Incorrect,2.0,11,SQUARE,SWIRL,...,49,0,0.250000,0.250000,0.250000,-0.250000,-0.250000,-0.250000,low,20180709
1,50,2,1,,1881,Incorrect,2.0,11,SQUARE,SWIRL,...,50,0,0.250000,0.250000,0.250000,-0.250000,-0.250000,-0.250000,low,20180709
2,51,2,2,,2206,Incorrect,3.0,11,SQUARE,SWIRL,...,51,0,0.250000,0.250000,0.250000,-0.250000,-0.250000,-0.250000,low,20180709
3,52,2,3,,2005,Correct,0.0,11,SQUARE,SWIRL,...,52,1,0.250000,0.250000,0.250000,0.750000,0.750000,0.750000,low,20180709
4,53,2,4,,1834,Correct,0.0,11,SQUARE,SWIRL,...,53,1,0.286681,0.287395,0.280660,0.713319,0.712605,0.719340,low,20180709
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
842,865,26,2,,2467,Correct,0.0,4,CYAN,GREEN,...,865,1,0.116135,0.158716,0.123846,0.883865,0.841284,0.876154,low,20180910
844,867,26,4,,2489,Correct,0.0,4,CYAN,GREEN,...,867,1,0.417026,0.500836,0.340898,0.582974,0.499164,0.659102,low,20180910
845,868,26,5,,2775,Correct,0.0,4,CYAN,GREEN,...,868,1,0.503678,0.471231,0.338906,0.496322,0.528769,0.661094,low,20180910
848,871,26,8,,2816,Correct,0.0,4,CYAN,GREEN,...,871,1,0.296669,0.597349,0.250715,0.703331,0.402651,0.749285,low,20180910
