[FSDP][optim_state_dict][6/N] Refactor the optim_state_dict APIs to support hooks #90798
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90798
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 Failure. As of commit c08d03d, NEW FAILURES - the following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
LGTM, thanks!
@@ -1204,6 +1204,7 @@ def _unflatten_process_groups(
 def _optim_state_dict(
     model: torch.nn.Module,
     optim: torch.optim.Optimizer,
+    optim_state_dict: Dict[str, Any],
noob q: do we expect this to be the vanilla state_dict from `optim.state_dict()` or `named_optim.state_dict()`?
Both are accepted.
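For illustration, a minimal sketch of the two shapes being discussed; the FQN-keyed layout shown for the NamedOptimizer case is an assumption, not the exact format:

```python
import torch

# Minimal sketch contrasting the two optimizer state_dict shapes said to be
# accepted; the FQN-keyed layout below is illustrative, not the exact
# NamedOptimizer format.
model = torch.nn.Linear(4, 4)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(2, 4)).sum().backward()
optim.step()

# 1) Vanilla optimizer state_dict: state is keyed by integer parameter ids.
vanilla_osd = optim.state_dict()
print(list(vanilla_osd["state"].keys()))  # e.g. [0, 1]

# 2) NamedOptimizer-style state_dict (assumed layout): state keyed by
#    parameter FQNs such as "weight" / "bias" instead of integer ids.
param_fqns = [name for name, _ in model.named_parameters()]
named_osd = {
    "state": {param_fqns[idx]: s for idx, s in vanilla_osd["state"].items()},
    "param_groups": vanilla_osd["param_groups"],
}
print(list(named_osd["state"].keys()))  # e.g. ['weight', 'bias']
```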
The internal API that is used by all the optim_state_dict implementations.
"""
if full_state_dict:
    FullyShardedDataParallel._raise_on_use_orig_params_optim_checkpoint(
nit: to avoid confusion, might be worth adding a comment here or in the doc of this function to clarify the existing surfaces for which optim state checkpointing works (i.e. the product of use_orig_params, rank0_only, sharded checkpoint, etc.).
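For example, a docstring note along these lines could spell that out; the parameter list and the wording of the support matrix below are placeholders, not the real implementation:

```python
# Hypothetical docstring sketch; the signature and the exact support matrix
# are placeholders to illustrate the suggestion above.
def _optim_state_dict_impl(model, optim, optim_state_dict, full_state_dict, rank0_only):
    """The internal API that is used by all the optim_state_dict implementations.

    Note (illustrative): which optimizer-state checkpointing surfaces are
    supported depends on the combination of ``use_orig_params``,
    ``rank0_only``, and full vs. sharded state dicts; not every combination
    is expected to work, and unsupported combinations should raise.
    """
    raise NotImplementedError("docstring sketch only")
```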
use_orig_params = False
for module in FullyShardedDataParallel.fsdp_modules(model):
    use_orig_params = module._use_orig_params
are we concerned about potential inconsistency here? should we check to ensure the setting is the same for all modules?
Sure, can error out if that's not true.
I changed FSDP to enforce the same `use_orig_params` for all modules in the same tree in #90871.
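For reference, a minimal sketch of the consistency check being discussed, assuming `fsdp_modules` returns every FSDP-wrapped submodule (error message illustrative):

```python
from torch.distributed.fsdp import FullyShardedDataParallel


def _check_use_orig_params_consistent(model):
    # Collect the _use_orig_params setting from every FSDP module in the tree
    # and error out if they disagree (per the discussion above, FSDP now
    # enforces this at construction in #90871).
    settings = {
        module._use_orig_params
        for module in FullyShardedDataParallel.fsdp_modules(model)
    }
    if len(settings) > 1:
        raise ValueError(
            "use_orig_params differs across FSDP modules in the same tree"
        )
    return settings.pop() if settings else False
```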
use_orig_params = False
for module in FullyShardedDataParallel.fsdp_modules(model):
    use_orig_params = module._use_orig_params
    break
looks like we just take the setting of the first module - should we just do `use_orig = next(FSDP.fsdp_modules(model)).use_orig_params`?
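One caveat with that one-liner (assuming `fsdp_modules` returns a plain list): `next()` needs an `iter()` wrapper, and the no-FSDP-module case needs a fallback, e.g.:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def _first_use_orig_params(model):
    # fsdp_modules(model) returns a list, so index it (or wrap it in iter()
    # before calling next()); fall back to False if there are no FSDP modules.
    fsdp_mods = FSDP.fsdp_modules(model)
    return fsdp_mods[0]._use_orig_params if fsdp_mods else False
```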
)

@staticmethod
def _optim_state_dict_to_load_impl(
nit: just for consistency, it might be good to add a docstring here similar to the above API's.
    True,
    use_orig_params,
)
if is_named_optimizer:
might be useful to have a small comment here saying that `NamedOptimizer` expects the keys to be FQNs, unlike regular optimizers.
Will add this in the next PR.
Nice!
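For illustration, the kind of inline note being asked for might read like this; the surrounding names mirror the diff above and the function is a sketch, not the real implementation:

```python
# Sketch only: illustrates where the suggested comment would go.
def _optim_state_dict_to_load_sketch(optim_state_dict, is_named_optimizer):
    if is_named_optimizer:
        # NamedOptimizer expects the state dict keys to be parameter FQNs,
        # unlike regular optimizers, which key state by integer parameter ids.
        return optim_state_dict
    return optim_state_dict
```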
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: Command
Details for Dev Infra team: Raised by workflow job
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 2 additional jobs have failed; the first few of them are: trunk, trunk / cuda11.6-py3.10-gcc7-sm86 / test (default, 3, 4, linux.g5.4xlarge.nvidia.gpu).
Details for Dev Infra team: Raised by workflow job
@pytorchbot merge -f "The failing test is not related"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
[FSDP][optim_state_dict][6/N] Refactor the optim_state_dict APIs to support hooks (pytorch#90798) **What does this PR do?** This PR splits the FSDP optim_state_dict APIs into common implementation parts that are shared across the different frontend APIs (we have many now and will consolidate them gradually). This PR also adds `_optim_state_dict_post_hook` and `_load_optim_state_dict_pre_hook` for the integration with `NamedOptimizer`. Pull Request resolved: pytorch#90798 Approved by: https://github.com/rohan-varma, https://github.com/awgu
Stack from ghstack (oldest at bottom):
What does this PR do?
This PR splits the FSDP optim_state_dict APIs into common implementation parts that are shared across the different frontend APIs (we have many now and will consolidate them gradually). This PR also adds `_optim_state_dict_post_hook` and `_load_optim_state_dict_pre_hook` for the integration with `NamedOptimizer`.
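To make the hook-based integration concrete, here is a rough, hypothetical sketch of how a NamedOptimizer-style wrapper might call such hooks. The hook names come from this PR, but their signatures and the wiring shown are assumptions, not the actual implementation:

```python
from typing import Any, Dict

import torch


def _optim_state_dict_post_hook(model, optim, osd):
    # Placeholder with an assumed signature; the real hook is added to FSDP
    # by this PR and would reshard / rename the state as needed.
    return osd


def _load_optim_state_dict_pre_hook(model, optim, osd):
    # Placeholder with an assumed signature; the real hook is added to FSDP
    # by this PR and would convert the incoming state dict into the form the
    # wrapped optimizer expects.
    return osd


class NamedOptimizerSketch:
    """Hypothetical wrapper illustrating the hook integration; not the real
    NamedOptimizer."""

    def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer):
        self.model = model
        self.optim = optim

    def state_dict(self) -> Dict[str, Any]:
        # Assumed flow: produce the plain optimizer state dict, then let the
        # post-hook transform it (e.g. into an FQN-keyed, FSDP-aware form).
        return _optim_state_dict_post_hook(
            self.model, self.optim, self.optim.state_dict()
        )

    def load_state_dict(self, osd: Dict[str, Any]) -> None:
        # Assumed flow: let the pre-hook convert the incoming state dict
        # before handing it to the wrapped optimizer.
        self.optim.load_state_dict(
            _load_optim_state_dict_pre_hook(self.model, self.optim, osd)
        )
```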