
[DTensor] Turn on foreach implementation of optimizer for DTensor by default #123394

Closed

Conversation

wz337 (Contributor) commented Apr 4, 2024

Append DTensor to the optimizer `_foreach_supported_types` list and turn on the foreach implementation of the optimizer for DTensor when not specified by the user.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @tianyu-l @wconstab @yf225 @chauhang @d4l3k @msaroufim @rohan-varma
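Not part of the PR itself, but a minimal sketch of the user-visible effect, assuming an initialized process group and CUDA devices; the mesh setup and tensor shapes below are illustrative only:

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh
from torch.optim.optimizer import _foreach_supported_types

# With this change, DTensor is registered as a foreach-capable type.
assert DTensor in _foreach_supported_types

# Illustrative 1-D mesh and a row-wise sharded parameter.
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
param = torch.nn.Parameter(
    distribute_tensor(torch.randn(8, 8), device_mesh=mesh, placements=[Shard(0)])
)

# foreach is left unspecified, so the optimizer can now pick the foreach path for
# DTensor parameters; passing foreach=False still forces the for-loop implementation.
opt = torch.optim.Adam([param], lr=0.1)
opt_loop = torch.optim.Adam([param], lr=0.1, foreach=False)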

pytorch-bot bot commented Apr 4, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/123394

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit f907706 with merge base 1a28f73:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor oncall: distributed Add this issue/PR to distributed oncall triage queue labels Apr 4, 2024
@wz337 wz337 added module: dtensor distributed tensor tag release notes: distributed (dtensor) release notes category ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR labels Apr 4, 2024
@wz337 wz337 changed the title [DTensor] Turn on foreach implementation of optimizer for DTensor if not specified by users [DTensor] Turn on foreach implementation of optimizer for DTensor by default Apr 4, 2024
@wz337 wz337 requested a review from wanchaol April 4, 2024 22:50
@wz337 wz337 marked this pull request as ready for review April 4, 2024 22:50
@@ -23,6 +25,11 @@
]


# Append DTensor to the list of supported types for foreach implementation of optimizer
# so that we will try to use foreach over the for-loop implementation on CUDA.
_foreach_supported_types.append(DTensor)
Contributor

Python question: Are we guaranteed that the __init__.py code will only ever run once? Do we need to check like:

if DTensor not in _foreach_supported_types:
    _foreach_supported_types.append(DTensor)

Contributor Author

Good point. I will just add it for a safety check.

Contributor

It should only be imported once from the Python import perspective, but a guard is safer, yes.

Contributor

Or change _foreach_supported_types to a Set. :)
cc: @janeyx99

Contributor

Feel free to make it a set if it's easier haha
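A rough sketch of the two options from this thread, where the guard is what the PR adds and the set-based registry is purely hypothetical:

from torch.distributed._tensor import DTensor
from torch.optim.optimizer import _foreach_supported_types

# Guard discussed above: re-running the registration becomes a no-op even if
# the module body were somehow executed more than once.
if DTensor not in _foreach_supported_types:
    _foreach_supported_types.append(DTensor)

# Hypothetical set-based alternative (not what the PR does): duplicates are
# impossible by construction, at the cost of changing the registry's type.
foreach_supported_types_as_set = set(_foreach_supported_types)
foreach_supported_types_as_set.add(DTensor)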

wz337 (Contributor Author) commented Apr 4, 2024

@pytorchmergebot rebase

pytorchmergebot (Collaborator)
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

pytorchmergebot (Collaborator)
Successfully rebased turn_on_dtensor_foreach onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout turn_on_dtensor_foreach && git pull --rebase)

@with_comms
def test_adam_1d_sharding(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

# TODO: add fused_adam support
adam_configs = [
{"lr": 0.1},
{"lr": 0.1, "foreach": False},
Contributor

Wondering why we turn foreach to False for all tests?
IIUC, even if we put DTensor into _foreach_supported_types, passing foreach=False manually to the optimizer would still disable the foreach path.

Contributor Author

I am just flipping each config based on whether it originally had "foreach": True.

If a config had "foreach": True, I remove it, since foreach is now on by default.
For configs that didn't have "foreach": True, I change them to "foreach": False.
That way we still have tests for both the foreach and the for-loop implementations.
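Roughly, the flip described above looks like the following sketch (the exact configs in the test file may differ):

# Before this PR: foreach had to be requested explicitly for DTensor.
adam_configs_before = [
    {"lr": 0.1},                   # for-loop implementation
    {"lr": 0.1, "foreach": True},  # foreach implementation (opt-in)
]

# After this PR: foreach is the default, so the explicit True is dropped and a
# foreach=False variant keeps the for-loop implementation covered.
adam_configs_after = [
    {"lr": 0.1},                    # foreach implementation (default)
    {"lr": 0.1, "foreach": False},  # for-loop implementation
]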

Contributor

ohh I see, makes sense

@wz337 wz337 requested a review from wanchaol April 5, 2024 00:47
wanchaol (Contributor) left a comment

lgtm, for grad norm clipping I guess it would happen in follow-up PRs?


wz337 (Contributor Author) commented Apr 19, 2024

@pytorchmergebot rebase

pytorchmergebot (Collaborator)
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

pytorchmergebot (Collaborator)
Successfully rebased turn_on_dtensor_foreach onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout turn_on_dtensor_foreach && git pull --rebase)

wz337 (Contributor Author) commented May 10, 2024

@pytorchmergebot rebase

pytorchmergebot (Collaborator)
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

pytorchmergebot (Collaborator)
Successfully rebased turn_on_dtensor_foreach onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout turn_on_dtensor_foreach && git pull --rebase)

@@ -302,6 +302,7 @@ def unwrap_to_op_info(
args_schema.append(arg._spec)
local_args.append(arg._local_tensor)
if mesh is not None:
print(f"{mesh=}, {arg.device_mesh=}")
Contributor

I tried repro'ing the test_dtensor_compile failure locally. It's actually coming from here - the test passes when I remove the print statement.

I could mostly tell by looking at the stack trace from the FakeTensor erroring, and seeing that it's coming from:

(1) this code is printing arg.device_mesh

(2) DeviceMesh wraps a tensor, and printing it calls tensor.tolist()

(3) printing a tensor is not a very trace-friendly operation... which is why you get a kind-of-obscure error
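Dropping the print is the straightforward fix; if such debug output were ever needed, one trace-friendly option, assuming torch.compiler.is_compiling() is available in this PyTorch version and using a hypothetical helper name, would be to skip it while compiling:

import torch

def maybe_debug_print_mesh(mesh, arg):
    # Printing a DeviceMesh materializes its underlying tensor (roughly tolist()),
    # which FakeTensor tracing cannot handle, so only print in eager mode.
    if not torch.compiler.is_compiling():
        print(f"{mesh=}, {arg.device_mesh=}")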

@wz337 wz337 force-pushed the turn_on_dtensor_foreach branch 4 times, most recently from d8643a1 to a89d0f0 Compare May 15, 2024 00:37
wz337 (Contributor Author) commented May 15, 2024

@pytorchmergebot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 15, 2024
pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here

wz337 (Contributor Author) commented May 15, 2024

@pytorchmergebot merge

pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here

pytorchmergebot pushed a commit that referenced this pull request May 16, 2024
Fixes #121799

We fix the DeviceMesh hash so that two meshes are considered equal if they have the same mesh and the same parent_mesh.
Examples can be found here: #121799

Also need this to unblock #123394

Pull Request resolved: #123572
Approved by: https://github.com/xunnanxu, https://github.com/wanchaol, https://github.com/yoyoyocmu
ZelboK pushed a commit to ZelboK/pytorch that referenced this pull request May 19, 2024
…default (pytorch#123394)

Append DTensor to the optimizer `_foreach_supported_types` and turn on foreach implementation of optimizer for DTensor if not specified by the users.

Pull Request resolved: pytorch#123394
Approved by: https://github.com/wanchaol
ZelboK pushed a commit to ZelboK/pytorch that referenced this pull request May 19, 2024
Fixes pytorch#121799

We fix the DeviceMesh hash so that two meshes are considered equal if they have the same mesh and the same parent_mesh.
Examples can be found here: pytorch#121799

Also need this to unblock pytorch#123394

Pull Request resolved: pytorch#123572
Approved by: https://github.com/xunnanxu, https://github.com/wanchaol, https://github.com/yoyoyocmu
Labels
ciflow/inductor
ciflow/periodic (Trigger jobs ran periodically on master (periodic.yml) on the PR)
ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
module: dtensor (distributed tensor tag)
oncall: distributed (Add this issue/PR to distributed oncall triage queue)
release notes: distributed (dtensor) (release notes category)
7 participants