
Conversation

@fegin (Contributor) commented Aug 13, 2024

Stack from ghstack (oldest at bottom):

This PR enables HSDP.

Discussions
1. How does the trainer get the DP mesh?
Right now, we flatten ["dp_replicate", "dp_shard"] into a single flattened dimension. Because DeviceMesh currently does not support slicing a flattened dimension, we need to either a) flatten again later, or b) bookkeep the flattened result. Why do we initialize all the device meshes at the beginning? It is a good strategy to avoid possible deadlocks and timeouts, and it makes debugging easier for users -- initializing a brand-new device mesh requires all ranks to participate.

What is the alternative? If DeviceMesh supported slicing a flattened dimension, we could simply slice it out of the world mesh. Moreover, we may have to support slicing a flattened dimension anyway: with HSDP + CP, we need to pass the DeviceMesh ["dp_replicate", "dp_shard_cp"] to fully_shard, where dp_shard_cp is a flattened dimension.

However, if DeviceMesh supports slicing a flattened dimension, what should its name be? Currently, DeviceMesh implicitly concatenates the dimension names that form the flattened dimension. Is this too implicit? We need to discuss this issue as well.

Conclusion: use named flatten + slicing
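
For concreteness, a minimal sketch of the "named flatten + slicing" flow under the dimension names used here; the `_flatten(mesh_dim_name=...)` call mirrors the code in this PR, while slicing the flattened `"dp"` dimension back out of the world mesh is exactly the DeviceMesh capability under discussion, so treat it as illustrative:

```python
from torch.distributed.device_mesh import init_device_mesh

# Example degrees on 8 GPUs: HSDP with replicate=2, shard=2, plus tp=2.
dp_replicate, dp_shard, tp = 2, 2, 2

# Initialize the full world mesh up front so every rank participates in
# creating all process groups (avoids later deadlocks/timeouts and is
# easier to debug).
world_mesh = init_device_mesh(
    "cuda",
    (dp_replicate, dp_shard, tp),
    mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
)

# fully_shard for HSDP takes the 2D (replicate, shard) submesh.
hsdp_mesh = world_mesh["dp_replicate", "dp_shard"]

# Flatten the two DP dims under an explicit name so that DP-only consumers
# (dataloader, loss) never need to know about the 2D structure.
hsdp_mesh._flatten(mesh_dim_name="dp")

# With flattened-dim slicing supported, this is all the trainer needs:
dp_mesh = world_mesh["dp"]
```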

2. How does TorchTitan expose HSDP to users?
Another UX question is how TorchTitan should expose HSDP. There are two options. One is to expose dp_shard and dp_replicate to users: for DDP, dp_shard == 1 and dp_replicate > 1; for HSDP, dp_shard > 1 and dp_replicate > 1; for FSDP, dp_shard > 1 and dp_replicate == 1.

The alternative, which this PR uses, is to expose dp_type: users explicitly specify FSDP, HSDP, or DDP, so we need another way to express the two degrees. This PR currently exposes dp_replicate, but another suggestion is to let dp accept both int and Tuple[int, int].

Conclusion: data_parallel_replicate and data_parallel_shard
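
Purely for illustration (this helper is not part of the PR), the mapping from the two exposed degrees to the familiar flavors:

```python
def dp_flavor(dp_replicate: int, dp_shard: int) -> str:
    """Map the two user-facing degrees to a data-parallel flavor."""
    if dp_replicate > 1 and dp_shard > 1:
        return "HSDP"  # replicate across one mesh dim, shard across the other
    if dp_shard > 1:
        return "FSDP"  # pure sharding (dp_replicate == 1)
    if dp_replicate > 1:
        return "DDP"   # pure replication (dp_shard == 1)
    return "none"      # no data parallelism

# e.g. on 8 GPUs with tp == pp == 1:
assert dp_flavor(dp_replicate=1, dp_shard=8) == "FSDP"
assert dp_flavor(dp_replicate=8, dp_shard=1) == "DDP"
assert dp_flavor(dp_replicate=4, dp_shard=2) == "HSDP"
```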

3. Buffer synchronization
DTensor implicitly synchronizes the RNG state. However, some buffers are not DTensors, so how do we ensure those buffers are synchronized? This PR currently uses _sync_module_states_with_mesh to synchronize the module states, including parameters and buffers. Another proposal is that users should set the random seed correctly and ensure the buffers are identical themselves.

Conclusion: let users handle the RNG state.
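
A hedged sketch of what "let users handle the RNG state" means in practice; `TinyModel` below is a stand-in, not torchtitan code:

```python
import torch
import torch.distributed as dist
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # A randomly initialized buffer stands in for things like
        # precomputed RoPE frequency tables.
        self.register_buffer("freqs", torch.randn(16))

# Seeding all ranks identically *before* building the model makes the
# random buffer equal everywhere with no explicit synchronization.
torch.manual_seed(0)
model = TinyModel()

# Alternatively, an explicit one-time broadcast from rank 0 gives the same
# guarantee when deterministic initialization cannot be assumed:
# for buf in model.buffers():
#     dist.broadcast(buf, src=0)
```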
FSDP with 8 GPUs, llama3 8B
Screenshot 2024-09-06 at 10 18 03 AM

HSDP with 8 GPUs (replicate 4, shard 2), llama3 8B
Screenshot 2024-09-06 at 10 13 43 AM

fegin added a commit that referenced this pull request Aug 13, 2024
This PR enables HSDP.

ghstack-source-id: d2c8521
Pull Request resolved: #518
@facebook-github-bot added the CLA Signed label on Aug 13, 2024
@fegin marked this pull request as draft on August 13, 2024 07:14
@fegin requested review from awgu, tianyu-l, wanchaol and wz337 on August 13, 2024 16:57
@fegin changed the title from "Enable HSDP" to "[RFC] Enable HSDP" on Aug 13, 2024
@fegin marked this pull request as ready for review on August 14, 2024 17:10
fegin added a commit that referenced this pull request Aug 15, 2024
This PR enables HSDP.

ghstack-source-id: 40a3289
Pull Request resolved: #518
@weifengpy (Contributor) commented:

mark

tianyu-l pushed a commit that referenced this pull request Aug 16, 2024
This PR enables HSDP.

ghstack-source-id: 40a3289
Pull Request resolved: #518
fegin added a commit that referenced this pull request Aug 16, 2024
This PR enables HSDP.

ghstack-source-id: 410e235
Pull Request resolved: #518
fegin added a commit that referenced this pull request Aug 16, 2024
This PR enables HSDP.

ghstack-source-id: 1a203e7
Pull Request resolved: #518
fegin added a commit that referenced this pull request Sep 3, 2024
This PR enables HSDP.

ghstack-source-id: df7167d
Pull Request resolved: #518
fegin added a commit that referenced this pull request Sep 3, 2024
This PR enables HSDP.

ghstack-source-id: c07ef4b
Pull Request resolved: #518
fegin added a commit that referenced this pull request Sep 3, 2024
This PR enables HSDP.

ghstack-source-id: d3887c9
Pull Request resolved: #518
self.pp,
)
assert (
dp_replicate != -1 or dp_shard != -1
A reviewer (Contributor) commented on the hunk above:
nit: we don't guard against -2?

if (name == "dp_replicate" and self.dp_shard == 1) or (
name == "dp_shard" and self.dp_replicate == 1
):
names.append("dp")
@wconstab (Contributor) commented Sep 5, 2024:
Maybe don't change this, but it is not obvious why we need to add 'dp'. What is the downside of leaving the original names? 'dp_replicate' is clearer than 'dp' if someone is looking at PG names and wondering what parallelism is used.

@fegin (Author) replied:
We need to add dp because the loss computation and the dataloader require a dp mesh, whether dp is dp_replicate + dp_shard (HSDP) or just dp_shard (FSDP). These two components only care about dp.
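
As a rough illustration of that point (a sketch only; nothing beyond the standard DeviceMesh accessors is assumed):

```python
from torch.distributed.device_mesh import DeviceMesh

def dataloader_dp_info(world_mesh: DeviceMesh) -> tuple[int, int]:
    """Return (dp_degree, dp_rank) -- all the dataloader needs, whether
    "dp" is flattened dp_replicate x dp_shard (HSDP) or a single
    dimension (FSDP / DDP)."""
    dp_mesh = world_mesh["dp"]
    return dp_mesh.size(), dp_mesh.get_local_rank()
```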

A reviewer (Contributor) commented:
It makes sense that the device mesh will have these axes:

  1. dp when DDP or FSDP is used;
  2. dp_shard and dp_replicate, as well as their flattened dp, when HSDP is used.

One corner case is self.world_size == tp * pp, where two dp entries would be added to names.

@fegin (Author) replied:
Hmm, I don't think so. If both dp_replicate and dp_shard are 1, the `if d > 1` check on line 59 won't be true, so we will never add the dp mesh.
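
For concreteness, an illustrative reconstruction of the surrounding loop -- the inner condition is quoted from the hunk above, the rest is a sketch rather than the actual ParallelDims code:

```python
def mesh_dims_and_names(dp_replicate: int, dp_shard: int, tp: int, pp: int):
    dims, names = [], []
    for d, name in zip(
        [dp_replicate, dp_shard, tp, pp],
        ["dp_replicate", "dp_shard", "tp", "pp"],
    ):
        if d > 1:  # degrees of 1 contribute no mesh dim (the "line 59" check)
            dims.append(d)
            if (name == "dp_replicate" and dp_shard == 1) or (
                name == "dp_shard" and dp_replicate == 1
            ):
                names.append("dp")  # a single surviving DP dim is just "dp"
            else:
                names.append(name)
    return dims, names

# Corner case from the thread: world_size == tp * pp -> no "dp" at all.
assert mesh_dims_and_names(1, 1, 4, 2) == ([4, 2], ["tp", "pp"])
# HSDP keeps both names (and the flattened "dp" is added separately later):
assert mesh_dims_and_names(4, 2, 1, 1) == ([4, 2], ["dp_replicate", "dp_shard"])
```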

@tianyu-l (Contributor) left a comment:
I think we should make the default value of data_parallel_replicate_degree 1 rather than -1. This would simplify the ParallelDims logic, the toml files, and test_runner.py. The reasoning is that FSDP is the default choice for most configs.

@fegin (Author) commented Sep 9, 2024:
@tianyu-l I changed the data_parallel_replicate_degree default to 1. But I don't see why the ParallelDims logic would be simplified -- are we no longer allowing data_parallel_replicate_degree to be -1? That would be a different story. It's true that the toml files and test_runner.py could be simplified, but I would prefer to be explicit in these files now that we have two data parallelism degrees, so I just left them in the files.
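
A hedged sketch of how a -1 degree can be resolved from the world size (illustrative only; the actual ParallelDims resolution logic may differ):

```python
def resolve_dp_shard(world_size: int, dp_replicate: int, dp_shard: int,
                     tp: int, pp: int) -> int:
    """If dp_shard is -1, infer it from whatever the other degrees leave over."""
    if dp_shard == -1:
        dp_shard = world_size // (dp_replicate * tp * pp)
    assert dp_replicate * dp_shard * tp * pp == world_size, (
        "parallelism degrees must multiply to the world size"
    )
    return dp_shard

# e.g. 8 GPUs, HSDP with replicate=4 and dp_shard left as -1:
assert resolve_dp_shard(8, 4, -1, 1, 1) == 2
```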

fegin added a commit that referenced this pull request Sep 9, 2024
This PR enables HSDP.

ghstack-source-id: 4b80a81
Pull Request resolved: #518
fegin added a commit that referenced this pull request Sep 10, 2024
This PR enables HSDP.

ghstack-source-id: 6255e9c
Pull Request resolved: #518
@tianyu-l (Contributor) left a comment:
Looks great to me! Please address final inline comments and make sure CI passes.

)
for d in (dp_replicate, tp, pp):
    assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
assert dp_shard == -1 or dp_replicate >= 1, " dp_shard must -1 or >=1."
A reviewer (Contributor) left a suggested change:

-assert dp_shard == -1 or dp_replicate >= 1, " dp_shard must -1 or >=1."
+assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."

@fegin (Author) replied:
good catch!

Comment on lines +73 to +74
if self.dp_replicate > 1 and self.dp_shard > 1:
mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp")
A reviewer (Contributor) asked:
May I ask when and why we need the flattened "dp" mesh? Is it just for HSDP?

@fegin (Author) replied Sep 10, 2024:
No. The dp mesh is needed for the dataloader and the loss computation; it's easier for those components to only know about DP. So I ensure there is always a DP mesh.

The reviewer replied:
oh makes sense!
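
A hedged sketch of the loss-side consumer described above; the averaging is illustrative of why a single flattened `dp` group is convenient, not necessarily the exact reduction torchtitan performs:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh

def dp_average_loss(loss: torch.Tensor, world_mesh: DeviceMesh) -> torch.Tensor:
    """Average a per-rank loss over the single flattened "dp" group,
    regardless of whether it came from FSDP, DDP, or HSDP's two dims."""
    dp_group = world_mesh["dp"].get_group()
    avg = loss.detach().clone()
    dist.all_reduce(avg, op=dist.ReduceOp.AVG, group=dp_group)
    return avg
```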

fegin added a commit that referenced this pull request Sep 10, 2024
This PR enables HSDP.

ghstack-source-id: c85046a
Pull Request resolved: #518
@XilunWu (Contributor) left a comment:
some questions and potential nits


@fegin merged commit c77b0c2 into gh/fegin/6/base on Sep 10, 2024
fegin added a commit that referenced this pull request Sep 10, 2024
This PR enables HSDP.

ghstack-source-id: c85046a
Pull Request resolved: #518
@fegin deleted the gh/fegin/6/head branch on September 10, 2024 22:43
@awgu mentioned this pull request on Sep 23, 2024