In [2]:
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS


In [3]:
from ddm_utils import simulate_ddm, parse_sim_results
from joblib import Parallel, delayed
import pickle
import numpy as np

# Simulate DDM data and save it to disk (it is converted to torch tensors later, inside potential_fn)

In [4]:
N_sim = 100
true_v = 0.2
true_a = 2
true_w = 0.5
true_params = [true_v, true_a, true_w]

# One simulated trial per joblib task, spread over all available cores.
# NOTE(review): true_w is not passed to simulate_ddm, so the starting point
# used for simulation is whatever simulate_ddm does by default — confirm.
sim_results = Parallel(n_jobs=-1)(
    delayed(simulate_ddm)(true_v, true_a) for _ in range(N_sim)
)
choices, RTs = parse_sim_results(sim_results)

# Persist the observed data; potential_fn reloads it from these two files.
for out_path, payload in (('sample_rt.pkl', RTs), ('sample_choice.pkl', choices)):
    with open(out_path, 'wb') as fh:
        pickle.dump(payload, fh)

In [13]:
def rtd_density_a_NUTS_pyro(t, v, a, w, K_max=10, t_switch=0.25):
    """Series first-passage-time density f(t | v, a, w) of a Wiener diffusion.

    Uses the Navarro & Fuss (2009) truncated series. The large-time series
    (k = 1 .. K_max) is used where t > t_switch. The small-time series
    (k = -K_max//2 .. K_max//2) is used elsewhere. In this notebook the
    (v, w) form scores choice -1 trials, and (-v, 1 - w) scores choice 1
    trials.

    Args:
        t: response time(s). A scalar or a tensor of any shape; evaluated
            elementwise.
        v: drift rate (Python float or 0-d tensor).
        a: boundary separation (Python float or 0-d tensor).
        w: relative starting point as a fraction of a (Python float or 0-d
            tensor; presumably in (0, 1)).
        K_max: number of large-time terms. The small-time series uses
            K_max // 2 terms on each side of k = 0.
        t_switch: raw-time threshold (not normalized by a**2) above which the
            large-time series is used. Defaults to 0.25, the previous
            hard-coded value.

    Returns:
        Tensor with the same shape as ``t`` holding density values, not logs.
        Because the series are truncated, callers that take the log should
        expect very small values.
    """
    t = torch.as_tensor(t)
    if not t.is_floating_point():
        t = t.to(torch.get_default_dtype())
    # Trailing axis over which each series is summed; broadcasts against the
    # 1-D vector of series indices k.
    t_col = t.unsqueeze(-1)
    drift_factor = torch.exp(-v * a * w - (v**2 * t) / 2)

    # Large-time series: k = 1 .. K_max.
    k_large = torch.arange(1, K_max + 1, dtype=t.dtype, device=t.device)
    large_sum = torch.sum(
        k_large
        * torch.sin(k_large * torch.pi * w)
        * torch.exp(-(k_large**2 * torch.pi**2 * t_col) / (2 * a**2)),
        dim=-1,
    )
    large_density = (torch.pi / a**2) * drift_factor * large_sum

    # Small-time series: k = -K_max//2 .. K_max//2.
    half = K_max // 2
    k_small = torch.arange(-half, half + 1, dtype=t.dtype, device=t.device)
    offsets = w + 2 * k_small
    small_sum = torch.sum(offsets * torch.exp(-(a**2 * offsets**2) / (2 * t_col)), dim=-1)
    small_density = (1 / a**2) * (a**3 / torch.sqrt(2 * torch.pi * t**3)) * drift_factor * small_sum

    # Both series are evaluated everywhere and one is selected per element.
    # For t > 0 and a != 0 both are finite, so the unselected branch gets a
    # zero (not NaN) gradient.
    return torch.where(t > t_switch, large_density, small_density)

In [16]:


# Observed data for potential_fn. _load_observed_rts() fills it on first use,
# so the pickles are read once per execution of this cell rather than on
# every potential evaluation NUTS performs. Re-run this cell after
# re-simulating the data.
_observed_rt_cache = {}


