Add an explicit _shutdown method to ProcessGroupNCCL (#111392)
Currently, the only way ProcessGroupNCCL shuts down its background threads and aborts all communicators is via the destructor.

However, given how Python garbage collection works and the fact that code often holds references to the PG in multiple places, in practice calling `destroy_process_group` does not actually end up invoking the destructor.

As a result, in this PR I'm adding an explicit shutdown method that users can call to clean up all resources.
Pull Request resolved: #111392
Approved by: https://github.com/XilunWu, https://github.com/wanchaol, https://github.com/fduwjj
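
To make the change concrete, here is a minimal usage sketch of the new hook from user code. It assumes a CUDA/NCCL job launched with torchrun; `_get_default_group`, `_get_backend`, and `_shutdown` are private hooks mirrored from the tests below, not a stable public API.

# Minimal sketch of the explicit-shutdown pattern exercised by test_close_pg below.
# Assumes a torchrun-style NCCL launch with at least one CUDA GPU per rank.
import os
import torch
import torch.distributed as dist

# As in the test below: handle the abort in-process rather than via the async
# error-handling watchdog.
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"

dist.init_process_group("nccl")
rank = dist.get_rank()
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

pg = dist.distributed_c10d._get_default_group()
t = torch.rand(10, 10, device=device)
dist.all_reduce(t)  # first collective initializes the NCCL communicators

# destroy_process_group() alone does not tear down NCCL state while `pg` is
# still referenced here, so a collective on the retained handle still works.
dist.destroy_process_group()
pg.allreduce([t])

# The new private hook aborts all communicators and signals the background
# threads to exit, regardless of outstanding references to the process group.
pg._get_backend(device)._shutdown()

# Any further collective on this process group now raises dist.DistBackendError.

The ProcessGroupNCCL destructor remains the final cleanup path; `_shutdown()` simply makes that cleanup reachable without relying on Python garbage collection.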
pritamdamania87 authored and pytorchmergebot committed Oct 24, 2023
1 parent 6d78f34 commit 0ad91c2
Showing 6 changed files with 54 additions and 11 deletions.
29 changes: 28 additions & 1 deletion test/distributed/test_c10d_nccl.py
@@ -1147,7 +1147,7 @@ def test_abort_pg(self):
         dist.all_reduce(t)
 
         def abortpg():
-            c10d.distributed_c10d._get_default_group()._get_backend(torch.device(device))._abort()
+            c10d.distributed_c10d._get_default_group()._get_backend(torch.device(device))._shutdown()
 
         # Initialize DDP to ensure "destroy_process_group" will not call
         # ProcessGroupNCCL destructor since DDP holds a reference to process group.
@@ -1170,6 +1170,33 @@ def abortpg():
 
         thread.join()
 
+    @requires_nccl()
+    @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
+    def test_close_pg(self):
+        # Disable ASYNC_ERROR_HANDLING for this test to ensure we can programmatically
+        # abort the process group.
+        os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
+
+        store = c10d.FileStore(self.file_name, self.world_size)
+        pg = self._create_process_group_nccl(store, self.opts())
+        device = self.rank_to_GPU[self.rank][0]
+
+        t = torch.rand(10, 10, device=device)
+        # First allreduce to initialize state.
+        pg.allreduce(t)
+
+        # Destroy pg and validate pg is still in working condition since we hold a
+        # reference above.
+        dist.destroy_process_group()
+        pg.allreduce([t])
+
+        # Now close pg and validate it no longer works.
+        pg._get_backend(torch.device(device))._shutdown()
+
+        # Try another collective.
+        with self.assertRaises(dist.DistBackendError):
+            pg.allreduce([t])
+
 class DistributedDataParallelTest(
     test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
 ):
3 changes: 2 additions & 1 deletion torch/csrc/distributed/c10d/NCCLUtils.cpp
@@ -15,7 +15,8 @@ ncclComm_t NCCLComm::getNcclComm() {
     auto commFailureMsg = commFailureReason_ != c10::nullopt
         ? c10::str(" Original reason for failure was: ", *commFailureReason_)
         : "";
-    TORCH_CHECK(
+    TORCH_CHECK_WITH(
+        DistBackendError,
         false,
         c10::str(
             "NCCL communicator was aborted on rank ",
21 changes: 18 additions & 3 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -878,18 +878,33 @@ void ProcessGroupNCCL::abort(c10::optional<std::string> abortReason) {
   abortCommsFromMap(inInitializationCommMap_, rank_, abortReason);
 }
 
-ProcessGroupNCCL::~ProcessGroupNCCL() {
+void ProcessGroupNCCL::shutdown() {
+  // Don't join threads here since the purpose of this method is to abort all
+  // communicators and signal the threads to exit. Joining on the threads could
+  // potentially block and hence avoid it in this method.
+  terminateProcessGroup_.store(true);
+
+  std::string abortReason = c10::str("Process Group shutdown on rank ", rank_);
+  abort(abortReason);
+
+  workMetaListCV_.notify_one();
+}
+
+ProcessGroupNCCL::~ProcessGroupNCCL() {
   terminateProcessGroup_.store(true);
   workMetaListCV_.notify_one();
 
 #ifdef ENABLE_NCCL_ERROR_CHECKING
-  ncclCommWatchdogThread_.join();
+  if (ncclCommWatchdogThread_.joinable()) {
+    ncclCommWatchdogThread_.join();
+  }
 #endif
 
   if (onCompletionHookThread_.joinable())
     onCompletionHookThread_.join();
 
-  // Abort all NCCL Communicators on Process Group Destruction
+  // Abort communicators after all threads have exited to avoid having the
+  // threads dying due to aborted communicator and raising a SIGABRT
   std::string abortReason = c10::str("Process Group destroyed on rank ", rank_);
   abort(abortReason);
 }
2 changes: 2 additions & 0 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -497,6 +497,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // instead of relying on ProcessGroupNCCL destructor.
   void abort(c10::optional<std::string> abortReason = c10::nullopt);
 
+  void shutdown();
+
  protected:
   // Helper that broadcasts nccl unique ID to all ranks through the store
   void broadcastUniqueNCCLID(
8 changes: 3 additions & 5 deletions torch/csrc/distributed/c10d/init.cpp
@@ -2239,12 +2239,10 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
           py::arg("timeout") = ::c10d::kProcessGroupNCCLDefaultTimeout,
           py::call_guard<py::gil_scoped_release>())
       .def(
-          "_abort",
-          [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self,
-             const c10::optional<std::string>& abortReason) {
-            return self->abort(abortReason);
+          "_shutdown",
+          [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) {
+            return self->shutdown();
           },
-          py::arg("abort_reason") = py::none(),
           py::call_guard<py::gil_scoped_release>())
       .def("_group_start", &::c10d::ProcessGroupNCCL::groupStart)
       .def("_group_end", &::c10d::ProcessGroupNCCL::groupEnd)
2 changes: 1 addition & 1 deletion torch/testing/_internal/distributed/distributed_test.py
@@ -9771,7 +9771,7 @@ def test_nccl_init_abort(self):
         def abort(device):
             pg = _get_default_group()
             while running:
-                pg._get_backend(torch.device(device))._abort()
+                pg._get_backend(torch.device(device))._shutdown()
                 time.sleep(1)
 
         if self.rank != 1:
