Do not invert Gamma rate and fix JAX implementation of Gamma and Pareto #460
Conversation
We might need to check if the Numba/JAX implementations are still correct:

It seems like JAX was probably doing something wrong? It was expecting the … (see pytensor/pytensor/link/jax/dispatch/random.py, line 233 in a708732).

Would be good to have a careful look and see if it was indeed broken, and if so, why didn't this test fail? (pytensor/tests/link/jax/test_random.py, lines 142 to 157 in 53b00ea)
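On the question of why the test didn't fail: if the case under test uses rate = 1, an accidental inversion is invisible, since 1/1 == 1. A minimal illustration (plain Python, not the actual test code):

```python
# With rate == 1, the correct conversion and a forgotten inversion agree,
# so a buggy implementation still passes the test.
rate = 1.0
correct_scale = 1.0 / rate    # what the implementation should compute
buggy_scale = rate            # implementation that forgot to invert
assert correct_scale == buggy_scale == 1.0  # bug is undetectable here

# Any other rate exposes the discrepancy immediately:
rate = 2.0
assert 1.0 / rate != rate
```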
@ricardoV94 I'm not sure about the … Regarding the tests, it looks like it's only testing with …
I would expect that you would multiply by scale or divide by rate, if those parameters mean what they usually mean. So my hunch is that the implementation is now correct and was wrong before. Would be good to check. And you're right, we were testing the case where they're equivalent; I misread the parametrization and thought we were testing both 1 and 2.
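The scale/rate convention described here can be sanity-checked numerically. A quick sketch using NumPy as an illustrative stand-in for the PyTensor/JAX code: a standard-Gamma draw multiplied by `scale` (equivalently, divided by `rate`) should have mean `shape * scale == shape / rate`.

```python
import numpy as np

rng = np.random.default_rng(0)
shape_param, rate = 3.0, 2.0
scale = 1.0 / rate

# Multiply a standard Gamma draw by scale (i.e. divide by rate).
samples = rng.standard_gamma(shape_param, size=200_000) * scale

# Mean of Gamma(shape, scale) is shape * scale == shape / rate == 1.5.
assert abs(samples.mean() - shape_param / rate) < 0.01
```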
Yes, I think you're right.
I think numba is just calling np.random.gamma under the hood with positional arguments, so it should start failing now with these changes (let's make sure we are not testing shape=rate=1 again).

Hmm, maybe the easiest solution is still to define the Op in terms of scale internally and have a helper "gamma" function (not …). In that function we could start deprecating the rate parametrization (with a warning) in favor of a scale kwarg. That way later we can silently remove the helper altogether. Something like:

```python
_gamma = Gamma()

def gamma(shape, rate=None, *, scale=None, ...):
    # TODO: Remove helper when rate is deprecated
    if rate is not None and scale is not None:
        raise ValueError("Can't specify both rate and scale")
    elif rate is None and scale is None:
        raise ValueError("Missing scale argument")
    elif rate is not None:
        warnings.warn(
            "Gamma rate argument is deprecated and will stop working, use scale instead",
            FutureWarning,
        )
        scale = 1.0 / rate
    return _gamma(shape, scale, ...)
```

This way we don't need a PR in PyMC (or just one that adds a test for the old bug).
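The helper pattern suggested above can be exercised end to end. The sketch below is hypothetical: `_gamma` is a trivial stand-in for the real Gamma Op, and the `...` placeholders from the comment are dropped so the code runs.

```python
import warnings

def _gamma(shape, scale):
    # Hypothetical stand-in for the real Gamma Op: just echo the inputs.
    return (shape, scale)

def gamma(shape, rate=None, *, scale=None):
    # TODO: remove this helper once the rate parametrization is gone
    if rate is not None and scale is not None:
        raise ValueError("Can't specify both rate and scale")
    if rate is None and scale is None:
        raise ValueError("Missing scale argument")
    if rate is not None:
        warnings.warn(
            "Gamma rate argument is deprecated and will stop working, "
            "use scale instead",
            FutureWarning,
        )
        scale = 1.0 / rate
    return _gamma(shape, scale)

# Calling with the deprecated rate kwarg warns and converts to scale.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = gamma(2.0, rate=4.0)

assert result == (2.0, 0.25)
assert caught and issubclass(caught[-1].category, FutureWarning)
```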
@ricardoV94 Thanks for the suggestions. I've updated the PR with the proposed changes, but I wasn't able to run the tests.
Looks like pre-commit is failing. Can you run it and push again? Also, you should rebase from main, as we merged a fix for an issue that was causing an unrelated CI job to fail.
I rebased and pushed some changes. Also fixed the JAX implementation of Pareto.
Force-pushed from b33078a to 600229b
Also fix wrong JAX implementation of Gamma and Pareto RVs
Codecov Report
Additional details and impacted files:

```
@@            Coverage Diff            @@
##             main     #460     +/-   ##
=========================================
  Coverage   80.64%   80.64%
=========================================
  Files         160      160
  Lines       46016    46022      +6
  Branches    11263    11265      +2
=========================================
+ Hits        37108    37114      +6
  Misses       6671     6671
  Partials     2237     2237
```
Thanks for your help @ricardoV94!
Motivation for these changes

The current implementation of `GammaRV` inverts the passed `rate` (`beta`) parameter upon each call. This breaks functionality in `pymc`, as described in this bug report. ~~This PR changes the parameterization so that we only invert `beta` when we draw samples. Along with this pytensor PR, the bug is fixed.~~ This PR changes the parameterization to store `scale` (`1/beta`) and adds a helper function that converts `rate` to `scale` and warns the user that `rate` is deprecated.

Implementation details
No changes to API.
Checklist
Major / Breaking Changes
~~Store `beta` directly rather than its inverse.~~ Store `scale` directly. There are no changes to the API.

New features
Bugfixes
Documentation
Maintenance