In [1]:
# %% [markdown]
# # Didn't converge, better to fit aborts separately first - SciPy Optimize Version

# %%
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from joblib import Parallel, delayed
# Comment out or remove pyvbmc import if not used elsewhere
# from pyvbmc import VBMC 
# import corner # Comment out or remove if not used for plotting results here
# from tqdm.notebook import tqdm # Use standard tqdm if not in notebook
import pickle
import random
from scipy.integrate import cumulative_trapezoid as cumtrapz
import scipy.optimize as opt
from scipy.optimize import Bounds
from time import time

# Assume time_vary_norm_utils is in the python path or same directory
# Import the two helpers the likelihood needs (rho_A_t_fn, cum_A_t_fn).
# If the project module is missing, placeholder stubs are defined instead so the
# rest of the script can still execute.
try:
    from time_vary_norm_utils import (
        # up_or_down_RTs_fit_fn, # Not used in this specific script
        # cum_pro_and_reactive_time_vary_fn, # Not used
        # rho_A_t_VEC_fn, # Not used
        # up_or_down_RTs_fit_wrt_stim_fn, # Not used
        rho_A_t_fn, 
        cum_A_t_fn
    )
    print("Successfully imported helper functions from time_vary_norm_utils")
except ImportError:
    print("ERROR: Could not import helper functions from time_vary_norm_utils.")
    print("Please ensure 'time_vary_norm_utils.py' is in the Python path or the same directory.")
    # Define dummy functions if import fails, to allow script to run partially
    # THIS IS FOR DEMONSTRATION ONLY - REPLACE WITH ACTUAL FUNCTIONS
    # NOTE(review): both stubs ignore V and theta, so any fit that runs with them
    # is insensitive to V_A and theta_A and its result is meaningless. Consider
    # re-raising instead of continuing.
    def rho_A_t_fn(t, V, theta): return np.exp(-t) if t > 0 else 0
    def cum_A_t_fn(t, V, theta): return 1 - np.exp(-t) if t > 0 else 0
    print("WARNING: Using dummy placeholder functions for rho_A_t_fn and cum_A_t_fn.")
    

# Not used in this specific script, comment out or remove
# from types import SimpleNamespace
# from time_vary_and_norm_simulators import psiam_tied_data_gen_wrapper_rate_norm_fn

# %% [markdown]
# ## Data Loading and Preprocessing

# %%
# Load the experiment table; re-raise with a clearer message if the CSV is missing.
try:
    exp_df = pd.read_csv('../outExp.csv')
except FileNotFoundError:
    raise FileNotFoundError("Error: '../outExp.csv' not found. Please check the path.")
    
    

# Count, then drop, abort trials (abort_event == 3) that have no RTwrtStim.
count = ((exp_df['RTwrtStim'].isna()) & (exp_df['abort_event'] == 3)).sum()
print(f"Number of rows where RTwrtStim is NaN and abort_event == 3: {count}")

# Original filtering condition
exp_df = exp_df[~((exp_df['RTwrtStim'].isna()) & (exp_df['abort_event'] == 3))].copy()
print(f"Shape after removing NaNs in RTwrtStim for abort_event==3: {exp_df.shape}")

# Keep the 'Comparable' batch and rows whose LED_trial is 0 or missing.
# NOTE(review): relies on pandas `isin` matching NaN against np.nan — confirm.
exp_df_batch = exp_df[
    (exp_df['batch_name'] == 'Comparable') &
    (exp_df['LED_trial'].isin([np.nan, 0]))
].copy()
print(f"Shape after filtering for batch and LED trial: {exp_df_batch.shape}")


# Keep trials that either have success in {1, -1} or are aborts (abort_event == 3).
df_valid_and_aborts = exp_df_batch[
    (exp_df_batch['success'].isin([1,-1])) |
    (exp_df_batch['abort_event'] == 3)
].copy()
print(f"Shape after selecting valid trials and aborts: {df_valid_and_aborts.shape}")


# NOTE(review): this df_aborts is never read; it is rebuilt below after the
# 'choice' and 'accuracy' columns have been added.
df_aborts = df_valid_and_aborts[df_valid_and_aborts['abort_event'] == 3].copy()

# Ensure 'response_poke' and 'ILD' exist before applying functions
if 'response_poke' in df_valid_and_aborts.columns:
    # 1 is right , -1 is left
    # response_poke 3 -> 1, 2 -> -1, anything else -> a random +/-1.
    # NOTE(review): `random` is not seeded in this script, so the random
    # assignments (and the derived 'accuracy') differ between runs.
    df_valid_and_aborts['choice'] = df_valid_and_aborts['response_poke'].apply(lambda x: 1 if x == 3 else (-1 if x == 2 else random.choice([1, -1])))
else:
    print("Warning: 'response_poke' column not found. Skipping 'choice' calculation.")
    df_valid_and_aborts['choice'] = 1 # Assign dummy value

if 'ILD' in df_valid_and_aborts.columns and 'choice' in df_valid_and_aborts.columns:
    # 1 or 0 if the choice was correct or not
    # ILD * choice > 0 means the signs agree; ILD == 0 is therefore scored as 0.
    df_valid_and_aborts['accuracy'] = (df_valid_and_aborts['ILD'] * df_valid_and_aborts['choice']).apply(lambda x: 1 if x > 0 else 0)
else:
     print("Warning: 'ILD' or 'choice' column not found. Skipping 'accuracy' calculation.")
     df_valid_and_aborts['accuracy'] = 1 # Assign dummy value

# Abort-only subset; this is the frame aggregate_loglike_fn iterates over.
df_aborts = df_valid_and_aborts[df_valid_and_aborts['abort_event'] == 3].copy()
print(f"Number of abort trials selected: {df_aborts.shape[0]}")
print(f"Total trials for fitting (valid + aborts): {df_valid_and_aborts.shape[0]}")


# find ABL and ILD
ABL_arr = df_valid_and_aborts['ABL'].unique()
ILD_arr = df_valid_and_aborts['ILD'].unique()

# sort ILD arr in ascending order
ILD_arr = np.sort(ILD_arr)
ABL_arr = np.sort(ABL_arr)

print('ABL:', ABL_arr)
print('ILD:', ILD_arr)

# Stop the script if the filters removed every row.
if df_valid_and_aborts.empty:
    raise SystemExit("ERROR: No data left after filtering. Check data loading and filtering steps.")

# %% [markdown]
# ## Model Parameters and Constants

# %%
proactive_trunc_time = 0.3
N_JOBS = 30 # Number of parallel jobs for likelihood calculation

# %% [markdown]
# ## Log-Likelihood Function

# %%
def compute_loglike(row, V_A, theta_A, t_A_aff):
    """Return the log-likelihood contribution of one trial.

    The observed time is ``row['TotalFixTime']`` and the stimulus/censoring
    time is ``row['intended_fix']``. Every contribution is divided by
    ``1 - cum_A_t_fn(proactive_trunc_time - t_A_aff, ...)`` (floored at 1e-50)
    to account for truncation at ``proactive_trunc_time``.

    * Abort trial (``abort_event == 3``): ``rho_A_t_fn`` evaluated at the
      afferent-shifted observed time, or 0 if the observed time is below
      ``proactive_trunc_time``.
    * Any other trial: ``1 - cum_A_t_fn`` at the afferent-shifted stimulus time.

    Returns:
        0 when either time is missing (the row is effectively skipped);
        ``-np.inf`` when the helpers raise ValueError/TypeError or the log is
        NaN; otherwise ``log(max(value, 1e-50))``.
    """
    rt = row['TotalFixTime']        # observed time
    t_stim = row['intended_fix']    # stimulus / censoring time

    # Rows with a missing time add nothing to the sum.
    if pd.isna(rt) or pd.isna(t_stim):
        return 0

    try:
        # Probability of no abort before the truncation time, floored so the
        # division below cannot hit zero.
        survival_at_trunc = max(
            1.0 - cum_A_t_fn(proactive_trunc_time - t_A_aff, V_A, theta_A),
            1e-50,
        )

        if row['abort_event'] != 3:
            # Non-abort trial: probability of not aborting before t_stim.
            density = 1.0 - cum_A_t_fn(t_stim - t_A_aff, V_A, theta_A)
        elif rt < proactive_trunc_time:
            # Aborts earlier than the truncation time carry zero density.
            density = 0.0
        else:
            density = rho_A_t_fn(rt - t_A_aff, V_A, theta_A)

        # Normalise by the truncation factor, then floor to keep log() finite.
        density = max(density / survival_at_trunc, 1e-50)

    except (ValueError, TypeError) as e:
        print(f"Error during PDF/CDF calculation for row {row.name}: {e}")
        print(f"Params: V_A={V_A}, theta_A={theta_A}, t_A_aff={t_A_aff}")
        print(f"Inputs: rt={rt}, t_stim={t_stim}, proactive_trunc_time={proactive_trunc_time}")
        return -np.inf

    log_density = np.log(density)
    # A NaN survives max() above, so it has to be caught after the log.
    return -np.inf if np.isnan(log_density) else log_density


