In [4]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from joblib import Parallel, delayed
from pyvbmc import VBMC
import corner
from tqdm.notebook import tqdm
import pickle
import random
from scipy.integrate import cumulative_trapezoid as cumtrapz

from time_vary_norm_utils import (
    up_or_down_RTs_fit_fn, cum_pro_and_reactive_time_vary_fn,
    rho_A_t_VEC_fn, up_or_down_RTs_fit_wrt_stim_fn, rho_A_t_fn, cum_A_t_fn,
    CDF_E_minus_small_t_NORM_rate_norm_l_time_varying_fn, rho_E_minus_small_t_NORM_rate_norm_time_varying_fn)
from types import SimpleNamespace
from time_vary_and_norm_simulators import psiam_tied_data_gen_wrapper_rate_norm_fn

import pandas as pd
import numpy as np
import statsmodels.formula.api as smf
from sklearn.model_selection import GroupKFold


In [3]:
exp_df = pd.read_csv('../outExp.csv')

# remove wrong rows 
count = ((exp_df['RTwrtStim'].isna()) & (exp_df['abort_event'] == 3)).sum()
print("Number of rows where RTwrtStim is NaN and abort_event == 3:", count)
exp_df = exp_df[~((exp_df['RTwrtStim'].isna()) & (exp_df['abort_event'] == 3))].copy()

# comparable batch
all_df = exp_df[
    (exp_df['batch_name'] == 'Comparable') &
    (exp_df['LED_trial'].isin([np.nan, 0]))
]

# aborts and valid
# df_valid_and_aborts = exp_df_batch[
#     (exp_df_batch['success'].isin([1,-1])) |
#     (exp_df_batch['abort_event'] == 3)
# ].copy()

# ## choice and acc columns
# # 1 is right , -1 is left
# df_valid_and_aborts['choice'] = df_valid_and_aborts['response_poke'].apply(lambda x: 1 if x == 3 else (-1 if x == 2 else random.choice([1, -1])))
# # 1 or 0 if the choice was correct or not
# df_valid_and_aborts['accuracy'] = (df_valid_and_aborts['ILD'] * df_valid_and_aborts['choice']).apply(lambda x: 1 if x > 0 else 0)

# ## df used for fitting - valid trials < 1s + stim
# df_valid_less_than_1 = df_valid_and_aborts[
#     (df_valid_and_aborts['success'].isin([1,-1])) & 
#     (df_valid_and_aborts['RTwrtStim'] < 1) &
#     (df_valid_and_aborts['RTwrtStim'] > 0)
# ]

# find ABL and ILD
ABL_arr = all_df['ABL'].unique()
ILD_arr = all_df['ILD'].unique()


# sort ILD arr in ascending order
ILD_arr = np.sort(ILD_arr)
ABL_arr = np.sort(ABL_arr)

print('ABL:', ABL_arr)
print('ILD:', ILD_arr)

Number of rows where RTwrtStim is NaN and abort_event == 3: 16
ABL: [10 25 40 50 55 70]
ILD: [-8.   -4.   -2.25 -1.25 -0.5   0.    0.5   1.25  2.25  4.    8.  ]


In [5]:
df = all_df.copy()

# ----- outcome flags -----
df['is_abort']   = (df['abort_event'] == 3).astype(int)
df['short_poke'] = ((df['is_abort'] == 1) & (df['TotalFixTime'] < 300)).astype(int)

# ----- quick sanity checks -----
print(df['short_poke'].value_counts(dropna=False))
print(df.groupby(['animal', 'session'])['short_poke'].mean().describe())


short_poke
0    108269
1     10598
Name: count, dtype: int64
count    164.000000
mean       0.087456
std        0.036282
min        0.022654
25%        0.056332
50%        0.089805
75%        0.112854
max        0.178474
Name: short_poke, dtype: float64


# get the dataframe into a “time-series per session” order

In [6]:
df = (df
      .sort_values(['animal', 'session', 'trial'])   # critical!
      .reset_index(drop=True))


# build history-based predictors (fixed effects)


In [7]:
MAX_LAG = 3       # 3 previous trials

for lag in range(1, MAX_LAG + 1):
    df[f'success_{lag}']   = df.groupby(['animal', 'session'])['success'].shift(lag)
    df[f'abort_{lag}']     = df.groupby(['animal', 'session'])['is_abort'].shift(lag)
    df[f'shortpoke_{lag}'] = df.groupby(['animal', 'session'])['short_poke'].shift(lag)

# session-time variable
df['trial_in_session'] = df.groupby(['animal', 'session']).cumcount() + 1

# drop first few rows of every session that now contain NaNs
df = df.dropna().reset_index(drop=True)


# specify a hierarchical logistic GLM

In [9]:
import pandas as pd
import numpy as np
import statsmodels.api as sm          # ← NEW
import statsmodels.formula.api as smf

print("statsmodels version →", sm.__version__)  # should print 0.15.x or 0.16.x

MAX_LAG = 3
formula = (
    "short_poke ~ "
    + " + ".join([f"success_{k} + abort_{k}" for k in range(1, MAX_LAG+1)])
    + " + trial_in_session"
)

md = smf.mixedlm(
        formula=formula,
        data=df,
        groups="animal",
        vc_formula={"session": "0 + C(session)"},
        family=sm.families.Binomial()   # ← use sm. not smf.
)

fit = md.fit(method="lbfgs")
print(fit.summary())


statsmodels version → 0.14.4


ValueError: argument family not permitted for MixedLM initialization