AOT Autograd - Contiguous tensors #537

Open
anijain2305 opened this issue Feb 25, 2022 · 7 comments

Comments

@anijain2305
Contributor

anijain2305 commented Feb 25, 2022

AOT Autograd does not handle non-contiguous tensors correctly right now.

One way to handle this is to trace the backward pass while forcing the forward outputs to be contiguous, and then call contiguous on the incoming gradient tensors in the backward pass. This is done in #536. However, this can introduce significant and unnecessary overhead.

The other option is to record the strides of the forward outputs in the forward pass and then restride the incoming gradients accordingly in the backward pass. We have to investigate whether that can be done for all cases.
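For illustration only (hypothetical helper names, not the actual #536 change), the two options might look roughly like this at the backward boundary:

```python
import torch

def force_contiguous(grad_outs):
    # Option 1: unconditionally make the incoming gradients contiguous before
    # feeding them to the traced backward graph. Simple, but pays for a copy
    # even when it isn't needed.
    return [g.contiguous() for g in grad_outs]

def restride_to_forward_outputs(grad_outs, recorded_strides):
    # Option 2: restride each incoming gradient to match the strides recorded
    # for the corresponding forward output, copying only when the layouts differ.
    restrided = []
    for g, strides in zip(grad_outs, recorded_strides):
        if g.stride() == strides:
            restrided.append(g)
        else:
            out = torch.empty_strided(g.shape, strides, dtype=g.dtype, device=g.device)
            out.copy_(g)
            restrided.append(out)
    return restrided
```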

@Chillee
Contributor

Chillee commented Feb 25, 2022

@anijain2305 Can you add an example failure case?

@anijain2305
Contributor Author

anijain2305 commented Feb 25, 2022

When I hit this error, I was running the TorchDynamo + AOT Autograd setup. I was unable to extract a standalone subgraph, using just AOT Autograd, that could expose this issue.

But let me think. Now that I understand the issue better, I might be able to come up with an example.

@ezyang
Contributor

ezyang commented Mar 1, 2022

This feels like it is probably because the tracer isn't replicating strides correctly (which is an easy mistake to make). If so, we should be able to make progress on this.

EDIT: OK well it's not /completely/ busted, see

```python
strides=elem.stride(), storage_offset=elem.storage_offset(),
```
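That line comes from the tracer's tensor wrapper; a minimal sketch of the pattern (assuming the torch.Tensor._make_wrapper_subclass API, with an illustrative class name) is:

```python
import torch

# Sketch of a tracing tensor subclass that mirrors the wrapped tensor's
# layout, so strides and storage offset are preserved on the wrapper.
class TracerTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            elem.size(),
            strides=elem.stride(),
            storage_offset=elem.storage_offset(),
            dtype=elem.dtype,
            device=elem.device,
            requires_grad=elem.requires_grad,
        )
```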

@Chillee
Contributor

Chillee commented Mar 1, 2022

@ezyang The problem is sadly ... more fundamental.

The problem is that, given f, we AOT trace out f_out, backwards_function = vjp(f, inputs). Then, even though we don't have the actual input to backwards_function, we know that its shape, dtype, and device must be identical to those of f_out. The problem is that the strides are not necessarily identical, and that's what's causing issues here.

In some sense this is kind of a stupid error. It only happens at the boundaries of the backwards pass, so at worst it can be fixed with a single extra contiguous call there. But... that somewhat pessimizes the performance for small cases, which is why I've resisted doing it :(
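As a small illustration of the mismatch (an assumed example, not taken from the original report): the forward output can be a non-contiguous view, while the gradient that eventually reaches the backward pass has a different layout, since autograd only ties its shape, dtype, and device to the output.

```python
import torch

def f(x):
    # The forward output is a non-contiguous view (a transpose).
    return x.t()

x = torch.randn(3, 4, requires_grad=True)
out = f(x)
print(out.stride())       # (1, 4): non-contiguous

# The incoming gradient only has to match shape/dtype/device, not strides.
grad_out = torch.ones(4, 3)
print(grad_out.stride())  # (3, 1): contiguous
out.backward(grad_out)    # eager mode handles this fine

# A backward graph traced assuming grad_out has out's strides (1, 4) would be
# specialized to the wrong layout when this contiguous gradient shows up.
```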

@ezyang
Contributor

ezyang commented Mar 2, 2022

Maybe we can just take a stride argument to vjp lol

@Chillee
Contributor

Chillee commented Mar 2, 2022

Maybe... I am wondering whether we're going to need to relax the current restrictions with __torch_dispatch__ though... Like, it seems reasonable to me that the inputs to vjp could be an arbitrary tensor type.

In which case we're going to need to change our tracing/caching strategy, more or less.

Currently we trace the forwards + backwards graph upon hitting the forwards pass (and cache on the forwards pass's inputs). Instead, we might need to cache on the inputs to the backwards pass as well, and it's a bit unclear what that would look like.
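Very roughly, and purely as an assumed sketch of the bookkeeping (none of these helpers exist in AOT Autograd), the difference between the two caching strategies is which tensors' metadata is available when the key is built:

```python
import torch

def tensor_key(t: torch.Tensor):
    # Metadata the compiled graph is implicitly specialized on.
    return (tuple(t.shape), t.dtype, t.device, t.stride(), t.requires_grad)

def forward_cache_key(inputs):
    # Today: the forward + backward trace is produced (and cached) when the
    # forward runs, so only the forward inputs are available for the key.
    return tuple(tensor_key(t) for t in inputs)

def backward_cache_key(grad_outputs):
    # Caching on the backward inputs would also need the incoming gradients'
    # metadata, but those tensors don't exist yet at forward-trace time.
    return tuple(tensor_key(g) for g in grad_outputs)
```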

@Chillee
Contributor

Chillee commented Mar 4, 2022

> Maybe we can just take a stride argument to vjp lol

Also, sorry, this doesn't even work. The problem is that, at the time we call vjp, we don't know what the stride of the backwards input is.
