### Draft code for time balancing: check whether running it is feasible and how good the resulting balance is.

In [1]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import utils.behavioral_utils as behavioral_utils
import utils.information_utils as information_utils
import utils.visualization_utils as visualization_utils
import utils.pseudo_classifier_utils as pseudo_classifier_utils
import utils.classifier_utils as classifier_utils

import utils.io_utils as io_utils

import utils.glm_utils as glm_utils
from matplotlib import pyplot as plt
import matplotlib
import utils.spike_utils as spike_utils
import utils.subspace_utils as subspace_utils
from trial_splitters.condition_trial_splitter import ConditionTrialSplitter 
from utils.session_data import SessionData
from constants.behavioral_constants import *
from constants.decoding_constants import *
import seaborn as sns
from scripts.pseudo_decoding.belief_partitions.belief_partition_configs import *
import scripts.pseudo_decoding.belief_partitions.belief_partitions_io as belief_partitions_io

import scipy
import argparse
import copy
from tqdm import tqdm

import numpy as np
from scipy.spatial.distance import pdist, cdist
from scipy.stats import skew
import matplotlib.pyplot as plt

In [2]:
def compute_pairwise_stats(beh):
    """
    Summarize how far apart, in trial numbers, the rows of `beh` are.

    Every unordered pair of rows contributes the absolute difference of their
    TrialNumber values; the mean and standard deviation of those pairwise
    distances are returned as a (mean, std) tuple.
    """
    trial_col = beh.TrialNumber.values.reshape(-1, 1)
    # On a single column, euclidean distance reduces to |trial_i - trial_j|.
    pairwise = pdist(trial_col, metric='euclidean')
    return pairwise.mean(), pairwise.std()

def time_balance_conds(beh, sample_size, num_iters, target_mean, target_std, seed=42, std_weight=0.0):
    """
    Random-search for a subsample of trials whose pairwise trial-number
    distances are as close as possible to a target mean (and optionally std).

    Draws `num_iters` random subsamples of `sample_size` rows from `beh`
    (without replacement) and keeps the one with the lowest error:

        err = (mean - target_mean) ** 2 + std_weight * (std - target_std) ** 2

    where (mean, std) come from `compute_pairwise_stats` on the subsample.

    Args:
        beh: DataFrame with a TrialNumber column.
        sample_size: number of rows in each candidate subsample.
        num_iters: number of random subsamples to try.
        target_mean: desired mean of pairwise trial-number distances.
        target_std: desired std of pairwise trial-number distances. Only
            contributes to the error when `std_weight` is non-zero.
        seed: seed for the numpy random generator, for reproducibility.
        std_weight: weight of the std term in the error. The default of 0.0
            keeps the mean-only objective, so existing calls give the same
            results as before.

    Returns:
        (best_sub_trials, lowest_err): the TrialNumber values of the best
        subsample found (None if `num_iters` is 0) and its error.

    Raises:
        ValueError: raised by `rng.choice` when `sample_size` exceeds
            `len(beh)`, since sampling is without replacement.
    """
    lowest_err = np.inf
    best_sub_trials = None
    rng = np.random.default_rng(seed=seed)
    for _ in tqdm(range(num_iters)):
        idx = rng.choice(len(beh), size=sample_size, replace=False)
        sub_beh = beh.iloc[idx]
        mean, std = compute_pairwise_stats(sub_beh)
        err = (mean - target_mean) ** 2
        # Only add the std term when asked to, so the default path is
        # numerically identical to the mean-only objective.
        if std_weight:
            err += std_weight * (std - target_std) ** 2
        if err < lowest_err:
            print(f"new err: {err}")
            best_sub_trials = sub_beh.TrialNumber.values
            lowest_err = err
    return best_sub_trials, lowest_err


In [3]:
# Session under study, and the feature used to label belief partitions.
subject, session, feat = "SA", "20180802", "TRIANGLE"
beh = behavioral_utils.get_belief_partitions(
    behavioral_utils.get_valid_belief_beh_for_sub_sess(subject, session),
    feat,
    use_x=True,
)

