# Compare different sampling algorithms

In [8]:
def eight_schools(J, sigma, y=None):
    """Centered hierarchical model for the eight-schools data.

    Args:
        J: number of schools; used as the size of the ``'J'`` plate.
        sigma: per-school observation standard deviations (the scale of
            the ``'obs'`` likelihood).
        y: observed per-school effects. When ``None``, the ``'obs'`` site
            is not conditioned on data.

    Sample sites, in order: ``mu``, ``tau``, then ``theta`` and ``obs``
    inside the ``'J'`` plate. Returns nothing; the model is defined purely
    through its sample sites.

    NOTE(review): ``theta`` is drawn directly from Normal(mu, tau) (centered
    parameterization). The NUTS run in this notebook reports divergences;
    a non-centered variant may reduce them.
    """
    # Population-level location and scale shared by all schools.
    population_mean = numpyro.sample('mu', dist.Normal(0, 5))
    population_scale = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        # One latent effect per school, then its noisy measurement.
        school_effect = numpyro.sample(
            'theta', dist.Normal(population_mean, population_scale)
        )
        numpyro.sample('obs', dist.Normal(school_effect, sigma), obs=y)


In [9]:
# Fixed PRNG key so the chain is reproducible on re-run.
rng_key = random.PRNGKey(0)
nuts_kernel = NUTS(eight_schools)
# 500 warmup iterations followed by 1000 retained draws (1500 total).
mcmc = MCMC(
    nuts_kernel,
    num_warmup=500,
    num_samples=1000,
)
# J, sigma and y are forwarded to eight_schools. The potential energy of
# each draw is also recorded so it can be inspected later.
mcmc.run(
    rng_key,
    J,
    sigma,
    y=y,
    extra_fields=('potential_energy',),
)


sample: 100%|██████████| 1500/1500 [00:05<00:00, 297.08it/s, 7 steps of size 2.77e-01. acc. prob=0.76]  


In [10]:
# Posterior summary table; prob=0.9 yields the 5% / 95% interval columns.
mcmc.print_summary(prob=0.9)


                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu      4.56      2.94      4.72     -0.48      9.02    209.48      1.00
       tau      4.59      3.32      3.63      0.95      8.78    100.77      1.00
  theta[0]      6.99      6.16      6.09     -3.43     16.85    340.16      1.00
  theta[1]      5.09      4.86      5.17     -2.64     12.71    385.50      1.00
  theta[2]      4.12      5.31      4.42     -4.35     13.10    396.04      1.00
  theta[3]      4.98      5.00      5.14     -3.13     12.09    520.83      1.00
  theta[4]      3.38      4.68      3.65     -3.33     11.41    344.31      1.00
  theta[5]      4.15      4.93      4.55     -3.46     12.20    307.29      1.00
  theta[6]      7.13      5.27      6.57     -1.15     15.38    241.43      1.00
  theta[7]      5.36      5.89      5.05     -3.64     14.45    519.28      1.00

Number of divergences: 9
