change torch._dynamo.export(aten_graph=...) to allow pre_autograd tracing (#98031)

pre_autograd tracing is still early, but it should work for basic cases. This PR changes the API a bit for export to expose pre_autograd tracing. Name bikeshedding is welcome, but it looks like:
```
torch._dynamo.export(..., aten_graph=True, pre_autograd=True)
```
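
For reference, a minimal usage sketch modeled on the test added in this PR (`test_pre_autograd_simple`); it is illustrative only, and the argument names simply mirror the diff below:

```python
import torch
import torch._dynamo

def f(x):
    y = torch.ones_like(x)
    return torch.matmul(x, y)

# pre_autograd=True requests an ATen graph captured before autograd runs;
# it is only valid together with aten_graph=True (see the new assert in eval_frame.py below).
gm, _ = torch._dynamo.export(
    f,
    torch.randn(5, 5),
    aten_graph=True,
    pre_autograd=True,
    tracing_mode="fake",
)

inp = torch.randn(6, 6)
print(torch.allclose(gm(inp), f(inp)))  # the exported graph matches eager execution
```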

Pull Request resolved: #98031
Approved by: https://github.com/ezyang
bdhirsh authored and pytorchmergebot committed Apr 25, 2023
1 parent 62fad31 commit 15e1bee
Showing 3 changed files with 41 additions and 1 deletion.
25 changes: 25 additions & 0 deletions test/dynamo/test_export.py
@@ -1840,6 +1840,31 @@ def f(x):
        inp = torch.randn(6, 7)
        self.assertEqual(gm(inp), f(inp))

    def test_pre_autograd_simple(self):
        def f(x):
            y = torch.ones_like(x)
            return torch.matmul(x, y)

        gm, _ = torch._dynamo.export(
            f,
            torch.randn(5, 5),
            aten_graph=True,
            pre_autograd=True,
            tracing_mode="fake",
        )

        inp = torch.randn(6, 6)
        self.assertEqual(gm(inp), f(inp))
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    ones_like_default = torch.ops.aten.ones_like.default(arg0, pin_memory = False)
    matmul_default = torch.ops.aten.matmul.default(arg0, ones_like_default); arg0 = ones_like_default = None
    return pytree.tree_unflatten([matmul_default], self._out_spec)""",
        )

    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_export_cond_in_aten_symbolic(self):
        class ConditionOp(torch.nn.Module):
10 changes: 10 additions & 0 deletions torch/_dynamo/eval_frame.py
@@ -676,6 +676,7 @@ def export(
    f: Callable[..., Any],
    *args,
    aten_graph: bool = False,
    pre_autograd: bool = False,
    decomposition_table: Optional[
        Dict[torch._ops.OpOverload, Callable[..., Any]]
    ] = None,
@@ -695,6 +696,12 @@ def export(
        aten_graph (bool): If True, exports a graph with ATen operators.
            If False, exports a graph with Python operators. Default is False.
        pre_autograd (bool): If True, exports a graph with ATen operators,
            but before autograd has run. This can be useful if you want to apply further transformations
            on a graph before running it through autograd.
            This flag is only valid if aten_graph=True is set.
            Default is False.
        decomposition_table (dict): A dictionary that maps operators to their decomposition functions.
            Required if aten_graph or tracing_mode is specified. Default is None.
@@ -722,6 +729,8 @@ def export(
        assert (
            aten_graph
        ), "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True"
    if pre_autograd:
        assert aten_graph, "pre_autograd=True can only be used when aten_graph=True"
    f = innermost_fn(f)

    graph = None
@@ -878,6 +887,7 @@ def graph_with_interpreter(*args):
                decomposition_table=decomposition_table,
                tracing_mode="real",
                _allow_non_fake_inputs=True,
                pre_autograd=pre_autograd,
            )(*example_fake_inputs)
        except CondOpArgsMismatchError as e:
            # Wrap the internal error to the user-facing error
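
A hedged illustration of the validation added above, based directly on the new `assert` in the diff (the message is quoted from it):

```python
import torch
import torch._dynamo

def f(x):
    return x + 1

# pre_autograd=True without aten_graph=True should trip the new assert.
try:
    torch._dynamo.export(f, torch.randn(3), pre_autograd=True)
except AssertionError as e:
    print(e)  # pre_autograd=True can only be used when aten_graph=True
```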
7 changes: 6 additions & 1 deletion torch/fx/experimental/proxy_tensor.py
@@ -416,7 +416,12 @@ def can_handle_tensor(x):
    else:
        constant = None

    track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
    # See Note [Per-Dispatch-Key Modes Must Be Reentrant]
    # If our mode is on multiple mode stacks (e.g. the Autograd and Python mode stacks)
    # then we only want it to trace out proxies the first time that we hit an op.
    # In particular, track_tensor_tree can call detach().
    with inside_mode(proxy_mode):
        track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
    return out


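The comment added above describes a general pattern. The following is a standalone sketch of that idea, not PyTorch internals: a mode that may be entered from more than one mode stack guards itself so an op is only traced once, even if its own bookkeeping (like `detach()`) re-dispatches into it.

```python
from contextlib import contextmanager

@contextmanager
def inside_mode(mode):
    # Mark the mode as "already handling an op" for the duration of the block.
    prev, mode.inside = mode.inside, True
    try:
        yield
    finally:
        mode.inside = prev

class RecordingMode:
    def __init__(self):
        self.inside = False   # True while we are already handling an op
        self.trace = []       # each op should be recorded exactly once

    def handle(self, op_name):
        if self.inside:
            # Re-entrant call (e.g. triggered by bookkeeping such as detach()):
            # skip instead of recording the op a second time.
            return
        with inside_mode(self):
            self.trace.append(op_name)
            # Bookkeeping that may itself re-dispatch into this same mode.
            self.handle(op_name + ".detach")

mode = RecordingMode()
mode.handle("aten.matmul")
print(mode.trace)  # ['aten.matmul'] -- the nested call was not re-recorded
```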
