Meta register all foreach ops #112281
Closed · wants to merge 19 commits
Conversation

@isuruf (Collaborator) commented on Oct 27, 2023

@pytorch-bot bot commented on Oct 27, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112281

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 7d4f4aa with merge base d64bc8f:

FLAKY - The following job failed, but the failure was likely due to flakiness already present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@lezcano (Collaborator) left a review:

Approving as it's a strict improvement, but see the comment

Review thread on torch/_meta_registrations.py (outdated, resolved).
@lezcano (Collaborator) left a review:
It looks like you missed a couple of ops here and there, though.

Comment on lines 62 to 65:

```python
scalar_op = getattr(aten, scalar_op_name).default
scalar_fn = decomposition_table.get(scalar_op, None)
if not scalar_fn:
    scalar_fn = meta_table[scalar_op]
```
Reviewer (Collaborator):
Isn't it enough to use the scalar_op here?

isuruf (Author):
Yes, fixed now.
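
(For illustration, a minimal sketch of the fix being discussed, assuming the surrounding helper looks roughly like the snippet above; this is not the merged code:)

```python
import torch

aten = torch.ops.aten

def _meta_foreach_unary(scalar_op_name, tensors):
    # Resolve the scalar overload once; calling it on meta tensors
    # dispatches to its registered meta kernel or decomposition directly,
    # so no manual decomposition_table / meta_table lookup is needed.
    scalar_op = getattr(aten, scalar_op_name).default
    return [scalar_op(t) for t in tensors]
```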

```python
)


@register_meta_foreach(
    [
        aten._foreach_abs_.default,
```
Reviewer (Collaborator):
Now we can merge this list and the out-of-place one, right?
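
(An illustrative sketch of what a merged registration could look like; register_op is a hypothetical helper standing in for whatever the PR's registration function ends up being:)

```python
# Hypothetical sketch: derive the in-place variant from each out-of-place
# name so a single list drives both registrations.
for name in ["_foreach_abs", "_foreach_neg", "_foreach_sqrt"]:
    for variant in (name, name + "_"):  # out-of-place, then in-place
        packet = getattr(aten, variant, None)
        if packet is not None:
            register_op(packet)  # hypothetical registration helper
```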

@lezcano (Collaborator) commented on Nov 3, 2023

Looks good overall, just two comments

@lezcano self-requested a review on November 3, 2023 10:48.
@isuruf (Author) commented on Nov 3, 2023

@lezcano, I changed this PR to a generic implementation for all foreach operations. Let me know what you think.

@lezcano changed the title from "Meta register all unary foreach ops" to "Meta register all foreach ops" on Nov 6, 2023.
Review thread on torch/_meta_registrations.py (outdated, resolved):

```python
def meta__foreach_binop__list(self, other, alpha=1):
    _check_foreach_binop_tensor_lists(self, other)
```

```python
def _check_meta_foreach(*args, _nlists=1, **kwargs):
```
Reviewer (Collaborator):
kwargs is not used. Also, why use a leading underscore in the parameter name?

Comment on lines 2947 to 2949:

```python
if i == 0:
    nelem = len(arg)
else:
```
Reviewer (Collaborator):
This check plus the nelem > 0 check should be outside the loop; at the moment, if _nlists == 1 you are not executing the nelem > 0 check at all.
Also, try to match the errors from the eager API whenever possible.
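
(A sketch of the suggested restructuring with both checks hoisted out of the loop; args and nlists follow the snippet above, and torch._check is the assertion helper used throughout _meta_registrations.py:)

```python
nelem = len(args[0])
torch._check(nelem > 0, lambda: "Tensor list must have at least one tensor.")
for arg in args[1:nlists]:
    torch._check(
        len(arg) == nelem,
        lambda: "All input tensor lists must have the same length",
    )
```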

```python
def register_meta_foreach(ops):
    for op_name, nlists in ops:
        for overload in ["default", "Scalar", "Tensor", "List", "ScalarList"]:
            op = getattr(getattr(aten, op_name), overload, None)
```
Reviewer (Collaborator):
OpOverloadPacket already has a method that returns the list of overloads if I'm not mistaken.
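
(The method in question is OpOverloadPacket.overloads(), which returns the overload names as strings; a minimal usage sketch:)

```python
import torch

packet = torch.ops.aten._foreach_add
# overloads() returns names such as "List", "Scalar", "Tensor", "ScalarList".
for overload_name in packet.overloads():
    op = getattr(packet, overload_name)
    print(op._schema)
```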

Comment on lines 3064 to 3066:

```python
("_foreach_addcdiv", 2),
("_foreach_addcmul", 2),
("_foreach_lerp", 2),
```
Reviewer (Collaborator):
Is it not possible to see that these should have a 2 rather than a 1 from the overloads directly?
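
(A hedged sketch of deriving the count from the schema instead of hard-coding it. It relies on OpOverload._schema exposing typed arguments; the "Tensor[]" string comparison is an assumption about how the list type prints, and the table above appears to count lists beyond self, so a derived count would be offset by one:)

```python
def infer_num_tensor_lists(op):
    # Count schema arguments whose type is a list of Tensors.
    return sum(1 for a in op._schema.arguments if str(a.type) == "Tensor[]")
```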

@isuruf (Author) commented on Nov 13, 2023

This is ready for another review.

```python
    len(self) == len(tensor1) and len(self) == len(tensor2),
    lambda: "All input tensor lists must have the same length",
)
nlists = 0
```
Reviewer (Collaborator):
I assume we must have nlists > 0?

Reviewer (Collaborator):
And if we have a few of them, we may need to perform some sort of "all lists have the same length" check?

isuruf (Author):
Yes, we need nlists > 0. Since the first argument is always a list and the dispatcher will dispatch only if the first argument is a list, I didn't add a check.

I added a check to see if all lists have the same length.

Reviewer (Collaborator):
Add an assert nlists > 0 to be on the safe side.

Also, have a look at the checks that foreach functions execute in eager mode and replicate them here please.

isuruf (Author):
> Add an assert nlists > 0 to be on the safe side.

Sure.

> Also, have a look at the checks that foreach functions execute in eager mode and replicate them here please.

Sorry, I couldn't find one. Do you have an example check?

@lezcano (Collaborator) commented on Nov 13, 2023:
Find the definition of the _foreach_* functions within ATen and see which checks they implement.
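
(For context, stated as my reading of ATen rather than something said in this thread: the eager foreach kernels route input validation through check_foreach_api_restrictions in aten/src/ATen/native/ForeachUtils.h, which enforces non-empty tensor lists and equal lengths across lists; the meta-side sketch after the earlier loop comment mirrors those two checks.)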

Review thread on torch/_meta_registrations.py (outdated, resolved).
@lezcano (Collaborator) left a review:
Ok, this LGTM! Thank you for the big push in this non-trivial issue :)

@lezcano (Collaborator) commented on Nov 15, 2023

One of the errors is real, though:

2023-11-15T04:28:44.5602104Z NotImplementedError: aten::_foreach_maximum.List: attempted to run this operator with Meta tensors, but there was no abstract impl or Meta kernel registered.

```python
aten._foreach_add_.List,
aten._foreach_sub_.List,
aten._foreach_mul_.List,
aten._foreach_div_.List,
aten._foreach_maximum_.List,
```
Reviewer (Collaborator):
What's the issue with this one, and with the others that cannot be registered in a generic way?

isuruf (Author):
1. _foreach_maximum: there is no aten.maximum_ (or torch.maximum_) in-place method to map to; see the sketch after this list.
2. _foreach_pow: it is the only op with a ScalarAndTensor overload, so it needed special handling.
3. _foreach_addcdiv / _foreach_addcmul: _foreach_addcmul and addcmul have different signatures. The scaling parameter is a keyword argument that is a list in _foreach_addcmul but a scalar in addcmul, and the names differ as well: scalars in _foreach_addcmul, value in addcmul.
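
(As an illustration of the first point, a hedged sketch of how the in-place _foreach_maximum_ could be special-cased given that aten.maximum_ does not exist; the decorator usage and function name are assumptions, not the exact merged code:)

```python
@register_meta(aten._foreach_maximum_.List)
def meta__foreach_maximum__list(self, other):
    torch._check(
        len(self) == len(other),
        lambda: "All input tensor lists must have the same length",
    )
    # Run the out-of-place maximum on each pair of meta tensors purely for
    # its shape/dtype checks; an in-place meta kernel returns nothing.
    for s, o in zip(self, other):
        aten.maximum.default(s, o)
```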

Reviewer (Collaborator):
Fair enough, mind leaving comments?

@isuruf (Author) commented on Nov 20, 2023

This is ready now.

@lezcano (Collaborator) commented on Nov 21, 2023

@pytorchbot merge

@pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Nov 21, 2023.
@pytorchmergebot (Collaborator):
Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team: raised by workflow job.

@isuruf (Author) commented on Nov 21, 2023

@pytorchbot label "module: tests" "module: meta tensors"

@pytorch-bot bot added the module: meta tensors and module: tests (Issues related to tests, not the torch.testing module) labels on Nov 21, 2023.
@isuruf (Author) commented on Nov 21, 2023

@pytorchbot label "release notes: Meta API"

@pytorch-bot bot added the release notes: Meta API label (release notes category) on Nov 21, 2023.
@isuruf (Author) commented on Nov 21, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator):
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced debugging: check the merge workflow status here.

Labels: ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: meta tensors, module: tests (Issues related to tests, not the torch.testing module), open source, release notes: Meta API (release notes category)

4 participants