[Release/1.7] Enable NCCL A2A on OSS #48857

Closed · wants to merge 2 commits
2 changes: 1 addition & 1 deletion torch/csrc/cuda/comm.cpp
@@ -130,7 +130,7 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntArrayRef devices) {
// When splitting, the view operations will make all Variables broadcast
// together to share a single version counter, because they are all views of the
// large Variable. However, that large Variable is immediately discarded and all
// these Varaibles do not share storage at all.
// these Variables do not share storage at all.
//
// For example, when two buffers are broadcast together in `DataParallel` and
// one of them is modified in-place during `forward` but the other is needed in
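The comment above (truncated by the collapsed context; the only change to comm.cpp is the "Varaibles" → "Variables" typo fix) refers to the coalesced-broadcast pattern: buffers are flattened into one large tensor, broadcast once, and split back into views of that tensor, which is why the outputs initially share a single version counter. A minimal illustrative sketch of that pattern, not code from this PR (the helper name and the libtorch-only setup are assumptions):

#include <torch/torch.h>
#include <vector>

// Flatten-broadcast-split sketch: the split outputs are views of `big`, so they
// start out sharing one version counter, exactly the situation the comment describes.
std::vector<torch::Tensor> coalesced_broadcast_sketch(
    const std::vector<torch::Tensor>& buffers) {
  std::vector<torch::Tensor> flat;
  std::vector<int64_t> numels;
  for (const auto& b : buffers) {
    flat.push_back(b.reshape({-1}));
    numels.push_back(b.numel());
  }
  // One large buffer means a single broadcast instead of one per tensor
  // (real code would hand `big` to torch::cuda::nccl::broadcast here).
  torch::Tensor big = torch::cat(flat);
  std::vector<torch::Tensor> out;
  int64_t offset = 0;
  for (size_t i = 0; i < buffers.size(); ++i) {
    out.push_back(big.narrow(0, offset, numels[i]).view(buffers[i].sizes()));
    offset += numels[i];
  }
  return out;
}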
95 changes: 69 additions & 26 deletions torch/csrc/cuda/nccl.cpp
@@ -21,6 +21,10 @@ ncclComm_t* to_nccl_comm(torch::cuda::nccl::ncclComm_t* var) {
return reinterpret_cast<ncclComm_t*>(var);
}

ncclComm_t to_nccl_comm(torch::cuda::nccl::ncclComm_t var) {
return reinterpret_cast<ncclComm_t>(var);
}

ncclUniqueId* to_nccl_unique_id(torch::cuda::nccl::ncclUniqueId* var) {
return reinterpret_cast<ncclUniqueId*>(var);
}
@@ -107,16 +111,20 @@ using namespace at;

namespace detail {

static inline void NCCL_CHECK(ncclResult_t result) {
NCCL_CHECK(from_nccl_result(result));
}

struct AutoNcclGroup {
AutoNcclGroup() {
(c10::cuda::CUDACachingAllocator::getFreeMutex())->lock();
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
NCCL_CHECK(from_nccl_result(ncclGroupStart()));
NCCL_CHECK(ncclGroupStart());
#endif
}
~AutoNcclGroup() {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
NCCL_CHECK(from_nccl_result(ncclGroupEnd()));
NCCL_CHECK(ncclGroupEnd());
#endif
(c10::cuda::CUDACachingAllocator::getFreeMutex())->unlock();
}
@@ -133,8 +141,8 @@ struct NcclCommList {
int ndevices;
NcclCommList(const std::vector<int>& devices)
: comms(new ncclComm_t[devices.size()]), ndevices(devices.size()) {
NCCL_CHECK(from_nccl_result(
ncclCommInitAll(to_nccl_comm(comms.get()), devices.size(), devices.data())));
NCCL_CHECK(
ncclCommInitAll(to_nccl_comm(comms.get()), devices.size(), devices.data()));
}
NcclCommList(NcclCommList&& foo) = default;
~NcclCommList() {
@@ -326,7 +334,7 @@ void get_unique_id(ncclUniqueId& id)
{
#ifdef USE_NCCL
using namespace torch::cuda::nccl::detail;
NCCL_CHECK(from_nccl_result(ncclGetUniqueId(to_nccl_unique_id(&id))));
NCCL_CHECK(ncclGetUniqueId(to_nccl_unique_id(&id)));
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
@@ -337,11 +345,11 @@ ncclComm_t comm_init_rank(int nranks, const ncclUniqueId& comm_id, int rank) {
using namespace torch::cuda::nccl::detail;
ncclComm_t comm;
ncclUniqueId id = comm_id;
NCCL_CHECK(from_nccl_result(ncclCommInitRank(
NCCL_CHECK(ncclCommInitRank(
to_nccl_comm(&comm),
nranks,
*(to_nccl_unique_id(&id)),
rank)));
rank));
return comm;
#else
return nullptr;
@@ -362,8 +370,7 @@ void comm_destroy(ncclComm_t comm)

#ifdef USE_NCCL
using namespace torch::cuda::nccl::detail;
NCCL_CHECK(from_nccl_result(ncclCommDestroy(
*(to_nccl_comm(&comm)))));
NCCL_CHECK(ncclCommDestroy(to_nccl_comm(comm)));
#endif
}

@@ -420,8 +427,8 @@ void broadcast(
count_max,
")");
ncclComm_t comm = comms[i];
NCCL_CHECK(from_nccl_result(ncclBcast(
tensors[i].data_ptr(), numel, data_type, 0, *(to_nccl_comm(&comm)), stream)));
NCCL_CHECK(ncclBcast(
tensors[i].data_ptr(), numel, data_type, 0, to_nccl_comm(comm), stream));
}
#else
AT_ERROR("PyTorch built without NCCL support");
@@ -460,15 +467,15 @@ void reduce(
: streams[i]->stream();

ncclComm_t comm = comms_ref[i];
NCCL_CHECK(from_nccl_result(ncclReduce(
NCCL_CHECK(ncclReduce(
inputs[i].data_ptr(),
root == i ? output.data_ptr() : nullptr,
count,
data_type,
to_nccl_red_op(op),
root,
*(to_nccl_comm(&comm)),
stream)));
to_nccl_comm(comm),
stream));
}
#else
AT_ERROR("PyTorch built without NCCL support");
@@ -512,14 +519,14 @@ void all_reduce(
: streams[i]->stream();

ncclComm_t comm = comms_ref[i];
NCCL_CHECK(from_nccl_result(ncclAllReduce(
NCCL_CHECK(ncclAllReduce(
inputs[i].data_ptr(),
outputs[i].data_ptr(),
count,
data_type,
to_nccl_red_op(op),
*(to_nccl_comm(&comm)),
stream)));
to_nccl_comm(comm),
stream));
}
#else
AT_ERROR("PyTorch built without NCCL support");
@@ -554,14 +561,14 @@ void reduce_scatter(
: streams[i]->stream();

ncclComm_t comm = comms_ref[i];
NCCL_CHECK(from_nccl_result(ncclReduceScatter(
NCCL_CHECK(ncclReduceScatter(
inputs[i].data_ptr(),
outputs[i].data_ptr(),
count,
data_type,
to_nccl_red_op(op),
*(to_nccl_comm(&comm)),
stream)));
to_nccl_comm(comm),
stream));
}
#else
AT_ERROR("PyTorch built without NCCL support");
@@ -596,27 +603,63 @@ void all_gather(

ncclComm_t comm = comms_ref[i];
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
NCCL_CHECK(from_nccl_result(ncclAllGather(
NCCL_CHECK(ncclAllGather(
inputs[i].data_ptr(),
outputs[i].data_ptr(),
count,
data_type,
*(to_nccl_comm(&comm)),
stream)));
to_nccl_comm(comm),
stream));
#else
NCCL_CHECK(from_nccl_result(ncclAllGather(
NCCL_CHECK(ncclAllGather(
inputs[i].data_ptr(),
count,
data_type,
outputs[i].data_ptr(),
*(to_nccl_comm(&comm)),
stream)));
to_nccl_comm(comm),
stream));
#endif
}
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
}

void all2all(at::Tensor& input,
at::Tensor& output,
int size,
ncclComm_t _comm,
at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
using namespace torch::cuda::nccl::detail;

int numranks;
auto type = to_nccl_data_type(input);
size_t count = input.numel() / size;
size_t rankdiff = input.nbytes() / size;
const auto* sendbuff = reinterpret_cast<char*>(input.data_ptr());
auto* recvbuff = reinterpret_cast<char *>(output.data_ptr());
auto comm = to_nccl_comm(_comm);
NCCL_CHECK(ncclCommCount(comm, &numranks));
NCCL_CHECK(ncclGroupStart());
for (int r = 0; r < numranks; r++) {
// NCCL uses 0 byte message for synchronization
// Avoid send/recv when message size is zero
if (count != 0) {
NCCL_CHECK(ncclSend(sendbuff + r * rankdiff, count, type, r, comm, stream));
NCCL_CHECK(ncclRecv(recvbuff + r * rankdiff, count, type, r, comm, stream));
}
}
NCCL_CHECK(ncclGroupEnd());
#else
AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
}

} // namespace nccl
} // namespace cuda
} // namespace torch
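For reference, a minimal usage sketch of the new torch::cuda::nccl::all2all entry point, mirroring how ProcessGroupNCCL::alltoall_base invokes it further down in this diff. This is illustrative only: `comm` is assumed to be an already-initialized communicator for this rank, and the tensors are assumed to be contiguous CUDA tensors of equal shape whose numel is divisible by the rank count.

#include <ATen/cuda/CUDAContext.h>
#include <torch/csrc/cuda/nccl.h>

// Hypothetical helper: chunk r of `input` is sent to rank r, and chunk r of
// `output` is filled with the data received from rank r.
void run_all2all(at::Tensor& input,
                 at::Tensor& output,
                 int nranks,
                 torch::cuda::nccl::ncclComm_t comm) {
  at::cuda::CUDAStream stream =
      at::cuda::getCurrentCUDAStream(input.device().index());
  torch::cuda::nccl::all2all(input, output, nranks, comm, stream);
}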
7 changes: 7 additions & 0 deletions torch/csrc/cuda/nccl.h
@@ -136,6 +136,13 @@ TORCH_CUDA_API void all_gather(
const stream_list& streams = {},
const comm_list& user_comms = {});

TORCH_CUDA_API void all2all(
at::Tensor& input,
at::Tensor& output,
int size,
ncclComm_t comm,
at::cuda::CUDAStream& stream);

} // namespace nccl
} // namespace cuda
} // namespace torch
3 changes: 0 additions & 3 deletions torch/lib/c10d/NCCLUtils.hpp
@@ -17,8 +17,6 @@
#define ENABLE_NCCL_ERROR_CHECKING
#endif

// Fix build issues with NCCL P2P - until then disable NCCL send/recv.
#if defined(ENABLE_NCCL_A2A) && (ENABLE_NCCL_A2A == 1)
// P2P is enabled only for NCCL versions 2.7+ since ncclSend()
// and ncclRecv() are not supported in earlier versions.
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
@@ -27,7 +25,6 @@
#elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
#define ENABLE_NCCL_P2P_SUPPORT
#endif
#endif

// Macro to throw on a non-successful NCCL return value.
#define C10D_NCCL_CHECK(cmd) \
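With the ENABLE_NCCL_A2A guard gone, ENABLE_NCCL_P2P_SUPPORT is driven purely by the NCCL version. The middle of the #if above is collapsed in the diff; a sketch of the effective check after this change (the `(NCCL_MINOR >= 7)` line is an assumption inferred from the "2.7+" comment and from the version guard in nccl.cpp's all2all):

#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
    (NCCL_MINOR >= 7)
#define ENABLE_NCCL_P2P_SUPPORT
#elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
#define ENABLE_NCCL_P2P_SUPPORT
#endif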
39 changes: 7 additions & 32 deletions torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -8,6 +8,7 @@

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/cuda/nccl.h>

#include <c10d/Utils.hpp>
namespace c10d {
@@ -158,31 +159,6 @@ std::string getNcclAbortedCommStoreKey(const std::string ncclIdStr) {
}

#ifdef ENABLE_NCCL_P2P_SUPPORT
ncclResult_t ncclAlltoall(
void* sendbuff,
void* recvbuff,
size_t count,
size_t size,
ncclDataType_t type,
ncclComm_t comm,
cudaStream_t stream) {
int numranks;
size_t rankdiff = count * size;
C10D_NCCL_CHECK(ncclCommCount(comm, &numranks));
C10D_NCCL_CHECK(ncclGroupStart());
for (int r = 0; r < numranks; r++) {
// NCCL uses 0 byte message for synchronization
// Avoid send/recv when message size is zero
if (count != 0) {
C10D_NCCL_CHECK(ncclSend(
((char*)sendbuff) + r * rankdiff, count, type, r, comm, stream));
C10D_NCCL_CHECK(ncclRecv(
((char*)recvbuff) + r * rankdiff, count, type, r, comm, stream));
}
}
C10D_NCCL_CHECK(ncclGroupEnd());
return ncclSuccess;
}

ncclResult_t ncclAlltoallv(
void* sendbuff,
@@ -1255,14 +1231,13 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall_base(
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
return ncclAlltoall(
input.data_ptr(),
output.data_ptr(),
input.numel() / size_,
input.element_size(),
getNcclDataType(input.scalar_type()),
torch::cuda::nccl::all2all(
input,
output,
this->getSize(),
comm,
stream.stream());
stream);
return ncclSuccess;
});
} else {
c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);