In [10]:
import sys
sys.path.append('../../pybeh')
import warnings
warnings.simplefilter('ignore')
import pandas as pd
import numpy as np
import cmlreaders as cml
from SimulatedSubjectData import *
from pandas_to_pybeh import pd_crp, get_all_matrices
import matplotlib.pyplot as plt
import pickle
pd.set_option("display.max_columns", None)

In [11]:
list_len = 15
num_lists = 1000

exp = 'CourierReinstate1'
subjects = ['LTP564', 'LTP565', 'LTP566', 'LTP567', 'LTP568', 'LTP569', 'LTP571', 'LTP572', 'LTP573',
            'LTP574', 'LTP575', 'LTP576', 'LTP577', 'LTP578', 'LTP579', 'LTP580', 'LTP581', 'LTP583',
            'LTP584', 'LTP585', 'LTP586', 'LTP587', 'LTP588', 'LTP589', 'LTP590', 'LTP591', 'LTP592', 
            'LTP593', 'LTP594', 'LTP595', 'LTP596', 'LTP597', 'LTP598', 'LTP599', 'LTP600', 'LTP601', 
            'LTP602', 'LTP603', 'LTP604', 'LTP605']

In [12]:
df = cml.get_data_index('ltp', rootdir='/').query("experiment == @exp")

full_evs = None
for i, row in df.iterrows():
    reader = cml.CMLReader(subject=row['subject'], experiment=row['experiment'], session=row['session'])
    evs = reader.load('task_events')
    full_evs = evs if full_evs is None else pd.concat([full_evs, evs], ignore_index=True)
full_evs = full_evs.query("subject in @subjects")
full_evs = full_evs[(full_evs['item'] != "AMPLIFIER") & (full_evs['item'] != "APPLE") & (full_evs['item'] != "AXE") & 
                    (full_evs['item'] != "BASKETBALL_HOOP") & (full_evs['item'] != "DOOR") & (full_evs['item'] != "IRONING_BOARD") & 
                    (full_evs['item'] != "SHOVEL") & (full_evs['item'] != "STOVE")]

In [13]:
with open("words.pkl", "rb") as f:
    wordpool = pickle.load(f)

In [14]:
def compute_recall_rate(data):
    word_evs = data[data['type'] == 'WORD']
    return word_evs['recalled'].sum() / len(word_evs)

def compute_first_recall(data, list_len):
    rec_evs = data[data['type'] == 'REC_WORD']
    rec_evs['pos'] = rec_evs.groupby(['session', 'trial']).cumcount()
    first_recall_df = rec_evs.query('pos == 0 and serialpos >= 0')
    first_recall_df = first_recall_df.groupby(
        ['serialpos']).agg(
        {'recalled': 'count'}).reindex(range(1, list_len+1), fill_value=0)
    n_lists = first_recall_df['recalled'].sum()
    return first_recall_df['recalled'].to_numpy(dtype=float) / n_lists

def compute_lag_crp_single_subject_array(
    data, 
    list_len
):
    center = list_len - 1
    min_lag = -center
    max_lag = center + 1
    actual = {lag: 0 for lag in range(min_lag, max_lag)}
    possible = {lag: 0 for lag in range(min_lag, max_lag)}
    for session_id, session_data in data.groupby('session'):
        recalls = session_data[session_data.type == 'REC_WORD']
        # print(recalls)
        words = session_data[session_data.type == 'WORD']
        if recalls.empty or words.empty:
            print(f"session {session_id} has no events")
            continue
        # print(recalls.intruded)
        recalls = recalls[(recalls['trial'] != -999)]
        word_to_pos = dict(zip(words['item'], words['serialpos']))
        # print(word_to_pos)
        # print(recalls)
        for trial in recalls['trial'].unique():
            trial_words = words[words['trial'] == trial]['item'].tolist()
            trial_recalls = (recalls[recalls['trial'] == trial]
                             .sort_values('rectime')
                             .drop_duplicates('item'))
            
            if len(trial_recalls) < 2:
                print(f"session {session_id}, trial {trial} doesn't have enough events")
                continue
            trial_recalls = trial_recalls[trial_recalls['item'].isin(trial_words)]
            recall_pos = [word_to_pos[w] for w in trial_recalls['item']]
            # print(recall_pos)
            for i, cur in enumerate(recall_pos[:-1]):
                lag = recall_pos[i+1] - cur
                if min_lag <= lag <= max_lag and lag != 0:
                    actual[lag] += 1
                for pos in set(range(1, list_len+1)) - set(recall_pos[:i+1]):
                    pl = pos - cur
                    if min_lag <= pl <= max_lag and pl != 0:
                        possible[pl] += 1

    # build CRP array
    full_len = 2*list_len - 1
    crp = np.full(full_len, np.nan)
    center = list_len - 1
    for lag in range(min_lag, max_lag):
        idx = center + lag
        if 0 <= idx < full_len:
            crp[idx] = (actual[lag] / possible[lag]) if possible[lag] > 0 else np.nan
    crp[center] = 0.0
    return crp

