Skip to content

Commit

Permalink
[Gradient Compression] Refactor CommHookInterface and PythonCommHook.
Browse files Browse the repository at this point in the history
Pull Request resolved: #46512

1. Merge 1-line PythonCommHook constructor into the header for simplicity.
2. Rename processFuture method as parseHookResult for readability.
3. Simplify some comments.

Original PR issue: C++ DDP Communication Hook #46348
ghstack-source-id: 115091355

Differential Revision: [D24374282](https://our.internmc.facebook.com/intern/diff/D24374282/)
  • Loading branch information
wayi committed Oct 24, 2020
1 parent ccb79f3 commit 80be339
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 42 deletions.
28 changes: 19 additions & 9 deletions torch/csrc/distributed/c10d/comm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,16 @@ void broadcast_coalesced(
}
}

PythonCommHook::PythonCommHook(py::object state, py::object hook)
: state_(std::move(state)), hook_(std::move(hook)){};
PythonCommHook::~PythonCommHook() {
py::gil_scoped_acquire ag;
state_.dec_ref();
hook_.dec_ref();
// Explicitly set state_ and hook_ to nullptr to prevent py::object's dtor
// to decref on the PyObject again.
// See Note [Destructing py::object] in python_ivalue.h
state_.ptr() = nullptr;
hook_.ptr() = nullptr;
}

c10::intrusive_ptr<torch::jit::Future> PythonCommHook::runHook(
const GradBucket& bucket) {
Expand All @@ -109,20 +117,22 @@ c10::intrusive_ptr<torch::jit::Future> PythonCommHook::runHook(
}
}

std::vector<at::Tensor> PythonCommHook::processFuture(
c10::IValue future_value) {
// Since we have a Python hook, future_value can be a PyObject.
if (future_value.isPyObject()) {
// We first convert it to an IValue that contains a TensorVector.
std::vector<at::Tensor> PythonCommHook::parseHookResult(
const c10::IValue& result) {
TORCH_INTERNAL_ASSERT(
result.isPyObject() || result.isTensorList(),
"expected the hook result is either a PyObject or TensorList");

if (result.isPyObject()) {
py::gil_scoped_acquire ag;
py::object obj = torch::jit::toPyObject(future_value);
py::object obj = torch::jit::toPyObject(result);
auto value = torch::jit::toIValue(
obj, c10::ListType::create(c10::TensorType::get()));

return value.toTensorVector();
}

return future_value.toTensorVector();
return result.toTensorVector();
}

} // namespace c10d
55 changes: 23 additions & 32 deletions torch/csrc/distributed/c10d/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,57 +35,48 @@ class GradBucket {
std::vector<at::Tensor> tensors_;
};

// DDP's c10d reducer allows communication hooks defined as a sub class
// of CommHookInterface. CommHookInterface is an abstract class and can
// be used to implement both Python and CPP hooks.
struct TORCH_PYTHON_API CommHookInterface {
// Base class of both `PythonCommHook` and `CppCommHook`.
// Requires implementing 1) `runHook` method that communicates gradients
// asynchronously, and 2) `parseHookResult` method that converts the hook result
// into a tensor vector.
class TORCH_PYTHON_API CommHookInterface {
public:
virtual ~CommHookInterface() {}

// runHook takes a GradBucket type bucket and passes the tensors of
// this grad bucket to hook's callback. This function is called once
// the bucket is ready. The hook can perform whatever processing is
// needed and return a Future that will hold the new value of the grad
// bucket's tensors once ready.
// Passes the input grad bucket to the registered communication hook.
// Once the tensors in the bucket are ready, kicks off the hook asynchronously
// and returns a future that holds the communication results.
virtual c10::intrusive_ptr<torch::jit::Future> runHook(
const GradBucket& bucket) = 0;

// Once the grad bucket of Future is ready, c10d reducer will call this
// function to get the resulting tensors of the grad bucket. Then c10d
// reducer will use these tensors and copy grads to the grads of individual
// Returns the resulting tensors once the communication hook result is ready.
// The resulting tensors will then be copied to the grads of individual
// parameters.
virtual std::vector<at::Tensor> processFuture(c10::IValue future_value) = 0;
virtual std::vector<at::Tensor> parseHookResult(
const c10::IValue& result) = 0;
};

// PythonCommHook enables registering a python hook to c10d reducer and is a
// sub class of CommHookInterface.
class TORCH_PYTHON_API PythonCommHook : public CommHookInterface {
public:
// The constructor takes a state and a callable hook. Inputs are Python
// objects. The state is passed to the hook in runHook function can be used to
// maintain and update any state information that users would like to maintain
// as part of the training process. The hook can perform whatever processing
// user specified and return a Future indicating completion of any async work.
PythonCommHook(py::object state, py::object hook);
// Takes a state and a callable hook. The inputs are Python objects.
// The state is passed to the hook in runHook method, and it can be used to
// maintain and update any state information during the execution of the hook.
// The hook performs user-specified processing and returns a future indicating
// asychronous communication of gradients.
PythonCommHook(py::object state, py::object hook)
: state_(std::move(state)), hook_(std::move(hook)) {}

~PythonCommHook() override {
py::gil_scoped_acquire ag;
state_.dec_ref();
hook_.dec_ref();
// explicitly setting PyObject* state_ and hook_ to nullptr to prevent
// py::object's dtor to decref on the PyObject again.
// See Note [Destructing py::object] in python_ivalue.h
state_.ptr() = nullptr;
hook_.ptr() = nullptr;
}
~PythonCommHook() override;

c10::intrusive_ptr<torch::jit::Future> runHook(
const GradBucket& bucket) override;

std::vector<at::Tensor> processFuture(c10::IValue future_value) override;
std::vector<at::Tensor> parseHookResult(const c10::IValue& result) override;

private:
// Only needed for stateful communication.
py::object state_;
// Indicates an asynchrounous communication of gradients.
py::object hook_;
};

Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/distributed/c10d/reducer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1165,7 +1165,7 @@ void Reducer::finalize_backward() {
bucket.future_work->wait();

auto future_result =
comm_hook_->processFuture(bucket.future_work->value());
comm_hook_->parseHookResult(bucket.future_work->value());

for (size_t i = 0; i < future_result.size(); i++) {
auto& replica = bucket.replicas[i];
Expand Down

0 comments on commit 80be339

Please sign in to comment.