Conversation

@wanchaol (Collaborator) commented Jan 23, 2024

Stack from ghstack (oldest at bottom):

This PR adds support for rowwise sharded embedding by adding a
MaskPartial placement that inherits from the default Partial placement
and overrides the Partial contracts to construct the mask and release
it after the reduction.

The MaskPartial placement has the potential to support other ops whose
sharding computation requires a mask for semantic correctness. For now
it lives in the embedding ops, but we can move it to a common place if
needed.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @fduwjj @wz337 @tianyu-l @wconstab @yf225
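
To make the mechanism concrete, here is a minimal, self-contained sketch of the pattern described above: a partial placement that builds a mask when the input indices are partitioned and applies/releases it when the partial results are reduced. The class layout and hook names below are illustrative stand-ins, not the actual DTensor Partial API from this PR, and the even-split assumption is only for brevity.

```python
# Illustrative sketch only: the base class and hook names are hypothetical
# stand-ins for the Partial placement contracts mentioned in the description.
from dataclasses import dataclass
from typing import Optional

import torch
import torch.distributed as dist


@dataclass
class PartialSketch:
    """Stand-in for the default partial placement (sum-reduced)."""
    reduce_op: str = "sum"

    def partition_value(self, tensor: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
        return tensor

    def reduce_value(self, tensor: torch.Tensor) -> torch.Tensor:
        if dist.is_initialized():  # requires an initialized process group
            dist.all_reduce(tensor)
        return tensor


@dataclass
class MaskPartialSketch(PartialSketch):
    """Partial placement that masks out indices not owned by the local shard."""
    logical_dim_size: int = 0          # total rows of the unsharded embedding table
    mask: Optional[torch.Tensor] = None

    def partition_value(self, indices: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
        # each rank owns a contiguous block of rows (assumes even divisibility)
        local_rows = self.logical_dim_size // world_size
        start = rank * local_rows
        # remember which indices fall outside the local shard ...
        self.mask = (indices < start) | (indices >= start + local_rows)
        # ... and remap them to a valid local row so the lookup does not fail
        local_indices = indices - start
        local_indices[self.mask] = 0
        return local_indices

    def reduce_value(self, partial_output: torch.Tensor) -> torch.Tensor:
        # zero the rows produced from masked (non-local) indices, sum-reduce
        # across ranks, then release the mask
        assert self.mask is not None, "partition_value must run before reduce_value"
        partial_output[self.mask] = 0.0
        self.mask = None
        if dist.is_initialized():
            dist.all_reduce(partial_output)
        return partial_output
```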

pytorch-bot bot commented Jan 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/118080

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3748220 with merge base d59c2d6:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

github-actions bot added the oncall: distributed and ciflow/inductor labels Jan 23, 2024
…skPartial"

This PR add support for rowwise sharded embedding by adding a
MaskPartial placement that inherits from the default partial placement,
and override the Partial constracts to construct the mask and release
the mask after the reduction

The MaskPartial placement have the potential to support other ops
sharding computation that requires a mask for semantic correctness.
currently make it live in the embedding ops but we can move it to a
common place if needed

cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu fduwjj wz337 tianyu-l wconstab yf225

[ghstack-poisoned]
…skPartial"

This PR add support for rowwise sharded embedding by adding a
MaskPartial placement that inherits from the default partial placement,
and override the Partial constracts to construct the mask and release
the mask after the reduction

The MaskPartial placement have the potential to support other ops
sharding computation that requires a mask for semantic correctness.
currently make it live in the embedding ops but we can move it to a
common place if needed

[ghstack-poisoned]
wanchaol added the ciflow/trunk and release notes: distributed (dtensor) labels Jan 23, 2024
…skPartial"

This PR add support for rowwise sharded embedding by adding a
MaskPartial placement that inherits from the default partial placement,
and override the Partial constracts to construct the mask and release
the mask after the reduction

The MaskPartial placement have the potential to support other ops
sharding computation that requires a mask for semantic correctness.
currently make it live in the embedding ops but we can move it to a
common place if needed

cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu fduwjj wz337 tianyu-l wconstab yf225

[ghstack-poisoned]
…skPartial"

This PR add support for rowwise sharded embedding by adding a
MaskPartial placement that inherits from the default partial placement,
and override the Partial constracts to construct the mask and release
the mask after the reduction

The MaskPartial placement have the potential to support other ops
sharding computation that requires a mask for semantic correctness.
currently make it live in the embedding ops but we can move it to a
common place if needed

cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu fduwjj wz337 tianyu-l wconstab yf225

[ghstack-poisoned]
wanchaol added a commit that referenced this pull request Jan 24, 2024
ghstack-source-id: df0a074
Pull Request resolved: #118080
@tianyu-l (Contributor) left a comment

LGTM!
This is an elegant way of implementing row-wise embedding in DTensor. The creative use of a buffer variable in _MaskPartial slightly violates the design principle of keeping Placement subclasses frozen (e.g. for caching). Nevertheless, this should be justified by the benefits it brings.

Comment on lines +133 to +134
if self.mask_buffer.data is not None or other.mask_buffer.data is not None:
return False

Some remarks on what we discussed offline:

  1. For a _MaskPartial produced in the output of sharding propagation (either output_spec or schema_suggestions in OutputSharding), the cached self.mask_buffer.data could be filled and not released (by reductions) yet still be returned on a cache hit. An extreme example is two parallel row-wise embeddings applied to the same (replicated) input. Such cases should be rare, and if they happen, MaskBuffer.materialize_mask would just throw an exception, which is OK (see the sketch after this list).
  2. self.mask_buffer.data is almost always not None as input to sharding propagation (because otherwise the _MaskPartial would probably have been reduced to Replicate or Shard), so this effectively forbids cache hits when _MaskPartial is an input, as noted in the follow-up PR [dtensor] add comment to clarify MaskPartial cache hit #118330.
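
A short sketch of the failure mode mentioned in remark 1, assuming a MaskBuffer-style holder as described in the review; the method names mirror the discussion, but the bodies are illustrative, not the source from this PR.

```python
# Illustrative sketch of the mask-buffer behavior discussed in remark 1;
# not the actual code from this PR.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class MaskBufferSketch:
    data: Optional[torch.Tensor] = None

    def materialize_mask(self, mask: torch.Tensor) -> None:
        # refusing to overwrite an existing mask makes the rare cache-hit
        # scenario above fail loudly rather than silently reuse a stale mask
        if self.data is not None:
            raise RuntimeError("mask buffer is already materialized")
        self.data = mask

    def release_mask(self) -> None:
        # normally called after the reduction, so an input _MaskPartial whose
        # buffer still holds data is a sign the placement was never reduced
        if self.data is None:
            raise RuntimeError("mask buffer is not materialized")
        self.data = None
```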

pytorchmergebot pushed a commit that referenced this pull request Jan 26, 2024
As titled, this PR enables rowwise embedding sharding in the
RowwiseParallel style and adds tests to ensure it works as expected.

Pull Request resolved: #118242
Approved by: https://github.com/tianyu-l
ghstack dependencies: #118079, #118080
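
For context on the #118242 commit above, here is a hedged usage sketch of how a row-wise sharded embedding is typically set up through the tensor-parallel API. The module paths and arguments (init_device_mesh, parallelize_module, RowwiseParallel, Replicate) follow the public API as generally documented, but the defaults and launch details are assumptions; treat this as illustrative and check the docs. Run under torchrun with one process per device.

```python
# Hedged usage sketch for rowwise-sharded embedding via the tensor-parallel
# API (related to #118242); launch details and defaults are assumptions.
import os

import torch
import torch.nn as nn
from torch.distributed._tensor import Replicate
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import RowwiseParallel, parallelize_module

# one-dimensional mesh over the torchrun world
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
mesh = init_device_mesh("cuda", (world_size,))

model = nn.Sequential(nn.Embedding(num_embeddings=50_000, embedding_dim=256)).cuda()

# shard the embedding table along its row (vocabulary) dimension; the token
# indices are replicated so every rank sees the full index tensor
model = parallelize_module(
    model,
    mesh,
    {"0": RowwiseParallel(input_layouts=Replicate())},
)

tokens = torch.randint(0, 50_000, (8, 128), device="cuda")
out = model(tokens)  # same shape as an unsharded lookup: (8, 128, 256)
```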
pytorchmergebot added a commit that referenced this pull request Jan 26, 2024
…artial (#118080)"

This reverts commit 8cc02b4.

Reverted #118080 on behalf of https://github.com/DanilBaibak due to breaking an internal build (see the comment on #118079).
@pytorchmergebot (Collaborator) commented

@wanchaol your PR has been successfully reverted.

pytorchmergebot pushed a commit that referenced this pull request Jan 26, 2024
As titled, this PR enables rowwise embedding sharding in the
RowwiseParallel style and adds tests to ensure it works as expected.

Pull Request resolved: #118242
Approved by: https://github.com/tianyu-l
ghstack dependencies: #118079, #118080
facebook-github-bot deleted the gh/wanchaol/430/head branch January 30, 2024 15:23
jeffdaily pushed a commit to ROCm/pytorch that referenced this pull request Feb 8, 2024
…ytorch#118080)

Pull Request resolved: pytorch#118080
Approved by: https://github.com/tianyu-l
ghstack dependencies: pytorch#118079
jeffdaily pushed a commit to ROCm/pytorch that referenced this pull request Feb 8, 2024
…8242)

As titled, this PR enables rowwise embedding sharding in the
RowwiseParallel style and adds tests to ensure it works as expected.

Pull Request resolved: pytorch#118242
Approved by: https://github.com/tianyu-l
ghstack dependencies: pytorch#118079, pytorch#118080

Labels

ciflow/inductor, ciflow/trunk, Merged, oncall: distributed, release notes: distributed (dtensor), Reverted
