Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support list output for HigherOrderOperators #101986

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
51 changes: 47 additions & 4 deletions test/dynamo/test_higher_order_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,24 +417,67 @@ def f(x):
{"Invoking an nn.Module inside HigherOrderOperator": 1},
)

def test_fallback_on_non_single_tensor_output(self):
def test_flat_list_output(self):
def f(x):
return wrap(lambda x: [torch.sin(x), torch.cos(x)], x)

x = torch.randn(3)
self._test_wrap_simple(f, (x,), 2, expected_opcount=3)

def test_fallback_on_python_primitives_output(self):
counters.clear()
cnt = CompileCounter()

@torch.compile(backend=cnt)
def f(x):
return wrap(lambda x: [1, torch.sin(x), 2.0], x)

x = torch.randn(3)
result = f(x)
self.assertEqual(result, [1, torch.sin(x), 2.0])
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(
dict(counters["graph_break"]),
{"HigherOrderOperator body's output must consist of tensors only": 1},
)

def test_fallback_on_nested_tuple_output(self):
# We can likely support this in the future, I just don't want to deal
# with it right now
counters.clear()
cnt = CompileCounter()

@torch.compile(backend=cnt)
def f(x):
return wrap(lambda x: (x.sin(), x.cos()), x)
return wrap(lambda x: ((x.sin(), x.cos()),), x)

x = torch.randn(2, 3)
result = f(x)

self.assertEqual(result, (x.sin(), x.cos()))
self.assertEqual(result, ((x.sin(), x.cos()),))
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(
dict(counters["graph_break"]),
{"HigherOrderOperator body's output must consist of tensors only": 1},
)

def test_fallback_on_output_with_dict(self):
# We can likely support this in the future, I just don't want to deal
# with it right now
counters.clear()
cnt = CompileCounter()

@torch.compile(backend=cnt)
def f(x):
return wrap(lambda x: [{"a": -x}], x)

x = torch.randn(3)
result = f(x)
self.assertEqual(result, [{"a": -x}])
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(
dict(counters["graph_break"]),
{"HigherOrderOperator with body with non single Tensor output": 1},
{"HigherOrderOperator body's output must consist of tensors only": 1},
)

def test_access_module_attr(self):
Expand Down
58 changes: 35 additions & 23 deletions torch/_dynamo/variables/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,33 +773,42 @@ def speculate_subgraph(
output = f.call_function(tx, args, {})
# Register output to graph
# Modeled off of compile_and_call_fx_graph
# TODO: support non single Tensor output
# TODO: support pytree output
# We check always_restore because we dont use the output or side effects of always_restore code,
# like bwd.
if not isinstance(output, TensorVariable) and not always_restore:
unimplemented(
"HigherOrderOperator with body with non single Tensor output"
)

if always_restore:
# Nothing left to do here
return output, tx.output.graph, tracer.lifted_freevars
else:
if not isinstance(
output, (TensorVariable, ListVariable, TupleVariable)
):
unimplemented("HigherOrderOperator with body with pytree output")

if isinstance(output, (ListVariable, TupleVariable)):
if any(
not isinstance(var, TensorVariable)
for var in output.unpack_var_sequence(tx)
):
unimplemented(
"HigherOrderOperator body's output must consist of tensors only"
)

tx.output.guards.update(output.guards)
tx.output.create_node(
"output",
"output",
(tracer.create_arg((output.as_proxy(),))),
{},
)
graph = tx.output.graph
lifted_freevars = tracer.lifted_freevars

tx.output.guards.update(output.guards)
tx.output.create_node(
"output",
"output",
(tracer.create_arg((output.as_proxy(),))),
{},
)
graph = tx.output.graph
lifted_freevars = tracer.lifted_freevars

return (
output,
graph,
lifted_freevars,
)
return (
output,
graph,
lifted_freevars,
)

except torch._dynamo.exc.Unsupported as ex:
tx.output.graph = graph_checkpoint
Expand Down Expand Up @@ -1205,8 +1214,11 @@ def speculate_branch(branch):
*(arg.as_proxy() for arg in args[1:]),
*(arg for arg in body_lifted_freevars),
)
r = body_r.as_proxy().node.meta["example_value"]
example_value = r
example_value = pytree.tree_map_only(
torch.fx.Proxy,
lambda a: a.node.meta["example_value"],
body_r.as_proxy(),
)
elif self.value.__name__ in (
"trampoline_autograd_fwd",
"trampoline_autograd_bwd",
Expand Down