[NCCL] Support NCCL Send/Recv (#44921)
Summary:
Pull Request resolved: #44921

This diff adds support for Process Group point-to-point operations on NCCL backend based on ncclSend/ncclRecv. See #43995 for more context.
ghstack-source-id: 113592785

Test Plan: unittest

Reviewed By: jiayisuse

Differential Revision: D23709848

fbshipit-source-id: cdf38050379ecbb10450f3394631317b41163258
mingzhe09088 authored and facebook-github-bot committed Oct 6, 2020
1 parent b04ae95 commit 59083d6
Showing 6 changed files with 527 additions and 46 deletions.
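Illustrative usage (not part of the commit): a minimal ring exchange built on the API this diff adds. The file name, launcher invocation, and device mapping below are assumptions for the sketch; the NCCL backend additionally requires CUDA tensors and an NCCL version that provides ncclSend/ncclRecv (2.7+).

    # ring_p2p.py (hypothetical file name). Launch one process per GPU with any
    # launcher that sets RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT.
    import torch
    import torch.distributed as dist

    def main():
        dist.init_process_group(backend="nccl")        # env:// rendezvous
        rank, world_size = dist.get_rank(), dist.get_world_size()
        torch.cuda.set_device(rank % torch.cuda.device_count())
        device = torch.device("cuda")

        send_tensor = (torch.arange(2, dtype=torch.float32) + 2 * rank).to(device)
        recv_tensor = torch.empty(2, dtype=torch.float32, device=device)

        # Send to the right neighbour, receive from the left one.
        ops = [
            dist._P2POp(dist.isend, send_tensor, (rank + 1) % world_size),
            dist._P2POp(dist.irecv, recv_tensor, (rank - 1) % world_size),
        ]
        for req in dist._batch_isend_irecv(ops):
            req.wait()

        print("rank", rank, "received", recv_tensor.tolist())

    if __name__ == "__main__":
        main()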
6 changes: 6 additions & 0 deletions torch/csrc/distributed/c10d/init.cpp
@@ -949,6 +949,12 @@ that adds a prefix to each key inserted to the store.
.def(py::init<>())
.def_readwrite("is_high_priority", &::c10d::ProcessGroupNCCL::Options::isHighPriorityStream)
.def_readwrite("op_timeout", &::c10d::ProcessGroupNCCL::Options::opTimeout);
processGroupNCCL.def_static("_group_start", []() {
::c10d::ProcessGroupNCCL::groupStart();
});
processGroupNCCL.def_static("_group_end", []() {
::c10d::ProcessGroupNCCL::groupEnd();
});
#endif

#ifdef USE_C10D_MPI
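Illustrative note (not part of the diff): the two static bindings above expose NCCL's group bracket to Python. ncclSend/ncclRecv calls that must progress concurrently have to be fused between ncclGroupStart() and ncclGroupEnd(), and these bindings are what the _batch_p2p_manager helper further down uses to do that. A hedged sketch of calling them directly, assuming an NCCL-enabled build and an already initialized NCCL process group with CUDA tensors:

    import torch.distributed as dist
    from torch.distributed import ProcessGroupNCCL   # available in NCCL-enabled builds

    def fused_exchange(send_buf, recv_buf, peer, tag=0):
        ProcessGroupNCCL._group_start()              # maps to ncclGroupStart()
        try:
            reqs = [dist.isend(send_buf, peer, tag=tag),
                    dist.irecv(recv_buf, peer, tag=tag)]
        finally:
            ProcessGroupNCCL._group_end()            # maps to ncclGroupEnd()
        for req in reqs:
            req.wait()                               # wait only after the group is closed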
4 changes: 4 additions & 0 deletions torch/distributed/__init__.py
@@ -25,3 +25,7 @@ def is_available():
# this.

from .distributed_c10d import _backend

# TODO: remove this once CI issue is resolved
# https://github.com/pytorch/pytorch/issues/42517
from .distributed_c10d import _P2POp, _batch_isend_irecv
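With this re-export in place, the new helpers are reachable from the package root; an illustrative sketch (not part of the diff):

    import torch.distributed as dist

    dist._P2POp              # describes a single isend/irecv
    dist._batch_isend_irecv  # issues a list of _P2POp and returns request handles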
137 changes: 119 additions & 18 deletions torch/distributed/distributed_c10d.py
@@ -1,6 +1,7 @@
import pickle
import torch
import warnings
import contextlib
from torch._six import string_classes
from datetime import timedelta

@@ -159,8 +160,7 @@ class GroupMember(object):

def _rank_not_in_group(group):
"""
-    Helper that checks if the current process's rank is not in a given group
+    Helper that checks if the current process's rank is not in a given group.
"""
if group == GroupMember.WORLD:
return False
@@ -170,8 +170,7 @@ def _rank_not_in_group(group):
def _get_group_rank(group, rank):
"""
Helper that gets a given group's local rank in the group from a given global
-    rank
+    rank.
"""
if group is GroupMember.WORLD:
raise RuntimeError("group.WORLD does not have local rank to global "
@@ -188,8 +187,7 @@ def _get_group_rank(group, rank):
def _get_global_rank(group, group_rank):
"""
Helper that gets a given group's global rank from a given local rank in the
-    group
+    group.
"""
if group is GroupMember.WORLD:
raise RuntimeError("group.WORLD does not have local rank to global "
@@ -204,17 +202,15 @@ def _get_global_rank(group, group_rank):
def _check_default_pg():
"""
Helper that checks if the default ProcessGroup has been initialized, with
-    assertion
+    assertion.
"""
assert _default_pg is not None, \
"Default process group is not initialized"


def _get_group_size(group):
"""
-    Helper that gets a given group's world size
+    Helper that gets a given group's world size.
"""
if group is GroupMember.WORLD:
_check_default_pg()
@@ -227,7 +223,6 @@ def _get_group_size(group):
def _check_single_tensor(param, param_name):
"""
Helper to check that the parameter ``param_name`` is a single tensor.
"""
if not isinstance(param, torch.Tensor):
raise RuntimeError("Invalid function argument. Expected parameter `{}` "
@@ -237,50 +232,69 @@ def _check_single_tensor(param, param_name):
def _check_tensor_list(param, param_name):
"""
Helper to check that the parameter ``param_name`` is a list of tensors.
"""
if not isinstance(param, list) or \
not all(isinstance(p, torch.Tensor) for p in param):
raise RuntimeError("Invalid function argument. Expected parameter `{}` "
"to be of type List[torch.Tensor].".format(param_name))


def _check_op(op):
"""
Helper to check that the ``op`` is either isend or irecv.
"""
if op not in [isend, irecv]:
raise RuntimeError("Invalid ``op``. Expected ``op`` "
"to be of type ``torch.distributed.isend`` or "
"``torch.distributed.irecv``.")

def _check_p2p_op_list(p2p_op_list):
"""
Helper to check that the ``p2p_op_list`` is a list of _P2POp instances and
all ops use the same backend.
"""
if not isinstance(p2p_op_list, list) or \
not all(isinstance(p2p_op, _P2POp) for p2p_op in p2p_op_list):
raise RuntimeError("Invalid ``p2p_op_list``. Each op is expected to "
"to be of type ``torch.distributed._P2POp``.")


backend = get_backend(p2p_op_list[0].group)
if not all(backend == get_backend(p2p_op.group) for p2p_op in p2p_op_list):
raise RuntimeError("All groups need to use the same backend.")


def is_mpi_available():
"""
Checks if the MPI backend is available.
"""
return _MPI_AVAILABLE


def is_nccl_available():
"""
Checks if the NCCL backend is available.
"""
return _NCCL_AVAILABLE


def is_gloo_available():
"""
Checks if the Gloo backend is available.
"""
return _GLOO_AVAILABLE


def is_initialized():
"""
Checking if the default process group has been initialized
"""
return _default_pg is not None


def _get_default_group():
"""
Getting the default process group created by init_process_group
"""
if not is_initialized():
raise RuntimeError("Default process group has not been initialized, "
@@ -291,7 +305,6 @@ def _get_default_group():
def _get_default_store():
"""
Getting the default store created by init_process_group
"""
if not is_initialized():
raise RuntimeError("Default process group has not been initialized, "
@@ -757,6 +770,94 @@ def recv(tensor,
return src


class _P2POp(object):
"""
A class to build point-to-point operations for ``_batch_isend_irecv``.
This class builds the type of P2P operation, communication buffer, peer rank,
process group, and tag. Instances of this class will be passed to
``_batch_isend_irecv`` for point-to-point communications.
Arguments:
op (callable): A function to send data to or receive data from a peer process.
The type of ``op`` is either ``torch.distributed.isend`` or
``torch.distributed.irecv``.
tensor (Tensor): Tensor to send or receive.
peer (int): Destination or source rank.
group (ProcessGroup, optional): The process group to work on.
tag (int, optional): Tag to match send with recv.
"""
def __init__(self, op, tensor, peer, group=group.WORLD, tag=0):
self.op = op
self.tensor = tensor
self.peer = peer
self.group = group
self.tag = tag

def __new__(cls, op, tensor, peer, group=group.WORLD, tag=0):
_check_op(op)
_check_single_tensor(tensor, "tensor")
return object.__new__(cls)


@contextlib.contextmanager
def _batch_p2p_manager(backend):
if backend == Backend.NCCL:
ProcessGroupNCCL._group_start()
try:
yield
finally:
if backend == Backend.NCCL:
ProcessGroupNCCL._group_end()


def _batch_isend_irecv(p2p_op_list):
"""
Send or receive a batch of tensors asynchronously and return a list of requests.
Processes each of the operations in ``p2p_op_list`` and returns the corresponding
requests. The NCCL and Gloo backends are currently supported.
Arguments:
p2p_op_list: A list of point-to-point operations (the type of each operation is
    ``torch.distributed._P2POp``). The order of the isend/irecv in the list
    matters and it needs to match the corresponding isend/irecv on the
    remote end.
Returns:
A list of distributed request objects returned by calling the corresponding
op in ``p2p_op_list``.
Examples:
>>> send_tensor = torch.arange(2, dtype=torch.float32) + 2 * rank
>>> recv_tensor = torch.randn(2)
>>> send_op = dist._P2POp(dist.isend, send_tensor, (rank + 1) % world_size)
>>> recv_op = dist._P2POp(dist.irecv, recv_tensor, (rank + 1) % world_size)
>>> reqs = _batch_isend_irecv([send_op, recv_op])
>>> for req in reqs:
>>>     req.wait()
>>> recv_tensor
tensor([2., 3.])  # Rank 0
tensor([0., 1.])  # Rank 1
"""
_check_p2p_op_list(p2p_op_list)
backend = get_backend(p2p_op_list[0].group)
reqs = []
with _batch_p2p_manager(backend):
for p2p_op in p2p_op_list:
op = p2p_op.op
tensor = p2p_op.tensor
peer = p2p_op.peer
curr_group = p2p_op.group
tag = p2p_op.tag

ret = op(tensor, peer, curr_group, tag)

if ret is not None:
reqs.append(ret)
return reqs


def broadcast_multigpu(tensor_list,
src,
group=group.WORLD,
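One illustrative note on why the batching matters on the NCCL backend (a sketch, not a claim about this diff's tests): NCCL requires send/recv operations that must progress concurrently to be fused inside an ncclGroupStart/ncclGroupEnd section. Because _batch_isend_irecv wraps its loop in _batch_p2p_manager, both ranks in the sketch below can list their send before their recv even though the two pairs cross; rank, send_tensor, and recv_tensor are assumed to be set up as in the docstring example, with CUDA tensors and world_size == 2.

    peer = 1 - rank                                  # the other rank in a 2-rank job
    ops = [
        dist._P2POp(dist.isend, send_tensor, peer),  # both ranks enqueue the send first...
        dist._P2POp(dist.irecv, recv_tensor, peer),  # ...and the matching recv second
    ]
    for req in dist._batch_isend_irecv(ops):         # the group bracket lets NCCL progress both
        req.wait()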
