Improved Tensor subclassing support, preserving subclasses on function/method calls #28361
Can you explain what exactly the delta between this proposal and #22402 is? Or is this just a pattern that subclasses of `Tensor` can use to implement extensions?
@ezyang Please see the update to the issue; I've added more details on how this can be made automatic, and I've also added an example use-case. It basically sets out how (if we allow …
@hameerabbasi I'd suggest editing the description some more. The relevant goals are: …
There's no delta; we just need an issue for this topic for discussion (and reporting) that's not mixed with the multi-topic gh-22402. That issue is basically implementable in three parts: …
OK, sgtm. @jph00 how does this look to you?
Thanks gang. I have no comment on the proposed implementation, but the goals look great. :) Although it's covered already implicitly by the stated goals, I should mention that we've had trouble getting …
Thanks @jph00, that's exactly the type of input we need.
It seems like variable-argument and no-argument methods do not parse `self` in their list of arguments: … and …
It might be good for the purposes of this issue to allow …
Also, would it be better to have an expected method on subclasses for the default …? Examples:

```
class TensorSubclass(Tensor):
    def __init__(self, *a, **kw):
        if len(a) == 1 and len(kw) == 0 and isinstance(a[0], torch.Tensor):
            ...  # Do conversion here
```

vs.

```
class TensorSubclass(Tensor):
    def __torch_wrap__(self, tensor):
        ...  # Do conversion here
```

@jph00 Thoughts?
It's hard to talk about a very specific implementation problem without seeing more about the planned implementation. In particular, why does `PythonArgParser` need `self`?
Here's my line of reasoning: …

The other option is to refactor/rewrite the logic separately for `self`, i.e.:

```
if type(self) is not Tensor and hasattr(self, '__torch_function__'):
    ...  # Process things here, separately.
```

Any opinions on which path to take?
I have a WIP branch up at https://github.com/Quansight/pytorch/tree/subclassing; I've added tests at https://github.com/Quansight/pytorch/blob/a68761ef8942041089d5da4815db07f020667260/test/test_subclassing.py
OK, let's talk about this for a moment. In the …
Okay, so my vision is the following: …
OK, this sounds plausible to me. However, it sounds like this is different from the way NumPy handles arrays in … Also, we have to be careful about this change, because if I define both …
On Tue, Nov 19, 2019, at 8:43 AM, Hameer Abbasi wrote:
> Can you compare this refinement to the Numpy approach?
`__array_function__` in NumPy doesn't apply to `ndarray` methods. NumPy, in order to handle subclassing behaviour, does `ret = ret.view(subclass)` at the end of every method, and then in addition calls `ret.__array_finalize__(self)` (assuming it exists).
This is also how fastai v2 works BTW - we call `retain_types()` at the end of `Transform.encodes` and various other places (automatically, in most cases).
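To make the NumPy mechanism described above concrete, here is a minimal illustration; the `UnitArray` class is invented for this example:

```
import numpy as np

class UnitArray(np.ndarray):
    def __new__(cls, data, unit=None):
        obj = np.asarray(data).view(cls)  # the ret.view(subclass) step
        obj.unit = unit
        return obj

    def __array_finalize__(self, obj):
        # NumPy calls this whenever a UnitArray is created from another
        # array (view casting, slicing, etc.), so attributes can survive.
        self.unit = getattr(obj, 'unit', None)

a = UnitArray([1.0, 2.0, 3.0], unit='m')
b = a[1:]  # new-from-template: still a UnitArray, unit preserved
print(type(b).__name__, b.unit)  # UnitArray m
```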
The option of using …
This actually doesn't work for PyTorch because `Tensor.view` behaves quite differently from `ndarray.view`. We had tests in the `__torch_function__` branch that used it (adapted from NumPy) but they didn't work, so we changed to only use the regular Python way of creating and instantiating a subclass.
> This actually doesn't work for PyTorch because `Tensor.view` behaves quite differently from `ndarray.view`. We had tests in the `__torch_function__` branch that used it (adapted from NumPy) but they didn't work so we changed to only use the regular Python way of creating and instantiating a subclass.
Maybe it could be called `.cast` instead?
On Tue, Nov 19, 2019, at 4:43 PM, Ralf Gommers wrote:
> Maybe it could be called `.cast` instead?
I find a number of discussions on the behavior of `view`, so I think it has been considered before and rejected. `cast` implies a dtype change I'd think rather than a shape change. `view` is basically equivalent to `reshape` in NumPy. Or `Tensor.reshape`, except that that also works for mismatching shapes.
Possibly one of us is misunderstanding something (and it could well be me!)
In numpy, `view` does exactly that: it's a dtype change, not a shape change. That's why I suggested `cast` as the name for the equivalent functionality in pytorch.
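A small demonstration of the difference under discussion:

```
import numpy as np
import torch

a = np.arange(4, dtype=np.int64)
print(a.view(np.int32))  # NumPy view: reinterpret the same bytes as a new
                         # dtype (ndarray.view can also cast to a subclass)

t = torch.arange(6)
print(t.view(2, 3))      # PyTorch view: same storage, new shape
```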
There's a bit more to …

EDIT: I'd use …
What is meant here by "dictionary-based dispatch"? I am a bit lost now.
Dictionary-based dispatch is where, inside … (see the sketch below).
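A rough sketch of the general pattern; the `HANDLED_FUNCTIONS`/`implements` names are borrowed from NumPy's `__array_function__` documentation rather than from this thread, and the `__torch_function__` signature shown is the one that eventually landed:

```
import torch

HANDLED_FUNCTIONS = {}

def implements(torch_function):
    "Register an override implementation for a given torch function."
    def decorator(func):
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator

class ScalarTensor:
    def __init__(self, value):
        self._value = value

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Dictionary-based dispatch: look the function up in a dict
        # instead of branching on it with if/elif chains.
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

@implements(torch.mean)
def mean(t):
    return float(t._value)
```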
@hameerabbasi I would suggest adding the …
Each "constraint" should be a separate test case. To make progress more easily, it may be useful to add a slow mechanism analogous to NumPy's Also, this mechanism is independent of public API changes like |
This will have the same composition issue, unfortunately. I pointed that out here: …
@ezyang Would you have an idea of what needs to be done for such a function, what data needs copying and what needs views and so on?
@hameerabbasi I'm actually not too sure what the precise semantics of … are.
I was hoping there would be a way to do it while keeping the same data pointer, whatever that entails, and also keeping any autograd data attached.
I believe the semantics should be exactly the same as replacing …
I think it's also helpful to have some special method that's called at this time if it exists - in fastai2, for instance, it's called …
I'm not sure how to do it. Let me give some information about how the PyObject is implemented in PyTorch, and maybe that gives you some information. The PyObject representing a Tensor looks like this: …

Every Variable also contains a `pyobj` field which points to the unique … Does that answer your question?
Somewhat. It seems to me that if, in true RAII fashion, …
This is how I updated the fastai2 implementation a couple of weeks ago:
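(Judging from the version quoted a few comments below, minus the `.copy()` that quoter added, the snippet was presumably along these lines:)

```
def as_subclass(self: Tensor, typ):
    "Cast `self` to type `typ`, sharing its data and its attributes."
    res = torch.Tensor._make_subclass(typ, self)
    if hasattr(self, '__dict__'): res.__dict__ = self.__dict__
    return res
```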
It seems to be working fine for us - but if we're missing something important, I'd love to know now so we can try to fix it! (And if we're not missing something important, is this a solution that pytorch can use too?)
I'm just going to go ahead and summarise the issue with …
```
def as_subclass(self: Tensor, typ):
    res = torch.Tensor._make_subclass(typ, self)
    if hasattr(self, '__dict__'): res.__dict__ = self.__dict__.copy()  ## I added the copy
    return res
```

Otherwise modifying any attribute on `res` would also modify it on `self` (unless that was the intention?)
That is absolutely the intention! :) A cast object should be a reference, not a copy. Note that this is *already* the behavior you see in `_make_subclass`:
```
import torch
from torch import Tensor, tensor

a = tensor([1,2,3])
class T(Tensor): pass
res = torch.Tensor._make_subclass(T, a)
res[1] = 5
print(res)
```
```
tensor([1, 5, 3])
```

It would be extremely confusing if cast objects acted as references when it came to their tensor data, but as copies when it came to their attributes.
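A small hypothetical example of those reference semantics extended to attributes, with the `__dict__` shared rather than copied:

```
import torch

class T(torch.Tensor): pass

a = torch.tensor([1, 2, 3])
a.tag = 'original'

res = torch.Tensor._make_subclass(T, a)
res.__dict__ = a.__dict__  # share attributes, mirroring the shared data
res.tag = 'changed'
print(a.tag)               # prints 'changed': a reference, not a copy
```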
So are you saying, instead of …
I mean that all methods that …
This is the first time you've mentioned … So, if I understand correctly, what you are proposing is that when you subclass `Tensor`, you are obligated to use a decorator, e.g., …
If this is the case, in what order do I end up calling these functions, if I have multiple subclasses, and …
So, for the faulty case, we would replace it with the following code: …
What happens is the following: …
I feel there is a step missing before …
I called …
So, think of …
One presentational note: we should probably call the code that … Let me see if I understand what you're saying correctly. Your proposal says: …
You use …
The way NumPy handles this is … (see the sketch below).
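For reference, NEP 18's ordering rule for `__array_function__` is: implementations are tried left to right over the distinct argument types, except that subclasses are always consulted before their superclasses. A rough sketch (class names invented):

```
import numpy as np

class A(np.ndarray):
    def __array_function__(self, func, types, args, kwargs):
        print('A consulted for', func.__name__)
        return super().__array_function__(func, types, args, kwargs)

class B(A):  # a subclass of A, so it gets consulted first
    def __array_function__(self, func, types, args, kwargs):
        print('B consulted for', func.__name__)
        return super().__array_function__(func, types, args, kwargs)

a = np.zeros(3).view(A)
b = np.zeros(3).view(B)
np.concatenate((a, b))  # 'B consulted' prints before 'A consulted'
```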
I'm also writing an RFC as requested.
@ezyang @jph00 First draft of the proposal is up: pytorch/rfcs#3
🚀 Feature
Related: #22402
This feature proposes passing through `Tensor` subclasses via `__torch_function__`.

Desired Behaviour
Example desired behavior would be:
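A minimal illustration of the kind of behaviour being asked for; `MyTensor` is a hypothetical subclass, and the cast uses `Tensor.as_subclass`, which landed later:

```
import torch

class MyTensor(torch.Tensor):
    pass

t = torch.tensor([1., 2., 3.]).as_subclass(MyTensor)

# Desired: functions and methods preserve the subclass rather than
# decaying to a plain torch.Tensor.
assert isinstance(t + 1, MyTensor)
assert isinstance(torch.sin(t), MyTensor)
assert isinstance(t.reshape(3, 1), MyTensor)
```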
Goals
Quoting #22402: …

Additionally, from #28361 (comment): …
Rough Sketch of Implementation
Anything with a type like a built-in tensor will bypass `__torch_function__` via the fast path (although it will have a default implementation), but anything else defined by an external library will have the option to allow it. The following code snippet shows what the default `__torch_function__` on `TensorBase` would look like (see the sketch below).

cc @ezyang @gchanan @zou3519 @jerryzh168 @jph00 @rgommers
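A rough sketch of such a default, written against the `__torch_function__` API as it eventually landed rather than the 2019 draft, and handling only single-Tensor results:

```
import torch

class TensorBase(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Run the underlying op with overrides disabled to avoid recursion.
        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
        # Re-wrap plain Tensor results in the calling subclass.
        if type(ret) is torch.Tensor:
            ret = ret.as_subclass(cls)
        return ret
```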