Update on "[RFC][FSDP2] Added register_fsdp_forward_method for user fwd methods"


FSDP only runs its pre/post-forward hooks on `nn.Module.forward`. If the user instead calls a custom method that acts as a forward pass, FSDP does not all-gather the parameters for that call. Examples include HuggingFace models' `generate()` (#123962, #100069) and other custom methods (#109385).

This PR adds a monkey-patching API, `register_fsdp_forward_method`, that lets FSDP's pre/post-forward hooks run on such a method.
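
As a rough illustration of the monkey-patching idea, here is a minimal pure-Python sketch. It is not the PR's implementation and does not use torch: `ToyModule`, `pre_forward`, `post_forward`, and `register_forward_method` are hypothetical names chosen for this example, and the hooks only record that they ran instead of all-gathering or resharding parameters.

```python
import functools
import types


class ToyModule:
    """Hypothetical stand-in for an FSDP-managed module (not a torch class)."""

    def __init__(self):
        self.events = []  # records the order in which hooks and methods run

    def pre_forward(self):
        # Assumption: real FSDP would all-gather the sharded parameters here.
        self.events.append("pre_forward")

    def post_forward(self):
        # Assumption: real FSDP would free (reshard) the parameters here.
        self.events.append("post_forward")

    def generate(self, prompt):
        # A forward-like method that is not named `forward`.
        self.events.append("generate")
        return prompt.upper()


def register_forward_method(module, method_name):
    """Hypothetical helper: patch `method_name` so the hooks run around it."""
    orig_method = getattr(module, method_name)  # bound method on this instance

    @functools.wraps(orig_method)
    def wrapped(self, *args, **kwargs):
        self.pre_forward()
        try:
            return orig_method(*args, **kwargs)
        finally:
            # Run the post hook even if the method raises.
            self.post_forward()

    # Patch only this instance; the class itself is left untouched.
    setattr(module, method_name, types.MethodType(wrapped, module))


model = ToyModule()
register_forward_method(model, "generate")
result = model.generate("hello")
```

Patching the instance rather than the class keeps the change local to the module that was registered, which seems to match the per-module registration the PR title suggests.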

cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
awgu committed May 2, 2024
1 parent fcad622 commit 4428b8c
Showing 1 changed file with 1 addition and 3 deletions.
@@ -1147,9 +1147,7 @@ def world_size(self) -> int:

     @unittest.skipIf(not TEST_CUDA, "no cuda")
     def test_register_fsdp_forward_method(self):
-        """
-        Based on https://github.com/pytorch/pytorch/issues/109385
-        """
+        """Based on https://github.com/pytorch/pytorch/issues/109385"""

         class VisionTransformer(nn.Module):
             def __init__(self):