[NCCL] use cudaEventQuery instead of cudaStreamAddCallback to catch NCCL errors #43232
Conversation
This method avoids the expensive serialization from adding a callback and instead polls the CUDA event to check for completion. It then performs the same error handling to throw an exception from the workCleanupThread.

Differential Revision: [D22929042](https://our.internmc.facebook.com/intern/diff/D22929042/)
Pull Request resolved: #43232
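For context, a minimal sketch of the polling approach described above, assuming a plain CUDA event recorded on the NCCL stream; the helper name and error handling are illustrative, not the actual ProcessGroupNCCL code:

```cpp
#include <cuda_runtime.h>
#include <stdexcept>

// Illustrative sketch: instead of registering a cudaStreamAddCallback (which
// serializes work on the stream), the cleanup thread periodically polls the
// CUDA event recorded after the NCCL collective.
bool eventCompleted(cudaEvent_t event) {
  cudaError_t err = cudaEventQuery(event);
  if (err == cudaSuccess) {
    return true;   // all work recorded before the event has finished
  }
  if (err == cudaErrorNotReady) {
    return false;  // still running; poll again on the next loop iteration
  }
  // Any other code is a real CUDA error; surface it as an exception so the
  // cleanup thread can throw it to the user, as described above.
  throw std::runtime_error(cudaGetErrorString(err));
}
```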
💊 CI failures summary and remediations, as of commit 1d0a139 (more details on the Dr. CI page):
❄️ 1 failure tentatively classified as flaky, but reruns have not yet been triggered to confirm: pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test (1/1), Step: "Run tests"
Looks good overall, have a few minor comments inline.
Can we also have another PR on top of this to remove the busy waiting in `wait()` and replace it with `cudaEventSynchronize`?
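If it helps, a minimal sketch of what that blocking variant could look like; this is illustrative only, not the real `wait()`, which carries additional bookkeeping:

```cpp
#include <cuda_runtime.h>
#include <stdexcept>

// Illustrative sketch: block the calling host thread until all work captured
// by the event has completed, instead of spinning on cudaEventQuery.
void waitForEvent(cudaEvent_t event) {
  cudaError_t err = cudaEventSynchronize(event);
  if (err != cudaSuccess) {
    throw std::runtime_error(cudaGetErrorString(err));
  }
}
```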
}
return false;
}

void ProcessGroupNCCL::WorkNCCL::handleNCCLGuard() {
We no longer need this function?
work->handleNCCLGuard();
// Handle Exceptions on failed GPU operations and remove completed
// workNCCL objects from work vector.
if (work->isCompletedAndThrowException()) {
It would be nice to add a LOG(ERROR) before we throw an exception mentioning the following: "Some NCCL operations have timed out/failed and due to the async nature of CUDA kernels subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we're taking the entire process down."
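A sketch of where such a log line could sit in the exception path; the `exception_` member follows the naming in this PR's snippets, but the exact placement shown here is illustrative:

```cpp
// Illustrative: log loudly before surfacing the stored NCCL error, since the
// process is about to go down.
if (exception_) {
  LOG(ERROR) << "Some NCCL operations have timed out/failed and due to the "
                "async nature of CUDA kernels, subsequent GPU operations might "
                "run on corrupted/incomplete data. To avoid this inconsistency, "
                "we are taking the entire process down.";
  std::rethrow_exception(exception_);
}
```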
Created #44988 with this change
Stack from ghstack:
This Commit:
Here we introduce some experimental lock optimizations. This essentially combines the check for workNCCL completion with `handleNCCLGuard` so that the lock is acquired once instead of twice. The performance implications of this are still being measured; the first 5 PRs in this stack work perfectly well without this optimization.
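A hedged sketch of the single-lock pattern this refers to; the helper and member names besides `handleNCCLGuard` are assumptions, not the exact implementation:

```cpp
// Illustrative sketch: the completion check and the error handling share one
// lock acquisition instead of locking separately in two member functions.
bool ProcessGroupNCCL::WorkNCCL::isCompletedAndThrowException() {
  std::lock_guard<std::mutex> lock(mutex_);  // acquired once for both steps
  if (!finishedGPUExecutionInternal()) {
    return false;                            // collective still running
  }
  if (exception_) {
    std::rethrow_exception(exception_);      // surface the stored NCCL error
  }
  return true;                               // completed without error
}
```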
This Stack:
The purpose of this stack is to fix the hanging behavior observed when using PyTorch DDP training with NCCL. In various situations (desynchronization, high GPU utilization, etc.), NCCL collectives may hang while waiting on an unresponsive worker. This stack detects such hangs and aborts timed-out collectives by throwing a user-visible exception, all with minimal perf regression. Training can then be restarted from a previous checkpoint with something like torchelastic.
Differential Revision: D22929042