Skip to content

Commit

Permalink
add a c++ interface in processGroup to get its backend name
Browse files Browse the repository at this point in the history
The backend name of a process group created using the distributed_c10d Python API is tracked, but there is no good way to track the name of a process group created using the ProcessGroup C++ API. In some cases, knowing the backend name of a ProcessGroup is useful, e.g., to log the backend name, or to write code that depends on the known backend.

Differential Revision: [D26059769](https://our.internmc.facebook.com/intern/diff/D26059769/)

ghstack-source-id: 120349147
Pull Request resolved: #51066
  • Loading branch information
zhaojuanmao committed Jan 26, 2021
1 parent 8096a50 commit 089a879
Show file tree
Hide file tree
Showing 12 changed files with 66 additions and 1 deletion.
1 change: 1 addition & 0 deletions c10/util/Logging.h
Expand Up @@ -319,6 +319,7 @@ struct DDPLoggingData {
int bucket_cap_mb;
bool find_unused_parameters;
bool gradient_as_bucket_view;
std::string backend_name;
};

namespace detail {
Expand Down
3 changes: 2 additions & 1 deletion torch/csrc/distributed/c10d/init.cpp
Expand Up @@ -1191,7 +1191,8 @@ that adds a prefix to each key inserted to the store.
.def_readwrite("broadcast_buffers", &c10::DDPLoggingData::broadcast_buffers)
.def_readwrite("bucket_cap_mb", &c10::DDPLoggingData::bucket_cap_mb)
.def_readwrite("find_unused_parameters", &c10::DDPLoggingData::find_unused_parameters)
.def_readwrite("gradient_as_bucket_view", &c10::DDPLoggingData::gradient_as_bucket_view);
.def_readwrite("gradient_as_bucket_view", &c10::DDPLoggingData::gradient_as_bucket_view)
.def_readwrite("backend_name", &c10::DDPLoggingData::backend_name);

module.def(
"_compute_bucket_assignment_by_size",
Expand Down
10 changes: 10 additions & 0 deletions torch/lib/c10d/ProcessGroup.hpp
Expand Up @@ -42,6 +42,12 @@ enum class OpType : std::uint8_t {
UNKNOWN = 100,
};

// Canonical backend names returned by ProcessGroup::getBackendName() for each
// concrete implementation. Centralized here so tests and logging code compare
// against a single definition instead of scattering string literals.
// NOTE(review): `const std::string` at namespace scope in a header gives every
// translation unit its own copy plus a dynamic initializer; consider
// `constexpr const char*` (or C++17 `inline const std::string`) — verify
// against the project's minimum C++ standard before changing.
const std::string GLOO_BACKEND_NAME = "gloo";
const std::string NCCL_BACKEND_NAME = "nccl";
const std::string MPI_BACKEND_NAME = "mpi";
const std::string ROUND_ROBIN_BACKEND_NAME = "round_robin";
const std::string UNDEFINED = "undefined";

// Converts OpType to human readable string.
std::string opTypeToString(OpType opType);

Expand Down Expand Up @@ -170,6 +176,10 @@ class ProcessGroup : public torch::CustomClassHolder {
return size_;
}

// Returns a human-readable name identifying this process group's backend
// (e.g. "gloo", "nccl", "mpi"). Concrete subclasses override this; the base
// implementation returns the "undefined" sentinel so callers can detect a
// backend that did not report a name.
virtual const std::string getBackendName() const {
  return UNDEFINED;
}

virtual c10::intrusive_ptr<ProcessGroup::Work> broadcast(
std::vector<at::Tensor>& data,
const BroadcastOptions& opts = BroadcastOptions()) = 0;
Expand Down
4 changes: 4 additions & 0 deletions torch/lib/c10d/ProcessGroupGloo.hpp
Expand Up @@ -133,6 +133,10 @@ class ProcessGroupGloo : public ProcessGroup {
int threads;
};

// Returns the canonical backend name for this implementation ("gloo").
// Marked `override` so the compiler verifies the signature matches the
// virtual declared on the ProcessGroup base class.
const std::string getBackendName() const override {
  return GLOO_BACKEND_NAME;
}

// Helper functions to create a new device object.
// They are static functions on this class to keep them logically
// separate from the rest of the code base (e.g. torch/csrc/distributed).
Expand Down
4 changes: 4 additions & 0 deletions torch/lib/c10d/ProcessGroupMPI.hpp
Expand Up @@ -108,6 +108,10 @@ class ProcessGroupMPI : public ProcessGroup {
// Abort the MPI program, needs to be called when exception is detected
void abort();

// Returns the canonical backend name for this implementation ("mpi").
// Marked `override` so the compiler verifies the signature matches the
// virtual declared on the ProcessGroup base class.
const std::string getBackendName() const override {
  return MPI_BACKEND_NAME;
}

c10::intrusive_ptr<ProcessGroup::Work> broadcast(
std::vector<at::Tensor>& data,
const BroadcastOptions& opts = BroadcastOptions()) override;
Expand Down
4 changes: 4 additions & 0 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
Expand Up @@ -227,6 +227,10 @@ class ProcessGroupNCCL : public ProcessGroup {

virtual ~ProcessGroupNCCL();

// Returns the canonical backend name for this implementation ("nccl").
// Marked `override` so the compiler verifies the signature matches the
// virtual declared on the ProcessGroup base class.
const std::string getBackendName() const override {
  return NCCL_BACKEND_NAME;
}

c10::intrusive_ptr<ProcessGroup::Work> broadcast(
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts = BroadcastOptions()) override;
Expand Down
4 changes: 4 additions & 0 deletions torch/lib/c10d/ProcessGroupRoundRobin.hpp
Expand Up @@ -25,6 +25,10 @@ class ProcessGroupRoundRobin final : public ProcessGroup {

~ProcessGroupRoundRobin() override;

// Returns the canonical backend name for this implementation ("round_robin").
// Marked `override` so the compiler verifies the signature matches the
// virtual declared on the ProcessGroup base class.
const std::string getBackendName() const override {
  return ROUND_ROBIN_BACKEND_NAME;
}

c10::intrusive_ptr<ProcessGroup::Work> broadcast(
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts = BroadcastOptions()) override;
Expand Down
1 change: 1 addition & 0 deletions torch/lib/c10d/reducer.cpp
Expand Up @@ -1481,6 +1481,7 @@ void Reducer::set_construction_logging_data(
ddp_logging_data_->bucket_cap_mb = bucket_bytes_cap_ / (1024 * 1024);
ddp_logging_data_->find_unused_parameters = find_unused_parameters_;
ddp_logging_data_->gradient_as_bucket_view = gradient_as_bucket_view_;
ddp_logging_data_->backend_name = process_group_->getBackendName();
}

c10::DDPLoggingData Reducer::get_ddp_logging_data() {
Expand Down
13 changes: 13 additions & 0 deletions torch/lib/c10d/test/ProcessGroupGlooTest.cpp
Expand Up @@ -577,4 +577,17 @@ TEST(ProcessGroupGlooTest, testAlltoallCUDA) {
}
}

// Every rank of a gloo-backed process group must report the canonical
// "gloo" backend name.
TEST(ProcessGroupGlooTest, testBackendName) {
  TemporaryFile file;
  const int worldSize = 2;
  auto tests = CollectiveTest::initialize(file.path, worldSize);

  for (int rank = 0; rank < worldSize; ++rank) {
    const auto reported = tests[rank].getProcessGroup().getBackendName();
    EXPECT_EQ(reported, c10d::GLOO_BACKEND_NAME);
  }
}

#endif
8 changes: 8 additions & 0 deletions torch/lib/c10d/test/ProcessGroupMPITest.cpp
Expand Up @@ -334,6 +334,13 @@ void testSendRecv(bool recvAnysource, int iter = 10000) {
}
}

// Verifies that an MPI process group reports the expected backend name.
void testBackendName() {
  const auto pg = c10d::ProcessGroupMPI::createProcessGroupMPI();
  const bool nameMatches = (pg->getBackendName() == MPI_BACKEND_NAME);
  if (!nameMatches) {
    throw std::runtime_error("BOOM!");
  }
}

int main(int argc, char** argv) {
#ifdef MPIEXEC
// If we are within an openmpi mpirun, then skip the exec
Expand All @@ -350,6 +357,7 @@ int main(int argc, char** argv) {
testScatter();
testSendRecv(false);
testSendRecv(true);
testBackendName();

std::cout << "Test successful" << std::endl;
#else
Expand Down
13 changes: 13 additions & 0 deletions torch/lib/c10d/test/ProcessGroupNCCLTest.cpp
Expand Up @@ -470,3 +470,16 @@ TEST_F(ProcessGroupNCCLTest, testReduceScatter) {
testReduceScatter(file.path, rank_, size_);
}
}

// An NCCL-backed process group must report the canonical "nccl" backend name.
TEST_F(ProcessGroupNCCLTest, testBackendName) {
  if (skipTest()) {
    return;
  }
  TemporaryFile file;
  NCCLTestBase test(file.path);
  test.initialize(rank_, size_);
  const auto reported = test.getProcessGroup().getBackendName();
  EXPECT_EQ(reported, c10d::NCCL_BACKEND_NAME);
}
2 changes: 2 additions & 0 deletions torch/testing/_internal/distributed/distributed_test.py
Expand Up @@ -3156,6 +3156,7 @@ def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_gradient(self):
BACKEND == "nccl", "nccl does not support DDP on CPU models"
)
def test_ddp_logging_data(self):
group, group_id, rank = self._init_global_test()
model_DDP = copy.deepcopy(DDP_NET)
model_DDP = nn.parallel.DistributedDataParallel(model_DDP)
ddp_logging_data = model_DDP.get_ddp_logging_data()
Expand All @@ -3170,6 +3171,7 @@ def test_ddp_logging_data(self):
self.assertEqual(ddp_logging_data.bucket_cap_mb, 25)
self.assertEqual(ddp_logging_data.find_unused_parameters, False)
self.assertEqual(ddp_logging_data.gradient_as_bucket_view, False)
self.assertEqual(ddp_logging_data.backend_name, dist.get_backend(group_id))

@skipIfNoTorchVision
def test_SyncBatchNorm_process_group(self):
Expand Down

0 comments on commit 089a879

Please sign in to comment.