[HigherOrderOp] stop erroring out on non-Tensor returns #107461

Closed · wants to merge 3 commits
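What this change does, in user-facing terms: when Dynamo traces a HigherOrderOp body or an `autograd.Function.forward` that returns something other than a single Tensor, it now raises `Unsupported` (i.e. graph-breaks and falls back to eager) instead of erroring out. A minimal sketch of the `autograd.Function` case (the `TwoOutputs` name and the `backend="eager"` choice are illustrative, not part of this PR):

import torch


class TwoOutputs(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Returns a tuple of Tensors rather than a single Tensor.
        return x.clone(), x.clone()

    @staticmethod
    def backward(ctx, grad1, grad2):
        # Both outputs are clones of x, so the incoming gradients just add up.
        return grad1 + grad2


@torch.compile(backend="eager")
def f(x):
    return TwoOutputs.apply(x)


# Previously Dynamo raised a hard error here; with this change it graph-breaks
# on the multi-Tensor return and the call still produces the eager result.
out = f(torch.randn(3, requires_grad=True))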
27 changes: 27 additions & 0 deletions test/dynamo/test_autograd_function.py
@@ -7,6 +7,7 @@

import torch._dynamo.test_case
import torch._dynamo.testing
import torch._dynamo.utils


class CustomFunc1(torch.autograd.Function):
@@ -321,6 +322,32 @@ def test_function_context_mark_and_save(self):
after = compiled_model(*args, **kwargs)
self.assertEqual(before, after)

def test_multi_output(self):
torch._dynamo.utils.counters.clear()
cnt = torch._dynamo.testing.CompileCounter()

class Foo(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x.clone(), x.clone()

@staticmethod
def backward(ctx, grad1, grad2):
return grad1 + grad2

@torch.compile(backend=cnt)
def f(x):
return Foo.apply(x)

x = torch.randn(3, requires_grad=True)
result = f(x)

self.assertEqual(result, Foo.apply(x))
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(
list(torch._dynamo.utils.counters["graph_break"].values()), [1]
)

@unittest.expectedFailure
def test_function_with_bound_free_variable(self):
class LowerBound(torch.autograd.Function):
12 changes: 12 additions & 0 deletions test/dynamo/test_higher_order_ops.py
@@ -967,6 +967,18 @@ def inner2(x, y):
name_set.add(name)
self.assertEqual(name_set, {"", "map_body_1.map_body_0", "map_body_1"})

def test_map_multi_return(self):
cnt = CompileCounter()

@torch.compile(backend=cnt)
def f(x):
return control_flow.map(lambda x: (x.sin(), x.sin()), x)

x = torch.randn(3)
result = f(x)
self.assertEqual(result, (x.sin(), x.sin()))
self.assertEqual(cnt.frame_count, 0)

def test_cond_subgraph_name_is_valid(self):
backend = EagerAndRecordGraphs()
cnt = CompileCounterWithBackend(backend)
11 changes: 11 additions & 0 deletions torch/_dynamo/variables/higher_order_ops.py
@@ -482,6 +482,15 @@ def fixup_branch_inps(graph, add_after, new_args, suffix) -> None:
)


def non_single_tensor_return_unsupported(api, ret):
from . import TensorVariable

if not isinstance(ret, TensorVariable):
raise Unsupported(
f"{api} over function that returns something " f"other than one Tensor"
)


class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
def call_function(
self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
@@ -546,6 +555,7 @@ def call_function(
*(arg.as_proxy() for arg in args[1:]),
*(arg for arg in body_lifted_freevars.keys()),
)
non_single_tensor_return_unsupported("torch.ops.higher_order.map", body_r)
r = body_r.as_proxy().node.meta["example_value"]
example_value = r.new_empty(
[get_fake_value(args[1].as_proxy().node, tx).shape[0], *r.shape]
@@ -979,6 +989,7 @@ def call_function(
*(arg.as_proxy() for arg in args),
*(arg for arg in body_lifted_freevars.keys()),
)
non_single_tensor_return_unsupported("autograd.Function forward", body_r)
r = body_r.as_proxy().node.meta["example_value"]
example_value = r

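The same check guards `torch.ops.higher_order.map`. A rough sketch of the resulting fallback, mirroring `test_map_multi_return` above (the eager backend and the `functorch.experimental.control_flow` entry point are assumed here; they are not shown in this diff):

import torch
from functorch.experimental import control_flow


@torch.compile(backend="eager")
def f(x):
    # The mapped body returns a tuple of Tensors, so Dynamo hits
    # non_single_tensor_return_unsupported and graph-breaks instead of raising.
    return control_flow.map(lambda t: (t.sin(), t.sin()), x)


x = torch.randn(3)
result = f(x)
# The whole frame falls back to eager, so result matches the uncompiled
# control_flow.map call: a tuple (x.sin(), x.sin()).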