[DSD] Implement broadcast_from_rank0 option for optim state_dict #125339

Closed · 6 commits

Conversation

@fegin (Contributor) commented May 1, 2024

[ghstack-poisoned]

pytorch-bot bot commented May 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125339

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit f3bb51d with merge base 196a0b1:

NEW FAILURE - The following job has failed: periodic / win-vs2019-cuda11.8-py3 / test (default, 4, 4, windows.g5.4xlarge.nvidia.gpu)

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: distributed_checkpoint and oncall: distributed labels May 1, 2024
@fegin fegin added the ciflow/periodic and ciflow/trunk labels May 1, 2024
@fegin fegin requested review from wz337 and LucasLLC May 1, 2024 21:31
On the check(equal) helper in the new test:

    def check(equal):
        ...
        if equal:
            self.assertEqual(states, fsdp_states)
Contributor

Do we need to check get_optimizer_state_dict as well? In torchtune, we call the model and optimizer state_dicts separately.

@fegin (Contributor, Author) replied May 2, 2024

Oh yes, somehow that was removed during rebasing. Sorry for the confusion; I will add it back.
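
For context, a minimal sketch of fetching the model and optimizer state_dicts with two separate calls, the way torchtune does. This is illustrative rather than code from this PR; it assumes an already-initialized process group and an FSDP-wrapped `model` and `optim` created elsewhere.

    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_model_state_dict,
        get_optimizer_state_dict,
    )

    # `model` and `optim` are assumed to be an FSDP-wrapped module and its
    # optimizer created elsewhere; the options shown are illustrative.
    options = StateDictOptions(full_state_dict=True, cpu_offload=True)

    # Model and optimizer state_dicts come from two separate calls, which is
    # why the test should exercise get_optimizer_state_dict as well.
    model_sd = get_model_state_dict(model, options=options)
    optim_sd = get_optimizer_state_dict(model, optim, options=options)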

fegin added a commit that referenced this pull request May 2, 2024
Summary:
This is useful if users would like to avoid CPU memory OOM when loading from a full state_dict.

ghstack-source-id: d6131de1ec30b64453b4b8179c48177f2b3b00d7
Pull Request resolved: #125339
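
As a usage note for the commit message above, a minimal sketch of loading a full checkpoint with this option, assuming an initialized process group and an FSDP-sharded `model` and `optim`. The checkpoint path and the empty-dict pattern on non-zero ranks are assumptions for illustration, not documentation from this PR.

    import torch
    import torch.distributed as dist
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        set_model_state_dict,
        set_optimizer_state_dict,
    )

    # Only rank 0 materializes the full (unsharded) checkpoint in CPU memory;
    # the other ranks receive their shards via broadcast, avoiding CPU OOM.
    if dist.get_rank() == 0:
        full_sd = torch.load("full_checkpoint.pt", map_location="cpu")  # hypothetical path
    else:
        full_sd = {"model": {}, "optim": {}}  # assumed: non-zero ranks pass empty dicts

    options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
    set_model_state_dict(model, model_state_dict=full_sd["model"], options=options)
    set_optimizer_state_dict(
        model, optim, optim_state_dict=full_sd["optim"], options=options
    )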
@@ -683,11 +693,33 @@ def _load_optim_state_dict(
            optim_state_dict = FSDP.optim_state_dict_to_load(
                model, optim, optim_state_dict
            )
        elif info.broadcast_from_rank0:
            info.full_state_dict = False
            local_state_dict = _get_optim_state_dict(model, (optim,), info)
Contributor

Are we using _get_optim_state_dict instead of torch.optim.Optimizer.state_dict because we want to align FQNs between local_state_dict and optim_state_dict? torch.optim.Optimizer.state_dict only gives us ID keys.

@fegin (Contributor, Author) replied

Yes, it is easier to proceed with FQN keys.
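
To spell out the distinction discussed above (not code from this PR): torch.optim.Optimizer.state_dict keys per-parameter state by integer parameter IDs, while the distributed state_dict helpers key it by fully qualified names (FQNs), which is what lets local and full state_dicts be matched up.

    import torch

    model = torch.nn.Linear(4, 4)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    model(torch.randn(2, 4)).sum().backward()
    optim.step()

    # torch.optim.Optimizer.state_dict() uses integer IDs as keys:
    #   {'state': {0: {...}, 1: {...}}, 'param_groups': [{'params': [0, 1], ...}]}
    print(list(optim.state_dict()["state"].keys()))  # [0, 1]

    # By contrast, the FQN-keyed form produced by the distributed state_dict
    # helpers looks like {'weight': {...}, 'bias': {...}} (or 'layer.weight'
    # for nested modules), so entries can be matched by parameter name.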

fegin added a commit that referenced this pull request May 3, 2024
Summary:
This is useful if users would like to avoid CPU memory OOM when loading from a full state_dict.

ghstack-source-id: 467a730fbe6d461f8228e00615a856a96d581acc
Pull Request resolved: #125339
@@ -653,7 +659,11 @@ def _load_optim_state_dict(
        return

    for optim in optimizers:
        optim_state_dict = _split_optim_state_dict(model, optim, state_dict, info)
        _init_optim_state(optim)
Contributor

_init_optim_state seems to update model.parameters() for Adam even though we set grad=0?

Repro: pytest test_distributed.py P1233005758

@fegin (Contributor, Author) replied

Fixed with #125708.
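
For background, a rough sketch of the zero-grad lazy-initialization pattern that _init_optim_state is built around. This is an assumption about the helper rather than its actual source, but it shows why an optimizer configured with weight decay can still move the parameters when stepped with zero gradients, the behavior addressed in #125708.

    import torch

    def init_optim_state_sketch(optim: torch.optim.Optimizer) -> None:
        """Materialize per-parameter optimizer state (exp_avg, step, ...) so a
        loaded state_dict has matching entries to copy into (sketch only)."""
        if optim.state:  # state already initialized, nothing to do
            return
        for group in optim.param_groups:
            for param in group["params"]:
                if param.requires_grad:
                    param.grad = torch.zeros_like(param)
        # Note: with weight decay enabled, a step on zero gradients can still
        # change the parameters, matching the behavior reported above.
        optim.step()
        optim.zero_grad(set_to_none=True)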

fegin added a commit that referenced this pull request May 7, 2024
Summary:
This is useful if users would like to avoid CPU memory OOM when loading from a full state_dict.

ghstack-source-id: 0192056670c5c23de26dd0f0eb09da21cb736f73
Pull Request resolved: #125339
@fegin (Contributor, Author) commented May 8, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 jobs have failed, first few of them are: periodic / win-vs2019-cuda11.8-py3 / test (default, 4, 4, windows.g5.4xlarge.nvidia.gpu)

Details for Dev Infra team · Raised by workflow job

@fegin (Contributor, Author) commented May 8, 2024

@pytorchbot merge -f "The failing tests are not related."

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@github-actions github-actions bot deleted the gh/fegin/236/head branch June 8, 2024 01:54
Labels
ciflow/periodic · ciflow/trunk · Merged · module: distributed_checkpoint · oncall: distributed · release notes: distributed (checkpoint)