[FSDP] Fix wrapped module changing after ctor #87837

Closed
awgu wants to merge 6 commits

@awgu awgu commented Oct 27, 2022

Stack from ghstack:

Recently, I retired `FlattenParamsWrapper`, which means that FSDP now registers its `FlatParameter` on the wrapped module instead of on the `FlattenParamsWrapper` instance. This is only relevant for `use_orig_params=False`.

If the user changes an FSDP instance's wrapped module after the FSDP constructor, then the `FlatParameter` is no longer registered on the wrapped module. This can cause issues for full state dict, which checks whether the `FlatParameter` is currently registered as an early-return condition for `rank0_only=True`.
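To make the failure mode concrete, here is a minimal CPU-only sketch using plain `nn.Module` objects. `ToyWrapper` and `flat_param_is_registered` are hypothetical stand-ins and not FSDP code; only the attribute names `_fsdp_wrapped_module` and `_flat_param` are taken from the error message quoted later in this thread. The sketch assumes the change happens by wrapping the already-wrapped module in another module after construction.

```python
import torch
import torch.nn as nn

FLAT_PARAM = "_flat_param"


class ToyWrapper(nn.Module):
    """Hypothetical stand-in for an FSDP instance (no sharding, no process group)."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self._fsdp_wrapped_module = module
        # A plain list keeps a reference without registering the parameter on this wrapper.
        self._flat_params = [nn.Parameter(torch.zeros(8))]
        # Register the flat parameter on the wrapped module via its private dict.
        module._parameters[FLAT_PARAM] = self._flat_params[0]

    def flat_param_is_registered(self) -> bool:
        # Looks only at the module that is wrapped right now.
        return FLAT_PARAM in self._fsdp_wrapped_module._parameters


inner = nn.Linear(4, 2)
wrapper = ToyWrapper(inner)
registered_before = wrapper.flat_param_is_registered()

# The user changes the wrapped module after the constructor has run.
wrapper._fsdp_wrapped_module = nn.Sequential(inner)
registered_after_swap = wrapper.flat_param_is_registered()

print(registered_before, registered_after_swap)
```

In this toy setup the flat parameter is still attached to `inner`, but any check that inspects the current wrapped module no longer finds it, which is the kind of mismatch the early-return check described above would run into.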

The solution in this PR is to re-establish the wrapped module in `_lazy_init()`: de-register the `FlatParameter` from the old wrapped module and re-register it on the new wrapped module. This assumes that the user does not modify the module structure once `_lazy_init()` has run.
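The de-register and re-register step can be sketched as follows. This is a toy model under stated assumptions, not the code in this PR: `ToyWrapper`, `_flat_params`, and `_registered_on` are hypothetical names, and the `_lazy_init` method here only mirrors the name of the real one.

```python
import torch
import torch.nn as nn

FLAT_PARAM = "_flat_param"


class ToyWrapper(nn.Module):
    """Hypothetical stand-in for an FSDP instance; not the real implementation."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self._fsdp_wrapped_module = module
        # Singleton lists hold references without registering them on this wrapper.
        self._flat_params = [nn.Parameter(torch.zeros(8))]
        self._registered_on = [module]
        module._parameters[FLAT_PARAM] = self._flat_params[0]

    def _lazy_init(self) -> None:
        old, new = self._registered_on[0], self._fsdp_wrapped_module
        if old is not new:
            # De-register from the old wrapped module.
            old._parameters.pop(FLAT_PARAM, None)
            # Re-register on the module that is wrapped now.
            new._parameters[FLAT_PARAM] = self._flat_params[0]
            self._registered_on[0] = new


inner = nn.Linear(4, 2)
wrapper = ToyWrapper(inner)
wrapper._fsdp_wrapped_module = nn.Sequential(inner)  # changed after the constructor
wrapper._lazy_init()

on_new = FLAT_PARAM in wrapper._fsdp_wrapped_module._parameters
on_old = FLAT_PARAM in inner._parameters
param_names = [name for name, _ in wrapper.named_parameters()]
```

After `_lazy_init()` the flat parameter is found on the current wrapped module and no longer on the old one. Any change to the module structure made later would not be picked up by this sketch, which matches the assumption stated above.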

Directly accessing the private attribute `_parameters` of `nn.Module` is not ideal, but we already rely on it for the dynamic `FlatParameter` registration. The tradeoff is between keeping an additional `nn.Module` wrapper (`FlattenParamsWrapper`) that uses `delattr` plus a singleton list for the dynamic registration, and accessing `_parameters` directly. If this becomes a problem, we can work with the Core team on a solution.
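For readers unfamiliar with the two mechanisms, the sketch below contrasts them on a plain `nn.Linear`. It is my reading of the tradeoff, not code from `FlattenParamsWrapper` or FSDP; the variable names and the use of `nn.Linear` are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

flat = nn.Parameter(torch.zeros(10))  # stand-in for a FlatParameter

# Approach A: write to the private `_parameters` dict directly.
module_a = nn.Linear(4, 2)
module_a._parameters["_flat_param"] = flat
names_a = [name for name, _ in module_a.named_parameters()]
module_a._parameters.pop("_flat_param")
names_a_after = [name for name, _ in module_a.named_parameters()]

# Approach B: attribute assignment, `delattr`, and a singleton list.
module_b = nn.Linear(4, 2)
module_b._flat_param = flat        # nn.Module.__setattr__ registers a Parameter
saved = [module_b._flat_param]     # a plain list keeps a reference without registering it
delattr(module_b, "_flat_param")   # nn.Module.__delattr__ de-registers it
names_b_after = [name for name, _ in module_b.named_parameters()]

print(names_a, names_a_after, names_b_after)
```

Both approaches end with the parameter de-registered while a reference to it survives. Approach A touches a private attribute; approach B stays on public attribute machinery but needs somewhere to keep the list, which is the role the extra wrapper module played.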

Differential Revision: D40799962

pytorch-bot bot commented Oct 27, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87837

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failures, 1 Pending

As of commit 3ac9c0d:

The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

awgu added a commit that referenced this pull request Oct 27, 2022
ghstack-source-id: 2346ee6752e3b1930356e6462adda990aebb0636
Pull Request resolved: #87837
awgu added a commit that referenced this pull request Oct 27, 2022
ghstack-source-id: 4595d556ad345c213c38e92c2dbc8a2d2a0c6bf5
Pull Request resolved: #87837
awgu added a commit that referenced this pull request Oct 27, 2022
ghstack-source-id: 963850d8bc3bcf4a8eeff07594b84f6096c992a7
Pull Request resolved: #87837
test/distributed/fsdp/test_fsdp_misc.py (outdated)
@@ -220,9 +234,10 @@ def _validate_state_dict_contents(

 @skip_if_lt_x_gpu(2)
 @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS)
-@parametrize("checkpoint_wrap", ["first", "second", "both"])
+@parametrize("checkpoint_wrap", ["source", "dest", "both", "source_after_wrap"])
Contributor

Curious: if we add `both_after_wrap` with `rank0_only_and_offload` set to `False`, would that reproduce the error where loading gets stuck?

Contributor Author

"both_after_wrap" with rank0_only_and_offload=False produces a different error:

RuntimeError: Error(s) in loading state_dict for FullyShardedDataParallel:
	While copying the parameter named "_fsdp_wrapped_module.0._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param", whose dimensions in the model are torch.Size([100]) and whose dimensions in the checkpoint are torch.Size([100]), an exception occurred : ('CUDA error: invalid argument\nCUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.

However, since it is still an error, I added the "both_after_wrap" option to the unit test. We still need to figure out how to reproduce the load error.

torch/distributed/fsdp/fully_sharded_data_parallel.py (outdated)
awgu added a commit that referenced this pull request Oct 27, 2022
ghstack-source-id: 44cfbc7cd4c37babdedc151567f91630d55ab903
Pull Request resolved: #87837
@awgu awgu requested a review from fegin October 27, 2022 13:56
awgu added a commit that referenced this pull request Oct 27, 2022
ghstack-source-id: 6f069b9970df34a028a13bfdbe906f58f5292036
Pull Request resolved: #87837
awgu added the ciflow/trunk label (Trigger trunk jobs on your pull request) Oct 27, 2022
awgu added a commit that referenced this pull request Oct 27, 2022
ghstack-source-id: 971ee81a4fba4ffc12cf41d0017932ec1e4c460b
Pull Request resolved: #87837
@zhaojuanmao
Contributor

Once we decide to move the checkpoint wrapper to a hook-based implementation, we may want to reconsider whether to support this use case at all. Maybe we should just error out and say that changing the module structure after FSDP wrapping is not supported.

awgu commented Oct 27, 2022

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

@github-actions

Hey @awgu.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

awgu added the topic: improvements label (topic category) Oct 28, 2022
sgrigory pushed a commit to sgrigory/pytorch that referenced this pull request Oct 28, 2022
Pull Request resolved: pytorch#87837
Approved by: https://github.com/zhaojuanmao
awgu commented Oct 28, 2022

@awgu has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Nov 5, 2022
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
@facebook-github-bot facebook-github-bot deleted the gh/awgu/143/head branch June 8, 2023 15:23
Labels
ciflow/trunk, Merged, release notes: distributed (fsdp), topic: improvements

4 participants