In [1]:
import numpy as np
import pickle
from scipy import stats
from joblib import Parallel, delayed
import matplotlib.pyplot as plt
from nuts import NUTSSampler, nuts6
import sys
from numba import jit

# NUTS on some random normal data

In [2]:
data = np.random.normal(0, 1, 1000)

In [29]:

def loglike_fn(params):
    mu, sigma = params
    n = len(data)
    print('mu, sigma = ', mu,sigma)
    # Calculate log-likelihood
    log_likelihood = -n/2 * np.log(2 * np.pi) - n * np.log(sigma) - (1/(2 * sigma**2)) * np.sum((data - mu)**2)
    return log_likelihood


In [30]:
def numerical_grad(params):
    epsilon=1e-6
    mu, sigma = params
    dL_mu = (loglike_fn([mu+epsilon, sigma]) - loglike_fn([mu, sigma]))/(epsilon)
    dL_sigma = (loglike_fn([mu, sigma+epsilon]) - loglike_fn([mu, sigma]))/(epsilon)

    return dL_mu, dL_sigma

In [32]:
def grad_fn(params):
    mu, sigma = params
    n = len(data)
    
    # Calculate the gradient
    dL_dmu = np.sum(data - mu) / (sigma**2)
    dL_dsigma = -n/sigma + (1/(sigma**3)) * np.sum((data - mu)**2)
    
    return np.array([dL_dmu, dL_dsigma])


In [33]:
def loglike_and_grad_fns(params):
    return loglike_fn(params), grad_fn(params)

In [9]:
# both agree
grad_fn([0,1]), numerical_grad([0,1])

(array([-18.22231857,   7.41381434]), (-18.222818653157447, 7.412803370243637))

In [34]:
sampler = NUTSSampler(2, loglike_fn, grad_fn)
x0 = np.array([0.5, 2])  # Initial guess
M, Madapt = 5000, 5000  # Number of samples and adaptation steps
delta = 1e-9
samples = sampler.run_mcmc(x0, M, Madapt, delta )

Running HMC with dual averaging and trajectory length 0.00...
mu, sigma =  0.5 2.0
mu, sigma =  1.0458138059777236 1.6241929288541423
mu, sigma =  1.0458138059777236 1.6241929288541423
mu, sigma =  1.4620720323134992 0.9078403743218819
find_reasonable_epsilon= 2.0
mu, sigma =  -1.2501941232318623 2.9397183018226647
mu, sigma =  1.7319718046640717 -0.3019002353682749
mu, sigma =  1359.822332621312 -1812.3194689857132
mu, sigma =  356110.00684015633 -599303.1462749633
mu, sigma =  250352495.70650688 -422962104.4704965
mu, sigma =  240849644959.8068 -406807941427.29987
mu, sigma =  253034056167273.4 -427384458458982.44
mu, sigma =  2.58236423789849e+17 -4.361713316538111e+17
mu, sigma =  2.4005174959057348e+20 -4.054567145658757e+20
mu, sigma =  1.9620650561919687e+23 -3.3140039707843843e+23
mu, sigma =  1.3843055582872454e+26 -2.338145772721079e+26
mu, sigma =  8.360149918366833e+28 -1.4120617427214582e+29
mu, sigma =  4.311793551512329e+31 -7.282786524231427e+31
mu, sigma =  1.901820265

  log_likelihood = -n/2 * np.log(2 * np.pi) - n * np.log(sigma) - (1/(2 * sigma**2)) * np.sum((data - mu)**2)
  dL_dsigma = -n/sigma + (1/(sigma**3)) * np.sum((data - mu)**2)
  dL_dsigma = -n/sigma + (1/(sigma**3)) * np.sum((data - mu)**2)
  log_likelihood = -n/2 * np.log(2 * np.pi) - n * np.log(sigma) - (1/(2 * sigma**2)) * np.sum((data - mu)**2)
  log_likelihood = -n/2 * np.log(2 * np.pi) - n * np.log(sigma) - (1/(2 * sigma**2)) * np.sum((data - mu)**2)
  log_likelihood = -n/2 * np.log(2 * np.pi) - n * np.log(sigma) - (1/(2 * sigma**2)) * np.sum((data - mu)**2)
  dL_dsigma = -n/sigma + (1/(sigma**3)) * np.sum((data - mu)**2)
  dL_dmu = np.sum(data - mu) / (sigma**2)


mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma =  inf -inf
mu, sigma 

In [28]:
samples, lnprob, epsilon = nuts6(loglike_and_grad_fns, M, Madapt, x0, delta, progress=True)

  log_likelihood = -n/2 * np.log(2 * np.pi) - n * np.log(sigma) - (1/(2 * sigma**2)) * np.sum((data - mu)**2)


mu, sigma =  0.5 2.0
mu, sigma =  -65.4400859525262 -167.64472993690956
find_reasonable_epsilon= 0.5


  dL_dsigma = -n/sigma + (1/(sigma**3)) * np.sum((data - mu)**2)
  dL_dsigma = -n/sigma + (1/(sigma**3)) * np.sum((data - mu)**2)
  log_likelihood = -n/2 * np.log(2 * np.pi) - n * np.log(sigma) - (1/(2 * sigma**2)) * np.sum((data - mu)**2)
  log_likelihood = -n/2 * np.log(2 * np.pi) - n * np.log(sigma) - (1/(2 * sigma**2)) * np.sum((data - mu)**2)
  log_likelihood = -n/2 * np.log(2 * np.pi) - n * np.log(sigma) - (1/(2 * sigma**2)) * np.sum((data - mu)**2)
  dL_dsigma = -n/sigma + (1/(sigma**3)) * np.sum((data - mu)**2)
  dL_dmu = np.sum(data - mu) / (sigma**2)
  8%|▊         | 829/9999 [00:00<00:01, 8286.35it/s]

mu, sigma =  -15.151588798699368 -41.1191185800115
mu, sigma =  -61477.51533154101 -161585.99264841937
mu, sigma =  -20133240.547809515 -52922722.87829188
mu, sigma =  -14221223049.21792 -37381405424.21077
mu, sigma =  -13680348395071.162 -35959706827465.05
mu, sigma =  -1.4372287815217544e+16 -3.777851728138636e+16
mu, sigma =  -1.4667777057585304e+19 -3.855523044485568e+19
mu, sigma =  -1.3634891255617793e+22 -3.584022121163022e+22
mu, sigma =  -1.1144490185827083e+25 -2.9294035871834626e+25
mu, sigma =  -7.862827819996304e+27 -2.0667967432548378e+28
mu, sigma =  -4.7485483940990374e+30 -1.248187621653462e+31
mu, sigma =  -2.4490900934370136e+33 -6.43759668268518e+33
mu, sigma =  -1.0802300980190917e+36 -2.839456871831587e+36
mu, sigma =  -4.08947438585428e+38 -1.0749456220843024e+39
mu, sigma =  -1.3353437434811073e+41 -3.510040106860394e+41
mu, sigma =  -3.782101611769701e+43 -9.941506380166607e+43
mu, sigma =  -9.346928607348577e+45 -2.4569025352398676e+46
mu, sigma =  -2.02776210

 17%|█▋        | 1734/9999 [00:00<00:00, 8718.96it/s]

mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf


 26%|██▌       | 2606/9999 [00:00<00:00, 8188.89it/s]

mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =

 43%|████▎     | 4344/9999 [00:00<00:00, 7954.13it/s]

mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =

 61%|██████▏   | 6143/9999 [00:00<00:00, 8524.84it/s]

mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =

 80%|███████▉  | 7994/9999 [00:00<00:00, 8777.92it/s]

mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =

100%|██████████| 9999/9999 [00:01<00:00, 8665.77it/s]

mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =  -inf -inf
mu, sigma =




