In [22]:
%load_ext autoreload
import os, sys, glob
import json
import re
import numpy as np
import pandas as pd
from natsort import natsorted

sys.path.append('/dartfs/rc/lab/F/FinnLab/tommy/isc_asynchrony_behavior/code/utils/')
sys.path.append('/dartfs/rc/lab/F/FinnLab/tommy/utils/gentle')

import gentle
from config import *
import preproc_utils as preproc
# from preproc_utils import create_balanced_orders, get_consecutive_list_idxs, sort_consecutive_constraint, check_consecutive_spacing

# from text_utils import get_pos_tags, get_lemma

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Set directories 

In [2]:
base_dir = '/dartfs/rc/lab/F/FinnLab/tommy/isc_asynchrony_behavior/'
stim_dir = os.path.join(base_dir, 'stimuli')
cache_dir = os.path.join('/dartfs/rc/lab/F/FinnLab/tommy/', 'models')

gentle_dir = os.path.join(stim_dir, 'gentle')

In [36]:
def load_model_data(model_dir, model_name, task, window_size, top_n):
    '''
    Loads model data from directory
    '''

    model_dir = os.path.join(model_dir, task, model_name, f'window-size-{window_size}')
    results_fn = natsorted(glob.glob(os.path.join(model_dir, f'*top-{top_n}*')))[0]

    # load the data, remove nans
    model_results = pd.read_csv(results_fn)
#     model_results['glove_continuous_accuracy'] = model_results['glove_continuous_accuracy'].apply(np.nan_to_num)
#     model_results['word2vec_continuous_accuracy'] = model_results['word2vec_continuous_accuracy'].apply(np.nan_to_num)

    model_results['glove_avg_accuracy'] = model_results['glove_avg_accuracy'] #.apply(np.nan_to_num)
    model_results['word2vec_avg_accuracy'] = model_results['word2vec_avg_accuracy'] #.apply(np.nan_to_num)
    model_results['fasttext_avg_accuracy'] = model_results['fasttext_avg_accuracy'] #.apply(np.nan_to_num)

    model_results['glove_max_accuracy'] = model_results['glove_max_accuracy'] #.apply(np.nan_to_num)
    model_results['word2vec_max_accuracy'] = model_results['word2vec_max_accuracy'] #.apply(np.nan_to_num)
    model_results['fasttext_max_accuracy'] = model_results['fasttext_max_accuracy'] #.apply(np.nan_to_num)

    return model_results

def get_stim_candidate_idxs(task):
    '''
    Find the NWP candidate indices of a preprocessed transcript
    '''

    preproc_fn = os.path.join(STIM_DIR, 'preprocessed', task, f'{task}_transcript-preprocessed.csv')
    df_preproc = pd.read_csv(preproc_fn)
    nwp_idxs = np.where(df_preproc['NWP_Candidate'])[0]

    return df_preproc, nwp_idxs

def divide_nwp_dataframe(df, accuracy_type, percentile):

    df_divide = df.copy()

    # first find the lowest and highest percentile for entropy
    low_entropy_idxs = df['entropy'] < np.nanpercentile(df['entropy'], percentile)
    high_entropy_idxs = df['entropy'] >= np.nanpercentile(df['entropy'], 100-percentile)

    ## set names for entropy group
    df_divide.loc[low_entropy_idxs, 'entropy_group'] = 'low'
    df_divide.loc[high_entropy_idxs, 'entropy_group'] = 'high'

    # repeat for continuous accuracy
    low_accuracy_idxs = df[accuracy_type] < np.nanpercentile(df[accuracy_type], percentile)
    high_accuracy_idxs = df[accuracy_type] >= np.nanpercentile(df[accuracy_type], 100-percentile)

    ## set names for accuracy group
    df_divide.loc[low_accuracy_idxs, 'accuracy_group'] = 'low'
    df_divide.loc[high_accuracy_idxs, 'accuracy_group'] = 'high'

    return df_divide #.dropna()

# def get_quadrant_distributions(df_divide, indices):
    
#     df_idx = df_divide.loc[indices]
    
