
Fix distributions which don't properly honor validate_args=False #53600


Closed

Conversation

t-vi (Collaborator) commented Mar 9, 2021

A number of derived distributions use base distributions in their implementation.

We add what we hope is a comprehensive test of whether all distributions actually honor skipping validation of arguments in log_prob, and then fix the bugs we found. These bugs are particularly cumbersome in PyTorch 1.8 and on master, where validate_args is turned on by default. In addition, one might argue that validate_args is not performing as well as it should when the default is not to validate but validation is turned on at instantiation.

Arguably, there is another set of bugs, or at least inconsistencies, where validating the inputs does not prevent invalid indices from reaching sample validation (with validation enabled, the test sees an IndexError). We would encourage implementors to be more ambitious when validation is turned on and to amend sample validation to throw a ValueError for consistency.
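
For illustration, a minimal sketch of the failure mode this PR addresses, assuming the usual wiring in which Beta delegates to an internal Dirichlet (the parameter values are arbitrary):

import torch
from torch.distributions import Beta

# validate_args=False asks the distribution to skip all argument and sample checking.
d = Beta(torch.tensor(0.5), torch.tensor(0.5), validate_args=False)

# Before the fix, an internally constructed base distribution could still run its
# own validation and raise on a value outside the support; with the fix, this call
# goes through and simply returns a mathematically meaningless number.
print(d.log_prob(torch.tensor(-1.0)))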

facebook-github-bot (Contributor) commented Mar 9, 2021

💊 CI failures summary and remediations

As of commit f455848 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI. Follow this link to opt out of these comments for your Pull Requests.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

codecov bot commented Mar 9, 2021

Codecov Report

Merging #53600 (f455848) into master (e90e773) will decrease coverage by 0.00%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master   #53600      +/-   ##
==========================================
- Coverage   77.65%   77.64%   -0.01%     
==========================================
  Files        1869     1869              
  Lines      182294   182298       +4     
==========================================
- Hits       141562   141551      -11     
- Misses      40732    40747      +15     

t-vi requested a review from neerajprad March 9, 2021 20:01
t-vi (Collaborator, Author) commented Mar 9, 2021

@neerajprad can I interest you in this?

d_nonval = Dist(validate_args=False, **param)
d_val = Dist(validate_args=True, **param)
for v in torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]):
    try:
neerajprad (Contributor) commented Mar 9, 2021

Samples from distributions follow specific batch and event shape semantics based on the shape of their parameters. To isolate incorrect shape-related errors from incorrect support errors, how about we do the following:

for v in torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]):
    # get a sample of the correct shape
    val = torch.full(d_val.batch_shape + d_val.event_shape, v)
    # samples with incorrect shape must throw ValueError only
    try:
        log_prob = d_val.log_prob(v)
    except ValueError:
        pass
    # check samples with incorrect support
    try:
        log_prob = d_val.log_prob(val)
    except ValueError as e:
        if e.args and 'must be within the support' in e.args[0]:
            try:
                log_prob = d_nonval.log_prob(val)
            except RuntimeError:
                pass

On my system, this surfaces a couple of other distributions, like LogisticNormal and RelaxedOneHotCategorical, that we need to fix additionally. Also, you are right that we shouldn't throw IndexError when validate_args is True. I see that certain distributions aren't calling validate_sample when validate_args is set, which causes this IndexError (you can find these offenders by calling d_val.log_prob(v) as you were doing earlier and checking which throw anything other than ValueError). If you want, I can push these changes to your PR; let me know!
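
For reference, a minimal sketch of the pattern a conforming log_prob follows, using a toy distribution (the Unit class is made up for illustration; _validate_args and _validate_sample are the real hooks on torch.distributions.Distribution):

import torch
from torch.distributions import Distribution, constraints

class Unit(Distribution):
    # Hypothetical uniform distribution on (0, 1), just to show the hook.
    arg_constraints = {}
    support = constraints.unit_interval

    def __init__(self, validate_args=None):
        super().__init__(torch.Size(), validate_args=validate_args)

    def log_prob(self, value):
        # Validate only when asked to; validate_args=False must skip this entirely.
        if self._validate_args:
            self._validate_sample(value)  # raises ValueError outside the support
        return torch.zeros_like(value)  # density 1 on (0, 1), so log_prob is 0

Unit(validate_args=False).log_prob(torch.tensor(-1.0))  # returns tensor(0.), no error
Unit(validate_args=True).log_prob(torch.tensor(-1.0))   # raises ValueError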

neerajprad (Contributor) commented

Updated the snippet above to surface both the issues.

t-vi (Collaborator, Author) commented Mar 9, 2021

Whatever is most convenient for you. My use case was to be able to disable validation just for Beta without changing the default, but I thought it might be nice to do this for the other distributions as well.
Your improvement would improve the user experience in general; the users will be grateful (and I never know whether the indexing stuff wouldn't trigger device_asserts on CUDA).

neerajprad (Contributor) commented

I just pushed some additional fixes. Note that you should be able to get back the earlier behavior (no validation for all distributions) by using Distribution.set_default_validate_args(False) at the top of your file. Thanks for making these fixes along the way!
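
For anyone who lands here needing that workaround, a short sketch of both options (the Beta parameters are placeholders):

import torch
from torch.distributions import Beta, Distribution

# Process-wide: restore the pre-1.8 default of not validating.
Distribution.set_default_validate_args(False)

# Per-instance: opt out for a single distribution only, leaving the default alone.
d = Beta(torch.tensor(0.5), torch.tensor(0.5), validate_args=False)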

t-vi (Collaborator, Author) commented

Yes, thank you! I appreciate that it is a workaround, but I don't want people to copy and paste that kind of not-so-good practice from my code. :)
Thank you for completing my fixes!

neerajprad (Contributor) commented

You write some of the highest visibility code for PyTorch, so thanks for helping us establish good hygiene practices!

neerajprad (Contributor) commented

cc. @fritzo, @fehiepsi.

neerajprad (Contributor) commented

@t-vi: Thanks for surfacing this, and making these fixes! There are a few additional fixes that can go into this PR, and I'll be happy to push these if you'd like.

neerajprad requested review from fritzo and fehiepsi March 9, 2021 23:10
facebook-github-bot (Contributor) left a comment

@neerajprad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


facebook-github-bot (Contributor) commented

@neerajprad merged this pull request in 76b58dd.
