### Count how many pairs of rules have at least N sessions in which both rules appear in at least 3 blocks

In [1]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import utils.behavioral_utils as behavioral_utils
import utils.information_utils as information_utils
import utils.visualization_utils as visualization_utils
import utils.glm_utils as glm_utils
from matplotlib import pyplot as plt
import utils.spike_utils as spike_utils
from constants.glm_constants import *
from constants.behavioral_constants import *

import seaborn as sns
import scipy.stats


In [2]:
# NOTE(review): num_bins is not referenced by any of the cells shown here — confirm it is still needed.
num_bins = 2

def get_sess_beh(row, beh_path):
    """
    Load one session's behavior CSV and annotate it for the rule-pair analysis.

    Args:
        row: record exposing a `session_name` attribute.
        beh_path: path template with a `{sess_name}` placeholder.

    Returns:
        The session's behavior DataFrame after it has been passed through
        `get_valid_trials`, inner-merged with the selection features on
        `TrialNumber`, passed through `get_beliefs_per_session` (subject "SA")
        and `get_belief_value_labels`, and tagged with a `session` column.
    """
    sess_name = row.session_name
    raw_trials = pd.read_csv(beh_path.format(sess_name=sess_name))

    # Restrict to valid trials before deriving anything from them
    valid_trials = behavioral_utils.get_valid_trials(raw_trials)
    selections = behavioral_utils.get_selection_features(valid_trials)
    with_selections = pd.merge(valid_trials, selections, on="TrialNumber", how="inner")

    # Belief annotations; "SA" is the subject identifier handed to the helper
    with_beliefs = behavioral_utils.get_beliefs_per_session(with_selections, sess_name, "SA")
    labeled = behavioral_utils.get_belief_value_labels(with_beliefs)

    # Tag rows so they stay attributable once sessions are concatenated
    labeled["session"] = sess_name
    return labeled

In [3]:
# Monkey S: read the table of valid sessions, build one annotated behavior frame
# per session, then stack them into a single table.
SESSIONS_PATH = "/data/patrick_res/sessions/SA/valid_sessions.pickle"
sessions = pd.read_pickle(SESSIONS_PATH)
beh_path = SESS_BEHAVIOR_PATH
per_session_beh = sessions.apply(lambda sess_row: get_sess_beh(sess_row, beh_path), axis=1)
# reset_index() without drop=True keeps the previous row index as a column
all_beh = pd.concat(per_session_beh.values).reset_index()

In [5]:
# Number of rows in the sessions table read from SESSIONS_PATH
len(sessions)

45

In [5]:
# Library helper for the pair search over all sessions.
# NOTE(review): the second argument (3) is presumably the minimum number of blocks
# per rule — confirm against behavioral_utils.
behavioral_utils.get_good_pairs_across_sessions(all_beh, 3)

Unnamed: 0,pair,sessions,num_sessions,dim_type
0,"[CIRCLE, SQUARE]","[20180615, 20180709, 20180802, 20180803, 20180...",14,within dim
1,"[CIRCLE, STAR]","[20180709, 20180803, 20180821, 20180911, 20180...",9,within dim
2,"[CIRCLE, TRIANGLE]","[20180615, 20180802, 20180803, 20180806, 20180...",8,within dim
3,"[CIRCLE, CYAN]","[20180802, 20180803, 20180911, 20180918, 20180...",7,across dim
4,"[CIRCLE, GREEN]","[20180802, 20180806, 20180808, 20180813, 20180...",11,across dim
...,...,...,...,...
61,"[ESCHER, RIPPLE]","[20180803, 20180810, 20180821, 20180907, 20180...",6,within dim
62,"[ESCHER, SWIRL]","[20180803, 20180808, 20180810, 20180813, 20180...",5,within dim
63,"[POLKADOT, RIPPLE]","[20180705, 20180802, 20180803, 20180810, 20180...",7,within dim
64,"[POLKADOT, SWIRL]","[20180709, 20180802, 20180803, 20180808, 20180...",9,within dim


In [6]:
# Count the distinct BlockNumber values for every (session, CurrentRule) group.
# The apply result is unnamed, so after reset_index() the count lands in a column
# labelled 0 (integer), which the pair-building cell below indexes as num_blocks[0].
num_blocks = all_beh.groupby(["session", "CurrentRule"]).apply(lambda x: len(x.BlockNumber.unique())).reset_index()

In [7]:
# Preview: one row per (session, CurrentRule) with its block count in column 0
num_blocks

Unnamed: 0,session,CurrentRule,0
0,20180615,CIRCLE,3
1,20180615,ESCHER,1
2,20180615,MAGENTA,1
3,20180615,POLKADOT,1
4,20180615,RIPPLE,3
...,...,...,...
493,20181010,RIPPLE,4
494,20181010,SQUARE,1
495,20181010,STAR,2
496,20181010,SWIRL,2


