In [1]:
from dynesty import NestedSampler
from dynesty import plotting as dyplot
import numpy as np
import matplotlib.pyplot as plt
import pickle
import multiprocessing
from joblib import Parallel, delayed
from psiam_utils import rho_A_t_fn, rho_E_t_fn, cum_A_t_fn, rho_E_minus_t_fn, P_large_t_btn_1_2
from scipy.integrate import quad
import os

# Read data

In [1]:
with open('psiam_data_5k_1.pkl', 'rb') as f:
    psiam_data = pickle.load(f)


choices = psiam_data['choices']
RTs = psiam_data['RTs']
is_act_resp = psiam_data['is_act_resp']
V_A = psiam_data['V_A']
theta_A = psiam_data['theta_A']
V_E = psiam_data['V_E']
theta_E = psiam_data['theta_E']
Z_E = psiam_data['Z_E']
t_stim = psiam_data['t_stim']


indices_evid = np.where(is_act_resp == 0)[0]
RTs_evid = RTs[indices_evid].flatten()

indices_act = np.where(is_act_resp == 1)[0]
RTs_act = RTs[indices_act].flatten()


RTs = RTs.flatten()

correct_idx = np.where(choices == 1)[0]
wrong_idx = np.where(choices == -1)[0]


correct_RT = RTs[correct_idx]
wrong_RT = RTs[wrong_idx]

abort_idx = np.where(RTs < t_stim)[0]
abort_RT = RTs[abort_idx]

print(f"V_A: {V_A}")
print(f"theta_A: {theta_A}")
print(f"V_E: {V_E}")
print(f"theta_E: {theta_E}")
print(f"Num of AI process: {is_act_resp.sum()}/{len(is_act_resp)}")
print(f"t start is {t_stim}")

NameError: name 'pickle' is not defined

In [3]:
# Bounds used for BADs
V_A_bounds = [0.1, 3]; V_A_plausible_bounds = [0.5, 1.5]
theta_A_bounds = [1, 3]; theta_A_plausible_bounds = [1.5, 2.5]
V_E_bounds = [-5, 5]; V_E_plausible_bounds = [-2, 2]
theta_E_bounds = [0.1, 5]; theta_E_plausible_bounds = [0.5,1.5]
Z_bounds = [-0.5, 0.5]; Z_plausible_bounds = [-0.2, 0.2]

def transform_random_number(u, a, b):
    return (b-a)*u + a 

def psiam_prior_fn(u):
    priors = np.zeros_like(u)
    
    priors[0] = transform_random_number(u[0], V_A_bounds[0], V_A_bounds[1])
    priors[1] = transform_random_number(u[1], theta_A_bounds[0], theta_A_bounds[1])
    priors[2] = transform_random_number(u[2], V_E_bounds[0], V_E_bounds[1])
    priors[3] = transform_random_number(u[3], theta_E_bounds[0], theta_E_bounds[1])
    priors[4] = transform_random_number(u[4], Z_bounds[0], Z_bounds[1])

    return priors

In [4]:
def calculate_abort_loglike(t, V_A, theta_A, t_a, V_E, theta_E, K_max, t_stim, Z, t_E, abort_norm_term):
    P_A = rho_A_t_fn(t, V_A, theta_A, t_a)
    C_E = quad(rho_E_t_fn, 0, t, args=(V_E, theta_E, K_max, t_stim))[0]
    P_E = rho_E_t_fn(t, V_E, theta_E, K_max, t_stim, Z, t_E)
    C_A = cum_A_t_fn(t, V_A, theta_A, t_a)
    p_abort = P_A * (1 - C_E) + P_E * (1 - C_A)
    if p_abort <= 0:
        p_abort = 1e-6
    return np.log(p_abort / abort_norm_term)

def calculate_correct_loglike(t, V_A, theta_A, t_a, V_E, theta_E, Z, K_max, t_stim, t_E, correct_norm_term):
    P_A = rho_A_t_fn(t, V_A, theta_A, t_a)
    P_E_btn_1_2 = P_large_t_btn_1_2(1, 2, t, V_E, theta_E, Z, K_max, t_stim)
    P_E_plus = rho_E_minus_t_fn(t, -V_E, theta_E, K_max, t_stim, -Z, t_E)
    C_A = cum_A_t_fn(t, V_A, theta_A, t_a)
    p_correct = P_A * P_E_btn_1_2 + P_E_plus * (1 - C_A)
    if p_correct <= 0:
        p_correct = 1e-6
    return np.log(p_correct / correct_norm_term)

