[FSDP2] Showed 2D MLP with colwise + colwise sharding #126073
```python
model = parallelize_module(
    model,
    tp_mesh,
    {"w1.weight": ColwiseParallel(), "w2.weight": ColwiseParallel()},
)
```
For ColwiseParallel, IIRC the actual tensor sharding dim would be 0: ColwiseParallel is supposed to shard on dim 1 of the logical matrix, but because nn.Linear stores the weight transposed, the weight sharding ends up on dim 0.
Also, the desired input layout of ColwiseParallel is replicated, while the desired input layout of RowwiseParallel is Shard(-1), so there might be some runtime difference.
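To see the transpose point concretely, a tiny standalone check (not from the PR): nn.Linear stores its weight as (out_features, in_features), so the logical dim-1 ("column") sharding lands on dim 0 of the stored tensor.

```python
import torch.nn as nn

linear = nn.Linear(8, 4, bias=False)
# Stored layout is (out_features, in_features) = (4, 8): the matmul's
# "columns" (output features) live on dim 0, so ColwiseParallel's logical
# dim-1 sharding becomes Shard(0) on this tensor.
print(linear.weight.shape)  # torch.Size([4, 8])
```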
If we want to achieve what you want (TP shards on dim 1 while FSDP2 shards on dim 0), I think we can start directly from the DTensor APIs (i.e., copy over ColwiseParallel/RowwiseParallel and make adjustments). If that works out, we can think about how to expose this option in the default ColwiseParallel/RowwiseParallel.
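A minimal sketch of what starting from the DTensor APIs could look like, assuming a 2x2 mesh with "dp" and "tp" dims (in older releases distribute_tensor lives under torch.distributed._tensor). This is an illustration, not the proposed implementation:

```python
# Hedged sketch: shard a Linear weight on dim 1 for TP via raw DTensor APIs.
# Run under torchrun with 4 ranks.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
linear = nn.Linear(16, 16, bias=False)
# Shard the stored weight on dim 1 over the TP mesh dim -- the layout a
# custom ColwiseParallel variant would produce; FSDP2 would then shard
# dim 0 over the "dp" mesh dim.
w_tp = distribute_tensor(linear.weight, mesh_2d["tp"], [Shard(1)])
linear.weight = nn.Parameter(w_tp)
```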
Meaning that one matrix dim is sharded by both TP and FSDP while the other matrix dim is not sharded.
Would it make sense to just support this? Curious whether this feels natural to folks (to me it kind of does). If so, how difficult would this be (if possible)?
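For illustration, the layout being asked about can be written down directly with DTensor placements, assuming a 2x2 mesh: both mesh dims shard tensor dim 0 with plain nested Shard(0) (the strided-sharding placement mentioned in the reply below refines how these nested shards reassemble):

```python
# Hedged sketch of "one dim sharded by both TP and FSDP"; run under
# torchrun with 4 ranks.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
w = torch.randn(8, 8)
# Dim 0 is sharded by both mesh dims (FSDP over "dp", TP over "tp");
# dim 1 is left unsharded. Each of the 4 ranks holds a (2, 8) shard.
w_2d = distribute_tensor(w, mesh_2d, [Shard(0), Shard(0)])
print(w_2d.to_local().shape)  # torch.Size([2, 8])
```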
@yifuwang That is how 2D currently works with FSDP2 (so it is supported; it is just missing the strided-sharding placement that is being worked on, which only affects resharding).
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as Stale.
Stack from ghstack (oldest at bottom):

- `MLPStack` to de-dup code #126070
- `CommDebugMode` in grad acc test #126067

What is the point of this?
cc @XilunWu @H-Huang @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @penguinwu @tianyu-l @yf225 @chauhang