Adds opinfo-based autograd tests and (un)supported dtype tests #43451
Conversation
💊 CI failures summary (as of commit 932c71b): ci.pytorch.org: 2 failed.
low, high, requires_grad: bool = False) -> torch.Tensor:
    """Returns a tensor of the specified size on the given device and dtype.
    The tensor's values are between -9 and 9, inclusive, unless low (high)
    is not None, in which case the values are between max(-9, low) and
    min(9, high).
Could you explain why you chose to use the max instead of just using the provided low value?
Low could be -inf, and we only want to generate values in a finite range, so we need some stopping point. If we only represented finite domains it would make sense to use low directly, but that may run into other issues: some of our tests fail when float values become too large due to cross-platform discrepancies.
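For readers following along, a minimal sketch of the clamping behavior under discussion, assuming a floating-point dtype; make_tensor_sketch and its use of uniform_ are simplifications, not the actual implementation:

import torch

def make_tensor_sketch(size, device, dtype, low=None, high=None,
                       requires_grad: bool = False) -> torch.Tensor:
    # Clamp the requested range into [-9, 9] so generated values stay in
    # a finite range; this is why low=-inf still yields values >= -9.
    low = -9 if low is None else max(-9, low)
    high = 9 if high is None else min(9, high)
    t = torch.empty(size, device=device, dtype=dtype).uniform_(low, high)
    t.requires_grad_(requires_grad)
    return t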
test/test_ops.py (outdated)
output.backward(t)
inplace_output.backward(t)

self.assertEqual(sample.input.grad, inplace_input.grad)
Move to gradcheck and gradgradcheck (as appropriate)
The test for same dtypes has been moved to gradcheck
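A hedged sketch of what the gradcheck-based version might look like; check_op_grads, op, and sample_input are illustrative names, not the PR's exact code:

import torch
from torch.autograd import gradcheck, gradgradcheck

def check_op_grads(op, sample_input):
    # Use a double-precision leaf tensor so numerical gradients are stable.
    t = sample_input.detach().clone().double().requires_grad_(True)
    # check_grad_dtypes=True performs the same-dtype check that was moved
    # here from the hand-rolled assertEqual comparison above.
    assert gradcheck(op, (t,), check_grad_dtypes=True)
    assert gradgradcheck(op, (t,))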
Codecov Report

@@            Coverage Diff             @@
##           master   #43451      +/-   ##
==========================================
+ Coverage   69.31%   69.34%   +0.03%
==========================================
  Files         378      378
  Lines       46745    46801      +56
==========================================
+ Hits        32403    32456      +53
- Misses      14342    14345       +3
Thanks for the updates. LGTM
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
This breaks because it calls self.skipTest with an unsupported signature (it only takes one string arg).
I'll send a fix.
Yes, you're correct. It's meant to be a single string. Looks like that code path just didn't get tested.
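For reference, unittest's TestCase.skipTest accepts a single positional argument, the reason string; a minimal illustration:

import unittest

class TestOps(unittest.TestCase):
    def test_example(self):
        # Correct: one string argument. Passing additional positional
        # arguments raises a TypeError, which is the breakage described above.
        self.skipTest("dtype not supported on this device")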
  nondet_tol: float = 0.0,
- check_undefined_grad: bool = True
+ check_undefined_grad: bool = True,
+ check_grad_dtypes: bool = False
@mruberry this new arg was not documented here. Is that on purpose or just an oversight?
I don't recall offhand what we decided. This is probably an oversight? We may have thought to not document it at the time so we'd have more time to review the UX for this function later, though.
Ok. Do you think it is stable enough now that we should document it? Or is it still fairly internal?
Your call. I'm not planning any more changes to this function.
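A minimal usage sketch of the flag under discussion, assuming the signature shown in the diff above:

import torch
from torch.autograd import gradcheck

x = torch.randn(3, dtype=torch.double, requires_grad=True)
# With check_grad_dtypes=True, gradcheck additionally verifies that the
# analytical gradients have the same dtype as their corresponding inputs.
gradcheck(torch.sin, (x,), check_grad_dtypes=True)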
float_dtype = torch.float if dtype is torch.cfloat else torch.double
real = torch.rand(size, device=device, dtype=float_dtype) * span - (span / 2)
imag = torch.rand(size, device=device, dtype=float_dtype) * span - (span / 2)
c = torch.complex(real, imag)
I think the above code can be rewritten as c = torch.rand(size, device=device, dtype=dtype) * span - span/2 * (1+1j)
PRs welcome ;)
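A quick sanity check of the suggested one-liner, assuming span is a Python float; the concrete size, device, and span values here are arbitrary:

import torch

size, device, dtype, span = (4, 4), "cpu", torch.cfloat, 6.0
# torch.rand with a complex dtype draws the real and imaginary parts
# independently from [0, 1), so scaling and shifting the whole tensor
# matches the original torch.complex(real, imag) construction.
c = torch.rand(size, device=device, dtype=dtype) * span - span / 2 * (1 + 1j)
# Both parts now lie in [-span/2, span/2).
assert c.real.abs().max() < span / 2
assert c.imag.abs().max() < span / 2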
Welcome (kinda?) back, @anjali411!
This PR adds a new test suite, test_ops.py, designed for generic tests across all operators with OpInfos. It currently has two kinds of tests: autograd tests and (un)supported dtype tests.
This is a significant expansion and simplification of the current autogenerated autograd tests, which spend considerable time processing their inputs. As an alternative, this PR extends OpInfos with "SampleInputs" that are much easier to use. These sample inputs are analogous to the existing tuples in method_tests(). Future PRs will extend OpInfo-based testing to other uses of method_tests(), like test_jit.py, to ensure that new operator tests can be implemented entirely using an OpInfo.
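An illustrative sketch, not the PR's exact code, of how a SampleInput might pair a tensor with extra arguments so a test can consume it directly:

import torch

class SampleInput:
    """Bundles an input tensor with any extra positional args for an op."""
    def __init__(self, input, args=()):
        self.input = input
        self.args = args

def sample_inputs_example(device, dtype, requires_grad):
    # Unlike method_tests() tuples, no further processing is needed:
    # a test can simply call op(sample.input, *sample.args).
    t = torch.randn(3, 3, device=device, dtype=dtype,
                    requires_grad=requires_grad)
    return [SampleInput(t)]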