
Conversation

@teng-li (Contributor) commented Jan 29, 2018

After removing the hacky clearing of the NCCL communicator cache and adding all the tests, I think we are in good shape to move the NCCL backend out of experimental status.


@soumith merged commit 5c65466 into pytorch:master Jan 30, 2018
@Yi-Li commented Feb 26, 2018

Hi
I downloaded the most recent PyTorch and tried to install it from source so that I can make use of the new NCCL APIs. I could run torch/lib/nccl/test/single successfully.
I specified WITH_SYSTEM_NCCL=0 (along with WITH_NCCL=1, WITH_DISTRIBUTED=1, WITH_CUDA=1) when I invoked "python setup.py build develop". It seems the generated libnccl.so is version 1.3.5, not version 2+, so THD was compiled without NCCL support.
I also downloaded NCCL2 from the NVIDIA website, tried WITH_SYSTEM_NCCL=1, and specified NCCL_INCLUDE_DIR, NCCL_LIB_DIR, and NCCL_ROOT_DIR, but THD was still compiled without NCCL support. I have been stuck here for a few days. Could you help with this? Thanks a lot!

Best,
Lissa

@apaszke (Contributor) commented Feb 26, 2018

The NCCL library provided in the repo is version 1. Version 2 is closed source, so you have to download it from NVIDIA and build with WITH_SYSTEM_NCCL=1.
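For example (a sketch, assuming a source build that exposes torch.cuda.nccl.version(); the exact return format differs across PyTorch versions), you can check which NCCL your build actually picked up:

import torch
import torch.cuda.nccl as nccl

# Prints an NCCL version code; a value in the 2xxx range indicates the
# system NCCL 2 library was found at build time.
print('CUDA available:', torch.cuda.is_available())
print('NCCL version seen by PyTorch:', nccl.version())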

@Yi-Li commented Feb 26, 2018 via email

@Yi-Li commented Feb 27, 2018 via email

@Yi-Li commented Feb 27, 2018

Hi Adam,

Yes, I downloaded NCCL2 from the NVIDIA website, tried WITH_SYSTEM_NCCL=1, and specified NCCL_INCLUDE_DIR, NCCL_LIB_DIR, and NCCL_ROOT_DIR to install PyTorch. The installed PyTorch version is 0.4.0a0+7703670. When I run the following simple test example (toy.py), this error is thrown:

before init
after init
begin rank 1
Traceback (most recent call last):
  File "toy.py", line 32, in <module>
    init_processes(args.rank, size, run, 'nccl')
  File "toy.py", line 23, in init_processes
    fn(rank, size)
  File "toy.py", line 11, in run
    dist.all_reduce(tensor, op=dist.reduce_op.SUM, group=group)
  File "/home/liy/programs/pytorch/torch/distributed/__init__.py", line 326, in all_reduce
    return torch._C._dist_all_reduce(tensor, op, group)
RuntimeError: NCCL error in: /home/liy/programs/pytorch/torch/lib/THD/base/data_channels/DataChannelNccl.cpp:324, unhandled system error

Am I doing something wrong?

==========================
cat toy.py:
import torch
import torch.distributed as dist
import argparse

def run(rank, size):
    """ Simple point-to-point communication. """
    print('begin rank', rank)
    group = dist.new_group([0, 1])
    tensor = torch.ones(1).cuda()
    dist.all_reduce(tensor, op=dist.reduce_op.SUM, group=group)
    print('Rank ', rank, ' has data ', tensor[0])

def init_processes(rank, size, fn, backend):
    """ Initialize the distributed environment. """
    print('before init')
    init_method="tcp://10.6.48.150:13530"
    dist.init_process_group(backend,rank=rank,world_size=size,init_method=init_method)
    print('after init')
    fn(rank, size)

if __name__ == "__main__":
    size = 2
    parser = argparse.ArgumentParser()
    parser.add_argument('--rank', default=-1, type=int,
                        help='rank')
    args = parser.parse_args()
    init_processes(args.rank, size, run, 'nccl')

@apaszke (Contributor) commented Feb 27, 2018

Hmm, I don't know; it looks good at first glance, and the error is coming from somewhere inside NCCL, where we can't easily tell what's wrong 😕

@Yi-Li commented Feb 27, 2018 via email

@apaszke (Contributor) commented Feb 27, 2018

Sorry, I don't know that

@teng-li (Contributor, Author) commented Feb 27, 2018

@Yi-Li this could possibly mean that NCCL is picking up the wrong interface. What does your ifconfig show?

@teng-li (Contributor, Author) commented Feb 27, 2018

@Yi-Li I would first try setting NCCL_SOCKET_IFNAME to the interface you would like NCCL to communicate over.

Also, setting NCCL_DEBUG=INFO and running your program will give more info on why it is failing in NCCL. See the sketch below.
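For example, a minimal sketch of doing this from inside the script (the interface name 'eth0' is only a placeholder; use whatever ifconfig reports for the interface you want NCCL to use):

import os

# Must be set before the first NCCL collective runs; the interface name
# below is a placeholder taken from your ifconfig output.
os.environ['NCCL_SOCKET_IFNAME'] = 'eth0'
os.environ['NCCL_DEBUG'] = 'INFO'

import torch
import torch.distributed as dist

# rank=0 shown for the first process; the second process uses rank=1.
dist.init_process_group('nccl', rank=0, world_size=2,
                        init_method='tcp://10.6.48.150:13530')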

@teng-li deleted the ncc2_release branch February 27, 2018 19:58
@Yi-Li commented Feb 27, 2018 via email

@teng-li (Contributor, Author) commented Feb 27, 2018

Yeah, I have seen similar issues with a mismatched glibc version.

@Yi-Li commented Mar 29, 2018 via email

@apaszke (Contributor) commented Mar 29, 2018

Exactly. That address should be the IP and port that you gave the first process to listen at.

@cyang49 commented Mar 30, 2018

I'm having the same issue as @Yi-Li with the imagenet example. I tried the toy.py she posted and at first couldn't reproduce it. I realized that the difference is that I set CUDA_VISIBLE_DEVICES when running imagenet. After adding it when running toy.py, I can reproduce the same error.

The modified toy.py I used:

import torch
import torch.distributed as dist
import argparse

def run(rank, size):
    """ Simple point-to-point communication. """
    print('begin rank', rank)
    group = dist.new_group([0, 1])
    tensor = torch.ones(1).cuda()
    dist.all_reduce(tensor, op=dist.reduce_op.SUM, group=group)
    print('Rank ', rank, ' has data ', tensor[0])

def init_processes(rank, size, fn, backend):
    """ Initialize the distributed environment. """
    print('before init')
    init_method="file://./sync"
    #dist.init_process_group(backend,rank=rank,world_size=size,init_method=init_method)
    dist.init_process_group(backend,world_size=size,init_method=init_method)
    print('after init')
    fn(rank, size)


if __name__ == "__main__":
    size = 2
    parser = argparse.ArgumentParser()
    parser.add_argument('--rank', default=-1, type=int,
                        help='rank')
    args = parser.parse_args()
    init_processes(args.rank, size, run, 'nccl')

The commands I used in different terminals on the same machine to trigger the error:

# terminal 0
CUDA_VISIBLE_DEVICES=0 python toy.py --rank 0
# terminal 1
CUDA_VISIBLE_DEVICES=1 python toy.py --rank 1

I'm not exactly sure this is the same problem as @Yi-Li's, but could anyone help with using NCCL across different CUDA devices?

@apaszke (Contributor) commented Mar 30, 2018

CUDA_VISIBLE_DEVICES doesn't mix well with NCCL. Remove it and then use torch.cuda.set_device(X) at the top of your script

@teng-li (Contributor, Author) commented Mar 30, 2018

@cyang49 Yeah, CUDA_VISIBLE_DEVICES is incompatible with CUDA IPC, since each process only sees its own GPU and hence cannot see the others to use GPUDirect P2P. So please make sure that all processes that will use NCCL can see all the devices. Like Adam said, you can control which GPU each process operates on either with torch.cuda.set_device() or with the torch.cuda.device() context manager, as in the sketch below.
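A minimal sketch of that pattern (no CUDA_VISIBLE_DEVICES; each process picks its GPU from its rank):

import torch
import torch.distributed as dist

def init_processes(rank, size, fn, backend='nccl'):
    # Every process can see all GPUs; bind this one to its own GPU
    # before any CUDA or NCCL work happens.
    torch.cuda.set_device(rank)
    dist.init_process_group(backend, rank=rank, world_size=size,
                            init_method='tcp://127.0.0.1:16543')
    fn(rank, size)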

@cyang49 commented Mar 30, 2018

@apaszke @teng-li Using torch.cuda.device() worked. However, I put the call to fn(rank, size) in a loop, and after a while it hits an error on rank 1 (the other rank hangs):

1258th operation
begin rank 1
Rank  1  has data  
 2
[torch.cuda.FloatTensor of size () (GPU 0)]

1259th operation
begin rank 1
Traceback (most recent call last):
  File "toy.py", line 33, in <module>
    init_processes(args.rank, size, run, 'nccl')
  File "toy.py", line 24, in init_processes
    fn(rank, size)
  File "toy.py", line 10, in run
    dist.all_reduce(tensor, op=dist.reduce_op.SUM, group=group)
  File "/home/ccyang/anaconda3/lib/python3.6/site-packages/torch/distributed/__init__.py", line 326, in all_reduce
    return torch._C._dist_all_reduce(tensor, op, group)
RuntimeError: NCCL error in: /home/ccyang/gpfs/pytorch/torch/lib/THD/base/data_channels/DataChannelNccl.cpp:324, internal error

The above error is reproducible on my system and happens at the same 1259th operation on rank 1.

And I also got this other kind of error:

928th operation
begin rank 1
mlx5: c460login01: got completion with error:
00000000 00000000 00000000 00000000
00000000 00000000 00000000 00000000
00000001 00000000 00000000 00000000
00000000 9d00c311 00012ecb 0003b3e2
Traceback (most recent call last):
  File "toy.py", line 33, in <module>
    init_processes(args.rank, size, run, 'nccl')
  File "toy.py", line 24, in init_processes
    fn(rank, size)
  File "toy.py", line 10, in run
    dist.all_reduce(tensor, op=dist.reduce_op.SUM, group=group)
  File "/home/ccyang/anaconda3/lib/python3.6/site-packages/torch/distributed/__init__.py", line 326, in all_reduce
    return torch._C._dist_all_reduce(tensor, op, group)
RuntimeError: NCCL error in: /home/ccyang/gpfs/pytorch/torch/lib/THD/base/data_channels/DataChannelNccl.cpp:324, unhandled system error

The code is here:

import torch
import torch.distributed as dist
import argparse

def run(rank, size):
    """ Simple point-to-point communication. """
    print('begin rank', rank)
    group = dist.new_group([0, 1])
    tensor = torch.FloatTensor(torch.ones(1)).cuda()
    dist.all_reduce(tensor, op=dist.reduce_op.SUM, group=group)
    print('Rank ', rank, ' has data ', tensor[0])

def init_processes(rank, size, fn, backend):
    """ Initialize the distributed environment. """
    print('before init')
    torch.cuda.device(rank)
    init_method="tcp://127.0.0.1:16543"
    dist.init_process_group(backend,rank=rank,world_size=size,init_method=init_method)
    print('after init')
    for i in range(100000):
      print('{}th operation'.format(i))
      fn(rank, size)

if __name__ == "__main__":
    size = 2
    parser = argparse.ArgumentParser()
    parser.add_argument('--rank', default=-1, type=int,
                        help='rank')
    args = parser.parse_args()
    init_processes(args.rank, size, run, 'nccl')

@teng-li (Contributor, Author) commented Mar 30, 2018

This looks like an NCCL issue to me. Could you get the NCCL logs by setting NCCL_DEBUG=INFO, e.g. by running: NCCL_DEBUG=INFO python YOUR_PROGRAM

@cyang49 commented Mar 30, 2018

Adding NCCL_DEBUG=INFO gives some hint as to what happened. Maybe it's a memory leak? I watched nvidia-smi and saw that device 0's memory ran out while device 1's memory usage stayed constant. I'm not sure the tensors are being allocated on the right devices.

c460c041:101211:106353 [0] INFO CUDA Dev 0, IB Ports : mlx5_2/1(SOC) mlx5_0/1(SOC)
c460c041:101211:106353 [0] init.cu:218 WARN Cuda failure 'out of memory'
c460c041:101211:106353 [0] transport/p2p.cu:404 WARN rank 1 failed to get CUDA IPC handle to device 0 : 11 invalid argument
c460c041:101211:106353 [0] INFO init.cu:191 -> 3
c460c041:101211:106353 [0] INFO init.cu:266 -> 3
c460c041:101211:106353 [0] INFO init.cu:460 -> 3
c460c041:101211:106353 [0] INFO init.cu:517 -> 3
c460c041:101211:106353 [0] INFO misc/group.cu:70 -> 3 [Async thread]
Traceback (most recent call last):
  File "toy.py", line 33, in <module>
    init_processes(args.rank, size, run, 'nccl')
  File "toy.py", line 24, in init_processes
    fn(rank, size)
  File "toy.py", line 10, in run
    dist.all_reduce(tensor, op=dist.reduce_op.SUM, group=group)
  File "/home/ccyang/anaconda3/lib/python3.6/site-packages/torch/distributed/__init__.py", line 326, in all_reduce
    return torch._C._dist_all_reduce(tensor, op, group)
RuntimeError: NCCL error in: /home/ccyang/gpfs/pytorch/torch/lib/THD/base/data_channels/DataChannelNccl.cpp:324, internal error

@Yi-Li commented Mar 30, 2018 via email

@cyang49 commented Mar 30, 2018

@Yi-Li

def init_processes(rank, size, fn, backend):
    """ Initialize the distributed environment. """
    print('before init')
    torch.cuda.set_device(rank)
    init_method="tcp://127.0.0.1:16543"
    dist.init_process_group(backend,rank=rank,world_size=size,init_method=init_method)
    print('after init')
    for i in range(100000):
      print('{}th operation'.format(i))
      fn(rank, size)

Edit:
I think torch.cuda.device() didn't really take effect. I used torch.cuda.set_device() instead, and nvidia-smi now shows both devices being used. The same out-of-memory error still happens.
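That is consistent with how the two APIs behave: torch.cuda.device is a context manager, so a bare call like torch.cuda.device(rank) constructs the object but never switches the current device, while torch.cuda.set_device(rank) changes it for the whole process. A short sketch of both usages:

import torch

rank = 1  # example GPU index

# Process-wide: everything after this allocates on GPU `rank`.
torch.cuda.set_device(rank)
x = torch.ones(1).cuda()

# Scoped: only allocations inside the `with` block go to GPU `rank`.
with torch.cuda.device(rank):
    y = torch.ones(1).cuda()

# A bare torch.cuda.device(rank) call, without `with`, has no effect.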

@Yi-Li commented Mar 30, 2018 via email

@zsk423200

@cyang49 I'm having the same issue as you: two Docker containers on one node, using NCCL.
1st docker : export CUDA_VISIBLE_DEVICES=0,1
2nd docker: export CUDA_VISIBLE_DEVICES=2,3

I get this error:
DataChannelNccl.cpp:324, unhandled cuda error

With NCCL_DEBUG=INFO, the log shows:
host1:1567:1583 [0] transport/p2p.cu:515 WARN failed to open CUDA IPC handle : 11 invalid argument
host1:1567:1583 [0] INFO init.cu:485 -> 1
host1:1567:1583 [0] INFO init.cu:542 -> 1
host1:1567:1583 [0] INFO misc/group.cu:70 -> 1 [Async thread]

If I set the same CUDA_VISIBLE_DEVICES in both containers, it works fine, and Gloo is also OK.

Have you solved the problem?

@stdacore

I had the same problem, but I found out the main reason: the sample code calls dist.new_group() every iteration, which seems to allocate new memory on each call. I now create the group just once at initialization, and that solved the problem. See the sketch below.
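A sketch of that fix applied to the toy example above (create the group once at init and reuse it every iteration):

import torch
import torch.distributed as dist

def run(rank, size, group):
    tensor = torch.ones(1).cuda()
    # Reuse the group built once at init instead of creating a new one
    # on every iteration.
    dist.all_reduce(tensor, op=dist.reduce_op.SUM, group=group)
    print('Rank ', rank, ' has data ', tensor[0])

def init_processes(rank, size, fn, backend='nccl'):
    torch.cuda.set_device(rank)
    dist.init_process_group(backend, rank=rank, world_size=size,
                            init_method='tcp://127.0.0.1:16543')
    group = dist.new_group([0, 1])  # created exactly once
    for i in range(100000):
        print('{}th operation'.format(i))
        fn(rank, size, group)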
