In [2]:
import pymc as pm
import numpy as np

# Observed per-group effects and their known standard errors
# (these values match the classic "eight schools" data set).
y = np.asarray([28, 8, -3, 7, -1, 1, 18, 12], dtype=np.float32)
sigma = np.asarray([15, 10, 16, 11, 9, 11, 10, 18], dtype=np.float32)
# Number of groups.
J = y.shape[0]



# Model

In [31]:
# Non-centred hierarchical model: theta_j = mu + tau * eta_j with eta_j ~ N(0, 1),
# and y_j ~ N(theta_j, sigma_j) with the standard errors sigma_j treated as known.
with pm.Model() as model:
    mu = pm.Normal('mu', mu=0., sigma=10.)            # population mean effect
    tau = pm.LogNormal('tau', mu=0., sigma=1.)        # population spread, tau > 0
    # One standardized offset per group; size follows the data instead of a literal 8.
    eta = pm.Normal('eta', mu=0., sigma=1., shape=J)
    obs = pm.Normal('observed', mu=mu + tau * eta, sigma=sigma, observed=y)


In [32]:
import bayeux as bx

# Wrap the PyMC model for bayeux; the result provides the `log_density`,
# `test_point` and `mcmc` samplers used in the cells below.
bx_model = bx.Model.from_pymc(model)

In [33]:
# Log density bayeux derived from the PyMC model, evaluated at its initial point;
# the hand-written distrax density below is checked against this value.
bx_model.log_density(bx_model.test_point)

Array(-43.57248197, dtype=float64)

In [34]:
# Initial point taken from the PyMC model, keyed by parameter name
# (note `tau` appears as a positive value, i.e. not log-transformed).
bx_model.test_point

{'mu': array(0.),
 'tau': array(1.64872122),
 'eta': array([0., 0., 0., 0., 0., 0., 0., 0.])}

In [39]:
import distrax
from jax import grad
import jax.numpy as jnp


def log_likelihood(test_point):
    """Unnormalised log joint density of the non-centred hierarchical model.

    Despite the name, this is log prior + log likelihood (the same quantity as
    ``bx_model.log_density``), written by hand with distrax so it can be handed
    to ``bx.Model`` directly. The priors must mirror the PyMC model exactly.

    Args:
      test_point: dict with keys ``'mu'`` (scalar), ``'tau'`` (scalar, positive,
        on the constrained scale) and ``'eta'`` (vector with one entry per group).

    Returns:
      Scalar JAX array:
      sum_j log p(eta_j) + sum_j log p(y_j | mu, tau, eta) + log p(mu) + log p(tau).
    """
    mu, tau, eta = test_point['mu'], test_point['tau'], test_point['eta']

    log_prior_eta = distrax.Normal(0., 1.).log_prob(eta).sum()
    log_prior_mu = distrax.Normal(0., 10.).log_prob(mu)
    # LogNormal(0, 1), expressed as an exp-transformed standard normal.
    log_prior_tau = distrax.Transformed(
        distrax.Normal(loc=0., scale=1.), distrax.Lambda(jnp.exp)
    ).log_prob(tau)

    # Likelihood of the observed effects given theta_j = mu + tau * eta_j;
    # y and sigma are the module-level data arrays.
    theta = mu + tau * eta
    log_like = distrax.Normal(theta, sigma).log_prob(y).sum()

    # Same summation order as before so results are bit-for-bit comparable.
    return log_prior_eta + log_like + log_prior_mu + log_prior_tau


# Gradient of the log density; kept for interactive debugging of the density.
dlog_p = grad(log_likelihood)

# Sanity check: the hand-written density must agree with the one bayeux derived
# from the PyMC model (tiny differences presumably come from float32 y/sigma).
np.testing.assert_allclose(
    log_likelihood(bx_model.test_point),
    bx_model.log_density(bx_model.test_point),
    rtol=1e-6,
)
log_likelihood(bx_model.test_point)

-43.572482770210605


# Inference

In [37]:
# Reference posterior from PyMC's own NUTS sampler. Seeded so the run is
# reproducible; the bayeux runs below are seeded with jax.random.key(0) as well.
with model:
    i_data_pmc = pm.sample(draws=2000, random_seed=0)

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, tau, eta]
  self.pid = os.fork()


Output()

  self.pid = os.fork()


Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 5 seconds.
There were 2 divergences after tuning. Increase `target_accept` or reparameterize.


In [38]:
import jax 

# NumPyro NUTS (through bayeux) on the PyMC-derived log density, with a fixed key.
nuts_key = jax.random.key(0)
idata_bx = bx_model.mcmc.numpyro_nuts(seed=nuts_key)

sample: 100%|██████████| 1500/1500 [00:08<00:00, 167.03it/s]


