Fix cond branches take no arguments (pytorch#109308)
For code like this:
```python
import torch
from functorch.experimental import control_flow
def exportdb_example2(x):
    def true_fn():
        return torch.sin(x)

    def false_fn():
        return torch.cos(x)

    return control_flow.cond(x.sum() > 0, true_fn, false_fn, [])
ep = torch._export.export(exportdb_example2, (torch.randn(4, 5),))
```
Before this PR, when the branches take an empty list/tuple of operands, we got an error like the following:
```python
Traceback (most recent call last):
  File "/home/yidi/local/pytorch/test_cond.py", line 11, in <module>
    ep = torch._export.export(exportdb_example2, (torch.randn(4, 5),))
  File "/home/yidi/local/pytorch/torch/_export/__init__.py", line 340, in export
    gm_torch_level, _ = torch._dynamo.export(
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 1207, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/test_cond.py", line 3, in exportdb_example2
    def exportdb_example2(x):
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 1173, in result_capturing_wrapper
    graph_captured_result = torch.func.functional_call(
  File "/home/yidi/local/pytorch/torch/_functorch/functional_call.py", line 143, in functional_call
    return nn.utils.stateless._functional_call(
  File "/home/yidi/local/pytorch/torch/nn/utils/stateless.py", line 264, in _functional_call
    return module(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/fx/graph_module.py", line 725, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/home/yidi/local/pytorch/torch/fx/graph_module.py", line 305, in __call__
    raise e
  File "/home/yidi/local/pytorch/torch/fx/graph_module.py", line 292, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.2", line 10, in forward
  File "/home/yidi/local/pytorch/torch/_ops.py", line 301, in __call__
    return wrapper()
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_ops.py", line 297, in wrapper
    return self.dispatch(
  File "/home/yidi/local/pytorch/torch/_ops.py", line 280, in dispatch
    return kernel(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/utils.py", line 52, in inner
    return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/utils.py", line 25, in autograd_not_implemented_inner
    result = operator(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_ops.py", line 301, in __call__
    return wrapper()
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_ops.py", line 297, in wrapper
    return self.dispatch(
  File "/home/yidi/local/pytorch/torch/_ops.py", line 255, in dispatch
    return self.python_key_mode_table[type(curr_mode)](*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 310, in cond_fake_tensor_mode
    flat_false_outs, _ = pytree.tree_flatten(false_fn(*operands))
  File "/home/yidi/local/pytorch/torch/fx/graph_module.py", line 725, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/home/yidi/local/pytorch/torch/fx/graph_module.py", line 305, in __call__
    raise e
  File "/home/yidi/local/pytorch/torch/fx/graph_module.py", line 292, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given
```

Thanks to @williamwen42 for spotting this error! The root cause is in `fixup_branch_inps`, which lifts the other branch's closure arguments into a branch graph by inserting placeholders after the placeholder at index `add_after`. When `operands` is empty, `add_after` is -1 for the false graph, so the old placeholder-counting loop never matched and no placeholders were added, leaving the branch's signature out of sync with the call. We fix it by handling the case where `add_after` is -1 and inserting the new placeholders before the graph's first node.

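For intuition, here is a minimal toy sketch of the old bug (an illustrative hand-built `fx.Graph`, not the PR's actual code): a placeholder-counting loop that inserts after index `add_after` can never fire when `add_after` is -1, so the lifted arguments were silently dropped.

```python
# Toy sketch of the pre-PR insertion loop (assumption: simplified, on a
# hand-built graph standing in for the traced false branch).
import torch.fx as fx

g = fx.Graph()
g.output(g.placeholder("l_x_"))  # false branch: one lifted placeholder

add_after = -1  # empty operands: the new args belong *before* existing ones
inp_count = 0
for node in g.nodes:
    if node.op == "placeholder":
        if inp_count == add_after:  # 0 == -1: never true
            with g.inserting_after(node):
                g.placeholder("l_x__true_branch")
            break
        inp_count += 1

# Nothing was inserted: the false graph still takes one input, so calling
# it with two operands raises the TypeError shown in the traceback above.
assert sum(n.op == "placeholder" for n in g.nodes) == 1
```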
Test Plan:
See newly added tests.
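The new tests should be runnable directly from a PyTorch source checkout, e.g. something like:

```
python test/dynamo/test_higher_order_ops.py -k test_cond_branches_no_arguments
```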

Pull Request resolved: pytorch#109308
Approved by: https://github.com/williamwen42
ydwu4 authored and pytorchmergebot committed Sep 15, 2023
1 parent 1aba61e commit f3d1401
Showing 2 changed files with 112 additions and 10 deletions.
test/dynamo/test_higher_order_ops.py (92 additions, 0 deletions)

```diff
@@ -1093,6 +1093,98 @@ def false_fn(x):
         )
         self.assertEqual(num_placeholders, 5)
 
+    def _check_simple_cond_graph(
+        self, fn, args, exp_graph, exp_true_graph, exp_false_graph
+    ):
+        backend = EagerAndRecordGraphs()
+        cnt = CompileCounterWithBackend(backend)
+        out = torch.compile(fn, backend=cnt, fullgraph=True)(*args)
+        self.assertEqual(out, fn(*args))
+        self.assertEqual(cnt.frame_count, 1)
+        self.assertEqual(len(backend.graphs), 1)
+
+        # Dynamic shapes produce a slightly different graph.
+        if check_dynamic_shape_capture():
+            return
+
+        gm = backend.graphs[0]
+        graph = gm.code.strip()
+        true_graph = gm.cond_true_0.code.strip()
+        false_graph = gm.cond_false_0.code.strip()
+        self.assertExpectedInline(graph, exp_graph)
+        self.assertExpectedInline(true_graph, exp_true_graph)
+        self.assertExpectedInline(false_graph, exp_false_graph)
+
+    def test_cond_branches_no_arguments(self):
+        def fn(x):
+            def true_fn():
+                return torch.sin(x)
+
+            def false_fn():
+                return torch.cos(x)
+
+            return control_flow.cond(x.sum() > 0, true_fn, false_fn, tuple())
+
+        exp_graph = """\
+def forward(self, L_x_ : torch.Tensor):
+    l_x_ = L_x_
+    sum_1 = l_x_.sum()
+    gt = sum_1 > 0; sum_1 = None
+    cond_true_0 = self.cond_true_0
+    cond_false_0 = self.cond_false_0
+    cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, [l_x_, l_x_]); \
+gt = cond_true_0 = cond_false_0 = l_x_ = None
+    return (cond,)
+""".strip()
+        exp_true_graph = """\
+def forward(self, l_x_, l_x__false_branch):
+    sin = torch.sin(l_x_); l_x_ = None
+    return sin
+""".strip()
+        exp_false_graph = """\
+def forward(self, l_x__true_branch, l_x_):
+    cos = torch.cos(l_x_); l_x_ = None
+    return cos
+""".strip()
+        self._check_simple_cond_graph(
+            fn, (torch.randn(4, 5),), exp_graph, exp_true_graph, exp_false_graph
+        )
+
+    def test_cond_branches_no_arguments_no_closure(self):
+        def fn(x):
+            def true_fn():
+                return torch.ones(3, 4)
+
+            def false_fn():
+                return torch.ones(3, 4).sin()
+
+            return control_flow.cond(x.sum() > 0, true_fn, false_fn, tuple())
+
+        exp_graph = """\
+def forward(self, L_x_ : torch.Tensor):
+    l_x_ = L_x_
+    sum_1 = l_x_.sum(); l_x_ = None
+    gt = sum_1 > 0; sum_1 = None
+    cond_true_0 = self.cond_true_0
+    cond_false_0 = self.cond_false_0
+    cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, []); gt = cond_true_0 = cond_false_0 = None
+    return (cond,)
+""".strip()
+        exp_true_graph = """\
+def forward(self):
+    ones = torch.ones(3, 4)
+    return ones
+""".strip()
+        exp_false_graph = """\
+def forward(self):
+    ones = torch.ones(3, 4)
+    sin = ones.sin(); ones = None
+    return sin
+""".strip()
+        self._check_simple_cond_graph(
+            fn, (torch.randn(4, 5),), exp_graph, exp_true_graph, exp_false_graph
+        )
+
     def test_cond_side_effect_in_one_branches(self):
         backend = EagerAndRecordGraphs()
         cnt = CompileCounterWithBackend(backend)
```
torch/_dynamo/variables/higher_order_ops.py (20 additions, 10 deletions)

```diff
@@ -452,16 +452,26 @@ def speculate_branch(branch):
         # false_fn(x, a_true, b_true, c_true, a, b, d)
         # https://github.com/pytorch/pytorch/issues/103530
         def fixup_branch_inps(graph, add_after, new_args, suffix) -> None:
-            inp_count = 0
-            for node in graph.nodes:
-                if node.op == "placeholder":
-                    if inp_count == add_after:
-                        with graph.inserting_after(node):
-                            for inp_node in new_args:
-                                new_node_name = inp_node.node.name + suffix
-                                graph.placeholder(new_node_name)
-                        break
-                    inp_count += 1
+            original_phs = [node for node in graph.nodes if node.op == "placeholder"]
+            assert add_after < len(
+                original_phs
+            ), f"Invalid index for inserting lifted arguments {add_after}."
+
+            # When operands is empty, add_after can be -1 for the false graph. In that case,
+            # insert the new nodes before the first node in the graph, since placeholders precede normal nodes.
+            def _add_phs():
+                for inp_node in new_args:
+                    new_node_name = inp_node.node.name + suffix
+                    graph.placeholder(new_node_name)
+
+            if add_after == -1:
+                first_node = next(iter(graph.nodes))
+                with graph.inserting_before(first_node):
+                    _add_phs()
+            else:
+                insertion_node = original_phs[add_after]
+                with graph.inserting_after(insertion_node):
+                    _add_phs()
 
         fixup_branch_inps(
             true_graph,
```
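For intuition, here is a standalone toy version of the fixed logic (an assumption-laden simplification of the diff above: the real `fixup_branch_inps` receives dynamo variables with `.node.name`, not plain strings) exercising the `add_after == -1` path:

```python
import torch
import torch.fx as fx


def fixup_branch_inps_toy(graph, add_after, new_names):
    # Simplified sketch of the fixed logic from the diff above.
    phs = [n for n in graph.nodes if n.op == "placeholder"]
    assert add_after < len(phs), f"Invalid index {add_after}."

    def _add_phs():
        for name in new_names:
            graph.placeholder(name)

    if add_after == -1:
        # No operand placeholder to anchor on: insert before the graph's
        # first node so the new placeholders still precede compute nodes.
        with graph.inserting_before(next(iter(graph.nodes))):
            _add_phs()
    else:
        with graph.inserting_after(phs[add_after]):
            _add_phs()


# Hand-built stand-in for the traced false branch: one lifted placeholder.
g = fx.Graph()
x = g.placeholder("l_x_")
g.output(g.call_function(torch.cos, args=(x,)))

# add_after == -1: the true branch's lifted arg goes in front of l_x_,
# matching forward(self, l_x__true_branch, l_x_) in the expected graph.
fixup_branch_inps_toy(g, -1, ["l_x__true_branch"])
print([n.name for n in g.nodes if n.op == "placeholder"])
# -> ['l_x__true_branch', 'l_x_']
```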
