Skip to content

Conversation

supercharleszhu
Copy link
Contributor

Investigated together with @pyemma and @taotaohuang001

Problem

when calling exported module with dict nested in the args tuple, it will make following complaits

Traceback (most recent call last):
  File "/home/chzhu/infinitrain/test_torch_export.py", line 32, in <module>
    print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)}))
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 848, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 424, in __call__
    raise e
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 411, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1806, in inner
    args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
    return fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 81, in _check_input_constraints_pre_hook
    flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 64, in _check_inputs_match
    raise ValueError(  # noqa: B904
ValueError: Trying to flatten user inputs with exported input tree spec: 
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a1', 'a2'], [*,
      *])]),
  TreeSpec(dict, [], [])])
but actually got inputs with tree spec of: 
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a2', 'a1'], [*,
      *])]),
  TreeSpec(dict, [], [])]).
Please check that the inputs have the same number and type of args and kwargs as the ones you used when tracing.

How to reproduce the issue

import torch


# create a nn.Module with data_batch as input and output as output
class MyModel(torch.nn.Module):
   def __init__(self):
       super(MyModel, self).__init__()
       self.linear = torch.nn.Linear(10, 1)


   def forward(self, data_batch):
       h1 = self.linear(data_batch["a1"])
       h2 = self.linear(data_batch["a2"])
       return h1 + h2




# torch export this module
model = MyModel()
example_args_forward = (
   {
       "a1": torch.randn(10),
       "a2": torch.randn(10),
   },
)
exported_model = torch.export.export(model, example_args_forward, strict=True)


# save the exported model
torch.export.save(exported_model, "exported_model.pt2")


# load the exported model
exported_model = torch.export.load("exported_model.pt2").module()


# run the exported model
print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)}))

Root Cause

Input spec is encoded as TreeSpec in torch export. With (args, kwargs) at the top level. When we call the exported model, it has a pre-execution hook to check the input TreeSpec matches the received TreeSpec, where in Treespec, the dict key order is preserved. Something like

TreeSpec(dict, ['a2', 'a1'], [,])

To workaround this, the input check reorders kwargs, that is why kwargs can be out of order. But the dict nested in the args is not re-ordered, so any re-ordering of the keys will throw errors.

Solution

Update eq_spec to handle the dict case, where we only guarantee that key set is the same without ordering constraints.

Copy link

pytorch-bot bot commented Sep 10, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/162618

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 45b8873 with merge base 2f53395 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@supercharleszhu supercharleszhu changed the title Fix export with nested args with dict input Fix torch export with nested args with dict input Sep 10, 2025
@supercharleszhu supercharleszhu changed the title Fix torch export with nested args with dict input Fix torch export with dict input nested in args Sep 10, 2025
Copy link
Contributor

@angelayi angelayi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the fix!

@supercharleszhu supercharleszhu force-pushed the fix-export-nested-dict-arg branch from 50de85a to 8d65984 Compare September 10, 2025 21:46
@supercharleszhu
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 13, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
Investigated together with @pyemma and @taotaohuang001

## Problem
when calling exported module with dict nested in the args tuple, it will make following complaits
```
Traceback (most recent call last):
  File "/home/chzhu/infinitrain/test_torch_export.py", line 32, in <module>
    print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)}))
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 848, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 424, in __call__
    raise e
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 411, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1806, in inner
    args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
    return fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 81, in _check_input_constraints_pre_hook
    flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 64, in _check_inputs_match
    raise ValueError(  # noqa: B904
ValueError: Trying to flatten user inputs with exported input tree spec:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a1', 'a2'], [*,
      *])]),
  TreeSpec(dict, [], [])])
but actually got inputs with tree spec of:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a2', 'a1'], [*,
      *])]),
  TreeSpec(dict, [], [])]).
Please check that the inputs have the same number and type of args and kwargs as the ones you used when tracing.

```

## How to reproduce the issue
```python
import torch

# create a nn.Module with data_batch as input and output as output
class MyModel(torch.nn.Module):
   def __init__(self):
       super(MyModel, self).__init__()
       self.linear = torch.nn.Linear(10, 1)

   def forward(self, data_batch):
       h1 = self.linear(data_batch["a1"])
       h2 = self.linear(data_batch["a2"])
       return h1 + h2

# torch export this module
model = MyModel()
example_args_forward = (
   {
       "a1": torch.randn(10),
       "a2": torch.randn(10),
   },
)
exported_model = torch.export.export(model, example_args_forward, strict=True)

# save the exported model
torch.export.save(exported_model, "exported_model.pt2")

# load the exported model
exported_model = torch.export.load("exported_model.pt2").module()

# run the exported model
print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)}))

```

