### Want to look at subsampling to ensure we are correctly evaluating within dimension vs. across dimension cross decoding accuracies. 
- coarse data point: on average, there's 7.1 sessions per pair for across dim, 7.7 for within dim

How should we balance: want to balance number of trials (data points) and number of units (features)
- across all pairs, look at all sessions, find min number of units for pair. This is the number of units that will be subselected. 
- across all pairs, for each session, across each condition, look at minimum number of trials present. This should be number of trials that is sub-sampled. 

In [1]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import utils.behavioral_utils as behavioral_utils
import utils.information_utils as information_utils
import utils.visualization_utils as visualization_utils
import utils.pseudo_classifier_utils as pseudo_classifier_utils
import utils.classifier_utils as classifier_utils

import utils.io_utils as io_utils

import utils.glm_utils as glm_utils
from matplotlib import pyplot as plt
import matplotlib
import utils.spike_utils as spike_utils
import utils.subspace_utils as subspace_utils
from trial_splitters.condition_trial_splitter import ConditionTrialSplitter 
from utils.session_data import SessionData
from constants.behavioral_constants import *


import scipy


### Load pairs, data

In [2]:
pairs = pd.read_pickle("/data/patrick_res/sessions/pairs_at_least_3blocks_7sess.pickle")
pairs = pairs.reset_index(drop=True)
SESSIONS_PATH = "/data/patrick_res/sessions/valid_sessions_rpe.pickle"
sessions = pd.read_pickle(SESSIONS_PATH)
all_units = spike_utils.get_unit_positions(sessions)

In [3]:
all_units.to_pickle("/data/patrick_res/firing_rates/all_units.pickle")

In [4]:
all_units

Unnamed: 0,Channel,Unit,SpikeTimesFile,UnitID,electrode_id,x,y,z,distance,in_brain,...,structure_level1,structure_level2,structure_level3,structure_level4,structure_level5,structure_level6,structure_potential,session,PseudoUnitID,manual_structure
0,100,1,/data/rawdata/sub-SA/sess-20180709/spikes/sub-...,0,100,13.535342,-55.049086,19.90947,27.5,True,...,telencephalon (tel),lateral_and_ventral_pallium (LVPal),lateral_pallium (LPal),claustrum (Cl),claustrum (Cl),claustrum (Cl),[],20180709,2018070900,Claustrum
1,108,1,/data/rawdata/sub-SA/sess-20180709/spikes/sub-...,1,108,13.699304,-54.08206,21.82937,28.75,True,...,telencephalon (tel),lateral_and_ventral_pallium (LVPal),lateral_pallium (LPal),claustrum (Cl),claustrum (Cl),claustrum (Cl),[],20180709,2018070901,Claustrum
2,109,1,/data/rawdata/sub-SA/sess-20180709/spikes/sub-...,2,109,15.204069,-54.488859,21.577754,27.5,True,...,telencephalon (tel),lateral_and_ventral_pallium (LVPal),lateral_pallium (LPal),claustrum (Cl),claustrum (Cl),claustrum (Cl),[],20180709,2018070902,Claustrum
3,10a,1,/data/rawdata/sub-SA/sess-20180709/spikes/sub-...,3,10a,0.789428,-75.75489,38.106774,9,True,...,Frontal_Lobe (Frontal),lateral_prefrontal_cortex (lat_PFC),dorsolateral_prefrontal_cortex (dlPFC),area_8B (area_8B),area_8B (area_8B),medial_area_8B (area_8Bm),dlPFC,20180709,2018070903,Prefrontal Cortex
4,10a,2,/data/rawdata/sub-SA/sess-20180709/spikes/sub-...,4,10a,0.789428,-75.75489,38.106774,9,True,...,Frontal_Lobe (Frontal),lateral_prefrontal_cortex (lat_PFC),dorsolateral_prefrontal_cortex (dlPFC),area_8B (area_8B),area_8B (area_8B),medial_area_8B (area_8Bm),dlPFC,20180709,2018070904,Prefrontal Cortex
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
25,90,1,/data/rawdata/sub-SA/sess-20180910/spikes/sub-...,25,90,10.452756,-52.703145,19.318001,31.125,True,...,telencephalon (tel),amygdala (Amy),pallial_amygdala (pAmy),lateropallial_amygdala (lpAmy),lateral_amygdaloid_nucleus (La),lateral_dorsal_amygdaloid_nucleus (LaD),Hippocampal,20180910,2018091025,Amygdala
26,92,1,/data/rawdata/sub-SA/sess-20180910/spikes/sub-...,26,92,13.520363,-51.91999,19.341959,30.625,True,...,telencephalon (tel),basal_ganglia (BG),striatum (Str),dorsal_striatum (DStr),caudate (Cd),caudate_tail (CdT),Hippocampal,20180910,2018091026,Basal Ganglia
27,95,1,/data/rawdata/sub-SA/sess-20180910/spikes/sub-...,27,95,5.433664,-68.902483,15.922112,16,True,...,telencephalon (tel),basal_ganglia (BG),striatum (Str),dorsal_striatum (DStr),caudate (Cd),caudate_head (CdH),[],20180910,2018091027,Basal Ganglia
28,99,1,/data/rawdata/sub-SA/sess-20180910/spikes/sub-...,28,99,12.140106,-51.630938,21.155323,31.75,True,...,telencephalon (tel),amygdala (Amy),pallial_amygdala (pAmy),lateropallial_amygdala (lpAmy),lateral_amygdaloid_nucleus (La),lateral_dorsal_amygdaloid_nucleus (LaD),Hippocampal,20180910,2018091028,Amygdala