#     # get the items as a dictionary for passing out to aggregate
#     quadrant_dist = {f'{labels[0]}-entropy_{labels[1]}-accuracy': round(len(df)/len(df_idx), 2) 
#                  for labels, df in df_idx.groupby(['entropy_group', 'accuracy_group'])}

#     df_quadrants = pd.DataFrame.from_dict(quadrant_dist, orient='index').T
    
#     return df_quadrants

def select_prediction_words(df_divide, remove_perc, select_perc, min_spacing_thresh=3):
    '''
    
    df_divide: candidate words divided into quartiles based on entropy and accuracy
    
    remove_perc: percentage of words to remove based on proximity to other words
        helps ensure decent spacing between presented words
        
    select_perc: percentage of words to select for presentation    
    
    '''
    
    df_divide['spacing'] = np.hstack([np.nan, np.diff(df_divide.index)])
    
    quadrant_distributions = get_quadrant_distributions(df_divide, df_divide.index).to_numpy()
    
    updated = []

    for i, df in df_divide.groupby(['entropy_group', 'accuracy_group']):
        # find how many words to remove in the quadrant based on the percent
        n_words = round(remove_perc * len(df))
        df = df.sort_values(by='spacing').iloc[n_words:]
        updated.append(df.sort_index())

    updated = pd.concat(updated).sort_index()
    updated_distributions = get_quadrant_distributions(updated, updated.index).to_numpy()
    
    print (quadrant_distributions)
    print (updated_distributions)
    
    assert (np.isclose(quadrant_distributions, updated_distributions, atol=0.01).all())
    
    # make sure it is scaled to the original dataframe
    select_perc = select_perc/(1-remove_perc)
    min_spacing = 0
    RANDOM_STATE = 0
    
    print (f'Selecting {select_perc*100:.2f}% of remaining items')
    
    while (min_spacing < min_spacing_thresh):
        # now sample the words from each quadrant
        sampled = []

        for i, df in updated.groupby(['entropy_group', 'accuracy_group']):

            df_sampled = df.sample(frac=select_perc, random_state=RANDOM_STATE).sort_index()
            sampled.append((len(df_sampled), df_sampled))

        n_sampled, sampled = zip(*sampled)
        sampled = pd.concat(sampled).sort_index()

        min_spacing = np.diff(sampled.index).min()
        
        RANDOM_STATE += 1
    
    print (f'Min spacing of {min_spacing}')
    print (f'{len(sampled)} total words')

    return sampled

import random

def random_chunks(lst, n, shuffle=False):
    """Created randomized n-sized chunks from lst."""
    
    tmp_lst = lst.copy()
    n_total = len(lst)
    
    if shuffle:
        random.shuffle(tmp_lst)
    
    all_chunks = []
    
    for i in range(0, len(tmp_lst), n):
        all_chunks.append(tmp_lst[i:i + n])
    
    # distribute remaining items across orders
    if len(all_chunks) != n_total//n:
        remainder = all_chunks.pop()
        
        for i, item in enumerate(remainder):      
            all_chunks[i%n].append(item)
    
    # lastly sort for ordered indices
    all_chunks = [sorted(chunk) for chunk in all_chunks]
    
    return all_chunks
    

In [None]:
adventuresinsayingyes
black
bronx
example_trial
eyespy
milkywayoriginal
milkywaysynonyms
milkywayvodka
nwp_practice_trial
piemanpni
prettymouth
shame
tunnel
wheretheressmoke

In [38]:
all_tasks_quadrants = []

models_dir = os.path.join(DERIVATIVES_DIR, 'model-predictions')
model_name = 'gpt2-xl'
task = 'life'

df_preproc, candidate_idxs = get_stim_candidate_idxs(task)

model_results = load_model_data(models_dir, model_name=model_name, task=task, top_n=5, window_size=100)
model_results.loc[:, 'binary_accuracy'] = model_results['binary_accuracy'].astype(bool)
model_results = model_results.iloc[candidate_idxs]

df_divide = divide_nwp_dataframe(model_results, accuracy_type='word2vec_avg_accuracy', percentile=45)


