[Gradient Compression] Surface C++ comm hooks to Python API as built-in comm hooks #46959

Closed
wants to merge 7 commits
45 changes: 44 additions & 1 deletion test/distributed/test_c10d.py
@@ -3398,6 +3398,21 @@ def _gpu_model_with_ddp_comm_hook(self, process_group, hook=None, gradient_as_bu

return gpu_model

def _gpu_model_with_builtin_ddp_comm_hook(self, process_group, hook=None, gradient_as_bucket_view=False):
device_id = gpus_for_rank(self.world_size)[self.rank][0]
gpu_model = DistributedDataParallel(
ModuleForDdpCommHook().to(device_id),
device_ids=[device_id],
process_group=process_group,
gradient_as_bucket_view=gradient_as_bucket_view,
)

# Register a built-in DDP communication hook if defined
if hook is not None:
gpu_model._register_builtin_comm_hook(hook)

return gpu_model

def _run_and_verify_hook(self, model, input, expected_grad):
# Run forward
output = model(input, self.rank)
@@ -3474,18 +3489,46 @@ def allreduce_hook(state: object, bucket: dist._GradBucket) -> torch._C.Future:
# check whether the grads are equal to what DDP without hook would return.
self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2))

def _test_builtin_ddp_comm_hooks_nccl(self, gradient_as_bucket_view=False):
"""
This unit test verifies whether built-in DDP communication hooks ALLREDUCE and FP16_COMPRESS
can give the same result as the case of no hook registered.
"""
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)

for comm_hook_type in [dist.BuiltinCommHookType.ALLREDUCE, dist.BuiltinCommHookType.FP16_COMPRESS]:
# Get GPU model with the specified built-in communication hook.
gpu_model = self._gpu_model_with_builtin_ddp_comm_hook(
process_group, comm_hook_type, gradient_as_bucket_view)

# check whether the grads are equal to what DDP without hook would return.
self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2))

@requires_nccl()
@skip_if_lt_x_gpu(2)
@skip_if_rocm
def test_ddp_comm_hook_allreduce_hook_nccl(self):
self._test_ddp_comm_hook_allreduce_hook_nccl()

@requires_nccl()
@skip_if_lt_x_gpu(2)
@skip_if_rocm
def test_builtin_ddp_comm_hooks_nccl(self):
self._test_builtin_ddp_comm_hooks_nccl()

@requires_nccl()
@skip_if_lt_x_gpu(2)
@skip_if_rocm
def test_ddp_comm_hook_allreduce_hook_nccl_grad_is_view(self):
self._test_ddp_comm_hook_allreduce_hook_nccl(gradient_as_bucket_view=True)

@requires_nccl()
@skip_if_lt_x_gpu(2)
@skip_if_rocm
def test_builtin_ddp_comm_hooks_nccl_grad_is_view(self):
self._test_builtin_ddp_comm_hooks_nccl(gradient_as_bucket_view=True)

@requires_nccl()
@skip_if_lt_x_gpu(2)
@skip_if_rocm
@@ -3603,7 +3646,7 @@ def dummy_hook(state, bucket):
model._register_comm_hook(None, dummy_hook)

with self.assertRaisesRegex(
RuntimeError, "register_comm_hook can only be called once."
RuntimeError, "register_comm_hook or register_builtin_comm_hook can only be called once."
):
model._register_comm_hook(None, dummy_hook)

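For context, the updated error message reflects that the two registration paths share one slot: once any hook (Python or built-in) is registered, a second registration of either kind raises. A minimal sketch of that behavior, assuming `ddp_model` is an existing DistributedDataParallel instance on a NCCL process group and using the `_GradBucket.get_tensors()` accessor seen in the Python hooks above:

    import torch
    import torch.distributed as dist

    def dummy_hook(state, bucket):
        # Return the bucket's tensors unchanged via an already-completed future.
        fut = torch.futures.Future()
        fut.set_result(bucket.get_tensors())
        return fut

    ddp_model._register_comm_hook(None, dummy_hook)

    # Any further registration now fails with:
    # RuntimeError: register_comm_hook or register_builtin_comm_hook can only be called once.
    ddp_model._register_builtin_comm_hook(dist.BuiltinCommHookType.ALLREDUCE)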
3 changes: 2 additions & 1 deletion torch/csrc/distributed/c10d/comm.h
@@ -82,8 +82,9 @@ class TORCH_PYTHON_API PythonCommHook : public CommHookInterface {

// This CppCommHook interface only requires implementing runHook method that
// potentially uses a state.
// TORCH_PYTHON_API is still needed here, instead of TORCH_API, to support the Windows platform.
template <typename T>
class TORCH_API CppCommHookInterface : public CommHookInterface {
class TORCH_PYTHON_API CppCommHookInterface : public CommHookInterface {
public:
explicit CppCommHookInterface(T& state) : state_(state) {}

13 changes: 13 additions & 0 deletions torch/csrc/distributed/c10d/default_comm_hooks.h
@@ -5,13 +5,26 @@

namespace c10d {

enum class BuiltinCommHookType {
ALLREDUCE = 1,
FP16_COMPRESS = 2,
};

class AllReduceCommHook : public CppCommHookInterface<ProcessGroup*> {
public:
explicit AllReduceCommHook(ProcessGroup* state)
: CppCommHookInterface<ProcessGroup*>(state) {}

~AllReduceCommHook() override {}

c10::intrusive_ptr<torch::jit::Future> runHook(GradBucket& bucket) override;
};

class FP16CompressCommHook : public CppCommHookInterface<ProcessGroup*> {
public:
explicit FP16CompressCommHook(ProcessGroup* state)
: CppCommHookInterface<ProcessGroup*>(state) {}

~FP16CompressCommHook() override {}

c10::intrusive_ptr<torch::jit::Future> runHook(GradBucket& bucket) override;
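For intuition, FP16_COMPRESS amounts to halving gradient precision before the allreduce and restoring it afterwards. A rough Python analogue of what the C++ hook does (a sketch only, not the actual implementation; it assumes the `_GradBucket.get_tensors()` and `Work.get_future()` APIs used by the Python hooks in this code base):

    import torch
    import torch.distributed as dist

    def fp16_compress_sketch(process_group, bucket):
        group = process_group if process_group is not None else dist.group.WORLD
        world_size = dist.get_world_size(group)

        # Compress: cast the bucket's flattened gradients to fp16 and pre-divide
        # by the world size so the allreduce sum yields an average.
        compressed = bucket.get_tensors()[0].to(torch.float16).div_(world_size)
        fut = dist.all_reduce(compressed, group=group, async_op=True).get_future()

        def decompress(fut):
            # Decompress: copy the reduced fp16 result back into the original buffer.
            out = bucket.get_tensors()[0]
            out.copy_(fut.value()[0])
            return [out]

        return fut.then(decompress)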
45 changes: 32 additions & 13 deletions torch/csrc/distributed/c10d/init.cpp
@@ -26,6 +26,7 @@

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/distributed/c10d/comm.h>
#include <torch/csrc/distributed/c10d/default_comm_hooks.h>
#include <torch/csrc/distributed/c10d/reducer.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/object_ptr.h>
@@ -123,19 +124,25 @@ class PythonStore : public ::c10d::Store {
}
};

// This method is called from DDP's Python API. Its inputs are
// a c10d reducer object, state, and callable comm_hook. State and
// comm_hook inputs are Python objects and this function creates a
// c10d PythonCommHook object using these inputs. It later calls
// register_comm_hook function of the reducer input to register that
// PythonCommHook object.
// Called from DDP's Python API to create a c10d Python comm hook object.
// The input state and callable comm_hook are Python objects. It then calls
// the reducer's register_comm_hook function to register the hook.
void _register_comm_hook(
::c10d::Reducer& reducer,
py::object state,
py::object comm_hook) {
reducer.register_comm_hook(std::make_unique<::c10d::PythonCommHook>(
std::move(state), std::move(comm_hook)));
};
}

// Called from DDP's Python API to register a built-in C++ comm hook.
// The input is an enum hook type. It then calls the reducer's
// register_builtin_comm_hook function to create and install the hook.
void _register_builtin_comm_hook(
::c10d::Reducer& reducer,
::c10d::BuiltinCommHookType comm_hook_type) {
reducer.register_builtin_comm_hook(comm_hook_type);
}

PyObject* c10d_init(PyObject* _unused, PyObject* noargs) {
C10_LOG_API_USAGE_ONCE("c10d.python.import");
@@ -146,12 +153,19 @@ PyObject* c10d_init(PyObject* _unused, PyObject* noargs) {

auto module = py::handle(c10d_module).cast<py::module>();

module.def(
"_register_comm_hook",
&_register_comm_hook,
py::arg("reducer"),
py::arg("state"),
py::arg("comm_hook"));
module
.def(
"_register_comm_hook",
&_register_comm_hook,
py::arg("reducer"),
py::arg("state"),
py::arg("comm_hook"),
py::call_guard<py::gil_scoped_release>())
.def(
"_register_builtin_comm_hook",
&_register_builtin_comm_hook,
py::arg("reducer"),
py::arg("comm_hook_type"));

shared_ptr_class_<::c10d::GradBucket>(module, "_GradBucket")
.def(py::init<std::vector<Tensor>&>(), py::arg("tensors"))
@@ -167,6 +181,11 @@
a single tensor.
)");

py::enum_<::c10d::BuiltinCommHookType>(module, "BuiltinCommHookType", R"(
An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_COMPRESS``.)")
.value("ALLREDUCE", ::c10d::BuiltinCommHookType::ALLREDUCE)
.value("FP16_COMPRESS", ::c10d::BuiltinCommHookType::FP16_COMPRESS);

shared_ptr_class_<::c10d::Reducer>(module, "Reducer")
.def(
py::init<
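On the Python side, these bindings surface as a new enum and a new module-level registration function on torch.distributed; a short sketch of the exposed surface (the `reducer` argument is the internal c10d Reducer that DistributedDataParallel owns and passes in on the user's behalf):

    import torch.distributed as dist

    # New enum values exposed by the py::enum_ binding above.
    dist.BuiltinCommHookType.ALLREDUCE
    dist.BuiltinCommHookType.FP16_COMPRESS

    # New module-level binding; normally called indirectly through
    # DistributedDataParallel._register_builtin_comm_hook rather than directly:
    # dist._register_builtin_comm_hook(reducer, dist.BuiltinCommHookType.FP16_COMPRESS)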
33 changes: 31 additions & 2 deletions torch/csrc/distributed/c10d/reducer.cpp
@@ -197,7 +197,8 @@ Reducer::Reducer(
// used for algorithms like Gradient Compression/GossipGrad. This hook can be
// registered from Python API using `register_comm_hook`. `PythonCommHook`
// enables registering a Python hook and is a subclass of `CommHookInterface`.
// `CommHookInterface` can be used to implement CPP hooks in the future.
// Additionally, some built-in C++ hook implementations can be selected by
// calling `register_builtin_comm_hook` from the Python API.

Reducer::~Reducer() noexcept(false) {
// Remove all hooks on variables registered by this Reducer. This is necessary
@@ -1357,7 +1358,8 @@ bool Reducer::rebuild_buckets() {
// See Note [DDP Communication Hook]
void Reducer::register_comm_hook(std::unique_ptr<CommHookInterface> iface) {
TORCH_CHECK(
comm_hook_ == nullptr, "register_comm_hook can only be called once.");
comm_hook_ == nullptr,
"register_comm_hook or register_builtin_comm_hook can only be called once.");
// TODO(@sinannasir): Single-process multiple-device mode support for DDP
// communication hook. Related to GH Issue #42542.
TORCH_CHECK(
@@ -1367,6 +1369,33 @@ void Reducer::register_comm_hook(std::unique_ptr<CommHookInterface> iface) {
comm_hook_ = std::move(iface);
}

// See Note [DDP Communication Hook]
void Reducer::register_builtin_comm_hook(
c10d::BuiltinCommHookType comm_hook_type) {
TORCH_CHECK(
comm_hook_ == nullptr,
"register_builtin_comm_hook or register_comm_hook can only be called once.");
TORCH_CHECK(
replicas_.size() == 1,
"Communication hook does not support single-process multiple-device mode.");

switch (comm_hook_type) {
case c10d::BuiltinCommHookType::ALLREDUCE:
comm_hook_ =
std::make_unique<c10d::AllReduceCommHook>(process_group_.get());
LOG(INFO) << "Built-in communication hook ALLREDUCE is registered.";
break;
case c10d::BuiltinCommHookType::FP16_COMPRESS:
comm_hook_ =
std::make_unique<c10d::FP16CompressCommHook>(process_group_.get());
LOG(INFO) << "Built-in communication hook FP16_COMPRESS is registered.";
break;
default:
TORCH_WARN_ONCE(
"Unknown built-in DDP comm hook type is provided. No comm hook will be used.");
}
}

void Reducer::ensure_prior_reduction_finished() {
// Check that any prior reduction has finished.
// The variable `require_finalize_` is true until all gradients
7 changes: 7 additions & 0 deletions torch/csrc/distributed/c10d/reducer.h
@@ -12,6 +12,7 @@
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/distributed/autograd/context/context.h>
#include <torch/csrc/distributed/c10d/comm.h>
#include <torch/csrc/distributed/c10d/default_comm_hooks.h>

namespace c10d {

@@ -58,8 +59,14 @@ class Reducer {
// Registers a hook to the reducer. The hook is `CommHookInterface`
// type to allow both Python and CPP hooks. This function can only
// be called once before calling backward.
// Cannot combine with the call of `register_builtin_comm_hook`.
void register_comm_hook(std::unique_ptr<CommHookInterface> iface);

// Registers a built-in C++ comm hook to the reducer. This function can only
// be called once before calling backward.
// Cannot combine with the call of `register_comm_hook`.
void register_builtin_comm_hook(c10d::BuiltinCommHookType comm_hook_type);

// Returns a vector of tensors in each bucket in sequential order.
std::vector<std::vector<at::Tensor>> get_bucket_tensors() const;

2 changes: 2 additions & 0 deletions torch/distributed/algorithms/ddp_comm_hooks/__init__.py
@@ -1,6 +1,7 @@
from enum import Enum
from functools import partial

import torch.distributed as dist
import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default
import torch.distributed.algorithms.ddp_comm_hooks.quantization_hooks as quantization
from torch.nn.parallel import DistributedDataParallel
@@ -38,6 +39,7 @@ def register_ddp_comm_hook(
to the DDP model. The user can specify the hook type as a ``DDPCommHookType``
enum using the ``comm_hook_type`` input. The ``state`` input will be passed
to the model.
This helper uses Python comm hook implementations.

Example::
>>> register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, model, state)
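For comparison, the two registration paths for FP16 compression now look as follows (a sketch; `ddp_model` is an existing DistributedDataParallel instance and `state` is typically a process group or None). Only one of the two registrations may be used on a given model:

    import torch.distributed as dist
    from torch.distributed.algorithms.ddp_comm_hooks import (
        DDPCommHookType,
        register_ddp_comm_hook,
    )

    # Existing path: Python hook implementation registered through this helper.
    register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, ddp_model, state)

    # New path from this PR: built-in C++ hook implementation.
    # ddp_model._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS)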
30 changes: 27 additions & 3 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
@@ -14,6 +14,8 @@
#include <ATen/cuda/CUDAEvent.h>
#include <c10/core/Stream.h>
#include <c10/core/StreamGuard.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>

namespace c10d {

@@ -302,8 +304,29 @@

// Do not free the underlying data storage of value_ before its
// usage on futureNCCLCallbackStream_ finish.
TORCH_INTERNAL_ASSERT(record_stream_cb_);
record_stream_cb_(value_, futureNCCLCallbackStream_->unwrap());
if (record_stream_cb_ != nullptr) {
// If a Python communication hook is used, record_stream_cb_ will have been
// set in torch/csrc/jit/python/pybind_utils.h, which is allowed to take a
// Python dependency.
record_stream_cb_(value_, futureNCCLCallbackStream_->unwrap());
} else {
// If a C++ communication hook is used, create and set a record stream
// callback.
TORCH_INTERNAL_ASSERT(
value_.isTensorList() || value_.isTensor(),
"the future value must be either a tensor list or a tensor.");
at::Tensor tensor;
if (value_.isTensorList()) {
const auto tensors = value_.toTensorVector();
TORCH_INTERNAL_ASSERT(
tensors.size() == 1, "expected exactly 1 tensor");
tensor = tensors[0];
} else {
tensor = value_.toTensor();
}
c10::cuda::CUDACachingAllocator::recordStream(
tensor.storage().data_ptr(), *futureNCCLCallbackStream_);
}

// Use the dedicated callback stream to run callback.
// Cannot move capture std::function in lambda, because it cannot deduce
@@ -558,7 +581,8 @@ class ProcessGroupNCCL : public ProcessGroup {
// This function iterates through the list of WorkNCCL objects in the
// workList_ corresponding to incomplete collectives and then aborts NCCL
// communicators associated with timed out collectives.
void abortTimedOutCollectives(std::unordered_set<std::string>& abortedCommIds);
void abortTimedOutCollectives(
std::unordered_set<std::string>& abortedCommIds);

void workCleanupLoop();

36 changes: 35 additions & 1 deletion torch/nn/parallel/distributed.py
@@ -975,7 +975,7 @@ def join(self, divide_by_initial_world_size=True, enable=True):

def _register_comm_hook(self, state: object, hook: callable):
r"""
Register a communication hook which is an enhancement that provides a
Registers a communication hook which is an enhancement that provides a
flexible hook to users where they can specify how DDP aggregates gradients
across multiple workers.

@@ -1060,6 +1060,40 @@ def _register_comm_hook(self, state: object, hook: callable):
self._check_comm_hook(hook)
dist._register_comm_hook(self.reducer, state, hook)

def _register_builtin_comm_hook(
self, comm_hook_type: dist.BuiltinCommHookType
):
r"""
Registers a built-in communication hook that specifies how DDP
aggregates gradients across multiple workers.
The built-in hooks provide efficient C++ implementations for certain hooks,
which would be less efficient if implemented in Python as Python communication hooks.

Arguments:
comm_hook_type (dist.BuiltinCommHookType): type of communication hook, such as
ALLREDUCE, FP16_COMPRESS, etc.

.. warning ::
DDP communication hook can only be registered once and should be registered
before calling backward.

.. warning ::
DDP communication hook does not support single-process multiple-device mode.
GradBucket tensors should consist of only a single tensor.

.. warning ::
DDP communication hook is experimental and subject to change.

Example::
Below is an example of FP16 compression, where gradients are
compressed into 16-bit floating-point numbers before allreduce, and
then decompressed after allreduce.

>>> ddp._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS)

"""
dist._register_builtin_comm_hook(self.reducer, comm_hook_type)

def _distributed_broadcast_coalesced(
self, tensors, buffer_size, authoritative_rank=0
):
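Putting it all together, a minimal single-node training sketch using the new built-in hook (an illustration only; it assumes one process per GPU with RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT set by the launcher, and uses a placeholder model):

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel

    def main():
        dist.init_process_group(backend="nccl")
        rank = dist.get_rank()
        torch.cuda.set_device(rank)

        model = nn.Linear(2, 2).to(rank)
        ddp_model = DistributedDataParallel(model, device_ids=[rank])

        # Register the built-in C++ FP16 compression hook once, before any backward.
        ddp_model._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS)

        out = ddp_model(torch.randn(8, 2, device=rank))
        out.sum().backward()  # the gradient allreduce runs through the fp16 hook

    if __name__ == "__main__":
        main()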