diff --git a/test/cpp/c10d/ProcessGroupNCCLTest.cpp b/test/cpp/c10d/ProcessGroupNCCLTest.cpp
index c8ef529f433cc..410e7470642d7 100644
--- a/test/cpp/c10d/ProcessGroupNCCLTest.cpp
+++ b/test/cpp/c10d/ProcessGroupNCCLTest.cpp
@@ -1,9 +1,12 @@
+#include
 #include
 #include
 
 #include "CUDATest.hpp"
 #include "TestUtils.hpp"
+#include "c10d/Types.hpp"
+#include "c10d/ProcessGroup.hpp"
 
 #include
 #include
@@ -19,7 +22,10 @@ using c10d::ProcessGroup;
 
 class NCCLTestBase {
  public:
-  NCCLTestBase(const std::string& path) : path_(path) {}
+  NCCLTestBase(
+      const std::string& path,
+      const std::chrono::milliseconds pgTimeout = kProcessGroupDefaultTimeout)
+      : path_(path), pgTimeout_(pgTimeout) {}
 
   NCCLTestBase(NCCLTestBase&& other) {
     path_ = std::move(other.path_);
@@ -33,19 +39,22 @@ class NCCLTestBase {
 
   void initialize(int rank, int size) {
     auto store = c10::make_intrusive<::c10d::FileStore>(path_, size);
 
+    c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts = c10::make_intrusive<c10d::ProcessGroupNCCL::Options>();
+    opts->timeout = pgTimeout_;
     pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>(
-        new ::c10d::ProcessGroupNCCL(store, rank, size));
+        new ::c10d::ProcessGroupNCCL(store, rank, size, std::move(opts)));
   }
 
  protected:
   std::string path_;
   std::unique_ptr<::c10d::ProcessGroupNCCL> pg_;
+  std::chrono::milliseconds pgTimeout_;
 };
 
 class NCCLTest : public NCCLTestBase {
  public:
-  NCCLTest(const std::string& path, int worldSize)
-      : NCCLTestBase(path),
+  NCCLTest(const std::string& path, int worldSize, std::chrono::milliseconds pgTimeout = kProcessGroupDefaultTimeout)
+      : NCCLTestBase(path, pgTimeout),
        numDevices_(cudaNumDevices()),
        state_(::at::globalContext().lazyInitCUDA()),
        worldSize_(worldSize) {
@@ -497,10 +506,50 @@ void testReduceScatter(const std::string& path, int rank, int size) {
   }
 }
 
+void testProcessGroupNCCLHealthCheckFailHelper(const std::string& path, bool timeout) {
+  // simulate world_size > 1 here via threads.
+  const int worldSize = 4;
+  std::mutex m;
+  std::unordered_set nums;
+  auto runTest = [&](int i) {
+    NCCLTest test(path, worldSize, std::chrono::milliseconds(3000));
+    // Catch error relating to health check failure
+    bool error_caught = false;
+    try {
+      test.initialize(timeout ? 0 : -1, worldSize);
+    } catch (const std::exception& e) {
+      std::string errMsg = e.what();
+      const std::string kTimeoutErr = "Failed to initialize NCCL communicator on rank";
+      const std::string kInvalidRankErr = "Invalid rank";
+      std::string expectedSubstr = timeout ? kTimeoutErr : kInvalidRankErr;
+      bool cond = errMsg.find(expectedSubstr) != std::string::npos;
+      EXPECT_TRUE(cond);
+      error_caught = true;
+    }
+    EXPECT_TRUE(error_caught);
+  };
+  std::vector<std::thread> threads;
+  threads.reserve(worldSize);
+  for (const auto r : c10::irange(worldSize)) {
+    threads.emplace_back(std::thread([=]() { runTest(r); }));
+  }
+  for (auto& t : threads) {
+    t.join();
+  }
+}
+
+void testProcessGroupNCCLHealthCheckFailException(const std::string& path, int /* unused */, int /* unused */) {
+  testProcessGroupNCCLHealthCheckFailHelper(path, /* timeout */ false);
+}
+
+void testProcessGroupNCCLHealthCheckFailTimeout(const std::string& path, int /* unused */, int /* unused */) {
+  testProcessGroupNCCLHealthCheckFailHelper(path, /* timeout */ true);
+}
+
 void testSequenceNumInit(const std::string& path, int /* unused */, int /* unused */) {
   // Note: ProcessGroupNCCLTest doesn't support multiprocess testing. So we
   // simulate world_size > 1 here via threads.
-  const int worldSize = 4;
+  const int worldSize = 2;
   std::mutex m;
   std::unordered_set nums;
   auto runTest = [&](int i) {
@@ -625,6 +674,26 @@ TEST_F(ProcessGroupNCCLTest, testSequenceNumInit) {
   }
 }
 
+TEST_F(ProcessGroupNCCLTest, testProcessGroupNCCLHealthCheckFailTimeout) {
+  if (skipTest()) {
+    return;
+  }
+  {
+    TemporaryFile file;
+    testProcessGroupNCCLHealthCheckFailTimeout(file.path, rank_, size_);
+  }
+}
+
+TEST_F(ProcessGroupNCCLTest, testProcessGroupNCCLHealthCheckFailException) {
+  if (skipTest()) {
+    return;
+  }
+  {
+    TemporaryFile file;
+    testProcessGroupNCCLHealthCheckFailException(file.path, rank_, size_);
+  }
+}
+
 TEST_F(ProcessGroupNCCLTest, testReduceScatterBase) {
   if (skipTest()) {
     return;
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index d5b88517de4d9..4cecca9109764 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -657,11 +657,13 @@ def test_nccl_propagate_error_reason(self):
         # otherwise process will be taken down and we can't check for errors.
         os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
         os.environ["NCCL_BLOCKING_WAIT"] = "1"
-        timeout = timedelta(seconds=2)
+        # TODO: smaller timeout can fail since PG NCCL does health check in
+        # constructor. Look into reducing this test's runtime.
         store = c10d.FileStore(self.file_name, self.world_size)
-        pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, timeout=timeout)
+        # provide sufficient timeout to initialize NCCL comm.
+        pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, timeout=timedelta(seconds=15))
         pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
-        pg.barrier().wait()
+        pg.barrier().wait(timedelta(seconds=5))
         # Simulate stuckness in rank 0.
         if self.rank == 0:
             pg_gloo.barrier().wait()
@@ -670,7 +672,7 @@ def test_nccl_propagate_error_reason(self):
         if self.rank != 0:
             # Time out due to rank 0 not calling into allreduce.
             with self.assertRaises(RuntimeError):
-                pg.allreduce([inp]).wait()
+                pg.allreduce([inp]).wait(timedelta(seconds=5))
 
             # Now when nonzero rank attempts to use communicator, original failure reason should be logged.
             try:
@@ -2264,14 +2266,14 @@ def _test_nccl_errors_blocking(self, func):
             store,
             self.rank,
             self.world_size,
-            timeout=timedelta(seconds=self.op_timeout_sec),
+            timeout=timedelta(seconds=10),
         )
         process_group.allreduce(torch.rand(10).cuda(self.rank))
         if self.rank == 0:
             work = process_group.allreduce(torch.rand(10).cuda(self.rank))
             with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg):
                 # Operation would time out in blocking mode.
-                work.wait()
+                work.wait(timeout=timedelta(seconds=self.op_timeout_sec))
             # Run some GPU operations to make sure cuda has not gotten stuck.
             # It was observed cuda could get stuck if NCCL communicators were
             # not properly aborted before throwing RuntimeError.
@@ -2340,13 +2342,13 @@ def test_nccl_blocking_wait_with_barrier(self):
             store,
             self.rank,
             self.world_size,
-            timeout=timedelta(seconds=self.op_timeout_sec),
+            timeout=timedelta(seconds=10),
         )
         process_group.barrier().wait()
         if self.rank == 0:
             with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg):
                 # This should timeout
-                process_group.barrier().wait()
+                process_group.barrier().wait(timeout=timedelta(seconds=self.op_timeout_sec))
 
     def _run_invalid_nccl_blocking_wait_env(self, val):
         os.environ["NCCL_BLOCKING_WAIT"] = val
@@ -2383,21 +2385,19 @@ def test_nccl_timeout(self):
         store = c10d.FileStore(self.file_name, self.world_size)
 
         # Initialize process_group.
-        timeout = 1
         process_group = c10d.ProcessGroupNCCL(
-            store, self.rank, self.world_size, timeout=timedelta(seconds=timeout)
+            store, self.rank, self.world_size, timeout=timedelta(seconds=10)
         )
-        process_group.allreduce(torch.rand(10).cuda(self.rank)).wait()
+        process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=timedelta(seconds=5))
 
         if self.rank == 0:
             # This should timeout in about 1 second.
-            start = time.time()
             # Watchdog may abort timed out work resulting in NCCL error instead of operation timed out.
             with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg):
-                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait()
+                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=timedelta(seconds=1))
         else:
             # Sleep to ensure timeout.
-            time.sleep(2 * timeout)
+            time.sleep(10)
 
         self._wait_for_comm_abort(process_group)
 
@@ -2547,14 +2547,14 @@ def test_nccl_barrier_timeout(self):
         store = c10d.FileStore(self.file_name, self.world_size)
         if self.rank == 0:
             with self.assertRaisesRegex(
-                RuntimeError, "Timed out initializing process group"
+                RuntimeError, "Health check failure"
             ):
                 c10d.init_process_group(
                     backend="nccl",
                     rank=self.rank,
                     world_size=self.world_size,
                     store=store,
-                    timeout=timedelta(seconds=1),
+                    timeout=timedelta(seconds=10),
                 )
 
     @requires_nccl()
@@ -2566,12 +2566,12 @@ def test_nccl_barrier_timeout_new_group(self):
            rank=self.rank,
            world_size=self.world_size,
            store=store,
-           timeout=timedelta(seconds=1),
+           timeout=timedelta(seconds=10),
        )
        if self.rank == 0:
            with self.assertRaisesRegex(
-               RuntimeError, "Timed out initializing process group"
+               RuntimeError, "Health check failure"
            ):
                c10d.new_group([0, 1], timeout=timedelta(seconds=1))
@@ -2589,12 +2589,12 @@ def test_nccl_barrier_timeout_new_group_non_member(self):
            rank=self.rank,
            world_size=self.world_size,
            store=store,
-           timeout=timedelta(seconds=1),
+           timeout=timedelta(seconds=10),
        )
        if self.rank == 1:
            with self.assertRaisesRegex(
-               RuntimeError, "Timed out initializing process group"
+               RuntimeError, "Health check failure"
            ):
                c10d.new_group([0, 1], timeout=timedelta(seconds=1))
diff --git a/test/distributed/test_jit_c10d.py b/test/distributed/test_jit_c10d.py
index 65d82fb033b7d..324070e02b2ac 100644
--- a/test/distributed/test_jit_c10d.py
+++ b/test/distributed/test_jit_c10d.py
@@ -39,7 +39,7 @@ def setUp(self):
     def _create_nccl_pg(self, name_prefix):
         tcp_store = create_tcp_store(jit_class=True)
-        opts = torch.classes.dist_c10d.ProcessGroupNCCLOptions(0, True)
+        opts = torch.classes.dist_c10d.ProcessGroupNCCLOptions(10000, True)
 
         name = unique_process_group_name(name_prefix)
 
@@ -49,7 +49,7 @@ def _create_nccl_pg_as_base_process_group(self, name):
         tcp_store = create_tcp_store(jit_class=True)
 
         return torch.classes.dist_c10d.frontend().new_process_group_helper(
-            self.world_size, self.rank, [], "nccl", tcp_store, name, 0)
+            self.world_size, self.rank, [], "nccl", tcp_store, name, 10000)
 
 @requires_nccl()
 @sandcastle_skip_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
@@ -172,7 +172,7 @@ def test_frontend_singleton(self):
         pg_name = unique_process_group_name("singleton_test_process_group")
 
         ProcessGroupNCCL1 = frontend1.new_process_group_helper(
-            self.world_size, self.rank, [], "nccl", tcp_store, pg_name, 0)
+            self.world_size, self.rank, [], "nccl", tcp_store, pg_name, 10000)
 
         ProcessGroupNCCL2 = frontend2.get_process_group_by_name(pg_name)
         self.assertEqual(frontend2.get_name_of_process_group(ProcessGroupNCCL2), pg_name)
@@ -189,7 +189,7 @@ def __init__(self):
                name = unique_process_group_name("module_member_process_group")
 
                self.pg = torch.classes.dist_c10d.frontend().new_process_group_helper(
-                    1, 0, [], "nccl", tcp_store, name, 0)
+                    1, 0, [], "nccl", tcp_store, name, 10000)
 
             def forward(self, input: torch.Tensor):
                 if self.pg is None:
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 2d31d25658dfe..b1507af53ce53 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -12,15 +12,18 @@
 #include
 #include
+#include
+#include
+#include
 #include
 #include
+#include
 #include
 #include
 #include
-#include
+
 #include
-#include
 
 namespace c10d {
 
@@ -159,6 +162,14 @@ std::vector getDeviceList(const std::vector& tensors) {
   return res;
 }
 
+// Return CUDA device with ordinal given by input rank.
+at::Device getDeviceForRank(int rank) {
+  TORCH_CHECK(rank >= 0, "Invalid rank ", rank);
+  auto numGPUs = at::cuda::getNumGPUs();
+  int16_t deviceIdx = static_cast<int16_t>(rank % numGPUs);
+  return at::Device(at::DeviceType::CUDA, deviceIdx);
+}
+
 // [Sync Streams] Helper that lets the input ncclStreams to wait for the current
 // stream. NCCL communications run on ncclStreams, but input tensors are
 // allocated on different streams (i.e., current streams). Communications on
@@ -502,6 +513,13 @@ ProcessGroupNCCL::ProcessGroupNCCL(
     asyncErrorHandling_ = false;
   }
 
+  // Perform health check by initializing dummy communicators and destroying
+  // them. This will help indicate any NCCL-related issues prior to the first
+  // collective.
+  // Run it in a separate thread and wait on CV to handle timeouts, since
+  // majority of getNCCLComm failures are hangs.
+  runHealthCheck();
+
 #ifdef ENABLE_NCCL_ERROR_CHECKING
   ncclCommWatchdogThread_ =
       std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this);
@@ -527,6 +545,64 @@ ProcessGroupNCCL::ProcessGroupNCCL(
             << "\nNCCL_DEBUG: " << ncclDebugLevel;
 }
 
+void ProcessGroupNCCL::runHealthCheck() {
+  // Run health check in a separate thread and wait on CV to handle timeouts,
+  // since majority of getNCCLComm failures are hangs.
+
+  struct HealthCheckData {
+    std::mutex healthCheckMutex;
+    std::condition_variable healthCheckCv;
+    bool healthCheckSuccess = false;
+    std::exception_ptr healthCheckException;
+  };
+
+  HealthCheckData healthCheckData;
+  auto t = std::thread([&healthCheckData, this]() {
+    try {
+      std::vector<at::Device> rankDevice = {getDeviceForRank(rank_)};
+      const auto key = getKeyFromDevices(rankDevice);
+      // OpType does not matter, only need to set to not go through send/recv
+      // path.
+      getNCCLComm(key, rankDevice, OpType::ALLREDUCE);
+      // Now destroy the communicators and remove them from cache so we don't
+      // use destroyed communicators.
+      destroyNCCLComms(key);
+      // Notify main thread the health check is complete.
+      {
+        std::lock_guard<std::mutex> lk(healthCheckData.healthCheckMutex);
+        healthCheckData.healthCheckSuccess = true;
+      }
+      healthCheckData.healthCheckCv.notify_one();
+    } catch (const std::exception& e) {
+      // Populate exception ptr.
+      healthCheckData.healthCheckException = std::current_exception();
+      // Unblock waiting main thread which will report exception.
+      healthCheckData.healthCheckCv.notify_one();
+    } // Unknown exceptions will just cause the program to terminate.
+  });
+  // We don't need to join the thread, just need to verify health check via the
+  // CV. Hence we detach the thread here.
+  t.detach(); // NOLINT
+  LOG(INFO) << "[Rank " << rank_ << "]"
+            << " will wait up to " << options_->timeout.count()
+            << " msec for NCCL health check to complete.";
+  std::unique_lock<std::mutex> lock(healthCheckData.healthCheckMutex);
+  healthCheckData.healthCheckCv.wait_for(
+      lock, options_->timeout, [&healthCheckData]() {
+        return healthCheckData.healthCheckSuccess;
+      });
+
+  if (healthCheckData.healthCheckException) {
+    std::rethrow_exception(healthCheckData.healthCheckException);
+  }
+  // If there is no exception, the likely culprit is a timeout/hang which is how
+  // most communicator init issues manifest themselves.
+  TORCH_CHECK(
+      healthCheckData.healthCheckSuccess,
+      "ProcessGroupNCCL: Health check failure: Failed to initialize NCCL communicator on rank ",
+      rank_);
+}
+
 void ProcessGroupNCCL::setSequenceNumberForGroup() {
   if (rank_ == 0) {
     // Create and broadcast sequence number
@@ -874,6 +950,30 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID(
   }
 }
 
+
+void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) {
+  std::lock_guard<std::mutex> lock(mutex_);
+  if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) {
+    TORCH_INTERNAL_ASSERT(
+        false,
+        "Expected to find key ",
+        devNCCLCommMapKey,
+        " in NCCL communicator map.");
+  }
+  std::vector<std::shared_ptr<NCCLComm>>& ncclComms =
+      devNCCLCommMap_[devNCCLCommMapKey];
+  // Loop through communicators and call ncclCommAbort.
+  for (const auto& comm : ncclComms) {
+    // ncclCommDestroy(comm->getNcclComm()) results in segfault when PG is being
+    // destroyed, so using ncclCommAbort here.
+    comm->ncclCommAbort();
+  }
+  // Remove communicators from the cache.
+  devNCCLCommMap_.erase(devNCCLCommMapKey);
+  // Clear used device indices.
+  usedDeviceIdxs_.clear();
+}
+
 std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
     const std::string& devicesKey,
     const std::vector<at::Device>& devices,
@@ -1697,7 +1797,7 @@ c10::intrusive_ptr ProcessGroupNCCL::barrier(
        "This can potentially cause a hang if this rank to GPU mapping is incorrect.",
        "Specify device_ids in barrier() to force use of a particular device."
    );
-    devices.emplace_back(at::DeviceType::CUDA, deviceIdx);
+    devices.emplace_back(getDeviceForRank(rank_));
   } else {
     for (auto usedDeviceIdx : usedDeviceIdxs_) {
       devices.emplace_back(at::DeviceType::CUDA, usedDeviceIdx);
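
The health check added above deliberately runs the dummy-communicator setup on a detached worker thread while the constructor blocks on a condition variable for at most the process-group timeout, so a hanging getNCCLComm surfaces as a reportable error instead of stalling construction forever; any exception is carried back through an std::exception_ptr and rethrown on the main thread. The self-contained sketch below illustrates the same detach-and-wait-on-CV pattern in isolation; runWithDeadline and all other names in it are illustrative only and are not PyTorch APIs.

// Minimal sketch of the detach-and-wait-on-CV pattern used by runHealthCheck().
// Illustrative names only; not part of the patch.
#include <chrono>
#include <condition_variable>
#include <exception>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <thread>

// Run `op` on a detached worker thread and wait up to `timeout` for it to
// finish. Rethrows any exception raised by `op`; throws on timeout.
template <typename Op>
void runWithDeadline(Op op, std::chrono::milliseconds timeout) {
  struct State {
    std::mutex mutex;
    std::condition_variable cv;
    bool done = false;
    std::exception_ptr error;
  };
  // Shared ownership keeps the state alive even if the detached thread
  // finishes after this function has already timed out and returned.
  auto state = std::make_shared<State>();

  std::thread([state, op]() {
    std::exception_ptr err;
    try {
      op(); // may hang; the caller only waits up to `timeout`
    } catch (...) {
      err = std::current_exception();
    }
    {
      std::lock_guard<std::mutex> lk(state->mutex);
      state->error = err;
      state->done = true;
    }
    state->cv.notify_one();
  }).detach();

  std::unique_lock<std::mutex> lk(state->mutex);
  bool finished =
      state->cv.wait_for(lk, timeout, [&state] { return state->done; });
  if (finished && state->error) {
    std::rethrow_exception(state->error);
  }
  if (!finished) {
    throw std::runtime_error("operation did not finish within the deadline");
  }
}

A caller would wrap the risky step, for example runWithDeadline([] { /* create and destroy a dummy communicator */ }, std::chrono::milliseconds(30000)). Sharing the state through a std::shared_ptr keeps it valid even when the detached thread outlives a timed-out wait.
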
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index a35b4681acf93..12e856f31fdbf 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -430,6 +430,18 @@ class TORCH_API ProcessGroupNCCL : public ProcessGroup {
   void abortTimedOutCollectives(
       std::unordered_set& abortedCommIds);
 
+  // Performs a health check by initializing dummy NCCL communicators and then
+  // destroying them. This will help indicate and signal any NCCL-related issues
+  // prior to the first collective. The actual initialization and subsequent
+  // destruction is run on a separate thread and the main thread is signalled
+  // about timeouts/errors to report to the application.
+  void runHealthCheck();
+
+  // Destroys initialized NCCL communicators in devNCCLCommMap_ given by input
+  // key. Throws if there are no communicators to destroy. Also removes
+  // communicators from the cache and clears used device indices.
+  void destroyNCCLComms(const std::string& devNCCLCommMapKey);
+
   void workCleanupLoop();
 
  protected:
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index a03a7fa009426..ff07f8ec1e7a0 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -7334,7 +7334,9 @@ def _test_monitored_barrier_allreduce_hang(self, wait_all_ranks):
         # tests expected behavior when nonzero rank hangs.
         nccl_pg = dist.new_group(
             ranks=list(i for i in range(int(self.world_size))),
-            timeout=timedelta(seconds=2),
+            # provide sufficient timeout so communicators
+            # can be initialized in ctor.
+            timeout=timedelta(seconds=15),
             backend=dist.Backend.NCCL,
         )
         gloo_pg = dist.new_group(
@@ -7345,7 +7347,7 @@ def _test_monitored_barrier_allreduce_hang(self, wait_all_ranks):
         # Let all ranks call allreduce first to set up communicators etc.
         # Directly simulating error here will run into store issue described
         # in https://github.com/pytorch/pytorch/issues/54524.
-        nccl_pg.allreduce(tensors).wait()
+        nccl_pg.allreduce(tensors).wait(timedelta(seconds=5))
         # All ranks besides 0 call into allreduce. This is to simulate a
         # desync across the world, where some ranks call into
         # monitored_barrier() and others are stuck in collective comm. In
@@ -7379,6 +7381,8 @@ def _test_monitored_barrier_allreduce_hang(self, wait_all_ranks):
             monitored_barrier_timeout_seconds, wait_all_ranks=wait_all_ranks
         )
 
+        self._barrier(timeout=30)
+
     @with_nccl_blocking_wait
     @require_backend({"gloo", "nccl"})
     @require_backends_available({"gloo", "nccl"})
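
Because the health check now runs inside the ProcessGroupNCCL constructor, the tests above raise their construction timeouts from 1-2 seconds to 10-15 seconds so that communicator initialization itself fits inside the budget. The sketch below shows, at the API level exercised by NCCLTestBase::initialize() in this patch, what caller-side construction with a sufficiently large Options timeout looks like; it is illustrative only, and the include paths, helper name, store path, and the 15-second value are assumptions rather than part of the patch.

// Illustrative sketch only: construct a ProcessGroupNCCL whose timeout also
// covers the constructor-time health check. Shapes mirror
// NCCLTestBase::initialize() above; include paths may differ by release.
#include <chrono>
#include <memory>
#include <string>
#include <utility>

#include <c10d/FileStore.hpp>         // placeholder include path
#include <c10d/ProcessGroupNCCL.hpp>  // placeholder include path

std::unique_ptr<c10d::ProcessGroupNCCL> makeNcclProcessGroup(
    const std::string& storePath, int rank, int worldSize) {
  auto store = c10::make_intrusive<c10d::FileStore>(storePath, worldSize);
  auto opts = c10::make_intrusive<c10d::ProcessGroupNCCL::Options>();
  // A 1-2 second timeout can now fail in the constructor itself, since the
  // health check creates and destroys a dummy NCCL communicator up front.
  opts->timeout = std::chrono::milliseconds(15000);
  return std::unique_ptr<c10d::ProcessGroupNCCL>(
      new c10d::ProcessGroupNCCL(store, rank, worldSize, std::move(opts)));
}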