Skip to content

Commit

Permalink
Fix pickling for Tensor subclasses (redo) (#47732)
Browse files Browse the repository at this point in the history
Summary:
Fixes #47051
Redo of #47115

Pull Request resolved: #47732

Reviewed By: izdeby

Differential Revision: D25465382

Pulled By: ezyang

fbshipit-source-id: 3a8d57281a2d6f57415d5735d34ad307f3526638
  • Loading branch information
hameerabbasi authored and facebook-github-bot committed Feb 1, 2021
1 parent 508bab4 commit b1907f5
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 0 deletions.
9 changes: 9 additions & 0 deletions test/test_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import inspect
import functools
import pprint
import pickle

from torch.testing._internal.common_utils import TestCase, run_tests
from torch.overrides import (
Expand Down Expand Up @@ -839,6 +840,14 @@ def test_newones(self):
n = t.new_ones((1, 2))
self.assertEqual(type(n), SubTensor2)

class TestPickle(TestCase):
"Regression test for gh-47051"
def test_pickle(self):
t = torch.tensor([1]).as_subclass(SubTensor2)
t.abcd = "e"
t2 = pickle.loads(pickle.dumps(t))
self.assertIs(type(t2), SubTensor2)
self.assertEqual(t2.abcd, "e")

class TestBroadcastAllOverride(TestCase):
""" test for gh-37141 """
Expand Down
1 change: 1 addition & 0 deletions torch/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def get_ignored_functions() -> Set[Callable]:
Tensor._make_subclass,
Tensor.stride,
Tensor.unflatten,
Tensor._reduce_ex_internal,
}


Expand Down
18 changes: 18 additions & 0 deletions torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ def wrapped(*args, **kwargs):
return NotImplemented
return wrapped

def _rebuild_from_type(func, type, args, dict):
if type is Tensor:
return func(*args)

ret = func(*args).as_subclass(type)
ret.__dict__ = dict
return ret


# NB: If you subclass Tensor, and want to share the subclassed class
# across processes, you must also update torch/multiprocessing/reductions.py
Expand Down Expand Up @@ -83,6 +91,16 @@ def __deepcopy__(self, memo):
return new_tensor

def __reduce_ex__(self, proto):
if type(self) is Tensor:
return self._reduce_ex_internal(proto)
relevant_args = (self,)
from torch.overrides import has_torch_function, handle_torch_function
if type(self) is not Tensor and has_torch_function(relevant_args):
return handle_torch_function(Tensor.__reduce_ex__, relevant_args, self, proto)
func, args = self._reduce_ex_internal(proto)
return (_rebuild_from_type, (func, type(self), args, self.__dict__))

def _reduce_ex_internal(self, proto):
if has_torch_function_unary(self):
return handle_torch_function(Tensor.__reduce_ex__, (self,), self, proto)
check_serializing_named_tensor(self)
Expand Down

0 comments on commit b1907f5

Please sign in to comment.