diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index ad3661d36908114..d91a27684ba1f01 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -448,6 +448,7 @@ def wrapper_set_seed(op, *args, **kwargs): "mT", "mH", "rsub", + "triu", } diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 800eb5180438fd3..cd5db74e89cad03 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -308,6 +308,7 @@ def core_aten_decompositions() -> Dict[OpOverload, Callable]: aten.trace, aten.transpose.int, aten.tril.default, + aten.triu.default, aten.unfold, aten.unfold_backward, aten.upsample_bilinear2d, diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 2c0c907d9741d77..4c77ebdf82b0b98 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -1505,30 +1505,6 @@ def fn(index): ) -@register_lowering(aten.triu) -def triu(x, diagonal=0): - x_loader = x.make_loader() - dtype = x.get_dtype() - - def inner_fn(index): - *_, i, j = index - return ops.where( - ops.ge( - ops.index_expr(j - i - diagonal, torch.int32), - ops.constant(0, torch.int32), - ), - x_loader(index), - ops.constant(0, dtype), - ) - - return Pointwise.create( - device=x.get_device(), - dtype=dtype, - inner_fn=inner_fn, - ranges=list(x.get_size()), - ) - - @register_lowering(aten.select_scatter, type_promotion_kind=None) def select_scatter(x, src, dim: int, index: int): assert x.get_dtype() == src.get_dtype()