[NCCL] use cudaEventQuery instead of cudaStreamAddCallback to catch NCCL errors #43232

Closed
wants to merge 9 commits into from
20 changes: 16 additions & 4 deletions torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -306,6 +306,19 @@ void ProcessGroupNCCL::WorkNCCL::checkAndThrowException() {
   }
 }

+bool ProcessGroupNCCL::WorkNCCL::isCompletedAndThrowException() {
+  checkAndSetException();
+  std::lock_guard<std::mutex> lock(mutex_);
+  if (exception_ || finishedGPUExecutionInternal()) {
+    completed_ = true;
+    if (exception_) {
+      std::rethrow_exception(exception_);
+    }
+    return true;
+  }
+  return false;
+}
+
 void ProcessGroupNCCL::WorkNCCL::handleNCCLGuard() {
Contributor

We no longer need this function?

   std::lock_guard<std::mutex> lock(mutex_);
   completed_ = true;
@@ -632,10 +645,9 @@ void ProcessGroupNCCL::workCleanupLoop() {
   for (auto it = workList_.begin(); it != workList_.end();
        /* no increment*/) {
     auto& work = *it;
-    if (work->isCompleted()) {
-      // Handle Exceptions on failed GPU operations and remove completed
-      // workNCCL objects from work vector.
-      work->handleNCCLGuard();
+    // Handle Exceptions on failed GPU operations and remove completed
+    // workNCCL objects from work vector.
+    if (work->isCompletedAndThrowException()) {
osalpekar marked this conversation as resolved.
Contributor

It would be nice to add a LOG(ERROR) before we throw the exception, with a message along these lines: "Some NCCL operations have timed out or failed and, due to the async nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we're taking the entire process down."

Member Author

Created #44988 with this change
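
The suggestion in this thread (log an error, then let the exception propagate) could look roughly like the sketch below. It is only an illustration under stated assumptions: `logAndRethrow` and `kAsyncErrorNotice` are hypothetical names, `std::cerr` stands in for `LOG(ERROR)`, and nothing here is taken from the change referenced above.

```cpp
#include <cassert>
#include <exception>
#include <iostream>
#include <stdexcept>
#include <string>

// Hypothetical message text, adapted from the review comment in this thread.
const std::string kAsyncErrorNotice =
    "Some NCCL operations have failed or timed out. Due to the asynchronous "
    "nature of CUDA kernels, subsequent GPU operations might run on "
    "corrupted/incomplete data. To avoid this inconsistency, we are taking "
    "the entire process down.";

// Writes the notice to `log` (std::cerr stands in for LOG(ERROR) here) and
// then rethrows the stored exception so the caller still sees the failure.
// Precondition: `error` must hold an exception.
[[noreturn]] void logAndRethrow(const std::exception_ptr& error,
                                std::ostream& log = std::cerr) {
  assert(error != nullptr);
  log << kAsyncErrorNotice << std::endl;
  std::rethrow_exception(error);
}
```

Logging before the rethrow means the explanation reaches the log even if nothing catches the exception and the process terminates.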

       it = workList_.erase(it);
     } else {
       // Increment the iterator if the current WorkNCCL object is not
4 changes: 4 additions & 0 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
@@ -95,6 +95,10 @@ class ProcessGroupNCCL : public ProcessGroup {
     // It actually returns a FutureNCCL object which is a sub class Future.
     c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

+    // Checks for completion of the WorkNCCL object, and if complete, handles
+    // any caught errors or exceptions. Returns true if the work is completed.
+    bool isCompletedAndThrowException();
+
     // Helper function that sets an exception_ptr on the WorkNCCL object.
     void setException(std::exception_ptr exception_ptr);

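For readers who want to try the poll-and-rethrow pattern from this diff outside PyTorch, here is a minimal, CUDA-free sketch. `MockWork` and `cleanupOnce` are hypothetical stand-ins, not the real `WorkNCCL` or `workCleanupLoop()`. The `gpuDone` callable is an assumption that plays the role of a non-blocking completion poll (the PR title suggests `cudaEventQuery` serves that purpose in the real code). The sketch also omits the `checkAndSetException()` call that the real method makes first.

```cpp
#include <cassert>
#include <cstddef>
#include <exception>
#include <functional>
#include <list>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <utility>

// Hypothetical stand-in for WorkNCCL; not the PyTorch class.
class MockWork {
 public:
  // `gpuDone` stands in for a non-blocking "has the GPU work finished?" poll.
  explicit MockWork(std::function<bool()> gpuDone)
      : gpuDone_(std::move(gpuDone)) {}

  // Records an error, mirroring the setException helper in the header.
  void setException(std::exception_ptr e) {
    std::lock_guard<std::mutex> lock(mutex_);
    exception_ = e;
  }

  // Returns true once the work has finished. If an error was recorded,
  // marks the work completed and rethrows that error instead of returning.
  bool isCompletedAndThrowException() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (exception_ || gpuDone_()) {
      completed_ = true;
      if (exception_) {
        std::rethrow_exception(exception_);
      }
      return true;
    }
    return false;
  }

 private:
  std::function<bool()> gpuDone_;
  std::mutex mutex_;
  std::exception_ptr exception_;
  bool completed_ = false;
};

// One pass of a cleanup loop: erase finished work and keep pending work.
// An exception from a failed work item propagates to the caller.
std::size_t cleanupOnce(std::list<std::shared_ptr<MockWork>>& workList) {
  std::size_t removed = 0;
  for (auto it = workList.begin(); it != workList.end();
       /* no increment */) {
    if ((*it)->isCompletedAndThrowException()) {
      it = workList.erase(it);
      ++removed;
    } else {
      ++it;
    }
  }
  return removed;
}
```

Folding the completion check and the rethrow into one call means the loop cannot observe a work item as complete without also surfacing its stored error. That appears to be the reason the separate `handleNCCLGuard()` call is dropped from the loop in the diff.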