[dtensor] implement dim-0 (row) embedding sharding with MaskPartial #118080
Conversation
This PR adds support for row-wise sharded embedding by adding a MaskPartial placement that inherits from the default Partial placement and overrides the Partial contracts to construct the mask and release it after the reduction. The MaskPartial placement has the potential to support other ops whose sharded computation requires a mask for semantic correctness. For now it lives in the embedding ops, but we can move it to a common place if needed. [ghstack-poisoned]
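For readers new to the approach, here is a minimal sketch of the mask-then-reduce idea behind row-wise embedding sharding. It uses plain `torch.distributed` collectives rather than the actual `_MaskPartial` internals; the function name and shard-bound arguments are illustrative, not part of the PR.

```python
import torch
import torch.distributed as dist


def rowwise_embedding_lookup(indices: torch.Tensor, local_weight: torch.Tensor,
                             row_start: int) -> torch.Tensor:
    """Embedding lookup where this rank owns rows [row_start, row_start + local_rows).

    Sketch of the mask-then-reduce idea: out-of-range indices are masked so the
    local lookup stays in bounds, the masked rows are zeroed so they contribute
    nothing, and an all-reduce (sum) recombines the partial results.
    """
    local_rows = local_weight.shape[0]
    # 1. Build the mask: True for indices this rank does NOT own.
    mask = (indices < row_start) | (indices >= row_start + local_rows)
    # 2. Shift owned indices to local coordinates; masked ones point at row 0.
    local_indices = (indices - row_start).masked_fill(mask, 0)
    # 3. Local lookup, then zero out rows produced by masked indices.
    out = local_weight[local_indices]
    out = out.masked_fill(mask.unsqueeze(-1), 0.0)
    # 4. Reduce the partial results across ranks; afterwards the mask can be released.
    dist.all_reduce(out, op=dist.ReduceOp.SUM)
    return out
```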
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/118080
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 3748220 with merge base d59c2d6. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
…skPartial" This PR add support for rowwise sharded embedding by adding a MaskPartial placement that inherits from the default partial placement, and override the Partial constracts to construct the mask and release the mask after the reduction The MaskPartial placement have the potential to support other ops sharding computation that requires a mask for semantic correctness. currently make it live in the embedding ops but we can move it to a common place if needed cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
…skPartial" This PR add support for rowwise sharded embedding by adding a MaskPartial placement that inherits from the default partial placement, and override the Partial constracts to construct the mask and release the mask after the reduction The MaskPartial placement have the potential to support other ops sharding computation that requires a mask for semantic correctness. currently make it live in the embedding ops but we can move it to a common place if needed [ghstack-poisoned]
…skPartial" This PR add support for rowwise sharded embedding by adding a MaskPartial placement that inherits from the default partial placement, and override the Partial constracts to construct the mask and release the mask after the reduction The MaskPartial placement have the potential to support other ops sharding computation that requires a mask for semantic correctness. currently make it live in the embedding ops but we can move it to a common place if needed cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
…skPartial" This PR add support for rowwise sharded embedding by adding a MaskPartial placement that inherits from the default partial placement, and override the Partial constracts to construct the mask and release the mask after the reduction The MaskPartial placement have the potential to support other ops sharding computation that requires a mask for semantic correctness. currently make it live in the embedding ops but we can move it to a common place if needed cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
ghstack-source-id: df0a074
Pull Request resolved: #118080
LGTM!
This is an elegant way of implementing row-wise embedding in DTensor. The creative use of a buffer variable in `_MaskPartial` slightly violates the design principle of keeping `Placement` subclasses frozen (e.g. for caching). Nevertheless, this should be justified by the benefits it brings.
if self.mask_buffer.data is not None or other.mask_buffer.data is not None:
    return False
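To show where this check sits, here is a self-contained sketch of a partial-style placement that carries a mutable mask buffer; the class and field names are hypothetical stand-ins, not the actual DTensor definitions.

```python
from typing import Optional

import torch


class MaskBuffer:
    """Mutable holder for a transient mask; deliberately kept out of equality/hashing."""

    def __init__(self) -> None:
        self.data: Optional[torch.Tensor] = None


class MaskedPartialSketch:
    """Hypothetical partial placement that holds a transient mask buffer."""

    def __init__(self, reduce_op: str = "sum") -> None:
        self.reduce_op = reduce_op
        self.mask_buffer = MaskBuffer()

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, MaskedPartialSketch) or self.reduce_op != other.reduce_op:
            return False
        # A placement holding a live (materialized) mask is tied to a specific
        # input, so it must not compare equal to another placement; this is what
        # keeps the sharding-propagation cache from reusing it.
        if self.mask_buffer.data is not None or other.mask_buffer.data is not None:
            return False
        return True

    def __hash__(self) -> int:
        # Hash only on the immutable field so the mutable buffer never affects it.
        return hash(self.reduce_op)
```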
Some remarks for what we discussed offline:
- For a `_MaskPartial` produced in the output of sharding propagation (either `output_spec` or `schema_suggestions` in `OutputSharding`), the cached `self.mask_buffer.data` could be filled and not yet released (by a reduction) but still be returned on a cache hit. An extreme example is two parallel row-wise embeddings applied to the same (replicated) input. Such cases should be rare, and if they happen `MaskBuffer.materialize_mask` would just throw an exception, which is OK.
- `self.mask_buffer.data` is almost always `not None` when `_MaskPartial` appears as an input to sharding propagation (because otherwise the `_MaskPartial` would probably have been reduced to `Replicate` or `Shard` already), so this check effectively forbids cache hits when `_MaskPartial` is an input, as noted in the follow-up PR [dtensor] add comment to clarify MaskPartial cache hit #118330.
As titled, this PR enables rowwise embedding sharding in the RowwiseParallel style and adds tests to ensure it works as expected. Pull Request resolved: #118242 Approved by: https://github.com/tianyu-l ghstack dependencies: #118079, #118080
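For context on how this surfaces to users, a hedged usage sketch of parallelizing an `nn.Embedding` row-wise. The import paths and `RowwiseParallel` arguments reflect the tensor-parallel API around this release and may differ; treat them as assumptions rather than a definitive recipe.

```python
import torch.nn as nn
import torch.distributed as dist
from torch.distributed._tensor import Replicate
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel

# Assumes the default process group is already initialized (e.g. via torchrun).
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

model = nn.Sequential(nn.Embedding(num_embeddings=1024, embedding_dim=64))

# Shard the embedding table along dim 0 (rows): each rank keeps a slice of the
# vocabulary, and the masked-partial reduction recombines the local lookups.
# Indices are replicated across ranks, hence input_layouts=Replicate().
model = parallelize_module(
    model,
    mesh,
    {"0": RowwiseParallel(input_layouts=Replicate())},
)
```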
…artial (#118080)" This reverts commit 8cc02b4. Reverted #118080 on behalf of https://github.com/DanilBaibak due to breaking an internal build ([comment](#118079 (comment))).
@wanchaol your PR has been successfully reverted.
Stack from ghstack (oldest at bottom):
This PR adds support for row-wise sharded embedding by adding a
MaskPartial placement that inherits from the default Partial placement
and overrides the Partial contracts to construct the mask and release
the mask after the reduction.

The MaskPartial placement has the potential to support other ops whose
sharded computation requires a mask for semantic correctness. For now
it lives in the embedding ops, but we can move it to a common place if
needed.
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @fduwjj @wz337 @tianyu-l @wconstab @yf225