From 06fa6c15c0d12fe037182a9b5b8100e057585a52 Mon Sep 17 00:00:00 2001 From: Rohan Varma Date: Thu, 14 Oct 2021 22:22:08 -0700 Subject: [PATCH] Back out "Revert D31299350: Back out "Revert D31005792: [NCCL] Init dummy NCCL comms in constructor"" (#66393) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66393 Third try! Fixes: - test_nccl_timeout can be flaky because of 1s timeout, bump up the timeout to resolve the flakiness. But in general we should not have been relying on time.sleep for this test, filed https://github.com/pytorch/pytorch/issues/66354 to track that. - ciflow/all did not actually run tests due to a bug causing multigpu tests to not be run. This has since been fixed. ghstack-source-id: 140560113 Test Plan: CI Reviewed By: mrshenli Differential Revision: D31534735 fbshipit-source-id: 8b7e0f4fed3972b7a77cbcda28876c9eefb0c7e2 --- test/cpp/c10d/ProcessGroupNCCLTest.cpp | 79 ++++++++++++- test/distributed/test_c10d_nccl.py | 40 +++---- test/distributed/test_jit_c10d.py | 8 +- .../distributed/c10d/ProcessGroupNCCL.cpp | 106 +++++++++++++++++- .../distributed/c10d/ProcessGroupNCCL.hpp | 12 ++ .../_internal/distributed/distributed_test.py | 8 +- 6 files changed, 219 insertions(+), 34 deletions(-) diff --git a/test/cpp/c10d/ProcessGroupNCCLTest.cpp b/test/cpp/c10d/ProcessGroupNCCLTest.cpp index c8ef529f433c..410e7470642d 100644 --- a/test/cpp/c10d/ProcessGroupNCCLTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLTest.cpp @@ -1,9 +1,12 @@ +#include #include #include #include #include "CUDATest.hpp" #include "TestUtils.hpp" +#include "c10d/Types.hpp" +#include "c10d/ProcessGroup.hpp" #include #include @@ -19,7 +22,10 @@ using c10d::ProcessGroup; class NCCLTestBase { public: - NCCLTestBase(const std::string& path) : path_(path) {} + NCCLTestBase( + const std::string& path, + const std::chrono::milliseconds pgTimeout = kProcessGroupDefaultTimeout + ) : path_(path), pgTimeout_(pgTimeout) {} NCCLTestBase(NCCLTestBase&& other) { path_ = std::move(other.path_); @@ -33,19 +39,22 @@ class NCCLTestBase { void initialize(int rank, int size) { auto store = c10::make_intrusive<::c10d::FileStore>(path_, size); + c10::intrusive_ptr opts = c10::make_intrusive(); + opts->timeout = pgTimeout_; pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>( - new ::c10d::ProcessGroupNCCL(store, rank, size)); + new ::c10d::ProcessGroupNCCL(store, rank, size, std::move(opts))); } protected: std::string path_; std::unique_ptr<::c10d::ProcessGroupNCCL> pg_; + std::chrono::milliseconds pgTimeout_; }; class NCCLTest : public NCCLTestBase { public: - NCCLTest(const std::string& path, int worldSize) - : NCCLTestBase(path), + NCCLTest(const std::string& path, int worldSize, std::chrono::milliseconds pgTimeout = kProcessGroupDefaultTimeout) + : NCCLTestBase(path, pgTimeout), numDevices_(cudaNumDevices()), state_(::at::globalContext().lazyInitCUDA()), worldSize_(worldSize) { @@ -497,10 +506,50 @@ void testReduceScatter(const std::string& path, int rank, int size) { } } +void testProcessGroupNCCLHealthCheckFailHelper(const std::string& path, bool timeout) { + // simulate world_size > 1 here via threads. + const int worldSize = 4; + std::mutex m; + std::unordered_set nums; + auto runTest = [&](int i) { + NCCLTest test(path, worldSize, std::chrono::milliseconds(3000)); + // Catch error relating to health check failure + bool error_caught = false; + try { + test.initialize(timeout ? 
0 : -1, worldSize); + } catch (const std::exception &e) { + std::string errMsg = e.what(); + const std::string kTimeoutErr = "Failed to initialize NCCL communicator on rank"; + const std::string kInvalidRankErr = "Invalid rank"; + std::string expectedSubstr = timeout ? kTimeoutErr : kInvalidRankErr; + bool cond = errMsg.find(expectedSubstr) != std::string::npos; + EXPECT_TRUE(cond); + error_caught = true; + } + EXPECT_TRUE(error_caught); + }; + std::vector threads; + threads.reserve(worldSize); + for (const auto r : c10::irange(worldSize)) { + threads.emplace_back(std::thread([=]() { runTest(r); })); + } + for (auto& t : threads) { + t.join(); + } +} + +void testProcessGroupNCCLHealthCheckFailException(const std::string& path, int /* unused */, int /* unused */) { + testProcessGroupNCCLHealthCheckFailHelper(path, /* timeout */ false); +} + +void testProcessGroupNCCLHealthCheckFailTimeout(const std::string& path, int /* unused */, int /* unused */) { + testProcessGroupNCCLHealthCheckFailHelper(path, /* timeout */ true); +} + void testSequenceNumInit(const std::string& path, int /* unused */, int /* unused */) { // Note: ProcessGroupNCCLTest doesn't support multiprocess testing. So we // simulate world_size > 1 here via threads. - const int worldSize = 4; + const int worldSize = 2; std::mutex m; std::unordered_set nums; auto runTest = [&](int i) { @@ -625,6 +674,26 @@ TEST_F(ProcessGroupNCCLTest, testSequenceNumInit) { } } +TEST_F(ProcessGroupNCCLTest, testProcessGroupNCCLHealthCheckFailTimeout) { + if (skipTest()) { + return; + } + { + TemporaryFile file; + testProcessGroupNCCLHealthCheckFailTimeout(file.path, rank_, size_); + } +} + +TEST_F(ProcessGroupNCCLTest, testProcessGroupNCCLHealthCheckFailException) { + if (skipTest()) { + return; + } + { + TemporaryFile file; + testProcessGroupNCCLHealthCheckFailException(file.path, rank_, size_); + } +} + TEST_F(ProcessGroupNCCLTest, testReduceScatterBase) { if (skipTest()) { return; diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index d5b88517de4d..4cecca910976 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -657,11 +657,13 @@ def test_nccl_propagate_error_reason(self): # otherwise process will be taken down and we can't check for errors. os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" os.environ["NCCL_BLOCKING_WAIT"] = "1" - timeout = timedelta(seconds=2) + # TODO: smaller timeout can fail since PG NCCl does health check in + # constructor. Look into reducing this test's runtime. store = c10d.FileStore(self.file_name, self.world_size) - pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, timeout=timeout) + # provide sufficient timeout to initialize NCCL comm. + pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, timeout=timedelta(seconds=15)) pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size) - pg.barrier().wait() + pg.barrier().wait(timedelta(seconds=5)) # Simulate stuckness in rank 0. if self.rank == 0: pg_gloo.barrier().wait() @@ -670,7 +672,7 @@ def test_nccl_propagate_error_reason(self): if self.rank != 0: # Time out due to rank 0 not calling into allreduce. 
with self.assertRaises(RuntimeError): - pg.allreduce([inp]).wait() + pg.allreduce([inp]).wait(timedelta(seconds=5)) # Now when nonzero rank attempts to use communicator, original failure reason should be logged.j try: @@ -2264,14 +2266,14 @@ def _test_nccl_errors_blocking(self, func): store, self.rank, self.world_size, - timeout=timedelta(seconds=self.op_timeout_sec), + timeout=timedelta(seconds=10), ) process_group.allreduce(torch.rand(10).cuda(self.rank)) if self.rank == 0: work = process_group.allreduce(torch.rand(10).cuda(self.rank)) with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg): # Operation would time out in blocking mode. - work.wait() + work.wait(timeout=timedelta(seconds=self.op_timeout_sec)) # Run some GPU operations to make sure cuda has not gotten stuck. # It was observed cuda could get stuck if NCCL communicators were # not properly aborted before throwing RuntimeError. @@ -2340,13 +2342,13 @@ def test_nccl_blocking_wait_with_barrier(self): store, self.rank, self.world_size, - timeout=timedelta(seconds=self.op_timeout_sec), + timeout=timedelta(seconds=10), ) process_group.barrier().wait() if self.rank == 0: with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg): # This should timeout - process_group.barrier().wait() + process_group.barrier().wait(timeout=timedelta(seconds=self.op_timeout_sec)) def _run_invalid_nccl_blocking_wait_env(self, val): os.environ["NCCL_BLOCKING_WAIT"] = val @@ -2383,21 +2385,19 @@ def test_nccl_timeout(self): store = c10d.FileStore(self.file_name, self.world_size) # Initialize process_group. - timeout = 1 process_group = c10d.ProcessGroupNCCL( - store, self.rank, self.world_size, timeout=timedelta(seconds=timeout) + store, self.rank, self.world_size, timeout=timedelta(seconds=10) ) - process_group.allreduce(torch.rand(10).cuda(self.rank)).wait() + process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=timedelta(seconds=5)) if self.rank == 0: # This should timeout in about 1 second. - start = time.time() # Watchdog may abort timed out work resulting in NCCL error instead of operation timed out. with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg): - process_group.allreduce(torch.rand(10).cuda(self.rank)).wait() + process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=timedelta(seconds=1)) else: # Sleep to ensure timeout. 
- time.sleep(2 * timeout) + time.sleep(10) self._wait_for_comm_abort(process_group) @@ -2547,14 +2547,14 @@ def test_nccl_barrier_timeout(self): store = c10d.FileStore(self.file_name, self.world_size) if self.rank == 0: with self.assertRaisesRegex( - RuntimeError, "Timed out initializing process group" + RuntimeError, "Health check failure" ): c10d.init_process_group( backend="nccl", rank=self.rank, world_size=self.world_size, store=store, - timeout=timedelta(seconds=1), + timeout=timedelta(seconds=10), ) @requires_nccl() @@ -2566,12 +2566,12 @@ def test_nccl_barrier_timeout_new_group(self): rank=self.rank, world_size=self.world_size, store=store, - timeout=timedelta(seconds=1), + timeout=timedelta(seconds=10), ) if self.rank == 0: with self.assertRaisesRegex( - RuntimeError, "Timed out initializing process group" + RuntimeError, "Health check failure" ): c10d.new_group([0, 1], timeout=timedelta(seconds=1)) @@ -2589,12 +2589,12 @@ def test_nccl_barrier_timeout_new_group_non_member(self): rank=self.rank, world_size=self.world_size, store=store, - timeout=timedelta(seconds=1), + timeout=timedelta(seconds=10), ) if self.rank == 1: with self.assertRaisesRegex( - RuntimeError, "Timed out initializing process group" + RuntimeError, "Health check failure" ): c10d.new_group([0, 1], timeout=timedelta(seconds=1)) diff --git a/test/distributed/test_jit_c10d.py b/test/distributed/test_jit_c10d.py index 65d82fb033b7..324070e02b2a 100644 --- a/test/distributed/test_jit_c10d.py +++ b/test/distributed/test_jit_c10d.py @@ -39,7 +39,7 @@ def setUp(self): def _create_nccl_pg(self, name_prefix): tcp_store = create_tcp_store(jit_class=True) - opts = torch.classes.dist_c10d.ProcessGroupNCCLOptions(0, True) + opts = torch.classes.dist_c10d.ProcessGroupNCCLOptions(10000, True) name = unique_process_group_name(name_prefix) @@ -49,7 +49,7 @@ def _create_nccl_pg_as_base_process_group(self, name): tcp_store = create_tcp_store(jit_class=True) return torch.classes.dist_c10d.frontend().new_process_group_helper( - self.world_size, self.rank, [], "nccl", tcp_store, name, 0) + self.world_size, self.rank, [], "nccl", tcp_store, name, 10000) @requires_nccl() @sandcastle_skip_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs") @@ -172,7 +172,7 @@ def test_frontend_singleton(self): pg_name = unique_process_group_name("singleton_test_process_group") ProcessGroupNCCL1 = frontend1.new_process_group_helper( - self.world_size, self.rank, [], "nccl", tcp_store, pg_name, 0) + self.world_size, self.rank, [], "nccl", tcp_store, pg_name, 10000) ProcessGroupNCCL2 = frontend2.get_process_group_by_name(pg_name) self.assertEqual(frontend2.get_name_of_process_group(ProcessGroupNCCL2), pg_name) @@ -189,7 +189,7 @@ def __init__(self): name = unique_process_group_name("module_member_process_group") self.pg = torch.classes.dist_c10d.frontend().new_process_group_helper( - 1, 0, [], "nccl", tcp_store, name, 0) + 1, 0, [], "nccl", tcp_store, name, 10000) def forward(self, input: torch.Tensor): if self.pg is None: diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 2d31d25658df..b1507af53ce5 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -12,15 +12,18 @@ #include #include +#include +#include +#include #include #include +#include #include #include #include -#include + #include -#include namespace c10d { @@ -159,6 +162,14 @@ std::vector getDeviceList(const std::vector& tensors) { return res; } +// 
Return CUDA device with ordinal given by input rank. +at::Device getDeviceForRank(int rank) { + TORCH_CHECK(rank >= 0, "Invalid rank ", rank); + auto numGPUs = at::cuda::getNumGPUs(); + int16_t deviceIdx = static_cast(rank % numGPUs); + return at::Device(at::DeviceType::CUDA, deviceIdx); +} + // [Sync Streams] Helper that lets the input ncclStreams to wait for the current // stream. NCCL communications run on ncclStreams, but input tensors are // allocated on different streams (i.e., current streams). Communications on @@ -502,6 +513,13 @@ ProcessGroupNCCL::ProcessGroupNCCL( asyncErrorHandling_ = false; } + // Perform health check by initializing dummy communicators and destroying + // them. This will help indicate any NCCL-related issues prior to the first + // collective. + // Run it in a separate thread and wait on CV to handle timeouts, since + // majority of getNCCLComm failures are hangs. + runHealthCheck(); + #ifdef ENABLE_NCCL_ERROR_CHECKING ncclCommWatchdogThread_ = std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this); @@ -527,6 +545,64 @@ ProcessGroupNCCL::ProcessGroupNCCL( << "\nNCCL_DEBUG: " << ncclDebugLevel; } +void ProcessGroupNCCL::runHealthCheck() { + // Run health check in a separate thread and wait on CV to handle timeouts, + // since majority of getNCCLComm failures are hangs. + + struct HealthCheckData { + std::mutex healthCheckMutex; + std::condition_variable healthCheckCv; + bool healthCheckSuccess = false; + std::exception_ptr healthCheckException; + }; + + HealthCheckData healthCheckData; + auto t = std::thread([&healthCheckData, this]() { + try { + std::vector rankDevice = {getDeviceForRank(rank_)}; + const auto key = getKeyFromDevices(rankDevice); + // OpType does not matter, only need to set to not go through send/recv + // path. + getNCCLComm(key, rankDevice, OpType::ALLREDUCE); + // Now destroy the communicators and remove them from cache so we don't + // use destroyed communicators. + destroyNCCLComms(key); + // Notify main thread the health check is complete. + { + std::lock_guard lk(healthCheckData.healthCheckMutex); + healthCheckData.healthCheckSuccess = true; + } + healthCheckData.healthCheckCv.notify_one(); + } catch (const std::exception& e) { + // Populate exception ptr. + healthCheckData.healthCheckException = std::current_exception(); + // Unblock waiting main thread which will report exception. + healthCheckData.healthCheckCv.notify_one(); + } // Unknown exceptions will just cause the program to terminate. + }); + // We don't need to join the thread, just need to verify health check via the + // CV. Hence we detach the thread here. + t.detach(); // NOLINT + LOG(INFO) << "[Rank " << rank_ << "]" + << " will wait up to " << options_->timeout.count() + << " msec for NCCL health check to complete."; + std::unique_lock lock(healthCheckData.healthCheckMutex); + healthCheckData.healthCheckCv.wait_for( + lock, options_->timeout, [&healthCheckData]() { + return healthCheckData.healthCheckSuccess; + }); + + if (healthCheckData.healthCheckException) { + std::rethrow_exception(healthCheckData.healthCheckException); + } + // If there is no exception, the likely culprit is a timeout/hang which is how + // most communicator init issues manifest themselves. 
+ TORCH_CHECK( + healthCheckData.healthCheckSuccess, + "ProcessGroupNCCL: Health check failure: Failed to initialize NCCL communicator on rank ", + rank_); +} + void ProcessGroupNCCL::setSequenceNumberForGroup() { if (rank_ == 0) { // Create and broadcast sequence number @@ -874,6 +950,30 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID( } } + +void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) { + std::lock_guard lock(mutex_); + if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) { + TORCH_INTERNAL_ASSERT( + false, + "Expected to find key ", + devNCCLCommMapKey, + " in NCCL communicator map."); + } + std::vector>& ncclComms = + devNCCLCommMap_[devNCCLCommMapKey]; + // Loop through communicators and call ncclCommAbort. + for (const auto& comm : ncclComms) { + // ncclCommDestroy(comm->getNcclComm()) results in segfault when PG is being + // destroyed, so using ncclCommAbort here. + comm->ncclCommAbort(); + } + // Remove communicators from the cache. + devNCCLCommMap_.erase(devNCCLCommMapKey); + // Clear used device indices. + usedDeviceIdxs_.clear(); +} + std::vector>& ProcessGroupNCCL::getNCCLComm( const std::string& devicesKey, const std::vector& devices, @@ -1697,7 +1797,7 @@ c10::intrusive_ptr ProcessGroupNCCL::barrier( "This can potentially cause a hang if this rank to GPU mapping is incorrect.", "Specify device_ids in barrier() to force use of a particular device." ); - devices.emplace_back(at::DeviceType::CUDA, deviceIdx); + devices.emplace_back(getDeviceForRank(rank_)); } else { for (auto usedDeviceIdx : usedDeviceIdxs_) { devices.emplace_back(at::DeviceType::CUDA, usedDeviceIdx); diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index a35b4681acf9..12e856f31fdb 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -430,6 +430,18 @@ class TORCH_API ProcessGroupNCCL : public ProcessGroup { void abortTimedOutCollectives( std::unordered_set& abortedCommIds); + // Performs a health check by initializing dummy NCCL communicators and then + // destroying them. This will help indicate and signal any NCCL-related issues + // prior to the first collective. The actual initialization and subsequent + // destruction is ran on a separate thread and the main thread is signalled + // about timeouts/errors to report to the application. + void runHealthCheck(); + + // Destroys initialized NCCL communicators in devNCCLComMap_ given by input + // key. Throws if there are no communicators to destroy. Also removes + // communicators from the cache and clears used device indices. + void destroyNCCLComms(const std::string& devNCCLCommMapKey); + void workCleanupLoop(); protected: diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index a03a7fa00942..ff07f8ec1e7a 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -7334,7 +7334,9 @@ def _test_monitored_barrier_allreduce_hang(self, wait_all_ranks): # tests expected behavior when nonzero rank hangs. nccl_pg = dist.new_group( ranks=list(i for i in range(int(self.world_size))), - timeout=timedelta(seconds=2), + # provide sufficient timeout so communicators + # can be initialized in ctor. 
+ timeout=timedelta(seconds=15), backend=dist.Backend.NCCL, ) gloo_pg = dist.new_group( @@ -7345,7 +7347,7 @@ def _test_monitored_barrier_allreduce_hang(self, wait_all_ranks): # Let all ranks call allreduce first to set up communicators etc. # Directly simulating error here will run into store issue described # in https://github.com/pytorch/pytorch/issues/54524. - nccl_pg.allreduce(tensors).wait() + nccl_pg.allreduce(tensors).wait(timedelta(seconds=5)) # All ranks besides 0 call into allreduce. This is to simulate a # desync across the world, where some ranks call into # monitored_barrier() and others are stuck in collective comm. In @@ -7379,6 +7381,8 @@ def _test_monitored_barrier_allreduce_hang(self, wait_all_ranks): monitored_barrier_timeout_seconds, wait_all_ranks=wait_all_ranks ) + self._barrier(timeout=30) + @with_nccl_blocking_wait @require_backend({"gloo", "nccl"}) @require_backends_available({"gloo", "nccl"})
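
For reference, the health check added in ProcessGroupNCCL::runHealthCheck() above reduces to a detach-and-wait pattern: the dummy communicator init runs on a worker thread, the constructor waits on a condition variable for at most the process-group timeout, exceptions from the worker are rethrown on the caller, and a hang surfaces as a timeout error instead of blocking the constructor forever. Below is a minimal standalone C++ sketch of that pattern only — the runWithTimeout helper, its placeholder check function, and the 3000 ms timeout are illustrative assumptions, not code from this diff, and unlike the stack-allocated HealthCheckData in the patch the sketch keeps its shared state in a shared_ptr so a worker that outlives the wait never touches freed stack memory.

// Sketch of the timeout-guarded health check pattern (assumptions noted above).
#include <chrono>
#include <condition_variable>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <thread>

// Run checkFn on a detached worker thread and wait up to `timeout` for it to
// report success. Exceptions thrown by checkFn are rethrown on the caller;
// a timeout (the common failure mode for a hanging communicator init) is
// reported as a runtime_error.
void runWithTimeout(
    const std::function<void()>& checkFn,
    std::chrono::milliseconds timeout) {
  struct State {
    std::mutex mutex;
    std::condition_variable cv;
    bool success = false;
    std::exception_ptr error;
  };
  auto state = std::make_shared<State>();

  std::thread([state, checkFn]() {
    try {
      checkFn();
      {
        std::lock_guard<std::mutex> lk(state->mutex);
        state->success = true;
      }
    } catch (...) {
      std::lock_guard<std::mutex> lk(state->mutex);
      state->error = std::current_exception();
    }
    // Wake the waiting caller whether the check succeeded or threw.
    state->cv.notify_one();
  }).detach();  // Detached: a hung worker must not block the caller past the timeout.

  std::unique_lock<std::mutex> lk(state->mutex);
  state->cv.wait_for(lk, timeout, [&] {
    return state->success || static_cast<bool>(state->error);
  });
  if (state->error) {
    std::rethrow_exception(state->error);
  }
  if (!state->success) {
    throw std::runtime_error("Health check timed out");
  }
}

int main() {
  try {
    // Placeholder for "init dummy NCCL comms, then destroy them".
    runWithTimeout(
        [] { std::this_thread::sleep_for(std::chrono::milliseconds(100)); },
        std::chrono::milliseconds(3000));
    std::cout << "health check passed\n";
  } catch (const std::exception& e) {
    std::cout << "health check failed: " << e.what() << "\n";
  }
}

Detaching rather than joining is the key design choice: a genuinely stuck init cannot delay the constructor beyond options_->timeout, at the cost that the worker must never reference caller-stack state after the wait returns, which is why the sketch routes everything through the shared_ptr-owned State.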