In [4]:
# Inspect only the trials labelled "High X" for the chosen feature.
beh.loc[beh.BeliefPartition == "High X"]

Unnamed: 0,TrialNumber,BlockNumber,TrialAfterRuleChange,TaskInterrupt,ConditionNumber,Response,ItemChosen,TrialType,CurrentRule,LastRule,...,PrevColor,PrevShape,PrevPattern,session,BeliefConf,BeliefPolicy,BeliefPartition,NextBeliefConf,NextBeliefPolicy,NextBeliefPartition
39,74,3,17,,545,Incorrect,3.0,9,TRIANGLE,CIRCLE,...,GREEN,TRIANGLE,RIPPLE,20180802,High,X,High X,Low,Not X,Low
41,76,3,19,,177,Incorrect,2.0,9,TRIANGLE,CIRCLE,...,YELLOW,TRIANGLE,ESCHER,20180802,High,X,High X,High,X,High X
42,77,3,20,,510,Correct,0.0,9,TRIANGLE,CIRCLE,...,YELLOW,STAR,SWIRL,20180802,High,X,High X,High,X,High X
43,78,3,21,,406,Correct,0.0,9,TRIANGLE,CIRCLE,...,GREEN,TRIANGLE,ESCHER,20180802,High,X,High X,High,X,High X
44,79,3,22,,432,Correct,0.0,9,TRIANGLE,CIRCLE,...,YELLOW,TRIANGLE,SWIRL,20180802,High,X,High X,High,X,High X
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1286,1321,44,22,,81,Correct,0.0,9,TRIANGLE,SQUARE,...,CYAN,TRIANGLE,ESCHER,20180802,High,X,High X,High,X,High X
1287,1322,44,23,,107,Correct,0.0,9,TRIANGLE,SQUARE,...,CYAN,TRIANGLE,ESCHER,20180802,High,X,High X,High,X,High X
1288,1323,45,0,,4081,Incorrect,2.0,7,GREEN,TRIANGLE,...,CYAN,TRIANGLE,ESCHER,20180802,High,X,High X,High,X,High X
1289,1324,45,1,,4336,Incorrect,1.0,7,GREEN,TRIANGLE,...,MAGENTA,SQUARE,ESCHER,20180802,High,X,High X,High,X,High X


In [5]:
# Baseline: pairwise trial-number distance stats over every row of `beh`.
mean, std = compute_pairwise_stats(beh)
print(mean, std, sep="\n")

568.8819021656823
401.90658925054845


In [8]:
# Search 20000 random 500-trial subsamples for the best match to the targets.
best_sub_trials, lowest_err = time_balance_conds(
    beh,
    sample_size=500,
    num_iters=20000,
    target_mean=600,
    target_std=450,
)

  0%|          | 0/20000 [00:00<?, ?it/s]

  2%|▏         | 319/20000 [00:00<00:12, 1598.92it/s]

new err: 1367.440364297335
new err: 960.8782403299567
new err: 250.7918442392767
new err: 127.01497784416989
new err: 6.89926548253192


  5%|▍         | 980/20000 [00:00<00:11, 1642.50it/s]

new err: 0.5113596982501665


  9%|▉         | 1808/20000 [00:01<00:11, 1647.71it/s]

new err: 0.1887980867547167


 23%|██▎       | 4614/20000 [00:02<00:09, 1632.97it/s]

new err: 0.031289994112478625


 82%|████████▏ | 16385/20000 [00:10<00:02, 1623.41it/s]

new err: 0.0003569752089355063


100%|██████████| 20000/20000 [00:12<00:00, 1627.42it/s]


In [9]:
# Recompute the stats on the selected trials to confirm what the search achieved.
sub_mean, sub_std = compute_pairwise_stats(beh.loc[beh.TrialNumber.isin(best_sub_trials)])
print(sub_mean, sub_std, sep="\n")

599.9811062124248
424.2868070551108


In [10]:
# Error of the best subsample returned by time_balance_conds
# (squared difference between its mean pairwise distance and target_mean).
lowest_err

0.0003569752089355063

### If we only care about the distance between pairs of conditions:

