
[Gradient Compression] Surface C++ comm hooks to Python API as built-in comm hooks

Pull Request resolved: #47270

This is almost the same as #46959, except that in caffe2/torch/nn/parallel/distributed.py, BuiltinCommHookType should be imported conditionally, only when dist.is_available(). Otherwise, this Python enum type defined in caffe2/torch/csrc/distributed/c10d/init.cpp cannot be imported. See #47153
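The guarded import pattern looks roughly like the sketch below (an illustration only, not the literal lines from distributed.py; the exact import path may differ):

import torch.distributed as dist

if dist.is_available():
    # BuiltinCommHookType is a Python enum bound in
    # torch/csrc/distributed/c10d/init.cpp, so it only exists when the
    # distributed package is built; guard the import accordingly.
    from torch.distributed import BuiltinCommHookType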

I tried to follow another enum type, ReduceOp, defined in the same file, but that did not work, because that C++ enum class is defined in the torch/lib/c10d library, while BuiltinCommHookType is defined in the torch/csrc/distributed library. These two libraries are compiled in two different ways.

To avoid adding typing to the distributed package, which could be a project of its own, I simply removed the BuiltinCommHookType argument annotation in this file.

To review the diff on top of #46959, compare V1 vs Latest:
https://www.internalfb.com/diff/D24700959?src_version_fbid=270445741055617

Main changes in V1 (#46959), with a usage sketch after this list:
1. Implemented the pybind11 bindings.
2. In the reducer, once builtin_comm_hook_type is set, a C++ comm hook instance is created in Reducer::autograd_hook.
3. Added unit tests for the built-in comm hooks.
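For context, a minimal usage sketch of the resulting Python API (MyModel and device_id are placeholders; an initialized NCCL process group is assumed; the DDP method is the private _register_builtin_comm_hook added in this PR):

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# MyModel and device_id are placeholders for illustration.
model = DistributedDataParallel(MyModel().to(device_id), device_ids=[device_id])

# Ask the reducer to use the built-in C++ FP16_COMPRESS hook; the actual
# hook instance is created once the type is set, in Reducer::autograd_hook
# (see change 2 above).
model._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS)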

Original PR issue: C++ DDP Communication Hook #46348
ghstack-source-id: 115783237

Differential Revision: [D24700959](https://our.internmc.facebook.com/intern/diff/D24700959/)
wayi committed Nov 3, 2020
1 parent 09a5267 commit f8f7c3b
Showing 9 changed files with 192 additions and 21 deletions.
45 changes: 44 additions & 1 deletion test/distributed/test_c10d.py
@@ -3398,6 +3398,21 @@ def _gpu_model_with_ddp_comm_hook(self, process_group, hook=None, gradient_as_bu

        return gpu_model

    def _gpu_model_with_builtin_ddp_comm_hook(self, process_group, hook=None, gradient_as_bucket_view=False):
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            ModuleForDdpCommHook().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        # Register a built-in DDP communication hook if defined
        if hook is not None:
            gpu_model._register_builtin_comm_hook(hook)

        return gpu_model

    def _run_and_verify_hook(self, model, input, expected_grad):
        # Run forward
        output = model(input, self.rank)
@@ -3474,18 +3489,46 @@ def allreduce_hook(state: object, bucket: dist._GradBucket) -> torch._C.Future:
        # check whether the grads are equal to what DDP without hook would return.
        self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2))

    def _test_builtin_ddp_comm_hooks_nccl(self, gradient_as_bucket_view=False):
        """
        This unit test verifies whether the built-in DDP communication hooks ALLREDUCE and FP16_COMPRESS
        give the same result as when no hook is registered.
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)

        for comm_hook_type in [dist.BuiltinCommHookType.ALLREDUCE, dist.BuiltinCommHookType.FP16_COMPRESS]:
            # Get GPU model with the built-in communication hook.
            gpu_model = self._gpu_model_with_builtin_ddp_comm_hook(
                process_group, comm_hook_type, gradient_as_bucket_view)

            # check whether the grads are equal to what DDP without hook would return.
            self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2))

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    @skip_if_rocm
    def test_ddp_comm_hook_allreduce_hook_nccl(self):
        self._test_ddp_comm_hook_allreduce_hook_nccl()

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    @skip_if_rocm
    def test_builtin_ddp_comm_hooks_nccl(self):
        self._test_builtin_ddp_comm_hooks_nccl()

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    @skip_if_rocm
    def test_ddp_comm_hook_allreduce_hook_nccl_grad_is_view(self):
        self._test_ddp_comm_hook_allreduce_hook_nccl(gradient_as_bucket_view=True)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    @skip_if_rocm
    def test_builtin_ddp_comm_hooks_nccl_grad_is_view(self):
        self._test_builtin_ddp_comm_hooks_nccl(gradient_as_bucket_view=True)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    @skip_if_rocm
@@ -3603,7 +3646,7 @@ def dummy_hook(state, bucket):
        model._register_comm_hook(None, dummy_hook)

        with self.assertRaisesRegex(
            RuntimeError, "register_comm_hook can only be called once."
            RuntimeError, "register_comm_hook or register_builtin_comm_hook can only be called once."
        ):
            model._register_comm_hook(None, dummy_hook)

3 changes: 2 additions & 1 deletion torch/csrc/distributed/c10d/comm.h
@@ -60,8 +60,9 @@ class TORCH_PYTHON_API CommHookInterface {

// This CppCommHook interface only requires implementing runHook method that
// potentially uses a state.
// Still need TORCH_PYTHON_API instead of TORCH_API to support Windows platform.
template <typename T>
class TORCH_API CppCommHookInterface : public CommHookInterface {
class TORCH_PYTHON_API CppCommHookInterface : public CommHookInterface {
public:
explicit CppCommHookInterface(T& state) : state_(state) {}

13 changes: 13 additions & 0 deletions torch/csrc/distributed/c10d/default_comm_hooks.h
@@ -5,13 +5,26 @@

namespace c10d {

enum class BuiltinCommHookType {
ALLREDUCE = 1,
FP16_COMPRESS = 2,
};

class AllReduceCommHook : public CppCommHookInterface<ProcessGroup*> {
public:
explicit AllReduceCommHook(ProcessGroup* state)
: CppCommHookInterface<ProcessGroup*>(state) {}

~AllReduceCommHook() override {}

c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override;
};

class FP16CompressCommHook : public CppCommHookInterface<ProcessGroup*> {
public:
explicit FP16CompressCommHook(ProcessGroup* state)
: CppCommHookInterface<ProcessGroup*>(state) {}

~FP16CompressCommHook() override {}

c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override;
44 changes: 31 additions & 13 deletions torch/csrc/distributed/c10d/init.cpp
@@ -124,19 +124,25 @@ class PythonStore : public ::c10d::Store {
}
};

// This method is called from DDP's Python API. Its inputs are
// a c10d reducer object, state, and callable comm_hook. State and
// comm_hook inputs are Python objects and this function creates a
// c10d PythonCommHook object using these inputs. It later calls
// register_comm_hook function of the reducer input to register that
// PythonCommHook object.
// Called from DDP's Python API to create a c10d Python comm hook object.
// The input state and callable comm_hook are Python objects. It later calls
// register_comm_hook function of the reducer input to register the hook.
void _register_comm_hook(
::c10d::Reducer& reducer,
py::object state,
py::object comm_hook) {
reducer.register_comm_hook(std::make_unique<::c10d::PythonCommHook>(
std::move(state), std::move(comm_hook)));
};
}

// Called from DDP's Python API to create a c10d C++ comm hook.
// The input is an enum hook type. It later calls register_builtin_comm_hook
// function of the reducer input to set the hook type.
void _register_builtin_comm_hook(
::c10d::Reducer& reducer,
::c10d::BuiltinCommHookType comm_hook_type) {
reducer.register_builtin_comm_hook(comm_hook_type);
}

PyObject* c10d_init(PyObject* _unused, PyObject* noargs) {
C10_LOG_API_USAGE_ONCE("c10d.python.import");
@@ -147,12 +153,19 @@ PyObject* c10d_init(PyObject* _unused, PyObject* noargs) {

auto module = py::handle(c10d_module).cast<py::module>();

module.def(
"_register_comm_hook",
&_register_comm_hook,
py::arg("reducer"),
py::arg("state"),
py::arg("comm_hook"));
module
.def(
"_register_comm_hook",
&_register_comm_hook,
py::arg("reducer"),
py::arg("state"),
py::arg("comm_hook"),
py::call_guard<py::gil_scoped_release>())
.def(
"_register_builtin_comm_hook",
&_register_builtin_comm_hook,
py::arg("reducer"),
py::arg("comm_hook_type"));

shared_ptr_class_<::c10d::GradBucket>(module, "_GradBucket")
.def(py::init<std::vector<Tensor>&>(), py::arg("tensors"))
@@ -168,6 +181,11 @@
a single tensor.
)");

py::enum_<::c10d::BuiltinCommHookType>(module, "BuiltinCommHookType", R"(
An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_COMPRESS``.)")
.value("ALLREDUCE", ::c10d::BuiltinCommHookType::ALLREDUCE)
.value("FP16_COMPRESS", ::c10d::BuiltinCommHookType::FP16_COMPRESS);

shared_ptr_class_<::c10d::Reducer>(module, "Reducer")
.def(
py::init<
33 changes: 31 additions & 2 deletions torch/csrc/distributed/c10d/reducer.cpp
@@ -208,7 +208,8 @@ Reducer::Reducer(
// used for algorithms like Gradient Compression/GossipGrad. This hook can be
// registered from Python API using `register_comm_hook`. `PythonCommHook`
// enables registering a Python hook and is a subclass of `CommHookInterface`.
// `CommHookInterface` can be used to implement CPP hooks in the future.
// Additionally, there are also some built-in C++ hook implementations that can
// be specified by calling `register_builtin_comm_hook` from Python API.

Reducer::~Reducer() noexcept(false) {
// Remove all hooks on variables registered by this Reducer. This is necessary
@@ -1369,7 +1370,8 @@ bool Reducer::rebuild_buckets() {
// See Note [DDP Communication Hook]
void Reducer::register_comm_hook(std::unique_ptr<CommHookInterface> iface) {
TORCH_CHECK(
comm_hook_ == nullptr, "register_comm_hook can only be called once.");
comm_hook_ == nullptr,
"register_comm_hook or register_builtin_comm_hook can only be called once.");
// TODO(@sinannasir): Single-process multiple-device mode support for DDP
// communication hook. Related to GH Issue #42542.
TORCH_CHECK(
@@ -1379,6 +1381,33 @@ void Reducer::register_comm_hook(std::unique_ptr<CommHookInterface> iface) {
comm_hook_ = std::move(iface);
}

// See Note [DDP Communication Hook]
void Reducer::register_builtin_comm_hook(
c10d::BuiltinCommHookType comm_hook_type) {
TORCH_CHECK(
comm_hook_ == nullptr,
"register_builtin_comm_hook or register_comm_hook can only be called once.");
TORCH_CHECK(
replicas_.size() == 1,
"Communication hook does not support single-process multiple-device mode.");

switch (comm_hook_type) {
case c10d::BuiltinCommHookType::ALLREDUCE:
comm_hook_ =
std::make_unique<c10d::AllReduceCommHook>(process_group_.get());
LOG(INFO) << "Built-in communication hook ALLREDUCE is registered.";
break;
case c10d::BuiltinCommHookType::FP16_COMPRESS:
comm_hook_ =
std::make_unique<c10d::FP16CompressCommHook>(process_group_.get());
LOG(INFO) << "Built-in communication hook FP16_COMPRESS is registered.";
break;
default:
TORCH_WARN_ONCE(
"Unknown built-in DDP comm hook type is provided. No comm hook will be used.");
}
}

void Reducer::ensure_prior_reduction_finished() {
// Check that any prior reduction has finished.
// The variable `require_finalize_` is true until all gradients
7 changes: 7 additions & 0 deletions torch/csrc/distributed/c10d/reducer.h
@@ -12,6 +12,7 @@
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/distributed/autograd/context/context.h>
#include <torch/csrc/distributed/c10d/comm.h>
#include <torch/csrc/distributed/c10d/default_comm_hooks.h>

namespace c10d {

@@ -58,8 +59,14 @@ class Reducer {
// Registers a hook to the reducer. The hook is `CommHookInterface`
// type to allow both Python and CPP hooks. This function can only
// be called once before calling backward.
// Cannot combine with the call of `register_builtin_comm_hook`.
void register_comm_hook(std::unique_ptr<CommHookInterface> iface);

// Registers a built-in C++ comm hook to the reducer. This function can only
// be called once before calling backward.
// Cannot combine with the call of `register_comm_hook`.
void register_builtin_comm_hook(c10d::BuiltinCommHookType comm_hook_type);

// Returns a vector of tensors in each bucket in sequential order.
std::vector<std::vector<at::Tensor>> get_bucket_tensors() const;

2 changes: 2 additions & 0 deletions torch/distributed/algorithms/ddp_comm_hooks/__init__.py
@@ -1,6 +1,7 @@
from enum import Enum
from functools import partial

import torch.distributed as dist
import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default
import torch.distributed.algorithms.ddp_comm_hooks.quantization_hooks as quantization
from torch.nn.parallel import DistributedDataParallel
@@ -38,6 +39,7 @@ def register_ddp_comm_hook(
    to the DDP model. User can specify the type of hook as an enum
    ``DDPCommHookType`` type using ``comm_hook_type`` input. State input will
    be passed to the model.
    Uses Python comm hook implementations.
    Example::
        >>> register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, model, state)
30 changes: 27 additions & 3 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
@@ -14,6 +14,8 @@
#include <ATen/cuda/CUDAEvent.h>
#include <c10/core/Stream.h>
#include <c10/core/StreamGuard.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>

namespace c10d {

@@ -302,8 +304,29 @@

// Do not free the underlying data storage of value_ before its
// usage on futureNCCLCallbackStream_ finish.
TORCH_INTERNAL_ASSERT(record_stream_cb_);
record_stream_cb_(value_, futureNCCLCallbackStream_->unwrap());
if (record_stream_cb_ != nullptr) {
// If a Python communication hook is used, record_stream_cb_ will be
// set in torch/csrc/jit/python/pybind_utils.h, which allows Python
// dependency to be imported.
record_stream_cb_(value_, futureNCCLCallbackStream_->unwrap());
} else {
// If a C++ communication hook is used, create and set a record stream
// callback.
TORCH_INTERNAL_ASSERT(
value_.isTensorList() || value_.isTensor(),
"the future value must be either a tensor list or a tensor.");
at::Tensor tensor;
if (value_.isTensorList()) {
const auto tensors = value_.toTensorVector();
TORCH_INTERNAL_ASSERT(
tensors.size() == 1, "expected exactly 1 tensor");
tensor = tensors[0];
} else {
tensor = value_.toTensor();
}
c10::cuda::CUDACachingAllocator::recordStream(
tensor.storage().data_ptr(), *futureNCCLCallbackStream_);
}

// Use the dedicated callback stream to run callback.
// Cannot move capture std::function in lambda, because it cannot deduce
@@ -558,7 +581,8 @@ class ProcessGroupNCCL : public ProcessGroup {
// This function iterates through the list of WorkNCCL objects in the
// workList_ corresponding to incomplete collectives and then aborts NCCL
// communicators associated with timed out collectives.
void abortTimedOutCollectives(std::unordered_set<std::string>& abortedCommIds);
void abortTimedOutCollectives(
std::unordered_set<std::string>& abortedCommIds);

void workCleanupLoop();

36 changes: 35 additions & 1 deletion torch/nn/parallel/distributed.py
@@ -975,7 +975,7 @@ def join(self, divide_by_initial_world_size=True, enable=True):

    def _register_comm_hook(self, state: object, hook: callable):
        r"""
        Register a communication hook which is an enhancement that provides a
        Registers a communication hook which is an enhancement that provides a
        flexible hook to users where they can specify how DDP aggregates gradients
        across multiple workers.
@@ -1060,6 +1060,40 @@ def _register_comm_hook(self, state: object, hook: callable):
        self._check_comm_hook(hook)
        dist._register_comm_hook(self.reducer, state, hook)

    def _register_builtin_comm_hook(
        self, comm_hook_type
    ):
        r"""
        Registers a built-in communication hook that specifies how DDP
        aggregates gradients across multiple workers.
        The built-in hooks provide efficient C++ implementations of certain hooks,
        which would be less efficient if implemented in Python as Python communication hooks.
        Arguments:
            comm_hook_type (dist.BuiltinCommHookType): type of communication hook, such as
                ALLREDUCE, FP16_COMPRESS, etc.
        .. warning ::
            DDP communication hook can only be registered once and should be registered
            before calling backward.
        .. warning ::
            DDP communication hook does not support single-process multiple-device mode.
            GradBucket tensors should consist of only a single tensor.
        .. warning ::
            DDP communication hook is experimental and subject to change.
        Example::
            Below is an example of FP16 compression, where gradients are
            compressed into 16-bit floating-point numbers before allreduce and
            then decompressed after allreduce.
            >>> ddp._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS)
        """
        dist._register_builtin_comm_hook(self.reducer, comm_hook_type)

    def _distributed_broadcast_coalesced(
        self, tensors, buffer_size, authoritative_rank=0
    ):
