
Tensor subclasses lose type when pickling #47051

Open
jph00 opened this issue Oct 29, 2020 · 7 comments
Labels
enhancement: Not as big of a feature, but technically not a bug. Should be easy to fix
module: serialization: Issues related to serialization (e.g., via pickle, or otherwise) of PyTorch objects
module: __torch_function__
triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments


jph00 commented Oct 29, 2020

🐛 Bug

Pickling and then unpickling a subclass of Tensor should result in an object of the same class as what was pickled. However, it doesn't: unpickling always produces a plain Tensor.

To Reproduce

The following prints torch.Tensor; however, it should print _T.

import pickle

import torch

class _T(torch.Tensor): ...

t = torch.tensor([1]).as_subclass(_T)
print(type(pickle.loads(pickle.dumps(t))))

Expected behavior

It should print _T.

Environment

torch.__version__ == 1.7.0, running on Ubuntu.

Additional context

Placing the following method in the base class resolves the problem:

    def __reduce_ex__(self, proto):
        torch.utils.hooks.warn_if_has_hooks(self)
        args = (type(self), self.storage(), self.storage_offset(), tuple(self.size()), self.stride())
        if self.is_quantized:
            args = args + (self.q_scale(), self.q_zero_point())
        f = _fa_rebuild_qtensor if self.is_quantized else _fa_rebuild_tensor
        return (f, args + (self.requires_grad, OrderedDict()))

Perhaps something like this could be used to fix this issue.
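For readers unfamiliar with the mechanics, the pattern behind this kind of fix can be shown without PyTorch at all: the reduce tuple carries the subclass type alongside the rebuild arguments, and a module-level rebuild function restores that type after reconstruction. A minimal pure-Python sketch, where `Base` is merely a stand-in for torch.Tensor and every name is illustrative rather than a PyTorch API:

```python
import pickle


class Base:
    """Stand-in for torch.Tensor: its default pickling rebuilds a plain Base."""
    def __init__(self, data):
        self.data = data

    def __reduce_ex__(self, proto):
        # Like Tensor's default reduce, this drops the subclass type.
        return (_rebuild_base, (self.data,))


def _rebuild_base(data):
    return Base(data)


def _rebuild_from_type(func, typ, args, state):
    # Rebuild via the base-class path, then restore the subclass type and state.
    obj = func(*args)
    obj.__class__ = typ  # analogous to Tensor.as_subclass(typ)
    obj.__dict__.update(state)
    return obj


class Sub(Base):
    def __reduce_ex__(self, proto):
        # Pass type(self) through so unpickling can restore the subclass.
        return (_rebuild_from_type,
                (_rebuild_base, type(self), (self.data,), self.__dict__))
```

The snippet above follows the same idea, with `as_subclass` playing the role of the `__class__` reassignment.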

cc @hameerabbasi @rgommers @mruberry

@albanD albanD added the enhancement, module: __torch_function__, module: serialization, and triaged labels Oct 29, 2020
@hameerabbasi
Collaborator

I guess pickle.loads/pickle.dumps isn't meant to be fast, so the proposed fix should be okay. Thoughts?

@jph00
Author

jph00 commented Oct 29, 2020

I originally wrote this for, IIRC, PyTorch 1.4, since we added our own subclassing support back then, when PyTorch didn't support it. At the time I largely copied, pasted, and refactored code from PyTorch to do so, so it shouldn't be any slower than PyTorch, unless things have changed since then.

@hameerabbasi
Collaborator

Yes, __reduce_ex__ seems to have changed significantly, so the proposed implementation is too simple to do what it needs to do.

@hameerabbasi
Collaborator

Being worked on in #47115

@jph00
Author

jph00 commented Nov 5, 2020

Thank you @hameerabbasi ! :)

facebook-github-bot pushed a commit that referenced this issue Feb 1, 2021
Summary:
Fixes #47051
Redo of #47115

Pull Request resolved: #47732

Reviewed By: izdeby

Differential Revision: D25465382

Pulled By: ezyang

fbshipit-source-id: 3a8d57281a2d6f57415d5735d34ad307f3526638
@aschuh-hf

aschuh-hf commented Jul 11, 2021

Hi, should pickling of torch.Tensor subclasses work without the subclass needing to do anything special, now that this issue is closed? I run into this very issue with PyTorch 1.8.1. Only after adding a __reduce_ex__ implementation to the subclass, as done by Jeremy in fastai, did pickling work as expected, i.e.,

from collections import OrderedDict

import torch

class TensorSubclass(torch.Tensor):
    ...

    def __reduce_ex__(self, proto):
        torch.utils.hooks.warn_if_has_hooks(self)
        args = (self.storage(), self.storage_offset(), tuple(self.size()), self.stride())
        if self.is_quantized:
            args = args + (self.q_scale(), self.q_zero_point())
        args = args + (self.requires_grad, OrderedDict())
        f = torch._utils._rebuild_qtensor if self.is_quantized else torch._utils._rebuild_tensor_v2
        return (_rebuild_from_type, (f, type(self), args, self.__dict__))


def _rebuild_from_type(func, type, args, dict):
    ret = func(*args).as_subclass(type)
    ret.__dict__ = dict
    return ret

(cf. https://github.com/fastai/fastai/blob/2af48aa4747c910283f679505bb5ef77a4b8eeec/fastai/torch_core.py#L325-L331)

Is there anything different I would need to do in __torch_function__? I see that torch.Tensor.__reduce_ex__ calls this function in its own implementation:

return handle_torch_function(Tensor.__reduce_ex__, relevant_args, self, proto)

If you could post a snippet of a proper __torch_function__ that also handles pickling, that would be great. I also think a section could be added to the PyTorch documentation at https://pytorch.org/docs/stable/notes/extending.html#extending-torch, which was very helpful in creating a custom torch.Tensor subclass, except for pickling support.

I see that the regression test for this issue uses a tensor subclass type which does not define __torch_function__, in which case it seems to be working fine because has_torch_function() is False. I was looking at the test for an example tensor subclass implementation that would support pickling according to the changes made as part of this issue.

I get the impression I run into this issue because my __torch_function__ looks as follows:

    def __torch_function__(self, func, types, args=(), kwargs=None):
        args = [arg.as_subclass(Tensor) if isinstance(arg, Tensor) else arg for arg in args]
        if kwargs is None:
            kwargs = {}
        result = func(*args, **kwargs)
        return result

i.e., it converts all args to plain torch.Tensor first. Without that conversion, I run into an infinite recursion.

Similarly, changing this to:

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if func == Tensor.__reduce_ex__:
            return func(*args, **kwargs)
        args = [arg.as_subclass(Tensor) if isinstance(arg, Tensor) else arg for arg in args]
        if kwargs is None:
            kwargs = {}
        result = func(*args, **kwargs)
        return result

leads to an infinite recursion.

It is also puzzling that both torch.Tensor.__reduce_ex__ and torch.Tensor._reduce_ex_internal check for the existence of __torch_function__ and call handle_torch_function. I would otherwise just call _reduce_ex_internal from my subclass's __reduce_ex__, but this is not possible.

Any suggestion on how to properly implement a subclass that supports pickling appreciated.

Thanks!
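One pattern that avoids the recursion described above is to run the underlying function with __torch_function__ dispatch disabled, rather than special-casing __reduce_ex__. This is a sketch, not necessarily the canonical answer: it assumes disabling dispatch is acceptable for all ops in your subclass, and it uses torch._C.DisableTorchFunction (older releases) or torch._C.DisableTorchFunctionSubclass (newer releases), whichever is available. On PyTorch versions that include the fix from this issue, pickling then round-trips the subclass type:

```python
import pickle

import torch

# The context manager's name varies across PyTorch releases; fall back
# to the older name when the newer one is absent.
_disable_tf = getattr(torch._C, "DisableTorchFunctionSubclass",
                      getattr(torch._C, "DisableTorchFunction", None))


class MyTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Dispatch is disabled inside this block, so calling func on a
        # MyTensor does not re-enter this hook (no infinite recursion).
        with _disable_tf():
            return func(*args, **kwargs)


t = torch.tensor([1.0, 2.0]).as_subclass(MyTensor)
t2 = pickle.loads(pickle.dumps(t))
```

Because the disabled-dispatch path falls through to Tensor's own __reduce_ex__, the reduce tuple is built with type(self) and the subclass survives unpickling without a hand-written __reduce_ex__.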

@mruberry mruberry reopened this Jul 15, 2021
@mruberry
Collaborator

> Hi, should pickling of torch.Tensor subclasses work without subclass code needing to do anything special after this issue was closed? […]

Reopening this issue until we have a chance to review your question, @aschuh-hf.