def aggregate_loglike_fn(params):
    """Sum the per-trial log-likelihoods of all rows in ``df_aborts``.

    Trials are evaluated in parallel with joblib (``N_JOBS`` workers), one
    ``compute_loglike`` call per row.

    Args:
        params: sequence ``(V_A, theta_A, t_A_aff)``.

    Returns:
        The total log-likelihood, or ``-np.inf`` when
        * ``V_A`` or ``theta_A`` is not strictly positive,
        * the parallel evaluation raises,
        * there are no trials, or
        * any single trial returns a non-finite / non-numeric value.

    ``compute_loglike`` returns ``-np.inf`` as its "invalid trial" sentinel.
    Those entries must make the whole sum ``-np.inf``; dropping them before
    summing would reward parameter values at which trials fail to evaluate.
    """
    V_A, theta_A, t_A_aff = params

    # Check for obviously bad parameter values early
    if V_A <= 0 or theta_A <= 0:
        return -np.inf

    try:
        all_loglike = Parallel(n_jobs=N_JOBS)(
            delayed(compute_loglike)(row, V_A, theta_A, t_A_aff)
            for _, row in df_aborts.iterrows()
        )
        # None entries become NaN here; anything non-numeric raises and is
        # handled by the except clause below.
        loglikes = np.asarray(all_loglike, dtype=float)
    except Exception as e:
        print(f"Error during parallel computation with params {params}: {e}")
        return -np.inf

    if loglikes.size == 0:
        print("Warning: No valid log-likelihoods computed.")
        return -np.inf

    # One invalid trial invalidates the whole parameter vector.
    if not np.all(np.isfinite(loglikes)):
        return -np.inf

    return np.sum(loglikes)


# %% [markdown]
# ## Prior Function

# %%
def trapezoidal_logpdf(x, a, b, c, d):
    """Log-density of a trapezoidal distribution with breakpoints a <= b <= c <= d.

    The density rises linearly from 0 at ``a`` to its plateau at ``b``, stays
    flat on ``[b, c]`` and falls linearly back to 0 at ``d``; the plateau
    height is chosen so the total area is 1.

    Args:
        x: point at which to evaluate.
        a, d: lower / upper limits of the support.
        b, c: lower / upper limits of the plateau.

    Returns:
        ``log(pdf(x))``, or ``-np.inf`` when ``x`` lies outside ``[a, d]``, the
        breakpoints are not ordered, the support has zero width, or the density
        at ``x`` is at most 1e-100.
    """
    if x < a or x > d or a > b or b > c or c > d:
        return -np.inf  # Outside support or invalid bounds

    width_bottom = d - a
    width_top = c - b
    if width_bottom <= 0:
        return -np.inf  # Degenerate support

    # width_bottom > 0 and width_top >= 0, so the area is strictly positive.
    area = (width_bottom + width_top) / 2.0
    h_max = 1.0 / area  # Plateau height of the normalised trapezoid

    if a <= x < b:
        # a <= x < b guarantees b - a > 0, so the ratio is well defined and in
        # [0, 1) however narrow the edge is. An "infinite slope" shortcut here
        # would produce inf/nan densities on very narrow edges.
        pdf_value = h_max * ((x - a) / (b - a))
    elif b <= x <= c:
        pdf_value = h_max
    elif c < x <= d:
        # c < x <= d guarantees d - c > 0; same reasoning as the rising edge.
        pdf_value = h_max * ((d - x) / (d - c))
    else:
        # Only reachable when x is NaN (every comparison above is False).
        pdf_value = 0.0

    if pdf_value <= 1e-100:  # Treat vanishing density as zero
        return -np.inf
    return np.log(pdf_value)


# Define bounds (copied from the original script)
V_A_bounds = [0.01, 10]
theta_A_bounds = [0.1, 6]
t_A_aff_bounds = [-1, 0.1]

V_A_plausible_bounds = [0.1, 3]
theta_A_plausible_bounds = [0.5, 4]
t_A_aff_plausible_bounds = [-0.25, 0.05]

plb = np.array([
    V_A_plausible_bounds[0],
    theta_A_plausible_bounds[0],
    t_A_aff_plausible_bounds[0]
])

pub = np.array([
    V_A_plausible_bounds[1],
    theta_A_plausible_bounds[1],
    t_A_aff_plausible_bounds[1]
])

def prior_fn(params):
    """Joint log-prior: sum of three independent trapezoidal log-densities.

    Each parameter uses its hard bounds as the support and its plausible bounds
    as the plateau. Returns ``-np.inf`` if any single term is ``-np.inf``.
    """
    V_A, theta_A, t_A_aff = params

    # (value, hard bounds, plausible bounds) for each parameter, in order.
    specs = (
        (V_A, V_A_bounds, V_A_plausible_bounds),
        (theta_A, theta_A_bounds, theta_A_plausible_bounds),
        (t_A_aff, t_A_aff_bounds, t_A_aff_plausible_bounds),
    )
    log_terms = [
        trapezoidal_logpdf(value, hard[0], plausible[0], plausible[1], hard[1])
        for value, hard, plausible in specs
    ]

    # Any parameter outside its prior support rules out the whole vector.
    if any(np.isneginf(term) for term in log_terms):
        return -np.inf

    return log_terms[0] + log_terms[1] + log_terms[2]

# %% [markdown]
# ## Joint Log-Posterior Function

# %%
def joint_log_posterior_fn(params):
    """Return log-prior + log-likelihood, or ``-np.inf`` if either is ``-np.inf``.

    The likelihood is only evaluated when the prior is not ``-np.inf``, which
    skips the expensive parallel computation for parameters outside the prior
    support.
    """
    lp = prior_fn(params)
    if not np.isneginf(lp):
        ll = aggregate_loglike_fn(params)
        if not np.isneginf(ll):
            return lp + ll
    return -np.inf

# %% [markdown]
# ## Objective Functions for Minimization (Negative Log-Posterior/Likelihood)

# %%
# Define hard bounds (lower and upper)
lb = np.array([V_A_bounds[0], theta_A_bounds[0], t_A_aff_bounds[0]])
ub = np.array([V_A_bounds[1], theta_A_bounds[1], t_A_aff_bounds[1]])
bounds_obj = Bounds(lb, ub)

# Bounds for differential_evolution (list of tuples)
bounds_list = list(zip(lb, ub))

# Objective function for MAP estimation (minimize negative log-posterior)
def neg_log_joint(params):
    """Minimisation objective for MAP estimation: negative log-posterior.

    Returns ``np.inf`` when any parameter lies outside the hard bounds
    ``[lb, ub]`` or when the log-posterior is ``-np.inf``.
    """
    within_hard_bounds = np.all((params >= lb) & (params <= ub))
    if not within_hard_bounds:
        return np.inf

    log_post = joint_log_posterior_fn(params)
    return np.inf if np.isneginf(log_post) else -log_post

# Objective function for MLE estimation (minimize negative log-likelihood)
# Use this if you want to ignore the prior
def neg_log_likelihood(params):
    """Minimisation objective for MLE: negative log-likelihood (prior ignored).

    Returns ``np.inf`` when any parameter lies outside the hard bounds
    ``[lb, ub]`` or when the log-likelihood is ``-np.inf``.
    """
    within_hard_bounds = np.all((params >= lb) & (params <= ub))
    if not within_hard_bounds:
        return np.inf

    ll = aggregate_loglike_fn(params)
    return np.inf if np.isneginf(ll) else -ll


# %% [markdown]
# ## Setup for Optimization

# %%
# Bounds for L-BFGS-B (using scipy.optimize.Bounds)
bounds_obj = Bounds(lb, ub)

# Bounds for differential_evolution (list of tuples)
bounds_list = list(zip(lb, ub))

# Initial Guess (using plausible centers or provided values)
# x_0 = np.array([
#     (V_A_plausible_bounds[0] + V_A_plausible_bounds[1]) / 2,
#     (theta_A_plausible_bounds[0] + theta_A_plausible_bounds[1]) / 2,
#     (t_A_aff_plausible_bounds[0] + t_A_aff_plausible_bounds[1]) / 2,
# ])
# Draw a random initial guess uniformly within the plausible bounds (no seed is set here, so it differs between runs)
x_0 = np.clip(np.random.uniform(low=plb, high=pub, size=len(plb)), lb, ub)

# Ensure x_0 is within hard bounds
# x_0 = np.clip(x_0, lb, ub) 
print(f"Using Initial Guess (clipped to bounds): {x_0}")


# %% [markdown]
# ## Run Optimization Methods

# %%

# --- Method 1: L-BFGS-B (Local Optimization for MAP) ---
# Local, bounded minimisation of neg_log_joint starting from x_0. No gradient is
# supplied to opt.minimize here.
# NOTE(review): in the recorded run the reported "MAP" equals x_0 exactly and a
# RuntimeWarning was raised on `df = fun(x) - f0`, so the optimiser appears to
# have stopped without moving. neg_log_joint returns np.inf outside the
# bounds / prior support, which may be what breaks the step — TODO confirm.
print("\n--- Running L-BFGS-B ---")
objective_fn_lbfgsb = neg_log_joint # Use neg_log_joint for MAP
print(f"Optimizing: MAP (Negative Log Joint)")

start_time_lbfgsb = time()
result_lbfgsb = None # Initialize result variable
try:
    result_lbfgsb = opt.minimize(
        objective_fn_lbfgsb,
        x_0,
        method='L-BFGS-B',
        bounds=bounds_obj,
        options={'disp': False, 'maxiter': 200, 'ftol': 1e-7, 'gtol': 1e-5} # Standard options
    )
except Exception as e:
    # Any failure is reported and leaves result_lbfgsb as None.
    print(f"L-BFGS-B optimization failed with error: {e}")
    
end_time_lbfgsb = time()

print("\n--- L-BFGS-B Results ---")
# params_lbfgsb stays None unless the optimiser reports success.
params_lbfgsb = None
if result_lbfgsb is not None and result_lbfgsb.success:
    print(f"Optimization successful!")
    params_lbfgsb = result_lbfgsb.x
    # The objective is the negative log-posterior, so flip the sign back.
    max_log_posterior_lbfgsb = -result_lbfgsb.fun 
    print(f"Found MAP parameters (V_A, theta_A, t_A_aff): {params_lbfgsb}")
    print(f"Maximum log posterior value: {max_log_posterior_lbfgsb}")
