New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Split out reusable CUDAFuture from FutureNCCL #48506
Changes from all commits
90a4f88
ea34f6c
3a1fde3
97516a3
05a4b5b
67eed12
956aef9
74ecf01
0c93853
d432680
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
#pragma once | ||
|
||
#include <functional> | ||
#include <memory> | ||
#include <mutex> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include <ATen/core/ivalue.h> | ||
#include <ATen/core/ivalue_inl.h> | ||
#include <ATen/core/jit_type.h> | ||
#include <ATen/cuda/CUDAEvent.h> | ||
#include <ATen/cuda/CUDAMultiStreamGuard.h> | ||
#include <c10/core/Allocator.h> | ||
#include <c10/core/Device.h> | ||
#include <c10/cuda/CUDACachingAllocator.h> | ||
#include <c10/cuda/CUDAFunctions.h> | ||
#include <c10/cuda/CUDAStream.h> | ||
#include <c10/macros/Export.h> | ||
#include <c10/util/intrusive_ptr.h> | ||
|
||
namespace at { namespace cuda { | ||
|
||
struct TORCH_CUDA_API CUDAFuture : at::ivalue::Future { | ||
public: | ||
using at::ivalue::Future::Future; | ||
|
||
void setDataPtrExtractor(DataPtrExtractor dataPtrExtractor) override { | ||
std::unique_lock<std::mutex> lock(dataPtrExtractorMutex_); | ||
dataPtrExtractor_ = std::move(dataPtrExtractor); | ||
} | ||
|
||
protected: | ||
c10::intrusive_ptr<Future> createInstance(at::TypePtr type) override { | ||
auto fut = c10::make_intrusive<CUDAFuture>(std::move(type)); | ||
// The new future needs the DataPtr extractor when it gets marked complete | ||
// but this might happen immediately inline or in parallel by another | ||
// thread. In both these cases this would/might happen before the user has | ||
// time to set their own DataPtr extractor, which might lead to failures | ||
// if the default extractor can't handle some of the user's types. | ||
// Therefore we propagate our extractor. | ||
fut->setDataPtrExtractor(dataPtrExtractor_); | ||
return fut; | ||
} | ||
|
||
void postMarkCompletedHook(const at::IValue& value) override { | ||
std::vector<bool> isCudaDeviceUsed(c10::cuda::device_count(), false); | ||
for (const at::DataPtr& data_ptr : extractDataPtrs(value)) { | ||
if (data_ptr.device().is_cuda()) { | ||
isCudaDeviceUsed[data_ptr.device().index()] = true; | ||
} | ||
} | ||
|
||
cudaEvents_ = std::make_shared<std::vector<at::cuda::CUDAEvent>>(); | ||
for (c10::DeviceIndex idx = 0; idx < isCudaDeviceUsed.size(); idx++) { | ||
if (isCudaDeviceUsed[idx]) { | ||
at::cuda::CUDAEvent cudaEvent; | ||
cudaEvent.record(at::cuda::getCurrentCUDAStream(idx)); | ||
(*cudaEvents_).push_back(std::move(cudaEvent)); | ||
} | ||
} | ||
} | ||
|
||
std::function<void(void)> wrapCallback( | ||
std::function<void(void)> callback) override { | ||
return [this, callback{std::move(callback)}]() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what's the difference of this between capture by reference? are the callbacks gonna destructed sometime but we still want them alive in the lambda? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
// We'd love to get a stream for all devices, even those that are not used | ||
// by the value, because the callback could use those other devices, but | ||
// unfortunately this could cause a deadlock with NCCL. See | ||
// https://github.com/pytorch/pytorch/pull/48500#issuecomment-735395414 | ||
// In general, if some devices haven't been used yet, by getting a stream | ||
// for them we'd initialize them, and in addition to causing NCCL to | ||
// misbehaving this also ends up using memory on those devices, which the | ||
// user might not want. | ||
std::vector<at::cuda::CUDAStream> streams; | ||
for (at::cuda::CUDAEvent& cudaEvent : *cudaEvents_) { | ||
c10::DeviceIndex idx = cudaEvent.device_index(); | ||
// FIXME Should we find a way to allow to change the priority of | ||
// streams? | ||
at::cuda::CUDAStream stream = | ||
at::cuda::getStreamFromPool(/*isHighPriority=*/false, idx); | ||
cudaEvent.block(stream); | ||
streams.push_back(stream); | ||
} | ||
|
||
// Use the dedicated callback stream to run callback. | ||
at::cuda::CUDAMultiStreamGuard streamGuard(streams); | ||
|
||
// Do not free the underlying data storage of value_ before its | ||
// usage on the stream finishes. | ||
for (const at::DataPtr& data_ptr : extractDataPtrs(constValue())) { | ||
if (data_ptr.device().is_cuda()) { | ||
c10::cuda::CUDACachingAllocator::recordStream( | ||
data_ptr, at::cuda::getCurrentCUDAStream(data_ptr.device().index())); | ||
} | ||
} | ||
|
||
callback(); | ||
}; | ||
} | ||
|
||
void postWaitHook(const at::IValue& value) override { | ||
for (at::cuda::CUDAEvent& cudaEvent : *cudaEvents_) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmm, this means users cannot call wait on There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure I follow: the |
||
cudaEvent.block( | ||
at::cuda::getCurrentCUDAStream(cudaEvent.device_index())); | ||
} | ||
|
||
for (const at::DataPtr& data_ptr : extractDataPtrs(value)) { | ||
if (data_ptr.device().is_cuda()) { | ||
c10::cuda::CUDACachingAllocator::recordStream( | ||
data_ptr, at::cuda::getCurrentCUDAStream(data_ptr.device().index())); | ||
} | ||
} | ||
} | ||
|
||
// FIXME This field is protected (rather than private) and wrapped in a | ||
// shared_ptr in order to support the FutureNCCL subclass, which wants to set | ||
// the events on its own in order to use the same ones as its WorkNCCL class. | ||
// Once WorkNCCL is gone (as part of the Future and Work merge) this should be | ||
// fixed. | ||
protected: | ||
// The events that correspond to the completion of the async I/O kernels. They | ||
// are recorded on the appropriate streams when the future is marked completed | ||
// and can then be queried/waited/blocked on. There is one event for each | ||
// distinct device on which the value's tensors reside. | ||
std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents_; | ||
|
||
private: | ||
DataPtrExtractor dataPtrExtractor_; | ||
std::mutex dataPtrExtractorMutex_; | ||
|
||
// FIXME This too is protected so that it can be used by FutureNCCL. Please | ||
// undo that once FutureNCCL is dropped in favor of a "vanilla" CUDAFuture. | ||
protected: | ||
std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs( | ||
const at::IValue& value) { | ||
std::unique_lock<std::mutex> lock(dataPtrExtractorMutex_); | ||
std::vector<std::reference_wrapper<const at::DataPtr>> data_ptrs; | ||
if (dataPtrExtractor_ != nullptr) { | ||
// If a Python communication hook is used, dataPtrExtractor_ will be | ||
// set in torch/csrc/jit/python/pybind_utils.h, which allows Python | ||
// dependency to be imported. | ||
data_ptrs = dataPtrExtractor_(value); | ||
} else { | ||
// If a C++ communication hook is used, use the default extractor. | ||
data_ptrs = at::ivalue::Future::defaultDataPtrExtractor(value); | ||
} | ||
return data_ptrs; | ||
} | ||
}; | ||
|
||
} // namespace cuda | ||
} // namespace at |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
|
||
#include <ATen/cuda/CUDAContext.h> | ||
#include <ATen/cuda/CUDAEvent.h> | ||
#include <ATen/cuda/CUDAFuture.h> | ||
#include <ATen/cuda/CUDAMultiStreamGuard.h> | ||
#include <c10/core/Stream.h> | ||
#include <c10/core/StreamGuard.h> | ||
|
@@ -207,146 +208,38 @@ class ProcessGroupNCCL : public ProcessGroup { | |
// enables synchronizing the appropriate streams and avoids stalling PyTorch's | ||
// default stream while running the callback. In case of multiple then | ||
// callbacks, each will be executed on its own fresh stream. | ||
struct FutureNCCL : at::ivalue::Future { | ||
struct FutureNCCL : at::cuda::CUDAFuture { | ||
public: | ||
explicit FutureNCCL( | ||
FutureNCCL( | ||
at::IValue value, | ||
std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents) | ||
: at::ivalue::Future(c10::ListType::create(c10::TensorType::get())), | ||
cudaEvents_(std::move(cudaEvents)) { | ||
: at::cuda::CUDAFuture(c10::ListType::create(c10::TensorType::get())){ | ||
// Check that the device indices are distinct | ||
std::unordered_set<c10::DeviceIndex> uniqueDeviceIndices; | ||
for (const at::cuda::CUDAEvent& event : *cudaEvents_) { | ||
for (const at::cuda::CUDAEvent& event : *cudaEvents) { | ||
TORCH_INTERNAL_ASSERT(event.isCreated()); | ||
uniqueDeviceIndices.insert(event.device_index()); | ||
} | ||
TORCH_INTERNAL_ASSERT( | ||
cudaEvents_->size() == uniqueDeviceIndices.size(), | ||
"Got ", cudaEvents_->size(), " events, but only ", | ||
cudaEvents->size() == uniqueDeviceIndices.size(), | ||
"Got ", cudaEvents->size(), " events, but only ", | ||
uniqueDeviceIndices.size(), " distinct devices"); | ||
for (const at::DataPtr& data_ptr : extractDataPtrs(value)) { | ||
TORCH_INTERNAL_ASSERT( | ||
std::find_if( | ||
cudaEvents_->begin(), | ||
cudaEvents_->end(), | ||
cudaEvents->begin(), | ||
cudaEvents->end(), | ||
[&](const at::cuda::CUDAEvent& ev) { | ||
return ev.device_index() == data_ptr.device().index(); | ||
}) != cudaEvents_->end()); | ||
}) != cudaEvents->end()); | ||
} | ||
cudaEvents_ = std::move(cudaEvents); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are these events used anywhere? Will the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These events will not be overridden when the future is marked complete, because FutureNCCL provides its own However, these events will still be used by the postWaitHook, by the wrapCallback, etc. It's somewhat hacky, but it works, and it's a temporary solution only until we finish merging Future and Work (at which point ProcessGroupNCCL should be able to use CUDAFuture unmodified). |
||
markCompleted(std::move(value)); | ||
} | ||
|
||
using at::ivalue::Future::Future; | ||
|
||
void setDataPtrExtractor(DataPtrExtractor dataPtrExtractor) override { | ||
std::unique_lock<std::mutex> lock(dataPtrExtractorMutex_); | ||
dataPtrExtractor_ = std::move(dataPtrExtractor); | ||
} | ||
|
||
protected: | ||
c10::intrusive_ptr<Future> createInstance(at::TypePtr type) override { | ||
auto fut = c10::make_intrusive<FutureNCCL>(std::move(type)); | ||
// The new future needs the DataPtr extractor when it gets marked complete | ||
// but this might happen immediately inline or in parallel by another | ||
// thread. In both these cases this would/might happen before the user has | ||
// time to set their own DataPtr extractor, which might lead to failures | ||
// if the default extractor can't handle some of the user's types. | ||
// Therefore we propagate our extractor. | ||
fut->setDataPtrExtractor(dataPtrExtractor_); | ||
return fut; | ||
} | ||
|
||
void postMarkCompletedHook(const at::IValue& value) override { | ||
// Check whether the first or second constructor created this instance. | ||
if (cudaEvents_ == nullptr) { | ||
std::vector<bool> isCudaDeviceUsed(c10::cuda::device_count(), false); | ||
for (const at::DataPtr& data_ptr : extractDataPtrs(value)) { | ||
if (data_ptr.device().is_cuda()) { | ||
isCudaDeviceUsed[data_ptr.device().index()] = true; | ||
} | ||
} | ||
|
||
cudaEvents_ = std::make_shared<std::vector<at::cuda::CUDAEvent>>(); | ||
for (c10::DeviceIndex idx = 0; idx < isCudaDeviceUsed.size(); idx++) { | ||
if (isCudaDeviceUsed[idx]) { | ||
at::cuda::CUDAEvent cudaEvent; | ||
cudaEvent.record(at::cuda::getCurrentCUDAStream(idx)); | ||
(*cudaEvents_).push_back(std::move(cudaEvent)); | ||
} | ||
} | ||
} | ||
} | ||
|
||
std::function<void(void)> wrapCallback(std::function<void(void)> callback) override { | ||
return [this, callback{std::move(callback)}]() { | ||
// We'd love to get a stream for all devices, even those that are not used | ||
// by the value, because the callback could use those other devices, but | ||
// unfortunately this could cause a deadlock with NCCL. See | ||
// https://github.com/pytorch/pytorch/pull/48500#issuecomment-735395414 | ||
// In general, if some devices haven't been used yet, by getting a stream | ||
// for them we'd initialize them, and in addition to causing NCCL to | ||
// misbehaving this also ends up using memory on those devices, which the | ||
// user might not want. | ||
std::vector<at::cuda::CUDAStream> streams; | ||
for (at::cuda::CUDAEvent& cudaEvent : *cudaEvents_) { | ||
c10::DeviceIndex idx = cudaEvent.device_index(); | ||
// FIXME Should we find a way to allow to change the priority of | ||
// streams? | ||
at::cuda::CUDAStream stream = | ||
at::cuda::getStreamFromPool(/*isHighPriority=*/false, idx); | ||
cudaEvent.block(stream); | ||
streams.push_back(stream); | ||
} | ||
|
||
// Use the dedicated callback stream to run callback. | ||
at::cuda::CUDAMultiStreamGuard streamGuard(streams); | ||
|
||
// Do not free the underlying data storage of value_ before its | ||
// usage on the stream finishes. | ||
for (const at::DataPtr& data_ptr : extractDataPtrs(constValue())) { | ||
if (data_ptr.device().is_cuda()) { | ||
c10::cuda::CUDACachingAllocator::recordStream( | ||
data_ptr, at::cuda::getCurrentCUDAStream(data_ptr.device().index())); | ||
} | ||
} | ||
|
||
callback(); | ||
}; | ||
} | ||
|
||
void postWaitHook(const at::IValue& value) override { | ||
for (at::cuda::CUDAEvent& cudaEvent : *cudaEvents_) { | ||
cudaEvent.block( | ||
at::cuda::getCurrentCUDAStream(cudaEvent.device_index())); | ||
} | ||
|
||
for (const at::DataPtr& data_ptr : extractDataPtrs(value)) { | ||
if (data_ptr.device().is_cuda()) { | ||
c10::cuda::CUDACachingAllocator::recordStream( | ||
data_ptr, at::cuda::getCurrentCUDAStream(data_ptr.device().index())); | ||
} | ||
} | ||
} | ||
|
||
private: | ||
std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents_; | ||
DataPtrExtractor dataPtrExtractor_; | ||
std::mutex dataPtrExtractorMutex_; | ||
|
||
std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs( | ||
const at::IValue& value) { | ||
std::unique_lock<std::mutex> lock(dataPtrExtractorMutex_); | ||
std::vector<std::reference_wrapper<const at::DataPtr>> data_ptrs; | ||
if (dataPtrExtractor_ != nullptr) { | ||
// If a Python communication hook is used, dataPtrExtractor_ will be | ||
// set in torch/csrc/jit/python/pybind_utils.h, which allows Python | ||
// dependency to be imported. | ||
data_ptrs = dataPtrExtractor_(value); | ||
} else { | ||
// If a C++ communication hook is used, use the default extractor. | ||
data_ptrs = at::ivalue::Future::defaultDataPtrExtractor(value); | ||
} | ||
return data_ptrs; | ||
// Do nothing because the constructor already stored the events. | ||
} | ||
}; | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if this is the only way we want to create an instance, can we delete the default constructor? also, where's the PR that have the
createInstance
for ivalue::Future?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the only way I found to have
at::ivalue::Future::then()
be able to create an instance ofat::cuda::CUDAFuture
. In short, the a method of the superclass needs to create an instance of the subclass. It's not supposed to be used by external users, hence it'sprotected
.The superclass implementation of
createInstance
is here: https://github.com/pytorch/pytorch/pull/48505/files#diff-2f833078d12338d0ac920ab654b5c791cb1219729d6e7c97a08393b32a46d173R491