### Script to find good pairs for each subject
Criteria: 
- For each pair of features, find sessions in which each feature of the pair shows up as the rule in at least N blocks. 
- Find pairs of features which have at least M sessions that satisfy this condition. 

In [1]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import utils.behavioral_utils as behavioral_utils
import utils.information_utils as information_utils
import utils.visualization_utils as visualization_utils
import utils.glm_utils as glm_utils
from matplotlib import pyplot as plt
import utils.spike_utils as spike_utils
from constants.glm_constants import *
from constants.behavioral_constants import *

import seaborn as sns
import scipy.stats


In [2]:
# NOTE(review): `num_bins` is not referenced in any cell visible here — confirm it is still needed.
num_bins = 2

def get_sess_beh(row, beh_path, sub):
    """
    Load one session's behavior CSV and run it through the behavioral_utils pipeline.

    Args:
        row: record exposing a `session_name` attribute (e.g. one row of the sessions table).
        beh_path: path template containing a `{sess_name}` placeholder.
        sub: subject identifier, forwarded to `behavioral_utils.get_beliefs_per_session`.

    Returns:
        The processed behavior DataFrame with a `session` column set to the session name.
    """
    sess_name = row.session_name
    raw_trials = pd.read_csv(beh_path.format(sess_name=sess_name))

    # filter trials, then join the selection features back on by trial number
    valid_trials = behavioral_utils.get_valid_trials(raw_trials)
    selections = behavioral_utils.get_selection_features(valid_trials)
    merged = pd.merge(valid_trials, selections, on="TrialNumber", how="inner")

    # belief columns and their labels are produced by behavioral_utils
    with_beliefs = behavioral_utils.get_beliefs_per_session(merged, sess_name, sub)
    labeled = behavioral_utils.get_belief_value_labels(with_beliefs)

    labeled["session"] = sess_name
    return labeled

### Load behavior based on subject

In [5]:
sub = "BL"

# Behavior CSV path template for each supported subject.
BEH_PATH_BY_SUBJECT = {
    "SA": SESS_BEHAVIOR_PATH,     # Monkey S
    "BL": BL_SESS_BEHAVIOR_PATH,  # Monkey B
}
if sub not in BEH_PATH_BY_SUBJECT:
    raise ValueError(
        f"Unknown subject {sub!r}; expected one of {sorted(BEH_PATH_BY_SUBJECT)}"
    )

beh_path = BEH_PATH_BY_SUBJECT[sub]
# Both subjects keep their valid-session table at the same location under their own directory.
sessions = pd.read_pickle(f"/data/patrick_res/sessions/{sub}/valid_sessions.pickle")

# One behavior frame per session, stacked. `reset_index()` (without drop) keeps each
# session's original row index as an `index` column in the combined table.
all_beh = pd.concat(
    sessions.apply(lambda x: get_sess_beh(x, beh_path, sub), axis=1).values
).reset_index()

  beh["PreferredBelief"] = beh[[f"{feat}Prob" for feat in FEATURES]].idxmax(axis=1).apply(lambda x: x[:-4])
  beh["PreferredBelief"] = beh[[f"{feat}Prob" for feat in FEATURES]].idxmax(axis=1).apply(lambda x: x[:-4])


In [6]:
all_beh

Unnamed: 0,index,TrialNumber,BlockNumber,TrialAfterRuleChange,TaskInterrupt,ConditionNumber,Response,ItemChosen,TrialType,CurrentRule,...,YELLOWProb,ESCHERProb,POLKADOTProb,RIPPLEProb,SWIRLProb,BeliefStateValue,BeliefStateValueBin,PreferredBelief,BeliefStateValueLabel,session
0,0,27,2,0,,3698,Incorrect,2.0,6,YELLOW,...,0.016559,0.047355,0.016991,0.026388,0.023243,0.544836,1.0,CYAN,High CYAN,20190128
1,1,28,2,1,,3871,Incorrect,2.0,6,YELLOW,...,0.032866,0.059105,0.014452,0.041241,0.038562,0.507766,1.0,CYAN,High CYAN,20190128
2,2,29,2,2,,3975,Incorrect,1.0,6,YELLOW,...,0.046895,0.069315,0.031161,0.023503,0.051762,0.481420,1.0,CYAN,High CYAN,20190128
3,3,30,2,3,,3852,Incorrect,2.0,6,YELLOW,...,0.078819,0.104462,0.060822,0.022639,0.084385,0.418154,0.0,CYAN,Low,20190128
4,4,31,2,4,,3891,Incorrect,1.0,6,YELLOW,...,0.086070,0.107958,0.070709,0.016575,0.090821,0.417044,0.0,CYAN,Low,20190128
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
39352,88,150,3,81,,3829,Correct,0.0,6,YELLOW,...,0.696677,0.043605,0.020507,0.030682,0.025258,0.537880,1.0,YELLOW,High YELLOW,20190123
39353,89,151,3,82,,3679,Correct,0.0,6,YELLOW,...,0.697238,0.063663,0.017939,0.022231,0.019943,0.539393,1.0,YELLOW,High YELLOW,20190123
39354,90,152,3,83,,3504,Correct,0.0,6,YELLOW,...,0.700474,0.083443,0.016921,0.018738,0.017769,0.541134,1.0,YELLOW,High YELLOW,20190123
39355,91,153,3,84,,3984,Correct,0.0,6,YELLOW,...,0.718449,0.045601,0.016837,0.040527,0.017204,0.547438,1.0,YELLOW,High YELLOW,20190123


