### Following the idea of single-feature binary decoding, check how many sessions have more than N blocks of a given rule


In [9]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import utils.behavioral_utils as behavioral_utils
import utils.information_utils as information_utils
import utils.visualization_utils as visualization_utils
import utils.glm_utils as glm_utils
from matplotlib import pyplot as plt
import utils.spike_utils as spike_utils
from constants.glm_constants import *
from constants.behavioral_constants import *

import seaborn as sns
import scipy.stats


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
num_bins = 2  # NOTE(review): not used by the cells visible here; confirm it is still needed


def get_sess_beh(row, beh_path, sub):
    """
    Load and annotate the behavior table of a single session.

    Args:
        row: a row of the sessions table; only ``row.session_name`` is read.
        beh_path: path template containing a ``{sess_name}`` placeholder.
        sub: subject code, forwarded to ``behavioral_utils.get_beliefs_per_session``.

    Returns:
        DataFrame of valid trials, inner-joined on ``TrialNumber`` with the
        selected-feature table, with belief columns and belief-value labels
        added by ``behavioral_utils``, plus a ``session`` column.
    """
    sess_name = row.session_name
    trials = pd.read_csv(beh_path.format(sess_name=sess_name))

    # keep only valid trials, then attach per-trial selected features
    trials = behavioral_utils.get_valid_trials(trials)
    trials = pd.merge(
        trials,
        behavioral_utils.get_selection_features(trials),
        on="TrialNumber",
        how="inner",
    )

    # belief estimates for this session, then their value labels
    trials = behavioral_utils.get_belief_value_labels(
        behavioral_utils.get_beliefs_per_session(trials, sess_name, sub)
    )
    trials["session"] = sess_name
    return trials

In [11]:
# Subject to analyze: "SA" (Monkey S) or "BL" (Monkey B).
sub = "SA"

# Per-session behavior-file template for the chosen subject. Only the constant
# for the selected subject is looked up.
if sub == "SA":
    beh_path = SESS_BEHAVIOR_PATH
elif sub == "BL":
    beh_path = BL_SESS_BEHAVIOR_PATH
else:
    raise ValueError(f'Unknown subject {sub!r}; expected "SA" or "BL"')

# Both subjects keep their valid-session list at the same relative location.
# Alternative for BL: derive sessions from the `session` column of
# /data/patrick_res/firing_rates/BL/all_units.pickle.
SESSIONS_PATH = f"/data/patrick_res/sessions/{sub}/valid_sessions.pickle"
sessions = pd.read_pickle(SESSIONS_PATH)

# Load every session and stack them. An explicit list avoids relying on
# DataFrame.apply(axis=1) returning a Series of DataFrames. reset_index() without
# drop=True keeps each session's original row index as an "index" column.
all_beh = pd.concat(
    [get_sess_beh(row, beh_path, sub) for _, row in sessions.iterrows()]
).reset_index()

In [12]:
# Block-count threshold (the "N" in the notebook title) for keeping a session
# for a rule. NOTE(review): whether the comparison is inclusive is decided inside
# behavioral_utils.get_good_sessions_per_rule. Confirm there.
# Each row of the result is read below via `.feat` and `.sessions`.
block_thresh = 3
good_sess = behavioral_utils.get_good_sessions_per_rule(
    all_beh,
    block_thresh,
)


### For each session and each rule, what is the minimum number of trials that match:
- high-confidence trials where the feature is preferred
- high-confidence trials where the feature is chosen but not preferred

In [13]:
def _min_trials_per_choice(trials):
    """Return the smallest per-``Choice`` trial count in ``trials``, or 0 if it has no trials.

    ``np.min`` on an empty pandas Series returns NaN instead of raising. A NaN
    ``min_all`` is sorted *last* by ``sort_values``, which hides exactly the
    sessions with no usable trials, so an empty selection is reported as 0.
    """
    counts = trials.groupby("Choice").TrialNumber.count()
    # NOTE(review): only Choice values that actually occur are counted. If an
    # expected class is entirely absent, the minimum is taken over the remaining
    # class(es) rather than being 0. Confirm this is acceptable for decoding.
    return counts.min() if len(counts) > 0 else 0


