[pytorch][PR] Record FutureNCCL callback stream on CUDA caching allocator #45318

Closed · wants to merge 1 commit
10 changes: 9 additions & 1 deletion aten/src/ATen/core/ivalue_inl.h
@@ -10,6 +10,7 @@
#include <ATen/core/qualified_name.h>
#include <ATen/core/rref_interface.h>
#include <c10/core/Scalar.h>
#include <c10/core/Stream.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/util/intrusive_ptr.h>
@@ -138,7 +139,7 @@ inline at::Tensor IValue::toTensor() const& {
inline c10::Stream IValue::toStream() && {
return c10::Stream::unpack(payload.as_int);
}
-inline c10::Stream IValue::toStream() const & {
+inline c10::Stream IValue::toStream() const& {
return c10::Stream::unpack(payload.as_int);
}
inline c10::intrusive_ptr<caffe2::Blob> IValue::toBlob() && {
@@ -411,6 +412,13 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
return fut;
}

  // Since this file cannot import the CUDA dependency, the type of the second arg
// in the callback is c10::Stream instead of at::cuda::CUDAStream, and
// CUDAStream is constructed on the fly. The default implementation
// is a no-op, since it does not deal with any CUDA streams.
virtual void setRecordStreamCallback(
std::function<void(const at::IValue&, const c10::Stream&)> record_stream_cb) {}

// Tries to retrieve the error message from std::exception_ptr.
std::string tryRetrieveErrorMessage() {
TORCH_CHECK(hasError(), "No error present on the future.");
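A minimal sketch of how this new hook is meant to be used (illustrative only, not part of this diff): a caller installs the type-erased callback, and only a CUDA-aware Future subclass such as FutureNCCL overrides the default no-op.

    #include <ATen/core/ivalue.h>
    #include <c10/core/Stream.h>

    // Illustrative sketch: `fut` stands in for a FutureNCCL obtained from a
    // NCCL work object; on the base class this call is a no-op.
    void installRecordStreamHook(c10::intrusive_ptr<c10::ivalue::Future> fut) {
      fut->setRecordStreamCallback(
          [](const at::IValue& value, const c10::Stream& stream) {
            // A CUDA-aware callback reconstructs an at::cuda::CUDAStream from
            // the generic c10::Stream here, because this header cannot use
            // CUDA types.
          });
    }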
117 changes: 67 additions & 50 deletions torch/csrc/distributed/c10d/init.cpp
@@ -3,8 +3,8 @@
#include <c10d/FileStore.hpp>
#ifndef _WIN32
#include <c10d/HashStore.hpp>
-#include <c10d/TCPStore.hpp>
 #include <c10d/ProcessGroupRoundRobin.hpp>
+#include <c10d/TCPStore.hpp>
#endif
#include <c10d/ProcessGroup.hpp>

@@ -292,7 +292,8 @@ They are used in specifying strategies for reduction collectives, e.g.,

auto store =
py::class_<::c10d::Store, std::shared_ptr<::c10d::Store>, PythonStore>(
module, "Store",
module,
"Store",
R"(
Base class for all store implementations, such as the 3 provided by PyTorch
distributed: (:class:`~torch.distributed.TCPStore`, :class:`~torch.distributed.FileStore`,
@@ -494,7 +495,10 @@ Example::
>>> store.wait(["bad_key"], timedelta(seconds=10))
)");

-  shared_ptr_class_<::c10d::FileStore>(module, "FileStore", store,
+  shared_ptr_class_<::c10d::FileStore>(
+      module,
+      "FileStore",
+      store,
R"(
A store implementation that uses a file to store the underlying key-value pairs.

@@ -514,7 +518,10 @@ Example::
.def(py::init<const std::string&, int>());

#ifndef _WIN32
-  shared_ptr_class_<::c10d::HashStore>(module, "HashStore", store,
+  shared_ptr_class_<::c10d::HashStore>(
+      module,
+      "HashStore",
+      store,
R"(
A thread-safe store implementation based on an underlying hashmap. This store can be used
within the same process (for example, by other threads), but cannot be used across processes.
@@ -528,7 +535,10 @@ Example::
)")
.def(py::init<>());

-  shared_ptr_class_<::c10d::TCPStore>(module, "TCPStore", store,
+  shared_ptr_class_<::c10d::TCPStore>(
+      module,
+      "TCPStore",
+      store,
R"(
A TCP-based distributed key-value store implementation. The server store holds
the data, while the client stores can connect to the server store over TCP and
@@ -565,7 +575,10 @@ Example::
std::chrono::milliseconds(::c10d::Store::kDefaultTimeout));
#endif

-  shared_ptr_class_<::c10d::PrefixStore>(module, "PrefixStore", store,
+  shared_ptr_class_<::c10d::PrefixStore>(
+      module,
+      "PrefixStore",
+      store,
R"(
A wrapper around any of the 3 key-value stores (:class:`~torch.distributed.TCPStore`,
:class:`~torch.distributed.FileStore`, and :class:`~torch.distributed.HashStore`)
@@ -886,12 +899,13 @@ that adds a prefix to each key inserted to the store.
py::arg("interface") = "");

processGroupGloo
-      .def(py::init<
-               const std::shared_ptr<::c10d::Store>&,
-               int,
-               int,
-               ::c10d::ProcessGroupGloo::Options>(),
-           py::call_guard<py::gil_scoped_release>())
+      .def(
+          py::init<
+              const std::shared_ptr<::c10d::Store>&,
+              int,
+              int,
+              ::c10d::ProcessGroupGloo::Options>(),
+          py::call_guard<py::gil_scoped_release>())
.def(
py::init([](const std::shared_ptr<::c10d::Store>& store,
int rank,
@@ -927,42 +941,45 @@ that adds a prefix to each key inserted to the store.
#endif

#ifdef USE_C10D_NCCL
-  auto processGroupNCCL = shared_ptr_class_<::c10d::ProcessGroupNCCL>(
-      module, "ProcessGroupNCCL", processGroup)
-      .def(py::init<
-          const std::shared_ptr<::c10d::Store>&,
-          int,
-          int,
-          ::c10d::ProcessGroupNCCL::Options>(),
-          py::call_guard<py::gil_scoped_release>())
-      .def(
-          py::init([](const std::shared_ptr<::c10d::Store>& store,
-                      int rank,
-                      int size,
-                      const std::chrono::milliseconds& timeout){
-            ::c10d::ProcessGroupNCCL::Options options;
-            options.isHighPriorityStream = false;
-            options.opTimeout = timeout;
-            return std::make_shared<::c10d::ProcessGroupNCCL>(
-                store, rank, size, options);
-          }),
-          py::arg("store"),
-          py::arg("rank"),
-          py::arg("size"),
-          py::arg("timeout") = std::chrono::milliseconds(
-              ::c10d::ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis),
-          py::call_guard<py::gil_scoped_release>());
+  auto processGroupNCCL =
+      shared_ptr_class_<::c10d::ProcessGroupNCCL>(
+          module, "ProcessGroupNCCL", processGroup)
+          .def(
+              py::init<
+                  const std::shared_ptr<::c10d::Store>&,
+                  int,
+                  int,
+                  ::c10d::ProcessGroupNCCL::Options>(),
+              py::call_guard<py::gil_scoped_release>())
+          .def(
+              py::init([](const std::shared_ptr<::c10d::Store>& store,
+                          int rank,
+                          int size,
+                          const std::chrono::milliseconds& timeout) {
+                ::c10d::ProcessGroupNCCL::Options options;
+                options.isHighPriorityStream = false;
+                options.opTimeout = timeout;
+                return std::make_shared<::c10d::ProcessGroupNCCL>(
+                    store, rank, size, options);
+              }),
+              py::arg("store"),
+              py::arg("rank"),
+              py::arg("size"),
+              py::arg("timeout") = std::chrono::milliseconds(
+                  ::c10d::ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis),
+              py::call_guard<py::gil_scoped_release>());

py::class_<::c10d::ProcessGroupNCCL::Options>(processGroupNCCL, "Options")
.def(py::init<>())
.def_readwrite("is_high_priority", &::c10d::ProcessGroupNCCL::Options::isHighPriorityStream)
.def_readwrite("op_timeout", &::c10d::ProcessGroupNCCL::Options::opTimeout);
processGroupNCCL.def_static("_group_start", []() {
::c10d::ProcessGroupNCCL::groupStart();
});
processGroupNCCL.def_static("_group_end", []() {
::c10d::ProcessGroupNCCL::groupEnd();
});
.def_readwrite(
"is_high_priority",
&::c10d::ProcessGroupNCCL::Options::isHighPriorityStream)
.def_readwrite(
"op_timeout", &::c10d::ProcessGroupNCCL::Options::opTimeout);
processGroupNCCL.def_static(
"_group_start", []() { ::c10d::ProcessGroupNCCL::groupStart(); });
processGroupNCCL.def_static(
"_group_end", []() { ::c10d::ProcessGroupNCCL::groupEnd(); });
#endif

#ifdef USE_C10D_MPI
@@ -973,11 +990,11 @@
// this function may return null. This happens if this process is not
// part of a sub group that is to be created.
processGroupMPI.def_static(
"create",
[](std::vector<int> ranks) {
return ::c10d::ProcessGroupMPI::createProcessGroupMPI(ranks);
},
py::call_guard<py::gil_scoped_release>());
"create",
[](std::vector<int> ranks) {
return ::c10d::ProcessGroupMPI::createProcessGroupMPI(ranks);
},
py::call_guard<py::gil_scoped_release>());
#endif

shared_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work")
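Most of the changes in this file are formatting-only reflows of the pybind11 registrations. For reference, a self-contained sketch of the idiom used above (the Widget class and module name are made up): a factory lambda bound as the constructor, keyword arguments with defaults, and the GIL released while the C++ constructor runs.

    #include <chrono>
    #include <memory>
    #include <pybind11/chrono.h>
    #include <pybind11/pybind11.h>

    namespace py = pybind11;

    // Hypothetical class standing in for ProcessGroupNCCL and friends.
    struct Widget {
      Widget(int rank, std::chrono::milliseconds timeout) {}
    };

    PYBIND11_MODULE(example, m) {
      py::class_<Widget, std::shared_ptr<Widget>>(m, "Widget")
          .def(
              // Factory lambda bound as __init__, as in init.cpp above.
              py::init([](int rank, std::chrono::milliseconds timeout) {
                return std::make_shared<Widget>(rank, timeout);
              }),
              py::arg("rank"),
              py::arg("timeout") = std::chrono::milliseconds(1000),
              // Release the GIL while the potentially blocking ctor runs.
              py::call_guard<py::gil_scoped_release>());
    }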
33 changes: 33 additions & 0 deletions torch/csrc/jit/python/pybind_utils.h
@@ -30,6 +30,11 @@
#endif

#include <ATen/core/function_schema.h>
#include <c10/core/Stream.h>
#ifdef USE_C10D_NCCL
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>
#endif
#include <c10/util/Exception.h>

#include <algorithm>
@@ -113,6 +118,34 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper
// Future owns a reference to the py::function in its callback
// vector, but Future does not acquire GIL on destruction.
auto pf = std::make_shared<PythonFunctionGuard>(std::move(cb));

#ifdef USE_C10D_NCCL
    // This callback is only used by the NCCL backend, so skip this code on
    // other backends to avoid importing the CUDA dependency.
    // By default, assume that the input value is, or can be cast to, a tensor
    // vector that holds exactly one tensor.
auto record_stream_cb = [](const at::IValue& value,
const c10::Stream& stream) {
if (value.isTensorList() || value.isPyObject()) {
std::vector<at::Tensor> tensors;
if (value.isTensorList()) {
tensors = value.toTensorVector();
} else {
pybind11::gil_scoped_acquire gil;
py::object obj = torch::jit::toPyObject(value);
tensors = torch::jit::toIValue(
obj, c10::ListType::create(c10::TensorType::get()))
.toTensorVector();
}
TORCH_INTERNAL_ASSERT(tensors.size() == 1, "expected exactly 1 tensor");
at::cuda::CUDAStream cuda_stream(stream);
c10::cuda::CUDACachingAllocator::recordStream(
tensors[0].storage().data_ptr(), cuda_stream);
}
};
fut->setRecordStreamCallback(record_stream_cb);
#endif

return std::make_shared<jit::PythonFutureWrapper>(fut->then(
// Capture a copy of the ivalue::Future instead of the `this` pointer
// because the PythonFutureWrapper object could have been deleted
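The record_stream_cb added above exists because of the CUDA caching allocator's stream-awareness. A minimal sketch of that contract (assumes a CUDA build; the kernel launch is a placeholder): a block allocated on one stream must be recorded on any other stream that consumes it, or the allocator may recycle the memory before the consumer's queued work finishes.

    #include <ATen/ATen.h>
    #include <c10/cuda/CUDACachingAllocator.h>
    #include <c10/cuda/CUDAGuard.h>
    #include <c10/cuda/CUDAStream.h>

    void useOnSideStream(const at::Tensor& t) {
      c10::cuda::CUDAStream side = c10::cuda::getStreamFromPool();
      {
        c10::cuda::CUDAStreamGuard guard(side);
        // ... enqueue kernels that read `t` on `side` ...
      }
      // Associate t's storage with `side` so the caching allocator defers
      // reuse of the block until the work queued on `side` completes.
      c10::cuda::CUDACachingAllocator::recordStream(
          t.storage().data_ptr(), side);
    }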
22 changes: 18 additions & 4 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
@@ -12,6 +12,7 @@

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/core/Stream.h>
#include <c10/core/StreamGuard.h>

namespace c10d {
@@ -299,6 +300,11 @@ class ProcessGroupNCCL : public ProcessGroup {
auto fut = c10::make_intrusive<FutureNCCL>(
deviceIndex_, thenFutCudaEvents, futureNCCLCallbackStream_);

// Do not free the underlying data storage of value_ before its
    // use on futureNCCLCallbackStream_ finishes.
TORCH_INTERNAL_ASSERT(record_stream_cb_);
record_stream_cb_(value_, futureNCCLCallbackStream_->unwrap());

    // Use the dedicated callback stream to run the callback.
    // Cannot move-capture the std::function in a lambda, because the
    // template type for std::function cannot be deduced. Hence use std::bind to explicitly
@@ -333,11 +339,19 @@
return !value_.isNone();
}

void setRecordStreamCallback(
std::function<void(const at::IValue&, const c10::Stream&)>
record_stream_cb) override {
record_stream_cb_ = std::move(record_stream_cb);
}

private:
at::IValue value_;
c10::DeviceIndex deviceIndex_;
std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents_;
std::shared_ptr<at::cuda::CUDAStream> futureNCCLCallbackStream_;
std::function<void(const at::IValue&, const c10::Stream&)>
record_stream_cb_;
c10::optional<FutureError> error_;
};

@@ -584,10 +598,10 @@ class ProcessGroupNCCL : public ProcessGroup {
//
// For point-to-point operations:
// The key is a string of my current rank and the peer process rank.
-  // e.g. If process 1 and process 2 are involved in a point-to-point communication,
-  // the key will be "1:2" on both processes.
-  // Note: this is for the scenario where there is only 1 GPU per process.
-  // When it comes to multiple GPUs per process, this part may need to redesigned.
+  // e.g. If process 1 and process 2 are involved in a point-to-point
+  // communication, the key will be "1:2" on both processes. Note: this is for
+  // the scenario where there is only 1 GPU per process. When it comes to
+  // multiple GPUs per process, this part may need to be redesigned.
std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
devNCCLCommMap_;

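For reference, the c10::Stream / CUDAStream round-trip that connects this file to the callback installed in pybind_utils.h, as a minimal sketch (CUDA build assumed): then() unwraps the dedicated CUDAStream into a generic c10::Stream, because the callback type in ivalue_inl.h cannot mention CUDA, and the callback reconstructs the CUDAStream on the fly.

    #include <c10/cuda/CUDAStream.h>

    void streamRoundTrip() {
      c10::cuda::CUDAStream cudaStream = c10::cuda::getStreamFromPool();
      c10::Stream generic = cudaStream.unwrap(); // what record_stream_cb_ receives
      c10::cuda::CUDAStream back(generic);       // rebuilt inside the callback
      (void)back; // silence unused-variable warnings in this sketch
    }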