def calculate_wrong_loglike(t, V_A, theta_A, t_a, V_E, theta_E, Z, K_max, t_stim, t_E, wrong_norm_term):
    P_A = rho_A_t_fn(t, V_A, theta_A, t_a)
    P_E_btn_0_1 = P_large_t_btn_1_2(0, 1, t, V_E, theta_E, Z, K_max, t_stim)
    P_E_minus = rho_E_minus_t_fn(t, V_E, theta_E, K_max, t_stim, Z, t_E)
    C_A = cum_A_t_fn(t, V_A, theta_A, t_a)
    p_wrong = P_A * P_E_btn_0_1 + P_E_minus * (1 - C_A)
    if p_wrong <= 0:
        p_wrong = 1e-6
    return np.log(p_wrong / wrong_norm_term)



def psiam_loglike_fn(params):
    V_A, theta_A, V_E, theta_E, Z = params
    # hyperparams
    t_a = 0; t_E = 0; 
    K_max = 10
    
    # norm terms
    N = len(RTs)
    N_abort = len(abort_RT)
    N_correct = len(correct_RT)
    N_wrong = len(wrong_RT)

    abort_norm_term = N_abort/N
    correct_norm_term = N_correct/N
    wrong_norm_term = N_wrong/N

    # abort_loglike = sum(Parallel(n_jobs=n_jobs)(
    #     delayed(calculate_abort_loglike)(t, V_A, theta_A, t_a, V_E, theta_E, K_max, t_stim, Z, t_E, abort_norm_term)
    #     for t in abort_RT
    # ))
    abort_loglike = 0
    for t in abort_RT:
        abort_loglike += calculate_abort_loglike(t, V_A, theta_A, t_a, V_E, theta_E, K_max, t_stim, Z, t_E, abort_norm_term)

    # correct_loglike = sum(Parallel(n_jobs=n_jobs)(
    #     delayed(calculate_correct_loglike)(t, V_A, theta_A, t_a, V_E, theta_E, Z, K_max, t_stim, t_E, correct_norm_term)
    #     for t in correct_RT
    # ))
    correct_loglike = 0
    for t in correct_RT:
        correct_loglike += calculate_correct_loglike(t, V_A, theta_A, t_a, V_E, theta_E, Z, K_max, t_stim, t_E, correct_norm_term)

    # wrong_loglike = sum(Parallel(n_jobs=n_jobs)(
    #     delayed(calculate_wrong_loglike)(t, V_A, theta_A, t_a, V_E, theta_E, Z, K_max, t_stim, t_E, wrong_norm_term)
    #     for t in wrong_RT
    # ))
    wrong_loglike = 0
    for t in wrong_RT:
        wrong_loglike += calculate_wrong_loglike(t, V_A, theta_A, t_a, V_E, theta_E, Z, K_max, t_stim, t_E, wrong_norm_term)

    total_loglike = abort_loglike + correct_loglike + wrong_loglike


    if np.isnan(total_loglike):
        raise ValueError("Log-likelihood is NaN or infinite.")
    if np.isinf(total_loglike):
        raise ValueError("Log-likelihood is infinite.")

    return total_loglike

    

In [5]:

import time

start_time = time.time()
result = psiam_loglike_fn([V_A, theta_A, V_E, theta_E, Z_E])
end_time = time.time()

time_taken = end_time - start_time

print(f'loglike = {result}')
print(f'Time taken: {time_taken:.6f} seconds')


loglike = -8337.297015813281
Time taken: 1.505448 seconds


In [6]:
num_process = os.cpu_count()
pool = multiprocessing.Pool(processes=num_process)
ndim = 5
sampler = NestedSampler(psiam_loglike_fn, psiam_prior_fn, ndim, pool=pool, queue_size=num_process)
sampler.run_nested()
pool.close()
pool.join()

1807it [2:05:41,  4.17s/it, bound: 0 | nc: 18 | ncall: 17281 | eff(%): 10.457 | loglstar:   -inf < 52621.451 <    inf | logz: 52610.933 +/-  0.145 | dlogz: 83733.580 >  0.509]  


KeyboardInterrupt: 