In [6]:
def compute_pairwise_stats_by_pair(beh_a, beh_b, mean=True):
    """
    Trial-number distances between every trial in `beh_a` and every trial in
    `beh_b`.

    Args:
        beh_a: DataFrame with a TrialNumber column (first condition).
        beh_b: DataFrame with a TrialNumber column (second condition).
        mean: if True (default), return the mean of all cross-condition
            distances; if False, return the full distance matrix.

    Returns:
        A scalar mean distance when `mean` is True, otherwise an array of
        shape (len(beh_a), len(beh_b)) whose [i, j] entry is
        |beh_a.TrialNumber[i] - beh_b.TrialNumber[j]|.
    """
    trial_dists = cdist(
        beh_a.TrialNumber.values.reshape(-1, 1),
        beh_b.TrialNumber.values.reshape(-1, 1),
        metric='euclidean')
    if mean:
        return np.mean(trial_dists)
    # Previously this branch evaluated `trial_dists` without returning it,
    # so mean=False silently fell through and returned the mean.
    return trial_dists

def time_balance_conds_by_pair(beh_a, beh_b, num_iters, target_mean, sample_size=None, seed=42):
    """
    Random-search for equally sized subsamples of two conditions whose mean
    cross-condition trial-number distance is as close as possible to
    `target_mean`.

    Each of the `num_iters` iterations draws `sample_size` rows without
    replacement from `beh_a` and, independently, from `beh_b`. The pair with
    the smallest squared difference between its mean pairwise distance and
    `target_mean` is kept. When `sample_size` is None, the smaller of the two
    frame lengths is used.

    Returns:
        (best_sub_a_trials, best_sub_b_trials, lowest_err): TrialNumber values
        of the best subsample from each condition (None if `num_iters` is 0)
        and the corresponding squared error.
    """
    rng = np.random.default_rng(seed=seed)
    if sample_size is None:
        sample_size = np.min((len(beh_a), len(beh_b)))

    lowest_err = np.inf
    best_sub_a_trials, best_sub_b_trials = None, None
    for _ in tqdm(range(num_iters)):
        # Draw order (a first, then b) is part of the seeded sequence.
        sub_beh_a = beh_a.iloc[rng.choice(len(beh_a), size=sample_size, replace=False)]
        sub_beh_b = beh_b.iloc[rng.choice(len(beh_b), size=sample_size, replace=False)]

        mean = compute_pairwise_stats_by_pair(sub_beh_a, sub_beh_b)
        err = (mean - target_mean) ** 2
        if not err < lowest_err:
            continue
        print(f"new err: {err}, new mean {mean}")
        best_sub_a_trials = sub_beh_a.TrialNumber.values
        best_sub_b_trials = sub_beh_b.TrialNumber.values
        lowest_err = err
    return best_sub_a_trials, best_sub_b_trials, lowest_err


In [7]:
# Reload the session behavior; this rebinds `beh` to the loader's output,
# before any get_belief_partitions call.
subject, session = "SA", "20180802"
beh = behavioral_utils.get_valid_belief_beh_for_sub_sess(subject, session)

In [8]:
# Number of distinct blocks in which each rule was the current rule.
beh.groupby("CurrentRule")["BlockNumber"].nunique()

CurrentRule
CIRCLE      6
CYAN        6
ESCHER      1
GREEN       6
MAGENTA     3
POLKADOT    4
RIPPLE      4
SQUARE      5
STAR        2
SWIRL       4
TRIANGLE    7
YELLOW      5
Name: BlockNumber, dtype: int64

In [9]:
def select_high_x_trials(beh, feat):
    """Label belief partitions for `feat` and keep only the "High X" trials."""
    partitioned = behavioral_utils.get_belief_partitions(beh, feat, use_x=True)
    return partitioned[partitioned.BeliefPartition == "High X"]


# One "High X" frame per feature, built in the same order as before.
# The stray groupby expression that used to open this cell was dropped:
# its value was discarded because it was not the cell's last expression.
circle_beh, triangle_beh, green_beh = (
    select_high_x_trials(beh, feat) for feat in ("CIRCLE", "TRIANGLE", "GREEN")
)


