59 changes: 33 additions & 26 deletions torch/csrc/cuda/nccl.cpp
@@ -21,6 +21,10 @@ ncclComm_t* to_nccl_comm(torch::cuda::nccl::ncclComm_t* var) {
return reinterpret_cast<ncclComm_t*>(var);
}

+ncclComm_t to_nccl_comm(torch::cuda::nccl::ncclComm_t var) {
+return reinterpret_cast<ncclComm_t>(var);
+}

ncclUniqueId* to_nccl_unique_id(torch::cuda::nccl::ncclUniqueId* var) {
return reinterpret_cast<ncclUniqueId*>(var);
}
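Note: the hunk above adds a by-value counterpart to the existing pointer-based conversion helper, which is what lets later call sites write to_nccl_comm(comm) instead of *(to_nccl_comm(&comm)). A minimal, self-contained sketch of that pattern, using illustrative stand-in names rather than the real torch/NCCL declarations:

#include <cstdio>

struct RealCommImpl { int rank; };   // stand-in for NCCL's internal communicator
using RealComm = RealCommImpl*;      // stand-in for NCCL's opaque ncclComm_t handle
using OpaqueComm = void*;            // stand-in for torch::cuda::nccl::ncclComm_t

// Pointer overload: reinterpret storage of an opaque handle as storage of a real one.
RealComm* to_comm(OpaqueComm* var) { return reinterpret_cast<RealComm*>(var); }
// By-value overload: reinterpret the handle itself, no address-of/dereference step.
RealComm to_comm(OpaqueComm var) { return reinterpret_cast<RealComm>(var); }

int main() {
  RealCommImpl impl{7};
  OpaqueComm comm = &impl;
  std::printf("%d\n", to_comm(comm)->rank);      // new style: pass the handle directly
  std::printf("%d\n", (*to_comm(&comm))->rank);  // old style the diff moves away from
  return 0;
}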
@@ -107,16 +111,20 @@ using namespace at;

namespace detail {

+static inline void NCCL_CHECK(ncclResult_t result) {
+NCCL_CHECK(from_nccl_result(result));
+}

struct AutoNcclGroup {
AutoNcclGroup() {
(c10::cuda::CUDACachingAllocator::getFreeMutex())->lock();
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
-NCCL_CHECK(from_nccl_result(ncclGroupStart()));
+NCCL_CHECK(ncclGroupStart());
#endif
}
~AutoNcclGroup() {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
-NCCL_CHECK(from_nccl_result(ncclGroupEnd()));
+NCCL_CHECK(ncclGroupEnd());
#endif
(c10::cuda::CUDACachingAllocator::getFreeMutex())->unlock();
}
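The new static inline NCCL_CHECK(ncclResult_t) overload above converts the raw NCCL return code via from_nccl_result and forwards it to the pre-existing check on torch's mirrored result type, which is what allows every call site below to drop the explicit from_nccl_result(...) wrapping. A self-contained sketch of that overload-forwarding pattern, with illustrative names standing in for the real ncclResult_t/from_nccl_result/NCCL_CHECK machinery:

#include <stdexcept>

enum class LibResult { Success, Failure };      // stand-in for the raw ncclResult_t
enum class WrappedResult { Success, Failure };  // stand-in for torch's mirrored enum

// Pre-existing conversion from the raw library code to the wrapped enum.
WrappedResult from_lib_result(LibResult r) {
  return r == LibResult::Success ? WrappedResult::Success : WrappedResult::Failure;
}

// Pre-existing checker on the wrapped enum (error handling is illustrative only).
void LIB_CHECK(WrappedResult r) {
  if (r != WrappedResult::Success) throw std::runtime_error("library call failed");
}

// The new thin overload: accept the raw code, convert, forward.
static inline void LIB_CHECK(LibResult r) {
  LIB_CHECK(from_lib_result(r));
}

LibResult lib_do_work() { return LibResult::Success; }  // stand-in library call

int main() {
  LIB_CHECK(lib_do_work());  // call sites stay terse: no explicit conversion needed
  return 0;
}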
@@ -133,8 +141,8 @@ struct NcclCommList {
int ndevices;
NcclCommList(const std::vector<int>& devices)
: comms(new ncclComm_t[devices.size()]), ndevices(devices.size()) {
-NCCL_CHECK(from_nccl_result(
-ncclCommInitAll(to_nccl_comm(comms.get()), devices.size(), devices.data())));
+NCCL_CHECK(
+ncclCommInitAll(to_nccl_comm(comms.get()), devices.size(), devices.data()));
}
NcclCommList(NcclCommList&& foo) = default;
~NcclCommList() {
@@ -326,7 +334,7 @@ void get_unique_id(ncclUniqueId& id)
{
#ifdef USE_NCCL
using namespace torch::cuda::nccl::detail;
-NCCL_CHECK(from_nccl_result(ncclGetUniqueId(to_nccl_unique_id(&id))));
+NCCL_CHECK(ncclGetUniqueId(to_nccl_unique_id(&id)));
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
@@ -337,11 +345,11 @@ ncclComm_t comm_init_rank(int nranks, const ncclUniqueId& comm_id, int rank) {
using namespace torch::cuda::nccl::detail;
ncclComm_t comm;
ncclUniqueId id = comm_id;
-NCCL_CHECK(from_nccl_result(ncclCommInitRank(
+NCCL_CHECK(ncclCommInitRank(
to_nccl_comm(&comm),
nranks,
*(to_nccl_unique_id(&id)),
-rank)));
+rank));
return comm;
#else
return nullptr;
@@ -362,8 +370,7 @@ void comm_destroy(ncclComm_t comm)

#ifdef USE_NCCL
using namespace torch::cuda::nccl::detail;
-NCCL_CHECK(from_nccl_result(ncclCommDestroy(
-*(to_nccl_comm(&comm)))));
+NCCL_CHECK(ncclCommDestroy(to_nccl_comm(comm)));
#endif
}

@@ -420,8 +427,8 @@ void broadcast(
count_max,
")");
ncclComm_t comm = comms[i];
-NCCL_CHECK(from_nccl_result(ncclBcast(
-tensors[i].data_ptr(), numel, data_type, 0, *(to_nccl_comm(&comm)), stream)));
+NCCL_CHECK(ncclBcast(
+tensors[i].data_ptr(), numel, data_type, 0, to_nccl_comm(comm), stream));
}
#else
AT_ERROR("PyTorch built without NCCL support");
@@ -460,15 +467,15 @@ void reduce(
: streams[i]->stream();

ncclComm_t comm = comms_ref[i];
-NCCL_CHECK(from_nccl_result(ncclReduce(
+NCCL_CHECK(ncclReduce(
inputs[i].data_ptr(),
root == i ? output.data_ptr() : nullptr,
count,
data_type,
to_nccl_red_op(op),
root,
-*(to_nccl_comm(&comm)),
-stream)));
+to_nccl_comm(comm),
+stream));
}
#else
AT_ERROR("PyTorch built without NCCL support");
@@ -512,14 +519,14 @@ void all_reduce(
: streams[i]->stream();

ncclComm_t comm = comms_ref[i];
-NCCL_CHECK(from_nccl_result(ncclAllReduce(
+NCCL_CHECK(ncclAllReduce(
inputs[i].data_ptr(),
outputs[i].data_ptr(),
count,
data_type,
to_nccl_red_op(op),
-*(to_nccl_comm(&comm)),
-stream)));
+to_nccl_comm(comm),
+stream));
}
#else
AT_ERROR("PyTorch built without NCCL support");
@@ -554,14 +561,14 @@ void reduce_scatter(
: streams[i]->stream();

ncclComm_t comm = comms_ref[i];
-NCCL_CHECK(from_nccl_result(ncclReduceScatter(
+NCCL_CHECK(ncclReduceScatter(
inputs[i].data_ptr(),
outputs[i].data_ptr(),
count,
data_type,
to_nccl_red_op(op),
-*(to_nccl_comm(&comm)),
-stream)));
+to_nccl_comm(comm),
+stream));
}
#else
AT_ERROR("PyTorch built without NCCL support");
@@ -596,21 +603,21 @@ void all_gather(

ncclComm_t comm = comms_ref[i];
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
-NCCL_CHECK(from_nccl_result(ncclAllGather(
+NCCL_CHECK(ncclAllGather(
inputs[i].data_ptr(),
outputs[i].data_ptr(),
count,
data_type,
-*(to_nccl_comm(&comm)),
-stream)));
+to_nccl_comm(comm),
+stream));
#else
-NCCL_CHECK(from_nccl_result(ncclAllGather(
+NCCL_CHECK(ncclAllGather(
inputs[i].data_ptr(),
count,
data_type,
outputs[i].data_ptr(),
-*(to_nccl_comm(&comm)),
-stream)));
+to_nccl_comm(comm),
+stream));
#endif
}
#else