Skip to content

Commit

Permalink
Decompose/add reference for view_as_complex (#108005)
Browse files Browse the repository at this point in the history
Aten source: https://github.com/pytorch/pytorch/blob/d4a99631dd589afbb972b3401e56f029192c9b0f/aten/src/ATen/native/ComplexHelper.h#L78

Documentation reference:
https://pytorch.org/docs/stable/generated/torch.view_as_complex.html

Note: this adds a new primitive `view_of_dtype`, which is trivially implemented, as its meta function is already implemented elsewhere.

Finally, this is not registered as a decomposition (yet), because TorchInductor does not yet support complex types. It should be added once we do.

Closes #108020 as well.

Pull Request resolved: #108005
Approved by: https://github.com/peterbell10, https://github.com/ezyang
  • Loading branch information
Fidget-Spinner authored and pytorchmergebot committed Sep 7, 2023
1 parent 366ce58 commit c458fa0
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 28 deletions.
41 changes: 19 additions & 22 deletions aten/src/ATen/native/TensorConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -695,15 +695,15 @@ Tensor sparse_compressed_to_dense(

// Computes the strides for view_dtype output when the view dtype is
// smaller than the original dtype
inline DimVector compute_strides_for_view_dtype_downsize(IntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
inline SymDimVector compute_strides_for_view_dtype_downsize(SymIntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
const int64_t ndim = old_strides.size();

TORCH_CHECK(
old_strides[ndim - 1] == 1,
"self.stride(-1) must be 1 to view ", old_dtype, " as ", new_dtype,
" (different element sizes), but got ", old_strides[ndim - 1]);

DimVector new_strides(ndim);
SymDimVector new_strides(ndim);
for (int64_t dim_idx = 0; dim_idx < ndim - 1; dim_idx++) {
new_strides[dim_idx] = old_strides[dim_idx] * size_ratio;
}
Expand All @@ -713,14 +713,14 @@ inline DimVector compute_strides_for_view_dtype_downsize(IntArrayRef old_strides

// Computes the strides for view_dtype output when the view dtype is
// larger than the original dtype
inline DimVector compute_strides_for_view_dtype_upsize(IntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
inline SymDimVector compute_strides_for_view_dtype_upsize(SymIntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
const int64_t ndim = old_strides.size();
TORCH_CHECK(
old_strides[ndim - 1] == 1,
"self.stride(-1) must be 1 to view ", old_dtype, " as ", new_dtype,
" (different element sizes), but got ", old_strides[ndim - 1]);

DimVector new_strides(ndim);
SymDimVector new_strides(ndim);
for (int64_t dim_idx = 0; dim_idx < ndim - 1; dim_idx++) {
TORCH_CHECK(
(old_strides[dim_idx] % size_ratio) == 0,
Expand Down Expand Up @@ -753,8 +753,7 @@ Tensor view_dtype(const Tensor& self, ScalarType dtype) {
auto* impl = new_tensor.unsafeGetTensorImpl();

if (self_element_size == new_element_size) {
impl->set_storage_offset(self.storage_offset());
impl->set_sizes_and_strides(self.sizes(), self.strides());
impl->set_sizes_and_strides(self.sym_sizes(), self.sym_strides(), self.sym_storage_offset());

} else if (self.dim() == 0) {
TORCH_CHECK(false,
Expand All @@ -766,47 +765,45 @@ Tensor view_dtype(const Tensor& self, ScalarType dtype) {

int64_t size_ratio = self_element_size / new_element_size;
auto new_strides = compute_strides_for_view_dtype_downsize(
self.strides(), size_ratio, self.scalar_type(), dtype);
self.sym_strides(), size_ratio, self.scalar_type(), dtype);

auto old_sizes = self.sizes();
DimVector new_sizes(self.dim());
auto old_sizes = self.sym_sizes();
SymDimVector new_sizes(self.dim());
std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
new_sizes[self.dim() - 1] *= size_ratio;

auto new_storage_offset = size_ratio * self.storage_offset();
auto new_storage_offset = size_ratio * self.sym_storage_offset();

impl->set_storage_offset(new_storage_offset);
impl->set_sizes_and_strides(new_sizes, new_strides);
impl->set_sizes_and_strides(new_sizes, new_strides, new_storage_offset);

} else {
// Upsizing element size

int64_t size_ratio = new_element_size / self_element_size;

TORCH_CHECK(
(self.size(-1) % size_ratio) == 0,
(self.sym_size(-1) % size_ratio) == 0,
"self.size(-1) must be divisible by ", size_ratio, " to view ",
self.scalar_type(), " as ", dtype, " (different element sizes), ",
"but got ", self.size(-1));
"but got ", self.sym_size(-1));

TORCH_CHECK(
(self.storage_offset() % size_ratio) == 0,
(self.sym_storage_offset() % size_ratio) == 0,
"self.storage_offset() must be divisible by ", size_ratio, " to view ",
self.scalar_type(), " as ", dtype, " (different element sizes), but got ",
self.storage_offset());
self.sym_storage_offset());

auto new_strides = compute_strides_for_view_dtype_upsize(
self.strides(), size_ratio, self.scalar_type(), dtype);
self.sym_strides(), size_ratio, self.scalar_type(), dtype);

auto old_sizes = self.sizes();
DimVector new_sizes(self.dim());
auto old_sizes = self.sym_sizes();
SymDimVector new_sizes(self.dim());
std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
new_sizes[self.dim() - 1] /= size_ratio;

auto new_storage_offset = self.storage_offset() / size_ratio;
auto new_storage_offset = self.sym_storage_offset() / size_ratio;

impl->set_storage_offset(new_storage_offset);
impl->set_sizes_and_strides(new_sizes, new_strides);
impl->set_sizes_and_strides(new_sizes, new_strides, new_storage_offset);
}

return new_tensor;
Expand Down
1 change: 1 addition & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1876,6 +1876,7 @@ class TestRefsOpsInfo(TestCase):
'_refs.imag',
'_refs.reshape_as',
'_refs.view_as',
'_refs.view_as_complex' # TorchInductor does not support complex at the moment.
}

@parametrize("op", ref_ops_names)
Expand Down
1 change: 1 addition & 0 deletions torch/_decomp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
aten._unsafe_index,
aten.upsample_bilinear2d,
aten.upsample_nearest2d_backward,
aten.view_as_complex,
aten.xlogy,
aten.xlogy_,
aten.zero,
Expand Down
1 change: 1 addition & 0 deletions torch/_inductor/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1886,6 +1886,7 @@ def apply_constraint(arg, fx_arg):
make_fallback(aten.upsample_bicubic2d_backward, require_contiguous)
make_fallback(aten._scaled_mm.default)

# TODO: This is done, just need to enable support in TorchInductor for complex types.
make_fallback(aten.view_as_complex, require_contiguous)

# The following were added as a result of https://github.com/pytorch/pytorch/pull/94039 to pass tests
Expand Down
5 changes: 0 additions & 5 deletions torch/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2642,11 +2642,6 @@ def meta_complex(real, imag):
return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype))


@register_meta(aten.view.dtype)
def view_dtype(self, dtype):
return utils.clone_preserve_strides(self).to(dtype)


@register_meta([aten.nonzero_static.default, aten.nonzero_static.out])
def nonzero_static(self, *, size: int, fill_value: int = -1):
return self.new_empty((size, self.dim()), dtype=torch.long)
Expand Down
23 changes: 22 additions & 1 deletion torch/_prims/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@
"squeeze",
"transpose",
"view_of",
"view_element_type",
#
# Functionalized view mutations
#
Expand All @@ -172,7 +173,6 @@
"item",
"maximum_value",
"minimum_value",
"to_dtype",
"copy_strided",
#
# Inplace prims
Expand Down Expand Up @@ -1780,6 +1780,27 @@ def _view_of_aten(a: Tensor) -> Tensor:
doc=_view_of_doc,
)


def _view_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
return a.view(dtype)


def _view_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
return a.view(dtype)


_view_element_type_doc = """
Creates a view of the tensor with a different dtype.
"""

view_element_type = _make_prim(
schema="view_of_dtype(Tensor(a) a, ScalarType dtype) -> Tensor",
meta=_view_element_type_meta,
impl_aten=_view_element_type_aten,
return_type=RETURN_TYPE.VIEW,
doc=_view_element_type_doc,
)

#
# Functionalized view mutations
#
Expand Down
38 changes: 38 additions & 0 deletions torch/_refs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@
"view_as",
"vsplit",
"vstack",
"view_as_complex",
"unflatten",
"unbind",
"triu",
Expand Down Expand Up @@ -949,6 +950,43 @@ def trunc(a):
return prims.trunc(a)


# TODO: register this as a real ref/decomposition once TorchInductor supports complex!
def view_as_complex(self: TensorLikeType) -> TensorLikeType:
input_dtype = self.dtype
torch._check(
utils.is_float_dtype(input_dtype),
lambda: f"view_as_complex is only supported for floating point"
f"tensors, but got a tensor of scalar type: {input_dtype}",
)
sizes = self.size()
torch._check(
len(sizes) != 0,
lambda: "Input tensor must have one or more dimensions",
)
torch._check(
sizes[-1] == 2,
lambda: "Tensor must have a last dimension of size 2",
)

old_strides = self.stride()
torch._check(
old_strides[-1] == 1,
lambda: "Tensor must have a last dimension with stride 1",
)
dims = old_strides[:-1]
torch._check(
py_all(stride % 2 == 0 for stride in dims),
lambda: "Tensor must have a stride divisible by 2 for all but last dimension",
)
torch._check(
self.storage_offset() % 2 == 0,
lambda: "Tensor must have a storage_offset divisible by 2",
)
return prims.view_element_type(
self, utils.corresponding_complex_dtype(input_dtype)
).squeeze(-1)


def _make_elementwise_binary_reference(
type_promotion_kind,
aten_op=infer_aten_op,
Expand Down
4 changes: 4 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -21576,6 +21576,10 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
),
],
),
PythonRefInfo(
"_refs.view_as_complex",
torch_opinfo_name="view_as_complex",
),
]
python_ref_db += opinfo.definitions.python_ref_db

Expand Down

0 comments on commit c458fa0

Please sign in to comment.