pre_autograd make_fx broken with simple F.linear with symbolic shape #100055

Closed
haijieg opened this issue Apr 26, 2023 · 21 comments
Labels
module: dynamic shapes · oncall: pt2 · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@haijieg

haijieg commented Apr 26, 2023

🐛 Describe the bug

I'm trying to get a pre_autograd aten_graph from dynamo with dynamic shape. The following code sort of works on torch==2.1.0.dev20230425 but is broken with '2.1.0a0+git1c11065'

import torch
import torch.nn as nn
import torch._dynamo as dynamo
from torch.fx.experimental.proxy_tensor import make_fx
from torch._dispatch.python import enable_python_dispatcher
from torch._guards import detect_fake_mode

def compiler(gm, example_inputs):
    fake_mode = detect_fake_mode(example_inputs)
    fake_inputs = [fake_mode.from_tensor(i) if isinstance(i, torch.Tensor) else i
                   for i in example_inputs]
    with fake_mode, enable_python_dispatcher():
        fx_graph = make_fx(gm, pre_autograd=True)(*fake_inputs)
        print(fx_graph.graph)
    return gm.forward


@dynamo.optimize(compiler, dynamic=True)
def f(x, w, b):
    z = torch.nn.functional.linear(x, w, b)
    return z


w = torch.randn(20, 10)
b = torch.randn(20)
f(torch.randn(1, 2, 10), w, b)
f(torch.randn(1, 3, 10), w, b)

Output from torch==2.1.0.dev20230425

graph():
    %arg0_1 : [#users=2] = placeholder[target=arg0_1]
    %arg1_1 : [#users=1] = placeholder[target=arg1_1]
    %arg2_1 : [#users=1] = placeholder[target=arg2_1]
    %arg3_1 : [#users=1] = placeholder[target=arg3_1]
    %t : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%arg2_1,), kwargs = {})
    %view : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%arg1_1, [%arg0_1, 10]), kwargs = {})
    %addmm : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%arg3_1, %view, %t), kwargs = {})
    %view_1 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%addmm, [1, %arg0_1, 20]), kwargs = {})

Output from '2.1.0a0+git1c11065'

torch._dynamo.exc.BackendCompilerFailed: backend='compiler' raised:
RuntimeError: s0 is not tracked with proxy for <torch.fx.experimental.proxy_tensor.PythonKeyTracer object at 0x7f25897eed70>

Expected Output:
FX graph that contains "call_function[target=torch.ops.aten.linear.default]".

With torch==2.1.0.dev20230425, the problem is that aten.linear.default did not go through python dispatch, so we get a graph with "transpose + view + addmm" instead of a single linear node.

With '2.1.0a0+git1c11065', aten.linear.default properly goes through python dispatch, and the transpose + view + addmm decomposition runs through the inside_mode of ProxyTorchDispatchMode. However, within Linear.cpp we have the following code:

static inline Tensor _flatten_3d_linear(const Tensor& input, const Tensor& weight, const Tensor& bias) {
    const auto input_sizes = input.sym_sizes();
    const auto result = at::addmm(bias, input.view_symint({input_sizes[0] * input_sizes[1], input_sizes[2]}), weight.t());
    return result.view_symint({input_sizes[0], input_sizes[1], result.sym_size(1)});
}

So the size passed to aten::view contains a new symbol s0 = 1 * s0, but it is somehow not tracked by ProxyTorchDispatchMode.symnode_tracker.
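To make the "new symbol" part concrete, here is a minimal sketch of where the fresh SymInt comes from. It assumes that from_tensor on a ShapeEnv-backed FakeTensorMode allocates symbolic sizes on the builds discussed here (the same setup appears in the make_fx-only repro later in this thread; dims of size 1 get specialized, so the sizes come out roughly as (1, s0, s1)):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

fake_mode = FakeTensorMode(shape_env=ShapeEnv())
x = fake_mode.from_tensor(torch.randn(1, 2, 10))
# _flatten_3d_linear computes input_sizes[0] * input_sizes[1]; doing the same here
# produces the fresh 1 * s0 SymInt the report describes, which the proxy tracer
# then has to know about before it can be used as a view() size.
flat = x.size(0) * x.size(1)
print(type(flat), flat)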

Versions

'2.1.0a0+git1c11065'

cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh

@anijain2305
Contributor

cc @bdhirsh

@anijain2305 anijain2305 added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Apr 28, 2023
@bdhirsh
Contributor

bdhirsh commented Apr 28, 2023

I'm not able to repro this on master. @haijieg can you try running again on a nightly?

FWIW: the pre_autograd flag is very experimental: in particular, it's known to not properly preserve autograd APIs like checkpointing or no_grad() in the graph. We also might (in the near future) rename it to "pre_dispatch".

@haijieg
Author

haijieg commented Apr 28, 2023

I'm not able to repro this on master. @haijieg can you try running again on a nightly?

FWIW: the pre_autograd flag is very experimental: in particular, it's known to not properly preserve autograd APIs like checkpointing or no_grad() in the graph. We also might (in the near future) rename it to "pre_dispatch".

Hi @bdhirsh, I can still repro this on the 0428 nightly as well as on master built from the latest commit: 9e1f46d.

@ezyang
Contributor

ezyang commented Apr 30, 2023

Yeah, I can repro this on a branch based off of 23da1fd

@ezyang
Contributor

ezyang commented Apr 30, 2023

@haijieg your diagnosis is very good. What is supposed to happen is that ProxyTensorMode gets a crack at all the returned outputs when addmm returns, which would cause the fresh s0 * 1 to get tracked as a size of addmm's output. But clearly this isn't happening. Maybe we are not properly interposing modes for python dispatcher registrations? @bdhirsh maybe this rings a bell?

@bdhirsh
Contributor

bdhirsh commented May 2, 2023

I can repro now, taking a look - not sure what I was doing wrong the first time.

What is supposed to happen is ProxyTensorMode is supposed to have a crack at all the returned outputs when addmm returns, which would cause the fresh s0 * 1 to get tracked as being a size of addmm. But clearly this isn't happening.

hmm - that linear decomposition runs below autograd. So I think I'd actually expect that any symint compute that happens inside of the linear call doesn't need to worry about proxies, and only the inputs / outputs to the linear() call need to worry about proxies. I'll keep looking.

@haijieg
Author

haijieg commented May 5, 2023

@ezyang @bdhirsh have you had a chance to root cause the issue? Is there anything I can help with?

@bdhirsh
Contributor

bdhirsh commented May 5, 2023

Some progress so far, but I'm still looking around. The fix seems pretty involved unfortunately :(

(1) The original reason for the "not tracked with proxy" error is that we shouldn't have made those proxies to begin with. This code in proxy_tensor.py:

    if proxy_mode.is_inside_mode:
        return func(*args, **kwargs)

is meant to make proxy tensor tracing "re-entrant". When we trace pre-autograd, we end up running through ProxyTensorMode twice - once for the autograd key, and again at the bottom of the dispatcher. If we hit the mode twice, we want to create proxies the first time (at the autograd layer), but pass through when we hit it the second time. The problem was that this check needs to be moved earlier: right now we try to grab proxies from the inputs before hitting this code, and those proxies might not exist (here).
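As a plain-Python toy of that ordering problem (not the real proxy_tensor.py; ToyProxyMode and its fields are invented purely for illustration):

class ToyProxyMode:
    def __init__(self):
        self.is_inside_mode = False
        self.input_proxies = {}  # id(obj) -> proxy; empty for re-entrant calls

    def dispatch_buggy(self, func, *args):
        # Buggy ordering: look up input proxies before the re-entrancy check,
        # so a re-entrant call fails on inputs that were never given proxies.
        proxies = [self.input_proxies[id(a)] for a in args]
        if self.is_inside_mode:
            return func(*args)
        return func(*args)  # stand-in for "record proxies in the graph, then run"

    def dispatch_fixed(self, func, *args):
        # Fixed ordering: bail out first, then it is safe to touch proxies.
        if self.is_inside_mode:
            return func(*args)
        proxies = [self.input_proxies[id(a)] for a in args]
        return func(*args)  # stand-in for "record proxies in the graph, then run"

mode = ToyProxyMode()
mode.is_inside_mode = True           # simulate hitting the mode a second time
print(mode.dispatch_fixed(abs, -3))  # 3: the re-entrant call just passes through
try:
    mode.dispatch_buggy(abs, -3)     # KeyError, analogous to "not tracked with proxy"
except KeyError:
    print("buggy ordering fails on re-entrant calls")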

When I do that I run into a new issue though:

RuntimeError: Creating a new Tensor subclass FakeTensor but the raw Tensor object is already associated to a python object of type FakeTensor

It's coming from here. The mode stack should be inactive in that code, so the call to torch.empty() should produce a plain tensor, which we then wrap in a FakeTensor. The mode stack is active though, so we end up producing a fake tensor and wrapping it inside another fake tensor (which errors). Still digging into this one.
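A small, self-contained illustration of that symptom (this is just general FakeTensorMode behavior, not the linked code path): constructing a tensor while a FakeTensorMode is active already yields a FakeTensor, so wrapping the result again is exactly the double-wrapping the error complains about.

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

fake_mode = FakeTensorMode()
with fake_mode:
    t = torch.empty(2, 3)         # the mode is active, so this is already a FakeTensor
print(isinstance(t, FakeTensor))  # True; wrapping t in another FakeTensor is the failure mode above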

@haijieg
Author

haijieg commented May 12, 2023

Hi @bdhirsh, has there been any more progress on this ticket?

@bdhirsh
Contributor

bdhirsh commented May 12, 2023

Not yet - quick update on my side is that most of my time this week was spent on #100587 and a few other bugs. That PR should land soon, and I'm gonna devote more time to this next week.

@haijieg
Author

haijieg commented May 12, 2023

Thank you @bdhirsh, #100587 looks pretty cool and we are very excited to be able to use it. Our project will benefit a lot from having dynamic-shape-compatible pre_dispatch tracing for the forward-only graph, plus the post-autograd forward+backward joint tracing you did in #100587.

@bdhirsh
Contributor

bdhirsh commented May 18, 2023

One (basic) observation so far: I can make the error go away by using a fresh FakeTensorMode instead of re-using the ambient one:

fake_mode = FakeTensorMode()

That at least points to the fact that the issue has something to do with the current state stashed on the shared FakeTensorMode (the tensor_memo, or the ShapeEnv, maybe?). The bug also doesn't repro if I remove dynamic=True.
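Concretely, "using a fresh FakeTensorMode" here means a variant of the compiler from the original report along these lines (a diagnostic only, not a fix: a fresh mode has no shape_env, so symbolic shapes are lost, as the next comment points out; compiler_fresh_fake_mode is just an illustrative name):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch._dispatch.python import enable_python_dispatcher

def compiler_fresh_fake_mode(gm, example_inputs):
    fake_mode = FakeTensorMode()  # fresh mode instead of detect_fake_mode(example_inputs)
    fake_inputs = [fake_mode.from_tensor(i) if isinstance(i, torch.Tensor) else i
                   for i in example_inputs]
    with fake_mode, enable_python_dispatcher():
        fx_graph = make_fx(gm, pre_autograd=True)(*fake_inputs)
        print(fx_graph.graph)
    return gm.forward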

@haijieg
Author

haijieg commented May 18, 2023

It has to do with the ShapeEnv, because the error is raised from an untracked symint. A fresh FakeTensorMode won't have a shape_env, and removing dynamic=True also removes any symbolic shape tracking, which makes the error go away.
Also, we need to call f twice: the error only triggers on the 2nd call, because on the 1st call dynamo assumes static shapes by default, and dynamic shapes only kick in on the 2nd call.

@bdhirsh
Contributor

bdhirsh commented May 18, 2023

Yep you're totally right - make_fx-only repro here:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.fx.experimental.proxy_tensor import make_fx
from torch._dispatch.python import enable_python_dispatcher

# f is the plain linear function from the original report, without the dynamo decorator
def f(x, w, b):
    return torch.nn.functional.linear(x, w, b)

w = torch.randn(20, 10)
b = torch.randn(20)
i1 = torch.randn(1, 2, 10)

s = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=s)
fake_inputs = [fake_mode.from_tensor(i) if isinstance(i, torch.Tensor) else i
               for i in [i1, w, b]]
with fake_mode, enable_python_dispatcher():
    g1 = make_fx(f, pre_autograd=True)(*fake_inputs)

(it looks like it's only "dynamic shapes related" because we hit the error with the mode stack when we go down this fast-path impl, which is only triggered for dynamic shapes (code))

@bdhirsh
Contributor

bdhirsh commented May 18, 2023

Should be fixed by #101817!

Also, quick heads up - the PR later in the stack renames pre_autograd to pre_dispatch. The motivation there was that people writing pre-autograd transforms are probably willing to write their transforms on a graph that hasn't had any dispatcher logic run yet (autograd, but also e.g. autocast). This way we won't have to worry about eventually adding more variants of pre_autograd for each functionality that the dispatcher has.

bdhirsh added a commit that referenced this issue May 18, 2023

Fixes #100055

I left more comments in the PR, but there were two issues that both stemmed from "ProxyTorchDispatchMode is supposed to effectively be a no-op when it runs re-entrantly" (more details at this note: https://github.com/pytorch/pytorch/blob/a33ac4454058d25abe43532fbe86930e7f55bdda/torch/utils/_python_dispatch.py#L86)

(1) We were (incorrectly) attaching proxies to the input tensors of each op when the mode was run re-entrantly.

(2) We have some logic to return `NotImplemented` if our mode can't handle the inputs it was passed. The linked issue hit a problem where we:

(a) hit `linear()` with ProxyMode
(b) desugared into `addmm` with ProxyMode, with FakeTensor inputs

What we **want** to happen is for the re-entrant ProxyMode call to no-op and redispatch, so we hit FakeTensorMode and execute `addmm` using FakeTensorMode. Instead, we hit the NotImplemented code path. I didn't dig too much deeper into what exactly it was about hitting the not-implemented path that caused the error, but hitting it went against the idea of "we should no-op when we re-enter ProxyTorchDispatchMode".
@haijieg
Author

haijieg commented May 18, 2023

Awesome! Thanks for the fix and heads up. I like the name change to pre_dispatch which is more intuitive.

@haijieg
Author

haijieg commented May 30, 2023

Hi @bdhirsh, are we ready to merge #101817?

@bdhirsh
Contributor

bdhirsh commented May 31, 2023

Sorry for the wait @haijieg. I'm actually re-writing that PR to simplify it: there have been several bugs stemming from "ProxyTorchDispatchMode is now re-entrant", so instead I'm going to make it so that the mode doesn't need to be re-entrant in the first place. I'm going to put it on a dedicated dispatch key, and ensure that that dispatch key is always active when pre_dispatch tracing is running.

@haijieg
Author

haijieg commented May 31, 2023

Thank you for the update @bdhirsh. Looking forward to your new PR landing.

bdhirsh added a commit that referenced this issue Jun 5, 2023
@haijieg
Author

haijieg commented Jun 21, 2023

Hi @bdhirsh, would you kindly provide an update on this issue?

I noticed that a recent change 3318a83#diff-91cb5d56bb8b6759eb8540df226fe248c5f422c09edd0751555cf2dc626add33R1897 also broke pre_dispatch tracing.

@bdhirsh
Contributor

bdhirsh commented Jun 21, 2023

@haijieg I landed the PR before I went on PTO, but unfortunately it was reverted due to an internal failure.

I have a re-land PR that you can follow here (it should land soon): #103888
