In [1]:
import numpy as np
import pickle
from psiam_utils import rho_A_t_fn, rho_E_t_fn, cum_A_t_fn, rho_E_minus_t_fn, P_large_t_btn_1_2
from scipy.integrate import quad
from joblib import Parallel, delayed
from pybads import BADS
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from scipy.stats import median_abs_deviation
import matplotlib.pyplot as plt

# Read data

In [2]:
with open('psiam_data_5k_1.pkl', 'rb') as f:
    psiam_data = pickle.load(f)


choices = psiam_data['choices']
RTs = psiam_data['RTs']
is_act_resp = psiam_data['is_act_resp']
V_A = psiam_data['V_A']
theta_A = psiam_data['theta_A']
V_E = psiam_data['V_E']
theta_E = psiam_data['theta_E']
Z_E = psiam_data['Z_E']
t_stim = psiam_data['t_stim']


indices_evid = np.where(is_act_resp == 0)[0]
RTs_evid = RTs[indices_evid].flatten()

indices_act = np.where(is_act_resp == 1)[0]
RTs_act = RTs[indices_act].flatten()


RTs = RTs.flatten()

correct_idx = np.where(choices == 1)[0]
wrong_idx = np.where(choices == -1)[0]


correct_RT = RTs[correct_idx]
wrong_RT = RTs[wrong_idx]

abort_idx = np.where(RTs < t_stim)[0]
abort_RT = RTs[abort_idx]

print(f"V_A: {V_A}")
print(f"theta_A: {theta_A}")
print(f"V_E: {V_E}")
print(f"theta_E: {theta_E}")
print(f"Num of AI process: {is_act_resp.sum()}/{len(is_act_resp)}")
print(f"t start is {t_stim}")

V_A: 1
theta_A: 2
V_E: 0.5
theta_E: 1
Num of AI process: 1975/10000
t start is 0


In [2]:
def simulate_AI_for_T(v, a, dt):
    dB = dt**0.5 

    DV = 0
    t = 0
    for t in range(1,t_max):
        DV += v*dt + np.random.normal(0, dB)
        t += 1
        if DV >= a/2:
            break
        
    
    if DV < a/2:
        print(f'Re-running')
        simulate_AI_for_T(v, a, dt)
    
    return  t*dt

v = 1.2
a = 2
dt = 1e-4
t_max = int(10/dt)
N_sim = 10000
sim_results = np.zeros((N_sim))

from tqdm import tqdm

for i in tqdm(range(N_sim)):
    result = simulate_AI_for_T(v, a, dt)
    sim_results[i] = result

  0%|          | 0/10000 [00:00<?, ?it/s]

 32%|███▏      | 3157/10000 [00:23<00:52, 129.77it/s]

Re-running


100%|██████████| 10000/10000 [01:14<00:00, 134.48it/s]


In [3]:
def bads_negloglike_fn(params):
    v,a = params
    p_a_test = np.array([rho_A_t_fn(t, v, a/2, 0) for t in sim_results])
    return -np.sum(np.log(p_a_test))

In [6]:
v_bounds = [0.1, 1.2]; v_plausible_bounds = [0.5, 1.1]
a_bounds = [1,3]; a_plausible_bounds = [1.5, 2.5]

lb = np.array([v_bounds[0], a_bounds[0]])
ub = np.array([v_bounds[1], a_bounds[1]])
plb = np.array([v_plausible_bounds[0], a_plausible_bounds[0]])
pub = np.array([v_plausible_bounds[1], a_plausible_bounds[1]])

v_0 = np.random.uniform(plb[0], pub[0])
a_0 = np.random.uniform(plb[1], pub[1])

x0 = np.array([v_0, a_0]);


options = {'display': 'off'}
bads = BADS(bads_negloglike_fn, x0, lb, ub, plb, pub, options=options)
optimize_result = bads.optimize()
x_min = optimize_result['x']
print(f'v = {x_min[0]:.3f}, a = {x_min[1]:.3f}')

v = 1.200, a = 2.001


In [8]:
def simulate_ddm(v, a, dt):
    t_max = int(7/dt)
    
    dB = dt**0.5 
    DV = 0
    c = 0
    for t in range(1,t_max):
        DV += v*dt + np.random.normal(0, dB)
        if DV >= a/2:
            c = 1
            break
        elif DV <= -a/2:
            c = -1
            break
    
    if DV < a/2 and DV > -a/2:
        print(f'DV value is {DV}')
        simulate_ddm(v, a, dt)
    
    return  t*dt, c


