Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dynamo] No graph break on calling dict & collections.OrderedDict() #95250

Closed
wants to merge 5 commits into from

Conversation

yanboliang
Copy link
Contributor

@yanboliang yanboliang commented Feb 21, 2023

It's common to call dict() or collections.OrderedDict() inside of forward function, so we should not graph break.

This pattern has been used in many places including:

  • The use case in torchvision.
  • It causes ~100 model failures(nopython=True) in the 14k github models.
  • Also it hits several Meta internal use cases.

cc @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire

@pytorch-bot
Copy link

pytorch-bot bot commented Feb 21, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/95250

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit bf1d037:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@yanboliang yanboliang added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 23, 2023
@yanboliang
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 23, 2023
…#95250)

It's common to call ```dict()``` or ```collections.OrderedDict()``` inside of ```forward``` function, so we should not graph break.

This pattern has been used in many places including:
* The use case in [torchvision](
https://github.com/pytorch/vision/blob/928b05cad36eadb13e169f03028767c8bcd1f21d/torchvision/models/_utils.py#L66-L73).
* It causes ~100 model failures(nopython=True) in the 14k github models.
* Also it hits several Meta internal use cases.

Pull Request resolved: pytorch/pytorch#95250
Approved by: https://github.com/jansel
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 23, 2023
…#95250)

It's common to call ```dict()``` or ```collections.OrderedDict()``` inside of ```forward``` function, so we should not graph break.

This pattern has been used in many places including:
* The use case in [torchvision](
https://github.com/pytorch/vision/blob/928b05cad36eadb13e169f03028767c8bcd1f21d/torchvision/models/_utils.py#L66-L73).
* It causes ~100 model failures(nopython=True) in the 14k github models.
* Also it hits several Meta internal use cases.

Pull Request resolved: pytorch/pytorch#95250
Approved by: https://github.com/jansel
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 23, 2023
…#95250)

It's common to call ```dict()``` or ```collections.OrderedDict()``` inside of ```forward``` function, so we should not graph break.

This pattern has been used in many places including:
* The use case in [torchvision](
https://github.com/pytorch/vision/blob/928b05cad36eadb13e169f03028767c8bcd1f21d/torchvision/models/_utils.py#L66-L73).
* It causes ~100 model failures(nopython=True) in the 14k github models.
* Also it hits several Meta internal use cases.

Pull Request resolved: pytorch/pytorch#95250
Approved by: https://github.com/jansel
@yanboliang yanboliang deleted the dict branch February 23, 2023 16:57
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 25, 2023
…#95250)

It's common to call ```dict()``` or ```collections.OrderedDict()``` inside of ```forward``` function, so we should not graph break.

This pattern has been used in many places including:
* The use case in [torchvision](
https://github.com/pytorch/vision/blob/928b05cad36eadb13e169f03028767c8bcd1f21d/torchvision/models/_utils.py#L66-L73).
* It causes ~100 model failures(nopython=True) in the 14k github models.
* Also it hits several Meta internal use cases.

Pull Request resolved: pytorch/pytorch#95250
Approved by: https://github.com/jansel
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 25, 2023
…#95250)

It's common to call ```dict()``` or ```collections.OrderedDict()``` inside of ```forward``` function, so we should not graph break.

This pattern has been used in many places including:
* The use case in [torchvision](
https://github.com/pytorch/vision/blob/928b05cad36eadb13e169f03028767c8bcd1f21d/torchvision/models/_utils.py#L66-L73).
* It causes ~100 model failures(nopython=True) in the 14k github models.
* Also it hits several Meta internal use cases.

Pull Request resolved: pytorch/pytorch#95250
Approved by: https://github.com/jansel
pytorchmergebot pushed a commit that referenced this pull request Mar 1, 2023
Fixes OrderedDict reconstruction issue found in #95250 with an attempt to fix it here #95725

Pull Request resolved: #95800
Approved by: https://github.com/yanboliang, https://github.com/clee2000
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 2, 2023
Fixes OrderedDict reconstruction issue found in pytorch/pytorch#95250 with an attempt to fix it here pytorch/pytorch#95725

Pull Request resolved: pytorch/pytorch#95800
Approved by: https://github.com/yanboliang, https://github.com/clee2000
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 5, 2023
Fixes OrderedDict reconstruction issue found in pytorch/pytorch#95250 with an attempt to fix it here pytorch/pytorch#95725

Pull Request resolved: pytorch/pytorch#95800
Approved by: https://github.com/yanboliang, https://github.com/clee2000
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 5, 2023
…#95250)

It's common to call ```dict()``` or ```collections.OrderedDict()``` inside of ```forward``` function, so we should not graph break.

This pattern has been used in many places including:
* The use case in [torchvision](
https://github.com/pytorch/vision/blob/928b05cad36eadb13e169f03028767c8bcd1f21d/torchvision/models/_utils.py#L66-L73).
* It causes ~100 model failures(nopython=True) in the 14k github models.
* Also it hits several Meta internal use cases.

Pull Request resolved: pytorch/pytorch#95250
Approved by: https://github.com/jansel
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 5, 2023
Fixes OrderedDict reconstruction issue found in pytorch/pytorch#95250 with an attempt to fix it here pytorch/pytorch#95725

Pull Request resolved: pytorch/pytorch#95800
Approved by: https://github.com/yanboliang, https://github.com/clee2000
@RuibaiXu
Copy link

RuibaiXu commented May 2, 2023

I found that in the test_tristandeleu_pytorch_meta.py of the 14k github models projects, there will still be cases calling OrderedDict() with iterators args leading to graph break.
For example, in line 358 in test_tristandeleu_pytorch_meta.py
params = OrderedDict(self.named_parameters())
the arg self.named_parameters() will be ListIteratorVariable().

I want to know if the case is likely to be supported. And if so, what are the general ideas?

@RuibaiXu
Copy link

RuibaiXu commented May 2, 2023

I found that in the test_tristandeleu_pytorch_meta.py of the 14k github models projects, there will still be cases calling OrderedDict() with iterators args leading to graph break. For example, in line 358 in test_tristandeleu_pytorch_meta.py params = OrderedDict(self.named_parameters()) the arg self.named_parameters() will be ListIteratorVariable().

I want to know if the case is likely to be supported. And if so, what are the general ideas?

I found it has been fixed by #96122, thank you!

pruthvistony added a commit to ROCm/pytorch that referenced this pull request May 2, 2023
jhavukainen pushed a commit to kulinseth/pytorch that referenced this pull request Mar 15, 2024
…ytorch#95250)

It's common to call ```dict()``` or ```collections.OrderedDict()``` inside of ```forward``` function, so we should not graph break.

This pattern has been used in many places including:
* The use case in [torchvision](
https://github.com/pytorch/vision/blob/928b05cad36eadb13e169f03028767c8bcd1f21d/torchvision/models/_utils.py#L66-L73).
* It causes ~100 model failures(nopython=True) in the 14k github models.
* Also it hits several Meta internal use cases.

Pull Request resolved: pytorch#95250
Approved by: https://github.com/jansel
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants