-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NCCL] use cudaEventQuery instead of cudaStreamAddCallback to catch NCCL errors #43232
Changes from all commits
94eb6ed
0dc34c9
22939e1
5dce7fe
76a38bf
6c25257
d61c6f2
86e2482
1d0a139
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -306,6 +306,19 @@ void ProcessGroupNCCL::WorkNCCL::checkAndThrowException() { | |
} | ||
} | ||
|
||
bool ProcessGroupNCCL::WorkNCCL::isCompletedAndThrowException() { | ||
checkAndSetException(); | ||
std::lock_guard<std::mutex> lock(mutex_); | ||
if (exception_ || finishedGPUExecutionInternal()) { | ||
completed_ = true; | ||
if (exception_) { | ||
std::rethrow_exception(exception_); | ||
} | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
void ProcessGroupNCCL::WorkNCCL::handleNCCLGuard() { | ||
std::lock_guard<std::mutex> lock(mutex_); | ||
completed_ = true; | ||
|
@@ -632,10 +645,9 @@ void ProcessGroupNCCL::workCleanupLoop() { | |
for (auto it = workList_.begin(); it != workList_.end(); | ||
/* no increment*/) { | ||
auto& work = *it; | ||
if (work->isCompleted()) { | ||
// Handle Exceptions on failed GPU operations and remove completed | ||
// workNCCL objects from work vector. | ||
work->handleNCCLGuard(); | ||
// Handle Exceptions on failed GPU operations and remove completed | ||
// workNCCL objects from work vector. | ||
if (work->isCompletedAndThrowException()) { | ||
osalpekar marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be nice to add a LOG(ERROR) before we throw an exception mentioning the following: "Some NCCL operations have timed out/failed and due to the async nature of CUDA kernels subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we're taking the entire process down." There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Created #44988 with this change |
||
it = workList_.erase(it); | ||
} else { | ||
// Increment the iterator if the current WorkNCCL object is not | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We no longer need this function?