Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate TypedStorage, its derived classes, and all of their public methods #85303

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/storage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ holds the data as an untyped array of bytes.
Every strided :class:`torch.Tensor` contains a :class:`torch.TypedStorage`,
which stores all of the data that the :class:`torch.Tensor` views.

.. warning::
All storage classes except for :class:`torch.UntypedStorage` will be removed
in the future, and :class:`torch.UntypedStorage` will be used in all cases.

.. autoclass:: torch.TypedStorage
:members:
:undoc-members:
Expand Down
2 changes: 1 addition & 1 deletion test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6805,8 +6805,8 @@ def pack(x):
with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x):
a = torch.ones(5, requires_grad=True)

warnings.simplefilter('always')
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
y = a * a
# should raise two warnings from a being saved twice
self.assertEqual(len(w), 2)
Expand Down
2 changes: 1 addition & 1 deletion test/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ def test_serialization_array_with_storage(self):
self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor))
self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor))
self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage))
self.assertTrue(isinstance(q_copy[3]._storage, torch.UntypedStorage))
self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage))
q_copy[1].fill_(10)
self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))

Expand Down
121 changes: 121 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6470,6 +6470,127 @@ def test_storage_casts(self):
self.assertEqual(complexdouble_storage.type(), 'torch.ComplexDoubleStorage')
self.assertIs(complexdouble_storage.dtype, torch.complex128)

# Test that internal versions of functions related to TypedStorage do not
# produce a deprecation warning
def test_typed_storage_internal_no_warning(self):
s0 = torch.FloatStorage(10)
s0_untyped = s0.untyped()
t0 = torch.randn(10)

funcs = [
lambda: torch.FloatStorage(_internal=True),
lambda: torch.TypedStorage(
dtype=torch.float,
device='cpu',
_internal=True),
lambda: torch.TypedStorage(
wrap_storage=s0_untyped,
dtype=s0.dtype,
_internal=True),
lambda: torch.FloatStorage._dtype,
lambda: s0._resize_(20),
lambda: s0._size(),
lambda: s0._untyped_storage,
lambda: s0._is_shared(),
lambda: s0._share_memory_(),
lambda: s0._pickle_storage_type(),
lambda: s0._setitem(slice(0, s0._size()), 1),
lambda: s0._element_size(),
lambda: s0._deepcopy({}),
lambda: s0._data_ptr(),
lambda: s0._nbytes(),
lambda: t0._typed_storage(),
]

if torch.cuda.is_available():
s1 = torch.cuda.FloatStorage(10)
s1_untyped = s1.untyped()
t1 = torch.randn(10, device='cuda')

funcs += [
lambda: torch.cuda.FloatStorage(_internal=True),
lambda: torch.TypedStorage(
dtype=torch.float,
device='cuda',
_internal=True),
lambda: torch.TypedStorage(
wrap_storage=s1_untyped,
dtype=s1.dtype,
_internal=True),
lambda: torch.cuda.FloatStorage._dtype,
lambda: s1._resize_(20),
lambda: s1._size(),
lambda: s1._untyped_storage,
lambda: s1._is_shared(),
lambda: s1._share_memory_(),
lambda: s1._pickle_storage_type(),
lambda: s1._setitem(slice(0, s1._size()), 1),
lambda: s1._element_size(),
lambda: s1._deepcopy({}),
lambda: s1._data_ptr(),
lambda: s1._nbytes(),
lambda: t1._typed_storage(),
]

# Check that each of the TypedStorage internal function calls do not
# produce a deprecation warning
for f in funcs:
with warnings.catch_warnings():
warnings.filterwarnings('error', "TypedStorage is deprecated")
f()

# Test that public functions related to TypedStorage produce a deprecation
# warning
def test_typed_storage_deprecation_warning(self):
s0 = torch.FloatStorage(10)
funcs = [
lambda: torch.FloatStorage(),
lambda: torch.FloatStorage.dtype,
lambda: s0.fill_(0),
lambda: s0.is_cuda,
lambda: s0.untyped(),
lambda: len(s0),
lambda: s0[0],
]

if torch.cuda.is_available():
s1 = torch.cuda.FloatStorage(10)
funcs += [
lambda: torch.cuda.FloatStorage(),
lambda: torch.cuda.FloatStorage.dtype,
lambda: s1.fill_(0),
lambda: s1.is_cuda,
lambda: s1.untyped(),
lambda: len(s1),
lambda: s1[0],
]

