Skip to content

Commit

Permalink
Avoid raising warnings for internal TypedStorage usage
Browse files Browse the repository at this point in the history
  • Loading branch information
kurtamohler committed Nov 8, 2022
1 parent 5ece8a5 commit 9ec684c
Show file tree
Hide file tree
Showing 36 changed files with 496 additions and 179 deletions.
2 changes: 1 addition & 1 deletion test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6805,8 +6805,8 @@ def pack(x):
with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x):
a = torch.ones(5, requires_grad=True)

warnings.simplefilter('always')
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
y = a * a
# should raise two warnings from a being saved twice
self.assertEqual(len(w), 2)
Expand Down
2 changes: 1 addition & 1 deletion test/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ def test_serialization_array_with_storage(self):
self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor))
self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor))
self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage))
self.assertTrue(isinstance(q_copy[3]._storage, torch.UntypedStorage))
self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage))
q_copy[1].fill_(10)
self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))

Expand Down
71 changes: 71 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6470,6 +6470,77 @@ def test_storage_casts(self):
self.assertEqual(complexdouble_storage.type(), 'torch.ComplexDoubleStorage')
self.assertIs(complexdouble_storage.dtype, torch.complex128)

# Test that internal versions of functions related to TypedStorage do not
# produce a deprecation warning
def test_typed_storage_internal_no_warning(self):
s0 = torch.FloatStorage(10)
s0_untyped = s0.untyped()
t0 = torch.randn(10)

funcs = [
lambda: torch.FloatStorage(_internal=True),
lambda: torch.TypedStorage(
dtype=torch.float,
device='cpu',
_internal=True),
lambda: torch.TypedStorage(
wrap_storage=s0_untyped,
dtype=s0.dtype,
_internal=True),
lambda: torch.FloatStorage._dtype,
lambda: s0._resize_(20),
lambda: s0._size(),
lambda: s0._untyped_storage,
lambda: s0._is_shared(),
lambda: s0._share_memory_(),
lambda: s0._pickle_storage_type(),
lambda: s0._setitem(slice(0, s0._size()), 1),
lambda: s0._element_size(),
lambda: s0._deepcopy({}),
lambda: s0._data_ptr(),
lambda: s0._nbytes(),
lambda: t0._typed_storage(),
]

if torch.cuda.is_available():
s1 = torch.cuda.FloatStorage(10)
s1_untyped = s1.untyped()
t1 = torch.randn(10, device='cuda')

funcs += [
lambda: torch.cuda.FloatStorage(_internal=True),
lambda: torch.TypedStorage(
dtype=torch.float,
device='cuda',
_internal=True),
lambda: torch.TypedStorage(
wrap_storage=s1_untyped,
dtype=s1.dtype,
_internal=True),
lambda: torch.cuda.FloatStorage._dtype,
lambda: s1._resize_(20),
lambda: s1._size(),
lambda: s1._untyped_storage,
lambda: s1._is_shared(),
lambda: s1._share_memory_(),
lambda: s1._pickle_storage_type(),
lambda: s1._setitem(slice(0, s1._size()), 1),
lambda: s1._element_size(),
lambda: s1._deepcopy({}),
lambda: s1._data_ptr(),
lambda: s1._nbytes(),
lambda: t1._typed_storage(),
]

# Check that each of the TypedStorage internal function calls do not
# produce a deprecation warning
for f in funcs:
with warnings.catch_warnings():
warnings.filterwarnings('error', "TypedStorage is deprecated")
f()