df_selected = preproc.select_prediction_words(df_divide, remove_perc=0.5, select_perc=0.4, min_spacing_thresh=2)


[[0.12 0.29 0.29 0.11]]
[[0.14 0.35 0.36 0.14]]


  df_divide.loc[low_entropy_idxs, 'entropy_group'] = 'low'
  df_divide.loc[low_accuracy_idxs, 'accuracy_group'] = 'low'


AssertionError: 

In [20]:
all_tasks_quadrants = []

models_dir = os.path.join(DERIVATIVES_DIR, 'model-predictions')
model_name = 'gpt2-xl'
task = 'life'

df_preproc, candidate_idxs = get_stim_candidate_idxs(task)

model_results = load_model_data(models_dir, model_name=model_name, task=task, top_n=5, window_size=100)
model_results.loc[:, 'binary_accuracy'] = model_results['binary_accuracy'].astype(bool)
model_results = model_results.iloc[candidate_idxs]

df_divide = divide_nwp_dataframe(model_results, accuracy_type='word2vec_avg_accuracy', percentile=45)


df_selected = select_prediction_words(df_divide, remove_perc=0.5, select_perc=0.4, min_spacing_thresh=2)

# # percent_sampled = 0.
# n_orders = 3
# n_participants_per_item = 50
# consecutive_spacing = 8

# # find distribution of selected words from the divided quadrants
# quadrant_distribution = get_quadrant_distributions(df_divide, df_selected.index).to_numpy()
# deviation_threshold = 0.05
# order_distributions = np.zeros((n_orders, 4))

# # find indices for presentation and set number of items each subject sees
# nwp_indices = sorted(df_selected.index)

# # # Find lists with consecutive items violating our constraint

# while not (np.allclose(quadrant_distribution, order_distributions, atol=deviation_threshold)):
    
#     subject_experiment_orders = random_chunks(nwp_indices, len(nwp_indices)//n_orders, shuffle=True)
    
#     print ('starting')
# #     test_orders = subject_experiment_orders.copy()
#     idxs = get_consecutive_list_idxs(subject_experiment_orders, consecutive_spacing=consecutive_spacing)
#     subject_experiment_orders = sort_consecutive_constraint(subject_experiment_orders, consecutive_spacing=consecutive_spacing)
    
#     order_distributions = [get_quadrant_distributions(df_divide, order).to_numpy() for order in subject_experiment_orders]
    
#     # sometimes the randomized order makes a quadrant be dropped --> reset and try again
#     if not all([order.shape[-1] == 4 for order in order_distributions]):
#         order_distributions = np.zeros((n_orders, 4))
# # # Test again once we have completed resorting
# # idxs = get_consecutive_list_idxs(subject_experiment_orders, consecutive_spacing=p.consecutive_spacing)
# # print (f'Lists violating consecutive index constraint: {100*(len(idxs))/len(subject_experiment_orders)}%')

# # uniq, counts = np.unique(subject_experiment_orders, return_counts=True)
# # print (f'All counts per word: {np.sum(counts >= p.n_participants_per_item) / len(counts)*100}%')

# # counts = Counter(tuple(o) for o in subject_experiment_orders)
# # unique_orders = np.sum([v for k, v in counts.items()]) / len(counts)

# # print (f'Unique orders: {unique_orders*100}%')

# # orders_meeting_consecutive = np.sum([check_consecutive_spacing(order, consecutive_spacing=p.consecutive_spacing) for order in subject_experiment_orders]) / len(subject_experiment_orders)
# # print (f'Consecutive constraint: {orders_meeting_consecutive*100}%'

