# Fix torch export with dict input nested in args #162618
## Conversation
🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/162618. ✅ No failures as of commit 45b8873 with merge base 2f53395. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
thanks for the fix!
Commits updated from 50de85a to 8d65984 (compare).
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Investigated together with @pyemma and @taotaohuang001

## Problem

Calling an exported module with a dict nested in the args tuple raises the following complaint:

```
Traceback (most recent call last):
  File "/home/chzhu/infinitrain/test_torch_export.py", line 32, in <module>
    print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)}))
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 848, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 424, in __call__
    raise e
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 411, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1806, in inner
    args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
    return fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 81, in _check_input_constraints_pre_hook
    flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 64, in _check_inputs_match
    raise ValueError(  # noqa: B904
ValueError: Trying to flatten user inputs with exported input tree spec: TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a1', 'a2'], [*, *])]), TreeSpec(dict, [], [])]) but actually got inputs with tree spec of: TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a2', 'a1'], [*, *])]), TreeSpec(dict, [], [])]). Please check that the inputs have the same number and type of args and kwargs as the ones you used when tracing.
```

## How to reproduce the issue

```python
import torch

# create a nn.Module that takes data_batch as input and returns an output
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = torch.nn.Linear(10, 1)

    def forward(self, data_batch):
        h1 = self.linear(data_batch["a1"])
        h2 = self.linear(data_batch["a2"])
        return h1 + h2

# torch export this module
model = MyModel()
example_args_forward = (
    {
        "a1": torch.randn(10),
        "a2": torch.randn(10),
    },
)
exported_model = torch.export.export(model, example_args_forward, strict=True)

# save the exported model
torch.export.save(exported_model, "exported_model.pt2")

# load the exported model
exported_model = torch.export.load("exported_model.pt2").module()

# run the exported model, passing the dict keys in a different order
print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)}))
```

## Root Cause

The input spec is encoded as a [TreeSpec](https://github.com/pytorch/pytorch/blob/582d278983b28a91ac0cedd035183f2495bb6887/torch/utils/_pytree.py#L1059) in torch export, with `(args, kwargs)` at the top level. When the exported model is called, a pre-execution [hook](https://github.com/pytorch/pytorch/blob/582d278983b28a91ac0cedd035183f2495bb6887/torch/export/_unlift.py#L66) checks that the received TreeSpec matches the exported one, and a TreeSpec preserves dict key order — something like `TreeSpec(dict, ['a2', 'a1'], [*, *])`. To work around this, the input check reorders [kwargs](https://github.com/pytorch/pytorch/blob/582d278983b28a91ac0cedd035183f2495bb6887/torch/export/_unlift.py#L67), which is why kwargs can be passed out of order. But a dict nested inside the args is not re-ordered, so any re-ordering of its keys throws the error above.

## Solution

Update `eq_spec` to handle the dict case, so that only the key set must match, with no ordering constraint.

Pull Request resolved: pytorch#162618
Approved by: https://github.com/angelayi
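The key-order sensitivity can be demonstrated without torch. This pure-Python sketch (illustrative names, not the actual pytree API) mimics how a pytree-style spec records dict keys in insertion order, so strict spec equality fails while a key-set comparison succeeds:

```python
# Sketch of why dict key order breaks strict spec equality:
# a pytree-style spec records dict keys in insertion order.

def dict_spec(d):
    """Record a dict's keys in insertion order, as a TreeSpec-like tuple."""
    return ("dict", list(d.keys()))

spec_traced = dict_spec({"a1": 1, "a2": 2})    # ('dict', ['a1', 'a2'])
spec_runtime = dict_spec({"a2": 2, "a1": 1})   # ('dict', ['a2', 'a1'])

print(spec_traced == spec_runtime)             # False: strict equality fails

def specs_match(s1, s2):
    """Order-insensitive comparison: only the key *set* must agree."""
    return s1[0] == s2[0] and set(s1[1]) == set(s2[1])

print(specs_match(spec_traced, spec_runtime))  # True
```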