In [10]:
# Mean trial-number distance between CIRCLE and TRIANGLE "High X" trials.
compute_pairwise_stats_by_pair(beh_a=circle_beh, beh_b=triangle_beh)

425.7869162087912

In [11]:
# Mean trial-number distance between CIRCLE and GREEN "High X" trials.
compute_pairwise_stats_by_pair(beh_a=circle_beh, beh_b=green_beh)

635.0475470430108

In [12]:
# Search for 30-trial subsamples of CIRCLE and GREEN with mean distance near 510.
circle_sub, green_sub, _ = time_balance_conds_by_pair(
    circle_beh,
    green_beh,
    num_iters=20000,
    target_mean=510,
    sample_size=30,
)

  4%|▍         | 835/20000 [00:00<00:04, 4171.87it/s]

new err: 12515.642711111112, new mean 621.8733333333333
new err: 4396.574044444437, new mean 576.3066666666666
new err: 79.68537777777796, new mean 501.0733333333333
new err: 27.995856790123863, new mean 515.2911111111111
new err: 6.656400000000211, new mean 512.58
new err: 4.391353086419786, new mean 507.90444444444444
new err: 3.4348444444445176, new mean 508.14666666666665


 12%|█▏        | 2470/20000 [00:00<00:04, 3991.72it/s]

new err: 0.8877827160493389, new mean 509.0577777777778
new err: 0.1067111111111027, new mean 510.32666666666665


 18%|█▊        | 3664/20000 [00:00<00:04, 3957.73it/s]

new err: 0.027041975308636407, new mean 510.1644444444444
new err: 0.0019753086419733996, new mean 509.9555555555556


100%|██████████| 20000/20000 [00:04<00:00, 4091.40it/s]


In [14]:
# Sanity check: the GREEN subsample should contain `sample_size` trials.
len(green_sub)

30

In [16]:
# Same search for CIRCLE vs TRIANGLE; the returned tuple is displayed as-is.
time_balance_conds_by_pair(
    circle_beh,
    triangle_beh,
    num_iters=20000,
    target_mean=510,
    sample_size=30,
)

  0%|          | 0/20000 [00:00<?, ?it/s]

  4%|▍         | 802/20000 [00:00<00:04, 4008.94it/s]

new err: 18950.275600000008, new mean 372.34
new err: 2454.872177777779, new mean 460.4533333333333
new err: 488.9995111111116, new mean 487.88666666666666
new err: 79.68537777777796, new mean 501.0733333333333
new err: 41.38777777777783, new mean 503.56666666666666


  8%|▊         | 1612/20000 [00:00<00:04, 4033.23it/s]

new err: 9.81777777777773, new mean 506.8666666666667
new err: 7.556390123457004, new mean 512.7488888888889
new err: 2.7115111111110464, new mean 511.64666666666665


 12%|█▏        | 2420/20000 [00:00<00:04, 4027.07it/s]

new err: 1.1520444444444233, new mean 511.0733333333333
new err: 0.3211111111111068, new mean 509.43333333333334


 38%|███▊      | 7673/20000 [00:01<00:03, 4026.85it/s]

new err: 0.31859753086421694, new mean 510.56444444444446


 57%|█████▋    | 11315/20000 [00:02<00:02, 4028.60it/s]

new err: 0.008711111111105877, new mean 510.0933333333333


 77%|███████▋  | 15358/20000 [00:03<00:01, 4043.82it/s]

new err: 0.0012641975308645928, new mean 510.03555555555556


100%|██████████| 20000/20000 [00:04<00:00, 4031.55it/s]


(array([ 57, 824, 251, 223, 258, 216,  46, 250,  51, 259,  56, 721, 257,
        826, 732, 254, 252, 730,  52, 719, 823, 720, 256, 897,  54, 896,
        899,  53, 815, 734]),
 array([1321,   89,  669,  279,  749,  280, 1325,   86,  683, 1317,  523,
         997, 1324,   94,  753,   76,   77, 1318, 1323,   90,  516,  520,
         679, 1313, 1319, 1314,  525,   81,  750,  281]),
 0.0012641975308645928)