else:
    status_msg = result_lbfgsb.message if result_lbfgsb is not None else "Optimization did not run or failed."
    print(f"Optimization failed: {status_msg}")
    if result_lbfgsb is not None:
         print(f"Last parameters: {result_lbfgsb.x}")
         print(f"Last function value: {-result_lbfgsb.fun}")

print(f"Time taken: {end_time_lbfgsb - start_time_lbfgsb:.2f} seconds")


# --- Method 2: Differential Evolution (Global Optimization for MAP) ---
# Global, bounded minimisation of neg_log_joint over bounds_list.
# NOTE(review): no `seed` is passed, so the result presumably varies from run to
# run — TODO confirm against the SciPy docs and pin a seed if reproducibility matters.
print("\n\n--- Running Differential Evolution ---")
objective_fn_de = neg_log_joint # Use neg_log_joint for MAP
print(f"Optimizing: MAP (Negative Log Joint)")

# Adjust workers based on your joblib N_JOBS and system cores
# Using N_JOBS directly assumes DE can efficiently manage these workers
# NOTE(review): the objective itself calls joblib Parallel(n_jobs=N_JOBS) inside
# aggregate_loglike_fn, so each DE worker starts its own parallel loop (nested
# parallelism). The repeated joblib `n_jobs = self._backend.configure(...)`
# warnings in the recorded output look related — TODO confirm and parallelise
# at one level only.
n_workers_de = N_JOBS # Use the same number of workers as likelihood calculation

start_time_de = time()
result_de = None # Initialize result variable
try:
    result_de = opt.differential_evolution(
        objective_fn_de,
        bounds=bounds_list,
        strategy='best1bin',    # Common strategy
        maxiter=100,            # Iterations limit (adjust as needed)
        popsize=15,             # Population size (rule of thumb: >= 5*Dim)
        tol=0.01,               # Convergence tolerance
        mutation=(0.5, 1),      # Mutation factor range
        recombination=0.7,      # Recombination probability
        disp=False,              # Display progress
        workers=n_workers_de,   # Use parallel workers
        updating='deferred'     # Recommended when using workers
    )
except Exception as e:
     # Any failure is reported and leaves result_de as None.
     print(f"Differential Evolution optimization failed with error: {e}")
     
end_time_de = time()

print("\n--- Differential Evolution Results ---")
# Unlike the L-BFGS-B branch, params_de is also stored when DE terminates
# without the success flag, as long as a result object with `x` exists.
params_de = None
if result_de is not None and result_de.success:
    print(f"Optimization successful!")
    params_de = result_de.x
    # The objective is the negative log-posterior, so flip the sign back.
    max_log_posterior_de = -result_de.fun 
    print(f"Found MAP parameters (V_A, theta_A, t_A_aff): {params_de}")
    print(f"Maximum log posterior value: {max_log_posterior_de}")
else:
    # DE might terminate without success flag but still have a result
    if result_de is not None and hasattr(result_de, 'x'):
         status_msg = result_de.message if hasattr(result_de, 'message') else "Terminated without success flag."
         print(f"Optimization terminated: {status_msg}")
         params_de = result_de.x # Still store the best found params
         print(f"Best parameters found: {params_de}")
         print(f"Best log posterior value found: {-result_de.fun}")
    else:
         print(f"Optimization did not run or failed catastrophically.")


print(f"Time taken: {end_time_de - start_time_de:.2f} seconds")


# %% [markdown]
# ## Final Parameter Output

# %%
print("\n\n--- Final Optimization Results ---")

print("Parameters Format: [V_A, theta_A, t_A_aff]")

if params_lbfgsb is not None:
    print(f"\nL-BFGS-B MAP Parameters: {params_lbfgsb}")
    # You could re-evaluate the likelihood/posterior at this point if needed
    # final_log_post_lbfgsb = joint_log_posterior_fn(params_lbfgsb)
    # print(f"L-BFGS-B Final Log Posterior: {final_log_post_lbfgsb}")
else:
    print("\nL-BFGS-B did not converge successfully or did not run.")

if params_de is not None:
     print(f"\nDifferential Evolution MAP Parameters: {params_de}")
     # final_log_post_de = joint_log_posterior_fn(params_de)
     # print(f"Differential Evolution Final Log Posterior: {final_log_post_de}")
else:
    print("\nDifferential Evolution did not converge successfully or did not run.")

# You can choose which parameters to use going forward, e.g., from DE if it found a better value
# Pick the final parameters by objective value rather than by method.
# params_de may come from a DE run that terminated without success, so it is
# not guaranteed to beat a converged L-BFGS-B result. When both are available
# the one with the lower negative log-posterior wins; ties (or an incomparable
# value such as NaN) fall back to DE.
best_params = None
if params_de is not None and params_lbfgsb is not None:
    # Both result objects exist whenever their params were stored above.
    if result_lbfgsb.fun < result_de.fun:
        best_params = params_lbfgsb
        print(f"\nUsing L-BFGS-B parameters as final result (higher log posterior than DE).")
    else:
        best_params = params_de
        print(f"\nUsing Differential Evolution parameters as final result.")
elif params_de is not None:
    best_params = params_de
    print(f"\nUsing Differential Evolution parameters as final result.")
elif params_lbfgsb is not None:
    best_params = params_lbfgsb
    print(f"\nUsing L-BFGS-B parameters as final result (DE failed).")
else:
    print("\nNo successful optimization result obtained.")

if best_params is not None:
    print(f"\n---> Final Selected Parameters: {best_params}")

Successfully imported helper functions from time_vary_norm_utils
Number of rows where RTwrtStim is NaN and abort_event == 3: 16
Shape after removing NaNs in RTwrtStim for abort_event==3: (792588, 62)
Shape after filtering for batch and LED trial: (118867, 62)
Shape after selecting valid trials and aborts: (90296, 62)
Number of abort trials selected: 10598
Total trials for fitting (valid + aborts): 90296
ABL: [10 25 40 50 55 70]
ILD: [-8.   -4.   -2.25 -1.25 -0.5   0.    0.5   1.25  2.25  4.    8.  ]
Using Initial Guess (clipped to bounds): [ 0.60544472  2.85673582 -0.16070492]

--- Running L-BFGS-B ---
Optimizing: MAP (Negative Log Joint)


  df = fun(x) - f0



--- L-BFGS-B Results ---
Optimization successful!
Found MAP parameters (V_A, theta_A, t_A_aff): [ 0.60544472  2.85673582 -0.16070492]
Maximum log posterior value: -520180.585435247
Time taken: 7.62 seconds


