[DTensor] Turn on foreach implementation of optimizer for DTensor by default #123394
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/123394
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 unrelated failures) As of commit f907706 with merge base 1a28f73.
FLAKY - The following jobs failed but were likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -23,6 +25,11 @@
]

# Append DTensor to the list of supported types for foreach implementation of optimizer
# so that we will try to use foreach over the for-loop implementation on CUDA.
_foreach_supported_types.append(DTensor)
Python question: are we guaranteed that the __init__.py code will only ever run once? Do we need a check like:

if DTensor not in _foreach_supported_types:
    _foreach_supported_types.append(DTensor)
Good point. I will just add it as a safety check.
It should only be imported once from the Python importing perspective, but a guard is safer, yes.
Or change `_foreach_supported_types` to a set. :)
cc: @janeyx99
Feel free to make it a set if it's easier haha
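For illustration, a minimal sketch of the guarded registration discussed above; the import path for `_foreach_supported_types` and the surrounding module are assumptions, not necessarily the exact code in this PR:

```python
# Hedged sketch of the guarded registration; the import location of
# _foreach_supported_types is an assumption, not confirmed by the PR diff.
from torch.distributed._tensor import DTensor
from torch.utils._foreach_utils import _foreach_supported_types

# Guard so that re-running this module (e.g. via importlib.reload or an
# unusual import path) cannot append DTensor twice.
if DTensor not in _foreach_supported_types:
    _foreach_supported_types.append(DTensor)
```

If the list were changed to a set, the membership check would come for free, at the cost of losing ordering.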
Force-pushed from 51774c9 to 9a107df.
@pytorchmergebot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased.
Force-pushed from 9a107df to 76bbf9c.
@with_comms
def test_adam_1d_sharding(self):
    mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

    # TODO: add fused_adam support
    adam_configs = [
        {"lr": 0.1},
        {"lr": 0.1, "foreach": False},
Wondering why we turn foreach to False for all these tests? IIUC, even if we put DTensor into _foreach_supported_types, passing foreach=False manually to the optimizer would still disable the foreach path.
I am just adjusting each config based on whether it originally had "foreach": True. If a config had "foreach": True, I remove it, since foreach is now on by default. For configs that didn't have "foreach": True, I change them to "foreach": False. So we still have tests covering both the foreach and the for-loop implementations, as sketched below.
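A small before/after sketch of the kind of config list this produces (illustrative values only, not the exact test file):

```python
# Before this PR (illustrative): foreach had to be requested explicitly.
adam_configs_before = [
    {"lr": 0.1},                   # for-loop implementation (old default for DTensor)
    {"lr": 0.1, "foreach": True},  # foreach implementation, opt-in
]

# After this PR: foreach is the default for DTensor params, so the explicit
# True entry is dropped and an explicit False entry keeps the for-loop
# implementation covered.
adam_configs_after = [
    {"lr": 0.1},                    # foreach implementation (new default)
    {"lr": 0.1, "foreach": False},  # for-loop implementation
]
```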
Ohh I see, makes sense.
LGTM. For grad norm clipping, I guess it would happen in follow-up PRs?
Force-pushed from 76bbf9c to 0dfb5c0.
@pytorchmergebot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased.
Force-pushed from 0dfb5c0 to 199d77a.
@pytorchmergebot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased.
Force-pushed from 199d77a to 240236d.
@@ -302,6 +302,7 @@ def unwrap_to_op_info(
        args_schema.append(arg._spec)
        local_args.append(arg._local_tensor)
        if mesh is not None:
            print(f"{mesh=}, {arg.device_mesh=}")
I tried repro'ing the test_dtensor_compile failure locally. It's actually coming from here: the test passes when I remove the print statement.
I could mostly tell by looking at the stack trace from the FakeTensor erroring, and seeing that it comes from:
(1) this code printing arg.device_mesh,
(2) DeviceMesh being a tensor, so printing it calls tensor.tolist(),
(3) printing a tensor not being a very trace-friendly operation, which is why you get a kind-of-obscure error.
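As an aside, if a debug print were ever needed on a code path that can be traced, one hedged option (not from this PR) is to skip it while compiling; `torch.compiler.is_compiling()` exists in recent PyTorch versions, but the helper below is hypothetical:

```python
import torch

# Hypothetical debug helper: only print when we are not being traced by
# torch.compile, so the FakeTensor/tracing path never hits tolist().
def _debug_print_mesh(mesh, arg):
    if not torch.compiler.is_compiling():
        print(f"{mesh=}, {arg.device_mesh=}")
```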
Force-pushed from d8643a1 to a89d0f0.
Force-pushed from a89d0f0 to f907706.
@pytorchmergebot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchmergebot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes #121799. We fix the DeviceMesh hash so that two meshes are considered equal if they have the same mesh and the same parent_mesh. Examples can be found here: #121799. Also needed to unblock #123394. Pull Request resolved: #123572 Approved by: https://github.com/xunnanxu, https://github.com/wanchaol, https://github.com/yoyoyocmu
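For intuition, a rough sketch of the equality semantics described in that fix; this is not the actual DeviceMesh implementation, just an assumed illustration of "equal iff same mesh and same parent mesh":

```python
import torch

class _MeshLike:
    """Toy stand-in for DeviceMesh to illustrate the described __eq__/__hash__."""

    def __init__(self, mesh: torch.Tensor, parent=None):
        self.mesh = mesh          # tensor of ranks
        self._parent = parent     # parent mesh object, or None

    def __eq__(self, other):
        if not isinstance(other, _MeshLike):
            return NotImplemented
        # Equal only when both the mesh contents and the parent mesh match.
        return torch.equal(self.mesh, other.mesh) and self._parent is other._parent

    def __hash__(self):
        # Hash derived from the same fields used by __eq__.
        return hash((tuple(self.mesh.flatten().tolist()), id(self._parent)))
```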
[DTensor] Turn on foreach implementation of optimizer for DTensor by default (pytorch#123394). Append DTensor to the optimizer `_foreach_supported_types` and turn on the foreach implementation of the optimizer for DTensor if not specified by the user. Pull Request resolved: pytorch#123394 Approved by: https://github.com/wanchaol
Append DTensor to the optimizer `_foreach_supported_types` and turn on the foreach implementation of the optimizer for DTensor if not specified by the user.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @tianyu-l @wconstab @yf225 @chauhang @d4l3k @msaroufim @rohan-varma
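To illustrate the user-facing effect, a minimal sketch; it assumes a distributed run already launched (e.g. via torchrun) with a CUDA backend, and the shapes/placements are illustrative only:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Shard

# Assumes torch.distributed is already initialized (e.g. launched with torchrun).
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# Shard a parameter across the mesh; each rank holds a slice along dim 0.
param = torch.nn.Parameter(distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)]))

# With this change, Adam uses the foreach implementation for DTensor
# parameters by default; pass foreach=False to force the for-loop path.
opt = torch.optim.Adam([param], lr=0.1)
opt_forloop = torch.optim.Adam([param], lr=0.1, foreach=False)
```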