In [7]:
len(sessions)

61

In [12]:
# Block-count threshold passed to `behavioral_utils.get_good_pairs_across_sessions`
# (presumably the "N blocks per session" criterion from the notebook header — verify against the helper).
block_thresh = 3

In [13]:
# Candidate feature pairs across all sessions, then the 20 rows with the most sessions.
pairs = behavioral_utils.get_good_pairs_across_sessions(all_beh, block_thresh)
pairs.sort_values(by="num_sessions", ascending=False).head(20)

Unnamed: 0,pair,sessions,num_sessions,dim_type
61,"[ESCHER, RIPPLE]","[20190529, 20190617, 20190710, 20190816, 20190...",5,within dim
62,"[ESCHER, SWIRL]","[20190529, 20190627, 20190814, 20190823]",4,within dim
65,"[RIPPLE, SWIRL]","[20190524, 20190529, 20190823]",3,within dim
27,"[STAR, POLKADOT]","[20190605, 20190606, 20190607]",3,across dim
2,"[CIRCLE, TRIANGLE]","[20190603, 20190606, 20190812]",3,within dim
34,"[TRIANGLE, ESCHER]","[20190221, 20190710, 20190814]",3,across dim
1,"[CIRCLE, STAR]","[20190228, 20190606, 20190703]",3,within dim
60,"[ESCHER, POLKADOT]","[20190227, 20190531, 20190816]",3,within dim
25,"[STAR, YELLOW]","[20190524, 20190605, 20190703]",3,across dim
52,"[MAGENTA, ESCHER]","[20190130, 20190530, 20190823]",3,across dim


In [18]:
# One entry per (pair, session) for every pair that has at least 3 sessions.
good_sess = pairs.loc[pairs.num_sessions >= 3, "sessions"].explode()

In [21]:
# Load the `all_units` pickle; only its `session` column is used below.
# NOTE(review): the path is hard-coded to BL and does not follow `sub` — confirm this is intended.
all_units = pd.read_pickle("/data/patrick_res/firing_rates/BL/all_units.pickle")

# Keep the entries of `good_sess` whose session also appears in `all_units.session`.
# NOTE(review): `good_sess` comes from exploding per-pair session lists, so a session shared by
# several pairs can appear more than once — TODO confirm whether duplicates should be dropped.
good_sess[good_sess.isin(all_units.session.unique())]

19

In [14]:
# Minimum `num_sessions` a pair needs in order to be kept in `good_pairs`.
session_thresh = 6

In [15]:
# Pairs that meet the session-count threshold.
enough_sessions = pairs.num_sessions >= session_thresh
good_pairs = pairs[enough_sessions]

In [16]:
len(good_pairs)

8

In [17]:
len(good_pairs[good_pairs.dim_type == "within dim"])

4

In [18]:
good_pairs

Unnamed: 0,pair,sessions,num_sessions,dim_type
26,"[STAR, ESCHER]","[20190531, 20190611, 20190617, 20190625, 20190...",7,across dim
28,"[STAR, RIPPLE]","[20190531, 20190611, 20190617, 20190625, 20190...",6,across dim
34,"[TRIANGLE, ESCHER]","[20190531, 20190611, 20190627, 20190710, 20190...",6,across dim
41,"[CYAN, ESCHER]","[20190207, 20190220, 20190529, 20190531, 20190...",6,across dim
60,"[ESCHER, POLKADOT]","[20190220, 20190529, 20190531, 20190611, 20190...",9,within dim
61,"[ESCHER, RIPPLE]","[20190529, 20190531, 20190611, 20190617, 20190...",9,within dim
62,"[ESCHER, SWIRL]","[20190529, 20190531, 20190625, 20190627, 20190...",6,within dim
63,"[POLKADOT, RIPPLE]","[20190529, 20190531, 20190611, 20190617, 20190...",8,within dim


