diff --git a/torch/csrc/distributed/c10d/comm.cpp b/torch/csrc/distributed/c10d/comm.cpp
index 6d1181514aeb..f9fa2bbd7e8f 100644
--- a/torch/csrc/distributed/c10d/comm.cpp
+++ b/torch/csrc/distributed/c10d/comm.cpp
@@ -97,7 +97,7 @@ PythonCommHook::~PythonCommHook() {
 }
 
 c10::intrusive_ptr<c10::ivalue::Future> PythonCommHook::runHook(
-    const GradBucket& bucket) {
+    GradBucket& bucket) {
   py::gil_scoped_acquire acquire;
 
   py::object py_fut = hook_(state_, bucket);
diff --git a/torch/csrc/distributed/c10d/comm.h b/torch/csrc/distributed/c10d/comm.h
index bb5391a68aaa..0d2978e7221f 100644
--- a/torch/csrc/distributed/c10d/comm.h
+++ b/torch/csrc/distributed/c10d/comm.h
@@ -1,7 +1,5 @@
 #pragma once
 
-#include
-
 #include
 #include
 #include
@@ -47,7 +45,7 @@ class TORCH_PYTHON_API CommHookInterface {
   // Once the tensors in the bucket are ready, kicks off the hook asynchronously
   // and returns a future that holds the communication results.
   virtual c10::intrusive_ptr<c10::ivalue::Future> runHook(
-      const GradBucket& bucket) = 0;
+      GradBucket& bucket) = 0;
 
   // Returns the resulting tensors once the communication hook result is ready.
   // The resulting tensors will then be copied to the grads of individual
@@ -68,16 +66,39 @@ class TORCH_PYTHON_API PythonCommHook : public CommHookInterface {
 
   ~PythonCommHook() override;
 
-  c10::intrusive_ptr<c10::ivalue::Future> runHook(
-      const GradBucket& bucket) override;
+  c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override;
 
   std::vector<at::Tensor> parseHookResult(const c10::IValue& result) override;
 
  private:
   // Only needed for stateful communication.
   py::object state_;
-  // Indicates an asynchrounous communication of gradients.
   py::object hook_;
 };
 
+// This CppCommHook interface only requires implementing runHook method that
+// potentially uses a state.
+template <typename T>
+class TORCH_API CppCommHookInterface : public CommHookInterface {
+ public:
+  explicit CppCommHookInterface(T& state) : state_(state) {}
+
+  virtual ~CppCommHookInterface() {}
+
+  std::vector<at::Tensor> parseHookResult(const c10::IValue& result) override {
+    TORCH_INTERNAL_ASSERT(
+        result.isTensor() || result.isTensorList(),
+        "expected the hook result is either a Tensor or a TensorList");
+
+    if (result.isTensor()) {
+      return {result.toTensor()};
+    }
+
+    return result.toTensorVector();
+  }
+
+ protected:
+  T state_; // Not owned.
+};
+
 } // namespace c10d
diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp
index ea1eef082a52..abb26dfaa10c 100644
--- a/torch/csrc/distributed/c10d/reducer.cpp
+++ b/torch/csrc/distributed/c10d/reducer.cpp
@@ -700,7 +700,8 @@ void Reducer::mark_bucket_ready(size_t bucket_index) {
     if (comm_hook_ == nullptr) {
       bucket.work = process_group_->allreduce(tensors);
     } else {
-      bucket.future_work = comm_hook_->runHook(GradBucket(tensors));
+      GradBucket grad_bucket(tensors);
+      bucket.future_work = comm_hook_->runHook(grad_bucket);
     }
   }
 }
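
For illustration only, not part of the patch above: a minimal sketch of a C++ hook built on the new CppCommHookInterface. The class name AllReduceCommHookExample, the include paths, and the choice of a raw ProcessGroup* as the state are hypothetical; the sketch assumes GradBucket::getTensors() and ProcessGroup::Work::getFuture() are available, and it inherits parseHookResult from CppCommHookInterface.

// Illustrative sketch only -- not part of this patch. Assumes
// GradBucket::getTensors() and ProcessGroup::Work::getFuture() are available;
// the class name and the raw ProcessGroup* state are hypothetical.
#include <c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/comm.h>

namespace c10d {

class AllReduceCommHookExample : public CppCommHookInterface<ProcessGroup*> {
 public:
  explicit AllReduceCommHookExample(ProcessGroup* state)
      : CppCommHookInterface<ProcessGroup*>(state) {}

  ~AllReduceCommHookExample() override {}

  // Kicks off an asynchronous allreduce on the bucket's tensors and returns
  // the future tied to the resulting work handle; parseHookResult is
  // inherited from CppCommHookInterface.
  c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override {
    std::vector<at::Tensor> tensors = bucket.getTensors();
    auto work = state_->allreduce(tensors);
    return work->getFuture();
  }
};

} // namespace c10d

Such a hook would flow through the same path as a Python hook: Reducer::mark_bucket_ready now constructs a GradBucket lvalue and passes it to comm_hook_->runHook, as in the reducer.cpp change above, which is why runHook takes a non-const GradBucket&.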