-
Notifications
You must be signed in to change notification settings - Fork 25.7k
[FSDP2] Required mesh_dim_names
for HSDP
#137436
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137436
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 299973c with merge base d1b87e2 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
should add an HSDP test for all-gather extensions |
cc XilunWu H-Huang kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o [ghstack-poisoned]
cc XilunWu H-Huang kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o [ghstack-poisoned]
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
return mesh | ||
elif mesh.ndim == 2: | ||
assert mesh.mesh_dim_names is not None | ||
return mesh[mesh.mesh_dim_names[-1]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't this mean 3D parallelsim won't be supported like it is for FSDP1?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is specifically the submesh used for FSDP. I think it is always 1D (FSDP) or 2D (HSDP), regardless of if there are other parallelisms involved.
@pytorchbot revert -m "Looks like it broke distributed testing, see https://github.com/pytorch/pytorch/actions/runs/11239761070/job/31249854217" -c nosignal |
@pytorchbot successfully started a revert job. Check the current status here. |
This reverts commit 5fb30df. Reverted #137436 on behalf of https://github.com/malfet due to Looks like it broke distributed testing, see https://github.com/pytorch/pytorch/actions/runs/11239761070/job/31249854217 ([comment](#137436 (comment)))
@awgu your PR has been successfully reverted. |
I think I just missed one
let me fix this, run CI, and try to reland tomorrow |
Two changes: 1. Require `mesh_dim_names` if using HSDP 2. Pass only the shard mesh to `fsdp_pre_all_gather` Change 1 is technically BC breaking, but it should not be hard to fix on the user side. cc XilunWu H-Huang kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o [ghstack-poisoned]
alright let us try this again |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Stack from ghstack (oldest at bottom):
mesh_dim_names
for HSDP #137436Two changes:
mesh_dim_names
if using HSDPfsdp_pre_all_gather
Change 1 is technically BC breaking, but it should not be hard to fix on the user side.
cc @XilunWu @H-Huang @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o