diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h
index 740118daaa6e..2efae01077f3 100644
--- a/aten/src/ATen/core/ivalue_inl.h
+++ b/aten/src/ATen/core/ivalue_inl.h
@@ -10,6 +10,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -138,7 +139,7 @@ inline at::Tensor IValue::toTensor() const& {
 inline c10::Stream IValue::toStream() && {
   return c10::Stream::unpack(payload.as_int);
 }
-inline c10::Stream IValue::toStream() const & {
+inline c10::Stream IValue::toStream() const& {
   return c10::Stream::unpack(payload.as_int);
 }
 inline c10::intrusive_ptr IValue::toBlob() && {
@@ -411,6 +412,13 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
     return fut;
   }

+  // Since this file cannot import the CUDA dependency, the type of the second
+  // arg in the callback is c10::Stream instead of at::cuda::CUDAStream, and
+  // the CUDAStream is constructed on the fly. The default implementation
+  // is a no-op, since it does not deal with any CUDA streams.
+  virtual void setRecordStreamCallback(
+      std::function<void(const at::IValue&, const c10::Stream&)>
+          record_stream_cb) {}
+
   // Tries to retrieve the error message from std::exception_ptr.
   std::string tryRetrieveErrorMessage() {
     TORCH_CHECK(hasError(), "No error present on the future.");
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 3ca74069e893..47a3ebabe941 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -3,8 +3,8 @@
 #include
 #ifndef _WIN32
 #include
-#include
 #include
+#include
 #endif

 #include
@@ -292,7 +292,8 @@ They are used in specifying strategies for reduction collectives, e.g.,
   auto store =
       py::class_<::c10d::Store, std::shared_ptr<::c10d::Store>, PythonStore>(
-          module, "Store",
+          module,
+          "Store",
           R"(
 Base class for all store implementations, such as the 3 provided by PyTorch
 distributed: (:class:`~torch.distributed.TCPStore`, :class:`~torch.distributed.FileStore`,
@@ -494,7 +495,10 @@ Example::
     >>> store.wait(["bad_key"], timedelta(seconds=10))
 )");

-  shared_ptr_class_<::c10d::FileStore>(module, "FileStore", store,
+  shared_ptr_class_<::c10d::FileStore>(
+      module,
+      "FileStore",
+      store,
       R"(
 A store implementation that uses a file to store the underlying key-value pairs.
@@ -514,7 +518,10 @@ Example::
       .def(py::init());

 #ifndef _WIN32
-  shared_ptr_class_<::c10d::HashStore>(module, "HashStore", store,
+  shared_ptr_class_<::c10d::HashStore>(
+      module,
+      "HashStore",
+      store,
       R"(
 A thread-safe store implementation based on an underlying hashmap. This store can be used within the same
 process (for example, by other threads), but cannot be used across processes.
@@ -528,7 +535,10 @@ Example::
 )")
       .def(py::init<>());

-  shared_ptr_class_<::c10d::TCPStore>(module, "TCPStore", store,
+  shared_ptr_class_<::c10d::TCPStore>(
+      module,
+      "TCPStore",
+      store,
       R"(
 A TCP-based distributed key-value store implementation. The server store holds
 the data, while the client stores can connect to the server store over TCP and
@@ -565,7 +575,10 @@ Example::
           std::chrono::milliseconds(::c10d::Store::kDefaultTimeout));
 #endif

-  shared_ptr_class_<::c10d::PrefixStore>(module, "PrefixStore", store,
+  shared_ptr_class_<::c10d::PrefixStore>(
+      module,
+      "PrefixStore",
+      store,
       R"(
 A wrapper around any of the 3 key-value stores (:class:`~torch.distributed.TCPStore`,
 :class:`~torch.distributed.FileStore`, and :class:`~torch.distributed.HashStore`)
@@ -886,12 +899,13 @@ that adds a prefix to each key inserted to the store.
py::arg("interface") = ""); processGroupGloo - .def(py::init< - const std::shared_ptr<::c10d::Store>&, - int, - int, - ::c10d::ProcessGroupGloo::Options>(), - py::call_guard()) + .def( + py::init< + const std::shared_ptr<::c10d::Store>&, + int, + int, + ::c10d::ProcessGroupGloo::Options>(), + py::call_guard()) .def( py::init([](const std::shared_ptr<::c10d::Store>& store, int rank, @@ -927,42 +941,45 @@ that adds a prefix to each key inserted to the store. #endif #ifdef USE_C10D_NCCL - auto processGroupNCCL = shared_ptr_class_<::c10d::ProcessGroupNCCL>( - module, "ProcessGroupNCCL", processGroup) - .def(py::init< - const std::shared_ptr<::c10d::Store>&, - int, - int, - ::c10d::ProcessGroupNCCL::Options>(), - py::call_guard()) - .def( - py::init([](const std::shared_ptr<::c10d::Store>& store, - int rank, - int size, - const std::chrono::milliseconds& timeout){ - ::c10d::ProcessGroupNCCL::Options options; - options.isHighPriorityStream = false; - options.opTimeout = timeout; - return std::make_shared<::c10d::ProcessGroupNCCL>( - store, rank, size, options); - }), - py::arg("store"), - py::arg("rank"), - py::arg("size"), - py::arg("timeout") = std::chrono::milliseconds( - ::c10d::ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis), - py::call_guard()); + auto processGroupNCCL = + shared_ptr_class_<::c10d::ProcessGroupNCCL>( + module, "ProcessGroupNCCL", processGroup) + .def( + py::init< + const std::shared_ptr<::c10d::Store>&, + int, + int, + ::c10d::ProcessGroupNCCL::Options>(), + py::call_guard()) + .def( + py::init([](const std::shared_ptr<::c10d::Store>& store, + int rank, + int size, + const std::chrono::milliseconds& timeout) { + ::c10d::ProcessGroupNCCL::Options options; + options.isHighPriorityStream = false; + options.opTimeout = timeout; + return std::make_shared<::c10d::ProcessGroupNCCL>( + store, rank, size, options); + }), + py::arg("store"), + py::arg("rank"), + py::arg("size"), + py::arg("timeout") = std::chrono::milliseconds( + ::c10d::ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis), + py::call_guard()); py::class_<::c10d::ProcessGroupNCCL::Options>(processGroupNCCL, "Options") .def(py::init<>()) - .def_readwrite("is_high_priority", &::c10d::ProcessGroupNCCL::Options::isHighPriorityStream) - .def_readwrite("op_timeout", &::c10d::ProcessGroupNCCL::Options::opTimeout); - processGroupNCCL.def_static("_group_start", []() { - ::c10d::ProcessGroupNCCL::groupStart(); - }); - processGroupNCCL.def_static("_group_end", []() { - ::c10d::ProcessGroupNCCL::groupEnd(); - }); + .def_readwrite( + "is_high_priority", + &::c10d::ProcessGroupNCCL::Options::isHighPriorityStream) + .def_readwrite( + "op_timeout", &::c10d::ProcessGroupNCCL::Options::opTimeout); + processGroupNCCL.def_static( + "_group_start", []() { ::c10d::ProcessGroupNCCL::groupStart(); }); + processGroupNCCL.def_static( + "_group_end", []() { ::c10d::ProcessGroupNCCL::groupEnd(); }); #endif #ifdef USE_C10D_MPI @@ -973,11 +990,11 @@ that adds a prefix to each key inserted to the store. // this function may return null. This happens if this process is not // part of a sub group that is to be created. 
   processGroupMPI.def_static(
-      "create",
-      [](std::vector ranks) {
-        return ::c10d::ProcessGroupMPI::createProcessGroupMPI(ranks);
-      },
-      py::call_guard());
+      "create",
+      [](std::vector ranks) {
+        return ::c10d::ProcessGroupMPI::createProcessGroupMPI(ranks);
+      },
+      py::call_guard());
 #endif

   shared_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work")
diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h
index 09a8efa9e813..15c1cdd272b2 100644
--- a/torch/csrc/jit/python/pybind_utils.h
+++ b/torch/csrc/jit/python/pybind_utils.h
@@ -30,6 +30,11 @@
 #endif
 #include
+#include
+#ifdef USE_C10D_NCCL
+#include
+#include
+#endif
 #include
 #include
@@ -113,6 +118,34 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper
       // Future owns a reference to the py::function in its callback
       // vector, but Future does not acquire GIL on destruction.
       auto pf = std::make_shared(std::move(cb));
+
+#ifdef USE_C10D_NCCL
+      // This callback is only used by the NCCL backend, so skip this code on
+      // other backends to avoid importing a CUDA dependency.
+      // By default, assume that the input value is or can be cast into a
+      // tensor vector that has exactly one tensor.
+      auto record_stream_cb = [](const at::IValue& value,
+                                 const c10::Stream& stream) {
+        if (value.isTensorList() || value.isPyObject()) {
+          std::vector<at::Tensor> tensors;
+          if (value.isTensorList()) {
+            tensors = value.toTensorVector();
+          } else {
+            pybind11::gil_scoped_acquire gil;
+            py::object obj = torch::jit::toPyObject(value);
+            tensors = torch::jit::toIValue(
+                          obj, c10::ListType::create(c10::TensorType::get()))
+                          .toTensorVector();
+          }
+          TORCH_INTERNAL_ASSERT(tensors.size() == 1, "expected exactly 1 tensor");
+          at::cuda::CUDAStream cuda_stream(stream);
+          c10::cuda::CUDACachingAllocator::recordStream(
+              tensors[0].storage().data_ptr(), cuda_stream);
+        }
+      };
+      fut->setRecordStreamCallback(record_stream_cb);
+#endif
+
       return std::make_shared(fut->then(
           // Capture a copy of the ivalue::Future instead of the `this` pointer
           // because the PythonFutureWrapper object could have been deleted
diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp
index 24029a4f2f69..3b52616ee3fa 100644
--- a/torch/lib/c10d/ProcessGroupNCCL.hpp
+++ b/torch/lib/c10d/ProcessGroupNCCL.hpp
@@ -12,6 +12,7 @@
 #include
 #include
+#include
 #include

 namespace c10d {
@@ -299,6 +300,11 @@ class ProcessGroupNCCL : public ProcessGroup {
       auto fut = c10::make_intrusive(
           deviceIndex_, thenFutCudaEvents, futureNCCLCallbackStream_);

+      // Do not free the underlying data storage of value_ before its
+      // usage on futureNCCLCallbackStream_ finishes.
+      TORCH_INTERNAL_ASSERT(record_stream_cb_);
+      record_stream_cb_(value_, futureNCCLCallbackStream_->unwrap());
+
       // Use the dedicated callback stream to run callback.
       // Cannot move capture std::function in lambda, because it cannot deduce
       // the template type for std::function. Hence use std::bind to explicitly
@@ -333,11 +339,19 @@ class ProcessGroupNCCL : public ProcessGroup {
     return !value_.isNone();
   }

+  void setRecordStreamCallback(
+      std::function<void(const at::IValue&, const c10::Stream&)>
+          record_stream_cb) override {
+    record_stream_cb_ = std::move(record_stream_cb);
+  }
+
  private:
   at::IValue value_;
   c10::DeviceIndex deviceIndex_;
   std::shared_ptr> cudaEvents_;
   std::shared_ptr futureNCCLCallbackStream_;
+  std::function<void(const at::IValue&, const c10::Stream&)>
+      record_stream_cb_;
   c10::optional error_;
 };

@@ -584,10 +598,10 @@ class ProcessGroupNCCL : public ProcessGroup {
   //
   // For point-to-point operations:
   // The key is a string of my current rank and the peer process rank.
-  // e.g. If process 1 and process 2 are involved in a point-to-point communication,
-  // the key will be "1:2" on both processes.
-  // Note: this is for the scenario where there is only 1 GPU per process.
-  // When it comes to multiple GPUs per process, this part may need to redesigned.
+  // e.g. If process 1 and process 2 are involved in a point-to-point
+  // communication, the key will be "1:2" on both processes. Note: this is for
+  // the scenario where there is only 1 GPU per process. When it comes to
+  // multiple GPUs per process, this part may need to be redesigned.
   std::unordered_map>> devNCCLCommMap_;
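
As a reading aid (not part of the patch), here is a minimal standalone C++ sketch of the pattern the diff introduces: the generic future gains a type-erased record-stream hook that defaults to a no-op, and the CUDA-aware subclass stores the callback and invokes it on the result value before running user callbacks on its dedicated stream. The names `Value`, `Stream`, and `CudaAwareFuture` are hypothetical stand-ins for `at::IValue`, `c10::Stream`, and `ProcessGroupNCCL::FutureNCCL`; none of this code exists in PyTorch.

```cpp
// Standalone sketch of the setRecordStreamCallback hook (stand-in types only).
#include <functional>
#include <iostream>
#include <string>
#include <utility>

// Hypothetical stand-ins for at::IValue and c10::Stream.
struct Value {
  std::string payload;
};
struct Stream {
  int id;
};

struct Future {
  virtual ~Future() = default;
  // Default implementation is a no-op, mirroring ivalue::Future in the patch:
  // the base class knows nothing about CUDA streams.
  virtual void setRecordStreamCallback(
      std::function<void(const Value&, const Stream&)> /*record_stream_cb*/) {}
};

// Stand-in for FutureNCCL: keeps the callback and invokes it on the stored
// value before running callbacks, so the caller can mark the value's storage
// as in use on the callback stream.
struct CudaAwareFuture : Future {
  void setRecordStreamCallback(
      std::function<void(const Value&, const Stream&)> record_stream_cb)
      override {
    record_stream_cb_ = std::move(record_stream_cb);
  }

  void then(const Value& value) {
    if (record_stream_cb_) {
      // Analogous to record_stream_cb_(value_, futureNCCLCallbackStream_->unwrap()).
      record_stream_cb_(value, callback_stream_);
    }
    std::cout << "running user callback on stream " << callback_stream_.id
              << " for value '" << value.payload << "'\n";
  }

  Stream callback_stream_{1};
  std::function<void(const Value&, const Stream&)> record_stream_cb_;
};

int main() {
  CudaAwareFuture fut;
  // The wrapper (cf. PythonFutureWrapper in the patch) installs the hook; in
  // PyTorch the hook would call c10::cuda::CUDACachingAllocator::recordStream.
  fut.setRecordStreamCallback([](const Value& v, const Stream& s) {
    std::cout << "recordStream('" << v.payload << "', stream " << s.id << ")\n";
  });
  fut.then(Value{"tensor output"});
  return 0;
}
```

The point of the indirection, as the comments in the patch state, is that the core header cannot depend on CUDA: the hook traffics only in `c10::Stream`, and the NCCL-specific code is the only place that converts it to an `at::cuda::CUDAStream` and records the tensor storage with the caching allocator.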