[inductor] Lower diagonal, diagonal_copy and diagonal_scatter
Currently these are decomposed into `as_strided`, which forces a buffer to be
realized. Instead, this lowers them to a native inductor view node, so no
buffer needs to be realized.

ghstack-source-id: b0b2191e196980f1f386047d154ad050fa1f49a0
Pull Request resolved: pytorch#103755
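As an illustration (not part of the commit), a minimal sketch of the kind of graph this lowering targets: `diagonal` feeding a pointwise op under `torch.compile`, where the diagonal can now stay a view and fuse into the surrounding kernel instead of materializing an `as_strided` buffer.

```python
# Minimal sketch, not from the PR: diagonal feeding pointwise ops under
# torch.compile. With this lowering the diagonal is a native inductor view,
# so it can fuse into the cos/add kernel rather than realizing a buffer.
import torch

def f(x):
    d = torch.diagonal(x, offset=1)  # lowered as a view, not via as_strided
    return d.cos() + 1

compiled = torch.compile(f, backend="inductor")
print(compiled(torch.randn(8, 8)))
```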
peterbell10 committed Jun 19, 2023
1 parent c441302 commit 1fb6489
Showing 4 changed files with 93 additions and 19 deletions.
3 changes: 3 additions & 0 deletions test/inductor/test_torchinductor_opinfo.py
@@ -485,6 +485,9 @@ def wrapper_set_seed(op, *args, **kwargs):
# Always test with all sample for following ops
inductor_all_samples = {
"arange",
"diagonal",
"diagonal_copy",
"diagonal_scatter",
"softmax.with_dtype",
"index_add",
"index_copy",
2 changes: 1 addition & 1 deletion torch/_decomp/__init__.py
@@ -201,7 +201,7 @@ def core_aten_decompositions() -> Dict[OpOverload, Callable]:
aten.deg2rad,
aten.detach,
aten.diag_embed,
aten.diagonal,
aten.diagonal_backward,
aten.dot,
aten.elu,
aten.elu_backward,
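For context, `aten.diagonal` is dropped from the core decompositions (it is now lowered directly) while `aten.diagonal_backward` is added. Semantically, `diagonal_backward` scatters the incoming gradient onto the diagonal of an otherwise-zero tensor of the input's shape, which the new `diagonal_scatter` lowering handles. A hedged check of that relationship, assuming the standard semantics of these ops:

```python
# Illustration only: diagonal's gradient equals a diagonal_scatter of the
# incoming gradient into zeros shaped like the original input.
import torch

grad = torch.randn(3)  # gradient w.r.t. diagonal(input, offset=1) of a 4x4 input
via_backward = torch.ops.aten.diagonal_backward(grad, [4, 4], 1, 0, 1)
via_scatter = torch.diagonal_scatter(torch.zeros(4, 4), grad, offset=1)
assert torch.equal(via_backward, via_scatter)
```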
35 changes: 21 additions & 14 deletions torch/_inductor/ir.py
@@ -1408,23 +1408,13 @@ def __init__(self, data):


@dataclasses.dataclass
class View(BaseView):
class GenericView(BaseView):
size: List[Expr]
reindex: Callable[..., Any]

def make_reindexer(self):
return self.reindex

@staticmethod
def handle_negative_index(idx, size):
idx = sympy.expand(idx)
size = sympy.expand(size)
sizevars = V.graph.sizevars
if sizevars.size_hint(idx) < 0:
sizevars.guard_lt(idx, 0)
idx = idx + size
return idx

def reindex_str(self):
index_old = [sympy_symbol(f"i{n}") for n in range(len(self.size))]
index_new = list(self.reindex(index_old))
@@ -1437,6 +1427,26 @@ def __str__(self):

__repr__ = __str__

@classmethod
def create(cls, x, new_size, reindex):
return cls(x, list(new_size), reindex)

def get_size(self):
return self.size


@dataclasses.dataclass
class View(GenericView):
@staticmethod
def handle_negative_index(idx, size):
idx = sympy.expand(idx)
size = sympy.expand(size)
sizevars = V.graph.sizevars
if sizevars.size_hint(idx) < 0:
sizevars.guard_lt(idx, 0)
idx = idx + size
return idx

@classmethod
def create(cls, x, new_size):
assert isinstance(new_size, (tuple, list))
@@ -1559,9 +1569,6 @@ def reindex(index):

return reindex

def get_size(self):
return self.size


@dataclasses.dataclass
class ReinterpretView(BaseView):
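The ir.py change splits the existing `View` into a base `GenericView` that carries an arbitrary `reindex` callable mapping an index in the view back to an index in the underlying tensor; the diagonal lowering below builds on it. A standalone sketch of that reindexing-view idea, in plain Python rather than inductor IR:

```python
# Illustrative sketch, not inductor code: a "view" is just a new shape plus a
# callable mapping each view index back to an index in the base tensor.
import itertools
import torch

def make_view(base, size, reindex):
    # Materialize the view eagerly, element by element, for illustration only.
    out = base.new_empty(size)
    for idx in itertools.product(*(range(s) for s in size)):
        out[idx] = base[tuple(reindex(idx))]
    return out

x = torch.arange(12).reshape(3, 4)
# A transpose expressed as a reindexer: view index (i, j) reads base[(j, i)].
t = make_view(x, (4, 3), lambda idx: (idx[1], idx[0]))
assert torch.equal(t, x.t())
```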
72 changes: 68 additions & 4 deletions torch/_inductor/lowering.py
@@ -13,7 +13,9 @@
import torch.fx
import torch.utils._pytree as pytree
from torch._prims_common import (
canonicalize_dim,
canonicalize_dims,
check,
dtype_to_type,
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
@@ -898,13 +900,19 @@ def as_strided(x, size, stride, storage_offset=None):
return TensorBox(ir.ReinterpretView(storage, new_layout))


@register_lowering(aten.as_strided_)
@register_lowering(aten.as_strided_, type_promotion_kind=None)
def as_strided_(x, size, stride, storage_offset=None):
assert isinstance(x, TensorBox)
x.data = as_strided(x, size, stride, storage_offset).data
return x


@register_lowering(aten.as_strided_copy, type_promotion_kind=None)
def as_strided_copy(x, size, stride, storage_offset=None):
result = as_strided(x, size, stride, storage_offset)
return clone(result)


@register_lowering(aten.cat)
def cat(inputs, dim=0):
if len(inputs) == 1:
@@ -918,6 +926,65 @@ def cat(inputs, dim=0):
return TensorBox(ir.ConcatKernel.create(inputs, dim))


@register_lowering(aten.diagonal, type_promotion_kind=None)
def diagonal(input, offset: int = 0, dim1: int = 0, dim2: int = 1):
original_shape = input.get_size()
num_dims = len(original_shape)
dim1 = canonicalize_dim(idx=dim1, rank=num_dims)
dim2 = canonicalize_dim(idx=dim2, rank=num_dims)

check(
dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
)

evaluate_expr = V.graph.sizevars.shape_env.evaluate_expr
offset_negative = evaluate_expr(sympy.Lt(offset, 0))
if offset_negative:
diag_size = max(min(original_shape[dim1] + offset, original_shape[dim2]), 0)
else:
diag_size = max(min(original_shape[dim1], original_shape[dim2] - offset), 0)

base_idx = (0, 0)
if offset_negative:
base_idx = (-offset, 0)
else:
base_idx = (0, offset)

sizes = [s for i, s in enumerate(original_shape) if i not in (dim1, dim2)]
sizes.append(diag_size)

def reindexer(idx):
diag_idx = idx[-1]
original_idx = [0] * len(original_shape)
cur_dim = 0
for d in range(num_dims):
if d == dim1:
original_idx[d] = diag_idx + base_idx[0]
elif d == dim2:
original_idx[d] = diag_idx + base_idx[1]
else:
original_idx[d] = idx[cur_dim]
cur_dim += 1

assert cur_dim == len(original_shape) - 2
return original_idx

return TensorBox(ir.GenericView.create(input, sizes, reindexer))


@register_lowering(aten.diagonal_copy, type_promotion_kind=None)
def diagonal_copy(input, offset: int = 0, dim1: int = 0, dim2: int = 1):
return clone(diagonal(input, offset, dim1, dim2))


@register_lowering(aten.diagonal_scatter, type_promotion_kind=None)
def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1):
output = clone(input)
target = diagonal(output, offset, dim1, dim2)
mutate_to(target, src)
return output


@register_lowering(aten.select, type_promotion_kind=None)
def select(x, dim, idx):
idx = View.handle_negative_index(idx, x.get_size()[dim])
@@ -1535,8 +1602,6 @@ def apply_constraint(arg, fx_arg):
make_fallback(aten.cummax)
make_fallback(aten.cummin)
make_fallback(aten.cumprod, warn=False)
make_fallback(aten.diagonal_copy, warn=False)
make_fallback(aten.diagonal_scatter, warn=False)
make_fallback(aten.digamma, warn=False)
make_fallback(aten._efficientzerotensor)
make_fallback(aten._embedding_bag_per_sample_weights_backward)
@@ -1626,7 +1691,6 @@ def apply_constraint(arg, fx_arg):
make_fallback(aten.adaptive_max_pool3d_backward)
make_fallback(aten.avg_pool3d_backward)
make_fallback(aten._cdist_backward)
make_fallback(aten.diagonal_backward, warn=False)
make_fallback(aten._embedding_bag_dense_backward)
make_fallback(aten.fractional_max_pool2d_backward)
make_fallback(aten.fractional_max_pool3d_backward)
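To make the index arithmetic in the new `diagonal` lowering concrete, here is a standalone eager reference (assumed helper name, not the inductor implementation) that computes the same view size and index mapping as the lowering's `reindexer` and checks it against `torch.diagonal`:

```python
# Eager reference for the diagonal lowering's index math; illustration only.
import itertools
import torch

def diagonal_reference(x, offset=0, dim1=0, dim2=1):
    shape = list(x.shape)
    ndim = len(shape)
    dim1, dim2 = dim1 % ndim, dim2 % ndim
    if offset >= 0:
        diag_size = max(min(shape[dim1], shape[dim2] - offset), 0)
        base = (0, offset)    # dim1 starts at 0, dim2 starts at +offset
    else:
        diag_size = max(min(shape[dim1] + offset, shape[dim2]), 0)
        base = (-offset, 0)   # dim1 starts at -offset, dim2 starts at 0
    # The output keeps the remaining dims in order and appends the diagonal last.
    out_shape = [s for d, s in enumerate(shape) if d not in (dim1, dim2)] + [diag_size]
    out = x.new_empty(out_shape)
    for idx in itertools.product(*(range(s) for s in out_shape)):
        diag_idx = idx[-1]
        src, cur = [0] * ndim, 0
        for d in range(ndim):
            if d == dim1:
                src[d] = diag_idx + base[0]
            elif d == dim2:
                src[d] = diag_idx + base[1]
            else:
                src[d] = idx[cur]
                cur += 1
        out[idx] = x[tuple(src)]
    return out

x = torch.randn(2, 5, 4)
for off in (-2, 0, 3):
    assert torch.equal(diagonal_reference(x, off, dim1=1, dim2=2),
                       torch.diagonal(x, off, dim1=1, dim2=2))
```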
