FSDP crashes when submodule calls method that isn't forward()
#109385
Comments
Hi @siddk. This is a known limitation of FSDP. Our design relies on … If you have a way to work around this for now (e.g. monkey patching), then that would be the shortest path.
Awesome, thanks @awgu! I was lucky: I found that the monkey-patching trick worked just as I was writing up the minimal example for the bug report. It would be great to add this to the docs somewhere so others don't fall into the same trap!
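The monkey patch mentioned here is the one from the issue description: assigning `vit.forward = vit.forward_features` on the ViT instance and calling `self.vit(imgs)` from `Net.forward()`. It likely works because a module call dispatches through `self.forward`, and Python resolves a plain method on the instance before the class. The following is a minimal pure-Python sketch of that lookup rule. `HookedModule` and `ToyViT` are hypothetical stand-ins, not PyTorch or FSDP classes, and the "hooks" only record that they ran rather than gathering any parameters.

```python
class HookedModule:
    """Hypothetical stand-in: hooks run only around the call to self.forward."""

    def __init__(self):
        self.calls = []  # records hook activity so it can be inspected

    def __call__(self, *args, **kwargs):
        self.calls.append("pre")  # where FSDP would all-gather parameters
        try:
            # Plain methods are non-data descriptors, so an instance-level
            # `forward` attribute takes precedence over the class method.
            return self.forward(*args, **kwargs)
        finally:
            self.calls.append("post")  # where FSDP would reshard parameters


class ToyViT(HookedModule):
    """Hypothetical model with a second entry point besides forward()."""

    def forward(self, imgs):
        return ("logits", imgs)

    def forward_features(self, imgs):
        return ("patch_features", imgs)


vit = ToyViT()
vit.forward_features("imgs")        # direct call: no hooks run
assert vit.calls == []
vit.forward = vit.forward_features  # the instance-level monkey patch
out = vit("imgs")                   # goes through __call__, so hooks run
assert out == ("patch_features", "imgs") and vit.calls == ["pre", "post"]
```

Note that the patch changes what calling `vit(...)` returns for every caller of that instance, not just `Net.forward()`.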
Facing the same problem today. Thanks for opening the issue and for the discussion.
… (#125394): FSDP only runs its pre/post-forward hooks on `nn.Module.forward`. This means that if the user runs a custom method meant as a forward pass, FSDP will not all-gather the parameters. Examples include HuggingFace models' `generate()` (#123962, #100069) and others (#109385). This PR adds a monkey-patching API, `register_fsdp_forward_method(module: nn.Module, method_name: str)`, that allows the FSDP pre/post-forward hooks to run on the given method. The function is a no-op if the passed-in `module` is not an FSDP module, so the register function can be called even if the FSDP wrapping changes. Pull Request resolved: #125394. Approved by: https://github.com/weifengpy, https://github.com/wanchaol
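To illustrate the mechanism this PR describes, here is a minimal pure-Python sketch (no PyTorch, no GPUs). It is not FSDP's implementation of `register_fsdp_forward_method`. `ToyShardedModule`, `register_forward_method`, and the `gathered` flag are hypothetical stand-ins: the toy wrapper "all-gathers" only around `forward`, so calling `forward_features` directly sees sharded parameters, and the register helper reroutes that one method through the same pre/post hooks. It also mirrors the no-op behavior for modules that are not wrapped.

```python
class ToyShardedModule:
    """Hypothetical stand-in for an FSDP-wrapped module (not real FSDP).

    Its "parameters" are usable only while `gathered` is True, and, as with
    nn.Module.__call__, the pre/post-forward hooks run only around forward().
    """

    def __init__(self):
        self.gathered = False

    def _pre_forward(self):
        self.gathered = True  # stands in for all-gathering the parameters

    def _post_forward(self):
        self.gathered = False  # stands in for freeing the unsharded parameters

    def _require_gathered(self):
        if not self.gathered:
            raise RuntimeError("parameters are still sharded")

    def __call__(self, *args, **kwargs):
        self._pre_forward()
        try:
            return self.forward(*args, **kwargs)
        finally:
            self._post_forward()

    def forward(self, x):
        self._require_gathered()
        return x + 1

    def forward_features(self, x):
        self._require_gathered()
        return x * 2


def register_forward_method(module, method_name):
    """Hypothetical analogue of the patching idea: hook another method too."""
    if not isinstance(module, ToyShardedModule):
        return  # no-op for unwrapped modules, as the PR describes for its API
    original = getattr(module, method_name)  # bound method from the class

    def hooked(*args, **kwargs):
        module._pre_forward()
        try:
            return original(*args, **kwargs)
        finally:
            module._post_forward()

    # An instance attribute shadows the class method for this module only.
    setattr(module, method_name, hooked)


demo = ToyShardedModule()
assert demo(1) == 2  # forward() runs inside the hooks
try:
    demo.forward_features(3)  # bypasses the hooks, like the bug in this issue
except RuntimeError as err:
    print("before patching:", err)
register_forward_method(demo, "forward_features")
print("after patching:", demo.forward_features(3))
```

Patching per method keeps `forward` unchanged. In contrast, reassigning `forward` itself changes what every plain call to the module returns.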
🐛 Describe the bug
I am getting various runtime errors with an FSDP module that wraps multiple child modules, where the forward pass invokes a submodule's non-forward method. The auto-wrap policy wraps each submodule separately. The minimal example below should make this clearer:
Run with (at least 2 GPUs): `torchrun --standalone --nnodes 1 --nproc-per-node 2 <script.py>`
This results in the following error message:
Further context: I'm working on a project where we take the patch features from a (frozen) Vision Transformer backbone and transform them into a different latent space where they're used to decode other modalities (e.g., depth).
This gist provides an annotated example that reflects our setup a bit better: https://gist.github.com/siddk/db3e8808bed2a9cb90ae62b5338de68d
Some other things I tried (to help speed up debugging); all of this is in the linked Gist:

- Setting `use_orig_params=True` results in a different error at the same Conv2D call (`RuntimeError: weight should have at least three dimensions`).
- Freezing the ViT (as required in our original setup) results in yet another error at the Conv2D call (`RuntimeError: GET was unable to find an engine to execute this computation`).

Interestingly, if we monkey patch the `vit` instance such that `vit.forward = vit.forward_features` and call `self.vit(imgs)` in `Net.forward()`, all of these bugs disappear!

Versions
cc @zhaojuanmao @mrshenli @rohan-varma @awgu @fegin @penguinwu