# NumPyro: eight schools hierarchical model sampled with NUTS

In [1]:
import numpyro
import numpyro.distributions as dist

In [2]:
import numpy as np
# Observed per-school values (conditioned on as the `obs` site of the model)
# and their known observation standard deviations.
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
# Derive the group count from the data rather than hard-coding it, so J can
# never disagree with the array lengths.
J = len(y)
assert sigma.shape == (J,), "y and sigma must have one entry per school"

In [3]:
def eight_schools(J, sigma, y=None, centered=True):
    """Hierarchical normal model for the eight-schools data.

    Population-level parameters are ``mu ~ Normal(0, 5)`` and
    ``tau ~ HalfCauchy(5)``. Inside the ``'J'`` plate, each group has
    ``theta ~ Normal(mu, tau)`` and an observation
    ``obs ~ Normal(theta, sigma)``.

    Args:
        J: number of groups (size of the ``'J'`` plate).
        sigma: per-group observation standard deviations, used as the
            scale of the ``obs`` likelihood.
        y: observed values for the ``obs`` site; ``None`` leaves the site
            unobserved.
        centered: if True (default, the original model), sample ``theta``
            directly from ``Normal(mu, tau)``. If False, use a fully
            non-centered parameterization via ``LocScaleReparam(centered=0)``.
            It has the same joint distribution but avoids the funnel geometry
            that makes NUTS diverge when ``tau`` is small.
    """
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        if centered:
            theta = numpyro.sample('theta', dist.Normal(mu, tau))
        else:
            # Local imports: only the non-centered variant needs them.
            from numpyro.handlers import reparam
            from numpyro.infer.reparam import LocScaleReparam

            with reparam(config={'theta': LocScaleReparam(centered=0)}):
                theta = numpyro.sample('theta', dist.Normal(mu, tau))
        numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)

In [4]:
from jax import random
from numpyro.infer import MCMC, NUTS

In [5]:
# Sampler settings, named so they are easy to find and tune.
NUM_WARMUP = 500
NUM_SAMPLES = 1000
SEED = 0

nuts_kernel = NUTS(eight_schools)
mcmc = MCMC(nuts_kernel, num_warmup=NUM_WARMUP, num_samples=NUM_SAMPLES)
rng_key = random.PRNGKey(SEED)
# Also record the potential energy of each sample; it is read back later
# with mcmc.get_extra_fields().
mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))

sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1500/1500 [00:00<00:00, 1585.22it/s, 31 steps of size 6.07e-02. acc. prob=0.75]


In [6]:
# prob=0.9 (the default) gives the 5.0% / 95.0% interval columns.
mcmc.print_summary(prob=0.9)


                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu      4.48      2.88      5.14     -0.66      8.88    116.32      1.05
       tau      2.59      2.60      1.83      0.22      6.07     12.80      1.17
  theta[0]      5.65      4.24      5.74     -1.52     12.17    182.43      1.00
  theta[1]      4.82      3.96      5.37     -1.61     11.47    246.74      1.02
  theta[2]      4.26      4.16      5.30     -2.94      9.97    199.61      1.05
  theta[3]      4.68      4.02      5.36     -1.36     11.70    212.51      1.02
  theta[4]      3.89      3.95      4.85     -3.23      9.25    115.18      1.06
  theta[5]      4.28      4.14      5.30     -2.46     10.35    179.70      1.03
  theta[6]      5.87      4.12      5.70     -0.27     12.99    157.47      1.00
  theta[7]      4.77      4.14      5.31     -3.17     10.03    195.87      1.01

Number of divergences: 105


In [7]:
# Potential energy requested via `extra_fields` in the MCMC run; its negation
# is averaged as the log joint density below.
pe = mcmc.get_extra_fields()["potential_energy"]

In [8]:
# Potential energy is treated as the negative log joint, so negate before averaging.
print(f"Expected log joint density: {np.mean(-pe):.2f}")

Expected log joint density: -49.01
