Meta register all foreach ops #112281
Conversation
🔗 Helpful links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/112281.

Note: links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 unrelated failure.) As of commit 7d4f4aa with merge base d64bc8f: FLAKY, the following job failed but was likely due to flakiness present on trunk.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Approving as it's a strict improvement, but see the comment
It looks like you missed a couple of ops here and there, though.
torch/_meta_registrations.py (outdated)

```python
scalar_op = getattr(aten, scalar_op_name).default
scalar_fn = decomposition_table.get(scalar_op, None)
if not scalar_fn:
    scalar_fn = meta_table[scalar_op]
```
Isn't it enough to use the `scalar_op` here?
Yes, fixed now.
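The lookup under discussion can be sketched as a plain two-table fallback (the dicts here are toy stand-ins for `decomposition_table` and `meta_table`, which in the real file map `OpOverload` objects to functions):

```python
# Toy stand-ins: the real tables are keyed by OpOverload objects, not strings.
decomposition_table = {"aten.abs.default": abs}
meta_table = {"aten.neg.default": lambda x: -x}


def lookup_scalar_fn(scalar_op):
    # Prefer a decomposition if one exists, otherwise fall back to the
    # meta registration, mirroring the .get(...) + fallback in the diff.
    scalar_fn = decomposition_table.get(scalar_op)
    if scalar_fn is None:
        scalar_fn = meta_table[scalar_op]
    return scalar_fn
```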
torch/_meta_registrations.py (outdated)

```python
)


@register_meta_foreach(
    [
        aten._foreach_abs_.default,
```
Now we can merge this list and the out-of-place one, right?
Looks good overall, just two comments.
@lezcano, I changed this PR to a generic implementation for all foreach operations. Let me know what you think.
torch/_meta_registrations.py (outdated)

```python
def meta__foreach_binop__list(self, other, alpha=1):
    _check_foreach_binop_tensor_lists(self, other)


def _check_meta_foreach(*args, _nlists=1, **kwargs):
```
`kwargs` is not used. Also, why use a leading `_` in a parameter name?
torch/_meta_registrations.py (outdated)

```python
if i == 0:
    nelem = len(arg)
else:
```
This check plus the `nelem > 0` check should be outside the loop. At the moment, if `_nlists == 1` you never execute the `nelem > 0` check.

Also, try to match the errors from the eager API whenever possible.
torch/_meta_registrations.py (outdated)

```python
def register_meta_foreach(ops):
    for op_name, nlists in ops:
        for overload in ["default", "Scalar", "Tensor", "List", "ScalarList"]:
            op = getattr(getattr(aten, op_name), overload, None)
```
`OpOverloadPacket` already has a method that returns the list of overloads, if I'm not mistaken.
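`OpOverloadPacket.overloads()` does indeed return the list of overload name strings in PyTorch. A self-contained sketch of the suggested loop, using a hypothetical mock packet so the example runs without torch:

```python
class FakeOverloadPacket:
    """Hypothetical stand-in for torch's OpOverloadPacket (mock only)."""

    def __init__(self, op_name, overload_names):
        self._names = list(overload_names)
        for name in self._names:
            # The real packet exposes each overload as an attribute.
            setattr(self, name, f"{op_name}.{name}")

    def overloads(self):
        # The real OpOverloadPacket.overloads() also returns the list of
        # overload name strings, e.g. ["default", "List", "ScalarList"].
        return list(self._names)


def register_meta_foreach(packet, meta_table, fn):
    # Iterate the packet's actual overloads instead of probing a
    # hard-coded ["default", "Scalar", ...] list with getattr.
    for name in packet.overloads():
        meta_table[getattr(packet, name)] = fn
```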
torch/_meta_registrations.py (outdated)

```python
("_foreach_addcdiv", 2),
("_foreach_addcmul", 2),
("_foreach_lerp", 2),
```
Is it not possible to see that these should have a 2 rather than a 1 from the overloads directly?
This is ready for another review.
torch/_meta_registrations.py (outdated)

```python
    len(self) == len(tensor1) and len(self) == len(tensor2),
    lambda: "All input tensor lists must have the same length",
)
nlists = 0
```
I assume we must have `nlists > 0`?
And if we have a few of them, we may need to perform some sort of "all lists have the same length" check?
Yes, we need `nlists > 0`. Since the first argument is always a list and the dispatcher will dispatch only if the first argument is a list, I didn't add a check. I added a check that all lists have the same length.
Add an `assert nlists > 0` to be on the safe side.

Also, have a look at the checks that the foreach functions execute in eager mode and replicate them here, please.
> Add an `assert nlists > 0` to be on the safe side.

Sure.

> Also, have a look at the checks that foreach functions execute in eager mode and replicate them here please.

Sorry, I couldn't find one. Do you have an example check?
Find the definition of the `_foreach_*` functions within ATen and see which checks they implement.
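For illustration, here is a pure-Python sketch of the kind of eager-mode restriction ATen's foreach helpers enforce (non-empty tensor lists, and any scalar list matching the tensor list in length). The function name mirrors ATen's `check_foreach_api_restrictions` in `ForeachUtils.h`; plain lists stand in for tensor lists, and the exact message strings here are assumptions:

```python
def check_foreach_api_restrictions(tensors, scalars=None):
    # A tensor-list argument must contain at least one tensor.
    if len(tensors) == 0:
        raise RuntimeError("Tensor list must have at least one tensor.")
    # A scalar list, when present, must pair up one-to-one with the tensors.
    if scalars is not None and len(scalars) != len(tensors):
        raise RuntimeError(
            "Tensor list must have same number of elements as scalar list."
        )
```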
Ok, this LGTM! Thank you for the big push in this non-trivial issue :)
One of the errors is real, though.
torch/_meta_registrations.py (outdated)

```python
aten._foreach_add_.List,
aten._foreach_sub_.List,
aten._foreach_mul_.List,
aten._foreach_div_.List,
aten._foreach_maximum_.List,
```
What's the issue with this one, and the others that cannot be registered in a generic way?
- `foreach_maximum`: there is no `aten.maximum_` method or any `torch.maximum_` method.
- `foreach_pow`: only `foreach_pow` has a ScalarAndTensor overload, which needed special handling.
- `foreach_addc{div,mul}`: `foreach_addcmul` and `addcmul` have different signatures. The scalar parameter is a keyword argument that is a list in `foreach_addcmul` and a scalar in `addcmul`, and the names differ as well: `scalars` in `foreach_addcmul` versus `value` in `addcmul`.
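The `addcmul` mismatch can be sketched in pure Python with numbers standing in for tensors (both function names below are hypothetical): the eager per-element op takes a single scalar kwarg named `value`, while the foreach variant takes a list kwarg named `scalars`, so a generic per-element mapping cannot forward the kwargs unchanged:

```python
def addcmul_scalar(self, tensor1, tensor2, *, value=1):
    # Eager per-element signature: one scalar kwarg named `value`.
    return self + value * tensor1 * tensor2


def foreach_addcmul_meta(selfs, tensor1s, tensor2s, *, scalars=None):
    # foreach signature: a *list* kwarg named `scalars`, so mapping onto
    # the per-element op requires renaming and zipping the kwarg -- this
    # is why the op could not use the generic registration path.
    if scalars is None:
        scalars = [1] * len(selfs)
    return [
        addcmul_scalar(s, t1, t2, value=v)
        for s, t1, t2, v in zip(selfs, tensor1s, tensor2s, scalars)
    ]
```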
Fair enough, mind leaving comments?
This is ready now.
@pytorchbot merge
Merge failed. Reason: this PR needs a `release notes:` label. If it is not user facing, please add the `topic: not user facing` label instead. To add a label, you can comment to pytorchbot, for example `@pytorchbot label "topic: not user facing"`. Details for Dev Infra team: raised by workflow job.
@pytorchbot label "module: tests" "module: meta tensors"

@pytorchbot label "release notes: Meta API"

@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
cc @mruberry @ZainRizvi @ezyang @eellison @bdhirsh