### Figure out how many units per pair, what is min number of units across pairs

In [3]:
def find_num_units(row, all_units):
    sub_units = all_units[all_units.session.isin(row.sessions)]
    return len(sub_units)
pairs["num_units_available"] = pairs.apply(lambda row: find_num_units(row, all_units), axis=1)

In [4]:
pairs.num_units_available.min()

179

### What is min number of trials across pair/session/condition

In [7]:
exploded = pairs.explode("sessions")
def find_num_trials_per_pair_sess(row):
    feat1, feat2 = row.pair
    sess_name = row.sessions
    behavior_path = SESS_BEHAVIOR_PATH.format(sess_name=sess_name)
    beh = pd.read_csv(behavior_path)
    beh = behavioral_utils.get_valid_trials(beh)
    feature_selections = behavioral_utils.get_selection_features(beh)
    beh = pd.merge(beh, feature_selections, on="TrialNumber", how="inner")
    beh = behavioral_utils.get_beliefs_per_session(beh, sess_name)
    beh = behavioral_utils.get_belief_value_labels(beh)
    beh = beh[
        ((beh[FEATURE_TO_DIM[feat1]] == feat1) & (beh.BeliefStateValueLabel == f"High {feat1}")) |
        ((beh[FEATURE_TO_DIM[feat2]] == feat2) & (beh.BeliefStateValueLabel == f"High {feat2}")) |
        (beh.BeliefStateValueLabel == "Low")
    ]
    beh = behavioral_utils.balance_trials_by_condition(beh, ["BeliefStateValueLabel"])

    return len(beh)
    
exploded["balanced_num_trials_per_pair_sess"] = exploded.apply(find_num_trials_per_pair_sess, axis=1)

In [15]:
exploded["pair_str"] = exploded.apply(lambda x: "_".join(x.pair), axis=1)

In [17]:
exploded[["pair_str", "balanced_num_trials_per_pair_sess"]].groupby("pair_str").sum()["balanced_num_trials_per_pair_sess"].min()

693

### Current thinking: 
- For each pair, want equal number of trials and units. 
- The issue is, what's the best way to attribute each trials across sessions?
### One way to distribute: 
- One way: take the min trials per pair, X trials, 
- take session with min number of trials, Y trials
- X_AVG = X/ num sessions for pair
- if X_AVG > Y: 
  - take Y trials from all sessions first, then order sessions by number of trials ascending, grabbing all trials from each session until X is hit. ,  
- if X_AVG < Y:
  - take X_AVG from all sessions, done