
Good First Issue: Add MLX Op Handler for aten.bitwise_or #18926

@metascroy

Description

🚀 The feature, motivation and pitch

Summary

Add support for aten.bitwise_or in the MLX delegate. This op performs element-wise bitwise OR and is needed for bit manipulation, flags/masks, and low-level data processing.
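
To make the flags/masks use case concrete, here is a pure-Python sketch of element-wise OR over two equal-length lists standing in for integer tensors. The flag names and the grant helper are illustrative only and do not come from the MLX delegate.

```python
# Illustrative single-bit flags (hypothetical names, not from the MLX delegate).
READ = 0b001
WRITE = 0b010
EXEC = 0b100


def grant(perms, extra):
    """Element-wise OR: set the bits from `extra` on each entry of `perms`."""
    if len(perms) != len(extra):
        raise ValueError("operands must have the same length")
    return [p | e for p, e in zip(perms, extra)]


current = [READ, 0, READ | WRITE]
requested = [WRITE, EXEC, 0]
print(grant(current, requested))  # [3, 4, 3]
```

Each output entry keeps the bits it already had and gains the requested ones. This is the same per-element behavior the delegate needs to reproduce on integer tensors.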

Background

The MLX delegate currently has no handler for bitwise_or. MLX supports it natively via mlx::core::bitwise_or, which works on both boolean and integer tensors.
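
A pure-Python reference shows the semantics the handler should preserve for both dtype families. It is a sketch that uses lists as stand-ins for flattened tensors and leaves out broadcasting.

```python
def bitwise_or(a, b):
    """Reference element-wise bitwise OR over two same-length flat lists."""
    if len(a) != len(b):
        raise ValueError("operands must have the same length")
    return [x | y for x, y in zip(a, b)]


# Boolean tensors: bitwise OR coincides with logical OR and stays boolean.
print(bitwise_or([True, False, False], [False, False, True]))  # [True, False, True]

# Integer tensors: bits are combined position by position.
print(bitwise_or([0b1100, 7], [0b1010, 8]))  # [14, 15]

# Signed values: Python ints behave as two's complement, matching signed int dtypes.
print(bitwise_or([-2], [1]))  # [-1]
```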

Approach: New schema node + runtime

Add a BitwiseOrNode to handle this via MLX's native bitwise_or.
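
The node/runtime split can be modeled in a few lines of Python. This sketch is not ExecuTorch code. The dataclasses are hypothetical stand-ins that loosely mirror the schema table from step 1 and the get_tensor/set_tensor calls in the step 3 exec function. The real node is FlatBuffers-generated, and the real runtime calls mlx::core::bitwise_or on lazily evaluated MLX arrays. This sketch models neither of those, and it also leaves out broadcasting.

```python
from dataclasses import dataclass, field


@dataclass
class BitwiseOrNode:
    # Tensor ids (the schema's Tid) modeled as plain ints.
    a: int
    b: int
    out: int


@dataclass
class ExecutionState:
    # Hypothetical stand-in: tensor id -> flat list of element values.
    tensors: dict = field(default_factory=dict)

    def get_tensor(self, tid):
        return self.tensors[tid]

    def set_tensor(self, tid, value):
        self.tensors[tid] = value


def exec_bitwise_or(n, st):
    # Read both inputs by id, combine element-wise, write the result under n.out.
    a = st.get_tensor(n.a)
    b = st.get_tensor(n.b)
    st.set_tensor(n.out, [x | y for x, y in zip(a, b)])


st = ExecutionState({0: [1, 2, 4], 1: [1, 1, 1]})
exec_bitwise_or(BitwiseOrNode(a=0, b=1, out=2), st)
print(st.get_tensor(2))  # [1, 3, 5]
```

The node itself only stores tensor ids. The exec function resolves those ids against the execution state and writes a new tensor without modifying its inputs.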

Steps

  1. Add the node to backends/mlx/serialization/schema.fbs

    table BitwiseOrNode {
      a: Tid;
      b: Tid;
      out: Tid;
    }

    Add BitwiseOrNode to the OpNode union (append only, do not reorder).

  2. Regenerate serialization code

    python backends/mlx/serialization/generate.py

  3. Add a C++ runtime exec function to backends/mlx/runtime/MLXInterpreter.h

    inline void exec_bitwise_or(
        const BitwiseOrNode& n, ExecutionState& st, StreamOrDevice s) {
      auto a = st.get_tensor(n.a());
      auto b = st.get_tensor(n.b());
      auto out = mlx::core::bitwise_or(a, b, s);
      st.set_tensor(n.out(), out);
    }

  4. Add a handler in backends/mlx/ops.py

    Add to the _BINARY_OPS table:

    # In _BINARY_OPS list, add:
    ([torch.ops.aten.bitwise_or.Tensor, torch.ops.aten.bitwise_or.Scalar], BitwiseOrNode, "aten.bitwise_or", True),

  5. Add tests in backends/mlx/test/test_ops.py

    Use the _BINARY_OP_TESTS table:

    # Add to _BINARY_OP_TESTS list:
    {"op_name": "bitwise_or_bool", "op_fn": torch.bitwise_or,
     "shapes": _SHAPES_3, "dtypes": [torch.bool],
     "input_fn_a": _bool_input_fn(), "input_fn_b": _bool_input_fn()},
    {"op_name": "bitwise_or_int", "op_fn": torch.bitwise_or,
     "shapes": _SHAPES_3, "dtypes": [torch.int32, torch.int64],
     "input_fn_a": _int_input_fn(0, 256), "input_fn_b": _int_input_fn(0, 256)},
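
Step 1's "append only" rule matters because a FlatBuffers union is serialized as a numeric type tag, and by default those tags follow declaration order, with 0 reserved for NONE. The Python sketch below uses hypothetical member names, not the real OpNode contents. It shows that appending a member keeps existing tags stable, while inserting one in the middle shifts them. A shift would make previously exported programs decode to the wrong node type.

```python
def union_tags(members):
    """Default FlatBuffers numbering: members tagged 1, 2, ... in declaration order."""
    return {name: tag for tag, name in enumerate(members, start=1)}


def preserves_tags(old, new):
    """True if every member of `old` keeps its numeric tag in `new`."""
    old_tags, new_tags = union_tags(old), union_tags(new)
    return all(new_tags.get(name) == tag for name, tag in old_tags.items())


# Hypothetical member lists; the real OpNode union has many more entries.
old = ["AddNode", "MulNode", "SubNode"]
appended = old + ["BitwiseOrNode"]
reordered = ["AddNode", "BitwiseOrNode", "MulNode", "SubNode"]

print(preserves_tags(old, appended))   # True
print(preserves_tags(old, reordered))  # False
```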

Running tests

python -m executorch.backends.mlx.test.run_all_tests -k bitwise_or
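
The test entries in step 5 name torch.bitwise_or as op_fn, so the delegate's output is checked against it. The following pure-Python check illustrates the properties such a comparison relies on. It is not the actual test harness, and it assumes that _int_input_fn(0, 256) samples from the half-open range [0, 256).

```python
import itertools
import random


def expected_or(a, b):
    """Reference result a test harness could compare the delegate's output against."""
    return [x | y for x, y in zip(a, b)]


# Bool case: the full truth table matches logical OR.
for x, y in itertools.product([False, True], repeat=2):
    assert expected_or([x], [y]) == [x or y]

# Int case: inputs drawn from [0, 256), assuming _int_input_fn(0, 256) is half-open.
rng = random.Random(0)
a = [rng.randrange(0, 256) for _ in range(1000)]
b = [rng.randrange(0, 256) for _ in range(1000)]
out = expected_or(a, b)
assert all(0 <= v < 256 for v in out)  # OR of 8-bit values stays within 8 bits
assert all(v >= max(x, y) for v, x, y in zip(out, a, b))  # OR never clears a set bit
print("ok")
```

Because the results stay below 256, they are represented exactly in both int32 and int64, the integer dtypes the test entry covers.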

References

  • MLX C++: array bitwise_or(const array &a, const array &b, StreamOrDevice s = {})
  • PyTorch signatures: bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor and bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor
  • Supported dtypes: int8, int16, int32, int64, uint8, bool
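
Bitwise OR is not defined for floating-point dtypes. A handler may therefore want to check its inputs against the supported list above before delegating. This minimal sketch assumes dtypes are available as plain strings. The function name is hypothetical, and the check does not model PyTorch's type promotion for mixed integer inputs.

```python
# Dtypes listed in the References as supported for bitwise_or.
SUPPORTED_DTYPES = frozenset({"int8", "int16", "int32", "int64", "uint8", "bool"})


def can_delegate_bitwise_or(dtype_a, dtype_b):
    """Hypothetical lowering-time guard on dtype names; ignores type promotion."""
    return dtype_a in SUPPORTED_DTYPES and dtype_b in SUPPORTED_DTYPES


print(can_delegate_bitwise_or("int32", "int32"))      # True
print(can_delegate_bitwise_or("float32", "float32"))  # False
```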

Alternatives

No response

Additional context

No response

RFC (Optional)

No response
