Skip to content

Commit

Permalink
Clarify how to make predictive in SVI (#1549)
Browse files Browse the repository at this point in the history
* clarify how to make predictive in svi

* wrong usage of Predictive with guide

* fix svi example
  • Loading branch information
fehiepsi committed Mar 7, 2023
1 parent 14fd004 commit 9423ebc
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
4 changes: 2 additions & 2 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1647,8 +1647,8 @@ class AutoLaplaceApproximation(AutoContinuous):
Laplace approximation (quadratic approximation) approximates the posterior
:math:`\log p(z | x)` by a multivariate normal distribution in the
unconstrained space. Under the hood, it uses Delta distributions to
construct a MAP guide over the entire (unconstrained) latent space. Its
covariance is given by the inverse of the hessian of :math:`-\log p(x, z)`
construct a MAP (i.e. point estimate) guide over the entire (unconstrained) latent
space. Its covariance is given by the inverse of the hessian of :math:`-\log p(x, z)`
at the MAP point of `z`.
Usage::
Expand Down
10 changes: 8 additions & 2 deletions numpyro/infer/svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class SVI(object):
>>> def model(data):
... f = numpyro.sample("latent_fairness", dist.Beta(10, 10))
... with numpyro.plate("N", data.shape[0]):
... with numpyro.plate("N", data.shape[0] if data is not None else 10):
... numpyro.sample("obs", dist.Bernoulli(f), obs=data)
>>> def guide(data):
Expand All @@ -110,9 +110,15 @@ class SVI(object):
>>> svi_result = svi.run(random.PRNGKey(0), 2000, data)
>>> params = svi_result.params
>>> inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"])
>>> # use guide to make predictive
>>> predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
>>> samples = predictive(random.PRNGKey(1), data=None)
>>> # get posterior samples
>>> predictive = Predictive(guide, params=params, num_samples=1000)
>>> samples = predictive(random.PRNGKey(1), data)
>>> posterior_samples = predictive(random.PRNGKey(1), data=None)
>>> # use posterior samples to make predictive
>>> predictive = Predictive(model, posterior_samples, params=params, num_samples=1000)
>>> samples = predictive(random.PRNGKey(1), data=None)
:param model: Python callable with Pyro primitives for the model.
:param guide: Python callable with Pyro primitives for the guide
Expand Down

0 comments on commit 9423ebc

Please sign in to comment.