
Conversation

wanchaol
Collaborator

@wanchaol wanchaol commented Sep 14, 2023

Stack from ghstack (oldest at bottom):

resolves #109101

The problem is essentially that we were hashing all of the arguments, including
scalars (e.g. aten.div(tensor, scalar)). In the optimizer, the scalar may change
every time we call the op, so we get a cache miss on every call.

This PR improves the sharding cache behavior by introducing a
RuntimeSchemaInfo, used to record the hashing information needed at runtime
during op registration time. This enables us to:

  • only hash arguments that are tensors or that fall under static_argnum; this
    lets many cases like aten.div.Tensor(tensor, 0.23231) hit the cache, whereas
    we currently hash all args, which excludes those cases (see the sketch below)
  • with the correct cache behavior, optimizers will hit the cache again,
    resolving the high CPU overhead issue.
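
As a reference for the bullets above, here is a minimal sketch of the caching idea, not the actual DTensor implementation: tensor metadata stands in for DTensor's sharding specs, and RuntimeSchemaInfoSketch / sharding_cache_key are hypothetical names used only for illustration.

from dataclasses import dataclass
from typing import Any, Sequence

import torch

@dataclass
class RuntimeSchemaInfoSketch:
    # starting index of non-tensor args that still affect sharding propagation;
    # -1 means no non-tensor arg participates in the cache key
    static_argnum: int = -1

def sharding_cache_key(op_name: str, args: Sequence[Any], info: RuntimeSchemaInfoSketch) -> int:
    hashable = []
    for i, arg in enumerate(args):
        if isinstance(arg, torch.Tensor):
            # hash metadata relevant for sharding, not the tensor values
            hashable.append((tuple(arg.shape), arg.stride(), arg.dtype))
        elif info.static_argnum >= 0 and i >= info.static_argnum:
            # non-tensor args such as `dim` for transpose still affect sharding,
            # so they are hashed (assumes such args are themselves hashable)
            hashable.append(arg)
        # other non-tensor args (e.g. a changing scalar) are skipped on purpose
    return hash((op_name, tuple(hashable)))

# a changing scalar denominator no longer invalidates the cache entry
t = torch.randn(4, 4)
info = RuntimeSchemaInfoSketch()
assert sharding_cache_key("aten.div.Tensor", (t, 0.1), info) == \
       sharding_cache_key("aten.div.Tensor", (t, 0.2), info)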

A simple MLP shows all cache hits, and a single addmm drops to 0.319ms (from 0.341ms), showing some hashing improvements:
Screenshot 2023-09-14 at 11 06 07 AM

The Adam optimizer shows aten.div hitting the sharding cache again:
Screenshot 2023-09-14 at 11 02 10 AM

@pytorch-bot

pytorch-bot bot commented Sep 14, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109306

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b20036a with merge base dbddf18:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

wanchaol added a commit that referenced this pull request Sep 14, 2023

ghstack-source-id: 7c6a9a0
Pull Request resolved: #109306
@wanchaol wanchaol added release notes: distributed (dtensor) release notes category ciflow/trunk Trigger trunk jobs on your pull request labels Sep 15, 2023
@fduwjj fduwjj added the ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR label Sep 15, 2023
Comment on lines +203 to +207
elem_type = op_arg_type.getElementType()
return isinstance(elem_type, torch.TensorType) or (
isinstance(elem_type, torch.OptionalType)
and isinstance(elem_type.getElementType(), torch.TensorType)
)
Contributor


For my learning purposes, what cases will trigger this logic, i.e. when does
isinstance(elem_type, torch.TensorType) hold?

Collaborator Author

@wanchaol wanchaol Sep 15, 2023


For argument types like Tensor[] or Tensor?[]: this is basically to cover those routes and make sure we don't skip hashing those arguments. See the small illustration below.
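
To make the quoted check concrete, a small hedged illustration (the overload choices here are my own examples, not tests from the PR): aten.cat takes a Tensor[] and aten.index.Tensor takes a Tensor?[], which exercises both branches.

import torch

# Tensor[] argument: the schema arg type is a ListType whose element type is TensorType
cat_schema = torch.ops.aten.cat.default._schema
elem_type = cat_schema.arguments[0].type.getElementType()
print(isinstance(elem_type, torch.TensorType))                    # True

# Tensor?[] argument: the element type is an OptionalType wrapping a TensorType,
# which the second branch of the quoted check handles
index_schema = torch.ops.aten.index.Tensor._schema
elem_type = index_schema.arguments[1].type.getElementType()
print(isinstance(elem_type, torch.OptionalType))                  # True
print(isinstance(elem_type.getElementType(), torch.TensorType))   # True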

# This static_argnum records the static arg starting index for ops that have non-tensor
# args/kwargs which would affect sharding propagation results.
# Only a few ops need this information, e.g. view, transpose, var.dim, etc.
static_argnum: int = -1
Contributor


Does this mean we made the assumption that the first static_argnum args affect sharding props, and args beyond that do not? For a fix this is fine as long as all the DTensor unit tests pass, but I am wondering if this might be a hacky assumption.

Collaborator Author


Actually no, the comment says static_argnum is the starting index of the args that affect sharding prop, so the arguments at or after static_argnum influence whether to run the sharding propagation again. I think this is the right way to handle caching uniformly, and there's no violation so far in all the ops I skimmed.
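
A hedged, self-contained illustration of that indexing convention (the helper below is made up for this thread, not code from the PR):

# args at index >= static_argnum are included in the cache key; non-tensor args
# before that index are not
def hashed_arg_indices(num_args, tensor_indices, static_argnum):
    return [
        i for i in range(num_args)
        if i in tensor_indices or (static_argnum >= 0 and i >= static_argnum)
    ]

# aten.transpose.int(Tensor self, int dim0, int dim1), static_argnum = 1:
print(hashed_arg_indices(3, {0}, 1))   # [0, 1, 2] -> the dims participate in the key
# aten.div.Tensor(tensor, scalar) with no static args (sentinel -1):
print(hashed_arg_indices(2, {0}, -1))  # [0]       -> the changing scalar is skipped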

Contributor

@fduwjj fduwjj left a comment


Is the xla.py change expected? Overall the idea makes sense to me, and I'm accepting to unblock for now. We might need to think more about making it less hacky down the road?

@wanchaol
Collaborator Author

Is the xla.py change expected? Overall the idea makes sense to me, and I'm accepting to unblock for now. We might need to think more about making it less hacky down the road?

Oops, good catch, the xla.py change is not expected; I accidentally checked it in and will remove it. I think this is the right way to handle the sharding cache in a uniform manner: the previous hashing method did not work with many corner cases like aten.div.Scalar and was a bit bug-prone. How custom op registration works with this is not fleshed out, and I agree we probably need some better way to handle custom ops.

wanchaol added a commit that referenced this pull request Sep 15, 2023

ghstack-source-id: 5b4486f
Pull Request resolved: #109306
@wanchaol
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

wanchaol added a commit that referenced this pull request Sep 16, 2023
Looks like the op argument schema type check is not reliable: for
things like aten.div.Tensor(Tensor, Tensor), the second argument can be
a float/scalar for some reason, so switch to checking the instance type
directly.
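
A brief hedged illustration of the mismatch described above (the schema lookup and helper below are my own example, not the code in #109428):

import torch

# the schema declares both arguments of aten.div.Tensor as Tensor...
div_schema = torch.ops.aten.div.Tensor._schema
print(div_schema.arguments[1].type)          # Tensor, per the schema

# ...yet, as noted above, the second runtime argument can still arrive as a
# plain Python float, so checking the runtime instance directly is more robust
def is_tensor_arg(runtime_arg) -> bool:
    return isinstance(runtime_arg, torch.Tensor)

print(is_tensor_arg(torch.randn(4)))         # True
print(is_tensor_arg(0.5))                    # False -> treated as a non-tensor arg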

wanchaol added a commit that referenced this pull request Sep 16, 2023

ghstack-source-id: 720f52a
Pull Request resolved: #109428
pytorchmergebot pushed a commit that referenced this pull request Sep 16, 2023
Pull Request resolved: #109428
Approved by: https://github.com/awgu, https://github.com/fegin
@facebook-github-bot facebook-github-bot deleted the gh/wanchaol/357/head branch September 18, 2023 14:24