New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for NCCL alltoall #44374
Changes from all commits
d427943
0215034
d03c3b1
03831ec
d0da056
ffe67dc
16cec66
c331af0
865f4a8
7efc5a3
1c39d26
1f54c1f
0683f63
a78d22d
4fb63ac
2ab4ff2
62a0d7b
213409e
76090bb
8ba2aa8
06e9d9a
e2586cb
3d3e3d8
81ccdda
de82338
d1590e8
e94b602
81c214b
58d50c5
1e163a6
3e5f29f
482368a
7abc38a
3b20dd6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -71,11 +71,8 @@ torch::cuda::nccl::ncclResult from_nccl_result(ncclResult_t var) { | |
} | ||
} | ||
|
||
ncclDataType_t to_nccl_data_type(const at::Tensor& t) { | ||
if (!t.is_cuda()) { | ||
throw std::runtime_error("Unconvertible NCCL type"); | ||
} | ||
switch (t.scalar_type()) { | ||
ncclDataType_t to_nccl_data_type(c10::ScalarType type) { | ||
switch (type) { | ||
case at::kFloat: | ||
return ncclDataType_t::ncclFloat; | ||
case at::kHalf: | ||
|
@@ -89,14 +86,23 @@ ncclDataType_t to_nccl_data_type(const at::Tensor& t) { | |
case at::kChar: | ||
return ncclDataType_t::ncclChar; | ||
case at::kByte: | ||
return ncclDataType_t::ncclChar; | ||
return ncclDataType_t::ncclUint8; | ||
case at::kBool: | ||
return ncclDataType_t::ncclUint8; | ||
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 301 | ||
case at::kBFloat16: | ||
return ncclDataType_t::ncclBfloat16; | ||
#endif | ||
default: | ||
throw std::runtime_error("Unconvertible NCCL type"); | ||
TORCH_CHECK(false, "Unconvertible NCCL type ", type); | ||
} | ||
zasdfgbnm marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
ncclDataType_t to_nccl_data_type(const at::Tensor& t) { | ||
if (!t.is_cuda()) { | ||
TORCH_CHECK(false, "NCCL only supports CUDA tensors, but got a tensor on ", t.device()); | ||
} | ||
return to_nccl_data_type(t.scalar_type()); | ||
} | ||
|
||
ncclRedOp_t to_nccl_red_op(int var) { | ||
|
@@ -625,7 +631,7 @@ void all_gather( | |
#endif | ||
} | ||
|
||
void all2all(at::Tensor& input, | ||
void all2all_single_equal_split(at::Tensor& input, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Would I be correct if I assume this API is not visible to users? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It seems so? I cannot find anything about nccl at https://pytorch.org/cppdocs/api/library_root.html |
||
at::Tensor& output, | ||
int size, | ||
ncclComm_t _comm, | ||
|
@@ -660,6 +666,98 @@ void all2all(at::Tensor& input, | |
#endif | ||
} | ||
|
||
void all2all_single_unequal_split( | ||
void* sendbuff, | ||
const size_t* sendcounts, | ||
const size_t* senddispls, | ||
void* recvbuff, | ||
const size_t* recvcounts, | ||
const size_t* recvdispls, | ||
size_t size, | ||
c10::ScalarType _type, | ||
ncclComm_t _comm, | ||
at::cuda::CUDAStream& stream) { | ||
#ifdef USE_NCCL | ||
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27 | ||
using namespace torch::cuda::nccl::detail; | ||
|
||
auto type = to_nccl_data_type(_type); | ||
auto comm = to_nccl_comm(_comm); | ||
int numranks; | ||
NCCL_CHECK(ncclCommCount(comm, &numranks)); | ||
NCCL_CHECK(ncclGroupStart()); | ||
for (int r = 0; r < numranks; r++) { | ||
// NCCL uses 0 byte message for synchronization | ||
// Avoid send/recv when message size is zero | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does this mean even if all send/recv counts are 0, this would still trigger a zero-byte message to do sync across ranks? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If all send counts are zero, wouldn't this be an empty nccl group? |
||
if (sendcounts[r] != 0) { | ||
NCCL_CHECK(ncclSend( | ||
((char*)sendbuff) + senddispls[r] * size, | ||
sendcounts[r], | ||
type, | ||
r, | ||
comm, | ||
stream)); | ||
} | ||
if (recvcounts[r] != 0) { | ||
NCCL_CHECK(ncclRecv( | ||
((char*)recvbuff) + recvdispls[r] * size, | ||
recvcounts[r], | ||
type, | ||
r, | ||
comm, | ||
stream)); | ||
} | ||
} | ||
NCCL_CHECK(ncclGroupEnd()); | ||
#else | ||
AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0"); | ||
#endif | ||
#else | ||
AT_ERROR("PyTorch built without NCCL support"); | ||
#endif | ||
} | ||
|
||
void all2all(std::vector<at::Tensor>& outputTensors, | ||
std::vector<at::Tensor>& inputTensors, | ||
ncclComm_t _comm, | ||
at::cuda::CUDAStream& stream) { | ||
#ifdef USE_NCCL | ||
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27 | ||
using namespace torch::cuda::nccl::detail; | ||
auto comm = to_nccl_comm(_comm); | ||
|
||
NCCL_CHECK(ncclGroupStart()); | ||
for (size_t r = 0; r < outputTensors.size(); r++) { | ||
at::Tensor &input = inputTensors[r]; | ||
at::Tensor &output = outputTensors[r]; | ||
if (input.numel() != 0) { | ||
NCCL_CHECK(ncclSend( | ||
input.data_ptr(), | ||
input.numel(), | ||
to_nccl_data_type(input), | ||
r, | ||
comm, | ||
stream.stream())); | ||
} | ||
if (output.numel() != 0) { | ||
NCCL_CHECK(ncclRecv( | ||
output.data_ptr(), | ||
output.numel(), | ||
to_nccl_data_type(output), | ||
r, | ||
comm, | ||
stream.stream())); | ||
} | ||
} | ||
NCCL_CHECK(ncclGroupEnd()); | ||
#else | ||
AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0"); | ||
#endif | ||
#else | ||
AT_ERROR("PyTorch built without NCCL support"); | ||
#endif | ||
} | ||
|
||
void send( | ||
const at::Tensor& input, | ||
ncclComm_t comm, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -166,49 +166,6 @@ std::string getNcclAbortedCommStoreKey(const std::string ncclIdStr) { | |
return std::string(kNCCLAbortedCommStoreKey) + ":" + ncclIdStr; | ||
} | ||
|
||
#ifdef ENABLE_NCCL_P2P_SUPPORT | ||
|
||
ncclResult_t ncclAlltoallv( | ||
void* sendbuff, | ||
const size_t* sendcounts, | ||
const size_t* senddispls, | ||
void* recvbuff, | ||
const size_t* recvcounts, | ||
const size_t* recvdispls, | ||
size_t size, | ||
ncclDataType_t type, | ||
ncclComm_t comm, | ||
cudaStream_t stream) { | ||
int numranks; | ||
C10D_NCCL_CHECK(ncclCommCount(comm, &numranks)); | ||
C10D_NCCL_CHECK(ncclGroupStart()); | ||
for (int r = 0; r < numranks; r++) { | ||
// NCCL uses 0 byte message for synchronization | ||
// Avoid send/recv when message size is zero | ||
if (sendcounts[r] != 0) { | ||
C10D_NCCL_CHECK(ncclSend( | ||
((char*)sendbuff) + senddispls[r] * size, | ||
sendcounts[r], | ||
type, | ||
r, | ||
comm, | ||
stream)); | ||
} | ||
if (recvcounts[r] != 0) { | ||
C10D_NCCL_CHECK(ncclRecv( | ||
((char*)recvbuff) + recvdispls[r] * size, | ||
recvcounts[r], | ||
type, | ||
r, | ||
comm, | ||
stream)); | ||
} | ||
} | ||
C10D_NCCL_CHECK(ncclGroupEnd()); | ||
return ncclSuccess; | ||
} | ||
#endif | ||
|
||
} // namespace | ||
|
||
const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 10000; | ||
|
@@ -1468,7 +1425,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall_base( | |
// See [Sync Streams]. | ||
c10::cuda::CUDACachingAllocator::recordStream( | ||
output.storage().data_ptr(), stream); | ||
torch::cuda::nccl::all2all( | ||
torch::cuda::nccl::all2all_single_equal_split( | ||
input, | ||
output, | ||
this->getSize(), | ||
|
@@ -1501,23 +1458,50 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall_base( | |
// See [Sync Streams]. | ||
c10::cuda::CUDACachingAllocator::recordStream( | ||
output.storage().data_ptr(), stream); | ||
return ncclAlltoallv( | ||
torch::cuda::nccl::all2all_single_unequal_split( | ||
input.data_ptr(), | ||
send_lengths.data(), | ||
send_offsets.data(), | ||
output.data_ptr(), | ||
recv_lengths.data(), | ||
recv_offsets.data(), | ||
input.element_size(), | ||
getNcclDataType(input.scalar_type()), | ||
input.scalar_type(), | ||
comm, | ||
stream.stream()); | ||
stream); | ||
return ncclSuccess; | ||
}, | ||
OpType::ALLTOALL_BASE, | ||
"nccl:all_to_all"); | ||
} | ||
} | ||
|
||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall( | ||
std::vector<at::Tensor>& outputTensors, | ||
std::vector<at::Tensor>& inputTensors, | ||
const AllToAllOptions& /* unused */) { | ||
auto device = outputTensors[0].device(); | ||
for (size_t r = 0; r < outputTensors.size(); r++) { | ||
zasdfgbnm marked this conversation as resolved.
Show resolved
Hide resolved
|
||
check_gpu_single_tensor(outputTensors[r]); | ||
check_gpu_single_tensor(inputTensors[r]); | ||
TORCH_CHECK(device == outputTensors[r].device() && device == inputTensors[r].device(), | ||
"Tensors must be on the same device") | ||
} | ||
std::vector<at::Tensor> inputTensor0 = {inputTensors[0]}; | ||
std::vector<at::Tensor> outputTensor0 = {outputTensors[0]}; | ||
return collective( | ||
inputTensor0, | ||
outputTensor0, | ||
[&](at::Tensor& /* unused */, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @zasdfgbnm @mrshenli Hello, I'm a bit confused: why doesn't outputTensors need to record the ncclStream to prevent it from being freed before the collective finishes? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This looks like a bug... Thanks for catching it! There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You're welcome :) |
||
at::Tensor& /* unused */, | ||
ncclComm_t comm, | ||
at::cuda::CUDAStream& stream) { | ||
torch::cuda::nccl::all2all(outputTensors, inputTensors, comm, stream); | ||
return ncclSuccess; | ||
}, | ||
OpType::ALLTOALL); | ||
} | ||
|
||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::send( | ||
std::vector<at::Tensor>& tensors, | ||
int dstRank, | ||
|
@@ -1566,6 +1550,14 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall_base( | |
"ProcessGroupNCCL only supports alltoall* for NCCL lib version >= 2.7.0"); | ||
} | ||
|
||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall( | ||
std::vector<at::Tensor>& /* unused */, | ||
std::vector<at::Tensor>& /* unused */, | ||
const AllToAllOptions& /* unused */) { | ||
throw std::runtime_error( | ||
"ProcessGroupNCCL only supports alltoall* for NCCL lib version >= 2.7.0"); | ||
} | ||
|
||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::send( | ||
std::vector<at::Tensor>& /* unused */, | ||
int /* unused */, | ||
|
@@ -1597,13 +1589,6 @@ void ProcessGroupNCCL::groupEnd() { | |
--ncclActiveGroupCounter_; | ||
} | ||
|
||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall( | ||
std::vector<at::Tensor>& /* unused */, | ||
std::vector<at::Tensor>& /* unused */, | ||
const AllToAllOptions& /* unused */) { | ||
throw std::runtime_error("ProcessGroupNCCL does not support alltoall"); | ||
} | ||
|
||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::gather( | ||
std::vector<std::vector<at::Tensor>>& /* unused */, | ||
std::vector<at::Tensor>& /* unused */, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Today's `ProcessGroupNCCL` also supports `at::kBool` — is that the same as `at::kByte`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have added `kBool`. And yes, I think `kByte` should be `ncclUint8` as well, instead of `ncclChar` as currently in this file. I have updated this.