Skip to content

Commit

Permalink
[inductor] use triu ref instead of lowering (#96040) (#96462)
Browse files Browse the repository at this point in the history
Fixes #95958
Generated code is functionally identical with ref and lowering, only minor differences

Pull Request resolved: #96040
Approved by: https://github.com/jansel

Co-authored-by: Natalia Gimelshein <ngimel@fb.com>
  • Loading branch information
atalman and ngimel committed Mar 9, 2023
1 parent c9913cf commit c263bd4
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 24 deletions.
1 change: 1 addition & 0 deletions test/inductor/test_torchinductor_opinfo.py
Expand Up @@ -448,6 +448,7 @@ def wrapper_set_seed(op, *args, **kwargs):
"mT",
"mH",
"rsub",
"triu",
}


Expand Down
1 change: 1 addition & 0 deletions torch/_decomp/__init__.py
Expand Up @@ -308,6 +308,7 @@ def core_aten_decompositions() -> Dict[OpOverload, Callable]:
aten.trace,
aten.transpose.int,
aten.tril.default,
aten.triu.default,
aten.unfold,
aten.unfold_backward,
aten.upsample_bilinear2d,
Expand Down
24 changes: 0 additions & 24 deletions torch/_inductor/lowering.py
Expand Up @@ -1505,30 +1505,6 @@ def fn(index):
)


@register_lowering(aten.triu)
def triu(x, diagonal=0):
x_loader = x.make_loader()
dtype = x.get_dtype()

def inner_fn(index):
*_, i, j = index
return ops.where(
ops.ge(
ops.index_expr(j - i - diagonal, torch.int32),
ops.constant(0, torch.int32),
),
x_loader(index),
ops.constant(0, dtype),
)

return Pointwise.create(
device=x.get_device(),
dtype=dtype,
inner_fn=inner_fn,
ranges=list(x.get_size()),
)


@register_lowering(aten.select_scatter, type_promotion_kind=None)
def select_scatter(x, src, dim: int, index: int):
assert x.get_dtype() == src.get_dtype()
Expand Down

0 comments on commit c263bd4

Please sign in to comment.