Good First Issue: Add MLX Op Handler for aten.flip #18918

@metascroy

Summary

Add support for aten.flip in the MLX delegate. This op reverses the order of tensor elements along the specified dimensions and is used in image augmentation (e.g. horizontal flips) and sequence reversal.
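As a reference for the expected semantics, here is a minimal pure-Python sketch of flip on rectangular nested lists. The helper `flip_nested` is hypothetical (it is not part of PyTorch or the MLX delegate), and it assumes `dims` contains no duplicates. It reverses each listed dimension independently, which is what `torch.flip` computes on a tensor.

```python
from typing import Any, Sequence


def flip_nested(x: Any, dims: Sequence[int], ndim: int) -> Any:
    """Reverse a rectangular nested list of depth `ndim` along `dims`.

    Hypothetical reference helper; assumes `dims` has no duplicates.
    """
    # Normalize negative dims the way PyTorch does (e.g. -1 -> ndim - 1).
    flipped = {d % ndim for d in dims}

    def rec(node: Any, depth: int) -> Any:
        if depth == ndim:
            return node  # scalar leaf
        children = [rec(child, depth + 1) for child in node]
        # Reverse the list at this depth only if it is one of the flipped dims.
        return children[::-1] if depth in flipped else children

    return rec(x, 0)


x = [[1, 2, 3], [4, 5, 6]]
print(flip_nested(x, [0], 2))     # [[4, 5, 6], [1, 2, 3]]
print(flip_nested(x, [1], 2))     # [[3, 2, 1], [6, 5, 4]]
print(flip_nested(x, [0, 1], 2))  # [[6, 5, 4], [3, 2, 1]]
```

Because each dimension is reversed independently, the order in which dims are processed does not change the result, which is what makes a one-slice-per-dim decomposition valid.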

Background

The MLX delegate converts PyTorch aten ops into MLX graph nodes during export. When an aten op has no handler, it falls back to CPU execution, breaking the GPU acceleration pipeline. Adding a handler lets the op run on the Metal GPU via MLX.

Approach: Decomposed handler (preferred)

aten.flip can be decomposed into a chain of SliceNodes with step=-1, one per flipped dimension:

# flip(x, dims=[0, 2]) reverses along dims 0 and 2
# For each dim d in dims, with size = x.shape[d]:
#   x = slice(x, axis=d, start=size-1, stop=-(size+1), step=-1)

This approach reuses the existing SliceNode, which already supports a negative step (see the topk handler for reference).
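The slice bounds above can be checked with plain Python lists, assuming SliceNode follows Python/NumPy slice semantics for negative indices (the bounds in this issue imply that, but it is an assumption here). `reverse_via_slice` is a hypothetical helper. The check also shows why the stop is `-(size+1)` rather than `-1`: a stop of `-1` refers to the last element, which makes the reversed slice empty.

```python
def reverse_via_slice(seq: list) -> list:
    """Reverse `seq` using the start/stop/step from the flip decomposition."""
    size = len(seq)
    # start=size-1 is the last element; with step=-1, stop=-(size+1)
    # normalizes to -1, i.e. "one before index 0", so index 0 is included.
    return seq[size - 1 : -(size + 1) : -1]


data = [10, 20, 30, 40]
print(reverse_via_slice(data))              # [40, 30, 20, 10]
# slice.indices() shows the normalized bounds Python actually uses.
print(slice(3, -5, -1).indices(len(data)))  # (3, -1, -1)
# Pitfall: stop=-1 refers to the last element, so nothing is selected.
print(data[3:-1:-1])                        # []
```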

Steps

  1. Add a handler in backends/mlx/ops.py

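    # Assumes REGISTRY, MLXProgramBuilder, Node, Slot, SliceNode, IdCopyNode,
    # require_args and require_kwargs are already imported in ops.py.
    @REGISTRY.register(target=[torch.ops.aten.flip.default])
    def _flip_handler(P: MLXProgramBuilder, n: Node) -> Slot:
        args = P.args(n)
        require_args(args, 2, 2, "aten.flip")
        require_kwargs(P.kwargs(n), set(), "aten.flip")
        x, dims = args

        # Get the input shape for computing slice bounds
        x_meta = n.args[0].meta.get("val")
        ndim = len(x_meta.shape)

        out = x  # Start with the input and chain one reversing slice per dim
        for dim in dims:
            # Normalize negative dims (e.g. -1 -> ndim - 1) so the axis
            # passed to SliceNode is always non-negative
            dim = dim % ndim
            dim_size = x_meta.shape[dim]
            _, tmp = P.make_tmp_slot()
            P.emit(
                SliceNode(
                    x=P.slot_to_tid(out),
                    out=P.slot_to_tid(tmp),
                    axis=P.to_int_or_vid(dim),
                    start=P.to_int_or_vid(dim_size - 1),
                    # Assuming Python-style slice semantics, -(size + 1)
                    # normalizes to "before index 0", so the whole axis is reversed
                    stop=P.to_int_or_vid(-(dim_size + 1)),
                    step=-1,
                )
            )
            out = tmp

        # Copy the last intermediate (or x itself if dims is empty) into the node's output slot
        final_out = P.make_or_get_slot(n)
        P.emit(IdCopyNode(x=P.slot_to_tid(out), out=P.slot_to_tid(final_out)))
        return final_out

  2. Add a test in backends/mlx/test/test_ops.py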

    This op doesn't fit the simple unary test pattern because it takes a dims argument, so create a custom test class that follows existing patterns such as PermuteTest:

    # Assumes List, Tuple, torch, nn, register_test and OpTestCase are
    # already imported in test_ops.py.
    class FlipModel(nn.Module):
        def __init__(self, dims: List[int]):
            super().__init__()
            self.dims = dims

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.flip(x, self.dims)

    @register_test
    class FlipTest(OpTestCase):
        name = "flip"

        def __init__(self, shape: Tuple[int, ...], dims: List[int]):
            self.shape = shape
            self.dims = dims
            # Build a unique per-config name, e.g. "flip_4x5_dims0_1"
            dims_str = "_".join(str(d) for d in dims)
            shape_str = "x".join(str(s) for s in shape)
            self.name = f"flip_{shape_str}_dims{dims_str}"

        @classmethod
        def get_test_configs(cls) -> List["FlipTest"]:
            # Cover single dims, multiple dims, a negative dim, and 2-D/3-D inputs
            return [
                cls(shape=(4, 5), dims=[0]),
                cls(shape=(4, 5), dims=[1]),
                cls(shape=(4, 5), dims=[0, 1]),
                cls(shape=(3, 4, 5), dims=[-1]),
                cls(shape=(3, 4, 5), dims=[0, 2]),
            ]

        def create_model(self) -> nn.Module:
            return FlipModel(self.dims)

        def create_inputs(self) -> Tuple[torch.Tensor, ...]:
            return (torch.randn(self.shape),)
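To dry-run the handler's bookkeeping without an MLX build, its loop can be mirrored in plain Python. This is a hypothetical sketch: `PlannedSlice` stands in for the fields the handler passes to SliceNode, negative dims are normalized up front, and the plan is applied to nested lists with Python slice semantics (assumed, not verified, to match SliceNode's).

```python
from dataclasses import dataclass
from typing import Any, List, Sequence


@dataclass(frozen=True)
class PlannedSlice:
    """Hypothetical record of the SliceNode arguments for one flipped dim."""
    axis: int
    start: int
    stop: int
    step: int


def plan_flip(shape: Sequence[int], dims: Sequence[int]) -> List[PlannedSlice]:
    """One reversing slice per entry in dims, mirroring the handler loop."""
    ndim = len(shape)
    plan = []
    for dim in dims:
        axis = dim % ndim  # normalize negative dims, e.g. -1 -> ndim - 1
        size = shape[axis]
        plan.append(PlannedSlice(axis, size - 1, -(size + 1), -1))
    return plan


def slice_along(x: Any, s: PlannedSlice, depth: int = 0) -> Any:
    """Apply one planned slice to a nested list at depth s.axis."""
    if depth == s.axis:
        return x[s.start : s.stop : s.step]
    return [slice_along(child, s, depth + 1) for child in x]


def run_plan(x: Any, plan: List[PlannedSlice]) -> Any:
    # Chain the slices, like `out = tmp` in the handler.
    for s in plan:
        x = slice_along(x, s)
    return x


plan = plan_flip(shape=(2, 3), dims=[0, -1])
print(plan)
print(run_plan([[1, 2, 3], [4, 5, 6]], plan))  # [[6, 5, 4], [3, 2, 1]]
```

Reordering the entries in `dims` leaves the final result unchanged, since reversals along different axes commute.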

Running tests

python -m executorch.backends.mlx.test.run_all_tests -k flip

References
