From fb5c01a39ebb907a71da137f253734e4886c7f29 Mon Sep 17 00:00:00 2001 From: Salil Desai Date: Mon, 18 Sep 2023 11:37:19 -0700 Subject: [PATCH] [Decomposition] trunc (#109319) Summary: bypass-github-pytorch-ci-checks Add Decomp for Trunc and add it to core_aten_decompositions Test Plan: Phabricator + OSS Tests Reviewed By: SS-JIA, kirklandsign Differential Revision: D49042033 --- .../HasDecompTest.test_aten_core_operators.expect | 2 -- test/test_ops.py | 1 + torch/_decomp/__init__.py | 1 + torch/_decomp/decompositions.py | 5 +++++ torch/_inductor/decomposition.py | 1 + torch/_refs/__init__.py | 11 ++++++++--- 6 files changed, 16 insertions(+), 5 deletions(-) diff --git a/test/expect/HasDecompTest.test_aten_core_operators.expect b/test/expect/HasDecompTest.test_aten_core_operators.expect index 052d675cb68bd..3dbf4c45e66af 100644 --- a/test/expect/HasDecompTest.test_aten_core_operators.expect +++ b/test/expect/HasDecompTest.test_aten_core_operators.expect @@ -504,8 +504,6 @@ aten::tril_indices aten::tril_indices.out aten::triu_indices aten::triu_indices.out -aten::trunc -aten::trunc.out aten::trunc_ aten::unbind.int aten::unfold diff --git a/test/test_ops.py b/test/test_ops.py index 626c4c1457aea..11ef74dfff1b7 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1856,6 +1856,7 @@ class TestRefsOpsInfo(TestCase): '_refs.tensor_split', '_refs.to', '_refs.true_divide', + '_refs.trunc', '_refs.trunc_divide', '_refs.vsplit', '_refs.vstack', diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 9e0f1fc0aaaf9..a04cb21f4179b 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -370,6 +370,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: aten.tril_, aten.triu, aten.triu_, + aten.trunc, aten.unfold_backward, aten.unfold_copy, aten._unsafe_index, diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 9b4b1c7a637ff..9255570db230a 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -4020,6 +4020,11 @@ def scaled_dot_product_flash_attention( ) +@register_decomposition([aten.trunc]) +def trunc(self: Tensor, **kwargs) -> Tensor: + return torch.where(self > 0, torch.floor(self), torch.ceil(self)) + + def register_inplace(aten_op, outplace_op): @register_decomposition(aten_op) def inplace_op(*args, **kwargs): diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 5a4dbcd067f3d..1b7117d6bb264 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -72,6 +72,7 @@ aten._unsafe_index, aten._scaled_dot_product_flash_attention.default, # See comments in torch/_decomp/decompositions.py aten.clamp_min, + aten.trunc, ] remove_decompositions(decompositions, decomps_to_exclude) diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index f001c72717c2b..80fba10b2af4a 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -945,9 +945,14 @@ def tanh(a): return prims.tanh(a) -@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) -def trunc(a): - return prims.trunc(a) +@out_wrapper() +@elementwise_unary_scalar_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def trunc(a: TensorLikeType) -> TensorLikeType: + return handle_noncontiguous_outputs([a], prims.trunc(a)) # TODO: register this as a real ref/decomposition once TorchInductor supports complex!