[c10d] Figure out device to use for object collectives #100954

Closed
wants to merge 10 commits
5 changes: 2 additions & 3 deletions test/distributed/test_c10d_common.py
@@ -1610,8 +1610,8 @@ def test_backend_config(self):
# Ensure backend config can be created with the following arguments
backend_config_strings_and_expected_values = [
(dist.Backend.GLOO, "cpu:gloo,cuda:gloo"),
(dist.Backend.NCCL, "cpu:nccl,cuda:nccl"),
(dist.Backend.MPI, "cpu:mpi,cuda:mpi"),
(dist.Backend.NCCL, "cuda:nccl"),
(dist.Backend.MPI, "cpu:mpi"),
(dist.Backend.UCC, "cpu:ucc,cuda:ucc"),
(dist.Backend.DUMMY, "cpu:dummy,cuda:dummy"),
("DUMMY", "cpu:dummy,cuda:dummy"),
@@ -1620,7 +1620,6 @@ def test_backend_config(self):
("cpu:dummy,cuda:nccl", "cpu:dummy,cuda:nccl"),
("cpu:gloo,cuda:dummy", "cpu:gloo,cuda:dummy"),
("cpu:gloo,cuda:nccl", "cpu:gloo,cuda:nccl"),
("cPu:gLoO,cuDa:NcCl", "cpu:gloo,cuda:nccl")
]

for config_str, expected_value in backend_config_strings_and_expected_values:
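The updated expectations above encode the new per-backend device capabilities: `nccl` now maps only to `cuda`, and `mpi` only to `cpu`. Below is a minimal, self-contained sketch of the resolution the test expects; the `resolve` helper and its capability table are illustrative stand-ins, not the real `BackendConfig` parser.

```python
from typing import Dict

# Illustrative capability table mirroring this PR's assumptions:
# NCCL is cuda-only, MPI is cpu-only, GLOO/UCC support both.
CAPABILITY = {
    "gloo": ["cpu", "cuda"],
    "nccl": ["cuda"],
    "ucc": ["cpu", "cuda"],
    "mpi": ["cpu"],
}

def resolve(backend: str) -> Dict[str, str]:
    """Map a backend string to a {device: backend} dict, as the test expects."""
    if ":" in backend:
        # Explicit "device:backend" pairs, e.g. "cpu:gloo,cuda:nccl"
        return dict(pair.split(":") for pair in backend.lower().split(","))
    # Single backend name: expand it to every device it supports
    return {device: backend.lower() for device in CAPABILITY[backend.lower()]}

assert resolve("nccl") == {"cuda": "nccl"}
assert resolve("mpi") == {"cpu": "mpi"}
assert resolve("cpu:gloo,cuda:nccl") == {"cpu": "gloo", "cuda": "nccl"}
```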
14 changes: 0 additions & 14 deletions test/distributed/test_pg_wrapper.py
@@ -124,20 +124,6 @@ def _test_collectives_op_mismatch(self, wrapper_pg, use_cuda=False):
tensor=tensor,
)

with self.assertRaisesRegex(RuntimeError, ".*") as cm:
scatter_result = [torch.ones(4) * i for i in range(self.world_size)]
scattered_tensor = torch.empty(4)
if self.rank == 0:
wrapper_pg.scatter(scattered_tensor, scatter_result, 0)
else:
wrapper_pg.reduce_scatter(scattered_tensor, scatter_result)
self._validate_error(
exception=cm.exception,
op_type="SCATTER" if self.rank == 0 else "REDUCE_SCATTER",
rank=self.rank,
tensor=scattered_tensor,
)

with self.assertRaisesRegex(RuntimeError, ".*") as cm:
if self.rank == 0:
wrapper_pg.broadcast(tensor, 0)
16 changes: 16 additions & 0 deletions torch/csrc/distributed/c10d/ProcessGroup.hpp
@@ -548,7 +548,9 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
c10::DeviceType deviceType,
BackendType backendType,
const c10::optional<c10::intrusive_ptr<Backend>>& backend) {
// TODO: should we add these entries after the backend setting succeeds?
deviceTypeToBackendType_[deviceType] = backendType;
deviceTypes_.insert(deviceType);
// if the backendType is already set then reuse it for this device
if (backendTypeToBackend_.find(backendType) !=
backendTypeToBackend_.end()) {
@@ -585,6 +587,19 @@
return backendTypeToBackend_.at(backendType);
}

// Return device types supported by this ProcessGroup.
// Note: the return type is `Device` rather than `DeviceType` for ease of
// comparison at the Python level. Each `Device` carries the default index
// (-1).
std::vector<c10::Device> getDeviceTypes() const {
std::vector<c10::Device> devices;
devices.reserve(deviceTypes_.size());
for (auto& dt : deviceTypes_) {
devices.push_back(c10::Device(dt));
}
return devices;
}

protected:
// Implementations of this interface need to call this to setup
// appropriate logging etc.
@@ -603,6 +618,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
DebugLevel dist_debug_level_;

// Backend classes for this ProcessGroup
std::unordered_set<c10::DeviceType> deviceTypes_;
std::unordered_map<c10::DeviceType, BackendType> deviceTypeToBackendType_;
std::unordered_map<c10::DeviceType, c10::intrusive_ptr<Backend>>
deviceTypeToBackend_;
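Why `getDeviceTypes()` returns `c10::Device` values rather than raw `c10::DeviceType`: on the Python side each entry surfaces as a plain `torch.device` with no index, so membership checks reduce to simple equality. A small sketch of that comparison behavior (plain Python, not the C++ above):

```python
import torch

# Devices without an index compare equal to other index-less devices of the
# same type, which is what makes checks like `torch.device("cpu") in devices`
# work in _get_object_coll_device below.
devices = [torch.device("cpu"), torch.device("cuda")]
assert torch.device("cpu") in devices
assert torch.device("cuda", 0) not in devices  # an indexed device is a distinct value
```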
2 changes: 2 additions & 0 deletions torch/csrc/distributed/c10d/init.cpp
@@ -1603,6 +1603,8 @@ that adds a prefix to each key inserted to the store.
py::arg("timeout") = ::c10d::kUnsetTimeout,
py::arg("wait_all_ranks") = false,
py::call_guard<py::gil_scoped_release>())
.def_property_readonly(
"_device_types", &::c10d::ProcessGroup::getDeviceTypes)
.def(
"_get_backend_name",
&::c10d::ProcessGroup::getBackendName,
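A hedged usage sketch of the new binding: `_device_types` is a private, read-only property and may change, and this assumes a default process group was already initialized elsewhere (for example with a "cpu:gloo,cuda:nccl" backend string).

```python
import torch.distributed as dist

if dist.is_initialized():
    pg = dist.group.WORLD
    # e.g. [device(type='cpu'), device(type='cuda')] for "cpu:gloo,cuda:nccl"
    print(pg._device_types)
```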
131 changes: 113 additions & 18 deletions torch/distributed/distributed_c10d.py
@@ -204,6 +204,13 @@ class Backend:

backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI]

backend_capability: Dict[str, List[str]] = {
GLOO : ["cpu", "cuda"],
NCCL : ["cuda"],
UCC : ["cpu", "cuda"],
MPI : ["cpu"],
}

def __new__(cls, name: str):
if not isinstance(name, str):
raise ValueError(f"Backend name must be a string, but got: {name}")
@@ -214,7 +221,7 @@ def __new__(cls, name: str):
return value

@classmethod
def register_backend(cls, name, func, extended_api=False):
def register_backend(cls, name, func, extended_api=False, devices: Optional[Union[str, List[str]]] = None):
"""
Registers a new backend with the given name and instantiating function.

@@ -232,6 +239,9 @@ def register_backend(cls, name, func, extended_api=False):
Default: ``False``. If set to ``True``, the backend
will get an instance of ``c10d::DistributedBackendOptions``, and
a process group options object as defined by the backend implementation.
devices (str or list of str, optional): device types this backend
supports, e.g. "cpu", "cuda", etc. If `None`,
both "cpu" and "cuda" are assumed.

.. note:: This support of 3rd party backend is experimental and subject to change.
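A hedged sketch of registering a third-party backend with the new `devices` argument. `my_create_fn` and the backend name are hypothetical stand-ins for a plugin's construction function; only the `devices` keyword is the point here.

```python
import torch.distributed as dist

def my_create_fn(*args, **kwargs):
    # A real plugin would construct and return its backend/process group here.
    raise NotImplementedError

# Declare the backend as cpu-only so BackendConfig won't map it onto cuda.
dist.Backend.register_backend("my_cpu_backend", my_create_fn, devices="cpu")
print(dist.Backend.backend_capability["my_cpu_backend"])  # ["cpu"]
```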

@@ -248,28 +258,47 @@

setattr(Backend, name.upper(), name.upper())
Backend.backend_list.append(name.lower())

# Update device capability matrix in Backend class
if devices is None:
# This is mainly for backward compatibility with backends such as
# `threaded`: assume the default devices "cpu" and "cuda", but warn.
warnings.warn(
f"Device capability of {name} unspecified, assuming `cpu` and "
"`cuda`. Please specify it via the `devices` argument of "
"`register_backend`."
)
Backend.backend_capability[name.lower()] = ["cpu", "cuda"]
elif isinstance(devices, str):
# Single device string specified. Simply convert to list.
Backend.backend_capability[name.lower()] = [devices]
else:
Backend.backend_capability[name.lower()] = devices

Backend._plugins[name.upper()] = Backend._BackendPlugin(func, extended_api)

class BackendConfig:

def __init__(self, backend: Union[str, Backend]):
self.device_backend_map: Dict[torch.device, Backend] = {}

# Cases for when backend is a single string (without device types)
if backend == Backend.UNDEFINED:
# default config when backend is not specified
# supported since PyTorch 2.0
self.device_backend_map = {
"cpu": Backend.GLOO,
"cuda": Backend.NCCL,
}
elif backend.lower() in Backend.backend_list:
# backend applies to all devices (e.g. "NCCL", "GLOO", "UCC", "MPI", "custom_backend")
# Cases for when backend is a single string (without device types)
# e.g. "nccl", "gloo", "ucc", "mpi"
supported_devices = Backend.backend_capability[backend.lower()]
backend_val = Backend(backend)
self.device_backend_map = {
"cpu": backend_val,
"cuda": backend_val,
device : backend_val for device in supported_devices
}
else:
elif ":" in backend.lower():
# Backend specified in "device:backend" format
# make sure the backend string is in the correct format
# "{device_type1}:{backend1},{device_type2}:{backend2}"
# e.g. "cpu:gloo,cuda:nccl"
@@ -288,6 +317,24 @@ def __init__(self, backend: Union[str, Backend]):
raise ValueError(f"Duplicate device type {device} \
in backend string: {backend}. {backend_str_error_message}")
self.device_backend_map[device] = Backend(backend)
else:
# User specified a single backend name whose device capability is
# unknown; assume it supports the default devices of PyTorch
# (cpu and cuda)
warnings.warn(
f"Device capability of {backend} unknown, assuming `cpu` and "
"`cuda`. You can specify it in `device:backend` format in "
"`init_process_group` call."
)
backend_val = Backend(backend)
self.device_backend_map = {
"cpu" : backend_val,
"cuda" : backend_val,
}

logger.info(
f"Using backend config: {self.device_backend_map}" # noqa: G004
)

def __repr__(self):
# string with all the device:backend pairs separated by commas
@@ -406,6 +453,7 @@ class _World:
def __init__(self):
self._default_pg = None
self._pg_coalesce_state: Dict[ProcessGroup, List[Union[_CollOp, P2POp]]] = {}
self._pg_object_coll_device: Dict[ProcessGroup, torch.device] = {}

@property
def default_pg(self):
@@ -491,6 +539,10 @@ def pg_to_tag(self) -> Dict[ProcessGroup, str]:
def pg_coalesce_state(self) -> Dict[ProcessGroup, List[Union[_CollOp, P2POp]]]:
return self._pg_coalesce_state

@property
def pg_object_coll_device(self) -> Dict[ProcessGroup, torch.device]:
return self._pg_object_coll_device

_world = _World()
"""Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""

@@ -521,14 +573,54 @@ class GroupMember(metaclass=_WorldMeta):
STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key"


def _get_pg_device(group: ProcessGroup):
"""
Returns the device to use with ``group``.
This is cuda for NCCL and CPU for everything else
"""
if _check_for_nccl_backend(group):
return torch.device("cuda", torch.cuda.current_device())
return torch.device("cpu")
def _get_object_coll_device(group: Optional[ProcessGroup] = None):
group = group or _get_default_group()
if group in _world.pg_object_coll_device:
# Previously searched and cached; just return
return _world.pg_object_coll_device[group]

if not isinstance(group, ProcessGroup):
# Provide backward compatibility for cases where the `group` passed in is
# actually a Backend (like `ProcessGroupGloo`) rather than a
# `ProcessGroup` in the PyTorch 2.0 sense
warnings.warn(
f"You are using a Backend {type(group)} as a ProcessGroup. "
"This usage is deprecated since PyTorch 2.0. Please use a public API "
"of PyTorch Distributed instead."
)
# Most such users create a Gloo backend via private APIs for object collectives
_world.pg_object_coll_device[group] = torch.device("cpu")
return _world.pg_object_coll_device[group]

"""
``group._device_types`` is a property pybind that returns the devices
("cpu", "cuda", etc) supported by ``group``. Can be multiple if the
``group`` supports multiple devices.
"""
devices = group._device_types

if len(devices) == 1:
# The user specified exactly one backend in `init_process_group`
_world.pg_object_coll_device[group] = devices[0]
elif len(devices) == 0:
# No backend has been registered with this PG (maybe because no
# collective has been run yet?). Pick cpu as the default and hope that
# this lazily initializes Gloo or another available cpu backend.
_world.pg_object_coll_device[group] = torch.device("cpu")
elif torch.device("cpu") in devices:
# There are multiple backends in this PG and cpu is among them.
# cpu is preferred because the pickled object already lives in cpu
# memory, so no device copy is needed.
_world.pg_object_coll_device[group] = torch.device("cpu")
else:
# No cpu among the supported devices. Arbitrarily pick the first one.
_world.pg_object_coll_device[group] = devices[0]

logger.info(
f"Using device {_world.pg_object_coll_device[group]} for object " # noqa: G004
"collectives."
)
return _world.pg_object_coll_device[group]


# Environment variable to control whether we do a barrier after process group
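The selection order implemented above, summarized as a minimal standalone sketch that works on a plain list of devices rather than a real `ProcessGroup`:

```python
import torch

def pick_object_coll_device(devices):
    # Mirrors the priority order above: a single registered backend wins,
    # an empty list falls back to cpu, cpu is preferred when present
    # (objects are pickled into cpu memory), otherwise take the first entry.
    if len(devices) == 1:
        return devices[0]
    if len(devices) == 0:
        return torch.device("cpu")
    if torch.device("cpu") in devices:
        return torch.device("cpu")
    return devices[0]

assert pick_object_coll_device([torch.device("cuda")]) == torch.device("cuda")
assert pick_object_coll_device([]) == torch.device("cpu")
assert pick_object_coll_device([torch.device("cuda"), torch.device("cpu")]) == torch.device("cpu")
```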
@@ -1271,6 +1363,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
_world.pg_to_tag.clear()
_world.tags_to_pg.clear()
_world.pg_coalesce_state.clear()
_world.pg_object_coll_device.clear()

# when process group doesn't have an explicit name (only WORLD (default)
# process group can have an explicit name), we use global _world.group_count
@@ -1286,6 +1379,8 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
del _world.pg_names[pg]
del _world.pg_group_ranks[pg]
del _world.pg_backend_config[pg]
if pg in _world.pg_object_coll_device:
del _world.pg_object_coll_device[pg]
if pg in _world.pg_coalesce_state.keys():
warnings.warn(
"Some coalesced collectives haven't been launched when "
@@ -2236,7 +2331,7 @@ def all_gather_object(object_list, obj, group=None):
_warn_not_in_group("all_gather_object")
return

current_device = _get_pg_device(group)
current_device = _get_object_coll_device(group)
input_tensor, local_size = _object_to_tensor(obj, current_device)

# Gather all local sizes. This is so that we can find the max size, and index
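Usage sketch for the call site above, assuming a process group is already initialized: the staging device for the pickled objects is now chosen by `_get_object_coll_device` instead of the old NCCL-or-cpu rule.

```python
import torch.distributed as dist

if dist.is_initialized():
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, {"rank": dist.get_rank()})
    # gathered now holds one object per rank, e.g. [{'rank': 0}, {'rank': 1}, ...]
```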
@@ -2337,7 +2432,7 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
# Ensure object_gather_list is specified appropriately.
my_rank = get_rank()
_validate_output_list_for_rank(my_rank, dst, object_gather_list)
current_device = _get_pg_device(group)
current_device = _get_object_coll_device(group)
input_tensor, local_size = _object_to_tensor(obj, current_device)

# Gather all local sizes. This is so that we can find the max size, and index
@@ -2451,7 +2546,7 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
# ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the
# case it is not ``None`` we move the size and object tensors to be
# broadcasted to this device.
current_device = device or _get_pg_device(group)
current_device = device or _get_object_coll_device(group)
my_rank = get_rank()
# Serialize object_list elements to tensors on src rank.
if my_rank == src:
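For `broadcast_object_list` the explicit `device` argument still takes precedence; the group's object-collective device is only used when `device` is `None`. A short sketch, assuming an initialized group:

```python
import torch
import torch.distributed as dist

if dist.is_initialized():
    objs = ["payload"] if dist.get_rank() == 0 else [None]
    # Passing device explicitly bypasses _get_object_coll_device.
    dist.broadcast_object_list(objs, src=0, device=torch.device("cpu"))
```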
@@ -2556,7 +2651,7 @@ def scatter_object_list(
)

my_rank = get_rank()
pg_device = _get_pg_device(group)
pg_device = _get_object_coll_device(group)
if my_rank == src:
tensor_list, tensor_sizes = zip(
*[_object_to_tensor(obj, pg_device) for obj in scatter_object_input_list]
14 changes: 0 additions & 14 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -1488,20 +1488,6 @@ def test_batch_isend_irecv_gloo_tags(self):

self._barrier()

# NCCL Batch SEND RECV Tensor Error
@skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
@requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
def test_batch_isend_irecv_tensor_err(self):
self._barrier()
rank = dist.get_rank()
if rank == 0:
with self.assertRaisesRegex(
RuntimeError, "Tensors must be CUDA and dense"
):
send_tensor = _build_tensor(rank + 1)
send_op = dist.P2POp(dist.isend, send_tensor, 1)
dist.batch_isend_irecv([send_op])

# NCCL Batch SEND RECV Op Error
@skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
@requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
7 changes: 6 additions & 1 deletion torch/testing/_internal/distributed/multi_threaded_pg.py
@@ -381,14 +381,15 @@ class WorldData:
tags_to_pg: Dict[str, List[dist.ProcessGroup]]
pg_to_tag: Dict[dist.ProcessGroup, str]
pg_coalesce_state: Dict[dist.ProcessGroup, List[Union[_CollOp, P2POp]]]
pg_object_coll_device: Dict[dist.ProcessGroup, torch.device]


class ThreadLocalWorld:
_world = threading.local()

def _get_world(self) -> WorldData:
if not hasattr(ThreadLocalWorld._world, "world"):
ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {}, {})
ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {}, {}, {})
return ThreadLocalWorld._world.world

@property
@@ -435,6 +436,10 @@ def pg_to_tag(self):
def pg_coalesce_state(self) -> Dict[dist.ProcessGroup, List[Union[_CollOp, P2POp]]]:
return self._get_world().pg_coalesce_state

@property
def pg_object_coll_device(self) -> Dict[dist.ProcessGroup, torch.device]:
return self._get_world().pg_object_coll_device


_old_pg_world = None
