[RFC] Enable HSDP #518
```python
    self.pp,
)
assert (
    dp_replicate != -1 or dp_shard != -1
```
nit: we don't guard against -2?
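For context on why -1 is accepted at all, here is a small sketch (hypothetical values, not this PR's code) of how a -1 degree is typically inferred from the remaining world size; anything other than -1 or a positive value, such as -2, should indeed be rejected by validation:

```python
# Hypothetical degrees on 8 GPUs; -1 means "infer from the world size".
world_size = 8
dp_replicate, tp, pp = 2, 2, 1
dp_shard = -1

if dp_shard == -1:
    dp_shard = world_size // (dp_replicate * tp * pp)  # -> 2

assert dp_replicate * dp_shard * tp * pp == world_size
```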
```python
if (name == "dp_replicate" and self.dp_shard == 1) or (
    name == "dp_shard" and self.dp_replicate == 1
):
    names.append("dp")
```
Maybe don't change this, but it is not obvious whether we need to add 'dp'. What is the downside of keeping the original names? 'dp_replicate' is clearer than 'dp' if someone is looking at PG names and wondering what parallelism is used.
We need to add `dp` because the loss computation and the dataloader require it, whether `dp` means `dp_replicate` + `dp_shard` (HSDP) or just `dp_shard` (FSDP). These two components care only about `dp`.
This makes sense; the device mesh will have the following axes:
`dp` when DDP or FSDP is used; `dp_shard` and `dp_replicate`, as well as their flattened `dp`, when HSDP is used.
One corner case is `self.world_size == tp * pp`, where two `dp` dims would be added to `names`.
Hmm, I don't think so. If both `dp_replicate` and `dp_shard` are 1, then `if d > 1` on line 59 won't be true, so we will never add a `dp` mesh.
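Putting the pieces of this thread together, a minimal sketch (hypothetical degrees; requires a distributed launch, and `_flatten` is a private DeviceMesh API) of how the mesh dim names end up being chosen so that a `dp` dim always exists whenever data parallelism is used:

```python
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical degrees for an 8-GPU run: HSDP 2x2 with TP 2.
dp_replicate, dp_shard, tp, pp = 2, 2, 2, 1

names, dims = [], []
for name, d in (("pp", pp), ("dp_replicate", dp_replicate),
                ("dp_shard", dp_shard), ("tp", tp)):
    if d > 1:
        # When only one of the two DP degrees is active, call the dim "dp"
        # directly; otherwise keep both names and flatten them into "dp" below.
        if (name == "dp_replicate" and dp_shard == 1) or (
            name == "dp_shard" and dp_replicate == 1
        ):
            name = "dp"
        names.append(name)
        dims.append(d)

mesh = init_device_mesh("cuda", tuple(dims), mesh_dim_names=tuple(names))
if dp_replicate > 1 and dp_shard > 1:
    mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp")
```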
tianyu-l
left a comment
I think we should make the default value of `data_parallel_replicate_degree` 1 rather than -1. This would simplify the ParallelDims logic, the toml files, and test_runner.py. The reason is that FSDP is the default choice for most configs.
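For illustration, the suggested defaults would look roughly like the following (a sketch with field names taken from this discussion, not necessarily the final code):

```python
from dataclasses import dataclass

@dataclass
class ParallelDims:
    dp_replicate: int = 1   # suggested default: plain FSDP needs no extra toml entry
    dp_shard: int = -1      # -1 means "infer from the world size"
    tp: int = 1
    pp: int = 1
    world_size: int = 1
```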
@tianyu-l I changed
tianyu-l
left a comment
Looks great to me! Please address final inline comments and make sure CI passes.
```python
)
for d in (dp_replicate, tp, pp):
    assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
assert dp_shard == -1 or dp_replicate >= 1, " dp_shard must -1 or >=1."
```
```diff
-assert dp_shard == -1 or dp_replicate >= 1, " dp_shard must -1 or >=1."
+assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."
```
good catch!
```python
if self.dp_replicate > 1 and self.dp_shard > 1:
    mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp")
```
May I ask when and why we need the flattened "dp" mesh? Is it just for HSDP?
No, the `dp` mesh is needed for the dataloader and the loss computation. It's easier for those two components to only know about DP, so I make sure a `dp` mesh always exists.
oh makes sense!
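A short sketch of the consumer side described in this thread (hypothetical layout; assumes slicing the flattened name is supported, per the PR description's conclusion):

```python
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical 8-rank HSDP layout: replicate 2 x shard 4 (requires torchrun).
world_mesh = init_device_mesh(
    "cuda", (2, 4), mesh_dim_names=("dp_replicate", "dp_shard")
)
world_mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp")

# Whether "dp" is a plain dim (FSDP/DDP) or the flattened dim above (HSDP),
# the dataloader and the loss all-reduce only ever ask for "dp".
dp_mesh = world_mesh["dp"]
dp_degree = dp_mesh.size()           # number of data-parallel ranks
dp_rank = dp_mesh.get_local_rank()   # which shard of the dataset this rank reads
```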
XilunWu
left a comment
some questions and potential nits
Stack from ghstack (oldest at bottom):
This PR enables HSDP.
**Discussions**

**1. How does the trainer get the DP mesh?**

Right now, we flatten `["dp_replicate", "dp_shard"]` into a flattened dimension. Because DeviceMesh currently does not support slicing a flattened dimension, we need to either a) flatten again, or b) bookkeep the flattened result. Why do we initialize all the device meshes at the beginning? It is a good strategy to avoid possible deadlocks and timeouts, and it is easier for users to debug -- initializing a brand-new device mesh requires all ranks to participate.

What is the alternative solution? If DeviceMesh supported slicing a flattened dimension, we could just slice it out of the world mesh. Moreover, we may have to support slicing a flattened dimension anyway: with HSDP + CP, we need to pass the DeviceMesh `["dp_replicate", "dp_shard_cp"]` to `fully_shard`, where `dp_shard_cp` is a flattened dimension.

However, if DeviceMesh supports slicing a flattened dimension, what will the name be? Currently, DeviceMesh implicitly concatenates the dimension names that form the flattened dimension. Is this too implicit? We also need to discuss this issue.

**Conclusion: use named flatten + slicing**
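A minimal sketch of what "named flatten + slicing" could look like for the HSDP + CP case mentioned above (hypothetical degrees on 8 ranks; `_flatten` is a private DeviceMesh API, and slicing the flattened name assumes the support discussed here):

```python
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical 8-rank layout: replicate 2 x shard 2 x context-parallel 2.
world_mesh = init_device_mesh(
    "cuda", (2, 2, 2), mesh_dim_names=("dp_replicate", "dp_shard", "cp")
)

# Give the flattened dim an explicit name instead of the implicitly
# concatenated one, so it can be referred to later.
world_mesh["dp_shard", "cp"]._flatten(mesh_dim_name="dp_shard_cp")

# If slicing a flattened dim is supported, this 2-D mesh is what HSDP + CP
# would hand to fully_shard: replicate over "dp_replicate", shard over
# "dp_shard_cp".
hsdp_mesh = world_mesh["dp_replicate", "dp_shard_cp"]
```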
**2. How does TorchTitan expose HSDP to users?**

Another UX issue is how TorchTitan exposes HSDP. There are two ways. One is to expose `dp_shard` and `dp_replicate` to users: for DDP, `dp_shard == 1 and dp_replicate > 1`; for HSDP, `dp_shard > 1 and dp_replicate > 1`; for FSDP, `dp_shard > 1 and dp_replicate == 1`.

An alternative, which this PR uses, is to expose `dp_type`: users explicitly specify `FSDP`, `HSDP`, or `DDP`, so we need another way to express the two degrees. This PR currently exposes `dp_replicate`, but another suggestion is to let `dp` accept both `int` and `Tuple[int, int]`.

**Conclusion: data_parallel_replicate and data_parallel_shard**
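For illustration only (the function and argument names are mine, not this PR's), the mapping between the two exposed degrees and the three data-parallel flavors discussed above:

```python
def dp_flavor(data_parallel_replicate_degree: int,
              data_parallel_shard_degree: int) -> str:
    """Hypothetical helper: classify the DP setup from the two degrees."""
    if data_parallel_replicate_degree > 1 and data_parallel_shard_degree > 1:
        return "HSDP"   # replicate across one dim, shard across the other
    if data_parallel_replicate_degree > 1:
        return "DDP"    # pure replication, no sharding
    return "FSDP"       # pure sharding (the default in most configs)

assert dp_flavor(4, 2) == "HSDP"
assert dp_flavor(8, 1) == "DDP"
assert dp_flavor(1, 8) == "FSDP"
```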
**3. Buffer synchronization**

DTensor will implicitly synchronize the RNG state. However, there are buffers that are not DTensors. How do we ensure that these buffers are synchronized? This PR currently uses `_sync_module_states_with_mesh` to synchronize the module states, including parameters and buffers. Another proposal is that users should set the random seed correctly and ensure the buffers are the same.

**Conclusion: let users handle the RNG status.**
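A minimal sketch of the "let users handle RNG" conclusion (the module below is an illustrative stand-in for the real model):

```python
import torch
import torch.nn as nn

# Seeding every rank with the same value before constructing the model makes
# randomly initialized parameters and non-DTensor buffers bit-identical across
# the dp_replicate ranks, so an explicit broadcast such as
# _sync_module_states_with_mesh is not needed.
torch.manual_seed(0)          # identical seed on every rank
model = nn.Linear(16, 16)     # identical init on all ranks
```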

FSDP with 8 GPUs, llama3 8B
HSDP with 8 GPUs (replicate 4, shard2), llama3 8B
