
Failed to create Gloo new group after initialized with NCCL #68726

Open
zhuzilin opened this issue Nov 22, 2021 · 9 comments
Labels
module: nccl Problems related to nccl support oncall: distributed Add this issue/PR to distributed oncall triage queue

Comments

zhuzilin (Contributor) commented Nov 22, 2021

🐛 Bug

In our project Tencent/PatrickStar, we need to create both an NCCL comm group and a Gloo comm group in order to use GPU reduce-scatter as well as CPU comm operations. However, it seems that after initializing with NCCL, the new Gloo group does not pick up the master address and master port, and instead uses localhost (127.0.0.1).

I suspect the reason is that the NCCL store and the Gloo store are not compatible with each other, so the new Gloo group cannot read the master address saved by the NCCL group.

The relevant error message is:

Traceback (most recent call last):
  File "test_new_group.py", line 4, in <module>
    cpu_comm = torch.distributed.new_group(backend="gloo")
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2843, in new_group
    pg = _new_process_group_helper(
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 668, in _new_process_group_helper
    pg = ProcessGroupGloo(prefix_store, rank, world_size, timeout=timeout)
RuntimeError: [/opt/pytorch/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:799] connect [127.0.0.1]:1453: Connection refused
RuntimeError: [/opt/pytorch/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:799] connect [127.0.0.1]:4005: Connection refused
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 6027) of binary: /opt/conda/bin/python3

To Reproduce

# test_new_group.py
import torch

torch.distributed.init_process_group(backend="nccl")
cpu_comm = torch.distributed.new_group(backend="gloo")

Run:

python3 -m torch.distributed.launch --nproc_per_node=1 \
    --nnodes=2 --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \
    --node_rank=$NODE_RANK \
    test_new_group.py

Note that the code works when running on a single machine, or when the second group is also created with backend "nccl".

Expected behavior

Users should be able to create two comm groups with different backends.

Environment

I'm using the NGC container: nvcr.io/nvidia/pytorch:21.09-py3

  • PyTorch Version (e.g., 1.0): 1.10.0
  • OS (e.g., Linux): linux
  • Python version: 3.8
  • CUDA/cuDNN version: 11.4
  • GPU models and configuration: A100

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang

jbschlosser added the module: nccl (Problems related to nccl support) and oncall: distributed (Add this issue/PR to distributed oncall triage queue) labels Nov 22, 2021
rohan-varma (Member) commented

Hmm, I am not able to reproduce the issue using the latest version of PyTorch nightly.

I ran the following script:

https://github.com/rohan-varma/torch-script/blob/master/training_script.py

via

srun -p train --nodes=2 -t 5:00:00 --gpus-per-node=1 --cpus-per-task=8 ./test.sh

on a GPU cluster, where test.sh is https://github.com/rohan-varma/torch-script/blob/master/test.sh, and the output was -

[W socket.cpp:634] The server socket on [ip-10-200-91-114.ec2.internal]:29501 is not yet listening (generic error: 111 - Connection refused).
imported
imported
initialized pgs
done
initialized pgs
done
/fsx/users/rvarm1/conda/envs/pytorch_nightly/lib/python3.8/site-packages/torch/distributed/launch.py:178: FutureWarning: The module torch.distributed.launch is deprecated

A significant overhaul of the TCPStore has landed recently which improves the error logging, could you try on the latest/nightly PyTorch so we can get more error details? #68226

rohan-varma (Member) commented

In addition, new_group will use the same default_store as the default process group, so it is quite strange that it would attempt to connect to the wrong address.

Can you print out $MASTER_ADDR and $MASTER_PORT in both of your worker scripts?
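
(For reference, a minimal diagnostic sketch, not from the original report: printing the rendezvous-related environment at the top of the training script shows whether every worker sees the address and port the launcher set.)

# diagnostic sketch: print the rendezvous environment on every rank
# (these variable names are the ones set by torch.distributed.launch)
import os

for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE", "LOCAL_RANK"):
    print(f"{key}={os.environ.get(key)}")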

zhuzilin (Contributor, Author) commented

@rohan-varma Thank you for your reply! I've double-checked the cluster and found that the real problem is that I could not establish the Gloo comm group at all. The reason is that the hostnames of the nodes were all set to the same value and resolved to 127.0.0.1, so Gloo resolved the hostname to localhost and connected to itself... I've manually changed the hostnames and it works now :)

Thank you for your help! I wonder, is there a way to pass the IP directly to Gloo?
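
(As a quick check for the misconfiguration described above, a hedged sketch: resolve the local hostname on each node and make sure it does not map to the loopback address.)

# sketch: verify the node's hostname does not resolve to loopback,
# which is what made Gloo connect to 127.0.0.1 in this issue
import socket

hostname = socket.gethostname()
resolved = socket.gethostbyname(hostname)
print(f"{hostname} resolves to {resolved}")
if resolved.startswith("127."):
    print("Warning: hostname resolves to loopback; Gloo peers will try to connect to 127.0.0.1")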

weiyx16 commented Oct 2, 2022

I found that if we set TORCH_DISTRIBUTED_DEBUG=INFO, PyTorch will set up a Gloo-backend group. If the cluster doesn't support Gloo communication, setting this environment variable causes the error.
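
(An illustrative sketch based on the observation above, assuming a cluster where Gloo TCP connections cannot be established: with the debug variable set, even an NCCL-only init fails because of the extra Gloo group.)

# sketch: TORCH_DISTRIBUTED_DEBUG must be set before init_process_group runs;
# per the report above, debug mode creates an additional Gloo-backend group,
# so this fails on clusters that cannot establish Gloo TCP connections
import os
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"

import torch.distributed as dist
dist.init_process_group(backend="nccl")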

masip85 commented Jul 17, 2023

> I found that if we set TORCH_DISTRIBUTED_DEBUG=INFO, PyTorch will set up a Gloo-backend group. If the cluster doesn't support Gloo communication, setting this environment variable causes the error.

This is happening to me too. If I want DEBUG info, is there a way to avoid that? Has this issue been identified?

kumpera (Contributor) commented Jul 24, 2023

This is not currently possible.

trias702 commented

@zhuzilin How did you manage to change the hostname so it works? Did you ever find a way to pass an IP directly to gloo?

chestnut-Q commented

> How did you manage to change the hostname so it works? Did you ever find a way to pass an IP directly to Gloo?

@trias702 You can try manually setting the network interface as follows: first, use ifconfig to find the interface corresponding to your IP address, such as em0 or eth0; then set the environment variable os.environ['GLOO_SOCKET_IFNAME'] = 'eth0'.
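
(Concretely, that workaround looks something like the sketch below; eth0 is just an example, substitute the interface reported by ifconfig, and set the variable before the Gloo group is created.)

# sketch: pin Gloo to a specific network interface instead of relying on
# hostname resolution; "eth0" is an example, use the interface that owns
# the node's reachable IP address (see `ifconfig` or `ip addr`)
import os
os.environ["GLOO_SOCKET_IFNAME"] = "eth0"

import torch.distributed as dist
dist.init_process_group(backend="nccl")
cpu_comm = dist.new_group(backend="gloo")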

jbohnslav commented

This is an unfortunate bug. The only time you want to set TORCH_DISTRIBUTED_DEBUG=INFO is if you're having trouble with torch.distributed. That seems like the wrong time to set up an extra process group. In my case, the gloo backend wasn't compatible with my environment and crashed all my jobs.
