# NUTS

In [2]:
import numpy as np
import pickle
from scipy import stats
from joblib import Parallel, delayed
import matplotlib.pyplot as plt
from nuts import NUTSSampler, nuts6
import sys
from numba import jit

In [3]:
@jit(nopython=True)
def rtd_density_a(t, v, a, w, K_max=10):
    """Evaluate a truncated series for the response-time density at time ``t``.

    Two series representations are used, selected by ``t``:

    * ``t > 0.25``: a sine series over ``k = 1 .. K_max``.
    * otherwise: a Gaussian-type series over ``k = -K .. K`` with
      ``K = int(K_max / 2)``.

    Both share the factor ``exp(-v*a*w - v**2 * t / 2)``.

    NOTE(review): the formulas look like the large-time / small-time
    expansions of the Wiener first-passage-time density at the lower
    boundary -- confirm against the reference the model was taken from.

    Parameters
    ----------
    t : float
        Time at which the density is evaluated. Must be > 0 (the
        ``t <= 0.25`` branch divides by ``t``).
    v, a, w : float
        Model parameters entering the formulas below; ``a`` must be
        non-zero because both branches divide by ``a**2``.
    K_max : int, optional
        Controls the number of series terms (see above).

    Returns
    -------
    float
        The series value, or ``1e-6`` when the truncated series evaluates
        to zero or a negative number.
    """
    if t > 0.25:
        non_sum_term = (np.pi/a**2)*np.exp(-v*a*w - (v**2 * t/2))
        k_vals = np.linspace(1, K_max, K_max)
        sum_sine_term = np.sin(k_vals*np.pi*w)
        sum_exp_term = np.exp(-(k_vals**2 * np.pi**2 * t)/(2*a**2))
        sum_result = np.sum(k_vals * sum_sine_term * sum_exp_term)
    else:
        non_sum_term = (1/a**2)*(a**3/np.sqrt(2*np.pi*t**3))*np.exp(-v*a*w - (v**2 * t)/2)
        # Symmetric index range -K..K; the series uses half as many |k| values.
        half_k = int(K_max/2)
        k_vals = np.linspace(-half_k, half_k, 2*half_k + 1)
        sum_w_term = w + 2*k_vals
        sum_exp_term = np.exp(-(a**2 * sum_w_term**2)/(2*t))
        sum_result = np.sum(sum_w_term*sum_exp_term)

    density = non_sum_term * sum_result
    # A truncated series can come out zero or negative. Adding a small
    # constant (the previous behaviour) leaves a negative value negative, and
    # callers divide by / take the log of this result, so clamp to a floor.
    if density <= 0:
        density = 1e-6
    return density

In [4]:
def loglike_fn(params):
    """Return the summed log-density of the pickled sample at ``params``.

    ``params`` unpacks to ``(v, a, w)``. Response times are read from
    ``sample_rt_nuts.pkl`` and choices from ``sample_choice_nuts.pkl`` on
    every call. Trials with choice ``1`` are scored with
    ``rtd_density_a(t, -v, a, 1-w)`` and trials with choice ``-1`` with
    ``rtd_density_a(t, v, a, w)``; non-positive densities are replaced by
    ``1e-10`` before the logarithm is taken.

    Raises
    ------
    ValueError
        If any log-density is NaN (diagnostics for the ``-1`` trials are
        printed first).
    """
    v, a, w = params

    # NOTE(review): pickle.load runs arbitrary code from the file; these
    # files are assumed to be trusted, locally produced data.
    with open('sample_rt_nuts.pkl', 'rb') as rt_file:
        rts = np.array(pickle.load(rt_file))
    with open('sample_choice_nuts.pkl', 'rb') as choice_file:
        responses = np.array(pickle.load(choice_file))

    # Split the response times by which choice was made.
    upper_rts = rts[np.nonzero(responses == 1)[0]]
    lower_rts = rts[np.nonzero(responses == -1)[0]]

    upper_density = np.array(
        Parallel(n_jobs=-1)(
            delayed(rtd_density_a)(t, -v, a, 1-w) for t in upper_rts
        )
    )
    lower_density = np.array(
        Parallel(n_jobs=-1)(
            delayed(rtd_density_a)(t, v, a, w) for t in lower_rts
        )
    )

    # Keep the logarithm finite for non-positive densities.
    upper_density[upper_density <= 0] = 1e-10
    lower_density[lower_density <= 0] = 1e-10

    upper_log = np.log(upper_density)
    lower_log = np.log(lower_density)

    if np.any(np.isnan(upper_log)) or np.any(np.isnan(lower_log)):
        print('log_neg', lower_log)
        print('prob_neg = ', lower_density)
        raise ValueError("NaN values found in log_pos or log_neg")

    return np.sum(upper_log) + np.sum(lower_log)