def _load_observed_rts():
    """Load the simulated data once and return (RTs_pos, RTs_neg).

    RTs_pos are the RTs of trials with choice == 1 and RTs_neg those with
    choice == -1, as float32 tensors. Trials with any other choice code are
    ignored.
    """
    if not _observed_rt_cache:
        # These pickles are written by this notebook itself; never point
        # pickle.load at files from an untrusted source.
        with open('sample_choice.pkl', 'rb') as f:
            choices = pickle.load(f)
        with open('sample_rt.pkl', 'rb') as f:
            RTs = pickle.load(f)

        choices = torch.tensor(choices, dtype=torch.float32)
        RTs = torch.tensor(RTs, dtype=torch.float32)

        _observed_rt_cache['pos'] = RTs[torch.where(choices == 1)[0]]
        _observed_rt_cache['neg'] = RTs[torch.where(choices == -1)[0]]
    return _observed_rt_cache['pos'], _observed_rt_cache['neg']


def potential_fn(params):
    """Potential energy (negative log posterior) of the DDM, for Pyro's NUTS.

    Args:
        params: dict of tensors with keys 'v' (drift), 'a' (boundary
            separation) and 'w' (relative starting point).

    Returns:
        Scalar tensor equal to -(log-likelihood + log-prior). It is non-finite
        when a parameter lies outside its Uniform prior range, so NUTS rejects
        that proposal instead of the run crashing.
    """
    v = params['v']
    a = params['a']
    w = params['w']

    RTs_pos, RTs_neg = _load_observed_rts()

    # Choice 1 trials are scored with the mirrored parameters (-v, 1 - w);
    # choice -1 trials with (v, w).
    densities = [rtd_density_a_NUTS_pyro(t, -v, a, 1 - w) for t in RTs_pos]
    densities += [rtd_density_a_NUTS_pyro(t, v, a, w) for t in RTs_neg]
    # torch.stack rejects an empty list, so empty data contributes 0.
    sum_loglike = torch.log(torch.stack(densities)).sum() if densities else torch.zeros(())

    # Flat priors. The NUTS kernel in this notebook is built from this
    # potential without transforms, so it can propose values outside these
    # ranges. With argument validation on (the torch default), log_prob would
    # raise ValueError and abort the run. With validate_args=False it returns
    # -inf, and the proposal is rejected.
    v_prior = dist.Uniform(-5., 5., validate_args=False)
    a_prior = dist.Uniform(1., 3., validate_args=False)
    w_prior = dist.Uniform(0.3, 0.7, validate_args=False)

    log_prior = v_prior.log_prob(v) + a_prior.log_prob(a) + w_prior.log_prob(w)

    return -(sum_loglike + log_prior)

# Fixed seed so that the random initial point and the NUTS trajectory are
# reproducible on Restart & Run All. pyro.set_rng_seed seeds Python's
# `random`, NumPy and torch together.
SEED = 0
pyro.set_rng_seed(SEED)

# Start each parameter at a uniform draw from its prior range.
init_v = torch.tensor(np.random.uniform(-5, 5))
init_a = torch.tensor(np.random.uniform(1, 3))
init_w = torch.tensor(np.random.uniform(0.3, 0.7))

initial_params = {
    "v": init_v,
    "a": init_a,
    "w": init_w,
}

# NUTS driven directly by the hand-written potential (negative log posterior);
# no Pyro model and no transforms are involved.
nuts_kernel = NUTS(potential_fn=potential_fn)

# The recorded run of 10000 draws + 500 warmup steps took about 70 minutes.
mcmc = MCMC(
    nuts_kernel,
    num_samples=10000,
    warmup_steps=500,
    num_chains=1,
    initial_params=initial_params,
)

mcmc.run()

samples = mcmc.get_samples()


Sample: 100%|██████████| 10500/10500 [1:09:42,  2.51it/s, step size=6.82e-01, acc. prob=0.914]
