Conversation

@kwen2501 (Contributor) commented Mar 29, 2024

Adding wildcard support for TP's `parallelize_module` API.

Example patterns:
`layers.*.linear`: any characters
`layers.?.linear`: single character
`layers.[1-2]`: digit range, matches `layers.1` and `layers.2`

Example use case:
A model has multiple layers, and we want to parallelize the linear module `lin` inside each layer.

```
model_tp = parallelize_module(
    model,
    device_mesh,
    {
        "layers.*.lin": ColwiseParallel(),
    },
)
```
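For intuition, here is a minimal sketch of how such matching could behave, assuming shell-style globbing is applied to each dot-separated segment of a module's fully qualified name via Python's `fnmatch`; the helper `match_fqn` and the segment-wise semantics are illustrative assumptions, not necessarily the actual implementation:

```
from fnmatch import fnmatchcase

def match_fqn(pattern: str, fqn: str) -> bool:
    # Hypothetical helper: match each dot-separated segment of a module
    # FQN against the corresponding segment of the pattern.
    p_atoms, f_atoms = pattern.split("."), fqn.split(".")
    if len(p_atoms) != len(f_atoms):
        return False
    return all(fnmatchcase(f, p) for p, f in zip(p_atoms, f_atoms))

assert match_fqn("layers.*.lin", "layers.0.lin")
assert match_fqn("layers.?.lin", "layers.3.lin")
assert not match_fqn("layers.?.lin", "layers.10.lin")  # '?' is one character
assert match_fqn("layers.[1-2]", "layers.2")
```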

Stack from ghstack (oldest at bottom):

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang

[ghstack-poisoned]
@pytorch-bot bot commented Mar 29, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/122968

Note: Links to docs will display an error until the docs builds have been completed.

❌ 40 New Failures, 3 Unrelated Failures

As of commit 40258ef with merge base 4dc09d6:

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added the oncall: distributed label Mar 29, 2024
kwen2501 added a commit that referenced this pull request Mar 29, 2024
ghstack-source-id: a8fac4d
Pull Request resolved: #122968
@kwen2501 kwen2501 requested review from wanchaol and wz337 March 29, 2024 18:54
@kurman (Contributor) commented Apr 1, 2024

I wonder if a jq- or XPath-style spec would improve the UX, given the hierarchical nature? It comes with higher complexity, though.

@XilunWu (Contributor) commented Apr 1, 2024

Users need to be aware that the patterns in the dict param must be mutually exclusive; otherwise, applying parallelize_module repeatedly to the same submodule may cause issues. cc @wanchaol @fduwjj
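For example (a hypothetical plan; the module names are illustrative), both keys below match `layers.1.lin`, so that submodule would be hit by two parallel styles:

```
# Overlapping patterns: "layers.*.lin" and "layers.1.lin" both match
# the submodule "layers.1.lin".
plan = {
    "layers.*.lin": ColwiseParallel(),
    "layers.1.lin": RowwiseParallel(),
}
```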

@XilunWu (Contributor) left a comment

LGTM

@wz337 (Contributor) left a comment

LGTM! Could we add some tests to demonstrate the usage as well?

@kwen2501 (Contributor, Author) commented Apr 2, 2024

@wz337 Thanks for the review! The tests are in a stacked PR: #123101

@wanchaol (Collaborator) left a comment

Can you fold the tests PR into this PR? I think every new-feature PR should come with tests in the PR itself, not in a separate PR.

Adding wildcard support for TP's `parallelize_module` API.

Example patterns:
`layers.*.linear`: any characters
`layers.?.linear`: single character
`layers.[1-2]`: digit range, matches `layers.1` and `layers.2`

Example use case:
A model has multiple layers, and we want to parallelize the linear module `lin` inside each layer.
```
model_tp = parallelize_module(
    model,
    device_mesh,
    {
        "layers.*.lin": ColwiseParallel(),
    },
)
```




cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang

[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request Apr 2, 2024
[TP] Add tests for wildcard support

ghstack-source-id: abb8e51
Pull Request resolved: #122968
@kwen2501 (Contributor, Author) commented Apr 2, 2024

@wanchaol Done!

@wanchaol (Collaborator) left a comment

looks great, thanks for addressing comments!

@@ -78,6 +78,17 @@ def reset_parameters(self):
        self.net2.reset_parameters()


class MLPStacked(nn.Module):
    def __init__(self, device):
@wanchaol (Collaborator) commented on this diff:
nit: I think it would be nice to have a num_layers arg (it can have a default) to control how many MLP layers this stacked MLP constructs
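A rough sketch of that suggestion (assuming the test file's MLP block is named MLPModule and takes a device argument; both names are guesses from context, not verified against the file):

```
class MLPStacked(nn.Module):
    def __init__(self, device, num_layers: int = 2):
        super().__init__()
        # Stack `num_layers` MLP blocks under "layers" so wildcard plans
        # like "layers.*.net?" can address every block uniformly.
        self.layers = nn.ModuleList(
            [MLPModule(device) for _ in range(num_layers)]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
```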

    model_tp,
    device_mesh,
    {
        "layers.*.net?": ColwiseParallel(output_layouts=Replicate()),
@wanchaol (Collaborator) commented on this diff:
wondering, can we do an e2e col + row test here?

```
{
    "layers.*.net[1]": ColwiseParallel(),
    "layers.*.net[2]": RowwiseParallel()
}
```
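For context, Colwise followed by Rowwise is the standard two-linear TP pairing: the column-sharded output of the first linear feeds the row-sharded input of the second, so the pair needs only a single reduction at the end. Spelled out as a full call, a sketch reusing the names from the test above:

```
model_tp = parallelize_module(
    model_tp,
    device_mesh,
    {
        "layers.*.net[1]": ColwiseParallel(),
        "layers.*.net[2]": RowwiseParallel(),
    },
)
```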

@kwen2501 (Contributor, Author) commented Apr 2, 2024

@pytorchbot merge

@pytorch-bot bot added the ciflow/trunk label Apr 2, 2024
@pytorchmergebot (Collaborator) commented:

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.


@kwen2501 (Contributor, Author) commented Apr 2, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@kwen2501 (Contributor, Author) commented Apr 3, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator) commented:

Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x c3682288af1f7c66a3f685fcccba192299c38c3f returned non-zero exit code 1

```
The previous cherry-pick is now empty, possibly due to conflict resolution.
If you wish to commit it anyway, use:

    git commit --allow-empty

Otherwise, please use 'git cherry-pick --skip'
On branch main
Your branch is up to date with 'origin/main'.

You are currently cherry-picking commit c3682288af1.
  (all conflicts fixed: run "git cherry-pick --continue")
  (use "git cherry-pick --skip" to skip this patch)
  (use "git cherry-pick --abort" to cancel the cherry-pick operation)

nothing to commit, working tree clean
```

pytorchmergebot pushed a commit that referenced this pull request Apr 3, 2024
Improve tests per @wanchaol's suggestions in #122968

Pull Request resolved: #123199
Approved by: https://github.com/wanchaol
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
Adding wildcard support for TP's `parallelize_module` API.

Example patterns:
`layers.*.linear`: any characters
`layers.?.linear`: single character
`layers.[1-2]`: digit range, matches `layers.1` and `layers.2`

Example use case:
A model has multiple layers, and we want to parallelize the linear module `lin` inside each layer.
```
model_tp = parallelize_module(
    model,
    device_mesh,
    {
        "layers.*.lin": ColwiseParallel(),
    },
)
```

Pull Request resolved: pytorch#122968
Approved by: https://github.com/XilunWu, https://github.com/wz337, https://github.com/wanchaol
ghstack dependencies: pytorch#122919
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
Improve tests per @wanchaol's suggestions in pytorch#122968

Pull Request resolved: pytorch#123199
Approved by: https://github.com/wanchaol
@atalman atalman closed this Apr 22, 2024
@atalman atalman reopened this Apr 22, 2024
@atalman (Contributor) commented Apr 22, 2024

This is already merged in 5027ef7, hence closing the PR.
