Hacks to get PP+TP working #125250
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125250
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure
As of commit b34edc8 with merge base 99059af: NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: e401d5d85b1ce24a3ce56969c331c0e0a5047358 Pull Request resolved: #125250
cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l yf225 chauhang d4l3k [ghstack-poisoned]
ghstack-source-id: 3ee05685de14cb662fa534a88f1079bb7f62e2ed Pull Request resolved: #125250
continue
submod_name, submod_type = list(meta['nn_module_stack'].values())[-1]
module, classname = submod_type.rsplit('.', 1)
submod_class = getattr(importlib.import_module(module), classname)
nit: if we do the import for every node, how much overhead would it add?
How about:
if _check_tp_module_type(module, nn.Linear) or \
_check_tp_module_classname(module, "torch.nn.Linear"):
Oh, the import was already cleaned up IIRC; I may have forgotten to push.
I think it's easier to have just one function call for now. Inside it I already return isinstance() or newcodepath(), so the time spent on the new code path is avoided in non-PP cases.
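For readers following the thread, here is a minimal sketch of the pattern being discussed: check `isinstance()` first and only fall back to resolving a recorded class name when that fails, with the import cached so it runs once per distinct name rather than once per node. The helper names `resolve_class` and `matches_type` are hypothetical and are not the functions in this PR; the standard-library `OrderedDict` stands in for `nn.Linear` so the snippet runs without torch.

```python
import importlib
from functools import lru_cache
from typing import Optional


@lru_cache(maxsize=None)
def resolve_class(qualified_name: str) -> type:
    """Import the class named by a dotted path; cached per distinct name."""
    module_name, class_name = qualified_name.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)


def matches_type(obj: object, expected: type,
                 recorded_type: Optional[str] = None) -> bool:
    """Hypothetical helper: isinstance() first, recorded class name second."""
    if isinstance(obj, expected):
        return True  # fast path: the fallback below is never evaluated
    if recorded_type is None:
        return False
    # Fallback: compare the class recorded as a string (assumed to be a
    # fully qualified name, as in the rsplit('.', 1) call in the diff).
    return issubclass(resolve_class(recorded_type), expected)
```

Because the fast path returns early, callers whose modules pass `isinstance()` never pay for the import, and repeated fallbacks for the same name hit the cache.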
f"Found TP device_mesh has a parent mesh with dims {parent_mesh.ndim}",
"Currently we only support 2D TP composition with DP.",
)
# if parent_mesh.ndim != 2:
I think it's fine to remove this constraint. It existed mainly because only 2-D checkpoint save/load was enabled in DCP at that time; for 3D, we can test and verify separately that save/load works.
cc @wz337
f"Found TP device_mesh on the {tp_mesh_dim} dimension of its parent mesh.",
"Currently we only support intranode TP and TP needs to be the innermost dimension on its parent mesh.",
)
# if tp_mesh_dim != 1:
Can we preserve this as something like `if tp_mesh_dim != parent_mesh.ndim - 1`?
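A rough sketch of what that generalized check could look like, assuming the only requirement is that TP sits on the last (innermost) dimension of a parent mesh of any rank. The function name `check_tp_is_innermost` and the plain-integer arguments are illustrative, not the code in this PR; the error text is copied from the diff above.

```python
def check_tp_is_innermost(tp_mesh_dim: int, parent_ndim: int) -> None:
    """Raise if the TP dimension is not the last dimension of the parent mesh."""
    if not 0 <= tp_mesh_dim < parent_ndim:
        raise ValueError(
            f"tp_mesh_dim={tp_mesh_dim} is out of range for a "
            f"{parent_ndim}-D parent mesh"
        )
    # Replaces the hard-coded `tp_mesh_dim != 1`, which only fits 2-D meshes.
    if tp_mesh_dim != parent_ndim - 1:
        raise RuntimeError(
            f"Found TP device_mesh on the {tp_mesh_dim} dimension of its parent mesh.",
            "Currently we only support intranode TP and TP needs to be the "
            "innermost dimension on its parent mesh.",
        )
```

With this form, a 2-D (DP, TP) mesh still passes with `tp_mesh_dim == 1`, and a 3-D mesh passes when TP is on dimension 2.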
oops almost missed this part
Currently a 3D mesh with a submesh sliced out for TP is going to fail this check. According to wanchaol in [this comment](#125250 (comment)) it should be OK to remove these checks. Though I would appreciate a more careful review here, since I'm not too sure if there are other edge cases where these checks are important. ghstack-source-id: fd6cfe98098d8c6e41110bf64d79c1f6c1894458 Pull Request resolved: #125763
ghstack-source-id: 0e9f93a75891e8c63c61ac5c4995430945e22b61 Pull Request resolved: #125250
ghstack-source-id: b3d8be5c649576f666874b1a378564550b6ef736 Pull Request resolved: #125250
LGTM. Thanks for solving the composability issue.
Currently a 3D mesh with a submesh sliced out for TP is going to fail this check. According to @wanchaol in [this comment](#125250 (comment)) it should be OK to remove these checks. Though I would appreciate a more careful review here, since I'm not too sure if there are other edge cases where these checks are important. Pull Request resolved: #125763 Approved by: https://github.com/wz337, https://github.com/wanchaol
Stack from ghstack (oldest at bottom):
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @yf225 @chauhang @d4l3k