diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index aff2da31c133..0e7ae8e2200f 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -720,6 +720,12 @@ They are used in specifying strategies for reduction collectives, e.g., .def(py::init<>()) .def_readwrite("is_high_priority", &::c10d::ProcessGroupNCCL::Options::isHighPriorityStream) .def_readwrite("op_timeout", &::c10d::ProcessGroupNCCL::Options::opTimeout); + processGroupNCCL.def_static("_group_start", []() { + ::c10d::ProcessGroupNCCL::groupStart(); + }); + processGroupNCCL.def_static("_group_end", []() { + ::c10d::ProcessGroupNCCL::groupEnd(); + }); #endif #ifdef USE_C10D_MPI diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index c7d66f322bb1..5fdf8e9169a9 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1,6 +1,7 @@ import pickle import torch import warnings +import contextlib from torch._six import string_classes from datetime import timedelta @@ -159,8 +160,7 @@ class GroupMember(object): def _rank_not_in_group(group): """ - Helper that checks if the current process's rank is not in a given group - + Helper that checks if the current process's rank is not in a given group. """ if group == GroupMember.WORLD: return False @@ -170,8 +170,7 @@ def _rank_not_in_group(group): def _get_group_rank(group, rank): """ Helper that gets a given group's local rank in the group from a given global - rank - + rank. """ if group is GroupMember.WORLD: raise RuntimeError("group.WORLD does not have local rank to global " @@ -188,8 +187,7 @@ def _get_group_rank(group, rank): def _get_global_rank(group, group_rank): """ Helper that gets a given group's global rank from a given local rank in the - group - + group. """ if group is GroupMember.WORLD: raise RuntimeError("group.WORLD does not have local rank to global " @@ -204,8 +202,7 @@ def _get_global_rank(group, group_rank): def _check_default_pg(): """ Helper that checks if the default ProcessGroup has been initialized, with - assertion - + assertion. """ assert _default_pg is not None, \ "Default process group is not initialized" @@ -213,8 +210,7 @@ def _check_default_pg(): def _get_group_size(group): """ - Helper that gets a given group's world size - + Helper that gets a given group's world size. """ if group is GroupMember.WORLD: _check_default_pg() @@ -227,7 +223,6 @@ def _get_group_size(group): def _check_single_tensor(param, param_name): """ Helper to check that the parameter ``param_name`` is a single tensor. - """ if not isinstance(param, torch.Tensor): raise RuntimeError("Invalid function argument. Expected parameter `{}` " @@ -237,7 +232,6 @@ def _check_single_tensor(param, param_name): def _check_tensor_list(param, param_name): """ Helper to check that the parameter ``param_name`` is a list of tensors. - """ if not isinstance(param, list) or \ not all(isinstance(p, torch.Tensor) for p in param): @@ -245,10 +239,34 @@ def _check_tensor_list(param, param_name): "to be of type List[torch.Tensor].".format(param_name)) +def _check_op(op): + """ + Helper to check that the ``op`` is either isend or irecv. + """ + if op not in [isend, irecv]: + raise RuntimeError("Invalid ``op``. 
Expected ``op`` " + "to be of type ``torch.distributed.isend`` or " + "``torch.distributed.irecv``.") + +def _check_p2p_op_list(p2p_op_list): + """ + Helper to check that the ``p2p_op_list`` is a list of P2POp instances and + all ops use the same backend. + """ + if not isinstance(p2p_op_list, list) or \ + not all(isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list): + raise RuntimeError("Invalid ``p2p_op_list``. Each op is expected to " + "be of type ``torch.distributed.P2POp``.") + + + backend = get_backend(p2p_op_list[0].group) + if not all(backend == get_backend(p2p_op.group) for p2p_op in p2p_op_list): + raise RuntimeError("All groups need to use the same backend.") + + def is_mpi_available(): """ Checks if the MPI backend is available. - """ return _MPI_AVAILABLE @@ -256,7 +274,6 @@ def is_mpi_available(): def is_nccl_available(): """ Checks if the NCCL backend is available. - """ return _NCCL_AVAILABLE @@ -264,7 +281,6 @@ def is_nccl_available(): def is_gloo_available(): """ Checks if the Gloo backend is available. - """ return _GLOO_AVAILABLE @@ -272,7 +288,6 @@ def is_gloo_available(): def is_initialized(): """ Checking if the default process group has been initialized - """ return _default_pg is not None @@ -280,7 +295,6 @@ def is_initialized(): def _get_default_group(): """ Getting the default process group created by init_process_group - """ if not is_initialized(): raise RuntimeError("Default process group has not been initialized, " @@ -291,7 +305,6 @@ def _get_default_group(): def _get_default_store(): """ Getting the default store created by init_process_group - """ if not is_initialized(): raise RuntimeError("Default process group has not been initialized, " @@ -753,6 +766,94 @@ def recv(tensor, return src +class P2POp(object): + """ + A class to build point-to-point operations for ``batch_isend_irecv``. + + This class builds the type of P2P operation, communication buffer, peer rank, + process group, and tag. Instances of this class will be passed to + ``batch_isend_irecv`` for point-to-point communications. + + Arguments: + op (callable): A function to send data to or receive data from a peer process. + The type of ``op`` is either ``torch.distributed.isend`` or + ``torch.distributed.irecv``. + tensor (Tensor): Tensor to send or receive. + peer (int): Destination or source rank. + group (ProcessGroup, optional): The process group to work on. + tag (int, optional): Tag to match send with recv. + """ + def __init__(self, op, tensor, peer, group=group.WORLD, tag=0): + self.op = op + self.tensor = tensor + self.peer = peer + self.group = group + self.tag = tag + + def __new__(cls, op, tensor, peer, group=group.WORLD, tag=0): + _check_op(op) + _check_single_tensor(tensor, "tensor") + return object.__new__(cls) + + +@contextlib.contextmanager +def _batch_p2p_manager(backend): + if backend == Backend.NCCL: + ProcessGroupNCCL._group_start() + try: + yield + finally: + if backend == Backend.NCCL: + ProcessGroupNCCL._group_end() + + +def batch_isend_irecv(p2p_op_list): + """ + Send or receive a batch of tensors asynchronously and return a list of requests. + + Process each of the operations in ``p2p_op_list`` and return the corresponding + requests. The NCCL and Gloo backends are currently supported. + + Arguments: + p2p_op_list: A list of point-to-point operations (the type of each operation is + ``torch.distributed.P2POp``). The order of the isend/irecv in the list + matters and it needs to match the corresponding isend/irecv on the + remote end.
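+            When the ops use the NCCL backend, the whole list is issued inside a
+            single NCCL group (``_group_start``/``_group_end`` via the internal
+            ``_batch_p2p_manager``), so the sends and receives are dispatched
+            together rather than one at a time.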
+ + Returns: + A list of distributed request objects returned by calling the corresponding + op in the op_list. + + Examples: + >>> send_tensor = torch.arange(2) + 2 * rank + >>> recv_tensor = torch.randn(2) + >>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1)%world_size) + >>> recv_op = dist.P2POp(dist.irecv, recv_tensor, (rank + 1)%world_size) + >>> reqs = batch_isend_irecv([send_op, recv_op]) + >>> for req in reqs: + >>> req.wait() + >>> recv_tensor + tensor([2, 3]) # Rank 0 + tensor([0, 1]) # Rank 1 + """ + _check_p2p_op_list(p2p_op_list) + backend = get_backend(p2p_op_list[0].group) + reqs = [] + with _batch_p2p_manager(backend): + for p2p_op in p2p_op_list: + op = p2p_op.op + tensor = p2p_op.tensor + peer = p2p_op.peer + curr_group = p2p_op.group + tag = p2p_op.tag + + ret = op(tensor, peer, curr_group, tag) + + if ret is not None: + reqs.append(ret) + return reqs + + def broadcast_multigpu(tensor_list, src, group=group.WORLD, diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index a3765670c6b2..d9b0cd75b08b 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -232,6 +232,7 @@ const int64_t ProcessGroupNCCL::kWorkCleanupThreadSleepMillis = 1000; constexpr int64_t kWaitForAbortCommStoreKey = 1000; constexpr int64_t kSynchronizeBusyWaitMillis = 10; const int64_t ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis = 10 * 1000; +thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0; ProcessGroupNCCL::WorkNCCL::WorkNCCL(const std::vector& devices) : devices_(devices), workStartTime_(std::chrono::steady_clock::now()) { @@ -756,7 +757,26 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( std::vector streamVal; streamVal.reserve(devices.size()); - // Create the NCCL communicators for each GPU + // [Group Start/End Note] This is used to ensure that nccl communicator will be created + // before communication primitives are called. Let's look at this example: + // Using the batch_isend_irecv to send a tensor to a target process. On the sender side, + // the corresponding underlying NCCL calls will look like + // ncclGroupStart() // This is in batch_isend_irecv + // ncclGroupStart() // This is [Note 1] + // ncclCommInitRank() // Inside NCCLComm::create + // ncclSend() + // ncclGroupEnd() // This is [Note 2] + // ncclGroupEnd() // This is in batch_isend_irecv + // With this pattern, the nccl communicator will be created in the last ncclGroupEnd + // which means when ncclSend is processed, the passed communicator argument is NULL which will + // lead to runtime error. So we need to "close" all active nccl groups to ensure + // nccl communicator is actually created before encountering any communication calls. + // This is why we need the following for loop. 
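+  // For example, when getNCCLComm() is reached from batch_isend_irecv,
+  // ncclActiveGroupCounter_ is 1, so the loop below issues one ncclGroupEnd()
+  // to flush the user's open group, the communicators are then created eagerly,
+  // and the matching loop after [Note 2] re-opens one group so the caller's
+  // nesting depth is left unchanged.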
+ for (size_t i = 0; i < ncclActiveGroupCounter_; ++i) { + C10D_NCCL_CHECK(ncclGroupEnd()); + } + + // [Note 1] Create the NCCL communicators for each GPU C10D_NCCL_CHECK(ncclGroupStart()); for (size_t i = 0; i < devices.size(); ++i) { @@ -781,8 +801,14 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( } } + // [Note 2 ] C10D_NCCL_CHECK(ncclGroupEnd()); + // See [Group Start/End Note] + for (size_t i = 0; i < ncclActiveGroupCounter_; ++i) { + C10D_NCCL_CHECK(ncclGroupStart()); + } + ncclStreams_.emplace(devicesKey, std::move(streamVal)); // Note: these events are created with the (default) cudaEventDisableTiming @@ -1006,6 +1032,72 @@ std::shared_ptr ProcessGroupNCCL::collective( return work; } +template +std::shared_ptr ProcessGroupNCCL::pointToPoint( + std::vector& tensors, + Fn fn, + bool isRecv, + PreProcess pre, + PostProcess post) { + const auto devices = getDeviceList(tensors); + const auto key = getKeyFromDevices(devices); + auto& ncclComms = getNCCLComm(key, devices); + + // First let NCCL streams wait for input tensors allocation streams + syncStreams(devices, ncclEvents_[key], ncclStreams_[key]); + + // Work itself will create the CUDA events on all GPUs of tensors + auto work = initWork(devices); + + if (isRecv) { + // Store references to outputs and futureNCCLCallbackStream to be used by + // WorkNCCL::getFuture. + work->outputs_ = std::make_shared>(tensors); + work->futureNCCLCallbackStreams_ = futureNCCLCallbackStreams_; + } + + at::cuda::OptionalCUDAGuard gpuGuard; + + pre(ncclStreams_[key]); + + for (size_t i = 0; i < tensors.size(); ++i) { + gpuGuard.set_index(devices[i].index()); + at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i]; + + // Both send tensor and recv tensor are created on a worker stream and used in + // different ncclStreams. Hence, both must record the ncclStream to + // prevent being freed before the collective finishes. + // + // See [Sync Streams]. 
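+    // Unlike the collective() helper above, there is only a single tensor list
+    // here (the send or recv buffers), so only that list has to be recorded on
+    // the NCCL stream.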
+ c10::cuda::CUDACachingAllocator::recordStream( + tensors[i].storage().data_ptr(), ncclStream); + } + + { + AutoNcclGroup nccl_group_guard; + for (size_t i = 0; i < tensors.size(); ++i) { + gpuGuard.set_index(devices[i].index()); + at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i]; + C10D_NCCL_CHECK( + fn(tensors[i], ncclComms[i]->getNcclComm(), ncclStream)); + } + } + + post(ncclStreams_[key]); + + // Event should only be recorded after the ncclGroupEnd() + for (size_t i = 0; i < tensors.size(); ++i) { + at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i]; + (*work->cudaEvents_)[i].record(ncclStream); + work->ncclComms_[i] = ncclComms[i]; + work->blockingWait_ = blockingWait_; + work->opTimeout_ = opTimeout_; + work->store_ = store_; + } + + return work; +} + template std::shared_ptr ProcessGroupNCCL::collective( std::vector& inputs, @@ -1019,6 +1111,19 @@ std::shared_ptr ProcessGroupNCCL::collective( [](std::vector&) {}); } +template +std::shared_ptr ProcessGroupNCCL::pointToPoint( + std::vector& tensor, + Fn fn, + bool isRecv) { + return pointToPoint( + tensor, + fn, + isRecv, + [](std::vector&) {}, + [](std::vector&) {}); +} + std::shared_ptr ProcessGroupNCCL::allreduce( std::vector& tensors, const AllreduceOptions& opts) { @@ -1294,6 +1399,50 @@ std::shared_ptr ProcessGroupNCCL::alltoall_base( }); } } + +std::shared_ptr ProcessGroupNCCL::send( + std::vector& tensors, + int dstRank, + int /* unused */) { + check_gpu_tensors(tensors); + auto ret = pointToPoint( + tensors, + [&](at::Tensor& input, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + return ncclSend( + input.data_ptr(), + input.numel(), + getNcclDataType(input.scalar_type()), + dstRank, + comm, + stream.stream()); + }, + /* isRecv */ false); + return ret; +} + +std::shared_ptr ProcessGroupNCCL::recv( + std::vector& tensors, + int srcRank, + int /* unused */) { + check_gpu_tensors(tensors); + auto ret= pointToPoint( + tensors, + [&](at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream) { + return ncclRecv( + output.data_ptr(), + output.numel(), + getNcclDataType(output.scalar_type()), + srcRank, + comm, + stream.stream()); + }, + /* isRecv */ true); + return ret; +} #else std::shared_ptr ProcessGroupNCCL::alltoall_base( at::Tensor& /* unused */, @@ -1304,7 +1453,37 @@ std::shared_ptr ProcessGroupNCCL::alltoall_base( throw std::runtime_error( "ProcessGroupNCCL only supports alltoall* for NCCL lib version >= 2.7.0"); } + +std::shared_ptr ProcessGroupNCCL::send( + std::vector& /* unused */, + int /* unused */, + int /* unused */) { + throw std::runtime_error( + "ProcessGroupNCCL only supports send for NCCL lib version >= 2.7.0"); +} + +std::shared_ptr ProcessGroupNCCL::recv( + std::vector& /* unused */, + int /* unused */, + int /* unused */) { + throw std::runtime_error( + "ProcessGroupNCCL only supports recv for NCCL lib version >= 2.7.0"); +} +#endif + +void ProcessGroupNCCL::groupStart() { +#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2) + C10D_NCCL_CHECK(ncclGroupStart()); +#endif + ++ncclActiveGroupCounter_; +} + +void ProcessGroupNCCL::groupEnd() { +#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2) + C10D_NCCL_CHECK(ncclGroupEnd()); #endif + --ncclActiveGroupCounter_; +} std::shared_ptr ProcessGroupNCCL::alltoall( std::vector& /* unused */, @@ -1327,24 +1506,10 @@ std::shared_ptr ProcessGroupNCCL::scatter( throw std::runtime_error("ProcessGroupNCCL does not support scatter"); } -std::shared_ptr ProcessGroupNCCL::send( - std::vector& /* unused */, - int /* unused */, - int /* unused */) { - throw 
std::runtime_error("ProcessGroupNCCL does not support send"); -} - -std::shared_ptr ProcessGroupNCCL::recv( - std::vector& /* unused */, - int /* unused */, - int /* unused */) { - throw std::runtime_error("ProcessGroupNCCL does not support recv"); -} - std::shared_ptr ProcessGroupNCCL::recvAnysource( std::vector& /* unused */, int /* unused */) { - throw std::runtime_error("ProcessGroupNCCL does not support recv"); + throw std::runtime_error("ProcessGroupNCCL does not support recvAnysource"); } std::shared_ptr ProcessGroupNCCL::allgather_base( diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index 5a5e5a718ad8..d69e002cec52 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -415,6 +415,20 @@ class ProcessGroupNCCL : public ProcessGroup { std::vector& inputTensors, const AllToAllOptions& opts = AllToAllOptions()) override; + std::shared_ptr send( + std::vector& tensors, + int dstRank, + int tag) override; + + std::shared_ptr recv( + std::vector& tensors, + int srcRank, + int tag) override; + + static void groupStart(); + + static void groupEnd(); + // Unsupported Ops std::shared_ptr gather( std::vector>& outputTensors, @@ -426,16 +440,6 @@ class ProcessGroupNCCL : public ProcessGroup { std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) override; - std::shared_ptr send( - std::vector& tensors, - int dstRank, - int tag) override; - - std::shared_ptr recv( - std::vector& tensors, - int srcRank, - int tag) override; - std::shared_ptr recvAnysource( std::vector& tensors, int tag) override; @@ -479,6 +483,22 @@ class ProcessGroupNCCL : public ProcessGroup { PreProcess pre, PostProcess post); + // Helper that encapsulates work shared across point-to-point communication + // primitives. It is the same structure as the helper used for collective + // communication primitives. + template + std::shared_ptr pointToPoint( + std::vector& tensor, + Fn fn, + bool isRecv); + template + std::shared_ptr pointToPoint( + std::vector& tensor, + Fn fn, + bool isRecv, + PreProcess pre, + PostProcess post); + // Checks for NCCL errors on each of the communicators and returns an // appropriate exception_ptr (nullptr if no errors). static std::exception_ptr checkForNCCLErrorsInternal( @@ -634,6 +654,11 @@ class ProcessGroupNCCL : public ProcessGroup { // Schedule NCCL operations on high priority CUDA streams. bool isHighPriorityStream_ = false; + + // The number of active ncclGroupStart() calls. This counter will be increased + // by 1 when ncclGroupStart() is called and decreased by 1 when ncclGroupEnd() + // is called.
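+  // It is thread_local because ncclGroupStart()/ncclGroupEnd() only affect the
+  // calling thread, and getNCCLComm() consults it to temporarily close any open
+  // groups (see [Group Start/End Note] in ProcessGroupNCCL.cpp).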
+ static thread_local uint64_t ncclActiveGroupCounter_; }; } // namespace c10d diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 85b1d65a06ec..841990e6dc3c 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -33,6 +33,7 @@ skip_if_lt_x_gpu, skip_if_no_gpu, require_n_gpus_for_nccl_backend, + requires_nccl_version, ) from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT @@ -198,10 +199,13 @@ def _lock(): lf.close() -def _build_tensor(size, value=None, dtype=torch.float): +def _build_tensor(size, value=None, dtype=torch.float, device_id=None): if value is None: value = size - return torch.empty(size, size, size, dtype=dtype).fill_(value) + if device_id is None: + return torch.empty(size, size, size, dtype=dtype).fill_(value) + else: + return torch.empty(size, size, size, dtype=dtype).fill_(value).cuda(device_id) def _build_multidim_tensor(dim, dim_size, value=None): @@ -571,6 +575,184 @@ def test_backend_group(self): def test_backend_full_group(self): self._test_group_override_backend(self._init_full_group_test) + # NCCL Batch SEND RECV + @skip_if_no_gpu + @unittest.skip("NCCL P2P is not enabled for OSS builds") + @unittest.skipIf(BACKEND != "nccl", "NCCL Batch Send Recv Only") + @requires_nccl_version(2700, "Need NCCL 2.7+ for send/recv") + def test_batch_isend_irecv_nccl(self): + self._barrier() + rank = dist.get_rank() + rank_to_GPU = self._init_multigpu_helper() + device_id = rank_to_GPU[rank][0] + p2p_op_list = [] + + for val in ["1", "0"]: + os.environ["NCCL_BLOCKING_WAIT"] = val + for src in range(0, dist.get_world_size()): + send_tensor = _build_tensor(rank + 1, device_id=device_id) + recv_tensor = _build_tensor(src + 1, value=-1, device_id=device_id) + recv_op = dist.P2POp(dist.irecv, recv_tensor, src) + p2p_op_list.append(recv_op) + send_op = dist.P2POp(dist.isend, send_tensor, src) + p2p_op_list.append(send_op) + + reqs = dist.batch_isend_irecv(p2p_op_list) + for req in reqs: + req.wait() + + self._barrier() + + # GLOO Batch SEND RECV CPU + @skip_if_no_gpu + @unittest.skipIf(BACKEND != "gloo", "GLOO Batch Send Recv CPU") + def test_batch_isend_irecv_gloo(self): + self._barrier() + rank = dist.get_rank() + p2p_op_list = [] + + for src in range(0, dist.get_world_size()): + if src == rank: + continue + send_tensor = _build_tensor(rank + 1) + recv_tensor = _build_tensor(src + 1, value=-1) + recv_op = dist.P2POp(dist.irecv, recv_tensor, src) + p2p_op_list.append(recv_op) + send_op = dist.P2POp(dist.isend, send_tensor, src) + p2p_op_list.append(send_op) + + reqs = dist.batch_isend_irecv(p2p_op_list) + for req in reqs: + req.wait() + + self._barrier() + + # GLOO Batch SEND RECV CPU with provided tags + @skip_if_no_gpu + @unittest.skipIf(BACKEND != "gloo", "GLOO Batch Send Recv CPU") + def test_batch_isend_irecv_gloo_tags(self): + self._barrier() + rank = dist.get_rank() + p2p_op_list = [] + + for src in range(0, dist.get_world_size()): + if src == rank: + continue + send_tensor = _build_tensor(rank + 1) + recv_tensor = _build_tensor(src + 1, value=-1) + recv_op = dist.P2POp(dist.irecv, recv_tensor, src, tag=src) + p2p_op_list.append(recv_op) + send_op = dist.P2POp(dist.isend, send_tensor, src, tag=rank) + p2p_op_list.append(send_op) + + reqs = dist.batch_isend_irecv(p2p_op_list) + for req in reqs: + req.wait() + + self._barrier() + + # NCCL Batch 
SEND RECV Tensor Error + @unittest.skip("NCCL P2P is not enabled for OSS builds") + @unittest.skipIf(BACKEND != "nccl", "NCCL Batch Send Recv Only") + @requires_nccl_version(2700, "Need NCCL 2.7+ for send/recv") + def test_batch_isend_irecv_tensor_err(self): + self._barrier() + rank = dist.get_rank() + if rank == 0: + rank_to_GPU = self._init_multigpu_helper() + device_id = rank_to_GPU[rank][0] + with self.assertRaisesRegex( + RuntimeError, "Tensors must be CUDA and dense" + ): + send_tensor = _build_tensor(rank + 1) + send_op = dist.P2POp(dist.isend, send_tensor, 1) + req = dist.batch_isend_irecv([send_op]) + req.wait() + + # NCCL Batch SEND RECV Op Error + @unittest.skip("NCCL P2P is not enabled for OSS builds") + @unittest.skipIf(BACKEND != "nccl", "NCCL Batch Send Recv Only") + @requires_nccl_version(2700, "Need NCCL 2.7+ for send/recv") + def test_batch_isend_irecv_op_err(self): + self._barrier() + rank = dist.get_rank() + if rank == 0: + rank_to_GPU = self._init_multigpu_helper() + device_id = rank_to_GPU[rank][0] + with self.assertRaisesRegex( + RuntimeError, "^Invalid ``op``" + ): + send_tensor = _build_tensor(rank + 1, device_id=device_id) + send_op = dist.P2POp(dist.broadcast, send_tensor, 1) + req = dist.batch_isend_irecv([send_op]) + req.wait() + + # NCCL Batch SEND RECV p2p_op_list Error + @unittest.skip("NCCL P2P is not enabled for OSS builds") + @unittest.skipIf(BACKEND != "nccl", "NCCL Batch Send Recv Only") + @requires_nccl_version(2700, "Need NCCL 2.7+ for send/recv") + def test_batch_isend_irecv_op_list_err(self): + self._barrier() + rank = dist.get_rank() + if rank == 0: + rank_to_GPU = self._init_multigpu_helper() + device_id = rank_to_GPU[rank][0] + with self.assertRaisesRegex( + RuntimeError, "^Invalid ``p2p_op_list``" + ): + send_tensor = _build_tensor(rank + 1) + req = dist.batch_isend_irecv([1, 2]) + req.wait() + + # NCCL Batch SEND RECV Mixed Backend Error + @unittest.skip("NCCL P2P is not enabled for OSS builds") + @unittest.skipIf(BACKEND != "nccl", "NCCL Batch Send Recv Only") + @requires_nccl_version(2700, "Need NCCL 2.7+ for send/recv") + def test_batch_isend_irecv_mixed_backend_err(self): + self._barrier() + rank = dist.get_rank() + rank_to_GPU = self._init_multigpu_helper() + device_id = rank_to_GPU[rank][0] + group_gloo = dist.new_group(ranks=[0, 1], backend="gloo") + group_nccl = dist.new_group(ranks=[0, 1], backend="nccl") + if rank == 0: + with self.assertRaisesRegex( + RuntimeError, "All groups need to use the same backend" + ): + send_tensor = _build_tensor(rank + 1) + send_op_gloo = dist.P2POp(dist.isend, send_tensor, 1, group_gloo) + send_op_nccl = dist.P2POp(dist.isend, send_tensor, 1, group_nccl) + req = dist.batch_isend_irecv([send_op_gloo, send_op_nccl]) + req.wait() + + # NCCL SEND RECV + @unittest.skip("NCCL P2P is not enabled for OSS builds") + @skip_if_no_gpu + @unittest.skipIf(BACKEND != "nccl", "NCCL Send Recv Only") + @requires_nccl_version(2700, "Need NCCL 2.7+ for send/recv") + def test_send_recv_nccl(self): + rank = dist.get_rank() + rank_to_GPU = self._init_multigpu_helper() + device_id = rank_to_GPU[rank][0] + + tensor = _build_tensor(rank + 1, device_id=device_id) + + for src in range(0, dist.get_world_size()): + if src == rank: + # Send mode + for dst in range(0, dist.get_world_size()): + if dst == rank: + continue + dist.send(tensor, dst) + else: + # Recv mode + expected_tensor = _build_tensor(src + 1) + output_tensor = _build_tensor(src + 1, value=-1, device_id=device_id) + dist.recv(output_tensor, src) + 
self.assertEqual(output_tensor, expected_tensor) + + self._barrier() + # SEND RECV @unittest.skipIf(BACKEND == "nccl", "Nccl does not support send/recv") def test_send_recv(self):