In [8]:
# A rule must have been active for at least this many blocks within a session for
# that session to count toward a pair.
MIN_BLOCKS_PER_RULE = 3

# Keep only (session, rule) rows with enough blocks. Column 0 of num_blocks holds
# the per-group count of unique BlockNumber values.
enough_blocks = num_blocks[num_blocks[0] >= MIN_BLOCKS_PER_RULE]

pairs = []
# Enumerate every unordered pair of features. The bound comes from FEATURES itself
# rather than a hard-coded 12, so the loop stays correct if the feature list changes.
num_features = len(FEATURES)
for i in range(num_features):
    feat1 = FEATURES[i]
    # Invariant across the inner loop: sessions where feat1 has enough blocks
    sess_1 = enough_blocks[enough_blocks.CurrentRule == feat1].session
    for j in range(i + 1, num_features):
        feat2 = FEATURES[j]
        sess_2 = enough_blocks[enough_blocks.CurrentRule == feat2].session
        # Sessions in which both rules of the pair qualify
        joints = sess_1[sess_1.isin(sess_2)].values
        if FEATURE_TO_DIM[feat1] == FEATURE_TO_DIM[feat2]:
            dim_type = "within dim"
        else:
            dim_type = "across dim"
        pairs.append({"pair": [feat1, feat2], "sessions": joints, "num_sessions": len(joints), "dim_type": dim_type})
pairs = pd.DataFrame(pairs)

In [9]:
# Top 20 pairs ranked by the number of qualifying sessions
pairs.sort_values("num_sessions", ascending=False).head(20)

Unnamed: 0,pair,sessions,num_sessions,dim_type
0,"[CIRCLE, SQUARE]","[20180615, 20180709, 20180802, 20180803, 20180...",14,within dim
40,"[CYAN, YELLOW]","[20180705, 20180801, 20180802, 20180820, 20180...",14,within dim
6,"[CIRCLE, YELLOW]","[20180802, 20180810, 20180813, 20180821, 20180...",13,across dim
46,"[GREEN, YELLOW]","[20180802, 20180813, 20180829, 20180904, 20180...",12,within dim
16,"[SQUARE, YELLOW]","[20180801, 20180802, 20180810, 20180820, 20180...",12,across dim
13,"[SQUARE, CYAN]","[20180801, 20180802, 20180803, 20180820, 20180...",11,across dim
51,"[MAGENTA, YELLOW]","[20180705, 20180801, 20180802, 20180820, 20180...",11,within dim
56,"[YELLOW, ESCHER]","[20180810, 20180813, 20180820, 20180821, 20180...",11,across dim
58,"[YELLOW, RIPPLE]","[20180705, 20180802, 20180810, 20180821, 20180...",11,across dim
15,"[SQUARE, MAGENTA]","[20180801, 20180802, 20180803, 20180806, 20180...",11,across dim


In [14]:
# Keep the pairs that have at least 10 qualifying sessions
good_pairs = pairs.loc[pairs["num_sessions"] >= 10]

In [15]:
# Number of pairs that pass the session-count filter
len(good_pairs)

17

In [13]:
# How many of the selected pairs have both features in the same dimension
len(good_pairs[good_pairs.dim_type == "within dim"])

8

In [14]:
# NOTE(review): the saved output for this cell lists pairs with num_sessions of 7-8,
# which cannot come from the `num_sessions >= 10` filter that defines good_pairs.
# It looks stale from an earlier threshold — restart the kernel and re-run to refresh.
good_pairs

Unnamed: 0,pair,sessions,num_sessions,dim_type
0,"[CIRCLE, SQUARE]","[20180709, 20180802, 20180803, 20180806, 20180...",8,within dim
4,"[CIRCLE, GREEN]","[20180802, 20180806, 20180808, 20180921, 20180...",7,across dim
6,"[CIRCLE, YELLOW]","[20180802, 20180918, 20180921, 20180924, 20181...",7,across dim
12,"[SQUARE, TRIANGLE]","[20180801, 20180802, 20180803, 20180806, 20180...",7,within dim
14,"[SQUARE, GREEN]","[20180802, 20180806, 20180808, 20180924, 20180...",7,across dim
15,"[SQUARE, MAGENTA]","[20180801, 20180802, 20180803, 20180806, 20180...",7,across dim
16,"[SQUARE, YELLOW]","[20180801, 20180802, 20180924, 20180925, 20180...",7,across dim
18,"[SQUARE, POLKADOT]","[20180709, 20180802, 20180803, 20180808, 20180...",7,across dim
27,"[STAR, POLKADOT]","[20180709, 20180803, 20180917, 20180920, 20180...",7,across dim
38,"[CYAN, GREEN]","[20180802, 20180910, 20180912, 20180921, 20180...",7,within dim


In [22]:
# Persist the selected pairs table.
# NOTE(review): the commented-out line is presumably the target used for an earlier
# 7-session threshold run — confirm before deleting it.
# good_pairs.to_pickle("/data/patrick_res/sessions/pairs_at_least_3blocks_7sess.pickle")
good_pairs.to_pickle("/data/patrick_res/sessions/SA/pairs_at_least_3blocks_10sess_more_sess.pickle")

### For each session and each pair, what is the minimum number of trials that match:
- high preferred features
- high-confidence trials where the features are chosen but not preferred

In [17]:
def min_trials_per_session(row):
    """
    For one feature pair, compute per-session minimum trial counts across the pair.

    Reads the module-level `all_beh` table.

    Args:
        row: record with `pair` (two feature names) and `sessions` (iterable of
            session identifiers matching `all_beh.session`).

    Returns:
        DataFrame with one row per session and the columns `pair`, `session`,
        `min_pref`, `min_pref_chose` and `min_not_pref_chose`.
    """
    feat1, feat2 = row.pair
    pref_labels = [f"High {feat1}", f"High {feat2}"]
    res = []
    for sess in row.sessions:
        beh = all_beh[all_beh.session == sess]

        # find minimum number of trials, when either features are preferred
        pref_beh = beh[beh.BeliefStateValueLabel.isin(pref_labels)]
        pref_counts = pref_beh.groupby("BeliefStateValueLabel").count().TrialNumber
        # groupby drops labels that never occur, so a feature with no preferred
        # trials would silently vanish and the minimum would report the other
        # feature's count. Reindex onto both expected labels so it counts as 0.
        min_pref = np.min(pref_counts.reindex(pref_labels, fill_value=0))

        # NOTE(review): the two helper-based counts below are left unchanged because
        # the helpers' outputs are not visible here; they presumably share the same
        # missing-group blind spot — confirm and apply the same treatment if so.
        pref_chose = behavioral_utils.get_chosen_preferred_trials(row.pair, pref_beh)
        min_pref_chose = np.min(pref_chose.groupby("BeliefStateValueLabel").count().TrialNumber)

        not_pref_chose = behavioral_utils.get_chosen_not_preferred_trials(row.pair, beh)
        min_not_pref = np.min(not_pref_chose.groupby("Choice").count().TrialNumber)

        res.append({
            "pair": row.pair,
            "session": sess,
            "min_pref": min_pref,
            "min_pref_chose": min_pref_chose,
            "min_not_pref_chose": min_not_pref
        })
    return pd.DataFrame(res)
        
# Stack the per-pair frames into one table. Each frame keeps its own 0..n-1 index
# (no ignore_index), so index labels repeat across pairs in the result.
min_trials = pd.concat(good_pairs.apply(min_trials_per_session, axis=1).values)



In [18]:
# Preview: one row per (pair, session) with the three minimum trial counts
min_trials

Unnamed: 0,pair,session,min_pref,min_pref_chose,min_not_pref_chose
0,"[CIRCLE, SQUARE]",20180615,27,25,64
1,"[CIRCLE, SQUARE]",20180709,51,42,70
2,"[CIRCLE, SQUARE]",20180802,54,46,134
3,"[CIRCLE, SQUARE]",20180803,25,23,122
4,"[CIRCLE, SQUARE]",20180806,87,68,69
...,...,...,...,...,...
6,"[YELLOW, RIPPLE]",20180912,62,50,65
7,"[YELLOW, RIPPLE]",20180921,45,39,45
8,"[YELLOW, RIPPLE]",20181005,59,46,78
9,"[YELLOW, RIPPLE]",20181009,44,38,59


In [20]:
# Row-wise minimum over the three count columns; added to min_trials in place
min_trials["min_all"] = min_trials[["min_pref", "min_pref_chose", "min_not_pref_chose"]].min(axis=1)

In [18]:
# min_trials.to_pickle("/data/patrick_res/sessions/SA/pairs_at_least_3blocks_7sess_min_trials.pickle")
# Write the per-session minimum trial counts to their own `_min_trials` file. The
# previous target was the exact path the good_pairs table is pickled to, so running
# this cell replaced that table with min_trials.
min_trials.to_pickle("/data/patrick_res/sessions/SA/pairs_at_least_3blocks_10sess_more_sess_min_trials.pickle")


In [18]:
# Load a previously saved min-trials table to compare against.
# NOTE(review): this path has no `SA/` directory, whereas the commented-out save path
# for the 7-session min-trials file does — confirm this is the intended file.
min_trials_original = pd.read_pickle("/data/patrick_res/sessions/pairs_at_least_3blocks_7sess_min_trials.pickle")

In [23]:
# Total difference in min_all between the earlier table and the current one.
# NOTE(review): Series subtraction aligns on index labels, not on (pair, session).
# min_trials carries repeated 0..n-1 labels from the per-pair concat, and the original
# presumably does too — verify the rows actually line up before trusting this sum.
(min_trials_original.min_all - min_trials.min_all).sum()

78