
[NCCL] Support NCCL Send/Recv #44921

Closed
wants to merge 14 commits into from
6 changes: 6 additions & 0 deletions torch/csrc/distributed/c10d/init.cpp
@@ -720,6 +720,12 @@ They are used in specifying strategies for reduction collectives, e.g.,
.def(py::init<>())
.def_readwrite("is_high_priority", &::c10d::ProcessGroupNCCL::Options::isHighPriorityStream)
.def_readwrite("op_timeout", &::c10d::ProcessGroupNCCL::Options::opTimeout);
processGroupNCCL.def_static("_group_start", []() {
::c10d::ProcessGroupNCCL::groupStart();
});
processGroupNCCL.def_static("_group_end", []() {
::c10d::ProcessGroupNCCL::groupEnd();
});
#endif

#ifdef USE_C10D_MPI
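The two static bindings added above expose NCCL group semantics to Python: point-to-point calls issued between _group_start() and _group_end() are grouped by NCCL and launched together. A minimal sketch of how they are meant to be paired, assuming an NCCL-backed process group has already been initialized and that the class is reachable as torch.distributed.ProcessGroupNCCL (as it is in NCCL-enabled builds):

import torch.distributed as dist

# Open an NCCL group, issue the point-to-point ops, and always close the
# group, even if one of the ops raises.
dist.ProcessGroupNCCL._group_start()
try:
    # dist.isend(...) / dist.irecv(...) calls would go here
    pass
finally:
    dist.ProcessGroupNCCL._group_end()

The _batch_p2p_manager context manager added to distributed_c10d.py below wraps exactly this pairing.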
137 changes: 119 additions & 18 deletions torch/distributed/distributed_c10d.py
@@ -1,6 +1,7 @@
import pickle
import torch
import warnings
import contextlib
from torch._six import string_classes
from datetime import timedelta

@@ -159,8 +160,7 @@ class GroupMember(object):

def _rank_not_in_group(group):
"""
Helper that checks if the current process's rank is not in a given group.
"""
if group == GroupMember.WORLD:
return False
@@ -170,8 +170,7 @@ def _rank_not_in_group(group):
def _get_group_rank(group, rank):
"""
Helper that gets a given group's local rank in the group from a given global
rank.
"""
if group is GroupMember.WORLD:
raise RuntimeError("group.WORLD does not have local rank to global "
@@ -188,8 +187,7 @@ def _get_global_rank(group, group_rank):
def _get_global_rank(group, group_rank):
"""
Helper that gets a given group's global rank from a given local rank in the
group.
"""
if group is GroupMember.WORLD:
raise RuntimeError("group.WORLD does not have local rank to global "
@@ -204,17 +202,15 @@ def _get_global_rank(group, group_rank):
def _check_default_pg():
"""
Helper that checks if the default ProcessGroup has been initialized, with
assertion.
"""
assert _default_pg is not None, \
"Default process group is not initialized"


def _get_group_size(group):
"""
Helper that gets a given group's world size.
"""
if group is GroupMember.WORLD:
_check_default_pg()
@@ -227,7 +223,6 @@ def _get_group_size(group):
def _check_single_tensor(param, param_name):
"""
Helper to check that the parameter ``param_name`` is a single tensor.

"""
if not isinstance(param, torch.Tensor):
raise RuntimeError("Invalid function argument. Expected parameter `{}` "
@@ -237,50 +232,81 @@ def _check_single_tensor(param, param_name):
def _check_tensor_list(param, param_name):
"""
Helper to check that the parameter ``param_name`` is a list of tensors.

"""
if not isinstance(param, list) or \
not all(isinstance(p, torch.Tensor) for p in param):
raise RuntimeError("Invalid function argument. Expected parameter `{}` "
"to be of type List[torch.Tensor].".format(param_name))


def _check_op_list(op_list, list_name):
"""
Helper to check that the ``op_list`` is a list of functions.
"""
if not isinstance(op_list, list) or \
not all(op in [isend, irecv] for op in op_list):
raise RuntimeError(f"Invalid function. Expected fucntion in `{op_list}` "
f"to be of type torch.distributed.isend or "
f"torch.distributed.irecv.")


def _check_input_length(op_list, tensor_list, peer_list, group_list, tag_list):
"""
Helper to check that all the inputs have the same length.
"""
expected_length = len(op_list)
check_list = [tensor_list, peer_list]
if group_list is not None:
check_list.append(group_list)
if tag_list is not None:
check_list.append(tag_list)

for curr_list in check_list:
if not isinstance(curr_list, list) or expected_length != len(curr_list):
raise RuntimeError(f"Expected parameters to be the same length, "
f"{expected_length} vs {len(curr_list)}, value: {curr_list}")


def _check_group_backend(group_list, backend_name):
"""
Helper to check that all groups in ``group_list`` use the same backend.
"""
if not isinstance(group_list, list) or \
not all(backend_name == get_backend(group) for group in group_list):
raise RuntimeError("All groups need to use the same backend.")


def is_mpi_available():
"""
Checks if the MPI backend is available.

"""
return _MPI_AVAILABLE


def is_nccl_available():
"""
Checks if the NCCL backend is available.

"""
return _NCCL_AVAILABLE


def is_gloo_available():
"""
Checks if the Gloo backend is available.

"""
return _GLOO_AVAILABLE


def is_initialized():
"""
Checking if the default process group has been initialized

"""
return _default_pg is not None


def _get_default_group():
"""
Getting the default process group created by init_process_group

"""
if not is_initialized():
raise RuntimeError("Default process group has not been initialized, "
@@ -291,7 +317,6 @@ def _get_default_group():
def _get_default_store():
"""
Getting the default store created by init_process_group

"""
if not is_initialized():
raise RuntimeError("Default process group has not been initialized, "
@@ -753,6 +778,82 @@ def recv(tensor,
return src


@contextlib.contextmanager
def _batch_p2p_manager(backend):
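"""
Context manager that opens an NCCL group before the enclosed point-to-point
ops and closes it afterwards; it is a no-op for all other backends.
"""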
if backend == Backend.NCCL:
ProcessGroupNCCL._group_start()
try:
yield
finally:
if backend == Backend.NCCL:
ProcessGroupNCCL._group_end()


def batch_isend_irecv(op_list,
tensor_list,
peer_list,
group_list=None,
tag_list=None):
"""
Send or Receive tensors asynchronously and return a list of requests.

Processes the first item of each passed list, then the second item of each,
and so on. Each of these lists should be the same length as ``op_list``. The
``i``-th elements of ``tensor_list``, ``peer_list``, ``group_list``, and
``tag_list`` are, respectively, the tensor to communicate, the peer rank, the
process group, and the tag used for the ``i``-th operation in ``op_list``.

Arguments:
op_list: list of point-to-point operations (the type of each operation is
either ``torch.distributed.isend`` or ``torch.distributed.irecv``). The
order of the isend/irecv operations in the list matters, and it needs to
match the corresponding isend/irecv on the remote end.
tensor_list (list[Tensor]): list of send or recv tensors.
peer_list (list[int]): list of peer ranks to send to or receive from.
group_list (list[ProcessGroup], Optional): list of process groups to operate on. All
groups in ``group_list`` should use the same backend.
tag_list (list[int], Optional): list of tags to match send with recv.

Returns:
A list of distributed request objects returned by calling the corresponding
op in the op_list.

Examples:
>>> send_tensor = torch.arange(2) + 2 * rank
>>> recv_tensor = torch.randn(2)
>>> op_list = [dist.isend, dist.irecv]
>>> tensor_list = [send_tensor, recv_tensor]
>>> peer_list = [(rank + 1) % world_size, (rank + 1) % world_size]
>>> reqs = dist.batch_isend_irecv(op_list, tensor_list, peer_list)
>>> for req in reqs:
>>> req.wait()
>>> recv_tensor
tensor([2, 3]) # Rank 0
tensor([0, 1]) # Rank 1
"""
_check_op_list(op_list, "op_list")
_check_input_length(op_list, tensor_list, peer_list, group_list, tag_list)

backend = get_backend(group.WORLD if group_list is None else group_list[0])
if group_list is not None:
_check_group_backend(group_list, backend)

reqs = []
with _batch_p2p_manager(backend):
for i in range(len(op_list)):
op = op_list[i]
tensor = tensor_list[i]
peer = peer_list[i]
curr_group = group.WORLD if group_list is None else group_list[i]
tag = 0 if tag_list is None else tag_list[i]

ret = op(tensor, peer, curr_group, tag)

if ret is not None:
reqs.append(ret)
return reqs
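
For context, a self-contained script exercising the new API end to end might look like the sketch below. This is not part of the PR: it spawns two processes on a single machine and uses the gloo backend so it runs without GPUs (with NCCL, _batch_p2p_manager additionally brackets the ops with _group_start/_group_end). The MASTER_ADDR/MASTER_PORT values are arbitrary, and it assumes batch_isend_irecv is exported from torch.distributed once this change lands.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank, world_size):
    # Arbitrary rendezvous settings for a single-machine demo.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    send_tensor = torch.arange(2) + 2 * rank
    recv_tensor = torch.zeros(2, dtype=torch.int64)
    peer = (rank + 1) % world_size

    # One isend and one irecv per rank, issued as a single batch.
    reqs = dist.batch_isend_irecv([dist.isend, dist.irecv],
                                  [send_tensor, recv_tensor],
                                  [peer, peer])
    for req in reqs:
        req.wait()
    print(f"rank {rank} received {recv_tensor.tolist()}")  # rank 0: [2, 3], rank 1: [0, 1]

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)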


def broadcast_multigpu(tensor_list,
src,
group=group.WORLD,