
[FSDP2] allow meta tensors during loading state dict and cpu offloading #126267

Closed
weifengpy wants to merge 2 commits

Conversation

weifengpy (Contributor) commented May 15, 2024

Stack from ghstack (oldest at bottom):

Unit test: ``pytest test/distributed/_composable/fsdp/test_fully_shard_state_dict.py``

With meta init and CPU offloading, we have meta tensors after ``model.load_state_dict(assign=True, strict=False)``. This PR avoids calling ``.cpu()`` on meta tensors; otherwise it raises a runtime error.
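For illustration (not part of the PR itself), a minimal standalone sketch of the failure mode the guard avoids: copying a meta tensor to CPU fails because there is no data behind it.

```python
import torch

# A meta tensor carries shape/dtype metadata but has no storage.
t = torch.empty(4, device="meta")
assert t.is_meta

# Moving it to CPU has no data to copy, so PyTorch raises an error;
# this is the runtime error the CPU-offloading path must not hit.
try:
    t.cpu()
except Exception as e:
    print(f"{type(e).__name__}: {e}")
```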

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k


pytorch-bot bot commented May 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126267

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5eb2e7d with merge base c1dc8bb:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot added the `oncall: distributed` and `release notes: distributed (fsdp)` labels May 15, 2024
weifengpy added a commit that referenced this pull request May 15, 2024

ghstack-source-id: c077e9e6ae7ae23c855eb0997e5da18120649919
Pull Request resolved: #126267
@weifengpy weifengpy requested a review from awgu May 15, 2024 07:15
awgu (Contributor) left a comment:


LGTM! I left some nits on the unit test.

awgu added the `release notes: distributed (fsdp2)` label and removed the `release notes: distributed (fsdp)` label May 15, 2024
weifengpy added a commit that referenced this pull request May 15, 2024

ghstack-source-id: 002a35b0abf528b8a7327ce7762416f0e42efa13
Pull Request resolved: #126267
weifengpy (Contributor, Author) commented:

@pytorchmergebot merge

pytorch-bot added the `ciflow/trunk` label May 15, 2024
pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@@ -245,7 +245,7 @@ def _init_sharded_param(self, param: nn.Parameter, device: torch.device):
         self.padded_sharded_param_size = padded_sharded_param.size()
         if sharded_param.numel() > 0:
             padded_sharded_param[: sharded_param.size(0)].copy_(sharded_param)
-        if self.offload_to_cpu:
+        if self.offload_to_cpu and not torch.empty(0).is_meta:
Suggested change:
-        if self.offload_to_cpu and not torch.empty(0).is_meta:
+        if self.offload_to_cpu and not padded_sharded_param.is_meta:

weifengpy (Contributor, Author) replied:

resolving in #126305
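For context, a minimal sketch (not from this PR) of why the two checks differ: ``torch.empty(0).is_meta`` only reflects the ambient default device (for example, inside a ``torch.device("meta")`` context), while ``padded_sharded_param.is_meta`` inspects the actual tensor, which is what matters when ``fully_shard`` runs outside the meta context; that case is what #126305 addresses.

```python
import torch

# torch.empty(0).is_meta depends on the current default device context.
with torch.device("meta"):
    print(torch.empty(0).is_meta)  # True: new tensors default to the meta device here
print(torch.empty(0).is_meta)      # False once we are outside the meta context

# A tensor's own .is_meta flag is independent of the ambient context.
p = torch.empty(4, device="meta")
print(p.is_meta)  # True, even though this line runs outside the meta context
```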

pytorchmergebot pushed a commit that referenced this pull request May 15, 2024
support fully_shard(model_on_meta, cpu_offload) when fully_shard is placed outside of `torch.device("meta")`

Pull Request resolved: #126305
Approved by: https://github.com/awgu
ghstack dependencies: #126267
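For illustration only, a hedged sketch of the setup that follow-up targets, assuming a single-GPU rank launched with ``torchrun`` and the FSDP2 entry points ``fully_shard`` and ``CPUOffloadPolicy`` under ``torch.distributed._composable.fsdp``; this is not the PR's own test.

```python
import torch
import torch.distributed as dist
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard

# Assumes `torchrun --nproc_per_node=1 this_script.py` with a CUDA device available.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

# Parameters are created on the meta device, so they have no storage yet.
with torch.device("meta"):
    model = torch.nn.Linear(8, 8)

# fully_shard itself runs *outside* the meta device context, so a check based on
# torch.empty(0).is_meta sees a non-meta default device even though the parameters
# being sharded are still meta tensors; checking the parameter itself handles
# this placement, which is the scenario #126305 supports.
fully_shard(model, offload_policy=CPUOffloadPolicy())

dist.destroy_process_group()
```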
ZelboK pushed a commit to ZelboK/pytorch that referenced this pull request May 19, 2024
[FSDP2] allow meta tensors during loading state dict and cpu offloading (pytorch#126267)

unit test: ``pytest test/distributed/_composable/fsdp/test_fully_shard_state_dict.py``

With meta init and CPU offloading, we have meta tensors after `model.load_state_dict(assign=True, strict=False)`. This PR avoids calling `.cpu()` on meta tensors; otherwise it raises a runtime error.

Pull Request resolved: pytorch#126267
Approved by: https://github.com/awgu
ZelboK pushed a commit to ZelboK/pytorch that referenced this pull request May 19, 2024
support fully_shard(model_on_meta, cpu_offload) when fully_shard is placed outside of `torch.device("meta")`

Pull Request resolved: pytorch#126305
Approved by: https://github.com/awgu
ghstack dependencies: pytorch#126267
@github-actions github-actions bot deleted the gh/weifengpy/3/head branch June 15, 2024 02:14
Labels
ciflow/trunk (Trigger trunk jobs on your pull request), Merged, oncall: distributed (Add this issue/PR to distributed oncall triage queue), release notes: distributed (fsdp2) (release notes category)