Good First Issue: Add MLX Op Handler for aten.roll #18919

@metascroy

Description

Summary

Add support for aten.roll in the MLX delegate. This op shifts tensor elements along specified dimensions with wrap-around and is needed by Swin Transformer's shift-window attention mechanism.

Background

The MLX delegate converts PyTorch aten ops into MLX graph nodes during export. While aten.roll decomposes into index_select + arange + cat operations, a native MLX implementation using mlx::core::roll would be more efficient (single kernel vs multiple ops).
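To see why a single kernel wins, here is a sketch of what the decomposed path amounts to for one axis: a gather with a wrap-around index map. NumPy is used as the reference here (an assumption; `numpy.roll` matches `torch.roll` when dims are given explicitly), and `roll_via_gather` is an illustrative helper, not part of the delegate.

```python
import numpy as np

# Illustrative sketch of the decomposed path: for a shift s along an axis
# of length n, roll is a gather with indices (arange(n) - s) % n.
def roll_via_gather(x: np.ndarray, shift: int, axis: int) -> np.ndarray:
    n = x.shape[axis]
    idx = (np.arange(n) - shift) % n   # wrap-around index map
    return np.take(x, idx, axis=axis)  # one gather (index_select) per axis

x = np.arange(12).reshape(3, 4)
assert np.array_equal(roll_via_gather(x, 1, 0), np.roll(x, 1, axis=0))
assert np.array_equal(roll_via_gather(x, -2, 1), np.roll(x, -2, axis=1))
```

A native `mlx::core::roll` replaces this arange + gather chain with one kernel launch per call.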

Approach: New schema node + runtime

This requires a new RollNode because MLX has a dedicated roll function that's more efficient than the decomposed representation.

Steps

  1. Add node to backends/mlx/serialization/schema.fbs

    table RollNode {
      x: Tid;
      out: Tid;
      shift: [IntOrVid];  // shift amounts per axis
      axes: [int];        // axes to roll
    }

    Add RollNode to the OpNode union (append only, do not reorder).

  2. Regenerate serialization code

    python backends/mlx/serialization/generate.py
  3. Add C++ runtime exec function in backends/mlx/runtime/MLXInterpreter.h

    void exec_RollNode(const RollNode& node) {
      auto x = get_tensor(node.x());

      // Shift amounts may be compile-time ints or runtime values (Vids),
      // so resolve each entry before building the shift vector.
      std::vector<int> shifts;
      for (auto s : *node.shift()) {
        shifts.push_back(resolve_int_or_vid(s));
      }

      std::vector<int> axes;
      for (auto a : *node.axes()) {
        axes.push_back(a);
      }

      auto out = mlx::core::roll(x, shifts, axes, stream_);
      set_tensor(node.out(), out);
    }
  4. Add handler in backends/mlx/ops.py

    aten.roll takes extra parameters (shifts, dims) beyond the standard unary signature (x, out), so it does not fit the _UNARY_OPS table. Register a custom handler instead:

    @REGISTRY.register(target=[torch.ops.aten.roll.default])
    def _roll_handler(P: MLXProgramBuilder, n: Node) -> Slot:
        args = P.args(n)
        require_args(args, 2, 3, "aten.roll")
        require_kwargs(P.kwargs(n), set(), "aten.roll")
        x = args[0]
        shifts = args[1]
        dims = args[2] if len(args) > 2 else []

        # Normalize shifts and dims to lists
        if isinstance(shifts, int):
            shifts = [shifts]
        if isinstance(dims, int):
            dims = [dims]
        if not dims:
            # torch.roll with no dims flattens the tensor, rolls the flat
            # view, and reshapes back; defaulting to all dims would be
            # incorrect, so reject this case for now.
            raise NotImplementedError("aten.roll with empty dims is not yet supported")
        
        out = P.make_or_get_slot(n)
        P.emit(
            RollNode(
                x=P.slot_to_tid(x),
                out=P.slot_to_tid(out),
                shift=[P.to_int_or_vid(s) for s in shifts],
                axes=list(dims),
            )
        )
        return out
  5. Add test in backends/mlx/test/test_ops.py

    This op doesn't fit the simple unary pattern, so create a custom test class:

    class RollModel(nn.Module):
        def __init__(self, shifts: int, dims: int):
            super().__init__()
            self.shifts = shifts
            self.dims = dims
        
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.roll(x, shifts=self.shifts, dims=self.dims)
    
    @register_test
    class RollTest(OpTestCase):
        name = "roll"
        
        def __init__(self, shape: Tuple[int, ...], shifts: int, dims: int):
            self.shape = shape
            self.shifts = shifts
            self.dims = dims
            self.name = f"roll_shift{shifts}_dim{dims}"
        
        @classmethod
        def get_test_configs(cls) -> List["RollTest"]:
            return [
                cls(shape=(8,), shifts=2, dims=0),
                cls(shape=(4, 5), shifts=1, dims=0),
                cls(shape=(4, 5), shifts=-2, dims=1),
                cls(shape=(3, 4, 5), shifts=3, dims=2),
            ]
        
        def create_model(self) -> nn.Module:
            return RollModel(self.shifts, self.dims)
        
        def create_inputs(self) -> Tuple[torch.Tensor, ...]:
            return (torch.randn(self.shape),)
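Two semantic details are worth sanity-checking before wiring this up: vector shifts/axes compose as sequential single-axis rolls (what exec_RollNode relies on), and rolling with no dims flattens the tensor first, which is not the same as rolling every axis. A quick NumPy sketch (assuming, as NumPy and PyTorch both document, that `numpy.roll` matches `torch.roll` for explicitly given dims; `roll_multi` is an illustrative helper):

```python
import numpy as np

# 1) Vector shifts/axes compose as sequential single-axis rolls.
def roll_multi(x, shifts, axes):
    out = x
    for s, a in zip(shifts, axes):
        out = np.roll(out, s, axis=a)
    return out

x = np.arange(24).reshape(2, 3, 4)
assert np.array_equal(roll_multi(x, [1, -2], [0, 2]),
                      np.roll(x, shift=(1, -2), axis=(0, 2)))

# 2) With no dims/axis, roll flattens, rolls the flat view, and reshapes
#    back; this differs from rolling each axis individually.
y = np.arange(6).reshape(2, 3)
assert np.array_equal(np.roll(y, 1), np.roll(y.ravel(), 1).reshape(2, 3))
assert not np.array_equal(np.roll(y, 1), np.roll(np.roll(y, 1, 0), 1, 1))
```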

Running tests

python -m executorch.backends.mlx.test.run_all_tests -k roll

References

  • MLX C++: array roll(const array &a, int shift, int axis, StreamOrDevice s = {})
  • PyTorch signature: roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor
  • Use case: Swin Transformer shift-window attention
