[testing] Support input samples where self is broadcasted. #53014
Conversation
💊 CI failures summary (Dr. CI): as of commit 64883fa, ci.pytorch.org reports 1 failed job.
test/test_ops.py (outdated)
self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}") | ||
|
||
def is_inplace(variant): | ||
return variant.__name__.endswith('_') |
Is there a better way to check this?
Or should we pass the variant type and variant down from the function that calls this helper?
Line 141 in c4c77e2:
self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
Line 161 in c4c77e2:
self._gradgrad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
The correct way to check this is by testing if the variant is
the inplace variant acquired from the OpInfo.
Unfortunately that won't work, as op.get_inplace() is wrapped by _get_safe_inplace.
Lines 68 to 73 in c4c77e2:
def _get_safe_inplace(self, inplace_variant):
    @wraps(inplace_variant)
    def _fn(t, *args, **kwargs):
        return inplace_variant(t.clone(), *args, **kwargs)
    return _fn
Can you access it through the dunder wrapped attribute?
foo.__wrapped__ is op.get_inplace()
(If foo may or may not be wrapped, you'll need to check for the attribute first, of course.)
Yup, it works with that. Thanks!!
def is_inplace(variant):
    if hasattr(variant, "__wrapped__"):
        return variant.__wrapped__ is op.get_inplace()
    return variant is op.get_inplace()
__slots__ = ['input', 'args', 'kwargs', 'output_process_fn_grad', 'broadcasts_self']

-   def __init__(self, input, *, args=tuple(), kwargs=None, output_process_fn_grad=None):
+   def __init__(self, input, *, args=tuple(), kwargs=None, output_process_fn_grad=None, broadcasts_self=False):
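For context, here is a minimal sketch of how this flag might be produced and consumed; the sample function, the make_arg helper, and the skip logic below are illustrative only and are not part of this diff.

import torch
from torch.testing._internal.common_methods_invocations import SampleInput

S = 5

def sample_inputs_masked_fill_sketch(op_info, device, dtype, requires_grad):
    # Illustrative only: `self` has shape (S,) while the mask is (S, S), so the
    # result broadcasts `self`; an inplace variant cannot produce that shape,
    # so the sample is flagged with broadcasts_self=True.
    def make_arg(shape):
        return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg((S,)),
                      args=(torch.randn(S, S, device=device) > 0, 10),
                      broadcasts_self=True)

A test could then skip such samples when exercising an inplace variant, e.g. with a check like "if is_inplace(variant) and sample.broadcasts_self: continue" inside its sample loop.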
This approach is reasonable but I still think we should filter sample_inputs() by whether the samples will be used for inplace computations or not. However, let me get a few other opinions on this. Maybe this is a better approach.
Sorry I neglected to update this, @kshitij12345. General consensus is that it'd be preferable to filter sample_inputs() by what's inplace for now. We may want to revise that decision later, however.
IIUC, we will have to update the signature of every sample_inputs_* to sample_inputs_*(op_info, device, dtype, requires_grad, is_inplace_variant)?
pytorch/torch/testing/_internal/common_methods_invocations.py
Lines 255 to 261 in afe339d:
def sample_inputs(self, device, dtype, requires_grad=False):
    """Returns an iterable of SampleInputs.

    These samples should be sufficient to test the function works correctly
    with autograd, TorchScript, etc.
    """
    return self.sample_inputs_func(self, device, dtype, requires_grad)
And update sample_inputs to:
def sample_inputs(self, device, dtype, requires_grad=False, is_inplace_variant=False):
We could do that (and make is_inplace_variant kwarg-only). That would be the most discoverable thing.
But we don't have to. Tests that exercise the inplace variant can request sample inputs by passing the for_inplace_variant kwarg (we can work on the name) like this:
try:
    samples = op.sample_inputs(..., for_inplace_variant=True)
except TypeError as te:
    samples = op.sample_inputs(...)
Then functions which support this option can implement it as:
def sample_inputs_foo(..., *, for_inplace_variant=False)
A nicer (but more disruptive) solution would be to make sample_inputs take **kwargs, and then the functions that want to use "for_inplace_variant" or similar options can query for it from the kwarg dict.
What would you think of an approach like that?
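To make the second (kwargs-based) option concrete, here is a hedged sketch of a sample-input function that opts into the extra option; sample_inputs_foo and its tensors are illustrative and not an actual OpInfo from common_methods_invocations.py.

import torch
from torch.testing._internal.common_methods_invocations import SampleInput

def sample_inputs_foo(op_info, device, dtype, requires_grad, *, for_inplace_variant=False):
    # Illustrative sample-input function that understands the extra kwarg.
    def make(shape):
        return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad)
    samples = [SampleInput(make((5,)), args=(make((5,)),))]
    if not for_inplace_variant:
        # Samples whose `self` is broadcast are only valid for the out-of-place variant.
        samples.append(SampleInput(make((1,)), args=(make((5,)),)))
    return samples

Functions that have not been updated simply raise TypeError when the extra kwarg is forwarded to them, which the try/except in the caller (or, as done later in this PR, inside OpInfo.sample_inputs itself) falls back on.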
I like the second approach (though more disruptive). Will try that.
Thanks!
For test_variant_consistency_eager,
Lines 204 to 242 in 0d81528:
@_variant_ops(op_db)
def test_variant_consistency_eager(self, device, dtype, op):
    samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd)

    for sample in samples:
        # Acquires variants (method variant, inplace variant, aliases)
        method = op.get_method()
        inplace = op.get_inplace()

        # list of all inplace ops: inplace variant + alias inplace variants if exist
        inplace_ops = [inplace, ]

        aliases = []
        for a_op in op.aliases:
            aliases.append(a_op.op)
            aliases.append(a_op.method_variant)
            aliases.append(a_op.inplace_variant)
            inplace_ops.append(a_op.inplace_variant)
        aliases = tuple(aliases)

        inplace_ops = tuple(v for v in inplace_ops if v is not None)
        variants = (v for v in (method, inplace) + aliases if v is not None)

        # Computes function forward and backward values
        sample.input.grad = None
        expected_forward = op(sample.input, *sample.args, **sample.kwargs)
        expected_grad = None

        # TODO: backward consistency only supported for single tensor outputs
        # TODO: backward consistency only checked on sample.input, not all
        #   tensor inputs
        # TODO: update to handle checking grads of all tensor inputs as
        #   derived from each tensor output
        if (op.supports_autograd and isinstance(expected_forward, torch.Tensor)):
            expected_forward.sum().backward()
            expected_grad = sample.input.grad

        # Test eager consistency
        for variant in variants:
We will have to put the call to sample_inputs inside the variants loop. Since each call to sample_inputs has to actually materialize all the tensors, and the forward pass would run multiple times on the same input sample (once per variant, when that sample is valid for multiple variants), this will result in a performance regression for this test.
I am feeling a bit sceptical about this now.
Let me know if that sounds acceptable (or if I missed anything).
That's a great point, and that function could probably use a small refactoring. I agree with you but think we can mitigate the damage:
samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd)

for sample in samples:
    # Acquires variants (method variant, inplace variant, aliases)
    method = op.get_method()
    inplace = op.get_inplace()

    # list of all inplace ops: inplace variant + alias inplace variants if exist
    inplace_ops = [inplace, ]

    aliases = []
    for a_op in op.aliases:
        aliases.append(a_op.op)
        aliases.append(a_op.method_variant)
        aliases.append(a_op.inplace_variant)
        inplace_ops.append(a_op.inplace_variant)
    aliases = tuple(aliases)

    inplace_ops = tuple(v for v in inplace_ops if v is not None)
    variants = (v for v in (method, inplace) + aliases if v is not None)
What's weird about this is that the test is rebuilding its understanding of the operator's variants on each sample. This is unnecessary since variants are sample-independent. We should lift that section out of the for loop and acquire the variants once upfront.
Now let's look at the rest of the function:
# Computes function forward and backward values
sample.input.grad = None
expected_forward = op(sample.input, *sample.args, **sample.kwargs)
expected_grad = None
if (op.supports_autograd and isinstance(expected_forward, torch.Tensor)):
    expected_forward.sum().backward()
    expected_grad = sample.input.grad

for variant in variants:
    sample.input.grad = None
    cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input
    variant_forward = variant(cloned, *sample.args, **sample.kwargs)
    self.assertEqual(expected_forward, variant_forward)

    if expected_grad is not None and (variant not in inplace_ops or op.supports_inplace_autograd):
        variant_forward.sum().backward()
        self.assertEqual(expected_grad, sample.input.grad)
I think we can get rid of the for variant in variants loop here by using an itertools.product(sample inputs, variants) after the sample-input and variant acquisition at the start of the test. Then we can create a helper function that executes this test body. The helper function takes a variant, a sample input, and whether the operation is inplace or not (so it knows whether to perform the copy). Then the test works like this:
- the variants are identified
- sample inputs are acquired
- if the op has an inplace variant, inplace sample inputs are acquired
- an itertools product of sample inputs x non-inplace variants invokes the helper function
- an itertools product of inplace sample inputs x inplace variants invokes the helper function
Does that make sense? Looking forward to hearing your thoughts.
You are correct that this will still redundantly create sample inputs that are common to both the out-of-place and in-place variants, which is unfortunate from a performance standpoint; your idea of marking SampleInputs as safe-for-inplace would have handled this problem better. However, I think this is an OK penalty (for now, anyway).
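A hedged sketch of that structure follows (variant acquisition lifted out of the loop, a helper executing the body, and two itertools.product loops); the helper name is illustrative, clone_input_helper and the OpInfo accessors are as in the existing test, and the gradient consistency checks are elided for brevity, so this is not the final test code.

import itertools

def test_variant_consistency_eager(self, device, dtype, op):
    # 1. Acquire variants once; they are sample-independent.
    method = op.get_method()
    inplace = op.get_inplace()
    inplace_ops = [inplace]
    aliases = []
    for a_op in op.aliases:
        aliases.extend([a_op.op, a_op.method_variant, a_op.inplace_variant])
        inplace_ops.append(a_op.inplace_variant)
    inplace_variants = tuple(v for v in inplace_ops if v is not None)
    variants = tuple(v for v in (method, inplace) + tuple(aliases) if v is not None)
    outplace_variants = tuple(v for v in variants if v not in inplace_variants)

    def _check(sample, variant, is_inplace):
        # Executes the old loop body (gradient checks elided in this sketch).
        expected_forward = op(sample.input, *sample.args, **sample.kwargs)
        cloned = clone_input_helper(sample.input) if is_inplace else sample.input
        variant_forward = variant(cloned, *sample.args, **sample.kwargs)
        self.assertEqual(expected_forward, variant_forward)

    # 2. Out-of-place variants run over the regular samples ...
    samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd)
    for sample, variant in itertools.product(samples, outplace_variants):
        _check(sample, variant, is_inplace=False)

    # 3. ... and inplace variants run over inplace-safe samples.
    if inplace_variants:
        inplace_samples = op.sample_inputs(device, dtype,
                                           requires_grad=op.supports_autograd,
                                           for_inplace_variant=True)
        for sample, variant in itertools.product(inplace_samples, inplace_variants):
            _check(sample, variant, is_inplace=True)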
It does make sense. I'll try to do that. Will ping here if I stumble into another block.
Thanks!
@mruberry, I have fixed the merge conflicts. I think this is ready for another round. PTAL :)
Note: or should we add a try/except (similar to the one mentioned above) temporarily, and remove it once this signature becomes the norm?
try:
    samples = op.sample_inputs(..., **kwargs)
except TypeError as te:
    samples = op.sample_inputs(...)
Thanks!
        return variant.__wrapped__ is op.get_inplace()
    return variant is op.get_inplace()

samples = op.sample_inputs(device, dtype, requires_grad=True,
I think your proposal to wrap this in a try/except, so that not every sample_inputs function needs to be updated to use **kwargs when this lands, is a good idea.
if len(inplace_ops) > 0:
    inplace_samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad,
                                       for_inplace_variant=True)
Can also wrap this in try/except to minimize logical merge conflicts
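For reference, a sketch of what that call-site wrapping might look like (only needed until the fallback lives inside OpInfo.sample_inputs itself, as done later in this PR):

if len(inplace_ops) > 0:
    try:
        inplace_samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad,
                                           for_inplace_variant=True)
    except TypeError:
        # Sample-input functions that don't accept the kwarg yet fall back to
        # the regular samples.
        inplace_samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)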
yield SampleInput(make_arg((S,)),
                  args=(torch.randn(S, S, device=device) > 0, make_arg(())))
yield SampleInput(make_arg((S,)),
                  args=(torch.randn(S, S, device=device) > 0, 10))
Why this change instead of using the bernoulli op?
bernoulli produces a scalar tensor; however, we want an (S, S) tensor.
pytorch/torch/testing/_internal/common_methods_invocations.py
Lines 3908 to 3909 in 3d492b0:
def bernoulli_scalar():
    return torch.tensor(0, dtype=torch.bool).bernoulli_()
('masked_fill', (M,), (torch.BoolTensor(M, M).bernoulli_(), 10), 'broadcast_lhs'),
I guess I mean where did this case go? We don't have to worry about it in this PR.
Ah, I replaced it with the following (semantically both are doing the same thing):
SampleInput(make_arg((S,)),
            args=(torch.randn(S, S, device=device) > 0, 10))
OK
Hey @kshitij12345! Overall this looks good, as usual.
I made a few comments; I like your suggestion to reduce logical merge conflicts by using a try/except.
I think the OpInfos for masked_fill and masked_scatter can also be updated with this change, right?
SkipInfo('TestOpInfo', 'test_duplicate_method_tests'),
    samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)
except TypeError:
    samples = self.sample_inputs_func(self, device, dtype, requires_grad)
return samples
I have added the try/except here (so we don't need to add a try/except in test_ops).
@mruberry can you please see if this is ok? Will rebase accordingly then. Thanks!
Yep, this seems fine
-   inplace_ops = tuple(v for v in inplace_ops if v is not None)
-   variants = (v for v in (method, inplace) + aliases if v is not None)
+   inplace_variants = tuple(v for v in inplace_ops if v is not None)
+   variants = tuple(v for v in (method, inplace) + aliases if v is not None)
Nice fix
Cool! And thanks for including the fix for the variant generator. If you rebase this now I think we can get it landed, @kshitij12345.
I made one comment, but we don't need to worry about it in this PR even if it is an issue.
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
With this, we might want to update/add an entry mentioning this new change at https://github.com/pytorch/pytorch/wiki/Writing-tests-in-PyTorch-1.8 and #54261 (the OpInfo porting tracker issue).
Codecov Report
@@ Coverage Diff @@
## master #53014 +/- ##
==========================================
- Coverage 77.42% 77.23% -0.20%
==========================================
Files 1895 1895
Lines 187524 187556 +32
==========================================
- Hits 145194 144859 -335
- Misses 42330 42697 +367
Hmm... I was hoping to land this, but the ROCm build failure is worrying. @kshitij12345, would you take a look? Edit: nevermind, the ROCm failures are in the base!
Fixes #50747
Reference #50006