Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix cond branches take no arguments (pytorch#109308)
For code like this: ```python import torch from functorch.experimental import control_flow def exportdb_example2(x): def true_fn(): return torch.sin(x) def false_fn(): return torch.cos(x) return control_flow.cond(x.sum() > 0, true_fn, false_fn, []) ep = torch._export.export(exportdb_example2, (torch.randn(4, 5),)) ``` before the pr, when the branches take an empty/list of tuple as inputs, we'll have error like following: ```python Traceback (most recent call last): File "/home/yidi/local/pytorch/test_cond.py", line 11, in <module> ep = torch._export.export(exportdb_example2, (torch.randn(4, 5),)) File "/home/yidi/local/pytorch/torch/_export/__init__.py", line 340, in export gm_torch_level, _ = torch._dynamo.export( File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 1207, in inner result_traced = opt_f(*args, **kwargs) File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn return fn(*args, **kwargs) File "/home/yidi/local/pytorch/test_cond.py", line 3, in exportdb_example2 def exportdb_example2(x): File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn return fn(*args, **kwargs) File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner return fn(*args, **kwargs) File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 1173, in result_capturing_wrapper graph_captured_result = torch.func.functional_call( File "/home/yidi/local/pytorch/torch/_functorch/functional_call.py", line 143, in functional_call return nn.utils.stateless._functional_call( File "/home/yidi/local/pytorch/torch/nn/utils/stateless.py", line 264, in _functional_call return module(*args, **kwargs) File "/home/yidi/local/pytorch/torch/fx/graph_module.py", line 725, in call_wrapped return self._wrapped_call(self, *args, **kwargs) File "/home/yidi/local/pytorch/torch/fx/graph_module.py", line 305, in __call__ raise e File "/home/yidi/local/pytorch/torch/fx/graph_module.py", line 292, in __call__ return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl return forward_call(*args, **kwargs) File "<eval_with_key>.2", line 10, in forward File "/home/yidi/local/pytorch/torch/_ops.py", line 301, in __call__ return wrapper() File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn return fn(*args, **kwargs) File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner return fn(*args, **kwargs) File "/home/yidi/local/pytorch/torch/_ops.py", line 297, in wrapper return self.dispatch( File "/home/yidi/local/pytorch/torch/_ops.py", line 280, in dispatch return kernel(*args, **kwargs) File "/home/yidi/local/pytorch/torch/_higher_order_ops/utils.py", line 52, in inner return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs) File "/home/yidi/local/pytorch/torch/_higher_order_ops/utils.py", line 25, in autograd_not_implemented_inner result = operator(*args, **kwargs) File "/home/yidi/local/pytorch/torch/_ops.py", line 301, in __call__ return wrapper() File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn return fn(*args, **kwargs) File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner return fn(*args, **kwargs) File "/home/yidi/local/pytorch/torch/_ops.py", line 297, in wrapper return self.dispatch( File "/home/yidi/local/pytorch/torch/_ops.py", line 255, in dispatch return self.python_key_mode_table[type(curr_mode)](*args, **kwargs) File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 310, in cond_fake_tensor_mode flat_false_outs, _ = pytree.tree_flatten(false_fn(*operands)) File "/home/yidi/local/pytorch/torch/fx/graph_module.py", line 725, in call_wrapped return self._wrapped_call(self, *args, **kwargs) File "/home/yidi/local/pytorch/torch/fx/graph_module.py", line 305, in __call__ raise e File "/home/yidi/local/pytorch/torch/fx/graph_module.py", line 292, in __call__ return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl return forward_call(*args, **kwargs) TypeError: forward() takes 2 positional arguments but 3 were given ``` Thanks for @williamwen42 spotting this error! We fix it by addressing the case when add_after is -1. Test Plan: See newly added tests. Pull Request resolved: pytorch#109308 Approved by: https://github.com/williamwen42
- Loading branch information