v = 0.7
a = 2
dt = 1e-4
t_max = int(7/dt)
N_sim = 25000
choices = np.zeros((N_sim))
rts = np.zeros((N_sim))

from tqdm import tqdm

for i in tqdm(range(N_sim)):
    result = simulate_ddm(v, a, dt)
    rts[i] = result[0]
    choices[i] = result[1]


 42%|████▏     | 10533/25000 [01:22<01:51, 130.08it/s]

DV value is 0.5275891614928279


100%|██████████| 25000/25000 [03:17<00:00, 126.68it/s]


In [9]:
choices_np = np.array(choices); rts_np = np.array(rts);
upp_bound_idx = np.where(choices_np == 1)[0]
upp_bound_rt = rts_np[upp_bound_idx]

low_bound_idx = np.where(choices_np == -1)[0]
low_bound_rt = rts_np[low_bound_idx]

In [14]:
def bads_ea_neg_loglike_fn(params):
    v,a,z = params
    

    t_pts = upp_bound_rt
    theory_up = np.array([rho_E_minus_t_fn(t, -v, a/2, 50, 0, -z, 0) for t in t_pts])/(len(upp_bound_idx)/N_sim)

    t_pts = low_bound_rt
    theory_down = np.array([rho_E_minus_t_fn(t, v, a/2, 50, 0, z, 0) for t in t_pts])/(len(low_bound_idx)/N_sim)

    return -(np.sum(np.log(theory_up)) + np.sum(np.log(theory_down)))

In [17]:
v_bounds = [0.1, 3]; v_plausible_bounds = [0.2, 1]
a_bounds = [1,3]; a_plausible_bounds = [1.5, 2.5]
z_bounds = [-0.9, 0.9]; z_plausible_bounds = [-0.5, 0.5]

lb = np.array([v_bounds[0], a_bounds[0], z_bounds[0]])
ub = np.array([v_bounds[1], a_bounds[1], z_bounds[1]])
plb = np.array([v_plausible_bounds[0], a_plausible_bounds[0], z_plausible_bounds[0]])
pub = np.array([v_plausible_bounds[1], a_plausible_bounds[1], z_plausible_bounds[1]])

v_0 = np.random.uniform(plb[0], pub[0])
a_0 = np.random.uniform(plb[1], pub[1])
z_0 = np.random.uniform(plb[2], pub[2])

x0 = np.array([v_0, a_0, z_0]);


options = {'display': 'off'}
bads = BADS(bads_ea_neg_loglike_fn, x0, lb, ub, plb, pub, options=options)
optimize_result = bads.optimize()
x_min = optimize_result['x']
print(f'v = {x_min[0]:.3f}, a = {x_min[1]:.3f}, z = {x_min[2]:.3f}')

v = 0.692, a = 2.009, z = 0.001


In [19]:
print(f'True params likelihood = {bads_ea_neg_loglike_fn([v,a,0])}')
print(f'BADS est likelihood = {bads_ea_neg_loglike_fn(x_min)}')

True params likelihood = 17857.428586494272
BADS est likelihood = 17854.96368788252


# DV Btn x1 and x2

In [2]:
def simulate_ddm(v, a, dt):
    t_max = int(7/dt)
    
    dB = dt**0.5 
    DV_arr = np.full((t_max), np.nan)

    DV = 0
    DV_arr[0] = DV
    # DV_arr = []
    # DV_arr.append(DV)
    for t in range(1,t_max):
        DV += v*dt + np.random.normal(0, dB)
        DV_arr[t] = DV
        # DV_arr.append(DV)
        if DV >= a/2:
            break
        elif DV <= -a/2:
            break
    
    if DV < a/2 and DV > -a/2:
        print(f'DV value is {DV}')
        raise ValueError('Simulation failed')
    
    return  DV_arr

v = 1.2
a = 2
dt = 1e-4
t_max = int(7/dt)
N_sim = 5000
sim_results = np.zeros((N_sim, t_max))

from tqdm import tqdm

for i in tqdm(range(N_sim)):
    result = simulate_ddm(v, a, dt)
    sim_results[i,:] = result


100%|██████████| 5000/5000 [00:43<00:00, 114.45it/s]