In [20]:
# Persist the selected pairs under the subject's directory; both thresholds are encoded in the file name.
good_pairs_path = f"/data/patrick_res/sessions/{sub}/pairs_at_least_{block_thresh}blocks_{session_thresh}sess.pickle"
good_pairs.to_pickle(good_pairs_path)

### For each session, for each pair, what is the minimum number of trials that match: 
- high preferred features
- high conf trials where features are chosen but not preferred. 

In [17]:
def min_trials_per_session(row):
    """
    Compute, for one feature pair, the per-session minimum trial counts.

    Reads the notebook-level `all_beh` table.

    Args:
        row: a row of the pairs table exposing `pair` (two feature names) and
            `sessions` (iterable of session identifiers).

    Returns:
        DataFrame with one row per session and columns `pair`, `session`,
        `min_pref`, `min_pref_chose` and `min_not_pref_chose`.
    """
    feat1, feat2 = row.pair
    high_labels = [f"High {feat1}", f"High {feat2}"]

    def min_count_per_label(trials):
        # Count trials per label, then reindex over BOTH expected labels so a label with
        # no trials contributes 0 to the minimum instead of silently dropping out of it.
        counts = trials.groupby("BeliefStateValueLabel").TrialNumber.count()
        return counts.reindex(high_labels, fill_value=0).min()

    res = []
    for sess in row.sessions:
        beh = all_beh[all_beh.session == sess]

        # find minimum number of trials, when either features are preferred
        pref_beh = beh[beh.BeliefStateValueLabel.isin(high_labels)]
        min_pref = min_count_per_label(pref_beh)

        # assumes the helper returns a subset of the rows it is given — TODO confirm
        pref_chose = behavioral_utils.get_chosen_preferred_trials(row.pair, pref_beh)
        min_pref_chose = min_count_per_label(pref_chose)

        not_pref_chose = behavioral_utils.get_chosen_not_preferred_trials(row.pair, beh)
        # NOTE(review): the values of `Choice` are not visible here, so this count is left
        # as-is; a group with no trials is likewise absent from this minimum — verify.
        min_not_pref = np.min(not_pref_chose.groupby("Choice").count().TrialNumber)

        res.append({
            "pair": row.pair,
            "session": sess,
            "min_pref": min_pref,
            "min_pref_chose": min_pref_chose,
            "min_not_pref_chose": min_not_pref
        })
    return pd.DataFrame(res)

min_trials = pd.concat(good_pairs.apply(min_trials_per_session, axis=1).values)



In [18]:
min_trials

Unnamed: 0,pair,session,min_pref,min_pref_chose,min_not_pref_chose
0,"[CIRCLE, SQUARE]",20180615,27,25,64
1,"[CIRCLE, SQUARE]",20180709,51,42,70
2,"[CIRCLE, SQUARE]",20180802,54,46,134
3,"[CIRCLE, SQUARE]",20180803,25,23,122
4,"[CIRCLE, SQUARE]",20180806,87,68,69
...,...,...,...,...,...
6,"[YELLOW, RIPPLE]",20180912,62,50,65
7,"[YELLOW, RIPPLE]",20180921,45,39,45
8,"[YELLOW, RIPPLE]",20181005,59,46,78
9,"[YELLOW, RIPPLE]",20181009,44,38,59


In [20]:
# Row-wise minimum over the three per-session counts.
count_cols = ["min_pref", "min_pref_chose", "min_not_pref_chose"]
min_trials["min_all"] = min_trials[count_cols].min(axis="columns")

In [18]:
# Build the output path from the current subject and thresholds so that running the notebook
# for one subject cannot overwrite another subject's file. With sub="SA", block_thresh=3 and
# session_thresh=10 this resolves to the same file name that was previously hard-coded here.
min_trials.to_pickle(f"/data/patrick_res/sessions/{sub}/pairs_at_least_{block_thresh}blocks_{session_thresh}sess_more_sess.pickle")


In [18]:
# Load a previously saved min-trials table to compare against the freshly computed `min_trials`.
# NOTE(review): this path has no subject directory and a fixed "3blocks_7sess" name — confirm it
# corresponds to the subject/thresholds currently configured.
min_trials_original = pd.read_pickle("/data/patrick_res/sessions/pairs_at_least_3blocks_7sess_min_trials.pickle")

In [23]:
# Total difference in `min_all` between the saved table and the recomputed one.
# NOTE(review): Series subtraction aligns on index labels rather than row position — confirm
# both tables carry the same index before relying on this number.
(min_trials_original.min_all - min_trials.min_all).sum()

78