From 9f654a8ce3ebab4cfbb3b0aa7606b9d6c8d5b7ee Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Tue, 27 Sep 2022 14:35:15 -0500 Subject: [PATCH] Avoid raising warnings for internal TypedStorage usage --- test/test_cuda.py | 2 +- test/test_torch.py | 71 ++++++++ torch/__init__.py | 68 +++++++ torch/_deploy.py | 2 +- torch/_prims/__init__.py | 2 +- torch/_subclasses/fake_tensor.py | 8 +- torch/_subclasses/meta_utils.py | 4 +- torch/_tensor.py | 32 +++- torch/_utils.py | 4 +- torch/csrc/DynamicTypes.cpp | 2 +- torch/cuda/__init__.py | 48 +++++ torch/cuda/_dynamo_graphs.py | 4 +- .../_shard/checkpoint/filesystem.py | 2 +- torch/distributed/fsdp/_utils.py | 10 +- torch/distributed/fsdp/flat_param.py | 12 +- torch/distributed/fsdp/utils.py | 10 +- .../pipeline/sync/_balance/profile.py | 2 +- torch/distributed/pipeline/sync/stream.py | 2 +- torch/fx/passes/reinplace.py | 24 +-- torch/multiprocessing/reductions.py | 14 +- torch/package/package_exporter.py | 2 +- torch/package/package_importer.py | 6 +- torch/serialization.py | 37 ++-- torch/storage.py | 169 ++++++++++++------ torch/utils/bundled_inputs.py | 4 +- torch/utils/data/_utils/collate.py | 2 +- 26 files changed, 403 insertions(+), 140 deletions(-) diff --git a/test/test_cuda.py b/test/test_cuda.py index 29556fa140a1f..93c43af115139 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -570,7 +570,7 @@ def test_serialization_array_with_storage(self): self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor)) self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor)) self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage)) - self.assertTrue(isinstance(q_copy[3]._storage, torch.UntypedStorage)) + self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage)) q_copy[1].fill_(10) self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10)) diff --git a/test/test_torch.py b/test/test_torch.py index a39e1936fbed0..5ac9298d7a189 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -6430,6 +6430,77 @@ def test_storage_casts(self): self.assertEqual(complexdouble_storage.type(), 'torch.ComplexDoubleStorage') self.assertIs(complexdouble_storage.dtype, torch.complex128) + # Test that internal versions of functions related to TypedStorage do not + # produce a deprecation warning + def test_typed_storage_internal_no_warning(self): + s0 = torch.FloatStorage(10) + s0_untyped = s0.untyped() + t0 = torch.randn(10) + + funcs = [ + lambda: torch.FloatStorage(_internal=True), + lambda: torch.TypedStorage( + dtype=torch.float, + device='cpu', + _internal=True), + lambda: torch.TypedStorage( + wrap_storage=s0_untyped, + dtype=s0.dtype, + _internal=True), + lambda: torch.FloatStorage._dtype, + lambda: s0._resize_(20), + lambda: s0._size(), + lambda: s0._untyped_storage, + lambda: s0._is_shared(), + lambda: s0._share_memory_(), + lambda: s0._pickle_storage_type(), + lambda: s0._setitem(slice(0, s1._size()), 1), + lambda: s0._element_size(), + lambda: s0._deepcopy({}), + lambda: s0._data_ptr(), + lambda: s0._nbytes(), + lambda: t0._typed_storage(), + ] + + if torch.cuda.is_available(): + s1 = torch.cuda.FloatStorage(10) + s1_untyped = s1.untyped() + t1 = torch.randn(10, device='cuda') + + funcs += [ + lambda: torch.cuda.FloatStorage(_internal=True), + lambda: torch.TypedStorage( + dtype=torch.float, + device='cuda', + _internal=True), + lambda: torch.TypedStorage( + wrap_storage=s1_untyped, + dtype=s1.dtype, + _internal=True), + lambda: torch.cuda.FloatStorage._dtype, + lambda: s1._resize_(20), + lambda: s1._size(), + 
lambda: s1._untyped_storage, + lambda: s1._is_shared(), + lambda: s1._share_memory_(), + lambda: s1._pickle_storage_type(), + lambda: s1._setitem(slice(0, s1._size()), 1), + lambda: s1._element_size(), + lambda: s1._deepcopy({}), + lambda: s1._data_ptr(), + lambda: s1._nbytes(), + lambda: t1._typed_storage(), + ] + + # Check that each of the TypedStorage internal function calls do not + # produce a deprecation warning + for f in funcs: + with warnings.catch_warnings(): + warnings.filterwarnings('error', "TypedStorage is deprecated") + f() + + # Test that public functions related to TypedStorage produce a deprecation + # warning def test_typed_storage_deprecation_warning(self): s0 = torch.FloatStorage(10) funcs = [ diff --git a/torch/__init__.py b/torch/__init__.py index 88fb039d8053e..5f2a0eba11ed5 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -656,102 +656,170 @@ class ByteStorage(_LegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.uint8 class DoubleStorage(_LegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.double class FloatStorage(_LegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.float class HalfStorage(_LegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.half class LongStorage(_LegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.long class IntStorage(_LegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.int class ShortStorage(_LegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.short class CharStorage(_LegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.int8 class BoolStorage(_LegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.bool class BFloat16Storage(_LegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.bfloat16 class ComplexDoubleStorage(_LegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.cdouble class ComplexFloatStorage(_LegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.cfloat class QUInt8Storage(_LegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.quint8 class QInt8Storage(_LegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.qint8 class QInt32Storage(_LegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return 
torch.qint32 class QUInt4x2Storage(_LegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.quint4x2 class QUInt2x4Storage(_LegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.quint2x4 _storage_classes = { diff --git a/torch/_deploy.py b/torch/_deploy.py index 53769538b6c11..30c022eac8793 100644 --- a/torch/_deploy.py +++ b/torch/_deploy.py @@ -23,7 +23,7 @@ def persistent_id(obj): if isinstance(obj, torch.storage.TypedStorage): # TODO: Once we decide to break serialization FC, we can # remove this case - storage = obj._storage + storage = obj._untyped_storage dtype = obj.dtype else: storage = obj diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index a52d1e1388e4b..3db17eec09e43 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -1104,7 +1104,7 @@ def _as_strided_meta( # as_strided to shapes with no elements are trivially valid, so it's OK pass elif isinstance(a, torch.Tensor): - utils.check_in_bounds_for_storage(a.storage(), size, stride, storage_offset) + utils.check_in_bounds_for_storage(a._typed_storage(), size, stride, storage_offset) return TensorMeta(a, shape=size, strides=stride) diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index cd14c3b1aac66..21ec0c47ac636 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -144,7 +144,7 @@ def add_constant_storage_mapping(self, fake_tensor): # const_tensor.add_(torch.rand([1])) # all aliases of it must become no longer const assert isinstance(fake_tensor, FakeTensor) and fake_tensor.constant is not None - weak_st = StorageWeakRef(fake_tensor.constant.storage()) + weak_st = StorageWeakRef(fake_tensor.constant._typed_storage()) # we need a map from a weak storage to all of its corresponding # constant tensors. python doesn't have the weak value equivalent @@ -156,7 +156,7 @@ def add_constant_storage_mapping(self, fake_tensor): def invalidate_constant_aliases(self, tensor): assert not isinstance(tensor, FakeTensor) - weak_st = StorageWeakRef(tensor.storage()) + weak_st = StorageWeakRef(tensor._typed_storage()) if weak_st not in self.constant_storage_mapping: return @@ -928,7 +928,7 @@ def to_real_tensor(e): for e in tree_flatten((args, kwargs))[0]: if isinstance(e, torch.Tensor): if not e.is_sparse: - storages.add(e.storage()._cdata) + storages.add(e._typed_storage()._cdata) # TODO: also check metadata change on inputs # proper aliasing/metadata relationship between outputs and inputs will @@ -938,7 +938,7 @@ def to_real_tensor(e): if id(e) not in inp_impls and ( isinstance(e, torch.Tensor) and not e.is_sparse - and e.storage()._cdata in storages + and e._typed_storage()._cdata in storages ): raise orig_not_implemented_exception diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index 9d641ba458e70..e668da3334057 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -98,7 +98,7 @@ def set_tensor_memo(self, t, v): if t.is_sparse: weak_st = None else: - weak_st = StorageWeakRef(t.storage()) + weak_st = StorageWeakRef(t._typed_storage()) tensor_ref_key = WeakTensorRefKey(t) def del_ten(): @@ -255,7 +255,7 @@ def is_c_of_r(complex_dtype, real_dtype): # As long as meta storage is not supported, need to prevent # redispatching on set_(Storage, ...) 
which will choke with # meta storage - s = self.meta_storage(t.storage()) + s = self.meta_storage(t._typed_storage()) with no_dispatch(): with torch.no_grad(): r.set_(s, sym_storage_offset(t), sym_size(t), sym_stride(t)) diff --git a/torch/_tensor.py b/torch/_tensor.py index 96293e9732996..976ea36d3ed20 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -128,7 +128,7 @@ def __deepcopy__(self, memo): "different type." ) else: - new_storage = self.storage().__deepcopy__(memo) + new_storage = self._typed_storage()._deepcopy(memo) if self.is_quantized: # quantizer_params can be different type based on torch attribute quantizer_params: Union[ @@ -159,7 +159,9 @@ def __deepcopy__(self, memo): # need to wrap with TypedStorage new_tensor = torch._utils._rebuild_qtensor( torch.storage.TypedStorage( - wrap_storage=new_storage.untyped(), dtype=self.dtype + wrap_storage=new_storage._untyped_storage, + dtype=self.dtype, + _internal=True ), self.storage_offset(), self.size(), @@ -253,7 +255,15 @@ def storage(self): if has_torch_function_unary(self): return handle_torch_function(Tensor.storage, (self,), self) - return torch.TypedStorage(wrap_storage=self._storage(), dtype=self.dtype) + torch.storage._warn_typed_storage_removal() + return self._typed_storage() + + # For internal use only, to avoid raising deprecation warning + def _typed_storage(self): + return torch.TypedStorage( + wrap_storage=self._storage(), + dtype=self.dtype, + _internal=True) def _reduce_ex_internal(self, proto): check_serializing_named_tensor(self) @@ -325,7 +335,9 @@ def _reduce_ex_internal(self, proto): # need to wrap with TypedStorage args_qtensor = ( torch.storage.TypedStorage( - wrap_storage=self.storage().untyped(), dtype=self.dtype + wrap_storage=self._typed_storage()._untyped_storage, + dtype=self.dtype, + _internal=True ), self.storage_offset(), tuple(self.size()), @@ -383,7 +395,9 @@ def _reduce_ex_internal(self, proto): # need to wrap with TypedStorage args = ( torch.storage.TypedStorage( - wrap_storage=self.storage().untyped(), dtype=self.dtype + wrap_storage=self._typed_storage()._untyped_storage, + dtype=self.dtype, + _internal=True ), self.storage_offset(), tuple(self.size()), @@ -601,7 +615,7 @@ def is_shared(self): """ if has_torch_function_unary(self): return handle_torch_function(Tensor.is_shared, (self,), self) - return self.storage().is_shared() + return self._typed_storage()._is_shared() def share_memory_(self): r"""Moves the underlying storage to shared memory. @@ -611,7 +625,7 @@ def share_memory_(self): """ if has_torch_function_unary(self): return handle_torch_function(Tensor.share_memory_, (self,), self) - self.storage().share_memory_() + self._typed_storage()._share_memory_() return self def __reversed__(self): @@ -1053,7 +1067,9 @@ def storage_type(self): if has_torch_function_unary(self): return handle_torch_function(Tensor.storage_type, (self,), self) - return self.storage()._get_legacy_storage_class() + torch.storage._warn_typed_storage_removal() + + return self._typed_storage()._get_legacy_storage_class() def refine_names(self, *names): r"""Refines the dimension names of :attr:`self` according to :attr:`names`. 
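The hunks above in torch/__init__.py, torch/cuda/__init__.py, and torch/_tensor.py all apply one pattern: the public entry point (the legacy storage `dtype` classproperty, `Tensor.storage()`) calls `_warn_typed_storage_removal()` once and then defers to an underscore-prefixed twin (`_dtype`, `Tensor._typed_storage()`) that skips the warning, so internal call sites can switch to the twin without triggering deprecation noise. Below is a minimal standalone sketch of that split; the names (`FakeFloatStorage`, `FakeTensor`), the simplified `classproperty` helper, and the warning text are illustrative stand-ins, not the real torch definitions.

import warnings


def _warn_typed_storage_removal(stacklevel=2):
    # Assumed message text; the real helper lives in torch/storage.py.
    warnings.warn(
        "TypedStorage is deprecated and will be removed in the future.",
        UserWarning,
        stacklevel=stacklevel + 1,
    )


class classproperty:
    # Simplified read-only class-level property, standing in for the helper
    # the legacy storage classes use for their `dtype` attribute.
    def __init__(self, fget):
        self.fget = fget

    def __get__(self, obj, owner=None):
        return self.fget(owner if owner is not None else type(obj))


class FakeFloatStorage:
    @classproperty
    def dtype(cls):
        _warn_typed_storage_removal()   # public spelling: warns once...
        return cls._dtype               # ...then defers to the twin

    @classproperty
    def _dtype(cls):
        return "float32"                # internal spelling: stays silent


class FakeTensor:
    def storage(self):
        _warn_typed_storage_removal()   # public accessor warns
        return self._typed_storage()

    def _typed_storage(self):
        # Internal accessor, analogous to the Tensor._typed_storage() added
        # above; here it just returns a placeholder object.
        return object()


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    FakeFloatStorage.dtype         # warns
    FakeFloatStorage._dtype        # silent
    FakeTensor().storage()         # warns
    FakeTensor()._typed_storage()  # silent
    assert len(caught) == 2
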
diff --git a/torch/_utils.py b/torch/_utils.py index 8a539d75f5657..f178cfbaea4ae 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -143,8 +143,8 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs): # be a TypedStorage def _rebuild_tensor(storage, storage_offset, size, stride): # first construct a tensor with the correct dtype/device - t = torch.tensor([], dtype=storage.dtype, device=storage.untyped().device) - return t.set_(storage.untyped(), storage_offset, size, stride) + t = torch.tensor([], dtype=storage.dtype, device=storage._untyped_storage.device) + return t.set_(storage._untyped_storage, storage_offset, size, stride) def _rebuild_tensor_v2( diff --git a/torch/csrc/DynamicTypes.cpp b/torch/csrc/DynamicTypes.cpp index b3021ffe0d8d8..93bb37017ce0b 100644 --- a/torch/csrc/DynamicTypes.cpp +++ b/torch/csrc/DynamicTypes.cpp @@ -135,7 +135,7 @@ at::Storage createStorageGetType( TORCH_INTERNAL_ASSERT(THPDtype_Check(dtype_obj)); scalar_type = reinterpret_cast(dtype_obj)->scalar_type; - untyped_storage_obj = PyObject_GetAttrString(obj, "_storage"); + untyped_storage_obj = PyObject_GetAttrString(obj, "_untyped_storage"); TORCH_INTERNAL_ASSERT(untyped_storage_obj); Py_DECREF(untyped_storage_obj); diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index d7c83ebfc739b..6d301357b2012 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -744,72 +744,120 @@ class ByteStorage(_CudaLegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.uint8 class DoubleStorage(_CudaLegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.double class FloatStorage(_CudaLegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.float class HalfStorage(_CudaLegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.half class LongStorage(_CudaLegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.long class IntStorage(_CudaLegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.int class ShortStorage(_CudaLegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.short class CharStorage(_CudaLegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.int8 class BoolStorage(_CudaLegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.bool class BFloat16Storage(_CudaLegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.bfloat16 class ComplexDoubleStorage(_CudaLegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.cdouble class ComplexFloatStorage(_CudaLegacyStorage): @classproperty def dtype(self): _warn_typed_storage_removal() + return self._dtype + + 
@classproperty + def _dtype(self): return torch.cfloat del _LegacyStorage diff --git a/torch/cuda/_dynamo_graphs.py b/torch/cuda/_dynamo_graphs.py index 56973e96435fc..c493c584c07b3 100644 --- a/torch/cuda/_dynamo_graphs.py +++ b/torch/cuda/_dynamo_graphs.py @@ -89,7 +89,7 @@ def find_input_mutations(g): mutated_inputs = set() for n in g.nodes: if n.op == 'placeholder': - inputs[StorageWeakRef(n.meta[FK].storage())].add(input_idx) + inputs[StorageWeakRef(n.meta[FK]._typed_storage())].add(input_idx) input_idx += 1 elif n.op == 'call_function': if n.target is operator.getitem: @@ -109,7 +109,7 @@ def find_input_mutations(g): if mut_arg: # TODO: not correct for args that contain tensors in a struct # like list - mutated_inputs |= inputs[StorageWeakRef(argument.meta[FK].storage())] + mutated_inputs |= inputs[StorageWeakRef(argument.meta[FK]._typed_storage())] # TODO: error on unrecognized nodes return mutated_inputs diff --git a/torch/distributed/_shard/checkpoint/filesystem.py b/torch/distributed/_shard/checkpoint/filesystem.py index ece9000b3ddfb..609c567e1e93e 100644 --- a/torch/distributed/_shard/checkpoint/filesystem.py +++ b/torch/distributed/_shard/checkpoint/filesystem.py @@ -51,7 +51,7 @@ class _StoragePrefix: def _trim(tensor: torch.Tensor) -> torch.Tensor: tensor = tensor.detach().cpu() - if tensor.storage().size() != tensor.numel(): + if tensor._typed_storage().size() != tensor.numel(): tensor = tensor.clone() return tensor diff --git a/torch/distributed/fsdp/_utils.py b/torch/distributed/fsdp/_utils.py index 80688e5dec03d..1a1d5a3e25d26 100644 --- a/torch/distributed/fsdp/_utils.py +++ b/torch/distributed/fsdp/_utils.py @@ -96,14 +96,14 @@ def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool: bool: ``True`` if this method allocated storage and ``False`` if the storage was already allocated. """ - already_allocated = tensor.storage().size() == size.numel() + already_allocated = tensor._typed_storage().size() == size.numel() if not already_allocated: - tensor_storage_size = tensor.storage().size() + tensor_storage_size = tensor._typed_storage().size() p_assert( tensor_storage_size == 0, f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}", ) - tensor.storage().resize_(size.numel()) + tensor._typed_storage().resize_(size.numel()) return not already_allocated @@ -116,13 +116,13 @@ def _free_storage(tensor: torch.Tensor) -> bool: bool: ``True`` if the method freed the storage and ``False`` if the storage was already freed. 
""" - already_freed = tensor.storage().size() == 0 + already_freed = tensor._typed_storage().size() == 0 if not already_freed: p_assert( tensor.storage_offset() == 0, "Freeing a tensor's storage is unsafe when it is not the sole occupant", ) - tensor.storage().resize_(0) + tensor._typed_storage().resize_(0) return not already_freed diff --git a/torch/distributed/fsdp/flat_param.py b/torch/distributed/fsdp/flat_param.py index 1fd1f277906db..c1960ce4e6605 100644 --- a/torch/distributed/fsdp/flat_param.py +++ b/torch/distributed/fsdp/flat_param.py @@ -416,7 +416,7 @@ def shard(self, process_group: dist.ProcessGroup): assert ( flat_param.storage_offset() == 0 ), "The `FlatParameter` is not the sole occupant of its storage" - orig_storage = flat_param.storage() + orig_storage = flat_param._typed_storage() local_shard, numel_padded = FlatParamHandle._get_shard( flat_param, self.rank, self.world_size ) @@ -680,7 +680,7 @@ def needs_unshard(self) -> bool: return False unsharded_flat_param = self._get_padded_unsharded_flat_param() already_unsharded = ( - unsharded_flat_param.storage().size() == unsharded_flat_param.numel() + unsharded_flat_param._typed_storage()._size() == unsharded_flat_param.numel() ) return not already_unsharded @@ -857,9 +857,9 @@ def to_cpu(self): # the padded unsharded flattened parameter as expected # NOTE: This check is not strictly needed for correctness but is a # useful sanity check since the tensor should only be used internally. - unpadded_storage_ptr = self.flat_param.storage().data_ptr() + unpadded_storage_ptr = self.flat_param._typed_storage()._data_ptr() padded_storage_ptr = ( - self._get_padded_unsharded_flat_param().storage().data_ptr() + self._get_padded_unsharded_flat_param()._typed_storage()._data_ptr() ) p_assert( unpadded_storage_ptr == padded_storage_ptr, @@ -1067,7 +1067,7 @@ def _check_on_compute_device(self, tensor: Tensor): @staticmethod def _check_storage_freed(tensor: Tensor): - storage_size: int = tensor.storage().size() + storage_size: int = tensor._typed_storage()._size() p_assert( storage_size == 0, f"Expects storage to be freed but got storage with size {storage_size}", @@ -1075,7 +1075,7 @@ def _check_storage_freed(tensor: Tensor): @staticmethod def _check_storage_allocated(tensor: Tensor): - storage_size: int = tensor.storage().size() + storage_size: int = tensor._typed_storage()._size() p_assert(storage_size > 0, "Expects storage to be allocated") def _check_low_precision_shard(self): diff --git a/torch/distributed/fsdp/utils.py b/torch/distributed/fsdp/utils.py index 9f50ec87e2e38..51849cc6de767 100644 --- a/torch/distributed/fsdp/utils.py +++ b/torch/distributed/fsdp/utils.py @@ -96,14 +96,14 @@ def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool: bool: ``True`` if this method allocated storage and ``False`` if the storage was already allocated. """ - already_allocated = tensor.storage().size() == size.numel() + already_allocated = tensor._typed_storage()._size() == size.numel() if not already_allocated: - tensor_storage_size = tensor.storage().size() + tensor_storage_size = tensor._typed_storage()._size() p_assert( tensor_storage_size == 0, f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}", ) - tensor.storage().resize_(size.numel()) + tensor._typed_storage()._resize_(size.numel()) return not already_allocated @@ -116,13 +116,13 @@ def _free_storage(tensor: torch.Tensor) -> bool: bool: ``True`` if the method freed the storage and ``False`` if the storage was already freed. 
""" - already_freed = tensor.storage().size() == 0 + already_freed = tensor._typed_storage()._size() == 0 if not already_freed: p_assert( tensor.storage_offset() == 0, "Freeing a tensor's storage is unsafe when it is not the sole occupant", ) - tensor.storage().resize_(0) + tensor._typed_storage()._resize_(0) return not already_freed diff --git a/torch/distributed/pipeline/sync/_balance/profile.py b/torch/distributed/pipeline/sync/_balance/profile.py index 6b8a240a2cdfe..1dbd60898eb34 100644 --- a/torch/distributed/pipeline/sync/_balance/profile.py +++ b/torch/distributed/pipeline/sync/_balance/profile.py @@ -105,7 +105,7 @@ def profile_sizes( latent_size = memory_after - memory_before # Analyze size of parameters. - param_size = sum(p.storage().nbytes() for p in layer.parameters()) + param_size = sum(p._typed_storage()._nbytes() for p in layer.parameters()) # Combine size of parameters and activations with normalize scales. size = latent_size * latent_scale + param_size * param_scale diff --git a/torch/distributed/pipeline/sync/stream.py b/torch/distributed/pipeline/sync/stream.py index 41e1591793b6c..0a9e788731041 100644 --- a/torch/distributed/pipeline/sync/stream.py +++ b/torch/distributed/pipeline/sync/stream.py @@ -102,7 +102,7 @@ def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None: # # Issue: https://github.com/pytorch/pytorch/issues/27366 # - tensor = tensor.new_empty([0]).set_(tensor.storage()) + tensor = tensor.new_empty([0]).set_(tensor._typed_storage()) # Typechecking: torch.cuda.Stream is incompatible with torch._C.Stream tensor.record_stream(as_cuda(stream)) # type: ignore[arg-type] diff --git a/torch/fx/passes/reinplace.py b/torch/fx/passes/reinplace.py index ff24ef97f5459..c32b48841437c 100644 --- a/torch/fx/passes/reinplace.py +++ b/torch/fx/passes/reinplace.py @@ -100,8 +100,8 @@ def run_node(self, node: Node): # Assert here that this is actually the case, and their storages are the same. assert isinstance(node.meta['fake_result'], FakeTensor) assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor) - view_storage = StorageWeakRef(node.meta['fake_result'].storage()) - base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result'].storage()) + view_storage = StorageWeakRef(node.meta['fake_result']._typed_storage()) + base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result']._typed_storage()) assert view_storage == base_storage return result @@ -176,7 +176,7 @@ def _maybe_get_inplace_op(op): def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int): def _add_if_tensor(x, set_): if isinstance(x, FakeTensor): - set_.add(StorageWeakRef(x.storage())) + set_.add(StorageWeakRef(x._typed_storage())) nodes_used_after = set() for t in tensor_aliases: @@ -452,7 +452,7 @@ def f(x): # Useful debug printing # def _print(x): # if isinstance(x, FakeTensor): - # print(f'fake_result: {StorageWeakRef(x.storage()).cdata}') + # print(f'fake_result: {StorageWeakRef(x._typed_storage()).cdata}') # for n in gm.graph.nodes: # print(n.format_node()) @@ -468,7 +468,7 @@ def f(x): # so we know not to re-inplace them. # NOTE: later, we'll need to add an optimization for fully recovering performance # on programs that mutate inputs. 
- input_storages = set(StorageWeakRef(node.meta['fake_result'].storage()) for node in gm.graph.nodes if node.op == 'placeholder') + input_storages = set(StorageWeakRef(node.meta['fake_result']._typed_storage()) for node in gm.graph.nodes if node.op == 'placeholder') # We also need to know for a given node, what are all of its aliasing nodes. @@ -478,7 +478,7 @@ def f(x): # Tree-mapping because some ops can return lists of tensors. def _add_to_map(x): if isinstance(x, FakeTensor): - storage_to_nodes[StorageWeakRef(x.storage())].add(n) + storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n) tree_map(_add_to_map, n.meta['fake_result']) # inplace-ify functional ops, subject to the constraints written below. @@ -529,7 +529,7 @@ def _add_to_map(x): # Step 1b: ensure that the op we're trying to re-inplace isn't a program input self_arg_name = self_arg.name - self_arg_storage = StorageWeakRef(self_arg.meta['fake_result'].storage()) + self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage()) if self_arg_storage in input_storages: # TODO: later, add the optimization for handling `copy_()` calls in the graph. continue @@ -539,7 +539,7 @@ def _add_to_map(x): # so we prevent re-inplacing in this case. continue - self_arg_storage = StorageWeakRef(self_arg.meta['fake_result'].storage()) + self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage()) self_aliases = storage_to_nodes[self_arg_storage] # First, we find all later usages of any of the aliases of self_arg. @@ -594,7 +594,7 @@ def _add_to_map(x): # Hmm... morally I think we also want to keep the `fake_result` metadata # up to date here, but I'm not sure how easy it is to do. # Maybe it's fine to wait until the end of the pass to update it. - curr_node_storage = StorageWeakRef(node.meta['fake_result'].storage()) + curr_node_storage = StorageWeakRef(node.meta['fake_result']._typed_storage()) storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage]) storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage]) @@ -624,8 +624,8 @@ def replace_arg(a): old_flattened_res, _ = tree_flatten(old.meta['fake_result']) node_flattened_res, _ = tree_flatten(node_to_update.meta['fake_result']) - old_res_storage = set(StorageWeakRef(x.storage()) for x in old_flattened_res if isinstance(x, FakeTensor)) - node_res_storage = set(StorageWeakRef(x.storage()) for x in node_flattened_res if isinstance(x, FakeTensor)) + old_res_storage = set(StorageWeakRef(x._typed_storage()) for x in old_flattened_res if isinstance(x, FakeTensor)) + node_res_storage = set(StorageWeakRef(x._typed_storage()) for x in node_flattened_res if isinstance(x, FakeTensor)) # This will happen if we're updating a view op, e.g. # e.g. replacing @@ -639,7 +639,7 @@ def replace_arg(a): # We can't just check equality because we might encounter FX nodes that return zero tensor outputs. 
if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage: new_flattened_res, _ = tree_flatten(new.meta['fake_result']) - new_res_storage = set(StorageWeakRef(x.storage()) for x in new_flattened_res if isinstance(x, FakeTensor)) + new_res_storage = set(StorageWeakRef(x._typed_storage()) for x in new_flattened_res if isinstance(x, FakeTensor)) assert len(new_res_storage) == 1 (old_ref,) = old_res_storage (new_ref,) = new_res_storage diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index 403b28d6a63c6..82c283e0ea7ca 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -133,7 +133,7 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset, storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset, device=storage_device) t = torch._utils._rebuild_tensor( - torch.storage.TypedStorage(wrap_storage=storage.untyped(), dtype=dtype), + torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype, _internal=True), tensor_offset, tensor_size, tensor_stride) if tensor_cls == torch.nn.parameter.Parameter: @@ -147,7 +147,7 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset, def reduce_tensor(tensor): - storage = tensor.storage() + storage = tensor._typed_storage() if tensor.requires_grad and not tensor.is_leaf: raise RuntimeError("Cowardly refusing to serialize non-leaf tensor which requires_grad, " @@ -248,7 +248,7 @@ def reduce_tensor(tensor): # eliminated it so that we could just use tensor views to implement the same # thing. # - if storage.is_cuda: + if storage._untyped_storage.device.type == 'cuda': (device, handle, storage_size_bytes, @@ -334,18 +334,18 @@ def rebuild_storage_empty(cls): return cls() def rebuild_typed_storage(storage, dtype): - return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype) + return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype, _internal=True) # Use for torch.storage.TypedStorage def reduce_typed_storage(storage): - return (rebuild_typed_storage, (storage._storage, storage.dtype)) + return (rebuild_typed_storage, (storage._untyped_storage, storage.dtype)) def rebuild_typed_storage_child(storage, storage_type): - return storage_type(wrap_storage=storage) + return storage_type(wrap_storage=storage, _internal=True) # Use for child classes of torch.storage.TypedStorage, like torch.FloatStorage def reduce_typed_storage_child(storage): - return (rebuild_typed_storage_child, (storage._storage, type(storage))) + return (rebuild_typed_storage_child, (storage._untyped_storage, type(storage))) def reduce_storage(storage): from . 
import get_sharing_strategy diff --git a/torch/package/package_exporter.py b/torch/package/package_exporter.py index 81b5e650b518b..fbe3cc5a73cde 100644 --- a/torch/package/package_exporter.py +++ b/torch/package/package_exporter.py @@ -887,7 +887,7 @@ def _persistent_id(self, obj): if isinstance(obj, torch.storage.TypedStorage): # TODO: Once we decide to break serialization FC, we can # remove this case - untyped_storage = obj._storage + untyped_storage = obj._untyped_storage storage_type_str = obj.pickle_storage_type() storage_type = getattr(torch, storage_type_str) storage_numel = obj.size() diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py index 6efa943f11e7e..4b6d8fcfd6919 100644 --- a/torch/package/package_importer.py +++ b/torch/package/package_importer.py @@ -205,7 +205,7 @@ def load_tensor(dtype, size, key, location, restore_location): name = f"{key}.storage" if storage_context.has_storage(name): - storage = storage_context.get_storage(name, dtype).storage() + storage = storage_context.get_storage(name, dtype)._typed_storage() else: tensor = self.zip_reader.get_storage_from_record( ".data/" + name, size, dtype @@ -236,7 +236,9 @@ def persistent_load(saved_id): # TODO: Once we decide to break serialization FC, we can # stop wrapping with TypedStorage return torch.storage.TypedStorage( - wrap_storage=storage.untyped(), dtype=dtype + wrap_storage=storage._untyped_storage, + dtype=dtype, + _internal=True ) elif typename == "reduce_package": # to fix BC breaking change, objects on this load path diff --git a/torch/serialization.py b/torch/serialization.py index 04573a059f029..903b6d1ec5597 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -463,12 +463,12 @@ def persistent_id(obj: Any) -> Optional[Tuple]: if isinstance(obj, torch.storage.TypedStorage): # TODO: Once we decide to break serialization FC, this case # can be deleted - storage = obj._storage + storage = obj._untyped_storage storage_dtype = obj.dtype - storage_type_str = obj.pickle_storage_type() + storage_type_str = obj._pickle_storage_type() storage_type = getattr(torch, storage_type_str) dtype = obj.dtype - storage_numel = obj.size() + storage_numel = obj._size() elif isinstance(obj, torch.UntypedStorage): storage = obj @@ -591,11 +591,11 @@ def persistent_id(obj): if isinstance(obj, torch.storage.TypedStorage): # TODO: Once we decide to break serialization FC, this case # can be deleted - storage = obj._storage + storage = obj._untyped_storage storage_dtype = obj.dtype - storage_type_str = obj.pickle_storage_type() + storage_type_str = obj._pickle_storage_type() storage_type = getattr(torch, storage_type_str) - storage_numel = obj.size() + storage_numel = obj._size() else: storage = obj @@ -863,7 +863,8 @@ def persistent_load(saved_id): # stop wrapping with TypedStorage deserialized_objects[key] = torch.storage.TypedStorage( wrap_storage=obj, - dtype=dtype) + dtype=dtype, + _internal=True) storage_views = pickle_module.load(f, **pickle_load_args) for target_cdata, root_cdata, offset, numel in storage_views: @@ -873,8 +874,9 @@ def persistent_load(saved_id): # TODO: Once we decide to break serialization FC, we can # stop wrapping with TypedStorage deserialized_objects[target_cdata] = torch.storage.TypedStorage( - wrap_storage=root._storage[offset_bytes:offset_bytes + numel * element_size], - dtype=root.dtype) + wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel * element_size], + dtype=root.dtype, + _internal=True) tar.extract('tensors', path=tmpdir) with 
open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f: @@ -890,7 +892,7 @@ def persistent_load(saved_id): stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim)) storage_offset, = struct.unpack(' Union[T, str]: _warn_typed_storage_removal() @@ -578,22 +596,26 @@ def type(self, dtype: str = None, non_blocking: bool = False) -> Union[T, str]: return '.'.join([self.__module__, type(self).__name__]) else: - return self._storage.type(dtype, non_blocking) + return self._untyped_storage.type(dtype, non_blocking) def cuda(self, device=None, non_blocking=False, **kwargs) -> T: _warn_typed_storage_removal() if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]: raise RuntimeError("Cannot create CUDA storage with quantized dtype") - cuda_storage: torch.UntypedStorage = self._storage.cuda(device, non_blocking, **kwargs) + cuda_storage: torch.UntypedStorage = self._untyped_storage.cuda(device, non_blocking, **kwargs) return self._new_wrapped_storage(cuda_storage) def element_size(self): _warn_typed_storage_removal() + return self._element_size() + + # For internal use only, to avoid deprecation warning + def _element_size(self): return torch._utils._element_size(self.dtype) def get_device(self) -> int: _warn_typed_storage_removal() - return self._storage.get_device() + return self._untyped_storage.get_device() def __str__(self): _warn_typed_storage_removal() @@ -616,11 +638,15 @@ def __iter__(self): def __copy__(self): _warn_typed_storage_removal() - return self._new_wrapped_storage(copy.copy(self._storage)) + return self._new_wrapped_storage(copy.copy(self._untyped_storage)) def __deepcopy__(self, memo): _warn_typed_storage_removal() - return self._new_wrapped_storage(copy.deepcopy(self._storage, memo)) + return self._deepcopy(memo) + + # For internal use only, to avoid deprecation warning + def _deepcopy(self, memo): + return self._new_wrapped_storage(copy.deepcopy(self._untyped_storage, memo)) def __sizeof__(self): _warn_typed_storage_removal() @@ -629,7 +655,7 @@ def __sizeof__(self): def clone(self): """Returns a copy of this storage""" _warn_typed_storage_removal() - return self._new_wrapped_storage(self._storage.clone()) + return self._new_wrapped_storage(self._untyped_storage.clone()) def tolist(self): """Returns a list containing the elements of this storage""" @@ -639,12 +665,12 @@ def tolist(self): def cpu(self): """Returns a CPU copy of this storage if it's not already on the CPU""" _warn_typed_storage_removal() - return self._new_wrapped_storage(self._storage.cpu()) + return self._new_wrapped_storage(self._untyped_storage.cpu()) def pin_memory(self): """Coppies the storage to pinned memory, if it's not already pinned.""" _warn_typed_storage_removal() - return self._new_wrapped_storage(self._storage.pin_memory()) + return self._new_wrapped_storage(self._untyped_storage.pin_memory()) def share_memory_(self): """Moves the storage to shared memory. 
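The constructor-side half of the same idea: the call sites changed in this patch (the rebuild helpers in torch/multiprocessing/reductions.py, the serialization and package_importer paths, and Tensor._typed_storage()) construct TypedStorage with `_internal=True`, so only user-level construction warns. A standalone sketch of that gating, assuming an `_internal` keyword with a `False` default like the one those call sites pass; the class name and warning text are placeholders, not the real torch.storage.TypedStorage.

import warnings


def _warn_typed_storage_removal():
    # Assumed message text; the real helper lives in torch/storage.py.
    warnings.warn("TypedStorage is deprecated", UserWarning, stacklevel=3)


class TypedStorageSketch:
    # Stand-in for torch.storage.TypedStorage, only to show the gating.
    def __init__(self, data=(), *, _internal=False):
        if not _internal:
            _warn_typed_storage_removal()
        self._untyped_storage = list(data)

    # Public method: warns, then defers to the internal twin.
    def size(self):
        _warn_typed_storage_removal()
        return self._size()

    # Internal twin: what rebuild helpers, __deepcopy__, etc. call.
    def _size(self):
        return len(self._untyped_storage)


def _rebuild(data):
    # Mirrors helpers such as rebuild_typed_storage in
    # torch/multiprocessing/reductions.py: internal reconstruction passes
    # _internal=True and therefore stays silent.
    return TypedStorageSketch(data, _internal=True)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    s_user = TypedStorageSketch([0.0] * 4)   # user path: warns
    s_int = _rebuild([0.0] * 4)              # internal path: silent
    assert s_int._size() == 4                # silent
    assert s_user.size() == 4                # warns again
    assert len(caught) == 2
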
@@ -656,7 +682,11 @@ def share_memory_(self): Returns: self """ _warn_typed_storage_removal() - self._storage.share_memory_() + return self._share_memory_() + + # For internal use only, to avoid deprecation warning + def _share_memory_(self): + self._untyped_storage.share_memory_() return self def _new_shared(self, size, *, device=None): @@ -664,51 +694,67 @@ def _new_shared(self, size, *, device=None): if device is None: device = 'cpu' device = torch.device(device) - untyped_storage = torch.UntypedStorage._new_shared(size * self.element_size(), device=device) + untyped_storage = torch.UntypedStorage._new_shared(size * self._element_size(), device=device) return TypedStorage( wrap_storage=untyped_storage, - dtype=self.dtype) + dtype=self.dtype, + _internal=True) @property def _cdata(self): - return self._storage._cdata + return self._untyped_storage._cdata @property def device(self): _warn_typed_storage_removal() - return self._storage.device + return self._untyped_storage.device def size(self): _warn_typed_storage_removal() - return len(self) + return self._size() + + # For internal use only, to avoid deprecation warning + def _size(self): + return self._untyped_storage.nbytes() // self._element_size() def pickle_storage_type(self): _warn_typed_storage_removal() + return self._pickle_storage_type() + + # For internal use only, to avoid deprecation warning + def _pickle_storage_type(self): try: return _dtype_to_storage_type_map()[self.dtype] except KeyError: raise KeyError(f'dtype {self.dtype} is not recognized') def __reduce__(self): - _warn_typed_storage_removal() b = io.BytesIO() torch.save(self, b, _use_new_zipfile_serialization=False) return (_load_from_bytes, (b.getvalue(),)) def data_ptr(self): _warn_typed_storage_removal() - return self._storage.data_ptr() + return self._data_ptr() + + # For internal use only, to avoid deprecation warning + def _data_ptr(self): + return self._untyped_storage.data_ptr() def resize_(self, size): _warn_typed_storage_removal() - self._storage.resize_(size * self.element_size()) + self._resize_(size) + + # For internal use only, to avoid deprecation warning + def _resize_(self, size): + self._untyped_storage.resize_(size * self._element_size()) @classmethod def _free_weak_ref(cls, *args, **kwargs): return UntypedStorage._free_weak_ref(*args, **kwargs) def _weak_ref(self, *args, **kwargs): - return self._storage._weak_ref(*args, **kwargs) + return self._untyped_storage._weak_ref(*args, **kwargs) @classmethod def from_buffer(cls, *args, dtype=None, device=None, **kwargs): @@ -733,12 +779,15 @@ def from_buffer(cls, *args, dtype=None, device=None, **kwargs): dtype = cls.dtype untyped_storage = torch.UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs) - return TypedStorage(wrap_storage=untyped_storage, dtype=dtype) + return TypedStorage( + wrap_storage=untyped_storage, + dtype=dtype, + _internal=True) def _to(self, dtype): if not isinstance(dtype, torch.dtype): raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}") - storage = torch.tensor([], dtype=self.dtype, device=self.device).set_(self).to(dtype).storage() + storage = torch.tensor([], dtype=self.dtype, device=self.device).set_(self).to(dtype)._typed_storage() if storage.data_ptr() == self.data_ptr(): storage = storage.clone() return storage @@ -838,34 +887,38 @@ def _expired(cls, *args, **kwargs): def is_pinned(self): _warn_typed_storage_removal() - return self._storage.is_pinned() + return self._untyped_storage.is_pinned() def _write_file(self, *args, **kwargs): - return 
self._storage._write_file(*args, **kwargs) + return self._untyped_storage._write_file(*args, **kwargs) def _set_from_file(self, *args, **kwargs): - return self._storage._set_from_file(*args, **kwargs) + return self._untyped_storage._set_from_file(*args, **kwargs) def _set_cdata(self, *args, **kwargs): - return self._storage._set_cdata(*args, **kwargs) + return self._untyped_storage._set_cdata(*args, **kwargs) def _share_cuda_(self, *args, **kwargs): - return self._storage._share_cuda_(*args, **kwargs) + return self._untyped_storage._share_cuda_(*args, **kwargs) def is_shared(self): _warn_typed_storage_removal() - return self._storage.is_shared() + return self._is_shared() + + # For internal use only, to avoid deprecation warning + def _is_shared(self): + return self._untyped_storage.is_shared() @classmethod def _new_shared_cuda(cls, *args, **kwargs): return torch.UntypedStorage._new_shared_cuda(*args, **kwargs) def _share_filename_cpu_(self, *args, **kwargs): - manager_handle, storage_handle, size = self._storage._share_filename_cpu_(*args, **kwargs) - return manager_handle, storage_handle, size // self.element_size() + manager_handle, storage_handle, size = self._untyped_storage._share_filename_cpu_(*args, **kwargs) + return manager_handle, storage_handle, size // self._element_size() def _shared_decref(self): - self._storage._shared_decref() + self._untyped_storage._shared_decref() return self @classmethod @@ -873,11 +926,11 @@ def _release_ipc_counter(cls, *args, device=None, **kwargs): return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs) def _shared_incref(self, *args, **kwargs): - return self._storage._shared_incref(*args, **kwargs) + return self._untyped_storage._shared_incref(*args, **kwargs) def _share_fd_cpu_(self, *args, **kwargs): - fd, size = self._storage._share_fd_cpu_(*args, **kwargs) - return fd, size // self.element_size() + fd, size = self._untyped_storage._share_fd_cpu_(*args, **kwargs) + return fd, size // self._element_size() def _get_legacy_storage_class(self): if self.dtype not in _dtype_to_storage_type_map(): @@ -911,7 +964,7 @@ class _LegacyStorage(TypedStorage, metaclass=_LegacyStorageMeta): @classmethod def _new_shared(cls, size): """Creates a new storage in shared memory with the same data type""" - untyped_storage = torch.UntypedStorage._new_shared(size * cls().element_size()) + untyped_storage = torch.UntypedStorage._new_shared(size * cls()._element_size()) return cls(wrap_storage=untyped_storage) @classmethod diff --git a/torch/utils/bundled_inputs.py b/torch/utils/bundled_inputs.py index 1ca2d56616bc2..4ae39733ff2e4 100644 --- a/torch/utils/bundled_inputs.py +++ b/torch/utils/bundled_inputs.py @@ -391,7 +391,7 @@ def _inflate_expr( if isinstance(arg, torch.Tensor): # Small-storage tensors can just be saved directly. - if arg.storage().size() <= MAX_RAW_TENSOR_SIZE or skip_size_check: + if arg._typed_storage().size() <= MAX_RAW_TENSOR_SIZE or skip_size_check: return arg, ref, None # Small contiguous tensors can be cloned to have small storage. # TODO: Should we do this even for non-contiguous tensors? @@ -407,7 +407,7 @@ def _inflate_expr( # TODO: Provide more useful diagnostics. raise Exception( f"Bundled input argument at position '{ref}' is " - f"a tensor with storage size {arg.storage().size()}. " + f"a tensor with storage size {arg._typed_storage().size()}. " f"You probably don't want to bundle this as an input. 
" ) else: diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py index 0ba9f25c2c9d2..1a00cd4514f58 100644 --- a/torch/utils/data/_utils/collate.py +++ b/torch/utils/data/_utils/collate.py @@ -158,7 +158,7 @@ def collate_tensor_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[ # If we're in a background process, concatenate directly into a # shared memory tensor to avoid an extra copy numel = sum(x.numel() for x in batch) - storage = elem.storage()._new_shared(numel, device=elem.device) + storage = elem._typed_storage()._new_shared(numel, device=elem.device) out = elem.new(storage).resize_(len(batch), *list(elem.size())) return torch.stack(batch, 0, out=out)