In [3]:
t_pts = np.linspace(0.0001, 7, t_max)

prob_data = np.zeros((t_max))
for t in range(1, t_max):
    # remove the nan. if at "t" upper bound is reached, then from t+1 to end of array "nans" are filled
    cleaned_data = sim_results[:,t][~np.isnan(sim_results[:,t])]
    # In how many simulations, is the DV between starting point and upper threshold
    count_between_0_and_1 = np.sum((cleaned_data >= 0) & (cleaned_data <= 1))
    # count_between_0_and_1 = np.sum((cleaned_data >= -1) & (cleaned_data <= 0))

    # Prob that at time 't' DV = Num of simulations in which DV is between starting pt and upper threshold / Total num of simulations 
    prob_data[t] = count_between_0_and_1 / N_sim


In [11]:
from psiam_utils import P_small_t_btn_1_2

def bads_btn_loglike_fn(params):
    V_E,Z = params
    theta_E = 1
    K_max = 20
    P_EA_btn = np.zeros((len(t_pts)))
    for i in range(len(t_pts)):
        t = t_pts[i]
        p = P_large_t_btn_1_2(1, 2, t, V_E, theta_E, Z, K_max, 0)
        # p = P_small_t_btn_1_2(t, V_E, theta_E, Z, K_max, 0)
        if p <= 0:
            p = 1e-9
        if np.isnan(p):
            raise ValueError(f'p is nan. Params={[V_E,theta_E,Z,t]}')
        
            
        P_EA_btn[i] = p

    P_EA_btn = np.array(P_EA_btn)
    

    return np.sum((prob_data - P_EA_btn)**2)*(1/len(P_EA_btn))

In [14]:
v_bounds = [0.1, 2]; v_plausible_bounds = [0.2, 1.6]
a_bounds = [0.1, 2]; a_plausible_bounds = [0.5, 1.5] # theta here
z_bounds = [-0.5, 0.5]; z_plausible_bounds = [-0.3, 0.3]

# lb = np.array([v_bounds[0], a_bounds[0], z_bounds[0]])
# ub = np.array([v_bounds[1], a_bounds[1], z_bounds[1]])
# plb = np.array([v_plausible_bounds[0], a_plausible_bounds[0], z_plausible_bounds[0]])
# pub = np.array([v_plausible_bounds[1], a_plausible_bounds[1], z_plausible_bounds[1]])

lb = np.array([v_bounds[0], z_bounds[0]])
ub = np.array([v_bounds[1], z_bounds[1]])
plb = np.array([v_plausible_bounds[0], z_plausible_bounds[0]])
pub = np.array([v_plausible_bounds[1], z_plausible_bounds[1]])

v_0 = np.random.uniform(plb[0], pub[0])
# a_0 = np.random.uniform(plb[1], pub[1])
z_0 = np.random.uniform(plb[1], pub[1])

x0 = np.array([v_0,  z_0]);


options = {'display': 'off'}
bads = BADS(bads_btn_loglike_fn, x0, lb, ub, plb, pub, options=options)
optimize_result = bads.optimize()
x_min = optimize_result['x']
# print(f'BADS: v = {x_min[0]:.3f}, theta = {x_min[1]:.3f}, z = {x_min[2]:.3f}')
print(f'True: v={v}, theta = {a/2} z = {0}')

True: v=1.2, theta = 1.0 z = 0


In [15]:
x_min

array([1.12526567, 0.00139539])

In [9]:
print(f'Err with true params = {bads_btn_loglike_fn([v,a/2,0])}')
print(f'Err with BADS params = {bads_btn_loglike_fn(x_min)}')

Err with true params = 1.2490716207029107e-05
Err with BADS params = 1.0499221336388876e-05


In [None]:
# large, K_max = 100

