Skip to content

Commit

Permalink
[NCCL] Add torch::cuda::nccl::send/recv (#45926)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #45926

torch/csrc/cuda/nccl.cpp is compiled as part of torch_cuda library and thus by calling this function from ProcessGroupNCCCL.cpp it avoids linking 2nd instance of libnccl.a into torch_python
Fixes similiar issue as #42517

ghstack-source-id: 113910530

Test Plan: waitforsandcastle

Reviewed By: jiayisuse

Differential Revision: D24147802

fbshipit-source-id: d8901fdb31bdc22ddca2364f8050844639a1beb3
  • Loading branch information
mingzhe09088 authored and facebook-github-bot committed Oct 9, 2020
1 parent b7f7378 commit 8cd3857
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 15 deletions.
48 changes: 48 additions & 0 deletions torch/csrc/cuda/nccl.cpp
Expand Up @@ -660,6 +660,54 @@ void all2all(at::Tensor& input,
#endif
}

void send(
const at::Tensor& input,
ncclComm_t comm,
at::cuda::CUDAStream stream,
int dst) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
(NCCL_MINOR >= 7)
using namespace torch::cuda::nccl::detail;
NCCL_CHECK(ncclSend(
input.data_ptr(),
input.numel(),
to_nccl_data_type(input),
dst,
to_nccl_comm(comm),
stream.stream()));
#else
AT_ERROR("Send is only supported for NCCL lib version >= 2.7.0");
#endif
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
}

void recv(
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream stream,
int src) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
(NCCL_MINOR >= 7)
using namespace torch::cuda::nccl::detail;
NCCL_CHECK(ncclRecv(
output.data_ptr(),
output.numel(),
to_nccl_data_type(output),
src,
to_nccl_comm(comm),
stream.stream()));
#else
AT_ERROR("Recv is only supported for NCCL lib version >= 2.7.0");
#endif
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
}

} // namespace nccl
} // namespace cuda
} // namespace torch
11 changes: 11 additions & 0 deletions torch/csrc/cuda/nccl.h
Expand Up @@ -143,6 +143,17 @@ TORCH_CUDA_API void all2all(
ncclComm_t comm,
at::cuda::CUDAStream& stream);

TORCH_CUDA_API void send(
const at::Tensor& input,
ncclComm_t comm,
at::cuda::CUDAStream stream,
int dst);

TORCH_CUDA_API void recv(
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream stream,
int src);
} // namespace nccl
} // namespace cuda
} // namespace torch
20 changes: 5 additions & 15 deletions torch/lib/c10d/ProcessGroupNCCL.cpp
Expand Up @@ -1437,18 +1437,13 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::send(
int /* unused */) {
check_gpu_tensors(tensors);
auto ret = pointToPoint(
tensors,
tensors,
[&](at::Tensor& input,
ncclComm_t comm,
at::cuda::CUDAStream& stream,
int dst) {
return ncclSend(
input.data_ptr(),
input.numel(),
getNcclDataType(input.scalar_type()),
dst,
comm,
stream.stream());
torch::cuda::nccl::send(input, comm, stream, dst);
return ncclSuccess;
},
dstRank,
NCCLCommType::SEND);
Expand All @@ -1466,13 +1461,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::recv(
ncclComm_t comm,
at::cuda::CUDAStream& stream,
int src) {
return ncclRecv(
output.data_ptr(),
output.numel(),
getNcclDataType(output.scalar_type()),
src,
comm,
stream.stream());
torch::cuda::nccl::recv(output, comm, stream, src);
return ncclSuccess;
},
srcRank,
NCCLCommType::RECV);
Expand Down

0 comments on commit 8cd3857

Please sign in to comment.