Skip to content

Commit

Permalink
[Decomposition] Trunc
Browse files Browse the repository at this point in the history
Summary:
Work in progress

A reference for trunc already exists in `_refs`, but that implementation relies on prims; this adds a prim-free decomposition.

Differential Revision: D49042033
  • Loading branch information
salilsdesai authored and facebook-github-bot committed Sep 14, 2023
1 parent ea94344 commit 5391170
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 5 deletions.
2 changes: 0 additions & 2 deletions test/expect/HasDecompTest.test_aten_core_operators.expect
Original file line number Diff line number Diff line change
Expand Up @@ -509,8 +509,6 @@ aten::tril_indices
aten::tril_indices.out
aten::triu_indices
aten::triu_indices.out
aten::trunc
aten::trunc.out
aten::trunc_
aten::unbind.int
aten::unfold
Expand Down
1 change: 1 addition & 0 deletions torch/_decomp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
aten.tril_,
aten.triu,
aten.triu_,
aten.trunc,
aten.unfold_backward,
aten.unfold_copy,
aten._unsafe_index,
Expand Down
5 changes: 5 additions & 0 deletions torch/_decomp/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4020,6 +4020,11 @@ def scaled_dot_product_flash_attention(
)


@register_decomposition([aten.trunc])
def trunc(self: Tensor) -> Tensor:
    """Round each element of *self* toward zero, elementwise.

    Positive elements are rounded down (floor) and non-positive elements
    are rounded up (ceil); the two branches agree at zero, so every value
    moves toward zero. NaN/inf pass through floor/ceil unchanged.
    """
    toward_neg_inf = torch.floor(self)
    toward_pos_inf = torch.ceil(self)
    return torch.where(self > 0, toward_neg_inf, toward_pos_inf)


def register_inplace(aten_op, outplace_op):
@register_decomposition(aten_op)
def inplace_op(*args, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions torch/_inductor/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
# Decompositions registered in torch/_decomp that Inductor should NOT use;
# they are stripped from Inductor's decomposition table below.
decomps_to_exclude = [
    aten._unsafe_index,
    aten._scaled_dot_product_flash_attention.default,  # See comments in torch/_decomp/decompositions.py
    aten.trunc,  # NOTE(review): presumably Inductor handles trunc natively (lowering/ref) — confirm
]

remove_decompositions(decompositions, decomps_to_exclude)
Expand Down
11 changes: 8 additions & 3 deletions torch/_refs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,9 +945,14 @@ def tanh(a):
return prims.tanh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def trunc(a):
return prims.trunc(a)
@out_wrapper()
@elementwise_unary_scalar_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def trunc(a: TensorLikeType) -> TensorLikeType:
    """Reference for aten.trunc: round each element of *a* toward zero.

    The wrapper stack (outermost first) adds `out=` support, handling of
    Python-scalar inputs, and default elementwise type promotion for `a`.
    """
    # NOTE(review): handle_noncontiguous_outputs appears to reconcile the
    # prim's output layout with a non-contiguous input `a` — confirm
    # against its definition elsewhere in this module.
    return handle_noncontiguous_outputs([a], prims.trunc(a))


# TODO: register this as a real ref/decomposition once TorchInductor supports complex!
Expand Down

0 comments on commit 5391170

Please sign in to comment.