In [23]:
import sys
sys.path.append('../../pybeh')
import warnings
warnings.simplefilter('ignore')
import pandas as pd
import numpy as np
import cmlreaders as cml
from SimulatedSubjectData import *
from pandas_to_pybeh import pd_crp, get_all_matrices
import matplotlib.pyplot as plt
import pickle
pd.set_option("display.max_columns", None)

In [24]:
# Analysis-wide settings used by the cells below.
value_acc = 0.6
list_len = 15      # passed to the CRP / first-recall helpers as the list length
num_lists = 1000

# Experiment to pull from the data index, and the subjects kept for analysis.
exp = 'CourierReinstate1'
# LTP564 through LTP605 inclusive, without LTP570 and LTP582 (40 subjects).
subjects = [f'LTP{num}' for num in range(564, 606) if num not in (570, 582)]

In [25]:
# Rows of the 'ltp' data index that belong to the experiment named by `exp`.
df = cml.get_data_index('ltp', rootdir='/').query("experiment == @exp")

# Load the task events of every indexed session. The frames are collected in a
# list and concatenated once: calling pd.concat inside the loop re-copies all
# previously loaded rows on every iteration.
session_evs = []
for i, row in df.iterrows():
    reader = cml.CMLReader(subject=row['subject'], experiment=row['experiment'], session=row['session'])
    evs = reader.load('task_events')
    session_evs.append(evs)

# Fail with a clear message instead of an AttributeError on None further down.
if not session_evs:
    raise ValueError(f"no sessions found in the data index for experiment {exp!r}")
full_evs = pd.concat(session_evs, ignore_index=True)

# Keep only the subjects selected for this analysis.
full_evs = full_evs.query("subject in @subjects")

# Drop every event whose item is one of these names.
# NOTE(review): the reason these eight items are excluded is not recorded here - TODO document.
excluded_items = ["AMPLIFIER", "APPLE", "AXE", "BASKETBALL_HOOP",
                  "DOOR", "IRONING_BOARD", "SHOVEL", "STOVE"]
full_evs = full_evs[~full_evs['item'].isin(excluded_items)]

In [26]:
# Load the pickled word pool from words.pkl (resolved against the current working directory).
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so this file
# must come from a trusted source.
# The structure of `wordpool` is not visible from this cell - TODO document it.
with open("words.pkl", "rb") as f:
    wordpool = pickle.load(f)

In [27]:
# # Remove/replace rows with NaN as itemno

# # 1) Define which recall-of-word types to keep, then keep only what exists
# rec_word_types_all = ['REC_WORD', 'REC_WORD_VV']
# present_types = set(full_evs['type'].unique())
# present_rec_types = [t for t in rec_word_types_all if t in present_types]

# need_types = ['WORD'] + present_rec_types
# clean_evs = full_evs[full_evs['type'].isin(need_types)].copy()

# # 2) Normalize recall variants to 'REC_WORD' so your function can use a single rec_type
# for t in present_rec_types:
#     if t != 'REC_WORD':
#         clean_evs.loc[clean_evs['type'] == t, 'type'] = 'REC_WORD'

# # 3) Drop bad WORD rows; fill recall NaNs with sentinel 0
# is_word = clean_evs['type'].eq('WORD')
# is_rec  = clean_evs['type'].eq('REC_WORD')

# dropped_word = (is_word & clean_evs['itemno'].isna()).sum()
# temporal_evs = clean_evs[~(is_word & clean_evs['itemno'].isna())].copy()
# filled_rec = temporal_evs.loc[is_rec, 'itemno'].isna().sum()
# temporal_evs.loc[is_rec, 'itemno'] = temporal_evs.loc[is_rec, 'itemno'].fillna(0)

# # 4) Safe to cast now
# temporal_evs['itemno'] = temporal_evs['itemno'].astype('int64')
# temporal_evs

In [28]:
def compute_recall_rate(data):
    """Return the fraction of presented words that were recalled.

    Parameters
    ----------
    data : pandas.DataFrame
        Events with a 'type' column and a 'recalled' column. Only rows whose
        type is 'WORD' are used.

    Returns
    -------
    float
        Sum of 'recalled' over the WORD rows divided by the number of WORD
        rows, or NaN when there are no WORD rows (instead of dividing by zero).
    """
    word_evs = data[data['type'] == 'WORD']
    n_words = len(word_evs)
    if n_words == 0:
        # No presentations: the rate is undefined rather than an error.
        return float('nan')
    return word_evs['recalled'].sum() / n_words

