[FSDP2] Showed 2D MLP with colwise + colwise sharding #126073

Closed
wants to merge 1 commit

Conversation

@awgu (Contributor) commented May 13, 2024

Stack from ghstack (oldest at bottom):

What is the point of this?

  • FSDP2 currently only supports rowwise sharding (i.e. dim-0 sharding). This means that when combined with tensor parallelism, some linear weights end up rowwise sharded twice: one matrix dim is sharded by both TP and FSDP while the other matrix dim is not sharded at all.
  • For possible performance reasons (e.g. quantization block sizes), we may prefer those linear weights to instead have FSDP do colwise sharding (i.e. dim-1 sharding). However, implementing that in FSDP2 introduces some complexity. It may be simpler to transpose the weight and change TP to shard colwise. (More verbosely: instead of using TP rowwise sharding and FSDP colwise sharding, we transpose the weight, use TP colwise sharding, and FSDP rowwise sharding; see the sketch below.)
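A minimal sketch of the two layouts, using plain DTensor placements on a hypothetical 2D ("dp", "tp") device mesh. The shapes, mesh size, variable names, and import paths below are illustrative assumptions (not from this PR) and assume an already-initialized 8-rank job:

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Shard

# Hypothetical 2D mesh: mesh dim "dp" for FSDP, mesh dim "tp" for TP.
mesh_2d = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))

w = torch.randn(1024, 512)  # an nn.Linear weight stored as [out, in]

# Today: FSDP shards dim 0, and for colwise-parallel weights TP also
# shards dim 0, so dim 0 is cut twice while dim 1 is never sharded.
w_both_dim0 = distribute_tensor(w, mesh_2d, [Shard(0), Shard(0)])

# The alternative described above: transpose the stored weight so that
# FSDP keeps sharding dim 0 while TP shards dim 1; each matrix dim is
# then sharded exactly once.
w_t = distribute_tensor(w.t().contiguous(), mesh_2d, [Shard(0), Shard(1)])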

cc @XilunWu @H-Huang @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @penguinwu @tianyu-l @yf225 @chauhang

pytorch-bot bot commented May 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126073

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 62ef78a with merge base afda668:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ci-td-distributed, oncall: distributed, and topic: not user facing labels May 13, 2024
awgu added a commit that referenced this pull request May 13, 2024
ghstack-source-id: 7ee04d09f5b4c299ea295680922813e57458b332
Pull Request resolved: #126073
@awgu (Contributor, Author) commented May 13, 2024

cc: @wanchaol @yifuwang curious to get your thoughts here

In my naive understanding, there should not be any noticeable performance implication.

from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

model = parallelize_module(
    model,
    tp_mesh,
    {"w1.weight": ColwiseParallel(), "w2.weight": ColwiseParallel()},
)
Contributor commented:
For ColwiseParallel, IIRC the actual tensor sharding dim would be 0 (this is because the weight is stored transposed: colwise parallel is supposed to shard on dim 1, but because of the transpose the weight ends up sharded on dim 0).

Also, the desired input layout of ColwiseParallel is replicated, while the desired input layout of rowwise parallel is Shard(-1), so there might be some runtime difference.

If we want to achieve what you described (TP shards on dim 1 while FSDP2 shards on dim 0), I think we can start directly from the DTensor APIs (i.e. copy over Colwise/RowwiseParallel and make adjustments). If that works out, we can think about how to expose this option in the default Colwise/RowwiseParallel.
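A rough sketch of what "start directly from the DTensor APIs" could mean for a single linear weight. This is purely illustrative: the helper name and the manual parameter swap are assumptions, and the real Colwise/RowwiseParallel styles also handle input/output layouts and module forwarding, which this omits:

import torch.nn as nn
from torch.distributed._tensor import distribute_tensor, Shard

def shard_linear_weight_on_dim1(linear: nn.Linear, tp_mesh) -> None:
    # Hypothetical helper: shard the stored [out, in] weight on dim 1 over
    # the TP mesh, instead of the dim-0 sharding that ColwiseParallel does.
    dtensor_weight = distribute_tensor(linear.weight.detach(), tp_mesh, [Shard(1)])
    linear.weight = nn.Parameter(dtensor_weight)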

Contributor commented:
> meaning that one matrix dim is sharded both by TP and FSDP while the other matrix dim is not sharded.

Would it make sense to just support this? Curious if this feels natural to folks (for me it kinda is). If so, how difficult would this be (if possible)?

@awgu (Contributor, Author) commented:
@yifuwang That is the current way 2D works for FSDP2 (so it is supported, just missing the strided sharding placement that is being worked on, which only affects resharding).


Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jul 13, 2024
@awgu awgu closed this Jul 15, 2024
@github-actions github-actions bot deleted the gh/awgu/583/head branch August 15, 2024 01:53