Skip to content

Commit

Permalink
[WIP] Use cudaEventQuery instead of cudaStreamAddCallback to catch NCCL errors
Browse files Browse the repository at this point in the history

This method avoids the expensive serialization from adding a callback
and instead polls the CUDA event to check for completion. It then performs the
same error handling to throw an exception from the workCleanupThread.

Differential Revision: [D22929042](https://our.internmc.facebook.com/intern/diff/D22929042/)

ghstack-source-id: 110188426
Pull Request resolved: #43232
  • Loading branch information
osalpekar committed Aug 18, 2020
1 parent c51c724 commit 75679c3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 17 deletions.
33 changes: 16 additions & 17 deletions torch/lib/c10d/ProcessGroupNCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,12 +225,6 @@ ncclResult_t ncclAlltoallv(
}
#endif

// Stream callback that surfaces NCCL errors on the user stream: recovers the
// WorkNCCL object smuggled through the opaque `data` pointer and delegates to
// its NCCL error-handling routine. The stream and status arguments supplied by
// the CUDA runtime are intentionally ignored.
void CUDART_CB
errorGuard(cudaStream_t /* unused */, cudaError_t /* unused */, void* data) {
  auto* work = static_cast<ProcessGroupNCCL::WorkNCCL*>(data);
  work->handleNCCLGuard();
}

} // namespace

const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 10000;
Expand Down Expand Up @@ -310,6 +304,19 @@ void ProcessGroupNCCL::WorkNCCL::checkAndThrowException() {
}
}

// Checks whether this WorkNCCL has finished, surfacing any captured error.
// First refreshes the stored exception state, then — under the object lock —
// treats the work as done when either an exception was recorded or the GPU
// execution has completed. On completion it marks `completed_`, rethrows a
// recorded exception (if any) to the caller, and otherwise returns true;
// returns false while the work is still in flight.
bool ProcessGroupNCCL::WorkNCCL::isCompletedAndThrowException() {
  checkAndSetException();
  std::lock_guard<std::mutex> lock(mutex_);
  const bool done = exception_ || finishedGPUExecutionInternal();
  if (!done) {
    return false;
  }
  completed_ = true;
  if (exception_) {
    // Propagate the captured NCCL/CUDA error to the caller (e.g. the
    // workCleanupThread). The lock_guard releases mutex_ during unwinding.
    std::rethrow_exception(exception_);
  }
  return true;
}

void ProcessGroupNCCL::WorkNCCL::handleNCCLGuard() {
std::lock_guard<std::mutex> lock(mutex_);
completed_ = true;
Expand All @@ -330,16 +337,6 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeStreams() {

// Block the current stream on the NCCL stream
cudaEvents_[i].block(currentStream);
// Enqueue guard function as a callback on the current stream to throw
// user-stream exception upon error.
if (!blockingWait_) {
cudaStreamAddCallback(currentStream, errorGuard, this, 0);
}
// If we use the work to do barrier, we should block here
if (!barrierTensors_.empty()) {
at::cuda::CUDAGuard gpuGuard(devices_[i]);
AT_CUDA_CHECK(cudaDeviceSynchronize());
}
}
}

Expand Down Expand Up @@ -631,7 +628,9 @@ void ProcessGroupNCCL::workCleanupLoop() {
/* no increment*/) {
auto& work = *it;
if (work->isCompleted()) {
// Remove all Completed WorkNCCL Objects from the Vector
// Handle Exceptions on failed GPU operations and remove completed
// workNCCL objects from work vector.
work->handleNCCLGuard();
it = workVector_.erase(it);
} else {
// Increment the iterator if the current WorkNCCL object is not
Expand Down
4 changes: 4 additions & 0 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ class ProcessGroupNCCL : public ProcessGroup {
// It actually returns a FutureNCCL object which is a sub class Future.
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

// Checks for completion of the WorkNCCL object, and if complete, handles
// any caught errors or exceptions. Returns true if the work is completed.
bool isCompletedAndThrowException();

// Helper function that sets an exception_ptr on the WorkNCCL object.
void setException(std::exception_ptr exception_ptr);

Expand Down

0 comments on commit 75679c3

Please sign in to comment.