Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Gradient Compression] Add CppCommHook subclass for supporting the C++ API of communication hook. #46566

Closed
wants to merge 12 commits into from
Closed
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion torch/csrc/distributed/c10d/comm.cpp
Expand Up @@ -97,7 +97,7 @@ PythonCommHook::~PythonCommHook() {
}

c10::intrusive_ptr<torch::jit::Future> PythonCommHook::runHook(
const GradBucket& bucket) {
GradBucket& bucket) {
py::gil_scoped_acquire acquire;

py::object py_fut = hook_(state_, bucket);
Expand Down
33 changes: 27 additions & 6 deletions torch/csrc/distributed/c10d/comm.h
@@ -1,7 +1,5 @@
#pragma once

#include <memory>

#include <ATen/ATen.h>
#include <c10d/ProcessGroup.hpp>
#include <torch/csrc/utils/pybind.h>
Expand Down Expand Up @@ -47,7 +45,7 @@ class TORCH_PYTHON_API CommHookInterface {
// Once the tensors in the bucket are ready, kicks off the hook asynchronously
// and returns a future that holds the communication results.
virtual c10::intrusive_ptr<torch::jit::Future> runHook(
const GradBucket& bucket) = 0;
GradBucket& bucket) = 0;
wayi1 marked this conversation as resolved.
Show resolved Hide resolved

// Returns the resulting tensors once the communication hook result is ready.
// The resulting tensors will then be copied to the grads of individual
Expand All @@ -68,16 +66,39 @@ class TORCH_PYTHON_API PythonCommHook : public CommHookInterface {

~PythonCommHook() override;

c10::intrusive_ptr<torch::jit::Future> runHook(
const GradBucket& bucket) override;
c10::intrusive_ptr<torch::jit::Future> runHook(GradBucket& bucket) override;

std::vector<at::Tensor> parseHookResult(const c10::IValue& result) override;

private:
// Only needed for stateful communication.
py::object state_;
// Indicates an asynchronous communication of gradients.
py::object hook_;
};

// Base class for communication hooks written in C++. A concrete hook only has
// to implement runHook(); parseHookResult() is already provided here.
//
// T is the type of the (optional) state that the hook may use. It is exposed
// to subclasses through the protected member state_.
template <typename T>
class TORCH_API CppCommHookInterface : public CommHookInterface {
 public:
  // Copy-constructs state_ from `state`. The argument itself is not retained,
  // so later changes to the caller's object are not visible through state_
  // (unless T is itself a pointer or handle type).
  explicit CppCommHookInterface(T& state) : state_(state) {}

  // Spelled with `override` (rather than a bare `virtual`) so the compiler
  // verifies it against the base class, matching PythonCommHook.
  ~CppCommHookInterface() override = default;

  // Converts the value produced by the hook's future into a vector of tensors.
  // A single Tensor becomes a one-element vector; a TensorList is returned as
  // a vector. Any other kind of IValue fails the internal assertion.
  std::vector<at::Tensor> parseHookResult(const c10::IValue& result) override {
    TORCH_INTERNAL_ASSERT(
        result.isTensor() || result.isTensorList(),
        "expected the hook result is either a Tensor or a TensorList");

    if (result.isTensor()) {
      return {result.toTensor()};
    }

    return result.toTensorVector();
  }

 protected:
  // Held by value: this is the hook's own copy of the T passed to the
  // constructor. If T is a raw pointer, only the pointer is copied and the
  // pointee is not owned by this class.
  T state_;
};

} // namespace c10d
3 changes: 2 additions & 1 deletion torch/csrc/distributed/c10d/reducer.cpp
Expand Up @@ -700,7 +700,8 @@ void Reducer::mark_bucket_ready(size_t bucket_index) {
if (comm_hook_ == nullptr) {
bucket.work = process_group_->allreduce(tensors);
} else {
bucket.future_work = comm_hook_->runHook(GradBucket(tensors));
GradBucket grad_bucket(tensors);
bucket.future_work = comm_hook_->runHook(grad_bucket);
wayi1 marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Expand Down