Don't store device indices separately on FutureNCCL (#48501)
Summary:
Pull Request resolved: #48501

This commit is part of a stack that reworks FutureNCCL in order to extract a generic CUDA-aware Future subclass. The stack deliberately breaks up this transition into elementary changes, to make it easier to verify that the behavior is preserved (or to highlight how it gets changed).

 ---

FutureNCCL stores a set of devices (the ones on which the tensors in its data reside) and a CUDA event for each of those devices. However, each event instance already records the device it belongs to, so we can avoid storing that information separately and eliminate the risk of the two copies becoming mismatched or inaccurate.
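As an illustration, the following minimal sketch (hypothetical BeforeSketch/AfterSketch types, not the actual FutureNCCL class) contrasts the two layouts, relying only on the CUDAEvent accessors that also appear in the diff below (device_index(), block()) and on at::cuda::getCurrentCUDAStream:

#include <memory>
#include <vector>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>

// Before: a device-index vector that must be kept in sync with the events.
struct BeforeSketch {
  std::vector<c10::DeviceIndex> deviceIndices_; // redundant copy of the devices
  std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents_;
};

// After: only the events; each event already records its own device.
struct AfterSketch {
  std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents_;

  // Make the current stream on each event's device wait for that event,
  // without consulting a separate device-index list.
  void blockCurrentStreams() {
    for (at::cuda::CUDAEvent& event : *cudaEvents_) {
      event.block(at::cuda::getCurrentCUDAStream(event.device_index()));
    }
  }
};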
ghstack-source-id: 118180024

Test Plan: Unit tests

Reviewed By: mrshenli

Differential Revision: D25177554

fbshipit-source-id: 64667c176efc2a7dafe99457a1fbba5d142cb06c
lw authored and facebook-github-bot committed Dec 10, 2020
1 parent e294c2d commit 9fe3ac3
Showing 2 changed files with 18 additions and 41 deletions.
11 changes: 1 addition & 10 deletions torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -1008,16 +1008,7 @@ std::vector<at::Tensor> ProcessGroupNCCL::WorkNCCL::result() {
 
 c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupNCCL::WorkNCCL::
     getFuture() {
-  std::vector<c10::DeviceIndex> deviceIndices;
-  for (const c10::Device& device : devices_) {
-    TORCH_INTERNAL_ASSERT(device.is_cuda());
-    deviceIndices.push_back(device.index());
-  }
-
-  return c10::make_intrusive<FutureNCCL>(
-      at::IValue(*outputs_),
-      std::move(deviceIndices),
-      cudaEvents_);
+  return c10::make_intrusive<FutureNCCL>(at::IValue(*outputs_), cudaEvents_);
 }
 
 void ProcessGroupNCCL::workEnqueue(
48 changes: 17 additions & 31 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
@@ -211,40 +211,28 @@ class ProcessGroupNCCL : public ProcessGroup {
  public:
   explicit FutureNCCL(
       at::IValue value,
-      std::vector<c10::DeviceIndex> deviceIndices,
       std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents)
       : at::ivalue::Future(c10::ListType::create(c10::TensorType::get())),
         value_(std::move(value)),
-        deviceIndices_(std::move(deviceIndices)),
         cudaEvents_(std::move(cudaEvents)) {
     // Check that the device indices are distinct
     std::unordered_set<c10::DeviceIndex> uniqueDeviceIndices;
-    for (const auto& deviceIndex : deviceIndices_) {
-      uniqueDeviceIndices.insert(deviceIndex);
-    }
-    TORCH_INTERNAL_ASSERT(
-        deviceIndices_.size() == uniqueDeviceIndices.size(),
-        "Got ", deviceIndices_.size(), " devices, but only ",
-        uniqueDeviceIndices.size(), " distinct ones");
-    TORCH_INTERNAL_ASSERT(
-        cudaEvents_->size() == deviceIndices_.size(),
-        "The device indices and the events must be paired up. Got ",
-        deviceIndices_.size(), " devices and ", cudaEvents_->size(),
-        " events.");
     for (const at::cuda::CUDAEvent& event : *cudaEvents_) {
       TORCH_INTERNAL_ASSERT(event.isCreated());
-      TORCH_INTERNAL_ASSERT(
-          std::find(
-              deviceIndices_.begin(),
-              deviceIndices_.end(),
-              event.device_index()) != deviceIndices_.end());
+      uniqueDeviceIndices.insert(event.device_index());
     }
+    TORCH_INTERNAL_ASSERT(
+        cudaEvents_->size() == uniqueDeviceIndices.size(),
+        "Got ", cudaEvents_->size(), " events, but only ",
+        uniqueDeviceIndices.size(), " distinct devices");
     for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
       TORCH_INTERNAL_ASSERT(
-          std::find(
-              deviceIndices_.begin(),
-              deviceIndices_.end(),
-              data_ptr.device().index()) != deviceIndices_.end());
+          std::find_if(
+              cudaEvents_->begin(),
+              cudaEvents_->end(),
+              [&](const at::cuda::CUDAEvent& ev) {
+                return ev.device_index() == data_ptr.device().index();
+              }) != cudaEvents_->end());
     }
   }

@@ -266,9 +254,9 @@ class ProcessGroupNCCL : public ProcessGroup {
       throw *error_;
     }
 
-    for (int i = 0; i < deviceIndices_.size(); i++) {
-      (*cudaEvents_)[i].block(
-          at::cuda::getCurrentCUDAStream(deviceIndices_[i]));
+    for (at::cuda::CUDAEvent& cudaEvent : *cudaEvents_) {
+      cudaEvent.block(
+          at::cuda::getCurrentCUDAStream(cudaEvent.device_index()));
     }
 
     for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
@@ -303,7 +291,6 @@ class ProcessGroupNCCL : public ProcessGroup {
       if (isCudaDeviceUsed[idx]) {
         at::cuda::CUDAEvent cudaEvent;
         cudaEvent.record(at::cuda::getCurrentCUDAStream(idx));
-        deviceIndices_.push_back(idx);
         (*cudaEvents_).push_back(std::move(cudaEvent));
       }
     }
@@ -336,13 +323,13 @@ class ProcessGroupNCCL : public ProcessGroup {
     // misbehaving this also ends up using memory on those devices, which the
     // user might not want.
     std::vector<at::cuda::CUDAStream> streams;
-    for (int i = 0; i < deviceIndices_.size(); i++) {
-      c10::DeviceIndex idx = deviceIndices_[i];
+    for (at::cuda::CUDAEvent& cudaEvent : *cudaEvents_) {
+      c10::DeviceIndex idx = cudaEvent.device_index();
       // FIXME Should we find a way to allow to change the priority of
       // streams?
       at::cuda::CUDAStream stream =
           at::cuda::getStreamFromPool(/*isHighPriority=*/false, idx);
-      (*cudaEvents_)[i].block(stream);
+      cudaEvent.block(stream);
       streams.push_back(stream);
     }

@@ -406,7 +393,6 @@ class ProcessGroupNCCL : public ProcessGroup {
 
  private:
   at::IValue value_;
-  std::vector<c10::DeviceIndex> deviceIndices_;
   std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents_;
   DataPtrExtractor dataPtrExtractor_;
   std::mutex dataPtrExtractorMutex_;