convert output_device at data_parallel from torch.device to index #10189

Closed · wants to merge 1 commit

Conversation

weiyangfb (Contributor)

@weiyangfb (Contributor, Author)

@pytorchbot retest this please

@vishwakftw (Contributor)

I believe there are some instances of the same case in nn/parallel/distributed.py and nn/parallel/distributed_c10d.py. Could those be changed too?

@weiyangfb (Contributor, Author)

@vishwakftw I see, I will change them as well

test/test_nn.py (outdated):
# test output_device
l = nn.Linear(10, 5).float().cuda()
i = Variable(torch.randn(20, 10).float().cuda())
out = dp.data_parallel(l, i, (0, 1), torch.device('cuda'))

weiyangfb added the "ready for review" label on Aug 14, 2018
@li-roy (Contributor) left a comment:

This looks good to me.

@li-roy (Contributor) commented on Aug 16, 2018:

actually, can we add a test for the other two code paths as well?

@ssnl (Collaborator) left a comment:

The better fix is to make scatter, gather, parallel_apply, etc. accept device objects (instead of converting to an index in DataParallel). You can also make device_ids support device objects this way.
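
As a rough sketch of that suggestion (the helper name _to_cuda_index and its exact behavior are hypothetical, not code from this PR), accepting device objects could amount to normalizing them up front:

import torch

def _to_cuda_index(device):
    # Hypothetical helper: accept an int, a 'cuda:N' string, or a
    # torch.device and return a plain integer CUDA device index.
    if isinstance(device, str):
        device = torch.device(device)
    if isinstance(device, torch.device):
        if device.type != 'cuda':
            raise ValueError('expected a CUDA device, got {!r}'.format(device))
        # torch.device('cuda') carries no index; fall back to the current device
        return device.index if device.index is not None else torch.cuda.current_device()
    return int(device)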

@fmassa (Member) commented on Aug 28, 2018:

Ping @weiyangfb on @ssnl's suggestion.

@ssnl (Collaborator) commented on Aug 28, 2018:

You can probably use/adapt torch.cuda._get_device_index to do that now, after #10833.
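
For reference, a rough illustration of how that helper behaves (torch.cuda._get_device_index is a private utility, so the exact signature may differ across versions; the second positional argument below is assumed to mean optional=True, letting an index-less device resolve to the current device):

import torch
from torch.cuda._utils import _get_device_index

_get_device_index(1)                           # plain integer indices pass through -> 1
_get_device_index(torch.device('cuda:1'))      # device objects are unwrapped -> 1
_get_device_index(torch.device('cuda'), True)  # no index given: resolves to the current device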

@facebook-github-bot (Contributor) left a comment:

weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@weiyangfb (Contributor, Author)

is this good to go? @ssnl

@weiyangfb (Contributor, Author)

is this good to go? @ssnl @teng-li @ailzhang

@@ -36,7 +37,7 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
assert len(modules) == len(devices)
else:
devices = [None] * len(modules)

devices = list(map(lambda x: _get_device_index(x, True), devices))

test/test_nn.py (outdated):
@@ -2922,6 +2922,18 @@ def test_data_parallel_small_back(self):
out = dp.data_parallel(l, i, (0, 1))
self.assertEqual(out, l(i))

# test output_device
l = nn.Linear(10, 5).float().cuda()

device_ids: CUDA devices (default: all devices)
output_device: device location of output (default: device_ids[0])
module (Module): module to be parallelized
device_ids (list of int or Device): CUDA devices (default: all devices)
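
After this change, data_parallel should accept torch.device objects as well as plain indices for both device_ids and output_device; a minimal sketch of a call, assuming at least two visible GPUs:

import torch
import torch.nn as nn
from torch.nn.parallel import data_parallel

module = nn.Linear(10, 5).cuda()
inputs = torch.randn(20, 10).cuda()

# indices and torch.device objects can be mixed; both are normalized internally
out = data_parallel(module, inputs,
                    device_ids=[torch.device('cuda:0'), 1],
                    output_device=torch.device('cuda:0'))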

@@ -1,10 +1,12 @@
import torch.cuda.comm as comm
from torch.cuda._utils import _get_device_index


def replicate(network, devices, detach=False):
from ._functions import Broadcast

devices = tuple(devices)

@weiyangfb (Contributor, Author)

@ssnl would you like to take a quick pass on this? The updates are a separate test function and doc fixes. Thanks!

@ssnl (Collaborator) left a comment:

lgtm, but there is one remaining nit to be addressed

@weiyangfb (Contributor, Author)

@ssnl I see, fixed more places with Device -> torch.device

Commit: …d APIs

1. convert torch.device to device.index in APIs
2. docs fixes
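
For context, the conversion referred to here boils down to reading the .index attribute of a torch.device, which is None when no index is given:

import torch

assert torch.device('cuda:1').index == 1    # explicit index carried on the device object
assert torch.device('cuda').index is None   # index-less device; callers fall back to the current device
assert torch.device('cuda:1').type == 'cuda'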

Labels
ready for review
Development

Successfully merging this pull request may close these issues.

torch.device and torch.nn.parallel.data_parallel compatibility
9 participants