Add torch::cuda::nccl::all2all (#45900)
Summary:
Pull Request resolved: #45900

Use `torch::cuda::nccl::all2all` from `ProcessGroupNCCL.cpp`

Fixes #42517

Here is the NCCL dependency graph:
```
libnccl.a --> libtorch_cuda.so ---> libtorch_python.so
    |                                   ^
    |                                   |
    --------> libc10d.a -----------------
```
When a static library is linked into a dynamic library or an executable, the linker removes all unused/duplicate symbols from that library unless the `-whole-archive` option is used. Before #42514, all NCCL calls made from `ProcessGroupNCCL.cpp` were also made from `torch/csrc/cuda/nccl.cpp`, which is compiled as part of `libtorch_cuda.so`.
But adding `ncclSend`/`ncclRecv` calls to `ProcessGroupNCCL.cpp` forced the linker to embed those symbols into `libtorch_python.so`, which in turn pulled other dependent symbols into that library as well.

This PR adds the `nccl[Send|Recv]` calls to `torch_cuda.so` by implementing `all2all` in `torch_cuda`, and thus avoids double-linking the static library.
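For illustration, a minimal sketch of the resulting call-site shape (the free-standing `runAlltoallBase` helper and its parameter names are illustrative only; in the actual diff below the change sits inside a lambda in `ProcessGroupNCCL::alltoall_base`):

```cpp
#include <nccl.h>                  // ncclComm_t, ncclResult_t, ncclSuccess
#include <torch/csrc/cuda/nccl.h>  // torch::cuda::nccl::all2all (added by this PR)

// Sketch only: the process group no longer issues ncclSend/ncclRecv itself for
// all-to-all; it forwards the tensors to the wrapper compiled into libtorch_cuda.so,
// so the NCCL symbols are resolved once, inside torch_cuda.
ncclResult_t runAlltoallBase(
    at::Tensor& input,
    at::Tensor& output,
    int worldSize,
    ncclComm_t comm,
    at::cuda::CUDAStream& stream) {
  // Previously this body built the per-rank send/recv loop with direct NCCL API
  // calls, which forced ncclSend/ncclRecv into libtorch_python.so as well.
  torch::cuda::nccl::all2all(input, output, worldSize, comm, stream);
  return ncclSuccess;
}
```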

A more involved, but less error-prone, solution would be to use wrappers exported in the `torch::cuda::nccl` namespace instead of making direct NCCL API calls.
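Such wrappers might look roughly like this (a hypothetical sketch, not part of this PR; `send`/`recv` are illustrative names, and the `detail` helpers and `NCCL_CHECK` macro are the same ones used by `all2all` in the nccl.cpp diff below):

```cpp
// Hypothetical thin wrappers in torch/csrc/cuda/nccl.cpp: c10d would call these
// instead of the NCCL API directly, keeping every raw NCCL symbol confined to
// libtorch_cuda.so.
namespace torch { namespace cuda { namespace nccl {

void send(const at::Tensor& tensor, ncclComm_t _comm, at::cuda::CUDAStream& stream, int dst) {
  using namespace torch::cuda::nccl::detail;
  auto comm = to_nccl_comm(_comm);
  NCCL_CHECK(ncclSend(
      tensor.data_ptr(), tensor.numel(), to_nccl_data_type(tensor), dst, comm, stream));
}

void recv(at::Tensor& tensor, ncclComm_t _comm, at::cuda::CUDAStream& stream, int src) {
  using namespace torch::cuda::nccl::detail;
  auto comm = to_nccl_comm(_comm);
  NCCL_CHECK(ncclRecv(
      tensor.data_ptr(), tensor.numel(), to_nccl_data_type(tensor), src, comm, stream));
}

}}} // namespace torch::cuda::nccl
```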

Test Plan: Imported from OSS

Reviewed By: mingzhe09088

Differential Revision: D24138011

Pulled By: malfet

fbshipit-source-id: 33305197fc7d8707b7fd3a66b543f7733b9241a1
malfet authored and mingzhe0908 committed Dec 4, 2020
1 parent fea103d commit 445963c
Showing 5 changed files with 51 additions and 36 deletions.
2 changes: 1 addition & 1 deletion torch/csrc/cuda/comm.cpp
@@ -130,7 +130,7 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntArrayRef devices) {
// When splitting, the view operations will make all Variables broadcast
// together to share a single version counter, because they are all views of the
// large Variable. However, that large Variable is immediately discarded and all
// these Varaibles do not share storage at all.
// these Variables do not share storage at all.
//
// For example, when two buffers are broadcast together in `DataParallel` and
// one of them is modified in-place during `forward` but the other is needed in
36 changes: 36 additions & 0 deletions torch/csrc/cuda/nccl.cpp
@@ -624,6 +624,42 @@ void all_gather(
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void all2all(at::Tensor& input,
    at::Tensor& output,
    int size,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
  using namespace torch::cuda::nccl::detail;

  int numranks;
  auto type = to_nccl_data_type(input);
  size_t count = input.numel() / size;
  size_t rankdiff = input.nbytes() / size;
  const auto* sendbuff = reinterpret_cast<char*>(input.data_ptr());
  auto* recvbuff = reinterpret_cast<char *>(output.data_ptr());
  auto comm = to_nccl_comm(_comm);
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclGroupStart());
  for (int r = 0; r < numranks; r++) {
    // NCCL uses 0 byte message for synchronization
    // Avoid send/recv when message size is zero
    if (count != 0) {
      NCCL_CHECK(ncclSend(sendbuff + r * rankdiff, count, type, r, comm, stream));
      NCCL_CHECK(ncclRecv(recvbuff + r * rankdiff, count, type, r, comm, stream));
    }
  }
  NCCL_CHECK(ncclGroupEnd());
#else
  AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

} // namespace nccl
} // namespace cuda
} // namespace torch
7 changes: 7 additions & 0 deletions torch/csrc/cuda/nccl.h
@@ -136,6 +136,13 @@ TORCH_CUDA_API void all_gather(
    const stream_list& streams = {},
    const comm_list& user_comms = {});

TORCH_CUDA_API void all2all(
    at::Tensor& input,
    at::Tensor& output,
    int size,
    ncclComm_t comm,
    at::cuda::CUDAStream& stream);

} // namespace nccl
} // namespace cuda
} // namespace torch
3 changes: 0 additions & 3 deletions torch/lib/c10d/NCCLUtils.hpp
@@ -17,8 +17,6 @@
#define ENABLE_NCCL_ERROR_CHECKING
#endif

// Fix build issues with NCCL P2P - until then disable NCCL send/recv.
#if defined(ENABLE_NCCL_A2A) && (ENABLE_NCCL_A2A == 1)
// P2P is enabled only for NCCL versions 2.7+ since ncclSend()
// and ncclRecv() are not supported in earlier versions.
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
@@ -27,7 +25,6 @@
#elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
#define ENABLE_NCCL_P2P_SUPPORT
#endif
#endif

// Macro to throw on a non-successful NCCL return value.
#define C10D_NCCL_CHECK(cmd) \
39 changes: 7 additions & 32 deletions torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -8,6 +8,7 @@

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/cuda/nccl.h>

#include <c10d/Utils.hpp>
namespace c10d {
@@ -158,31 +159,6 @@ std::string getNcclAbortedCommStoreKey(const std::string ncclIdStr) {
}

#ifdef ENABLE_NCCL_P2P_SUPPORT
ncclResult_t ncclAlltoall(
    void* sendbuff,
    void* recvbuff,
    size_t count,
    size_t size,
    ncclDataType_t type,
    ncclComm_t comm,
    cudaStream_t stream) {
  int numranks;
  size_t rankdiff = count * size;
  C10D_NCCL_CHECK(ncclCommCount(comm, &numranks));
  C10D_NCCL_CHECK(ncclGroupStart());
  for (int r = 0; r < numranks; r++) {
    // NCCL uses 0 byte message for synchronization
    // Avoid send/recv when message size is zero
    if (count != 0) {
      C10D_NCCL_CHECK(ncclSend(
          ((char*)sendbuff) + r * rankdiff, count, type, r, comm, stream));
      C10D_NCCL_CHECK(ncclRecv(
          ((char*)recvbuff) + r * rankdiff, count, type, r, comm, stream));
    }
  }
  C10D_NCCL_CHECK(ncclGroupEnd());
  return ncclSuccess;
}

ncclResult_t ncclAlltoallv(
    void* sendbuff,
@@ -1255,14 +1231,13 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall_base(
    at::Tensor& output,
    ncclComm_t comm,
    at::cuda::CUDAStream& stream) {
  return ncclAlltoall(
      input.data_ptr(),
      output.data_ptr(),
      input.numel() / size_,
      input.element_size(),
      getNcclDataType(input.scalar_type()),
  torch::cuda::nccl::all2all(
      input,
      output,
      this->getSize(),
      comm,
      stream.stream());
      stream);
  return ncclSuccess;
});
} else {
c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);