OpInfo tests fail gradcheck for ops with TensorList inputs. #51996
Let's focus exclusively on the autograd side of this issue for the moment, as how we address it may affect the fix for the jit tests. @albanD I think this is a bug in {grad, gradgrad}check, where they don't handle tuple inputs properly? Before @anjali411 implemented its OpInfo, for example, torch.stack had a custom autograd test (see 8fb5f16), but I don't think it used gradcheck. When given a tuple of tensors as input, gradcheck will try to splat them when calling the function, but functions like torch.stack don't accept a variable number of tensor arguments. One way to fix this might be to add a "grad_input_function" to gradcheck. This idea has come up elsewhere (#50837, cc @IvanYashchuk), too. @albanD, it seems like you're in favor of that solution on the other issue, so let's start by trying it out as a fix for this, too?
The proposal there was not to modify gradcheck but OpInfo, right? And to have a pre-processing function before calling the user function.
This is explicitly not supported. The doc says: "func (function) – a Python function that takes Tensor inputs and returns a Tensor or a tuple of Tensors". A tuple of Tensors is not supported as an input to the function.
You are correct.
Yes, I think this is conceptually what the input function would do.
I don't think we need to make any updates in
We could continue to wrap ops in lambdas, that's true, but that's also what we want to avoid. Wrapping ops in lambdas is going to bite us later when we go to compare with NumPy references, for example. Basically, we want to wrap the op in a lambda only in some contexts (like gradcheck).
Offline discussion w/ @albanD: let's continue to investigate wrapping the op in a lambda before gradchecking. Maybe subsume the current output_func post-processing, which currently gets wrapped with the op, too.
Couldn't gradcheck flatten the list of tensors, do things with the tensors, and then reconstruct the list of tensors when it comes time to call the function?
The question is: how do you differentiate between a single argument that is a tuple of tensors vs. a tuple containing multiple Tensors that should be passed as separate arguments to the function?
I am not suggesting that we would do this (I'm just curious), but to go down this rabbit hole: we'd have to modify the signature of gradcheck so that inputs are no longer splatted implicitly. So if a user wanted to pass in multiple tensors, they would have to do so explicitly. This would make the handling of functions with TensorLists nicer, especially if we think that we'll add more TensorList operations to the API.
I think @zou3519's suggestion is the cleanest approach from an API perspective. Otherwise, wrapping the function in a lambda that converts from the flattened list of tensors back into the correct arguments to the op is not so straightforward. What if we had an op that takes 2 TensorLists? We'd have to keep track of counts to reconstruct the input arguments from the flattened list of tensors.
It is cleaner, but it is very BC-breaking :/ And yes, if the op takes multiple TensorLists, you will need to do slicing. That can still be done automatically, but it is indeed a bit more complex.
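To make the slicing idea concrete, here is a minimal sketch of such an automatic reconstruction. All names here (`make_flat_wrapper`, `list_lengths`) are hypothetical, not part of gradcheck or any PyTorch API: the wrapper records how many tensors belong to each list, accepts a flat tuple of arguments (the form gradcheck splats into the function), and slices it back into the original lists before calling the op.

```python
def make_flat_wrapper(op, list_lengths):
    """Wrap `op`, which takes several tensor lists as positional
    arguments, so it can instead be called with one flat tuple of
    tensors. `list_lengths[i]` is the number of tensors in the i-th
    list; the wrapper slices the flat tuple back into those lists.
    Hypothetical helper, for illustration only.
    """
    def wrapper(*flat_tensors):
        lists, start = [], 0
        for n in list_lengths:
            lists.append(list(flat_tensors[start:start + n]))
            start += n
        return op(*lists)
    return wrapper

# Dummy op taking two "tensor lists" (plain numbers for brevity):
def pairwise_sum(xs, ys):
    return [x + y for x, y in zip(xs, ys)]

flat_op = make_flat_wrapper(pairwise_sum, [2, 2])
assert flat_op(1, 2, 10, 20) == [11, 22]
```

The count bookkeeping is exactly the complexity mentioned above: it is mechanical, but the caller has to supply `list_lengths` (or gradcheck would need to infer it), which is why the flat-tuple approach is less clean than changing the signature.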
We could have our internal gradcheck do this first (that's why we have an internal gradcheck :D)
Internal gradcheck??? Which one is that? I thought there was just the one (the one in
It is here: pytorch/torch/testing/_internal/common_utils.py, lines 1941 to 1954 (at c8b3686).
Here's the original issue for context: #49409
The failure cause of the following tests seems to be similar to that of
That looks like the jit's failure to understand complex literals. I have a PR disabling those tests for dtypes other than float32, btw. We get too much noise and very little signal from them for other dtypes.
btw, complex literal support is coming very soon: #52881
This PR adds a workaround to OpInfo tests for ops that take TensorList inputs, which were failing gradcheck (see #51996). It also updates the *stack ops to pass a list of Tensors as input to SampleInput, to demonstrate that the tests now pass without the need for defining a lambda in op. Differential Revision: [D26860788](https://our.internmc.facebook.com/intern/diff/D26860788) [ghstack-poisoned]
For ops like `torch.stack` and `torch.linalg.multi_dot` that take a `Tensor[]` as input, as opposed to a single Tensor, OpInfo testing does not work properly. If the `sample_inputs_func` returns a tuple of tensors per SampleInput, it will be passed as individual positional parameters to the op. If instead each SampleInput is returned with a list, then gradcheck will fail with an error. The current workaround is to wrap the op in a lambda. However, this changes the actual op, which can cause problems in the future, for instance, when comparing the results of the op against NumPy.
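As a sketch of that lambda workaround (the variable names are illustrative, not taken from the PR), the idea is that gradcheck splats its input tuple into positional arguments, so the lambda regroups them into the single sequence that `torch.stack` expects:

```python
import torch
from torch.autograd import gradcheck

# gradcheck requires double precision for its finite-difference checks.
tensors = tuple(torch.randn(2, 3, dtype=torch.double, requires_grad=True)
                for _ in range(3))

# gradcheck calls the function as func(*tensors), so the lambda
# collects the splatted arguments back into one sequence for stack.
stack_op = lambda *ts: torch.stack(ts)
assert gradcheck(stack_op, tensors)
```

This is exactly the drawback discussed above: the test now exercises `stack_op` rather than `torch.stack` itself, which is why the PR moves this regrouping into the test machinery instead.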
cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @jianyuh @mruberry @heitorschueroff @walterddr @IvanYashchuk @VitalyFedyunin