[   2    4    5    8   11   15   19   22   23   24   27   31   35   37
   38   39   41   43   44   45   48   51   53   57   58   62   67   68
   72   75   77   78   82   86   91   94   98   99  103  106  108  110
  117  119  121  122  125  128  129  133  138  141  142  145  148  151
  154  160  161  162  169  173  178  179  181  183  184  185  186  189
  190  191  193  195  196  197  199  201  203  205  210  214  216  218
  220  221  223  224  230  233  239  240  242  246  247  248  250  251
  254  255  257  262  264  265  268  269  270  273  274  278  281  283
  285  286  290  294  298  299  304  306  308  309  311  312  315  320
  321  325  327  334  339  340  342  345  347  351  355  359  362  365
  368  370  371  374  378  380  382  383  385  386  387  390  391  394
  398  401  402  405  409  410  414  417  420  421  425  426  431  432
  434  435  442  447  451  457  460  462  464  465  467  470  473  476
  478  480  482  486  488  491  496  497  498  499  504  506  508  511
  512 

  df_divide.loc[low_entropy_idxs, 'entropy_group'] = 'low'
  df_divide.loc[low_accuracy_idxs, 'accuracy_group'] = 'low'


In [718]:
N_ORDERS = 4
CONSECUTIVE_SPACING = 10 

# get baseline distribution of quadrants --> deviation threshold is the amount of error
# we tolerate from the distribution in each order
quadrant_distribution = get_quadrant_distributions(df_divide, df_selected.index).to_numpy()
deviation_threshold = 0.05
order_distributions = np.zeros((n_orders, 4))

# find indices for presentation and set number of items each subject sees
nwp_indices = sorted(df_selected.index)

# Find lists with consecutive items violating our constraint
while not (np.allclose(quadrant_distribution, order_distributions, atol=deviation_threshold)):

    # randomly chunk all indices into N_ORDERS
    subject_experiment_orders = random_chunks(nwp_indices, len(nwp_indices)//N_ORDERS, shuffle=True)
    subject_experiment_orders = sort_consecutive_constraint(subject_experiment_orders, consecutive_spacing=CONSECUTIVE_SPACING)
    
    # now find distribution of each order
    order_distributions = [get_quadrant_distributions(df_divide, order).to_numpy() for order in subject_experiment_orders]

    # sometimes the randomized order makes a quadrant be dropped --> reset and try again
    if not all([order.shape[-1] == 4 for order in order_distributions]):
        order_distributions = np.zeros((N_ORDERS, 4))

idxs = get_consecutive_list_idxs(subject_experiment_orders, consecutive_spacing=CONSECUTIVE_SPACING)
# now update df_preproc with our selected indices --> write out
selected_idxs = df_selected.index
df_final = df_preproc.copy()
df_final.loc[selected_idxs, ['entropy_group', 'accuracy_group']] = df_selected[['entropy_group', 'accuracy_group']]

df_final


# for each of the experiment orders
# for i, order in enumerate(subject_experiment_orders):

Starting pass #1
Number of lists w/ violation: 3
Starting pass #2
Number of lists w/ violation: 2
Starting pass #3
Number of lists w/ violation: 2
Starting pass #4
Number of lists w/ violation: 0
Starting pass #1
Number of lists w/ violation: 3
Starting pass #2
Number of lists w/ violation: 2
Starting pass #3
Number of lists w/ violation: 1
Starting pass #4
Number of lists w/ violation: 0
Starting pass #1
Number of lists w/ violation: 3
Starting pass #2
Number of lists w/ violation: 2
Starting pass #3
Number of lists w/ violation: 1
Starting pass #4
Number of lists w/ violation: 0
Starting pass #1
Number of lists w/ violation: 0
Starting pass #1
Number of lists w/ violation: 2
Starting pass #2
Number of lists w/ violation: 3
Starting pass #3
Number of lists w/ violation: 2
Starting pass #4
Number of lists w/ violation: 2
Starting pass #5
Number of lists w/ violation: 0
Starting pass #1
Number of lists w/ violation: 3
Starting pass #2
Number of lists w/ violation: 2
Starting pass #3
Num

Number of lists w/ violation: 2
Starting pass #2
Number of lists w/ violation: 2
Starting pass #3
Number of lists w/ violation: 2
Starting pass #4
Number of lists w/ violation: 2
Starting pass #5
Number of lists w/ violation: 1
Starting pass #6
Number of lists w/ violation: 1
Starting pass #7
Number of lists w/ violation: 0
Starting pass #1
Number of lists w/ violation: 3
Starting pass #2
Number of lists w/ violation: 3
Starting pass #3
Number of lists w/ violation: 3
Starting pass #4
Number of lists w/ violation: 2
Starting pass #5
Number of lists w/ violation: 1
Starting pass #6
Number of lists w/ violation: 1
Starting pass #7
Number of lists w/ violation: 2
Starting pass #8
Number of lists w/ violation: 1
Starting pass #9
Number of lists w/ violation: 1
Starting pass #10
Number of lists w/ violation: 1
Starting pass #11
Number of lists w/ violation: 0
Starting pass #1
Number of lists w/ violation: 3
Starting pass #2
Number of lists w/ violation: 2
Starting pass #3
Number of lists w/

Unnamed: 0,Word_Written,Case,POS,POS_Definition,Punctuation,Stop_Word,Word_Vocab,Onset,Offset,Duration,Named_Entity,NWP_Candidate,entropy_group,accuracy_group
0,I,success,PRP,"pronoun, personal",,True,I,0.012472,0.127781,0.115309,False,False,,
1,reached,success,VBD,"verb, past tense",,False,reached,0.127781,0.493847,0.366067,False,False,,
2,over,success,RB,adverb,,True,over,0.493847,0.960317,0.466470,False,False,,
3,and,success,CC,"conjunction, coordinating",,True,and,1.539002,1.661162,0.122160,False,False,,
4,secretly,success,RB,adverb,,False,secretly,1.664915,2.377098,0.712183,False,True,high,low
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1822,I,success,PRP,"pronoun, personal",,True,I,590.470522,590.611418,0.140897,False,False,,
1823,still,success,RB,adverb,,False,still,590.611418,590.999320,0.387902,False,False,,
1824,miss,success,VBP,"verb, present tense, not 3rd person singular",,False,miss,590.999320,591.188889,0.189569,False,False,,
1825,the,success,DT,determiner,,True,the,591.188889,591.265763,0.076874,False,False,,


In [736]:
test = [[12, 64, 81, 100, 295, 317, 341, 368, 380, 404, 431, 470, 483, 504, 517, 572, 598, 609, 666, 681, 704, 716, 741, 774, 907, 981, 1020, 1035, 1068, 1094, 1141, 1174, 1191, 1211, 1227, 1239, 1269, 1281, 1318, 1336, 1376, 1388, 1408, 1430, 1445, 1462, 1473, 1485, 1525], [19, 75, 148, 190, 235, 255, 305, 331, 354, 390, 409, 451, 476, 496, 507, 540, 558, 588, 626, 648, 670, 687, 710, 731, 813, 842, 856, 874, 919, 948, 994, 1009, 1043, 1055, 1112, 1136, 1182, 1233, 1246, 1273, 1300, 1311, 1325, 1380, 1394, 1416, 1436, 1493, 1516], [9, 49, 85, 116, 154, 201, 221, 244, 272, 302, 383, 401, 421, 446, 462, 528, 576, 673, 693, 707, 724, 747, 779, 799, 833, 845, 903, 914, 935, 970, 986, 1004, 1125, 1147, 1178, 1202, 1230, 1251, 1264, 1284, 1369, 1421, 1433, 1454, 1476, 1488, 1507, 1539], [6, 46, 90, 111, 143, 163, 212, 239, 263, 299, 346, 361, 373, 395, 417, 441, 500, 547, 567, 580, 603, 645, 663, 677, 713, 727, 744, 828, 848, 868, 941, 1017, 1062, 1075, 1100, 1133, 1187, 1217, 1236, 1291, 1303, 1328, 1343, 1358, 1412, 1440, 1479, 1496]]

In [748]:
len(test)

4

In [735]:
subject_experiment_orders

[[4,
  30,
  51,
  70,
  83,
  106,
  130,
  144,
  165,
  202,
  244,
  257,
  276,
  344,
  367,
  406,
  435,
  453,
  490,
  539,
  552,
  574,
  605,
  639,
  661,
  699,
  711,
  748,
  835,
  848,
  863,
  881,
  907,
  950,
  966,
  1021,
  1065,
  1167,
  1179,
  1239,
  1299,
  1341,
  1397,
  1409,
  1462,
  1476,
  1513,
  1542,
  1565,
  1590,
  1678,
  1745,
  1758,
  1817],
 [64,
  114,
  136,
  169,
  208,
  239,
  261,
  288,
  341,
  355,
  372,
  389,
  413,
  432,
  484,
  531,
  548,
  565,
  589,
  621,
  650,
  664,
  675,
  691,
  826,
  852,
  889,
  924,
  959,
  978,
  994,
  1037,
  1062,
  1076,
  1099,
  1130,
  1151,
  1190,
  1203,
  1222,
  1259,
  1318,
  1389,
  1451,
  1481,
  1497,
  1551,
  1601,
  1634,
  1658,
  1673,
  1732,
  1772,
  1783],
 [23,
  38,
  78,
  99,
  121,
  151,
  198,
  251,
  272,
  319,
  337,
  382,
  425,
  445,
  469,
  504,
  525,
  544,
  585,
  601,
  632,
  670,
  696,
  736,
  756,
  769,
  813,
  830,
  842,
  896,
 

In [738]:
np.tile(test, (50, 1)).shape[0]

ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (4,) + inhomogeneous part.

In [None]:
# get the items as a dictionary for passing out to aggregate
quadrant_dist = {f'{labels[0]}-entropy_{labels[1]}-accuracy': round(len(df)/len(df_divide), 2) 
             for labels, df in df_divide.groupby(['entropy_group', 'accuracy_group'])}

df_quadrants = pd.DataFrame.from_dict(quadrant_dist, orient='index').T

In [343]:
order_distributions

[array([[0.15, 0.3 , 0.4 , 0.14]]),
 array([[0.13, 0.4 , 0.33, 0.14]]),
 array([[0.15, 0.34, 0.38, 0.12]]),
 array([[0.13, 0.33, 0.36, 0.18]])]

In [None]:
def create_n_random_orders(n_orders, n_participants_per_item, consecutive_spacing)

In [125]:
get_quadrant_distributions(df_divide, order).to_numpy().min()

0.15

In [207]:
test_orders

[[12,
  89,
  95,
  158,
  169,
  202,
  282,
  355,
  390,
  408,
  425,
  463,
  496,
  504,
  535,
  582,
  597,
  608,
  691,
  696,
  742,
  822,
  904,
  997,
  1024,
  1032,
  1065,
  1081,
  1105,
  1110,
  1222,
  1227,
  1235,
  1245,
  1282,
  1302,
  1307,
  1345,
  1447,
  1486,
  1565,
  1617,
  1637,
  1722,
  1798],
 [17,
  73,
  99,
  114,
  204,
  253,
  261,
  323,
  329,
  351,
  531,
  605,
  610,
  646,
  661,
  701,
  738,
  762,
  795,
  802,
  813,
  839,
  896,
  902,
  955,
  1103,
  1136,
  1162,
  1259,
  1290,
  1371,
  1411,
  1443,
  1448,
  1459,
  1484,
  1513,
  1551,
  1607,
  1701,
  1710,
  1735,
  1747,
  1754],
 [144,
  151,
  187,
  198,
  217,
  251,
  292,
  324,
  376,
  475,
  536,
  548,
  578,
  606,
  681,
  713,
  749,
  786,
  806,
  873,
  933,
  940,
  980,
  987,
  1044,
  1087,
  1102,
  1154,
  1172,
  1231,
  1255,
  1278,
  1327,
  1354,
  1367,
  1409,
  1438,
  1451,
  1571,
  1579,
  1610,
  1629,
  1655,
  1765],
 [11,
  23,


In [245]:
order_distributions

[array([[0.13, 0.28, 0.38, 0.2 ]]), array([[0.2 , 0.35, 0.3 , 0.15]])]

In [257]:
order_distributions

[array([[0.17, 0.35, 0.29, 0.19]]), array([[0.17, 0.28, 0.39, 0.16]])]

In [295]:
np.diff(test_orders[0]).shape

(91,)