[DSD] Implement broadcast_from_rank0 option for optim state_dict #125339
Conversation
Helpful Links: See artifacts and rendered test results at hud.pytorch.org/pr/125339. (Links to docs will display an error until the docs builds have been completed.)
❌ 1 New Failure as of commit f3bb51d with merge base 196a0b1.
)
if equal:
    self.assertEqual(states, fsdp_states)
def check(equal):
do we need to check get_optimizer_state_dict as well? In torchtune, we call the model and optimizer state_dict separately
oh yes, somehow that was removed during rebasing. Sorry for the confusion, will add it back.
Summary: This is useful if users would like to avoid CPU memory OOM when loading from a full state_dict.
ghstack-source-id: d6131de1ec30b64453b4b8179c48177f2b3b00d7
Pull Request resolved: #125339
@@ -683,11 +693,33 @@ def _load_optim_state_dict(
        optim_state_dict = FSDP.optim_state_dict_to_load(
            model, optim, optim_state_dict
        )
    elif info.broadcast_from_rank0:
        info.full_state_dict = False
        local_state_dict = _get_optim_state_dict(model, (optim,), info)
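The hunk above switches to the broadcast path when info.broadcast_from_rank0 is set. The core idea can be sketched outside DSD with a hypothetical helper (this is not the PR's implementation; the single-rank gloo setup below exists only to make the sketch self-contained):

```python
import os
import torch
import torch.distributed as dist

def _broadcast_optim_state(state: dict, src: int = 0) -> dict:
    # Hypothetical sketch of the broadcast_from_rank0 idea: only rank `src`
    # materializes the full optimizer state in CPU memory; every other rank
    # allocates tensors of the right shape and receives the values via
    # broadcast, avoiding one full copy of the state_dict per rank.
    for tensor in state.values():
        dist.broadcast(tensor, src=src)
    return state

# Single-process group purely so the sketch runs end to end.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29512")
dist.init_process_group("gloo", rank=0, world_size=1)

state = {"exp_avg": torch.ones(4), "exp_avg_sq": torch.full((4,), 2.0)}
_broadcast_optim_state(state)
dist.destroy_process_group()
```

In the real code path, each non-zero rank builds its local (sharded) state_dict skeleton first — which is why the hunk fetches local_state_dict with full_state_dict=False — and then fills it from rank 0.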
are we using _get_optim_state_dict instead of torch.optim.Optimizer.state_dict because we want to align FQNs between local_state_dict and optim_state_dict? torch.optim.Optimizer.state_dict only gives us ID keys
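The "ID keys" point is easy to see with a plain optimizer: Optimizer.state_dict keys parameters by integer position, not by fully-qualified name, so aligning it with a model state_dict requires building the name mapping separately.

```python
import torch

model = torch.nn.Linear(4, 4)
optim = torch.optim.SGD(model.parameters(), lr=0.1)

# torch.optim.Optimizer.state_dict keys parameters by integer ID:
ids = optim.state_dict()["param_groups"][0]["params"]
print(ids)  # [0, 1] — the order params were passed in; nothing about names

# Aligning with the model state_dict requires the FQN mapping yourself:
id_to_fqn = {i: name for i, (name, _) in enumerate(model.named_parameters())}
print(id_to_fqn)  # {0: 'weight', 1: 'bias'}
```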
yes, it is easier to proceed with FQN keys.
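As a standalone illustration of why FQN keys are easier to work with, here is a hypothetical helper (not the DSD implementation) that rewrites an ID-keyed optimizer state_dict into an FQN-keyed one — once keys are FQNs, the two state_dicts can be matched directly:

```python
def remap_ids_to_fqns(optim_sd: dict, id_to_fqn: dict) -> dict:
    # Hypothetical helper: rewrite both "state" and "param_groups" of an
    # Optimizer.state_dict-style dict so entries are keyed by FQN, not ID.
    return {
        "state": {id_to_fqn[i]: s for i, s in optim_sd["state"].items()},
        "param_groups": [
            {**group, "params": [id_to_fqn[i] for i in group["params"]]}
            for group in optim_sd["param_groups"]
        ],
    }

sd = {
    "state": {0: {"momentum_buffer": None}},
    "param_groups": [{"lr": 0.1, "params": [0, 1]}],
}
fqn_sd = remap_ids_to_fqns(sd, {0: "weight", 1: "bias"})
print(fqn_sd["param_groups"][0]["params"])  # ['weight', 'bias']
```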
Summary: This is useful if users would like to avoid CPU memory OOM when loading from a full state_dict.
ghstack-source-id: 467a730fbe6d461f8228e00615a856a96d581acc
Pull Request resolved: #125339
@@ -653,7 +659,11 @@ def _load_optim_state_dict(
        return

    for optim in optimizers:
        optim_state_dict = _split_optim_state_dict(model, optim, state_dict, info)
        _init_optim_state(optim)
_init_optim_state seems to update model.parameters() for Adam even though we set grad=0?
repro: pytest test_distributed.py
P1233005758
Fixed with #125708
Summary: This is useful if users would like to avoid CPU memory OOM when loading from a full state_dict.
ghstack-source-id: 0192056670c5c23de26dd0f0eb09da21cb736f73
Pull Request resolved: #125339
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed; the first few are: periodic / win-vs2019-cuda11.8-py3 / test (default, 4, 4, windows.g5.4xlarge.nvidia.gpu). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge -f "The failing tests are not related."
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
Summary:
This is useful if users would like to avoid CPU memory OOM when loading from a full state_dict.
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @LucasLLC