# Good First Issue: Add MLX Op Handler for `aten.roll`

## Summary

Add support for `aten.roll` in the MLX delegate. This op shifts tensor elements along specified dimensions with wrap-around and is needed by Swin Transformer's shifted-window attention mechanism.

## Background

The MLX delegate converts PyTorch aten ops into MLX graph nodes during export. While `aten.roll` decomposes into `index_select` + `arange` + `cat` operations, a native MLX implementation using `mlx::core::roll` would be more efficient (a single kernel instead of multiple ops).
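For intuition about what the decomposition costs, the fallback form amounts to slicing the tensor and re-concatenating the pieces along the rolled dimension. A minimal sketch in plain PyTorch (independent of the delegate):

```python
import torch

def roll_decomposed(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
    # What aten.roll computes along one dim, expressed as narrow + cat:
    # the last `shift` elements wrap around to the front.
    n = x.size(dim)
    shift = shift % n
    if shift == 0:
        return x.clone()
    return torch.cat(
        [x.narrow(dim, n - shift, shift), x.narrow(dim, 0, n - shift)], dim=dim
    )

x = torch.arange(6).reshape(2, 3)
assert torch.equal(roll_decomposed(x, 1, 1), torch.roll(x, 1, 1))
```

Each call materializes two slices plus a concatenated output, which is exactly the multi-op overhead a native `roll` kernel avoids.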
## Approach: New schema node + runtime

This requires a new `RollNode` because MLX has a dedicated `roll` function that is more efficient than the decomposed representation.

## Steps
- **Add node to `backends/mlx/serialization/schema.fbs`**

  ```
  table RollNode {
    x: Tid;
    out: Tid;
    shift: [IntOrVid]; // shift amounts per axis
    axes: [int];       // axes to roll
  }
  ```

  Add `RollNode` to the `OpNode` union (append only, do not reorder).
- **Regenerate serialization code**

  ```
  python backends/mlx/serialization/generate.py
  ```
- **Add a C++ runtime exec function in `backends/mlx/runtime/MLXInterpreter.h`**

  ```cpp
  void exec_RollNode(const RollNode& node) {
    auto x = get_tensor(node.x());
    std::vector<int> shifts;
    for (auto s : *node.shift()) {
      shifts.push_back(resolve_int_or_vid(s));
    }
    std::vector<int> axes;
    for (auto a : *node.axes()) {
      axes.push_back(a);
    }
    auto out = mlx::core::roll(x, shifts, axes, stream_);
    set_tensor(node.out(), out);
  }
  ```
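As a sanity check on the semantics the exec function relies on, `mlx::core::roll` is documented to follow NumPy's `np.roll` wrap-around behavior (worth confirming against the MLX docs for your version), including per-axis shifts in a single call:

```python
import numpy as np

x = np.arange(6).reshape(2, 3)

# Roll along axis 1 by 1: the last column wraps around to the front.
assert (np.roll(x, 1, axis=1) == [[2, 0, 1], [5, 3, 4]]).all()

# Per-axis shifts in one call, the form exec_RollNode passes to roll():
# equivalent to rolling each axis in turn.
once_each = np.roll(np.roll(x, 1, axis=0), 2, axis=1)
assert (np.roll(x, (1, 2), axis=(0, 1)) == once_each).all()
```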
- **Add handler in `backends/mlx/ops.py`**

  Simple unary-style ops (input tensor → output tensor) can be registered via the `_UNARY_OPS` table:

  ```python
  # In the _UNARY_OPS list, add:
  (torch.ops.aten.roll.default, RollNode, "aten.roll"),
  ```

  However, that table only works for ops with the standard unary signature `(x, out)`. Since `roll` takes additional parameters (`shifts`, `dims`), write a custom handler instead:

  ```python
  @REGISTRY.register(target=[torch.ops.aten.roll.default])
  def _roll_handler(P: MLXProgramBuilder, n: Node) -> Slot:
      args = P.args(n)
      require_args(args, 2, 3, "aten.roll")
      require_kwargs(P.kwargs(n), set(), "aten.roll")
      x = args[0]
      shifts = args[1]
      dims = args[2] if len(args) > 2 else []
      # Normalize shifts and dims to lists.
      if isinstance(shifts, int):
          shifts = [shifts]
      if isinstance(dims, int):
          dims = [dims]
      if not dims:
          # torch.roll with dims omitted flattens the tensor, rolls it, and
          # restores the shape -- that is NOT the same as rolling every axis.
          # Either emit the flatten/roll/reshape form here, or reject the case
          # so it falls back to the decomposition.
          raise NotImplementedError("aten.roll without dims is not supported")
      out = P.make_or_get_slot(n)
      P.emit(
          RollNode(
              x=P.slot_to_tid(x),
              out=P.slot_to_tid(out),
              shift=[P.to_int_or_vid(s) for s in shifts],
              axes=list(dims),
          )
      )
      return out
  ```
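  A quick check of `torch.roll`'s calling conventions that any handler must mirror; note in particular that omitting `dims` flattens the tensor first, which is not the same as rolling along every axis:

  ```python
  import torch

  x = torch.arange(6).reshape(2, 3)

  # Scalar shifts/dims behave like one-element lists.
  assert torch.equal(torch.roll(x, 1, 0), torch.roll(x, shifts=[1], dims=[0]))

  # With dims omitted, torch flattens, rolls, then restores the shape --
  # not the same thing as rolling along every axis.
  assert torch.equal(torch.roll(x, 1), x.flatten().roll(1).reshape(x.shape))
  ```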
- **Add test in `backends/mlx/test/test_ops.py`**

  This op doesn't fit the simple unary pattern, so create a custom test class:

  ```python
  class RollModel(nn.Module):
      def __init__(self, shifts: int, dims: int):
          super().__init__()
          self.shifts = shifts
          self.dims = dims

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          return torch.roll(x, shifts=self.shifts, dims=self.dims)


  @register_test
  class RollTest(OpTestCase):
      name = "roll"

      def __init__(self, shape: Tuple[int, ...], shifts: int, dims: int):
          self.shape = shape
          self.shifts = shifts
          self.dims = dims
          self.name = f"roll_shift{shifts}_dim{dims}"

      @classmethod
      def get_test_configs(cls) -> List["RollTest"]:
          return [
              cls(shape=(8,), shifts=2, dims=0),
              cls(shape=(4, 5), shifts=1, dims=0),
              cls(shape=(4, 5), shifts=-2, dims=1),
              cls(shape=(3, 4, 5), shifts=3, dims=2),
          ]

      def create_model(self) -> nn.Module:
          return RollModel(self.shifts, self.dims)

      def create_inputs(self) -> Tuple[torch.Tensor, ...]:
          return (torch.randn(self.shape),)
  ```
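  The negative-shift config above exercises `torch.roll`'s wrap-around: a shift of `-k` along a size-`n` dim is the same as a shift of `n - k`, which is worth keeping in mind when eyeballing diffs against the MLX output:

  ```python
  import torch

  x = torch.randn(4, 5)
  # Shift of -2 along a size-5 dim is the same as a shift of +3.
  assert torch.equal(torch.roll(x, -2, 1), torch.roll(x, 3, 1))
  ```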
## Running tests

```
python -m executorch.backends.mlx.test.run_all_tests -k roll
```
## References

- MLX C++: `array roll(const array &a, int shift, int axis, StreamOrDevice s = {})`
- PyTorch signature: `roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor`
- Use case: Swin Transformer shifted-window attention