Update on "[FSDP][2/N] Remove params_with_grad"
This PR removes the property `params_with_grad` from `FullyShardedDataParallel`. It was introduced when implementing `clip_grad_norm_()` but was not consistently used. Personally, I do not think it makes sense for `FullyShardedDataParallel` to expose this helper because it is not a common paradigm.

This PR is technically BC-breaking. However, I checked that no one internally is using this API.
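For anyone who did depend on the removed helper, the equivalent is easy to compute by hand. A minimal sketch, assuming the property simply filtered the module's parameters down to those whose .grad is populated (an assumption about its behavior, not taken from this PR):

import torch

def params_with_grad(module: torch.nn.Module):
    # Hypothetical stand-in for the removed FSDP property: collect the
    # parameters whose .grad has been populated by a backward pass.
    return [p for p in module.parameters() if p.grad is not None]

Because this works on any nn.Module, dropping the FSDP-specific property does not lose any functionality.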


cc @ezyang @gchanan

[ghstack-poisoned]
awgu committed Oct 24, 2022
2 parents 315aa94 + 13ed8a8 commit d69725e
Showing 1 changed file with 4 additions and 4 deletions.
torch/distributed/fsdp/fully_sharded_data_parallel.py (4 additions, 4 deletions)

@@ -2485,7 +2485,7 @@ def state_dict(self, *args, **kwargs):
     >>> local_dict.keys()
     >>> odict_keys(['flat_param', 'inner.flat_param'])
-    .. warning:: This needs to be called on all ranks since it calls
+    .. warning:: This needs to be called on all ranks since it uses
         collective communications.
     """
     # TODO (rohan-varma): separate these out once a state_dict pre-hook
@@ -2782,7 +2782,7 @@ def load_state_dict(
     >>> local_dict.keys()
     >>> odict_keys(['flat_param', 'inner.flat_param'])
-    .. warning:: This needs to be called on all ranks since it calls
+    .. warning:: This needs to be called on all ranks since it uses
         collective communications.
     """
     return super().load_state_dict(state_dict, *args)
@@ -3875,7 +3875,7 @@ def clip_grad_norm_(
     calling it for FSDP models would lead to different scaling being
     applied per subset of model parameters.
-    .. warning:: This needs to be called on all ranks since it calls
+    .. warning:: This needs to be called on all ranks since it uses
         collective communications.
     """
     self._lazy_init()
@@ -3961,7 +3961,7 @@ def full_optim_state_dict(
     and ``"param_groups"``. The flattened parameters in ``FSDP`` modules
     contained in ``model`` are mapped back to their unflattened parameters.
-    .. warning:: This needs to be called on all ranks since it calls
+    .. warning:: This needs to be called on all ranks since it uses
         collective communications. However, if ``rank0_only=True``, then
         the state dict is only populated on rank 0, and all other ranks
         return an empty :class:`dict`.
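As the updated warnings emphasize, all four of these APIs use collective communications, so every rank in the process group must enter them together; calling them on only a subset of ranks can hang. A rough usage sketch of that pattern (the surrounding setup, fsdp_model, and optim are illustrative assumptions, not part of this commit):

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def clip_and_checkpoint(fsdp_model: FSDP, optim: torch.optim.Optimizer) -> None:
    # Every rank must execute these lines: each call performs collective
    # communication across the process group.
    fsdp_model.clip_grad_norm_(max_norm=1.0)  # e.g. after backward, before optim.step()
    model_state = fsdp_model.state_dict()
    optim_state = FSDP.full_optim_state_dict(fsdp_model, optim)
    if dist.get_rank() == 0:
        # Per the warning above, with rank0_only=True only rank 0 holds the
        # populated optimizer state; other ranks receive an empty dict.
        torch.save({"model": model_state, "optim": optim_state}, "checkpoint.pt")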
