Skip to content

Commit

Permalink
Elaborate methematical details of pyro.infer.predictive.WeighedPredic…
Browse files Browse the repository at this point in the history
…tive.
  • Loading branch information
Ben Zickel committed Mar 25, 2024
1 parent 026cae2 commit 260e528
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions pyro/infer/predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,25 @@ class WeighedPredictive(Predictive):
on the same initialization interface as :class:`Predictive`.
The methods `.forward` and `.call` can be called with an additional keyword argument
`model_guide` which is the model used to create and optimize the guide (if not
provided `model_guide` defaults to `self.model`), and they return both samples and log_weights.
``model_guide`` which is the model used to create and optimize the guide (if not
provided ``model_guide`` defaults to ``self.model``), and they return both samples and log_weights.
The weights are calculated as the per sample gap between the model_guide log-probability
and the guide log-probability (a guide must always be provided).
A typical use case would be based on a ``model`` :math:`p(x,z)=p(x|z)p(z)` and ``guide`` :math:`q(z)`
that has already been fitted to the model given observations :math:`p(X_{obs},z)`, both of which
are provided at itialization of :class:`WeighedPredictive` (same as you would do with :class:`Predictive`).
When calling an instance of :class:`WeighedPredictive` we provide the model given observations :math:`p(X_{obs},z)`
as the keyword argument ``model_guide``.
The resulting output would be the usual samples :math:`p(x|z)q(z)` returned by :class:`Predictive`,
along with per sample weights :math:`p(X_{obs},z)/q(z)`. The samples and weights can be fed into
:any:`weighed_quantile` in order to obtain the true quantiles of the resulting distribution.
Note that the ``model`` can be more elaborate with sample sites :math:`y` that are not observed
and are not part of the guide, if the samples sites :math:`y` are sampled after the observations
and the latent variables sampled by the guide, such that :math:`p(x,y,z)=p(y|x,z)p(x|z)p(z)` where
each element in the product represents a set of ``pyro.sample`` statements.
"""

def call(self, *args, **kwargs):
Expand Down

0 comments on commit 260e528

Please sign in to comment.