In [15]:
df595 = full_evs[full_evs['subject'] == 'LTP595']
# compute_recall_rate(df595)
# compute_first_recall(df595, list_len)
crp595 = compute_lag_crp_single_subject_array(df595, list_len)
crp595

array([0.06382979, 0.05882353, 0.04424779, 0.02898551, 0.06470588,
       0.04102564, 0.05479452, 0.04979253, 0.07806691, 0.04895105,
       0.06644518, 0.08196721, 0.14195584, 0.3250774 , 0.        ,
       0.45721271, 0.20478723, 0.10670732, 0.06713781, 0.06477733,
       0.06161137, 0.0483871 , 0.03870968, 0.02459016, 0.04081633,
       0.05263158, 0.04081633, 0.        , 0.        ])

In [16]:
# CourierReinstate1 data
sub_parameters = {}

for sub in subjects:
    df_sub = full_evs[full_evs['subject'] == sub]
    recall_rate = compute_recall_rate(df_sub)
    first_recall = compute_first_recall(df_sub, list_len)
    lag_crp = compute_lag_crp_single_subject_array(df_sub, list_len)
    sub_parameters[sub] = {
        'recall_rate': recall_rate,
        'first_recall': first_recall,
        'lag_crp': lag_crp
    }

session 2, trial 7 doesn't have enough events


In [17]:
import optuna
import numpy as np

def crp_loss(simulated_crp, target_crp):
    mask = ~np.isnan(simulated_crp) & ~np.isnan(target_crp)
    return np.mean((simulated_crp[mask] - target_crp[mask])**2)

def make_objective(sub_id, sub_parameters, target_crp, list_len, num_lists, n_avg=3):
    def objective(trial):
        # Suggest lag_crp distribution for this subject
        lag_crp = np.array([
            trial.suggest_float(f"lag_{i}", 1e-6, 1.0)
            for i in range(len(target_crp))
        ])
        lag_crp /= lag_crp.sum()  # normalize
        print(lag_crp)
        # Pull subject parameters
        recall_rate = sub_parameters[sub_id]['recall_rate']
        first_recall = sub_parameters[sub_id]['first_recall']

        losses = []
        for _ in range(n_avg):
            sim = SimulatedSubjectData(
                subject=sub_id,
                first_recall=first_recall,
                lag_crp=lag_crp,
                recall_rate=recall_rate,
                value_acc=0.6,     # if you use this, add to sub_parameters
                complex_params=None,     # same here
                seed=None
            )
            df = sim.generateData(list_len, num_lists, gen_pos=True)
            sim_crp = compute_lag_crp_single_subject_array(df, list_len)
            losses.append(crp_loss(sim_crp, target_crp))

        return np.mean(losses)
    return objective


In [18]:
# Example: pick a subject
sub_id = subjects[0]
target_crp = sub_parameters[sub_id]['lag_crp']  # empirical CRP

# Create Optuna study
study = optuna.create_study(direction="minimize")
objective = make_objective(sub_id, sub_parameters, target_crp, list_len=15, num_lists=30, n_avg=5)
study.optimize(objective, n_trials=100)

print("Best loss:", study.best_value)

