Skip to content

Conversation

wz337
Copy link
Contributor

@wz337 wz337 commented Oct 10, 2023

Stack from ghstack (oldest at bottom):

This PR adds a all_gather_dtensor() method to fsdp/_fsdp_extensions.py and the actual implementation in tensor/parallel/fsdp.py. This enables FSDP to load 2D DTensor state_dict into model when calling model.load_state_dict().

cc. @fegin

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 10, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110925

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (4 Unrelated Failures)

As of commit 598cde8 with merge base ad24965 (image):

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

wz337 added a commit that referenced this pull request Oct 10, 2023
ghstack-source-id: 6fa37cc
Pull Request resolved: #110925
@wz337 wz337 changed the title enable 2D FSDP+TP load_state_dict() [2D] Enable 2D FSDP+TP model.load_state_dict() Oct 10, 2023
@wz337 wz337 added ciflow/trunk Trigger trunk jobs on your pull request ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR labels Oct 10, 2023
This PR adds a all_gather_dtensor() method to fsdp/_fsdp_extensions.py and the actual implementation in tensor/parallel/fsdp.py. This enables FSDP to load 2D DTensor state_dict into model when calling `model.load_state_dict()`.

cc. fegin 


[ghstack-poisoned]
)


def _all_gather_dtensor(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we merge this logic and the one in torch/distributed/tensor/parallel/fsdp.py into one common place?

Copy link
Contributor Author

@wz337 wz337 Oct 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This _all_gather_dtensor() is actually internal to FSDP. We are following the extension design here. See this function as an example: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fsdp_extensions.py#L91

Essentially, there are two code paths:
FSDP only --> _all_gather_dtensor()
FSDP + TP -> _extensions.all_gather_dtensor()

This PR adds a all_gather_dtensor() method to fsdp/_fsdp_extensions.py and the actual implementation in tensor/parallel/fsdp.py. This enables FSDP to load 2D DTensor state_dict into model when calling `model.load_state_dict()`.

cc. fegin 


[ghstack-poisoned]

def _all_gather_dtensor(
tensor: DTensor,
parent_mesh: DeviceMesh,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This typing should be optional.

This PR adds a all_gather_dtensor() method to fsdp/_fsdp_extensions.py and the actual implementation in tensor/parallel/fsdp.py. This enables FSDP to load 2D DTensor state_dict into model when calling `model.load_state_dict()`.

cc. fegin 


[ghstack-poisoned]
@wz337 wz337 requested a review from fegin October 10, 2023 07:45
This PR adds a all_gather_dtensor() method to fsdp/_fsdp_extensions.py and the actual implementation in tensor/parallel/fsdp.py. This enables FSDP to load 2D DTensor state_dict into model when calling `model.load_state_dict()`.

cc. fegin 


[ghstack-poisoned]
wz337 added a commit that referenced this pull request Oct 10, 2023
ghstack-source-id: da730a1
Pull Request resolved: #110925
Copy link
Contributor

@fegin fegin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@wz337
Copy link
Contributor Author

wz337 commented Oct 11, 2023

@pytorchmergebot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@facebook-github-bot facebook-github-bot deleted the gh/wz337/3/head branch October 15, 2023 14:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: distributed (fsdp) release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants