From 45084ff8bf168f12f8d232c10fdd5e7f1f72d239 Mon Sep 17 00:00:00 2001 From: wayi Date: Mon, 19 Oct 2020 15:54:03 -0700 Subject: [PATCH 1/8] [Gradient Compression] Add CppCommHook subclass for supporting the C++ API of communication hook. Only provides an interface. Some built-in implementations will be provided in a follow-up commit. Original PR issue: C++ DDP Communication Hook https://github.com/pytorch/pytorch/issues/46348 Differential Revision: [D24379460](https://our.internmc.facebook.com/intern/diff/D24379460/) [ghstack-poisoned] --- torch/csrc/distributed/c10d/comm.h | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/torch/csrc/distributed/c10d/comm.h b/torch/csrc/distributed/c10d/comm.h index 880605e134c3..b8516063c48b 100644 --- a/torch/csrc/distributed/c10d/comm.h +++ b/torch/csrc/distributed/c10d/comm.h @@ -1,5 +1,7 @@ #pragma once +#include +#include #include #include @@ -73,8 +75,28 @@ class TORCH_PYTHON_API PythonCommHook : public CommHookInterface { private: // Only needed for stateful communication. py::object state_; - // Indicates an asynchrounous communication of gradients. py::object hook_; }; +class TORCH_API CppCommHook : public CommHookInterface { + public: + explicit CppCommHook( + std::function(const GradBucket&, std::any*)>& hook, + std::unique_ptr state = nullptr) + : state_(std::move(state)), hook_(std::move(hook)) {} + + c10::intrusive_ptr runHook( + const GradBucket& bucket) override { + return hook_(bucket, state_.get()); + } + + private: + std::unique_ptr state_; + std::function( + const GradBucket&, + std::any* state)> + hook_; +}; + } // namespace c10d From 2c487dcc0cc3de145ab46f2276ab8edfc25658f6 Mon Sep 17 00:00:00 2001 From: wayi Date: Mon, 19 Oct 2020 16:42:45 -0700 Subject: [PATCH 2/8] Update on "[Gradient Compression] Add CppCommHook subclass for supporting the C++ API of communication hook." Only provides an interface. 
Some built-in implementations will be provided in a follow-up commit. Original PR issue: C++ DDP Communication Hook https://github.com/pytorch/pytorch/issues/46348 Differential Revision: [D24379460](https://our.internmc.facebook.com/intern/diff/D24379460/) [ghstack-poisoned] --- torch/csrc/distributed/c10d/comm.h | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/torch/csrc/distributed/c10d/comm.h b/torch/csrc/distributed/c10d/comm.h index b8516063c48b..26688e5e0fb5 100644 --- a/torch/csrc/distributed/c10d/comm.h +++ b/torch/csrc/distributed/c10d/comm.h @@ -1,8 +1,6 @@ #pragma once -#include #include -#include #include #include @@ -81,21 +79,22 @@ class TORCH_PYTHON_API PythonCommHook : public CommHookInterface { class TORCH_API CppCommHook : public CommHookInterface { public: explicit CppCommHook( - std::function(const GradBucket&, std::any*)>& hook, - std::unique_ptr state = nullptr) - : state_(std::move(state)), hook_(std::move(hook)) {} + std::function( + const GradBucket&, ProcessGroup*)>& hook, + ProcessGroup* process_group = nullptr) + : process_group_(process_group), hook_(std::move(hook)) {} c10::intrusive_ptr runHook( const GradBucket& bucket) override { - return hook_(bucket, state_.get()); + return hook_(bucket, process_group_); } private: - std::unique_ptr state_; + // This can be a more generic state if needed. + ProcessGroup* process_group_; // Not owned. std::function( const GradBucket&, - std::any* state)> + ProcessGroup* process_group)> hook_; }; From 69ba4c42a6d7a254dcb3044d6b508306dedd11fe Mon Sep 17 00:00:00 2001 From: wayi Date: Mon, 19 Oct 2020 16:59:34 -0700 Subject: [PATCH 3/8] Update on "[Gradient Compression] Add CppCommHook subclass for supporting the C++ API of communication hook." Only provides an interface. Some built-in implementations will be provided in a follow-up commit. 
Original PR issue: C++ DDP Communication Hook https://github.com/pytorch/pytorch/issues/46348 Differential Revision: [D24379460](https://our.internmc.facebook.com/intern/diff/D24379460/) [ghstack-poisoned] --- torch/csrc/distributed/c10d/comm.cpp | 2 +- torch/csrc/distributed/c10d/comm.h | 20 +++++++++----------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/torch/csrc/distributed/c10d/comm.cpp b/torch/csrc/distributed/c10d/comm.cpp index e3b11d524da4..fe454444004c 100644 --- a/torch/csrc/distributed/c10d/comm.cpp +++ b/torch/csrc/distributed/c10d/comm.cpp @@ -104,7 +104,7 @@ std::vector CommHookInterface::parseFromHookResult( } c10::intrusive_ptr PythonCommHook::runHook( - const GradBucket& bucket) { + GradBucket& bucket) { py::gil_scoped_acquire acquire; py::object py_fut = hook_(state_, bucket); diff --git a/torch/csrc/distributed/c10d/comm.h b/torch/csrc/distributed/c10d/comm.h index 26688e5e0fb5..3f9d96567b5b 100644 --- a/torch/csrc/distributed/c10d/comm.h +++ b/torch/csrc/distributed/c10d/comm.h @@ -45,7 +45,7 @@ class TORCH_API CommHookInterface { // Runs the registered communication hook to communicate gradients // asynchronously, Returns a future that holds the communication results. virtual c10::intrusive_ptr runHook( - const GradBucket& bucket) = 0; + GradBucket& bucket) = 0; // Returns the resulting tensors once the communication hook result is ready. std::vector parseFromHookResult(const c10::IValue& result); @@ -67,8 +67,7 @@ class TORCH_PYTHON_API PythonCommHook : public CommHookInterface { hook_.ptr() = nullptr; } - c10::intrusive_ptr runHook( - const GradBucket& bucket) override; + c10::intrusive_ptr runHook(GradBucket& bucket) override; private: // Only needed for stateful communication. 
@@ -79,22 +78,21 @@ class TORCH_PYTHON_API PythonCommHook : public CommHookInterface { class TORCH_API CppCommHook : public CommHookInterface { public: explicit CppCommHook( - std::function( - const GradBucket&, ProcessGroup*)>& hook, + std::function(ProcessGroup*, GradBucket&)>& hook, ProcessGroup* process_group = nullptr) : process_group_(process_group), hook_(std::move(hook)) {} - c10::intrusive_ptr runHook( - const GradBucket& bucket) override { - return hook_(bucket, process_group_); + c10::intrusive_ptr runHook(GradBucket& bucket) override { + return hook_(process_group_, bucket); } private: // This can be a more generic state if needed. - ProcessGroup* process_group_; // Not owned. + ProcessGroup* process_group_; // Not owned. std::function( - const GradBucket&, - ProcessGroup* process_group)> + ProcessGroup* process_group, + GradBucket&)> hook_; }; From 4a28856bfef8170f171603d2ddfba704412479d2 Mon Sep 17 00:00:00 2001 From: wayi Date: Mon, 19 Oct 2020 21:49:01 -0700 Subject: [PATCH 4/8] Update on "[Gradient Compression] Add CppCommHook subclass for supporting the C++ API of communication hook." Only provides an interface. Some built-in implementations will be provided in a follow-up commit. Original PR issue: C++ DDP Communication Hook https://github.com/pytorch/pytorch/issues/46348 Differential Revision: [D24379460](https://our.internmc.facebook.com/intern/diff/D24379460/) [ghstack-poisoned] --- torch/csrc/distributed/c10d/comm.h | 2 ++ torch/csrc/distributed/c10d/reducer.cpp | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/csrc/distributed/c10d/comm.h b/torch/csrc/distributed/c10d/comm.h index 3f9d96567b5b..2f757357fca0 100644 --- a/torch/csrc/distributed/c10d/comm.h +++ b/torch/csrc/distributed/c10d/comm.h @@ -89,6 +89,8 @@ class TORCH_API CppCommHook : public CommHookInterface { private: // This can be a more generic state if needed. + // Note that std::optional cannot be used, since ProcessGroup is + // an abstract class. 
ProcessGroup* process_group_; // Not owned. std::function( ProcessGroup* process_group, diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index a5d52b4f6c21..ba81f6b776e0 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -698,7 +698,8 @@ void Reducer::mark_bucket_ready(size_t bucket_index) { if (comm_hook_ == nullptr) { bucket.work = process_group_->allreduce(tensors); } else { - bucket.future_work = comm_hook_->runHook(GradBucket(tensors)); + GradBucket grad_bucket(tensors); + bucket.future_work = comm_hook_->runHook(grad_bucket); } } } From 50351bfd33c474299135694d1fd0f3cf1430a52c Mon Sep 17 00:00:00 2001 From: wayi Date: Mon, 26 Oct 2020 15:25:11 -0700 Subject: [PATCH 5/8] Update on "[Gradient Compression] Add CppCommHook subclass for supporting the C++ API of communication hook." Only provides an interface. Some built-in implementations will be provided in a follow-up commit. Original PR issue: C++ DDP Communication Hook https://github.com/pytorch/pytorch/issues/46348 Differential Revision: [D24379460](https://our.internmc.facebook.com/intern/diff/D24379460/) [ghstack-poisoned] --- torch/csrc/distributed/c10d/comm.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch/csrc/distributed/c10d/comm.h b/torch/csrc/distributed/c10d/comm.h index 9c398b7ae5b4..c7493ee5db5b 100644 --- a/torch/csrc/distributed/c10d/comm.h +++ b/torch/csrc/distributed/c10d/comm.h @@ -85,6 +85,8 @@ class TORCH_API CppCommHookInterface : public CommHookInterface { public: explicit CppCommHookInterface(T* state = nullptr) : state_(state) {} + ~CppCommHookInterface() override {} + std::vector parseHookResult(const c10::IValue& result) override { TORCH_INTERNAL_ASSERT( result.isTensor() || result.isTensorList(), @@ -97,7 +99,7 @@ class TORCH_API CppCommHookInterface : public CommHookInterface { return result.toTensorVector(); } - private: + protected: T* state_; // Not owned. 
}; From 7b5b7c01e9cf08bd71ba87e846ddd4ed7345ef80 Mon Sep 17 00:00:00 2001 From: wayi Date: Mon, 26 Oct 2020 15:36:45 -0700 Subject: [PATCH 6/8] Update on "[Gradient Compression] Add CppCommHook subclass for supporting the C++ API of communication hook." Only provides an interface. Some built-in implementations will be provided in a follow-up commit. Original PR issue: C++ DDP Communication Hook https://github.com/pytorch/pytorch/issues/46348 Differential Revision: [D24379460](https://our.internmc.facebook.com/intern/diff/D24379460/) [ghstack-poisoned] --- torch/csrc/distributed/c10d/comm.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/distributed/c10d/comm.h b/torch/csrc/distributed/c10d/comm.h index c7493ee5db5b..5a16ffda2703 100644 --- a/torch/csrc/distributed/c10d/comm.h +++ b/torch/csrc/distributed/c10d/comm.h @@ -85,7 +85,7 @@ class TORCH_API CppCommHookInterface : public CommHookInterface { public: explicit CppCommHookInterface(T* state = nullptr) : state_(state) {} - ~CppCommHookInterface() override {} + virtual ~CppCommHookInterface() {} std::vector parseHookResult(const c10::IValue& result) override { TORCH_INTERNAL_ASSERT( From 19f27e7de3a4e482ec9c2b8be3cc65e63da4aebc Mon Sep 17 00:00:00 2001 From: wayi Date: Mon, 26 Oct 2020 16:14:04 -0700 Subject: [PATCH 7/8] Update on "[Gradient Compression] Add CppCommHook subclass for supporting the C++ API of communication hook." Only provides an interface. Some built-in implementations will be provided in a follow-up commit. 
Original PR issue: C++ DDP Communication Hook https://github.com/pytorch/pytorch/issues/46348 Differential Revision: [D24379460](https://our.internmc.facebook.com/intern/diff/D24379460/) [ghstack-poisoned] --- torch/csrc/distributed/c10d/comm.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/csrc/distributed/c10d/comm.h b/torch/csrc/distributed/c10d/comm.h index 5a16ffda2703..3b4d3675e6cc 100644 --- a/torch/csrc/distributed/c10d/comm.h +++ b/torch/csrc/distributed/c10d/comm.h @@ -83,7 +83,7 @@ class TORCH_PYTHON_API PythonCommHook : public CommHookInterface { template class TORCH_API CppCommHookInterface : public CommHookInterface { public: - explicit CppCommHookInterface(T* state = nullptr) : state_(state) {} + explicit CppCommHookInterface(T& state) : state_(state) {} virtual ~CppCommHookInterface() {} @@ -100,7 +100,7 @@ class TORCH_API CppCommHookInterface : public CommHookInterface { } protected: - T* state_; // Not owned. + T state_; // Not owned. }; } // namespace c10d From 2746501fbe889148dd731f5f8ae0a913633fc450 Mon Sep 17 00:00:00 2001 From: wayi Date: Mon, 26 Oct 2020 17:12:35 -0700 Subject: [PATCH 8/8] Update on "[Gradient Compression] Add CppCommHook subclass for supporting the C++ API of communication hook." Only provides an interface. Some built-in implementations will be provided in a follow-up commit. Original PR issue: C++ DDP Communication Hook https://github.com/pytorch/pytorch/issues/46348 Differential Revision: [D24379460](https://our.internmc.facebook.com/intern/diff/D24379460/) [ghstack-poisoned] --- torch/csrc/distributed/c10d/comm.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/torch/csrc/distributed/c10d/comm.h b/torch/csrc/distributed/c10d/comm.h index 3b4d3675e6cc..0d2978e7221f 100644 --- a/torch/csrc/distributed/c10d/comm.h +++ b/torch/csrc/distributed/c10d/comm.h @@ -1,7 +1,5 @@ #pragma once -#include - #include #include #include