[NCCL] Additional Lock Optimizations for handleNCCLGuard from Cleanup Loop

Pull Request resolved: #43232

**This Commit:**
Here we introduce some experimental lock optimizations: the check for WorkNCCL completion and the `handleNCCLGuard` exception handling are folded into a single call, so the lock is acquired once instead of twice. The performance implications of this are still being measured. The first 5 diffs in this stack work perfectly well without this optimization.
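For illustration, here is a minimal sketch of the before/after locking pattern, using simplified stand-ins for the WorkNCCL members (the real code also polls CUDA events and calls `checkAndSetException()` before taking the lock):

```cpp
#include <mutex>

// Illustrative sketch only; member names are simplified stand-ins,
// not the exact ProcessGroupNCCL::WorkNCCL implementation.
struct WorkSketch {
  std::mutex mutex_;
  bool completed_ = false;
  bool gpuExecutionFinished_ = false;  // assumed stand-in for the CUDA-event check
  std::exception_ptr exception_;

  // Before: the cleanup loop called isCompleted() and then handleNCCLGuard(),
  // taking mutex_ twice for every completed work object.
  bool isCompleted() {
    std::lock_guard<std::mutex> lock(mutex_);  // lock acquisition #1
    return exception_ || gpuExecutionFinished_;
  }
  void handleNCCLGuard() {
    std::lock_guard<std::mutex> lock(mutex_);  // lock acquisition #2
    completed_ = true;
    if (exception_) {
      std::rethrow_exception(exception_);
    }
  }

  // After: one call, one critical section.
  bool isCompletedAndThrowException() {
    std::lock_guard<std::mutex> lock(mutex_);  // single lock acquisition
    if (exception_ || gpuExecutionFinished_) {
      completed_ = true;
      if (exception_) {
        std::rethrow_exception(exception_);
      }
      return true;
    }
    return false;
  }
};
```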

**This Stack:**
The purpose of this stack is to fix the hanging behavior observed when using PyTorch DDP training with NCCL. In various situations (desynchronization, high GPU utilization, etc.), NCCL collectives may hang while waiting on an unresponsive worker. This stack detects such hangs and aborts timed-out collectives by throwing a user-visible exception, all with minimal perf regression. Training can then be restarted from a previous checkpoint with something like torchelastic.
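A rough sketch of the detection idea (names and the default timeout below are assumptions for illustration, not the actual ProcessGroupNCCL watchdog):

```cpp
#include <chrono>
#include <stdexcept>

// Sketch: each collective records when it started, and a background pass
// compares the elapsed time against a configured timeout, throwing a
// user-visible exception instead of blocking forever.
struct PendingCollective {
  std::chrono::steady_clock::time_point startTime =
      std::chrono::steady_clock::now();
  std::chrono::milliseconds opTimeout{30000};  // assumed default

  bool timedOut() const {
    return std::chrono::steady_clock::now() - startTime > opTimeout;
  }
};

void checkForTimeout(const PendingCollective& op) {
  if (op.timedOut()) {
    // The real implementation also aborts the NCCL communicators so blocked
    // kernels are torn down before the exception propagates to the user.
    throw std::runtime_error("NCCL collective timed out");
  }
}
```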

ghstack-source-id: 110300867

Differential Revision: [D22929042](https://our.internmc.facebook.com/intern/diff/D22929042/)
osalpekar committed Aug 20, 2020
1 parent e0ffdff commit 7b570e7
Showing 2 changed files with 18 additions and 2 deletions.
16 changes: 14 additions & 2 deletions torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -304,6 +304,19 @@ void ProcessGroupNCCL::WorkNCCL::checkAndThrowException() {
   }
 }
 
+bool ProcessGroupNCCL::WorkNCCL::isCompletedAndThrowException() {
+  checkAndSetException();
+  std::lock_guard<std::mutex> lock(mutex_);
+  if (exception_ || finishedGPUExecutionInternal()) {
+    completed_ = true;
+    if (exception_) {
+      std::rethrow_exception(exception_);
+    }
+    return true;
+  }
+  return false;
+}
+
 void ProcessGroupNCCL::WorkNCCL::handleNCCLGuard() {
   std::lock_guard<std::mutex> lock(mutex_);
   completed_ = true;
@@ -615,10 +628,9 @@ void ProcessGroupNCCL::workCleanupLoop() {
     for (auto it = workVector_.begin(); it != workVector_.end();
          /* no increment*/) {
       auto& work = *it;
-      if (work->isCompleted()) {
+      if (work->isCompletedAndThrowException()) {
         // Handle Exceptions on failed GPU operations and remove completed
         // workNCCL objects from work vector.
-        work->handleNCCLGuard();
         it = workVector_.erase(it);
       } else {
         // Increment the iterator if the current WorkNCCL object is not
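The `workCleanupLoop` fragment above also leans on the erase-while-iterating idiom, where `erase()` returns the iterator to the next element. A self-contained sketch of that pattern, with a simplified stand-in for WorkNCCL:

```cpp
#include <memory>
#include <vector>

// Simplified stand-in for WorkNCCL: isCompletedAndThrowException() just
// reports a flag here instead of checking CUDA events and rethrowing.
struct WorkStub {
  bool done = false;
  bool isCompletedAndThrowException() const { return done; }
};

// Mirrors the loop in workCleanupLoop(): never increment in the for header;
// either erase (which returns the next valid iterator) or advance manually.
void cleanupPass(std::vector<std::shared_ptr<WorkStub>>& workVector) {
  for (auto it = workVector.begin(); it != workVector.end();
       /* no increment */) {
    if ((*it)->isCompletedAndThrowException()) {
      it = workVector.erase(it);  // erase invalidates it; use the return value
    } else {
      ++it;  // still in flight; keep it and move on
    }
  }
}
```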
4 changes: 4 additions & 0 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
@@ -94,6 +94,10 @@ class ProcessGroupNCCL : public ProcessGroup {
   // It actually returns a FutureNCCL object which is a sub class Future.
   c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
 
+  // Checks for completion of the WorkNCCL object, and if complete, handles
+  // any caught errors or exceptions. Returns true if the work is completed.
+  bool isCompletedAndThrowException();
+
   // Helper function that sets an exception_ptr on the WorkNCCL object.
   void setException(std::exception_ptr exception_ptr);
 