In [None]:
def grad_fn(params):
    """Forward-difference gradient of the log-likelihood w.r.t. ``(v, a, w)``.

    Response times and choices are read from ``sample_rt_nuts.pkl`` and
    ``sample_choice_nuts.pkl`` on every call. For each trial ``delta_rule``
    returns the finite-difference derivative of the log-density with respect
    to one argument of ``rtd_density_a``; the per-trial values are summed.

    Trials with choice ``1`` are evaluated at the transformed arguments
    ``(-v, a, 1-w)``. ``delta_rule`` perturbs those transformed arguments, so
    its result is the derivative with respect to ``-v`` / ``1-w``; the chain
    rule factor (``-1`` for ``v`` and ``w``, ``+1`` for ``a``) is applied here
    so the result is the gradient with respect to the original parameters.

    Parameters
    ----------
    params : sequence of 3 floats
        ``(v, a, w)``.

    Returns
    -------
    numpy.ndarray
        Float array with the shape of ``params`` holding
        ``[d/dv, d/da, d/dw]`` of the summed log-density.
    """
    v, a, w = params
    delta = 1e-6
    # Always accumulate in floating point: zeros_like(params) would inherit an
    # integer dtype from integer-valued params and truncate the gradient.
    grads = np.zeros(np.shape(params), dtype=float)

    # NOTE(review): pickle.load runs arbitrary code from the file; these
    # files are assumed to be trusted, locally produced data.
    with open('sample_rt_nuts.pkl', 'rb') as f:
        RTs = np.array(pickle.load(f))
    with open('sample_choice_nuts.pkl', 'rb') as f:
        choices = np.array(pickle.load(f))

    RTs_pos = RTs[np.where(choices == 1)[0]]
    RTs_neg = RTs[np.where(choices == -1)[0]]

    pos_args = [-v, a, 1-w]
    neg_args = [v, a, w]
    # d(-v)/dv = -1, da/da = +1, d(1-w)/dw = -1
    pos_chain = (-1.0, 1.0, -1.0)

    for axis in range(3):
        # Perturb exactly one density argument by ``delta``.
        step = [0, 0, 0]
        step[axis] = delta

        grad_pos = Parallel(n_jobs=-1)(
            delayed(delta_rule)(t, pos_args, step, delta) for t in RTs_pos
        )
        grad_neg = Parallel(n_jobs=-1)(
            delayed(delta_rule)(t, neg_args, step, delta) for t in RTs_neg
        )
        grads[axis] = pos_chain[axis]*np.sum(grad_pos) + np.sum(grad_neg)

    return grads

def delta_rule(t, params, delta_arr, delta):
    """Forward-difference estimate of d(log density)/d(param) for one trial.

    Evaluates ``rtd_density_a`` at ``params`` and at ``params + delta_arr``
    and returns ``(f_shifted - f_base) / (delta * f_base)``.

    Parameters
    ----------
    t : float
        Response time passed through to ``rtd_density_a``.
    params : sequence of 3 floats
        ``(v, a, w)`` arguments for ``rtd_density_a``.
    delta_arr : sequence of 3 floats
        Per-argument offsets added to ``params``.
    delta : float
        Step size used as the divisor of the difference quotient.
    """
    base_v, base_a, base_w = params
    base_density = rtd_density_a(t, base_v, base_a, base_w)

    # Element-wise shift of the three arguments.
    shifted = np.array(params) + np.array(delta_arr)
    step_v, step_a, step_w = shifted
    shifted_density = rtd_density_a(t, step_v, step_a, step_w)

    return (1.0/base_density)*(shifted_density - base_density)*(1.0/delta)


In [None]:
def log_like_and_grad(params):
    """Return ``(loglike_fn(params), grad_fn(params))`` as a tuple.

    The log-likelihood is evaluated first, then the gradient.
    """
    return loglike_fn(params), grad_fn(params)