Skip to content

Commit

Permalink
[C10D] Track pg name in c++. (#108813)
Browse files Browse the repository at this point in the history
Pull Request resolved: #108813
Approved by: https://github.com/wconstab
  • Loading branch information
Rodrigo Kumpera authored and pull[bot] committed Jan 27, 2024
1 parent e130d22 commit 2286256
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 5 deletions.
10 changes: 10 additions & 0 deletions torch/csrc/distributed/c10d/Backend.hpp
Expand Up @@ -346,6 +346,15 @@ class TORCH_API Backend : public torch::CustomClassHolder {
return onCompletionHook_ != nullptr;
}

// Do not call this directly, use ProcessGroup::setGroupName instead.
void setGroupName(const std::string& name) {
pg_name_ = name;
}

const std::string& getGroupName() const {
return pg_name_;
}

protected:
// Implementations of this interface need to call this to setup
// appropriate logging etc.
Expand All @@ -358,6 +367,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
// Debug level setting. It is parsed once when ProcessGroup is constructed and
// remains the same across use of this process group.
DebugLevel dist_debug_level_;
std::string pg_name_;

std::function<void(std::shared_ptr<WorkInfo>)> onCompletionHook_;
};
Expand Down
12 changes: 12 additions & 0 deletions torch/csrc/distributed/c10d/ProcessGroup.cpp
Expand Up @@ -153,4 +153,16 @@ void ProcessGroup::init() {
C10_LOG_API_USAGE_ONCE(
fmt::format("c10d.process_group_{}", getBackendName()));
}

const std::string& ProcessGroup::getGroupName() const {
TORCH_CHECK(deviceTypeToBackend_.size(), "ProcessGroup name not set");
return deviceTypeToBackend_.begin()->second->getGroupName();
}

void ProcessGroup::setGroupName(const std::string& name) {
for (auto& kv : deviceTypeToBackend_) {
kv.second->setGroupName(name);
}
}

} // namespace c10d
3 changes: 3 additions & 0 deletions torch/csrc/distributed/c10d/ProcessGroup.hpp
Expand Up @@ -678,6 +678,9 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
return getDefaultBackend()->hasHooks();
}

const std::string& getGroupName() const;
void setGroupName(const std::string& name);

protected:
// Implementations of this interface need to call this to setup
// appropriate logging etc.
Expand Down
11 changes: 10 additions & 1 deletion torch/csrc/distributed/c10d/init.cpp
Expand Up @@ -1756,7 +1756,16 @@ The hook must have the following signature:
.def(
"_has_hooks",
&::c10d::ProcessGroup::hasHooks,
py::call_guard<py::gil_scoped_acquire>());
py::call_guard<py::gil_scoped_acquire>())
.def(
"_set_group_name",
&::c10d::ProcessGroup::setGroupName,
py::call_guard<py::gil_scoped_acquire>(),
"Sets the process group name. This is an internal C10D method, do not use.")
.def_property_readonly(
"group_name",
&::c10d::ProcessGroup::getGroupName,
"(Gets this process group name. It's cluster unique)");

py::enum_<::c10d::ProcessGroup::BackendType>(processGroup, "BackendType")
.value("UNDEFINED", ::c10d::ProcessGroup::BackendType::UNDEFINED)
Expand Down
11 changes: 7 additions & 4 deletions torch/distributed/distributed_c10d.py
Expand Up @@ -1130,7 +1130,7 @@ def init_process_group(
)

default_pg, _ = _new_process_group_helper(
-1, -1, [], backend, None, group_name=group_name, timeout=timeout
-1, -1, [], backend, None, group_name, timeout=timeout
)
_update_default_pg(default_pg)
else:
Expand All @@ -1152,8 +1152,8 @@ def init_process_group(
[],
backend,
store,
group_name,
pg_options=pg_options,
group_name=group_name,
timeout=timeout
)
_update_default_pg(default_pg)
Expand Down Expand Up @@ -1190,8 +1190,8 @@ def _new_process_group_helper(
global_ranks_in_group,
backend,
store,
group_name,
pg_options=None,
group_name=None,
timeout=default_pg_timeout,
pg_tag=None
):
Expand Down Expand Up @@ -1352,8 +1352,11 @@ def _new_process_group_helper(
pg._register_backend(torch.device(device), backend_type, backend_class)

# update global state
assert group_name is not None
_world.pg_map[pg] = (backend, prefix_store)
_world.pg_names[pg] = group_name
pg._set_group_name(group_name)

_world.pg_backend_config[pg] = str(backend_config)
# "" is the default tag for user PGs
if pg_tag in [None, ""]:
Expand Down Expand Up @@ -3948,7 +3951,7 @@ def _new_group_with_tag(
ranks,
backend,
default_store,
group_name=group_name,
group_name,
pg_options=pg_options,
timeout=timeout,
pg_tag=pg_tag
Expand Down

0 comments on commit 2286256

Please sign in to comment.