
[c10d] Figure out device to use for object collectives (#100954)
Fixes #97938

This PR is a clone of #100238, which is important to me, but @kwen2501 has not yet
resolved its merge conflict, so this PR is submitted to resolve it.
The only conflict is at `distributed_c10d.py:2653`.
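
For context: with this change, object collectives derive their staging device from the devices actually registered with the process group, instead of assuming CUDA whenever NCCL is in use. A minimal standalone sketch of the decision rule (mirroring the `_get_object_coll_device` logic added below; not the exact implementation):

```python
import torch

def pick_object_coll_device(devices):
    # `devices` stands in for the new `group._device_types` property.
    if len(devices) == 1:
        # User fixed exactly one backend in `init_process_group`.
        return devices[0]
    if len(devices) == 0:
        # Nothing registered yet; default to cpu so a cpu backend
        # (e.g. Gloo) can initialize lazily.
        return torch.device("cpu")
    if torch.device("cpu") in devices:
        # Objects are pickled into cpu memory, so prefer cpu and
        # skip a device copy.
        return torch.device("cpu")
    # No cpu backend available; arbitrarily take the first device.
    return devices[0]

# A "cpu:gloo,cuda:nccl" group stages object collectives on cpu:
assert pick_object_coll_device(
    [torch.device("cpu"), torch.device("cuda")]
) == torch.device("cpu")
```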

Pull Request resolved: #100954
Approved by: https://github.com/kwen2501
kwen2501 authored and pytorchmergebot committed May 11, 2023
1 parent a0e6ae2 commit 0848ed2
Showing 7 changed files with 139 additions and 50 deletions.
5 changes: 2 additions & 3 deletions test/distributed/test_c10d_common.py
@@ -1610,8 +1610,8 @@ def test_backend_config(self):
# Ensure backend config can be created with the following arguments
backend_config_strings_and_expected_values = [
(dist.Backend.GLOO, "cpu:gloo,cuda:gloo"),
(dist.Backend.NCCL, "cpu:nccl,cuda:nccl"),
(dist.Backend.MPI, "cpu:mpi,cuda:mpi"),
(dist.Backend.NCCL, "cuda:nccl"),
(dist.Backend.MPI, "cpu:mpi"),
(dist.Backend.UCC, "cpu:ucc,cuda:ucc"),
(dist.Backend.DUMMY, "cpu:dummy,cuda:dummy"),
("DUMMY", "cpu:dummy,cuda:dummy"),
@@ -1620,7 +1620,6 @@ def test_backend_config(self):
("cpu:dummy,cuda:nccl", "cpu:dummy,cuda:nccl"),
("cpu:gloo,cuda:dummy", "cpu:gloo,cuda:dummy"),
("cpu:gloo,cuda:nccl", "cpu:gloo,cuda:nccl"),
("cPu:gLoO,cuDa:NcCl", "cpu:gloo,cuda:nccl")
]

for config_str, expected_value in backend_config_strings_and_expected_values:
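The updated expectations above are easy to reproduce; a quick sketch, assuming `BackendConfig` is reachable through `torch.distributed` the same way the test reaches it through `dist`:

```python
import torch.distributed as dist

# NCCL no longer claims a cpu entry, and MPI no longer claims cuda:
print(str(dist.BackendConfig(dist.Backend.NCCL)))     # cuda:nccl
print(str(dist.BackendConfig(dist.Backend.MPI)))      # cpu:mpi
print(str(dist.BackendConfig("cpu:gloo,cuda:nccl")))  # cpu:gloo,cuda:nccl
```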
14 changes: 0 additions & 14 deletions test/distributed/test_pg_wrapper.py
@@ -126,20 +126,6 @@ def _test_collectives_op_mismatch(self, wrapper_pg, use_cuda=False):
tensor=tensor,
)

with self.assertRaisesRegex(RuntimeError, ".*") as cm:
scatter_result = [torch.ones(4) * i for i in range(self.world_size)]
scattered_tensor = torch.empty(4)
if self.rank == 0:
wrapper_pg.scatter(scattered_tensor, scatter_result, 0)
else:
wrapper_pg.reduce_scatter(scattered_tensor, scatter_result)
self._validate_error(
exception=cm.exception,
op_type="SCATTER" if self.rank == 0 else "REDUCE_SCATTER",
rank=self.rank,
tensor=scattered_tensor,
)

with self.assertRaisesRegex(RuntimeError, ".*") as cm:
if self.rank == 0:
wrapper_pg.broadcast(tensor, 0)
16 changes: 16 additions & 0 deletions torch/csrc/distributed/c10d/ProcessGroup.hpp
@@ -548,7 +548,9 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
c10::DeviceType deviceType,
BackendType backendType,
const c10::optional<c10::intrusive_ptr<Backend>>& backend) {
// TODO: should we add these entries after the backend setting succeeds?
deviceTypeToBackendType_[deviceType] = backendType;
deviceTypes_.insert(deviceType);
// if the backendType is already set then reuse it for this device
if (backendTypeToBackend_.find(backendType) !=
backendTypeToBackend_.end()) {
@@ -585,6 +587,19 @@
return backendTypeToBackend_.at(backendType);
}

// Return device types supported by this ProcessGroup.
// Note: the return type is `Device` rather than `DeviceType` for the purpose
// of easy comparison at Python level. The `Device` will have default index
// (-1).
std::vector<c10::Device> getDeviceTypes() const {
std::vector<c10::Device> devices;
devices.reserve(deviceTypes_.size());
for (auto& dt : deviceTypes_) {
devices.push_back(c10::Device(dt));
}
return devices;
}

protected:
// Implementations of this interface need to call this to setup
// appropriate logging etc.
@@ -603,6 +618,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
DebugLevel dist_debug_level_;

// Backend classes for this ProcessGroup
std::unordered_set<c10::DeviceType> deviceTypes_;
std::unordered_map<c10::DeviceType, BackendType> deviceTypeToBackendType_;
std::unordered_map<c10::DeviceType, c10::intrusive_ptr<Backend>>
deviceTypeToBackend_;
2 changes: 2 additions & 0 deletions torch/csrc/distributed/c10d/init.cpp
@@ -1603,6 +1603,8 @@ that adds a prefix to each key inserted to the store.
py::arg("timeout") = ::c10d::kUnsetTimeout,
py::arg("wait_all_ranks") = false,
py::call_guard<py::gil_scoped_release>())
.def_property_readonly(
"_device_types", &::c10d::ProcessGroup::getDeviceTypes)
.def(
"_get_backend_name",
&::c10d::ProcessGroup::getBackendName,
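On the Python side the binding surfaces as a `_device_types` property. An illustrative (hypothetical) session, which assumes an already-initialized process group:

```python
import torch
import torch.distributed as dist

# e.g. after dist.init_process_group("cpu:gloo,cuda:nccl", ...)
pg = dist.group.WORLD
print(pg._device_types)
# [device(type='cpu'), device(type='cuda')]  (order not guaranteed)

# Returning Device (with default index -1) rather than DeviceType makes
# comparison against torch.device(...) trivial from Python:
assert torch.device("cpu") in pg._device_types
```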
131 changes: 113 additions & 18 deletions torch/distributed/distributed_c10d.py
@@ -204,6 +204,13 @@ class Backend:

backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI]

backend_capability: Dict[str, List[str]] = {
GLOO : ["cpu", "cuda"],
NCCL : ["cuda"],
UCC : ["cpu", "cuda"],
MPI : ["cpu"],
}

def __new__(cls, name: str):
if not isinstance(name, str):
raise ValueError(f"Backend name must be a string, but got: {name}")
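
Since the capability matrix is a plain class-level dict, the new defaults are easy to inspect (illustrative):

```python
import torch.distributed as dist

print(dist.Backend.backend_capability["nccl"])  # ['cuda']
print(dist.Backend.backend_capability["mpi"])   # ['cpu']
print(dist.Backend.backend_capability["gloo"])  # ['cpu', 'cuda']
```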
@@ -214,7 +221,7 @@ def __new__(cls, name: str):
return value

@classmethod
def register_backend(cls, name, func, extended_api=False):
def register_backend(cls, name, func, extended_api=False, devices: Optional[Union[str, List[str]]] = None):
"""
Registers a new backend with the given name and instantiating function.
@@ -232,6 +239,9 @@ def register_backend(cls, name, func, extended_api=False):
Default: ``False``. If set to ``True``, the backend
will get an instance of ``c10d::DistributedBackendOptions``, and
a process group options object as defined by the backend implementation.
                       devices (str or list of str, optional): device type(s) this backend
                supports, e.g. "cpu", "cuda", etc. If `None`,
                both "cpu" and "cuda" are assumed.
.. note:: This support of 3rd party backend is experimental and subject to change.
@@ -248,28 +258,47 @@ def register_backend(cls, name, func, extended_api=False):

setattr(Backend, name.upper(), name.upper())
Backend.backend_list.append(name.lower())

# Update device capability matrix in Backend class
if devices is None:
            # This is mainly for backward compatibility with backends like
            # `threaded`: assume the default devices "cpu" and "cuda", but warn.
warnings.warn(
f"Device capability of {name} unspecified, assuming `cpu` and "
"`cuda`. Please specify it via the `devices` argument of "
"`register_backend`."
)
Backend.backend_capability[name.lower()] = ["cpu", "cuda"]
elif isinstance(devices, str):
# Single device string specified. Simply convert to list.
Backend.backend_capability[name.lower()] = [devices]
else:
Backend.backend_capability[name.lower()] = devices

Backend._plugins[name.upper()] = Backend._BackendPlugin(func, extended_api)

class BackendConfig:

def __init__(self, backend: Union[str, Backend]):
self.device_backend_map: Dict[torch.device, Backend] = {}

# Cases for when backend is a single string (without device types)
if backend == Backend.UNDEFINED:
# default config when backend is not specified
# supported since PyTorch 2.0
self.device_backend_map = {
"cpu": Backend.GLOO,
"cuda": Backend.NCCL,
}
elif backend.lower() in Backend.backend_list:
# backend applies to all devices (e.g. "NCCL", "GLOO", "UCC", "MPI", "custom_backend")
# Cases for when backend is a single string (without device types)
# e.g. "nccl", "gloo", "ucc", "mpi"
supported_devices = Backend.backend_capability[backend.lower()]
backend_val = Backend(backend)
self.device_backend_map = {
"cpu": backend_val,
"cuda": backend_val,
device : backend_val for device in supported_devices
}
else:
elif ":" in backend.lower():
# Backend specified in "device:backend" format
# make sure the backend string is in the correct format
# "{device_type1}:{backend1},{device_type2}:{backend2}"
# e.g. "cpu:gloo,cuda:nccl"
@@ -288,6 +317,24 @@ def __init__(self, backend: Union[str, Backend]):
raise ValueError(f"Duplicate device type {device} \
in backend string: {backend}. {backend_str_error_message}")
self.device_backend_map[device] = Backend(backend)
else:
# User specified a single backend name whose device capability is
            # unknown; assume it can support the default devices of PyTorch
# (cpu and cuda)
warnings.warn(
f"Device capability of {backend} unknown, assuming `cpu` and "
"`cuda`. You can specify it in `device:backend` format in "
"`init_process_group` call."
)
backend_val = Backend(backend)
self.device_backend_map = {
"cpu" : backend_val,
"cuda" : backend_val,
}

logger.info(
f"Using backend config: {self.device_backend_map}" # noqa: G004
)

def __repr__(self):
# string with all the device:backend pairs separated by commas
@@ -406,6 +453,7 @@ class _World:
def __init__(self):
self._default_pg = None
self._pg_coalesce_state: Dict[ProcessGroup, List[Union[_CollOp, P2POp]]] = {}
self._pg_object_coll_device: Dict[ProcessGroup, torch.device] = {}

@property
def default_pg(self):
@@ -491,6 +539,10 @@ def pg_to_tag(self) -> Dict[ProcessGroup, str]:
def pg_coalesce_state(self) -> Dict[ProcessGroup, List[Union[_CollOp, P2POp]]]:
return self._pg_coalesce_state

@property
def pg_object_coll_device(self) -> Dict[ProcessGroup, torch.device]:
return self._pg_object_coll_device

_world = _World()
"""Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""

@@ -521,14 +573,54 @@ class GroupMember(metaclass=_WorldMeta):
STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key"


def _get_pg_device(group: ProcessGroup):
"""
Returns the device to use with ``group``.
This is cuda for NCCL and CPU for everything else
"""
if _check_for_nccl_backend(group):
return torch.device("cuda", torch.cuda.current_device())
return torch.device("cpu")
def _get_object_coll_device(group: Optional[ProcessGroup] = None):
group = group or _get_default_group()
if group in _world.pg_object_coll_device:
# Previously searched and cached; just return
return _world.pg_object_coll_device[group]

if not isinstance(group, ProcessGroup):
# Provide backward compatibility to cases where `group` passed in is
# actually a Backend (like `ProcessGroupGloo`) rather than a
# `ProcessGroup` in PT 2.0 sense
warnings.warn(
f"You are using a Backend {type(group)} as a ProcessGroup. "
"This usage is deprecated since PyTorch 2.0. Please use a public API "
"of PyTorch Distributed instead."
)
# Most users create Gloo with private API for object collectives
_world.pg_object_coll_device[group] = torch.device("cpu")
return _world.pg_object_coll_device[group]

"""
``group._device_types`` is a property pybind that returns the devices
("cpu", "cuda", etc) supported by ``group``. Can be multiple if the
``group`` supports multiple devices.
"""
devices = group._device_types

if len(devices) == 1:
# User fixed exactly one backend in `init_process_group`
_world.pg_object_coll_device[group] = devices[0]
elif len(devices) == 0:
        # No backend has been registered with this PG (maybe because no
        # collective has been run yet?). Pick cpu as the default, hoping that
        # this lazily initializes Gloo or another available cpu backend.
_world.pg_object_coll_device[group] = torch.device("cpu")
elif torch.device("cpu") in devices:
# There are multiple backends in this PG and cpu is among them.
# cpu is preferred as the object is in cpu memory. No need for device
# copy.
_world.pg_object_coll_device[group] = torch.device("cpu")
else:
        # No cpu among the supported devices; arbitrarily pick the first one.
_world.pg_object_coll_device[group] = devices[0]

logger.info(
f"Using device {_world.pg_object_coll_device[group]} for object " # noqa: G004
"collectives."
)
return _world.pg_object_coll_device[group]


# Environment variable to control whether we do a barrier after process group
@@ -1271,6 +1363,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
_world.pg_to_tag.clear()
_world.tags_to_pg.clear()
_world.pg_coalesce_state.clear()
_world.pg_object_coll_device.clear()

# when process group doesn't have an explicit name (only WORLD (default)
# process group can have an explicit name), we use global _world.group_count
@@ -1286,6 +1379,8 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
del _world.pg_names[pg]
del _world.pg_group_ranks[pg]
del _world.pg_backend_config[pg]
if pg in _world.pg_object_coll_device:
del _world.pg_object_coll_device[pg]
if pg in _world.pg_coalesce_state.keys():
warnings.warn(
"Some coalesced collectives haven't been launched when "
@@ -2236,7 +2331,7 @@ def all_gather_object(object_list, obj, group=None):
_warn_not_in_group("all_gather_object")
return

current_device = _get_pg_device(group)
current_device = _get_object_coll_device(group)
input_tensor, local_size = _object_to_tensor(obj, current_device)

# Gather all local sizes. This is so that we can find the max size, and index
@@ -2337,7 +2432,7 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
# Ensure object_gather_list is specified appropriately.
my_rank = get_rank()
_validate_output_list_for_rank(my_rank, dst, object_gather_list)
current_device = _get_pg_device(group)
current_device = _get_object_coll_device(group)
input_tensor, local_size = _object_to_tensor(obj, current_device)

# Gather all local sizes. This is so that we can find the max size, and index
@@ -2451,7 +2546,7 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
# ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the
# case it is not ``None`` we move the size and object tensors to be
# broadcasted to this device.
current_device = device or _get_pg_device(group)
current_device = device or _get_object_coll_device(group)
my_rank = get_rank()
# Serialize object_list elements to tensors on src rank.
if my_rank == src:
@@ -2556,7 +2651,7 @@ def scatter_object_list(
)

my_rank = get_rank()
pg_device = _get_pg_device(group)
pg_device = _get_object_coll_device(group)
if my_rank == src:
tensor_list, tensor_sizes = zip(
*[_object_to_tensor(obj, pg_device) for obj in scatter_object_input_list]
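Third-party backends can now declare their device capability up front via the new `devices` argument to `register_backend`; leaving it unset still works but emits a warning. A hedged sketch (`make_fancy_pg` is a hypothetical factory, not part of this PR):

```python
import torch.distributed as dist

def make_fancy_pg(*args, **kwargs):
    # Hypothetical: a real factory constructs and returns a backend here.
    raise NotImplementedError

# Declare that "fancy" supports cpu only; object collectives on a
# "fancy"-backed group will then stage their tensors on cpu.
dist.Backend.register_backend("fancy", make_fancy_pg, devices="cpu")
```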
14 changes: 0 additions & 14 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -1488,20 +1488,6 @@ def test_batch_isend_irecv_gloo_tags(self):

self._barrier()

# NCCL Batch SEND RECV Tensor Error
@skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
@requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
def test_batch_isend_irecv_tensor_err(self):
self._barrier()
rank = dist.get_rank()
if rank == 0:
with self.assertRaisesRegex(
RuntimeError, "Tensors must be CUDA and dense"
):
send_tensor = _build_tensor(rank + 1)
send_op = dist.P2POp(dist.isend, send_tensor, 1)
dist.batch_isend_irecv([send_op])

# NCCL Batch SEND RECV Op Error
@skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
@requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
7 changes: 6 additions & 1 deletion torch/testing/_internal/distributed/multi_threaded_pg.py
@@ -381,14 +381,15 @@ class WorldData:
tags_to_pg: Dict[str, List[dist.ProcessGroup]]
pg_to_tag: Dict[dist.ProcessGroup, str]
pg_coalesce_state: Dict[dist.ProcessGroup, List[Union[_CollOp, P2POp]]]
pg_object_coll_device: Dict[dist.ProcessGroup, torch.device]


class ThreadLocalWorld:
_world = threading.local()

def _get_world(self) -> WorldData:
if not hasattr(ThreadLocalWorld._world, "world"):
ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {}, {})
ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {}, {}, {})
return ThreadLocalWorld._world.world

@property
Expand Down Expand Up @@ -435,6 +436,10 @@ def pg_to_tag(self):
def pg_coalesce_state(self) -> Dict[dist.ProcessGroup, List[Union[_CollOp, P2POp]]]:
return self._get_world().pg_coalesce_state

@property
def pg_object_coll_device(self) -> Dict[dist.ProcessGroup, torch.device]:
return self._get_world().pg_object_coll_device


_old_pg_world = None

