-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FSDP] Fix wrapped module changing after ctor #87837
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87837
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 Failures, 1 PendingAs of commit 3ac9c0d: The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
ghstack-source-id: 2346ee6752e3b1930356e6462adda990aebb0636 Pull Request resolved: #87837
[ghstack-poisoned]
ghstack-source-id: 4595d556ad345c213c38e92c2dbc8a2d2a0c6bf5 Pull Request resolved: #87837
Recently, I retired `FlattenParamsWrapper`, which meant that FSDP registers its `FlatParameter` on the wrapped module instead of the `FlattenParamsWrapper` instance. This is only relevant for `use_orig_params=False`. If the user changes an FSDP instance's wrapped module after the FSDP constructor, then the `FlatParameter` is no longer registered on the wrapped module. This can cause issues for full state dict, which checks if the `FlatParameter` is currently registered as an early return condition for `rank0_only=True`. The solution in this PR is to re-establish the wrapped module in `_lazy_init()`, de-registering from the old wrapped module and re-registering to the new wrapped module, where the assumption is that the user should not modify the module structure upon `_lazy_init()`. The direct access to the private attribute `_parameters` from `nn.Module` is not ideal, but we already rely on it for the dynamic `FlatParameter` registration. The tradeoff is whether we want an additional `nn.Module` wrapper (`FlattenParamsWrapper`) and use `delattr` plus a singleton list to do the dynamic registration or we want to access `_parameters`. If this becomes a problem, we can work with Core team on a solution. [ghstack-poisoned]
ghstack-source-id: 963850d8bc3bcf4a8eeff07594b84f6096c992a7 Pull Request resolved: #87837
@@ -220,9 +234,10 @@ def _validate_state_dict_contents( | |||
|
|||
@skip_if_lt_x_gpu(2) | |||
@parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS) | |||
@parametrize("checkpoint_wrap", ["first", "second", "both"]) | |||
@parametrize("checkpoint_wrap", ["source", "dest", "both", "source_after_wrap"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Curious, if we have both_after_wrap
, then with rank0_only_and_offload being False, would it be possible to reproduce the loading stuck error?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"both_after_wrap"
with rank0_only_and_offload=False
produces a different error:
RuntimeError: Error(s) in loading state_dict for FullyShardedDataParallel:
While copying the parameter named "_fsdp_wrapped_module.0._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param", whose dimensions in the model are torch.Size([100]) and whose dimensions in the checkpoint are torch.Size([100]), an exception occurred : ('CUDA error: invalid argument\nCUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
However, since it is still an error, I added the "both_after_wrap"
option to the unit test. We still need to figure out how to reproduce the load error.
Recently, I retired `FlattenParamsWrapper`, which meant that FSDP registers its `FlatParameter` on the wrapped module instead of the `FlattenParamsWrapper` instance. This is only relevant for `use_orig_params=False`. If the user changes an FSDP instance's wrapped module after the FSDP constructor, then the `FlatParameter` is no longer registered on the wrapped module. This can cause issues for full state dict, which checks if the `FlatParameter` is currently registered as an early return condition for `rank0_only=True`. The solution in this PR is to re-establish the wrapped module in `_lazy_init()`, de-registering from the old wrapped module and re-registering to the new wrapped module, where the assumption is that the user should not modify the module structure upon `_lazy_init()`. The direct access to the private attribute `_parameters` from `nn.Module` is not ideal, but we already rely on it for the dynamic `FlatParameter` registration. The tradeoff is whether we want an additional `nn.Module` wrapper (`FlattenParamsWrapper`) and use `delattr` plus a singleton list to do the dynamic registration or we want to access `_parameters`. If this becomes a problem, we can work with Core team on a solution. [ghstack-poisoned]
ghstack-source-id: 44cfbc7cd4c37babdedc151567f91630d55ab903 Pull Request resolved: #87837
Recently, I retired `FlattenParamsWrapper`, which meant that FSDP registers its `FlatParameter` on the wrapped module instead of the `FlattenParamsWrapper` instance. This is only relevant for `use_orig_params=False`. If the user changes an FSDP instance's wrapped module after the FSDP constructor, then the `FlatParameter` is no longer registered on the wrapped module. This can cause issues for full state dict, which checks if the `FlatParameter` is currently registered as an early return condition for `rank0_only=True`. The solution in this PR is to re-establish the wrapped module in `_lazy_init()`, de-registering from the old wrapped module and re-registering to the new wrapped module, where the assumption is that the user should not modify the module structure upon `_lazy_init()`. The direct access to the private attribute `_parameters` from `nn.Module` is not ideal, but we already rely on it for the dynamic `FlatParameter` registration. The tradeoff is whether we want an additional `nn.Module` wrapper (`FlattenParamsWrapper`) and use `delattr` plus a singleton list to do the dynamic registration or we want to access `_parameters`. If this becomes a problem, we can work with Core team on a solution. [ghstack-poisoned]
ghstack-source-id: 6f069b9970df34a028a13bfdbe906f58f5292036 Pull Request resolved: #87837
Recently, I retired `FlattenParamsWrapper`, which meant that FSDP registers its `FlatParameter` on the wrapped module instead of the `FlattenParamsWrapper` instance. This is only relevant for `use_orig_params=False`. If the user changes an FSDP instance's wrapped module after the FSDP constructor, then the `FlatParameter` is no longer registered on the wrapped module. This can cause issues for full state dict, which checks if the `FlatParameter` is currently registered as an early return condition for `rank0_only=True`. The solution in this PR is to re-establish the wrapped module in `_lazy_init()`, de-registering from the old wrapped module and re-registering to the new wrapped module, where the assumption is that the user should not modify the module structure upon `_lazy_init()`. The direct access to the private attribute `_parameters` from `nn.Module` is not ideal, but we already rely on it for the dynamic `FlatParameter` registration. The tradeoff is whether we want an additional `nn.Module` wrapper (`FlattenParamsWrapper`) and use `delattr` plus a singleton list to do the dynamic registration or we want to access `_parameters`. If this becomes a problem, we can work with Core team on a solution. [ghstack-poisoned]
ghstack-source-id: 971ee81a4fba4ffc12cf41d0017932ec1e4c460b Pull Request resolved: #87837
once we decided to move checkpoint wrapper to hook based, we may think about whether we want to support this use case, and maybe just error out and say changing module structures after FSDP wrapping is not supported |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Hey @awgu. |
Recently, I retired `FlattenParamsWrapper`, which meant that FSDP registers its `FlatParameter` on the wrapped module instead of the `FlattenParamsWrapper` instance. This is only relevant for `use_orig_params=False`. If the user changes an FSDP instance's wrapped module after the FSDP constructor, then the `FlatParameter` is no longer registered on the wrapped module. This can cause issues for full state dict, which checks if the `FlatParameter` is currently registered as an early return condition for `rank0_only=True`. The solution in this PR is to re-establish the wrapped module in `_lazy_init()`, de-registering from the old wrapped module and re-registering to the new wrapped module, where the assumption is that the user should not modify the module structure upon `_lazy_init()`. The direct access to the private attribute `_parameters` from `nn.Module` is not ideal, but we already rely on it for the dynamic `FlatParameter` registration. The tradeoff is whether we want an additional `nn.Module` wrapper (`FlattenParamsWrapper`) and use `delattr` plus a singleton list to do the dynamic registration or we want to access `_parameters`. If this becomes a problem, we can work with Core team on a solution. Pull Request resolved: pytorch#87837 Approved by: https://github.com/zhaojuanmao
@awgu has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
Recently, I retired `FlattenParamsWrapper`, which meant that FSDP registers its `FlatParameter` on the wrapped module instead of the `FlattenParamsWrapper` instance. This is only relevant for `use_orig_params=False`. If the user changes an FSDP instance's wrapped module after the FSDP constructor, then the `FlatParameter` is no longer registered on the wrapped module. This can cause issues for full state dict, which checks if the `FlatParameter` is currently registered as an early return condition for `rank0_only=True`. The solution in this PR is to re-establish the wrapped module in `_lazy_init()`, de-registering from the old wrapped module and re-registering to the new wrapped module, where the assumption is that the user should not modify the module structure upon `_lazy_init()`. The direct access to the private attribute `_parameters` from `nn.Module` is not ideal, but we already rely on it for the dynamic `FlatParameter` registration. The tradeoff is whether we want an additional `nn.Module` wrapper (`FlattenParamsWrapper`) and use `delattr` plus a singleton list to do the dynamic registration or we want to access `_parameters`. If this becomes a problem, we can work with Core team on a solution. Pull Request resolved: pytorch#87837 Approved by: https://github.com/zhaojuanmao
Recently, I retired `FlattenParamsWrapper`, which meant that FSDP registers its `FlatParameter` on the wrapped module instead of the `FlattenParamsWrapper` instance. This is only relevant for `use_orig_params=False`. If the user changes an FSDP instance's wrapped module after the FSDP constructor, then the `FlatParameter` is no longer registered on the wrapped module. This can cause issues for full state dict, which checks if the `FlatParameter` is currently registered as an early return condition for `rank0_only=True`. The solution in this PR is to re-establish the wrapped module in `_lazy_init()`, de-registering from the old wrapped module and re-registering to the new wrapped module, where the assumption is that the user should not modify the module structure upon `_lazy_init()`. The direct access to the private attribute `_parameters` from `nn.Module` is not ideal, but we already rely on it for the dynamic `FlatParameter` registration. The tradeoff is whether we want an additional `nn.Module` wrapper (`FlattenParamsWrapper`) and use `delattr` plus a singleton list to do the dynamic registration or we want to access `_parameters`. If this becomes a problem, we can work with Core team on a solution. Pull Request resolved: pytorch#87837 Approved by: https://github.com/zhaojuanmao
Stack from ghstack:
Recently, I retired
FlattenParamsWrapper
, which meant that FSDP registers itsFlatParameter
on the wrapped module instead of theFlattenParamsWrapper
instance. This is only relevant foruse_orig_params=False
.If the user changes an FSDP instance's wrapped module after the FSDP constructor, then the
FlatParameter
is no longer registered on the wrapped module. This can cause issues for full state dict, which checks if theFlatParameter
is currently registered as an early return condition forrank0_only=True
.The solution in this PR is to re-establish the wrapped module in
_lazy_init()
, de-registering from the old wrapped module and re-registering to the new wrapped module, where the assumption is that the user should not modify the module structure upon_lazy_init()
.The direct access to the private attribute
_parameters
fromnn.Module
is not ideal, but we already rely on it for the dynamicFlatParameter
registration. The tradeoff is whether we want an additionalnn.Module
wrapper (FlattenParamsWrapper
) and usedelattr
plus a singleton list to do the dynamic registration or we want to access_parameters
. If this becomes a problem, we can work with Core team on a solution.Differential Revision: D40799962