In [2]:
@jit(nopython=True)
def rtd_density_a(t, v, a, w, K_max=10):
    if t > 0.25:
        non_sum_term = (np.pi/a**2)*np.exp(-v*a*w - (v**2 * t/2))
        k_vals = np.linspace(1, K_max, K_max)
        sum_sine_term = np.sin(k_vals*np.pi*w)
        sum_exp_term = np.exp(-(k_vals**2 * np.pi**2 * t)/(2*a**2))
        sum_result = np.sum(k_vals * sum_sine_term * sum_exp_term)
    else:
        non_sum_term = (1/a**2)*(a**3/np.sqrt(2*np.pi*t**3))*np.exp(-v*a*w - (v**2 * t)/2)
        K_max = int(K_max/2)
        k_vals = np.linspace(-K_max, K_max, 2*K_max + 1)
        sum_w_term = w + 2*k_vals
        sum_exp_term = np.exp(-(a**2 * (w + 2*k_vals)**2)/(2*t))
        sum_result = np.sum(sum_w_term*sum_exp_term)

    
    density =  non_sum_term * sum_result
    if density <= 0:
        density = 1e-10
    return density

# Gradient checks

In [3]:
def grad_analytic(t,mu):
    K_max = 4
    k_vals = np.linspace(0, K_max, K_max + 1)
    sum_neg_1_term = (-1)**k_vals
    sum_two_k_term = (2 * k_vals) + 1
    sum_exp_term = np.exp(-((2 * k_vals + 1)**2) / (2 * t))
    sum_term = np.sum(sum_neg_1_term * sum_two_k_term * sum_exp_term)
    d_rtd_d_mu = (1 / np.sqrt(2 * np.pi * (t**3)))*(2*np.exp(-(mu**2)*t/2)*(np.sinh(mu) - mu*t*np.cosh(mu)))*sum_term
    return d_rtd_d_mu


def grad_comp(t,mu):
    dmu = 1e-6
    d_rtd_d_mu = (rtd_mu_small_t_parallel(t, mu+dmu) - rtd_mu_small_t_parallel(t,mu))/dmu
    return d_rtd_d_mu

In [4]:
t = 1; a = 2; v = 2; mu = v*(a/2)
print(f"comp={grad_comp(t, mu)}, analytic={grad_analytic(t, mu)}")

comp=-0.24124800054869944, analytic=-0.2412480172864229


# TODO: change gradient calculation
1. directly putting v,a instead of mu
2. do numerically and see

In [4]:


def log_like_and_grad(params):
    v,a = params
    print(f"v={v}, a={a}")
    mu = v * (a / 2)
    with open('sample_rt.pkl', 'rb') as f:
        t_arr = pickle.load(f)
    probs = Parallel(n_jobs=-1)(delayed(rtd_mu_small_t_parallel)(t, mu) for t in t_arr)
    probs = [p + 1e-10 for p in probs]
    # loglike = np.sum(np.log(probs))*(1/len(probs))
    loglike = np.sum(np.log(probs))
    

    
    grads = np.zeros_like(params)
    for t in t_arr:
        rtd = rtd_mu_small_t_parallel(t, mu) + 1e-10
        
        # Partial derivative of rtd with respect to mu
        try:
            K_max = 4
            k_vals = np.linspace(0, K_max, K_max + 1)
            sum_neg_1_term = (-1)**k_vals
            sum_two_k_term = (2 * k_vals) + 1
            sum_exp_term = np.exp(-((2 * k_vals + 1)**2) / (2 * t))
            sum_term = np.sum(sum_neg_1_term * sum_two_k_term * sum_exp_term)
            d_rtd_d_mu = (1 / np.sqrt(2 * np.pi * (t**3)))*(2*np.exp(-(mu**2)*t/2)*(np.sinh(mu) - mu*t*np.cosh(mu)))*sum_term
        except Warning as w:
            print(f"Warning occurred: {w}")
            print(f"Value of mu: {mu}")
            sys.exit()

        # Partial derivatives of mu with respect to v and a
        d_mu_d_v = a / 2
        d_mu_d_a = v / 2
        
        # Apply chain rule to get the partial derivatives of log-likelihood
        grads[0] += (1/rtd)*(d_rtd_d_mu) * d_mu_d_v  # Gradient with respect to v
        grads[1] += (1/rtd)*(d_rtd_d_mu) * d_mu_d_a  # Gradient with respect to a

    # if condition to check grads are nan
    if np.isnan(grads).any():
        print(f"Warning: grads is nan")
        grads[np.isnan(grads)] = 0.01

    # grads /= len(t_arr)
    print('loglike=',loglike,' grads=' ,grads)
    return loglike, grads




In [8]:


def loglike_fn(params):
    v,a = params
    print(f"v={v}, a={a}")
    mu = v * (a / 2)
    with open('sample_rt.pkl', 'rb') as f:
        t_arr = pickle.load(f)
    probs = Parallel(n_jobs=-1)(delayed(rtd_mu_small_t_parallel)(t, mu) for t in t_arr)
    probs = [p + 1e-10 for p in probs]
    # loglike = np.sum(np.log(probs))*(1/len(probs))
    loglike = np.sum(np.log(probs))
    print(f"loglike={loglike}")
    return loglike

def grad_fn(params):
    v,a = params
    mu = v * (a / 2)
    grads = np.zeros_like(params)
    with open('sample_rt.pkl', 'rb') as f:
        t_arr = pickle.load(f)
    for t in t_arr:
        rtd = rtd_mu_small_t_parallel(t, mu) + 1e-10
        try:
            K_max = 4
            k_vals = np.linspace(0, K_max, K_max + 1)
            sum_neg_1_term = (-1)**k_vals
            sum_two_k_term = (2 * k_vals) + 1
            sum_exp_term = np.exp(-((2 * k_vals + 1)**2) / (2 * t))
            sum_term = np.sum(sum_neg_1_term * sum_two_k_term * sum_exp_term)
            d_rtd_d_mu = (1 / np.sqrt(2 * np.pi * (t**3)))*(2*np.exp(-(mu**2)*t/2)*(np.sinh(mu) - mu*t*np.cosh(mu)))*sum_term
        except Warning as w:
            print(f"Warning occurred: {w}")
            print(f"Value of mu: {mu}")
            sys.exit()

        # Partial derivatives of mu with respect to v and a
        d_mu_d_v = a / 2
        d_mu_d_a = v / 2
        
        # Apply chain rule to get the partial derivatives of log-likelihood
        grads[0] += (1/rtd)*(d_rtd_d_mu) * d_mu_d_v  # Gradient with respect to v
        grads[1] += (1/rtd)*(d_rtd_d_mu) * d_mu_d_a  # Gradient with respect to a

    # if condition to check grads are nan
    if np.isnan(grads).any():
        print(f"Warning: grads is nan")
        grads[np.isnan(grads)] = 0.01

    # grads /= len(t_arr)
    print(' grads=' ,grads)
    return grads


In [9]:
sampler = NUTSSampler(2, loglike_fn, grad_fn)

In [11]:
x0 = np.array([1.8, 1.9])  # Initial guess
M, Madapt = 500, 500  # Number of samples and adaptation steps
delta = 0.1
samples = sampler.run_mcmc(x0, M, Madapt, delta )

Running HMC with dual averaging and trajectory length 0.10...
v=1.8, a=1.9
loglike=-2086.5056751345255
 grads= [5102.89353914 4834.32019498]
v=1.8040501092728878, a=1.9024491061499125
loglike=-2054.304625149823
 grads= [5004.92447051 4746.05838796]
v=1.80000150121395, a=1.9000000343632124
loglike=-2086.4978484958788
 grads= [5102.8685609  4834.30047579]
v=1.8000030075307938, a=1.9000000735607452
loglike=-2086.4899724856778
 grads= [5102.84343735 4834.2806203 ]
v=1.8000060354731615, a=1.9000001664587711
loglike=-2086.474072353307
 grads= [5102.79275429 4834.24050042]
v=1.8000121525926198, a=1.9000004102666654
loglike=-2086.4416796598043
 grads= [5102.68964439 4834.15862501]
v=1.800024631770426, a=1.9000011299298236
loglike=-2086.374524714104
 grads= [5102.47644918 4833.98833131]
