switch dtensor and functional collective to use optree #110670
Conversation
optree recently landed and provides quite good perf; conditionally import the new optree-backed pytree if optree is installed.
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110670
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 73754cb with merge base 2aa3064.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
optree recently landed and provides quite good perf; conditionally import the new optree if optree is installed.
ghstack-source-id: 28ac266eab5e4b34690bbf8aae0cfe7619451e12
Pull Request resolved: #110670
Wow, do we know how much perf we can get from switching to optree?
```python
try:
    from torch.utils._cxx_pytree import tree_flatten, tree_unflatten
except ImportError:
    tree_flatten = torch.utils._pytree.tree_flatten  # type: ignore[assignment]
```
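As a standalone illustration of the fallback pattern in this diff, here is a minimal sketch: `_fast_pytree` is a hypothetical stand-in for `torch.utils._cxx_pytree`, and the toy flatten in the `except` branch stands in for the pure-Python `torch.utils._pytree` fallback.

```python
# Sketch of the conditional-import fallback pattern from this PR.
# "_fast_pytree" is a hypothetical stand-in for torch.utils._cxx_pytree;
# when it isn't installed, a pure-Python implementation is used instead.
try:
    from _fast_pytree import tree_flatten  # type: ignore[import-not-found]
except ImportError:
    def tree_flatten(tree):
        # Tiny pure-Python fallback: flatten nested lists/tuples into leaves.
        # Returns (leaves, spec-like marker), loosely mirroring the real API.
        if isinstance(tree, (list, tuple)):
            leaves = []
            for child in tree:
                sub, _ = tree_flatten(child)
                leaves.extend(sub)
            return leaves, type(tree)
        return [tree], None

leaves, _ = tree_flatten([1, (2, 3), [4]])
print(leaves)  # [1, 2, 3, 4]
```

Because the fallback is bound at import time, call sites pay no per-call dispatch cost for the check.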
Curious, why the assignment version here instead of doing an import like above?
I followed a similar approach taken in #109684; another import would probably also work (in both cases we need to ignore the type annotation, lmk your preference :)
```python
try:
    from torch.utils._cxx_pytree import tree_map_only
except ImportError:
    tree_map_only = torch.utils._pytree.tree_map_only  # type: ignore[assignment]
```
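For context on what `tree_map_only` does, here is a toy pure-Python sketch of its semantics (map a function only over leaves of a given type). The real torch helpers also handle registered pytree node types; this stand-in covers only lists, tuples, and dicts.

```python
# Toy sketch of tree_map_only semantics: apply fn only to leaves that are
# instances of ty, leaving other leaves and the container structure intact.
# This is an illustrative stand-in, not the actual torch implementation.
def tree_map_only(ty, fn, tree):
    if isinstance(tree, ty):
        return fn(tree)
    if isinstance(tree, list):
        return [tree_map_only(ty, fn, x) for x in tree]
    if isinstance(tree, tuple):
        return tuple(tree_map_only(ty, fn, x) for x in tree)
    if isinstance(tree, dict):
        return {k: tree_map_only(ty, fn, v) for k, v in tree.items()}
    return tree  # non-matching leaf: returned unchanged

print(tree_map_only(int, lambda x: x * 2, {"a": [1, 2], "b": ("x", 3)}))
# {'a': [2, 4], 'b': ('x', 6)}
```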
do we know under what builds/setups this import won't exist? Should we issue a warning if we're in a known slow mode?
I believe certain CI setups don't have access to the pip channel and only have the conda channel, where optree is not yet available. @XuehaiPan might have more details.
I tried it and there's not much perf gain for large-scale 2D, probably because we already hide almost all of the CPU overhead. This could potentially benefit smaller models; we can try. See some updates in the summary: local testing shows around a 10% CPU overhead reduction.
Nice
optree recently landed and provides quite good perf; conditionally import the new optree if optree is installed. Some numbers testing an MLP layer with TP + func collective: before this PR, 10.390ms; after this PR, 9.189ms. So around a 10% e2e CPU overhead reduction.
optree recently landed and provides quite good perf; conditionally import the new optree if optree is installed.
ghstack-source-id: cdaf58eaacd93bf32f2edc2060d0b54b68470d0d
Pull Request resolved: #110670
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
optree recently landed and provides quite good perf; conditionally import the new optree if optree is installed.

Some numbers testing an MLP layer with TP + func collective:
before this PR: 10.390ms
after this PR: 9.189ms
So around a 10% e2e CPU overhead reduction.
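For readers who want to reproduce this kind of comparison, here is a hypothetical micro-benchmark harness using `timeit`. Note that the numbers above came from a real TP + functional-collective MLP run, not from a toy like this; the flatten function below is an illustrative stand-in, and any optree-backed variant would be timed the same way.

```python
# Hypothetical micro-benchmark sketch: time a pytree-flatten routine over a
# fixed nested structure with timeit. Swap in the optree-backed flatten to
# compare, using the same tree and call count for a fair measurement.
import timeit

def flatten_py(tree):
    # Naive recursive flatten over lists/tuples (illustrative stand-in).
    if isinstance(tree, (list, tuple)):
        out = []
        for child in tree:
            out.extend(flatten_py(child))
        return out
    return [tree]

tree = [[1, 2], (3, [4, 5]), 6] * 100

t = timeit.timeit(lambda: flatten_py(tree), number=1000)
print(f"pure-Python flatten: {t:.4f}s for 1000 calls")
```

Micro-benchmarks like this capture only the flatten/unflatten cost itself; the e2e numbers in the PR also include everything else in the forward pass, which is why a large per-call speedup can translate to a smaller end-to-end reduction.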