Skip to content

Commit

Permalink
graph break on None output when not restoring
Browse files Browse the repository at this point in the history
  • Loading branch information
ydwu4 committed May 23, 2023
1 parent 6c1f486 commit 32d8904
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 37 deletions.
6 changes: 3 additions & 3 deletions test/dynamo/test_higher_order_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def f(x):
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(
dict(counters["graph_break"]),
{"HigherOrderOperator body's output must consist of tensors or None only": 1},
{"HigherOrderOperator body's output must consist of tensors only": 1},
)

def test_fallback_on_nested_tuple_output(self):
Expand All @@ -458,7 +458,7 @@ def f(x):
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(
dict(counters["graph_break"]),
{"HigherOrderOperator body's output must consist of tensors or None only": 1},
{"HigherOrderOperator body's output must consist of tensors only": 1},
)

def test_fallback_on_output_with_dict(self):
Expand All @@ -477,7 +477,7 @@ def f(x):
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(
dict(counters["graph_break"]),
{"HigherOrderOperator body's output must consist of tensors or None only": 1},
{"HigherOrderOperator body's output must consist of tensors only": 1},
)

def test_access_module_attr(self):
Expand Down
62 changes: 28 additions & 34 deletions torch/_dynamo/variables/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,45 +776,39 @@ def speculate_subgraph(
# TODO: support pytree output
# We check always_restore because we dont use the output or side effects of always_restore code,
# like bwd.
if (
not isinstance(output, (TensorVariable, ListVariable, TupleVariable))
and not always_restore
):
unimplemented("HigherOrderOperator with body with pytree output")

if isinstance(output, (ListVariable, TupleVariable)):

def _is_supported(var):
return isinstance(var, TensorVariable) or (
isinstance(var, ConstantVariable) and var.value is None
)

if any(
not _is_supported(var) for var in output.unpack_var_sequence(tx)
):
unimplemented(
"HigherOrderOperator body's output must consist of tensors or None only"
)

if always_restore:
# Nothing left to do here
return output, tx.output.graph, tracer.lifted_freevars
else:
if not isinstance(
output, (TensorVariable, ListVariable, TupleVariable)
):
unimplemented("HigherOrderOperator with body with pytree output")

tx.output.guards.update(output.guards)
tx.output.create_node(
"output",
"output",
(tracer.create_arg((output.as_proxy(),))),
{},
)
graph = tx.output.graph
lifted_freevars = tracer.lifted_freevars
if isinstance(output, (ListVariable, TupleVariable)):
if any(
not isinstance(var, TensorVariable)
for var in output.unpack_var_sequence(tx)
):
unimplemented(
"HigherOrderOperator body's output must consist of tensors only"
)

tx.output.guards.update(output.guards)
tx.output.create_node(
"output",
"output",
(tracer.create_arg((output.as_proxy(),))),
{},
)
graph = tx.output.graph
lifted_freevars = tracer.lifted_freevars

return (
output,
graph,
lifted_freevars,
)
return (
output,
graph,
lifted_freevars,
)

except torch._dynamo.exc.Unsupported as ex:
tx.output.graph = graph_checkpoint
Expand Down

0 comments on commit 32d8904

Please sign in to comment.