[WIP][NCCL] Support NCCL Send/Recv
Pull Request resolved: #44921

This diff adds support for Process Group point-to-point operations on the NCCL backend, based on ncclSend/ncclRecv. See #43995 for more context.
ghstack-source-id: 112545024

Differential Revision: [D23709848](https://our.internmc.facebook.com/intern/diff/D23709848/)
mingzhe0908 committed Sep 21, 2020
1 parent 24df3b7 commit 171c4e1
Showing 5 changed files with 473 additions and 46 deletions.
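
For orientation, a minimal sketch of the point-to-point pattern this change enables (a hedged example, assuming a two-process job with one CUDA device per rank, an NCCL group already created via init_process_group, and that batch_isend_irecv is re-exported under torch.distributed):

import torch
import torch.distributed as dist

# Assumed setup (not shown): dist.init_process_group("nccl", rank=rank, world_size=2)
rank = dist.get_rank()
device = torch.device(f"cuda:{rank}")
send_tensor = torch.arange(2, device=device) + 2 * rank
recv_tensor = torch.empty(2, dtype=send_tensor.dtype, device=device)
peer = 1 - rank  # the other rank in a two-process job

reqs = dist.batch_isend_irecv(op_list=[dist.isend, dist.irecv],
                              tensor_list=[send_tensor, recv_tensor],
                              peer_list=[peer, peer])
for req in reqs:
    req.wait()
# recv_tensor now holds the peer's send_tensor.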
6 changes: 6 additions & 0 deletions torch/csrc/distributed/c10d/init.cpp
@@ -713,6 +713,12 @@ They are used in specifying strategies for reduction collectives, e.g.,
.def(py::init<>())
.def_readwrite("is_high_priority", &::c10d::ProcessGroupNCCL::Options::isHighPriorityStream)
.def_readwrite("op_timeout", &::c10d::ProcessGroupNCCL::Options::opTimeout);
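// Static hooks that let Python bracket a batch of p2p ops so they are issued
// as one fused NCCL group call (presumably wrapping ncclGroupStart/ncclGroupEnd).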
processGroupNCCL.def_static("_group_start", []() {
::c10d::ProcessGroupNCCL::groupStart();
});
processGroupNCCL.def_static("_group_end", []() {
::c10d::ProcessGroupNCCL::groupEnd();
});
#endif

#ifdef USE_C10D_MPI
137 changes: 119 additions & 18 deletions torch/distributed/distributed_c10d.py
@@ -1,6 +1,7 @@
import pickle
import torch
import warnings
import contextlib
from torch._six import string_classes
from datetime import timedelta

@@ -159,8 +160,7 @@ class GroupMember(object):

def _rank_not_in_group(group):
"""
Helper that checks if the current process's rank is not in a given group
Helper that checks if the current process's rank is not in a given group.
"""
if group == GroupMember.WORLD:
return False
@@ -170,8 +170,7 @@ def _rank_not_in_group(group):
def _get_group_rank(group, rank):
"""
Helper that gets a given group's local rank in the group from a given global
rank
rank.
"""
if group is GroupMember.WORLD:
raise RuntimeError("group.WORLD does not have local rank to global "
@@ -188,8 +187,7 @@ def _get_group_rank(group, rank):
def _get_global_rank(group, group_rank):
"""
Helper that gets a given group's global rank from a given local rank in the
group
group.
"""
if group is GroupMember.WORLD:
raise RuntimeError("group.WORLD does not have local rank to global "
@@ -204,17 +202,15 @@ def _get_global_rank(group, group_rank):
def _check_default_pg():
"""
Helper that checks if the default ProcessGroup has been initialized, with
assertion
assertion.
"""
assert _default_pg is not None, \
"Default process group is not initialized"


def _get_group_size(group):
"""
Helper that gets a given group's world size
Helper that gets a given group's world size.
"""
if group is GroupMember.WORLD:
_check_default_pg()
@@ -227,7 +223,6 @@ def _get_group_size(group):
def _check_single_tensor(param, param_name):
"""
Helper to check that the parameter ``param_name`` is a single tensor.
"""
if not isinstance(param, torch.Tensor):
raise RuntimeError("Invalid function argument. Expected parameter `{}` "
@@ -237,50 +232,81 @@ def _check_single_tensor(param, param_name):
def _check_tensor_list(param, param_name):
"""
Helper to check that the parameter ``param_name`` is a list of tensors.
"""
if not isinstance(param, list) or \
not all(isinstance(p, torch.Tensor) for p in param):
raise RuntimeError("Invalid function argument. Expected parameter `{}` "
"to be of type List[torch.Tensor].".format(param_name))


def _check_op_list(op_list, list_name):
"""
Helper to check that ``op_list`` contains only ``isend``/``irecv`` functions.
"""
if not isinstance(op_list, list) or \
not all(op in [isend, irecv] for op in op_list):
raise RuntimeError(f"Invalid function. Each op in `{op_list}` is "
f"expected to be of type torch.distributed.isend or "
f"torch.distributed.irecv.")


def _check_input_length(op_list, tensor_list, peer_list, group_list, tag_list):
"""
Helper to check that all the inputs have the same length.
"""
expected_length = len(op_list)
check_list = [tensor_list, peer_list]
if group_list is not None:
check_list.append(group_list)
if tag_list is not None:
check_list.append(tag_list)

for curr_list in check_list:
if not isinstance(curr_list, list) or expected_length != len(curr_list):
raise RuntimeError(f"Expected parameters to be the same length, "
f"{expected_length} vs {len(curr_list)}, value: {curr_list}")


def _check_group_backend(group_list, backend_name):
"""
Helper to check that all groups in ``group_list`` use the same backend.
"""
if not isinstance(group_list, list) or \
not all(backend_name == get_backend(group) for group in group_list):
raise RuntimeError("All groups need to use the same backend.")
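
A quick illustration of how these validators are meant to behave (hypothetical calls; isend/irecv are the module-level p2p functions defined in this file):

t = torch.zeros(2)
_check_op_list([isend, irecv], "op_list")      # OK: both are supported p2p ops
_check_input_length([isend, irecv], [t, t],
                    [0, 1], None, None)        # OK: every list has length 2
_check_input_length([isend, irecv], [t],
                    [0, 1], None, None)        # raises RuntimeError: 2 vs 1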


def is_mpi_available():
"""
Checks if the MPI backend is available.
"""
return _MPI_AVAILABLE


def is_nccl_available():
"""
Checks if the NCCL backend is available.
"""
return _NCCL_AVAILABLE


def is_gloo_available():
"""
Checks if the Gloo backend is available.
"""
return _GLOO_AVAILABLE


def is_initialized():
"""
Checking if the default process group has been initialized
"""
return _default_pg is not None


def _get_default_group():
"""
Getting the default process group created by init_process_group
"""
if not is_initialized():
raise RuntimeError("Default process group has not been initialized, "
@@ -291,7 +317,6 @@ def _get_default_group():
def _get_default_store():
"""
Getting the default store created by init_process_group
"""
if not is_initialized():
raise RuntimeError("Default process group has not been initialized, "
@@ -753,6 +778,82 @@ def recv(tensor,
return src


@contextlib.contextmanager
def _batch_p2p_manager(backend):
if backend == Backend.NCCL:
ProcessGroupNCCL._group_start()
try:
yield
finally:
if backend == Backend.NCCL:
ProcessGroupNCCL._group_end()
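
The manager matters because NCCL batches every p2p call issued between _group_start() and _group_end() into a single fused group call, letting sends and recvs posted in any order progress together instead of deadlocking. A sketch of equivalent direct use (normally internal to batch_isend_irecv below; send_tensor, recv_tensor, and peer are assumed):

with _batch_p2p_manager(Backend.NCCL):
    # Both ops are queued inside the NCCL group; neither blocks the other.
    req_send = isend(send_tensor, dst=peer)
    req_recv = irecv(recv_tensor, src=peer)
# Exiting the manager calls _group_end(), which issues the fused group call.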


def batch_isend_irecv(op_list,
tensor_list,
peer_list,
group_list=None,
tag_list=None):
"""
Send or receive tensors asynchronously and return a list of requests.
Processes the first item of each passed list, then the second item of each,
and so on. Each list must be the same length as ``op_list``. The ``i``th
elements of ``tensor_list``, ``peer_list``, ``group_list``, and ``tag_list``
are, respectively, the communication tensor, peer rank, process group, and
tag for the ``i``th element of ``op_list``.
Arguments:
op_list: list of point-to-point operations (the type of each operation is
either ``torch.distributed.isend`` or ``torch.distributed.irecv``). The
order of the isend/irecv ops in the list matters and needs to match the
corresponding isend/irecv on the remote end.
tensor_list (list[Tensor]): list of send or recv tensors.
peer_list (list[int]): list of peer ranks to send to or receive from.
group_list (list[ProcessGroup], Optional): list of process groups to
operate on. All groups in ``group_list`` should use the same backend.
tag_list (list[int], Optional): list of tags to match send with recv.
Returns:
A list of distributed request objects returned by calling the corresponding
op in the op_list.
Examples:
>>> send_tensor = torch.arange(2) + 2 * rank
>>> recv_tensor = torch.empty(2, dtype=torch.int64)
>>> op_list = [dist.isend, dist.irecv]
>>> tensor_list = [send_tensor, recv_tensor]
>>> peer_list = [(rank + 1) % world_size, (rank + 1) % world_size]
>>> reqs = batch_isend_irecv(op_list, tensor_list, peer_list)
>>> for req in reqs:
>>> req.wait()
>>> recv_tensor
tensor([2, 3]) # Rank 0
tensor([0, 1]) # Rank 1
"""
_check_op_list(op_list, "op_list")
_check_input_length(op_list, tensor_list, peer_list, group_list, tag_list)

backend = get_backend(group.WORLD if group_list is None else group_list[0])
if group_list is not None:
_check_group_backend(group_list, backend)

reqs = []
with _batch_p2p_manager(backend):
for i in range(len(op_list)):
op = op_list[i]
tensor = tensor_list[i]
peer = peer_list[i]
curr_group = group.WORLD if group_list is None else group_list[i]
tag = 0 if tag_list is None else tag_list[i]

ret = op(tensor, peer, curr_group, tag)

if ret is not None:
reqs.append(ret)
return reqs
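
As a usage note, the same API extends past the two-rank docstring example; a hedged sketch of a ring exchange (assuming rank and world_size come from an initialized NCCL group, with one CUDA device per rank):

device = torch.device(f"cuda:{rank}")
send_t = torch.full((1,), float(rank), device=device)
recv_t = torch.empty(1, device=device)
right = (rank + 1) % world_size   # send to the right neighbor
left = (rank - 1) % world_size    # receive from the left neighbor

reqs = batch_isend_irecv([isend, irecv], [send_t, recv_t], [right, left])
for req in reqs:
    req.wait()
# recv_t now holds the left neighbor's rank.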


def broadcast_multigpu(tensor_list,
src,
group=group.WORLD,