Skip to content

Commit

Permalink
[Gradient Compression] Refactor CommHookInterface and PythonCommHook.
Browse files Browse the repository at this point in the history
Pull Request resolved: #46512

1. Merge 1-line PythonCommHook constructor into the header for simplicity.
2. Move the implementation of PythonCommHook destructor from the header file to cpp file.
3. Rename processFuture method as parseHookResult for readability.
4. Simplify some comments.

Original PR issue: C++ DDP Communication Hook #46348
ghstack-source-id: 115161086

Differential Revision: [D24374282](https://our.internmc.facebook.com/intern/diff/D24374282/)
  • Loading branch information
wayi committed Oct 26, 2020
1 parent b61671c commit 35fa56e
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 42 deletions.
28 changes: 19 additions & 9 deletions torch/csrc/distributed/c10d/comm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,16 @@ void broadcast_coalesced(
}
}

PythonCommHook::PythonCommHook(py::object state, py::object hook)
: state_(std::move(state)), hook_(std::move(hook)){};
PythonCommHook::~PythonCommHook() {
py::gil_scoped_acquire ag;
state_.dec_ref();
hook_.dec_ref();
// Explicitly set state_ and hook_ to nullptr to prevent py::object's dtor
// to decref on the PyObject again.
// See Note [Destructing py::object] in python_ivalue.h
state_.ptr() = nullptr;
hook_.ptr() = nullptr;
}

c10::intrusive_ptr<torch::jit::Future> PythonCommHook::runHook(
const GradBucket& bucket) {
Expand All @@ -109,20 +117,22 @@ c10::intrusive_ptr<torch::jit::Future> PythonCommHook::runHook(
}
}

std::vector<at::Tensor> PythonCommHook::processFuture(
c10::IValue future_value) {
// Since we have a Python hook, future_value can be a PyObject.
if (future_value.isPyObject()) {
// We first convert it to an IValue that contains a TensorVector.
std::vector<at::Tensor> PythonCommHook::parseHookResult(
const c10::IValue& result) {
TORCH_INTERNAL_ASSERT(
result.isPyObject() || result.isTensorList(),
"expected the hook result is either a PyObject or TensorList");

if (result.isPyObject()) {
py::gil_scoped_acquire ag;
py::object obj = torch::jit::toPyObject(future_value);
py::object obj = torch::jit::toPyObject(result);
auto value = torch::jit::toIValue(
obj, c10::ListType::create(c10::TensorType::get()));

return value.toTensorVector();
}

return future_value.toTensorVector();
return result.toTensorVector();
}

} // namespace c10d
55 changes: 23 additions & 32 deletions torch/csrc/distributed/c10d/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,57 +35,48 @@ class GradBucket {
std::vector<at::Tensor> tensors_;
};

// DDP's c10d reducer allows communication hooks defined as a sub class
// of CommHookInterface. CommHookInterface is an abstract class and can
// be used to implement both Python and CPP hooks.
struct TORCH_PYTHON_API CommHookInterface {
// Base class of both `PythonCommHook` and `CppCommHook`.
// Requires implementing 1) `runHook` method that communicates gradients
// asynchronously, and 2) `parseHookResult` method that converts the hook result
// into a tensor vector.
class TORCH_PYTHON_API CommHookInterface {
public:
virtual ~CommHookInterface() {}

// runHook takes a GradBucket type bucket and passes the tensors of
// this grad bucket to hook's callback. This function is called once
// the bucket is ready. The hook can perform whatever processing is
// needed and return a Future that will hold the new value of the grad
// bucket's tensors once ready.
// Passes the input grad bucket to the registered communication hook.
// Once the tensors in the bucket are ready, kicks off the hook asynchronously
// and returns a future that holds the communication results.
virtual c10::intrusive_ptr<torch::jit::Future> runHook(
const GradBucket& bucket) = 0;

// Once the grad bucket of Future is ready, c10d reducer will call this
// function to get the resulting tensors of the grad bucket. Then c10d
// reducer will use these tensors and copy grads to the grads of individual
// Returns the resulting tensors once the communication hook result is ready.
// The resulting tensors will then be copied to the grads of individual
// parameters.
virtual std::vector<at::Tensor> processFuture(c10::IValue future_value) = 0;
virtual std::vector<at::Tensor> parseHookResult(
const c10::IValue& result) = 0;
};

// PythonCommHook enables registering a python hook to c10d reducer and is a
// sub class of CommHookInterface.
class TORCH_PYTHON_API PythonCommHook : public CommHookInterface {
public:
// The constructor takes a state and a callable hook. Inputs are Python
// objects. The state is passed to the hook in runHook function can be used to
// maintain and update any state information that users would like to maintain
// as part of the training process. The hook can perform whatever processing
// user specified and return a Future indicating completion of any async work.
PythonCommHook(py::object state, py::object hook);
// Takes a state and a callable hook. The inputs are Python objects.
// The state is passed to the hook in runHook method, and it can be used to
// maintain and update any state information during the execution of the hook.
// The hook performs user-specified processing and returns a future indicating
// asychronous communication of gradients.
PythonCommHook(py::object state, py::object hook)
: state_(std::move(state)), hook_(std::move(hook)) {}

~PythonCommHook() override {
py::gil_scoped_acquire ag;
state_.dec_ref();
hook_.dec_ref();
// explicitly setting PyObject* state_ and hook_ to nullptr to prevent
// py::object's dtor to decref on the PyObject again.
// See Note [Destructing py::object] in python_ivalue.h
state_.ptr() = nullptr;
hook_.ptr() = nullptr;
}
~PythonCommHook() override;

c10::intrusive_ptr<torch::jit::Future> runHook(
const GradBucket& bucket) override;

std::vector<at::Tensor> processFuture(c10::IValue future_value) override;
std::vector<at::Tensor> parseHookResult(const c10::IValue& result) override;

private:
// Only needed for stateful communication.
py::object state_;
// Indicates an asynchrounous communication of gradients.
py::object hook_;
};

Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/distributed/c10d/reducer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1165,7 +1165,7 @@ void Reducer::finalize_backward() {
bucket.future_work->wait();

auto future_result =
comm_hook_->processFuture(bucket.future_work->value());
comm_hook_->parseHookResult(bucket.future_work->value());

for (size_t i = 0; i < future_result.size(); i++) {
auto& replica = bucket.replicas[i];
Expand Down

0 comments on commit 35fa56e

Please sign in to comment.