diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 9cdca6021ae17..69e6f91cb7c1d 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -4066,17 +4066,6 @@
       default: 0
 ]]
-[[
-  name: reshape_
-  cname: resize
-  cpu_half: True
-  return: self
-  arguments:
-    - THTensor* self
-    - arg: THSize* size
-    - arg: THStride* stride
-]]
-
 [[
   name: _sparse_mask
   return: argument 0
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index a18833dd128eb..269ba3e016192 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -132,6 +132,95 @@ Tensor repeat(const Tensor& self, IntList repeats) {
   return result;
 }
 
+// Infers the size of a dim with size -1, if it exists. Also checks that new
+// shape is compatible with the number of elements.
+static std::vector<int64_t> infer_size(IntList shape, int64_t numel) {
+  auto res = shape.vec();
+  int64_t newsize = 1;
+  auto infer_dim = at::optional<int64_t>();
+  for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
+    if (shape[dim] == -1) {
+      if (infer_dim) {
+        throw std::runtime_error("only one dimension can be inferred");
+      }
+      infer_dim = dim;
+    } else if (shape[dim] >= 0) {
+      newsize *= shape[dim];
+    } else {
+      runtime_error("invalid shape dimension %zd", shape[dim]);
+    }
+  }
+
+  if (numel == newsize || (infer_dim && newsize > 0 && numel % newsize == 0)) {
+    if (infer_dim) {
+      res[*infer_dim] = numel / newsize;
+    }
+    if (numel == 0) {
+      // Collapse zero-element shapes into one dimension because TH handles zeros
+      // in sizes strangely: x.resize_(1, 0) has shape (1,). TODO: remove this
+      // once we have multi-dimensional empty tensors.
+      return {0};
+    }
+    return res;
+  }
+
+  std::ostringstream ss;
+  ss << "shape '" << shape << "' is invalid for input of size " << numel;
+  throw std::runtime_error(ss.str());
+}
+
+static at::optional<std::vector<int64_t>>
+compute_stride(const Tensor& self, IntList newshape) {
+  auto oldstride = self.strides();
+  auto oldshape = self.sizes();
+  if (oldshape.empty()) {
+    return std::vector<int64_t>(newshape.size(), 1);
+  }
+
+  std::vector<int64_t> newstride(newshape.size());
+  int64_t view_d = newshape.size() - 1;
+  // stride for each subspace in the chunk
+  int64_t chunk_base_stride = oldstride.back();
+  // numel in current chunk
+  int64_t tensor_numel = 1;
+  int64_t view_numel = 1;
+  for (int64_t tensor_d = oldshape.size() - 1; tensor_d >= 0; tensor_d--) {
+    tensor_numel *= oldshape[tensor_d];
+    // if end of tensor size chunk, check view
+    if ((tensor_d == 0) ||
+        (oldshape[tensor_d - 1] != 1 && oldstride[tensor_d - 1] != tensor_numel * chunk_base_stride)) {
+      while (view_d >= 0 && (view_numel < tensor_numel || newshape[view_d] == 1)) {
+        newstride[view_d] = view_numel * chunk_base_stride;
+        view_numel *= newshape[view_d];
+        view_d--;
+      }
+      if (view_numel != tensor_numel) {
+        return {};
+      }
+      if (tensor_d > 0) {
+        chunk_base_stride = oldstride[tensor_d - 1];
+        tensor_numel = 1;
+        view_numel = 1;
+      }
+    }
+  }
+  if (view_d != -1) {
+    return {};
+  }
+  return newstride;
+}
+
+Tensor reshape(const Tensor& self, IntList proposed_shape) {
+  if (self.type().is_sparse()) {
+    runtime_error("reshape is not implemented for sparse tensors");
+  }
+  auto shape = infer_size(proposed_shape, self.numel());
+  if (auto stride = compute_stride(self, shape)) {
+    return self.as_strided(shape, *stride);
+  }
+  return at::_unsafe_view(self.clone(), shape);
+}
+
 Tensor select(const Tensor& self, int64_t dim, int64_t index) {
   int64_t ndim = self.dim();
   AT_ASSERT(ndim > 0, "select() cannot be applied to a 0-dim tensor.");
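Review note: the view-vs-copy decision lives in `compute_stride` above. As a reading aid, here is a hypothetical Python transliteration (plain sequences instead of `IntList`; illustration only, not part of the patch):

```python
def compute_stride(oldshape, oldstride, newshape):
    """Return strides under which `newshape` can alias the original
    storage, or None if the reshape has to copy.

    Transliteration of the C++ compute_stride: walk the old shape
    right-to-left in "chunks" of dims that are contiguous relative to
    each other, and check that the new shape can be carved out of the
    same chunks.
    """
    if not oldshape:  # 0-dim input: all-ones strides work
        return [1] * len(newshape)
    newstride = [0] * len(newshape)
    view_d = len(newshape) - 1
    chunk_base_stride = oldstride[-1]  # stride of the current chunk's subspace
    tensor_numel = 1  # numel accumulated in the current chunk
    view_numel = 1
    for tensor_d in range(len(oldshape) - 1, -1, -1):
        tensor_numel *= oldshape[tensor_d]
        # at a chunk boundary, try to cover the chunk with new dims
        if tensor_d == 0 or (oldshape[tensor_d - 1] != 1 and
                             oldstride[tensor_d - 1] != tensor_numel * chunk_base_stride):
            while view_d >= 0 and (view_numel < tensor_numel or newshape[view_d] == 1):
                newstride[view_d] = view_numel * chunk_base_stride
                view_numel *= newshape[view_d]
                view_d -= 1
            if view_numel != tensor_numel:
                return None  # new dims straddle a discontiguity: must copy
            if tensor_d > 0:
                chunk_base_stride = oldstride[tensor_d - 1]
                tensor_numel = 1
                view_numel = 1
    return newstride if view_d == -1 else None
```

For example, `compute_stride((3, 3), (3, 1), (9,))` yields `[1]` (pure view), while `compute_stride((4, 4), (16, 1), (16,))` returns `None`; the latter is exactly the `y = torch.randn(4, 4, 4)[:, 0, :]` case in the new tests, where `reshape` falls back to `clone()` plus `_unsafe_view`.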
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 083e303141e86..86671255f8856 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -378,6 +378,8 @@
 - func: repeat(Tensor self, IntList repeats) -> Tensor
   variants: method  # This is method-only to match the previous tensor API. In the future we could make this a function too.
 
+- func: reshape(Tensor self, IntList shape) -> Tensor
+
 - func: RoiPooling2d_forward(Tensor input, Tensor rois, int64_t pooledHeight, int64_t pooledWidth, double spatialScale) -> (Tensor, Tensor)
   variants: function
   dispatch:
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index a0db66675ae5b..8f2d7d3bcd53a 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -252,6 +252,7 @@ view of a storage and defines numeric operations on it.
    .. automethod:: renorm
    .. automethod:: renorm_
    .. automethod:: repeat
+   .. automethod:: reshape
    .. automethod:: resize_
    .. automethod:: resize_as_
    .. automethod:: round
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 9085f93726095..3e16badc8350c 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -34,6 +34,7 @@ Indexing, Slicing, Joining, Mutating Ops
 .. autofunction:: index_select
 .. autofunction:: masked_select
 .. autofunction:: nonzero
+.. autofunction:: reshape
 .. autofunction:: split
 .. autofunction:: squeeze
 .. autofunction:: stack
diff --git a/test/test_autograd.py b/test/test_autograd.py
index e792ad8aed5a2..2390e85cd4d6c 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -2213,6 +2213,11 @@ class dont_convert(tuple):
     ('view', (S,), (S,), '1d'),
     ('view', (), (dont_convert(()),), 'scalar_to_scalar'),
     ('view', (), (1,), 'scalar_to_1d'),
+    ('reshape', (S, S, S), (S * S, S),),
+    ('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size'),
+    ('reshape', (S,), (S,), '1d'),
+    ('reshape', (), (dont_convert(()),), 'scalar_to_scalar'),
+    ('reshape', (), (1,), 'scalar_to_1d'),
     ('view_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
     ('view_as', (), (non_differentiable(torch.tensor(5.5)),), 'scalar'),
     ('view_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'),
@@ -2748,6 +2753,7 @@ def unpack_variables(args):
     'addmv_',
     'addr',
     'addr_',
+    'reshape',
     'where'  # argument order
 }
 EXCLUDE_GRADCHECK = {
diff --git a/test/test_torch.py b/test/test_torch.py
index 1c40e1ddf9d3f..576852651332c 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -4425,6 +4425,31 @@ def _test_view(self, cast):
     def test_view(self):
         TestTorch._test_view(self, lambda x: x)
 
+    def test_reshape(self):
+        x = torch.randn(3, 3)
+        self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr())
+        self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr())
+        self.assertEqual(torch.reshape(x, (9,)), x.reshape(9))
+        self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1))
+
+        y = torch.randn(4, 4, 4)[:, 0, :]
+        self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr())
+        self.assertEqual(y.contiguous().view(-1), y.reshape(-1))
+        self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr())
+
+        s = torch.randn(())
+        self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr())
+        self.assertEqual(s.reshape(-1).shape, (1,))
+        self.assertRaises(RuntimeError, lambda: s.reshape(2))
+
+        empty = torch.tensor([])
+        self.assertEqual(empty, empty.reshape(-1))
+        self.assertEqual(empty, empty.reshape([0]))
+        # TODO: fix these once we have multi-dimensional empty tensors
+        self.assertEqual(empty.reshape([0, 1]).shape, (0,))
+        self.assertEqual(empty.reshape([1, -1]).shape, (0,))
+        self.assertRaises(RuntimeError, lambda: empty.reshape(1))
+
     def test_expand(self):
         tensor = torch.rand(1, 8, 1)
         tensor2 = torch.rand(5)
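The semantics pinned down by `test_reshape`, in one snippet (a sketch assuming a build with this patch applied):

```python
import torch

# Contiguous input: reshape returns a view of the same storage.
x = torch.randn(3, 3)
assert x.reshape(-1).data_ptr() == x.data_ptr()

# At most one dimension may be -1; it is inferred from the rest.
assert x.reshape(-1, 3).shape == (3, 3)

# Non-contiguous input whose strides cannot express the new shape:
# reshape silently falls back to a copy with the same values.
y = torch.randn(4, 4, 4)[:, 0, :]
assert y.reshape(-1).data_ptr() != y.data_ptr()
assert torch.equal(y.reshape(-1), y.contiguous().view(-1))
```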
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index f41a7657e7c62..d78c6059ed9ce 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -1290,6 +1290,19 @@ def callable(a, b) -> number
 """)
 
+add_docstr_all('reshape',
+               r"""
+reshape(*shape) -> Tensor
+
+Returns a tensor with the same data and number of elements as :attr:`self`,
+but with the specified shape.
+
+Args:
+    shape (tuple of ints or int...): the desired shape
+
+See :func:`torch.reshape`
+""")
+
 add_docstr_all('resize_',
                r"""
 resize_(*sizes) -> Tensor
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 38642e9d36ae6..73ab218c68f6b 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -4295,6 +4295,41 @@
 """)
 
+add_docstr(torch.reshape,
+           r"""
+reshape(input, shape) -> Tensor
+
+Returns a tensor with the same data and number of elements as :attr:`input`,
+but with the specified shape. When possible, the returned tensor will be a view
+of :attr:`input`. Otherwise, it will be a copy. Contiguous inputs and inputs
+with compatible strides can be reshaped without copying, but you should not
+depend on the copying vs. viewing behavior.
+
+A single dimension may be -1, in which case it's inferred from the remaining
+dimensions and the number of elements in :attr:`input`.
+
+Args:
+    input (Tensor): the tensor to be reshaped
+    shape (tuple of ints): the new shape
+
+Example::
+
+    >>> a = torch.arange(4)
+    >>> torch.reshape(a, (2, 2))
+     0  1
+     2  3
+    [torch.FloatTensor of size (2,2)]
+
+    >>> b = torch.tensor([[0, 1], [2, 3]])
+    >>> torch.reshape(b, (-1,))
+     0
+     1
+     2
+     3
+    [torch.FloatTensor of size (4,)]
+""")
+
+
 add_docstr(torch.round,
            r"""
 round(input, out=None) -> Tensor
diff --git a/torch/lib/THD/master_worker/worker/dispatch/Tensor.cpp b/torch/lib/THD/master_worker/worker/dispatch/Tensor.cpp
index f2224370438c3..281dbef11dbfe 100644
--- a/torch/lib/THD/master_worker/worker/dispatch/Tensor.cpp
+++ b/torch/lib/THD/master_worker/worker/dispatch/Tensor.cpp
@@ -123,13 +123,10 @@ static void tensorNewClone(rpc::RPCMessage& raw_message) {
 static void tensorResize(rpc::RPCMessage& raw_message) {
   at::Tensor tensor = unpackRetrieveTensor(raw_message);
   THLongStorage *size = unpackTHLongStorage(raw_message);
-  THLongStorage *stride = unpackTHLongStorage(raw_message);
   finalize(raw_message);
   at::ArrayRef<int64_t> sizeRef(size->data, size->size);
-  at::ArrayRef<int64_t> strideRef(stride->data, stride->size);
-  tensor.reshape_(sizeRef, strideRef);
+  tensor.resize_(sizeRef);
   THLongStorage_free(size);
-  THLongStorage_free(stride);
 }
 
 static void tensorResizeAs(rpc::RPCMessage& raw_message) {
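Lastly, the `-1` handling described in the `torch.reshape` docstring is implemented by the `infer_size` helper in `TensorShape.cpp`; a hypothetical Python mirror, for reference only:

```python
def infer_size(shape, numel):
    """Resolve at most one -1 dimension and validate the element count;
    mirrors the C++ infer_size helper added in this patch."""
    res = list(shape)
    newsize = 1
    infer_dim = None
    for dim, s in enumerate(shape):
        if s == -1:
            if infer_dim is not None:
                raise RuntimeError("only one dimension can be inferred")
            infer_dim = dim
        elif s >= 0:
            newsize *= s
        else:
            raise RuntimeError("invalid shape dimension %d" % s)
    if numel == newsize or (infer_dim is not None and newsize > 0
                            and numel % newsize == 0):
        if infer_dim is not None:
            res[infer_dim] = numel // newsize
        if numel == 0:
            # TH quirk (see the TODO above): collapse zero-element
            # shapes into a single dimension.
            return [0]
        return res
    raise RuntimeError("shape '%s' is invalid for input of size %d"
                       % (tuple(shape), numel))
```

For instance, `infer_size((2, -1), 6)` gives `[2, 3]`, while `infer_size((-1, -1), 4)` raises, matching the `x.reshape(-1, -1)` case in `test_reshape`.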