diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp
index b48ab4c4a9ad7..6c22c766e2500 100644
--- a/aten/src/ATen/native/TensorConversions.cpp
+++ b/aten/src/ATen/native/TensorConversions.cpp
@@ -695,7 +695,7 @@ Tensor sparse_compressed_to_dense(
 
 // Computes the strides for view_dtype output when the view dtype is
 // smaller than the original dtype
-inline DimVector compute_strides_for_view_dtype_downsize(IntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
+inline SymDimVector compute_strides_for_view_dtype_downsize(SymIntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
   const int64_t ndim = old_strides.size();
 
   TORCH_CHECK(
@@ -703,7 +703,7 @@ inline DimVector compute_strides_for_view_dtype_downsize(IntArrayRef old_strides
     "self.stride(-1) must be 1 to view ", old_dtype, " as ", new_dtype,
     " (different element sizes), but got ", old_strides[ndim - 1]);
 
-  DimVector new_strides(ndim);
+  SymDimVector new_strides(ndim);
   for (int64_t dim_idx = 0; dim_idx < ndim - 1; dim_idx++) {
     new_strides[dim_idx] = old_strides[dim_idx] * size_ratio;
   }
@@ -713,14 +713,14 @@ inline DimVector compute_strides_for_view_dtype_downsize(IntArrayRef old_strides
 
 // Computes the strides for view_dtype output when the view dtype is
 // larger than the original dtype
-inline DimVector compute_strides_for_view_dtype_upsize(IntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
+inline SymDimVector compute_strides_for_view_dtype_upsize(SymIntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
   const int64_t ndim = old_strides.size();
   TORCH_CHECK(
     old_strides[ndim - 1] == 1,
     "self.stride(-1) must be 1 to view ", old_dtype, " as ", new_dtype,
     " (different element sizes), but got ", old_strides[ndim - 1]);
 
-  DimVector new_strides(ndim);
+  SymDimVector new_strides(ndim);
   for (int64_t dim_idx = 0; dim_idx < ndim - 1; dim_idx++) {
     TORCH_CHECK(
       (old_strides[dim_idx] % size_ratio) == 0,
@@ -753,8 +753,7 @@ Tensor view_dtype(const Tensor& self, ScalarType dtype) {
   auto* impl = new_tensor.unsafeGetTensorImpl();
 
   if (self_element_size == new_element_size) {
-    impl->set_storage_offset(self.storage_offset());
-    impl->set_sizes_and_strides(self.sizes(), self.strides());
+    impl->set_sizes_and_strides(self.sym_sizes(), self.sym_strides(), self.sym_storage_offset());
 
   } else if (self.dim() == 0) {
     TORCH_CHECK(false,
@@ -766,17 +765,16 @@ Tensor view_dtype(const Tensor& self, ScalarType dtype) {
     int64_t size_ratio = self_element_size / new_element_size;
 
     auto new_strides = compute_strides_for_view_dtype_downsize(
-      self.strides(), size_ratio, self.scalar_type(), dtype);
+      self.sym_strides(), size_ratio, self.scalar_type(), dtype);
 
-    auto old_sizes = self.sizes();
-    DimVector new_sizes(self.dim());
+    auto old_sizes = self.sym_sizes();
+    SymDimVector new_sizes(self.dim());
     std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
     new_sizes[self.dim() - 1] *= size_ratio;
 
-    auto new_storage_offset = size_ratio * self.storage_offset();
+    auto new_storage_offset = size_ratio * self.sym_storage_offset();
 
-    impl->set_storage_offset(new_storage_offset);
-    impl->set_sizes_and_strides(new_sizes, new_strides);
+    impl->set_sizes_and_strides(new_sizes, new_strides, new_storage_offset);
 
   } else {
     // Upsizing element size
 
@@ -784,29 +782,28 @@ Tensor view_dtype(const Tensor& self, ScalarType dtype) {
     int64_t size_ratio = new_element_size / self_element_size;
 
     TORCH_CHECK(
-      (self.size(-1) % size_ratio) == 0,
+      (self.sym_size(-1) % size_ratio) == 0,
       "self.size(-1) must be divisible by ", size_ratio, " to view ",
       self.scalar_type(), " as ", dtype, " (different element sizes), ",
-      "but got ", self.size(-1));
+      "but got ", self.sym_size(-1));
 
     TORCH_CHECK(
-      (self.storage_offset() % size_ratio) == 0,
+      (self.sym_storage_offset() % size_ratio) == 0,
       "self.storage_offset() must be divisible by ", size_ratio,
       " to view ", self.scalar_type(), " as ", dtype,
       " (different element sizes), but got ",
-      self.storage_offset());
+      self.sym_storage_offset());
 
     auto new_strides = compute_strides_for_view_dtype_upsize(
-      self.strides(), size_ratio, self.scalar_type(), dtype);
+      self.sym_strides(), size_ratio, self.scalar_type(), dtype);
 
-    auto old_sizes = self.sizes();
-    DimVector new_sizes(self.dim());
+    auto old_sizes = self.sym_sizes();
+    SymDimVector new_sizes(self.dim());
     std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
     new_sizes[self.dim() - 1] /= size_ratio;
 
-    auto new_storage_offset = self.storage_offset() / size_ratio;
+    auto new_storage_offset = self.sym_storage_offset() / size_ratio;
 
-    impl->set_storage_offset(new_storage_offset);
-    impl->set_sizes_and_strides(new_sizes, new_strides);
+    impl->set_sizes_and_strides(new_sizes, new_strides, new_storage_offset);
   }
 
   return new_tensor;
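
The C++ changes above only swap concrete sizes and strides for their SymInt counterparts; the size/stride arithmetic of view(dtype) itself is unchanged. As an illustration (a minimal sketch in eager mode with made-up shapes, runnable on a stock build), the downsize and upsize paths behave like this:

    import torch

    t = torch.zeros(2, 4)   # float32: 4-byte elements, contiguous strides (4, 1)

    # Downsize path: element size halves, so the last dimension and every
    # non-last stride are multiplied by size_ratio = 2.
    down = t.view(torch.int16)
    assert down.shape == (2, 8) and down.stride() == (8, 1)

    # Upsize path: element size doubles, so the last dimension and every
    # non-last stride are divided by size_ratio = 2.
    up = t.view(torch.float64)
    assert up.shape == (2, 2) and up.stride() == (2, 1)
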
diff --git a/test/test_ops.py b/test/test_ops.py
index ae3355f7dff7b..0012ac233f797 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1881,6 +1881,7 @@ class TestRefsOpsInfo(TestCase):
         '_refs.imag',
         '_refs.reshape_as',
         '_refs.view_as',
+        '_refs.view_as_complex',  # TorchInductor does not support complex at the moment.
     }
 
     @parametrize("op", ref_ops_names)
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py
index 7436c979cdb96..247f8eabd8e14 100644
--- a/torch/_decomp/__init__.py
+++ b/torch/_decomp/__init__.py
@@ -373,6 +373,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
         aten._unsafe_index,
         aten.upsample_bilinear2d,
         aten.upsample_nearest2d_backward,
+        aten.view_as_complex,
         aten.xlogy,
         aten.xlogy_,
         aten.zero,
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index a6e0e01b76dad..595b37effb7d2 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -1886,6 +1886,7 @@ def apply_constraint(arg, fx_arg):
 make_fallback(aten.upsample_bicubic2d_backward, require_contiguous)
 make_fallback(aten._scaled_mm.default)
 
+# TODO: This is done, just need to enable support in TorchInductor for complex types.
 make_fallback(aten.view_as_complex, require_contiguous)
 
 # The following were added as a result of https://github.com/pytorch/pytorch/pull/94039 to pass tests
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 638e046b27365..73909c29d880e 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -2642,11 +2642,6 @@ def meta_complex(real, imag):
     return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype))
 
 
-@register_meta(aten.view.dtype)
-def view_dtype(self, dtype):
-    return utils.clone_preserve_strides(self).to(dtype)
-
-
 @register_meta([aten.nonzero_static.default, aten.nonzero_static.out])
 def nonzero_static(self, *, size: int, fill_value: int = -1):
     return self.new_empty((size, self.dim()), dtype=torch.long)
diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py
index 8b1a58eef7864..35cd8c196933f 100644
--- a/torch/_prims/__init__.py
+++ b/torch/_prims/__init__.py
@@ -148,6 +148,7 @@
     "squeeze",
     "transpose",
     "view_of",
+    "view_element_type",
     #
     # Functionalized view mutations
     #
@@ -172,7 +173,6 @@
     "item",
     "maximum_value",
     "minimum_value",
-    "to_dtype",
     "copy_strided",
     #
     # Inplace prims
@@ -1780,6 +1780,27 @@ def _view_of_aten(a: Tensor) -> Tensor:
     doc=_view_of_doc,
 )
 
+
+def _view_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
+    return a.view(dtype)
+
+
+def _view_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
+    return a.view(dtype)
+
+
+_view_element_type_doc = """
+  Creates a view of the tensor with a different dtype.
+  """
+
+view_element_type = _make_prim(
+    schema="view_of_dtype(Tensor(a) a, ScalarType dtype) -> Tensor",
+    meta=_view_element_type_meta,
+    impl_aten=_view_element_type_aten,
+    return_type=RETURN_TYPE.VIEW,
+    doc=_view_element_type_doc,
+)
+
 #
 # Functionalized view mutations
 #
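
The new view_element_type prim above is a thin wrapper over Tensor.view(dtype) with VIEW aliasing semantics, giving the ref below a primitive to decompose through; with the C++ view.dtype path SymInt-ified earlier in this patch, the hand-written Python meta kernel for aten.view.dtype presumably becomes redundant, which is why it is deleted. A minimal sketch of what the prim does, assuming a build that includes this patch (shapes are illustrative):

    import torch
    from torch._prims import view_element_type

    a = torch.zeros(3, 2)                      # float32
    c = view_element_type(a, torch.complex64)  # reinterpret (real, imag) pairs
    assert c.dtype == torch.complex64 and c.shape == (3, 1)
    assert c.data_ptr() == a.data_ptr()        # a true view: storage is shared
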
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 1decb13c3bff8..66b3cd1476a2f 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -289,6 +289,7 @@
     "view_as",
     "vsplit",
     "vstack",
+    "view_as_complex",
     "unflatten",
     "unbind",
     "triu",
@@ -949,6 +950,43 @@ def trunc(a):
     return prims.trunc(a)
 
 
+# TODO: register this as a real ref/decomposition once TorchInductor supports complex!
+def view_as_complex(self: TensorLikeType) -> TensorLikeType:
+    input_dtype = self.dtype
+    torch._check(
+        utils.is_float_dtype(input_dtype),
+        lambda: f"view_as_complex is only supported for floating point "
+        f"tensors, but got a tensor of scalar type: {input_dtype}",
+    )
+    sizes = self.size()
+    torch._check(
+        len(sizes) != 0,
+        lambda: "Input tensor must have one or more dimensions",
+    )
+    torch._check(
+        sizes[-1] == 2,
+        lambda: "Tensor must have a last dimension of size 2",
+    )
+
+    old_strides = self.stride()
+    torch._check(
+        old_strides[-1] == 1,
+        lambda: "Tensor must have a last dimension with stride 1",
+    )
+    dims = old_strides[:-1]
+    torch._check(
+        py_all(stride % 2 == 0 for stride in dims),
+        lambda: "Tensor must have a stride divisible by 2 for all but last dimension",
+    )
+    torch._check(
+        self.storage_offset() % 2 == 0,
+        lambda: "Tensor must have a storage_offset divisible by 2",
+    )
+    return prims.view_element_type(
+        self, utils.corresponding_complex_dtype(input_dtype)
+    ).squeeze(-1)
+
+
 def _make_elementwise_binary_reference(
     type_promotion_kind,
     aten_op=infer_aten_op,
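
Under the checks above, the ref reinterprets the trailing (real, imag) pairs in place and squeezes away the now-size-1 last dimension, matching the existing ATen kernel. A hedged usage sketch, again assuming a build with this patch applied (shapes are illustrative):

    import torch
    from torch import _refs

    x = torch.randn(4, 2)          # last dim holds (real, imag) pairs
    z = _refs.view_as_complex(x)
    assert z.dtype == torch.complex64 and z.shape == (4,)
    torch.testing.assert_close(z, torch.view_as_complex(x))

    # Inputs the reinterpretation cannot express are rejected, e.g. a
    # last-dimension stride other than 1:
    try:
        _refs.view_as_complex(torch.randn(2, 4).t())
    except RuntimeError:
        pass
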
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 7bb36124c9746..7c5d9fcd27b8f 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -21576,6 +21576,10 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
             ),
         ],
     ),
+    PythonRefInfo(
+        "_refs.view_as_complex",
+        torch_opinfo_name="view_as_complex",
+    ),
 ]
 
 python_ref_db += opinfo.definitions.python_ref_db
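
With the PythonRefInfo entry in place, the OpInfo-based tests exercise the new ref against the ATen kernel over the existing view_as_complex sample inputs. Conceptually, the consistency being checked amounts to the following simplified sketch (not the actual test harness):

    import torch
    from torch import _refs

    # For a few float inputs, the ref and the ATen op must agree.
    for x in (torch.randn(2), torch.randn(3, 2), torch.randn(2, 3, 2)):
        torch.testing.assert_close(
            _refs.view_as_complex(x), torch.view_as_complex(x)
        )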