[FSDP] Support unfreezing params for reshard-only hook
ghstack-source-id: acd123f95b2c3fd90555aeef08274292c0c30f9d
Pull Request resolved: #104186
awgu committed Jun 26, 2023
1 parent 58feefa commit 1db9014
Showing 6 changed files with 82 additions and 42 deletions.
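For context, a rough sketch of the workflow this change targets: some submodules start frozen, so FSDP registers only reshard-only hooks for their handles, and the parameters are unfrozen partway through training. Everything below (model shape, wrapping policy, step count) is illustrative rather than taken from the commit, and it assumes a process group has already been initialized (e.g. via torchrun):

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy

    # Toy model: freeze every other linear before wrapping.
    model = nn.Sequential(*[nn.Linear(5, 5, device="cuda") for _ in range(4)])
    for i in range(0, len(model), 2):
        for param in model[i].parameters():
            param.requires_grad = False

    fsdp_model = FSDP(
        model,
        auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
        use_orig_params=True,
    )
    # SGD skips parameters whose `.grad` is None, so frozen params stay inert.
    optim = torch.optim.SGD(fsdp_model.parameters(), lr=1e-2)

    for step in range(4):
        if step == 2:
            # Unfreeze mid-training: handles that previously only had the
            # reshard-only hook now follow the normal post-backward hook path.
            for param in fsdp_model.parameters():
                param.requires_grad = True
        fsdp_model(torch.randn(8, 5, device="cuda")).sum().backward()
        optim.step()
        optim.zero_grad(set_to_none=True)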
test/distributed/fsdp/test_fsdp_fine_tune.py: 37 additions & 15 deletions
@@ -44,12 +44,18 @@ def _init_seq_module(self) -> nn.Module:
for _ in range(self.NUM_LINEARS):
modules += [nn.Linear(5, 5, device="cuda"), nn.ReLU()]
seq = nn.Sequential(*modules)
# Freeze every other linear
self._set_seq_module_requires_grad(seq, False)
return seq

def _set_seq_module_requires_grad(self, seq: nn.Module, requires_grad: bool):
# Assume that the linears are leaf modules, meaning that we pass
# `recurse=True` to have this method work for both pre and post
# FSDP wrapping
for i in range(self.NUM_LINEARS):
# Only set for every other linear to test mixing frozen/non-frozen
if i % 2 == 0:
for param in seq[i * 2].parameters(recurse=False):
param.requires_grad = False
return seq
for param in seq[i * 2].parameters(recurse=True):
param.requires_grad = requires_grad

@skip_if_lt_x_gpu(2)
def test_backward_reshard_hooks(self):
@@ -66,6 +72,7 @@ def test_backward_reshard_hooks(self):
],
"use_orig_params": [False, True],
"inp_requires_grad": [False, True],
"unfreeze_params": [False, True],
},
self._test_backward_reshard_hooks,
)
@@ -75,6 +82,7 @@ def _test_backward_reshard_hooks(
sharding_strategy: ShardingStrategy,
use_orig_params: bool,
inp_requires_grad: bool,
unfreeze_params: bool,
):
seq = self._init_seq_module()
policy = ModuleWrapPolicy({nn.Linear})
@@ -98,17 +106,31 @@ def _post_backward_reshard_with_count(*args, **kwargs):
"torch.distributed.fsdp._runtime_utils._post_backward_reshard",
_post_backward_reshard_with_count,
):
inp = torch.randn((8, 5), device="cuda", requires_grad=inp_requires_grad)
seq(inp).sum().backward()
# If the input does not require gradient, then the 0th frozen
# linear gets resharded in the catch-all reshard since we cannot
# register an autograd hook on it
expected_post_backward_reshard_count = (
self.NUM_LINEARS if inp_requires_grad else self.NUM_LINEARS - 1
)
self.assertEqual(
post_backward_reshard_count, expected_post_backward_reshard_count
)
num_steps = 2
for step_idx in range(num_steps):
if unfreeze_params and step_idx == num_steps - 1:
# Unfreeze the parameters on the last step to emulate some
# kinds of fine-tuning
self._set_seq_module_requires_grad(seq, True)

inp = torch.randn(
(8, 5), device="cuda", requires_grad=inp_requires_grad
)
seq(inp).sum().backward()
if step_idx < num_steps - 1 or not unfreeze_params:
# If the input does not require gradient, then the 0th
# frozen linear gets resharded in the catch-all reshard
# since we cannot register an autograd hook on it
expected_post_backward_reshard_count = (
self.NUM_LINEARS if inp_requires_grad else self.NUM_LINEARS - 1
)
else:
# This follows the normal post-backward hook path
expected_post_backward_reshard_count = self.NUM_LINEARS
self.assertEqual(
post_backward_reshard_count, expected_post_backward_reshard_count
)
post_backward_reshard_count = 0

@skip_if_lt_x_gpu(2)
def test_parity_with_ddp(self):
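The expected reshard counts asserted in `_test_backward_reshard_hooks` above follow a small counting argument. The helper below is illustrative only (it is not part of the test) and restates that argument:

    def expected_post_backward_reshards(
        num_linears: int, inp_requires_grad: bool, unfrozen_this_step: bool
    ) -> int:
        """How many `_post_backward_reshard` calls fire in one backward pass."""
        if unfrozen_this_step:
            # Every handle requires grad, so every reshard comes from the
            # normal post-backward hook.
            return num_linears
        # Otherwise the frozen handles reshard via the reshard-only hook,
        # which needs a grad-requiring input tensor to attach to. The 0th
        # (frozen) linear sees the model input, so if that input does not
        # require grad, its reshard happens in the catch-all reshard instead
        # and is not counted here.
        return num_linears if inp_requires_grad else num_linears - 1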
torch/distributed/fsdp/_optim_utils.py: 5 additions & 2 deletions
@@ -37,7 +37,10 @@
_ext_chunk_dtensor,
_ext_chunk_tensor,
)
from torch.distributed.fsdp._runtime_utils import _clear_grads_if_needed, _lazy_init
from torch.distributed.fsdp._runtime_utils import (
_lazy_init,
_reset_flat_param_grad_info_if_needed,
)
from torch.distributed.fsdp._shard_utils import _gather_state_dict
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.fsdp.flat_param import FlatParameter, FlatParamHandle
@@ -1302,7 +1305,7 @@ def _optim_state_dict(
:meth:`torch.optim.Optimizer.state_dict`. If ``rank0_only=False``,
then nonzero ranks return an empty :class:`dict`.
"""
_clear_grads_if_needed(traversal_utils._get_fsdp_handles(model))
_reset_flat_param_grad_info_if_needed(traversal_utils._get_fsdp_handles(model))
to_save = not rank0_only or (dist.get_rank(group) == 0 or shard_state)
fsdp_osd: Dict[str, Any] = {"state": {}} if to_save else {}
fsdp_osd_state: Dict[str, Any] = fsdp_osd["state"] if to_save else {}
torch/distributed/fsdp/_runtime_utils.py: 18 additions & 14 deletions
@@ -628,7 +628,7 @@ def _root_pre_forward(
state._streams_unshard,
state._streams_pre_unshard,
)
_clear_grads_if_needed(state._all_handles)
_reset_flat_param_grad_info_if_needed(state._all_handles)

# Prepares the forward inputs by moving them to ``compute_device``
# TODO: Do not use the side stream for tensor copies for now; investigate
@@ -683,7 +683,7 @@ def _pre_backward_hook(
# after all backward calls complete
if state._is_root and not state._post_backward_callback_queued:
_register_post_backward_final_callback(state, module)
_clear_grads_if_needed(state._all_handles)
_reset_flat_param_grad_info_if_needed(state._all_handles)
elif _handles_key:
allowed_states = [TrainingState.IDLE]
if _is_composable(state):
@@ -1046,6 +1046,10 @@ def _catch_all_reshard(
already_resharded = (
handle.flat_param.data_ptr()
== handle.flat_param._local_shard.data_ptr()
# If FSDP skipped using sharded views, then the flat parameter
# still points to the sharded data, so we need to reshard to
# use sharded views
and not handle._skipped_use_sharded_views
)
if already_resharded:
continue
@@ -1069,16 +1073,16 @@ def _finalize_params(
"""Finalizes the parameters before the next iteration."""
for handle in state._handles:
flat_param = handle.flat_param
if hasattr(flat_param, "_post_backward_hook_state"):
post_backward_hook_state_len = len(flat_param._post_backward_hook_state)
expected_post_backward_hook_state_len = int(flat_param.requires_grad) + 1
_p_assert(
post_backward_hook_state_len == expected_post_backward_hook_state_len,
f"Invalid: ``_post_backward_hook_state``: {flat_param._post_backward_hook_state}",
)
flat_param._post_backward_hook_state[-1].remove()
delattr(flat_param, "_post_backward_hook_state")
if flat_param.requires_grad:
if hasattr(flat_param, "_post_backward_hook_state"):
post_backward_hook_state_len = len(flat_param._post_backward_hook_state)
_p_assert(
post_backward_hook_state_len == 1
or post_backward_hook_state_len == 2,
f"Invalid: ``_post_backward_hook_state``: {flat_param._post_backward_hook_state}",
)
flat_param._post_backward_hook_state[-1].remove()
delattr(flat_param, "_post_backward_hook_state")
if not state._sync_gradients:
# Preserve the gradient accumulation state if not synchronizing
# gradients: `.grad` remains the unsharded gradient from prior
@@ -1400,7 +1404,7 @@ def _register_post_backward_reshard_only_hooks(
hook_handle = register_multi_grad_hook(
inp_tensors, functools.partial(_post_backward_reshard, state, handle)
)
handle.flat_param._post_backward_hook_state = hook_handle # type: ignore[attr-defined]
handle.flat_param._post_backward_hook_state = (hook_handle,) # type: ignore[attr-defined]


@no_type_check
@@ -1442,7 +1446,7 @@ def _wait_for_computation_stream(
pre_unshard_stream.wait_stream(computation_stream)


def _clear_grads_if_needed(
def _reset_flat_param_grad_info_if_needed(
handles: List[FlatParamHandle],
):
"""
@@ -1452,7 +1456,7 @@
"""
for handle in handles:
if handle._use_orig_params:
handle._clear_grads_if_needed()
handle._reset_flat_param_grad_info_if_needed()


@no_type_check
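For reference, a minimal standalone sketch of the autograd mechanism behind the reshard-only hook: `torch.autograd.graph.register_multi_grad_hook` fires a callback once gradients for all of the hooked tensors have been computed, which is the point at which a frozen handle's unsharded parameters can be freed. The frozen layer and callback below are stand-ins, not FSDP internals:

    import torch
    from torch.autograd.graph import register_multi_grad_hook

    frozen = torch.nn.Linear(5, 5)
    for param in frozen.parameters():
        param.requires_grad = False  # no AccumulateGrad node, so no per-param hook

    def on_input_grads_ready(grads):
        # FSDP's version calls `_post_backward_reshard(state, handle)` here.
        print("all input grads computed; safe to reshard the frozen layer")

    inp = torch.randn(8, 5, requires_grad=True)  # must require grad to be hookable
    hook_handle = register_multi_grad_hook((inp,), on_input_grads_ready)
    frozen(inp).sum().backward()  # callback fires during this backward
    hook_handle.remove()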
torch/distributed/fsdp/_state_dict_utils.py: 2 additions & 2 deletions
@@ -40,9 +40,9 @@
)
from torch.distributed.fsdp._runtime_utils import (
_cast_buffers_to_dtype_and_device,
_clear_grads_if_needed,
_get_orig_buffer_dtypes,
_lazy_init,
_reset_flat_param_grad_info_if_needed,
)
from torch.distributed.fsdp.api import (
FullStateDictConfig,
@@ -142,7 +142,7 @@ def _common_pre_state_dict_hook(
# TODO: need to check if this is always correct for composable FSDP.
_lazy_init(fsdp_state, module)
if fsdp_state._is_root:
_clear_grads_if_needed(fsdp_state._all_handles)
_reset_flat_param_grad_info_if_needed(fsdp_state._all_handles)


def _common_unshard_pre_state_dict_hook(
torch/distributed/fsdp/_unshard_param_utils.py: 2 additions & 2 deletions
@@ -13,9 +13,9 @@
TrainingState,
)
from torch.distributed.fsdp._runtime_utils import (
_clear_grads_if_needed,
_get_fsdp_root_states_with_modules,
_lazy_init,
_reset_flat_param_grad_info_if_needed,
_reshard,
_reshard_grads,
_unshard,
@@ -190,7 +190,7 @@ def _unshard_fsdp_state_params(
for handle in handles:
handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS

_clear_grads_if_needed(handles)
_reset_flat_param_grad_info_if_needed(handles)
free_unsharded_flat_params = [handle.needs_unshard() for handle in handles]
# No need to call `wait_stream()` since we unshard in the computation
# stream directly
torch/distributed/fsdp/flat_param.py: 18 additions & 7 deletions
@@ -2214,20 +2214,31 @@ def _writeback_tensor(
assert self.flat_param._is_grad_none_mask is not None
self.flat_param._is_grad_none_mask[tensor_index] = True

def _clear_grads_if_needed(self):
"""
When ``use_orig_params=True``, sets the underlying ``flat_param.grad``
to ``None`` if *all* of the original parameters' ``.grad`` are
``None``. This is targeting ``optim.zero_grad(set_to_none=True)``, in
def _reset_flat_param_grad_info_if_needed(self):
"""
When ``use_orig_params=True``:
(1) sets the underlying ``flat_param.grad`` to ``None`` if *all* of the
original parameters' ``.grad`` are ``None``, and
(2) sets ``flat_param.requires_grad=False`` if *none* of the original
parameters require gradient.
For (1), this is targeting ``optim.zero_grad(set_to_none=True)``, in
which case we want to free the gradients as soon after the
``zero_grad()`` call as possible.
"""
if not self._use_orig_params:
return
flat_param = self.flat_param
assert flat_param._params is not None
if all(param.grad is None for param in flat_param._params):
assert flat_param._params is not None # mypy
all_grad_none = True
requires_grad = False
for param in flat_param._params:
all_grad_none &= param.grad is None
requires_grad |= param.requires_grad
if all_grad_none:
flat_param.grad = None
# As long as one parameter requires gradient, then the flat parameter
# must require gradient
flat_param.requires_grad = requires_grad

def _deregister_orig_params(self):
for param_info in self.flat_param._param_infos:
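Finally, a standalone sketch, using toy tensors rather than FSDP's `FlatParameter`, of the two pieces of state that `_reset_flat_param_grad_info_if_needed` derives from the original parameters:

    import torch

    # Pretend these are the original parameters backing one flat parameter.
    orig_params = [torch.nn.Parameter(torch.randn(3)) for _ in range(4)]
    for p in orig_params:
        p.requires_grad = False  # all frozen, e.g. at the start of fine-tuning

    flat_param = torch.nn.Parameter(torch.cat([p.detach() for p in orig_params]))

    # (1) Free the flat gradient if every original `.grad` is None, mirroring
    #     `optim.zero_grad(set_to_none=True)`.
    if all(p.grad is None for p in orig_params):
        flat_param.grad = None
    # (2) The flat parameter requires grad iff at least one original parameter
    #     does, so unfreezing any original parameter later re-enables the
    #     normal post-backward hook path for this handle.
    flat_param.requires_grad = any(p.requires_grad for p in orig_params)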
