[NCCL] create NCCL communicator for send/recv on demand
Pull Request resolved: #44922

For NCCL send/recv operations, we create the NCCL communicator on demand, following the same design currently used for collective operations.
ghstack-source-id: 113592757

Differential Revision: [D23773726](https://our.internmc.facebook.com/intern/diff/D23773726/)
mingzhe0908 committed Oct 5, 2020
1 parent a7ae19a commit da7a42a
Showing 2 changed files with 70 additions and 20 deletions.
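For orientation before the diff: a minimal standalone sketch of the rank-pair bookkeeping this change introduces. getKeySendRecv mirrors the helper added to ProcessGroupNCCL.cpp below; the main driver and its concrete ranks (3 and 7) are illustrative only.

#include <iostream>
#include <string>

// Mirrors the helper added in ProcessGroupNCCL.cpp: the cache key for a
// point-to-point communicator is "<lowRank>:<highRank>".
std::string getKeySendRecv(int myRank, int peer) {
  int lowRank = myRank < peer ? myRank : peer;
  int highRank = myRank < peer ? peer : myRank;
  return std::to_string(lowRank) + ":" + std::to_string(highRank);
}

int main() {
  // Hypothetical pair of global ranks doing a send/recv.
  int myRank = 3, peer = 7;
  std::string key = getKeySendRecv(myRank, peer);  // "3:7" on both processes
  int p2pRank = myRank < peer ? 0 : 1;  // NCCL rank within the 2-rank comm
  int p2pTargetRank = 1 - p2pRank;      // the other rank of the pair
  std::cout << key << " " << p2pRank << " " << p2pTargetRank << std::endl;
  // Rank 3 sees: 3:7 0 1  (p2pRank 0 is the one that generates the NCCL id)
  // Rank 7 sees: 3:7 1 0
  return 0;
}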
64 changes: 47 additions & 17 deletions torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -107,6 +107,13 @@ std::string getKeyFromDevices(const std::vector<at::Device>& devices) {
return deviceList;
}

std::string getKeySendRecv(int myRank, int peer) {
int lowRank = myRank < peer ? myRank : peer;
int highRank = myRank < peer ? peer : myRank;
std::string sendRecvPair = std::to_string(lowRank) + ":" + std::to_string(highRank);
return sendRecvPair;
}

// Get the list of devices from list of tensors
std::vector<at::Device> getDeviceList(const std::vector<at::Tensor>& tensors) {
std::vector<at::Device> res;
@@ -716,7 +723,9 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID(ncclUniqueId* ncclID) {

std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
const std::string& devicesKey,
const std::vector<at::Device>& devices) {
const std::vector<at::Device>& devices,
NCCLCommType commType,
int p2pRank) {
// Sanity check
if (devicesKey.empty()) {
throw std::runtime_error(
@@ -743,7 +752,8 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
// Create the unique NCCL ID and broadcast it
ncclUniqueId ncclID;

if (rank_ == 0) {
// For point-to-point communication, the lower rank of the two will get the unique id.
if (rank_ == 0 || (commType != NCCLCommType::COLL && p2pRank == 0)) {
C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID));
}

@@ -779,8 +789,17 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(

for (size_t i = 0; i < devices.size(); ++i) {
// GPU world size and GPU rank
int numRanks = getSize() * devices.size();
int rank = getRank() * devices.size() + i;
int numRanks, rank;

if (commType == NCCLCommType::COLL) {
numRanks = getSize() * devices.size();
rank = getRank() * devices.size() + i;
} else {
// For point-to-point operations, there are only 2 processes involved, so
// the GPU rank is either 0 or 1.
numRanks = 2;
rank = p2pRank;
}
// Get the device index
int deviceIndex = devices[i].index();

@@ -1038,20 +1057,22 @@ template <typename Fn, typename PreProcess, typename PostProcess>
std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::pointToPoint(
std::vector<at::Tensor>& tensors,
Fn fn,
bool isRecv,
int peer,
NCCLCommType commType,
PreProcess pre,
PostProcess post) {
const auto devices = getDeviceList(tensors);
const auto key = getKeyFromDevices(devices);
auto& ncclComms = getNCCLComm(key, devices);
const auto key = getKeySendRecv(rank_, peer);
int p2pRank = rank_ < peer ? 0 : 1;
auto& ncclComms = getNCCLComm(key, devices, commType, p2pRank);

// First let NCCL streams wait for input tensors allocation streams
syncStreams(devices, ncclEvents_[key], ncclStreams_[key]);

// Work itself will create the CUDA events on all GPUs of tensors
auto work = initWork(devices);

if (isRecv) {
if (commType == NCCLCommType::RECV) {
// Store references to outputs and futureNCCLCallbackStream to be used by
// WorkNCCL::getFuture.
work->outputs_ = std::make_shared<std::vector<at::Tensor>>(tensors);
@@ -1080,8 +1101,11 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::pointToPoint(
for (size_t i = 0; i < tensors.size(); ++i) {
gpuGuard.set_index(devices[i].index());
at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i];
// For point-to-point communication, NCCL ranks can only
// be 0 or 1.
int p2pTargetRank = 1 - p2pRank;
C10D_NCCL_CHECK(
fn(tensors[i], ncclComms[i]->getNcclComm(), ncclStream));
fn(tensors[i], ncclComms[i]->getNcclComm(), ncclStream, p2pTargetRank));
}
}

@@ -1117,11 +1141,13 @@ template <typename Fn>
std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::pointToPoint(
std::vector<at::Tensor>& tensor,
Fn fn,
bool isRecv) {
int peer,
NCCLCommType type) {
return pointToPoint(
tensor,
fn,
isRecv,
peer,
type,
[](std::vector<at::cuda::CUDAStream>&) {},
[](std::vector<at::cuda::CUDAStream>&) {});
}
@@ -1411,16 +1437,18 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::send(
tensors,
[&](at::Tensor& input,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
at::cuda::CUDAStream& stream,
int dst) {
return ncclSend(
input.data_ptr(),
input.numel(),
getNcclDataType(input.scalar_type()),
dstRank,
dst,
comm,
stream.stream());
},
/* isRecv */ false);
dstRank,
NCCLCommType::SEND);
return ret;
}

@@ -1433,16 +1461,18 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::recv(
tensors,
[&](at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
at::cuda::CUDAStream& stream,
int src) {
return ncclRecv(
output.data_ptr(),
output.numel(),
getNcclDataType(output.scalar_type()),
srcRank,
src,
comm,
stream.stream());
},
/* isRecv */ true);
srcRank,
NCCLCommType::RECV);
return ret;
}
#else
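Putting the .cpp changes together: on the first send/recv between a pair of ranks, getNCCLComm finds no cached entry for the pair key, lets p2pRank 0 generate the NCCL unique id, exchanges it with the peer, and initializes a two-rank communicator. A simplified sketch of that path, assuming one GPU per process; exchangeId stands in for the store-based broadcastUniqueNCCLID, the flat cache stands in for devNCCLCommMap_, and error checking (C10D_NCCL_CHECK in the real code) is omitted:

#include <nccl.h>
#include <string>
#include <unordered_map>

// Stand-in for devNCCLCommMap_: key "<low>:<high>" -> cached communicator.
static std::unordered_map<std::string, ncclComm_t> p2pComms;

ncclComm_t getP2PComm(const std::string& key,
                      int p2pRank,
                      void (*exchangeId)(ncclUniqueId*)) {
  auto it = p2pComms.find(key);
  if (it != p2pComms.end()) {
    return it->second;  // communicator already created by an earlier call
  }
  ncclUniqueId ncclID;
  if (p2pRank == 0) {
    // The lower rank of the pair creates the unique id...
    ncclGetUniqueId(&ncclID);
  }
  // ...and both ranks agree on it (the real code broadcasts via the store).
  exchangeId(&ncclID);
  // Exactly 2 ranks participate in a point-to-point communicator.
  ncclComm_t comm;
  ncclCommInitRank(&comm, /*nranks=*/2, ncclID, p2pRank);
  p2pComms.emplace(key, comm);
  return comm;
}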
26 changes: 23 additions & 3 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
@@ -23,6 +23,13 @@ constexpr const char* NCCL_BLOCKING_WAIT = "NCCL_BLOCKING_WAIT";
// Handling with NCCL.
constexpr const char* NCCL_ASYNC_ERROR_HANDLING = "NCCL_ASYNC_ERROR_HANDLING";

// NCCL communication type
enum class NCCLCommType : std::uint8_t {
SEND = 0,
RECV,
COLL,
};

// ProcessGroupNCCL implements NCCL bindings for c10d.
//
// All functions of the class are expected to be called in the same order
@@ -459,7 +466,9 @@ class ProcessGroupNCCL : public ProcessGroup {
// a new set of NCCL communicators as a cache entry
std::vector<std::shared_ptr<NCCLComm>>& getNCCLComm(
const std::string& devicesKey,
const std::vector<at::Device>& devices);
const std::vector<at::Device>& devices,
NCCLCommType commType = NCCLCommType::COLL,
int p2pRank = 0);

// Wrapper method which can be overridden for tests.
virtual std::exception_ptr checkForNCCLErrors(
@@ -495,12 +504,14 @@ class ProcessGroupNCCL : public ProcessGroup {
std::shared_ptr<ProcessGroup::Work> pointToPoint(
std::vector<at::Tensor>& tensor,
Fn fn,
bool isRecv);
int peer,
NCCLCommType commType);
template <typename Fn, typename PreProcess, typename PostProcess>
std::shared_ptr<ProcessGroup::Work> pointToPoint(
std::vector<at::Tensor>& tensor,
Fn fn,
bool isRecv,
int peer,
NCCLCommType commType,
PreProcess pre,
PostProcess post);

@@ -545,6 +556,8 @@ class ProcessGroupNCCL : public ProcessGroup {
uint64_t ncclCommCounter_{0};

// The NCCL communicator that the process group has cached.
//
// For collective operations:
// The key is a list of GPU devices that an operation is operating on
// The GPU devices are stored in a device sequence and the cached NCCL
// communicator is associated with this GPU device sequence
@@ -563,6 +576,13 @@
// "0,4,5,6,7,1,2,3"
//
// Note that the order of the device for the tensor list matters.
//
// For point-to-point operations:
// The key is a string of the two process ranks, lower rank first.
// e.g. If process 1 and process 2 are involved in a point-to-point
// communication, the key will be "1:2" on both processes.
// Note: this is for the scenario where there is only 1 GPU per process.
// When it comes to multiple GPUs per process, this part may need to be
// redesigned.
std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
devNCCLCommMap_;

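From the caller's perspective, nothing changes except that the first point-to-point call between two ranks now pays the communicator-setup cost, and later calls between the same pair reuse the cached communicator. A hedged usage sketch; the (tensors, peer, tag) shape of the c10d send/recv API and the pre-constructed process group are assumptions, not part of this diff:

#include <c10d/ProcessGroupNCCL.hpp>

// Hypothetical driver: assumes `pg` was built with world size 2 and that
// each process's tensors live on its own GPU.
void pingTwice(c10d::ProcessGroupNCCL& pg, int rank) {
  std::vector<at::Tensor> tensors = {at::ones({1024}, at::kCUDA)};
  if (rank == 0) {
    // First send to rank 1 creates and caches the "0:1" communicator.
    pg.send(tensors, /*dstRank=*/1, /*tag=*/0)->wait();
    // The second send reuses the cached communicator.
    pg.send(tensors, /*dstRank=*/1, /*tag=*/0)->wait();
  } else if (rank == 1) {
    pg.recv(tensors, /*srcRank=*/0, /*tag=*/0)->wait();
    pg.recv(tensors, /*srcRank=*/0, /*tag=*/0)->wait();
  }
}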
