Avoid raising warnings for internal TypedStorage usage
kurtamohler committed Oct 4, 2022
1 parent 895f95c commit 9f654a8
Showing 26 changed files with 403 additions and 140 deletions.
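At a glance, the pattern this commit establishes: public TypedStorage entry points keep emitting the deprecation warning, while underscore-prefixed counterparts (and constructors invoked with _internal=True) stay silent, so PyTorch's own call sites can use them freely. Below is a minimal, torch-free sketch of that split; the class and helper names here are invented for illustration and are not the actual torch.storage code.

import warnings

def _warn_deprecated():
    # Public entry points funnel through one helper; internal callers skip it.
    warnings.warn("TypedStorage is deprecated", UserWarning, stacklevel=3)

class _DemoStorage:
    def __init__(self, size=0, *, _internal=False):
        if not _internal:
            _warn_deprecated()            # only user-facing construction warns
        self._data = [0] * size

    def _size(self):                      # internal, warning-free primitive
        return len(self._data)

    def size(self):                       # public wrapper: warn, then delegate
        _warn_deprecated()
        return self._size()

# Public calls warn; internal calls stay silent.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _DemoStorage(4).size()                     # warns twice (ctor + size)
    _DemoStorage(4, _internal=True)._size()    # no warning
assert len(caught) == 2

The diffs below apply that same idea across the test suite, the legacy storage classes in torch/__init__.py, and internal call sites such as torch._prims, the fake-tensor machinery, and meta_utils.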
2 changes: 1 addition & 1 deletion test/test_cuda.py
@@ -570,7 +570,7 @@ def test_serialization_array_with_storage(self):
self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor))
self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor))
self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage))
- self.assertTrue(isinstance(q_copy[3]._storage, torch.UntypedStorage))
+ self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage))
q_copy[1].fill_(10)
self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))

71 changes: 71 additions & 0 deletions test/test_torch.py
@@ -6430,6 +6430,77 @@ def test_storage_casts(self):
self.assertEqual(complexdouble_storage.type(), 'torch.ComplexDoubleStorage')
self.assertIs(complexdouble_storage.dtype, torch.complex128)

# Test that internal versions of functions related to TypedStorage do not
# produce a deprecation warning
def test_typed_storage_internal_no_warning(self):
s0 = torch.FloatStorage(10)
s0_untyped = s0.untyped()
t0 = torch.randn(10)

funcs = [
lambda: torch.FloatStorage(_internal=True),
lambda: torch.TypedStorage(
dtype=torch.float,
device='cpu',
_internal=True),
lambda: torch.TypedStorage(
wrap_storage=s0_untyped,
dtype=s0.dtype,
_internal=True),
lambda: torch.FloatStorage._dtype,
lambda: s0._resize_(20),
lambda: s0._size(),
lambda: s0._untyped_storage,
lambda: s0._is_shared(),
lambda: s0._share_memory_(),
lambda: s0._pickle_storage_type(),
lambda: s0._setitem(slice(0, s0._size()), 1),
lambda: s0._element_size(),
lambda: s0._deepcopy({}),
lambda: s0._data_ptr(),
lambda: s0._nbytes(),
lambda: t0._typed_storage(),
]

if torch.cuda.is_available():
s1 = torch.cuda.FloatStorage(10)
s1_untyped = s1.untyped()
t1 = torch.randn(10, device='cuda')

funcs += [
lambda: torch.cuda.FloatStorage(_internal=True),
lambda: torch.TypedStorage(
dtype=torch.float,
device='cuda',
_internal=True),
lambda: torch.TypedStorage(
wrap_storage=s1_untyped,
dtype=s1.dtype,
_internal=True),
lambda: torch.cuda.FloatStorage._dtype,
lambda: s1._resize_(20),
lambda: s1._size(),
lambda: s1._untyped_storage,
lambda: s1._is_shared(),
lambda: s1._share_memory_(),
lambda: s1._pickle_storage_type(),
lambda: s1._setitem(slice(0, s1._size()), 1),
lambda: s1._element_size(),
lambda: s1._deepcopy({}),
lambda: s1._data_ptr(),
lambda: s1._nbytes(),
lambda: t1._typed_storage(),
]

# Check that none of the TypedStorage internal function calls
# produces a deprecation warning
for f in funcs:
with warnings.catch_warnings():
warnings.filterwarnings('error', "TypedStorage is deprecated")
f()

# Test that public functions related to TypedStorage produce a deprecation
# warning
def test_typed_storage_deprecation_warning(self):
s0 = torch.FloatStorage(10)
funcs = [
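The test above leans on warnings.filterwarnings('error', "TypedStorage is deprecated"), which escalates only warnings whose message matches the given pattern into exceptions, so a stray deprecation warning fails the test immediately while unrelated warnings pass through. A small standalone illustration of that mechanism, independent of torch:

import warnings

def deprecated_path():
    warnings.warn("TypedStorage is deprecated", UserWarning)

def unrelated_path():
    warnings.warn("something else entirely", UserWarning)

with warnings.catch_warnings():
    # The second argument is a regex matched against the warning message, so
    # only the TypedStorage deprecation is escalated to an exception.
    warnings.filterwarnings("error", "TypedStorage is deprecated")

    unrelated_path()        # still just a warning; the test keeps running

    try:
        deprecated_path()   # now raises, which would fail the test
    except UserWarning as exc:
        print("escalated:", exc)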
68 changes: 68 additions & 0 deletions torch/__init__.py
@@ -656,102 +656,170 @@ class ByteStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.uint8

class DoubleStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.double

class FloatStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.float

class HalfStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.half

class LongStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.long

class IntStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.int

class ShortStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.short

class CharStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.int8

class BoolStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.bool

class BFloat16Storage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.bfloat16

class ComplexDoubleStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.cdouble

class ComplexFloatStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.cfloat

class QUInt8Storage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.quint8

class QInt8Storage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.qint8

class QInt32Storage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.qint32

class QUInt4x2Storage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.quint4x2

class QUInt2x4Storage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.quint2x4

_storage_classes = {
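Every legacy storage class in the torch/__init__.py hunk above follows the same recipe: the public dtype classproperty warns and then reads a private _dtype classproperty that internal code can query silently. A rough sketch of how such a class-level property can be built follows; torch defines its own classproperty helper, so this minimal descriptor is only an assumption about its behavior, and FloatStorageSketch is a made-up stand-in.

import warnings

class classproperty:
    """Minimal read-only class-level property (stand-in for torch's helper)."""
    def __init__(self, fget):
        self.fget = fget

    def __get__(self, obj, owner=None):
        return self.fget(owner if owner is not None else type(obj))

def _warn_typed_storage_removal():
    warnings.warn("TypedStorage is deprecated", UserWarning, stacklevel=3)

class FloatStorageSketch:
    @classproperty
    def dtype(cls):
        _warn_typed_storage_removal()   # public access warns ...
        return cls._dtype

    @classproperty
    def _dtype(cls):                    # ... internal access does not
        return "torch.float"

print(FloatStorageSketch._dtype)  # silent
print(FloatStorageSketch.dtype)   # emits the deprecation warning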
2 changes: 1 addition & 1 deletion torch/_deploy.py
@@ -23,7 +23,7 @@ def persistent_id(obj):
if isinstance(obj, torch.storage.TypedStorage):
# TODO: Once we decide to break serialization FC, we can
# remove this case
- storage = obj._storage
+ storage = obj._untyped_storage
dtype = obj.dtype
else:
storage = obj
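The _deploy.py hunk sits inside a standard pickle persistent_id hook: a typed storage is serialized as its untyped payload plus a dtype tag and reassembled on load. A toy round trip showing that hook, with an invented TypedThing class standing in for TypedStorage:

import io
import pickle

class TypedThing:
    """Stand-in for TypedStorage: an untyped payload plus a dtype tag."""
    def __init__(self, untyped, dtype):
        self._untyped_storage = untyped
        self.dtype = dtype

class Pickler(pickle.Pickler):
    def persistent_id(self, obj):
        if isinstance(obj, TypedThing):
            # Serialize the untyped payload and the dtype separately,
            # mirroring how torch splits a TypedStorage when saving.
            return ("typed_thing", obj._untyped_storage, obj.dtype)
        return None  # everything else is pickled normally

class Unpickler(pickle.Unpickler):
    def persistent_load(self, pid):
        tag, untyped, dtype = pid
        assert tag == "typed_thing"
        return TypedThing(untyped, dtype)

buf = io.BytesIO()
Pickler(buf).dump({"storage": TypedThing(b"\x00" * 8, "float32")})
buf.seek(0)
restored = Unpickler(buf).load()
print(restored["storage"].dtype)  # float32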
2 changes: 1 addition & 1 deletion torch/_prims/__init__.py
@@ -1104,7 +1104,7 @@ def _as_strided_meta(
# as_strided to shapes with no elements are trivially valid, so it's OK
pass
elif isinstance(a, torch.Tensor):
- utils.check_in_bounds_for_storage(a.storage(), size, stride, storage_offset)
+ utils.check_in_bounds_for_storage(a._typed_storage(), size, stride, storage_offset)

return TensorMeta(a, shape=size, strides=stride)

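Here the bounds check simply receives the internal _typed_storage() so that validating an as_strided call does not itself warn. Conceptually, the check verifies that the largest element a strided view can touch still lies inside the storage. A simplified version of that arithmetic, assuming non-negative strides (the real torch._prims.utils helper covers more cases):

def check_in_bounds_sketch(storage_len, size, stride, storage_offset):
    """Raise if a view with this size/stride/offset could read past the end
    of a storage holding `storage_len` elements."""
    if any(d == 0 for d in size):
        return  # empty views are trivially in bounds
    # Largest linear index the view can touch (non-negative strides assumed).
    max_index = storage_offset + sum((d - 1) * st for d, st in zip(size, stride))
    if max_index >= storage_len:
        raise RuntimeError(
            f"view reaches element {max_index}, but the storage only has "
            f"{storage_len} elements"
        )

check_in_bounds_sketch(storage_len=12, size=(3, 4), stride=(4, 1), storage_offset=0)  # fits exactly
# check_in_bounds_sketch(12, (3, 4), (4, 1), storage_offset=1)  # would raise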
8 changes: 4 additions & 4 deletions torch/_subclasses/fake_tensor.py
@@ -144,7 +144,7 @@ def add_constant_storage_mapping(self, fake_tensor):
# const_tensor.add_(torch.rand([1]))
# all aliases of it must become no longer const
assert isinstance(fake_tensor, FakeTensor) and fake_tensor.constant is not None
- weak_st = StorageWeakRef(fake_tensor.constant.storage())
+ weak_st = StorageWeakRef(fake_tensor.constant._typed_storage())

# we need a map from a weak storage to all of its corresponding
# constant tensors. python doesn't have the weak value equivalent
@@ -156,7 +156,7 @@ def add_constant_storage_mapping(self, fake_tensor):
def invalidate_constant_aliases(self, tensor):
assert not isinstance(tensor, FakeTensor)

- weak_st = StorageWeakRef(tensor.storage())
+ weak_st = StorageWeakRef(tensor._typed_storage())
if weak_st not in self.constant_storage_mapping:
return

@@ -928,7 +928,7 @@ def to_real_tensor(e):
for e in tree_flatten((args, kwargs))[0]:
if isinstance(e, torch.Tensor):
if not e.is_sparse:
- storages.add(e.storage()._cdata)
+ storages.add(e._typed_storage()._cdata)

# TODO: also check metadata change on inputs
# proper aliasing/metadata relationship between outputs and inputs will
@@ -938,7 +938,7 @@ def to_real_tensor(e):
if id(e) not in inp_impls and (
isinstance(e, torch.Tensor)
and not e.is_sparse
- and e.storage()._cdata in storages
+ and e._typed_storage()._cdata in storages
):
raise orig_not_implemented_exception

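The fake-tensor changes swap .storage() for ._typed_storage() when building the weak references used to track which tensors alias which storage. The underlying idea, that tensors alias exactly when they share a storage, can also be observed through public APIs; a short check using Tensor.untyped_storage(), which assumes a recent PyTorch release where that accessor exists:

import torch

def share_storage(a, b):
    # All views of one allocation share the same underlying untyped storage,
    # so comparing its base pointer detects aliasing across offsets/strides.
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

x = torch.randn(10)
print(share_storage(x, x[2:]))       # True: the slice aliases x
print(share_storage(x, x.clone()))   # False: clone owns fresh memory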
4 changes: 2 additions & 2 deletions torch/_subclasses/meta_utils.py
@@ -98,7 +98,7 @@ def set_tensor_memo(self, t, v):
if t.is_sparse:
weak_st = None
else:
- weak_st = StorageWeakRef(t.storage())
+ weak_st = StorageWeakRef(t._typed_storage())
tensor_ref_key = WeakTensorRefKey(t)

def del_ten():
@@ -255,7 +255,7 @@ def is_c_of_r(complex_dtype, real_dtype):
# As long as meta storage is not supported, need to prevent
# redispatching on set_(Storage, ...) which will choke with
# meta storage
- s = self.meta_storage(t.storage())
+ s = self.meta_storage(t._typed_storage())
with no_dispatch():
with torch.no_grad():
r.set_(s, sym_storage_offset(t), sym_size(t), sym_stride(t))
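meta_utils keys its memo table on a weak reference to the tensor's storage and registers a callback (del_ten) to drop the entry when the tensor goes away. A rough, torch-free analogue of that bookkeeping using weakref.finalize; the WeakMemo and Payload names are invented:

import weakref

class WeakMemo:
    """Cache whose entries vanish when the key object is garbage collected,
    a rough analogue of keying a memo table on a weak storage reference."""
    def __init__(self):
        self._memo = {}

    def set(self, key_obj, value):
        ident = id(key_obj)
        self._memo[ident] = value
        # Evict the entry once key_obj is collected (compare del_ten above).
        weakref.finalize(key_obj, self._memo.pop, ident, None)

    def get(self, key_obj):
        return self._memo.get(id(key_obj))

class Payload:   # weakref-able stand-in for a tensor / storage
    pass

memo = WeakMemo()
obj = Payload()
memo.set(obj, "meta-version")
print(memo.get(obj))      # meta-version
del obj                   # finalizer fires, the entry is evicted
print(len(memo._memo))    # 0 on CPython, where collection is immediate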