Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Gradient Compression] Add CppCommHook subclass for supporting the C++ API of communication hook. #46566

Closed
wants to merge 12 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 13 additions & 1 deletion torch/csrc/distributed/c10d/comm.cpp
Expand Up @@ -97,7 +97,7 @@ PythonCommHook::~PythonCommHook() {
}

c10::intrusive_ptr<torch::jit::Future> PythonCommHook::runHook(
const GradBucket& bucket) {
GradBucket& bucket) {
py::gil_scoped_acquire acquire;

py::object py_fut = hook_(state_, bucket);
Expand Down Expand Up @@ -135,4 +135,16 @@ std::vector<at::Tensor> PythonCommHook::parseHookResult(
return result.toTensorVector();
}

std::vector<at::Tensor> CppCommHook::parseHookResult(const c10::IValue& result) {
TORCH_INTERNAL_ASSERT(
result.isTensor() || result.isTensorList(),
"expected the hook result is either a Tensor or a TensorList");

if (result.isTensor()) {
return {result.toTensor()};
}

return result.toTensorVector();
}

} // namespace c10d
33 changes: 28 additions & 5 deletions torch/csrc/distributed/c10d/comm.h
@@ -1,6 +1,6 @@
#pragma once

#include <memory>
#include <functional>

#include <ATen/ATen.h>
#include <c10d/ProcessGroup.hpp>
Expand Down Expand Up @@ -47,7 +47,7 @@ class TORCH_API CommHookInterface {
// Once the tensors in the bucket are ready, kicks off the hook asynchronously
// and returns a future that holds the communication results.
virtual c10::intrusive_ptr<torch::jit::Future> runHook(
const GradBucket& bucket) = 0;
GradBucket& bucket) = 0;
wayi1 marked this conversation as resolved.
Show resolved Hide resolved

// Returns the resulting tensors once the communication hook result is ready.
// The resulting tensors will then be copied to the grads of individual
Expand All @@ -68,16 +68,39 @@ class TORCH_PYTHON_API PythonCommHook : public CommHookInterface {

~PythonCommHook() override;

c10::intrusive_ptr<torch::jit::Future> runHook(
const GradBucket& bucket) override;
c10::intrusive_ptr<torch::jit::Future> runHook(GradBucket& bucket) override;

std::vector<at::Tensor> parseHookResult(const c10::IValue& result) override;

private:
// Only needed for stateful communication.
py::object state_;
// Indicates an asynchrounous communication of gradients.
py::object hook_;
};

class TORCH_API CppCommHook : public CommHookInterface {
public:
explicit CppCommHook(
std::function<c10::intrusive_ptr<
torch::jit::Future>(ProcessGroup*, GradBucket&)>& hook,
wayi1 marked this conversation as resolved.
Show resolved Hide resolved
ProcessGroup* process_group = nullptr)
: process_group_(process_group), hook_(std::move(hook)) {}

c10::intrusive_ptr<torch::jit::Future> runHook(GradBucket& bucket) override {
return hook_(process_group_, bucket);
}

std::vector<at::Tensor> parseHookResult(const c10::IValue& result) override;

private:
// This can be a more generic state if needed.
// Note that std::optional<ProcessGroup> cannot be used, since ProcessGroup is
// an abstract class.
ProcessGroup* process_group_; // Not owned.
wayi1 marked this conversation as resolved.
Show resolved Hide resolved
std::function<c10::intrusive_ptr<torch::jit::Future>(
ProcessGroup* process_group,
GradBucket&)>
hook_;
};

} // namespace c10d
3 changes: 2 additions & 1 deletion torch/csrc/distributed/c10d/reducer.cpp
Expand Up @@ -700,7 +700,8 @@ void Reducer::mark_bucket_ready(size_t bucket_index) {
if (comm_hook_ == nullptr) {
bucket.work = process_group_->allreduce(tensors);
} else {
bucket.future_work = comm_hook_->runHook(GradBucket(tensors));
GradBucket grad_bucket(tensors);
bucket.future_work = comm_hook_->runHook(grad_bucket);
wayi1 marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Expand Down