Good First Issue: Add MLX Op Handler for aten.trunc #18923

@metascroy

🚀 The feature, motivation and pitch

Summary

Add support for aten.trunc in the MLX delegate. This op rounds each floating-point element toward zero (for example, 2.7 becomes 2.0 and -2.7 becomes -2.0) and is needed for integer conversion operations.

Background

The MLX delegate converts PyTorch aten ops into MLX graph nodes during export. When an aten op has no handler, it falls back to CPU execution, breaking the GPU acceleration pipeline. Adding a handler lets the op run on the Metal GPU via MLX.

Approach: Decomposed handler (preferred)

aten.trunc can be decomposed into floor and ceil, selected by the sign of the input:

# trunc(x) = sign(x) * floor(abs(x))
# Or equivalently: where(x >= 0, floor(x), ceil(x))

The where-based form uses the existing FloorNode, CeilNode, WhereNode, and GreaterEqualNode.
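
As a quick sanity check of the two identities, the following scalar sketch uses only Python's standard `math` module, not MLX or PyTorch. The helper names `trunc_via_where` and `trunc_via_sign` are made up for this example. It compares both decompositions against `math.trunc` on a handful of finite values.

```python
import math


def trunc_via_where(x: float) -> float:
    """Scalar version of where(x >= 0, floor(x), ceil(x))."""
    return float(math.floor(x)) if x >= 0 else float(math.ceil(x))


def trunc_via_sign(x: float) -> float:
    """Scalar version of sign(x) * floor(abs(x))."""
    return math.copysign(float(math.floor(abs(x))), x)


samples = [-2.7, -1.0, -0.5, 0.0, 0.5, 1.0, 2.7]
for x in samples:
    expected = float(math.trunc(x))
    # Both decompositions should agree with the reference on finite inputs.
    assert trunc_via_where(x) == expected
    assert trunc_via_sign(x) == expected

results = [trunc_via_where(x) for x in samples]
print(results)
```

The sketch deliberately skips NaN and infinities, since `math.floor` raises on them. Tensor implementations propagate those values instead, so they would need to be checked separately against the actual backend.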

Steps

  1. Add handler in backends/mlx/ops.py

    @REGISTRY.register(target=[torch.ops.aten.trunc.default])
    def _trunc_handler(P: MLXProgramBuilder, n: Node) -> Slot:
        """Handle aten.trunc - truncate toward zero.
        
        trunc(x) = where(x >= 0, floor(x), ceil(x))
        """
        args = P.args(n)
        require_args(args, 1, 1, "aten.trunc")
        require_kwargs(P.kwargs(n), set(), "aten.trunc")
        x = args[0]
        
        x_meta = n.args[0].meta.get("val")
        dtype = x_meta.dtype if x_meta is not None else torch.float32
        
        # Create zero constant for comparison
        zero_slot = emit_lifted_constant(P, 0.0, dtype)
        
        # x >= 0
        _, ge_zero = P.make_tmp_slot()
        P.emit(
            GreaterEqualNode(
                a=P.slot_to_tid(x),
                b=P.slot_to_tid(zero_slot),
                out=P.slot_to_tid(ge_zero),
            )
        )
        
        # floor(x)
        _, floor_x = P.make_tmp_slot()
        P.emit(
            FloorNode(
                x=P.slot_to_tid(x),
                out=P.slot_to_tid(floor_x),
            )
        )
        
        # ceil(x)
        _, ceil_x = P.make_tmp_slot()
        P.emit(
            CeilNode(
                x=P.slot_to_tid(x),
                out=P.slot_to_tid(ceil_x),
            )
        )
        
        # where(x >= 0, floor(x), ceil(x))
        out = P.make_or_get_slot(n)
        P.emit(
            WhereNode(
                condition=P.slot_to_tid(ge_zero),
                x=P.slot_to_tid(floor_x),
                y=P.slot_to_tid(ceil_x),
                out=P.slot_to_tid(out),
            )
        )
        return out

  2. Add test in backends/mlx/test/test_ops.py

    Use the existing _make_unary_op_test infrastructure:

    # Add to _UNARY_OP_TESTS list:
    {"op_name": "trunc", "op_fn": torch.trunc, "shapes": _SHAPES_3, "input_fn": _input_fn(scale=10)},

    Setting scale=10 spreads the inputs over a range where truncation behavior is visible, rather than clustering them near zero.
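
The effect of scaling can be illustrated without the test harness. This is a sketch under an assumption: uniform samples in (-1, 1) stand in for unscaled inputs, which may not match what `_input_fn` actually generates. With unscaled values every result truncates to zero, so a broken handler that always returned zero would still pass; scaling by 10 produces many distinct integer results.

```python
import math
import random

random.seed(0)  # fixed seed so the sketch is reproducible

# Assumption: small-magnitude samples stand in for "unscaled" test inputs.
unit = [random.uniform(-1.0, 1.0) for _ in range(1000)]
scaled = [10.0 * v for v in unit]

# Distinct truncated values seen in each input set.
distinct_unit = {math.trunc(v) for v in unit}
distinct_scaled = {math.trunc(v) for v in scaled}

# Count scaled inputs where floor and trunc disagree (negative non-integers).
floor_mismatches = sum(math.floor(v) != math.trunc(v) for v in scaled)

print(sorted(distinct_unit), len(distinct_scaled), floor_mismatches)
```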

Running tests

python -m executorch.backends.mlx.test.run_all_tests -k trunc

References

  • MLX C++: use a combination of floor and ceil, or check whether MLX provides a trunc op directly
  • PyTorch signature: trunc(Tensor self) -> Tensor
  • Mathematical definition: trunc(x) rounds toward zero, unlike floor, which rounds toward -inf