In [25]:
def P_small_t_btn_1_2_CUSTOM(t, V_E, theta_E, Z, n_max, t_stim):
    """
    Integration of P_small(x,t) with x from 1,2
    """
    v = V_E
    a = 2*theta_E
    mu = v*theta_E
    z = a * (Z + theta_E)/(2*theta_E)
    
    if t <= t_stim:
        return 0
    else:
        t = t - t_stim

    result = 0
    
    sqrt_t = np.sqrt(t)
    
    for n in range(-n_max, n_max + 1):
        term1 = np.exp(4 * mu * n) * (
            Phi((2 - (z + 4 * n + mu * t)) / sqrt_t) -
            Phi((1 - (z + 4 * n + mu * t)) / sqrt_t)
        )
        
        term2 = np.exp(2 * mu * (2 * (1 - n) - z)) * (
            Phi((2 - (-z + 4 * (1 - n) + mu * t)) / sqrt_t) -
            Phi((1 - (-z + 4 * (1 - n) + mu * t)) / sqrt_t)
        )

        if np.isnan(term1):
            print("term1 is NaN")
        
        if np.isnan(term2):
            print("term2 is NaN")
        
        result += term1 - term2
    
    return result


In [39]:
# P_small_t_btn_1_2(0.0001, 1.81, 2.16, 0.38, 50, 0)
P_small_t_btn_1_2_CUSTOM(0.001, 1.81, 2.16, 0.38, 30, 0)
# P_small_t_btn_1_2(0.0001, 0.5, 1, 0, 50, 0)


-0.014663879292970366

In [4]:
# V_A: 1 # theta_A: 2 # V_E: 0.5 # theta_E: 1 # Z = 0
# V_A_bounds = [0.1, 3]; V_A_plausible_bounds = [0.5, 1.5]
# theta_A_bounds = [1, 3]; theta_A_plausible_bounds = [1.5, 2.5]
# V_E_bounds = [-5, 5]; V_E_plausible_bounds = [-2, 2]
# theta_E_bounds = [0.1, 5]; theta_E_plausible_bounds = [0.5,1.5]
# Z_bounds = [-0.5, 0.5]; Z_plausible_bounds = [-0.2, 0.2]

V_A_bounds = [0.5, 1.5]; V_A_plausible_bounds = [0.7, 1.3]
theta_A_bounds = [1.5, 2.5]; theta_A_plausible_bounds = [1.7, 2.2]
V_E_bounds = [0, 1]; V_E_plausible_bounds = [0.2, 0.7]
theta_E_bounds = [0.1, 2]; theta_E_plausible_bounds = [0.5,1.5]
Z_bounds = [-0.2, 0.2]; Z_plausible_bounds = [-0.1, 0.1]

lb = np.array([V_A_bounds[0], theta_A_bounds[0], V_E_bounds[0], theta_E_bounds[0], Z_bounds[0]])
ub = np.array([V_A_bounds[1], theta_A_bounds[1], V_E_bounds[1], theta_E_bounds[1], Z_bounds[1]])
plb = np.array([V_A_plausible_bounds[0], theta_A_plausible_bounds[0], V_E_plausible_bounds[0], theta_E_plausible_bounds[0], Z_plausible_bounds[0]])
pub = np.array([V_A_plausible_bounds[1], theta_A_plausible_bounds[1], V_E_plausible_bounds[1], theta_E_plausible_bounds[1], Z_plausible_bounds[1]])


V_A_0 = np.random.uniform(plb[0], pub[0])
theta_A_0 =  np.random.uniform(plb[1], pub[1])
V_E_0 = np.random.uniform(plb[2], pub[2])
theta_E_0 = np.random.uniform(plb[3], pub[3])
Z_0 = np.random.uniform(plb[4], pub[4])

x0 = np.array([V_A_0, theta_A_0, V_E_0, theta_E_0, Z_0]);

options = {'display': 'off'}
bads = BADS(psiam_ai_loglike, x0, lb, ub, plb, pub, options=options)
optimize_result = bads.optimize()
x_min = optimize_result['x']


# V_A: 1 # theta_A: 2 # V_E: 0.5 # theta_E: 1 # Z = 0
print(f'Est. V_A = {x_min[0]}, True V_A = {V_A}')
print(f'Est. theta_A = {x_min[1]}, True theta_A = {theta_A}')
print(f'Est. V_E = {x_min[2]}, True V_E = {V_E}')
print(f'Est. theta_E = {x_min[3]}, True theta_E = {theta_E}')
print(f'Est. Z = {x_min[4]}, True Z = {Z_E}')

Est. V_A = 1.4999999999883586, True V_A = 1
Est. theta_A = 1.710153133160202, True theta_A = 2
Est. V_E = 0.00017448158469052233, True V_E = 0.5
Est. theta_E = 2.0, True theta_E = 1
Est. Z = -0.00027932946104556325, True Z = 0
