
Do not invert Gamma rate and fix JAX implementation of Gamma and Pareto #460

Merged 1 commit into pymc-devs:main on Sep 29, 2023

Conversation

@tvwenger (Contributor) commented Sep 27, 2023

Motivation for these changes

The current implementation of GammaRV inverts the passed rate (beta) parameter on every call, which breaks functionality in PyMC as described in this bug report. This PR changes the parameterization so that beta is only inverted once: the Op now stores scale (1/beta) directly, and a helper function converts a passed rate to scale while warning the user that the rate parameterization is deprecated. With these changes the bug is fixed.
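For reference, the rate and scale parameterizations describe the same distribution, related by scale = 1/rate; a minimal sanity check with scipy (illustrative only, not part of this PR):

import numpy as np
from scipy import stats
from scipy.special import gamma as gamma_fn

alpha, beta = 2.0, 3.0  # shape and rate
x = np.linspace(0.1, 5.0, 50)

# pdf written directly in the rate parameterization
pdf_rate = beta**alpha * x ** (alpha - 1) * np.exp(-beta * x) / gamma_fn(alpha)
# scipy's gamma is parameterized by scale, so pass scale = 1/rate
pdf_scale = stats.gamma(alpha, scale=1.0 / beta).pdf(x)

assert np.allclose(pdf_rate, pdf_scale)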

Implementation details

No changes to API.

Checklist

Major / Breaking Changes

  • Store scale (1/beta) directly rather than inverting the rate on each call. There are no changes to the API.

New features

  • N/A

Bugfixes

Documentation

  • N/A

Maintenance

  • N/A

@ricardoV94 ricardoV94 changed the title Reparameterize GammaRV so beta is not inverted at each call Do not invert rate in outer graph of GammaRV Sep 27, 2023
@ricardoV94 (Member) commented Sep 27, 2023

We might need to check if the Numba/JAX implementations are still correct:

@numba_funcify.register(aer.GammaRV)

It seems like JAX was probably doing something wrong? It was expecting the rate while what was being provided was the scale?

sample = jax_op(sampling_key, shape, size, dtype) / rate

Would be good to have a careful look and see whether it was indeed broken, and if so, why this test didn't fail:

(
    aer.gamma,
    [
        set_test_value(
            at.dvector(),
            np.array([1.0, 2.0], dtype=np.float64),
        ),
        set_test_value(
            at.dscalar(),
            np.array(1.0, dtype=np.float64),
        ),
    ],
    (2,),
    "gamma",
    lambda a, b: (a, 0.0, b),
),

@tvwenger (Contributor, PR author)

@ricardoV94 I'm not sure about the numba implementation.

Regarding JAX, I think we should be multiplying by rate, no?

Regarding the tests, it looks like it's only testing with beta = 1? That's the special case where rate == scale.
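A minimal sketch (not code from this PR) of why beta = 1 masks the problem: jax.random.gamma draws unit-scale samples, and multiplying by the scale versus dividing by it only gives the same result when beta = 1:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
shape_param = 2.0
unit_draws = jax.random.gamma(key, shape_param, shape=(5,))  # unit-scale Gamma draws

for beta in (1.0, 2.0):
    scale = 1.0 / beta
    correct = unit_draws * scale   # multiply by scale, i.e. divide by rate
    mixed_up = unit_draws / scale  # what a dispatch that confuses rate and scale computes
    print(beta, bool(jnp.allclose(correct, mixed_up)))  # True only for beta == 1.0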

@ricardoV94 (Member) commented Sep 27, 2023

Regarding JAX, I think we should be multiplying by rate, no?

I would expect that you would multiply by scale or divide by rate if those parameters mean what they usually mean. So my hunch is that the implementation is now correct and was wrong before.

Would be good to check. And you're right, we were only testing the case where they're equivalent; I misread the parametrization and thought we were testing both 1 and 2.

@tvwenger (Contributor, PR author)

I would expect that you would multiply by scale or divide by rate if those parameters mean what they usually mean. So my hunch is that the implementation is now correct and was wrong before.

Yes, I think you're right.

@ricardoV94 (Member) commented Sep 27, 2023

I think numba is just calling np.random.gamma under the hood with positional arguments, so it should start failing now with these changes (let's make sure we are not testing shape=rate=1 again).
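For context, NumPy's gamma takes scale, not rate, as its second positional argument, so a dispatch that forwards the Op's parameters positionally is only correct if the Op stores scale. A quick check:

import numpy as np

rng = np.random.default_rng(0)
shape_param, scale = 2.0, 3.0

# Generator.gamma(shape, scale, size): the second positional argument is the scale
draws = rng.gamma(shape_param, scale, 100_000)
print(draws.mean())  # approximately shape_param * scale = 6.0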

Hmm, maybe the easiest solution is still to define the Op in terms of scale internally and have a helper "gamma" function (not __call__) that does the one-time inversion for backwards compatibility.

In that function we could start deprecating the rate parametrization (with a warning) in favor of a scale kwarg. That way we can later silently remove the helper altogether. Something like:

_gamma = Gamma()

def gamma(shape, rate=None, *, scale=None, ...):
    # TODO: Remove helper when rate is deprecated
    if rate is not None and scale is not None:
        raise ValueError("Can't specify both rate and scale")
    elif rate is None and scale is None:
        raise ValueError("Missing scale argument")
    elif rate is not None:
        warnings.warn(
            "Gamma rate argument is deprecated and will stop working, use scale instead",
            FutureWarning,
        )
        scale = 1.0 / rate

    return _gamma(shape, scale, ...)

This way we don't need a PR in PyMC (or just one that adds a test for the old bug)
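A runnable stand-in for the proposed behaviour, using NumPy in place of the Op (the names and exact warning message here are illustrative, not PyTensor's API):

import warnings
import numpy as np

_rng = np.random.default_rng(0)

def gamma(shape, rate=None, *, scale=None, size=None):
    # Stand-in for the proposed helper: accept rate for backwards
    # compatibility, warn, and convert it to scale before sampling.
    if rate is not None and scale is not None:
        raise ValueError("Can't specify both rate and scale")
    if rate is None and scale is None:
        raise ValueError("Missing scale argument")
    if rate is not None:
        warnings.warn("Gamma rate argument is deprecated, use scale instead", FutureWarning)
        scale = 1.0 / rate
    return _rng.gamma(shape, scale, size)

gamma(2.0, scale=2.0, size=3)  # preferred: no warning
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    gamma(2.0, rate=0.5, size=3)  # still works, but emits a FutureWarning
assert issubclass(caught[-1].category, FutureWarning)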

@tvwenger (Contributor, PR author)

@ricardoV94 Thanks for the suggestions. I've updated the PR with the proposed changes, but I wasn't able to run the tests.

@ricardoV94 (Member)

Looks like pre-commit is failing. Can you run and push again?

Also, you should rebase from main, as we merged a fix for an issue that was causing an unrelated CI failure.

@ricardoV94 (Member)

I rebased and pushed some changes. Also fixed the JAX implementation of Pareto.
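For reference, a minimal sketch of the scale handling involved for Pareto (an illustration of the relationship, not the actual diff in this PR): jax.random.pareto draws from the standard Pareto with support [1, inf), so a scipy-style scale (minimum) parameter is multiplied in, analogous to the Gamma case:

import jax
from scipy import stats

key = jax.random.PRNGKey(0)
b, scale = 3.0, 2.0  # shape and scale (the distribution's minimum)

# jax.random.pareto samples the standard Pareto on [1, inf);
# multiplying by scale matches scipy.stats.pareto(b, scale=scale) on [scale, inf).
draws = jax.random.pareto(key, b, shape=(100_000,)) * scale

print(float(draws.min()))  # >= scale
print(float(draws.mean()), stats.pareto(b, scale=scale).mean())  # both close to b * scale / (b - 1) = 3.0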

@ricardoV94 added the labels bug, jax, and random variables on Sep 29, 2023
@ricardoV94 ricardoV94 changed the title Do not invert rate in outer graph of GammaRV Do not invert Gamma rate and fix JAX implementation of Gamma and Pareto Sep 29, 2023
Also fix wrong JAX implementation of Gamma and Pareto RVs
@codecov-commenter

Codecov Report

Merging #460 (a18ccda) into main (a708732) will increase coverage by 0.00%.
Report is 2 commits behind head on main.
The diff coverage is 100.00%.

Additional details and impacted files


@@           Coverage Diff           @@
##             main     #460   +/-   ##
=======================================
  Coverage   80.64%   80.64%           
=======================================
  Files         160      160           
  Lines       46016    46022    +6     
  Branches    11263    11265    +2     
=======================================
+ Hits        37108    37114    +6     
  Misses       6671     6671           
  Partials     2237     2237           
Files Coverage Δ
pytensor/link/jax/dispatch/random.py 96.05% <100.00%> (-0.02%) ⬇️
pytensor/tensor/random/basic.py 99.05% <100.00%> (+0.01%) ⬆️
pytensor/tensor/random/rewriting/jax.py 100.00% <100.00%> (ø)
pytensor/tensor/rewriting/linalg.py 77.31% <100.00%> (ø)

@ricardoV94 ricardoV94 merged commit 3169197 into pymc-devs:main Sep 29, 2023
53 checks passed
@tvwenger (Contributor, PR author)

Thanks for your help @ricardoV94!
