convert output_device at data_parallel from torch.device to index #10189
Conversation
weiyangfb commented on Aug 2, 2018
- Fixes torch.device and torch.nn.parallel.data_parallel compatibility (#9984)
@pytorchbot retest this please
I believe there are some instances of the same case in
@vishwakftw I see, I will change them as well
test/test_nn.py (outdated)
    # test output_device
    l = nn.Linear(10, 5).float().cuda()
    i = Variable(torch.randn(20, 10).float().cuda())
    out = dp.data_parallel(l, i, (0, 1), torch.device('cuda'))
Force-pushed from 6606483 to d7c16b6
This looks good to me. Actually, can we add a test for the other two code paths as well?
The better fix is to make scatter, gather, parallel_apply, etc. accept device objects (rather than converting to an index inside DataParallel). You can also make device_ids support device objects this way.
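The conversion under discussion is what torch.cuda._utils._get_device_index does: accept an int index, a device string, or a torch.device and hand back a plain integer index. A minimal, torch-free sketch of that idea, with a stand-in SimpleDevice class because torch is not imported here (get_device_index and SimpleDevice are illustrative names, not the real implementation):

```python
class SimpleDevice:
    """Stand-in for torch.device, used only in this sketch."""
    def __init__(self, type, index=None):
        self.type, self.index = type, index


def get_device_index(device, optional=False):
    """Normalize a CUDA device specifier to an integer index."""
    if isinstance(device, str):
        # 'cuda' or 'cuda:1' -> parse the part after the colon, if any
        if device == "cuda":
            device = SimpleDevice("cuda", None)
        else:
            typ, _, idx = device.partition(":")
            device = SimpleDevice(typ, int(idx))
    if isinstance(device, int):
        return device
    if getattr(device, "type", None) == "cuda":
        if device.index is None:
            if optional:
                # the real helper falls back to the current device;
                # this sketch assumes the current device is 0
                return 0
            raise ValueError("expected a cuda device with a specified index")
        return device.index
    raise ValueError("expected a cuda device, got {!r}".format(device))
```

With a helper like this, every API in the chain (scatter, gather, parallel_apply) can take either indices or device objects and normalize them at the boundary, e.g. `get_device_index("cuda:1")` yields `1`.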
ping @weiyangfb on @ssnl's suggestion.
you can probably use/adapt
Force-pushed from d7c16b6 to c70b9d2
weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Is this good to go? @ssnl
Force-pushed from c60a6a6 to 9e500da
@@ -36,7 +37,7 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
         assert len(modules) == len(devices)
     else:
         devices = [None] * len(modules)
+    devices = list(map(lambda x: _get_device_index(x, True), devices))
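The added line maps every entry of the devices list through _get_device_index so the rest of parallel_apply only ever sees plain integer indices (or None). A torch-free sketch of that normalization step, with `to_index` as a hypothetical stand-in for `_get_device_index(device, optional=True)`:

```python
def to_index(device):
    """Normalize one device specifier to an int index (or None)."""
    if device is None or isinstance(device, int):
        return device
    if isinstance(device, str):
        # 'cuda' carries no explicit index; this sketch treats it as device 0
        _, _, idx = device.partition(":")
        return int(idx) if idx else 0
    # duck-typed torch.device-like object with a .index attribute
    return device.index if device.index is not None else 0


def normalize_devices(devices):
    # Mirrors: devices = list(map(lambda x: _get_device_index(x, True), devices))
    return [to_index(d) for d in devices]
```

This lets callers mix specifiers freely, e.g. `normalize_devices([0, "cuda:1", None])` returns `[0, 1, None]`, which is why the downstream code no longer needs to care whether the caller passed indices or device objects.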
test/test_nn.py (outdated)
@@ -2922,6 +2922,18 @@ def test_data_parallel_small_back(self):
         out = dp.data_parallel(l, i, (0, 1))
         self.assertEqual(out, l(i))

+        # test output_device
+        l = nn.Linear(10, 5).float().cuda()
torch/nn/parallel/data_parallel.py (outdated)
-        device_ids: CUDA devices (default: all devices)
-        output_device: device location of output (default: device_ids[0])
+        module (Module): module to be parallelized
+        device_ids (list of int or Device): CUDA devices (default: all devices)
torch/nn/parallel/replicate.py (outdated)
@@ -1,10 +1,12 @@
 import torch.cuda.comm as comm
+from torch.cuda._utils import _get_device_index


 def replicate(network, devices, detach=False):
     from ._functions import Broadcast

     devices = tuple(devices)
Force-pushed from 25b90d1 to 4ed0d8f
@ssnl would you like to take a quick pass at this? The updates are a separate test function and doc fixes. Thanks!
LGTM, but there is one remaining nit to be addressed.
Force-pushed from 4ed0d8f to 3a4e211
@ssnl I see, fixed more places with
Force-pushed from 3a4e211 to f360f36
…d APIs
1. convert torch.device to device.index in APIs
2. docs fixes
Force-pushed from f360f36 to d5721ff