### Look at which brain regions the units were recorded in, for Blanche (BL) and Sam (SA)

In [11]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import utils.behavioral_utils as behavioral_utils
import utils.information_utils as information_utils
import utils.visualization_utils as visualization_utils
import utils.pseudo_classifier_utils as pseudo_classifier_utils
import utils.classifier_utils as classifier_utils

import utils.io_utils as io_utils

import utils.glm_utils as glm_utils
from matplotlib import pyplot as plt
import matplotlib
import utils.spike_utils as spike_utils
import utils.subspace_utils as subspace_utils
from trial_splitters.condition_trial_splitter import ConditionTrialSplitter 
from utils.session_data import SessionData
from constants.behavioral_constants import *
from constants.decoding_constants import *
import seaborn as sns
from scripts.pseudo_decoding.belief_partitions.belief_partition_configs import *
import scripts.pseudo_decoding.belief_partitions.belief_partitions_io as belief_partitions_io

import scipy
import argparse
import copy
import plotly.express as px
from scripts.anova_analysis.anova_configs import *



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Distinct units recorded in each region (structure_level2), one count table per subject.
subject = "BL"
bl_pos = (
    pd.read_pickle(UNITS_PATH.format(sub=subject))
    .assign(subject=subject)
    .groupby("structure_level2")
    .PseudoUnitID.nunique()
    .reset_index(name="BL_num_units")
)

subject = "SA"
sa_pos = (
    pd.read_pickle(UNITS_PATH.format(sub=subject))
    .assign(subject=subject)
    .groupby("structure_level2")
    .PseudoUnitID.nunique()
    .reset_index(name="SA_num_units")
)

In [3]:
# Inner merge: only regions recorded in BOTH subjects survive.
all_pos = bl_pos.merge(sa_pos, on="structure_level2")
all_pos = all_pos.assign(
    # smaller of the two subjects' counts, i.e. how many units each subject can contribute equally
    min_shared=np.minimum(all_pos.BL_num_units, all_pos.SA_num_units),
    total=all_pos.BL_num_units + all_pos.SA_num_units,
)

In [4]:
# Full per-region table (shared regions only), rendered without the index column.
region_table = all_pos.style
region_table.hide(axis="index")

structure_level2,BL_num_units,SA_num_units,min_shared,total
amygdala (Amy),69,83,69,152
anterior_cingulate_gyrus (ACgG),11,106,11,117
basal_ganglia (BG),62,95,62,157
cerebellum (Cb),4,2,2,6
extrastriate_visual_areas_2-4 (V2-V4),19,8,8,27
floor_of_the_lateral_sulcus (floor_of_ls),1,7,1,8
inferior_parietal_lobule (IPL),16,44,16,60
inferior_temporal_cortex (ITC),49,206,49,255
lateral_and_ventral_pallium (LVPal),6,93,6,99
medial_pallium (MPal),210,53,53,263


In [5]:
# Ten regions with the most units summed across both subjects.
top_by_total = all_pos.sort_values("total", ascending=False).head(10)
top_by_total.style.hide(axis="index")

structure_level2,BL_num_units,SA_num_units,min_shared,total
medial_pallium (MPal),210,53,53,263
inferior_temporal_cortex (ITC),49,206,49,255
basal_ganglia (BG),62,95,62,157
amygdala (Amy),69,83,69,152
anterior_cingulate_gyrus (ACgG),11,106,11,117
lateral_and_ventral_pallium (LVPal),6,93,6,99
posterior_medial_cortex (PMC),37,47,37,84
superior_parietal_lobule (SPL),10,62,10,72
inferior_parietal_lobule (IPL),16,44,16,60
thalamus (Thal),17,37,17,54


In [6]:
# Ten regions where the less-sampled subject still has the most units.
top_by_min_shared = all_pos.sort_values("min_shared", ascending=False).head(10)
top_by_min_shared.style.hide(axis="index")

structure_level2,BL_num_units,SA_num_units,min_shared,total
amygdala (Amy),69,83,69,152
basal_ganglia (BG),62,95,62,157
medial_pallium (MPal),210,53,53,263
inferior_temporal_cortex (ITC),49,206,49,255
posterior_medial_cortex (PMC),37,47,37,84
thalamus (Thal),17,37,17,54
inferior_parietal_lobule (IPL),16,44,16,60
anterior_cingulate_gyrus (ACgG),11,106,11,117
superior_parietal_lobule (SPL),10,62,10,72
extrastriate_visual_areas_2-4 (V2-V4),19,8,8,27


### Now look only at sessions in which some feature has at least 3 blocks

In [7]:
# Same per-region unit counts as above, but restricted to the sessions listed in each
# subject's feats_at_least_3blocks.pickle (the union of every feature row's `sessions`).
FEATS_PATH = "/data/patrick_res/sessions/{sub}/feats_at_least_3blocks.pickle"

bl_feats = pd.read_pickle(FEATS_PATH.format(sub="BL"))
bl_sessions = bl_feats.sessions.explode().unique()
bl_units = pd.read_pickle(UNITS_PATH.format(sub="BL"))
bl_pos = (
    bl_units[bl_units.session.isin(bl_sessions)]
    .groupby("structure_level2")
    .PseudoUnitID.nunique()
    .reset_index(name="BL_num_units")
)

sa_feats = pd.read_pickle(FEATS_PATH.format(sub="SA"))
sa_sessions = sa_feats.sessions.explode().unique()
sa_units = pd.read_pickle(UNITS_PATH.format(sub="SA"))
sa_pos = (
    sa_units[sa_units.session.isin(sa_sessions)]
    .groupby("structure_level2")
    .PseudoUnitID.nunique()
    .reset_index(name="SA_num_units")
)

# Inner merge keeps only regions recorded in both subjects.
all_pos = pd.merge(bl_pos, sa_pos, on="structure_level2")
all_pos["min_shared"] = np.minimum(all_pos.BL_num_units, all_pos.SA_num_units)
all_pos["total"] = all_pos.BL_num_units + all_pos.SA_num_units

In [8]:
# Ten regions with the most units (both subjects combined) in the restricted sessions.
top_by_total = all_pos.sort_values("total", ascending=False).head(10)
top_by_total.style.hide(axis="index")

structure_level2,BL_num_units,SA_num_units,min_shared,total
inferior_temporal_cortex (ITC),36,205,36,241
medial_pallium (MPal),161,53,53,214
amygdala (Amy),63,79,63,142
basal_ganglia (BG),46,90,46,136
anterior_cingulate_gyrus (ACgG),8,99,8,107
lateral_and_ventral_pallium (LVPal),6,86,6,92
superior_parietal_lobule (SPL),7,62,7,69
posterior_medial_cortex (PMC),20,46,20,66
inferior_parietal_lobule (IPL),9,44,9,53
thalamus (Thal),15,36,15,51


In [9]:
# Ten regions where the less-sampled subject still has the most units (restricted sessions).
top_by_min_shared = all_pos.sort_values("min_shared", ascending=False).head(10)
top_by_min_shared.style.hide(axis="index")

structure_level2,BL_num_units,SA_num_units,min_shared,total
amygdala (Amy),63,79,63,142
medial_pallium (MPal),161,53,53,214
basal_ganglia (BG),46,90,46,136
inferior_temporal_cortex (ITC),36,205,36,241
posterior_medial_cortex (PMC),20,46,20,66
thalamus (Thal),15,36,15,51
inferior_parietal_lobule (IPL),9,44,9,53
anterior_cingulate_gyrus (ACgG),8,99,8,107
extrastriate_visual_areas_2-4 (V2-V4),17,8,8,25
superior_parietal_lobule (SPL),7,62,7,69


### For each region, what is the number of "belief partition selective" units?
- Current working definition: a unit is selective if it exceeds the 95th percentile of its shuffled distribution in either the stimulus-onset (StimOnset) or the feedback-onset (FeedbackOnsetLong) period.

In [33]:
sig_level = "95th"

def get_sub_stats(sub, region, trial_events=("StimOnset", "FeedbackOnsetLong")):
    """Per-feature counts of recorded and BeliefPartition-selective units in one region.

    A unit counts as selective if `io_utils.read_anova_good_units` returns it (at the
    module-level `sig_level`) for ANY of `trial_events`; a unit selective in several
    events is counted once per feature.

    Args:
        sub: subject code used in the data paths, e.g. "SA" or "BL".
        region: `structure_level2` value to restrict to, e.g. "amygdala (Amy)".
        trial_events: trial events whose ANOVA results are pooled. The default
            reproduces the stimulus-onset + long feedback-onset definition. Must be
            non-empty (`pd.concat` raises on an empty list).

    Returns:
        DataFrame with a `feat` column plus int columns `num_units_sig`,
        `num_units_recorded` and `num_sessions_sig`; a feature missing from one
        source gets 0 for that source's counts.
    """
    args = argparse.Namespace(**AnovaConfigs()._asdict())
    args.conditions = ["BeliefConf", "BeliefPartition"]
    args.beh_filters = {"Response": "Correct", "Choice": "Chose"}
    args.subject = sub

    # Units recorded in this region, counted over each feature's own sessions
    # (the first `sessions` entry of that feature's rows, as before).
    feats_path = f"/data/patrick_res/sessions/{sub}/feats_at_least_3blocks.pickle"
    feats = pd.read_pickle(feats_path)
    unit_pos = pd.read_pickle(UNITS_PATH.format(sub=sub))
    unit_pos = unit_pos[unit_pos.structure_level2 == region]
    per_feat_totals = (
        feats.groupby("feat")["sessions"]
        .apply(lambda sessions: unit_pos[unit_pos.session.isin(sessions.iloc[0])].PseudoUnitID.nunique())
        .reset_index(name="num_units_recorded")
    )

    # Pool ANOVA-selective units across the requested trial events.
    event_results = []
    for trial_event in trial_events:
        args.trial_event = trial_event
        event_res = io_utils.read_anova_good_units(args, sig_level, "BeliefPartition", return_pos=True)
        event_res["trial_event"] = trial_event
        event_results.append(event_res)
    all_res = pd.concat(event_results)
    all_res = all_res[all_res.structure_level2 == region].copy()
    # The session id is the PseudoUnitID with its last two digits dropped; integer
    # division avoids the float round-trip of `/ 100` followed by a cast.
    all_res["session"] = (all_res.PseudoUnitID // 100).astype(int)

    per_feat_sig = all_res.groupby("feat").PseudoUnitID.nunique().reset_index(name="num_units_sig")
    per_feat_sess = all_res.groupby("feat").session.nunique().reset_index(name="num_sessions_sig")

    # Outer merges keep features that have recorded units but no selective ones (and vice versa).
    per_feat_units = pd.merge(per_feat_sig, per_feat_totals, on="feat", how="outer")
    per_feat_units = pd.merge(per_feat_units, per_feat_sess, on="feat", how="outer")
    per_feat_units = per_feat_units.fillna(0)
    count_cols = ["num_units_sig", "num_units_recorded", "num_sessions_sig"]
    per_feat_units[count_cols] = per_feat_units[count_cols].astype(int)
    return per_feat_units


# Regions with the most units shared across both subjects (see the min_shared tables above).
common_regions = ["inferior_temporal_cortex (ITC)", "medial_pallium (MPal)", "basal_ganglia (BG)", "amygdala (Amy)"]
for region in common_regions:
    sa_stats = get_sub_stats("SA", region)
    bl_stats = get_sub_stats("BL", region)
    stats = pd.merge(sa_stats, bl_stats, on="feat", how="outer", suffixes=["_sa", "_bl"])
    # The outer merge leaves NaN for a feature that appears for only one subject;
    # treat that as 0 units so the total is not NaN and the counts stay integers.
    count_cols = [col for col in stats.columns if col != "feat"]
    stats[count_cols] = stats[count_cols].fillna(0).astype(int)
    stats["num_units_sig_total"] = stats["num_units_sig_sa"] + stats["num_units_sig_bl"]
    print(region)
    display(stats[["feat", "num_units_sig_sa", "num_units_sig_bl", "num_units_sig_total"]].style.hide(axis="index"))


inferior_temporal_cortex (ITC)


feat,num_units_sig_sa,num_units_sig_bl,num_units_sig_total
CIRCLE,24,1,25
CYAN,26,1,27
ESCHER,16,4,20
GREEN,23,0,23
MAGENTA,18,1,19
POLKADOT,19,0,19
RIPPLE,10,2,12
SQUARE,23,1,24
STAR,21,2,23
SWIRL,8,4,12


medial_pallium (MPal)


feat,num_units_sig_sa,num_units_sig_bl,num_units_sig_total
CIRCLE,11,2,13
CYAN,11,7,18
ESCHER,6,25,31
GREEN,7,5,12
MAGENTA,5,7,12
POLKADOT,8,5,13
RIPPLE,5,15,20
SQUARE,12,5,17
STAR,4,8,12
SWIRL,3,9,12


basal_ganglia (BG)


feat,num_units_sig_sa,num_units_sig_bl,num_units_sig_total
CIRCLE,17,3,20
CYAN,12,4,16
ESCHER,7,9,16
GREEN,14,0,14
MAGENTA,12,0,12
POLKADOT,11,3,14
RIPPLE,7,0,7
SQUARE,9,0,9
STAR,11,6,17
SWIRL,15,1,16


amygdala (Amy)


feat,num_units_sig_sa,num_units_sig_bl,num_units_sig_total
CIRCLE,11,0,11
CYAN,17,2,19
ESCHER,9,7,16
GREEN,7,0,7
MAGENTA,11,0,11
POLKADOT,7,1,8
RIPPLE,12,4,16
SQUARE,12,1,13
STAR,6,5,11
SWIRL,6,1,7


In [29]:
# `bl_stats` is left over from the LAST iteration of the region loop above, so it holds
# the BL per-feature stats for `common_regions[-1]`; label it so it is not misread.
print(common_regions[-1])
bl_stats.style.hide(axis="index")

Unnamed: 0,feat,num_units_sig,num_units_recorded,num_sessions_sig
0,CYAN,2.0,17,2.0
1,ESCHER,7.0,31,6.0
2,POLKADOT,1.0,10,1.0
3,RIPPLE,4.0,20,3.0
4,SQUARE,1.0,10,1.0
5,STAR,5.0,21,2.0
6,SWIRL,1.0,19,1.0
7,TRIANGLE,4.0,20,3.0
8,YELLOW,2.0,13,2.0
9,CIRCLE,,7,