best_lag_crp = np.array([study.best_params[f"lag_{i}"] for i in range(len(target_crp))])
best_lag_crp /= best_lag_crp.sum()
print("Best lag_crp distribution:", best_lag_crp)


[I 2025-09-29 10:41:00,969] A new study created in memory with name: no-name-4d7746b0-d748-49a5-b0d9-d6401695a590
[W 2025-09-29 10:41:01,058] Trial 0 failed with parameters: {'lag_0': 0.3351049284032237, 'lag_1': 0.22981252179981246, 'lag_2': 0.701047065765901, 'lag_3': 0.5662477264264821, 'lag_4': 0.5591305159531477, 'lag_5': 0.5015079872251285, 'lag_6': 0.31125945124443755, 'lag_7': 0.8495277749740726, 'lag_8': 0.908906499797524, 'lag_9': 0.3798539856602812, 'lag_10': 0.6167349196796628, 'lag_11': 0.7259976763940551, 'lag_12': 0.9901911058562812, 'lag_13': 0.18507471084688357, 'lag_14': 0.16200857657090956, 'lag_15': 0.3182561379330872, 'lag_16': 0.2735111685873487, 'lag_17': 0.49285727185755474, 'lag_18': 0.19127264650496045, 'lag_19': 0.08975256815227611, 'lag_20': 0.07868189750682626, 'lag_21': 0.8744467435440847, 'lag_22': 0.1308410775010976, 'lag_23': 0.8689895121468862, 'lag_24': 0.7874585916103851, 'lag_25': 0.3525571672667155, 'lag_26': 0.07490362160105764, 'lag_27': 0.730412

[0.02376518 0.01629799 0.0497173  0.04015752 0.03965277 0.03556626
 0.0220741  0.06024735 0.06445841 0.02693873 0.043738   0.05148676
 0.070223   0.01312525 0.01148943 0.02257029 0.01939704 0.03495277
 0.0135648  0.00636513 0.00558001 0.06201457 0.00927907 0.06162755
 0.05584549 0.02500287 0.00531206 0.05179983 0.05775047]


KeyError: 'session'

In [19]:
results = {}

for sub_id in subjects:
    target_crp = sub_parameters[sub_id]['lag_crp']
    study = optuna.create_study(direction="minimize")
    objective = make_objective(sub_id, sub_parameters, target_crp, list_len=24, num_lists=30, n_avg=5)
    study.optimize(objective, n_trials=100)

    best_lag_crp = np.array([study.best_params[f"lag_{i}"] for i in range(len(target_crp))])
    best_lag_crp /= best_lag_crp.sum()

    results[sub_id] = {
        "best_loss": study.best_value,
        "best_lag_crp": best_lag_crp
    }


[I 2025-09-29 10:41:27,814] A new study created in memory with name: no-name-0ab37670-6c6f-4ecd-aa09-72333732ec62
[W 2025-09-29 10:41:28,017] Trial 0 failed with parameters: {'lag_0': 0.1644524028541934, 'lag_1': 0.8615258254382522, 'lag_2': 0.8264433203766354, 'lag_3': 0.42388402985046103, 'lag_4': 0.07128425926827085, 'lag_5': 0.7485794252419353, 'lag_6': 0.8436651517293269, 'lag_7': 0.5124572744425689, 'lag_8': 0.8155438165575044, 'lag_9': 0.9256983587242497, 'lag_10': 0.8841974526517897, 'lag_11': 0.31208045287930597, 'lag_12': 0.09064592302457865, 'lag_13': 0.5979484713356218, 'lag_14': 0.22914739377017262, 'lag_15': 0.8469405564428352, 'lag_16': 0.2724902498926321, 'lag_17': 0.07207430042728331, 'lag_18': 0.1608217282423759, 'lag_19': 0.7178968953897817, 'lag_20': 0.9456500309751311, 'lag_21': 0.9889190759484273, 'lag_22': 0.738074153509126, 'lag_23': 0.32534458115361936, 'lag_24': 0.6977095011638981, 'lag_25': 0.9788227246305353, 'lag_26': 0.6498142212547211, 'lag_27': 0.2134056

[0.01026299 0.05376531 0.05157591 0.02645336 0.00444864 0.04671665
 0.05265067 0.03198096 0.0508957  0.05777013 0.05518018 0.01947603
 0.00565695 0.03731621 0.01430042 0.05285508 0.01700532 0.00449795
 0.01003641 0.04480185 0.05901525 0.06171555 0.04606105 0.0203038
 0.04354201 0.06108546 0.040553   0.01331802 0.00675913]


KeyError: 'session'

In [None]:
# Simulated data

sim_dfs = {}

for sub, parameters in sub_parameters.items():
    sim = SimulatedSubjectData(
        subject=sub,
        first_recall=parameters['first_recall'],
        lag_crp=parameters['lag_crp'],
        recall_rate=parameters['recall_rate'],
        value_acc=0.6
    )
    df_sim = sim.generateData(list_len, num_lists, gen_pos=True)
    sim_dfs[sub] = df_sim

In [None]:
sim_crps = {}

for sub, df_sim in sim_dfs.items():
    sim_crps[sub] = compute_lag_crp_single_subject_array(df_sub, list_len)

In [None]:
fig, axes = plt.subplots(nrows=5, ncols=8, figsize=(24, 12.5), sharex=True, sharey=True)
axes = axes.flatten()

lags = np.arange(-(list_len-1), list_len)
mask = lags != 0
lags = lags[mask]

for ax, sub in zip(axes, sub_parameters.keys()):
    cr1_sub_crp = sub_parameters[sub]['lag_crp'][mask]
    sim_sub_crp = sim_crps.get(sub)
    if sim_sub_crp is not None:
        sim_sub_crp = sim_sub_crp[mask]
        ax.plot(lags, cr1_sub_crp, label="CourierReinstate1")
        ax.plot(lags, sim_sub_crp, linestyle="--", label="Simulated")
    ax.set_title(str(sub), fontsize=10)
    ax.tick_params(labelsize=8)

fig.suptitle("CR1 vs Simulated Lag CRPs", fontsize=16)
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, loc="upper right")
plt.tight_layout()
plt.show()

In [None]:
cr1_all_crp = None
sim_all_crp = None
cr1_stack = np.vstack([sub_parameters[sub]['lag_crp'] for sub in sub_parameters])
sim_stack = np.vstack([sim_crps.get(sub) for sub in sub_parameters if sub in sim_crps])
cr1_all_crp = cr1_stack.mean(axis=0)
sim_all_crp = sim_stack.mean(axis=0)

lags = np.arange(-(list_len-1), list_len)
mask = lags != 0
lags = lags[mask]

plt.figure(figsize=(8, 6))
plt.plot(lags, cr1_all_crp[mask], label="CourierReinstate1")
plt.plot(lags, sim_all_crp[mask], linestyle="--", label="Simulated")
plt.xlabel('Lag')
plt.ylabel('Conditional Response Probability')
plt.title('Across Subjects Average Lag CRP')
plt.legend()
plt.tight_layout()
plt.show()

In [None]:
errors = {}
lags = np.arange(-(list_len-1), list_len)
mask = lags != 0
lags = lags[mask]
        
for sub in sub_parameters.keys():
    if sub in sim_crps:
        cr1_sub_crp = sub_parameters[sub]['lag_crp'][mask]
        sim_sub_crp = sim_crps.get(sub)
        if sim_sub_crp is not None:
            sim_sub_crp = sim_sub_crp[mask]
        errors[sub] = np.sqrt(((cr1_sub_crp - sim_sub_crp) ** 2).mean())
    
errors

In [None]:
avg_error = np.nanmean(list(errors.values()))
avg_error

In [None]:
results = {}

for sub_id in subjects:
    target_crp = sub_parameters[sub_id]['lag_crp']
    study = optuna.create_study(direction="minimize")
    objective = make_objective(sub_id, sub_parameters, target_crp, list_len=24, num_lists=30, n_avg=5)
    study.optimize(objective, n_trials=100)

    best_lag_crp = np.array([study.best_params[f"lag_{i}"] for i in range(len(target_crp))])
    best_lag_crp /= best_lag_crp.sum()

    results[sub_id] = {
        "best_loss": study.best_value,
        "best_lag_crp": best_lag_crp
    }
