Skip to content

Commit

Permalink
[Gradient Compression] Refactor CommHookInterface and PythonCommHook.
Browse files Browse the repository at this point in the history
Pull Request resolved: #46512

1. Let CommHookInterface::processFuture be a non-virtual method, which can be shared by both Python and C++ implementations.
2. Merge 1-line PythonCommHook constructor into the header for simplicity.
3. Rename processFuture method as parseFromHookResult for readability.
4. Simplify the comments.

Original PR issue: C++ DDP Communication Hook #46348
ghstack-source-id: 114558527

Differential Revision: [D24374282](https://our.internmc.facebook.com/intern/diff/D24374282/)
  • Loading branch information
wayi committed Oct 17, 2020
1 parent 997e672 commit 99e5925
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 45 deletions.
35 changes: 17 additions & 18 deletions torch/csrc/distributed/c10d/comm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,23 @@ void broadcast_coalesced(
}
}

// Constructs a Python communication hook from a `state` object and a
// callable `hook`. Both py::object arguments are taken by value and moved
// into the corresponding members.
PythonCommHook::PythonCommHook(py::object state, py::object hook)
    : state_(std::move(state)), hook_(std::move(hook)) {}
// Extracts the tensors carried by a completed communication hook result.
//
// `result` must hold either a PyObject or a TensorList; anything else trips
// the internal assertion below. A TensorList is returned directly, while a
// PyObject is first converted (with the GIL held) into an IValue typed as a
// list of tensors.
std::vector<at::Tensor> CommHookInterface::parseFromHookResult(
    const c10::IValue& result) {
  TORCH_INTERNAL_ASSERT(
      result.isPyObject() || result.isTensorList(),
      "expected the future value is either a PyObject (if PythonCommHook is used) or TensorList (if CppCommHook is used)");

  // Fast path: the result already is a tensor list, no conversion needed.
  if (!result.isPyObject()) {
    return result.toTensorVector();
  }

  // The Python object is touched below, so the GIL is held for the rest of
  // this scope (including destruction of the locals declared after it).
  py::gil_scoped_acquire gil_guard;
  py::object py_result = torch::jit::toPyObject(result);
  auto converted = torch::jit::toIValue(
      py_result, c10::ListType::create(c10::TensorType::get()));
  return converted.toTensorVector();
}

c10::intrusive_ptr<torch::jit::Future> PythonCommHook::runHook(
const GradBucket& bucket) {
Expand All @@ -109,20 +124,4 @@ c10::intrusive_ptr<torch::jit::Future> PythonCommHook::runHook(
}
}

// Returns the tensors held by `future_value`.
//
// If the value is a PyObject, it is converted with the GIL held into an
// IValue typed as a list of tensors before the tensor vector is extracted.
// Otherwise the tensor vector is read from `future_value` directly.
std::vector<at::Tensor> PythonCommHook::processFuture(
    c10::IValue future_value) {
  const bool holds_py_object = future_value.isPyObject();
  if (!holds_py_object) {
    return future_value.toTensorVector();
  }

  // Everything below handles a Python object, so keep the GIL until return.
  py::gil_scoped_acquire gil_guard;
  py::object py_value = torch::jit::toPyObject(future_value);
  auto tensor_list = torch::jit::toIValue(
      py_value, c10::ListType::create(c10::TensorType::get()));
  return tensor_list.toTensorVector();
}

} // namespace c10d
40 changes: 14 additions & 26 deletions torch/csrc/distributed/c10d/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,45 +35,33 @@ class GradBucket {
std::vector<at::Tensor> tensors_;
};

// DDP's c10d reducer allows communication hooks defined as a sub class
// of CommHookInterface. CommHookInterface is an abstract class and can
// be used to implement both Python and CPP hooks.
struct TORCH_PYTHON_API CommHookInterface {
// Base class of both `PythonCommHook` and `CppCommHook`.
// Requires implementing the `runHook` method that communicates gradients
// asynchronously.
class TORCH_API CommHookInterface {
public:
virtual ~CommHookInterface() {}

// runHook takes a GradBucket type bucket and passes the tensors of
// this grad bucket to hook's callback. This function is called once
// the bucket is ready. The hook can perform whatever processing is
// needed and return a Future that will hold the new value of the grad
// bucket's tensors once ready.
// Runs the registered communication hook to communicate gradients
// asynchronously. Returns a future that holds the communication results.
virtual c10::intrusive_ptr<torch::jit::Future> runHook(
const GradBucket& bucket) = 0;

// Once the grad bucket of Future is ready, c10d reducer will call this
// function to get the resulting tensors of the grad bucket. Then c10d
// reducer will use these tensors and copy grads to the grads of individual
// parameters.
virtual std::vector<at::Tensor> processFuture(c10::IValue future_value) = 0;
// Returns the resulting tensors once the communication hook result is ready.
std::vector<at::Tensor> parseFromHookResult(const c10::IValue& result);
};

// PythonCommHook enables registering a python hook to c10d reducer and is a
// sub class of CommHookInterface.
class TORCH_PYTHON_API PythonCommHook : public CommHookInterface {
public:
// The constructor takes a state and a callable hook. Inputs are Python
// objects. The state is passed to the hook in the runHook function and can be
// used to maintain and update any state information that users would like to
// keep as part of the training process. The hook can perform whatever
// processing the user specified and return a Future indicating completion of
// any async work.
PythonCommHook(py::object state, py::object hook);
PythonCommHook(py::object state, py::object hook)
: state_(std::move(state)), hook_(std::move(hook)) {}

~PythonCommHook() override {
py::gil_scoped_acquire ag;
state_.dec_ref();
hook_.dec_ref();
// explicitly setting PyObject* state_ and hook_ to nullptr to prevent
// py::object's dtor to decref on the PyObject again.
// Explicitly set state_ and hook_ to nullptr to prevent py::object's dtor
// to decref on the PyObject again.
// See Note [Destructing py::object] in python_ivalue.h
state_.ptr() = nullptr;
hook_.ptr() = nullptr;
Expand All @@ -82,10 +70,10 @@ class TORCH_PYTHON_API PythonCommHook : public CommHookInterface {
c10::intrusive_ptr<torch::jit::Future> runHook(
const GradBucket& bucket) override;

std::vector<at::Tensor> processFuture(c10::IValue future_value) override;

private:
// Only needed for stateful communication.
py::object state_;
// Indicates an asynchronous communication of gradients.
py::object hook_;
};

Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/distributed/c10d/reducer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1163,7 +1163,7 @@ void Reducer::finalize_backward() {
bucket.future_work->wait();

auto future_result =
comm_hook_->processFuture(bucket.future_work->value());
comm_hook_->parseFromHookResult(bucket.future_work->value());

for (size_t i = 0; i < future_result.size(); i++) {
auto& replica = bucket.replicas[i];
Expand Down

0 comments on commit 99e5925

Please sign in to comment.