diff --git a/test/expect/HasDecompTest.test_aten_core_operators.expect b/test/expect/HasDecompTest.test_aten_core_operators.expect
index 53c5cf70d248..77b178f7aa9b 100644
--- a/test/expect/HasDecompTest.test_aten_core_operators.expect
+++ b/test/expect/HasDecompTest.test_aten_core_operators.expect
@@ -509,8 +509,6 @@ aten::tril_indices
 aten::tril_indices.out
 aten::triu_indices
 aten::triu_indices.out
-aten::trunc
-aten::trunc.out
 aten::trunc_
 aten::unbind.int
 aten::unfold
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py
index 247f8eabd8e1..fbe739659acc 100644
--- a/torch/_decomp/__init__.py
+++ b/torch/_decomp/__init__.py
@@ -368,6 +368,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
             aten.tril_,
             aten.triu,
             aten.triu_,
+            aten.trunc,
             aten.unfold_backward,
             aten.unfold_copy,
             aten._unsafe_index,
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 9b4b1c7a637f..26ea67f2d87e 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -4020,6 +4020,11 @@ def scaled_dot_product_flash_attention(
     )
 
 
+@register_decomposition([aten.trunc])
+def trunc(self: Tensor) -> Tensor:
+    return torch.where(self > 0, torch.floor(self), torch.ceil(self))
+
+
 def register_inplace(aten_op, outplace_op):
     @register_decomposition(aten_op)
     def inplace_op(*args, **kwargs):
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 3a7da0670d96..c034f840a270 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -71,6 +71,7 @@ decomps_to_exclude = [
     aten._unsafe_index,
     aten._scaled_dot_product_flash_attention.default,  # See comments in torch/_decomp/decompositions.py
+    aten.trunc,
 ]
 
 remove_decompositions(decompositions, decomps_to_exclude)
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index f001c72717c2..80fba10b2af4 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -945,9 +945,14 @@ def tanh(a):
     return prims.tanh(a)
 
 
-@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
-def trunc(a):
-    return prims.trunc(a)
+@out_wrapper()
+@elementwise_unary_scalar_wrapper
+@elementwise_type_promotion_wrapper(
+    type_promoting_args=("a",),
+    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+)
+def trunc(a: TensorLikeType) -> TensorLikeType:
+    return handle_noncontiguous_outputs([a], prims.trunc(a))
 
 
 # TODO: register this as a real ref/decomposition once TorchInductor supports complex!
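Note (not part of the patch): the new decomposition expresses aten.trunc as round-toward-zero built from floor/ceil, which is why aten::trunc and aten::trunc.out drop out of the "has no decomp" expect file. A minimal standalone sketch of the same identity, using only long-standing public APIs (torch.where, torch.floor, torch.ceil, torch.trunc), checking that the rewrite matches eager trunc on positive, negative, and zero inputs:

import torch

def trunc_decomp(x: torch.Tensor) -> torch.Tensor:
    # Round toward zero: floor positive values, ceil negative ones (and zero).
    return torch.where(x > 0, torch.floor(x), torch.ceil(x))

x = torch.tensor([-2.7, -0.5, 0.0, 0.5, 2.7])
assert torch.equal(trunc_decomp(x), torch.trunc(x))

The torch/_inductor/decomposition.py hunk adds aten.trunc to decomps_to_exclude, so Inductor keeps handling trunc itself rather than taking the floor/ceil form; the decomposition is there for backends consuming the core ATen opset.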