def min_trials_per_session(row, beh_df=None):
    """
    For one rule, report per-session trial counts available for decoding.

    For every session in ``row.sessions``, selects the trials returned by
    ``behavioral_utils.get_chosen_preferred_single`` and
    ``behavioral_utils.get_chosen_not_preferred_single`` for ``row.feat``,
    counts trials per ``Choice`` value, and keeps the smallest count.

    Args:
        row: a row of ``good_sess`` exposing ``feat`` and ``sessions``.
        beh_df: behavior table with a ``session`` column. Defaults to the
            notebook-level ``all_beh``, which keeps the original call signature.

    Returns:
        DataFrame with one row per session and columns ``feat``, ``session``,
        ``min_pref``, ``min_not_pref`` and ``min_all`` (the smaller of the two).
        A selection with no trials counts as 0.
    """
    if beh_df is None:
        beh_df = all_beh
    records = []
    for sess in row.sessions:
        beh = beh_df[beh_df.session == sess]
        min_pref = _min_trials_per_choice(
            behavioral_utils.get_chosen_preferred_single(row.feat, beh)
        )
        min_not_pref = _min_trials_per_choice(
            behavioral_utils.get_chosen_not_preferred_single(row.feat, beh)
        )
        records.append({
            "feat": row.feat,
            "session": sess,
            "min_pref": min_pref,
            "min_not_pref": min_not_pref,
            "min_all": min(min_pref, min_not_pref),
        })
    return pd.DataFrame(records)


min_trials = pd.concat(good_sess.apply(min_trials_per_session, axis=1).values)



In [20]:
# Session/rule pairs with the fewest usable trials first.
min_trials.sort_values("min_all")

Unnamed: 0,feat,session,min_pref,min_not_pref,min_all
4,MAGENTA,20180803,13,140,13
16,MAGENTA,20180928,16,72,16
9,RIPPLE,20180907,18,88,18
17,GREEN,20181005,20,90,20
18,GREEN,20181009,22,85,22
...,...,...,...,...,...
4,CYAN,20180820,90,91,90
3,YELLOW,20180802,94,132,94
15,YELLOW,20180912,101,99,99
20,MAGENTA,20181005,100,110,100


In [22]:
# Spot-check the MAGENTA-rule trials of session 20180803 (first row of the sorted table above).
all_beh.loc[
    (all_beh["session"] == "20180803")
    & (all_beh["CurrentRule"] == "MAGENTA")
]

Unnamed: 0,index,TrialNumber,BlockNumber,TrialAfterRuleChange,TaskInterrupt,ConditionNumber,Response,ItemChosen,TrialType,CurrentRule,...,YELLOWProb,ESCHERProb,POLKADOTProb,RIPPLEProb,SWIRLProb,BeliefStateValue,BeliefStateValueBin,PreferredBelief,BeliefStateValueLabel,session
39448,261,293,11,0,,3012,Incorrect,1.0,5,MAGENTA,...,0.012898,0.012959,0.021213,0.042448,0.014327,0.533522,1,STAR,High STAR,20180803
39449,262,294,11,1,,3403,Incorrect,3.0,5,MAGENTA,...,0.045439,0.045527,0.023958,0.087733,0.047484,0.394699,0,STAR,Low,20180803
39450,263,295,11,2,,2992,Incorrect,2.0,5,MAGENTA,...,0.072828,0.072928,0.020262,0.12074,0.075145,0.341186,0,STAR,Low,20180803
39451,264,296,11,3,,3122,Incorrect,2.0,5,MAGENTA,...,0.038075,0.091225,0.016244,0.138749,0.093429,0.333839,0,RIPPLE,Low,20180803
39452,265,297,11,4,,3117,Correct,0.0,5,MAGENTA,...,0.021454,0.099288,0.013226,0.142155,0.101276,0.335998,0,RIPPLE,Low,20180803
39453,266,298,11,5,,3419,Incorrect,3.0,5,MAGENTA,...,0.023347,0.163786,0.018581,0.093268,0.069587,0.34232,0,MAGENTA,Low,20180803
39454,267,299,11,6,,3403,Correct,0.0,5,MAGENTA,...,0.039487,0.170907,0.035027,0.043838,0.082757,0.343911,0,MAGENTA,Low,20180803
39455,268,300,11,7,,3293,Correct,0.0,5,MAGENTA,...,0.031372,0.244228,0.028973,0.033712,0.054642,0.3583,0,MAGENTA,Low,20180803
39456,269,301,11,8,,3312,Correct,0.0,5,MAGENTA,...,0.02832,0.148349,0.026967,0.029639,0.099182,0.370403,0,MAGENTA,Low,20180803
39457,270,302,11,9,,3322,Correct,0.0,5,MAGENTA,...,0.023083,0.081822,0.022421,0.023729,0.13824,0.395328,0,MAGENTA,Low,20180803


### Save files