
Improved Tensor subclassing support, preserving subclasses on function/method calls #28361

Closed · hameerabbasi opened this issue Oct 21, 2019 · 74 comments

Labels: feature (A request for a proper, new feature.) · high priority · module: numpy (Related to numpy support, and also numpy compatibility of our operators) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@hameerabbasi (Collaborator) commented Oct 21, 2019

🚀 Feature

Related: #22402

This feature proposes passing through Tensor subclasses via __torch_function__.

Desired Behaviour

Example desired behavior would be:

class MyTensor(torch.Tensor):
    _additional_attribute = "Kartoffel"

a = MyTensor([0, 1, 2, 3])
# b should be a MyTensor object, with all class attributes passed through,
# for any torch function (torch.mul here is just an example):
b = torch.mul(a, 2)

Goals

Quoting #22402

These are potential goals that have been collected from the above referenced PRs, other PyTorch issues (referenced in the relevant sections), and from discussions mainly with Edward Yang, as well as with other PyTorch and NumPy maintainers:

  1. Support subclassing torch.Tensor in Python
  2. Preserve Tensor subclasses when calling torch functions on them
  3. Preserve Tensor subclasses when calling numpy functions on them
  4. Use the NumPy API with PyTorch tensors (i.e. NumPy API calls dispatch to torch functions)
  5. Use the PyTorch API with torch.Tensor-like objects that are not Tensor subclasses
  6. Reuse NumPy ufunc implementations directly from PyTorch
  7. Allow operations on mixed array types, e.g. tensor + ndarray

Additionally, from #28361 (comment):

  • Preserve Tensor subclasses when calling Tensor methods
  • Propagate subclass instances correctly also through operators, views, slices, etc.

Rough Sketch of Implementation

Anything typed as a built-in tensor will bypass __torch_function__ via the fast path (although built-in tensors will still have a default implementation), but anything defined by an external library will have the option to opt in.

The following code snippet shows what the default __torch_function__ on TensorBase would look like.

class Tensor:
    def __torch_function__(self, func, types, args, kwargs):
        # Defer if any of the argument types is not a Tensor subclass.
        if not all(issubclass(t, TensorBase) for t in types):
            return NotImplemented
        # Call the original (wrapped) implementation, then re-wrap the
        # result in the subclass of self.
        result = func._wrapped(*args, **kwargs)
        return type(self)(result)

cc @ezyang @gchanan @zou3519 @jerryzh168 @jph00 @rgommers

@ezyang (Contributor) commented Oct 21, 2019

Can you explain what exactly the delta between this proposal and #22402 is? Or is this just a pattern that subclasses of Tensor can use to implement extensions?

@hameerabbasi (Collaborator, Author)

@ezyang Please see the update to the issue, I've added more details on how this can be made automatic. I've also added an example use-case.

It basically sets out how, if we allow __torch_function__ on subclasses, a simple extension lets us create a default __torch_function__ that makes passing subclasses through automatic.

@rgommers (Collaborator)

@hameerabbasi I'd suggest editing the description some more. The relevant goals are:

  • Support subclassing torch.Tensor in Python
  • Preserve Tensor subclasses when calling torch functions on them
  • Preserve Tensor subclasses when calling Tensor methods
  • Propagating subclass instances correctly also with operators, using views/slices/etc.

Can you explain what exactly the delta between this proposal and #22402 is?

There's no delta; we just need an issue for this topic for discussion (and reporting) that's not mixed with the multi-topic gh-22402. That issue is basically implementable in three parts: __torch_function__ (close to ready for review), this subclassing topic (just started), and NumPy protocol support (lowest priority, not started).

@ezyang (Contributor) commented Oct 21, 2019

OK, sgtm. @jph00 how does this look to you?

@jph00 commented Oct 21, 2019

Thanks gang. I have no comment on the proposed implementation, but the goals look great. :)

Although it's already covered implicitly by the stated goals, I should mention that we've had trouble getting __getitem__ working correctly in subclasses - so this might be something to make sure you test carefully. E.g. be sure to test passing tensors of various types as indices, including bool mask tensors and subclasses.
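A minimal sketch of the kind of indexing tests this suggests (MyTensor and the use of the private Tensor._make_subclass are illustrative; whether these assertions hold is exactly what this issue is about):

import torch

class MyTensor(torch.Tensor):
    pass

t = torch.Tensor._make_subclass(MyTensor, torch.arange(6.0))

# Indexing with various index types should preserve the subclass.
assert isinstance(t[1:4], MyTensor)                   # slice
assert isinstance(t[torch.tensor([0, 2])], MyTensor)  # integer tensor index
assert isinstance(t[t > 2], MyTensor)                 # bool mask tensor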

@rgommers (Collaborator)

Thanks @jph00, that's exactly the type of input we need.

@fmassa added the feature, high priority, module: numpy, and triaged labels Oct 23, 2019
@rgommers changed the title from "Pass through subclasses via __torch_function__." to "Better Tensor subclassing support" Oct 28, 2019
@rgommers changed the title from "Better Tensor subclassing support" to "ImprovedTensor subclassing support, preserving subclasses on function/method calls" Oct 28, 2019
@hameerabbasi (Collaborator, Author)

It seems that variable-argument and no-argument methods do not include self in their list of parsed arguments:

https://github.com/prasunanand/pytorch/blob/torch_function/tools/autograd/gen_python_functions.py#L101-L110

and

https://github.com/prasunanand/pytorch/blob/torch_function/tools/autograd/gen_python_functions.py#L73-L89

PythonArgParser also does not take self: https://github.com/prasunanand/pytorch/blob/torch_function/torch/csrc/utils/python_arg_parser.h#L102

It might be good for the purposes of this issue to allow self as an argument to PythonArgParser. However, I'm not sure what the overhead of parsing an argument is.

@hameerabbasi (Collaborator, Author)

Also, would it be better to have an expected method on subclasses for the default __torch_function__, à la __array_wrap__ (called __torch_wrap__ instead, mirroring NumPy), or to just call the default constructor with the output tensor?

Examples:

class TensorSubclass(Tensor):
    def __init__(self, *args, **kwargs):
        if len(args) == 1 and not kwargs and isinstance(args[0], torch.Tensor):
            ...  # Do conversion here

vs

class TensorSubclass(Tensor):
    def __torch_wrap__(self, tensor):
        ...  # Do conversion here

@jph00 Thoughts?

@hameerabbasi changed the title from "ImprovedTensor subclassing support, preserving subclasses on function/method calls" to "Improved Tensor subclassing support, preserving subclasses on function/method calls" Nov 15, 2019
@ezyang (Contributor) commented Nov 15, 2019

It's hard to talk about a very specific implementation problem without seeing more about the planned implementation. In particular, why does PythonArgParser need self?

@hameerabbasi (Collaborator, Author)

Here's my line of reasoning:

  1. We need a default __torch_function__ on Tensor (see issue description), which has to be applied to self as well.
  2. For this to work, __torch_function__ needs to apply to methods as well.
  3. self needs to be in the list of parsed arguments, because most of the __torch_function__ logic is inside PythonArgParser.

The other option is to refactor/rewrite the logic separately for self, i.e.

if type(self) is not Tensor and hasattr(self, '__torch_function__'):
    ...  # Process things here, separately.

Any opinions on which path to take?

@ezyang (Contributor) commented Nov 18, 2019

For this to work, __torch_function__ needs to apply to methods as well.

OK, let's talk about this for a moment. In the __torch_function__ PR I raised the question of whether it would make sense to have some sort of magic method for overriding both functions and methods, but we decided it was out of scope. Let's drop the question of the default tensor function preserving subclasses for a moment, and ask a simpler question: how exactly does the extension of __torch_function__ to support methods work?

@hameerabbasi (Collaborator, Author)

how exactly does the extension of __torch_function__ to support methods work?

Okay, so my vision is the following: __torch_function__ has the signature (func, args, kwargs) (from the previous PR). In traditional Python style, if it's called on a Tensor method, then func will be the method itself e.g. Tensor.__add__, and args/kwargs would also contain self, in addition to the other explicitly passed-in arguments. In this example, args will contain both self and other.
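A plain-Python sketch of that calling convention (all names here are illustrative stand-ins, not real torch machinery):

class Base:
    def __add__(self, other):
        # Under the proposal, a method call is routed through
        # __torch_function__ with the method itself as func, and with
        # self included in args alongside the explicit arguments.
        return self.__torch_function__(Base.__add__, (self, other), {})

    def __torch_function__(self, func, args, kwargs):
        print(f"dispatching {func.__qualname__}; args (including self): {args}")
        return "result"

class Sub(Base):
    pass

Sub() + Sub()  # prints: dispatching Base.__add__; args (including self): (...)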

@ezyang (Contributor) commented Nov 19, 2019

OK, this sounds plausible to me. However, it sounds like this is different from the way Numpy handles arrays in __array_function__. Can you compare this refinement to the Numpy approach?

Also, we have to be careful about this change because if I define both def __add__ and def __torch_function__, which one "wins"?

@hameerabbasi (Collaborator, Author)

Can you compare this refinement to the Numpy approach?

__array_function__ in NumPy doesn't apply to ndarray methods. NumPy, in order to handle subclassing behaviour, does ret = ret.view(subclass) at the end of every method, and then in addition calls ret.__array_finalize__(self) (assuming it exists).
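For reference, a minimal demonstration of that NumPy mechanism (__array_finalize__ is a real NumPy hook; the info attribute is illustrative):

import numpy as np

class MyArray(np.ndarray):
    def __array_finalize__(self, obj):
        # NumPy calls this whenever a MyArray is created from a template
        # (views, slices, ufunc results), so attributes can propagate.
        self.info = getattr(obj, "info", None)

a = np.arange(4).view(MyArray)
a.info = "hello"
b = a[1:]  # slicing creates a new MyArray via __array_finalize__
assert isinstance(b, MyArray) and b.info == "hello"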

Also, we have to be careful about this change because if I define both def __add__ and def __torch_function__, which one "wins"?

__add__ wins: under Python's method resolution order (__mro__), subclasses come before superclasses. NumPy has the same problem and the same model.
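A small illustration of that resolution order, with stand-in classes:

class Tensor:                        # stand-in for torch.Tensor
    def __add__(self, other):
        return "routed through __torch_function__"

class Sub(Tensor):
    def __add__(self, other):        # found first in Sub.__mro__
        return "Sub.__add__ wins"

assert Sub() + Sub() == "Sub.__add__ wins"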

@jph00 commented Nov 19, 2019 via email

@rgommers (Collaborator)

The option of using __array_finalize__ was discussed in gh-22402; the issue is that it's slow.

NumPy, in order to handle subclassing behaviour, does ret = ret.view(subclass) at the end of every method

This actually doesn't work for PyTorch because Tensor.view behaves quite differently from ndarray.view. We had tests in the __torch_function__ branch that used it (adapted from NumPy), but they didn't work, so we changed to only using the regular Python way of creating and instantiating a subclass.

@jph00 commented Nov 19, 2019 via email

@rgommers (Collaborator)

Maybe it could be called .cast instead?

I found a number of discussions on the behavior of view, so I think it has been considered before and rejected. I'd think cast implies a dtype change rather than a shape change. view is basically equivalent to reshape in NumPy, or to Tensor.reshape, except that the latter also works for mismatching shapes.

@jph00 commented Nov 20, 2019 via email

@rgommers (Collaborator) commented Nov 20, 2019

There's a bit more to view in NumPy:

>>> import numpy as np                                                                         
>>> class subarray(np.ndarray): 
...     newattr = "I'm here!" 
...                                                                                            
>>> x = np.arange(4)                                                                           
>>> x.view(subarray)                                                                           
subarray([0, 1, 2, 3])
>>> y = x.view(subarray)                                                                       
>>> isinstance(y, subarray)                                                                    
True
>>> y.newattr                                                                                  
"I'm here!"

EDIT: I'd use astype for a dtype change

@ezyang (Contributor) commented Dec 8, 2019

What is meant here by "dictionary-based dispatch"? I am a bit lost now.

@hameerabbasi (Collaborator, Author)

Dictionary-based dispatch is where, inside __torch_function__, one uses a dictionary keyed on func to look up the implementation of the function. This will fail because B.__add__ is not Tensor.__add__: if a class dispatches on the latter, the former won't be found in the dict. But I claim this is correct behavior, because using cls.__add__ would produce the correct behaviour. If a class is B, or one of its descendants that doesn't override __add__, then B.__add__ is the correct method to use, and looking up Tensor.__add__ is incorrect anyway.
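A minimal sketch of the dictionary-based pattern being described, using the (func, args, kwargs) signature proposed in this thread (HANDLED_FUNCTIONS, implements, and add_impl are illustrative names):

import torch

HANDLED_FUNCTIONS = {}

def implements(torch_func):
    def decorator(impl):
        HANDLED_FUNCTIONS[torch_func] = impl
        return impl
    return decorator

class B(torch.Tensor):
    def __torch_function__(self, func, args, kwargs):
        if func not in HANDLED_FUNCTIONS:
            # If dispatch arrives keyed on B.__add__ but registration used
            # torch.Tensor.__add__, this lookup misses: they are different
            # function objects.
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **(kwargs or {}))

@implements(torch.Tensor.__add__)
def add_impl(self, other):
    ...  # implementation for addition goes here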

@rgommers (Collaborator)

@hameerabbasi I would suggest adding the __add__ interaction with multiple subclasses to the test cases in your branch. The discussion is really hard to follow like this; I'd like to be able to figure out more easily whether this is a showstopper or a corner case.

A quick meta point: if we can't think of a good way (not a "way around") to do this, we should stop doing this, or change our underlying constraints until there is a good way.

Each "constraint" should be a separate test case.

To make progress more easily, it may be useful to add a slow mechanism analogous to NumPy's __array_finalize__ that meets all the constraints, and then assess what goes wrong if it's replaced with something faster (whether metaclass or __torch_function__ based or other).

Also, this mechanism is independent of public API changes like as_subclass, so it would be useful to look at those as well - they shouldn't need changes afterwards.

@hameerabbasi (Collaborator, Author)

To make progress more easily, it may be useful to add a slow mechanism analogous to NumPy's __array_finalize__ that meets all the constraints, and then assess what goes wrong if it's replaced with something faster (whether metaclass or __torch_function__ based or other).

This will have the same composition issue, unfortunately. I pointed that out here:

Okay, the other alternative here is to use __torch_wrap__ and __torch_finalize__. What these two protocols do is essentially pre- and post-processing when "wrapping into a subclass".

However, cautionary note: These have exactly the same problem with super that we just discussed (i.e. things will be processed in the wrong order in your example).

Also, this mechanism is independent of public API changes like as_subclass, so it would be useful to look at those as well - they shouldn't need changes afterwards.

@ezyang Would you have an idea of what needs to be done for such a function, what data needs copying and what needs views and so on?

@ezyang (Contributor) commented Dec 10, 2019

@hameerabbasi I'm actually not too sure what the precise semantics of as_subclass are (yes, I know it views the tensor as a subclass, but this is frustratingly vague). For starters, does it call the constructor of the subclass?

@hameerabbasi (Collaborator, Author) commented Dec 10, 2019

I was hoping there would be a way to do it while keeping the same data pointer, whatever that entails, and also keeping any autograd data attached.

@jph00 commented Dec 10, 2019

I believe the semantics should be exactly the same as replacing __class__ in a regular python object. That is also the behavior of view() in numpy, I believe. Which is to say:

  • All state, including in __dict__, is preserved
  • __init__ is not called
  • type() will return the new type, and method dispatch will use that type's methods in the usual python way (including metaclass dispatch, if a metaclass is defined)

I think it's also helpful to have some special method that's called at this time if it exists - in fastai2, for instance, it's called __after_cast__.
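A small plain-Python demonstration of those semantics (class names are illustrative):

class A:
    def __init__(self):
        self.state = 42

class B(A):
    def hello(self):
        return f"B with state {self.state}"

obj = A()
obj.__class__ = B            # no __init__ call; __dict__ (state) is preserved
assert type(obj) is B        # type() reports the new type
assert obj.hello() == "B with state 42"  # dispatch uses B's methods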

@ezyang (Contributor) commented Dec 11, 2019

I'm not sure how to do it. Let me give some information about how PyObject is implemented in PyTorch and maybe that gives you some information.

The PyObject representing Tensor looks like this:

// Python object that backs torch.autograd.Variable
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct THPVariable {
    PyObject_HEAD
    // Payload
    torch::autograd::Variable cdata;
    // Hooks to be run on backwards pass (corresponds to Python attr
    // '_backwards_hooks', set by 'register_hook')
    PyObject* backward_hooks = nullptr;
};

Every Variable also contains a pyobj field which points to the unique PyObject object representing the tensor. This ensures that C++ object identity and Python object identity coincide.

Does that answer your question?

@hameerabbasi (Collaborator, Author)

Does that answer your question?

Somewhat. It seems to me that if, in true RAII fashion, cdata is actually copied on assignment/copy construction, we'll need a way to shallow-copy it or change it to a pointer, but that's a pretty invasive change. Other than that, we can mostly just copy all the fields, as well as shallow-copying __dict__.

@jph00 commented Dec 19, 2019

This is how I updated the fastai2 implementation a couple of weeks ago:

def as_subclass(self:Tensor, typ):
    res = torch.Tensor._make_subclass(typ, self)
    if hasattr(self,'__dict__'): res.__dict__ = self.__dict__
    return res

It seems to be working fine for us - but if we're missing something important, I'd love to know now so we can try to fix it! (And if we're not missing something important, is this a solution that pytorch can use too?)
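A quick usage sketch, assuming the as_subclass above is in scope (MyTensor is illustrative; _make_subclass is a private PyTorch API):

import torch

class MyTensor(torch.Tensor):
    pass

t = torch.ones(3)
m = as_subclass(t, MyTensor)
assert isinstance(m, MyTensor)
assert m.data_ptr() == t.data_ptr()  # same storage: a view, not a copy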

@hameerabbasi (Collaborator, Author)

I'm just going to go ahead and summarise the issue with __torch_function__ for methods as well as __torch_finalize__, and then talk about my preference and my take on @ezyang's composability problem.

__torch_function__ for methods (and the problem with super)

Consider the following code (__torch_function__ for methods will just pass self as the first argument).

class SubclassA(torch.Tensor):
    def __add__(self, other):
        # Pre-process self/other into self_transformed/other_transformed,
        # then invoke the parent implementation on the transformed operands.
        temp_result = super(SubclassA, self_transformed).__add__(other_transformed)
        # Post-process temp_result into final_result
        return final_result

class SubclassB(SubclassA):
    def __torch_function__(self, func, args, kwargs):
        # Pre-process args/kwargs into args_transformed/kwargs_transformed.
        temp_result = super().__torch_function__(func, args_transformed, kwargs_transformed)
        # Post-process temp_result
        return temp_result

Now, consider what happens when we add an instance of SubclassB with another such instance.

Since __add__ is inherited from SubclassA, control flow goes there first instead of to SubclassB's __torch_function__. What happens, concretely, in my current proposal, is:

  • self/other are transformed by SubclassA.__add__. Hopefully, if nothing too weird happens, the transformations preserve the class (SubclassB in this case).
  • Since self_transformed/other_transformed are instances of SubclassB, the call to super goes to Tensor.__torch_function__, which by default does exactly the same as Tensor.__add__, and returns the result.
  • We then transform temp_result, and pass it back to SubclassA.__add__.
  • SubclassA does the final transformations and then returns the result.

The issue with this is the following: There is an inversion of control. SubclassB.__torch_function__ should be the one controlling the execution flow, but it isn't.

During a previous call, @ezyang and I discussed the following solution: Add a default __add__ to SubclassB (perhaps via metaclasses) that dispatches directly to SubclassB.__torch_function__.

I would like to propose the flip side of this, which has the benefit of making everything behave exactly as Tensor behaves. We could possibly even make Tensor itself work this way, were it not for the constraint of avoiding performance regressions:

Make all implementations of methods on subclasses also go through __torch_function__ by default.

__torch_finalize__ and the problem with super

Here the problem still exists, though it is less exacerbated. The inversion of control remains, but __torch_finalize__ (as the name implies) only finalizes the result (based on one of the inputs of that type) and performs no pre-processing.

as_subclass

I believe @ezyang can talk more about how this is okay or not, but I see at least one problem with it:

def as_subclass(self:Tensor, typ):
    res = torch.Tensor._make_subclass(typ, self)
    if hasattr(self,'__dict__'): res.__dict__ = self.__dict__.copy() ## I added the copy
    return res

Otherwise modifying any attribute on res would also modify it on self (unless that was the intention?)

@jph00 commented Jan 7, 2020 via email

@ezyang (Contributor) commented Jan 8, 2020

Make all implementations of methods on subclasses also go through __torch_function__ by default.

So are you saying, instead of super().__add__ being a valid way to call the parent implementation, you call __torch_function__? Or is this something else? (I apologize if you already described this above but the conversation is pretty long. It might be a good idea to edit the top message with the most up to date proposal for easy access.)

@hameerabbasi (Collaborator, Author)

I mean that all methods that Tensor already has would go through __torch_function__, even for subclasses. Concretely, in the example above, SubclassA.__add__ will be automatically decorated with @torch_function_dispatch, and we will recommend all subclasses do the same. This would have the desired effect of making super().__add__ go through super().__torch_function__.

@ezyang (Contributor) commented Jan 9, 2020

This is the first time you've mentioned torch_function_dispatch in this issue. :)

So, if I understand correctly, what you are proposing is that when you subclass tensor, you are obligated to use a decorator, e.g.,

class SubclassA(Tensor):
    @torch_function_dispatch
    def __add__(self, other):
        ...
        super().__add__(other)

If this is the case, in what order do I end up calling these functions, if I have multiple subclasses, and __torch_function__ and __add__ defined in both cases? I am still not completely understanding your proposal. It would be helpful if you could post more fleshed out example code, and walk me through what happens in these cases.

@hameerabbasi (Collaborator, Author) commented Jan 14, 2020

So, for the faulty case, we would replace it with the following code:

def _add_dispatcher(self, other):
    return self, other

class SubclassA(torch.Tensor):
    @torch_function_dispatch(_add_dispatcher)
    def __add__(self, other):
        # Pre-process self/other into self_transformed/other_transformed,
        # then invoke the parent implementation on the transformed operands.
        temp_result = super(SubclassA, self_transformed).__add__(other_transformed)
        # Post-process temp_result into final_result
        return final_result

class SubclassB(SubclassA):
    def __torch_function__(self, func, args, kwargs):
        # Pre-process args/kwargs into args_transformed/kwargs_transformed.
        temp_result = super().__torch_function__(func, args_transformed, kwargs_transformed)
        # Post-process temp_result
        return temp_result

What happens is the following:

  1. We would have a default implementation for each class for __torch_function__.
  2. Code would dispatch to an implementation if available otherwise the default.
  3. Suppose x.__add__ is called where type(x) is SubclassB. It'll hit SubclassA.__add__.
  4. SubclassA.__add__ would realise that classes other than SubclassA and its superclasses are present in the list of arguments, so it'll try self.__torch_function__ and then other.__torch_function__.
  5. So code would go through SubclassB.__torch_function__
  6. Transformations would take the form of t.as_subclass(SubclassA)
  7. When super().__torch_function__ is called it would dispatch to SubclassA.__add__, as appropriate.
  8. Control is passed back to SubclassB.
  9. Post-processing happens and result is returned.
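To make the control flow concrete, here is a self-contained plain-Python simulation of these steps. All names are illustrative stand-ins; there is no real torch machinery, and the inline overrides check plays the role of the dispatch decorator:

class Tensor:
    def __init__(self, value):
        self.value = value

    def as_subclass(self, cls):
        return cls(self.value)           # step 6: the transformation

    def __torch_function__(self, func, args, kwargs):
        return func(*args, **kwargs)     # default: run the implementation

class SubclassA(Tensor):
    def __add__(self, other):            # step 3 lands here
        def overrides(arg):
            # "Classes other than SubclassA and its superclasses" (step 4).
            t = type(arg)
            return t is not SubclassA and not issubclass(SubclassA, t)

        if any(overrides(arg) for arg in (self, other)):
            # Step 5: hand control to the most specific __torch_function__.
            return self.__torch_function__(SubclassA.__add__, (self, other), {})
        return Tensor(self.value + other.value)  # the actual addition

class SubclassB(SubclassA):
    def __torch_function__(self, func, args, kwargs):
        # Step 6: cast arguments down to SubclassA...
        cast = tuple(a.as_subclass(SubclassA) for a in args)
        # Step 7: ...so the super() call re-enters SubclassA.__add__ cleanly.
        temp = super().__torch_function__(func, cast, kwargs)
        # Steps 8-9: post-process and hand the result back.
        return temp.as_subclass(SubclassB)

x, y = SubclassB(1), SubclassB(2)
out = x + y
assert type(out) is SubclassB and out.value == 3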

@ezyang (Contributor) commented Jan 14, 2020

I feel there is a step missing before

  5. So code would go through SubclassB.__torch_function__

I called x.__add__() where x is a SubclassB. By normal Python resolution rules I'll hit SubclassA.__add__ when this happens. Are you saying the dispatch decorator will pass control to SubclassB.__torch_function__? I'm still not sure how this would work.

@hameerabbasi (Collaborator, Author)

So, think of SubclassA.__add__... It will follow the __torch_function__ protocol. When it realises that there are classes other than SubclassA and its superclasses present in the list of arguments, it'll try self.__torch_function__ and then other.__torch_function__. Since you mentioned self is SubclassB, it'll hit SubclassB.__torch_function__.

@ezyang (Contributor) commented Jan 16, 2020

One presentational note: we should probably call the code behind torch_function_dispatch something distinct from __torch_function__, since it is not the same code at all. I'll call this the "Python dispatcher" for now.

Let me see if I understand what you're saying correctly. Your proposal says:

  1. Whenever a user calls a method on a Tensor class, we always transfer control to the Python dispatcher first. All built-in methods on Tensor have this functionality, and any explicitly overridden methods on Tensor arrange for this transfer of control via a mandatory decorator (what happens if the user forgets to add this decorator?)
  2. Once we are in the Python dispatcher, we need to transfer control to the correct user-defined method or __torch_function__ implementation. Similar to how __torch_function__ operates from ngoldbaum's PR, we make a decision about the most specific class, and then attempt to invoke the corresponding method in the class (if it exists), or the __torch_function__ on that class.

You use super() in your example, but with my recap above I don't see how super can work. A super call will transfer control back to the Python dispatcher, but the Python dispatcher needs to know this time around that we have already "finished" with the most specific class, and we should do something higher in the class hierarchy, but I don't see how you can know that, in the proposal.

@hameerabbasi (Collaborator, Author)

You use super() in your example, but with my recap above I don't see how super can work. A super call will transfer control back to the Python dispatcher, but the Python dispatcher needs to know this time around that we have already "finished" with the most specific class, and we should do something higher in the class hierarchy, but I don't see how you can know that, in the proposal.

The way NumPy handles this is with a types argument in __array_function__: subclasses remove "themselves" from types before calling super().
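A minimal sketch of that NumPy pattern (__array_function__ and its (func, types, args, kwargs) signature are real NumPy; MyArray and its pre/post-processing are illustrative):

import numpy as np

class MyArray(np.ndarray):
    def __array_function__(self, func, types, args, kwargs):
        # Pre-process args/kwargs here if needed, then remove ourselves
        # from types so the super() call knows this class is already
        # handled and will not hand control straight back to us.
        remaining = tuple(t for t in types if t is not MyArray)
        result = super().__array_function__(func, remaining, args, kwargs)
        # Post-process result here if needed.
        return result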

@hameerabbasi (Collaborator, Author)

I'm also writing an RFC as requested.

@hameerabbasi (Collaborator, Author) commented Jan 24, 2020

@ezyang @jph00 First draft of the proposal is up. pytorch/rfcs#3
