Add a C++ interface in ProcessGroup to get its backend name #51066

Closed · wants to merge 6 commits
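For orientation, a minimal sketch (not part of this diff) of how the new accessor might be consumed from C++. The logBackend helper and the include path are illustrative assumptions; only getBackendName() and the backend-name constants come from this PR.

// Sketch only: log which backend a process group was built with.
// Assumes an already constructed c10d::ProcessGroup subclass instance
// (e.g. ProcessGroupGloo); helper name and include path are assumptions.
#include <iostream>
#include <string>

#include <c10d/ProcessGroup.hpp>

void logBackend(const c10d::ProcessGroup& pg) {
  // Returns one of the constants introduced in ProcessGroup.hpp:
  // "gloo", "nccl", "mpi", "round_robin", or "undefined" for the base class.
  const std::string name = pg.getBackendName();
  std::cout << "process group backend: " << name << std::endl;
}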
1 change: 1 addition & 0 deletions c10/util/Logging.h
@@ -319,6 +319,7 @@ struct DDPLoggingData {
  int bucket_cap_mb;
  bool find_unused_parameters;
  bool gradient_as_bucket_view;
  std::string backend_name;
};

namespace detail {
3 changes: 2 additions & 1 deletion torch/csrc/distributed/c10d/init.cpp
@@ -1191,7 +1191,8 @@ that adds a prefix to each key inserted to the store.
      .def_readwrite("broadcast_buffers", &c10::DDPLoggingData::broadcast_buffers)
      .def_readwrite("bucket_cap_mb", &c10::DDPLoggingData::bucket_cap_mb)
      .def_readwrite("find_unused_parameters", &c10::DDPLoggingData::find_unused_parameters)
      .def_readwrite("gradient_as_bucket_view", &c10::DDPLoggingData::gradient_as_bucket_view);
      .def_readwrite("gradient_as_bucket_view", &c10::DDPLoggingData::gradient_as_bucket_view)
      .def_readwrite("backend_name", &c10::DDPLoggingData::backend_name);

  module.def(
      "_compute_bucket_assignment_by_size",
10 changes: 10 additions & 0 deletions torch/lib/c10d/ProcessGroup.hpp
@@ -42,6 +42,12 @@ enum class OpType : std::uint8_t {
  UNKNOWN = 100,
};

const std::string GLOO_BACKEND_NAME = "gloo";
const std::string NCCL_BACKEND_NAME = "nccl";
const std::string MPI_BACKEND_NAME = "mpi";
const std::string ROUND_ROBIN_BACKEND_NAME = "round_robin";
const std::string UNDEFINED = "undefined";

// Converts OpType to human readable string.
std::string opTypeToString(OpType opType);

@@ -170,6 +176,10 @@ class ProcessGroup : public torch::CustomClassHolder {
    return size_;
  }

  virtual const std::string getBackendName() const {
    return UNDEFINED;
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> broadcast(
      std::vector<at::Tensor>& data,
      const BroadcastOptions& opts = BroadcastOptions()) = 0;
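The base-class default returns UNDEFINED, so the hook also works for process groups that do not override it. As a further illustration (again not part of this diff), a hypothetical out-of-tree backend could report its own name through the same virtual method; MyStoreBackend is made up, the ProcessGroup(rank, size) base constructor call is an assumption to be checked against the headers, and the remaining pure-virtual collectives are omitted, so the class stays abstract and only sketches the extension point.

// Hypothetical subclass: only the naming hook is shown. broadcast(),
// allreduce(), etc. remain pure virtual, so this class cannot be
// instantiated as written; it merely illustrates how a custom backend
// would surface its name via getBackendName() and DDPLoggingData.backend_name.
#include <string>

#include <c10d/ProcessGroup.hpp>

namespace c10d {

class MyStoreBackend : public ProcessGroup {
 public:
  MyStoreBackend(int rank, int size) : ProcessGroup(rank, size) {}

  const std::string getBackendName() const {
    return "my_store";
  }
};

} // namespace c10d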
4 changes: 4 additions & 0 deletions torch/lib/c10d/ProcessGroupGloo.hpp
@@ -133,6 +133,10 @@ class ProcessGroupGloo : public ProcessGroup {
    int threads;
  };

  const std::string getBackendName() const {
    return GLOO_BACKEND_NAME;
  }

  // Helper functions to create a new device object.
  // They are static functions on this class to keep them logically
  // separate from the rest of the code base (e.g. torch/csrc/distributed).
4 changes: 4 additions & 0 deletions torch/lib/c10d/ProcessGroupMPI.hpp
@@ -108,6 +108,10 @@ class ProcessGroupMPI : public ProcessGroup {
  // Abort the MPI program, needs to be called when exception is detected
  void abort();

  const std::string getBackendName() const {
    return MPI_BACKEND_NAME;
  }

  c10::intrusive_ptr<ProcessGroup::Work> broadcast(
      std::vector<at::Tensor>& data,
      const BroadcastOptions& opts = BroadcastOptions()) override;
4 changes: 4 additions & 0 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
@@ -227,6 +227,10 @@ class ProcessGroupNCCL : public ProcessGroup {

  virtual ~ProcessGroupNCCL();

  const std::string getBackendName() const {
    return NCCL_BACKEND_NAME;
  }

  c10::intrusive_ptr<ProcessGroup::Work> broadcast(
      std::vector<at::Tensor>& tensors,
      const BroadcastOptions& opts = BroadcastOptions()) override;
4 changes: 4 additions & 0 deletions torch/lib/c10d/ProcessGroupRoundRobin.hpp
@@ -25,6 +25,10 @@ class ProcessGroupRoundRobin final : public ProcessGroup {

  ~ProcessGroupRoundRobin() override;

  const std::string getBackendName() const {
    return ROUND_ROBIN_BACKEND_NAME;
  }

  c10::intrusive_ptr<ProcessGroup::Work> broadcast(
      std::vector<at::Tensor>& tensors,
      const BroadcastOptions& opts = BroadcastOptions()) override;
1 change: 1 addition & 0 deletions torch/lib/c10d/reducer.cpp
@@ -1481,6 +1481,7 @@ void Reducer::set_construction_logging_data(
  ddp_logging_data_->bucket_cap_mb = bucket_bytes_cap_ / (1024 * 1024);
  ddp_logging_data_->find_unused_parameters = find_unused_parameters_;
  ddp_logging_data_->gradient_as_bucket_view = gradient_as_bucket_view_;
  ddp_logging_data_->backend_name = process_group_->getBackendName();
}

c10::DDPLoggingData Reducer::get_ddp_logging_data() {
13 changes: 13 additions & 0 deletions torch/lib/c10d/test/ProcessGroupGlooTest.cpp
@@ -577,4 +577,17 @@ TEST(ProcessGroupGlooTest, testAlltoallCUDA) {
  }
}

TEST(ProcessGroupGlooTest, testBackendName) {
  {
    TemporaryFile file;
    const auto size = 2;
    auto tests = CollectiveTest::initialize(file.path, size);

    for (auto i = 0; i < size; i++) {
      EXPECT_EQ(
          tests[i].getProcessGroup().getBackendName(), c10d::GLOO_BACKEND_NAME);
    }
  }
}

#endif
8 changes: 8 additions & 0 deletions torch/lib/c10d/test/ProcessGroupMPITest.cpp
@@ -334,6 +334,13 @@ void testSendRecv(bool recvAnysource, int iter = 10000) {
  }
}

void testBackendName() {
  auto pg = c10d::ProcessGroupMPI::createProcessGroupMPI();
  if (pg->getBackendName() != c10d::MPI_BACKEND_NAME) {
    throw std::runtime_error("BOOM!");
  }
}

int main(int argc, char** argv) {
#ifdef MPIEXEC
  // If we are within an openmpi mpirun, then skip the exec
@@ -350,6 +357,7 @@ int main(int argc, char** argv) {
  testScatter();
  testSendRecv(false);
  testSendRecv(true);
  testBackendName();

  std::cout << "Test successful" << std::endl;
#else
13 changes: 13 additions & 0 deletions torch/lib/c10d/test/ProcessGroupNCCLTest.cpp
@@ -470,3 +470,16 @@ TEST_F(ProcessGroupNCCLTest, testReduceScatter) {
    testReduceScatter(file.path, rank_, size_);
  }
}

TEST_F(ProcessGroupNCCLTest, testBackendName) {
  if (skipTest()) {
    return;
  }
  {
    TemporaryFile file;
    auto test = NCCLTestBase(file.path);
    test.initialize(rank_, size_);
    EXPECT_EQ(
        test.getProcessGroup().getBackendName(), c10d::NCCL_BACKEND_NAME);
  }
}
2 changes: 2 additions & 0 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -3205,6 +3205,7 @@ def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_gradient(self):
        BACKEND == "nccl", "nccl does not support DDP on CPU models"
    )
    def test_ddp_logging_data(self):
        group, group_id, rank = self._init_global_test()
        model_DDP = copy.deepcopy(DDP_NET)
        model_DDP = nn.parallel.DistributedDataParallel(model_DDP)
        ddp_logging_data = model_DDP.get_ddp_logging_data()
@@ -3219,6 +3220,7 @@ def test_ddp_logging_data(self):
        self.assertEqual(ddp_logging_data.bucket_cap_mb, 25)
        self.assertEqual(ddp_logging_data.find_unused_parameters, False)
        self.assertEqual(ddp_logging_data.gradient_as_bucket_view, False)
        self.assertEqual(ddp_logging_data.backend_name, dist.get_backend(group_id))

    @skipIfNoTorchVision
    def test_SyncBatchNorm_process_group(self):