Deprecate TypedStorage, its derived classes, and all of their public methods

kurtamohler committed Sep 20, 2022
1 parent 62786a0 commit a8a2070
Showing 5 changed files with 140 additions and 2 deletions.
4 changes: 4 additions & 0 deletions docs/source/storage.rst
@@ -22,6 +22,10 @@ holds the data as an untyped array of bytes.
Every strided :class:`torch.Tensor` contains a :class:`torch.TypedStorage`,
which stores all of the data that the :class:`torch.Tensor` views.

.. warning::
    All storage classes except for :class:`torch.UntypedStorage` will be removed
    in the future, and :class:`torch.UntypedStorage` will be used in all cases.

.. autoclass:: torch.TypedStorage
   :members:
   :undoc-members:
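For illustration only (not part of the commit): the warning above points users toward torch.UntypedStorage. A minimal migration sketch, assuming a float32 storage of arbitrary size; note that UntypedStorage is sized in bytes rather than elements:

    import torch

    # Deprecated after this commit: a typed storage of 10 float32 elements
    # (emits the new "TypedStorage is deprecated" warning)
    s_typed = torch.FloatStorage(10)

    # Preferred: an untyped storage sized in bytes (10 elements * 4 bytes each)
    s_untyped = torch.UntypedStorage(10 * 4)

    # An existing typed storage exposes its untyped counterpart, as the
    # test below exercises via s0.untyped()
    s_bytes = s_typed.untyped()
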
50 changes: 50 additions & 0 deletions test/test_torch.py
@@ -6430,6 +6430,56 @@ def test_storage_casts(self):
self.assertEqual(complexdouble_storage.type(), 'torch.ComplexDoubleStorage')
self.assertIs(complexdouble_storage.dtype, torch.complex128)

    def test_typed_storage_deprecation_warning(self):
        s0 = torch.FloatStorage(10)

        funcs = [
            lambda: torch.FloatStorage(),
            lambda: torch.FloatStorage.dtype,
            lambda: s0.fill_(0),
            lambda: s0.is_cuda,
            lambda: s0.untyped(),
            lambda: len(s0),
            lambda: s0[0],
        ]

        if torch.cuda.is_available():
            s1 = torch.cuda.FloatStorage(10)

            funcs += [
                lambda: torch.cuda.FloatStorage(),
                lambda: torch.cuda.FloatStorage.dtype,
                lambda: s1.fill_(0),
                lambda: s1.is_cuda,
                lambda: s1.untyped(),
                lambda: len(s1),
                lambda: s1[0],
            ]

        # Check that each of the TypedStorage function calls produces a warning
        # if warnings are reset between each call
        for f in funcs:
            with warnings.catch_warnings(record=True) as w:
                f()
                self.assertEqual(len(w), 1)
                warning = w[0].message
                self.assertIsInstance(warning, DeprecationWarning)
                self.assertTrue(re.search(
                    '^TypedStorage is deprecated',
                    str(warning)))

        # Check that only one warning is raised from calling multiple
        # TypedStorage functions if warnings are not reset between each call
        with warnings.catch_warnings(record=True) as w:
            for f in funcs:
                f()
            self.assertEqual(len(w), 1)
            warning = w[0].message
            self.assertIsInstance(warning, DeprecationWarning)
            self.assertTrue(re.search(
                '^TypedStorage is deprecated',
                str(warning)))

    def test_from_file(self):
        def assert_with_filename(filename):
            size = 10000
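The helper this test exercises, _warn_typed_storage_removal, is imported from torch.storage (see the torch/__init__.py diff below); the torch/storage.py diff itself did not load on this page. A minimal sketch of what such a helper could look like, an assumption rather than the commit's actual code:

    import warnings

    def _warn_typed_storage_removal():
        # Hypothetical sketch; the real implementation lives in torch/storage.py
        warnings.warn(
            "TypedStorage is deprecated. It will be removed in the future, and "
            "UntypedStorage will be the only storage class.",
            DeprecationWarning)

Because every deprecated entry point funnels through the same warnings.warn call site, Python's default warning filter deduplicates repeats against the per-module warning registry, and entering a new warnings.catch_warnings block invalidates that registry. That is why the first loop above sees one warning per call while the second loop sees exactly one warning in total.
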
19 changes: 18 additions & 1 deletion torch/__init__.py
@@ -647,94 +647,111 @@ def is_warn_always_enabled():
################################################################################

from ._tensor import Tensor
from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage
from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage, _warn_typed_storage_removal

# NOTE: New <type>Storage classes should never be added. When adding a new
# dtype, use torch.storage.TypedStorage directly.

class ByteStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.uint8

class DoubleStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.double

class FloatStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.float

class HalfStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.half

class LongStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.long

class IntStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.int

class ShortStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.short

class CharStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.int8

class BoolStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.bool

class BFloat16Storage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.bfloat16

class ComplexDoubleStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.cdouble

class ComplexFloatStorage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.cfloat

class QUInt8Storage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.quint8

class QInt8Storage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.qint8

class QInt32Storage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.qint32

class QUInt4x2Storage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.quint4x2

class QUInt2x4Storage(_LegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.quint2x4

_storage_classes = {
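Each legacy class guards dtype behind @classproperty so that even class-level access such as torch.FloatStorage.dtype, with no instance involved, runs the deprecation warning before returning the dtype. classproperty is a PyTorch-internal descriptor; a minimal equivalent, a sketch rather than the exact internal code, could be:

    class classproperty:
        # Descriptor that makes a method act as a read-only property
        # on the class itself, not only on instances.
        def __init__(self, fget):
            self.fget = fget

        def __get__(self, instance, owner):
            # Invoked for both FloatStorage.dtype and instance.dtype;
            # the decorated function receives the class as its argument.
            return self.fget(owner)

This is what lets the lambda: torch.FloatStorage.dtype case in the test above trigger a warning without ever constructing a storage.
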
15 changes: 14 additions & 1 deletion torch/cuda/__init__.py
@@ -724,11 +724,12 @@ def type(self, *args, **kwargs):

__new__ = _lazy_new

from torch.storage import _LegacyStorage
from torch.storage import _LegacyStorage, _warn_typed_storage_removal

class _CudaLegacyStorage(_LegacyStorage):
    @classmethod
    def from_buffer(cls, *args, **kwargs):
        _warn_typed_storage_removal()
        raise RuntimeError('from_buffer: Not available for CUDA storage')

    @classmethod
@@ -742,61 +743,73 @@ def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None):
class ByteStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.uint8

class DoubleStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.double

class FloatStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.float

class HalfStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.half

class LongStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.long

class IntStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.int

class ShortStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.short

class CharStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.int8

class BoolStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.bool

class BFloat16Storage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.bfloat16

class ComplexDoubleStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.cdouble

class ComplexFloatStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return torch.cfloat

del _LegacyStorage
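A small usage sketch grounded in the test above, showing that the CUDA variants warn in the same way; it assumes a CUDA-enabled build and an available device:

    import warnings
    import torch

    if torch.cuda.is_available():
        with warnings.catch_warnings(record=True) as w:
            s = torch.cuda.FloatStorage(10)    # construction warns
            _ = torch.cuda.FloatStorage.dtype  # would warn too, but the repeat
                                               # is deduplicated within this block
        print(len(w))  # 1
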
(The diff for the fifth changed file, presumably torch/storage.py, did not load on this page.)
