
Conversation

@awgu (Collaborator) commented Oct 7, 2024

Stack from ghstack (oldest at bottom):

Two changes:

  1. Require mesh_dim_names if using HSDP
  2. Pass only the shard mesh to fsdp_pre_all_gather

Change 1 is technically BC breaking, but it should not be hard to fix on the user side.
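
For change 1, the user-side fix is to name the mesh dimensions when constructing the HSDP mesh. A minimal sketch, assuming a 2D HSDP mesh; the dim names and sizes here are placeholders, not mandated by the API:

    from torch.distributed.device_mesh import init_device_mesh

    replicate_size, shard_size = 2, 4  # placeholder sizes

    # Before this PR an unnamed 2D mesh worked; it now fails the HSDP assertion:
    # mesh = init_device_mesh("cuda", (replicate_size, shard_size))

    # After: name the dims so FSDP can identify the shard mesh.
    mesh = init_device_mesh(
        "cuda",
        (replicate_size, shard_size),
        mesh_dim_names=("dp_replicate", "dp_shard"),
    )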

cc @XilunWu @H-Huang @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

pytorch-bot bot added the "oncall: distributed" and "release notes: distributed (fsdp)" labels (Oct 7, 2024)
pytorch-bot bot commented Oct 7, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137436

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 299973c with merge base d1b87e2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

awgu pushed a commit that referenced this pull request Oct 7, 2024
ghstack-source-id: 2a916dd
Pull Request resolved: #137436
@awgu (Collaborator, Author) commented Oct 7, 2024

should add an HSDP test for all-gather extensions
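
For context, a sketch of the extension hook such a test would exercise. This is illustrative only: the subclass is hypothetical and the hook's full parameter list and return contract are elided. With change 2, the hook receives the 1D shard mesh even under HSDP:

    import torch

    class MyShardedTensor(torch.Tensor):
        # All-gather extension point (simplified signature). After this PR,
        # `mesh` is the 1D shard mesh even when the module uses HSDP;
        # previously it could be the full 2D HSDP mesh.
        def fsdp_pre_all_gather(self, mesh):
            assert mesh.ndim == 1
            ...

An HSDP test would shard a module holding such tensors and verify which mesh the hook observes.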

awgu added the "release notes: distributed (fsdp2)" label and removed the "release notes: distributed (fsdp)" label (Oct 7, 2024)
awgu pushed a commit that referenced this pull request Oct 7, 2024
ghstack-source-id: a615d6b
Pull Request resolved: #137436
@awgu awgu requested review from weifengpy, wz337 and y-sq October 7, 2024 22:14
@awgu awgu marked this pull request as ready for review October 7, 2024 22:14
@awgu (Collaborator, Author) commented Oct 8, 2024

@pytorchbot merge

pytorch-bot bot added the "ciflow/trunk" label (Oct 8, 2024)
pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Review thread on this hunk of the mesh-handling code:

        return mesh
    elif mesh.ndim == 2:
        assert mesh.mesh_dim_names is not None
        return mesh[mesh.mesh_dim_names[-1]]
A reviewer (Collaborator) commented:

Doesn't this mean 3D parallelism won't be supported like it is for FSDP1?

@awgu (Collaborator, Author) replied:

This is specifically the submesh used for FSDP. I think it is always 1D (FSDP) or 2D (HSDP), regardless of whether other parallelisms are involved.
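
For example, with 3D parallelism the global mesh has three dims, but only the HSDP submesh is handed to FSDP. A sketch, assuming the multi-dim mesh slicing available in recent PyTorch; the dim names and sizes are illustrative:

    import torch
    from torch.distributed._composable.fsdp import fully_shard
    from torch.distributed.device_mesh import init_device_mesh

    # Hypothetical 3D layout on 8 ranks: HSDP (2x2) x TP (2).
    global_mesh = init_device_mesh(
        "cuda", (2, 2, 2), mesh_dim_names=("dp_replicate", "dp_shard", "tp")
    )
    # FSDP only ever receives this 1D/2D data-parallel submesh, so the
    # ndim <= 2 handling above suffices even with tensor parallelism.
    hsdp_mesh = global_mesh["dp_replicate", "dp_shard"]
    model = torch.nn.Linear(8, 8, device="cuda")  # placeholder model
    fully_shard(model, mesh=hsdp_mesh)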

@malfet (Contributor) commented Oct 8, 2024

@pytorchbot revert -m "Looks like it broke distributed testing, see https://github.com/pytorch/pytorch/actions/runs/11239761070/job/31249854217" -c nosignal

pytorchmergebot (Collaborator) commented:

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Oct 8, 2024
pytorchmergebot (Collaborator) commented:

@awgu your PR has been successfully reverted.

@awgu (Collaborator, Author) commented Oct 9, 2024

I think I just missed one init_device_mesh callsite:

AssertionError: Please init the 2D mesh for HSDP with mesh_dim_names specified

To execute this test, run the following from the base repo dir:
    python test/distributed/_composable/fsdp/test_fully_shard_training.py TestFullyShardGradientAccumulation.test_gradient_accumulation

let me fix this, run CI, and try to reland tomorrow
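
For reference, a sketch of the kind of check that fires, reconstructed from the assertion message above (not the exact source):

    # Hypothetical reconstruction of FSDP2's mesh validation:
    def _check_hsdp_mesh(mesh):
        if mesh.ndim == 2:
            assert mesh.mesh_dim_names is not None, (
                "Please init the 2D mesh for HSDP with mesh_dim_names specified"
            )

So the reland just needs mesh_dim_names added at the remaining init_device_mesh callsite in that test.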

awgu pushed a commit that referenced this pull request Oct 9, 2024
ghstack-source-id: a4b9ea6
Pull Request resolved: #137436
awgu added the "ciflow/inductor" and "ciflow/periodic" labels (Oct 9, 2024)
@awgu awgu requested review from weifengpy and wz337 October 9, 2024 16:46
@awgu (Collaborator, Author) commented Oct 9, 2024

Alright, let's try this again.
I ran both periodic and inductor CI, so this should cover all distributed tests.

@awgu (Collaborator, Author) commented Oct 9, 2024

@pytorchbot merge

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot (Collaborator) commented:

This PR (#137436) was merged in a93ea61 but it is still open, likely due to a Github bug, so mergebot is closing it manually. If you think this is a mistake, please feel free to reopen and contact Dev Infra.