def compute_first_recall(data, list_length):
    """Return the probability of first recall for each serial position.

    For every (session, trial) group the first 'REC_WORD' row, in the row
    order of `data`, is taken as the first recall. First recalls with a
    negative serialpos are ignored, and serial positions outside
    1..list_length are dropped when the counts are aligned to that range.

    Parameters
    ----------
    data : pandas.DataFrame
        Events with 'type', 'session', 'trial', 'serialpos' and 'recalled'
        columns. The frame is not modified.
    list_length : int
        Number of serial positions; the result has this many entries.

    Returns
    -------
    numpy.ndarray
        Float array of length `list_length`; entry k is the share of counted
        first recalls that came from serial position k + 1. All NaN when no
        first recall is counted.
    """
    rec_evs = data[data['type'] == 'REC_WORD']
    # assign() returns a new frame, so no column is written onto a filtered
    # slice of the caller's data (the source of SettingWithCopyWarning).
    rec_evs = rec_evs.assign(pos=rec_evs.groupby(['session', 'trial']).cumcount())
    first_recalls = rec_evs.query('pos == 0 and serialpos >= 0')

    # count() tallies the non-null 'recalled' values per serial position.
    counts = (
        first_recalls.groupby('serialpos')['recalled']
        .count()
        .reindex(range(1, list_length + 1), fill_value=0)
        .to_numpy(dtype=float)
    )

    n_lists = counts.sum()
    if n_lists == 0:
        # Nothing to normalise by: return NaN explicitly instead of computing 0 / 0.
        return np.full(list_length, np.nan)
    return counts / n_lists

# not finalized; there is an error somewhere since the resulting array is filled with nan values
def compute_lag_crp(data, list_length, itemno_column='itemno', list_index=['subject', 'session', 'trial'],
                    pres_type='WORD', rec_type='REC_WORD', type_column='type', max_lag=None):
    """Lay the output of `pd_crp` out as a fixed-length lag-CRP array.

    Parameters
    ----------
    data : pandas.DataFrame
        Events passed through unchanged to `pd_crp`.
    list_length : int
        List length; the result has ``2 * list_length - 1`` entries.
    itemno_column, list_index, pres_type, rec_type, type_column
        Forwarded unchanged to `pd_crp`.
    max_lag : int, optional
        Upper bound on the lag requested from `pd_crp`. The requested lag is
        never larger than ``list_length - 1``.

    Returns
    -------
    numpy.ndarray
        Float array in which index ``list_length - 1 + lag`` holds the 'prob'
        value that `pd_crp` reported for that lag. Lags that `pd_crp` did not
        report stay 0.0, and the lag-0 entry is always 0.0.

    Notes
    -----
    NOTE(review): the array starts as zeros and is only overwritten with
    values from the 'prob' column, so any NaN in the result comes from
    `pd_crp` itself. Presumably that traces back to NaN item numbers in
    `data` - TODO confirm.
    """
    lag_num = list_length - 1
    if max_lag is not None:
        lag_num = min(lag_num, int(max_lag))

    crp_df = pd_crp(
        data,
        lag_num=lag_num,
        itemno_column=itemno_column,
        list_index=list_index,
        pres_type=pres_type,
        rec_type=rec_type,
        type_column=type_column
    )

    full_length = 2 * list_length - 1
    center = list_length - 1  # array index of lag 0
    crp_arr = np.zeros(full_length, dtype=float)

    for lag, p in zip(crp_df['lag'].to_numpy(), crp_df['prob'].to_numpy()):
        idx = center + int(lag)
        # Ignore lags that fall outside the fixed-length output.
        if 0 <= idx < full_length:
            crp_arr[idx] = float(p)

    # A transition of lag 0 is not meaningful, so that slot is pinned to 0.
    crp_arr[center] = 0.0
    return crp_arr