v=1.8000505698815983, a=1.9000034974456173
loglike=-2086.2307378948813
 grads= [5102.02215553 4833.62157015]
v=1.8001063651261804, a=1.9000119452351143
loglike=-2085.905267905172
 grads= [5101.00194061 4832.7833

  non_sum_term = 2 * np.cosh(mu) * np.exp(-(mu**2) * t / 2) * (1 / np.sqrt(2 * np.pi * (t**3)))
  non_sum_term = 2 * np.cosh(mu) * np.exp(-(mu**2) * t / 2) * (1 / np.sqrt(2 * np.pi * (t**3)))
  d_rtd_d_mu = (1 / np.sqrt(2 * np.pi * (t**3)))*(2*np.exp(-(mu**2)*t/2)*(np.sinh(mu) - mu*t*np.cosh(mu)))*sum_term
  d_rtd_d_mu = (1 / np.sqrt(2 * np.pi * (t**3)))*(2*np.exp(-(mu**2)*t/2)*(np.sinh(mu) - mu*t*np.cosh(mu)))*sum_term
  d_rtd_d_mu = (1 / np.sqrt(2 * np.pi * (t**3)))*(2*np.exp(-(mu**2)*t/2)*(np.sinh(mu) - mu*t*np.cosh(mu)))*sum_term


loglike=nan
 grads= [0.01 0.01]
v=2977.8094147497454, a=2821.4228520798633




loglike=nan
 grads= [0.01 0.01]
v=589255.0065612282, a=558288.2584482275




loglike=nan
 grads= [0.01 0.01]
v=194089437.22352612, a=183874105.65094987




loglike=nan
 grads= [0.01 0.01]
v=79311877860.28117, a=75137570513.95706




loglike=nan
NUTS: Exception while calling your likelihood function:
  params: [7.93118779e+10 7.51375705e+10]
  args: ()
  kwargs: {}
  exception:


Traceback (most recent call last):
  File "/home/rka/code/rough/NUTS/nuts/helpers.py", line 95, in __call__
    return self.f(x, *self.args, **self.kwargs)
  File "/tmp/ipykernel_176167/4013745345.py", line 21, in grad_fn
    rtd = rtd_mu_small_t_parallel(t, mu) + 1e-10
  File "/tmp/ipykernel_176167/3604288234.py", line 8, in rtd_mu_small_t_parallel
    sum_term = np.sum(sum_neg_1_term * sum_two_k_term * sum_exp_term)
  File "<__array_function__ internals>", line 180, in sum
  File "/home/rka/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py", line 2183, in _sum_dispatcher
    def _sum_dispatcher(a, axis=None, dtype=None, out=None, keepdims=None,
KeyboardInterrupt


KeyboardInterrupt: 

In [6]:
# Define initial parameters and sampler settings
x0 = np.array([1.8, 1.9])  # Initial guess
M, Madapt = 500, 500  # Number of samples and adaptation steps
delta = 0.7

# Initialize NUTS sampler
# sampler = NUTSSampler(len(x0), log_joint, grad_log_like)
samples, lnprob, epsilon = nuts6(log_like_and_grad, M, Madapt, x0, delta, progress=True)
# Run NUTS sampling
# print("Starting Sampling")
# samples = sampler.run_mcmc(x0, M, Madapt, delta)
# print("Sampling Completed")


v=1.8, a=1.9


loglike= -2086.5056751345255  grads= [5102.89353914 4834.32019498]
v=1.8012020719474742, a=1.9022992989062901
loglike= -2069.3421930019895  grads= [5053.66759986 4785.09168196]
v=1.7999986531766248, a=1.899999884555969
loglike= -2086.513105941151  grads= [5102.91703154 4834.33912744]
v=1.799997311456143, a=1.899999773946258
loglike= -2086.5204873710013  grads= [5102.94037863 4834.35792362]
v=1.79999464332386, a=1.8999995672297967
loglike= -2086.5351020978846  grads= [5102.98663686 4834.3951071 ]
v=1.7999893682940165, a=1.8999992118087166
loglike= -2086.5637390029574  grads= [5103.0774096 4834.4678386]
v=1.7999790631732195, a=1.8999987330139256
loglike= -2086.6186424828547  grads= [5103.25198031 4834.60676002]
v=1.799959432687185, a=1.8999987036138215
loglike= -2086.718967094326  grads= [5103.57322399 4834.8584383 ]
v=1.799924090737354, a=1.9000023575715226
loglike= -2086.8816796081237  grads= [5104.10412922 4835.25714966]
v=1.7998690829266444, a=1.9000245165185639
loglike= -2087.055313

  0%|          | 0/9999 [00:00<?, ?it/s]

v=2.4762084743570205, a=2.5630553692419817


  0%|          | 1/9999 [00:05<14:36:01,  5.26s/it]

loglike= -17070.826603633857  grads= [-34736.60824287 -33559.58858074]
v=7.148072881422716, a=6.994862470183241


  0%|          | 2/9999 [00:10<14:28:43,  5.21s/it]

loglike= -1127781.079013802  grads= [-19704.07348477 -20135.65726709]
v=1.8948090997435236, a=1.9922487194695149
loglike= -1428.377727126278  grads= [1986.07900489 1888.94114198]
v=1.8915965351827975, a=1.9843460925659464
loglike= -1450.7902750850776  grads= [2188.07843987 2085.80630722]
v=2.063122261073695, a=2.144885436283914


  0%|          | 3/9999 [00:25<27:18:01,  9.83s/it]

loglike= -1871.4374155570808  grads= [-5179.48289537 -4982.04066358]
v=1.9720573195718367, a=2.0492974615976536


  0%|          | 4/9999 [00:30<22:12:40,  8.00s/it]

loglike= -1342.1075222589493  grads= [-738.17485699 -710.35228274]
v=1.8804390292416984, a=1.9378327060031606


  0%|          | 5/9999 [00:35<19:02:50,  6.86s/it]

loglike= -1601.0998983186842  grads= [3173.80480245 3079.80477532]
v=1.9681364158568582, a=2.047850990986747
loglike= -1338.4962214225952  grads= [-621.58554049 -597.38972376]
v=1.9733926836973406, a=2.0482558442818366
loglike= -1342.3540471694985  grads= [-745.07176497 -717.83960677]
v=1.9721183510456113, a=2.044699913853604
loglike= -1339.0949001075498  grads= [-641.27612548 -618.51248028]
v=1.9685978773955564, a=2.0389775746277663
loglike= -1334.1906854253773  grads= [-443.8001544  -428.48143737]
v=1.963522944313356, a=2.0317544314810334
loglike= -1330.4109275453916  grads= [-184.06744671 -177.88599317]
v=1.9578032945138322, a=2.023908222845956
loglike= -1329.845373431308  grads= [98.29968001 95.08901402]
v=1.9524279502834945, a=2.0163950740651067


  0%|          | 6/9999 [01:10<45:43:08, 16.47s/it]

loglike= -1332.763116163208  grads= [362.43810286 350.94029507]
v=1.9523882833116883, a=2.0173569042437287


  0%|          | 7/9999 [01:15<35:09:18, 12.67s/it]

loglike= -1332.448235566878  grads= [344.02690671 332.94758126]
v=1.9661200512074521, a=2.047285296922198


  0%|          | 8/9999 [01:20<28:12:31, 10.16s/it]

loglike= -1336.9767481654023  grads= [-565.59490547 -543.17172413]
v=1.951597691022659, a=2.0160114492676473
loglike= -1333.213639957247  grads= [387.38690883 375.0094758 ]
v=1.951281297895831, a=2.015125042274837
loglike= -1333.6829127718859  grads= [411.64191981 398.60016759]
v=1.9514687944179654, a=2.014726560560903


  0%|          | 9/9999 [01:35<32:03:30, 11.55s/it]

loglike= -1333.764990553177  grads= [415.6833916  402.63189206]
v=1.9497028364267497, a=2.0175929845687945
loglike= -1333.3565558367134  grads= [395.29994748 381.99846785]
v=1.9509036019669856, a=2.0227466354395482
loglike= -1331.3108263429986  grads= [266.94595018 257.46468025]
v=1.953981178772906, a=2.029710437811377
loglike= -1329.703821771687  grads= [60.95059949 58.67650973]
v=1.9574872796740765, a=2.037086775883457
loglike= -1330.2541994056523  grads= [-165.33560641 -158.87509077]
v=1.959830959011912, a=2.043346114203349
loglike= -1332.388437673114  grads= [-346.04811565 -331.90451958]
v=1.959741685156416, a=2.0472719384127465
loglike= -1333.8067315971512  grads= [-426.45260436 -408.21980208]
v=1.956654159963488, a=2.0483277002634046


  0%|          | 10/9999 [02:10<52:12:16, 18.81s/it]

loglike= -1332.9673173250553  grads= [-381.46306545 -364.39056788]
v=1.9605392342215027, a=2.039420755376049
loglike= -1331.431356102144  grads= [-279.37647114 -268.57063769]
v=1.9354337508693882, a=2.0406751585915424
loglike= -1330.8908847633552  grads= [233.61558299 221.56759352]
v=1.9248298462831324, a=2.043091608573854


  0%|          | 11/9999 [02:24<48:35:53, 17.52s/it]

loglike= -1333.5757888789078  grads= [411.83379325 387.99531731]
v=1.9536436472296177, a=2.045591522477966


  0%|          | 12/9999 [02:29<37:55:03, 13.67s/it]

loglike= -1331.1667801135195  grads= [-258.96555344 -247.32523711]
v=1.9432605987237273, a=2.0441150760800473


  0%|          | 13/9999 [02:34<30:22:30, 10.95s/it]

loglike= -1329.616703866486  grads= [-4.24755752 -4.03798752]
v=1.9013442284563529, a=2.077636663941281
loglike= -1330.8107983595528  grads= [230.26284939 210.7244964 ]
v=1.9820102554612617, a=2.0075830169875237
loglike= -1329.750119312537  grads= [-74.5977338  -73.64750159]
v=1.965144502247038, a=1.916143981718507


  0%|          | 14/9999 [02:49<34:08:11, 12.31s/it]

loglike= -1438.0003972814116  grads= [1999.85397352 2050.99516469]
v=1.939690141650034, a=2.0476207714341177
loglike= -1329.6163084230875  grads= [1.00590477 0.95288326]
v=1.9467082368026247, a=2.0404926214905723
loglike= -1329.6166945818763  grads= [-4.19277223 -4.00006555]
v=1.9500346400147446, a=2.0367545041874675
loglike= -1329.6163125115725  grads= [1.08441869 1.03824688]
v=1.9533923994151605, a=2.0330464080053323
loglike= -1329.6169493499397  grads= [5.32012284 5.11168239]
v=1.9569039912560398, a=2.0294861171653684
loglike= -1329.6165753486493  grads= [3.51092366 3.38535971]
v=1.9605171021728802, a=2.0260237146939213
loglike= -1329.6163975852226  grads= [-2.18253234 -2.11196539]
v=1.9640671047206653, a=2.022500244311015
loglike= -1329.6168814092318  grads= [-5.01487981 -4.86999223]
v=1.9674721009927902, a=2.0188359571065515
loglike= -1329.6163495769133  grads= [-1.64701564 -1.60511176]
v=1.9708294734706475, a=2.015125257767296
loglike= -1329.6166510594385  grads= [3.91429885 3.82

  0%|          | 14/9999 [1:30:52<1080:14:42, 389.47s/it]


KeyboardInterrupt: 

In [None]:

# Print and plot results
print("Mean: ", np.mean(samples, axis=0))
per = np.percentile(samples, [16, 50, 84], axis=0)
print("Parameters = {} (+{} / -{})".format(per[1], per[2] - per[1], per[1] - per[0]))

plt.figure()
for i in range(len(x0)):
    plt.hist(samples[:, i], bins=30, alpha=0.5, label=f'Param {i+1}')
plt.legend()
plt.show()
