diff --git a/torch/csrc/cuda/nccl.cpp b/torch/csrc/cuda/nccl.cpp index 8b05caea5aba..5efb77ea536a 100644 --- a/torch/csrc/cuda/nccl.cpp +++ b/torch/csrc/cuda/nccl.cpp @@ -660,6 +660,54 @@ void all2all(at::Tensor& input, #endif } +void send( + const at::Tensor& input, + ncclComm_t comm, + at::cuda::CUDAStream stream, + int dst) { +#ifdef USE_NCCL +#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ + (NCCL_MINOR >= 7) + using namespace torch::cuda::nccl::detail; + NCCL_CHECK(ncclSend( + input.data_ptr(), + input.numel(), + to_nccl_data_type(input), + dst, + to_nccl_comm(comm), + stream.stream())); +#else + AT_ERROR("Send is only supported for NCCL lib version >= 2.7.0"); +#endif +#else + AT_ERROR("PyTorch built without NCCL support"); +#endif +} + +void recv( + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream stream, + int src) { +#ifdef USE_NCCL +#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ + (NCCL_MINOR >= 7) + using namespace torch::cuda::nccl::detail; + NCCL_CHECK(ncclRecv( + output.data_ptr(), + output.numel(), + to_nccl_data_type(output), + src, + to_nccl_comm(comm), + stream.stream())); +#else + AT_ERROR("Recv is only supported for NCCL lib version >= 2.7.0"); +#endif +#else + AT_ERROR("PyTorch built without NCCL support"); +#endif +} + } // namespace nccl } // namespace cuda } // namespace torch diff --git a/torch/csrc/cuda/nccl.h b/torch/csrc/cuda/nccl.h index ecf854ec2009..4cbae2e0208a 100644 --- a/torch/csrc/cuda/nccl.h +++ b/torch/csrc/cuda/nccl.h @@ -143,6 +143,17 @@ TORCH_CUDA_API void all2all( ncclComm_t comm, at::cuda::CUDAStream& stream); +TORCH_CUDA_API void send( + const at::Tensor& input, + ncclComm_t comm, + at::cuda::CUDAStream stream, + int dst); + +TORCH_CUDA_API void recv( + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream stream, + int src); } // namespace nccl } // namespace cuda } // namespace torch diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp 
b/torch/lib/c10d/ProcessGroupNCCL.cpp index 6e5b72a95b25..c41b60f208bc 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -1437,18 +1437,13 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::send( int /* unused */) { check_gpu_tensors(tensors); auto ret = pointToPoint( - tensors, + tensors, [&](at::Tensor& input, ncclComm_t comm, at::cuda::CUDAStream& stream, int dst) { - return ncclSend( - input.data_ptr(), - input.numel(), - getNcclDataType(input.scalar_type()), - dst, - comm, - stream.stream()); + torch::cuda::nccl::send(input, comm, stream, dst); + return ncclSuccess; }, dstRank, NCCLCommType::SEND); @@ -1466,13 +1461,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::recv( ncclComm_t comm, at::cuda::CUDAStream& stream, int src) { - return ncclRecv( - output.data_ptr(), - output.numel(), - getNcclDataType(output.scalar_type()), - src, - comm, - stream.stream()); + torch::cuda::nccl::recv(output, comm, stream, src); + return ncclSuccess; }, srcRank, NCCLCommType::RECV);