--- Running Differential Evolution ---
Optimizing: MAP (Negative Log Joint)


  n_jobs = self._backend.configure(n_jobs=self.n_jobs, parallel=self,
  n_jobs = self._backend.configure(n_jobs=self.n_jobs, parallel=self,
  n_jobs = self._backend.configure(n_jobs=self.n_jobs, parallel=self,
  n_jobs = self._backend.configure(n_jobs=self.n_jobs, parallel=self,
  n_jobs = self._backend.configure(n_jobs=self.n_jobs, parallel=self,
  n_jobs = self._backend.configure(n_jobs=self.n_jobs, parallel=self,
  n_jobs = self._backend.configure(n_jobs=self.n_jobs, parallel=self,
  n_jobs = self._backend.configure(n_jobs=self.n_jobs, parallel=self,
  n_jobs = self._backend.configure(n_jobs=self.n_jobs, parallel=self,
  n_jobs = self._backend.configure(n_jobs=self.n_jobs, parallel=self,
  n_jobs = self._backend.configure(n_jobs=self.n_jobs, parallel=self,
  n_jobs = self._backend.configure(n_jobs=self.n_jobs, parallel=self,
  n_jobs = self._backend.configure(n_jobs=self.n_jobs, parallel=self,
  n_jobs = self._backend.configure(n_jobs=self.n_jobs, parallel=self,
  n_jobs = self._bac


--- Differential Evolution Results ---
Optimization successful!
Found MAP parameters (V_A, theta_A, t_A_aff): [ 4.33935744  2.85768139 -0.12724264]
Maximum log posterior value: -494352.53566732287
Time taken: 11.48 seconds


--- Final Optimization Results ---
Parameters Format: [V_A, theta_A, t_A_aff]

L-BFGS-B MAP Parameters: [ 0.60544472  2.85673582 -0.16070492]

Differential Evolution MAP Parameters: [ 4.33935744  2.85768139 -0.12724264]

Using Differential Evolution parameters as final result.

---> Final Selected Parameters: [ 4.33935744  2.85768139 -0.12724264]


# Didn't converge, better to fit aborts separately first

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from joblib import Parallel, delayed
from pyvbmc import VBMC
import corner
from tqdm.notebook import tqdm
import pickle
import random
from scipy.integrate import cumulative_trapezoid as cumtrapz

from time_vary_norm_utils import (
    up_or_down_RTs_fit_fn, cum_pro_and_reactive_time_vary_fn,
    rho_A_t_VEC_fn, up_or_down_RTs_fit_wrt_stim_fn, rho_A_t_fn, cum_A_t_fn)
from types import SimpleNamespace
from time_vary_and_norm_simulators import psiam_tied_data_gen_wrapper_rate_norm_fn

In [None]:
exp_df = pd.read_csv('../outExp.csv')

count = ((exp_df['RTwrtStim'].isna()) & (exp_df['abort_event'] == 3)).sum()
print("Number of rows where RTwrtStim is NaN and abort_event == 3:", count)


exp_df = exp_df[~((exp_df['RTwrtStim'].isna()) & (exp_df['abort_event'] == 3))].copy()

exp_df_batch = exp_df[
    (exp_df['batch_name'] == 'Comparable') &
    (exp_df['LED_trial'].isin([np.nan, 0]))
]

df_valid_and_aborts = exp_df_batch[
    (exp_df_batch['success'].isin([1,-1])) |
    (exp_df_batch['abort_event'] == 3)
].copy()

# 1 is right , -1 is left
df_valid_and_aborts['choice'] = df_valid_and_aborts['response_poke'].apply(lambda x: 1 if x == 3 else (-1 if x == 2 else random.choice([1, -1])))

# 1 or 0 if the choice was correct or not
df_valid_and_aborts['accuracy'] = (df_valid_and_aborts['ILD'] * df_valid_and_aborts['choice']).apply(lambda x: 1 if x > 0 else 0)

df_aborts = df_valid_and_aborts[df_valid_and_aborts['abort_event'] == 3]
# find ABL and ILD
ABL_arr = df_valid_and_aborts['ABL'].unique()
ILD_arr = df_valid_and_aborts['ILD'].unique()


# sort ILD arr in ascending order
ILD_arr = np.sort(ILD_arr)
ABL_arr = np.sort(ABL_arr)

print('ABL:', ABL_arr)
print('ILD:', ILD_arr)

# vbmc

In [None]:
is_norm = False
is_time_vary = False
phi_params_obj = np.nan
rate_norm_l = np.nan

In [None]:
proactive_trunc_time = 0.3
K_max = 10

## loglike fn

In [None]:
def compute_loglike(row, V_A, theta_A, t_A_aff):
    """Return the log-likelihood contribution of one trial (VBMC version).

    Uses ``row['TotalFixTime']`` as the observed time and
    ``row['intended_fix']`` as the stimulus time. The value is divided by
    ``1 - cum_A_t_fn(proactive_trunc_time - t_A_aff, ...) + 1e-50`` and floored
    at 1e-50 before the log is taken.

    Raises:
        ValueError: if the floored value is NaN (diagnostics are printed first).
    """
    rt = row['TotalFixTime']
    t_stim = row['intended_fix']

    # Probability of no abort before the truncation time.
    survival_at_trunc = 1 - cum_A_t_fn(proactive_trunc_time - t_A_aff, V_A, theta_A)

    if row['abort_event'] != 3:
        # Non-abort trial: probability of not aborting before t_stim.
        value = 1 - cum_A_t_fn(t_stim - t_A_aff, V_A, theta_A)
    elif rt < proactive_trunc_time:
        # Aborts earlier than the truncation time carry zero density.
        value = 0
    else:
        value = rho_A_t_fn(rt - t_A_aff, V_A, theta_A)

    # The epsilon in the denominator guards against a zero truncation factor.
    value = max(value / (survival_at_trunc + 1e-50), 1e-50)

    if np.isnan(value):
        print(f'row["abort_event"] = {row["abort_event"]}')
        print(f'row["RTwrtStim"] = {row["RTwrtStim"]}')
        raise ValueError(f'nan pdf rt = {rt}, t_stim = {t_stim}')

    return np.log(value)
    
    


def vbmc_loglike_fn(params):
    """Total log-likelihood over every row of ``df_valid_and_aborts``.

    Rows are evaluated with ``compute_loglike`` across 30 joblib workers and
    the per-row values are summed.
    """
    V_A, theta_A, t_A_aff = params
    per_trial = delayed(compute_loglike)

    # Iterates over valid trials AND aborts (not df_aborts alone).
    jobs = (
        per_trial(row, V_A, theta_A, t_A_aff)
        for _, row in df_valid_and_aborts.iterrows()
    )
    return np.sum(Parallel(n_jobs=30)(jobs))


## priors

In [None]:
V_A_bounds = [0.01, 10]
theta_A_bounds = [0.1, 6]
t_A_aff_bounds = [-1, 0.1]


V_A_plausible_bounds = [0.1, 3]
theta_A_plausible_bounds = [0.5, 4]
t_A_aff_plausible_bounds = [-0.25, 0.05]


In [None]:
def trapezoidal_logpdf(x, a, b, c, d):
    """Log-density of a trapezoidal distribution with breakpoints a <= b <= c <= d.

    The density rises linearly from 0 at ``a`` to its plateau at ``b``, stays
    flat on ``[b, c]`` and falls linearly back to 0 at ``d``; the plateau
    height normalises the total area to 1.

    Returns:
        ``log(pdf(x))``, or ``-np.inf`` when ``x`` is outside ``[a, d]``, the
        breakpoints are not ordered, the support has zero width, or the density
        at ``x`` is zero.
    """
    if x < a or x > d:
        return -np.inf  # Logarithm of zero
    if a > b or b > c or c > d or d <= a:
        return -np.inf  # Unordered breakpoints or zero-width support

    # Ordered breakpoints with d > a make this strictly positive.
    area = ((b - a) + (d - c)) / 2 + (c - b)
    h_max = 1.0 / area  # Height of the trapezoid to normalize the area to 1

    # The plateau is tested first so a vertical edge (a == b or c == d) never
    # reaches the divisions below; at x == b and x == c the sloped formulas
    # equal h_max anyway.
    if b <= x <= c:
        pdf_value = h_max
    elif a <= x < b:
        pdf_value = ((x - a) / (b - a)) * h_max
    elif c < x <= d:
        pdf_value = ((d - x) / (d - c)) * h_max
    else:
        pdf_value = 0.0  # Only reachable when x is NaN

    if pdf_value <= 0.0:
        return -np.inf
    else:
        return np.log(pdf_value)
    

def vbmc_prior_fn(params):
    """Joint log-prior: sum of three independent trapezoidal log-densities.

    Each parameter uses its hard bounds as the support and its plausible bounds
    as the plateau of the trapezoid.
    """
    V_A, theta_A, t_A_aff = params

    def _log_trapezoid(value, hard, plausible):
        # Support = hard bounds, plateau = plausible bounds.
        return trapezoidal_logpdf(value, hard[0], plausible[0], plausible[1], hard[1])

    log_p_V_A = _log_trapezoid(V_A, V_A_bounds, V_A_plausible_bounds)
    log_p_theta_A = _log_trapezoid(theta_A, theta_A_bounds, theta_A_plausible_bounds)
    log_p_t_A_aff = _log_trapezoid(t_A_aff, t_A_aff_bounds, t_A_aff_plausible_bounds)

    return log_p_V_A + log_p_theta_A + log_p_t_A_aff

## prior + loglike

In [None]:
def vbmc_joint_fn(params):
    """Unnormalised log-posterior handed to VBMC: log-prior plus log-likelihood."""
    return vbmc_prior_fn(params) + vbmc_loglike_fn(params)

# run vbmc

In [None]:
# Hard (lb/ub) and plausible (plb/pub) bounds for the three fitted parameters (order: V_A, theta_A, t_A_aff)
lb = np.array([
    V_A_bounds[0],
    theta_A_bounds[0],
    t_A_aff_bounds[0],
])

ub = np.array([
    V_A_bounds[1],
    theta_A_bounds[1],
    t_A_aff_bounds[1]
])

plb = np.array([
    V_A_plausible_bounds[0],
    theta_A_plausible_bounds[0],
    t_A_aff_plausible_bounds[0]
])

pub = np.array([
    V_A_plausible_bounds[1],
    theta_A_plausible_bounds[1],
    t_A_aff_plausible_bounds[1]
])

# Initialize with random values within plausible bounds
np.random.seed(42)
# V_A_0 = np.random.uniform(*V_A_plausible_bounds)
# theta_A_0 = np.random.uniform(*theta_A_plausible_bounds)
# t_A_aff_0 = np.random.uniform(*t_A_aff_plausible_bounds)

V_A_0 = 2.8
theta_A_0 = 3.2
t_A_aff_0 = -0.22

x_0 = np.array([
    V_A_0,
    theta_A_0,
    t_A_aff_0
])

# Run VBMC
vbmc = VBMC(vbmc_joint_fn, x_0, lb, ub, plb, pub, options={'display': 'on'})
vp, results = vbmc.optimize()

In [None]:

# vbmc.save('ONLY_norm_vbmc_fit.pkl')

In [None]:
# Sample from the VBMC posterior (returns tuple: samples, log weights)
vp_samples = vp.sample(int(1e5))[0]

# No unit conversion is applied here: this fit has only V_A, theta_A and t_A_aff (there is no T_0 column)

# Parameter labels (order matters!)
param_labels = [
    r'$V_A$',           # 0
    r'$\theta_A$',      # 1
    r'$t_A^{aff}$',     # 2
]

# Compute 1st and 99th percentiles for each param to restrict range
percentiles = np.percentile(vp_samples, [1, 99], axis=0)
_ranges = [(percentiles[0, i], percentiles[1, i]) for i in range(vp_samples.shape[1])]

# Create the corner plot
fig = corner.corner(
    vp_samples,
    labels=param_labels,
    show_titles=True,
    quantiles=[0.025, 0.5, 0.975],
    range=_ranges,
    title_fmt=".3f"
)

plt.show()

# compare with data

In [None]:
V_A = vp_samples[:,0].mean()
theta_A = vp_samples[:,1].mean()
t_A_aff = vp_samples[:,2].mean()


print(f'V_A: {V_A}')
print(f'theta_A: {theta_A}')
print(f't_A_aff: {t_A_aff}')

In [None]:


# Compare the fitted abort-time density with the empirical abort histogram.
# Theory: for N_theory sampled stimulus times, evaluate the truncated density on
# t_pts between proactive_trunc_time and that trial's t_stim, then average over trials.
N_theory = int(1e3)
t_pts = np.arange(0,1.25, 0.001)
# NOTE(review): no random_state is passed, so the sampled t_stim set (and the
# theory curve) changes slightly between runs.
t_stim_samples = df_valid_and_aborts.sample(N_theory)['intended_fix']
pdf_samples = np.zeros((N_theory, len(t_pts)))

for i, t_stim in enumerate(t_stim_samples):
    t_stim_idx = np.searchsorted(t_pts, t_stim)
    # Loop-invariant: depends only on t_pts and proactive_trunc_time.
    proactive_trunc_idx = np.searchsorted(t_pts, proactive_trunc_time)
    pdf_samples[i, :proactive_trunc_idx] = 0
    pdf_samples[i, t_stim_idx:] = 0
    # NOTE(review): the slice stops at t_stim_idx-1, so the last grid point
    # before t_stim stays at 0; a t_stim_idx of 0 would turn the end index into
    # -1. Confirm whether `t_stim_idx` was intended.
    t_btn = t_pts[proactive_trunc_idx:t_stim_idx-1]
    
    # Density at the afferent-shifted times, normalised by the probability of
    # no abort before proactive_trunc_time (same truncation as in the likelihood).
    pdf_samples[i, proactive_trunc_idx:t_stim_idx-1] = rho_A_t_VEC_fn(t_btn - t_A_aff, V_A, theta_A) / (1 - cum_A_t_fn(proactive_trunc_time - t_A_aff, V_A, theta_A))



plt.figure(figsize=(10,5))

# Data: abort times later than the truncation time.
# Note that `bins` and `df_aborts` are (re)assigned at notebook scope here.
bins = np.arange(0,2,0.02)
df_aborts = df_valid_and_aborts[df_valid_and_aborts['abort_event'] == 3]
df_aborts_RT = df_aborts['TotalFixTime'].dropna().values
df_aborts_RT_trunc = df_aborts_RT[df_aborts_RT > proactive_trunc_time]

# Scale the unit-area histogram by the fraction of all trials that are
# (truncated) aborts so it is comparable with the trial-averaged theory curve.
frac_aborts = len(df_aborts_RT_trunc) / len(df_valid_and_aborts)
aborts_hist, _ = np.histogram(df_aborts_RT_trunc, bins=bins, density=True)

# NOTE(review): plotted against the left bin edges (bins[:-1]), not bin centres.
plt.plot(bins[:-1], aborts_hist * frac_aborts, label='data')
plt.xlabel('abort rt')
plt.ylabel('density')

plt.plot(t_pts, np.mean(pdf_samples, axis=0), label='theory')
plt.legend()

In [None]:
# vbmc.save('non_linear_only_norm_decent_non_convg.pkl')

# Simulation diagnostics

## params

In [None]:
# Posterior means of six parameters read from columns 0-5 of vp_samples.
# NOTE(review): the VBMC run in this notebook is started from a 3-element x_0
# (V_A, theta_A, t_A_aff), so vp_samples presumably has only 3 columns:
# columns 0-2 would then be mislabelled here and columns 3-5 would raise
# IndexError. This cell looks like a leftover from another notebook — TODO confirm.
rate_lambda = vp_samples[:, 0].mean()
T_0 = vp_samples[:, 1].mean()
theta_E = vp_samples[:, 2].mean()
w = vp_samples[:, 3].mean()
# Z_E is derived from w and theta_E: w = 0.5 gives 0, w = 1 gives +theta_E.
Z_E = (w - 0.5) * 2 * theta_E
t_E_aff = vp_samples[:, 4].mean()
# Overwrites the notebook-level rate_norm_l that was set to np.nan earlier.
rate_norm_l = vp_samples[:, 5].mean()

# Print them out
print("Posterior Means:")
print(f"rate_lambda  = {rate_lambda:.5f}")
print(f"T_0 (ms)      = {1e3*T_0:.5f}")
print(f"theta_E       = {theta_E:.5f}")
print(f"Z_E           = {Z_E:.5f}")
print(f"t_E_aff       = {1e3*t_E_aff:.5f} ms")
print(f"rate_norm_l   = {rate_norm_l:.5f}")


## simulate

In [None]:
# sample t-stim
# Draw N_sim stimulus times, ABLs and ILDs (each column resampled independently,
# with replacement) and run one simulated trial per draw across 30 joblib workers.
# NOTE(review): df_led_off is only defined in a later cell, and del_go is not
# defined anywhere in the visible notebook, so this cell fails on a fresh
# kernel run top to bottom — TODO confirm where these come from.
N_sim = int(1e6)

t_stim_samples = df_led_off['intended_fix'].sample(N_sim, replace=True).values
ABL_samples = df_led_off['ABL'].sample(N_sim, replace=True).values
ILD_samples = df_led_off['ILD'].sample(N_sim, replace=True).values

N_print = int(N_sim / 5)
dt  = 1e-4

sim_results = Parallel(n_jobs=30)(
    delayed(psiam_tied_data_gen_wrapper_rate_norm_fn)(
        V_A, theta_A, ABL_samples[iter_num], ILD_samples[iter_num], rate_lambda, T_0, theta_E, Z_E, t_A_aff, t_E_aff, del_go, 
        t_stim_samples[iter_num], rate_norm_l, iter_num, N_print, dt
    ) for iter_num in tqdm(range(N_sim))
)

# ******* TEMP add back ILD 16 *****************

In [None]:
# LED off rows
# NOTE(review): `df` is not defined anywhere in the visible notebook (the loaded
# frames are exp_df / exp_df_batch / df_valid_and_aborts) — TODO confirm the
# intended source frame.
df_led_off = df[ df['LED_trial'] == 0 ]
print(f'len df_led_off = {len(df_led_off)}')

# > 0 and < 1s valid rt 
# RT relative to stimulus is timed_fix - intended_fix; keep the open interval (0, 1).
df_led_off_valid = df_led_off[
    (df_led_off['timed_fix'] - df_led_off['intended_fix'] > 0) &
    (df_led_off['timed_fix'] - df_led_off['intended_fix'] < 1)
]

# Keep only trials whose response_poke is 2 or 3.
df_led_off_valid = df_led_off_valid[df_led_off_valid['response_poke'].isin([2,3])]

## prepare valid sim df, data df

In [None]:
# Simulated trials: keep responses made after stimulus onset and less than 1 s after it.
sim_results_df = pd.DataFrame(sim_results)
responded_after_stim = sim_results_df['rt'] > sim_results_df['t_stim']
responded_within_1s = (sim_results_df['rt'] - sim_results_df['t_stim']) < 1
sim_results_df_valid = sim_results_df[responded_after_stim & responded_within_1s].copy()

# 'correct' is 1 when ILD and choice have the same sign, else 0 (ILD == 0 scores 0).
ild_times_choice = sim_results_df_valid['ILD'] * sim_results_df_valid['choice']
sim_results_df_valid.loc[:, 'correct'] = ild_times_choice.apply(lambda x: 1 if x > 0 else 0)

# Give the data frame the same column names as the simulation output.
df_led_off_valid_renamed = df_led_off_valid.rename(
    columns={'timed_fix': 'rt', 'intended_fix': 't_stim'}
).copy()

sim_df_1 = sim_results_df_valid.copy()
data_df_1 = df_led_off_valid_renamed.copy()

## plots

In [None]:
# Per-condition (ILD x ABL) reaction-time distributions, data vs simulation.
# choice == 1 responses are drawn upwards and choice == -1 responses are
# mirrored below zero; each curve is scaled by the fraction of that condition's
# trials that made that choice.
bw = 0.02
bins = np.arange(0, 1, bw)
bin_centers = bins[:-1] + (0.5 * bw)

n_rows = len(ILD_arr)
n_cols = len(ABL_arr)


def _choice_rt_density(df_choice, n_condition):
    """RT density (rt - t_stim) of one choice, scaled by its share of the condition.

    Returns zeros when the condition or the choice subset is empty, or when no
    RT falls inside ``bins``, instead of dividing by zero or producing NaNs.
    """
    empty_hist = np.zeros(len(bins) - 1)
    if n_condition == 0 or len(df_choice) == 0:
        return empty_hist
    counts, _ = np.histogram(df_choice['rt'] - df_choice['t_stim'], bins=bins)
    n_in_range = counts.sum()
    if n_in_range == 0:
        return empty_hist
    # Same normalisation as np.histogram(..., density=True).
    density = counts / (n_in_range * np.diff(bins))
    return density * (len(df_choice) / n_condition)


# squeeze=False keeps axs two-dimensional even with a single row or column,
# so axs[i_idx, a_idx] is always valid.
fig, axs = plt.subplots(
    n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows), sharey='row', squeeze=False
)

for i_idx, ILD in enumerate(ILD_arr):
    for a_idx, ABL in enumerate(ABL_arr):
        ax = axs[i_idx, a_idx]

        sim_df_1_ABL_ILD = sim_df_1[(sim_df_1['ABL'] == ABL) & (sim_df_1['ILD'] == ILD)]
        data_df_1_ABL_ILD = data_df_1[(data_df_1['ABL'] == ABL) & (data_df_1['ILD'] == ILD)]

        sim_up = sim_df_1_ABL_ILD[sim_df_1_ABL_ILD['choice'] == 1]
        sim_down = sim_df_1_ABL_ILD[sim_df_1_ABL_ILD['choice'] == -1]
        data_up = data_df_1_ABL_ILD[data_df_1_ABL_ILD['choice'] == 1]
        data_down = data_df_1_ABL_ILD[data_df_1_ABL_ILD['choice'] == -1]

        data_total = len(data_df_1_ABL_ILD)
        sim_total = len(sim_df_1_ABL_ILD)

        # Histograms already normalised by the proportion of trials per choice.
        sim_up_hist = _choice_rt_density(sim_up, sim_total)
        sim_down_hist = _choice_rt_density(sim_down, sim_total)
        data_up_hist = _choice_rt_density(data_up, data_total)
        data_down_hist = _choice_rt_density(data_down, data_total)

        # Plot
        ax.plot(bin_centers, data_up_hist, color='b', label='Data' if (i_idx == 0 and a_idx == 0) else "")
        ax.plot(bin_centers, -data_down_hist, color='b')
        ax.plot(bin_centers, sim_up_hist, color='r', label='Sim' if (i_idx == 0 and a_idx == 0) else "")
        ax.plot(bin_centers, -sim_down_hist, color='r')

        # Compute fractions
        data_up_frac = len(data_up) / data_total if data_total else 0
        data_down_frac = len(data_down) / data_total if data_total else 0
        sim_up_frac = len(sim_up) / sim_total if sim_total else 0
        sim_down_frac = len(sim_down) / sim_total if sim_total else 0

        ax.set_title(
            f"ABL: {ABL}, ILD: {ILD}\n"
            f"Data,Sim: (+{data_up_frac:.2f},+{sim_up_frac:.2f}), "
            f"(-{data_down_frac:.2f},-{sim_down_frac:.2f})"
        )

        ax.axhline(0, color='k', linewidth=0.5)
        ax.set_xlim([0, 0.7])
        if a_idx == 0:
            ax.set_ylabel("Density (Up / Down flipped)")
        if i_idx == n_rows - 1:
            ax.set_xlabel("RT (s)")

fig.tight_layout()
fig.legend(loc='upper right')
plt.show()


In [None]:
# RT histogram grid used by the tachometric and grand-RTD plots:
# edges every `bw` from 0 up to (but excluding) 1, plus the matching centres.
bw = 0.02
bins = np.arange(0, 1, bw)
bin_centers = bins[:-1] + (0.5 * bw)

def plot_tacho(df_1):
    """Return the tachometric curve (mean `correct` per RT bin) of `df_1`.

    RT is taken as ``rt - t_stim`` and binned with the module-level `bins`.
    All work happens on a copy, so the caller's frame is not modified.

    Returns
    -------
    (bin_mid, mean) : two Series indexed by RT bin, holding the midpoint of
    every bin and the mean of `correct` inside it (NaN for bins with no rows).
    """
    binned = df_1.copy()
    rt_wrt_stim = binned['rt'] - binned['t_stim']
    binned['RT_bin'] = pd.cut(rt_wrt_stim, bins=bins, include_lowest=True)
    # observed=False keeps empty bins so the curve always spans the full grid
    per_bin = binned.groupby('RT_bin', observed=False)['correct'].agg(['mean', 'count'])
    per_bin['bin_mid'] = per_bin.index.map(lambda interval: interval.mid)
    return per_bin['bin_mid'], per_bin['mean']

# Panel grid: one row per ILD, one column per ABL.
n_rows = len(ILD_arr)
n_cols = len(ABL_arr)

# === Define fig2 ===
# squeeze=False keeps `axs` two-dimensional for every grid shape, so
# axs[row, col] stays valid even with a single ILD row or a single ABL column
# (the default squeezes such grids to a 1-D array or a bare Axes).
fig2, axs = plt.subplots(
    n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows), sharey='row', squeeze=False
)

for i_idx, ILD in enumerate(ILD_arr):
    for a_idx, ABL in enumerate(ABL_arr):
        ax = axs[i_idx, a_idx]

        # Trials of this (ABL, ILD) condition, for simulation and data
        sim_df_1_ABL_ILD = sim_df_1[(sim_df_1['ABL'] == ABL) & (sim_df_1['ILD'] == ILD)]
        data_df_1_ABL_ILD = data_df_1[(data_df_1['ABL'] == ABL) & (data_df_1['ILD'] == ILD)]

        sim_tacho_x, sim_tacho_y = plot_tacho(sim_df_1_ABL_ILD)
        data_tacho_x, data_tacho_y = plot_tacho(data_df_1_ABL_ILD)

        # Plotting: only the first panel carries labels, so the figure-level
        # legend below shows each entry exactly once.
        ax.plot(data_tacho_x, data_tacho_y, color='b', label='Data' if (i_idx == 0 and a_idx == 0) else "")
        ax.plot(sim_tacho_x, sim_tacho_y, color='r', label='Sim' if (i_idx == 0 and a_idx == 0) else "")

        ax.set_ylim([0.5, 1.05])
        ax.set_xlim([0, 0.7])
        ax.set_title(f"ABL: {ABL}, ILD: {ILD}")
        # Axis labels only on the left column and the bottom row
        if a_idx == 0:
            ax.set_ylabel("P(correct)")
        if i_idx == n_rows - 1:
            ax.set_xlabel("RT (s)")

fig2.tight_layout()
fig2.legend(loc='upper right')
plt.show()


# Grand RTDs and tacho

In [None]:
def grand_rtd(df_1):
    """Density histogram of RT relative to stimulus onset (``rt - t_stim``).

    The module-level `bins` are used as edges. Only the per-bin densities are
    returned; the edges that np.histogram also returns are discarded.
    """
    densities = np.histogram(df_1['rt'] - df_1['t_stim'], bins=bins, density=True)[0]
    return densities

def plot_psycho(df_1):
    """Compute the psychometric function P(choice == 1) for every ABL.

    Parameters
    ----------
    df_1 : DataFrame with columns 'ABL', 'ILD' and 'choice'.

    Returns
    -------
    dict mapping each ABL (ascending) to a list with one entry per ILD
    (ascending, over all ILD values present in `df_1`): the fraction of trials
    with ``choice == 1``. An (ABL, ILD) combination without trials gives NaN.
    """
    all_ABL = np.sort(df_1['ABL'].unique())
    all_ILD = np.sort(df_1['ILD'].unique())

    prob_choice_dict = {}
    for abl in all_ABL:
        filtered_df = df_1[df_1['ABL'] == abl]
        # Mean of the boolean mask: vectorised (no Python-level sum() over a
        # Series, filter evaluated once) and NaN instead of a
        # ZeroDivisionError when this ABL has no trials at some ILD.
        prob_choice_dict[abl] = [
            (filtered_df.loc[filtered_df['ILD'] == ild, 'choice'] == 1).mean()
            for ild in all_ILD
        ]

    return prob_choice_dict

# === Define fig3 ===
# One row of three summary panels: RT density, psychometric, tachometric.
fig3, axes = plt.subplots(1, 3, figsize=(15, 4))
rtd_ax, psycho_ax, tacho_ax = axes

# === Grand RTD ===
# Data in blue, simulation in red, both on the shared bin grid.
rtd_ax.plot(bin_centers, grand_rtd(data_df_1), color='b', label='data')
rtd_ax.plot(bin_centers, grand_rtd(sim_df_1), color='r', label='sim')
rtd_ax.set(xlabel='rt wrt stim', ylabel='density', title='Grand RTD')
rtd_ax.legend()

# === Grand Psychometric ===
data_psycho = plot_psycho(data_df_1)
sim_psycho = plot_psycho(sim_df_1)

colors = ['r', 'b', 'g']  # Adjust colors for your ABLs
for i, ABL in enumerate(ABL_arr):
    # Data as unconnected markers (labelled), simulation as a line (unlabelled)
    psycho_ax.plot(
        ILD_arr, data_psycho[ABL],
        marker='o', linestyle='None', color=colors[i], label=f'data ABL={ABL}',
    )
    psycho_ax.plot(ILD_arr, sim_psycho[ABL], linestyle='-', color=colors[i])

psycho_ax.set(xlabel='ILD', ylabel='P(right)', title='Grand Psychometric')
psycho_ax.legend()

# === Grand Tacho ===
data_tacho_x, data_tacho_y = plot_tacho(data_df_1)
sim_tacho_x, sim_tacho_y = plot_tacho(sim_df_1)

tacho_ax.plot(data_tacho_x, data_tacho_y, color='b', label='data')
tacho_ax.plot(sim_tacho_x, sim_tacho_y, color='r', label='sim')
tacho_ax.set(xlabel='rt wrt stim', ylabel='acc', title='Grand Tacho', ylim=(0.5, 1))
tacho_ax.legend()

fig3.tight_layout()
plt.show()


# all in a single PDF?

In [None]:
import os
from matplotlib.backends.backend_pdf import PdfPages
from docx import Document
from docx.shared import Inches

# Common filename stem shared by every exported file
output_filename = 'no_ILD_16_V4_NON_LINEAR_ONLY_Norm_report'

# Everything is written below ./outputs, created on demand
os.makedirs('outputs', exist_ok=True)

# === PNG export: one image per figure ===
# NOTE(review): `fig` is whichever figure that name is bound to when this cell
# runs; the "_updown_hist" suffix suggests the up/down RT histogram grid — confirm.
fig1_path = f'outputs/{output_filename}_updown_hist.png'
fig2_path = f'outputs/{output_filename}_tacho.png'
fig3_path = f'outputs/{output_filename}_grand_summary.png'

for fig_item, img_path in zip([fig, fig2, fig3], [fig1_path, fig2_path, fig3_path]):
    fig_item.savefig(img_path)

# === PDF export: the same three figures, one savefig call each ===
pdf_path = f'outputs/{output_filename}.pdf'
with PdfPages(pdf_path) as pdf:
    for fig_item in [fig, fig2, fig3]:
        pdf.savefig(fig_item)

# === DOCX export: heading, then each PNG preceded by a page break ===
docx_path = f'outputs/{output_filename}.docx'
doc = Document()
doc.add_heading('RTD and Tacho Analysis Results', 0)

for img_path in [fig1_path, fig2_path, fig3_path]:
    doc.add_page_break()
    doc.add_picture(img_path, width=Inches(6.5))

doc.save(docx_path)

print(f"✅ Saved PDF to: {pdf_path}")
print(f"✅ Saved DOCX to: {docx_path}")


# up and down RTD

In [None]:
# Number of randomly drawn trials averaged into the density below
N_theory = int(1e3)

# One (intended_fix, intended_fix - LED_onset_time) pair per row of `df`.
# Built column-wise: DataFrame.iterrows() creates a Series for every row,
# which is needlessly slow just to pair two columns.
t_stim_and_led_tuple = list(zip(df['intended_fix'], df['intended_fix'] - df['LED_onset_time']))
random_indices = np.random.randint(0, len(t_stim_and_led_tuple), N_theory)
t_pts = np.arange(-1, 2, 0.001)

# NOTE(review): the import of rho_A_t_VEC_fn is commented out in the first cell
# of this notebook — confirm it is imported before this cell runs.
P_A_samples = np.zeros((N_theory, len(t_pts)))
for idx in range(N_theory):
    # t_LED is unpacked but not used in this cell. `t_stim` stays bound to the
    # last drawn value after the loop, and later plotting cells pass `t_stim`
    # to up_or_down_RTs_fit_wrt_stim_fn without reassigning it.
    # NOTE(review): confirm that using this single leftover value is intended.
    t_stim, t_LED = t_stim_and_led_tuple[random_indices[idx]]
    # rho_A_t_VEC_fn evaluated on the time grid shifted by (t_stim - t_A_aff)
    P_A_samples[idx, :] = rho_A_t_VEC_fn(t_pts - t_A_aff + t_stim, V_A, theta_A)

# Average over the sampled trials, then integrate over t_pts (starting at 0)
P_A_samples_mean = np.mean(P_A_samples, axis=0)
C_A_mean = cumtrapz(P_A_samples_mean, t_pts, initial=0)

In [None]:
# Panel grid: one row per ILD and one column per ABL, sized from the arrays
# the plotting loop below iterates over (instead of a hard-coded 10 x 3).
n_rows = len(ILD_arr)
n_cols = len(ABL_arr)

# Create figure and axes row by row to enable row-wise shared Y axes
fig = plt.figure(figsize=(18, 24))
axes = []

for i in range(n_rows):
    row_axes = []
    for j in range(n_cols):
        # Every panel after the first in a row shares that first panel's Y axis
        ax = fig.add_subplot(n_rows, n_cols, i * n_cols + j + 1, sharey=row_axes[0] if row_axes else None)
        row_axes.append(ax)
    axes.append(row_axes)

fig.subplots_adjust(hspace=0.8, wspace=0.4)

bin_size = 0.02
bins = np.arange(-1, 2, bin_size)
bin_centers = bins[:-1] + (bin_size / 2)
# Has to be the grid P_A_samples_mean / C_A_mean were computed on: both are
# indexed by position in t_pts inside the theory loop.
t_pts = np.arange(-1, 2, 0.001)

phi_params_obj = np.nan

# data: LED-off trials that are either aborts with abort_event == 3 or carry a
# response poke of 2/3. Selected once, on an explicit copy, so the assignment
# below does not write into a slice of df_led_off and is not repeated (and
# re-randomised) for every panel.
df_led_off_abort_and_valid = df_led_off[
    (df_led_off['abort_event'] == 3) | (df_led_off['response_poke'].isin([2, 3]))
].copy()
mask_invalid = ~df_led_off_abort_and_valid['response_poke'].isin([2, 3])
# Rows without a 2/3 poke get one drawn at random from {2, 3} (50% each)
df_led_off_abort_and_valid.loc[mask_invalid, 'response_poke'] = np.random.choice([2, 3], size=mask_invalid.sum())

for a_idx, ABL in enumerate(ABL_arr):
    for i_idx, ILD in enumerate(ILD_arr):
        ax = axes[i_idx][a_idx]  # axes[row][col] = (ILD, ABL)

        df_ABL_ILD = df_led_off_abort_and_valid[
            (df_led_off_abort_and_valid['ABL'] == ABL) & (df_led_off_abort_and_valid['ILD'] == ILD)]
        df_ABL_ILD_up = df_ABL_ILD[df_ABL_ILD['response_poke'] == 3]
        df_ABL_ILD_down = df_ABL_ILD[df_ABL_ILD['response_poke'] == 2]

        # RT measured from intended_fix
        df_ABL_ILD_up_rt = df_ABL_ILD_up['timed_fix'] - df_ABL_ILD_up['intended_fix']
        df_ABL_ILD_down_rt = df_ABL_ILD_down['timed_fix'] - df_ABL_ILD_down['intended_fix']

        data_up_rt_hist, _ = np.histogram(df_ABL_ILD_up_rt, bins=bins, density=True)
        data_down_rt_hist, _ = np.histogram(df_ABL_ILD_down_rt, bins=bins, density=True)

        # Fraction of up / down responses; guarded so a condition without any
        # trial does not raise ZeroDivisionError.
        n_trials = len(df_ABL_ILD)
        data_frac_up = len(df_ABL_ILD_up) / n_trials if n_trials else 0
        data_frac_down = len(df_ABL_ILD_down) / n_trials if n_trials else 0

        # theory
        theory_ABL_ILD_up = np.zeros_like(t_pts)
        theory_ABL_ILD_down = np.zeros_like(t_pts)

        # NOTE(review): `t_stim` is not assigned anywhere in this cell; it is
        # whatever value an earlier cell left bound — confirm this is intended.
        for idx, t in enumerate(t_pts):
            P_A = P_A_samples_mean[idx]
            C_A = C_A_mean[idx]
            theory_ABL_ILD_up[idx] = up_or_down_RTs_fit_wrt_stim_fn(
                t, 1,
                P_A, C_A,
                t_stim, ABL, ILD, rate_lambda, T_0, theta_E, Z_E, t_E_aff, del_go,
                phi_params_obj, rate_norm_l,
                is_norm, is_time_vary, K_max)

            theory_ABL_ILD_down[idx] = up_or_down_RTs_fit_wrt_stim_fn(
                t, -1,
                P_A, C_A,
                t_stim, ABL, ILD, rate_lambda, T_0, theta_E, Z_E, t_E_aff, del_go,
                phi_params_obj, rate_norm_l,
                is_norm, is_time_vary, K_max)

        # Up responses are drawn upwards, down responses mirrored below zero.
        # Only the first panel is labelled so the legend lists each curve once.
        ax.plot(bin_centers, data_up_rt_hist * data_frac_up, 'b--', label='Data Up' if i_idx == 0 and a_idx == 0 else "")
        ax.plot(bin_centers, -data_down_rt_hist * data_frac_down, 'b--', label='Data Down' if i_idx == 0 and a_idx == 0 else "")
        ax.plot(t_pts, theory_ABL_ILD_up, 'r-', label='Theory Up' if i_idx == 0 and a_idx == 0 else "")
        ax.plot(t_pts, -theory_ABL_ILD_down, 'r-', label='Theory Down' if i_idx == 0 and a_idx == 0 else "")

        ax.set_title(f'ABL={ABL}, ILD={ILD}', fontsize=9)
        ax.axhline(0, color='black', linewidth=0.5)
        ax.set_xlim(-0.2, 0.7)

# Add single legend outside the plot
handles, labels = axes[0][0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper right', fontsize=10)

plt.tight_layout(rect=[0, 0, 0.98, 1])
plt.show()


# tachometric

In [None]:
import matplotlib.pyplot as plt

# Prepare the figure and axes: one row per ILD, one column per ABL, sized from
# the arrays the loop below iterates over (instead of a hard-coded 10 x 3).
# squeeze=False keeps `axes` two-dimensional for every grid shape.
n_rows = len(ILD_arr)
n_cols = len(ABL_arr)
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(18, 24), sharex=True, sharey=True, squeeze=False)
fig.subplots_adjust(hspace=0.5, wspace=0.3)
bin_size = 0.02
bins = np.arange(-1, 2, bin_size)
bin_centers = bins[:-1] + (bin_size / 2)
# Has to be the grid P_A_samples_mean / C_A_mean were computed on: both are
# indexed by position in t_pts inside the theory loop.
t_pts = np.arange(-1, 2, 0.001)

phi_params_obj = np.nan

# data: LED-off trials that are either aborts with abort_event == 3 or carry a
# response poke of 2/3. Selected once, on an explicit copy, so the assignment
# below does not write into a slice of df_led_off and is not repeated (and
# re-randomised) for every panel.
df_led_off_abort_and_valid = df_led_off[
    (df_led_off['abort_event'] == 3) | (df_led_off['response_poke'].isin([2, 3]))
].copy()
mask_invalid = ~df_led_off_abort_and_valid['response_poke'].isin([2, 3])
# Rows without a 2/3 poke get one drawn at random from {2, 3}
df_led_off_abort_and_valid.loc[mask_invalid, 'response_poke'] = np.random.choice([2, 3], size=mask_invalid.sum())

for a_idx, ABL in enumerate(ABL_arr):
    for i_idx, ILD in enumerate(ILD_arr):
        ax = axes[i_idx, a_idx]  # (row=ILD, col=ABL)
        is_first_panel = (i_idx == 0 and a_idx == 0)

        df_ABL_ILD = df_led_off_abort_and_valid[
            (df_led_off_abort_and_valid['ABL'] == ABL) & (df_led_off_abort_and_valid['ILD'] == ILD)].copy()

        df_ABL_ILD.loc[:, 'RT'] = df_ABL_ILD['timed_fix'] - df_ABL_ILD['intended_fix']
        # (2 * poke - 5) maps poke 2 -> -1 and poke 3 -> +1; a trial counts as
        # correct when that sign matches the sign of ILD.
        df_ABL_ILD.loc[:, 'is_correct'] = (
                    df_ABL_ILD['ILD'] * (2 * df_ABL_ILD['response_poke'] - 5)
                ) > 0
        df_ABL_ILD.loc[:, 'rt_bin'] = pd.cut(
                df_ABL_ILD['RT'], bins=bins, include_lowest=True
            )
        # observed=False keeps empty bins, so the curve lines up with bin_centers
        tachometric_curve = df_ABL_ILD.groupby('rt_bin', observed=False)['is_correct'].mean()

        # theory
        theory_ABL_ILD_up = np.zeros_like(t_pts)
        theory_ABL_ILD_down = np.zeros_like(t_pts)
        # NOTE(review): `t_stim` is not assigned anywhere in this cell; it is
        # whatever value an earlier cell left bound — confirm this is intended.
        for idx, t in enumerate(t_pts):
            P_A = P_A_samples_mean[idx]
            C_A = C_A_mean[idx]
            theory_ABL_ILD_up[idx] = up_or_down_RTs_fit_wrt_stim_fn(
                t, 1,
                P_A, C_A,
                t_stim, ABL, ILD, rate_lambda, T_0, theta_E, Z_E, t_E_aff, del_go,
                phi_params_obj, rate_norm_l,
                is_norm, is_time_vary, K_max)

            theory_ABL_ILD_down[idx] = up_or_down_RTs_fit_wrt_stim_fn(
                t, -1,
                P_A, C_A,
                t_stim, ABL, ILD, rate_lambda, T_0, theta_E, Z_E, t_E_aff, del_go,
                phi_params_obj, rate_norm_l,
                is_norm, is_time_vary, K_max)

        # Accuracy = density of the correct choice over the total density.
        # "Up" is treated as correct for ILD > 0, "down" otherwise; the 1e-10
        # avoids 0/0 where both densities vanish.
        theory_total = theory_ABL_ILD_up + theory_ABL_ILD_down + 1e-10
        theory_correct = theory_ABL_ILD_up if ILD > 0 else theory_ABL_ILD_down
        theory_tacho = theory_correct / theory_total

        # Label the curves on the first panel only: without any label the
        # figure-level legend below would be empty.
        ax.plot(bin_centers, tachometric_curve, 'b--', label='Data' if is_first_panel else "")
        ax.plot(t_pts, theory_tacho, 'r-', label='Theory' if is_first_panel else "")
        ax.set_title(f'ABL={ABL}, ILD={ILD}', fontsize=9)
        ax.axhline(0, color='black', linewidth=0.5)
        ax.set_xlim(0, 0.7)
        ax.set_ylim(0.5, 1.05)

        # Add shared axis labels: axes are shared, so label only the left
        # column and the bottom row.
        if a_idx == 0:
            ax.set_ylabel("P(correct)")
        if i_idx == n_rows - 1:
            ax.set_xlabel("RT (s)")

# Add a single legend outside the plot
handles, labels = axes[0, 0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper right', fontsize=10)

plt.tight_layout(rect=[0, 0, 0.98, 1])  # leave space for legend
plt.show()


# grand rtd, psycho, tacho

In [None]:
# Draw the conditions of the simulated trials by resampling columns of `df`
# with replacement. Each column is resampled on its own, in the order
# intended_fix -> ABL -> ILD.
N_sim = int(1e6)

t_stim_samples, ABL_samples, ILD_samples = (
    df[column].sample(n=N_sim, replace=True).values
    for column in ('intended_fix', 'ABL', 'ILD')
)

dt = 1e-4
N_print = N_sim // 5  # one fifth of the number of simulated trials

In [None]:
# Run the simulator once per sampled trial through joblib (n_jobs=30).
# tqdm wraps the range that feeds the task generator.
simulate_one_trial = delayed(psiam_tied_data_gen_wrapper_rate_norm_fn)
trial_ids = tqdm(range(N_sim))

sim_results = Parallel(n_jobs=30)(
    simulate_one_trial(
        V_A, theta_A,
        ABL_samples[iter_num], ILD_samples[iter_num],
        rate_lambda, T_0, theta_E, Z_E, t_A_aff, t_E_aff, del_go,
        t_stim_samples[iter_num], rate_norm_l, iter_num, N_print, dt,
    )
    for iter_num in trial_ids
)

In [None]:
# sim df: keep trials with rt after t_stim, then those with rt - t_stim < 1
sim_results_df = pd.DataFrame(sim_results)
responded_after_stim = sim_results_df['rt'] > sim_results_df['t_stim']
sim_results_df_valid = sim_results_df[responded_after_stim]
sim_rt_wrt_stim = sim_results_df_valid['rt'] - sim_results_df_valid['t_stim']
sim_results_df_valid_less_than_1 = sim_results_df_valid[sim_rt_wrt_stim < 1].copy()

# data df: add a signed choice (poke 2 -> -1, poke 3 -> +1) in place, then
# rename the timing columns to the names used by the simulation frame
df_led_off_valid.loc[:, 'choice'] = 2 * df_led_off_valid['response_poke'] - 5
column_renames = {
    'timed_fix': 'rt',
    'intended_fix': 't_stim'
}
df_led_off_valid_renamed = df_led_off_valid.rename(columns=column_renames).copy()

# add corr cols in both: 1 when choice and ILD have the same sign, else 0
for frame in (df_led_off_valid_renamed, sim_results_df_valid_less_than_1):
    frame.loc[:, 'correct'] = (frame['choice'] * frame['ILD'] > 0).astype(int)

In [None]:
# RT histogram grid for the grand summaries: edges every `bw` from 0 up to
# (but excluding) 1, plus the matching bin centres.
bw = 0.02
bins = np.arange(0, 1, bw)
bin_centers = bins[:-1] + 0.5*bw
def grand_rtd(df_1):
    """Density histogram of RT relative to stimulus onset (``rt - t_stim``).

    The module-level `bins` are used as edges; only the per-bin densities are
    returned.
    """
    rt_wrt_stim = df_1['rt'] - df_1['t_stim']
    return np.histogram(rt_wrt_stim, bins=bins, density=True)[0]

def plot_psycho(df_1):
    """Compute the psychometric function P(choice == 1) for every ABL.

    Parameters
    ----------
    df_1 : DataFrame with columns 'ABL', 'ILD' and 'choice'.

    Returns
    -------
    dict mapping each ABL (ascending) to a list with one entry per ILD
    (ascending, over all ILD values present in `df_1`): the fraction of trials
    with ``choice == 1``. An (ABL, ILD) combination without trials gives NaN.
    """
    all_ABL = np.sort(df_1['ABL'].unique())
    all_ILD = np.sort(df_1['ILD'].unique())

    prob_choice_dict = {}
    for abl in all_ABL:
        filtered_df = df_1[df_1['ABL'] == abl]
        # Mean of the boolean mask: vectorised (no Python-level sum() over a
        # Series, filter evaluated once) and NaN instead of a
        # ZeroDivisionError when this ABL has no trials at some ILD.
        prob_choice_dict[abl] = [
            (filtered_df.loc[filtered_df['ILD'] == ild, 'choice'] == 1).mean()
            for ild in all_ILD
        ]

    return prob_choice_dict

def plot_tacho(df_1, rt_bins=None):
    """Return the tachometric curve: probability of a correct choice vs RT.

    Parameters
    ----------
    df_1 : DataFrame with columns 'rt', 't_stim' and 'correct'. It is not
        modified (no 'RT_bin' column is added to the caller's frame).
    rt_bins : optional bin edges for ``rt - t_stim``. Defaults to the
        module-level `bins`, which other cells of this notebook rebind.

    Returns
    -------
    (bin_mid, mean) : two Series indexed by RT bin, holding the midpoint of
    every bin and the mean of `correct` inside it (NaN for bins with no rows).
    """
    edges = bins if rt_bins is None else rt_bins
    # Bin labels live in a separate Series instead of a column written into df_1
    rt_bin = pd.cut(df_1['rt'] - df_1['t_stim'], bins=edges, include_lowest=True).rename('RT_bin')
    # observed=False keeps empty bins so the curve always spans the full grid
    grouped_by_rt_bin = df_1['correct'].groupby(rt_bin, observed=False).agg(['mean', 'count'])
    grouped_by_rt_bin['bin_mid'] = grouped_by_rt_bin.index.map(lambda x: x.mid)
    return grouped_by_rt_bin['bin_mid'], grouped_by_rt_bin['mean']

In [None]:
import matplotlib.pyplot as plt

# One row of three summary panels: RT density, psychometric, tachometric
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
rtd_ax, psycho_ax, tacho_ax = axes

# === grand RTD ===
# Data in blue, simulation in red, both on the shared bin grid
rtd_ax.plot(bin_centers, grand_rtd(df_led_off_valid_renamed), color='b', label='data')
rtd_ax.plot(bin_centers, grand_rtd(sim_results_df_valid_less_than_1), color='r', label='sim')
rtd_ax.set(xlabel='rt wrt stim', ylabel='density', title='Grand RTD')
rtd_ax.legend()

# === grand psycho ===
data_psycho = plot_psycho(df_led_off_valid_renamed)
sim_psycho = plot_psycho(sim_results_df_valid_less_than_1)

colors = ['r', 'b', 'g']  # Define colors for each ABL
for i, ABL in enumerate(ABL_arr):
    # Data as unconnected markers, simulation as a line, same colour per ABL
    psycho_ax.plot(
        ILD_arr, data_psycho[ABL],
        marker='o', linestyle='None', color=colors[i], label=f'data ABL={ABL}',
    )
    psycho_ax.plot(
        ILD_arr, sim_psycho[ABL],
        linestyle='-', color=colors[i], label=f'sim ABL={ABL}',
    )

psycho_ax.set(xlabel='ILD', ylabel='P(right)', title='Grand Psychometric')
psycho_ax.legend()

# === grand tacho ===
data_tacho_x, data_tacho_y = plot_tacho(df_led_off_valid_renamed)
sim_tacho_x, sim_tacho_y = plot_tacho(sim_results_df_valid_less_than_1)

tacho_ax.plot(data_tacho_x, data_tacho_y, color='b', label='data')
tacho_ax.plot(sim_tacho_x, sim_tacho_y, color='r', label='sim')
tacho_ax.legend()
# Trailing semicolon suppresses the cell's output repr
tacho_ax.set(xlabel='rt wrt stim', ylabel='acc', title='Grand Tacho', ylim=(0.5, 1));
