change torch._dynamo.export(aten_graph=...) to allow pre_autograd tracing #98031

Closed · wants to merge 12 commits

Changes from 2 commits
18 changes: 18 additions & 0 deletions test/dynamo/test_export.py
Original file line number Diff line number Diff line change
@@ -1717,6 +1717,24 @@ def f(x):
        inp = torch.randn(6, 7)
        self.assertEqual(gm(inp), f(inp))

    def test_pre_autograd_simple(self):
        def f(x):
            y = torch.ones_like(x)
            return torch.matmul(x, y)

        gm, _ = torch._dynamo.export(
            f, torch.randn(5, 5), aten_graph=True, pre_autograd=True, tracing_mode="fake"
        )

        inp = torch.randn(6, 6)
        self.assertEqual(gm(inp), f(inp))
        self.assertExpectedInline(gm.code.strip(), """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    ones_like_default = torch.ops.aten.ones_like.default(arg0, pin_memory = False)
    matmul_default = torch.ops.aten.matmul.default(arg0, ones_like_default); arg0 = ones_like_default = None
    return pytree.tree_unflatten([matmul_default], self._out_spec)""")

    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_export_cond_in_aten_symbolic(self):
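As an illustration of what the new flag changes, here is a minimal sketch contrasting pre_autograd tracing with the existing post-autograd aten_graph behavior. It is based only on the test above and the docstring added in torch/_dynamo/eval_frame.py below; the comment about composite ops being decomposed (e.g. aten.matmul into aten.mm) in the post-autograd graph is an expectation, not an output shown in this PR.

    import torch
    import torch._dynamo

    def f(x):
        y = torch.ones_like(x)
        return torch.matmul(x, y)

    # Pre-autograd ATen graph: composite ops such as aten.matmul are still
    # visible, because tracing happens before autograd has run (see the
    # expected output in the test above).
    gm_pre, _ = torch._dynamo.export(
        f, torch.randn(5, 5), aten_graph=True, pre_autograd=True, tracing_mode="fake"
    )

    # Existing post-autograd ATen graph: composite ops are typically already
    # decomposed by the time the graph is captured (assumption about the
    # exact decomposition, e.g. aten.mm for 2-D matmul).
    gm_post, _ = torch._dynamo.export(f, torch.randn(5, 5), aten_graph=True)

    print(gm_pre.code)
    print(gm_post.code)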
10 changes: 10 additions & 0 deletions torch/_dynamo/eval_frame.py
@@ -605,6 +605,7 @@ def export(
    f: Callable[..., Any],
    *args,
    aten_graph: bool = False,
    pre_autograd: bool = False,
    decomposition_table: Optional[
        Dict[torch._ops.OpOverload, Callable[..., Any]]
    ] = None,
@@ -623,6 +624,12 @@
        aten_graph (bool): If True, exports a graph with ATen operators.
        If False, exports a graph with Python operators. Default is False.

        pre_autograd (bool): If True, exports a graph with ATen operators,
        but before autograd has run. This can be useful if you want to apply further transformations
        on a graph before running it through autograd.
        This flag is only valid if aten_graph=True is set.
        Default is False.

        decomposition_table (dict): A dictionary that maps operators to their decomposition functions.
        Required if aten_graph or tracing_mode is specified. Default is None.

@@ -650,6 +657,8 @@ def export(
        assert (
            aten_graph
        ), "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True"
    if pre_autograd:
        assert aten_graph, 'pre_autograd=True can only be used when aten_graph=True'
    f = innermost_fn(f)

    graph = None
@@ -801,6 +810,7 @@ def graph_with_interpreter(*args):
            graph_with_interpreter,
            decomposition_table=decomposition_table,
            tracing_mode="real",
            pre_autograd=pre_autograd,
            _allow_non_fake_inputs=True,
        )(*example_fake_inputs)

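A short hedged sketch of the flag validation added above: passing pre_autograd=True without aten_graph=True should trip the new assertion, with the message taken from the diff.

    import torch
    import torch._dynamo

    def f(x):
        return x.cos()

    # Valid: pre_autograd requires aten_graph=True (mirrors the test above).
    gm, _ = torch._dynamo.export(
        f, torch.randn(3), aten_graph=True, pre_autograd=True, tracing_mode="fake"
    )

    # Invalid: per the assert added in this PR, this should raise an
    # AssertionError ("pre_autograd=True can only be used when aten_graph=True").
    try:
        torch._dynamo.export(f, torch.randn(3), pre_autograd=True)
    except AssertionError as e:
        print(e)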
3 changes: 2 additions & 1 deletion torch/fx/experimental/proxy_tensor.py
@@ -414,7 +414,8 @@ def can_handle_tensor(x):
    else:
        constant = None

    track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
    with inside_mode(proxy_mode):
Contributor: This seems surprising to me.

Contributor: still waiting on explainer

Contributor Author: Sigh, I wrote this comment when I initially made the update and forgot to submit it.

I'd be happy to make this a separate PR. The issue was that:

(1) We detach tensors as part of creating proxies.

(2) In pre_autograd tracing, we currently push TorchProxyDispatchMode onto both the autograd mode stack and the original python key mode stack.

(3) The detach() calls get intercepted by the proxy mode on the original python key mode stack.

So we now need to be careful that any aten ops we call inside of TorchProxyDispatchMode happen inside a with inside_mode() block.

I think we agreed 2-3 weeks ago that we could avoid these re-entrancy issues if we solved the "fallthrough keys can't be intercepted by the python dispatcher" issue, but we agreed that this is difficult.

Contributor: Ok put this comment in the code?

        track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
    return out
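To make the re-entrancy issue described in the author's comment above concrete, here is a toy, self-contained sketch of the guard pattern. ToyProxyMode and this inside_mode are illustrative stand-ins, not PyTorch's real TorchProxyDispatchMode or the inside_mode used in the diff: the point is only that a tracing mode which itself issues ops it would normally intercept must disable itself around those internal calls, or its own detach()-style bookkeeping ends up in the trace.

    from contextlib import contextmanager

    class ToyProxyMode:
        """Toy stand-in for a tracing mode that records every op it intercepts."""
        def __init__(self):
            self.recorded = []
            self.enabled = True

        def dispatch(self, op_name, *args):
            if not self.enabled:
                # Internal bookkeeping ops land here while the guard is active
                # and are deliberately not recorded.
                return args
            self.recorded.append(op_name)
            # Creating a proxy involves internal ops of its own (the detach()-like
            # call mentioned in the comment above). Without the guard, that call
            # would be dispatched straight back into this mode and recorded too.
            with inside_mode(self):
                self.dispatch("aten.detach", *args)
            return args

    @contextmanager
    def inside_mode(mode):
        # Toy analogue of inside_mode(): temporarily disable the mode so its
        # own internal ops are not intercepted.
        prev, mode.enabled = mode.enabled, False
        try:
            yield
        finally:
            mode.enabled = prev

    mode = ToyProxyMode()
    mode.dispatch("aten.matmul", "x", "y")
    assert mode.recorded == ["aten.matmul"]  # the internal detach was not recorded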