In [80]:
def compute_lag_crp_single_subject_array(
    data, 
    list_len
):
    """Compute a lag-CRP array from one subject's events.

    For every pair of consecutive recalls in a trial, the lag between their
    serial positions is counted as an actual transition, and every lag to a
    serial position not yet recalled is counted as a possible transition. The
    CRP at a lag is actual / possible, pooled over all sessions and trials.

    Parameters
    ----------
    data : pandas.DataFrame
        Events with 'session', 'type', 'trial', 'item', 'serialpos' and
        'rectime' columns. Rows of type 'WORD' are presentations and rows of
        type 'REC_WORD' are recalls. Events are grouped by session only, so
        the frame should hold a single subject.
    list_len : int
        List length; serial positions 1..list_len are the possible targets.

    Returns
    -------
    numpy.ndarray
        Float array of length ``2 * list_len - 1``; index
        ``list_len - 1 + lag`` holds the CRP for that lag. Lags with no
        possible transition are NaN and the lag-0 entry is 0.0.

    Notes
    -----
    * Recalls on trial -999 are skipped.
    * Within a trial, recalls are ordered by 'rectime' and only the first
      recall of each item is kept.
    * Recalls of items not presented on that trial are removed before lags are
      taken, so the list items either side of a removed recall count as
      consecutive.
    * Serial positions are looked up per trial, so an item presented on more
      than one list of a session keeps the position of the list it was
      recalled from.
    * Lags outside ``-(list_len - 1) .. list_len - 1`` are ignored.
    * Sessions without events and trials with fewer than two recalls are
      reported with print() and skipped.
    """
    center = list_len - 1
    lags = range(-center, center + 1)
    actual = {lag: 0 for lag in lags}
    possible = {lag: 0 for lag in lags}
    all_positions = set(range(1, list_len + 1))

    for session_id, session_data in data.groupby('session'):
        recalls = session_data[session_data['type'] == 'REC_WORD']
        words = session_data[session_data['type'] == 'WORD']
        if recalls.empty or words.empty:
            print(f"session {session_id} has no events")
            continue
        recalls = recalls[recalls['trial'] != -999]

        for trial in recalls['trial'].unique():
            trial_words = words[words['trial'] == trial]
            # Built per trial: a session-wide map would give an item that is
            # presented on several lists the position from the last of them.
            word_to_pos = dict(zip(trial_words['item'], trial_words['serialpos']))
            trial_recalls = (recalls[recalls['trial'] == trial]
                             .sort_values('rectime')
                             .drop_duplicates('item'))

            if len(trial_recalls) < 2:
                print(f"session {session_id}, trial {trial} doesn't have enough events")
                continue

            # Keep only recalls of items presented on this trial.
            trial_recalls = trial_recalls[trial_recalls['item'].isin(trial_words['item'])]
            recall_pos = [word_to_pos[w] for w in trial_recalls['item']]

            recalled = set()
            for cur, nxt in zip(recall_pos[:-1], recall_pos[1:]):
                recalled.add(cur)
                lag = nxt - cur
                # Both bounds are inclusive keys of the counters; the previous
                # upper bound admitted lag == list_len, which is not a key.
                if lag != 0 and -center <= lag <= center:
                    actual[lag] += 1
                for pos in all_positions - recalled:
                    pl = pos - cur
                    if pl != 0 and -center <= pl <= center:
                        possible[pl] += 1

    # Build the CRP array; lags that were never possible stay NaN.
    crp = np.full(2 * list_len - 1, np.nan)
    for lag in lags:
        if possible[lag] > 0:
            crp[center + lag] = actual[lag] / possible[lag]
    crp[center] = 0.0
    return crp


In [82]:
# testing functions
df595 = full_evs[full_evs['subject'] == 'LTP595']
# compute_recall_rate(df595)
# compute_first_recall(df595, list_len)
crp = compute_lag_crp_single_subject_array(df595, list_len)

In [85]:
# Display the lag-CRP array; index list_len - 1 is lag 0, which the helper fixes at 0.
crp

array([0.06382979, 0.05882353, 0.04424779, 0.02898551, 0.06470588,
       0.04102564, 0.05479452, 0.04979253, 0.07806691, 0.04895105,
       0.06644518, 0.08196721, 0.14195584, 0.3250774 , 0.        ,
       0.45721271, 0.20478723, 0.10670732, 0.06713781, 0.06477733,
       0.06161137, 0.0483871 , 0.03870968, 0.02459016, 0.04081633,
       0.05263158, 0.04081633, 0.        , 0.        ])

In [53]:
# Boolean mask marking the recall ('REC_WORD') rows of the LTP595 events.
df595['type'].eq('REC_WORD')

152417    False
152418    False
152419    False
152420    False
152421    False
          ...  
158457    False
158458     True
158459     True
158460    False
158461    False
Name: type, Length: 5721, dtype: bool

In [None]:
# Per-subject summary statistics, keyed by subject ID.
sub_parameters = {}

for sub in subjects:
    df_sub = full_evs[full_evs['subject'] == sub]
    recall_rate = compute_recall_rate(df_sub)
    first_recall = compute_first_recall(df_sub, list_len)
    # compute_lag_crp (the pd_crp-based version) is marked "not finalized" in this
    # notebook because its output is all NaN, so the lag-CRP is taken from the
    # per-subject implementation instead.
    lag_crp = compute_lag_crp_single_subject_array(df_sub, list_len)
    sub_parameters[sub] = {
        'recall_rate': recall_rate,
        'first_recall': first_recall,
        'lag_crp': lag_crp
    }

sub_parameters