AOT Autograd - Contiguous tensors #537
Comments
@anijain2305 Can you add an example failure case?
When I got this error, I had a TorchDynamo + AOT Autograd setup. I was unable to extract a separate subgraph, with just AOT Autograd, that could expose this issue. But let me think; as I understand the issue better now, I might be able to come up with an example.
This feels like it is probably because the tracer isn't replicating strides correctly (which is an easy mistake to make). If so, we should be able to make progress on this. EDIT: OK, well, it's not *completely* busted; see functorch/functorch/_src/python_key.py, line 71 at 8acf2d1.
@ezyang The problem is sadly... more fundamental. The problem is that given ... In some sense this is kind of a stupid error. It only happens at the boundaries of the backwards pass, so at worst it can be fixed with a single extra contiguous call there. But... that somewhat pessimizes performance for small cases, which is why I've resisted doing it :(
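A minimal sketch of that "single extra contiguous call" at the backward boundary, using a hand-written torch.autograd.Function as a stand-in for the AOT Autograd wrapper (hypothetical code, not the actual functorch implementation):

```python
import torch

class CompiledFn(torch.autograd.Function):  # hypothetical stand-in for the AOT wrapper
    @staticmethod
    def forward(ctx, x):
        return x * 2  # stand-in for the traced forward graph

    @staticmethod
    def backward(ctx, grad_out):
        # The traced backward graph was built assuming contiguous cotangents,
        # so force the incoming grad contiguous at the boundary. For small
        # tensors this extra copy is the overhead being complained about.
        grad_out = grad_out.contiguous()
        return grad_out * 2  # stand-in for the traced backward graph

x = torch.randn(4, 3, requires_grad=True)
out = CompiledFn.apply(x)
out.backward(torch.ones(3, 4).t())  # a non-contiguous cotangent of out's shape
```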
Maybe we can just take a stride argument to vjp lol
Maybe... I am wondering whether we're going to need to relax the current restrictions with ... In which case, we're going to need to change our tracing/caching strategy, more or less. Currently we trace the forwards + backwards graph upon hitting the forwards pass (and cache on the forwards pass's inputs). Instead, we might need to cache on the inputs to the backwards pass, and it's a bit unclear what that would look like.
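For context, a rough sketch of what "cache on the forwards pass's inputs" means (hypothetical names, not functorch's actual cache). Note that nothing about the backward-pass cotangents, in particular their strides, participates in the key:

```python
import torch

compile_cache = {}  # hypothetical cache of traced joint graphs

def tensor_key(t):
    # Only metadata of the *forward* inputs is recorded.
    return (tuple(t.shape), tuple(t.stride()), str(t.dtype), str(t.device))

def cached_joint_trace(fn, *args):
    key = (id(fn), tuple(tensor_key(a) for a in args))
    if key not in compile_cache:
        # Stand-in for "trace the forwards + backwards graph here".
        compile_cache[key] = fn
    return compile_cache[key]

f = lambda x: x.cos().sum()
x = torch.randn(4, 3)
cached_joint_trace(f, x)(x)  # traced and cached
cached_joint_trace(f, x)(x)  # cache hit: same forward-input metadata
```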
Also, sorry, this doesn't even work. The problem is that, at the time we call ...
AOT Autograd does not handle non-contiguous tensors correctly right now.
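A hypothetical sketch of the kind of setup being discussed, not a confirmed repro: compile a function whose output, and therefore whose incoming gradient, is non-contiguous. The aot_function usage is assumed from functorch.compile, and the exact symptoms depend on the functorch version.

```python
import torch
from functorch.compile import aot_function

def noop_compiler(fx_module, example_inputs):
    # Trivial "compiler" that returns the traced FX graph unchanged.
    return fx_module

def f(x):
    # The transpose makes the output non-contiguous, so the cotangent fed to
    # the compiled backward graph will typically be non-contiguous as well.
    return x.t() * 2

compiled_f = aot_function(f, fw_compiler=noop_compiler, bw_compiler=noop_compiler)

x = torch.randn(4, 3, requires_grad=True)
out = compiled_f(x)                 # shape (3, 4), non-contiguous
out.backward(torch.ones(4, 3).t())  # non-contiguous cotangent hits the traced backward
```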
One way to handle this is to trace the backwards pass forcing contiguous tensors for the outputs, and then call contiguous on the backwards input grad tensors. This is done in #536. However, this can add high and unnecessary overhead.
Another option is to record the strides of out in the forward pass and then restride the input grads accordingly in the backward pass. We have to investigate whether that can be done for all cases.
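A minimal sketch of this second option, again with a hand-written autograd.Function standing in for the compiled graphs (hypothetical code, not from this repo or #536): record the output strides in the forward and materialize the incoming grad into that layout in the backward, instead of a blanket contiguous() call.

```python
import torch

class StrideRecordingFn(torch.autograd.Function):  # hypothetical stand-in
    @staticmethod
    def forward(ctx, x):
        out = x.t() * 2                # stand-in for the traced forward graph
        ctx.out_size = out.size()
        ctx.out_stride = out.stride()  # record the strides of out
        return out

    @staticmethod
    def backward(ctx, grad_out):
        if grad_out.stride() != ctx.out_stride:
            # Restride by copying into a buffer with the recorded layout;
            # whether something like this works for all cases (expanded or
            # overlapping grads, etc.) is the open question above.
            restrided = torch.empty_strided(ctx.out_size, ctx.out_stride,
                                            dtype=grad_out.dtype,
                                            device=grad_out.device)
            restrided.copy_(grad_out)
            grad_out = restrided
        return (grad_out * 2).t()      # stand-in for the traced backward graph

x = torch.randn(4, 3, requires_grad=True)
out = StrideRecordingFn.apply(x)
out.backward(torch.ones(3, 4))  # copied into the recorded layout if it differs
```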