# Check that each of the TypedStorage function calls produce a warning
# if warnings are reset between each
for f in funcs:
with warnings.catch_warnings(record=True) as w:
warnings.resetwarnings()
f()
self.assertEqual(len(w), 1)
warning = w[0].message
self.assertTrue(warning, DeprecationWarning)
self.assertTrue(re.search(
'^TypedStorage is deprecated',
str(warning)))

# Check that only one warning is raised from calling multiple
# TypedStorage functions if warnings are not reset between each
with warnings.catch_warnings(record=True) as w:
warnings.resetwarnings()
for f in funcs:
f()
self.assertEqual(len(w), 1)
warning = w[0].message
self.assertTrue(warning, DeprecationWarning)
self.assertTrue(re.search(
'^TypedStorage is deprecated',
str(warning)))

def test_from_file(self):
def assert_with_filename(filename):
size = 10000
Expand Down
2 changes: 1 addition & 1 deletion test/test_view_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def is_view_of(self, base, other):
# Note: only validates storage on native device types
# because some accelerators, like XLA, do not expose storage
if base.device.type == 'cpu' or base.device.type == 'cuda':
if base.storage().data_ptr() != other.storage().data_ptr():
if base._storage().data_ptr() != other._storage().data_ptr():
return False

return True
Expand Down
2 changes: 1 addition & 1 deletion tools/autograd/templates/python_variable_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -979,7 +979,7 @@ static PyObject * THPVariable_storage(PyObject* self, PyObject* arg)
{
HANDLE_TH_ERRORS
if (check_has_torch_function(self)) {
return handle_torch_function(self, "storage");
return handle_torch_function(self, "_storage");
}
auto& self_ = THPVariable_Unpack(self);
return createPyObject(self_.storage());
Expand Down
87 changes: 86 additions & 1 deletion torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,94 +709,179 @@ def is_warn_always_enabled():
################################################################################

from ._tensor import Tensor
from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage
from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage, _warn_typed_storage_removal

# NOTE: New <type>Storage classes should never be added. When adding a new
# dtype, use torch.storage.TypedStorage directly.

class ByteStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.uint8

class DoubleStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.double

class FloatStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.float

class HalfStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.half

class LongStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.long

class IntStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.int

class ShortStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.short

class CharStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.int8

class BoolStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.bool

class BFloat16Storage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.bfloat16

class ComplexDoubleStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.cdouble

class ComplexFloatStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.cfloat

class QUInt8Storage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.quint8

class QInt8Storage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.qint8

class QInt32Storage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.qint32

class QUInt4x2Storage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.quint4x2

class QUInt2x4Storage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype

@classproperty
def _dtype(self):
return torch.quint2x4

_storage_classes = {
Expand Down
2 changes: 1 addition & 1 deletion torch/_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def persistent_id(obj):
if isinstance(obj, torch.storage.TypedStorage):
# TODO: Once we decide to break serialization FC, we can
# remove this case
storage = obj._storage
storage = obj._untyped_storage
dtype = obj.dtype
else:
storage = obj
Expand Down
2 changes: 1 addition & 1 deletion torch/_dynamo/optimizations/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, *args, **kwargs):

def tensor_alias_group(self, value: torch.Tensor):
"""Assign a unique identifier to the storage of a given tensor"""
storage = StorageWeakRef(value.storage())
storage = StorageWeakRef(value._typed_storage())
alias_group = self.storage_to_alias_group.get(storage)
if alias_group is None:
alias_group = next(self.make_alias_group)
Expand Down
4 changes: 2 additions & 2 deletions torch/_dynamo/optimizations/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,15 @@ def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
for name, p in target.named_parameters():
param = target.get_parameter(name)
if p.requires_grad and not self._ignore_parameter(param):
buckets[0].size += p.storage().nbytes()
buckets[0].size += p._storage().nbytes()
buckets[0].params.append(f"{node.target}_{name}")
buckets[0].param_ids.append(id(param))
elif node.op == "get_attr":
maybe_param = getattr(gm, node.target)
if maybe_param.requires_grad and not self._ignore_parameter(
maybe_param
):
buckets[0].size += maybe_param.storage().nbytes()
buckets[0].size += maybe_param._storage().nbytes()
buckets[0].params.append(node.target)
buckets[0].param_ids.append(id(maybe_param))

Expand Down
Loading