Enable fused optimizer for DP #98270

Closed · wants to merge 6 commits

Changes from 1 commit
4 changes: 2 additions & 2 deletions torch/csrc/distributed/c10d/init.cpp
@@ -449,8 +449,8 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
},
py::call_guard<py::gil_scoped_release>())
.def(
"_set_grads_to_none",
[](::c10d::Reducer& reducer) { reducer.set_grads_to_none(true); },
"_set_optimizer_in_backward",
[](::c10d::Reducer& reducer) { reducer.set_optimizer_in_backward(); },
py::call_guard<py::gil_scoped_release>())
.def(
"_set_mixed_precision_param_dtype",
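For orientation, a minimal Python-side sketch of what the binding rename above implies for DDP internals (private APIs; the helper function itself is illustrative and not part of this PR):

from torch.nn.parallel import DistributedDataParallel as DDP

def _flag_optimizer_in_backward(ddp_model: DDP) -> None:
    # Illustrative helper (not part of this PR): after the rename above, DDP's
    # Python layer calls the new no-argument _set_optimizer_in_backward()
    # binding on its reducer instead of the removed _set_grads_to_none().
    ddp_model.reducer._set_optimizer_in_backward()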
12 changes: 4 additions & 8 deletions torch/csrc/distributed/c10d/reducer.cpp
@@ -1464,9 +1464,9 @@ void Reducer::finalize_bucket_dense(Bucket& bucket) {
}

if (!gradient_as_bucket_view_) {
if (set_grads_to_none_) {
if (optim_in_backward_) {
// Return early has optimizer has already run.
Review comment (Contributor): nit: Maybe typo?

Suggested change:
// Return early has optimizer has already run.
// Return early since optimizer has already run.

runGradCallbackForVariable(variable, [&](auto& grad) {
grad.reset();
return true;
});
} else {
@@ -1486,8 +1486,8 @@ void Reducer::finalize_bucket_dense(Bucket& bucket) {
bucket_view_in.copy_(bucket_view_out);
}
runGradCallbackForVariable(variable, [&](auto& grad) {
if (set_grads_to_none_) {
grad.reset();
if (optim_in_backward_) {
// Return early has optimizer has already run.
Review comment (Contributor): Same as above: maybe typo?

return true;
}
// If a parameter is globally unused, we keep its grad untouched.
@@ -1806,10 +1806,6 @@ void Reducer::register_builtin_comm_hook(
}
}

void Reducer::set_grads_to_none(bool set_to_none) {
set_grads_to_none_ = set_to_none;
}

void Reducer::ensure_prior_reduction_finished() {
// Check that any prior reduction has finished.
// The variable `require_finalize_` is true until all gradients
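A rough Python-level sketch of the per-variable finalize behavior introduced above (illustrative only; finalize_variable_grad is not a real PyTorch function, and the authoritative logic is the C++ in Reducer::finalize_bucket_dense):

import torch

def finalize_variable_grad(
    variable: torch.nn.Parameter,
    bucket_view: torch.Tensor,
    optim_in_backward: bool,
) -> None:
    # Illustrative sketch of the C++ branch above; not actual PyTorch code.
    if optim_in_backward:
        # The in-backward optimizer already consumed this gradient, so the
        # reducer drops it rather than copying the reduced value back out
        # of the bucket.
        variable.grad = None
        return
    # Default path: materialize param.grad from the reduced bucket view.
    variable.grad = bucket_view.detach().clone()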
9 changes: 5 additions & 4 deletions torch/csrc/distributed/c10d/reducer.hpp
@@ -98,9 +98,10 @@ class TORCH_API Reducer {
// Cannot combine with the call of `register_comm_hook`.
void register_builtin_comm_hook(c10d::BuiltinCommHookType comm_hook_type);

// If set_to_none=True, reducer will set gradients to None in
// finalize_backward callback.
void set_grads_to_none(bool set_to_none);
// Informs reducer that optimizer is running in backward, so gradients
// don't need to be copied from buckets as the optimizer would've already
// been applied.
void set_optimizer_in_backward() { optim_in_backward_ = true; };

// Runs allreduce or installed communication hook given GradBucket instance.
c10::intrusive_ptr<c10::ivalue::Future> run_comm_hook(
@@ -524,7 +525,7 @@ class TORCH_API Reducer {
// are rebuilt after which this mapping is static.
mutable std::unordered_map<size_t, std::vector<at::Tensor>> cached_variables_for_bucket_;

bool set_grads_to_none_{false};
bool optim_in_backward_{false};
friend class Logger;
};

torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py
@@ -58,8 +58,10 @@ def apply_optim_in_backward_hook(
hook_state: Any, bucket: dist.GradBucket, optim_stream_state,
) -> torch.futures.Future[torch.Tensor]:
# Run original hook
reducer_weakref, process_group = hook_state
fut = reducer_weakref()._run_allreduce_hook(bucket)
ddp_weakref = hook_state
ddp_inst = ddp_weakref()
reducer, process_group = ddp_inst.reducer, ddp_inst.process_group
fut = reducer._run_allreduce_hook(bucket)
optimizer_stream = optim_stream_state.optim_stream
with torch.cuda.stream(optimizer_stream):
fut.wait()
@@ -86,11 +88,16 @@ def apply_optim_in_backward_hook(
ret_fut.set_result(bucket.buffer())

# enqueue a callback to wait for this optimizer stream at the end of
# backward.
# backward and set all DDP managed grads to None.
def wait_for_optim_stream_callback():
torch.cuda.current_stream().wait_stream(
optim_stream_state.optim_stream
)
# Set DDP managed grads to None
for param in ddp_inst._get_data_parallel_params(ddp_inst.module):
if hasattr(param, '_in_backward_optimizers'):
param.grad = None

# reset for the next backwards pass
optim_stream_state.wait_for_optim_stream_enqueued = False

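For context, a hedged end-to-end sketch of the flow this hook serves: register an in-backward optimizer on the parameters via the prototype _apply_optimizer_in_backward API, then wrap the module in DDP so the overlapped hook above performs allreduce plus the optimizer step (APIs are private and may change; distributed launch setup is assumed):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.optim import _apply_optimizer_in_backward
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes torchrun-style env vars and one CUDA device per rank.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.Linear(10, 10).cuda()
# Attach a per-parameter SGD step that runs during backward.
_apply_optimizer_in_backward(
    optimizer_class=torch.optim.SGD,
    params=model.parameters(),
    optimizer_kwargs={"lr": 0.01},
)
# DDP detects the _in_backward_optimizers attribute and installs the
# _apply_optim_in_backward_hook shown in this file.
ddp_model = DDP(model, device_ids=[torch.cuda.current_device()])

out = ddp_model(torch.randn(4, 10, device="cuda"))
out.sum().backward()  # allreduce and the optimizer step are overlapped here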
43 changes: 16 additions & 27 deletions torch/nn/parallel/distributed.py
@@ -929,16 +929,16 @@ def _setup_in_backward_optimizers(self):
# Check if user has used apply_optim_in_backward to overlap optimizer
# step + DDP backward. Current constraints:
# 1. Only allreduce is supported at the moment, no custom communication.
# 2. The reducer by default sets all grads for parameters DDP manages to
# None after they have been applied by the optimizer. There is no support
# for setting only some parameter grads to None, this must be done manually
# by user (and DDP_OVERLAPPED_OPTIM_SET_GRADS_TO_NONE=0 needs to be set.)
# If your use case requires some DDP managed parameters to run with
# an in-backward optimizer and some with a traditional optimizer, please
# ping https://github.com/pytorch/pytorch/issues/90052.
# 2. For DDP-managed parameters that have their optimizer run in
# backward, their gradients are set to ``None``. If your use case
# requires DDP parameters grad not to be set to ``None`` after their
# in-backward optimizer runs, please ping
# https://github.com/pytorch/pytorch/issues/90052.
# NOTE: we use self._module_parameters instead of .parameters() since
# the former excludes ignored (non-DDP managed) parameters.
if any(hasattr(p, "_in_backward_optimizers") for p in self._module_parameters):
if any(
hasattr(p, '_in_backward_optimizers') for p in self._module_parameters
):
# Remove hooks that apply_optim_in_backward had registered because
# DDP customizes how optimizer is overlapped with backward due to
# the allreduce.
@@ -949,36 +949,22 @@ def _setup_in_backward_optimizers(self):
for handle in param_to_handle_map.get(p, []):
handle.remove()

# Need a weakref to the reducer in order to run all_reduce.
reducer_weakref = weakref.ref(self.reducer)
# Need a weakref to DDP instance to run all_reduce (from reducer)
# and get managed DDP parameters.
ddp_weakref = weakref.ref(self)
# Note: importing in function, otherwise this will cause a circular
# import.
from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
_apply_optim_in_backward_hook,
)

self.register_comm_hook(
(reducer_weakref, self.process_group),
ddp_weakref,
_apply_optim_in_backward_hook(
gradient_is_bucket_view=self.gradient_as_bucket_view
),
)

# TODO (rohan-varma): this is a workaround that allows users to
# disable the default behavior of DDP managed parameters with
# optimizer runing in backwards having their gradients all set to None.
# Currently, it is an "all or nothing behavior" where DDP will set
# no grads to None or all of them, relaxing this behavior will be
# done dependent on use cases.
if os.getenv("DDP_OVERLAPPED_OPTIM_SET_GRADS_TO_NONE", "1") != "0":
warnings.warn(
"DDP + apply_optim_in_backward will currently set all "
"parameter gradients to None. If this is not the desired "
"behavior, please set env variable "
"DDP_OVERLAPPED_OPTIM_SET_GRADS_TO_NONE=0, and manually set"
"gradients to None/zero as desired."
)
self.reducer._set_grads_to_none() # type: ignore[attr-defined]
self.reducer._set_optimizer_in_backward()

def _fire_reducer_autograd_hook(self, idx, *unused):
"""
@@ -2148,6 +2134,9 @@ def _distributed_rank(self):

@staticmethod
def _get_data_parallel_params(module, named_params=False):
"""
Returns a generator of parameters managed by a given DDP unit.
"""
for param in (
module.parameters() if not named_params else module.named_parameters()
):
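A small illustrative check of the user-visible effect described above: once backward finishes, DDP-managed parameters whose optimizer ran in backward should have their gradients cleared to None (this helper is a sketch, not part of the PR):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def check_in_backward_grads_cleared(ddp_model: DDP, inp: torch.Tensor) -> None:
    # Illustrative helper: verifies that grads for parameters carrying
    # _in_backward_optimizers are set to None once backward completes.
    ddp_model(inp).sum().backward()
    for name, param in ddp_model.module.named_parameters():
        if hasattr(param, "_in_backward_optimizers"):
            assert param.grad is None, f"{name}: expected grad to be cleared by DDP"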
2 changes: 1 addition & 1 deletion torch/testing/_internal/distributed/distributed_test.py
@@ -4977,7 +4977,7 @@ def _test_ddp_apply_optim_in_backward(
with torch.backends.cudnn.flags(
enabled=True, deterministic=True, benchmark=False
):
for i in range(100):
for i in range(8):
inp = (
torch.randn(1, 3, 1000, 1000, device="cuda")
if j == 1