
Support gradient clipping by norm with FSDP #19235

awaelchli (Contributor)
Description & Motivation

Our current implementation of gradient clipping for FSDP is limited to clipping by value only. Norm is not supported:

    def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
        # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
        # section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect with FSDP.
        # To overcome this we need to call root_sharded_module.clip_grad_norm(clip_val), but we don't have a reference
        # to the root module
        raise MisconfigurationException(
            f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
        )

The reason is that clipping by norm has to go through the FSDP API, which hasn't been wired up in Lightning yet: it cannot be done directly through the optimizer because a reference to the root FSDP module is required: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
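
For reference, a minimal sketch of clipping by norm through the FSDP API directly; `model` and the clip value are assumptions for illustration, not code from Lightning:

    # Minimal sketch: `model` is assumed to be the root FSDP-wrapped module
    # (i.e. FullyShardedDataParallel(...) created after init_process_group).
    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Incorrect under FSDP: the generic utility only sees the locally sharded
    # gradients, so the computed norm is not the true global gradient norm.
    # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # Correct: FSDP's own method reduces the norm across ranks before clipping.
    model.clip_grad_norm_(max_norm=1.0)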

Pitch

Support clipping by norm.

Change the API from

class Precision:
    ...
    def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        ...

to

class Precision:
    ...
    def clip_grad_by_norm(self, module: Module, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        ...

to take the module as input. The implementation in FSDPPrecision would then call module.clip_grad_norm_() instead of torch.nn.utils.clip_grad_norm_.
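
A rough sketch of what the FSDPPrecision override could look like under the proposed signature (a sketch of the proposal, not the final implementation; the base class and imports are assumptions):

    # Sketch of the proposed FSDPPrecision override, assuming the new
    # `clip_grad_by_norm(module, optimizer, clip_val)` signature from the pitch.
    from typing import Union

    from torch.nn import Module
    from torch.optim import Optimizer


    class FSDPPrecision(Precision):  # `Precision` as in the pitch above
        def clip_grad_by_norm(self, module: Module, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
            # Delegate to the root FSDP module so the total norm is computed
            # across all shards/ranks before clipping.
            module.clip_grad_norm_(clip_val)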

The LightningModule.clip_gradients() method should then pass self.trainer.model to self.trainer.precision_plugin.clip_gradients().
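
Roughly, the call site described above could look like this (attribute and argument names are assumptions based on the current Trainer internals, shown only to illustrate threading the module through):

    # Sketch (assumed names): LightningModule.clip_gradients() forwarding the
    # root (FSDP-wrapped) module down to the precision plugin.
    def clip_gradients(self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None):
        self.trainer.precision_plugin.clip_gradients(
            self.trainer.model,  # root sharded module, needed for clip-by-norm under FSDP
            optimizer,
            gradient_clip_val,
            gradient_clip_algorithm,
        )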

Alternatives

There is not much else we can do. I believe the proposal above leads to the fewest breaking changes (it only affects the signature of the precision plugin methods).

Additional context

Fabric's precision plugins already do this. We would need to make the same change on the Trainer side sooner or later anyway if we want to unify the precision/strategy implementations.

cc @Borda @awaelchli @carmocca

Activity

added the labels "feature" (Is an improvement or enhancement) and "needs triage", then removed "needs triage" on Jan 4, 2024
added this to the 2.2 milestone on Jan 4, 2024
modified the milestones: 2.2, 2.3 on Feb 3, 2024
xin-w8023 commented on Apr 10, 2024

Any updates about this?

modified the milestones: 2.3, future on Jun 2, 2024
amorehead (Contributor) commented on Sep 12, 2024

Agreed, any updates?

linked a pull request that will close this issue on May 3, 2025
amorehead (Contributor) commented on May 3, 2025

I've taken a shot at implementing a PR for this issue here. Open to any and all feedback on it.

Labels: feature (Is an improvement or enhancement), strategy: fsdp (Fully Sharded Data Parallel)
