Add support for NCCL alltoall #44374

Closed
zasdfgbnm wants to merge 34 commits into master from nccl-all2all
34 commits (the diff below shows changes from 27 commits)
d427943
Add support for NCCL all to all
zasdfgbnm Sep 8, 2020
0215034
fix
zasdfgbnm Sep 9, 2020
d03c3b1
fix
zasdfgbnm Sep 9, 2020
03831ec
fix
zasdfgbnm Sep 9, 2020
d0da056
Merge branch 'master' of github.com:pytorch/pytorch into nccl-all2all
zasdfgbnm Sep 9, 2020
ffe67dc
tests
zasdfgbnm Sep 9, 2020
16cec66
fix
zasdfgbnm Sep 9, 2020
c331af0
fix
zasdfgbnm Sep 9, 2020
865f4a8
cleanup
zasdfgbnm Sep 9, 2020
7efc5a3
group
zasdfgbnm Oct 16, 2020
1c39d26
Merge branch 'master' of github.com:pytorch/pytorch into nccl-all2all
zasdfgbnm Oct 16, 2020
1f54c1f
save
zasdfgbnm Oct 16, 2020
0683f63
Merge branch 'master' of github.com:pytorch/pytorch into nccl-all2all
zasdfgbnm Oct 16, 2020
a78d22d
error message
zasdfgbnm Oct 16, 2020
4fb63ac
Merge branch 'master' of github.com:pytorch/pytorch into nccl-all2all
zasdfgbnm Oct 20, 2020
2ab4ff2
OpType::ALLTOALL
zasdfgbnm Oct 20, 2020
62a0d7b
Merge branch 'master' of github.com:pytorch/pytorch into nccl-all2all
zasdfgbnm Oct 20, 2020
213409e
Merge branch 'master' of github.com:pytorch/pytorch into nccl-all2all
zasdfgbnm Oct 21, 2020
76090bb
update all to all
zasdfgbnm Oct 21, 2020
8ba2aa8
more
zasdfgbnm Oct 21, 2020
06e9d9a
fix
zasdfgbnm Oct 21, 2020
e2586cb
fix
zasdfgbnm Oct 21, 2020
3d3e3d8
@skip_if_rocm
zasdfgbnm Oct 21, 2020
81ccdda
Merge branch 'master' into nccl-all2all
zasdfgbnm Oct 23, 2020
de82338
fix
zasdfgbnm Oct 23, 2020
d1590e8
Merge branch 'master' into nccl-all2all
zasdfgbnm Nov 9, 2020
e94b602
Merge branch 'master' into nccl-all2all
zasdfgbnm Nov 14, 2020
81c214b
Update ProcessGroupNCCL.cpp
zasdfgbnm Nov 15, 2020
58d50c5
Update ProcessGroupNCCL.cpp
zasdfgbnm Nov 16, 2020
1e163a6
fix
zasdfgbnm Nov 19, 2020
3e5f29f
Merge branch 'master' of github.com:pytorch/pytorch into nccl-all2all
zasdfgbnm Nov 19, 2020
482368a
fix
zasdfgbnm Nov 19, 2020
7abc38a
Merge branch 'master' of github.com:pytorch/pytorch into nccl-all2all
zasdfgbnm Dec 7, 2020
3b20dd6
Merge branch 'master' of github.com:pytorch/pytorch into nccl-all2all
zasdfgbnm Jan 5, 2021
108 changes: 102 additions & 6 deletions torch/csrc/cuda/nccl.cpp
@@ -71,11 +71,8 @@ torch::cuda::nccl::ncclResult from_nccl_result(ncclResult_t var) {
}
}

ncclDataType_t to_nccl_data_type(const at::Tensor& t) {
if (!t.is_cuda()) {
throw std::runtime_error("Unconvertible NCCL type");
}
switch (t.scalar_type()) {
ncclDataType_t to_nccl_data_type(c10::ScalarType type) {
switch (type) {
case at::kFloat:
Contributor:
Today's ProcessGroupNCCL also supports at::kBool, is that the same as at::kByte?

    {at::kChar, ncclInt8},
    {at::kByte, ncclUint8},
    {at::kFloat, ncclFloat},
    {at::kDouble, ncclDouble},
    {at::kInt, ncclInt32},
    {at::kLong, ncclInt64},
    {at::kHalf, ncclHalf},
    {at::kBool, ncclUint8},
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 301
    {at::kBFloat16, ncclBfloat16},
#endif

Collaborator Author:
I have added kBool. And yes, I think kByte should be ncclUint8 as well, instead of ncclChar as currently in this file. I have updated this.

return ncclDataType_t::ncclFloat;
case at::kHalf:
@@ -99,6 +96,13 @@ ncclDataType_t to_nccl_data_type(const at::Tensor& t) {
}
}

ncclDataType_t to_nccl_data_type(const at::Tensor& t) {
if (!t.is_cuda()) {
throw std::runtime_error("Unconvertible NCCL type");
Contributor:
nit: this is prior to this PR, shall we be more explicit on the error message? Should this be the following?

f"NCCL only supports CUDA tensors, but got a tensor on {t.device}"

}
return to_nccl_data_type(t.scalar_type());
}
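
For illustration, a minimal C++ sketch of the more explicit error message the reviewer suggests above (an assumption about the wording; this diff keeps the original "Unconvertible NCCL type" text):

// Hypothetical sketch only, not code from this PR. It relies on TORCH_CHECK
// from <c10/util/Exception.h>, which is already available in nccl.cpp.
ncclDataType_t to_nccl_data_type(const at::Tensor& t) {
  TORCH_CHECK(
      t.is_cuda(),
      "NCCL only supports CUDA tensors, but got a tensor on ", t.device());
  return to_nccl_data_type(t.scalar_type());
}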

ncclRedOp_t to_nccl_red_op(int var) {
return (ncclRedOp_t)(var);
}
@@ -625,7 +629,7 @@ void all_gather(
#endif
}

void all2all(at::Tensor& input,
void all2all_single_equal_split(at::Tensor& input,
Contributor:
Would I be correct if I assume this API is not visible to users?

Collaborator Author:
It seems so? I can not find anything about nccl at https://pytorch.org/cppdocs/api/library_root.html

at::Tensor& output,
int size,
ncclComm_t _comm,
@@ -660,6 +664,98 @@ void all2all(at::Tensor& input,
#endif
}

void all2all_single_unequal_split(
void* sendbuff,
const size_t* sendcounts,
const size_t* senddispls,
void* recvbuff,
const size_t* recvcounts,
const size_t* recvdispls,
size_t size,
c10::ScalarType _type,
ncclComm_t _comm,
at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
using namespace torch::cuda::nccl::detail;

auto type = to_nccl_data_type(_type);
auto comm = to_nccl_comm(_comm);
int numranks;
NCCL_CHECK(ncclCommCount(comm, &numranks));
NCCL_CHECK(ncclGroupStart());
for (int r = 0; r < numranks; r++) {
// NCCL uses 0 byte message for synchronization
// Avoid send/recv when message size is zero
Contributor:
Does this mean that even if all send/recv counts are 0, this would still trigger a zero-byte message to sync across ranks?

Collaborator Author:
If all send counts are zero, wouldn't this be an empty nccl group?

if (sendcounts[r] != 0) {
NCCL_CHECK(ncclSend(
((char*)sendbuff) + senddispls[r] * size,
sendcounts[r],
type,
r,
comm,
stream));
}
if (recvcounts[r] != 0) {
NCCL_CHECK(ncclRecv(
((char*)recvbuff) + recvdispls[r] * size,
recvcounts[r],
type,
r,
comm,
stream));
}
}
NCCL_CHECK(ncclGroupEnd());
#else
AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
}
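
For context, a hypothetical caller sketch of the unequal-split entry point (not part of this diff; nranks, input, output, comm, and stream are assumed to be in scope). Counts and displacements are per-rank values in elements, and the size argument is bytes per element, matching the element_size() that ProcessGroupNCCL::alltoall_base passes further down:

// Hypothetical usage sketch: split a contiguous 1-D CUDA tensor evenly
// across nranks via the unequal-split path.
std::vector<size_t> counts(nranks), displs(nranks);
const size_t per_rank = input.numel() / nranks;
for (int r = 0; r < nranks; r++) {
  counts[r] = per_rank;      // elements exchanged with rank r
  displs[r] = r * per_rank;  // element offset of rank r's chunk
}
torch::cuda::nccl::all2all_single_unequal_split(
    input.data_ptr(), counts.data(), displs.data(),
    output.data_ptr(), counts.data(), displs.data(),
    input.element_size(), input.scalar_type(), comm, stream);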

void all2all(std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
ncclComm_t _comm,
at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
using namespace torch::cuda::nccl::detail;
auto comm = to_nccl_comm(_comm);

NCCL_CHECK(ncclGroupStart());
for (size_t r = 0; r < outputTensors.size(); r++) {
at::Tensor &input = inputTensors[r];
at::Tensor &output = outputTensors[r];
if (input.numel() != 0) {
NCCL_CHECK(ncclSend(
input.data_ptr(),
input.numel(),
to_nccl_data_type(input),
r,
comm,
stream.stream()));
}
if (output.numel() != 0) {
NCCL_CHECK(ncclRecv(
output.data_ptr(),
output.numel(),
to_nccl_data_type(output),
r,
comm,
stream.stream()));
}
}
NCCL_CHECK(ncclGroupEnd());
#else
AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
}
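
Likewise, a hypothetical usage sketch of the per-tensor variant (not part of this diff; myRank, nranks, comm, and stream are assumed to be in scope). Entry r of each list is exchanged with rank r, and shapes only need to agree pairwise between sender and receiver:

// Hypothetical usage sketch: rank i sends r + 1 floats to each rank r,
// so every rank receives myRank + 1 floats from each peer.
std::vector<at::Tensor> inputs, outputs;
for (int r = 0; r < nranks; r++) {
  auto opts = at::TensorOptions().device(at::kCUDA).dtype(at::kFloat);
  inputs.push_back(at::ones({r + 1}, opts));
  outputs.push_back(at::empty({myRank + 1}, opts));
}
torch::cuda::nccl::all2all(outputs, inputs, comm, stream);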

void send(
const at::Tensor& input,
ncclComm_t comm,
20 changes: 19 additions & 1 deletion torch/csrc/cuda/nccl.h
@@ -136,13 +136,31 @@ TORCH_CUDA_API void all_gather(
const stream_list& streams = {},
const comm_list& user_comms = {});

TORCH_CUDA_API void all2all(
TORCH_CUDA_API void all2all_single_equal_split(
at::Tensor& input,
at::Tensor& output,
int size,
ncclComm_t comm,
at::cuda::CUDAStream& stream);

TORCH_CUDA_API void all2all_single_unequal_split(
void* sendbuff,
const size_t* sendcounts,
const size_t* senddispls,
void* recvbuff,
const size_t* recvcounts,
const size_t* recvdispls,
size_t size,
c10::ScalarType type,
ncclComm_t comm,
at::cuda::CUDAStream& stream);

TORCH_CUDA_API void all2all(
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
ncclComm_t _comm,
at::cuda::CUDAStream& stream);

TORCH_CUDA_API void send(
const at::Tensor& input,
ncclComm_t comm,
86 changes: 39 additions & 47 deletions torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -166,49 +166,6 @@ std::string getNcclAbortedCommStoreKey(const std::string ncclIdStr) {
return std::string(kNCCLAbortedCommStoreKey) + ":" + ncclIdStr;
}

#ifdef ENABLE_NCCL_P2P_SUPPORT

ncclResult_t ncclAlltoallv(
void* sendbuff,
const size_t* sendcounts,
const size_t* senddispls,
void* recvbuff,
const size_t* recvcounts,
const size_t* recvdispls,
size_t size,
ncclDataType_t type,
ncclComm_t comm,
cudaStream_t stream) {
int numranks;
C10D_NCCL_CHECK(ncclCommCount(comm, &numranks));
C10D_NCCL_CHECK(ncclGroupStart());
for (int r = 0; r < numranks; r++) {
// NCCL uses 0 byte message for synchronization
// Avoid send/recv when message size is zero
if (sendcounts[r] != 0) {
C10D_NCCL_CHECK(ncclSend(
((char*)sendbuff) + senddispls[r] * size,
sendcounts[r],
type,
r,
comm,
stream));
}
if (recvcounts[r] != 0) {
C10D_NCCL_CHECK(ncclRecv(
((char*)recvbuff) + recvdispls[r] * size,
recvcounts[r],
type,
r,
comm,
stream));
}
}
C10D_NCCL_CHECK(ncclGroupEnd());
return ncclSuccess;
}
#endif

} // namespace

const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 10000;
@@ -1470,7 +1427,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall_base(
// See [Sync Streams].
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
torch::cuda::nccl::all2all(
torch::cuda::nccl::all2all_single_equal_split(
input,
output,
this->getSize(),
@@ -1503,23 +1460,50 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall_base(
// See [Sync Streams].
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
return ncclAlltoallv(
torch::cuda::nccl::all2all_single_unequal_split(
input.data_ptr(),
send_lengths.data(),
send_offsets.data(),
output.data_ptr(),
recv_lengths.data(),
recv_offsets.data(),
input.element_size(),
getNcclDataType(input.scalar_type()),
input.scalar_type(),
comm,
stream.stream());
stream);
return ncclSuccess;
},
OpType::ALLTOALL_BASE,
"nccl:all_to_all");
}
}
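
The hunk above consumes send_lengths/send_offsets (and their recv counterparts) without showing how they are built; below is a hypothetical sketch of deriving per-rank element counts and offsets from dim-0 split sizes (the actual helper used by alltoall_base is outside this hunk, and inputSplitSizes is an assumed parameter name):

// Hypothetical sketch: split sizes count rows along dim 0, so each rank's
// element count is its row count times the number of elements per row.
std::vector<size_t> send_lengths(getSize()), send_offsets(getSize());
const size_t row_elems = input.numel() / input.size(0);
size_t offset = 0;
for (int r = 0; r < getSize(); r++) {
  send_lengths[r] = inputSplitSizes[r] * row_elems;
  send_offsets[r] = offset;
  offset += send_lengths[r];
}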

std::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall(
Collaborator:
This is causing the current failure, so c10::intrusive_ptr<> should be needed here.

std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllToAllOptions& /* unused */) {
auto device = outputTensors[0].device();
for (size_t r = 0; r < outputTensors.size(); r++) {
check_gpu_single_tensor(outputTensors[r]);
check_gpu_single_tensor(inputTensors[r]);
TORCH_CHECK(device == outputTensors[r].device() && device == inputTensors[r].device(),
"Tensors must be on the same device")
}
std::vector<at::Tensor> inputTensor0 = {inputTensors[0]};
std::vector<at::Tensor> outputTensor0 = {outputTensors[0]};
return collective(
inputTensor0,
outputTensor0,
[&](at::Tensor& /* unused */,
Contributor (@cdzhan, Jul 14, 2022):
@zasdfgbnm @mrshenli Hello, I'm a bit confused: why don't the outputTensors need to be recorded with the ncclStream to prevent them from being freed before the collective finishes?

Collaborator Author:
This looks like a bug... Thanks for catching it!

Contributor:
You’re welcome :)

at::Tensor& /* unused */,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
torch::cuda::nccl::all2all(outputTensors, inputTensors, comm, stream);
return ncclSuccess;
},
OpType::ALLTOALL);
}
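
On the recordStream question raised in the thread above, a minimal sketch of one possible fix (an assumption, not code from this commit): record every input and output tensor with the NCCL stream inside the collective lambda, so the caching allocator keeps their blocks alive until the kernel finishes.

// Minimal sketch, assuming it runs inside the lambda where `stream` is the
// NCCL stream for this device; not the code shipped in this commit.
for (const auto& t : inputTensors) {
  c10::cuda::CUDACachingAllocator::recordStream(t.storage().data_ptr(), stream);
}
for (const auto& t : outputTensors) {
  c10::cuda::CUDACachingAllocator::recordStream(t.storage().data_ptr(), stream);
}
torch::cuda::nccl::all2all(outputTensors, inputTensors, comm, stream);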

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::send(
std::vector<at::Tensor>& tensors,
int dstRank,
@@ -1568,6 +1552,14 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall_base(
"ProcessGroupNCCL only supports alltoall* for NCCL lib version >= 2.7.0");
}

std::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall(
Collaborator:
Same as above: c10::intrusive_ptr<>

std::vector<at::Tensor>& /* unused */,
std::vector<at::Tensor>& /* unused */,
const AllToAllOptions& /* unused */) {
throw std::runtime_error(
"ProcessGroupNCCL only supports alltoall* for NCCL lib version >= 2.7.0");
}

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::send(
std::vector<at::Tensor>& /* unused */,
int /* unused */,