
Init connect timeout when using torch.distributed.run #79388

Open
tingweiwu opened this issue Jun 13, 2022 · 9 comments
Labels
oncall: distributed  Add this issue/PR to distributed oncall triage queue
oncall: r2p  Add this issue/PR to R2P (elastic) oncall triage queue

Comments

@tingweiwu

tingweiwu commented Jun 13, 2022

🐛 Describe the bug

TRAINING_SCRIPT.py

import torch.distributed as dist

def main():
    dist.init_process_group("nccl", init_method='env://')
    .......

if __name__ == "__main__":
    main()

when I run this on both node0 and node1

export LOGLEVEL=INFO && python -m torch.distributed.run --nproc_per_node=1 --nnodes=2 \
    --rdzv_id=ID1 --rdzv_backend=c10d --rdzv_endpoint='IP1:2222' TRAINING_SCRIPT.py

I get the error from both node0 and node1

INFO:torch.distributed.launcher.api:Starting elastic_operator with launch configs:
  entrypoint       : launch_mnist.py
  min_nodes        : 2
  max_nodes        : 2
  nproc_per_node   : 1
  run_id           : ID1
  rdzv_backend     : c10d
  rdzv_endpoint    : IP1:2222
  rdzv_configs     : {'timeout': 900}
  max_restarts     : 0
  monitor_interval : 5
  log_dir          : None
  metrics_cfg      : {}

[E socket.cpp:793] [c10d] The client socket has timed out after 60s while trying to connect to (IP1, 2222).
ERROR:torch.distributed.elastic.multiprocessing.errors.error_handler:{
  "message": {
    "message": "RendezvousConnectionError: The connection to the C10d store has failed. See inner exception for details.",
    "extraInfo": { .......}
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/site-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py", line 156, in _create_tcp_store
    host, port, is_master=is_server, timeout=timedelta(seconds=read_timeout)
TimeoutError: The client socket has timed out after 60s while trying to connect to (IP1, 2222).

The above exception was the direct cause of the following exception:

But when I change the command as follows:

on node0 (using localhost instead of IP1)

export LOGLEVEL=INFO && python -m torch.distributed.run --nproc_per_node=1 --nnodes=2 \
    --rdzv_id=ID1 --rdzv_backend=c10d --rdzv_endpoint='localhost:2222' TRAINING_SCRIPT.py

on node1

export LOGLEVEL=INFO && python -m torch.distributed.run --nproc_per_node=1 --nnodes=2 \
    --rdzv_id=ID1 --rdzv_backend=c10d --rdzv_endpoint='IP1:2222' TRAINING_SCRIPT.py

it goes well.

the output of node0

INFO:torch.distributed.launcher.api:Starting elastic_operator with launch configs:
  entrypoint       : launch_mnist_v12.py
  min_nodes        : 2
  max_nodes        : 2
  nproc_per_node   : 1
  run_id           : ID1
  rdzv_backend     : c10d
  rdzv_endpoint    : localhost:2222
  rdzv_configs     : {'timeout': 900}
  max_restarts     : 0
  monitor_interval : 5
  log_dir          : None
  metrics_cfg      : {}

INFO:torch.distributed.elastic.agent.server.local_elastic_agent:log directory set to: /tmp/torchelastic_6o014_3m/m638480e883e4cd58af52617214cfe50__u799hzl
INFO:torch.distributed.elastic.agent.server.api:[default] starting workers for entrypoint: python
INFO:torch.distributed.elastic.agent.server.api:[default] Rendezvous'ing worker group
INFO:torch.distributed.elastic.agent.server.api:[default] Rendezvous complete for workers. Result:
  restart_count=0
  master_addr=IP1
  master_port=54902
  group_rank=0
  group_world_size=2
  local_ranks=[0]
  role_ranks=[0]
  global_ranks=[0]
  role_world_sizes=[2]
  global_world_sizes=[2]

INFO:torch.distributed.elastic.agent.server.api:[default] Starting worker group
INFO:torch.distributed.elastic.multiprocessing:Setting worker0 reply file to: /tmp/torchelastic_6o014_3m/m638480e883e4cd58af52617214cfe50__u799hzl/attempt_0/0/error.json
env MASTER_ADDR=IP1
env MASTER_PORT=54902
env WORLD_SIZE=2
env RANK=0
env LOCAL_RANK=0
| distributed init (rank 0): env://,(backend nccl):, local rank:0, world size:2
NCCL version 2.10.3+cuda10.2
Train Epoch: 1 [0/60000 (0%)]	Loss: 2.317649

the output of node1

INFO:torch.distributed.launcher.api:Starting elastic_operator with launch configs:
  entrypoint       : launch_mnist_v12.py
  min_nodes        : 2
  max_nodes        : 2
  nproc_per_node   : 1
  run_id           : ID1
  rdzv_backend     : c10d
  rdzv_endpoint    : IP1:2222
  rdzv_configs     : {'timeout': 900}
  max_restarts     : 0
  monitor_interval : 5
  log_dir          : None
  metrics_cfg      : {}

INFO:torch.distributed.elastic.agent.server.local_elastic_agent:log directory set to: /tmp/torchelastic_n4bpjqqf/m638480e883e4cd58af52617214cfe50_gz_c6jhz
INFO:torch.distributed.elastic.agent.server.api:[default] starting workers for entrypoint: python
INFO:torch.distributed.elastic.agent.server.api:[default] Rendezvous'ing worker group
INFO:torch.distributed.elastic.agent.server.api:[default] Rendezvous complete for workers. Result:
  restart_count=0
  master_addr=IP1
  master_port=54902
  group_rank=1
  group_world_size=2
  local_ranks=[0]
  role_ranks=[1]
  global_ranks=[1]
  role_world_sizes=[2]
  global_world_sizes=[2]

INFO:torch.distributed.elastic.agent.server.api:[default] Starting worker group
INFO:torch.distributed.elastic.multiprocessing:Setting worker0 reply file to: /tmp/torchelastic_n4bpjqqf/m638480e883e4cd58af52617214cfe50_gz_c6jhz/attempt_0/0/error.json
env MASTER_ADDR=IP1
env MASTER_PORT=54902
env WORLD_SIZE=2
env RANK=1
env LOCAL_RANK=0
| distributed init (rank 1): env://,(backend nccl):, local rank:0, world size:2

Another strange thing is that when I use the deprecated module torch.distributed.launch, it works when I run the following

on both node0 and node1

python -m torch.distributed.launch --master_addr="IP1" --master_port=2222 --nproc_per_node=1 --nnodes=2 TRAINING_SCRIPT.py

as mentioned in #76367

Versions

Collecting environment information...
PyTorch version: 1.11.0+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.26

Python version: 3.7.5 (default, Apr 26 2022, 08:54:01) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-3.10.0-514.44.5.10.h193.x86_64-x86_64-with-debian-buster-sid
Is CUDA available: True
CUDA runtime version: 10.2.89
GPU models and configuration: GPU 0: Tesla V100-SXM2-32GB
Nvidia driver version: 450.102.04
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.6
[pip3] torch==1.11.0
[pip3] torchvision==0.12.0
[conda] Could not collect

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang @kwen2501

@facebook-github-bot facebook-github-bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Jun 13, 2022
@fduwjj
Contributor

fduwjj commented Jun 14, 2022

Hi @kiukchung, do you mind taking a look? I was only able to repro the single-host case; I could not run the multi-host case successfully on AWS. Maybe I am missing something here, and I can also learn how to set it up correctly.

@H-Huang H-Huang added the oncall: r2p Add this issue/PR to R2P (elastic) oncall triage queue label Jun 14, 2022
@kiukchung
Collaborator

kiukchung commented Jun 14, 2022

This is most likely due to the internal method _matches_machine_hostname("IP1") not returning True on node0. localhost references the loopback device (for which _matches_machine_hostname("localhost") has special handling logic).

torchelastic calls _matches_machine_hostname() on the "host" part of the rdzv_endpoint (in this case IP1) on each node to determine whether that node should be the "master" node. Depending on the networking setup this may return False on node0, which seems to be the case here.

To get around this, you can use the fully qualified domain name of node0 as the host part of rdzv_endpoint, or use the hostname returned by running $ hostname on node0. Of course, that hostname needs to be routable from node1 so that node1 can actually make a TCP connection to node0. This is easy to verify by running something like traceroute $node0_hostname from node1.
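
For illustration, here is a rough sketch of the kind of check involved (a simplification, not the actual torchelastic implementation): a plain comparison of the endpoint host against the loopback names and the machine's own hostname/FQDN, which is why an IP or an unfamiliar DNS name can fail to match on node0.

```python
import socket

def looks_like_this_machine(host: str) -> bool:
    # Simplified illustration: match only loopback names and the names the OS
    # reports for this machine. An address such as "IP1" is not in this set
    # unless it happens to equal the hostname/FQDN string itself.
    local_names = {"localhost", "127.0.0.1", "::1",
                   socket.gethostname(), socket.getfqdn()}
    return host in local_names

print(looks_like_this_machine("localhost"))  # True
print(looks_like_this_machine("IP1"))        # likely False on node0
```

If this check returns False on every node, no node starts the C10d store, and all of them time out trying to connect to it.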

@tingweiwu
Author

tingweiwu commented Jun 20, 2022

@kiukchung Hi, I mistakenly wrote IP1 in place of the domain name in my earlier description.
Allow me to re-describe the case: I run the distributed experiment on a k8s cluster.

1. First, create a worker-0 Service backed by the worker-0 pod, plus a worker-1 pod. I replace the start command with top -b so that I can start the test manually.

apiVersion: v1
kind: Service
metadata:
  name: worker-0
  namespace: system
spec:
  clusterIP: None
  ports:
  - port: 2222
    protocol: TCP
    targetPort: 2222
  selector:
    name: worker-0
  sessionAffinity: None
  type: ClusterIP
status:
  loadBalancer: {}

---

apiVersion: v1
items:
- apiVersion: v1
  kind: Pod
  metadata:
    labels:
      name: worker-0
    name: worker-0
    namespace: system
  spec:
    containers:
    - args:
      - top -b
      #- export LOGLEVEL=INFO && python -m torch.distributed.run  --nproc_per_node=1 --nnodes=2 --rdzv_id=id --rdzv_backend=c10d  --rdzv_endpoint='worker-0.system.svc.cluster.local:2222' TRAINING_SCRIPT.py
      command:
      - /bin/bash
      - -c
      image: pytorch:1.10.2-cuda11.0-py3.7.5
      .....
- apiVersion: v1
  kind: Pod
  metadata:
    labels:
      name: worker-1
    name: worker-1
    namespace: system
  spec:
    containers:
    - args:
      - top -b
      #- export LOGLEVEL=INFO && python -m torch.distributed.run  --nproc_per_node=1 --nnodes=2 --rdzv_id=id --rdzv_backend=c10d  --rdzv_endpoint='worker-0.system.svc.cluster.local:2222' TRAINING_SCRIPT.py
      command:
      - /bin/bash
      - -c
      image: pytorch:1.10.2-cuda11.0-py3.7.5
      .....

2. I run this inside worker-0 and worker-1:

# kubectl exec -it -n system worker-0 bash
@worker-0:~/$ export LOGLEVEL=INFO && python -m torch.distributed.run  --nproc_per_node=1 --nnodes=2 --rdzv_id=id --rdzv_backend=c10d  --rdzv_endpoint='worker-0.system.svc.cluster.local:2222' TRAINING_SCRIPT.py

# kubectl exec -it -n system worker-1 bash
@worker-1:~/$ export LOGLEVEL=INFO && python -m torch.distributed.run  --nproc_per_node=1 --nnodes=2 --rdzv_id=id --rdzv_backend=c10d  --rdzv_endpoint='worker-0.system.svc.cluster.local:2222' TRAINING_SCRIPT.py

It fails because _matches_machine_hostname("worker-0.system.svc.cluster.local") returns False:
socket.gethostname() on node0 returns worker-0, i.e. without the k8s suffix system.svc.cluster.local.
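
For reference, a quick check one could run inside the worker-0 pod (illustrative only; the printed values in the comments are assumptions based on the ping output below):

```python
import socket

print(socket.gethostname())   # e.g. "worker-0" -- the short pod name, no cluster suffix
print(socket.getfqdn())       # may or may not include "system.svc.cluster.local"
print(socket.gethostbyname("worker-0.system.svc.cluster.local"))  # e.g. "10.140.0.110"
```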

3. Execute the ping command inside worker-0 and worker-1:

# kubectl exec -it -n system worker-0 bash
@worker-0:~/$ ping worker-0.system.svc.cluster.local
PING worker-0.system.svc.cluster.local (10.140.0.110) 56(84) bytes of data.
64 bytes from worker-0 (10.140.0.110): icmp_seq=1 ttl=64 time=0.016 ms
64 bytes from worker-0 (10.140.0.110): icmp_seq=2 ttl=64 time=0.031 ms

# kubectl exec -it -n system worker-1 bash
@worker-1:~/$ ping worker-0.system.svc.cluster.local
PING worker-0.system.svc.cluster.local (10.140.0.110) 56(84) bytes of data.
64 bytes from 10-140-0-110.worker-0.system.svc.cluster.local (10.140.0.110): icmp_seq=1 ttl=64 time=1.49 ms
64 bytes from 10-140-0-110.worker-0.system.svc.cluster.local (10.140.0.110): icmp_seq=2 ttl=64 time=0.390 ms

4. Replacing the domain name (worker-0.system.svc.cluster.local) with the pod IP (10.140.0.110) inside worker-0 and worker-1 runs successfully.
Obviously _matches_machine_hostname("10.140.0.110") returns True from this branch.

5. Since it is impossible for me to know the IP of the pod before it is up, I would like to use localhost on worker-0 and worker-0.system.svc.cluster.local on worker-1. Will there be any unforeseen problems?

Considering our discussion in #76367:

torch.distributed.launch by design assumes that the same parameters will be passed on all nodes.

Although the backend used by launch is static rather than c10d, I am still a little worried.

@rohan-varma
Member

@d4l3k could you help take a look at this one, and feel free to re-add the oncall: distributed tag if distributed help is needed.

@rohan-varma rohan-varma removed the oncall: distributed Add this issue/PR to distributed oncall triage queue label Jun 23, 2022
@d4l3k
Collaborator

d4l3k commented Jun 23, 2022

Yeah, that's exactly the problem -- we don't do any DNS resolution when identifying whether the local host matches. We really should fall back to DNS resolution or correctly handle the search paths in /etc/resolv.conf. Sorry you ran into this issue -- we should probably fix it in trunk, though that wouldn't be released until PyTorch 1.13, I believe.

Your solution of using localhost on rank0 is exactly what we do in TorchX's Kubernetes scheduler to work around this problem:

https://github.com/pytorch/torchx/blob/main/torchx/schedulers/kubernetes_scheduler.py#L378-L383

If you haven't already looked at TorchX, it's our recommended solution for running PyTorch on Kubernetes, since we test that it works end-to-end on an actual K8s cluster.
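
For illustration, a minimal sketch of the same idea (a hypothetical launcher wrapper, not the TorchX code; the REPLICA_RANK variable and the service name are assumptions taken from this thread): replica 0 rendezvous on localhost, every other replica dials replica 0's headless-service name.

```python
import os
import subprocess

# Hypothetical: the job controller is assumed to export this replica's index.
rank = int(os.environ.get("REPLICA_RANK", "0"))
svc = "worker-0.system.svc.cluster.local:2222"  # service name used earlier in this thread

# Replica 0 hosts the C10d store, so it can always bind via localhost;
# the other replicas connect through the service DNS name.
endpoint = "localhost:2222" if rank == 0 else svc

subprocess.run(
    [
        "torchrun", "--nnodes=2", "--nproc_per_node=1",
        "--rdzv_id=ID1", "--rdzv_backend=c10d",
        f"--rdzv_endpoint={endpoint}",
        "TRAINING_SCRIPT.py",
    ],
    check=True,
)
```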

@kiukchung
Collaborator

fwiw if you are using Kubernetes I'd also encourage you to use TorchX to launch DDP jobs onto the k8s cluster. See: https://pytorch.org/torchx/latest/schedulers/kubernetes.html

TorchX is a PyTorch job launcher that we've worked on to help users launch training jobs onto various types of schedulers. We've ironed out these sorts of kinks on the schedulers that we support.

@d4l3k
Collaborator

d4l3k commented Jun 23, 2022

It would be good to fix this in elastic as well -- we could probably make a call to https://docs.python.org/3/library/socket.html#socket.gethostbyaddr and check against the current IP list. PRs welcome :)
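
One possible shape of that fallback, as a hedged sketch only (this is not the merged fix; it simply compares the IPs the endpoint host resolves to with the IPs the local hostname resolves to):

```python
import socket

def matches_via_dns(endpoint_host: str) -> bool:
    """Return True if endpoint_host resolves to one of this machine's IP addresses."""
    try:
        _, _, endpoint_ips = socket.gethostbyname_ex(endpoint_host)
        _, _, local_ips = socket.gethostbyname_ex(socket.gethostname())
    except OSError:
        return False
    # Treat loopback as a local match too, mirroring the existing localhost handling.
    return bool(set(endpoint_ips) & (set(local_ips) | {"127.0.0.1"}))
```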

@jbschlosser jbschlosser added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Jun 24, 2022
@kiukchung
Collaborator

kiukchung commented Jun 29, 2022

Have you tried passing --rdzv_endpoint='worker-0:2222'? @d4l3k, @rohan-varma, we could probably close this issue since it seems like there is a workaround, and the recommended way to launch PTD jobs on k8s is to use TorchX.

@Kait0

Kait0 commented Aug 26, 2022

I ran into the same error when running multi-node jobs on a SLURM cluster.

Traceback (most recent call last):
  File "/home/.conda/envs/TIL/lib/python3.7/site-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py", line 156, in _create_tcp_store
    host, port, is_master=is_server, timeout=timedelta(seconds=read_timeout)
TimeoutError: The client socket has timed out after 60s while trying to connect to (slurm-bm-13, 29500).

The same issue appears when I use torchrun:
master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR=$master_addr
echo "MASTER_ADDR="${MASTER_ADDR}

torchrun --nnodes=3 --nproc_per_node=8 --max_restarts=1 --rdzv_id=9876543210 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR} train.py

or torchx run:
torchx run --scheduler slurm --scheduler_args partition=gpu-2080ti,time=3-00:00,job_dir=/path/to/slurm_logs dist.ddp --env OMP_NUM_THREADS=64,OPENBLAS_NUM_THREADS=1 --j 2x8 --gpu 8 --cpu 64 --memMB 256000 --max_retries 1 --script train.py

I did not really understand the workaround that was proposed.

cz-37 added a commit to cz-37/pytorch that referenced this issue Dec 5, 2022
…th the current IP list

Summary:
Pull Request: pytorch#79388

Fix torch.distributed.run init connect timeout by comparing `host` with the current IP list.

Test Plan: unit tests

Differential Revision: D41373962

fbshipit-source-id: bc5eb3add8c5fd27709711742837f6c39f125187
cz-37 added a commit to cz-37/pytorch that referenced this issue Dec 19, 2022
…th the current IP list (pytorch#90221)

Summary:
Pull Request resolved: pytorch#90221

Pull Request: pytorch#79388

Fix torch.distributed.run init connect timeout by comparing `host` with the current IP list.

Test Plan:
```
> buck2 test mode/dev-nosan //caffe2/test/distributed/elastic/rendezvous:utils_test -- --exact 'caffe2/test/distributed/elastic/rendezvous:utils_test - test_matches_machine_hostname_returns_true_if_ip_address_match_between_hosts (utils_test.UtilsTest)'

Tests finished: Pass 1. Fail 0. Fatal 0. Skip 0. 0 builds failed
```
Unit tests

Reviewed By: d4l3k

Differential Revision: D41373962

fbshipit-source-id: 00e6c102ed920d8b91a9978b5fe7d3ed37b62584
cz-37 added a commit to cz-37/pytorch that referenced this issue Dec 19, 2022
…th the current IP list (pytorch#90221)

Summary:
Pull Request resolved: pytorch#90221

Pull Request: pytorch#79388

Fix torch.distributed.run init connect timeout by comparing `host` with the current IP list.

Test Plan:
```
> buck2 test mode/dev-nosan //caffe2/test/distributed/elastic/rendezvous:utils_test -- --exact 'caffe2/test/distributed/elastic/rendezvous:utils_test - test_matches_machine_hostname_returns_true_if_ip_address_match_between_hosts (utils_test.UtilsTest)'

Tests finished: Pass 1. Fail 0. Fatal 0. Skip 0. 0 builds failed
```
Unit tests

Reviewed By: d4l3k

Differential Revision: D41373962

fbshipit-source-id: 7f138e2ef74b057f70271d32b605710bc5d287f6
pytorchmergebot pushed a commit that referenced this issue Dec 19, 2022
…th the current IP list (#90221)

Summary:
Pull Request: #79388

Fix torch.distributed.run init connect timeout by comparing `host` with the current IP list.

Test Plan: unit tests

Differential Revision: D41373962

Pull Request resolved: #90221
Approved by: https://github.com/d4l3k