
Avoid scatter for single-device case in DDP #46304

Closed

Wants to merge 11 commits.

Conversation

@rohan-varma (Member) commented Oct 14, 2020

Stack from ghstack:

In the case that a single process operates on only one GPU, we can avoid the input scatter and instead replace it with a recursive version of `to`, which transfers the input tensors to the correct device.

The implementation of `_recursive_to` is modeled after `scatter` in https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/scatter_gather.py, in order to keep parity with the previous conventions (i.e. custom types do not have their tensors moved).

Differential Revision: [D24296377](https://our.internmc.facebook.com/intern/diff/D24296377/)
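For illustration, a minimal sketch of what such a recursive device-transfer helper can look like. This is not the exact `_recursive_to` added by the PR (see `torch/nn/parallel/scatter_gather.py` for that); the name and details here are assumptions, but it follows the same convention of leaving custom types untouched.

```python
import torch

def recursive_to(obj, device):
    # Hypothetical helper: move tensors, and tensors nested inside lists,
    # tuples, namedtuples, and dicts, onto `device`.
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):  # namedtuple
        return type(obj)(*(recursive_to(o, device) for o in obj))
    if isinstance(obj, (list, tuple)):
        return type(obj)(recursive_to(o, device) for o in obj)
    if isinstance(obj, dict):
        return {k: recursive_to(v, device) for k, v in obj.items()}
    # Custom types are returned as-is, so their internal tensors are not
    # moved, matching the convention of `scatter` in scatter_gather.py.
    return obj
```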

@facebook-github-bot (Contributor) commented Oct 14, 2020

💊 CI failures summary and remediations

As of commit 5b9773d (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@facebook-github-bot added the `oncall: distributed` label (add this issue/PR to distributed oncall triage queue) Oct 14, 2020
@dr-ci bot commented Oct 14, 2020

💊 CI failures summary and remediations

As of commit 0967685 (more details on the Dr. CI page):


None of the CI failures appear to be your fault 💚



🚧 1 ongoing upstream failure, probably caused by an upstream breakage that is not fixed yet.



    return _self.lin(x.t)
else:
    self.assertTrue(len(x), expected_len)
    self.assertTrue(x[0].device == x[1].device)
Contributor:
Shouldn't we pass in the expected device to the constructor of ToyModel and validate it is correct here?

Member (Author):
We could do this, but the test is basically expected to validate that the device is the current rank of the process, as we are testing single GPU per process. The next line asserts that the input is on the expected device.
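As a rough illustration of the point above (a hypothetical snippet, not the exact test code): with a single GPU per process, the expected device follows directly from the process rank.

```python
import torch
import torch.distributed as dist

def check_inputs_on_rank_device(x):
    # Hypothetical check: with one GPU per process, the inputs are expected
    # to land on the CUDA device corresponding to this process's rank.
    expected_device = torch.device("cuda", dist.get_rank())
    assert all(t.device == expected_device for t in x)
```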

Comment on lines 3827 to 3828
inp = [torch.randn(10, 10) for _ in range(expected_len)]
model(inp, list)
Contributor:
Can we add tests for dict and namedtuple as well?

Member (Author):
Added these in the latest version of the diff.
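For context, inputs for those additional container types could look roughly like this (illustrative only; the exact test code is in the diff):

```python
import collections
import torch

# Hypothetical dict and namedtuple inputs, analogous to the list input above.
TestNamedTuple = collections.namedtuple("TestNamedTuple", ["a", "b"])

dict_inp = {"a": torch.randn(10, 10), "b": torch.randn(10, 10)}
tuple_inp = TestNamedTuple(a=torch.randn(10, 10), b=torch.randn(10, 10))

# The model would then be invoked with the expected container type, e.g.:
# model(dict_inp, dict)
# model(tuple_inp, TestNamedTuple)
```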

@codecov bot commented Oct 14, 2020

Codecov Report

Merging #46304 into gh/rohan-varma/185/base will decrease coverage by 0.06%.
The diff coverage is 13.84%.


@@                     Coverage Diff                     @@
##           gh/rohan-varma/185/base   #46304      +/-   ##
===========================================================
- Coverage                    68.33%   68.27%   -0.07%     
===========================================================
  Files                          410      410              
  Lines                        53795    53856      +61     
===========================================================
+ Hits                         36760    36768       +8     
- Misses                       17035    17088      +53     
| Impacted Files | Coverage | Δ |
|---|---|---|
| torch/nn/parallel/distributed.py | 39.52% <10.00%> | -2.97% ⬇️ |
| .../testing/_internal/distributed/distributed_test.py | 29.50% <15.15%> | -0.23% ⬇️ |
| torch/nn/parallel/scatter_gather.py | 12.76% <50.00%> | ø |
| torch/testing/_internal/expecttest.py | 78.57% <0.00%> | +1.02% ⬆️ |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

if len(self.device_ids) == 1:
    inputs, kwargs = self.to_kwargs(inputs, kwargs, self.device_ids[0])
    output = self.module(*inputs[0], **kwargs[0])
Contributor:
Curious, any reason we still need to return inputs and kwargs as a list?

Member (Author):
I was mostly doing that to keep parity with the current version, but we can probably remove it in this code path.
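A hedged sketch of the parity being described (the helper name and exact signature are assumptions drawn from the snippet above, not necessarily the PR's code): wrapping the moved arguments in one-element lists keeps the `inputs[0]` / `kwargs[0]` call-site shape used by the multi-device scatter path.

```python
def to_kwargs(inputs, kwargs, device):
    # Hypothetical wrapper: move inputs and kwargs to `device` with a
    # recursive `to` (see the `recursive_to` sketch earlier), then wrap each
    # in a one-element list so the caller keeps the same indexing as the
    # scatter code path.
    moved_inputs = recursive_to(inputs, device)
    moved_kwargs = recursive_to(kwargs, device)
    return [moved_inputs], [moved_kwargs]
```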

@facebook-github-bot (Contributor)

This pull request has been merged in 7245d2c.

@ngimel (Collaborator) commented Dec 25, 2020

@rohan-varma what's the motivation behind this PR? Does scatter for a single device incur some performance penalty? As #49819 notes, it was previously possible to overlap H2D transfers with computation; now that the transfers happen on the default stream (`.to()` uses the default stream, as opposed to a side stream in `Scatter`), this overlap is not possible.
Of course, the current implementation can be fixed to use a side stream, but I'm wondering if it's worth the code complexity.
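For readers unfamiliar with the distinction, here is a rough sketch of the side-stream pattern being referenced (illustrative only, not the DDP or `Scatter` implementation): issuing the host-to-device copy on a separate CUDA stream lets it overlap with work already enqueued on the default stream.

```python
import torch

def copy_to_device_on_side_stream(tensor, device):
    # Illustrative pattern: perform the H2D copy on a side stream so it can
    # overlap with computation on the default stream. Effective overlap also
    # requires the source tensor to be in pinned (page-locked) host memory.
    copy_stream = torch.cuda.Stream(device=device)
    with torch.cuda.stream(copy_stream):
        moved = tensor.to(device, non_blocking=True)
    # Make the default stream wait for the copy before consuming the result,
    # and record the consumer stream so the memory is not reused too early.
    torch.cuda.current_stream(device).wait_stream(copy_stream)
    moved.record_stream(torch.cuda.current_stream(device))
    return moved
```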

Labels: Merged, oncall: distributed (add this issue/PR to distributed oncall triage queue)

5 participants