In [40]:
def transform_fn(test_point):
    """Transform passed to ``bx.Model`` alongside ``log_likelihood``.

    Maps ``tau`` through ``exp`` (so it is positive) and returns ``mu`` and
    ``eta`` unchanged, as a new dict with keys ``mu``, ``tau``, ``eta``.
    """
    mu, raw_tau, eta = (test_point[key] for key in ('mu', 'tau', 'eta'))
    return dict(mu=mu, tau=jnp.exp(raw_tau), eta=eta)

# Same bayeux NUTS run as above, but driven by the hand-written distrax density;
# the PyMC initial point is reused as the starting point.
bx_jax = bx.Model(
    log_density=log_likelihood,
    test_point=bx_model.test_point,
    transform_fn=transform_fn,
)
jax_nuts_key = jax.random.key(0)
idata_jax = bx_jax.mcmc.numpyro_nuts(seed=jax_nuts_key)

sample: 100%|██████████| 1500/1500 [00:08<00:00, 170.23it/s]


# Diagnostics

In [41]:
import arviz as az

# Posterior summary of the PyMC NUTS run (the reference for the two bayeux runs).
az.summary(i_data_pmc)

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
eta[0],0.127,0.995,-1.785,1.938,0.01,0.011,10706.0,6310.0,1.0
eta[1],0.026,0.991,-1.883,1.813,0.01,0.011,10673.0,6134.0,1.0
eta[2],-0.052,0.978,-1.847,1.85,0.009,0.011,11799.0,5829.0,1.0
eta[3],0.007,0.995,-1.84,1.938,0.01,0.012,10198.0,5886.0,1.0
eta[4],-0.109,0.984,-1.917,1.747,0.01,0.011,9632.0,6046.0,1.0
eta[5],-0.086,0.975,-1.846,1.771,0.01,0.01,9829.0,5947.0,1.0
eta[6],0.148,0.998,-1.725,1.998,0.01,0.011,9842.0,5513.0,1.0
eta[7],0.012,1.006,-1.86,1.95,0.01,0.012,10326.0,5411.0,1.0
mu,6.62,3.752,-0.455,13.761,0.04,0.031,8970.0,5744.0,1.0
tau,1.383,1.385,0.027,3.762,0.018,0.014,8400.0,5795.0,1.0


In [42]:
# Posterior summary of the bayeux run on the PyMC-derived log density.
az.summary(idata_bx)

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
eta[0],0.124,1.001,-1.703,2.032,0.008,0.012,14820.0,6565.0,1.0
eta[1],0.016,0.987,-1.866,1.871,0.008,0.012,16413.0,5418.0,1.0
eta[2],-0.052,1.006,-1.891,1.838,0.008,0.012,15802.0,6068.0,1.0
eta[3],0.003,0.99,-1.802,1.933,0.008,0.012,16815.0,5980.0,1.0
eta[4],-0.115,1.001,-1.977,1.737,0.008,0.013,14360.0,6108.0,1.0
eta[5],-0.057,1.014,-1.869,1.944,0.008,0.012,16357.0,6173.0,1.0
eta[6],0.173,0.977,-1.545,2.123,0.008,0.012,14891.0,6354.0,1.0
eta[7],0.021,0.994,-1.867,1.867,0.008,0.012,16846.0,6037.0,1.0
mu,6.59,3.759,-0.284,13.978,0.031,0.027,14393.0,5336.0,1.0
tau,1.42,1.458,0.052,3.885,0.017,0.014,13167.0,6131.0,1.0


In [43]:
# Posterior summary of the bayeux run on the hand-written distrax density;
# it should agree with the two summaries above up to Monte Carlo error.
az.summary(idata_jax)

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
eta[0],0.129,0.993,-1.806,1.947,0.008,0.012,15272.0,6498.0,1.0
eta[1],0.019,0.982,-1.864,1.853,0.007,0.012,17784.0,6075.0,1.0
eta[2],-0.053,1.018,-1.984,1.799,0.008,0.013,16922.0,5867.0,1.0
eta[3],0.006,0.991,-1.917,1.851,0.007,0.012,18187.0,6369.0,1.0
eta[4],-0.12,1.006,-1.995,1.757,0.009,0.012,13187.0,6288.0,1.0
eta[5],-0.056,1.004,-1.986,1.808,0.008,0.013,15836.0,5733.0,1.0
eta[6],0.157,0.99,-1.658,2.062,0.008,0.012,15153.0,6052.0,1.0
eta[7],0.023,0.995,-1.755,1.961,0.008,0.012,17503.0,6423.0,1.0
mu,6.599,3.797,-0.331,14.11,0.031,0.027,14896.0,5803.0,1.0
tau,1.414,1.457,0.025,3.96,0.016,0.013,13030.0,6499.0,1.0
