[FSDP2] Enable HSDP + TP #133335
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133335
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (8 Unrelated Failures)
As of commit d347f6b with merge base cc1cc71:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk.
BROKEN TRUNK - The following jobs failed but were also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and have been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
The SPMD placements change looks good to me.
    2 <= self._spmd_mesh.ndim <= 3
), f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}."
self._spmd_placements: Tuple[Placement, ...]
if self._spmd_mesh.ndim == 2:
nit: maybe factor out to avoid some duplication?
fsdp_shard_placement = (
    _StridedShard(0, split_factor=split_factor)
    if split_factor > 1
    else Shard(0)
)
if self._spmd_mesh.ndim == 2:
    self._spmd_placements = (fsdp_shard_placement, self._tp_spec.placements[0])
else:
    self._spmd_placements = (
        Replicate(),
        fsdp_shard_placement,
        self._tp_spec.placements[0],
    )
pp_size = 2 if self.world_size > 4 else 1
return init_device_mesh(
    "cuda",
    (2, 2, 2),
nit: I think we need to use the `dp_size` and `pp_size` above.
Forgot to remove, but we need 8 GPUs, so it is always 2x2x2. I skip the test if the number of GPUs is less than 8.
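For readers following along, a minimal sketch of such a guard, assuming PyTorch's internal `skip_if_lt_x_gpu` test helper; the class and method names are illustrative and not the ones in this PR:

```python
# Sketch (not the PR's actual test): skip when fewer than 8 GPUs are available,
# so the hard-coded (2, 2, 2) mesh always has enough ranks.
# Assumes the distributed test harness has already spawned 8 ranks.
from torch.distributed.device_mesh import init_device_mesh
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu


class TestHSDPTPExample:  # illustrative name
    @skip_if_lt_x_gpu(8)
    def test_hsdp_tp(self):
        mesh = init_device_mesh(
            "cuda",
            (2, 2, 2),  # (replicate, shard, tp) -> requires exactly 8 ranks
            mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
        )
        ...
```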
Not sure if you want to make more changes since the PR is still a draft, but stamping to unblock.
Thanks @awgu for the review. I'm going to add one more test for state_dict before landing.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR enables HSDP + TP.
Pull Request resolved: pytorch#133335
Approved by: https://github.com/awgu
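For reference, a minimal sketch of the composition this PR enables, assuming an 8-rank run (e.g. launched with torchrun) and a recent PyTorch with multi-dim mesh slicing; the toy `MLP`, mesh dim names, and TP plan are illustrative assumptions, not taken from this PR:

```python
# Sketch: HSDP (replicate + shard) composed with TP on a 3D device mesh.
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class MLP(nn.Module):  # illustrative toy model
    def __init__(self, dim: int = 16):
        super().__init__()
        self.w1 = nn.Linear(dim, 4 * dim)
        self.w2 = nn.Linear(4 * dim, dim)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


# One global 3D mesh over 8 ranks: (replicate, shard, tp).
mesh = init_device_mesh(
    "cuda", (2, 2, 2), mesh_dim_names=("dp_replicate", "dp_shard", "tp")
)

model = MLP().cuda()

# Apply TP on the innermost mesh dim, then HSDP on the outer 2D (replicate, shard) submesh.
parallelize_module(
    model, mesh["tp"], {"w1": ColwiseParallel(), "w2": RowwiseParallel()}
)
fully_shard(model, mesh=mesh["dp_replicate", "dp_shard"])
```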
@awgu thanks a lot for this PR.
This is the device mesh I pass in; it takes some DP mesh and reshapes it to work with HSDP. I see the following issue:
[rank26]: Traceback (most recent call last):
[rank26]: File "<frozen runpy>", line 198, in _run_module_as_main
[rank26]: File "<frozen runpy>", line 88, in _run_code
[rank26]: File "/u/mayank98/scratch/tmp1/dolomite-engine/dolomite_engine/pretrain.py", line 381, in <module>
[rank26]: main()
[rank26]: File "/u/mayank98/scratch/tmp1/dolomite-engine/dolomite_engine/pretrain.py", line 312, in main
[rank26]: model = wrap_model_for_distributed_training(args, model)
[rank26]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank26]: File "/u/mayank98/scratch/tmp1/dolomite-engine/dolomite_engine/distributed/__init__.py", line 210, in wrap_model_for_distributed_training
[rank26]: fully_shard(
[rank26]: File "/u/mayank98/miniconda3/envs/ai/lib/python3.11/site-packages/torch/distributed/_composable/contract.py", line 125, in wrapper
[rank26]: updated = func(inp_module, *args, **kwargs)
[rank26]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank26]: File "/u/mayank98/miniconda3/envs/ai/lib/python3.11/site-packages/torch/distributed/_composable/fsdp/fully_shard.py", line 129, in fully_shard
[rank26]: state._fsdp_param_group = FSDPParamGroup(
[rank26]: ^^^^^^^^^^^^^^^
[rank26]: File "/u/mayank98/miniconda3/envs/ai/lib/python3.11/site-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 114, in __init__
[rank26]: self.fsdp_params = [
[rank26]: ^
[rank26]: File "/u/mayank98/miniconda3/envs/ai/lib/python3.11/site-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 115, in <listcomp>
[rank26]: FSDPParam(
[rank26]: File "/u/mayank98/miniconda3/envs/ai/lib/python3.11/site-packages/torch/distributed/_composable/fsdp/_fsdp_param.py", line 235, in __init__
[rank26]: self._init_sharded_param(param, device)
[rank26]: File "/u/mayank98/miniconda3/envs/ai/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank26]: return func(*args, **kwargs)
[rank26]: ^^^^^^^^^^^^^^^^^^^^^
[rank26]: File "/u/mayank98/miniconda3/envs/ai/lib/python3.11/site-packages/torch/distributed/_composable/fsdp/_fsdp_param.py", line 269, in _init_sharded_param
[rank26]: raise AssertionError(
[rank26]: AssertionError: FSDP requires the DP and TP mesh to have the same parent mesh but got:
[rank26]: DP's global mesh: DeviceMesh('cuda', [[0, 8, 16, 24], [2, 10, 18, 26], [4, 12, 20, 28], [6, 14, 22, 30]])
[rank26]: TP's global mesh: DeviceMesh('cuda', [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15], [16, 17], [18, 19], [20, 21], [22, 23], [24, 25], [26, 27], [28, 29], [30, 31]], mesh_dim_names=('dp', 'tp'))
@mayank31398 How did you create the device mesh? Can you share the code that creates it?
@fegin sorry, it was an error on my end.
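For anyone who hits the same assertion: FSDP2 expects the DP mesh passed to `fully_shard` and the TP mesh used for the DTensor parameters to be slices of one parent mesh, not two independently constructed meshes. A hedged sketch of one way to set that up for the 32-rank, TP=2 layout shown in the traceback (the dim names and the 4x4x2 split are illustrative):

```python
# Sketch: build a single global mesh and slice it so the DP and TP meshes
# share the same parent mesh (sizes chosen to match the 32-rank, TP=2 run above).
from torch.distributed.device_mesh import init_device_mesh

global_mesh = init_device_mesh(
    "cuda", (4, 4, 2), mesh_dim_names=("dp_replicate", "dp_shard", "tp")
)
tp_mesh = global_mesh["tp"]                        # for parallelize_module
dp_mesh = global_mesh["dp_replicate", "dp_shard"]  # for fully_shard(..., mesh=...)
```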
Stack from ghstack (oldest at bottom):
This PR enables HSDP + TP
cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o