[TP] Add wildcard support #122968
Conversation
I wonder if a jq- or XPath-style spec would improve the UX, given the hierarchical nature? It comes with higher complexity, though.
LGTM
LGTM! Could we add some tests to demonstrate the usage as well?
Can you fold the tests PR into this PR? I think every new-feature PR should come with tests in the PR itself, not in a separate PR.
Adding wildcard support for TP's `parallelize_module` API.

Example patterns:
- `layers.*.linear`: `*` matches any characters
- `layers.?.linear`: `?` matches a single character
- `layers.[1-2]`: digit range; matches `layers.1` and `layers.2`

Example use case: a model has multiple layers, and we want to parallelize the linear module `lin` inside each layer.

```
model_tp = parallelize_module(
    model,
    device_mesh,
    {
        "layers.*.lin": ColwiseParallel(),
    },
)
```

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang
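For illustration, a sketch combining the other two pattern forms. This is not code from the PR: the `net1`/`net2` submodule names are borrowed from the stacked-MLP test discussed later in this thread, and `model` and `device_mesh` are assumed to exist.

```
# Sketch only: assumes `model` has submodules layers.0 ... layers.N,
# each containing `net1` and `net2` linear layers, and that
# `device_mesh` was created beforehand (e.g. via init_device_mesh).
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

model_tp = parallelize_module(
    model,
    device_mesh,
    {
        "layers.?.net1": ColwiseParallel(),      # `?`: one character, layers.0 through layers.9
        "layers.[1-2].net2": RowwiseParallel(),  # `[1-2]`: only layers.1 and layers.2
    },
)
```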
@wanchaol Done!
looks great, thanks for addressing comments!
@@ -78,6 +78,17 @@ def reset_parameters(self):
        self.net2.reset_parameters()


class MLPStacked(nn.Module):
    def __init__(self, device):
nit: I think it would be nice to have a `num_layers` arg (can have a default) to control how many MLP layers this stacked MLP constructs.
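A minimal sketch of that suggestion, assuming the file's existing `MLPModule` helper is the per-layer block (the helper name and default are assumptions, not code from the PR):

```
import torch.nn as nn

class MLPStacked(nn.Module):
    # Sketch of the reviewer's suggestion: `num_layers` controls how
    # many MLP blocks the stacked model contains.
    def __init__(self, device, num_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleList(
            [MLPModule(device) for _ in range(num_layers)]
        )

    def forward(self, x):
        # Run the blocks sequentially, as a plain stacked MLP would.
        for layer in self.layers:
            x = layer(x)
        return x
```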
    model_tp,
    device_mesh,
    {
        "layers.*.net?": ColwiseParallel(output_layouts=Replicate()),
Wondering: can we do an e2e col + row test here?
```
{
    "layers.*.net[1]": ColwiseParallel(),
    "layers.*.net[2]": RowwiseParallel(),
}
```
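A hedged sketch of how that end-to-end check could look inside the existing test method; the input shape and the reference-vs-TP comparison are assumptions, not code from the PR:

```
# Sketch only: apply the combined col + row plan (the classic pairing
# that keeps the intermediate activation sharded between the two
# linears), then check numerical parity against the reference model.
model_tp = parallelize_module(
    model_tp,
    device_mesh,
    {
        "layers.*.net[1]": ColwiseParallel(),
        "layers.*.net[2]": RowwiseParallel(),
    },
)
inp = torch.rand(4, 10, device=self.device_type)  # shape is an assumption
torch.testing.assert_close(model(inp), model_tp(inp))
```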
@pytorchbot merge
Merge failed. Reason: This PR needs a `release notes:` label. If not, please add the `topic: not user facing` label. To add a label, you can comment to pytorchbot, for example `@pytorchbot label "topic: not user facing"`. For more information, see the bot commands wiki. (Details for Dev Infra team: raised by workflow job.)
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: Command returned a non-zero exit code. (Details for Dev Infra team: raised by workflow job.)
Improve tests per @wanchaol 's suggestions in #122968. Pull Request resolved: #123199. Approved by: https://github.com/wanchaol
Pull Request resolved: pytorch#122968. Approved by: https://github.com/XilunWu, https://github.com/wz337, https://github.com/wanchaol. ghstack dependencies: pytorch#122919.
This is already merged: 5027ef7