Update on "[HigherOrderOp] stop erroring out on non-Tensor returns"
If the function passed to map or autograd.Function returns a non-Tensor,
the code currently just errors out. Instead of erroring out we should
graph break by raising Unsupported so users aren't confused. The better
fix is to actually support non-Tensor returns, but that requires more
work.
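
As a minimal sketch of the behavior this enables (the Function below is hypothetical, and backend="eager" is chosen only to make the fallback easy to observe): an autograd.Function whose forward returns something other than a single Tensor now causes a graph break at the .apply call, so the call falls back to eager execution instead of raising a hard error.

    import torch

    class ReturnsPair(torch.autograd.Function):
        # forward returns a (Tensor, int) tuple, not a single Tensor
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x.sin(), 42

        @staticmethod
        def backward(ctx, grad_out, _grad_int):
            (x,) = ctx.saved_tensors
            return grad_out * x.cos()

    @torch.compile(backend="eager")
    def f(x):
        # Dynamo now graph breaks here (Unsupported is raised internally)
        # rather than erroring out; the apply call runs eagerly.
        return ReturnsPair.apply(x)

    f(torch.randn(3))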

Test Plan:
- new tests

cc voznesenskym penguinwu anijain2305 @EikanWang jgong5 @Guobing-Chen @XiaobingSuper zhuhaozhe blzheng @Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov

[ghstack-poisoned]
zou3519 committed Aug 18, 2023
2 parents 6d27e5a + 6e55dfe commit 39b4100
Showing 2 changed files with 6 additions and 3 deletions.
test/dynamo/test_autograd_function.py (3 additions, 1 deletion)
@@ -344,7 +344,9 @@ def f(x):

         self.assertEqual(result, Foo.apply(x))
         self.assertEqual(cnt.frame_count, 1)
-        self.assertEqual(list(torch._dynamo.utils.counters['graph_break'].values()), [1])
+        self.assertEqual(
+            list(torch._dynamo.utils.counters["graph_break"].values()), [1]
+        )

     @unittest.expectedFailure
     def test_function_with_bound_free_variable(self):
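
For reference, the counter this test inspects lives in torch._dynamo.utils.counters, a defaultdict of Counters keyed by category; counters["graph_break"] maps each break reason to its count. A quick sketch of inspecting it outside the test suite (the explicit graph_break() call here is only to force a break for demonstration):

    import torch
    import torch._dynamo
    from torch._dynamo.utils import counters

    counters.clear()

    @torch.compile(backend="eager")
    def g(x):
        torch._dynamo.graph_break()  # deliberately split the graph
        return x + 1

    g(torch.randn(3))
    # each recorded graph-break reason maps to how often it occurred
    print(dict(counters["graph_break"]))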
torch/_dynamo/variables/higher_order_ops.py (3 additions, 2 deletions)
@@ -484,10 +484,11 @@ def fixup_branch_inps(graph, add_after, new_args, suffix) -> None:

 def non_single_tensor_return_unsupported(api, ret):
     from . import TensorVariable
+
     if not isinstance(ret, TensorVariable):
         raise Unsupported(
-            f"{api} over function that returns something "
-            f"other than one Tensor")
+            f"{api} over function that returns something " f"other than one Tensor"
+        )


 class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
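
To see what the helper does in isolation, a small sketch (the object() argument is a stand-in; constructing a real VariableTracker is more involved, but anything that is not a TensorVariable takes the same path, and the import paths assume torch at or after this commit):

    # Unsupported comes from torch._dynamo.exc; Dynamo turns it into a graph break
    from torch._dynamo.exc import Unsupported
    from torch._dynamo.variables.higher_order_ops import (
        non_single_tensor_return_unsupported,
    )

    try:
        non_single_tensor_return_unsupported("torch.ops.map", object())
    except Unsupported as e:
        # torch.ops.map over function that returns something other than one Tensor
        print(e)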
