[dtensor] Fix and improve the sharding cache behavior #109306
Conversation
resolves #109101 This PR improves the sharding cache behavior by introducing a RuntimeSchemaInfo, used to record the hashing information needed at runtime during op registration. This enables us to: * only hash arguments that are tensors or that sit at or after static_argnum, which lets many cases like aten.div.Tensor(tensor, 0.23231) hit the cache (currently we hash all args, which excludes those cases) * with the correct cache behavior, optimizers hit the cache again, resolving the high CPU overhead issue.
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109306
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit b20036a with merge base dbddf18. This comment was automatically generated by Dr. CI and updates every 15 minutes.
elem_type = op_arg_type.getElementType()
return isinstance(elem_type, torch.TensorType) or (
    isinstance(elem_type, torch.OptionalType)
    and isinstance(elem_type.getElementType(), torch.TensorType)
)
For my learning purposes, what cases will trigger this logic, i.e. when is `isinstance(elem_type, torch.TensorType)` true?
For argument types like `Tensor[]` or `Tensor?[]`: this is basically to test these routes and make sure we don't skip hashing those arguments.
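To make this concrete, here is a minimal runnable sketch (the helper name and the `aten.cat` / `aten.div` examples are my own illustration, not the PR's code) showing how the element-type check above behaves on real op schemas:

```python
import torch

def is_tensor_list_arg(op_arg_type) -> bool:
    # Same shape as the check in the diff above: a list whose element type
    # is Tensor (i.e. Tensor[]) or Optional[Tensor] (i.e. Tensor?[]).
    if not isinstance(op_arg_type, torch.ListType):
        return False
    elem_type = op_arg_type.getElementType()
    return isinstance(elem_type, torch.TensorType) or (
        isinstance(elem_type, torch.OptionalType)
        and isinstance(elem_type.getElementType(), torch.TensorType)
    )

# aten.cat's first argument is declared as Tensor[], so it must be hashed.
cat_arg = torch.ops.aten.cat.default._schema.arguments[0]
print(is_tensor_list_arg(cat_arg.type))   # True: element type is Tensor

# aten.div.Tensor's second argument is a plain Tensor, not a list.
div_arg = torch.ops.aten.div.Tensor._schema.arguments[1]
print(is_tensor_list_arg(div_arg.type))   # False: not a list type at all
```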
# This static_argnum records the static arg starting index for ops that have
# non-tensor args/kwargs which affect sharding propagation results.
# Only a few ops need this information, e.g. view, transpose, var.dim, etc.
static_argnum: int = -1
Does this mean we are making the assumption that the first static_argnum args affect sharding props, and the args beyond that do not? For a fix this is fine as long as all DTensor unit tests pass, but I am wondering if this might be a hacky assumption here.
Actually no, the comment says static_argnum is the starting index of args that would affect sharding prop, so the arguments at or after static_argnum influence whether to run the sharding propagation again. I think this is the right way to handle caching uniformly, and there's no violation so far in all the ops I skimmed.
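For illustration, here is a small self-contained sketch (the RuntimeSchemaInfo fields and the key-building helper are my approximation, not the exact DTensor code) of how such a starting index splits the arguments that feed the cache key:

```python
import torch
from dataclasses import dataclass
from typing import Any, Tuple

@dataclass
class RuntimeSchemaInfo:
    # Starting index of non-tensor args that affect sharding propagation;
    # -1 means only tensor args contribute to the cache key.
    static_argnum: int = -1

def build_cache_key(args: Tuple[Any, ...], info: RuntimeSchemaInfo) -> Tuple[Any, ...]:
    key = []
    for i, arg in enumerate(args):
        if isinstance(arg, torch.Tensor):
            # Only sharding-relevant tensor metadata goes into the key.
            key.append((tuple(arg.shape), arg.dtype))
        elif info.static_argnum != -1 and i >= info.static_argnum:
            # Args at or after static_argnum (e.g. transpose's dims) matter.
            key.append(arg)
        # All other non-tensor args (e.g. div's scalar) are skipped.
    return tuple(key)

t = torch.randn(4, 4)
# transpose.int: the dims start at index 1 and affect sharding, so they are hashed.
print(build_cache_key((t, 0, 1), RuntimeSchemaInfo(static_argnum=1)))
# div.Tensor: the scalar is skipped, so different scalars yield the same key.
print(build_cache_key((t, 0.1), RuntimeSchemaInfo()) == build_cache_key((t, 0.2), RuntimeSchemaInfo()))  # True
```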
Is the `xla.py` change expected? Overall, the idea makes sense to me and I accept to unblock for now. We might need to think more about making it less hacky down the road.
Oops, good catch, the
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Looks like the op argument schema type check is not reliable: for things like aten.div.Tensor(Tensor, Tensor), the second argument can still be a float/scalar at runtime, so switch to checking the instance type directly. Pull Request resolved: #109428. Approved by: https://github.com/awgu, https://github.com/fegin
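A minimal sketch (my own illustration, not the code of #109428) of why checking the runtime instance is more robust than trusting the schema here:

```python
import torch

# The schema declares both arguments of div.Tensor as Tensor...
print(torch.ops.aten.div.Tensor._schema)

# ...yet a Python scalar is still accepted for the second argument, which is
# the kind of mismatch that made a schema-only check unreliable.
out = torch.ops.aten.div.Tensor(torch.randn(4), 0.5)

def is_tensor_arg(arg) -> bool:
    # Deciding based on the actual runtime instance avoids that mismatch.
    return isinstance(arg, torch.Tensor)

print(is_tensor_arg(torch.randn(4)))  # True
print(is_tensor_arg(0.5))             # False
```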
Stack from ghstack (oldest at bottom):
resolves #109101
The problem is essentially that we were hashing all of the arguments, including scalars (e.g. aten.div(tensor, scalar)). In the optimizer, the scalar can change every time we call the op, so every call was a cache miss.
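To illustrate the miss pattern, here is a tiny conceptual sketch (the two key functions are my own simplification, not DTensor's actual hashing code):

```python
import torch

t = torch.randn(8, 8)

def old_style_key(args):
    # Previously every argument fed the key, scalars included, so a scalar
    # that changes each call produces a brand-new key each call.
    return tuple(
        (tuple(a.shape), a.dtype) if isinstance(a, torch.Tensor) else a for a in args
    )

def tensor_only_key(args):
    # With the fix, ops like aten.div only hash tensor metadata, so the
    # changing scalar no longer defeats the cache.
    return tuple((tuple(a.shape), a.dtype) for a in args if isinstance(a, torch.Tensor))

# Adam divides by a bias-correction term that changes every step:
step1_args = (t, 1 - 0.9 ** 1)
step2_args = (t, 1 - 0.9 ** 2)
print(old_style_key(step1_args) == old_style_key(step2_args))      # False: miss every step
print(tensor_only_key(step1_args) == tensor_only_key(step2_args))  # True: cache hit
```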
This PR improves the sharding cache behavior by introducing a RuntimeSchemaInfo, used to record the hashing information needed at runtime during op registration. This enables us to:

* only hash arguments that are tensors or that sit at or after static_argnum, which lets many cases like aten.div.Tensor(tensor, 0.23231) hit the cache (currently we hash all args, which excludes those cases)
* with the correct cache behavior, optimizers hit the cache again, resolving the high CPU overhead issue.
A simple MLP shows all cache hits, and a single addmm drops to 0.319ms (from 0.341ms), showing some hashing improvements:

The Adam optimizer shows aten.div hitting the sharding cache again:
