In [83]:
import json
import os
from glob import glob

import numpy as np
import pandas as pd
import scipy.io as sio
import statsmodels.api as sm
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm

In [40]:
# Experiment layout constants.
ll = 24                        # list length; also the number of output positions per list
nsess = 24                     # sessions per subject
lists_per_sess = 24            # lists per session (previously an unnamed literal)
nlist = lists_per_sess * nsess # total lists per subject
outputs = np.arange(ll)        # output-position labels 0..ll-1 (DataFrame columns)
lists = np.arange(nlist)       # list labels 0..nlist-1 (DataFrame index)
irt_lags = 4                   # default number of previous IRTs added as predictors
# Word-to-word semantic similarity matrix ('w2v' variable of a .mat file),
# looked up by sem_sim with 1-based word numbers. The default location is
# the original lab path; set the W2V_PATH environment variable to load it
# from elsewhere.
w2v_path = os.environ.get(
    'W2V_PATH', '/home1/shai.goldman/IRT_git/Scripts/Resources/w2v.mat')
w2v = sio.loadmat(w2v_path)['w2v']

In [17]:
def load_data(file_name):
    """Read one subject's behavioral JSON file into per-field pandas objects.

    Each top-level field of the JSON becomes one entry of the returned dict:

    * fields whose rows stack into a 2-D array become a DataFrame with one row
      per list (index ``lists``) and one column per output position
      (columns ``outputs``); rows are right-padded to ``ll`` columns by
      repeating each row's last value (``np.pad`` with ``mode='edge'``);
    * all other fields become a Series indexed by ``lists``.

    The 'serialpos' matrix is stored under the older name 'recalls'.
    """
    raw = pd.read_json(file_name)

    data = {}
    for key, column in raw.items():
        values = np.array(list(column.values))
        if values.ndim <= 1:
            # one value per list
            data[key] = pd.Series(values, index=lists)
            continue
        # list x output matrix: pad every row out to ll outputs
        n_missing = ll - values.shape[1]
        padded = np.pad(values, [(0, 0), (0, n_missing)], mode='edge')
        data[key] = pd.DataFrame(padded, columns=outputs, index=lists)

    # keep the legacy 'recalls' name for the serial-position matrix
    data['recalls'] = data.pop('serialpos')
    return data

In [18]:
def get_lags(data):
    """Serial-position lag between consecutive recalls within each list.

    Parameters
    ----------
    data : dict
        Must contain 'recalls': a DataFrame of serial positions with one row
        per list and columns in output order.

    Returns
    -------
    pandas.DataFrame
        Same labels as ``data['recalls']``; entry i is
        ``recalls[i] - recalls[i-1]``. The first output has no predecessor
        and is NaN, so the frame is float.
    """
    recalls = data['recalls']
    # shift along the columns so each output lines up with the one before it
    return recalls - recalls.shift(1, axis=1)

In [19]:
def strech_intrus(data):
    """Zero each list's recall record from its first intrusion onward.

    An intrusion is a ``-1`` in ``data['recalls']``. In every list (row) that
    has one, the intrusion's output position and all later positions are set
    to 0 in both ``data['recalls']`` and ``data['times']``. Lists without an
    intrusion are unchanged.

    Parameters
    ----------
    data : dict
        Must contain 'recalls' and 'times' DataFrames of the same shape, one
        row per list and columns in output order.

    Returns
    -------
    dict
        A shallow copy of ``data`` whose 'recalls' and 'times' are new
        frames. The caller's frames are not modified; the previous version
        wrote into them through a shallow ``dict.copy()``.
    """
    data = dict(data)

    recalls = data['recalls']
    # True at the first -1 of each row and at every position to its right
    from_first_intrus = np.logical_or.accumulate(recalls.to_numpy() == -1,
                                                 axis=1)

    data['recalls'] = recalls.mask(from_first_intrus, 0)
    data['times'] = data['times'].mask(from_first_intrus, 0)

    return data

In [20]:
def get_irts(data):
    """Inter-response times: each recall time minus the previous recall time.

    Parameters
    ----------
    data : dict
        Must contain 'times': a DataFrame of recall times with one row per
        list and columns in output order.

    Returns
    -------
    pandas.DataFrame
        Float frame with the labels of ``data['times']``. The first output has
        no previous recall and is NaN. Non-positive differences (for example
        where times were zeroed by ``strech_intrus``) are also set to NaN.
    """
    times = data['times'].astype(float)
    irts = times - times.shift(1, axis=1)
    # keep strictly positive IRTs; everything else becomes NaN
    return irts.where(irts > 0)

In [62]:
def add_prev_irts(data, lags=irt_lags):
    """Add lagged IRT frames 'irt-1' ... 'irt-<lags>' to ``data`` in place.

    ``data['irt-k']`` at output position i holds ``data['irt']`` at output
    position i-k. The first k output positions have no such IRT and are NaN.

    Parameters
    ----------
    data : dict
        Must contain 'irt': a DataFrame with one row per list and columns in
        output order. It is modified in place and nothing is returned.
    lags : int
        Number of lagged IRT frames to add.
    """
    irts = data['irt']
    for lag in range(1, lags + 1):
        # shifting right by `lag` columns pads the first `lag` outputs with NaN
        data[f'irt-{lag}'] = irts.shift(lag, axis=1)

In [63]:
def sem_sim(a, b, sims=None):
    """Semantic similarity between two words given their 1-based word numbers.

    Parameters
    ----------
    a, b : int or int-valued float
        Word numbers as stored in 'rec_nos' (numbered from 1). A value that is
        not positive (``<= 0``) or is NaN means "no valid word".
    sims : 2-D array, optional
        Word-by-word similarity matrix indexed from 0. Defaults to the
        module-level ``w2v`` matrix.

    Returns
    -------
    The matrix entry ``sims[a-1, b-1]``, or NaN when either word number is
    not positive or is NaN.
    """
    # `not (x > 0)` is also true for NaN, which a plain `x <= 0` test let through
    if not (a > 0 and b > 0):
        return np.nan
    if sims is None:
        sims = w2v
    # The -1 matters: word numbers started at 1 when the lab used MATLAB, but
    # numpy indexes from 0. int() also accepts float-typed word numbers,
    # which numpy refuses as array indices.
    return sims[int(a) - 1, int(b) - 1]

In [64]:
def get_sems(data):
    """Semantic similarity between each recalled word and the one before it.

    Parameters
    ----------
    data : dict
        Must contain 'rec_nos': a DataFrame of recalled word numbers with one
        row per list and integer output-position columns 0, 1, ...

    Returns
    -------
    pandas.DataFrame
        Same labels as ``data['rec_nos']``. Entry i is
        ``sem_sim(rec_nos[i-1], rec_nos[i])``; the first output has no
        predecessor and is NaN.
    """
    rec_nos = data['rec_nos']
    # Series.iteritems() was removed in pandas 2.0; items() is the same iterator
    sims = [[sem_sim(row.loc[out - 1], no) if out > 0 else np.nan
             for out, no in row.items()]
            for _, row in rec_nos.iterrows()]
    return pd.DataFrame(sims, index=rec_nos.index, columns=rec_nos.columns)

In [65]:
def detailed_data(data, irt_lags=irt_lags):
    """Derive the per-recall variables used by the IRT model.

    Parameters
    ----------
    data : dict
        Output of ``load_data``. It must contain 'recalls', 'times' and
        'rec_nos'. It is not modified.
    irt_lags : int
        How many previous IRTs to add as 'irt-1' ... 'irt-<irt_lags>'.

    Returns
    -------
    dict
        A new dict holding copies of the input entries plus 'lag', 'irt', the
        'irt-k' frames, 'total_recalls' and 'sem'. In it, repeated recalls are
        recoded as intrusions (-1), and everything from the first intrusion of
        each list onward is zeroed in 'recalls' and 'times'.
    """
    # copy every frame, not just the dict: a shallow dict copy let the
    # in-place edits below leak back into the caller's data
    data = {key: val.copy() for key, val in data.items()}

    data['lag'] = get_lags(data)

    # Set all repeats as intrusions. lag == 0 only catches an entry equal to
    # the one right before it (including runs of trailing padding); a serial
    # position (> 0) that already occurred anywhere earlier in the same list
    # is a repeat too.
    recalls = data['recalls']
    repeated = recalls.apply(lambda row: row.duplicated() & (row > 0), axis=1)
    data['recalls'] = recalls.mask(repeated | (data['lag'] == 0), -1)

    # remove all recalls from the first intrusion onward
    data = strech_intrus(data)

    data['irt'] = get_irts(data)
    add_prev_irts(data, lags=irt_lags)

    # number of positive serial positions (valid recalls) per list
    data['total_recalls'] = (data['recalls'] > 0).sum(axis=1)

    data['sem'] = get_sems(data)

    return data

