In [None]:
# dependencies

import os
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'

import numpy as np
import matplotlib.pyplot as plt

import jams.sampling

In [None]:
# define model

# Target: equal-weight mixture of two isotropic Gaussians in R^d,
# N(mu * 1, sig1 * I) and N(-mu * 1, sig2 * I). `mu` is a scalar that
# broadcasts to every coordinate; sig1 and sig2 are variances.
mu = 2
sig1 = 1/4
sig2 = 1/9
d = 4
rng = np.random.default_rng(0)


def eval_logcomp(x, mu, sig):
    """Log density of N(mu, sig * I) at x, omitting the -len(x)/2 * log(2*pi) term.

    Both mixture components omit the same constant, so their relative
    weights are unaffected.

    Args:
        x: point at which to evaluate; expected to be 1-D, since len(x) is
            used as the dimension and np.linalg.norm(..., 2) is the vector
            2-norm only for 1-D input.
        mu: mean (scalar or array broadcastable against x).
        sig: per-coordinate variance.

    Returns:
        Scalar log density (up to the omitted constant).
    """
    return -(len(x) * np.log(sig) + np.linalg.norm(x - mu, 2) ** 2 / sig) / 2


def eval_logp(x):
    """Unnormalized log density of the target mixture at x (uses globals mu, sig1, sig2)."""
    return np.logaddexp(eval_logcomp(x, mu, sig1), eval_logcomp(x, -mu, sig2))


def eval_d_logp(x):
    """Gradient of eval_logp with respect to x.

    The gradient of a log-mixture is the responsibility-weighted sum of the
    component gradients -(x - m_k) / s_k. The responsibilities
    exp(log c_k - log p) are formed in log space. The previous form
    exp(...) / exp(log p) underflowed to 0/0 = nan once x was far from both
    modes.

    Args:
        x: 1-D point at which to evaluate.

    Returns:
        Array of the same shape as x.
    """
    log_c1 = eval_logcomp(x, mu, sig1)
    log_c2 = eval_logcomp(x, -mu, sig2)
    log_p = np.logaddexp(log_c1, log_c2)
    # Responsibilities: finite and summing to 1 even when both component
    # densities underflow to zero in linear space.
    w1 = np.exp(log_c1 - log_p)
    w2 = np.exp(log_c2 - log_p)
    return -(w1 * (x - mu) / sig1 + w2 * (x + mu) / sig2)

In [None]:
# generate starting points

# 32 initial states, each a standard-normal draw in R^d from the seeded `rng`
# (presumably one per chain/walker in jams.sampling — confirm in its docs).
starting_points = rng.standard_normal((32, d))

In [None]:
# generate 1e6 samples from target distribution

sampler = jams.sampling.sample_posterior(eval_logp, eval_d_logp, starting_points)
samples = [next(sampler) for _ in range(10 ** 6)]

# Each draw is indexed positionally: element 0 is collected into `x`, element 1
# into `i`. `i` is used below as a mode/component label — presumably; confirm
# against the jams.sampling documentation.
x = np.array([draw[0] for draw in samples])
i = np.array([draw[1] for draw in samples])

In [None]:
# inspect distribution of first coordinate and compare to known true density 

def eval_logp1(x, d):
    """Exact log marginal density of the first coordinate of the target.

    The target is an equal-weight mixture of isotropic Gaussians with means
    +/-mu in every coordinate, so x_1 is marginally
    0.5 * N(mu, sig1) + 0.5 * N(-mu, sig2), fully normalized here.
    Works elementwise on array x. `d` is accepted for call compatibility
    but is not used.
    """
    log_norm_pos = -(np.log(2 * np.pi * sig1) + (x - mu) ** 2 / sig1) / 2
    log_norm_neg = -(np.log(2 * np.pi * sig2) + (x + mu) ** 2 / sig2) / 2
    return np.logaddexp(log_norm_pos, log_norm_neg) - np.log(2)

x1 = np.linspace(-4, 4, 2 ** 8 + 1)
fig, ax = plt.subplots(figsize=(9, 3))
# Drop the first 1000 draws (burn-in) before comparing with the exact marginal.
# density=True normalizes over draws inside range=(-4, 4) only.
ax.hist(x[1000:, 0], bins=128, range=(-4, 4), density=True, alpha=.5, label='samples')
ax.plot(x1, np.exp(eval_logp1(x1, d)), color='black', label='true density')
ax.set_xlabel(r'$x_{1}$')
ax.set_ylabel('dens')
ax.legend()
plt.show()

In [None]:
# inspect trace plot for the last 1000 observations

# Window of at most the last 1000 draws; min() keeps the start index
# non-negative if fewer draws were generated.
n_trace = min(1000, len(x))
k = np.arange(len(x) - n_trace, len(x))
fig, ax = plt.subplots(figsize=(9, 3))
ax.plot(k, x[k, 0])
# Overlay the label series `i` as a step function at the matching mode location.
# NOTE(review): assumes label 1 is the -mu mode and any other label is +mu —
# confirm against how jams.sampling orders the modes it discovers.
ax.step(k, np.where(i[k] == 1, -mu, mu), color='black', alpha=.25)
ax.set_xlabel('iter')
ax.set_ylabel(r'$x_{1}$')
plt.show()