🚀 The feature, motivation and pitch
Good First Issue: Add MLX Op Handler for aten.trunc
Summary
Add support for aten.trunc in the MLX delegate. This op truncates floating-point values toward zero and is needed for integer conversion operations.
Background
The MLX delegate converts PyTorch aten ops into MLX graph nodes during export. When an aten op has no handler, it falls back to CPU execution, breaking the GPU acceleration pipeline. Adding a handler lets the op run on the Metal GPU via MLX.
Approach: Decomposed handler (preferred)
aten.trunc can be decomposed using floor and ceil based on sign:
```python
# trunc(x) = sign(x) * floor(abs(x))
# Or equivalently: where(x >= 0, floor(x), ceil(x))
```
This uses existing FloorNode, CeilNode, WhereNode, and GreaterEqualNode.
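As a sanity check on the decomposition, the identity can be verified numerically in plain PyTorch (independent of the MLX builder API) on mixed-sign inputs:

```python
import torch

# Verify trunc(x) == where(x >= 0, floor(x), ceil(x)) across signs and zero
x = torch.tensor([-2.7, -0.4, 0.0, 0.4, 2.7])
decomposed = torch.where(x >= 0, torch.floor(x), torch.ceil(x))
print(decomposed)  # tensor([-2., -0., 0., 0., 2.])
assert torch.equal(decomposed, torch.trunc(x))
```

Note that ceil of a small negative value yields negative zero, which compares equal to zero, so the identity holds elementwise.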
Steps
- Add handler in backends/mlx/ops.py
```python
@REGISTRY.register(target=[torch.ops.aten.trunc.default])
def _trunc_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Handle aten.trunc - truncate toward zero.

    trunc(x) = where(x >= 0, floor(x), ceil(x))
    """
    args = P.args(n)
    require_args(args, 1, 1, "aten.trunc")
    require_kwargs(P.kwargs(n), set(), "aten.trunc")
    x = args[0]
    x_meta = n.args[0].meta.get("val")
    dtype = x_meta.dtype if x_meta is not None else torch.float32

    # Create zero constant for comparison
    zero_slot = emit_lifted_constant(P, 0.0, dtype)

    # x >= 0
    _, ge_zero = P.make_tmp_slot()
    P.emit(
        GreaterEqualNode(
            a=P.slot_to_tid(x),
            b=P.slot_to_tid(zero_slot),
            out=P.slot_to_tid(ge_zero),
        )
    )

    # floor(x)
    _, floor_x = P.make_tmp_slot()
    P.emit(
        FloorNode(
            x=P.slot_to_tid(x),
            out=P.slot_to_tid(floor_x),
        )
    )

    # ceil(x)
    _, ceil_x = P.make_tmp_slot()
    P.emit(
        CeilNode(
            x=P.slot_to_tid(x),
            out=P.slot_to_tid(ceil_x),
        )
    )

    # where(x >= 0, floor(x), ceil(x))
    out = P.make_or_get_slot(n)
    P.emit(
        WhereNode(
            condition=P.slot_to_tid(ge_zero),
            x=P.slot_to_tid(floor_x),
            y=P.slot_to_tid(ceil_x),
            out=P.slot_to_tid(out),
        )
    )
    return out
```
- Add test in backends/mlx/test/test_ops.py
Use the existing _make_unary_op_test infrastructure:
```python
# Add to the _UNARY_OP_TESTS list:
{"op_name": "trunc", "op_fn": torch.trunc, "shapes": _SHAPES_3, "input_fn": _input_fn(scale=10)},
```
The scale=10 ensures values span a range where truncation behavior is visible (not just near zero).
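To see why the scale matters: every input drawn from (-1, 1) truncates to zero, which would give almost no test signal. Scaling spreads inputs across several integers so positive and negative truncation are both exercised. A plain-PyTorch illustration (not the project's `_input_fn` itself; exactly representable values chosen so the scaled results are exact):

```python
import torch

small = torch.tensor([0.25, -0.75, 0.5])  # everything in (-1, 1)
print(torch.trunc(small))                  # all truncate to 0 -- weak signal
scaled = small * 10                        # [2.5, -7.5, 5.0], as scale=10 would give
print(torch.trunc(scaled))                 # tensor([2., -7., 5.])
```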
Running tests
```
python -m executorch.backends.mlx.test.run_all_tests -k trunc
```
References
- MLX C++: a combination of floor/ceil can be used, or check whether MLX exposes trunc directly
- PyTorch signature: `trunc(Tensor self) -> Tensor`
- Mathematical definition: trunc(x) rounds toward zero (unlike floor, which rounds toward -inf)
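The distinction only shows up for negative inputs, where floor and trunc diverge by one; any test inputs should therefore include negatives:

```python
import torch

x = torch.tensor([-2.7, -0.4, 1.9])
print(torch.trunc(x))  # tensor([-2., -0., 1.]) -- rounds toward zero
print(torch.floor(x))  # tensor([-3., -1., 1.]) -- rounds toward -inf
```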
Alternatives
No response
Additional context
No response
RFC (Optional)
No response