[Gradient Compression] Surface C++ comm hooks to Python API as built-in comm hooks #46959

Closed
wants to merge 7 commits
42 changes: 42 additions & 0 deletions test/distributed/test_c10d.py
@@ -3398,6 +3398,21 @@ def _gpu_model_with_ddp_comm_hook(self, process_group, hook=None, gradient_as_bu

return gpu_model

def _gpu_model_with_builtin_ddp_comm_hook(self, process_group, hook=None, gradient_as_bucket_view=False):
device_id = gpus_for_rank(self.world_size)[self.rank][0]
gpu_model = DistributedDataParallel(
ModuleForDdpCommHook().to(device_id),
device_ids=[device_id],
process_group=process_group,
gradient_as_bucket_view=gradient_as_bucket_view,
)

# Register a built-in DDP communication hook if defined
if hook is not None:
gpu_model._register_builtin_comm_hook(hook)

return gpu_model

def _run_and_verify_hook(self, model, input, expected_grad):
# Run forward
output = model(input, self.rank)
@@ -3474,18 +3489,45 @@ def allreduce_hook(state: object, bucket: dist._GradBucket) -> torch._C.Future:
# check whether the grads are equal to what DDP without hook would return.
self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2))

def _test_builtin_ddp_comm_hook_allreduce_hook_nccl(self, gradient_as_bucket_view=False):
"""
This unit test verifies whether a built-in DDP communication hook that just calls
allreduce gives the same result as the case where no hook is registered.
"""
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)

# Get GPU model with the built-in allreduce communication hook.
gpu_model = self._gpu_model_with_builtin_ddp_comm_hook(
process_group, dist.BuiltinCommHookType.ALLREDUCE, gradient_as_bucket_view)

# check whether the grads are equal to what DDP without hook would return.
self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2))

@requires_nccl()
@skip_if_lt_x_gpu(2)
@skip_if_rocm
def test_ddp_comm_hook_allreduce_hook_nccl(self):
self._test_ddp_comm_hook_allreduce_hook_nccl()

@requires_nccl()
@skip_if_lt_x_gpu(2)
@skip_if_rocm
def test_builtin_ddp_comm_hook_allreduce_hook_nccl(self):
self._test_builtin_ddp_comm_hook_allreduce_hook_nccl()

@requires_nccl()
@skip_if_lt_x_gpu(2)
@skip_if_rocm
def test_ddp_comm_hook_allreduce_hook_nccl_grad_is_view(self):
self._test_ddp_comm_hook_allreduce_hook_nccl(gradient_as_bucket_view=True)

@requires_nccl()
@skip_if_lt_x_gpu(2)
@skip_if_rocm
def test_builtin_ddp_comm_hook_allreduce_hook_nccl_grad_is_view(self):
self._test_builtin_ddp_comm_hook_allreduce_hook_nccl(gradient_as_bucket_view=True)

@requires_nccl()
@skip_if_lt_x_gpu(2)
@skip_if_rocm
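
For context, the sketch below shows how the registration exercised by `_gpu_model_with_builtin_ddp_comm_hook` might look outside the test harness. It is a minimal, hypothetical setup: the store/rank/world-size handling and the model are placeholders, and it assumes the usual one-process-per-GPU NCCL configuration used by these tests.

import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

def build_ddp_with_builtin_allreduce(store, rank, world_size):
    # NCCL process group, as in the tests above.
    process_group = dist.ProcessGroupNCCL(store, rank, world_size)

    # Placeholder model; any nn.Module works here.
    model = nn.Linear(2, 2).to(rank)
    ddp_model = DistributedDataParallel(
        model,
        device_ids=[rank],
        process_group=process_group,
    )

    # Register the built-in C++ allreduce hook, mirroring
    # gpu_model._register_builtin_comm_hook(hook) in the test helper.
    ddp_model._register_builtin_comm_hook(dist.BuiltinCommHookType.ALLREDUCE)
    return ddp_model
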
3 changes: 2 additions & 1 deletion torch/csrc/distributed/c10d/comm.h
@@ -82,8 +82,9 @@ class TORCH_PYTHON_API PythonCommHook : public CommHookInterface {

// This CppCommHook interface only requires implementing the runHook method,
// which may make use of the hook's state.
// TORCH_PYTHON_API is still needed here instead of TORCH_API to support the Windows platform.
template <typename T>
class TORCH_API CppCommHookInterface : public CommHookInterface {
class TORCH_PYTHON_API CppCommHookInterface : public CommHookInterface {
public:
explicit CppCommHookInterface(T& state) : state_(state) {}

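
A `CppCommHookInterface` implementation supplies the same two ingredients a Python hook passes through `register_comm_hook`: a piece of state and a per-bucket hook that returns a future. As a point of comparison, here is a minimal Python-side hook with state, following the `allreduce_hook(state, bucket)` signature from the test file above; the `get_tensors()` accessor on `dist._GradBucket` and the `get_future()` call on the allreduce work object are assumed to be available, as in the existing Python default hooks.

import torch
import torch.distributed as dist

def allreduce_with_state_hook(process_group, bucket: dist._GradBucket) -> torch._C.Future:
    # `process_group` plays the role of the state_ member held by CppCommHookInterface.
    world_size = process_group.size()

    # One flattened tensor per bucket (assumed accessor).
    tensor = bucket.get_tensors()[0]
    fut = dist.all_reduce(tensor, group=process_group, async_op=True).get_future()

    def average(fut):
        # Divide the summed gradients so the result matches plain DDP.
        return [fut.value()[0].div_(world_size)]

    return fut.then(average)
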
14 changes: 14 additions & 0 deletions torch/csrc/distributed/c10d/default_comm_hooks.h
@@ -5,13 +5,27 @@

namespace c10d {

enum class BuiltinCommHookType {
NONE = 0,
ALLREDUCE = 1,
FP16_COMPRESS = 2,
};

class AllReduceCommHook : public CppCommHookInterface<ProcessGroup*> {
public:
explicit AllReduceCommHook(ProcessGroup* state)
: CppCommHookInterface<ProcessGroup*>(state) {}

~AllReduceCommHook() override {}

c10::intrusive_ptr<torch::jit::Future> runHook(GradBucket& bucket) override;
};

class FP16CompressCommHook : public CppCommHookInterface<ProcessGroup*> {
public:
explicit FP16CompressCommHook(ProcessGroup* state)
: CppCommHookInterface<ProcessGroup*>(state) {}

~FP16CompressCommHook() override {}

c10::intrusive_ptr<torch::jit::Future> runHook(GradBucket& bucket) override;
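
To make the new enum values concrete, here is an illustrative Python re-implementation of what `FP16CompressCommHook` is expected to do: compress the bucket to float16, allreduce, average, and decompress. This is a sketch rather than the C++ code itself, and it assumes the same `_GradBucket.get_tensors()` and `get_future()` APIs used by the existing Python default hooks, with float32 gradients.

import torch
import torch.distributed as dist

def fp16_compress_sketch(process_group, bucket: dist._GradBucket) -> torch._C.Future:
    world_size = process_group.size()

    # Compress the bucket's flattened gradient tensor to half precision.
    compressed = bucket.get_tensors()[0].to(torch.float16)
    fut = dist.all_reduce(compressed, group=process_group, async_op=True).get_future()

    def decompress(fut):
        # Average the summed gradients and cast back to full precision.
        return [fut.value()[0].div_(world_size).to(torch.float32)]

    return fut.then(decompress)

The built-in C++ hook is selected purely by the `FP16_COMPRESS` enum value, so no Python callback is involved when the hook actually runs.
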
46 changes: 33 additions & 13 deletions torch/csrc/distributed/c10d/init.cpp
@@ -26,6 +26,7 @@

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/distributed/c10d/comm.h>
#include <torch/csrc/distributed/c10d/default_comm_hooks.h>
#include <torch/csrc/distributed/c10d/reducer.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/object_ptr.h>
@@ -123,19 +124,25 @@ class PythonStore : public ::c10d::Store {
}
};

// This method is called from DDP's Python API. Its inputs are
// a c10d reducer object, state, and callable comm_hook. State and
// comm_hook inputs are Python objects and this function creates a
// c10d PythonCommHook object using these inputs. It later calls
// register_comm_hook function of the reducer input to register that
// PythonCommHook object.
// Called from DDP's Python API to create a c10d PythonCommHook object.
// The input state and callable comm_hook are Python objects. It then calls
// the reducer's register_comm_hook function to register the hook.
void _register_comm_hook(
::c10d::Reducer& reducer,
py::object state,
py::object comm_hook) {
reducer.register_comm_hook(std::make_unique<::c10d::PythonCommHook>(
std::move(state), std::move(comm_hook)));
};
}

// Called from DDP's Python API to register a built-in C++ comm hook.
// The input is an enum hook type. It then calls the reducer's
// register_builtin_comm_hook function to set the hook type.
void _register_builtin_comm_hook(
::c10d::Reducer& reducer,
::c10d::BuiltinCommHookType comm_hook_type) {
reducer.register_builtin_comm_hook(comm_hook_type);
}

PyObject* c10d_init(PyObject* _unused, PyObject* noargs) {
C10_LOG_API_USAGE_ONCE("c10d.python.import");
@@ -146,12 +153,19 @@ PyObject* c10d_init(PyObject* _unused, PyObject* noargs) {

auto module = py::handle(c10d_module).cast<py::module>();

module.def(
"_register_comm_hook",
&_register_comm_hook,
py::arg("reducer"),
py::arg("state"),
py::arg("comm_hook"));
module
.def(
"_register_comm_hook",
&_register_comm_hook,
py::arg("reducer"),
py::arg("state"),
py::arg("comm_hook"),
py::call_guard<py::gil_scoped_release>())
.def(
"_register_builtin_comm_hook",
&_register_builtin_comm_hook,
py::arg("reducer"),
py::arg("comm_hook_type"));

shared_ptr_class_<::c10d::GradBucket>(module, "_GradBucket")
.def(py::init<std::vector<Tensor>&>(), py::arg("tensors"))
@@ -167,6 +181,12 @@
a single tensor.
)");

py::enum_<::c10d::BuiltinCommHookType>(module, "BuiltinCommHookType", R"(
An enum-like class for built-in communication hooks: ``NONE``, ``ALLREDUCE``, and ``FP16_COMPRESS``.)")
.value("NONE", ::c10d::BuiltinCommHookType::NONE)
.value("ALLREDUCE", ::c10d::BuiltinCommHookType::ALLREDUCE)
.value("FP16_COMPRESS", ::c10d::BuiltinCommHookType::FP16_COMPRESS);

shared_ptr_class_<::c10d::Reducer>(module, "Reducer")
.def(
py::init<
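
From Python, the two bindings registered above take the reducer plus either a (state, callable) pair or an enum value. Below is a minimal sketch of how DDP-side code might dispatch to them; it assumes the reducer is reachable as `model.reducer` (as in the ddp_comm_hooks helper further down) and that these bindings are re-exported through `torch.distributed`.

import torch.distributed as dist

def register_hook_on_reducer(model, python_hook=None, state=None):
    if python_hook is not None:
        # Python path: state and the callable are wrapped in a PythonCommHook.
        dist._register_comm_hook(model.reducer, state, python_hook)
    else:
        # Built-in path: only an enum value crosses the Python/C++ boundary,
        # so no Python object is involved when the hook runs.
        dist._register_builtin_comm_hook(
            model.reducer, dist.BuiltinCommHookType.ALLREDUCE
        )
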
33 changes: 32 additions & 1 deletion torch/csrc/distributed/c10d/reducer.cpp
@@ -197,7 +197,8 @@ Reducer::Reducer(
// used for algorithms like Gradient Compression/GossipGrad. This hook can be
// registered from Python API using `register_comm_hook`. `PythonCommHook`
// enables registering a Python hook and is a subclass of `CommHookInterface`.
// `CommHookInterface` can be used to implement CPP hooks in the future.
// Additionally, there are some built-in C++ hook implementations that can
// be selected by calling `register_builtin_comm_hook` from the Python API.

Reducer::~Reducer() noexcept(false) {
// Remove all hooks on variables registered by this Reducer. This is necessary
@@ -530,6 +531,23 @@ void Reducer::autograd_hook(VariableIndex index) {
// one replica.
push_rebuilt_params(index);

if (comm_hook_ == nullptr &&
builtin_comm_hook_type_ != c10d::BuiltinCommHookType::NONE) {
switch (builtin_comm_hook_type_) {
case c10d::BuiltinCommHookType::ALLREDUCE:
comm_hook_ =
std::make_unique<c10d::AllReduceCommHook>(process_group_.get());
break;
case c10d::BuiltinCommHookType::FP16_COMPRESS:
comm_hook_ =
std::make_unique<c10d::FP16CompressCommHook>(process_group_.get());
break;
default:
TORCH_WARN_ONCE(
"Unknown built-in DDP comm hook type is provided. No comm hook will be used.");
}
}

// If `find_unused_parameters_` is true there may be model parameters that
// went unused when computing the model output, they won't be part of the
// autograd graph, and won't receive gradients. These parameters are
@@ -1367,6 +1385,19 @@ void Reducer::register_comm_hook(std::unique_ptr<CommHookInterface> iface) {
comm_hook_ = std::move(iface);
}

// See Note [DDP Communication Hook]
void Reducer::register_builtin_comm_hook(
c10d::BuiltinCommHookType comm_hook_type) {
TORCH_CHECK(
comm_hook_ == nullptr,
"register_builtin_comm_hook can only be called once.");
TORCH_CHECK(
replicas_.size() == 1,
"Communication hook does not support single-process multiple-device mode.");

builtin_comm_hook_type_ = comm_hook_type;
}

void Reducer::ensure_prior_reduction_finished() {
// Check that any prior reduction has finished.
// The variable `require_finalize_` is true until all gradients
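
The checks above make the two registration paths mutually exclusive, and both must happen before the first backward pass (the built-in hook itself is only instantiated lazily in `autograd_hook`). A small sketch of the intended calling discipline follows; that a violated TORCH_CHECK surfaces as a RuntimeError in Python is an assumption here, not something shown in this diff.

import torch.distributed as dist

def configure_comm_hook(ddp_model, use_builtin_fp16=False, python_hook=None, state=None):
    # Register at most one hook, and do it before any backward pass.
    if use_builtin_fp16:
        ddp_model._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS)
    elif python_hook is not None:
        ddp_model._register_comm_hook(state, python_hook)
    # Registering a second hook (of either kind) afterwards is expected to
    # trip the TORCH_CHECKs above and surface as a RuntimeError (assumption).
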
9 changes: 9 additions & 0 deletions torch/csrc/distributed/c10d/reducer.h
@@ -12,6 +12,7 @@
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/distributed/autograd/context/context.h>
#include <torch/csrc/distributed/c10d/comm.h>
#include <torch/csrc/distributed/c10d/default_comm_hooks.h>

namespace c10d {

@@ -58,8 +59,14 @@ class Reducer {
// Registers a hook to the reducer. The hook is `CommHookInterface`
// type to allow both Python and CPP hooks. This function can only
// be called once before calling backward.
// It cannot be combined with a call to `register_builtin_comm_hook`.
void register_comm_hook(std::unique_ptr<CommHookInterface> iface);

// Registers a built-in C++ comm hook to the reducer. This function can only
// be called once before calling backward.
// It cannot be combined with a call to `register_comm_hook`.
void register_builtin_comm_hook(c10d::BuiltinCommHookType comm_hook_type);

// Returns a vector of tensors in each bucket in sequential order.
std::vector<std::vector<at::Tensor>> get_bucket_tensors() const;

@@ -347,6 +354,8 @@
private:
// comm_hook_ is used to access the DDP communication hook if registered.
std::unique_ptr<CommHookInterface> comm_hook_;

c10d::BuiltinCommHookType builtin_comm_hook_type_;
};

// This is equivalent to take_tensors but returns indices into the
16 changes: 16 additions & 0 deletions torch/distributed/algorithms/ddp_comm_hooks/__init__.py
@@ -1,6 +1,7 @@
from enum import Enum
from functools import partial

import torch.distributed as dist
import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default
import torch.distributed.algorithms.ddp_comm_hooks.quantization_hooks as quantization
from torch.nn.parallel import DistributedDataParallel
@@ -38,8 +39,23 @@ def register_ddp_comm_hook(
to the DDP model. The user can specify the type of hook as an enum
``DDPCommHookType`` using the ``comm_hook_type`` input. The ``state`` input is
passed to the registered hook.
This helper uses the Python comm hook implementations.

Example::
>>> register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, model, state)
"""
comm_hook_type.value(model=model, state=state)


def register_builtin_ddp_comm_hook(comm_hook_type: dist.BuiltinCommHookType, model: DistributedDataParallel):
"""
Registers a built-in hook from ``torch/csrc/distributed/c10d/default_comm_hooks.h``
to the DDP model. The user can specify the type of hook as an enum
``dist.BuiltinCommHookType`` using the ``comm_hook_type`` input.
This helper uses the C++ comm hook implementations.

Example::
>>> register_builtin_ddp_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS, model)
"""

model.reducer._register_builtin_comm_hook(comm_hook_type)
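
Putting the new helper next to the existing Python-hook helper, roughly as the docstring examples above suggest: the sketch below assumes the default process group is already initialized, and the module and device id are placeholders.

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms.ddp_comm_hooks import (
    DDPCommHookType,
    register_builtin_ddp_comm_hook,
    register_ddp_comm_hook,
)
from torch.nn.parallel import DistributedDataParallel

def wrap_with_fp16_hook(module: nn.Module, device_id: int, use_builtin: bool = True):
    model = DistributedDataParallel(module.to(device_id), device_ids=[device_id])
    if use_builtin:
        # C++ implementation of FP16 compression, selected by enum value.
        register_builtin_ddp_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS, model)
    else:
        # Python implementation registered through the existing helper.
        register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, model, state=None)
    return model
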
32 changes: 28 additions & 4 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
@@ -14,6 +14,8 @@
#include <ATen/cuda/CUDAEvent.h>
#include <c10/core/Stream.h>
#include <c10/core/StreamGuard.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>

namespace c10d {

@@ -302,8 +304,29 @@

// Do not free the underlying data storage of value_ before its
// usage on futureNCCLCallbackStream_ finish.
TORCH_INTERNAL_ASSERT(record_stream_cb_);
record_stream_cb_(value_, futureNCCLCallbackStream_->unwrap());
if (record_stream_cb_ != nullptr) {
// If a Python communication hook is used, record_stream_cb_ will be
// set in torch/csrc/jit/python/pybind_utils.h, where the Python
// dependency can be imported.
record_stream_cb_(value_, futureNCCLCallbackStream_->unwrap());
} else {
// If a C++ communication hook is used, create and set a record stream
// callback.
TORCH_INTERNAL_ASSERT(
value_.isTensorList() || value_.isTensor(),
"the future value must be either a tensor list or a tensor.");
at::Tensor tensor;
if (value_.isTensorList()) {
const auto tensors = value_.toTensorVector();
TORCH_INTERNAL_ASSERT(
tensors.size() == 1, "expected exactly 1 tensor");
tensor = tensors[0];
} else {
tensor = value_.toTensor();
}
c10::cuda::CUDACachingAllocator::recordStream(
tensor.storage().data_ptr(), *futureNCCLCallbackStream_);
}

// Use the dedicated callback stream to run callback.
// Cannot move capture std::function in lambda, because it cannot deduce
@@ -558,7 +581,8 @@ class ProcessGroupNCCL : public ProcessGroup {
// This function iterates through the list of WorkNCCL objects in the
// workList_ corresponding to incomplete collectives and then aborts NCCL
// communicators associated with timed out collectives.
void abortTimedOutCollectives(std::unordered_set<std::string>& abortedCommIds);
void abortTimedOutCollectives(
std::unordered_set<std::string>& abortedCommIds);

void workCleanupLoop();

@@ -703,6 +727,6 @@ class ProcessGroupNCCL : public ProcessGroup {
// by 1 when ncclGroupStart() is called and decreased by 1 when ncclGroupEnd()
// is called.
static thread_local uint64_t ncclActiveGroupCounter_;
};
}; // namespace c10d

} // namespace c10d
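
The recordStream branch above exists because the callback runs on a dedicated stream while the tensor was allocated on another, so the caching allocator must be told not to reuse that memory early. The same pattern can be expressed with the public Python API for illustration; this is a general CUDA stream-semantics sketch, not code from this PR.

import torch

def run_on_side_stream(x: torch.Tensor) -> torch.Tensor:
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        y = x * 2  # stand-in for the work a future callback would do
    # Analogue of CUDACachingAllocator::recordStream: keep x's storage from
    # being reclaimed and reused before the side-stream work has finished.
    x.record_stream(side_stream)
    # Make later work on the current stream wait for the side stream.
    torch.cuda.current_stream().wait_stream(side_stream)
    return y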