
Add Stable distribution with numerically integrated log-probability calculation (StableWithLogProb). #3369

Merged — 24 commits, May 28, 2024

Conversation

@BenZickel (Contributor) commented May 20, 2024

This fixes #3280 by adding pyro.distributions.StableWithLogProb, which is based on pyro.distributions.Stable with an additional log_prob method (I opted not to modify the pyro.distributions.Stable distribution at this stage).

The code combines #3280 (comment) by @mawright with the existing Pyro Stable distribution code base, with the following modifications:

  • Make the calculation numerically stable at and near an alpha value of one, and at values at and near zero.
  • Eliminate dependency on the torchquad package.
  • Cache integration points (not sure whether torchquad does this, but overall speed is 25% faster than the reference implementation based on torchquad).
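The cached-integration-points idea can be sketched in isolation (a minimal stdlib illustration with hypothetical names, not the actual Pyro implementation): a memoized function returns a fixed quadrature grid, so repeated log-probability evaluations reuse the same nodes and weights instead of rebuilding them.

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def integration_points(num_points):
    # Midpoint-rule nodes and weights on [0, 1]; lru_cache means repeated
    # calls with the same grid size reuse the same points instead of
    # recomputing them. (Hypothetical helper, not the actual Pyro code.)
    h = 1.0 / num_points
    nodes = tuple((i + 0.5) * h for i in range(num_points))
    weights = tuple(h for _ in range(num_points))
    return nodes, weights

def integrate(f, num_points=1000):
    # Approximate \int_0^1 f(u) du on the cached grid.
    nodes, weights = integration_points(num_points)
    return sum(w * f(u) for u, w in zip(nodes, weights))
```

After the first call, every further `integrate` with the same `num_points` is a cache hit on the grid, which is where the reported speedup over rebuilding the grid each iteration would come from.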

Per-iteration duration is about 5 times longer than with reparameterization, but overall convergence is much faster and covers cases that do not converge with reparameterization (such as estimation of the skew parameter beta).

The log-probability calculation is based on integration over a uniformly distributed random variable $u$ such that $P(x) = \int du\, P(x|u) P(u)$. The integral can be converted to a reparameterization in which we first sample $u$ with probability density $P(u)$, or $g(u)$ when approximating the posterior distribution by a guide, and then sample or observe $x$ from the distribution $P(x|u)$. Initial tests indicate this reparameterization works but is still slower than estimating the log-probability by integration.
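The marginalization above can be sketched with a toy conditional (a Gaussian stand-in for $P(x|u)$ rather than the actual stable conditional density; all names here are made up for illustration): with $u \sim \mathrm{Uniform}(0, 1)$, the integral $P(x) = \int_0^1 P(x|u)\,du$ is approximated on a fixed grid, combining terms with log-sum-exp so the result stays finite even when every $P(x|u)$ underflows.

```python
import math

def log_normal_pdf(x, mean, sd):
    # Log density of Normal(mean, sd) at x; stands in for log P(x|u).
    return -0.5 * math.log(2.0 * math.pi * sd * sd) - (x - mean) ** 2 / (2.0 * sd * sd)

def log_prob_marginal(x, num_points=200):
    # log P(x) = log \int_0^1 P(x|u) du  (u ~ Uniform(0, 1), so P(u) = 1),
    # approximated with a midpoint rule and combined via log-sum-exp so the
    # log-probability is computed without ever exponentiating tiny terms alone.
    h = 1.0 / num_points
    logs = [log_normal_pdf(x, (i + 0.5) * h, 1.0) + math.log(h)
            for i in range(num_points)]
    m = max(logs)
    return m + math.log(sum(math.exp(v - m) for v in logs))
```

For this toy conditional the exact marginal is $\Phi(x) - \Phi(x-1)$, which makes a handy sanity check; the real stable conditional density is more involved, but the quadrature-plus-log-sum-exp structure is the same.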

A usage example with real-life data has been added to the last section of the Stable distribution tutorial.

@fritzo (Member) left a comment


It's great to see this implemented!

I don't trust my review of the math, but I would trust some sort of density test. One option is to use goftests.density_goodness_of_fit, something like:

import goftests
import pytest
from pyro.distributions import StableWithLogProb

@pytest.mark.parametrize(...)
def test_density(stability, skew, loc, scale):
    d = StableWithLogProb(stability, skew, loc, scale)
    samples = d.sample((1000,))  # sample_shape must be a tuple
    probs = d.log_prob(samples).exp()
    gof = goftests.density_goodness_of_fit(samples, probs)
    assert gof > 1e-2  # goodness-of-fit p-value should not be tiny

Another option is to check against a reference implementation, say something in scipy. WDYT?
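A stdlib-only sketch of the reference-check idea (function names here are made up; a real test could instead compare against scipy.stats.levy_stable.logpdf): for $\alpha=1$, $\beta=0$ the stable distribution is the standard Cauchy, so a density obtained by numerically inverting the characteristic function can be compared with the closed form.

```python
import math

def cauchy_pdf_via_cf(x, t_max=40.0, num_points=20000):
    # Stable(alpha=1, beta=0) is the standard Cauchy, with characteristic
    # function exp(-|t|). Invert it:
    #   f(x) = (1/pi) * \int_0^inf exp(-t) * cos(t * x) dt
    # using a midpoint rule on a truncated interval [0, t_max].
    h = t_max / num_points
    total = 0.0
    for i in range(num_points):
        t = (i + 0.5) * h
        total += math.exp(-t) * math.cos(t * x) * h
    return total / math.pi

def cauchy_pdf_exact(x):
    # Closed-form standard Cauchy density, the reference to check against.
    return 1.0 / (math.pi * (1.0 + x * x))
```

The same comparison against an independent reference (scipy, or a closed form where one exists) is what gives confidence that the integrated log-probability is correct, independent of the sampling-based goodness-of-fit test.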

(btw thanks for your patience!)

beta = self.skew.double()
value = value.double()

return _stable_log_prob(alpha, beta, value, self.coords) - self.scale.log()
A Member left a comment

I think we'll want to convert the result of _stable_log_prob() back to value.dtype, right? Something like:

logp = _stable_log_prob(alpha, beta, value, self.coords)
return logp.to(dtype=value.dtype) - self.scale.log()


@fritzo (Member) left a comment

Looks great!

I think it would actually be cleaner to implement Stable.log_prob() rather than a separate StableWithLogProb class (but thank you for drafting a non-invasive solution!). Do you see any blockers to simply merging Stable <-> StableWithLogProb in this PR? I think the only change needed will be to update your tutorial's summary:

 ## Summary
-- [Stable.log_prob()](http://docs.pyro.ai/en/stable/distributions.html#stable) is undefined.
+- [Stable.log_prob()](http://docs.pyro.ai/en/stable/distributions.html#stable) is very expensive.

and simply omit the reparam stuff from your new section. The single Stable solution is nice in that users will at least be able to use default SVI and HMC, and the older StableReparam machinery can become an approximate cost-saving tool.

EDIT: I guess we'd need to revise some pytest.raises checks in the tests, which might be easiest by adding an internal distribution pyro.distributions.testing.fakes.StableWithoutLogProb.

BTW I am a big fan of the Lévy stable distribution, and am delighted to see Pyro improving its support for heavy-tailed inference.

@fritzo (Member) left a comment

Happy to merge as-is, but I think this .log_prob() deserves to be in the main Stable distribution. LMK if you want me to merge now, or if you want to combine Stable <-> StableWithLogProb.

@BenZickel (Contributor, Author) commented

Thanks @fritzo for the review!

I also think heavy-tailed inference is much needed and I really appreciate all the work done on this so far.

It might be better to combine StableWithLogProb and Stable, but I'd do it in a separate pull request (if at all). The advantage of keeping them separate is that users are made explicitly aware of both the high cost of the log-probability calculation and the possibility of reducing that cost, at the expense of accuracy, by reparameterization. If we do combine the two, we also need to figure out whether the behavior of MinimalReparam needs to be modified when handling the Stable distribution.

@BenZickel (Contributor, Author) commented

One more option that comes to mind is to keep both Stable and StableWithLogProb and add the .log_prob method to Stable as well. This way a user can enforce no reparameterization by using StableWithLogProb instead of Stable.

@fritzo fritzo merged commit 0678b35 into pyro-ppl:dev May 28, 2024
9 checks passed
Development

Successfully merging this pull request may close these issues.

Trouble with basic parameter estimation with the Stable distribution