pre_autograd make_fx broken with simple F.linear with symbolic shape #100055
cc @bdhirsh
I'm not able to repro this on master. @haijieg can you try running again on a nightly? FWIW: the …
Hi @bdhirsh, I can still repro this on the 0428 nightly as well as master built from the last commit: 9e1f46d.
Yeah, I can repro this on a branch based off of 23da1fd.
@haijieg your diagnosis is very good. What is supposed to happen is that ProxyTensorMode gets a crack at all the returned outputs when …
I can repro now, taking a look - not sure what I was doing wrong the first time.
hmm - that linear decomposition runs below autograd. So I think I'd actually expect that any symint compute that happens inside of the linear call doesn't need to worry about proxies, and only the inputs / outputs to the linear() call need to worry about proxies. I'll keep looking.
Some progress so far, but I'm still looking around. The fix seems pretty involved, unfortunately :(

(1) The original reason for the … is meant to make proxy tensor "re-entrant". When we trace pre-autograd, we end up running through ProxyTensorMode twice - once for the autograd key, and again at the bottom of the dispatcher. If we hit the mode twice, we want to create proxies the first time (at the autograd layer), but pass through when we hit it the second time. The problem was that this code needs to be moved earlier: right now we try to grab proxies from the inputs before hitting this code, and those proxies might not exist (here). When I do that I run into a new issue though:

(2) It's coming from here. The mode stack should be inactive in that code, so the call to …
Hi @bdhirsh, is there more progress on this ticket?
Not yet - quick update on my side is that most of my time this week was spent on #100587 and a few other bugs. That PR should land soon, and I'm gonna devote more time to this next week. |
One (basic) observation so far: I can make the error go away by using a fresh FakeTensorMode instead of re-using the ambient one:
That at least points to the fact that the issue has something to do with the current state stashed on the shared FakeTensorMode (the tensor_memo, or the ShapeEnv, maybe?). The bug also doesn't repro if I remove …
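The fresh-mode workaround described above could look roughly like this. This is a minimal sketch, not the code from the actual debugging session; the variable names are mine, and it only illustrates that a freshly constructed `FakeTensorMode` carries none of the ambient mode's state (in particular, no `ShapeEnv`):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

# Instead of reusing the ambient fake mode (which carries a shared
# ShapeEnv and tensor_memo), construct a brand-new FakeTensorMode.
fresh_mode = FakeTensorMode()

with fresh_mode:
    # Tensors created under the mode are FakeTensors: they carry
    # metadata (shape/dtype/device) but no real storage.
    x = torch.empty(4, 3)

assert isinstance(x, FakeTensor)
assert tuple(x.shape) == (4, 3)

# A mode created this way has no ShapeEnv attached, so symbolic-shape
# state from the shared ambient mode cannot leak in.
assert fresh_mode.shape_env is None
```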
It has to do with the ShapeEnv, because the error is raised from an untracked symint. A fresh FakeTensorMode won't have a shape_env. Removing …
Yep you're totally right - make_fx-only repro here:
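The repro itself was linked rather than pasted into the thread. A minimal sketch of what a make_fx-only symbolic-shape exercise of `F.linear` could look like (the `pre_autograd=True` flag from the original report is omitted so the snippet runs on any recent build; the function and variable names are mine):

```python
import torch
import torch.nn.functional as F
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, w, b):
    return F.linear(x, w, b)

x, w, b = torch.randn(4, 3), torch.randn(5, 3), torch.randn(5)

# tracing_mode="symbolic" allocates SymInts for the input sizes, which
# is what routes linear() through the symbolic-shape code paths
# discussed in this issue.
gm = make_fx(f, tracing_mode="symbolic")(x, w, b)

# The traced graph should compute the same result as eager mode.
assert torch.allclose(gm(x, w, b), f(x, w, b))
```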
(it looks like it's only "dynamic shapes related" because we hit the error with the mode stack when we go down this fast-path impl, which is only triggered for dynamic shapes (code))
Should be fixed by #101817! Also, quick heads up - the PR later in the stack renames …
…grad=True)"

Fixes #100055. I left more comments in the PR, but there were two issues that both stemmed from "ProxyTorchDispatchMode is supposed to effectively be a no-op when it runs re-entrantly" (more details at this note: https://github.com/pytorch/pytorch/blob/a33ac4454058d25abe43532fbe86930e7f55bdda/torch/utils/_python_dispatch.py#L86)

(1) We were (incorrectly) attaching proxies to the input tensors of each op when the mode was run re-entrantly.

(2) We have some logic to return `NotImplemented` if our mode can't handle the inputs it was passed. The linked issue hit a problem where we: (a) hit `linear()` with ProxyMode, (b) desugared into `addmm` with ProxyMode, with FakeTensor inputs. What we **want** to happen is for the re-entrant ProxyMode call to no-op and redispatch, so we hit FakeTensorMode and execute `addmm` using FakeTensorMode. Instead, we hit the NotImplemented code path. I didn't dig too much deeper into what exactly it was about hitting the not-implemented path that caused the error, but hitting it went against the idea of "we should no-op when we re-enter ProxyTorchDispatchMode".

[ghstack-poisoned]
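The "no-op when re-entrant" contract described above can be illustrated with a toy `TorchDispatchMode`. This is a simplified sketch, not the actual `ProxyTorchDispatchMode` implementation; the class name and the `enable_tracing` flag are mine:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class SketchProxyMode(TorchDispatchMode):
    """Records ops the first time through; passes through on re-entry."""

    def __init__(self):
        super().__init__()
        self.enable_tracing = True
        self.recorded = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if not self.enable_tracing:
            # Re-entrant hit: behave as a no-op and simply redispatch to
            # whatever sits below (e.g. FakeTensorMode). Crucially, do NOT
            # attach proxies to the inputs and do NOT return NotImplemented.
            return func(*args, **kwargs)
        # First (tracing) hit: record the op, then redispatch.
        self.recorded.append(func)
        return func(*args, **kwargs)

mode = SketchProxyMode()
with mode:
    torch.ones(2) + 1

# At least the add (and typically the factory op too) was recorded.
assert len(mode.recorded) >= 1
```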
Awesome! Thanks for the fix and heads up. I like the name change to …
Sorry for the wait @haijieg. I'm actually re-writing that PR to simplify it: there have been several bugs because "TorchProxyDispatchMode is now re-entrant", and so instead of that PR, I'm going to make it so that we don't need that mode to be re-entrant in the first place. I'm going to put it on a dedicated dispatch key, and ensure that that dispatch key is always active when pre_dispatch tracing is running.
Thank you for the update @bdhirsh. I look forward to your new PR landing.
Hi @bdhirsh, would you kindly provide an update on this issue? I notice a recent change 3318a83#diff-91cb5d56bb8b6759eb8540df226fe248c5f422c09edd0751555cf2dc626add33R1897 also broke …
🐛 Describe the bug
I'm trying to get a `pre_autograd` `aten_graph` from dynamo with dynamic shapes. The following code sort of works on `torch==2.1.0.dev20230425` but is broken with `2.1.0a0+git1c11065`.

Output from `torch==2.1.0.dev20230425`:

Output from `2.1.0a0+git1c11065`:

Expected output: an FX graph that contains `call_function[target=torch.ops.aten.linear.default]`.
With `torch==2.1.0.dev20230425`, the problem is that `aten.linear.default` did not go through python dispatch, so we get a graph with `transpose + view + addmm`.

With `2.1.0a0+git1c11065`, `aten.linear.default` properly goes through python dispatch, and `transpose + view + addmm` goes through `inside_mode` of `ProxyTorchDispatchMode`. However, within Linear.cpp we have the following code. So the input to `aten::view` is a new symbol `s0 = 1 * s0`, but it is somehow not tracked by `ProxyTorchDispatchMode.symnode_tracker`.

Versions

`2.1.0a0+git1c11065`
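The `s0 = 1 * s0` symbol in the report above arises from size computation done while tracing. As an illustration of how symbolic sizes flow through a traced graph (a minimal sketch under symbolic tracing; the function name is mine, and it does not reproduce the bug itself):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # The flattened length is computed from x's symbolic sizes, so the
    # graph records sym_size/mul nodes rather than baked-in constants.
    return x.view(x.shape[0] * x.shape[1])

gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2, 3))

# Because the sizes were traced symbolically, the same graph
# generalizes to inputs of other shapes.
assert tuple(gm(torch.randn(4, 5)).shape) == (20,)
```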
cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh