diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 7c6c804db533103..7a4f80111be24ca 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -1040,21 +1040,43 @@ Tensor reshape(const Tensor& self, IntArrayRef proposed_shape) {
     return at::_mkldnn_reshape(self, shape);
   }
 
-  auto stride =
-      at::detail::computeStride(self.sizes(), self.strides(), shape);
-  // `computeStride` returns the proper strides to use if this
-  // `reshape` can be just a view.
-  //
-  // NB: Even though we have viewable geometry and the target strides here,
-  // we do not just call `as_strided` on `self` because the backward
-  // for `as_strided` is not as efficient as that of `view` (since the
-  // former is meant to handle general cases).
+  // `computeStride` returns the proper strides to use if this
+  // `reshape` can be just a view.
+  auto stride = at::detail::computeStride(self.sizes(), self.strides(), shape);
+
+  // NB: Even though we have viewable geometry and the target strides here,
+  // we do not just call `as_strided` on `self` because the backward
+  // for `as_strided` is not as efficient as that of `view` (since the
+  // former is meant to handle general cases).
+  //
+  // Similarly, we don't call `view` because it duplicates some of the work
+  // we've already done. Instead we call our internal/private operator
+  // `_reshape_alias`, which does essentially the same thing as `view` and
+  // `as_strided` without any of the extra overhead.
   if (stride.has_value()) {
-    return self.view(shape);
+    // Temporary check to revert to the old behavior (`view`) in cases where
+    // the device is not supported (e.g. for XLA the operation is not
+    // supported, so we use `view` instead).
+    //
+    // We need to do the checks here instead of in `native_functions.yaml`
+    // to preserve backwards compatibility.
+    if (!self.is_xla()) {
+      return self._reshape_alias(shape, stride.value());
+    } else {
+      return self.view(shape);
+    }
   }
   return at::_unsafe_view(self.clone(at::MemoryFormat::Contiguous), shape);
 }
 
+Tensor _reshape_alias(const Tensor& self, IntArrayRef sizes, IntArrayRef strides) {
+  // This is only used by `reshape` in cases where it would otherwise have dispatched
+  // to `view`. This removes the overhead of calling `view`, which duplicates some of
+  // the work that's already been done (`infer_size_dv` and `computeStride`).
+
+  return alias_with_sizes_and_strides(self, sizes, strides);
+}
+
 Tensor reshape_as(const Tensor& self, const Tensor& other) {
   return self.reshape(other.sizes());
 }
@@ -2152,11 +2174,13 @@ Tensor numpy_T(const Tensor &self) {
   return self.permute(transpose_dims);
 }
 
-Tensor view(const Tensor& self, IntArrayRef size) {
+Tensor view(const Tensor& self,
+            IntArrayRef size) {
+
   at::DimVector inferred_size = at::infer_size_dv(size, self.numel());
   auto stride = at::detail::computeStride(self.sizes(),
-                                          self.strides(),
-                                          inferred_size);
+                                          self.strides(),
+                                          inferred_size);
   TORCH_CHECK(stride.has_value(), "view size is "
     "not compatible with input tensor's size and stride (at least one dimension"
     " spans across two contiguous subspaces). Use .reshape(...) instead.");
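The TensorShape.cpp change above keeps `reshape`'s user-visible contract intact: when `computeStride` finds valid strides, the result aliases the input's storage (now via `_reshape_alias` rather than `view`); otherwise `reshape` still falls back to a contiguous copy through `_unsafe_view`. A minimal Python sketch of that contract, not part of this patch and using only public APIs:

    import torch

    x = torch.randn(15, 4)

    # Contiguous input: the target shape is expressible as a view,
    # so the result shares storage with the input.
    v = x.reshape(20, 3)
    assert v.data_ptr() == x.data_ptr()

    # Transposed input: no valid strides exist for (20, 3), so reshape
    # falls back to a contiguous clone and the storage differs.
    c = x.t().reshape(20, 3)
    assert c.data_ptr() != x.data_ptr()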
instead."); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 0bc6315e4b94e95..d700d712ef6fd60 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3443,6 +3443,17 @@ device_check: NoCheck device_guard: False +# NOTE [ _reshape_alias ] is meant to be used in the implementation of reshape. +# They are not user-facing, hence the leading underscore. Please don't use it +# anywhere else. +- func: _reshape_alias(Tensor(a) self, int[] size, int[] stride) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA: _reshape_alias + # We don't need to support mkldnn since this is handled explicitly by the reshape operator. + - func: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor device_check: NoCheck device_guard: False diff --git a/test/cpp/api/tensor.cpp b/test/cpp/api/tensor.cpp index a0c227e3c71d8b9..f541d39e3bc4f5d 100644 --- a/test/cpp/api/tensor.cpp +++ b/test/cpp/api/tensor.cpp @@ -1117,3 +1117,28 @@ TEST(TensorTest, StdDimension) { ASSERT_EQ(torch::std(x, 0, /*unbiased=*/true).numel(), 3); ASSERT_EQ(std::get<0>(torch::std_mean(x, 0, /*unbiased=*/true)).numel(), 3); } + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(TensorTest, ReshapeAlias) { + // Tests the behavior of the _reshape_alias private operator so + // that it matches the behavior of as_strided and view. + auto x = torch::randn({3, 3}); + ASSERT_TRUE(torch::equal( + torch::_reshape_alias(x, {2, 2}, {1, 2}), + torch::as_strided(x, {2, 2}, {1, 2}) + )); + ASSERT_TRUE(torch::equal( + torch::_reshape_alias(x, {9}, {1}), + x.view({-1}) + )); + + // Test that the backward works fine. + auto y = torch::randn({3, 3}, torch::requires_grad(true)); + auto z = torch::clone(y).detach().requires_grad_(true); + (y * y).view({-1}).mean().backward(); + torch::_reshape_alias((z * z), {9}, {1}).mean().backward(); + ASSERT_TRUE(torch::equal( + y.grad(), + z.grad() + )); +} diff --git a/test/test_view_ops.py b/test/test_view_ops.py index 9a99ec5652c5613..d1e63c8d4dd88d2 100644 --- a/test/test_view_ops.py +++ b/test/test_view_ops.py @@ -1355,6 +1355,22 @@ def test_view(self, device): self.assertEqual(tensor.view(6, 2, 1), contig_tensor.view(6, 2, 1)) self.assertEqual(tensor.view(1, 6, 2, 1), contig_tensor.view(1, 6, 2, 1)) + @dtypes(*torch.testing.get_all_dtypes()) + def test_reshape_view_semantics(self, device, dtype): + tensor = make_tensor((15, 4), device, dtype) + target = (20, 3) + + # Cases where the tensor can be returned as a view. + view_tensor = tensor.reshape(target) + self.assertEqual((view_tensor.size()), target) + self.assertEqual(tensor.storage().data_ptr(), view_tensor.storage().data_ptr()) + + # Cases where the tensor must be copied (transpose makes it non-contiguous forcing + # the copy). + copy_tensor = tensor.transpose(0, 1).reshape(target) + self.assertEqual(copy_tensor.size(), target) + self.assertNotEqual(tensor.storage().data_ptr(), copy_tensor.storage().data_ptr()) + def test_contiguous(self, device): x = torch.randn(1, 16, 5, 5, device=device) self.assertTrue(x.is_contiguous()) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 2314044ec0d0aca..22481ac5fa46739 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1141,6 +1141,9 @@ # making it impossible (hard) to detect when it is actually a view. 
 # - name: reshape(Tensor self, IntArrayRef shape)
 
+- name: _reshape_alias(Tensor(a) self, int[] size, int[] stride) -> Tensor(a)
+  self: grad.reshape(self.sizes())
+
 - name: round(Tensor self) -> Tensor
   self: zeros_like(grad)
 
diff --git a/tools/autograd/gen_inplace_or_view_type.py b/tools/autograd/gen_inplace_or_view_type.py
index 33db4b29ba4f64e..6c42bec1e5d1262 100644
--- a/tools/autograd/gen_inplace_or_view_type.py
+++ b/tools/autograd/gen_inplace_or_view_type.py
@@ -58,6 +58,7 @@
     # discrete anyways.
     # FIXME: clone indices on construction.
     'sparse_coo_tensor_with_dims_and_tensors': 'values',
+    '_reshape_alias': 'self',
 }
 
 for key in VIEW_FUNCTIONS_WITH_METADATA_CHANGE:
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index 874e54d74a69d6b..e3ead5a10bcb2e6 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -92,6 +92,7 @@
     'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retains_grad',
     'set_', '_fw_primal', 'fake_quantize_per_tensor_affine_cachemask',
     'fake_quantize_per_channel_affine_cachemask',
+    '_reshape_alias',
 ]
 
 SKIP_PYTHON_BINDINGS = list(map(lambda pattern: re.compile(rf'^{pattern}$'), _SKIP_PYTHON_BINDINGS))
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 37326aaba38336d..a8a3b3a1e772fca 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -101,7 +101,7 @@
     'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace', 'polar', 'cumsum', 'rsub',
     'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward',
     'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm', '_conj_physical',
-    'scatter', 'scatter_add', 'sigmoid', 'sigmoid_backward', 'conj_physical_', '_neg_view'
+    'scatter', 'scatter_add', 'sigmoid', 'sigmoid_backward', 'conj_physical_', '_neg_view', '_reshape_alias'
 }
 
 GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
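Because `_reshape_alias` is registered as a view function with the simple backward `grad.reshape(self.sizes())` (see the derivatives.yaml hunk above), gradients flowing through `reshape` should remain identical to those of the old `view`-based path. A quick Python check of that equivalence, illustrative only and not part of this patch:

    import torch

    y = torch.randn(3, 3, requires_grad=True)
    z = y.detach().clone().requires_grad_(True)

    # Old path: the reshape expressed explicitly as a view.
    (y * y).view(-1).mean().backward()
    # New path: reshape, which may now dispatch to _reshape_alias internally.
    (z * z).reshape(-1).mean().backward()

    assert torch.equal(y.grad, z.grad)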