Decompose/add reference for view_as_complex #108005

Closed
41 changes: 19 additions & 22 deletions aten/src/ATen/native/TensorConversions.cpp
@@ -695,15 +695,15 @@ Tensor sparse_compressed_to_dense(

// Computes the strides for view_dtype output when the view dtype is
// smaller than the original dtype
-inline DimVector compute_strides_for_view_dtype_downsize(IntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
+inline SymDimVector compute_strides_for_view_dtype_downsize(SymIntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
const int64_t ndim = old_strides.size();

TORCH_CHECK(
old_strides[ndim - 1] == 1,
"self.stride(-1) must be 1 to view ", old_dtype, " as ", new_dtype,
" (different element sizes), but got ", old_strides[ndim - 1]);

-DimVector new_strides(ndim);
+SymDimVector new_strides(ndim);
for (int64_t dim_idx = 0; dim_idx < ndim - 1; dim_idx++) {
new_strides[dim_idx] = old_strides[dim_idx] * size_ratio;
}
@@ -713,14 +713,14 @@ inline DimVector compute_strides_for_view_dtype_downsize(IntArrayRef old_strides

// Computes the strides for view_dtype output when the view dtype is
// larger than the original dtype
-inline DimVector compute_strides_for_view_dtype_upsize(IntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
+inline SymDimVector compute_strides_for_view_dtype_upsize(SymIntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
const int64_t ndim = old_strides.size();
TORCH_CHECK(
old_strides[ndim - 1] == 1,
"self.stride(-1) must be 1 to view ", old_dtype, " as ", new_dtype,
" (different element sizes), but got ", old_strides[ndim - 1]);

-DimVector new_strides(ndim);
+SymDimVector new_strides(ndim);
for (int64_t dim_idx = 0; dim_idx < ndim - 1; dim_idx++) {
TORCH_CHECK(
(old_strides[dim_idx] % size_ratio) == 0,
@@ -753,8 +753,7 @@ Tensor view_dtype(const Tensor& self, ScalarType dtype) {
auto* impl = new_tensor.unsafeGetTensorImpl();

if (self_element_size == new_element_size) {
-impl->set_storage_offset(self.storage_offset());
-impl->set_sizes_and_strides(self.sizes(), self.strides());
+impl->set_sizes_and_strides(self.sym_sizes(), self.sym_strides(), self.sym_storage_offset());

} else if (self.dim() == 0) {
TORCH_CHECK(false,
@@ -766,47 +765,45 @@ Tensor view_dtype(const Tensor& self, ScalarType dtype) {

int64_t size_ratio = self_element_size / new_element_size;
auto new_strides = compute_strides_for_view_dtype_downsize(
-self.strides(), size_ratio, self.scalar_type(), dtype);
+self.sym_strides(), size_ratio, self.scalar_type(), dtype);

-auto old_sizes = self.sizes();
-DimVector new_sizes(self.dim());
+auto old_sizes = self.sym_sizes();
+SymDimVector new_sizes(self.dim());
std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
new_sizes[self.dim() - 1] *= size_ratio;

-auto new_storage_offset = size_ratio * self.storage_offset();
+auto new_storage_offset = size_ratio * self.sym_storage_offset();

-impl->set_storage_offset(new_storage_offset);
-impl->set_sizes_and_strides(new_sizes, new_strides);
+impl->set_sizes_and_strides(new_sizes, new_strides, new_storage_offset);

} else {
// Upsizing element size

int64_t size_ratio = new_element_size / self_element_size;

TORCH_CHECK(
-(self.size(-1) % size_ratio) == 0,
+(self.sym_size(-1) % size_ratio) == 0,
"self.size(-1) must be divisible by ", size_ratio, " to view ",
self.scalar_type(), " as ", dtype, " (different element sizes), ",
"but got ", self.size(-1));
"but got ", self.sym_size(-1));

TORCH_CHECK(
-(self.storage_offset() % size_ratio) == 0,
+(self.sym_storage_offset() % size_ratio) == 0,
"self.storage_offset() must be divisible by ", size_ratio, " to view ",
self.scalar_type(), " as ", dtype, " (different element sizes), but got ",
-self.storage_offset());
+self.sym_storage_offset());

auto new_strides = compute_strides_for_view_dtype_upsize(
-self.strides(), size_ratio, self.scalar_type(), dtype);
+self.sym_strides(), size_ratio, self.scalar_type(), dtype);

-auto old_sizes = self.sizes();
-DimVector new_sizes(self.dim());
+auto old_sizes = self.sym_sizes();
+SymDimVector new_sizes(self.dim());
std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
new_sizes[self.dim() - 1] /= size_ratio;

-auto new_storage_offset = self.storage_offset() / size_ratio;
+auto new_storage_offset = self.sym_storage_offset() / size_ratio;

-impl->set_storage_offset(new_storage_offset);
-impl->set_sizes_and_strides(new_sizes, new_strides);
+impl->set_sizes_and_strides(new_sizes, new_strides, new_storage_offset);
}

return new_tensor;
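Not part of the diff — a minimal Python sketch of the size/stride/offset arithmetic that view_dtype implements above, assuming a contiguous float32 input. It relies only on the documented behaviour of Tensor.view(dtype) for mismatched element sizes.

import torch

x = torch.zeros(4, 6)              # float32 (4-byte elements), strides (6, 1)

# Upsize: float32 -> complex64 (8 bytes), size_ratio = 2.
# The last dimension and every leading stride are divided by 2, hence the
# "self.size(-1) must be divisible by size_ratio" check above.
c = x.view(torch.complex64)
print(c.shape, c.stride())         # torch.Size([4, 3]) (3, 1)

# Downsize: float32 -> int16 (2 bytes), size_ratio = 2.
# The last dimension and every leading stride are multiplied by 2.
i = x.view(torch.int16)
print(i.shape, i.stride())         # torch.Size([4, 12]) (12, 1)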
1 change: 1 addition & 0 deletions test/test_ops.py
@@ -1881,6 +1881,7 @@ class TestRefsOpsInfo(TestCase):
'_refs.imag',
'_refs.reshape_as',
'_refs.view_as',
+'_refs.view_as_complex'  # TorchInductor does not support complex at the moment.
}

@parametrize("op", ref_ops_names)
1 change: 1 addition & 0 deletions torch/_decomp/__init__.py
@@ -373,6 +373,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
aten._unsafe_index,
aten.upsample_bilinear2d,
aten.upsample_nearest2d_backward,
+aten.view_as_complex,
aten.xlogy,
aten.xlogy_,
aten.zero,
1 change: 1 addition & 0 deletions torch/_inductor/lowering.py
@@ -1886,6 +1886,7 @@ def apply_constraint(arg, fx_arg):
make_fallback(aten.upsample_bicubic2d_backward, require_contiguous)
make_fallback(aten._scaled_mm.default)

+# TODO: This is done, just need to enable support in TorchInductor for complex types.
make_fallback(aten.view_as_complex, require_contiguous)

# The following were added as a result of https://github.com/pytorch/pytorch/pull/94039 to pass tests
5 changes: 0 additions & 5 deletions torch/_meta_registrations.py
@@ -2642,11 +2642,6 @@ def meta_complex(real, imag):
return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype))


-@register_meta(aten.view.dtype)
-def view_dtype(self, dtype):
-    return utils.clone_preserve_strides(self).to(dtype)


@register_meta([aten.nonzero_static.default, aten.nonzero_static.out])
def nonzero_static(self, *, size: int, fill_value: int = -1):
return self.new_empty((size, self.dim()), dtype=torch.long)
23 changes: 22 additions & 1 deletion torch/_prims/__init__.py
@@ -148,6 +148,7 @@
"squeeze",
"transpose",
"view_of",
"view_element_type",
#
# Functionalized view mutations
#
@@ -172,7 +173,6 @@
"item",
"maximum_value",
"minimum_value",
"to_dtype",
Collaborator Author: Removed this as this function doesn't exist.

"copy_strided",
#
# Inplace prims
@@ -1780,6 +1780,27 @@ def _view_of_aten(a: Tensor) -> Tensor:
doc=_view_of_doc,
)


+def _view_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
+    return a.view(dtype)
+
+
+def _view_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
+    return a.view(dtype)
+
+
+_view_element_type_doc = """
+Creates a view of the tensor with a different dtype.
+"""
+
+view_element_type = _make_prim(
+    schema="view_of_dtype(Tensor(a) a, ScalarType dtype) -> Tensor",
+    meta=_view_element_type_meta,
+    impl_aten=_view_element_type_aten,
+    return_type=RETURN_TYPE.VIEW,
+    doc=_view_element_type_doc,
+)

#
# Functionalized view mutations
#
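Not part of the diff — a quick usage sketch for the new prim, assuming it is callable through torch._prims as registered above. It simply forwards to Tensor.view(dtype) and returns a view of the same storage.

import torch
import torch._prims as prims

x = torch.randn(3, 2)                          # float32, last-dim stride 1
y = prims.view_element_type(x, torch.complex64)

print(y.dtype, tuple(y.shape))                 # torch.complex64 (3, 1)
print(y.data_ptr() == x.data_ptr())            # True: same storage, reinterpreted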
38 changes: 38 additions & 0 deletions torch/_refs/__init__.py
@@ -289,6 +289,7 @@
"view_as",
"vsplit",
"vstack",
"view_as_complex",
"unflatten",
"unbind",
"triu",
@@ -949,6 +950,43 @@ def trunc(a):
return prims.trunc(a)


+# TODO: register this as a real ref/decomposition once TorchInductor supports complex!
Collaborator: It should be fine to register the decomposition and inductor won't try to use it by default.

Collaborator Author: When I registered this as a decomposition, tests started failing and it seems Inductor generated code with complex64 and the like.

Collaborator: Oh, it is being picked up because you have it in the core_aten_decompositions list, which is used by inductor here:

decompositions = {**core_aten_decompositions(), **inductor_decompositions}

Collaborator: The function is composite (explicit) and it's already symintified in core, so there's not much point in registering it as a decomposition, not now nor when we support complex.

Collaborator: @lezcano CompositeExplicitAutograd operators are not traced through by FakeTensor, so you still need a decomposition.

Collaborator: Any reason why we wouldn't want to trace through those? Just because we are assuming that there may be some in-place ops or something?

Collaborator: Well, in the case of view.dtype it calls raw TensorImpl methods and wouldn't materialize anything meaningful in the fx graph if we traced through it. Although I would question whether that should be legal in composite methods. It certainly wouldn't work for any tensor subclasses, so maybe it should be registered as CPU, CUDA, Meta.

Either way, I'm guessing this function isn't alone in being registered as CompositeExplicit but actually not dispatching to anything.

+def view_as_complex(self: TensorLikeType) -> TensorLikeType:
+    input_dtype = self.dtype
+    torch._check(
+        utils.is_float_dtype(input_dtype),
+        lambda: f"view_as_complex is only supported for floating point "
+        f"tensors, but got a tensor of scalar type: {input_dtype}",
+    )
+    sizes = self.size()
+    torch._check(
+        len(sizes) != 0,
+        lambda: "Input tensor must have one or more dimensions",
+    )
+    torch._check(
+        sizes[-1] == 2,
+        lambda: "Tensor must have a last dimension of size 2",
+    )
+
+    old_strides = self.stride()
+    torch._check(
+        old_strides[-1] == 1,
+        lambda: "Tensor must have a last dimension with stride 1",
+    )
+    dims = old_strides[:-1]
+    torch._check(
+        py_all(stride % 2 == 0 for stride in dims),
+        lambda: "Tensor must have a stride divisible by 2 for all but last dimension",
+    )
+    torch._check(
+        self.storage_offset() % 2 == 0,
+        lambda: "Tensor must have a storage_offset divisible by 2",
+    )
+    return prims.view_element_type(
+        self, utils.corresponding_complex_dtype(input_dtype)
+    ).squeeze(-1)


def _make_elementwise_binary_reference(
type_promotion_kind,
aten_op=infer_aten_op,
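Not part of the diff — a small sanity-check sketch comparing the reference above with the eager op, assuming torch._refs.view_as_complex is importable as added in this PR.

import torch
import torch._refs as refs

x = torch.randn(5, 2)                        # float32, last dim of size 2, stride 1

out_ref = refs.view_as_complex(x)            # prims.view_element_type(...).squeeze(-1)
out_eager = torch.view_as_complex(x)

print(out_ref.dtype, tuple(out_ref.shape))   # torch.complex64 (5,)
print(torch.equal(out_ref, out_eager))       # True

# Inputs that violate the checks are rejected, e.g. a last dimension that is not 2:
# refs.view_as_complex(torch.randn(5, 3))    # raises: last dimension of size 2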
4 changes: 4 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -21576,6 +21576,10 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
),
],
),
+PythonRefInfo(
+    "_refs.view_as_complex",
+    torch_opinfo_name="view_as_complex",
+),
]
python_ref_db += opinfo.definitions.python_ref_db
