Allow DataParallel to wrap CPU modules #17065

Closed
mrshenli opened this issue Feb 13, 2019 · 9 comments
Assignees
Labels
oncall: distributed Add this issue/PR to distributed oncall triage queue

Comments

@mrshenli
Contributor

🚀 Feature

Creating a model on CPU and then wrapping the model with DataParallel should automatically replicate the model on destination GPUs. Is there any reason to enforce that DataParallel's input model must be on GPU?

Motivation

import torch
import torch.nn as nn

model = nn.Linear(2, 2)                          # model parameters live on CPU
net = nn.DataParallel(model, device_ids=[0, 1])
input_var = torch.randn(10, 2)
net(input_var)                                   # error is raised here, at forward time

The code above throws TypeError: Broadcast function not implemented for CPU tensors. To avoid the error, users need to explicitly call model.cuda() before wrapping (a sketch of this workaround follows the list below).

  1. It is confusing whether it is the input tensor or the model's parameters that should be placed on GPU.
  2. When calling nn.DataParallel(model, device_ids=[0,1]), we already have enough information about where the model should be replicated. This can be handled automatically regardless of whether the model is stored on CPU or GPU.
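For reference, a minimal sketch of the current workaround (moving the parameters to the first device in device_ids before wrapping; the device indices are just for illustration):

import torch
import torch.nn as nn

model = nn.Linear(2, 2).cuda(0)                  # parameters must live on a GPU before wrapping
net = nn.DataParallel(model, device_ids=[0, 1])
input_var = torch.randn(10, 2)                   # CPU input is scattered to the GPUs by DataParallel
output = net(input_var)                          # no Broadcast error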

Pitch

Support the above code snippet.

@mrshenli mrshenli self-assigned this Feb 13, 2019
@mrshenli mrshenli added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Feb 13, 2019
@mrshenli
Contributor Author

CC @douwekiela @pietern @soumith

@ssnl
Collaborator

ssnl commented Feb 13, 2019

Creating a model on CPU and then wrapping the model with DataParallel should automatically replicate the model on destination GPUs. Is there any reason to enforce that DataParallel's input model must be on GPU?

This is not true. The model is broadcast at the beginning of each forward, not when constructing the DataParallel wrapper. The disadvantage of having the model on CPU, of course, is that the gradients are reduced to CPU at each iteration, which is slow and undesirable. IMO, automatically moving the model to one GPU upon construction is also undesirable because:

  1. Users may save a pointer to the wrapped module and reasonably expect it to still be on the original device.
  2. It would initialize a CUDA context, which users would not expect from merely constructing the wrapper.

@mrshenli
Contributor Author

@ssnl

How about explicitly throwing an error when constructing DataParallel if the wrapped model is on CPU?

@ssnl
Collaborator

ssnl commented Feb 13, 2019

How about explicitly throwing an error when constructing DataParallel if the wrapped model is on CPU?

This SGTM :)
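For concreteness, a rough sketch of what such a construction-time check could look like (the helper name and error message are hypothetical, not the actual implementation):

# hypothetical helper, called from DataParallel.__init__ (illustrative only)
def _check_module_is_cuda(module):
    for p in module.parameters():
        if not p.is_cuda:
            raise RuntimeError(
                "DataParallel expects the wrapped module to be on a CUDA device; "
                "call model.cuda() before wrapping it")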

@douwekiela

Okay, I think we should update the documentation for this then? Also, what is the best way to then move input_var to GPU? With the .to() semantics, we would have to specify one of the device_ids manually?

@mrshenli
Contributor Author

mrshenli commented Feb 13, 2019

@douwekiela yes, I will update the docs in the fix for this issue.

Also, what is the best way to then move input_var to GPU? With the .to() semantics, we would have to specify one of the device_ids manually?

You don't have to move input_var to GPU. If you prefer to store it on GPU, I think it does not need to be on one of the devices in the device_ids list. To move input_var, any of the following will work (see also the note after the list):

  1. input_var.cuda()
  2. input_var.cuda(1)
  3. input_var.to('cuda:1')
  4. input_var.to(torch.device('cuda:1'))
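Note that none of these move input_var in place; .cuda() and .to() return a new tensor, so the result has to be rebound, e.g.:

device = torch.device('cuda:1')
input_var = input_var.to(device)   # reassign; the original CPU tensor is unchanged
output = net(input_var)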

@douwekiela

Right. So I guess .cuda() automatically assigns the tensor to the correct GPU for DataParallel (@ssnl can you confirm)? If so, that should be part of the code snippet imo ;)

@mrshenli
Contributor Author

.cuda() moves it to the default GPU [link].
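For example (torch.cuda.current_device() reports which GPU that is, and torch.cuda.set_device() changes it):

print(torch.cuda.current_device())   # index of the default GPU, usually 0
x = input_var.cuda()                 # same as input_var.cuda(torch.cuda.current_device())
torch.cuda.set_device(1)             # no-argument .cuda() calls now target GPU 1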

@ssnl
Collaborator

ssnl commented Feb 14, 2019

@douwekiela The recommended way is to define one device object and use it throughout your program. https://pytorch.org/blog/pytorch-0_4_0-migration-guide/ is a helpful reading.
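A minimal sketch of that pattern, reusing the model from the issue description:

import torch
import torch.nn as nn

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = nn.Linear(2, 2).to(device)               # everything follows the single device object
net = nn.DataParallel(model, device_ids=[0, 1]) if torch.cuda.device_count() > 1 else model

input_var = torch.randn(10, 2).to(device)
output = net(input_var)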
