Pytorch in multi-cpu cluster #2733

Closed
msabvid opened this issue Sep 14, 2017 · 5 comments
Labels
oncall: distributed (Add this issue/PR to distributed oncall triage queue)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

msabvid commented Sep 14, 2017

Hi,
I would like to use the distributed module to train a convolutional net on a CPU cluster. Looking through your code, I see that torch.cuda.device_count() is called in several places and is used to populate the device_ids list. Since I don't have any GPU devices in my cluster, device_count() always returns 0, and any subsequent attempt to access device_ids[0] results in an index exception.
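To illustrate the pattern (this is only a reconstruction, not the actual library code):

import torch

# On a CPU-only machine torch.cuda.device_count() returns 0, so the list is empty
device_ids = list(range(torch.cuda.device_count()))
first_device = device_ids[0]  # IndexError: list index out of range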
Taking the naive path and changing device_count so that it always returns the number of nodes I intend to use, I then get a different error:

if not all(input.is_cuda for input in inputs):
    raise TypeError('Broadcast function not implemented for CPU tensors')

So I would like to ask whether you have any plans to extend the distributed module to support training networks on a multi-CPU cluster.

Many thanks

Contributor

apaszke commented Sep 19, 2017

Just to confirm: when you say multi-CPU cluster, you don't mean having 2 CPUs within a single machine (NUMA), but an actual cluster with multiple machines, right?

The problem is that DistributedDataParallel doesn't support non-CUDA networks right now. We should implement it at some point.
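In the meantime, a rough workaround (only a sketch, not something the library provides for this case) is to skip DistributedDataParallel and average gradients manually with the torch.distributed collectives, assuming the process group was initialized with a CPU-capable backend:

import torch.distributed as dist

# Call after loss.backward() and before optimizer.step() in every process.
def average_gradients(model):
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data)  # default reduce op is SUM
            param.grad.data /= world_size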

Author

msabvid commented Sep 20, 2017

Yes, I meant a cluster with multiple machines.

No problem, I'll keep an eye out for updates.

Thanks


rabeehk commented Jan 14, 2019

Hi,
I am trying to run the DistributedDataParallelCPU module with a simple model.
Here is the file (test.py) I wrote:

import torch.distributed as dist
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', default=0, type=int)
args = parser.parse_args()

# Initialize the process group with the MPI backend (requires a PyTorch build with MPI support).
dist.init_process_group(backend="mpi", init_method="env://",
                        world_size=int(os.environ["WORLD_SIZE"]), rank=args.local_rank)

## Model.
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

model = Model()
net = torch.nn.parallel.DistributedDataParallelCPU(model)

and I launch it as follows:

python -m torch.distributed.launch  --nnodes=1 --node_rank=0 --master_addr=127.0.0.1 --master_port=6006  test.py 

I get the following error:
RuntimeError: the MPI backend is not available; try to recompile the THD package with MPI support at /opt/conda/conda-bld/pytorch_1532581333611/work/torch/lib/THD/process_group/General.cpp:17

Thanks for your assistance
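(Side note on the error above: the gloo backend ships with the standard PyTorch binaries and supports CPU tensors, so initializing the process group with it, roughly as in the sketch below, avoids recompiling with MPI support. This assumes the launcher sets WORLD_SIZE and RANK in the environment, which torch.distributed.launch does.)

import os
import torch.distributed as dist

# Sketch: same setup as above, but with the built-in gloo backend instead of MPI.
dist.init_process_group(backend="gloo", init_method="env://",
                        world_size=int(os.environ["WORLD_SIZE"]),
                        rank=int(os.environ["RANK"]))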

zou3519 added the oncall: distributed label on Jan 22, 2019

sth1997 commented May 23, 2019

Excuse me, does DistributedDataParallel support non-CUDA networks right now?

mrshenli added the triaged label on May 24, 2019
@mrshenli
Contributor

@sth1997 @marcsv87 DDP on CPU devices should be available now (see doc). I am closing this issue, but feel free to reopen it if it fails to work for you.
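A minimal usage sketch, assuming a recent PyTorch build and the gloo backend (the linked doc is authoritative for the exact API):

import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# Omitting device_ids keeps the model on CPU; gloo handles inter-process communication.
dist.init_process_group(backend="gloo", init_method="env://")
model = nn.Linear(10, 10)                    # any CPU-only module
ddp_model = DistributedDataParallel(model)   # no device_ids -> CPU training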

@rabeeh that seems to be a different issue from the one originally reported. If you still need assistance with it, could you please create a new issue?
