From cd986f06cb71876f558af2d672d74923ace0970f Mon Sep 17 00:00:00 2001 From: Syed Tousif Ahmed Date: Wed, 22 Mar 2023 16:26:18 -0700 Subject: [PATCH 1/7] Adds ncclCommInitRankConfig features --- test/distributed/test_c10d_nccl.py | 54 ++++++++++++++++--- torch/csrc/distributed/c10d/NCCLUtils.hpp | 21 ++++++++ .../distributed/c10d/ProcessGroupNCCL.cpp | 4 ++ .../distributed/c10d/ProcessGroupNCCL.hpp | 5 ++ torch/csrc/distributed/c10d/init.cpp | 39 ++++++++++++++ torch/nn/parallel/distributed.py | 8 +++ 6 files changed, 125 insertions(+), 6 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 521158fd2c9ea..cff84dff9be22 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -4,6 +4,7 @@ import math import os import random +import re import signal import sys import tempfile @@ -1291,6 +1292,21 @@ def _test_fp16(self, gradient_as_bucket_view=False): def test_fp16(self): self._test_fp16() + @requires_nccl() + @requires_nccl_version((2,17), "Need NCCL 2.17+ for configuring NCCL Communicators") + @skip_if_lt_x_gpu(2) + def test_ddp_default_cga(self): + nccl_debug_file = tempfile.NamedTemporaryFile() + os.environ["NCCL_DEBUG"] = "INFO" + os.environ["NCCL_DEBUG_FILE"] = nccl_debug_file.name + + self._test_fp16() + + # Tests if default CGA for DDP is 2 + nccl_debug_file_content = nccl_debug_file.read() + cga_cluster_size = re.search(rb'CGA cluster.*(\d+)|$', nccl_debug_file_content).group(1) + self.assertEqual(int(cga_cluster_size), 2) + @requires_nccl() @skip_if_lt_x_gpu(2) def test_fp16_grad_is_view(self): @@ -2713,12 +2729,7 @@ def test_sequence_num_set_nccl_new_group(self): torch.cuda.set_device(self.rank) self._test_sequence_num_set_new_group(backend="nccl") - @requires_nccl() - @skip_if_lt_x_gpu(2) - def test_pass_nccl_options_high_priority_stream(self): - pg_opts = c10d.ProcessGroupNCCL.Options() - pg_opts.is_high_priority_stream = True - + def _test_pass_nccl_options(self, pg_opts): store = c10d.FileStore(self.file_name, self.world_size) # Test init_process_group accepts options dist.init_process_group( @@ -2737,6 +2748,37 @@ def test_pass_nccl_options_high_priority_stream(self): expected_tensor = torch.tensor([3] * 10).cuda(self.rank) self.assertEqual(expected_tensor, t) + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_pass_nccl_options_high_priority_stream(self): + pg_opts = c10d.ProcessGroupNCCL.Options() + pg_opts.is_high_priority_stream = True + self._test_pass_nccl_options(pg_opts) + + @requires_nccl() + @requires_nccl_version((2,17), "Need NCCL 2.17+ for configuring NCCL Communicators") + @skip_if_lt_x_gpu(2) + def test_pass_nccl_options_config(self): + pg_opts = c10d.ProcessGroupNCCL.Options() + pg_opts.config.max_ctas = 4 + pg_opts.config.min_ctas = 2 + pg_opts.config.cga_cluster_size = 2 + nccl_debug_file = tempfile.NamedTemporaryFile() + os.environ["NCCL_DEBUG"] = "INFO" + os.environ["NCCL_DEBUG_FILE"] = nccl_debug_file.name + + # Tests functionality when passing nccl config + self._test_pass_nccl_options(pg_opts) + + # Tests if comms were configured + nccl_debug_file_content = nccl_debug_file.read() + max_ctas = re.search(rb'Max CTAs.*(\d+)|$', nccl_debug_file_content).group(1) + min_ctas = re.search(rb'Min CTAs.*(\d+)|$', nccl_debug_file_content).group(1) + cga_cluster_size = re.search(rb'CGA cluster.*(\d+)|$', nccl_debug_file_content).group(1) + self.assertEqual(pg_opts.config.max_ctas, int(max_ctas)) + self.assertEqual(pg_opts.config.min_ctas, int(min_ctas)) + 
self.assertEqual(pg_opts.config.cga_cluster_size, int(cga_cluster_size))
+
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
     def test_nccl_barrier(self):
diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp
index 0e0b98cd48702..73dbace47d128 100644
--- a/torch/csrc/distributed/c10d/NCCLUtils.hpp
+++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp
@@ -52,6 +52,12 @@
 #define ENABLE_NCCL_PREMUL_SUM_SUPPORT
 #endif
 
+#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 17)
+#define ENABLE_NCCL_RANK_CONFIG
+#elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
+#define ENABLE_NCCL_RANK_CONFIG
+#endif
+
 // Macro to throw on a non-successful NCCL return value.
 #define C10D_NCCL_CHECK(cmd, failureReason) \
   do {                                      \
@@ -195,6 +201,21 @@ class NCCLComm {
     return comm;
   }
 
+#ifdef ENABLE_NCCL_RANK_CONFIG
+  static std::shared_ptr<NCCLComm> create(
+      int numRanks,
+      int rank,
+      ncclUniqueId commId,
+      ncclConfig_t& config) {
+    auto comm = std::make_shared<NCCLComm>();
+    C10D_NCCL_CHECK(
+        ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), c10::nullopt);
+    comm->ncclId_ = commId;
+    comm->rank_ = rank;
+    return comm;
+  }
+#endif
+
   ncclUniqueId getNcclId() {
     return ncclId_;
   }
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 35ce10ab261ca..84d35b4684d0d 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -1156,7 +1156,11 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
     int deviceIndex = devices[i].index();
     gpuGuard.set_index(deviceIndex);
 
+#ifdef ENABLE_NCCL_RANK_CONFIG
+    ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID, options_->config);
+#else
     ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID);
+#endif
 
     // Creates the NCCL streams
     streamVal.push_back(
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index 3c4c12500a79f..51ddd3c054e33 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -279,6 +279,11 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 
     // Schedule NCCL operations on high priority CUDA streams
     bool is_high_priority_stream;
+
+#ifdef ENABLE_NCCL_RANK_CONFIG
+    // Configure ranks
+    ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
+#endif
   };
 
   // If you wish to create multiple process groups, each with a potentially
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index e78d5ca87ba53..18dec328f702e 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -2134,6 +2134,21 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
       .def_property_readonly(
           "is_ucc_available", &::c10d::ProcessGroupNCCL::isUCCAvailable);
 
+#ifdef ENABLE_NCCL_RANK_CONFIG
+  py::class_<ncclConfig_t>(processGroupNCCL, "NCCLConfig",
+      R"(
+ncclConfig_t data type for configuring NCCL communicators.
+See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
+for details.
+)") + .def(py::init<>()) + .def_readwrite("blocking", &ncclConfig_t::blocking) + .def_readwrite("cga_cluster_size", &ncclConfig_t::cgaClusterSize) + .def_readwrite("min_ctas", &ncclConfig_t::minCTAs) + .def_readwrite("max_ctas", &ncclConfig_t::maxCTAs) + .def_readwrite("net_name", &ncclConfig_t::netName); +#endif + intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>( processGroupNCCL, "Options", @@ -2147,18 +2162,42 @@ ProcessGroup options for the NCCL backend to prioritize NCCL kernels when there are compute kernels waiting. Default is False. +Attributes: + config (NCCLConfig): configures NCCL communicators (only avaiable for + builds using NCCL 2.17+). This can be used to improve + communication-computation overlap for NCCL kernels by tuning + available parameters in the config. See + https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t + for details. + Example:: >>> import torch.distributed as dist >>> >>> nccl_options = dist.ProcessGroupNCCL.Options(is_high_priority_stream=True) + >>> # For builds using NCCL 2.17+, configure communicators + >>> nccl_options.config.cga_cluster_size = 2 + >>> nccl_options.config.max_ctas = 4 + >>> nccl_options.config.min_ctas = 2 >>> # initialize a nccl process group with the options just created >>> dist.init_process_group("nccl", pg_options=nccl_options) )") .def(py::init(), py::arg("is_high_priority_stream") = false) +#ifdef ENABLE_NCCL_RANK_CONFIG + .def_readwrite( + "is_high_priority_stream", + &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream) + .def_readwrite( + "config", + &::c10d::ProcessGroupNCCL::Options::config); +#else .def_readwrite( "is_high_priority_stream", &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream); #endif + processGroupNCCL.def_static( + "_group_start", []() { ::c10d::ProcessGroupNCCL::groupStart(); }); + processGroupNCCL.def_static( + "_group_end", []() { ::c10d::ProcessGroupNCCL::groupEnd(); }); #ifdef USE_C10D_MPI auto processGroupMPI = diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 02fa9a13b37ea..2c06f01797a98 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -749,6 +749,14 @@ def __init__( else: self.process_group = process_group + if self.process_group.options.backend == "nccl" and torch.cuda.nccl.version() >= (2, 17): + # Note: NVIDIA recommends using CGA Cluster Size of 2 when using DDP. 
+            default_cga = dist.ProcessGroupNCCL.Options().config.cga_cluster_size
+            default_pg_nccl = self.process_group._get_backend(torch.device("cuda"))
+            current_cga = default_pg_nccl.options.config.cga_cluster_size
+            if current_cga == default_cga:
+                default_pg_nccl.options.config.cga_cluster_size = 2
+
         self.static_graph = False
         self.dim = dim
         self.module = module

From a86d4c7e8ab79e5ff4c216190fa98c97103b2520 Mon Sep 17 00:00:00 2001
From: Syed Tousif Ahmed
Date: Fri, 24 Mar 2023 12:15:18 -0700
Subject: [PATCH 2/7] Fixes lint and error

---
 test/distributed/test_c10d_nccl.py   |  4 ++--
 torch/csrc/distributed/c10d/init.cpp | 10 +++++-----
 torch/nn/parallel/distributed.py     |  7 +++++--
 3 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index cff84dff9be22..a311724c25304 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -1293,7 +1293,7 @@ def test_fp16(self):
         self._test_fp16()
 
     @requires_nccl()
-    @requires_nccl_version((2,17), "Need NCCL 2.17+ for configuring NCCL Communicators")
+    @requires_nccl_version((2, 17), "Need NCCL 2.17+ for configuring NCCL communicators")
     @skip_if_lt_x_gpu(2)
     def test_ddp_default_cga(self):
         nccl_debug_file = tempfile.NamedTemporaryFile()
@@ -2756,7 +2756,7 @@ def test_pass_nccl_options_high_priority_stream(self):
         self._test_pass_nccl_options(pg_opts)
 
     @requires_nccl()
-    @requires_nccl_version((2,17), "Need NCCL 2.17+ for configuring NCCL Communicators")
+    @requires_nccl_version((2, 17), "Need NCCL 2.17+ for configuring NCCL communicators")
     @skip_if_lt_x_gpu(2)
     def test_pass_nccl_options_config(self):
         pg_opts = c10d.ProcessGroupNCCL.Options()
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 18dec328f702e..cdee605850a28 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -2135,8 +2135,10 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
           "is_ucc_available", &::c10d::ProcessGroupNCCL::isUCCAvailable);
 
 #ifdef ENABLE_NCCL_RANK_CONFIG
-  py::class_<ncclConfig_t>(processGroupNCCL, "NCCLConfig",
-      R"(
+  py::class_<ncclConfig_t>(
+      processGroupNCCL,
+      "NCCLConfig",
+      R"(
 ncclConfig_t data type for configuring NCCL communicators.
 See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
 for details.
@@ -2186,9 +2188,7 @@ Example::
       .def_readwrite(
           "is_high_priority_stream",
           &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream)
-      .def_readwrite(
-          "config",
-          &::c10d::ProcessGroupNCCL::Options::config);
+      .def_readwrite("config", &::c10d::ProcessGroupNCCL::Options::config);
 #else
       .def_readwrite(
           "is_high_priority_stream",
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 2c06f01797a98..dc52516a9b1b2 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -749,9 +749,12 @@ def __init__(
         else:
             self.process_group = process_group
 
-        if self.process_group.options.backend == "nccl" and torch.cuda.nccl.version() >= (2, 17):
+        if (
+            dist.get_backend(self.process_group) == "nccl"
+            and torch.cuda.nccl.version() >= (2, 17)
+        ):
             # Note: NVIDIA recommends using CGA Cluster Size of 2 when using DDP.
- default_cga = dist.ProcessGroupNCCL.Options().config.cga_cluster_size + default_cga = dist.ProcessGroupNCCL.Options().config.cga_cluster_size # type: ignore[attr-defined] default_pg_nccl = self.process_group._get_backend(torch.device("cuda")) current_cga = default_pg_nccl.options.config.cga_cluster_size if current_cga == default_cga: From 5536c357ea3e39be2007d01bb9b78aa5e41a8cc2 Mon Sep 17 00:00:00 2001 From: Syed Tousif Ahmed Date: Tue, 28 Mar 2023 11:40:06 -0700 Subject: [PATCH 3/7] Fixes lint --- torch/nn/parallel/distributed.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index dc52516a9b1b2..40f99be826dca 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -749,12 +749,11 @@ def __init__( else: self.process_group = process_group - if ( - dist.get_backend(self.process_group) == "nccl" - and torch.cuda.nccl.version() >= (2, 17) - ): + if dist.get_backend( + self.process_group + ) == "nccl" and torch.cuda.nccl.version() >= (2, 17): # Note: NVIDIA recommends using CGA Cluster Size of 2 when using DDP. - default_cga = dist.ProcessGroupNCCL.Options().config.cga_cluster_size # type: ignore[attr-defined] + default_cga = dist.ProcessGroupNCCL.Options().config.cga_cluster_size # type: ignore[attr-defined] default_pg_nccl = self.process_group._get_backend(torch.device("cuda")) current_cga = default_pg_nccl.options.config.cga_cluster_size if current_cga == default_cga: From c45e6a2b0b61d35643c7c404fe02fccca52eebd1 Mon Sep 17 00:00:00 2001 From: Syed Tousif Ahmed Date: Tue, 18 Apr 2023 11:55:09 -0700 Subject: [PATCH 4/7] Addresses review --- test/distributed/test_c10d_nccl.py | 15 ------------ torch/csrc/distributed/c10d/NCCLUtils.hpp | 23 ++++++++----------- .../distributed/c10d/ProcessGroupNCCL.cpp | 2 +- .../distributed/c10d/ProcessGroupNCCL.hpp | 2 +- torch/csrc/distributed/c10d/init.cpp | 4 ++-- torch/nn/parallel/distributed.py | 10 -------- 6 files changed, 14 insertions(+), 42 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index a311724c25304..6f3bee7c5bc50 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -1292,21 +1292,6 @@ def _test_fp16(self, gradient_as_bucket_view=False): def test_fp16(self): self._test_fp16() - @requires_nccl() - @requires_nccl_version((2, 17), "Need NCCL 2.17+ for configuring NCCL communicators") - @skip_if_lt_x_gpu(2) - def test_ddp_default_cga(self): - nccl_debug_file = tempfile.NamedTemporaryFile() - os.environ["NCCL_DEBUG"] = "INFO" - os.environ["NCCL_DEBUG_FILE"] = nccl_debug_file.name - - self._test_fp16() - - # Tests if default CGA for DDP is 2 - nccl_debug_file_content = nccl_debug_file.read() - cga_cluster_size = re.search(rb'CGA cluster.*(\d+)|$', nccl_debug_file_content).group(1) - self.assertEqual(int(cga_cluster_size), 2) - @requires_nccl() @skip_if_lt_x_gpu(2) def test_fp16_grad_is_view(self): diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index 73dbace47d128..18dd291bed365 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -53,9 +53,9 @@ #endif #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 17) -#define ENABLE_NCCL_RANK_CONFIG +#define NCCL_HAS_COMM_CTA_CGA #elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3) -#define ENABLE_NCCL_RANK_CONFIG +#define NCCL_HAS_COMM_CTA_CGA #endif // Macro 
to throw on a non-successful NCCL return value.
@@ -185,31 +185,28 @@ class NCCLComm {
       int rank,
       ncclUniqueId commId) {
     auto comm = std::make_shared<NCCLComm>();
-#ifndef NCCL_HAS_COMM_NONBLOCKING
     C10D_NCCL_CHECK(
         ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank),
         c10::nullopt);
-#else
-    ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
-    if (nccl_use_nonblocking()) {
-      config.blocking = 0;
-    }
-    C10D_NCCL_CHECK_TIMEOUT(
-        ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), comm->ncclComm_, c10::nullopt);
-#endif
     comm->ncclId_ = commId;
     comm->rank_ = rank;
     return comm;
   }
 
-#ifdef ENABLE_NCCL_RANK_CONFIG
+#ifdef NCCL_HAS_COMM_NONBLOCKING
   static std::shared_ptr<NCCLComm> create(
       int numRanks,
       int rank,
       ncclUniqueId commId,
       ncclConfig_t& config) {
     auto comm = std::make_shared<NCCLComm>();
-    C10D_NCCL_CHECK(
+    if (nccl_use_nonblocking()) {
+      config.blocking = 0;
+      C10D_NCCL_CHECK_TIMEOUT(
+          ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), comm->ncclComm_, c10::nullopt);
+    else {
+      C10D_NCCL_CHECK(
         ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), c10::nullopt);
+    }
     comm->ncclId_ = commId;
     comm->rank_ = rank;
     return comm;
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 84d35b4684d0d..9ac3f14850da2 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -1156,7 +1156,7 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
     int deviceIndex = devices[i].index();
     gpuGuard.set_index(deviceIndex);
 
-#ifdef ENABLE_NCCL_RANK_CONFIG
+#ifdef NCCL_HAS_COMM_NONBLOCKING
     ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID, options_->config);
 #else
     ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID);
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index 51ddd3c054e33..57a2e540e8be0 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -280,7 +280,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
     // Schedule NCCL operations on high priority CUDA streams
     bool is_high_priority_stream;
 
-#ifdef ENABLE_NCCL_RANK_CONFIG
+#ifdef NCCL_HAS_COMM_NONBLOCKING
     // Configure ranks
     ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
 #endif
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index cdee605850a28..cd2447bfecb50 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -2134,7 +2134,7 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
       .def_property_readonly(
           "is_ucc_available", &::c10d::ProcessGroupNCCL::isUCCAvailable);
 
-#ifdef ENABLE_NCCL_RANK_CONFIG
+#ifdef NCCL_HAS_COMM_CTA_CGA
   py::class_<ncclConfig_t>(
       processGroupNCCL,
       "NCCLConfig",
@@ -2184,7 +2184,7 @@ Example::
 >>> dist.init_process_group("nccl", pg_options=nccl_options)
 )")
       .def(py::init<bool>(), py::arg("is_high_priority_stream") = false)
-#ifdef ENABLE_NCCL_RANK_CONFIG
+#ifdef NCCL_HAS_COMM_CTA_CGA
       .def_readwrite(
           "is_high_priority_stream",
           &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream)
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 40f99be826dca..02fa9a13b37ea 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -749,16 +749,6 @@ def __init__(
         else:
             self.process_group = process_group
 
-        if dist.get_backend(
-            self.process_group
-        ) == "nccl" and torch.cuda.nccl.version() >= (2, 17):
-            # Note: NVIDIA recommends using CGA Cluster Size of 2 when using DDP.
-            default_cga = dist.ProcessGroupNCCL.Options().config.cga_cluster_size  # type: ignore[attr-defined]
-            default_pg_nccl = self.process_group._get_backend(torch.device("cuda"))
-            current_cga = default_pg_nccl.options.config.cga_cluster_size
-            if current_cga == default_cga:
-                default_pg_nccl.options.config.cga_cluster_size = 2
-
         self.static_graph = False
         self.dim = dim
         self.module = module

From 858f81d21a8fa238956a1eab25f387415e085f44 Mon Sep 17 00:00:00 2001
From: Syed Tousif Ahmed
Date: Tue, 18 Apr 2023 11:59:59 -0700
Subject: [PATCH 5/7] Cleanup

---
 torch/csrc/distributed/c10d/init.cpp | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index cd2447bfecb50..2869b3bea3da1 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -2194,10 +2194,6 @@ Example::
         "is_high_priority_stream",
         &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream);
 #endif
-  processGroupNCCL.def_static(
-      "_group_start", []() { ::c10d::ProcessGroupNCCL::groupStart(); });
-  processGroupNCCL.def_static(
-      "_group_end", []() { ::c10d::ProcessGroupNCCL::groupEnd(); });
 
 #ifdef USE_C10D_MPI
   auto processGroupMPI =

From 23959f1767befb906c01d63a1cb22f23a08e7de9 Mon Sep 17 00:00:00 2001
From: Syed Tousif Ahmed
Date: Tue, 18 Apr 2023 12:01:26 -0700
Subject: [PATCH 6/7] Cleanup

---
 torch/csrc/distributed/c10d/init.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 2869b3bea3da1..4a697a217089c 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -2195,6 +2195,8 @@ Example::
         &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream);
 #endif
 
+#endif
+
 #ifdef USE_C10D_MPI
   auto processGroupMPI =
       intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupMPI>(

From e6fa35b4927ef1425dc7e41af70020be6e81bc16 Mon Sep 17 00:00:00 2001
From: Syed Tousif Ahmed
Date: Wed, 19 Apr 2023 10:33:49 -0700
Subject: [PATCH 7/7] Adds missing braces

---
 torch/csrc/distributed/c10d/NCCLUtils.hpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp
index 18dd291bed365..0ac84c6df52dd 100644
--- a/torch/csrc/distributed/c10d/NCCLUtils.hpp
+++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp
@@ -203,7 +203,7 @@ class NCCLComm {
       config.blocking = 0;
       C10D_NCCL_CHECK_TIMEOUT(
          ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), comm->ncclComm_,
c10::nullopt); - else { + } else { C10D_NCCL_CHECK( ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), c10::nullopt); }
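
Usage sketch for the feature this series adds. Illustrative only: it assumes
a PyTorch build against NCCL 2.17+ (so that Options exposes the new
NCCLConfig binding) and a single-node torchrun launch; the CTA values are
examples, and only the CGA cluster size of 2 follows the NVIDIA
recommendation noted in the DDP changes above.

    import os

    import torch
    import torch.distributed as dist


    def main() -> None:
        # torchrun sets RANK, WORLD_SIZE and LOCAL_RANK; one process per GPU.
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)

        opts = dist.ProcessGroupNCCL.Options(is_high_priority_stream=False)
        # Fields exposed by the NCCLConfig binding (NCCL 2.17+ builds only).
        opts.config.cga_cluster_size = 2
        opts.config.min_ctas = 2
        opts.config.max_ctas = 4
        dist.init_process_group("nccl", pg_options=opts)

        # Run one collective so the communicator is actually created with the
        # configuration above (NCCL communicators are initialized lazily).
        t = torch.full((10,), dist.get_rank() + 1, device="cuda")
        dist.all_reduce(t)

        dist.destroy_process_group()


    if __name__ == "__main__":
        main()

Running with NCCL_DEBUG=INFO set (for example, NCCL_DEBUG=INFO torchrun
--nproc_per_node=2 example.py) should log the configured Min/Max CTAs and
CGA cluster size, which is what test_pass_nccl_options_config greps for in
the debug file above.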