[ROCm] use ncclAllToAll for rocm (#75128)
Summary:
Use ncclAllToAll for ROCm versions >= 5.0.

Details on ncclAllToAll: ROCm/rccl#503
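
For context, a minimal C++ sketch of the dispatch this commit introduces: on ROCm >= 5.0 a single ncclAllToAll call is used, otherwise an equal-split all-to-all is emulated with grouped ncclSend/ncclRecv pairs (the pattern used by the existing non-ROCm path). The helper name, the NCCL_CHECK definition, and the bytes_per_rank parameter are illustrative assumptions, not the actual PyTorch internals.

#include <cstdio>
#include <cstdlib>
#include <nccl.h>  // on ROCm builds, RCCL provides the same header/API

// Illustrative error-checking macro (PyTorch has its own NCCL_CHECK).
#define NCCL_CHECK(cmd)                                                   \
  do {                                                                    \
    ncclResult_t res = (cmd);                                             \
    if (res != ncclSuccess) {                                             \
      std::fprintf(stderr, "NCCL error: %s\n", ncclGetErrorString(res));  \
      std::abort();                                                       \
    }                                                                     \
  } while (0)

// Hypothetical helper: equal-split all-to-all across all ranks of `comm`.
static void all2all_equal_split_sketch(
    const char* sendbuff,   // contiguous send buffer, one equal chunk per rank
    char* recvbuff,         // contiguous receive buffer, same layout
    size_t count,           // number of elements sent to each rank
    size_t bytes_per_rank,  // byte size of one per-rank chunk
    ncclDataType_t type,
    ncclComm_t comm,
    cudaStream_t stream) {  // hipStream_t after hipification on ROCm
#if defined(USE_ROCM) && ROCM_VERSION >= 50000
  // RCCL ships ncclAllToAll as a native collective (ROCm/rccl#503),
  // so one call replaces the send/recv loop below.
  NCCL_CHECK(ncclAllToAll(sendbuff, recvbuff, count, type, comm, stream));
#else
  // Emulate all-to-all with grouped point-to-point calls: send chunk r to
  // rank r and receive rank r's chunk into slot r of the output buffer.
  int numranks = 0;
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclGroupStart());
  for (int r = 0; r < numranks; ++r) {
    NCCL_CHECK(ncclSend(sendbuff + r * bytes_per_rank, count, type, r, comm, stream));
    NCCL_CHECK(ncclRecv(recvbuff + r * bytes_per_rank, count, type, r, comm, stream));
  }
  NCCL_CHECK(ncclGroupEnd());
#endif
}

The ncclSend/ncclRecv fallback requires NCCL >= 2.7, which is why the surrounding file guards the whole path behind that version check (the AT_ERROR branch at the end of the diff below).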

cc: jithunnair-amd, amathews-amd

Pull Request resolved: #75128
Approved by: https://github.com/wenkaidu, https://github.com/yzygitzh, https://github.com/seemethere

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/97fbe6f0a4d136d34fae851ffee823cb5e73bf71

Reviewed By: seemethere, osalpekar

Differential Revision: D35874469

fbshipit-source-id: 653579837396cdc55bf3a7c2be0e893e16990c9a
KyleCZH authored and facebook-github-bot committed Apr 26, 2022
1 parent eca929d commit ebe160f
Showing 1 changed file with 4 additions and 0 deletions.
torch/csrc/cuda/nccl.cpp
@@ -650,6 +650,9 @@ void all2all_single_equal_split(at::Tensor& input,
   const auto* sendbuff = reinterpret_cast<char*>(input.data_ptr());
   auto* recvbuff = reinterpret_cast<char *>(output.data_ptr());
   auto comm = to_nccl_comm(_comm);
+#if defined(USE_ROCM) && ROCM_VERSION >= 50000
+  NCCL_CHECK(ncclAllToAll(sendbuff , recvbuff , count, type, comm, stream));
+#else
   NCCL_CHECK(ncclCommCount(comm, &numranks));
   NCCL_CHECK(ncclGroupStart());
   for(const auto r : c10::irange(numranks)) {
@@ -661,6 +664,7 @@ void all2all_single_equal_split(at::Tensor& input,
     }
   }
   NCCL_CHECK(ncclGroupEnd());
+#endif
 #else
   AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
 #endif
