Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introducing pyro.infer.predictive.WeighedPredictive which reports weights along with predicted samples #3345

Merged
merged 10 commits into from
Mar 26, 2024

Conversation

BenZickel
Copy link
Contributor

The Problem

When sampling from the posterior predictive distribution we are often using a guide as an approximation for the posterior. As mentioned in #3340 it is often desirable to correct for the non-uniform per sample gap between the model log-probability and the guide log-probability. This gap is essentially the weight that should be assigned to each sample.

The current implementation of pyro.infer.predictive.Predictive does not support calculation of these weights.

The Proposed Solution

Add pyro.infer.predictive.WeighedPredictive which supports calculation of per sample weights.

The implementation relies on three objects:

  • Model which samples from priors and observations (same as in instantiation of pyro.infer.predictive.Predictive).
  • Guide which approximates the posterior given observations (same as in instantiation of pyro.infer.predictive.Predictive).
  • Model with observations constrained to be the actual observations. This model was used in creating the guide and is provided to instances of pyro.infer.predictive.WeighedPredictive when called as the keyword argument model_guide (as in the model that was used when creating the guide).

The model_guide is what enables calculation of the weights. If not provided we use the model provided at instantiation of pyro.infer.predictive.WeighedPredictive as the model_guide (in this case the model provided at instantiation is usually already with observations constrained to be the actual observations).

Design Considerations

  • Maintain backwards compatibility of pyro.infer.predictive.Predictive.
  • Reuse as much as possible from pyro.infer.predictive.Predictive when implementing pyro.infer.predictive.WeighedPredictive.

@fritzo
Copy link
Member

fritzo commented Mar 25, 2024

Hi @BenZickel I think it's a great idea to add a Predictive interface that can incorporate both a set of samples and some sort importance weights. And I appreciate your efforts to share interface and code with Predictive (which reduces our maintenance efforts!). Here are some general questions & comments before I do a thorough code review:

  1. I think it would also be good to add more mathematical details in the docstring of this class, to make it clear exactly what it's doing.
  2. Do I understand correctly that WeightedPredictive samples from the model p by drawing proposals from the guide q, then weighting each sample z by p(z)/q(z), returning a weighted set of samples?
  3. If so, it might be nice to add some sort of utility importance_resample: WeightedSamples -> Samples to convert from the weighted representation back to an unweighted representation (as done in SMCFilter).
  4. How does your WeightedPredictive relate to pyro.infer.Importance? Could they share any machinery? Should we link their docstrings?
  5. How does your WeightedPredictive relate to pyro.infer.ReweightedWakeSleep? IIUC, WeightedPredictive + importance_resample is like ReweightedWakeSleep, but where the former optimizes ELBO(guide,model) and the latter directly optimizes the model posterior density p(z|x)?

cc @martinjankowiak who may have a better understanding of the relationships between these inference algorithms.

@BenZickel
Copy link
Contributor Author

Thank you for your comments @fritzo. See below my feedback:

  1. I've added more mathematical details to the docstring.
  2. Yes, WeighedPredictive returns the exact same samples returned by Predictive, namely p(x|z)q(z), but accompanies each sample with its weight p(Xobs,z)/q(z) where p(Xobs,z)=p(Xobs|z)p(z).
  3. I agree that we need to have some way to convert weighed samples to unweighed samples, but I believe this should be added in another pull request as there are several considerations related to multi-dimensional event samples and interpolation methods (SMCFilter does not do interpolation). For now, we can calculate quantiles of weighed samples using weighed_quantile introduced in Add function for calculating quantiles of weighed samples. #3340.
  4. I've created as much shared machinery between WeighedPredictive and pyro.infer.Importance as I could in 026cae2. They do indeed share concepts and code.
  5. As I see it pyro.infer.ReweightedWakeSleep is a strategy to create a guide from your model and observations. As the created guide is usually not perfectly proportional to the model you would want to use WeighedPredictive in order to obtain the true distribution quantiles.

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code looks great. Thanks for adding tests!

@@ -31,53 +34,58 @@ def _guess_max_plate_nesting(model, args, kwargs):
return max_plate_nesting


class _predictiveResults(NamedTuple):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the cleanup! This will also help us with #2550

@fritzo fritzo merged commit 0dc635f into pyro-ppl:dev Mar 26, 2024
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants