Deprecate TypedStorage, its derived classes, and all of their public methods (#85303)

Part of #85302

Pull Request resolved: #85303
Approved by: https://github.com/ezyang
kurtamohler authored and pytorchmergebot committed Nov 8, 2022
1 parent 53ca5ad commit ee28b86
Showing 37 changed files with 631 additions and 176 deletions.
4 changes: 4 additions & 0 deletions docs/source/storage.rst
@@ -22,6 +22,10 @@ holds the data as an untyped array of bytes.
Every strided :class:`torch.Tensor` contains a :class:`torch.TypedStorage`,
which stores all of the data that the :class:`torch.Tensor` views.

.. warning::
    All storage classes except for :class:`torch.UntypedStorage` will be removed
    in the future, and :class:`torch.UntypedStorage` will be used in all cases.

.. autoclass:: torch.TypedStorage
   :members:
   :undoc-members:
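For readers hitting the new warning, a minimal sketch of the migration it points toward, assuming a PyTorch build where Tensor.untyped_storage() is available (the method is not part of this diff):

import torch

t = torch.arange(4, dtype=torch.float32)

# Deprecated direction: t.storage() hands back a torch.TypedStorage wrapper.
# Preferred direction per the warning above: work with the untyped byte storage.
u = t.untyped_storage()      # torch.UntypedStorage
print(u.nbytes())            # 4 elements * 4 bytes each -> 16
print(u.device)              # cpu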
2 changes: 1 addition & 1 deletion test/test_autograd.py
@@ -6805,8 +6805,8 @@ def pack(x):
with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x):
    a = torch.ones(5, requires_grad=True)

-   warnings.simplefilter('always')
    with warnings.catch_warnings(record=True) as w:
+       warnings.simplefilter('always')
        y = a * a
        # should raise two warnings from a being saved twice
        self.assertEqual(len(w), 2)
2 changes: 1 addition & 1 deletion test/test_cuda.py
@@ -595,7 +595,7 @@ def test_serialization_array_with_storage(self):
self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor))
self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor))
self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage))
- self.assertTrue(isinstance(q_copy[3]._storage, torch.UntypedStorage))
+ self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage))
q_copy[1].fill_(10)
self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))

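The updated assertion above tracks the rename of TypedStorage's internal attribute from _storage to _untyped_storage. A short sketch of the relationship it checks; _untyped_storage is an internal attribute and is used here only to mirror the test:

import torch

s = torch.FloatStorage(4)                    # legacy typed wrapper (now warns on use)
u = s._untyped_storage                       # internal handle, renamed from `_storage`
print(isinstance(u, torch.UntypedStorage))   # True
print(u.nbytes())                            # 4 float32 elements -> 16 bytes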
121 changes: 121 additions & 0 deletions test/test_torch.py
@@ -6470,6 +6470,127 @@ def test_storage_casts(self):
    self.assertEqual(complexdouble_storage.type(), 'torch.ComplexDoubleStorage')
    self.assertIs(complexdouble_storage.dtype, torch.complex128)

# Test that internal versions of functions related to TypedStorage do not
# produce a deprecation warning
def test_typed_storage_internal_no_warning(self):
    s0 = torch.FloatStorage(10)
    s0_untyped = s0.untyped()
    t0 = torch.randn(10)

    funcs = [
        lambda: torch.FloatStorage(_internal=True),
        lambda: torch.TypedStorage(
            dtype=torch.float,
            device='cpu',
            _internal=True),
        lambda: torch.TypedStorage(
            wrap_storage=s0_untyped,
            dtype=s0.dtype,
            _internal=True),
        lambda: torch.FloatStorage._dtype,
        lambda: s0._resize_(20),
        lambda: s0._size(),
        lambda: s0._untyped_storage,
        lambda: s0._is_shared(),
        lambda: s0._share_memory_(),
        lambda: s0._pickle_storage_type(),
        lambda: s0._setitem(slice(0, s0._size()), 1),
        lambda: s0._element_size(),
        lambda: s0._deepcopy({}),
        lambda: s0._data_ptr(),
        lambda: s0._nbytes(),
        lambda: t0._typed_storage(),
    ]

    if torch.cuda.is_available():
        s1 = torch.cuda.FloatStorage(10)
        s1_untyped = s1.untyped()
        t1 = torch.randn(10, device='cuda')

        funcs += [
            lambda: torch.cuda.FloatStorage(_internal=True),
            lambda: torch.TypedStorage(
                dtype=torch.float,
                device='cuda',
                _internal=True),
            lambda: torch.TypedStorage(
                wrap_storage=s1_untyped,
                dtype=s1.dtype,
                _internal=True),
            lambda: torch.cuda.FloatStorage._dtype,
            lambda: s1._resize_(20),
            lambda: s1._size(),
            lambda: s1._untyped_storage,
            lambda: s1._is_shared(),
            lambda: s1._share_memory_(),
            lambda: s1._pickle_storage_type(),
            lambda: s1._setitem(slice(0, s1._size()), 1),
            lambda: s1._element_size(),
            lambda: s1._deepcopy({}),
            lambda: s1._data_ptr(),
            lambda: s1._nbytes(),
            lambda: t1._typed_storage(),
        ]

    # Check that each of the TypedStorage internal function calls do not
    # produce a deprecation warning
    for f in funcs:
        with warnings.catch_warnings():
            warnings.filterwarnings('error', "TypedStorage is deprecated")
            f()

# Test that public functions related to TypedStorage produce a deprecation
# warning
def test_typed_storage_deprecation_warning(self):
    s0 = torch.FloatStorage(10)
    funcs = [
        lambda: torch.FloatStorage(),
        lambda: torch.FloatStorage.dtype,
        lambda: s0.fill_(0),
        lambda: s0.is_cuda,
        lambda: s0.untyped(),
        lambda: len(s0),
        lambda: s0[0],
    ]

    if torch.cuda.is_available():
        s1 = torch.cuda.FloatStorage(10)
        funcs += [
            lambda: torch.cuda.FloatStorage(),
            lambda: torch.cuda.FloatStorage.dtype,
            lambda: s1.fill_(0),
            lambda: s1.is_cuda,
            lambda: s1.untyped(),
            lambda: len(s1),
            lambda: s1[0],
        ]

    # Check that each of the TypedStorage function calls produce a warning
    # if warnings are reset between each
    for f in funcs:
        with warnings.catch_warnings(record=True) as w:
            warnings.resetwarnings()
            f()
            self.assertEqual(len(w), 1)
            warning = w[0].message
            self.assertTrue(warning, DeprecationWarning)
            self.assertTrue(re.search(
                '^TypedStorage is deprecated',
                str(warning)))

    # Check that only one warning is raised from calling multiple
    # TypedStorage functions if warnings are not reset between each
    with warnings.catch_warnings(record=True) as w:
        warnings.resetwarnings()
        for f in funcs:
            f()
        self.assertEqual(len(w), 1)
        warning = w[0].message
        self.assertTrue(warning, DeprecationWarning)
        self.assertTrue(re.search(
            '^TypedStorage is deprecated',
            str(warning)))

def test_from_file(self):
    def assert_with_filename(filename):
        size = 10000
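Outside the test suite, the same standard-library machinery can escalate or silence the new message while code is being migrated; a sketch under those assumptions, not part of this diff:

import warnings
import torch

# Turn the deprecation message into an error, as test_typed_storage_internal_no_warning does.
with warnings.catch_warnings():
    warnings.filterwarnings('error', 'TypedStorage is deprecated')
    try:
        torch.FloatStorage(10)
    except Warning as e:
        print('caught:', e)

# Or silence it temporarily while existing code is being ported.
with warnings.catch_warnings():
    warnings.filterwarnings('ignore', 'TypedStorage is deprecated')
    s = torch.FloatStorage(10)    # no warning escapes this block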
2 changes: 1 addition & 1 deletion test/test_view_ops.py
@@ -102,7 +102,7 @@ def is_view_of(self, base, other):
# Note: only validates storage on native device types
# because some accelerators, like XLA, do not expose storage
if base.device.type == 'cpu' or base.device.type == 'cuda':
-   if base.storage().data_ptr() != other.storage().data_ptr():
+   if base._storage().data_ptr() != other._storage().data_ptr():
        return False

return True
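The helper above treats two tensors as views of one another only when their storages share a data pointer. A standalone sketch of that idea with the public tensor API (assuming Tensor.untyped_storage() is available) rather than the private _storage() call used in the test:

import torch

base = torch.randn(4)
view = base.view(2, 2)     # reinterprets the same buffer
other = torch.randn(4)     # freshly allocated buffer

print(base.untyped_storage().data_ptr() == view.untyped_storage().data_ptr())   # True
print(base.untyped_storage().data_ptr() == other.untyped_storage().data_ptr())  # False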
2 changes: 1 addition & 1 deletion tools/autograd/templates/python_variable_methods.cpp
@@ -979,7 +979,7 @@ static PyObject * THPVariable_storage(PyObject* self, PyObject* arg)
{
  HANDLE_TH_ERRORS
  if (check_has_torch_function(self)) {
-   return handle_torch_function(self, "storage");
+   return handle_torch_function(self, "_storage");
  }
  auto& self_ = THPVariable_Unpack(self);
  return createPyObject(self_.storage());
87 changes: 86 additions & 1 deletion torch/__init__.py
@@ -709,94 +709,179 @@ def is_warn_always_enabled():
################################################################################

from ._tensor import Tensor
- from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage
+ from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage, _warn_typed_storage_removal

# NOTE: New <type>Storage classes should never be added. When adding a new
# dtype, use torch.storage.TypedStorage directly.

class ByteStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.uint8

class DoubleStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.double

class FloatStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.float

class HalfStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.half

class LongStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.long

class IntStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.int

class ShortStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.short

class CharStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.int8

class BoolStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.bool

class BFloat16Storage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.bfloat16

class ComplexDoubleStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.cdouble

class ComplexFloatStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.cfloat

class QUInt8Storage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.quint8

class QInt8Storage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.qint8

class QInt32Storage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.qint32

class QUInt4x2Storage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.quint4x2

class QUInt2x4Storage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.quint2x4

_storage_classes = {
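Every legacy class above follows the same shape: the public dtype classproperty warns, while the underscore-prefixed _dtype stays silent for internal callers. A self-contained sketch of that mechanism; the real classproperty descriptor and _warn_typed_storage_removal helper live in torch internals and differ in detail, and the warning text below is a placeholder:

import warnings

class classproperty:
    """Minimal property-like descriptor that works on the class itself."""
    def __init__(self, fget):
        self.fget = fget

    def __get__(self, obj, owner=None):
        return self.fget(owner)

def _warn_typed_storage_removal():
    # Placeholder message; the real helper lives in torch.storage.
    warnings.warn('TypedStorage is deprecated (illustrative message)',
                  UserWarning, stacklevel=3)

class ExampleStorage:
    @classproperty
    def dtype(cls):
        _warn_typed_storage_removal()   # public spelling warns
        return cls._dtype

    @classproperty
    def _dtype(cls):                    # internal spelling stays silent
        return 'torch.float64'

print(ExampleStorage._dtype)   # no warning emitted
print(ExampleStorage.dtype)    # emits the deprecation warning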
2 changes: 1 addition & 1 deletion torch/_deploy.py
@@ -23,7 +23,7 @@ def persistent_id(obj):
if isinstance(obj, torch.storage.TypedStorage):
    # TODO: Once we decide to break serialization FC, we can
    # remove this case
-   storage = obj._storage
+   storage = obj._untyped_storage
    dtype = obj.dtype
else:
    storage = obj
2 changes: 1 addition & 1 deletion torch/_dynamo/optimizations/analysis.py
@@ -27,7 +27,7 @@ def __init__(self, *args, **kwargs):

def tensor_alias_group(self, value: torch.Tensor):
    """Assign a unique identifier to the storage of a given tensor"""
-   storage = StorageWeakRef(value.storage())
+   storage = StorageWeakRef(value._typed_storage())
    alias_group = self.storage_to_alias_group.get(storage)
    if alias_group is None:
        alias_group = next(self.make_alias_group)
4 changes: 2 additions & 2 deletions torch/_dynamo/optimizations/distributed.py
@@ -157,15 +157,15 @@ def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
    for name, p in target.named_parameters():
        param = target.get_parameter(name)
        if p.requires_grad and not self._ignore_parameter(param):
-           buckets[0].size += p.storage().nbytes()
+           buckets[0].size += p._storage().nbytes()
            buckets[0].params.append(f"{node.target}_{name}")
            buckets[0].param_ids.append(id(param))
elif node.op == "get_attr":
    maybe_param = getattr(gm, node.target)
    if maybe_param.requires_grad and not self._ignore_parameter(
        maybe_param
    ):
-       buckets[0].size += maybe_param.storage().nbytes()
+       buckets[0].size += maybe_param._storage().nbytes()
        buckets[0].params.append(node.target)
        buckets[0].param_ids.append(id(maybe_param))

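The bucket accounting above sums parameter sizes in bytes through their storages. A small sketch of the same arithmetic with the public API (again assuming Tensor.untyped_storage()), outside this diff:

import torch

model = torch.nn.Linear(128, 64)

# Sum the storage footprint of trainable parameters, as the bucket sizing does.
total_bytes = sum(
    p.untyped_storage().nbytes()
    for p in model.parameters()
    if p.requires_grad
)
print(total_bytes)   # weight 64*128*4 bytes + bias 64*4 bytes = 33024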
(Diffs for the remaining changed files are not shown in this excerpt.)
