Skip to content

Commit

Permalink
Make unimplemented cases graph break
Browse files Browse the repository at this point in the history
  • Loading branch information
ydwu4 committed May 22, 2023
1 parent e1245f9 commit 19ba8b0
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 7 deletions.
59 changes: 54 additions & 5 deletions test/dynamo/test_higher_order_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,23 @@ def f(x):
x = torch.randn(3)
self._test_wrap_simple(f, (x,), 2, expected_opcount=3)

def test_fallback_on_python_primitives(self):
    """A wrapped body whose output mixes tensors with Python scalars
    (int/float) cannot be captured as a graph; dynamo must graph-break,
    fall back to eager, and still return the correct value."""
    counters.clear()
    backend_counter = CompileCounter()

    @torch.compile(backend=backend_counter)
    def fn(t):
        # Output list contains 1 and 2.0 — not tensors-only.
        return wrap(lambda t: [1, torch.sin(t), 2.0], t)

    inp = torch.randn(3)
    out = fn(inp)

    # Eager fallback still computes the right result.
    self.assertEqual(out, [1, torch.sin(inp), 2.0])
    # Exactly one frame was attempted...
    self.assertEqual(backend_counter.frame_count, 1)
    # ...and exactly one graph break was recorded, with the expected reason.
    self.assertEqual(
        dict(counters["graph_break"]),
        {"HigherOrderOperator body's output must consist of tensors only": 1},
    )

def test_fallback_on_pytree_output(self):
# We can likely support this in the future, I just don't want to deal
# with it right now
Expand All @@ -432,18 +449,50 @@ def test_fallback_on_pytree_output(self):

@torch.compile(backend=cnt)
def f(x):
return wrap(lambda x: [(x.sin(), x.cos()), {"a": -x}], x)
return wrap(lambda x: [(x.sin(), x.cos())], x)

x = torch.randn(2, 3)
result = f(x)

self.assertEqual(result, [(x.sin(), x.cos()), {"a": -x}])
self.assertEqual(result, [(x.sin(), x.cos())])
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(
dict(counters["graph_break"]),
{
"torch.* op returned non-Tensor dict call_function <built-in function getitem>": 1
},
{"HigherOrderOperator body's output must consist of tensors only": 1},
)

def test_fallback_on_output_with_dict(self):
    """A wrapped body returning a list that contains a dict is not a
    tensors-only output, so dynamo graph-breaks and falls back to eager."""
    counters.clear()
    compile_counter = CompileCounter()

    @torch.compile(backend=compile_counter)
    def outer(inp):
        return wrap(lambda v: [{"a": -v}], inp)

    inp = torch.randn(3)
    res = outer(inp)

    # Fallback path still yields the correct structured result.
    self.assertEqual(res, [{"a": -inp}])
    self.assertEqual(compile_counter.frame_count, 1)
    expected_breaks = {
        "HigherOrderOperator body's output must consist of tensors only": 1
    }
    self.assertEqual(dict(counters["graph_break"]), expected_breaks)

# NOTE(review): this method appears twice in this diff with the identical
# name and body (see the earlier hunk). In Python the later definition
# silently shadows the earlier one at class-creation time, so only one copy
# would actually run under the test runner. This may be an artifact of the
# rendered diff view — confirm against the committed file and, if real,
# delete or rename one copy.
def test_fallback_on_output_with_dict(self):
    """Wrapped body returning a list containing a dict triggers a graph
    break (output is not tensors-only) and falls back to eager."""
    counters.clear()
    cnt = CompileCounter()

    @torch.compile(backend=cnt)
    def f(x):
        # dict inside the output list -> not tensors-only -> graph break
        return wrap(lambda x: [{"a": -x}], x)

    x = torch.randn(3)
    result = f(x)
    # Eager fallback still produces the correct value.
    self.assertEqual(result, [{"a": -x}])
    self.assertEqual(cnt.frame_count, 1)
    # Exactly one graph break, with the expected reason string.
    self.assertEqual(
        dict(counters["graph_break"]),
        {"HigherOrderOperator body's output must consist of tensors only": 1},
    )

def test_access_module_attr(self):
Expand Down
15 changes: 13 additions & 2 deletions torch/_dynamo/variables/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,15 @@ def speculate_subgraph(
):
unimplemented("HigherOrderOperator with body with pytree output")

if isinstance(output, (ListVariable, TupleVariable)):
if any(
not isinstance(var, TensorVariable)
for var in output.unpack_var_sequence(tx)
):
unimplemented(
"HigherOrderOperator body's output must consist of tensors only"
)

if always_restore:
# Nothing left to do here
return output, tx.output.graph, tracer.lifted_freevars
Expand Down Expand Up @@ -1206,8 +1215,10 @@ def speculate_branch(branch):
*(arg.as_proxy() for arg in args[1:]),
*(arg for arg in body_lifted_freevars),
)
example_value = pytree.tree_map(
lambda a: a.node.meta["example_value"], body_r.as_proxy()
example_value = pytree.tree_map_only(
torch.fx.Proxy,
lambda a: a.node.meta["example_value"],
body_r.as_proxy(),
)
elif self.value.__name__ in (
"trampoline_autograd_fwd",
Expand Down

0 comments on commit 19ba8b0

Please sign in to comment.