
Init connect timeout when using torch.distributed.run #79388

@tingweiwu

Description


🐛 Describe the bug

TRAINING_SCRIPT.py

import torch.distributed as dist

def main():
    dist.init_process_group("nccl", init_method='env://')
    .......

if __name__ == "__main__":
    main()
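
For reference, init_method='env://' makes init_process_group read its connection details from environment variables, which torch.distributed.run sets for every worker. A minimal sketch of what the script effectively relies on (the values are illustrative, matching the logs further down):

import os
import torch.distributed as dist

# The launcher exports these for each worker; with init_method='env://'
# they fully determine how the process group connects.
for name in ("MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK", "LOCAL_RANK"):
    print(name, os.environ.get(name))  # e.g. MASTER_ADDR=IP1, MASTER_PORT=54902, ...

dist.init_process_group("nccl", init_method="env://")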

When I run the following on both node0 and node1:

export LOGLEVEL=INFO && python -m torch.distributed.run --nproc_per_node=1 --nnodes=2 \
--rdzv_id=ID1 --rdzv_backend=c10d --rdzv_endpoint='IP1:2222' TRAINING_SCRIPT.py
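
From the traceback below, my understanding is that the c10d rendezvous backend is backed by a TCPStore at the --rdzv_endpoint: one agent is supposed to host the store as the server and the others connect to it as clients. A rough sketch of that step using the documented TCPStore API (the exact arguments the agent passes internally are an assumption on my part):

from datetime import timedelta
import torch.distributed as dist

# On the node that hosts the rendezvous endpoint: bind and listen on IP1:2222.
server_store = dist.TCPStore("IP1", 2222, 2, True, timedelta(seconds=60))

# On every other node: connect to IP1:2222 as a client. If nothing is listening
# there, this fails with the same "client socket has timed out" error as below.
client_store = dist.TCPStore("IP1", 2222, 2, False, timedelta(seconds=60))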

I get this error on both node0 and node1:

INFO:torch.distributed.launcher.api:Starting elastic_operator with launch configs:
  entrypoint       : launch_mnist.py
  min_nodes        : 2
  max_nodes        : 2
  nproc_per_node   : 1
  run_id           : ID1
  rdzv_backend     : c10d
  rdzv_endpoint    : IP1:2222
  rdzv_configs     : {'timeout': 900}
  max_restarts     : 0
  monitor_interval : 5
  log_dir          : None
  metrics_cfg      : {}

[E socket.cpp:793] [c10d] The client socket has timed out after 60s while trying to connect to (IP1, 2222).
ERROR:torch.distributed.elastic.multiprocessing.errors.error_handler:{
  "message": {
    "message": "RendezvousConnectionError: The connection to the C10d store has failed. See inner exception for details.",
    "extraInfo": { .......}
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/site-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py", line 156, in _create_tcp_store
    host, port, is_master=is_server, timeout=timedelta(seconds=read_timeout)
TimeoutError: The client socket has timed out after 60s while trying to connect to (IP1, 2222).

The above exception was the direct cause of the following exception:
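
To rule out a plain network problem, a raw TCP connect from each node shows whether anything is listening on the endpoint at all (a simple diagnostic sketch):

import socket

# Simple diagnostic: try a raw TCP connect to the rendezvous endpoint.
# If this fails even from node0 itself, no agent ever started the C10d
# store server on IP1:2222.
try:
    with socket.create_connection(("IP1", 2222), timeout=5):
        print("IP1:2222 is reachable")
except OSError as exc:
    print(f"cannot connect to IP1:2222: {exc}")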

But when I change the command, using localhost instead of IP1 on node0:

export LOGLEVEL=INFO && python -m torch.distributed.run --nproc_per_node=1 --nnodes=2 \
--rdzv_id=ID1 --rdzv_backend=c10d --rdzv_endpoint='localhost:2222' TRAINING_SCRIPT.py

on node1

export LOGLEVEL=INFO && python -m torch.distributed.run --nproc_per_node=1 --nnodes=2 \
--rdzv_id=ID1 --rdzv_backend=c10d --rdzv_endpoint='IP1:2222' TRAINING_SCRIPT.py

it works fine.
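
My guess (possibly wrong) is that with --rdzv_endpoint='IP1:2222' neither agent recognizes itself as the host of the endpoint, so no one starts the C10d store server and both sides time out as clients, while localhost forces node0 to take the server role. A hypothetical check of that host-matching idea (my own illustration, not the agent's actual code):

import socket

# Does the endpoint host resolve to one of the addresses this machine's own
# hostname maps to? If IP1 is not among them (e.g. the hostname resolves to a
# different interface), node0 would not consider itself the store server host.
endpoint_ip = socket.gethostbyname("IP1")
local_ips = socket.gethostbyname_ex(socket.gethostname())[2]
print(endpoint_ip, local_ips, endpoint_ip in local_ips)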

The output of node0:

INFO:torch.distributed.launcher.api:Starting elastic_operator with launch configs:
  entrypoint       : launch_mnist_v12.py
  min_nodes        : 2
  max_nodes        : 2
  nproc_per_node   : 1
  run_id           : ID1
  rdzv_backend     : c10d
  rdzv_endpoint    : localhost:2222
  rdzv_configs     : {'timeout': 900}
  max_restarts     : 0
  monitor_interval : 5
  log_dir          : None
  metrics_cfg      : {}

INFO:torch.distributed.elastic.agent.server.local_elastic_agent:log directory set to: /tmp/torchelastic_6o014_3m/m638480e883e4cd58af52617214cfe50__u799hzl
INFO:torch.distributed.elastic.agent.server.api:[default] starting workers for entrypoint: python
INFO:torch.distributed.elastic.agent.server.api:[default] Rendezvous'ing worker group
INFO:torch.distributed.elastic.agent.server.api:[default] Rendezvous complete for workers. Result:
  restart_count=0
  master_addr=IP1
  master_port=54902
  group_rank=0
  group_world_size=2
  local_ranks=[0]
  role_ranks=[0]
  global_ranks=[0]
  role_world_sizes=[2]
  global_world_sizes=[2]

INFO:torch.distributed.elastic.agent.server.api:[default] Starting worker group
INFO:torch.distributed.elastic.multiprocessing:Setting worker0 reply file to: /tmp/torchelastic_6o014_3m/m638480e883e4cd58af52617214cfe50__u799hzl/attempt_0/0/error.json
env MASTER_ADDR=IP1
env MASTER_PORT=54902
env WORLD_SIZE=2
env RANK=0
env LOCAL_RANK=0
| distributed init (rank 0): env://,(backend nccl):, local rank:0, world size:2
NCCL version 2.10.3+cuda10.2
Train Epoch: 1 [0/60000 (0%)]	Loss: 2.317649

The output of node1:

INFO:torch.distributed.launcher.api:Starting elastic_operator with launch configs:
  entrypoint       : launch_mnist_v12.py
  min_nodes        : 2
  max_nodes        : 2
  nproc_per_node   : 1
  run_id           : ID1
  rdzv_backend     : c10d
  rdzv_endpoint    : IP1:2222
  rdzv_configs     : {'timeout': 900}
  max_restarts     : 0
  monitor_interval : 5
  log_dir          : None
  metrics_cfg      : {}

INFO:torch.distributed.elastic.agent.server.local_elastic_agent:log directory set to: /tmp/torchelastic_n4bpjqqf/m638480e883e4cd58af52617214cfe50_gz_c6jhz
INFO:torch.distributed.elastic.agent.server.api:[default] starting workers for entrypoint: python
INFO:torch.distributed.elastic.agent.server.api:[default] Rendezvous'ing worker group
INFO:torch.distributed.elastic.agent.server.api:[default] Rendezvous complete for workers. Result:
  restart_count=0
  master_addr=IP1
  master_port=54902
  group_rank=1
  group_world_size=2
  local_ranks=[0]
  role_ranks=[1]
  global_ranks=[1]
  role_world_sizes=[2]
  global_world_sizes=[2]

INFO:torch.distributed.elastic.agent.server.api:[default] Starting worker group
INFO:torch.distributed.elastic.multiprocessing:Setting worker0 reply file to: /tmp/torchelastic_n4bpjqqf/m638480e883e4cd58af52617214cfe50_gz_c6jhz/attempt_0/0/error.json
env MASTER_ADDR=IP1
env MASTER_PORT=54902
env WORLD_SIZE=2
env RANK=1
env LOCAL_RANK=0
| distributed init (rank 1): env://,(backend nccl):, local rank:0, world size:2

Another strange thing: the deprecated torch.distributed.launch module works fine when I run the following on both node0 and node1:

python -m torch.distributed.launch --master_addr="IP1" --master_port=2222 --nproc_per_node=1 --nnodes=2 TRAINING_SCRIPT.py

as mentioned in #76367

Versions

Collecting environment information...
PyTorch version: 1.11.0+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.26

Python version: 3.7.5 (default, Apr 26 2022, 08:54:01) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-3.10.0-514.44.5.10.h193.x86_64-x86_64-with-debian-buster-sid
Is CUDA available: True
CUDA runtime version: 10.2.89
GPU models and configuration: GPU 0: Tesla V100-SXM2-32GB
Nvidia driver version: 450.102.04
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.6
[pip3] torch==1.11.0
[pip3] torchvision==0.12.0
[conda] Could not collect

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang @kwen2501
