Skip to content

Commit

Permalink
[Decomposition] trunc (#109319)
Browse files Browse the repository at this point in the history
Summary:
bypass-github-pytorch-ci-checks


 Add a decomposition for `aten.trunc` and register it in `core_aten_decompositions`

Test Plan: Phabricator + OSS Tests

Reviewed By: SS-JIA, kirklandsign

Differential Revision: D49042033
  • Loading branch information
salilsdesai authored and facebook-github-bot committed Sep 18, 2023
1 parent 70ca3ee commit fb5c01a
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 5 deletions.
2 changes: 0 additions & 2 deletions test/expect/HasDecompTest.test_aten_core_operators.expect
Original file line number Diff line number Diff line change
Expand Up @@ -504,8 +504,6 @@ aten::tril_indices
aten::tril_indices.out
aten::triu_indices
aten::triu_indices.out
aten::trunc
aten::trunc.out
aten::trunc_
aten::unbind.int
aten::unfold
Expand Down
1 change: 1 addition & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1856,6 +1856,7 @@ class TestRefsOpsInfo(TestCase):
'_refs.tensor_split',
'_refs.to',
'_refs.true_divide',
'_refs.trunc',
'_refs.trunc_divide',
'_refs.vsplit',
'_refs.vstack',
Expand Down
1 change: 1 addition & 0 deletions torch/_decomp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
aten.tril_,
aten.triu,
aten.triu_,
aten.trunc,
aten.unfold_backward,
aten.unfold_copy,
aten._unsafe_index,
Expand Down
5 changes: 5 additions & 0 deletions torch/_decomp/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4020,6 +4020,11 @@ def scaled_dot_product_flash_attention(
)


@register_decomposition([aten.trunc])
def trunc(self: Tensor, **kwargs) -> Tensor:
    """Decomposition of aten.trunc: round every element toward zero.

    Negative elements are rounded up (ceil) and all other elements are
    rounded down (floor); together that is exactly truncation toward zero.
    Zero, -0.0 and NaN are unaffected by either branch.
    """
    # NOTE(review): **kwargs is accepted but ignored — presumably to absorb
    # overload-specific keyword arguments; confirm against the registration
    # machinery before relying on it.
    is_negative = self < 0
    return torch.where(is_negative, torch.ceil(self), torch.floor(self))


def register_inplace(aten_op, outplace_op):
@register_decomposition(aten_op)
def inplace_op(*args, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions torch/_inductor/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
aten._unsafe_index,
aten._scaled_dot_product_flash_attention.default, # See comments in torch/_decomp/decompositions.py
aten.clamp_min,
aten.trunc,
]

remove_decompositions(decompositions, decomps_to_exclude)
Expand Down
11 changes: 8 additions & 3 deletions torch/_refs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,9 +945,14 @@ def tanh(a):
return prims.tanh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def trunc(a):
    """Elementwise reference for trunc: round each element toward zero."""
    result = prims.trunc(a)
    return result
@out_wrapper()
@elementwise_unary_scalar_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def trunc(a: TensorLikeType) -> TensorLikeType:
    """Reference implementation of trunc: truncate each element toward zero.

    Delegates the computation to the prim and routes the result through
    handle_noncontiguous_outputs so layout handling matches the input.
    """
    truncated = prims.trunc(a)
    return handle_noncontiguous_outputs([a], truncated)


# TODO: register this as a real ref/decomposition once TorchInductor supports complex!
Expand Down

0 comments on commit fb5c01a

Please sign in to comment.