In [12]:
# tensor flow
import tensorflow_probability as tfp
import tensorflow as tf

# rest
from numba import njit, jit
import matplotlib.pyplot as plt
import numpy as np
import pickle
from joblib import Parallel, delayed

# parse
from ddm_utils import simulate_ddm, parse_sim_results


In [13]:
N_sim = 100; v = 0.2; a = 2
sim_results = Parallel(n_jobs=-1)(delayed(simulate_ddm)(v, a) for _ in range(N_sim))


choices, RTs = parse_sim_results(sim_results)
    
with open('sample_rt.pkl', 'wb') as f:
    pickle.dump(RTs, f)
with open('sample_choice.pkl', 'wb') as f:
    pickle.dump(choices, f)

In [14]:
def rtd_density_a(t, v, a, w, K_max=10):
    # Use tf operations instead of numpy
    if tf.greater(t, 0.25):
        non_sum_term = (np.pi/a**2)*tf.exp(-v*a*w - (v**2 * t/2))
        k_vals = tf.linspace(1., tf.cast(K_max, tf.float32), K_max)
        sum_sine_term = tf.sin(k_vals*np.pi*w)
        sum_exp_term = tf.exp(-(k_vals**2 * np.pi**2 * t)/(2*a**2))
        sum_result = tf.reduce_sum(k_vals * sum_sine_term * sum_exp_term)
    else:
        non_sum_term = (1/a**2)*(a**3/tf.sqrt(2*np.pi*t**3))*tf.exp(-v*a*w - (v**2 * t)/2)
        K_max = int(K_max/2)
        k_vals = tf.linspace(tf.cast(-K_max, tf.float32), tf.cast(K_max, tf.float32), 2*K_max + 1)
        sum_w_term = w + 2*k_vals
        sum_exp_term = tf.exp(-(a**2 * (w + 2*k_vals)**2)/(2*t))
        sum_result = tf.reduce_sum(sum_w_term*sum_exp_term)


    density =  non_sum_term * sum_result
    density = tf.where(density <= 0, 1e-10, density) # Use tf.where for conditional assignment
    return density


def loglike_fn(v,a,w):
    # No need to convert to numpy arrays, keep them as tensors

     # Load RTs and choices from the saved files
    print(f"v,a,w: {v}, {a}, {w}")
    with open('sample_rt.pkl', 'rb') as f:
        RTs = pickle.load(f)
    with open('sample_choice.pkl', 'rb') as f:
        choices = pickle.load(f)

    # Convert to tensors
    RTs = tf.constant(RTs, dtype=tf.float32)
    choices = tf.constant(choices, dtype=tf.int32)

    choices_pos = tf.where(choices == 1)[:, 0] # Use tf.where for indexing
    choices_neg = tf.where(choices == -1)[:, 0]

    RTs_pos = tf.gather(RTs, choices_pos)
    RTs_neg = tf.gather(RTs, choices_neg)

    # Use tf.map_fn for parallelization within TensorFlow graph
    prob_pos = tf.map_fn(lambda t: rtd_density_a(t, -v, a, 1-w), RTs_pos)
    prob_neg = tf.map_fn(lambda t: rtd_density_a(t, v, a, w), RTs_neg)

    # Use tf.where for conditional assignment
    prob_pos = tf.where(prob_pos <= 0, 1e-10, prob_pos)
    prob_neg = tf.where(prob_neg <= 0, 1e-10, prob_neg)

    log_pos = tf.math.log(prob_pos)
    log_neg = tf.math.log(prob_neg)

    # Calculate sum of log-likelihoods using TensorFlow operations
    sum_loglike = tf.reduce_sum(log_pos) + tf.reduce_sum(log_neg)
    print('loglike = ', sum_loglike)
    return sum_loglike



In [15]:
# NUTS init
nuts_kernel = tfp.mcmc.NoUTurnSampler(
    loglike_fn,
    step_size=0.1
)

num_burnin_steps = 100
nuts_adaptive_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
    nuts_kernel,
    num_adaptation_steps=int(num_burnin_steps * 0.8),
)

In [16]:
import random
init_state = [tf.constant(random.uniform(-2,2), dtype=tf.float32),
              tf.constant(random.uniform(1,3), dtype=tf.float32),
              tf.constant(random.uniform(0.3,0.7), dtype=tf.float32)
              ]
chain_output = tfp.mcmc.sample_chain(
    num_results=int(1000),
    num_burnin_steps=num_burnin_steps,
    current_state=init_state,
    kernel=nuts_adaptive_kernel,
    trace_fn=None
)

v,a,w: -0.07520861178636551, 2.2097630500793457, 0.6328628063201904
loglike =  tf.Tensor(-161.54095, shape=(), dtype=float32)
v,a,w: -0.07520861178636551, 2.2097630500793457, 0.6328628063201904
loglike =  tf.Tensor(-161.54095, shape=(), dtype=float32)
v,a,w: 0.19510459899902344, 2.158651113510132, 0.2182157039642334
loglike =  tf.Tensor(-223.37079, shape=(), dtype=float32)
v,a,w: 0.19510459899902344, 2.158651113510132, 0.2182157039642334
loglike =  tf.Tensor(-223.37079, shape=(), dtype=float32)
v,a,w: 1.2428256273269653, 1.4181015491485596, 4.7622528076171875
loglike =  tf.Tensor(-1041.2998, shape=(), dtype=float32)
v,a,w: 1.2428256273269653, 1.4181015491485596, 4.7622528076171875
loglike =  tf.Tensor(-1041.2998, shape=(), dtype=float32)
v,a,w: 1.024801254272461, 1.803241491317749, -2.2295587062835693
loglike =  tf.Tensor(-2302.5852, shape=(), dtype=float32)
v,a,w: 1.024801254272461, 1.803241491317749, -2.2295587062835693
loglike =  tf.Tensor(-2302.5852, shape=(), dtype=float32)
v,a,w:

KeyboardInterrupt: 