Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function for calculating quantiles of weighed samples. #3340

Merged
merged 6 commits into from
Mar 18, 2024

Conversation

BenZickel
Copy link
Contributor

The Problem

When calculating distribution quantiles based on samples sampled from a guide it is often desirable to correct for the non-uniform gap between the model log-probability and the guide log-probability (this gap is already computed by the pyro.infer.importance.Importance module).

Currently there is only support for the below:

  • Calculating quantiles from equally weighed samples by using pyro.ops.stats.quantile which interpolates quantiles over the sorted samples.
  • Sampling with non-uniform weights using pyro.distributions.empirical.Empirical which does not support interpolation.

The Proposed Solution

Add a function pyro.ops.stats.weighed_quantile which supports calculating interpolated quantiles with non-uniform weights. The implementation is similar to pyro.ops.stats.quantile but with different weights generated for each index of the input, according to the indices that sort that specific index.

@BenZickel
Copy link
Contributor Author

@fehiepsi, I've fixed the testing errors and now make format and make test pass successfully.
Sorry for not getting it right earlier.

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @BenZickel looks great! I wish we had designed Pyro for more native support of weighted sets of samples, but this is a step in the right direction.

Could you add a simple unit test of this function, similar to test_quantiles() in tests/ops/test_stats.py?

pyro/ops/stats.py Outdated Show resolved Hide resolved
@BenZickel BenZickel requested a review from fritzo March 17, 2024 22:50
Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, just one little comment on Tuple[float] vs Tuple[float, ...]

pyro/ops/stats.py Outdated Show resolved Hide resolved
@BenZickel BenZickel requested a review from fritzo March 17, 2024 23:19
@fritzo fritzo merged commit 0474cc9 into pyro-ppl:dev Mar 18, 2024
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants