Support list output for HigherOrderOperators (#101986)
Fixes the issue in #100278: support list output for HigherOrderOperator.
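
For context, a minimal sketch of the pattern this commit enables. The `wrap` import path below is an assumption for illustration (it mirrors what test/dynamo/test_higher_order_ops.py exercises); the point is that a body returning a flat list of tensors can now be traced without a graph break:

    import torch
    from torch._higher_order_ops.wrap import wrap  # import path assumed for illustration

    @torch.compile(backend="eager")
    def f(x):
        # Previously this fell back with "HigherOrderOperator with body with
        # non single Tensor output"; a flat list/tuple of tensors is now supported.
        return wrap(lambda x: [torch.sin(x), torch.cos(x)], x)

    print(f(torch.randn(3)))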

Pull Request resolved: #101986
Approved by: https://github.com/zou3519
ydwu4 authored and pytorchmergebot committed May 23, 2023
1 parent e7a6818 commit 7e58891
Showing 2 changed files with 82 additions and 27 deletions.
test/dynamo/test_higher_order_ops.py: 51 changes (47 additions, 4 deletions)
@@ -417,24 +417,67 @@ def f(x):
{"Invoking an nn.Module inside HigherOrderOperator": 1},
)

def test_fallback_on_non_single_tensor_output(self):
def test_flat_list_output(self):
def f(x):
return wrap(lambda x: [torch.sin(x), torch.cos(x)], x)

x = torch.randn(3)
self._test_wrap_simple(f, (x,), 2, expected_opcount=3)

def test_fallback_on_python_primitives_output(self):
counters.clear()
cnt = CompileCounter()

@torch.compile(backend=cnt)
def f(x):
return wrap(lambda x: [1, torch.sin(x), 2.0], x)

x = torch.randn(3)
result = f(x)
self.assertEqual(result, [1, torch.sin(x), 2.0])
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(
dict(counters["graph_break"]),
{"HigherOrderOperator body's output must consist of tensors only": 1},
)

def test_fallback_on_nested_tuple_output(self):
# We can likely support this in the future, I just don't want to deal
# with it right now
counters.clear()
cnt = CompileCounter()

@torch.compile(backend=cnt)
def f(x):
return wrap(lambda x: (x.sin(), x.cos()), x)
return wrap(lambda x: ((x.sin(), x.cos()),), x)

x = torch.randn(2, 3)
result = f(x)

self.assertEqual(result, (x.sin(), x.cos()))
self.assertEqual(result, ((x.sin(), x.cos()),))
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(
dict(counters["graph_break"]),
{"HigherOrderOperator body's output must consist of tensors only": 1},
)

def test_fallback_on_output_with_dict(self):
# We can likely support this in the future, I just don't want to deal
# with it right now
counters.clear()
cnt = CompileCounter()

@torch.compile(backend=cnt)
def f(x):
return wrap(lambda x: [{"a": -x}], x)

x = torch.randn(3)
result = f(x)
self.assertEqual(result, [{"a": -x}])
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(
dict(counters["graph_break"]),
{"HigherOrderOperator with body with non single Tensor output": 1},
{"HigherOrderOperator body's output must consist of tensors only": 1},
)

def test_access_module_attr(self):
torch/_dynamo/variables/torch.py: 58 changes (35 additions, 23 deletions)
@@ -773,33 +773,42 @@ def speculate_subgraph(
             output = f.call_function(tx, args, {})
             # Register output to graph
             # Modeled off of compile_and_call_fx_graph
-            # TODO: support non single Tensor output
+            # TODO: support pytree output
             # We check always_restore because we dont use the output or side effects of always_restore code,
             # like bwd.
-            if not isinstance(output, TensorVariable) and not always_restore:
-                unimplemented(
-                    "HigherOrderOperator with body with non single Tensor output"
-                )
-
             if always_restore:
                 # Nothing left to do here
                 return output, tx.output.graph, tracer.lifted_freevars
+            else:
+                if not isinstance(
+                    output, (TensorVariable, ListVariable, TupleVariable)
+                ):
+                    unimplemented("HigherOrderOperator with body with pytree output")
+
+                if isinstance(output, (ListVariable, TupleVariable)):
+                    if any(
+                        not isinstance(var, TensorVariable)
+                        for var in output.unpack_var_sequence(tx)
+                    ):
+                        unimplemented(
+                            "HigherOrderOperator body's output must consist of tensors only"
+                        )
-
-            tx.output.guards.update(output.guards)
-            tx.output.create_node(
-                "output",
-                "output",
-                (tracer.create_arg((output.as_proxy(),))),
-                {},
-            )
-            graph = tx.output.graph
-            lifted_freevars = tracer.lifted_freevars
+
+                tx.output.guards.update(output.guards)
+                tx.output.create_node(
+                    "output",
+                    "output",
+                    (tracer.create_arg((output.as_proxy(),))),
+                    {},
+                )
+                graph = tx.output.graph
+                lifted_freevars = tracer.lifted_freevars

-            return (
-                output,
-                graph,
-                lifted_freevars,
-            )
+                return (
+                    output,
+                    graph,
+                    lifted_freevars,
+                )

         except torch._dynamo.exc.Unsupported as ex:
             tx.output.graph = graph_checkpoint
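
Note: in plain terms, the check above accepts a body output that is a single tensor variable or a flat list/tuple whose elements are all tensor variables; anything else raises Unsupported, which the caller catches and turns into a graph break. A conceptual restatement, with ordinary Python types standing in for Dynamo's VariableTracker classes:

    import torch

    def output_is_supported(output) -> bool:
        # torch.Tensor stands in for TensorVariable; list/tuple for ListVariable/TupleVariable.
        if isinstance(output, torch.Tensor):
            return True
        if isinstance(output, (list, tuple)):
            return all(isinstance(item, torch.Tensor) for item in output)
        return False

    print(output_is_supported([torch.ones(2), torch.zeros(2)]))  # True: flat list of tensors
    print(output_is_supported([1, torch.ones(2)]))               # False: non-tensor element
    print(output_is_supported(((torch.ones(2),),)))              # False: nested structure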
@@ -1205,8 +1214,11 @@ def speculate_branch(branch):
                 *(arg.as_proxy() for arg in args[1:]),
                 *(arg for arg in body_lifted_freevars),
             )
-            r = body_r.as_proxy().node.meta["example_value"]
-            example_value = r
+            example_value = pytree.tree_map_only(
+                torch.fx.Proxy,
+                lambda a: a.node.meta["example_value"],
+                body_r.as_proxy(),
+            )
         elif self.value.__name__ in (
             "trampoline_autograd_fwd",
             "trampoline_autograd_bwd",
