In [42]:
import numpyro
import jax.numpy as jnp

In [43]:
import numpy as np

In [44]:
J = 8

y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])

sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

In [45]:
import numpyro.distributions as dist

In [46]:
# Eight Schools example

def eight_schools(J, sigma, y=None):
   mu = numpyro.sample('mu', dist.Normal(0, 5))
   tau = numpyro.sample('tau', dist.HalfCauchy(5))
   with numpyro.plate('J', J):
       theta = numpyro.sample('theta', dist.Normal(mu, tau))
       numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)

In [47]:
from jax import random
from numpyro.infer import MCMC, NUTS

nuts_kernel = NUTS(eight_schools)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))

sample: 100%|██████████| 1500/1500 [00:06<00:00, 238.02it/s, 31 steps of size 4.85e-02. acc. prob=0.99] 


In [48]:
mcmc.print_summary(exclude_deterministic=False)


                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu      4.30      3.44      4.23     -1.68      8.96    158.83      1.01
       tau      4.05      3.45      2.97      0.29      8.29     70.21      1.02
  theta[0]      6.41      5.66      5.74     -2.99     14.52    213.16      1.01
  theta[1]      5.16      4.92      4.98     -2.10     13.42    250.40      1.01
  theta[2]      3.70      5.47      3.73     -5.71     11.76    324.69      1.00
  theta[3]      4.83      5.32      4.52     -2.80     13.70    331.97      1.00
  theta[4]      3.49      4.95      3.68     -4.69     11.44    211.57      1.00
  theta[5]      3.96      5.10      4.24     -3.66     12.65    298.62      1.01
  theta[6]      6.75      5.52      6.05     -1.46     15.60    155.36      1.01
  theta[7]      5.07      5.31      4.80     -3.16     13.83    232.52      1.02

Number of divergences: 1


In [49]:
def new_school():
    mu = numpyro.sample('mu', dist.Normal(10_000, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    return numpyro.sample('obs', dist.Normal(mu, tau))

In [50]:
from numpyro.infer import Predictive
predictive = Predictive(new_school, posterior_samples=mcmc.get_samples())
samples_predictive = predictive(rng_key=random.PRNGKey(0))
samples_predictive['obs'].mean()

DeviceArray(4.1635695, dtype=float32)

In [51]:
from numpyro.distributions import constraints

In [56]:
# Now try VI
def school_guide(J, sigma, y=None):
    mu_loc = numpyro.param("mu_loc", 2.0)
    mu_scale = numpyro.param("mu_scale", 1.0, constraint=constraints.positive)

    log_tau_loc = numpyro.param("log_tau_loc", 0.0)
    log_tau_scale = numpyro.param("log_tau_scale", 1.0, constraint=constraints.positive)

    mu = numpyro.sample("mu", dist.Normal(mu_loc, mu_scale))
    # log_tau = numpyro.sample("log_tau", dist.Normal(log_tau_loc, log_tau_scale))
    tau = numpyro.sample("tau", dist.LogNormal(log_tau_loc, log_tau_scale))
    # tau = numpyro.deterministic("tau", jnp.exp(log_tau))

    with numpyro.plate("J", J):
        numpyro.sample("theta", dist.Normal(mu, tau))

In [57]:
from numpyro.infer import SVI, Trace_ELBO

optimizer = numpyro.optim.Adam(0.001)
svi = SVI(eight_schools, school_guide, optimizer, Trace_ELBO())
svi_results = svi.run(random.PRNGKey(2), 5000, J=J, sigma=sigma, y=y)
svi_results.params

100%|██████████| 5000/5000 [00:09<00:00, 548.42it/s, init loss: 33.9446, avg. loss [4751-5000]: 31.7353] 


{'log_tau_loc': DeviceArray(0.25769845, dtype=float32),
 'log_tau_scale': DeviceArray(0.88301367, dtype=float32),
 'mu_loc': DeviceArray(4.082992, dtype=float32),
 'mu_scale': DeviceArray(3.126091, dtype=float32)}

In [3]:
import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.special import logsumexp
import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS

In [4]:
N, D = 3000, 3
def logistic_regression(data, labels):
    coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(D), jnp.ones(D)))
    intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
    logits = jnp.sum(coefs * data + intercept, axis=-1)
    return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

data = random.normal(random.PRNGKey(0), (N, D))
true_coefs = jnp.arange(1., D + 1.)
logits = jnp.sum(true_coefs * data, axis=-1)
labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

In [13]:
num_warmup, num_samples = 1000, 1000
mcmc = MCMC(NUTS(model=logistic_regression), num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(random.PRNGKey(2), data, labels)
mcmc.print_summary()

sample: 100%|██████████| 2000/2000 [00:11<00:00, 176.32it/s, 3 steps of size 5.88e-01. acc. prob=0.91] 



                 mean       std    median      5.0%     95.0%     n_eff     r_hat
   coefs[0]      0.97      0.07      0.97      0.85      1.09    774.42      1.00
   coefs[1]      2.05      0.09      2.06      1.90      2.21    685.42      1.00
   coefs[2]      3.19      0.13      3.19      2.98      3.40    641.99      1.00
  intercept     -0.03      0.02     -0.03     -0.06      0.00    946.92      1.00

Number of divergences: 0


In [16]:
model = handlers.substitute(handlers.seed(logistic_regression, random.PRNGKey(2)), list(mcmc.get_samples().values())[0])
model_trace = handlers.trace(model).get_trace(data, labels)
model_trace

AttributeError: 'DeviceArray' object has no attribute 'get'

In [31]:
def log_likelihood(rng_key, params, model, *args, **kwargs):
    model = handlers.substitute(
        # handlers.seed(
            model,
            # rng_key),
        params)
    model_trace = handlers.trace(model).get_trace(*args, **kwargs)
    obs_node = model_trace['obs']
    # print(obs_node)
    # print(obs_node['fn'])
    # raise ValueError("stop")
    return obs_node['fn'].log_prob(obs_node['value'])

In [34]:
def log_predictive_density(rng_key, params, model, *args, **kwargs):
    n = list(params.values())[0].shape[0]
    log_lk_fn = vmap(lambda rng_key, params: log_likelihood(rng_key, params, model, *args, **kwargs))
    log_lk_vals = log_lk_fn(random.split(rng_key, n), params)
    print(log_lk_vals[0], log_lk_vals.shape, n)
    return jnp.sum(logsumexp(log_lk_vals, 0) - jnp.log(n))

In [35]:
print(log_predictive_density(random.PRNGKey(2), mcmc.get_samples(), logistic_regression, data, labels))

[-0.11249876 -0.08116782 -0.17313612 ... -0.4367317  -0.31474683
 -0.26940894] (1000, 3000) 1000
-874.9193
