Support resize on meta storage (#101988)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: #101988
Approved by: https://github.com/albanD, https://github.com/bdhirsh
ezyang authored and pytorchmergebot committed May 23, 2023
1 parent 51ff408 commit 7d1ba0a
Showing 4 changed files with 12 additions and 4 deletions.
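At the user level, this commit makes Storage.resize_() work on the meta device, where a storage records its size but owns no memory. A minimal sketch of the new behavior, assuming the patch is applied (it mirrors the test added below):

    import torch

    # A meta storage carries size/dtype metadata with no backing allocation.
    s = torch.TypedStorage([1, 2, 3, 4], device='meta', dtype=torch.float32)

    # Before this commit this raised RuntimeError ("got unexpected device
    # type"); now it succeeds and only updates the recorded size.
    s.resize_(10)
    print(s.size())   # 10
    print(s.device)   # meta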
aten/src/ATen/native/Resize.cpp (2 changes: 1 addition & 1 deletion)
@@ -156,7 +156,7 @@ const Tensor& resize_as_(
 }
 
 
-static void resize_bytes_meta(StorageImpl* storage, c10::SymInt size_bytes) {
+void resize_bytes_meta(StorageImpl* storage, c10::SymInt size_bytes) {
   TORCH_CHECK(storage->resizable(), "Trying to resize storage that is not resizable");
   storage->set_nbytes(std::move(size_bytes));
 }
aten/src/ATen/native/Resize.h (1 change: 1 addition & 0 deletions)
@@ -34,6 +34,7 @@ TORCH_API bool resize_output_check(const Tensor& output, IntArrayRef shape);
 TORCH_API bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape);
 
 TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes);
+TORCH_API void resize_bytes_meta(StorageImpl* storage, c10::SymInt size_bytes);
 
 static inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) {
   // It does not make sense to try to resize a storage
test/test_torch.py (11 changes: 8 additions & 3 deletions)
@@ -350,9 +350,6 @@ def test_storage_meta_errors(self, device, dtype):
         with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'):
             s0.pin_memory()
 
-        with self.assertRaisesRegex(RuntimeError, r'got unexpected device type'):
-            s0.resize_(10)
-
         with self.assertRaisesRegex(RuntimeError, r'only available on CPU'):
             s0.share_memory_()
 
@@ -369,6 +366,14 @@ def test_storage_meta_errors(self, device, dtype):
         with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'):
             s1.copy_(s0)
 
+    @onlyCPU
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
+    def test_storage_meta_ok(self, device, dtype):
+        s0 = torch.TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype)
+
+        # This is OK, it changes the meta storage size without allocating
+        s0.resize_(10)
+
     @onlyCUDA
     def test_module_share_memory(self):
         # Test fix for issue #80733
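For concreteness, this is roughly what the new test exercises for a single dtype; float64 is an arbitrary choice here, since the @dtypes decorator sweeps the whole list, and the byte arithmetic is an illustration rather than part of the test:

    import torch

    # Sketch of test_storage_meta_ok for one dtype; nothing is allocated.
    s0 = torch.TypedStorage([1, 2, 3, 4], device='meta', dtype=torch.float64)
    s0.resize_(10)        # resize_ counts elements, not bytes
    print(s0.size())      # 10
    print(s0.nbytes())    # 80, i.e. 10 * s0.element_size()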
torch/csrc/StorageMethods.cpp (2 changes: 2 additions & 0 deletions)
@@ -128,6 +128,8 @@ static PyObject* THPStorage_resize_(PyObject* self, PyObject* number_arg) {
     const auto size_bytes = static_cast<size_t>(size_bytes_i);
     at::native::resize_bytes_cuda(storage.unsafeGetStorageImpl(), size_bytes);
 #endif
+  } else if (device_type == at::kMeta) {
+    at::native::resize_bytes_meta(storage.unsafeGetStorageImpl(), newsize);
   } else if (device_type == at::kPrivateUse1) {
     ptrdiff_t size_bytes_i = newsize;
     TORCH_CHECK(
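This hunk routes meta storages around the allocating CPU/CUDA paths: newsize is forwarded to resize_bytes_meta, whose c10::SymInt parameter can also represent symbolic sizes, rather than being narrowed to a concrete byte count. The contrast, sketched from Python under the same assumption that the patch is applied:

    import torch

    # CPU path: resize_bytes_cpu reallocates a real buffer.
    cpu = torch.TypedStorage([1, 2, 3, 4], dtype=torch.uint8)
    cpu.resize_(8)    # allocates 8 bytes, preserving the old contents

    # Meta path: resize_bytes_meta only records the new size.
    meta = torch.TypedStorage([1, 2, 3, 4], device='meta', dtype=torch.uint8)
    meta.resize_(8)   # no allocation happens

    print(cpu.size(), meta.size())   # 8 8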

1 comment on commit 7d1ba0a

@pytorchmergebot (Collaborator)

Reverted #101988 on behalf of https://github.com/osalpekar due to Need to revert and rebase this in order to unblock train import (comment)