# Test that public functions related to TypedStorage produce a deprecation
# warning
def test_typed_storage_deprecation_warning(self):
s0 = torch.FloatStorage(10)
funcs = [
Expand Down
2 changes: 1 addition & 1 deletion test/test_view_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def is_view_of(self, base, other):
# Note: only validates storage on native device types
# because some accelerators, like XLA, do not expose storage
if base.device.type == 'cpu' or base.device.type == 'cuda':
if base.storage().data_ptr() != other.storage().data_ptr():
if base._storage().data_ptr() != other._storage().data_ptr():
return False

return True
Expand Down
2 changes: 1 addition & 1 deletion tools/autograd/templates/python_variable_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -979,7 +979,7 @@ static PyObject * THPVariable_storage(PyObject* self, PyObject* arg)
{
HANDLE_TH_ERRORS
if (check_has_torch_function(self)) {
return handle_torch_function(self, "storage");
return handle_torch_function(self, "_storage");
}
auto& self_ = THPVariable_Unpack(self);
return createPyObject(self_.storage());
Expand Down
68 changes: 68 additions & 0 deletions torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,102 +718,170 @@ class ByteStorage(_LegacyStorage):
@classproperty
def dtype(self):
    # Public accessor: emit the TypedStorage deprecation warning, then
    # delegate to the warning-free internal ``_dtype`` property.
    _warn_typed_storage_removal()
    return self._dtype

@classproperty
def _dtype(self):
    # Internal accessor used inside torch itself; intentionally silent.
    return torch.uint8

class DoubleStorage(_LegacyStorage):
    """Legacy ``torch.double`` storage: ``dtype`` warns, ``_dtype`` does not."""

    @classproperty
    def _dtype(self):
        # Internal accessor used inside torch; never warns.
        return torch.double

    @classproperty
    def dtype(self):
        # Public accessor: warn about TypedStorage deprecation, then delegate.
        _warn_typed_storage_removal()
        return self._dtype

class FloatStorage(_LegacyStorage):
    """Legacy ``torch.float`` storage: ``dtype`` warns, ``_dtype`` does not."""

    @classproperty
    def _dtype(self):
        # Internal accessor used inside torch; never warns.
        return torch.float

    @classproperty
    def dtype(self):
        # Public accessor: warn about TypedStorage deprecation, then delegate.
        _warn_typed_storage_removal()
        return self._dtype

class HalfStorage(_LegacyStorage):
    """Legacy ``torch.half`` storage: ``dtype`` warns, ``_dtype`` does not."""

    @classproperty
    def _dtype(self):
        # Internal accessor used inside torch; never warns.
        return torch.half

    @classproperty
    def dtype(self):
        # Public accessor: warn about TypedStorage deprecation, then delegate.
        _warn_typed_storage_removal()
        return self._dtype

class LongStorage(_LegacyStorage):
    """Legacy ``torch.long`` storage: ``dtype`` warns, ``_dtype`` does not."""

    @classproperty
    def _dtype(self):
        # Internal accessor used inside torch; never warns.
        return torch.long

    @classproperty
    def dtype(self):
        # Public accessor: warn about TypedStorage deprecation, then delegate.
        _warn_typed_storage_removal()
        return self._dtype

class IntStorage(_LegacyStorage):
    """Legacy ``torch.int`` storage: ``dtype`` warns, ``_dtype`` does not."""

    @classproperty
    def _dtype(self):
        # Internal accessor used inside torch; never warns.
        return torch.int

    @classproperty
    def dtype(self):
        # Public accessor: warn about TypedStorage deprecation, then delegate.
        _warn_typed_storage_removal()
        return self._dtype

class ShortStorage(_LegacyStorage):
    """Legacy ``torch.short`` storage: ``dtype`` warns, ``_dtype`` does not."""

    @classproperty
    def _dtype(self):
        # Internal accessor used inside torch; never warns.
        return torch.short

    @classproperty
    def dtype(self):
        # Public accessor: warn about TypedStorage deprecation, then delegate.
        _warn_typed_storage_removal()
        return self._dtype

class CharStorage(_LegacyStorage):
    """Legacy ``torch.int8`` storage: ``dtype`` warns, ``_dtype`` does not."""

    @classproperty
    def _dtype(self):
        # Internal accessor used inside torch; never warns.
        return torch.int8

    @classproperty
    def dtype(self):
        # Public accessor: warn about TypedStorage deprecation, then delegate.
        _warn_typed_storage_removal()
        return self._dtype

class BoolStorage(_LegacyStorage):
    """Legacy ``torch.bool`` storage: ``dtype`` warns, ``_dtype`` does not."""

    @classproperty
    def _dtype(self):
        # Internal accessor used inside torch; never warns.
        return torch.bool

    @classproperty
    def dtype(self):
        # Public accessor: warn about TypedStorage deprecation, then delegate.
        _warn_typed_storage_removal()
        return self._dtype

class BFloat16Storage(_LegacyStorage):
    """Legacy ``torch.bfloat16`` storage: ``dtype`` warns, ``_dtype`` does not."""

    @classproperty
    def _dtype(self):
        # Internal accessor used inside torch; never warns.
        return torch.bfloat16

    @classproperty
    def dtype(self):
        # Public accessor: warn about TypedStorage deprecation, then delegate.
        _warn_typed_storage_removal()
        return self._dtype

class ComplexDoubleStorage(_LegacyStorage):
    """Legacy ``torch.cdouble`` storage: ``dtype`` warns, ``_dtype`` does not."""

    @classproperty
    def _dtype(self):
        # Internal accessor used inside torch; never warns.
        return torch.cdouble

    @classproperty
    def dtype(self):
        # Public accessor: warn about TypedStorage deprecation, then delegate.
        _warn_typed_storage_removal()
        return self._dtype

class ComplexFloatStorage(_LegacyStorage):
    """Legacy ``torch.cfloat`` storage: ``dtype`` warns, ``_dtype`` does not."""

    @classproperty
    def _dtype(self):
        # Internal accessor used inside torch; never warns.
        return torch.cfloat

    @classproperty
    def dtype(self):
        # Public accessor: warn about TypedStorage deprecation, then delegate.
        _warn_typed_storage_removal()
        return self._dtype

class QUInt8Storage(_LegacyStorage):
    """Legacy ``torch.quint8`` storage: ``dtype`` warns, ``_dtype`` does not."""

    @classproperty
    def _dtype(self):
        # Internal accessor used inside torch; never warns.
        return torch.quint8

    @classproperty
    def dtype(self):
        # Public accessor: warn about TypedStorage deprecation, then delegate.
        _warn_typed_storage_removal()
        return self._dtype

class QInt8Storage(_LegacyStorage):
    """Legacy ``torch.qint8`` storage: ``dtype`` warns, ``_dtype`` does not."""

    @classproperty
    def _dtype(self):
        # Internal accessor used inside torch; never warns.
        return torch.qint8

    @classproperty
    def dtype(self):
        # Public accessor: warn about TypedStorage deprecation, then delegate.
        _warn_typed_storage_removal()
        return self._dtype

class QInt32Storage(_LegacyStorage):
    """Legacy ``torch.qint32`` storage: ``dtype`` warns, ``_dtype`` does not."""

    @classproperty
    def _dtype(self):
        # Internal accessor used inside torch; never warns.
        return torch.qint32

    @classproperty
    def dtype(self):
        # Public accessor: warn about TypedStorage deprecation, then delegate.
        _warn_typed_storage_removal()
        return self._dtype

class QUInt4x2Storage(_LegacyStorage):
    """Legacy ``torch.quint4x2`` storage: ``dtype`` warns, ``_dtype`` does not."""

    @classproperty
    def _dtype(self):
        # Internal accessor used inside torch; never warns.
        return torch.quint4x2

    @classproperty
    def dtype(self):
        # Public accessor: warn about TypedStorage deprecation, then delegate.
        _warn_typed_storage_removal()
        return self._dtype

class QUInt2x4Storage(_LegacyStorage):
    """Legacy ``torch.quint2x4`` storage: ``dtype`` warns, ``_dtype`` does not."""

    @classproperty
    def _dtype(self):
        # Internal accessor used inside torch; never warns.
        return torch.quint2x4

    @classproperty
    def dtype(self):
        # Public accessor: warn about TypedStorage deprecation, then delegate.
        _warn_typed_storage_removal()
        return self._dtype

_storage_classes = {
Expand Down
2 changes: 1 addition & 1 deletion torch/_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def persistent_id(obj):
if isinstance(obj, torch.storage.TypedStorage):
# TODO: Once we decide to break serialization FC, we can
# remove this case
storage = obj._storage
storage = obj._untyped_storage
dtype = obj.dtype
else:
storage = obj
Expand Down
2 changes: 1 addition & 1 deletion torch/_dynamo/optimizations/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, *args, **kwargs):

def tensor_alias_group(self, value: torch.Tensor):
"""Assign a unique identifier to the storage of a given tensor"""
storage = StorageWeakRef(value.storage())
storage = StorageWeakRef(value._typed_storage())
alias_group = self.storage_to_alias_group.get(storage)
if alias_group is None:
alias_group = next(self.make_alias_group)
Expand Down
4 changes: 2 additions & 2 deletions torch/_dynamo/optimizations/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,15 @@ def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
for name, p in target.named_parameters():
param = target.get_parameter(name)
if p.requires_grad and not self._ignore_parameter(param):
buckets[0].size += p.storage().nbytes()
buckets[0].size += p._storage().nbytes()
buckets[0].params.append(f"{node.target}_{name}")
buckets[0].param_ids.append(id(param))
elif node.op == "get_attr":
maybe_param = getattr(gm, node.target)
if maybe_param.requires_grad and not self._ignore_parameter(
maybe_param
):
buckets[0].size += maybe_param.storage().nbytes()
buckets[0].size += maybe_param._storage().nbytes()
buckets[0].params.append(node.target)
buckets[0].param_ids.append(id(maybe_param))

Expand Down
4 changes: 2 additions & 2 deletions torch/_dynamo/optimizations/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def meta_fk(meta):
mutated_inputs = set()
for n in g.nodes:
if n.op == "placeholder":
inputs[StorageWeakRef(meta_fk(n.meta).storage())].add(input_idx)
inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
input_idx += 1
elif n.op == "call_function":
if n.target is operator.getitem:
Expand All @@ -402,7 +402,7 @@ def meta_fk(meta):
# TODO: not correct for args that contain tensors in a struct
# like list
mutated_inputs |= inputs[
StorageWeakRef(meta_fk(argument.meta).storage())
StorageWeakRef(meta_fk(argument.meta)._typed_storage())
]
# TODO: error on unrecognized nodes
return mutated_inputs
Expand Down
4 changes: 3 additions & 1 deletion torch/_prims/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1158,7 +1158,9 @@ def _as_strided_meta(
# as_strided to shapes with no elements are trivially valid, so it's OK
pass
elif isinstance(a, torch.Tensor):
utils.check_in_bounds_for_storage(a.storage(), size, stride, storage_offset)
utils.check_in_bounds_for_storage(
a._typed_storage(), size, stride, storage_offset
)

return TensorMeta(a, shape=size, strides=stride)

Expand Down
8 changes: 4 additions & 4 deletions torch/_subclasses/fake_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def add_constant_storage_mapping(self, fake_tensor):
# const_tensor.add_(torch.rand([1]))
# all aliases of it must become no longer const
assert isinstance(fake_tensor, FakeTensor) and fake_tensor.constant is not None
weak_st = StorageWeakRef(fake_tensor.constant.storage())
weak_st = StorageWeakRef(fake_tensor.constant._typed_storage())

# we need a map from a weak storage to all of its corresponding
# constant tensors. python doesn't have the weak value equivalent
Expand All @@ -168,7 +168,7 @@ def add_constant_storage_mapping(self, fake_tensor):
def invalidate_constant_aliases(self, tensor):
assert not isinstance(tensor, FakeTensor)

weak_st = StorageWeakRef(tensor.storage())
weak_st = StorageWeakRef(tensor._typed_storage())
if weak_st not in self.constant_storage_mapping:
return

Expand Down Expand Up @@ -1042,7 +1042,7 @@ def to_real_tensor(e):
for e in tree_flatten((args, kwargs))[0]:
if isinstance(e, torch.Tensor):
if not e.is_sparse:
storages.add(e.storage()._cdata)
storages.add(e._typed_storage()._cdata)

# TODO: also check metadata change on inputs
# proper aliasing/metadata relationship between outputs and inputs will
Expand All @@ -1052,7 +1052,7 @@ def to_real_tensor(e):
if id(e) not in inp_impls and (
isinstance(e, torch.Tensor)
and not e.is_sparse
and e.storage()._cdata in storages
and e._typed_storage()._cdata in storages
):
raise orig_not_implemented_exception

Expand Down
Loading

0 comments on commit 9ec684c

Please sign in to comment.