
[NCCL] Support NCCL Send/Recv #44921

Closed
wants to merge 14 commits
6 changes: 6 additions & 0 deletions torch/csrc/distributed/c10d/init.cpp
@@ -949,6 +949,12 @@ that adds a prefix to each key inserted to the store.
.def(py::init<>())
.def_readwrite("is_high_priority", &::c10d::ProcessGroupNCCL::Options::isHighPriorityStream)
.def_readwrite("op_timeout", &::c10d::ProcessGroupNCCL::Options::opTimeout);
processGroupNCCL.def_static("_group_start", []() {
::c10d::ProcessGroupNCCL::groupStart();
});
processGroupNCCL.def_static("_group_end", []() {
::c10d::ProcessGroupNCCL::groupEnd();
});
#endif

#ifdef USE_C10D_MPI
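The two static bindings above expose NCCL's group-call semantics to Python. A minimal sketch of how they can be driven, under the assumption that the build has USE_C10D_NCCL and that ``ProcessGroupNCCL`` is reachable from ``torch.distributed`` (the methods are assumed to forward to ncclGroupStart()/ncclGroupEnd(), as the C++ above suggests):

import torch.distributed as dist

if dist.is_nccl_available():
    dist.ProcessGroupNCCL._group_start()
    # NCCL send/recv ops issued here are batched rather than launched
    # one by one; nothing runs until the group is closed below.
    dist.ProcessGroupNCCL._group_end()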
4 changes: 4 additions & 0 deletions torch/distributed/__init__.py
@@ -25,3 +25,7 @@ def is_available():
# this.

from .distributed_c10d import _backend

# TODO: remove this once CI issue is resolved
# https://github.com/pytorch/pytorch/issues/42517
from .distributed_c10d import _P2POp, _batch_isend_irecv
137 changes: 119 additions & 18 deletions torch/distributed/distributed_c10d.py
@@ -1,6 +1,7 @@
import pickle
import torch
import warnings
import contextlib
from torch._six import string_classes
from datetime import timedelta

@@ -159,8 +160,7 @@ class GroupMember(object):

def _rank_not_in_group(group):
"""
Helper that checks if the current process's rank is not in a given group.
"""
if group == GroupMember.WORLD:
return False
@@ -170,8 +170,7 @@ def _rank_not_in_group(group):
def _get_group_rank(group, rank):
"""
Helper that gets a given group's local rank in the group from a given global
rank.
"""
if group is GroupMember.WORLD:
raise RuntimeError("group.WORLD does not have local rank to global "
@@ -188,8 +187,7 @@ def _get_group_rank(group, rank):
def _get_global_rank(group, group_rank):
"""
Helper that gets a given group's global rank from a given local rank in the
group.
"""
if group is GroupMember.WORLD:
raise RuntimeError("group.WORLD does not have local rank to global "
@@ -204,17 +202,15 @@ def _get_global_rank(group, group_rank):
def _check_default_pg():
"""
Helper that checks if the default ProcessGroup has been initialized, with
assertion.
"""
assert _default_pg is not None, \
"Default process group is not initialized"


def _get_group_size(group):
"""
Helper that gets a given group's world size.
"""
if group is GroupMember.WORLD:
_check_default_pg()
@@ -227,7 +223,6 @@ def _get_group_size(group):
def _check_single_tensor(param, param_name):
"""
Helper to check that the parameter ``param_name`` is a single tensor.
"""
if not isinstance(param, torch.Tensor):
raise RuntimeError("Invalid function argument. Expected parameter `{}` "
@@ -237,50 +232,69 @@ def _check_single_tensor(param, param_name):
def _check_tensor_list(param, param_name):
"""
Helper to check that the parameter ``param_name`` is a list of tensors.
"""
if not isinstance(param, list) or \
not all(isinstance(p, torch.Tensor) for p in param):
raise RuntimeError("Invalid function argument. Expected parameter `{}` "
"to be of type List[torch.Tensor].".format(param_name))


def _check_op(op):
"""
Helper to check that the ``op`` is either isend or irecv.
"""
if op not in [isend, irecv]:
raise RuntimeError("Invalid ``op``. Expected ``op`` "
"to be of type ``torch.distributed.isend`` or "
"``torch.distributed.irecv``.")

def _check_p2p_op_list(p2p_op_list):
"""
Helper to check that the ``p2p_op_list`` is a list of _P2POp instances and
all ops use the same backend.
"""
if not isinstance(p2p_op_list, list) or \
not all(isinstance(p2p_op, _P2POp) for p2p_op in p2p_op_list):
raise RuntimeError("Invalid ``p2p_op_list``. Each op is expected to "
"to be of type ``torch.distributed._P2POp``.")


backend = get_backend(p2p_op_list[0].group)
if not all(backend == get_backend(p2p_op.group) for p2p_op in p2p_op_list):
raise RuntimeError("All groups need to use the same backend.")


def is_mpi_available():
"""
Checks if the MPI backend is available.
"""
return _MPI_AVAILABLE


def is_nccl_available():
"""
Checks if the NCCL backend is available.
"""
return _NCCL_AVAILABLE


def is_gloo_available():
"""
Checks if the Gloo backend is available.
"""
return _GLOO_AVAILABLE


def is_initialized():
"""
Checks if the default process group has been initialized.
"""
return _default_pg is not None


def _get_default_group():
"""
Gets the default process group created by init_process_group.
"""
if not is_initialized():
raise RuntimeError("Default process group has not been initialized, "
@@ -291,7 +305,6 @@ def _get_default_group():
def _get_default_store():
"""
Gets the default store created by init_process_group.
"""
if not is_initialized():
raise RuntimeError("Default process group has not been initialized, "
@@ -757,6 +770,94 @@ def recv(tensor,
return src


class _P2POp(object):
"""
A class to build point-to-point operations for ``_batch_isend_irecv``.

This class holds the type of P2P operation, the communication buffer, the peer
rank, the process group, and the tag. Instances of this class are passed to
``_batch_isend_irecv`` for point-to-point communication.

Arguments:
op (callable): A function to send data to or receive data from a peer process.
The type of ``op`` is either ``torch.distributed.isend`` or
``torch.distributed.irecv``.
tensor (Tensor): Tensor to send or receive.
peer (int): Destination or source rank.
group (ProcessGroup, optional): The process group to work on.
tag (int, optional): Tag to match send with recv.
"""
def __init__(self, op, tensor, peer, group=group.WORLD, tag=0):
self.op = op
self.tensor = tensor
self.peer = peer
self.group = group
self.tag = tag

def __new__(cls, op, tensor, peer, group=group.WORLD, tag=0):
_check_op(op)
_check_single_tensor(tensor, "tensor")
return object.__new__(cls)
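
A note on the shape of the class above: validation happens in ``__new__``, so malformed arguments are rejected before any instance state is set. A quick sketch of both paths:

import torch
import torch.distributed as dist

buf = torch.zeros(2)
ok = dist._P2POp(dist.isend, buf, peer=1)      # accepted by _check_op

try:
    dist._P2POp(dist.all_reduce, buf, peer=1)  # rejected: op must be isend/irecv
except RuntimeError as err:
    print(err)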


@contextlib.contextmanager
def _batch_p2p_manager(backend):
if backend == Backend.NCCL:
ProcessGroupNCCL._group_start()
try:
yield
finally:
if backend == Backend.NCCL:
ProcessGroupNCCL._group_end()
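
This context manager is what makes batching possible on NCCL: point-to-point ops issued inside the ``with`` block are queued between the group-start and group-end calls and launched together, which helps avoid deadlocks when peers order their sends and receives differently; for other backends it is a no-op and each op runs eagerly. A rough illustration (``send_buf``, ``recv_buf``, and ``peer`` are placeholders):

# Inside the manager, NCCL defers the queued ops until the group closes.
with _batch_p2p_manager(Backend.NCCL):
    isend(send_buf, peer)   # queued, not yet launched
    irecv(recv_buf, peer)   # queued, not yet launched
# both operations are launched together once _group_end() has run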


def _batch_isend_irecv(p2p_op_list):
"""
Send or receive a batch of tensors asynchronously and return a list of requests.

Process each of the operations in ``p2p_op_list`` and return the corresponding
requests. The NCCL and Gloo backends are currently supported.

Arguments:
p2p_op_list: A list of point-to-point operations (the type of each operator is
``torch.distributed._P2POp``). The order of the isend/irecv in the list
matters, and it needs to match the corresponding isend/irecv on the
remote end.

Returns:
A list of distributed request objects returned by calling the corresponding
op in ``p2p_op_list``.

Examples:
>>> send_tensor = torch.arange(2) + 2 * rank
>>> recv_tensor = torch.randn(2)
>>> send_op = dist._P2POp(dist.isend, send_tensor, (rank + 1)%world_size)
>>> recv_op = dist._P2POp(dist.irecv, recv_tensor, (rank + 1)%world_size)
>>> reqs = _batch_isend_irecv([send_op, recv_op])
>>> for req in reqs:
>>> req.wait()
>>> recv_tensor
tensor([2, 3]) # Rank 0
tensor([0, 1]) # Rank 1
"""
_check_p2p_op_list(p2p_op_list)
backend = get_backend(p2p_op_list[0].group)
reqs = []
with _batch_p2p_manager(backend):
for p2p_op in p2p_op_list:
op = p2p_op.op
tensor = p2p_op.tensor
peer = p2p_op.peer
curr_group = p2p_op.group
tag = p2p_op.tag

ret = op(tensor, peer, curr_group, tag)
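# ret is None when this rank is not part of the op's group (isend/irecv
# return early in that case), so only real requests are collected.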

if ret is not None:
reqs.append(ret)
return reqs
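
For completeness, a fuller, self-contained variant of the docstring example that generalizes the two-rank exchange into a ring. This is a hedged sketch rather than part of the change: the Gloo backend, env:// rendezvous, and launch command are assumptions (run with e.g. ``python -m torch.distributed.launch --nproc_per_node=2 ring.py``).

# ring.py: each rank sends to its right neighbor, receives from its left.
import torch
import torch.distributed as dist

def ring_exchange():
    # env:// rendezvous reads RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT
    # from the environment set up by the launcher.
    dist.init_process_group(backend="gloo", init_method="env://")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    send_tensor = torch.arange(2) + 2 * rank
    recv_tensor = torch.zeros(2, dtype=torch.int64)  # dtype matches the send

    send_op = dist._P2POp(dist.isend, send_tensor, (rank + 1) % world_size)
    recv_op = dist._P2POp(dist.irecv, recv_tensor, (rank - 1) % world_size)

    # One request per op; wait on all of them before reading the buffer.
    reqs = dist._batch_isend_irecv([send_op, recv_op])
    for req in reqs:
        req.wait()
    print("rank {} got {} from rank {}".format(
        rank, recv_tensor.tolist(), (rank - 1) % world_size))

if __name__ == "__main__":
    ring_exchange()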


def broadcast_multigpu(tensor_list,
src,
group=group.WORLD,