Fix distributions which don't properly honor validate_args=False #53600

Conversation
A number of derived distributions use base distributions in their implementation. We add what we hope is a comprehensive test of whether all distributions actually honor skipping validation of arguments in log_prob, and then fix the bugs we found. These bugs are particularly cumbersome in PyTorch 1.8 and master, where validate_args is turned on by default. In addition, one might argue that validate_args is not performing as well as it should when the default is not to validate but validation is turned on at instantiation. Arguably, there is another set of bugs, or at least inconsistencies, where validation of inputs does not prevent invalid indices in sample validation (with validation on, an IndexError is raised in the test). We would encourage the implementors to be more ambitious when validation is turned on and amend sample validation to throw a ValueError for consistency.
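To make the failure mode concrete, here is a minimal sketch of what honoring `validate_args=False` should look like. It uses `Chi2` (which is implemented on top of `Gamma`) purely as an illustrative stand-in; the distributions actually fixed are the ones in the diff:

```python
import torch
from torch.distributions import Chi2

# Chi2 delegates to a Gamma base distribution internally. The bug class
# this PR targets: validate_args=False passed to the derived distribution
# was not reliably propagated, so log_prob on an out-of-support value
# could still raise a ValueError from the base distribution.
d = Chi2(torch.tensor([2.0]), validate_args=False)

# With validation properly skipped, this should compute a value
# (typically nan or -inf) instead of raising.
print(d.log_prob(torch.tensor([-1.0])))
```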
💊 CI failures summary: as of commit f455848, Dr. CI reports no failures.
Codecov Report

```diff
@@            Coverage Diff             @@
##           master   #53600      +/-   ##
==========================================
- Coverage   77.65%   77.64%   -0.01%
==========================================
  Files        1869     1869
  Lines      182294   182298       +4
==========================================
- Hits       141562   141551      -11
- Misses      40732    40747      +15
```
@neerajprad can I interest you in this?
```python
d_nonval = Dist(validate_args=False, **param)
d_val = Dist(validate_args=True, **param)
for v in torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]):
    try:
```
Samples from distributions follow specific batch and event shape semantics based on the shape of their parameters. To isolate shape-related errors from support-related errors, how about we do the following:
```python
for v in torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]):
    # get sample of correct shape
    val = torch.full(d_val.batch_shape + d_val.event_shape, v)
    # samples with incorrect shape must throw ValueError only
    try:
        log_prob = d_val.log_prob(v)
    except ValueError:
        pass
    # check samples with incorrect support
    try:
        log_prob = d_val.log_prob(val)
    except ValueError as e:
        if e.args and 'must be within the support' in e.args[0]:
            try:
                log_prob = d_nonval.log_prob(val)
            except RuntimeError:
                pass
```
On my system, this surfaces a couple of other distributions, like `LogisticNormal` and `RelaxedOneHotCategorical`, that we need to fix additionally. Also, you are right that we shouldn't have to throw `IndexError` when `validate_args` is True. I see that there are certain distributions which aren't calling `validate_sample` when `validate_args` is set, which causes this `IndexError` (you can find these offenders if you call `d_val.log_prob(v)` as you were doing earlier and they throw anything other than `ValueError`). If you want I can push these changes to your PR, let me know!
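For reference, a minimal sketch of the pattern being described, assuming PyTorch's `Distribution` base class, where the sample-validation helper is the underscore-prefixed `_validate_sample`:

```python
# Inside a torch.distributions.Distribution subclass: gate sample
# validation on the instance's _validate_args flag so that invalid
# inputs surface as a ValueError rather than a downstream IndexError.
def log_prob(self, value):
    if self._validate_args:
        self._validate_sample(value)  # raises ValueError on bad shape/support
    # ... the actual log-density computation follows here ...
```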
Updated the snippet above to surface both issues.
Whatever is most convenient for you; my use case was to be able to disable validation just for Beta without changing the default, but I thought it might be nice to do this for the other distributions as well.
Your improvement would improve the user experience in general; the users will be grateful (and I never know whether the indexing stuff wouldn't trigger device asserts on CUDA).
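For instance, a sketch of that use case (assuming the fixes in this PR, so the flag is actually honored):

```python
import torch
from torch.distributions import Beta

# Disable validation for this one instance only; the global default
# (validation on, as of PyTorch 1.8) stays untouched.
d = Beta(torch.tensor([0.5]), torch.tensor([0.5]), validate_args=False)

# -0.5 is outside Beta's support; with validation off this must not
# raise and will typically return nan.
print(d.log_prob(torch.tensor([-0.5])))
```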
I just pushed some additional fixes. Note that you should be able to get back the earlier behavior (no validation for all distributions) by using `Distribution.set_default_validate_args(False)` at the top of your file. Thanks for making these fixes along the way!
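As a usage sketch of the workaround just mentioned:

```python
from torch.distributions import Distribution

# Restore the pre-1.8 behavior globally: distributions constructed after
# this call default to validate_args=False unless overridden per instance.
Distribution.set_default_validate_args(False)
```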
Yes, thank you! I appreciate that that is a workaround, but I don't want people to copy-paste that kind of not-so-good practice from my code. :)
Thank you for completing my fixes!
You write some of the highest visibility code for PyTorch, so thanks for helping us establish good hygiene practices!
@t-vi: Thanks for surfacing this, and for making these fixes! There are a few additional fixes that can go into this PR, and I'll be happy to push them if you'd like.
@neerajprad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@neerajprad merged this pull request in 76b58dd.