[RFC] Asynchronous Error Handling for Distributed Training with NCCL #46874

Closed · osalpekar opened this issue Oct 26, 2020 · 1 comment
Labels: oncall: distributed, triaged
Assignee: osalpekar

osalpekar (Member) commented on Oct 26, 2020:

Motivation

DistributedDataParallel (DDP) training on GPUs using the NCCL process group routinely hangs, which is an unpleasant experience for users of PyTorch Distributed. In various situations (desynchronizations, high GPU utilization, etc.), one of the ranks in the process group may be slower to call the collective than the remaining ranks. Often, the slow rank is blocked on a previous CUDA operation and due to high GPU utilization, it is not able to proceed to calling the collective. Meanwhile, all the remaining ranks block, waiting for the stuck rank, which causes the entire training to hang indefinitely. The training can often stay in this state for hours or until it is manually killed by the user.

Alternatives

Ideally we would like to provide a mechanism to detect and recover from hangs without any performance overhead. One feature already exists to detect hangs: NCCL_BLOCKING_WAIT. Blocking wait blocks the main thread when the wait function on the associated WorkNCCL object is called; this wait function polls at a fixed interval to check whether the collective has timed out and, if it has, throws an exception from the main thread. However, because it blocks the main thread, this approach may incur up to a 60% regression in training performance.
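
For illustration, the blocking-wait behavior can be sketched roughly in Python as follows (the real logic is C++ inside ProcessGroupNCCL; the method names, polling interval, and helpers below are illustrative assumptions, not the actual API):

import time
from datetime import datetime, timedelta

def blocking_wait(work, timeout=timedelta(minutes=30), poll_interval_s=0.01):
    # Runs on the main thread: it cannot make progress until the
    # collective completes or the timeout fires.
    start = datetime.now()
    while not work.is_completed():                 # hypothetical completion check
        if datetime.now() - start > timeout:
            work.abort_nccl_communicators()        # hypothetical abort helper
            raise RuntimeError("NCCL collective operation timed out")
        time.sleep(poll_interval_s)                # main thread stalls here

Because every wait() call stalls the main thread in this way, the timeout detection comes at the cost of the reported slowdown of up to 60%.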

Pitch

An analysis of the blocking-wait functionality suggests that error handling and timeout checking must happen asynchronously. The existing ncclCommWatchdogThread polls for NCCL errors at a fixed interval and aborts the associated NCCL communicators so that future NCCL functions do not operate on corrupted data. We can additionally make the ncclCommWatchdogThread check for timed-out collectives and, if necessary, set an appropriate exception on the WorkNCCL objects associated with those collectives.
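
A rough sketch of the extended watchdog loop (the real ncclCommWatchdogThread is a C++ thread inside ProcessGroupNCCL; the Python names and methods below are hypothetical stand-ins):

import time
from datetime import datetime

def nccl_comm_watchdog(outstanding_work, poll_interval_s=1.0):
    # Runs on a helper thread, so the main training thread is never blocked.
    while True:
        for work in list(outstanding_work):
            if work.has_nccl_error():                              # existing behavior
                work.abort_nccl_communicators()
                work.set_exception(RuntimeError("NCCL error detected"))
            elif datetime.now() - work.start_time > work.timeout:  # proposed addition
                work.abort_nccl_communicators()
                work.set_exception(RuntimeError("NCCL collective timed out"))
        time.sleep(poll_interval_s)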

We cannot surface exceptions set on the WorkNCCL objects by blocking the main thread or by relying on some other trigger (such as the next collective call), since this may incur a large performance overhead and may not work for all workloads. Instead, we introduce a new helper thread, the workCleanupThread. Every time a collective is called, we add its WorkNCCL object to a list, and the workCleanupThread iterates through this list of ongoing collectives. Collectives that have completed successfully (checked via cudaEventQuery) are removed from the list. For WorkNCCL objects that have an exception set (due to either errors or timeouts detected by the watchdog), we rethrow the exception. Since this exception is thrown from a helper thread, the training process will crash.
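
The cleanup thread can be sketched in the same illustrative style (again, the real workCleanupThread is C++; the list handling, locking, and method names here are assumptions made for readability):

import time

def work_cleanup_loop(work_list, work_list_lock, poll_interval_s=1.0):
    while True:
        with work_list_lock:
            for work in list(work_list):
                if work.exception() is not None:
                    # Set earlier by the watchdog (NCCL error or timeout).
                    # Raising from this helper thread crashes the process.
                    raise work.exception()
                if work.is_completed():            # e.g. backed by cudaEventQuery
                    work_list.remove(work)         # finished successfully; drop it
        time.sleep(poll_interval_s)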

Due to the asynchronous nature of detecting and surfacing errors, this feature has little to no performance overhead for DDP training on even the most complex models.

Usage

To enable this feature, set the environment variable NCCL_ASYNC_ERROR_HANDLING to 1. The timeout after which stuck collectives are aborted can be configured when initializing the process group:

from datetime import timedelta

import torch.distributed as dist

dist.init_process_group(
    ...,  # other init args elided
    backend="nccl",
    timeout=timedelta(seconds=30),  # Set your desired timeout here. The default is 30 minutes.
)
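
The environment variable is typically exported in the launching environment (export NCCL_ASYNC_ERROR_HANDLING=1). As a sketch, it can also be set from Python, assuming this happens before any rank constructs the NCCL process group:

import os

# Must be set before init_process_group creates the NCCL process group.
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
# Do not also set NCCL_BLOCKING_WAIT; only one of the two should be set.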

Using this feature by itself (while training with DDP over the NCCL backend) allows users to abort stuck collectives and thereby save compute time that would otherwise be wasted on the hang. Using it together with torchelastic, however, allows training to continue even after the hang: this feature crashes the training process after detecting a stuck collective, torchelastic sees the SIGABRT from the training process, and training restarts from the last checkpoint. Together this provides a comprehensive method for detecting and recovering from hangs with little performance overhead.
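
For the restart to be useful, the training script should checkpoint periodically and resume from the latest checkpoint when torchelastic relaunches it. A minimal sketch of that pattern (the checkpoint path and helper names are illustrative):

import os
import torch

CKPT_PATH = "/tmp/checkpoint.pt"  # illustrative path shared across restarts

def save_checkpoint(model, optimizer, epoch):
    torch.save({"model": model.state_dict(),
                "optim": optimizer.state_dict(),
                "epoch": epoch}, CKPT_PATH)

def load_checkpoint(model, optimizer):
    if not os.path.exists(CKPT_PATH):
        return 0                                   # fresh start
    ckpt = torch.load(CKPT_PATH, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optim"])
    return ckpt["epoch"] + 1                       # resume after the crash/restart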

Lastly, this feature is separate from NCCL_BLOCKING_WAIT, so only one of these two environment variables should be set during training.

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528 @osalpekar @jiayisuse @agolynski

osalpekar (Member, Author) commented:

Closed by stack starting from PR #41050

osalpekar added a commit to osalpekar/elastic that referenced this issue Nov 9, 2020
Summary:
NCCL Async Error Handling is a new mechanism implemented in ProcessGroupNCCL to provide reliability for DDP training runs using NCCL. See here for a more detailed background and implementation details: pytorch/pytorch#46874.

At a high level, this system was designed to ensure that desynchronizations, high GPU utilization, and NCCL errors don't cause indefinite hangs in distributed training runs. It catches these errors without any perf impact and brings down the training process, so torchelastic can detect the failure and restart training from the previous checkpoint. The time after which stuck collectives are detected can be tuned using the `timeout` argument to `init_process_group`.

Differential Revision: D23610237

fbshipit-source-id: dbf4acbcaa470eccb0c2d675652df644969f8689
Further commits with the same summary referenced this issue in osalpekar/elastic, pytorch/elastic, osalpekar/ClassyVision, facebookresearch/ClassyVision, and eth-easl/elastic between Nov 2020 and Feb 2022.