Commit: Enable fused optimizer for DP
Pull Request resolved: #98270

Enable DDP optimizer overlap for HPC 10x CMF, providing ~7% QPS gain.


ghstack-source-id: 185745827

Differential Revision: [D42714482](https://our.internmc.facebook.com/intern/diff/D42714482/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D42714482/)!
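For context, a minimal usage sketch of the workflow this commit enables: parameters are tagged with an in-backward optimizer, and DDP overlaps the optimizer step with its backward allreduce. The `_apply_optimizer_in_backward` helper and its keyword arguments are assumptions based on the private `torch.distributed.optim` API around the time of this PR; this snippet is illustrative and not part of the diff.

```python
# Illustrative sketch, not part of this commit. The private helper
# _apply_optimizer_in_backward and its signature are assumptions.
import torch
import torch.distributed as dist
from torch.distributed.optim import _apply_optimizer_in_backward
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(1024, 1024).cuda()

# Tag every parameter so its optimizer step runs during backward.
_apply_optimizer_in_backward(
    optimizer_class=torch.optim.SGD,
    params=model.parameters(),
    optimizer_kwargs={"lr": 0.01},
)

# DDP detects the tagged parameters (via `_in_backward_optimizers`), registers
# the overlapped-optimizer comm hook, and tells the reducer that the optimizer
# runs in backward, so DDP-managed grads are set to None afterwards.
ddp_model = DDP(model, device_ids=[torch.cuda.current_device()])

out = ddp_model(torch.randn(32, 1024, device="cuda"))
out.sum().backward()  # allreduce + optimizer step overlap here; no separate optimizer.step()
```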
rohan-varma committed Apr 11, 2023
1 parent fdfd370 commit d060567
Showing 6 changed files with 38 additions and 45 deletions.
4 changes: 2 additions & 2 deletions torch/csrc/distributed/c10d/init.cpp
@@ -449,8 +449,8 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
},
py::call_guard<py::gil_scoped_release>())
.def(
"_set_grads_to_none",
[](::c10d::Reducer& reducer) { reducer.set_grads_to_none(true); },
"_set_optimizer_in_backward",
[](::c10d::Reducer& reducer) { reducer.set_optimizer_in_backward(); },
py::call_guard<py::gil_scoped_release>())
.def(
"_set_mixed_precision_param_dtype",
12 changes: 4 additions & 8 deletions torch/csrc/distributed/c10d/reducer.cpp
@@ -1464,9 +1464,9 @@ void Reducer::finalize_bucket_dense(Bucket& bucket) {
}

if (!gradient_as_bucket_view_) {
if (set_grads_to_none_) {
if (optim_in_backward_) {
// Return early as optimizer has already run.
runGradCallbackForVariable(variable, [&](auto& grad) {
grad.reset();
return true;
});
} else {
@@ -1486,8 +1486,8 @@ void Reducer::finalize_bucket_dense(Bucket& bucket) {
bucket_view_in.copy_(bucket_view_out);
}
runGradCallbackForVariable(variable, [&](auto& grad) {
if (set_grads_to_none_) {
grad.reset();
if (optim_in_backward_) {
// Return early as optimizer has already run.
return true;
}
// If a parameter is globally unused, we keep its grad untouched.
@@ -1806,10 +1806,6 @@ void Reducer::register_builtin_comm_hook(
}
}

void Reducer::set_grads_to_none(bool set_to_none) {
set_grads_to_none_ = set_to_none;
}

void Reducer::ensure_prior_reduction_finished() {
// Check that any prior reduction has finished.
// The variable `require_finalize_` is true until all gradients
9 changes: 5 additions & 4 deletions torch/csrc/distributed/c10d/reducer.hpp
@@ -98,9 +98,10 @@ class TORCH_API Reducer {
// Cannot combine with the call of `register_comm_hook`.
void register_builtin_comm_hook(c10d::BuiltinCommHookType comm_hook_type);

// If set_to_none=True, reducer will set gradients to None in
// finalize_backward callback.
void set_grads_to_none(bool set_to_none);
// Informs reducer that optimizer is running in backward, so gradients
// don't need to be copied from buckets as the optimizer would've already
// been applied.
void set_optimizer_in_backward() { optim_in_backward_ = true; };

// Runs allreduce or installed communication hook given GradBucket instance.
c10::intrusive_ptr<c10::ivalue::Future> run_comm_hook(
@@ -524,7 +525,7 @@ class TORCH_API Reducer {
// are rebuilt after which this mapping is static.
mutable std::unordered_map<size_t, std::vector<at::Tensor>> cached_variables_for_bucket_;

bool set_grads_to_none_{false};
bool optim_in_backward_{false};
friend class Logger;
};

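Taken together, the binding rename in init.cpp and the reducer changes above replace the boolean `_set_grads_to_none(True)` knob with a one-way `_set_optimizer_in_backward()` flag: once set, `finalize_bucket_dense` skips writing gradients back out of the buckets, since the in-backward optimizer has already consumed them. A minimal sketch of the internal call path; `enable_fused_optimizer` is a hypothetical name and the private APIs are shown for orientation only:

```python
# Sketch only: `reducer` and the leading-underscore binding are private DDP
# internals and may change without notice.
def enable_fused_optimizer(ddp_model):
    # Before this commit the reducer was told to null out grads explicitly:
    #   ddp_model.reducer._set_grads_to_none(True)
    # After this commit the reducer is told the optimizer runs in backward:
    ddp_model.reducer._set_optimizer_in_backward()
```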
torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py
@@ -58,8 +58,10 @@ def apply_optim_in_backward_hook(
hook_state: Any, bucket: dist.GradBucket, optim_stream_state,
) -> torch.futures.Future[torch.Tensor]:
# Run original hook
reducer_weakref, process_group = hook_state
fut = reducer_weakref()._run_allreduce_hook(bucket)
ddp_weakref = hook_state
ddp_inst = ddp_weakref()
reducer, process_group = ddp_inst.reducer, ddp_inst.process_group
fut = reducer._run_allreduce_hook(bucket)
optimizer_stream = optim_stream_state.optim_stream
with torch.cuda.stream(optimizer_stream):
fut.wait()
@@ -86,11 +88,16 @@ def apply_optim_in_backward_hook(
ret_fut.set_result(bucket.buffer())

# enqueue a callback to wait for this optimizer stream at the end of
# backward.
# backward and set all DDP managed grads to None.
def wait_for_optim_stream_callback():
torch.cuda.current_stream().wait_stream(
optim_stream_state.optim_stream
)
# Set DDP managed grads to None
for param in ddp_inst._get_data_parallel_params(ddp_inst.module):
if hasattr(param, '_in_backward_optimizers'):
param.grad = None

# reset for the next backwards pass
optim_stream_state.wait_for_optim_stream_enqueued = False

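The hook now receives a weakref to the DDP instance itself (rather than just the reducer), which gives the end-of-backward callback access to `_get_data_parallel_params` so it can set DDP-managed grads to None. A simplified, hedged illustration of the overall pattern; `make_overlap_hook` is a hypothetical name and details such as `gradient_as_bucket_view` handling are omitted (the real implementation is `_apply_optim_in_backward_hook` in this file):

```python
import torch
import torch.distributed as dist
from torch.autograd import Variable


def make_overlap_hook(ddp_weakref):
    # Hypothetical, simplified sketch of the comm hook this file implements.
    optim_stream = torch.cuda.Stream()
    callback_enqueued = False

    def hook(state, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
        nonlocal callback_enqueued
        ddp_inst = ddp_weakref()
        # 1. Run the usual allreduce through the reducer.
        fut = ddp_inst.reducer._run_allreduce_hook(bucket)

        # 2. Step the in-backward optimizers for this bucket on a side stream.
        with torch.cuda.stream(optim_stream):
            fut.wait()
            for param in bucket.parameters():
                for optim in getattr(param, "_in_backward_optimizers", []):
                    optim.step()

        # 3. Once per backward pass, enqueue a callback that syncs the side
        #    stream and sets all DDP-managed grads to None.
        if not callback_enqueued:
            def wait_for_optim_stream():
                nonlocal callback_enqueued
                torch.cuda.current_stream().wait_stream(optim_stream)
                for p in ddp_inst._get_data_parallel_params(ddp_inst.module):
                    if hasattr(p, "_in_backward_optimizers"):
                        p.grad = None
                callback_enqueued = False  # reset for the next backward pass

            Variable._execution_engine.queue_callback(wait_for_optim_stream)
            callback_enqueued = True

        ret = torch.futures.Future()
        ret.set_result(bucket.buffer())
        return ret

    return hook
```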
43 changes: 16 additions & 27 deletions torch/nn/parallel/distributed.py
@@ -929,16 +929,16 @@ def _setup_in_backward_optimizers(self):
# Check if user has used apply_optim_in_backward to overlap optimizer
# step + DDP backward. Current constraints:
# 1. Only allreduce is supported at the moment, no custom communication.
# 2. The reducer by default sets all grads for parameters DDP manages to
# None after they have been applied by the optimizer. There is no support
# for setting only some parameter grads to None, this must be done manually
# by user (and DDP_OVERLAPPED_OPTIM_SET_GRADS_TO_NONE=0 needs to be set.)
# If your use case requires some DDP managed parameters to run with
# an in-backward optimizer and some with a traditional optimizer, please
# ping https://github.com/pytorch/pytorch/issues/90052.
# 2. For DDP-managed parameters that have their optimizer run in
# backward, their gradients are set to ``None``. If your use case
# requires DDP parameters grad not to be set to ``None`` after their
# in-backward optimizer runs, please ping
# https://github.com/pytorch/pytorch/issues/90052.
# NOTE: we use self._module_parameters instead of .parameters() since
# the former excludes ignored (non-DDP managed) parameters.
if any(hasattr(p, "_in_backward_optimizers") for p in self._module_parameters):
if any(
hasattr(p, '_in_backward_optimizers') for p in self._module_parameters
):
# Remove hooks that apply_optim_in_backward had registered because
# DDP customizes how optimizer is overlapped with backward due to
# the allreduce.
@@ -949,36 +949,22 @@ def _setup_in_backward_optimizers(self):
for handle in param_to_handle_map.get(p, []):
handle.remove()

# Need a weakref to the reducer in order to run all_reduce.
reducer_weakref = weakref.ref(self.reducer)
# Need a weakref to DDP instance to run all_reduce (from reducer)
# and get managed DDP parameters.
ddp_weakref = weakref.ref(self)
# Note: importing in function, otherwise this will cause a circular
# import.
from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
_apply_optim_in_backward_hook,
)

self.register_comm_hook(
(reducer_weakref, self.process_group),
ddp_weakref,
_apply_optim_in_backward_hook(
gradient_is_bucket_view=self.gradient_as_bucket_view
),
)

# TODO (rohan-varma): this is a workaround that allows users to
# disable the default behavior of DDP managed parameters with
# optimizer running in backwards having their gradients all set to None.
# Currently, it is an "all or nothing behavior" where DDP will set
# no grads to None or all of them, relaxing this behavior will be
# done dependent on use cases.
if os.getenv("DDP_OVERLAPPED_OPTIM_SET_GRADS_TO_NONE", "1") != "0":
warnings.warn(
"DDP + apply_optim_in_backward will currently set all "
"parameter gradients to None. If this is not the desired "
"behavior, please set env variable "
"DDP_OVERLAPPED_OPTIM_SET_GRADS_TO_NONE=0, and manually set"
"gradients to None/zero as desired."
)
self.reducer._set_grads_to_none() # type: ignore[attr-defined]
self.reducer._set_optimizer_in_backward()

def _fire_reducer_autograd_hook(self, idx, *unused):
"""
@@ -2151,6 +2137,9 @@ def _distributed_rank(self):

@staticmethod
def _get_data_parallel_params(module, named_params=False):
"""
Returns a generator of parameters managed by a given DDP unit.
"""
for param in (
module.parameters() if not named_params else module.named_parameters()
):
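Per the updated comment in `_setup_in_backward_optimizers`, DDP-managed parameters whose optimizer runs in backward now always have their gradients set to None after backward; the old `DDP_OVERLAPPED_OPTIM_SET_GRADS_TO_NONE` escape hatch is gone. A quick hedged check of that behavior, assuming `ddp_model` was built as in the earlier usage sketch:

```python
# Hedged sketch: after backward with optimizer-in-backward enabled, the
# gradients of DDP-managed parameters should be None, because the fused
# optimizer already consumed the bucketed gradients.
loss = ddp_model(torch.randn(32, 1024, device="cuda")).sum()
loss.backward()

for name, param in ddp_model.module.named_parameters():
    if hasattr(param, "_in_backward_optimizers"):
        assert param.grad is None, f"{name}: expected grad to be None after backward"
```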
2 changes: 1 addition & 1 deletion torch/testing/_internal/distributed/distributed_test.py
@@ -4977,7 +4977,7 @@ def _test_ddp_apply_optim_in_backward(
with torch.backends.cudnn.flags(
enabled=True, deterministic=True, benchmark=False
):
for i in range(100):
for i in range(8):
inp = (
torch.randn(1, 3, 1000, 1000, device="cuda")
if j == 1
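The test reduces the iteration count from 100 to 8. A condensed, hypothetical sketch of the parity check `_test_ddp_apply_optim_in_backward` performs, with model construction and process-group setup assumed; `ddp_overlapped`, `ddp_baseline`, and `baseline_optim` are placeholder names:

```python
# Hypothetical sketch: DDP with the in-backward (fused) optimizer should track
# DDP with a conventional post-backward optimizer step.
for i in range(8):
    inp = torch.randn(1, 3, 1000, 1000, device="cuda")

    ddp_overlapped(inp).sum().backward()   # optimizer step fused into backward

    ddp_baseline(inp).sum().backward()
    baseline_optim.step()
    baseline_optim.zero_grad()

    for p_fused, p_base in zip(ddp_overlapped.parameters(), ddp_baseline.parameters()):
        assert torch.allclose(p_fused, p_base)
```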
