Skip to content

Commit

Permalink
[inductor] Decompose diagonal_copy and diagonal_scatter
Browse files Browse the repository at this point in the history
This decomposes `diagonal_copy` into `as_strided_copy` and `diagonal_scatter`
into `as_strided_scatter`.

Currently these ops are decomposed into mutating operations, which inductor
cannot consume: functionalization emits the `_scatter` and `_copy` variants
precisely to remove mutations from the graph, so re-introducing mutations in
their decompositions violates the graph's mutation-free invariant.

ghstack-source-id: 5b6b8ef2fc3dd6439fc90af744324610bf2b4fa1
Pull Request resolved: #103755
  • Loading branch information
peterbell10 committed Jun 16, 2023
1 parent 2d745b9 commit fca717a
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 29 deletions.
2 changes: 2 additions & 0 deletions torch/_decomp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ def core_aten_decompositions() -> Dict[OpOverload, Callable]:
aten.detach,
aten.diag_embed,
aten.diagonal,
aten.diagonal_copy,
aten.diagonal_scatter,
aten.dot,
aten.elu,
aten.elu_backward,
Expand Down
10 changes: 7 additions & 3 deletions torch/_inductor/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,13 +898,19 @@ def as_strided(x, size, stride, storage_offset=None):
return TensorBox(ir.ReinterpretView(storage, new_layout))


@register_lowering(aten.as_strided_)
@register_lowering(aten.as_strided_, type_promotion_kind=None)
def as_strided_(x, size, stride, storage_offset=None):
    """In-place as_strided lowering: rebind x's data to the restrided view."""
    assert isinstance(x, TensorBox)
    restrided = as_strided(x, size, stride, storage_offset)
    x.data = restrided.data
    return x


@register_lowering(aten.as_strided_copy, type_promotion_kind=None)
def as_strided_copy(x, size, stride, storage_offset=None):
    """Out-of-place as_strided: restride, then materialize with a clone."""
    return clone(as_strided(x, size, stride, storage_offset))


@register_lowering(aten.cat)
def cat(inputs, dim=0):
if len(inputs) == 1:
Expand Down Expand Up @@ -1535,8 +1541,6 @@ def apply_constraint(arg, fx_arg):
make_fallback(aten.cummax)
make_fallback(aten.cummin)
make_fallback(aten.cumprod, warn=False)
make_fallback(aten.diagonal_copy, warn=False)
make_fallback(aten.diagonal_scatter, warn=False)
make_fallback(aten.digamma, warn=False)
make_fallback(aten._efficientzerotensor)
make_fallback(aten._embedding_bag_per_sample_weights_backward)
Expand Down
78 changes: 52 additions & 26 deletions torch/_refs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3777,36 +3777,12 @@ def diag(
return torch.diagonal_copy(self, offset)


@register_decomposition(aten.diagonal_scatter)
@out_wrapper()
def diagonal_scatter(
    input: TensorLikeType,
    src: TensorLikeType,
    offset: int = 0,
    dim1: int = 0,
    dim2: int = 1,
) -> TensorLikeType:
    """Return a copy of ``input`` with ``src`` written onto the diagonal
    selected by (``offset``, ``dim1``, ``dim2``).

    ``src`` must have exactly the shape of that diagonal; otherwise ``check``
    raises with a shape-mismatch message.
    """
    # Clone first so the caller's tensor is never mutated; preserving strides
    # keeps the in-place diagonal write landing on the intended elements.
    out = utils.clone_preserve_strides(input)
    diag = out.diagonal(offset, dim1, dim2)
    check(
        diag.shape == src.shape,
        lambda: "expected src to have a size equal to the diagonal of the input."
        f"Got {src.shape} for a diagonal of shape {diag.shape}",
    )
    # NOTE(review): in-place copy onto a view of the clone — per the commit
    # message, this mutation is why inductor cannot consume this form of the
    # decomposition.
    copy_to(diag, src)
    return out


@register_decomposition(aten.diagonal)
def diagonal(
def _diagonal_as_strided_args(
self: TensorLikeType,
offset: int = 0,
dim1: int = 0,
dim2: int = 1,
) -> TensorLikeType:
"""
Reference implementation of torch.diagonal
"""
num_dims = self.dim()
dim1 = utils.canonicalize_dim(idx=dim1, rank=num_dims)
dim2 = utils.canonicalize_dim(idx=dim2, rank=num_dims)
Expand Down Expand Up @@ -3834,12 +3810,62 @@ def diagonal(
strides = [s for i, s in enumerate(self.stride()) if i not in (dim1, dim2)]
strides.append(self.stride()[dim1] + self.stride()[dim2])

return sizes, strides, storage_offset


@register_decomposition(aten.diagonal_scatter)
@out_wrapper()
def diagonal_scatter(
    input: TensorLikeType,
    src: TensorLikeType,
    offset: int = 0,
    dim1: int = 0,
    dim2: int = 1,
) -> TensorLikeType:
    """Embed ``src`` into the diagonal of ``input`` selected by
    (``offset``, ``dim1``, ``dim2``) and return the result.

    Decomposes to ``aten.as_strided_scatter`` so the decomposition stays
    mutation-free (usable after functionalization).

    Raises (via ``check``) if ``src``'s shape does not equal the selected
    diagonal's shape.
    """
    # Validate shapes against a plain view of input — cloning the whole
    # tensor (as the previous formulation did) just to read the diagonal's
    # shape is wasted work.
    diag = input.diagonal(offset, dim1, dim2)
    check(
        diag.shape == src.shape,
        lambda: "expected src to have a size equal to the diagonal of the input. "
        f"Got {src.shape} for a diagonal of shape {diag.shape}",
    )
    sizes, strides, storage_offset = _diagonal_as_strided_args(
        input, offset, dim1, dim2
    )
    return aten.as_strided_scatter(
        input, src, size=sizes, stride=strides, storage_offset=storage_offset
    )


@register_decomposition(aten.diagonal)
def diagonal(
    self: TensorLikeType,
    offset: int = 0,
    dim1: int = 0,
    dim2: int = 1,
) -> TensorLikeType:
    """
    Reference implementation of torch.diagonal
    """
    # Compute the as_strided view arguments once, then take the view.
    size, stride, storage_offset = _diagonal_as_strided_args(
        self, offset, dim1, dim2
    )
    return self.as_strided(size=size, stride=stride, storage_offset=storage_offset)


diagonal_copy = _make_copy_from_view(diagonal)
@register_decomposition(aten.diagonal_copy)
def diagonal_copy(
    self: TensorLikeType,
    offset: int = 0,
    dim1: int = 0,
    dim2: int = 1,
) -> TensorLikeType:
    """Out-of-place torch.diagonal: materialize the selected diagonal via
    aten.as_strided_copy instead of returning a view."""
    size, stride, storage_offset = _diagonal_as_strided_args(
        self, offset, dim1, dim2
    )
    return aten.as_strided_copy(
        self, size=size, stride=stride, storage_offset=storage_offset
    )


@register_decomposition(aten.diag_embed)
Expand Down

0 comments on commit fca717a

Please sign in to comment.