
[FSDP][optim_state_dict][6/N] Refactor the optim_state_dict APIs to support hooks #90798

Closed · wants to merge 4 commits

Conversation

@fegin (Contributor) commented Dec 13, 2022

Stack from ghstack (oldest at bottom):

What does this PR do?

This PR splits the FSDP optim_state_dict APIs into common implementation parts that are shared by the different frontend APIs (we have many now and will consolidate them gradually). It also adds `_optim_state_dict_post_hook` and `_load_optim_state_dict_pre_hook` for the integration with `NamedOptimizer`.
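As a reading aid, here is a minimal sketch of the "shared core plus hooks" structure this description refers to; the helper name and hook signature below are illustrative assumptions, not this PR's actual code — only the two hook names come from the PR.

```python
# Illustrative sketch only: _shared_optim_state_dict_impl and the post_hook
# signature are assumptions; only the hook names appear in this PR.
from typing import Any, Callable, Dict, Optional

import torch


def _shared_optim_state_dict_impl(
    model: torch.nn.Module,
    optim: torch.optim.Optimizer,
    optim_state_dict: Dict[str, Any],
    post_hook: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """Common core that the different frontend optim_state_dict APIs could share."""
    consolidated = dict(optim_state_dict)  # stand-in for the real FSDP consolidation logic
    if post_hook is not None:
        # e.g. _optim_state_dict_post_hook, used by the NamedOptimizer integration
        consolidated = post_hook(consolidated)
    return consolidated
```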

@pytorch-bot (bot) commented Dec 13, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90798

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failure

As of commit c08d03d:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

fegin added a commit that referenced this pull request Dec 13, 2022
[FSDP][optim_state_dict][6/N] Refactor the optim_state_dict APIs to support hooks

ghstack-source-id: 19f16d686b7ed549af2a4b2030ab87a629f5572d
Pull Request resolved: #90798
@fegin fegin marked this pull request as draft December 13, 2022 21:36
@fegin fegin marked this pull request as ready for review December 20, 2022 08:26
@rohan-varma (Member) left a comment

LGTM, thanks!

@@ -1204,6 +1204,7 @@ def _unflatten_process_groups(
def _optim_state_dict(
model: torch.nn.Module,
optim: torch.optim.Optimizer,
optim_state_dict: Dict[str, Any],
Member:

noob q: do we expect this to be the vanilla state_dict from optim.state_dict() or named_optim.state_dict()?

Contributor Author:

Both are accepted.
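For context, a minimal plain-PyTorch illustration of the two key layouts under discussion (the FQN layout shown in the comment is schematic, not the exact NamedOptimizer output):

```python
import torch

model = torch.nn.Linear(4, 2)
optim = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
model(torch.randn(1, 4)).sum().backward()
optim.step()

# A vanilla optimizer keys its state by integer parameter index.
print(list(optim.state_dict()["state"].keys()))  # e.g. [0, 1]

# A NamedOptimizer-style state dict keys the same state by parameter FQN,
# e.g. {"weight": {...}, "bias": {...}}; both layouts are accepted here.
```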

The internal API that is used by all the optim_state_dict implementations.
"""
if full_state_dict:
FullyShardedDataParallel._raise_on_use_orig_params_optim_checkpoint(
Member:

nit: to avoid confusion, might be worth adding a comment here or in the docstring of this function to clarify the existing surfaces for which optim state checkpointing works (i.e. the product of use_orig_params, rank0_only, sharded checkpoint, etc.).


use_orig_params = False
for module in FullyShardedDataParallel.fsdp_modules(model):
use_orig_params = module._use_orig_params
Member:

are we concerned about potential inconsistency here? should we check to ensure the setting is the same for all modules?

Contributor Author:

Sure, can error out if that's not true.

Contributor:

I changed FSDP to enforce same use_orig_params for all in the same tree in #90871.

use_orig_params = False
for module in FullyShardedDataParallel.fsdp_modules(model):
use_orig_params = module._use_orig_params
break
Member:

looks like we just take the setting of the first module - should we just do

use_orig_params = next(iter(FSDP.fsdp_modules(model)))._use_orig_params
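Putting the two review suggestions together, a hedged sketch of what the check could look like (it relies on the private `_use_orig_params` attribute and is not the code that landed; since #90871 FSDP enforces a uniform setting per tree anyway):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel


def _get_use_orig_params(model: torch.nn.Module) -> bool:
    # Take the setting from the first FSDP module, but verify that every FSDP
    # module in the tree agrees.
    fsdp_modules = FullyShardedDataParallel.fsdp_modules(model)
    if not fsdp_modules:
        return False
    use_orig_params = fsdp_modules[0]._use_orig_params
    if any(m._use_orig_params != use_orig_params for m in fsdp_modules):
        raise ValueError(
            "All FSDP modules in the same tree must use the same use_orig_params setting."
        )
    return use_orig_params
```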

)

@staticmethod
def _optim_state_dict_to_load_impl(
Member:

nit: just for consistency, it might be good to add a docstring here similar to the one on the API above.

True,
use_orig_params,
)
if is_named_optimizer:
Member:

might be useful to have a small comment here saying that NamedOptimizer expects the keys to be FQNs, unlike regular optimizers (which key state by parameter index).

Contributor Author:

Will add this in the next PR.
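As an illustration of that comment, a hypothetical rekeying helper (not this PR's code) showing the index-to-FQN translation a NamedOptimizer-style state dict implies; `param_groups` handling is deliberately omitted:

```python
from typing import Any, Dict


def _rekey_state_by_fqn(
    optim_state: Dict[str, Any], index_to_fqn: Dict[int, str]
) -> Dict[str, Any]:
    # Hypothetical helper: turn index-keyed optimizer state into FQN-keyed
    # state, as a NamedOptimizer expects; param_groups are passed through
    # unchanged for brevity.
    return {
        "state": {index_to_fqn[i]: s for i, s in optim_state["state"].items()},
        "param_groups": optim_state["param_groups"],
    }
```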

@awgu (Contributor) left a comment

Nice!

Update on "[FSDP][optim_state_dict][6/N] Refactor the optim_state_dict APIs to support hooks"

[ghstack-poisoned]
@fegin (Contributor Author) commented Dec 21, 2022

@pytorchbot merge

@pytorch-bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Dec 21, 2022
@pytorchmergebot (Collaborator)
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@pytorchmergebot (Collaborator)
Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x 1521bf7dfc3da19cdddddd86d675b3c6932b77e7 returned non-zero exit code 1

Auto-merging torch/distributed/fsdp/_optim_utils.py
Auto-merging torch/distributed/fsdp/fully_sharded_data_parallel.py
CONFLICT (content): Merge conflict in torch/distributed/fsdp/fully_sharded_data_parallel.py
error: could not apply 1521bf7dfc... [FSDP][optim_state_dict][6/N] Refactor the optim_state_dict APIs to support hooks
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
Details for Dev Infra team. Raised by workflow job.

Update on "[FSDP][optim_state_dict][6/N] Refactor the optim_state_dict APIs to support hooks"

[ghstack-poisoned]
@fegin (Contributor Author) commented Dec 21, 2022

@pytorchbot merge

@pytorchmergebot (Collaborator)
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@pytorchmergebot (Collaborator)
Merge failed

Reason: 2 additional jobs have failed, first few of them are: trunk, trunk / cuda11.6-py3.10-gcc7-sm86 / test (default, 3, 4, linux.g5.4xlarge.nvidia.gpu)

Details for Dev Infra team. Raised by workflow job.

@fegin (Contributor Author) commented Dec 21, 2022

@pytorchbot merge -f "The failing test is not related"

@pytorchmergebot (Collaborator)
Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

ShisuiUzumaki pushed a commit to ShisuiUzumaki/pytorch that referenced this pull request Dec 23, 2022
[FSDP][optim_state_dict][6/N] Refactor the optim_state_dict APIs to support hooks (pytorch#90798)

Pull Request resolved: pytorch#90798
Approved by: https://github.com/rohan-varma, https://github.com/awgu
@facebook-github-bot facebook-github-bot deleted the gh/fegin/52/head branch June 8, 2023 17:16
Labels: ciflow/trunk, Merged, release notes: distributed (fsdp)