Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly convert namedtuples in DDP #44220

Closed
wants to merge 5 commits into from

Conversation

rohan-varma
Copy link
Member

@rohan-varma rohan-varma commented Sep 4, 2020

Stack from ghstack:

Closes #44009
Currently if a dataloader returns objects created with a
collections.namedtuple or typing.NamedTuple, this will incorrectly be cast to a tuple. As a result, if we have data of these types, there can be runtime errors during the forward pass if the module is expecting a named tuple.

Fix this in
scatter_gather.py to resolve the issue reported in
#44009

Differential Revision: D23536752

Closes #44009
Currently if a dataloader returns objects created with a
collections.namedtuple, this will incorrectly be cast to a tuple. As a result, if we have data of these types, there can be runtime errors during the forward pass if the module is expecting a named tuple.

Fix this in
`scatter_gather.py` to resolve the issue reported in
#44009

Differential Revision: [D23536752](https://our.internmc.facebook.com/intern/diff/D23536752/)

[ghstack-poisoned]
Closes #44009
Currently if a dataloader returns objects created with a
`collections.namedtuple` or `typing.NamedTuple`, this will incorrectly be cast to a tuple. As a result, if we have data of these types, there can be runtime errors during the forward pass if the module is expecting a named tuple.

Fix this in
`scatter_gather.py` to resolve the issue reported in
#44009

Differential Revision: [D23536752](https://our.internmc.facebook.com/intern/diff/D23536752/)

[ghstack-poisoned]
rohan-varma added a commit that referenced this pull request Sep 4, 2020
Pull Request resolved: #44220

Closes #44009
Currently if a dataloader returns objects created with a
collections.namedtuple, this will incorrectly be cast to a tuple. As a result, if we have data of these types, there can be runtime errors during the forward pass if the module is expecting a named tuple.

Fix this in
`scatter_gather.py` to resolve the issue reported in
#44009
ghstack-source-id: 111478085

Differential Revision: [D23536752](https://our.internmc.facebook.com/intern/diff/D23536752/)
@dr-ci
Copy link

dr-ci bot commented Sep 4, 2020

💊 CI failures summary and remediations

As of commit 939378b (more details on the Dr. CI page):


Commit 939378b was recently pushed. Waiting for builds...


This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group.

See how this bot performed.

This comment has been revised 15 times.

@rohan-varma
Copy link
Member Author

Not sure who the best reviewer for this change is since scatter_gather.py hasn't really been touched very often, so let me know if there are better reviewers than distributed folks.

Comment on lines 3134 to 3135
@require_backend({"nccl", "gloo"})
@require_n_gpus_for_nccl_backend(int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also add some unit test that directly test the scatter API with a variety of test cases? For example the case for namedtuple(a=1, b=1) you mentioned in the docs is a good test.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this test in test/test_cuda.py since it doesn't look like there were tests for this API anywhere else.

Closes #44009
Currently if a dataloader returns objects created with a
`collections.namedtuple` or `typing.NamedTuple`, this will incorrectly be cast to a tuple. As a result, if we have data of these types, there can be runtime errors during the forward pass if the module is expecting a named tuple.

Fix this in
`scatter_gather.py` to resolve the issue reported in
#44009

Differential Revision: [D23536752](https://our.internmc.facebook.com/intern/diff/D23536752/)

[ghstack-poisoned]
rohan-varma added a commit that referenced this pull request Sep 15, 2020
Pull Request resolved: #44220

Closes #44009
Currently if a dataloader returns objects created with a
collections.namedtuple, this will incorrectly be cast to a tuple. As a result, if we have data of these types, there can be runtime errors during the forward pass if the module is expecting a named tuple.

Fix this in
`scatter_gather.py` to resolve the issue reported in
#44009
ghstack-source-id: 112050439

Differential Revision: [D23536752](https://our.internmc.facebook.com/intern/diff/D23536752/)
Closes #44009
Currently if a dataloader returns objects created with a
`collections.namedtuple` or `typing.NamedTuple`, this will incorrectly be cast to a tuple. As a result, if we have data of these types, there can be runtime errors during the forward pass if the module is expecting a named tuple.

Fix this in
`scatter_gather.py` to resolve the issue reported in
#44009

Differential Revision: [D23536752](https://our.internmc.facebook.com/intern/diff/D23536752/)

[ghstack-poisoned]
rohan-varma added a commit that referenced this pull request Sep 16, 2020
Pull Request resolved: #44220

Closes #44009
Currently if a dataloader returns objects created with a
collections.namedtuple, this will incorrectly be cast to a tuple. As a result, if we have data of these types, there can be runtime errors during the forward pass if the module is expecting a named tuple.

Fix this in
`scatter_gather.py` to resolve the issue reported in
#44009
ghstack-source-id: 112198252

Differential Revision: [D23536752](https://our.internmc.facebook.com/intern/diff/D23536752/)
@codecov
Copy link

codecov bot commented Sep 16, 2020

Codecov Report

Merging #44220 into gh/rohan-varma/164/base will increase coverage by 0.65%.
The diff coverage is n/a.

Impacted file tree graph

@@                     Coverage Diff                     @@
##           gh/rohan-varma/164/base   #44220      +/-   ##
===========================================================
+ Coverage                    67.85%   68.50%   +0.65%     
===========================================================
  Files                          384      409      +25     
  Lines                        49919    52587    +2668     
===========================================================
+ Hits                         33873    36026    +2153     
- Misses                       16046    16561     +515     
Impacted Files Coverage Δ
torch/_lobpcg.py 69.50% <0.00%> (-16.84%) ⬇️
torch/distributed/rpc/__init__.py 36.92% <0.00%> (-9.89%) ⬇️
torch/types.py 77.41% <0.00%> (-9.54%) ⬇️
...istributed/rpc/process_group_agent_test_fixture.py 55.00% <0.00%> (-9.29%) ⬇️
torch/utils/mobile_optimizer.py 86.84% <0.00%> (-7.28%) ⬇️
torch/quantization/quantize_fx.py 91.42% <0.00%> (-6.69%) ⬇️
...g/_internal/distributed/rpc/dist_optimizer_test.py 20.66% <0.00%> (-6.19%) ⬇️
torch/quantization/fx/pattern_utils.py 87.23% <0.00%> (-1.86%) ⬇️
torch/onnx/utils.py 70.29% <0.00%> (-1.78%) ⬇️
torch/autograd/__init__.py 84.28% <0.00%> (-1.43%) ⬇️
... and 149 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 1552a92...939378b. Read the comment docs.

@rohan-varma
Copy link
Member Author

@pritamdamania87 Added the tests for scatter, any more thoughts on this diff? Thanks!


scatter_out = scatter_gather.scatter(inp, target_gpus)
for x in scatter_out:
self.assertTrue(isinstance(x, type(inp)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also assert that the _fields attribute is preserved correctly and also the appropriate values?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, adding the tests for valued actually uncovered a bug in the implementation that I just fixed.

Closes #44009
Currently if a dataloader returns objects created with a
`collections.namedtuple` or `typing.NamedTuple`, this will incorrectly be cast to a tuple. As a result, if we have data of these types, there can be runtime errors during the forward pass if the module is expecting a named tuple.

Fix this in
`scatter_gather.py` to resolve the issue reported in
#44009

Differential Revision: [D23536752](https://our.internmc.facebook.com/intern/diff/D23536752/)

[ghstack-poisoned]
rohan-varma added a commit that referenced this pull request Oct 2, 2020
Pull Request resolved: #44220

Closes #44009
Currently if a dataloader returns objects created with a
collections.namedtuple, this will incorrectly be cast to a tuple. As a result, if we have data of these types, there can be runtime errors during the forward pass if the module is expecting a named tuple.

Fix this in
`scatter_gather.py` to resolve the issue reported in
#44009
ghstack-source-id: 113423287

Differential Revision: [D23536752](https://our.internmc.facebook.com/intern/diff/D23536752/)
@facebook-github-bot
Copy link
Contributor

This pull request has been merged in f8c1ca5.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants