Skip to content

Commit

Permalink
[NCCL] Modularize ncclCommWatchdog (#46051)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #46051

Creates a subroutine for aborting timed out collectives. This should help modularize the ncclCommWatchdog a bit, since it is growing too large.
ghstack-source-id: 114398496

Test Plan:
Successful Flow Run:
f225037915
f217609101

Reviewed By: jiayisuse

Differential Revision: D23607535

fbshipit-source-id: 0b1c9483bcd3a41847fc8c0bf6b22cdba01fb1e6
  • Loading branch information
osalpekar authored and facebook-github-bot committed Oct 16, 2020
1 parent be0c431 commit 2e2fe8c
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 24 deletions.
52 changes: 28 additions & 24 deletions torch/lib/c10d/ProcessGroupNCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,33 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
}
}

void ProcessGroupNCCL::abortTimedOutCollectives(std::unordered_set<std::string>& abortedCommIds) {
std::unique_lock<std::mutex> lock(workMetaListMutex_);
for (auto& work : workMetaList_) {
work.checkAndSetException();
// Aborting NCCL Communicators due to errors is already handled above.
if (work.exception()) {
continue;
}

// Check for Timeouts in the WorkNCCL Operations, and abort all
// communicators accordingly.
if (work.timedOut()) {
LOG(INFO)
<< "[Rank " << rank_
<< "] Watchdog caught collective operation timeout for work: "
<< work;
std::exception_ptr exception_ptr = std::make_exception_ptr(
std::runtime_error("NCCL Operation Timed Out"));
work.setException(exception_ptr);
for (const auto& ncclComm : work.ncclComms_) {
ncclComm->ncclCommAbort();
abortedCommIds.emplace(buildNcclUniqueIdStr(ncclComm->getNcclId()));
}
}
}
}

void ProcessGroupNCCL::ncclCommWatchdog() {
try {
LOG(INFO) << "[Rank " << rank_ << "] NCCL watchdog thread started!";
Expand Down Expand Up @@ -556,30 +583,7 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() {
}

if (asyncErrorHandling_) {
std::unique_lock<std::mutex> lock(workMetaListMutex_);
for (auto& work : workMetaList_) {
work.checkAndSetException();
// Aborting NCCL Communicators due to errors is already handled above.
if (work.exception()) {
continue;
}

// Check for Timeouts in the WorkNCCL Operations, and abort all
// communicators accordingly.
if (work.timedOut()) {
LOG(INFO)
<< "[Rank " << rank_
<< "] Watchdog caught collective operation timeout for work: "
<< work;
std::exception_ptr exception_ptr = std::make_exception_ptr(
std::runtime_error("NCCL Operation Timed Out"));
work.setException(exception_ptr);
for (const auto& ncclComm : work.ncclComms_) {
ncclComm->ncclCommAbort();
abortedCommIds.emplace(buildNcclUniqueIdStr(ncclComm->getNcclId()));
}
}
}
abortTimedOutCollectives(abortedCommIds);
}

if (blockingWait_) {
Expand Down
5 changes: 5 additions & 0 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,11 @@ class ProcessGroupNCCL : public ProcessGroup {

void ncclCommWatchdogInternal();

// This function iterates through the list of WorkNCCL objects in the
// workList_ corresponding to incomplete collectives and then aborts NCCL
// communicators associated with timed out collectives.
void abortTimedOutCollectives(std::unordered_set<std::string>& abortedCommIds);

void workCleanupLoop();

protected:
Expand Down

0 comments on commit 2e2fe8c

Please sign in to comment.