[Decomposition] Trunc #109319

Closed · wants to merge 1 commit
2 changes: 0 additions & 2 deletions test/expect/HasDecompTest.test_aten_core_operators.expect
@@ -504,8 +504,6 @@ aten::tril_indices
 aten::tril_indices.out
 aten::triu_indices
 aten::triu_indices.out
-aten::trunc
-aten::trunc.out
 aten::trunc_
 aten::unbind.int
 aten::unfold
1 change: 1 addition & 0 deletions test/test_ops.py
@@ -1856,6 +1856,7 @@ class TestRefsOpsInfo(TestCase):
         '_refs.tensor_split',
         '_refs.to',
         '_refs.true_divide',
+        '_refs.trunc',
         '_refs.trunc_divide',
         '_refs.vsplit',
         '_refs.vstack',
1 change: 1 addition & 0 deletions torch/_decomp/__init__.py
@@ -370,6 +370,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
            aten.tril_,
            aten.triu,
            aten.triu_,
+           aten.trunc,
            aten.unfold_backward,
            aten.unfold_copy,
            aten._unsafe_index,
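
As an aside, a minimal sketch (hypothetical usage, not part of this PR) of how the table above is consumed: tracing through core_aten_decompositions() with make_fx shows whether aten.trunc gets decomposed.

import torch
from torch._decomp import core_aten_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.trunc(x)

# With aten.trunc registered above, the traced graph is expected to contain the
# where/floor/ceil decomposition instead of a single aten.trunc node.
gm = make_fx(f, decomposition_table=core_aten_decompositions())(torch.randn(4))
print(gm.graph)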
5 changes: 5 additions & 0 deletions torch/_decomp/decompositions.py
@@ -4020,6 +4020,11 @@ def scaled_dot_product_flash_attention(
     )


+@register_decomposition([aten.trunc])
+def trunc(self: Tensor, **kwargs) -> Tensor:
+    return torch.where(self > 0, torch.floor(self), torch.ceil(self))
Collaborator:

Why are we decomposing what is likely a hardware intrinsic operation into multiple operations? It seems to me that just because you can doesn't mean you should.

Contributor:

@peterbell10 Makes sense. We will remove this decomp, promote this operator to core, and give the same treatment to other operators that map cleanly to hardware intrinsics.

Contributor:

@peterbell10 To be sure I understand: if we were to do this decomp here, it would result in (1) a perf loss and (2) a potential numerics mismatch. The perf loss should be recoverable by pattern matching + fusion. (2) is the tough one, and for that I agree with you that we should probably make this a core aten op. But would you agree with me on (1)?

Collaborator:

I think we should also ask how many consumers would actually want to "re-compose" this operator. Since this is a hardware intrinsic, I suspect most backends would want it specialized, but now they have to pattern match it back into trunc first. So in the most common case I expect this to create more work, not save it.

On the other hand, if you could show that most backends have a hard time implementing trunc and that this would save a lot of maintenance, then I could see the performance trade-off being worth it.



 def register_inplace(aten_op, outplace_op):
     @register_decomposition(aten_op)
     def inplace_op(*args, **kwargs):
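
To make the numerics discussion concrete, here is a quick eager-mode sanity check (not part of the PR) that the proposed where/floor/ceil rewrite agrees with torch.trunc on ordinary finite inputs; the reviewers' concern is what happens once backends re-compose or fuse these ops rather than hitting a trunc intrinsic.

import torch

def trunc_decomp(x):
    # The decomposition registered above: floor positive values, ceil the rest.
    return torch.where(x > 0, torch.floor(x), torch.ceil(x))

x = torch.tensor([-2.5, -0.3, 0.0, 0.3, 2.5])
assert torch.equal(trunc_decomp(x), torch.trunc(x))  # both give tensor([-2., -0., 0., 0., 2.])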
1 change: 1 addition & 0 deletions torch/_inductor/decomposition.py
@@ -72,6 +72,7 @@
     aten._unsafe_index,
     aten._scaled_dot_product_flash_attention.default,  # See comments in torch/_decomp/decompositions.py
     aten.clamp_min,
+    aten.trunc,
 ]

 remove_decompositions(decompositions, decomps_to_exclude)
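
A hypothetical follow-up check (not in the PR): with aten.trunc excluded here, inductor's decomposition table should no longer carry an entry for it, so the backend lowers trunc directly rather than pattern matching the where/floor/ceil sequence back together.

import torch
from torch._inductor.decomposition import decompositions

# Assumes the exclusion above has landed; otherwise aten.trunc.default would map
# to the decomposition registered in torch/_decomp/decompositions.py.
assert torch.ops.aten.trunc.default not in decompositions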
11 changes: 8 additions & 3 deletions torch/_refs/__init__.py
@@ -945,9 +945,14 @@ def tanh(a):
     return prims.tanh(a)


-@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
-def trunc(a):
-    return prims.trunc(a)
+@out_wrapper()
+@elementwise_unary_scalar_wrapper
+@elementwise_type_promotion_wrapper(
+    type_promoting_args=("a",),
+    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+)
+def trunc(a: TensorLikeType) -> TensorLikeType:
+    return handle_noncontiguous_outputs([a], prims.trunc(a))


 # TODO: register this as a real ref/decomposition once TorchInductor supports complex!
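
For illustration, a rough sketch (assuming these wrappers behave as they do for other _refs unary ops) of what the rewritten reference supports: out= handling via out_wrapper and Python-scalar inputs via elementwise_unary_scalar_wrapper.

import torch
import torch._refs as refs

x = torch.tensor([-1.7, -0.2, 0.2, 1.7])
out = torch.empty_like(x)
refs.trunc(x, out=out)   # out= is handled by out_wrapper
print(out)               # tensor([-1., -0., 0., 1.])
print(refs.trunc(2.9))   # scalar input handled by elementwise_unary_scalar_wrapper -> 2.0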