
Add overload names to native_functions.yaml #23532

Closed · 8 commits

Conversation


@smessmer smessmer commented Jul 29, 2019

Stack from ghstack:

We need this to be able to register them with the c10 dispatcher.

The overload names are based on one-letter-per-argument-type.

Script used to change native_functions.yaml and derivatives.yaml: https://gist.github.com/dzhulgakov/e64b03ed38c7b530c65992a8318e7332

Differential Revision: D16553437

@pytorchbot pytorchbot added oncall: jit Add this issue/PR to JIT oncall triage queue module: internals Related to internal abstractions in c10 and ATen module: operators labels Jul 29, 2019

@zdevito zdevito left a comment


I don't like the idea of having overload names when we do not actually have more than one implementation of a method. Very few operators are overloaded. If everything were overloaded, we would just mangle the types automatically as this patch semi-manually does. The intention is for the overload name to be semantically meaningful to distinguish it from other overloads. A string like TTiTTTTTTTiiibfbbiTTb doesn't help with understanding.
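The kind of mangling being objected to can be sketched as follows. This is a toy reconstruction, not the actual script linked in the PR description: the letter table and function name here are hypothetical, and serve only to show why a one-letter-per-argument-type scheme produces opaque names like the one quoted above.

```python
# Hypothetical letter table: one letter per argument type.
LETTER = {"Tensor": "T", "int": "i", "bool": "b", "float": "f", "Scalar": "s"}

def mangle(arg_types):
    # Concatenate one letter per argument type, in order.
    return "".join(LETTER[t] for t in arg_types)

# Even a modest signature mangles into something unreadable:
print(mangle(["Tensor", "Tensor", "int", "Tensor", "bool", "bool", "float"]))
# -> TTiTbbf
```

Nothing in the resulting string tells a reader which overload it names or how it differs from its siblings, which is the core of the objection.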

smessmer added a commit that referenced this pull request Jul 30, 2019

smessmer commented Jul 30, 2019

@zdevito Ok, I changed it to only add overload names for the ops that actually have overloads. That's about 1,400 ops, which is more than 50%. It's not feasible to find semantic names for all of these, so I'd stick with the generated names for now.

@smessmer smessmer requested a review from zdevito July 30, 2019 01:06
@dzhulgakov (Collaborator)

@smessmer - I think the vast majority of ops with more than one overload can be dealt with by diagnosing the _out version of the op. Another big bucket is the Tensor vs. Scalar argument type. Also, do we allow a "default" overload without an overload name?

P.S. Can you remind me why we wanted to do this? I recall it was about the way of registering ops and verifying schema changes. That applies less to native_functions.yaml, and we could probably add the mangling you have directly to the codegen script. But if the names are more meaningful, it's nice to include them here.


zdevito commented Jul 30, 2019

I agree about the "out" variant. From looking at it, almost all of the overloads come from that. I would also prefer that the suffix only include the arguments that differ across the overloads. I think with that change it should become feasible to give the rest semantically meaningful names.

@smessmer (Contributor Author)

Yes, the reason is for verifying schema changes in registration and it is less important for native_functions.yaml. We could generate the mangling in the codegen, but that would mean that we introduce a difference between the function schema in native_functions.yaml and the one used when it's registered in jit. I think that's the wrong direction - we want to unify these schemas, not make them different.

What do you mean about diagnosing the out variant? Give them an "inplace" overload name?

@smessmer (Contributor Author)

After taking care of the out variants and simple Tensor/Scalar differences, there are now 357 mangled overloads left.


zdevito commented Jul 30, 2019

I think I was expecting that overload names could have semantic meaning for each overload. This is possible for the Tensor/Scalar and out variants (btw, "inplace" is the wrong word to use for those: add_ is inplace because it modifies self; the out variants do not modify self). Tagging anything with a mangled name is going to be a problem for the future. We can't change an overload name without breaking BC, but we will want to add additional default arguments. Once we do that, the overload name is not even going to match the mangling scheme. So we either need to:

  1. Figure out enough semi-automated rules to give the existing operators semantically meaningful overload names

  2. Give up on naming overloads, because the cost (confusing mangling names strewn about the code) is worse than the benefit.

There are other ways we can catch typos. For instance, if a user registers an op and marks it an overload, we could raise an error if it isn't overloading something loaded from aten.

smessmer added a commit that referenced this pull request Jul 30, 2019
@smessmer (Contributor Author)

OK, new approach. The script now (P75976311):

  • annotates all out variants
  • for each op that has overloads with different signatures but the same op_name + overload_name combination:
    • for each of these overloads:
      • find first argument that is not present in all of them
        • extend overload_name by that argument name (if name differs) or type (if type differs)
  • rinse and repeat the above until the overload names are unique

This seems to produce relatively good semantic names for the overloads.
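The iterative disambiguation described above can be sketched as a short loop. This is a toy model, not the actual script (P75976311): the schema representation and helper names are assumptions, and the out-variant annotation step is omitted for brevity.

```python
from collections import defaultdict

def disambiguate(overloads):
    """overloads: {uid: [(arg_name, arg_type), ...]} for one operator name.
    Returns {uid: overload_name}, extended until all names are unique."""
    names = {uid: "" for uid in overloads}
    while True:
        groups = defaultdict(list)
        for uid, name in names.items():
            groups[name].append(uid)
        clashing = [uids for uids in groups.values() if len(uids) > 1]
        if not clashing:
            return names  # rinse-and-repeat loop done: names are unique
        for uids in clashing:
            sigs = {u: overloads[u] for u in uids}
            # find the first argument position that is not identical everywhere
            pos = 0
            while (all(len(s) > pos for s in sigs.values())
                   and len({s[pos] for s in sigs.values()}) == 1):
                pos += 1
            for u in uids:
                sig = sigs[u]
                if pos >= len(sig):
                    continue  # a shorter overload keeps its current name
                arg_name, arg_type = sig[pos]
                names_at_pos = {s[pos][0] for s in sigs.values() if len(s) > pos}
                # extend by argument name if the names differ, else by type
                suffix = arg_name if len(names_at_pos) > 1 else arg_type
                names[u] = (names[u] + "_" + suffix).lstrip("_")

add_like = {
    "add":        [("self", "Tensor"), ("other", "Tensor")],
    "add_scalar": [("self", "Tensor"), ("other", "Scalar")],
}
print(disambiguate(add_like))  # {'add': 'Tensor', 'add_scalar': 'Scalar'}
```

On the add-like pair above, the first differing position is the `other` argument, where the name agrees and the type differs, so the types become the overload names, matching the kind of `Tensor`/`Scalar` overload names the thread converges on.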

@zdevito @dzhulgakov please take another look


apaszke commented Jul 31, 2019

Why can't we just do the overload trick that @zdevito suggests? This is exactly the approach taken in many programming languages, and it seems enough to catch the errors in most cases. It's also really simple to implement and verify, and it does not suffer from BC issues.


smessmer commented Jul 31, 2019

@apaszke We need overload names for several things. Error checking is one of them, but not the only one. Another reason is that we don't want mobile to have to resolve overloads, so we need to serialize a model that has overloads already resolved. I think there've been a few other reasons, but the decision for overload names happened some time ago and I don't remember everything. The functionality for overload names has been part of the system for some time and there are things relying on it, but we didn't actually add names for overloads of ATen ops before. This PR now catches up with that decision and adds overload names to existing ATen ops.

@zdevito zdevito left a comment


This looks a lot better now, thanks! I have a few minor things we should fix that I mentioned inline, but the overload names are now much more descriptive about the differences between the operators.

(Inline review comments on aten/src/ATen/native/native_functions.yaml, since resolved.)
@pytorchbot pytorchbot added the module: cpp-extensions Related to torch.utils.cpp_extension label Jul 31, 2019

apaszke commented Aug 1, 2019

> Another reason is that we don't want mobile to have to resolve overloads, so we need to serialize a model that has overloads already resolved.

Why can't we use full signatures to uniquely identify those ops then?

@zou3519 zou3519 deleted the gh/smessmer/12/head branch August 1, 2019 09:11
@facebook-github-bot

This pull request has been merged in 02f794b.


gchanan commented Aug 5, 2019

This should not have been merged.

native_functions is something our OSS contributors frequently touch, and this modifies the behavior without any explanation or documentation that someone not looking at this PR can find. See https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/README.md.


gchanan commented Aug 5, 2019

A few other issues:

  1. The justification is still not clear (see e.g. @apaszke's comment above). I see comments like:

> I think there've been a few other reasons but the decision for overload names happened some time ago and I don't remember everything.

Updating our understanding of why we are doing the changes is necessary; we can't ever get rid of things if we don't remember why we did them in the first place.

  2. Why is there no error checking (that overloads are named)? Presumably we are going to add more overloads in the future; do we expect users to run your script?


smessmer commented Aug 5, 2019

I added documentation: #23844

Backends registering a new kernel for an overload want to use a shorthand syntax that doesn't require them to specify the full schema:

static auto registry = torch::RegisterOperators("my::operator", &kernel_func);

This looks identical to a backend adding a new overload:

static auto registry = torch::RegisterOperators("my::operator", &kernel_func);

The only difference being that in one case the overload was already registered before, in the other it wasn't.

Without overload names, there would be no way for us to know if they intended to add a new overload or intended to add a new kernel to an existing one, so we would have to hope the behavior is correct and couldn't error out.
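The ambiguity, and how an explicit overload name resolves it, can be shown with a toy registry. This is NOT the real c10 dispatcher API; the class and method names here are hypothetical, chosen only to make the intent-detection point concrete.

```python
class ToyRegistry:
    """Toy stand-in for an operator registry keyed by (op, overload_name)."""

    def __init__(self):
        # (op_name, overload_name) -> list of kernels for that overload
        self._kernels = {}

    def register(self, qualified_name, kernel):
        # "my::operator.out" -> op "my::operator", overload "out";
        # a bare "my::operator" names the default (unnamed) overload.
        op, _, overload = qualified_name.partition(".")
        key = (op, overload)
        if key in self._kernels:
            # Same (op, overload): unambiguously a new kernel (e.g. for a
            # different backend) for an existing overload.
            self._kernels[key].append(kernel)
        else:
            # New (op, overload): unambiguously a new overload.
            self._kernels[key] = [kernel]

registry = ToyRegistry()
registry.register("my::operator", lambda x: x)              # default overload
registry.register("my::operator", lambda x: x + 0)          # 2nd kernel, same overload
registry.register("my::operator.out", lambda x, out: out)   # a distinct overload
```

Without the overload-name part of the key, both intents collapse onto the same call shape, which is exactly the ambiguity described above.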

Marking overloads like @zdevito proposed above with an API like this:

static auto registry = torch::RegisterOperators("my::operator", &kernel_func,
    torch::AddNewOverload());

or

static auto registry = torch::RegisterOperators("my::operator", &kernel_func,
    torch::AddKernelForExistingOverload());

would be able to differentiate between these, but it would make the API harder to use, and it actually doesn't work because we don't know which static initializers are going to run first: C++ doesn't guarantee that the registration adding the overload runs before the registrations adding kernels for it.

Also, as mentioned, the decision to go with overload names happened when the c10 dispatcher was designed some months ago in design discussions with @zdevito, @dzhulgakov and many other people. Changing that now would require changing how the c10 dispatcher works. Let's avoid boiling the ocean.

Error checking is implemented in the c10 dispatcher. I'm working on a stack of PRs that registers all ATen ops with c10 and c10 will balk if there's an operator with non-unique overload names. This is actually the reason why this PR came now, long after we decided to go with overload names in c10: Before, these ops weren't in c10 and soon they're going to be.

Operators that don't have overloads don't need overload names. Nobody is expected to run my script. If, between now and the time I am able to add the c10 registration, people manage to add overloads and forget about overload names, I will fix them.

smessmer added a commit that referenced this pull request Aug 6, 2019
smessmer added a commit that referenced this pull request Aug 6, 2019
smessmer added a commit that referenced this pull request Aug 8, 2019

apaszke commented Aug 8, 2019

If you registered both functions and only said that they're an implementation of "my::op", then how would the dispatch system pick one over the other? This seems highly ambiguous to me and I don't understand the example.

> C++ doesn't guarantee that the registration adding the overload is run before registrations adding kernels for it.

This doesn't seem like a very strong argument. In particular you don't have to do the checking immediately when the overloads are registered, but you can put it off until a later time (e.g. after all stdlib operators are loaded or when an op is used for the first time).


zdevito commented Aug 8, 2019

We should address the documentation concerns for having overload names, but generally I am mildly in favor of having semantically named overloads. I think error checking with overload names is really a secondary justification for why having overload names is good.

The primary reason for having them is so that once we have resolved the overload (either dynamically in python, or statically in TorchScript), we can provide a unique name for the result. This makes a few things easier:

  • When doing RPC, we resolve the overload on the sender, and send the fully resolved overload name over the wire. This prevents having to recheck the overload on the other side of the wire. In TorchScript, it is not always possible to completely recover the static type of a value from its IValue, so it is cleaner to resolve the overload at compile time on the client and then just send a message with the full name.

  • For mobile, we are going to keep code size and model size down by only putting the interpreter bytecode in the model. At this point, type information has been erased so we will need to refer to a unique overload unambiguously without knowing the input types.

That said, in both cases, we can instead refer to the full schema, with arguments, to unambiguously name overloads. This seems more fragile to me but not in ways I completely understand yet.
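One concrete way the full-schema alternative is more fragile: adding a defaulted argument (which the thread agrees we will want to do) changes the full schema string but not the `name.overload_name` identifier. The schema strings below are toy examples in the style used elsewhere in this thread, not pulled from a specific version of native_functions.yaml.

```python
# Two versions of a schema, before and after adding a defaulted argument.
schema_v1 = "aten::add.Tensor(Tensor self, Tensor other) -> Tensor"
schema_v2 = "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"

def overload_id(schema):
    # Everything before the argument list: "namespace::name.overload_name".
    return schema.split("(", 1)[0]

# The name.overload identifier is stable across adding a default argument...
assert overload_id(schema_v1) == overload_id(schema_v2) == "aten::add.Tensor"
# ...while the full schema string is not.
assert schema_v1 != schema_v2
```

So a wire format or mobile bytecode keyed on the full schema breaks on such a change, whereas one keyed on `name.overload_name` does not.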

smessmer added a commit that referenced this pull request Aug 8, 2019
smessmer added a commit that referenced this pull request Aug 8, 2019
smessmer added a commit that referenced this pull request Aug 9, 2019
smessmer added a commit that referenced this pull request Aug 10, 2019
smessmer added a commit that referenced this pull request Aug 12, 2019
smessmer added a commit that referenced this pull request Aug 12, 2019
salexspb pushed a commit to salexspb/pytorch that referenced this pull request Aug 13, 2019
Summary:
Pull Request resolved: pytorch#23748

This extends the changes from pytorch#23532

ghstack-source-id: 88157704

Differential Revision: D16629907

fbshipit-source-id: ffcf937ec34a798a971e7d28ad85afb3b646d1fe
zdevito pushed a commit to zdevito/ATen that referenced this pull request Aug 13, 2019

bwasti commented Aug 29, 2019

This is a manual form of type mangling with zero restrictions on symbol choice. It doesn't seem maintainable long term.

Alternate solutions

There might be some less invasive ways to solve some of the problems I found mentioned in this conversation:

  1. Registration ambiguity
static auto registry = torch::RegisterOperators("my::operator", &kernel_func);

totally specifies the schema if implemented to do so: `std::string kernel_func(std::string) {}` is, at compile time, not the same as `Tensor kernel_func(Tensor) {}`. Perhaps the real issue is that the current registry casts all functions to void* instead of deducing and storing the schema? I'd recommend checking out how pybind11 is able to do this (and provide incredible error messages) without the user having to overload function names or even manually specify the schema.

  2. Sending fully resolved names over the wire

This seems like it would be well solved by a mangling scheme, maintained in the same way the IValue variant type is maintained. This would obviate the need to introduce changes to schemas and manually write custom tags for only certain ops.

  3. Mobile

I'm honestly not clear on what the issue is here. It seems like there may be perf concerns doing these lookups? I think we should benchmark that and, if worst comes to worst, do something like interning the schema strings to speed things up.
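The first alternative above (deducing the schema from the kernel's own signature, as pybind11 does with C++ templates) can be sketched with Python type hints standing in for template deduction. The function names and annotation style here are illustrative assumptions, not a real registration API.

```python
import inspect

def deduce_schema(op_name, kernel):
    # Read the kernel's signature and render it as a schema string, so the
    # caller never has to spell the schema out by hand.
    sig = inspect.signature(kernel)
    args = ", ".join(f"{p.annotation} {name}" for name, p in sig.parameters.items())
    return f"{op_name}({args}) -> {sig.return_annotation}"

def kernel(self: "Tensor", other: "Tensor") -> "Tensor":
    ...

print(deduce_schema("my::operator", kernel))
# -> my::operator(Tensor self, Tensor other) -> Tensor
```

Two kernels with different parameter types would thus deduce different schemas from the same registration call, which is the compile-time distinction the comment is pointing at.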

Other issues

With respect to this actual change, I think there are some issues that might arise long term.

  1. Won't we just see a ton of schemas that look like this?
schema.first_arg_type(first_arg_type arg_0, ...) -> ret_type
  2. How do collisions work? More generally, it isn't easy to predict expected behavior with this scheme because it doesn't seem to be a very standard approach.

  3. Will folks who only need the CPU version of PyTorch end up writing "broken" schemas that screw up mobile without realizing?
