diff --git a/c10/util/Logging.h b/c10/util/Logging.h
index 0435123ea8bd..6b5bc5d12cd8 100644
--- a/c10/util/Logging.h
+++ b/c10/util/Logging.h
@@ -319,6 +319,7 @@ struct DDPLoggingData {
   int bucket_cap_mb;
   bool find_unused_parameters;
   bool gradient_as_bucket_view;
+  std::string backend_name;
 };
 
 namespace detail {
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 9502b3236874..d733654250f9 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -1191,7 +1191,8 @@ that adds a prefix to each key inserted to the store.
       .def_readwrite("broadcast_buffers", &c10::DDPLoggingData::broadcast_buffers)
       .def_readwrite("bucket_cap_mb", &c10::DDPLoggingData::bucket_cap_mb)
       .def_readwrite("find_unused_parameters", &c10::DDPLoggingData::find_unused_parameters)
-      .def_readwrite("gradient_as_bucket_view", &c10::DDPLoggingData::gradient_as_bucket_view);
+      .def_readwrite("gradient_as_bucket_view", &c10::DDPLoggingData::gradient_as_bucket_view)
+      .def_readwrite("backend_name", &c10::DDPLoggingData::backend_name);
 
   module.def(
       "_compute_bucket_assignment_by_size",
diff --git a/torch/lib/c10d/ProcessGroup.hpp b/torch/lib/c10d/ProcessGroup.hpp
index ea4b1428038a..d8ec4b31d36d 100644
--- a/torch/lib/c10d/ProcessGroup.hpp
+++ b/torch/lib/c10d/ProcessGroup.hpp
@@ -170,6 +170,10 @@ class ProcessGroup : public torch::CustomClassHolder {
     return size_;
   }
 
+  virtual const std::string getBackendName() const {
+    return "undefined";
+  }
+
   virtual c10::intrusive_ptr<ProcessGroup::Work> broadcast(
       std::vector<at::Tensor>& data,
       const BroadcastOptions& opts = BroadcastOptions()) = 0;
diff --git a/torch/lib/c10d/ProcessGroupGloo.hpp b/torch/lib/c10d/ProcessGroupGloo.hpp
index 0508b6f857a1..d0befc95e97f 100644
--- a/torch/lib/c10d/ProcessGroupGloo.hpp
+++ b/torch/lib/c10d/ProcessGroupGloo.hpp
@@ -27,6 +27,8 @@
 
 namespace c10d {
 
+constexpr const char* GLOO_BACKEND_NAME = "gloo";
+
 // ProcessGroupGloo implements Gloo bindings for c10d.
 //
 // All functions on this class are expected to be called in the same
@@ -133,6 +135,10 @@ class ProcessGroupGloo : public ProcessGroup {
     int threads;
   };
 
+  const std::string getBackendName() const override {
+    return std::string(GLOO_BACKEND_NAME);
+  }
+
   // Helper functions to create a new device object.
   // They are static functions on this class to keep them logically
   // separate from the rest of the code base (e.g. torch/csrc/distributed).
diff --git a/torch/lib/c10d/ProcessGroupMPI.hpp b/torch/lib/c10d/ProcessGroupMPI.hpp
index 420c78ef028a..4e5b7f5e619c 100644
--- a/torch/lib/c10d/ProcessGroupMPI.hpp
+++ b/torch/lib/c10d/ProcessGroupMPI.hpp
@@ -16,6 +16,8 @@
 
 namespace c10d {
 
+constexpr const char* MPI_BACKEND_NAME = "mpi";
+
 // WorkEntry is the state associated with a single MPI run instance.
 // It include the source Tensor list and destination Tensor list, as well as
 // The actual run function that will operate either on src or dst or both.
@@ -108,6 +110,10 @@ class ProcessGroupMPI : public ProcessGroup {
   // Abort the MPI program, needs to be called when exception is detected
   void abort();
 
+  const std::string getBackendName() const override {
+    return std::string(MPI_BACKEND_NAME);
+  }
+
   c10::intrusive_ptr<ProcessGroup::Work> broadcast(
       std::vector<at::Tensor>& data,
       const BroadcastOptions& opts = BroadcastOptions()) override;
diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp
index 4d9dc3bd1ae8..bc47b06245bf 100644
--- a/torch/lib/c10d/ProcessGroupNCCL.hpp
+++ b/torch/lib/c10d/ProcessGroupNCCL.hpp
@@ -32,6 +32,8 @@ constexpr const char* NCCL_BLOCKING_WAIT = "NCCL_BLOCKING_WAIT";
 // Handling with NCCL.
 constexpr const char* NCCL_ASYNC_ERROR_HANDLING = "NCCL_ASYNC_ERROR_HANDLING";
 
+constexpr const char* NCCL_BACKEND_NAME = "nccl";
+
 // ProcessGroupNCCL implements NCCL bindings for c10d.
 //
 // All functions of the class are expected to be called in the same order
@@ -227,6 +229,10 @@ class ProcessGroupNCCL : public ProcessGroup {
 
   virtual ~ProcessGroupNCCL();
 
+  const std::string getBackendName() const override {
+    return std::string(NCCL_BACKEND_NAME);
+  }
+
   c10::intrusive_ptr<ProcessGroup::Work> broadcast(
       std::vector<at::Tensor>& tensors,
       const BroadcastOptions& opts = BroadcastOptions()) override;
diff --git a/torch/lib/c10d/ProcessGroupRoundRobin.hpp b/torch/lib/c10d/ProcessGroupRoundRobin.hpp
index a8c2eba115a6..6ce7a7804150 100644
--- a/torch/lib/c10d/ProcessGroupRoundRobin.hpp
+++ b/torch/lib/c10d/ProcessGroupRoundRobin.hpp
@@ -6,6 +6,8 @@
 
 namespace c10d {
 
+constexpr const char* ROUND_ROBIN_BACKEND_NAME = "round_robin";
+
 // ProcessGroupRoundRobin implements simple load balancing.
 //
 // It is constructed with multiple processes groups. Each call is dispatched to
@@ -25,6 +27,10 @@ class ProcessGroupRoundRobin final : public ProcessGroup {
 
   ~ProcessGroupRoundRobin() override;
 
+  const std::string getBackendName() const override {
+    return std::string(ROUND_ROBIN_BACKEND_NAME);
+  }
+
   c10::intrusive_ptr<ProcessGroup::Work> broadcast(
       std::vector<at::Tensor>& tensors,
       const BroadcastOptions& opts = BroadcastOptions()) override;
diff --git a/torch/lib/c10d/reducer.cpp b/torch/lib/c10d/reducer.cpp
index d13c6c3e658a..3d6f949ab3e8 100644
--- a/torch/lib/c10d/reducer.cpp
+++ b/torch/lib/c10d/reducer.cpp
@@ -1481,6 +1481,7 @@ void Reducer::set_construction_logging_data(
   ddp_logging_data_->bucket_cap_mb = bucket_bytes_cap_ / (1024 * 1024);
   ddp_logging_data_->find_unused_parameters = find_unused_parameters_;
   ddp_logging_data_->gradient_as_bucket_view = gradient_as_bucket_view_;
+  ddp_logging_data_->backend_name = process_group_->getBackendName();
 }
 
 c10::DDPLoggingData Reducer::get_ddp_logging_data() {
diff --git a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp
index 469cf32a8442..a20d5245e614 100644
--- a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp
+++ b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp
@@ -577,4 +577,17 @@ TEST(ProcessGroupGlooTest, testAlltoallCUDA) {
   }
 }
 
+TEST(ProcessGroupGlooTest, testBackendName) {
+  {
+    TemporaryFile file;
+    const auto size = 2;
+    auto tests = CollectiveTest::initialize(file.path, size);
+
+    for (auto i = 0; i < size; i++) {
+      EXPECT_EQ(
+          tests[i].getProcessGroup().getBackendName(), std::string(c10d::GLOO_BACKEND_NAME));
+    }
+  }
+}
+
 #endif
diff --git a/torch/lib/c10d/test/ProcessGroupMPITest.cpp b/torch/lib/c10d/test/ProcessGroupMPITest.cpp
index 5503b4cde866..b159bb6fb353 100644
--- a/torch/lib/c10d/test/ProcessGroupMPITest.cpp
+++ b/torch/lib/c10d/test/ProcessGroupMPITest.cpp
@@ -334,6 +334,13 @@ void testSendRecv(bool recvAnysource, int iter = 10000) {
   }
 }
 
+void testBackendName() {
+  auto pg = c10d::ProcessGroupMPI::createProcessGroupMPI();
+  if (pg->getBackendName() != std::string(c10d::MPI_BACKEND_NAME)) {
+    throw std::runtime_error("BOOM!");
+  }
+}
+
 int main(int argc, char** argv) {
 #ifdef MPIEXEC
   // If we are within an openmpi mpirun, then skip the exec
@@ -350,6 +357,7 @@ int main(int argc, char** argv) {
   testScatter();
   testSendRecv(false);
   testSendRecv(true);
+  testBackendName();
 
   std::cout << "Test successful" << std::endl;
 #else
diff --git a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp
index 6c8bec2d0f92..4075c9041745 100644
--- a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp
+++ b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp
@@ -470,3 +470,16 @@ TEST_F(ProcessGroupNCCLTest, testReduceScatter) {
     testReduceScatter(file.path, rank_, size_);
   }
 }
+
+TEST_F(ProcessGroupNCCLTest, testBackendName) {
+  if (skipTest()) {
+    return;
+  }
+  {
+    TemporaryFile file;
+    auto test = NCCLTestBase(file.path);
+    test.initialize(rank_, size_);
+    EXPECT_EQ(
+        test.getProcessGroup().getBackendName(), std::string(c10d::NCCL_BACKEND_NAME));
+  }
+}
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 462cf8a4d9e0..e65dff8c3989 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -3205,6 +3205,7 @@ def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_gradient(self):
         BACKEND == "nccl", "nccl does not support DDP on CPU models"
     )
     def test_ddp_logging_data(self):
+        group, group_id, rank = self._init_global_test()
         model_DDP = copy.deepcopy(DDP_NET)
         model_DDP = nn.parallel.DistributedDataParallel(model_DDP)
         ddp_logging_data = model_DDP.get_ddp_logging_data()
@@ -3219,6 +3220,7 @@ def test_ddp_logging_data(self):
         self.assertEqual(ddp_logging_data.bucket_cap_mb, 25)
         self.assertEqual(ddp_logging_data.find_unused_parameters, False)
         self.assertEqual(ddp_logging_data.gradient_as_bucket_view, False)
+        self.assertEqual(ddp_logging_data.backend_name, dist.get_backend(group_id))
 
     @skipIfNoTorchVision
     def test_SyncBatchNorm_process_group(self):
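
Taken together, the patch threads one piece of information through three layers: each concrete ProcessGroup reports its backend via the new getBackendName() virtual (the base class defaults to "undefined"), the Reducer copies that string into DDPLoggingData at DDP construction time, and the pybind11 binding exposes it to Python as ddp_logging_data.backend_name. The sketch below illustrates just the virtual-dispatch pattern in isolation; ProcessGroupBase and GlooGroup are illustrative stand-ins, not the real c10d classes, which carry many pure-virtual collectives a self-contained toy example cannot implement.

    #include <iostream>
    #include <string>

    constexpr const char* GLOO_BACKEND_NAME = "gloo";

    // Stand-in for c10d::ProcessGroup: backends that do not override the
    // virtual report "undefined", mirroring ProcessGroup::getBackendName().
    class ProcessGroupBase {
     public:
      virtual ~ProcessGroupBase() = default;
      virtual const std::string getBackendName() const {
        return "undefined";
      }
    };

    // Stand-in for a concrete backend such as ProcessGroupGloo.
    class GlooGroup : public ProcessGroupBase {
     public:
      const std::string getBackendName() const override {
        return std::string(GLOO_BACKEND_NAME);
      }
    };

    int main() {
      GlooGroup pg;
      const ProcessGroupBase& base = pg;           // dispatch through the base class
      std::cout << base.getBackendName() << "\n";  // prints "gloo"
      return 0;
    }

Returning std::string by value, as the patch does, keeps ownership simple at the call sites in reducer.cpp and the Python binding, at the cost of a small copy on each call.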