## Root Cause
Input spec is encoded as [TreeSpec](https://github.com/pytorch/pytorch/blob/582d278983b28a91ac0cedd035183f2495bb6887/torch/utils/_pytree.py#L1059) in torch export. With (args, kwargs) at the top level. When we call the exported model, it has a pre-execution [hook](https://github.com/pytorch/pytorch/blob/582d278983b28a91ac0cedd035183f2495bb6887/torch/export/_unlift.py#L66) to check the input TreeSpec matches the received TreeSpec, where in Treespec, the dict key order is preserved. Something like

TreeSpec(dict, ['a2', 'a1'], [*,*])

To workaround this, the input check reorders [kwargs](https://github.com/pytorch/pytorch/blob/582d278983b28a91ac0cedd035183f2495bb6887/torch/export/_unlift.py#L67), that is why kwargs can be out of order. But the dict nested in the args is not re-ordered, so any re-ordering of the keys will throw errors.

## Solution
Update eq_spec to handle the dict case, where we only guarantee that key set is the same without ordering constraints.
Pull Request resolved: pytorch#162618
Approved by: https://github.com/angelayi
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
Investigated together with @pyemma and @taotaohuang001

## Problem
when calling exported module with dict nested in the args tuple, it will make following complaits
```
Traceback (most recent call last):
  File "/home/chzhu/infinitrain/test_torch_export.py", line 32, in <module>
    print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)}))
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 848, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 424, in __call__
    raise e
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 411, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1806, in inner
    args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
    return fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 81, in _check_input_constraints_pre_hook
    flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 64, in _check_inputs_match
    raise ValueError(  # noqa: B904
ValueError: Trying to flatten user inputs with exported input tree spec:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a1', 'a2'], [*,
      *])]),
  TreeSpec(dict, [], [])])
but actually got inputs with tree spec of:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a2', 'a1'], [*,
      *])]),
  TreeSpec(dict, [], [])]).
Please check that the inputs have the same number and type of args and kwargs as the ones you used when tracing.

```

## How to reproduce the issue
```python
import torch

# create a nn.Module with data_batch as input and output as output
class MyModel(torch.nn.Module):
   def __init__(self):
       super(MyModel, self).__init__()
       self.linear = torch.nn.Linear(10, 1)

   def forward(self, data_batch):
       h1 = self.linear(data_batch["a1"])
       h2 = self.linear(data_batch["a2"])
       return h1 + h2

# torch export this module
model = MyModel()
example_args_forward = (
   {
       "a1": torch.randn(10),
       "a2": torch.randn(10),
   },
)
exported_model = torch.export.export(model, example_args_forward, strict=True)

# save the exported model
torch.export.save(exported_model, "exported_model.pt2")

# load the exported model
exported_model = torch.export.load("exported_model.pt2").module()

# run the exported model
print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)}))

```

## Root Cause
Input spec is encoded as [TreeSpec](https://github.com/pytorch/pytorch/blob/582d278983b28a91ac0cedd035183f2495bb6887/torch/utils/_pytree.py#L1059) in torch export. With (args, kwargs) at the top level. When we call the exported model, it has a pre-execution [hook](https://github.com/pytorch/pytorch/blob/582d278983b28a91ac0cedd035183f2495bb6887/torch/export/_unlift.py#L66) to check the input TreeSpec matches the received TreeSpec, where in Treespec, the dict key order is preserved. Something like

TreeSpec(dict, ['a2', 'a1'], [*,*])

To workaround this, the input check reorders [kwargs](https://github.com/pytorch/pytorch/blob/582d278983b28a91ac0cedd035183f2495bb6887/torch/export/_unlift.py#L67), that is why kwargs can be out of order. But the dict nested in the args is not re-ordered, so any re-ordering of the keys will throw errors.

## Solution
Update eq_spec to handle the dict case, where we only guarantee that key set is the same without ordering constraints.
Pull Request resolved: pytorch#162618
Approved by: https://github.com/angelayi
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
Investigated together with @pyemma and @taotaohuang001

## Problem
when calling exported module with dict nested in the args tuple, it will make following complaits
```
Traceback (most recent call last):
  File "/home/chzhu/infinitrain/test_torch_export.py", line 32, in <module>
    print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)}))
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 848, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 424, in __call__
    raise e
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 411, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1806, in inner
    args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
    return fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 81, in _check_input_constraints_pre_hook
    flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 64, in _check_inputs_match
    raise ValueError(  # noqa: B904
ValueError: Trying to flatten user inputs with exported input tree spec:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a1', 'a2'], [*,
      *])]),
  TreeSpec(dict, [], [])])
but actually got inputs with tree spec of:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a2', 'a1'], [*,
      *])]),
  TreeSpec(dict, [], [])]).
Please check that the inputs have the same number and type of args and kwargs as the ones you used when tracing.

```

## How to reproduce the issue
```python
import torch

# create a nn.Module with data_batch as input and output as output
class MyModel(torch.nn.Module):
   def __init__(self):
       super(MyModel, self).__init__()
       self.linear = torch.nn.Linear(10, 1)

   def forward(self, data_batch):
       h1 = self.linear(data_batch["a1"])
       h2 = self.linear(data_batch["a2"])
       return h1 + h2

# torch export this module
model = MyModel()
example_args_forward = (
   {
       "a1": torch.randn(10),
       "a2": torch.randn(10),
   },
)
exported_model = torch.export.export(model, example_args_forward, strict=True)

# save the exported model
torch.export.save(exported_model, "exported_model.pt2")

# load the exported model
exported_model = torch.export.load("exported_model.pt2").module()

# run the exported model
print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)}))

```

## Root Cause
Input spec is encoded as [TreeSpec](https://github.com/pytorch/pytorch/blob/582d278983b28a91ac0cedd035183f2495bb6887/torch/utils/_pytree.py#L1059) in torch export. With (args, kwargs) at the top level. When we call the exported model, it has a pre-execution [hook](https://github.com/pytorch/pytorch/blob/582d278983b28a91ac0cedd035183f2495bb6887/torch/export/_unlift.py#L66) to check the input TreeSpec matches the received TreeSpec, where in Treespec, the dict key order is preserved. Something like

TreeSpec(dict, ['a2', 'a1'], [*,*])

To workaround this, the input check reorders [kwargs](https://github.com/pytorch/pytorch/blob/582d278983b28a91ac0cedd035183f2495bb6887/torch/export/_unlift.py#L67), that is why kwargs can be out of order. But the dict nested in the args is not re-ordered, so any re-ordering of the keys will throw errors.

## Solution
Update eq_spec to handle the dict case, where we only guarantee that key set is the same without ordering constraints.
Pull Request resolved: pytorch#162618
Approved by: https://github.com/angelayi
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
Investigated together with @pyemma and @taotaohuang001

## Problem
when calling exported module with dict nested in the args tuple, it will make following complaits
```
Traceback (most recent call last):
  File "/home/chzhu/infinitrain/test_torch_export.py", line 32, in <module>
    print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)}))
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 848, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 424, in __call__
    raise e
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 411, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1806, in inner
    args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
    return fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 81, in _check_input_constraints_pre_hook
    flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 64, in _check_inputs_match
    raise ValueError(  # noqa: B904
ValueError: Trying to flatten user inputs with exported input tree spec:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a1', 'a2'], [*,
      *])]),
  TreeSpec(dict, [], [])])
but actually got inputs with tree spec of:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a2', 'a1'], [*,
      *])]),
  TreeSpec(dict, [], [])]).
Please check that the inputs have the same number and type of args and kwargs as the ones you used when tracing.

```

## How to reproduce the issue
```python
import torch

# create a nn.Module with data_batch as input and output as output
class MyModel(torch.nn.Module):
   def __init__(self):
       super(MyModel, self).__init__()
       self.linear = torch.nn.Linear(10, 1)

   def forward(self, data_batch):
       h1 = self.linear(data_batch["a1"])
       h2 = self.linear(data_batch["a2"])
       return h1 + h2

# torch export this module
model = MyModel()
example_args_forward = (
   {
       "a1": torch.randn(10),
       "a2": torch.randn(10),
   },
)
exported_model = torch.export.export(model, example_args_forward, strict=True)

# save the exported model
torch.export.save(exported_model, "exported_model.pt2")

# load the exported model
exported_model = torch.export.load("exported_model.pt2").module()

# run the exported model
print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)}))

```

## Root Cause
Input spec is encoded as [TreeSpec](https://github.com/pytorch/pytorch/blob/582d278983b28a91ac0cedd035183f2495bb6887/torch/utils/_pytree.py#L1059) in torch export. With (args, kwargs) at the top level. When we call the exported model, it has a pre-execution [hook](https://github.com/pytorch/pytorch/blob/582d278983b28a91ac0cedd035183f2495bb6887/torch/export/_unlift.py#L66) to check the input TreeSpec matches the received TreeSpec, where in Treespec, the dict key order is preserved. Something like

TreeSpec(dict, ['a2', 'a1'], [*,*])

To workaround this, the input check reorders [kwargs](https://github.com/pytorch/pytorch/blob/582d278983b28a91ac0cedd035183f2495bb6887/torch/export/_unlift.py#L67), that is why kwargs can be out of order. But the dict nested in the args is not re-ordered, so any re-ordering of the keys will throw errors.

## Solution
Update eq_spec to handle the dict case, where we only guarantee that key set is the same without ordering constraints.
Pull Request resolved: pytorch#162618
Approved by: https://github.com/angelayi
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants