Conversation

ailzhang
Contributor

Currently PyTorch supports using one IB card. By default the first one in the device list is used, or you can manually specify the device name by setting .name="mlx5_1".

@ailzhang ailzhang force-pushed the master branch 4 times, most recently from b031f08 to bb09176 on September 29, 2017 20:08
@soumith
Member

soumith commented Sep 29, 2017

@pytorchbot add to whitelist

@ezyang
Contributor

ezyang commented Oct 20, 2017

@ailzhang What's the status on this patch? :)

@xqding

xqding commented Oct 25, 2017

I tried to compile from source with this patch and got the following error:
pytorch/torch/lib/THD/base/data_channels/DataChannelGloo.cpp:7:43: fatal error: gloo/transport/ibverbs/device.h: No such file or directory
 #include "gloo/transport/ibverbs/device.h"

However, I do have the file gloo/transport/ibverbs/device.h.
Any thoughts on it?

@ailzhang
Contributor Author

This error is caused by a missing header file in your temp build folder, i.e. "pytorch/torch/lib/tmp_install/include/gloo/transport/ibverbs/device.h" is missing. Could you please check that?

@xqding

xqding commented Oct 25, 2017

Just figured it out. I had to turn on WITH_IBVERBS=1 in the file torch/lib/build_libs.sh.
Now it compiles fine. I will see if it works soon. Thanks.

@xqding

xqding commented Oct 25, 2017

I tried a script similar to this one: https://github.com/pytorch/examples/blob/master/imagenet/main.py.
Here is the error message:

Traceback (most recent call last):
  File "./script/train_dist_gloo.py", line 99, in <module>
    outputs = net(inputs)
  File "/home/yuyou/apps/anaconda3/lib/python3.5/site-packages/torch/nn/modules/module.py", line 263, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/yuyou/apps/anaconda3/lib/python3.5/site-packages/torch/nn/parallel/distributed.py", line 156, in forward
    self._sync_params()
  File "/home/yuyou/apps/anaconda3/lib/python3.5/site-packages/torch/nn/parallel/distributed.py", line 187, in _sync_params
    dist.broadcast(flat_buffers, 0)
  File "/home/yuyou/apps/anaconda3/lib/python3.5/site-packages/torch/distributed/__init__.py", line 198, in broadcast
    return torch._C._dist_broadcast(tensor, src, group)
RuntimeError: [/home/yuyou/downloads/mypytorch2/pytorch/torch/lib/gloo/gloo/transport/ibverbs/buffer.cc:108] Read timeout LID: 34 QPN: 7435 PSN: 13656770
terminate called after throwing an instance of 'gloo::EnforceNotMet'
  what():  [enforce fail at /home/yuyou/downloads/mypytorch2/pytorch/torch/lib/gloo/gloo/cuda.cu:249] error == cudaSuccess. 29 vs 0. Error at: /home/yuyou/downloads/mypytorch2/pytorch/torch/lib/gloo/gloo/cuda.cu:249: driver shutting down

@xqding

xqding commented Oct 26, 2017

If I use torch.distributed to average the gradients myself instead of using torch.nn.parallel.DistributedDataParallel, the code works fine and scales well with the number of GPUs.
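
In other words, after loss.backward() each parameter's gradient is all-reduced across the workers and divided by the world size before optimizer.step(). A minimal sketch of that step, lifted from the training loop shared later in this thread (the helper name average_gradients is just for illustration):

import torch.distributed as dist

def average_gradients(model):
    # sum each parameter's gradient across all workers, then normalize
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM)
        param.grad.data /= world_size

# call this between loss.backward() and optimizer.step()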

@ailzhang
Contributor Author

Hi @xqding, could you share your script? I'm trying to debug this issue. Thanks!

@xqding

xqding commented Oct 28, 2017

@ailzhang My script is essentially the same as https://github.com/pytorch/examples/blob/master/imagenet/main.py in distributed mode, except that I use a different dataset instead of the ImageNet dataset.
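
For reference, the distributed setup in that example boils down to roughly the following sketch (the address and world size here are placeholders for the example's --dist-url and --world-size flags):

import torch.distributed as dist

# initialize the process group with the Gloo backend, as the example does
dist.init_process_group(backend='gloo',
                        init_method='tcp://10.0.0.1:23456',  # placeholder address
                        world_size=2)

# the model is then wrapped in DistributedDataParallel, e.g.
# net = torch.nn.parallel.DistributedDataParallel(net)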

@xqding

xqding commented Oct 28, 2017

Do you need the whole script to debug this? If so, I can try to reproduce the issue using the MNIST dataset.

@ailzhang
Contributor Author

Hi @xqding, it would be good if you could share a code snippet showing how you averaged the gradients using torch.distributed instead of DistributedDataParallel, thanks!

@xqding

xqding commented Oct 28, 2017

import numpy as np
import torch
import torch.distributed as dist
from torch.autograd import Variable
from torch.utils.data import DataLoader

# train_data, category_ids, net, criterion, optimizer and num_epoches
# are defined earlier in the full script.
train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
train_loader = DataLoader(train_data,
                          batch_size=32,
                          sampler=train_sampler,
                          num_workers=2)

for epoch in range(num_epoches):  # loop over the dataset multiple times
    running_loss = 0.0
    print("Epoch: {}".format(epoch))
    train_sampler.set_epoch(epoch)
    for i, data in enumerate(train_loader, 0):
        # get the inputs
        inputs = data['image']
        labels = data['category_id']
        labels = np.array([category_ids.index(l) for l in labels])
        print("i: {}".format(i))

        print("labels", labels)
        # wrap them in Variable
        inputs = inputs.cuda(async=True)
        labels = torch.from_numpy(labels).cuda(async=True)
        inputs, labels = Variable(inputs), Variable(labels)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        # average gradients across workers before the optimizer step
        size = float(dist.get_world_size())
        for param in net.parameters():
            dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM)
            param.grad.data /= size
        optimizer.step()

When epoch = 0, the code works fine. When it starts epoch = 1, it crashes when it reaches the dist.all_reduce call with the following error message:

Traceback (most recent call last):
  File "./script/train_dist_new.py", line 126, in <module>
    dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM)
  File "/home/yuyou/apps/anaconda3/lib/python3.5/site-packages/torch/distributed/__init__.py", line 216, in all_reduce
    return torch._C._dist_all_reduce(tensor, op, group)
RuntimeError: [/home/yuyou/downloads/mypytorch2/pytorch/torch/lib/gloo/gloo/transport/ibverbs/buffer.cc:108] Read timeout LID: 45 QPN: 29770 PSN: 11590737



@xqding

xqding commented Oct 28, 2017

Hi @ailzhang, let me know if you need any other information.

@ailzhang
Contributor Author

ailzhang commented Oct 29, 2017

Hi @xqding, I can reproduce the problem intermittently on my machine. I wonder whether your workaround below solves the problem permanently. If so, could you share that part, so it may help me locate which part of DistributedDataParallel caused the timeout? Thanks!

If I use torch.distributed to average the gradients myself instead of using torch.nn.parallel.DistributedDataParallel, the code works fine and scales well with the number of GPUs.

@xqding

xqding commented Oct 31, 2017

Here is a summary of what I have tried:

  1. Using DistributedDataParallel gives me the following error:
Traceback (most recent call last):
  File "./script/train_dist_gloo.py", line 99, in <module>
    outputs = net(inputs)
  File "/home/yuyou/apps/anaconda3/lib/python3.5/site-packages/torch/nn/modules/module.py", line 263, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/yuyou/apps/anaconda3/lib/python3.5/site-packages/torch/nn/parallel/distributed.py", line 156, in forward
    self._sync_params()
  File "/home/yuyou/apps/anaconda3/lib/python3.5/site-packages/torch/nn/parallel/distributed.py", line 187, in _sync_params
    dist.broadcast(flat_buffers, 0)
  File "/home/yuyou/apps/anaconda3/lib/python3.5/site-packages/torch/distributed/__init__.py", line 198, in broadcast
    return torch._C._dist_broadcast(tensor, src, group)
RuntimeError: [/home/yuyou/downloads/mypytorch2/pytorch/torch/lib/gloo/gloo/transport/ibverbs/buffer.cc:108] Read timeout LID: 34 QPN: 7435 PSN: 13656770
  2. Instead of using DistributedDataParallel, I tried to train an independent model on each node and average the gradients using the following code:
for epoch in range(num_epoches):  # loop over the dataset multiple times
    running_loss = 0.0
    print("Epoch: {}".format(epoch))
    train_sampler.set_epoch(epoch)
    for i, data in enumerate(train_loader, 0):
        # get the inputs
        inputs = data['image']
        labels = data['category_id']
        labels = np.array([category_ids.index(l) for l in labels])
        print("i: {}".format(i))

        print("labels", labels)
        # wrap them in Variable
        inputs = inputs.cuda(async=True)
        labels = torch.from_numpy(labels).cuda(async=True)
        inputs, labels = Variable(inputs), Variable(labels)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        # average gradients across workers before the optimizer step
        size = float(dist.get_world_size())
        for param in net.parameters():
            dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM)
            param.grad.data /= size
        optimizer.step()

It works fine for the first epoch. It crashes once it starts the second epoch with the following error:

Traceback (most recent call last):
  File "./script/train_dist_new.py", line 126, in <module>
    dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM)
  File "/home/yuyou/apps/anaconda3/lib/python3.5/site-packages/torch/distributed/__init__.py", line 216, in all_reduce
    return torch._C._dist_all_reduce(tensor, op, group)
RuntimeError: [/home/yuyou/downloads/mypytorch2/pytorch/torch/lib/gloo/gloo/transport/ibverbs/buffer.cc:108] Read timeout LID: 45 QPN: 29770 PSN: 11590737
  3. My workaround for now is to combine multiple epochs of data into one epoch when I define the Dataset: looping over the DataLoader once is then the same as looping over my training data multiple times. It works pretty well so far (see the sketch below).
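
The workaround amounts to a wrapper Dataset that repeats the underlying data, so that a single pass over the DataLoader covers several logical epochs. A minimal sketch, assuming a map-style base dataset (the class name RepeatedDataset and its parameters are hypothetical, not part of the script above):

from torch.utils.data import Dataset

class RepeatedDataset(Dataset):
    """Present `base` repeated `repeats` times as one long dataset."""
    def __init__(self, base, repeats):
        self.base = base
        self.repeats = repeats

    def __len__(self):
        return len(self.base) * self.repeats

    def __getitem__(self, idx):
        # map the long index back onto the underlying dataset
        return self.base[idx % len(self.base)]

# e.g. train_data = RepeatedDataset(train_data, repeats=num_epoches)
# before constructing the DistributedSampler and DataLoader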

@zjoe zjoe mentioned this pull request Nov 20, 2017
@OnezeroW

@xqding Your code snippet looks quite similar to this example.
I know that https://github.com/pytorch/examples/blob/master/imagenet/main.py uses Gloo as the default backend, as in the following code:
parser.add_argument('--dist-backend', default='gloo', type=str, help='distributed backend')

However, Gloo uses TCP by default. I wonder how to use Gloo with IBVERBS. Thanks.

@401qingkong

@xqding Your code snippet looks quite similar to this example.
I know that https://github.com/pytorch/examples/blob/master/imagenet/main.py uses Gloo as the default backend, as in the following code:
parser.add_argument('--dist-backend', default='gloo', type=str, help='distributed backend')

However, Gloo uses TCP by default. I wonder how to use Gloo with IBVERBS. Thanks.

Have you solved this problem? How can Gloo be used with IBVERBS?
