This repository has been archived by the owner on Jan 6, 2023. It is now read-only.

Elastic agent doesn't detect worker failures in NCCL #134

Closed
ruipeterpan opened this issue Nov 16, 2020 · 4 comments

Comments

ruipeterpan commented Nov 16, 2020

Context

I have been using torchelastic for a while to launch fault-tolerant jobs on CPUs using the gloo backend. I was switching to GPUs so that I can use broadcast and reduce. I first made the necessary modifications to move everything onto GPUs. Then I changed the backend for group initialization from gloo to nccl, hoping things would work as before. However, with nccl, when some workers get killed, the remaining workers stay in the previous rendezvous and hang, whereas the elastic agent should be able to detect a worker failure and halt all workers.

Current Behavior

With the nccl backend, when a worker is killed, the remaining workers hang instead of throwing a RuntimeError during all_reduce(), as they do with the gloo backend.

The workers that are killed output this (which is expected):

...
[INFO] 2020-11-16 05:37:30,257 api: [default] All workers successfully finished.
>

However, for the remaining workers, the elastic agent doesn't declare the process group as failed. Here is the log obtained by using export NCCL_DEBUG=INFO:

multigpu:141:158 [0] include/socket.h:416 NCCL WARN Net : Connection closed by remote peer
multigpu:141:158 [0] NCCL INFO transport/net_socket.cc:405 -> 2
multigpu:141:158 [0] NCCL INFO include/net.h:28 -> 2
multigpu:141:158 [0] NCCL INFO transport/net.cc:357 -> 2
multigpu:141:158 [0] NCCL INFO proxy.cc:198 -> 2 [Proxy Thread]

Expected Behavior

Just as with gloo, after some workers are killed, the remaining workers should be able to detect a missing member during all_reduce() and throw a RuntimeError, so that the local_elastic_agent can mark the worker group as failed, halt the training, and wait for a new worker to join the next rendezvous.

The workers that are killed should output this:

...
[INFO] 2020-11-16 05:13:25,931 api: [default] All workers successfully finished.
>

The surviving workers should output this:

...
Traceback (most recent call last):
  File "worker.py", line 250, in <module>
    parse_args()
  File "worker.py", line 246, in parse_args
    init_processes(0, args)
  File "worker.py", line 219, in init_processes
    train(args)
  File "worker.py", line 130, in train
    update_gradients(model)
  File "worker.py", line 55, in update_gradients
    dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/distributed_c10d.py", line 948, in all_reduce
    work.wait()
RuntimeError: [/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:575] Connection closed by peer [10.0.1.26]:5511
[ERROR] 2020-11-16 05:23:48,975 local_elastic_agent: [default] Worker group failed
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torchelastic-0.2.0-py3.8.egg/torchelastic/agent/server/local_elastic_agent.py", line 190, in _monitor_workers
    if self._process_context.join(timeout=-1):
  File "/usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py", line 118, in join
    raise Exception(msg)
Exception:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
    fn(i, *args)
  File "/usr/local/lib/python3.8/dist-packages/torchelastic-0.2.0-py3.8.egg/torchelastic/agent/server/local_elastic_agent.py", line 79, in _wrap
    ret = fn(*args)
  File "/usr/local/lib/python3.8/dist-packages/torchelastic-0.2.0-py3.8.egg/torchelastic/distributed/launch.py", line 392, in wrapper_fn
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/usr/bin/python3.8', '-u', 'worker.py', '-pindex', '1', '-jobid', '53706', '-num_iters', '938']' returned non-zero exit status 1.

[INFO] 2020-11-16 05:23:48,975 api: [default] Worker group FAILED. 3/3 attempts left; will restart worker group
[INFO] 2020-11-16 05:23:48,975 api: [default] Stopping worker group
[INFO] 2020-11-16 05:23:48,976 api: [default] Rendezvous'ing worker group
INFO 2020-11-16 05:23:48,976 Attempting to join next rendezvous
INFO 2020-11-16 05:23:48,980 Observed existing rendezvous state: {'status': 'final', 'version': '41', 'participants': [0, 1], 'keep_alives': ['/torchelastic/p2p/run_53706/rdzv/v_41/rank_1', '/torchelastic/p2p/run_53706/rdzv/v_41/rank_0'], 'num_workers_waiting': 0}
INFO 2020-11-16 05:23:49,059 Added self to waiting list. Rendezvous full state: {"status": "final", "version": "41", "participants": [0, 1], "keep_alives": ["/torchelastic/p2p/run_53706/rdzv/v_41/rank_1", "/torchelastic/p2p/run_53706/rdzv/v_41/rank_0"], "num_workers_waiting": 1}
INFO 2020-11-16 05:23:49,065 Keep-alive key /torchelastic/p2p/run_53706/rdzv/v_41/rank_0 is not renewed.
INFO 2020-11-16 05:23:49,066 Rendevous version 41 is incomplete.
INFO 2020-11-16 05:23:49,066 Attempting to destroy it.
INFO 2020-11-16 05:23:49,072 Destroyed rendezvous version 41 successfully.
INFO 2020-11-16 05:23:49,073 Previously existing rendezvous state changed. Will re-try joining.
INFO 2020-11-16 05:23:49,073 Attempting to join next rendezvous
INFO 2020-11-16 05:23:49,089 New rendezvous state created: {'status': 'joinable', 'version': '42', 'participants': []}
INFO 2020-11-16 05:23:49,163 Joined rendezvous version 42 as rank 0. Full state: {'status': 'joinable', 'version': '42', 'participants': [0]}
INFO 2020-11-16 05:23:49,163 Waiting for remaining peers.
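
The log above suggests the agent reacts to the worker's non-zero exit status (the CalledProcessError reporting "returned non-zero exit status 1"). Here is a minimal sketch of that distinction, using plain subprocess children as stand-ins for workers. The helper name `exit_status_after` and the two child commands are hypothetical and are not torchelastic code: a worker that raises exits with status 1, while a worker that hangs gives a supervisor no exit status to observe.

```python
import subprocess
import sys


def exit_status_after(code, wait_seconds):
    """Run `code` in a child interpreter.

    Returns the child's exit status, or None if it is still running after
    `wait_seconds`. Hypothetical helper for illustration only.
    """
    child = subprocess.Popen(
        [sys.executable, "-c", code],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    try:
        return child.wait(timeout=wait_seconds)
    except subprocess.TimeoutExpired:
        # The child is still alive, so there is no exit code to react to yet.
        child.kill()
        child.wait()
        return None


# Stand-in for the gloo case: the collective raises and the worker dies.
raising_worker = "raise RuntimeError('Connection closed by peer')"
# Stand-in for the hanging nccl case: the worker blocks and never exits.
hanging_worker = "import time; time.sleep(60)"

failed_status = exit_status_after(raising_worker, wait_seconds=30)
hung_status = exit_status_after(hanging_worker, wait_seconds=1)

print("raising worker exit status:", failed_status)
print("hanging worker exit status:", hung_status)
```

The point of the sketch is only that an uncaught exception turns into an exit code a supervisor can see, whereas a blocked collective does not.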

More details

  • I'm using dist.init_process_group(backend='gloo', init_method='env://') to initialize the process group.
  • I'm using the torchelastic launcher to launch the workers:
python3.8 -m torchelastic.distributed.launch --nnodes=2 --nproc_per_node=1 --rdzv_id=53706 --rdzv_backend=etcd --rdzv_endpoint=10.0.1.26:2379 worker.py
  • OS: Linux 5.3.0-1032-azure x86_64; Ubuntu 18.04.4
  • CUDA and NCCL version: CUDA 11.0 (11.0-devel-ubuntu18.04), NCCL 2.7.8-1
  • Framework (TF, PyTorch, MXNet): PyTorch 1.7.0+cu110 on Python 3.8
  • torchelastic release: 0.2.0 (45dc33f)
  • Please let me know if I need to provide more information!
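
For anyone trying to reproduce this, the launch command above can be assembled from Python so the NCCL debug setting travels with it. This is only a sketch: `build_launch` is a hypothetical helper, it uses the current interpreter instead of a hard-coded python3.8, and the flags are copied from the command in this issue rather than from the launcher's documentation.

```python
import os
import shlex
import sys


def build_launch(worker_script, rdzv_id, rdzv_endpoint,
                 nnodes=2, nproc_per_node=1, extra_env=None):
    """Return (argv, env) for the torchelastic launcher.

    Hypothetical helper; flag names are taken from the command in this issue.
    """
    argv = [
        sys.executable, "-m", "torchelastic.distributed.launch",
        f"--nnodes={nnodes}",
        f"--nproc_per_node={nproc_per_node}",
        f"--rdzv_id={rdzv_id}",
        "--rdzv_backend=etcd",
        f"--rdzv_endpoint={rdzv_endpoint}",
        worker_script,
    ]
    # Copy the parent environment so the caller's os.environ is not mutated.
    env = dict(os.environ)
    env.update(extra_env or {})
    return argv, env


argv, env = build_launch(
    "worker.py",
    rdzv_id="53706",
    rdzv_endpoint="10.0.1.26:2379",
    extra_env={"NCCL_DEBUG": "INFO"},
)
# Show the command without the interpreter path.
print(shlex.join(argv[1:]))

# To actually start the agent on this node (needs torchelastic installed and
# a reachable etcd endpoint):
#   import subprocess
#   subprocess.run(argv, env=env, check=True)
```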
@ruipeterpan changed the title from "Question about behavior difference on gloo and nccl" to "Elastic agent doesn't detect worker failures in NCCL" on Nov 16, 2020

tchaton commented Nov 16, 2020

Hey there,

I think I'm running into the same problem.
Any updates?

Best regards,
Thomas Chaton.

@ruipeterpan
Author

@tchaton Unfortunately I haven't been able to resolve this issue :(

@kiukchung
Contributor

Thanks for the question. Have you tried setting the NCCL_BLOCKING_WAIT env var (or, if you are using PyTorch nightly, NCCL_ASYNC_ERROR_HANDLING) on your trainers?
https://pytorch.org/docs/stable/distributed.html

@ruipeterpan
Author

Hey @kiukchung, thanks for the pointer! Setting the environment variable with export NCCL_BLOCKING_WAIT=1 makes the previously hanging worker throw the following error, which is subsequently caught by the elastic agent.

Traceback (most recent call last):
  File "worker.py", line 251, in <module>
    parse_args()
  File "worker.py", line 247, in parse_args
    init_processes(0, args)
  File "worker.py", line 220, in init_processes
    train(args)
  File "worker.py", line 130, in train
    update_gradients(model)
  File "worker.py", line 55, in update_gradients
    dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/distributed_c10d.py", line 948, in all_reduce
    work.wait()
RuntimeError: NCCL error: unhandled system error, NCCL version 2.7.8
[ERROR] 2020-11-16 21:08:51,160 local_elastic_agent: [default] Worker group failed

Thanks again for the quick help! Closing this issue.
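
For reference, the same fix can probably be applied from inside the worker instead of the shell. This is a sketch under a few assumptions: the helper names are hypothetical, the variable is set before the process group is created (my understanding is that it is read at that point), and the 5-minute timeout is an arbitrary value chosen for illustration. The PyTorch docs describe the `timeout` argument of `init_process_group` as applying to nccl only when blocking wait (or async error handling) is enabled.

```python
import os
from datetime import timedelta


def enable_nccl_failure_detection(env=os.environ):
    """Turn on NCCL blocking wait unless the caller already chose a value.

    Hypothetical helper; equivalent in intent to `export NCCL_BLOCKING_WAIT=1`.
    """
    env.setdefault("NCCL_BLOCKING_WAIT", "1")
    return env


def init_nccl_process_group(timeout=timedelta(minutes=5)):
    """Create the nccl process group with blocking wait enabled."""
    # Imported lazily so the helper above works on machines without PyTorch.
    import torch.distributed as dist

    # Assumption: the variable has to be in place before the group is created.
    enable_nccl_failure_detection()
    dist.init_process_group(backend="nccl", init_method="env://", timeout=timeout)
```

With this in place, a collective that loses a peer should raise in the surviving worker (as in the traceback above) instead of hanging, and the non-zero exit lets the elastic agent restart the group.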
