Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split FutureNCCL's CUDA-specific parts from generic future logic #48504

Closed
wants to merge 9 commits into from
112 changes: 65 additions & 47 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
Expand Up @@ -228,10 +228,7 @@ class ProcessGroupNCCL : public ProcessGroup {
throw *error_;
}

for (at::cuda::CUDAEvent& cudaEvent : *cudaEvents_) {
cudaEvent.block(
at::cuda::getCurrentCUDAStream(cudaEvent.device_index()));
}
postWaitHook();
}

// If FutureNCCL was created by FutureNCCL::then, its value would be empty
Expand All @@ -245,23 +242,7 @@ class ProcessGroupNCCL : public ProcessGroup {
"outputs or the return value of the callback.");
value_ = std::move(value);

if (cudaEvents_ == nullptr) {
std::vector<bool> isCudaDeviceUsed(c10::cuda::device_count(), false);
for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
if (data_ptr.device().is_cuda()) {
isCudaDeviceUsed[data_ptr.device().index()] = true;
}
}

cudaEvents_ = std::make_shared<std::vector<at::cuda::CUDAEvent>>();
for (c10::DeviceIndex idx = 0; idx < isCudaDeviceUsed.size(); idx++) {
if (isCudaDeviceUsed[idx]) {
at::cuda::CUDAEvent cudaEvent;
cudaEvent.record(at::cuda::getCurrentCUDAStream(idx));
(*cudaEvents_).push_back(std::move(cudaEvent));
}
}
}
postMarkCompletedHook();
}

// Just returns FutureNCCL's value after wait returns.
Expand All @@ -282,32 +263,9 @@ class ProcessGroupNCCL : public ProcessGroup {
// this callback. This new FutureNCCL's cudaEvents will record the
// callback's stream and will have the result value of the callback.
// Runs `callback` with a dedicated pool stream made current on every CUDA
// device, after ordering those streams behind the events in cudaEvents_.
//
// NOTE(review): as rendered here, this body contains both an inline
// implementation (stream setup followed by a direct `callback()` call) and a
// delegation to wrapCallback(), so `callback` would be invoked twice. This
// looks like removed and added diff lines shown together -- confirm which
// lines are current before relying on this listing.
void addCallback(std::function<void(void)> callback) override {
// Get a stream for all devices, even those that are not used by the
// value, because the user's callback could use those other devices.
std::vector<at::cuda::CUDAStream> streams;
for (c10::DeviceIndex idx = 0; idx < c10::cuda::device_count(); idx++) {
// FIXME Should we find a way to allow changing the priority of
// streams?
streams.push_back(
at::cuda::getStreamFromPool(/*isHighPriority=*/false, idx));
}

// Do not free the underlying data storage of value_ before its
// usage on the stream finishes.
for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
if (data_ptr.device().is_cuda()) {
c10::cuda::CUDACachingAllocator::recordStream(
data_ptr, streams[data_ptr.device().index()]);
}
}

// Make each callback stream wait on the event recorded for its device.
for (at::cuda::CUDAEvent& cudaEvent : *cudaEvents_) {
cudaEvent.block(streams[cudaEvent.device_index()]);
}

// Use the dedicated callback stream to run callback.
at::cuda::CUDAMultiStreamGuard streamGuard(streams);
callback();
// `callback` is moved into the wrapper only after the direct call above.
std::function<void(void)> wrappedCallback =
wrapCallback(std::move(callback));
wrappedCallback();
}

// Adds a callback to FutureNCCL, and returns another FutureNCCL to hold
Expand Down Expand Up @@ -356,6 +314,66 @@ class ProcessGroupNCCL : public ProcessGroup {
}
}

protected:
void postMarkCompletedHook() {
if (cudaEvents_ == nullptr) {
std::vector<bool> isCudaDeviceUsed(c10::cuda::device_count(), false);
for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
if (data_ptr.device().is_cuda()) {
isCudaDeviceUsed[data_ptr.device().index()] = true;
}
}

cudaEvents_ = std::make_shared<std::vector<at::cuda::CUDAEvent>>();
for (c10::DeviceIndex idx = 0; idx < isCudaDeviceUsed.size(); idx++) {
if (isCudaDeviceUsed[idx]) {
at::cuda::CUDAEvent cudaEvent;
cudaEvent.record(at::cuda::getCurrentCUDAStream(idx));
(*cudaEvents_).push_back(std::move(cudaEvent));
}
}
}
}

std::function<void(void)> wrapCallback(std::function<void(void)> callback) {
return [this, callback{std::move(callback)}]() {
// Get a stream for all devices, even those that are not used by the
// value, because the user's callback could use those other devices.
std::vector<at::cuda::CUDAStream> streams;
for (c10::DeviceIndex idx = 0; idx < c10::cuda::device_count(); idx++) {
// FIXME Should we find a way to allow to change the priority of
// streams?
streams.push_back(
at::cuda::getStreamFromPool(/*isHighPriority=*/false, idx));
}

// Do not free the underlying data storage of value_ before its
// usage on the stream finishes.
for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
if (data_ptr.device().is_cuda()) {
c10::cuda::CUDACachingAllocator::recordStream(
data_ptr, streams[data_ptr.device().index()]);
}
}

for (at::cuda::CUDAEvent& cudaEvent : *cudaEvents_) {
cudaEvent.block(streams[cudaEvent.device_index()]);
}

// Use the dedicated callback stream to run callback.
at::cuda::CUDAMultiStreamGuard streamGuard(streams);

callback();
};
}

void postWaitHook() {
for (at::cuda::CUDAEvent& cudaEvent : *cudaEvents_) {
cudaEvent.block(
at::cuda::getCurrentCUDAStream(cudaEvent.device_index()));
}
}

private:
at::IValue value_;
std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents_;
Expand Down