In [66]:
def prep_data_for_ols(data):
    """Flatten the per-list data into one row per output position for OLS.

    Parameters
    ----------
    data : dict
        Output of ``detailed_data``: DataFrames with one row per list and one
        column per output position, and Series with one value per list. It
        must contain the bookkeeping keys dropped below plus 'irt', 'lag',
        'sem' and 'total_recalls'.

    Returns
    -------
    pandas.DataFrame
        One row per (list, output position) with no missing values and no
        zero lags. It holds the remaining variables, with:

        - 'output_pos': ``1 / (n_outputs - output position)``
        - 'total_recalls': divided by the number of output positions
        - 'lag': ``ln(|lag|)``
        - 'lag_sem': the ``lag * sem`` interaction
    """
    # Take the lists x outputs layout from the first 2-D entry; all 2-D
    # entries are expected to share it.
    template = next(val for val in data.values() if val.ndim > 1)
    n_lists, n_outputs = template.shape

    #-----flatten data----#
    flat = {}
    for key, val in data.items():
        if val.ndim > 1:
            # row-major: every output of list 0, then list 1, ...
            flat[key] = val.values.flatten()
        else:
            # one value per list, repeated for each of its output positions
            flat[key] = np.repeat(val.values, n_outputs)
    flat_data = pd.DataFrame(flat)

    # include output position as a variable (same row-major order)
    flat_data['output_pos'] = np.tile(template.columns.values, n_lists)

    #-----filter data----#
    # remove keys that won't go into the model
    flat_data = flat_data.drop(columns=['pres_words', 'pres_nos', 'rec_words',
                                        'rec_nos', 'recalled', 'times',
                                        'intrusions', 'subject', 'good_trial',
                                        'recalls'])

    # drop rows with any missing value (e.g. the first output, which has no
    # IRT or lag); unlike np.isnan this also works on non-numeric columns
    flat_data = flat_data.dropna()
    # a zero lag would become log(0) = -inf and break the regression
    flat_data = flat_data[flat_data['lag'] != 0]

    #----adjust some variables for the model----#
    # assign() returns a new frame, avoiding chained-assignment warnings on
    # the filtered frame
    log_lag = np.log(np.abs(flat_data['lag']))
    flat_data = flat_data.assign(
        # output position is inverted
        output_pos=(n_outputs - flat_data['output_pos'].astype(float)) ** -1,
        # total recalls is normalized by the number of output positions
        total_recalls=flat_data['total_recalls'] / n_outputs,
        # lag is taken as ln(|lag|)
        lag=log_lag,
        # lag x semantic-similarity interaction
        lag_sem=log_lag * flat_data['sem'],
    )

    return flat_data

In [67]:
def fit_model(flat_data):
    """Regress 'irt' on every other column of ``flat_data`` plus an intercept.

    Returns the fitted statsmodels OLS results object.
    """
    exog = sm.add_constant(flat_data)
    endog = exog.pop('irt')
    results = sm.OLS(endog, exog).fit()
    return results

In [68]:
def get_model(file_name, irt_lags=irt_lags):
    """Load one subject's file and fit the IRT OLS model.

    Runs load_data -> detailed_data (adding ``irt_lags`` previous IRTs as
    predictors) -> prep_data_for_ols -> fit_model and returns the fitted
    statsmodels results object.
    """
    detailed = detailed_data(load_data(file_name), irt_lags=irt_lags)
    return fit_model(prep_data_for_ols(detailed))

In [3]:
# Behavioral data for all ltpFR2 subjects; files flagged 'incomplete' are
# skipped. glob returns files in arbitrary (filesystem) order, so the list is
# sorted to make the subject order downstream reproducible.
path = '/data/eeg/scalp/ltp/ltpFR2/behavioral/data/'
files = sorted(f for f in glob(path + 'beh_data_LTP*.json')
               if 'incomplete' not in f)

In [None]:
# For every subject, compare models with 1..9 previous IRTs as predictors by
# AIC and BIC.
lags = range(1, 10)
max_lag = max(lags)
aics = []
bics = []
for file in tqdm(files):
    subj = file.split('LTP')[1].replace('.json', '')
    # Build the design matrix once, with the largest lag, so every candidate
    # model is fit on exactly the same rows (rows missing any 'irt-k',
    # k <= max_lag, are dropped for all models). AIC/BIC are only comparable
    # between models fit to the same observations; fitting each lag on its
    # own NaN-filtered data gave every model a different sample.
    flat_data = prep_data_for_ols(
        detailed_data(load_data(file), irt_lags=max_lag))
    models = []
    for lag in lags:
        # keep 'irt-1' ... 'irt-<lag>' and drop the higher lags
        extra = [f'irt-{k}' for k in range(lag + 1, max_lag + 1)]
        models.append(fit_model(flat_data.drop(columns=extra)))
    aics.append(pd.Series([m.aic for m in models], index=lags, name=subj))
    bics.append(pd.Series([m.bic for m in models], index=lags, name=subj))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

In [82]:
# Combine the per-subject Series into lags x subjects tables, one column per
# subject (named by the Series name). The default axis=0 would stack all
# subjects into one long Series with a duplicated lag index and no subject
# labels.
# NOTE(review): this rebinds aics/bics from list to DataFrame, so re-running
# this cell alone fails; re-run the fitting loop first.
aics = pd.concat(aics, axis=1)
bics = pd.concat(bics, axis=1)