diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index e62e93b3a20..8c079e785e3 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -13,6 +13,7 @@ on: - extension/audio/** - examples/models/parakeet/** - examples/models/voxtral_realtime/** + - examples/models/qwen3_5_moe/** workflow_dispatch: permissions: {} @@ -63,6 +64,61 @@ jobs: ./cmake-out/backends/mlx/test/multi_thread_test_runner echo "::endgroup::" + echo "::group::Run gated_delta_rule op tests" + ${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run -v + echo "::endgroup::" + + test-mlx-qwen35-moe: + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + with: + job-name: test-mlx-qwen35-moe + runner: macos-14-xlarge + python-version: "3.12" + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 90 + script: | + set -eux + + echo "::group::Install ExecuTorch" + ${CONDA_RUN} python install_executorch.py > /dev/null + echo "::endgroup::" + + ${CONDA_RUN} pip list + + echo "::group::Export Qwen 3.5 MoE (tiny model)" + ${CONDA_RUN} python -m executorch.examples.models.qwen3_5_moe.export \ + --tiny-test \ + --backend mlx \ + --qlinear 4w \ + --qlinear-group-size 32 \ + --output-dir /tmp/qwen35_moe_mlx_tiny + echo "::endgroup::" + + echo "::group::Check AsType node count" + ASTYPE_COUNT=$(${CONDA_RUN} python -m executorch.backends.mlx.pte_inspector \ + /tmp/qwen35_moe_mlx_tiny/model.pte --mlx-instructions 2>&1 | grep -c "AsTypeNode" || true) + echo "AsType nodes: ${ASTYPE_COUNT}" + if [ "$ASTYPE_COUNT" -gt 23 ]; then + echo "Failed: expected no more than 23 AsType nodes, got ${ASTYPE_COUNT}" + exit 1 + fi + echo "::endgroup::" + + echo "::group::Run Qwen 3.5 MoE inference" + OUTPUT=$(${CONDA_RUN} python -m executorch.examples.models.qwen3_5_moe.run \ + --pte /tmp/qwen35_moe_mlx_tiny/model.pte \ + --prompt-len 4 \ + --max-new-tokens 5 2>&1) + echo 
"$OUTPUT" + if echo "$OUTPUT" | grep -q "Generated token ids: \[167, 167, 81, 167, 81\]"; then + echo "Success: Qwen 3.5 MoE MLX export + inference completed with expected output" + else + echo "Failed: unexpected output (expected [167, 167, 81, 167, 81])" + exit 1 + fi + echo "::endgroup::" + backend-tester: strategy: fail-fast: false diff --git a/backends/mlx/CMakeLists.txt b/backends/mlx/CMakeLists.txt index 00e7c497b1c..43968d09b5d 100644 --- a/backends/mlx/CMakeLists.txt +++ b/backends/mlx/CMakeLists.txt @@ -247,6 +247,14 @@ add_subdirectory(${MLX_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/mlx) # Op logging option (for debugging) - OFF by default for performance option(ET_MLX_ENABLE_OP_LOGGING "Enable per-op logging in MLX delegate" OFF) +# Custom kernel execution - OFF by default for security. When enabled, +# MetalKernelNode can execute arbitrary Metal shader code embedded in .pte +# files. Only enable for trusted .pte sources. +option(ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION + "Allow MetalKernelNode to execute custom Metal shaders from .pte files" + ON +) + set(_mlx_backend__srcs ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp ) @@ -262,6 +270,13 @@ if(ET_MLX_ENABLE_OP_LOGGING) message(STATUS "MLX delegate op logging ENABLED") endif() +if(ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION) + target_compile_definitions( + mlxdelegate PRIVATE ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION + ) + message(STATUS "MLX delegate custom kernel execution ENABLED") +endif() + target_include_directories( mlxdelegate PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/runtime ) diff --git a/backends/mlx/builder/program_builder.py b/backends/mlx/builder/program_builder.py index 2add4f1b7a3..0892476fedd 100644 --- a/backends/mlx/builder/program_builder.py +++ b/backends/mlx/builder/program_builder.py @@ -23,6 +23,7 @@ import traceback from collections import defaultdict +from contextlib import contextmanager from dataclasses import dataclass from typing import Any, 
@contextmanager
def new_chain(self):
    """Open a fresh instruction chain and make it current for the ``with`` body.

    A new empty chain is appended to ``self._chains`` and its index becomes
    ``self._current_chain``. The previously current chain index is restored
    when the ``with`` block exits, including when the body raises.

    Usage:
        with P.new_chain() as chain_idx:
            P.emit(MulNode(...))  # goes to the new chain
        # P.emit() goes back to the previous chain

    Yields:
        The index of the newly created chain.
    """
    saved_chain = self._current_chain
    self._chains.append([])
    # The chain just appended is always the last one.
    fresh_idx = len(self._chains) - 1
    self._current_chain = fresh_idx
    try:
        yield fresh_idx
    finally:
        self._current_chain = saved_chain
""" named_data_store = NamedDataStore() @@ -971,6 +998,17 @@ def get_named_data_store(self) -> NamedDataStore: key=lambda x: self._slot_to_final_tid.get(x[1], 0), ) + # Free EP constants not used by the MLX graph to reduce peak memory. + used = set(self._constant_name_to_slot.keys()) + for ispec in self.ep.graph_signature.input_specs: + if ispec.arg.name in used and ispec.target is not None: + used.add(ispec.target) + + for d in (self.ep._state_dict, self.ep._constants): + for name in list(d.keys()): + if name not in used and isinstance(d[name], torch.Tensor): + del d[name] + logger.debug(f"Adding {len(entries)} constants to NamedDataStore...") for canonical_name, _slot in entries: tensor = self._find_constant_tensor(canonical_name) @@ -983,6 +1021,15 @@ def get_named_data_store(self) -> NamedDataStore: data=t, alignment=16, ) + + # Free the original tensor from the EP immediately. + # The contiguous copy is now serialized as bytes in the + # NamedDataStore — the EP reference is no longer needed. + # (It would be deleted by lowered_backend_module.py after + # preprocess() returns anyway.) 
def _find_constant_tensor(self, name: str) -> Optional[torch.Tensor]:
    """Return the constant registered under ``name``, or ``None`` if unresolved."""
    located = self._resolve_constant(name)
    if located is None:
        return None
    container, key = located
    return container[key]


def _delete_constant_tensor(self, name: str) -> None:
    """Drop the constant registered under ``name`` from whichever dict holds it.

    Does nothing when the name cannot be resolved.
    """
    located = self._resolve_constant(name)
    if not located:
        return
    container, key = located
    del container[key]


def _resolve_constant(self, name):
    """Locate the dict and key that hold constant ``name``.

    Lookup order: ``ep._state_dict``, ``ep._constants``, ``extra_constants``
    by the name itself; then, for input specs whose ``arg.name`` matches and
    whose ``target`` is set, ``ep._state_dict`` / ``ep._constants`` by target.

    Returns:
        ``(dict, key)`` for the first match, or ``None``.
    """
    ep = self.ep
    for container in (ep._state_dict, ep._constants, self.extra_constants):
        if name in container:
            return container, name

    # Fall back to resolving through the graph signature's input specs.
    for ispec in ep.graph_signature.input_specs:
        if ispec.arg.name != name or ispec.target is None:
            continue
        for container in (ep._state_dict, ep._constants):
            if ispec.target in container:
                return container, ispec.target
    return None
@dataclass(eq=False, frozen=True)
class Slot:
    """An allocated tensor or symbolic-int slot.

    Equality and hashing are identity-based rather than field-based: two
    Slots carrying the same (id_type, id_space, idx) — which can happen when
    the delete-as-you-go allocator recycles an idx — stay distinct in sets
    and dicts during build().
    """

    id_type: IdType
    id_space: IdSpace
    idx: Optional[int] = None

    def __hash__(self):
        # Identity hash, consistent with the identity-based __eq__ below.
        return id(self)

    def __eq__(self, other):
        return other is self
+ """ + if rhs_indices is not None: + b_sel = b[rhs_indices] + else: + b_sel = b + return torch.matmul(a, b_sel) + + +@torch.library.register_fake("mlx::gather_mm") +def gather_mm_fake( + a: Tensor, + b: Tensor, + rhs_indices: Optional[Tensor] = None, + lhs_indices: Optional[Tensor] = None, + sorted_indices: bool = False, +) -> Tensor: + # Matches MLX: output = indices.shape + [M, N] + # For simplicity, use matmul shape rules after gather + M = a.shape[-2] + N = b.shape[-1] + if rhs_indices is not None: + batch = rhs_indices.shape + else: + batch = b.shape[:-2] + return a.new_empty((*batch, M, N)) + + +@torch.library.custom_op("mlx::gather_qmm", mutates_args=()) +def gather_qmm( + x: Tensor, # [..., M, K] + w: Tensor, # [E, out, in_packed] + scales: Tensor, # [E, out, in//gs] + biases: Optional[Tensor] = None, # [E, out, in//gs] (affine mode) + rhs_indices: Optional[Tensor] = None, # Expert selection indices + lhs_indices: Optional[Tensor] = None, # Optional LHS gather indices + transpose: bool = True, + group_size: int = 32, + bits: int = 4, + mode: str = "affine", + sorted_indices: bool = False, +) -> Tensor: + """ + Gather quantized matrix multiply — matches mlx::core::gather_qmm semantics. + + Output shape = broadcast(lhs_indices, rhs_indices).shape + [M, N] + + For MoE: x=[N_tokens, 1, K], w=[E, out, K_packed], rhs_indices=[N_tokens] + → output=[N_tokens, 1, out]. Caller squeezes dim -2. 
+ """ + # Eager fallback: gather, dequantize, matmul + if rhs_indices is not None: + w_sel = w[rhs_indices] + s_sel = scales[rhs_indices] + b_sel = biases[rhs_indices] if biases is not None else None + else: + w_sel = w + s_sel = scales + b_sel = biases + + # Dequantize + w_float = w_sel.to(x.dtype) + s_expanded = s_sel.repeat_interleave(group_size, dim=-1) + if b_sel is not None: + b_expanded = b_sel.repeat_interleave(group_size, dim=-1) + w_dequant = w_float * s_expanded + b_expanded + else: + w_dequant = w_float * s_expanded + + if transpose: + w_dequant = w_dequant.transpose(-1, -2) + + return torch.matmul(x, w_dequant) + + +@torch.library.register_fake("mlx::gather_qmm") +def gather_qmm_fake( + x: Tensor, + w: Tensor, + scales: Tensor, + biases: Optional[Tensor] = None, + rhs_indices: Optional[Tensor] = None, + lhs_indices: Optional[Tensor] = None, + transpose: bool = True, + group_size: int = 32, + bits: int = 4, + mode: str = "affine", + sorted_indices: bool = False, +) -> Tensor: + # Matches MLX: output = indices.shape + [M, N] + M = x.shape[-2] + N = w.shape[-2] if transpose else w.shape[-1] + if rhs_indices is not None: + batch = rhs_indices.shape + else: + batch = w.shape[:-2] + return x.new_empty((*batch, M, N)) diff --git a/backends/mlx/llm/switch.py b/backends/mlx/llm/switch.py new file mode 100644 index 00000000000..28d408cbd71 --- /dev/null +++ b/backends/mlx/llm/switch.py @@ -0,0 +1,290 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +SwitchLinear — per-expert linear layer using mlx::gather_mm / mlx::gather_qmm. + +A self-contained expert linear primitive (like nn.Linear but with expert +selection via index-based gather). Mirrors mlx-lm's SwitchLinear / +QuantizedSwitchLinear but uses PyTorch nn.Module with torchao quantization +support. 
class SwitchLinear(nn.Module):
    """Per-expert linear layer using mlx::gather_mm / mlx::gather_qmm.

    Stores expert weights as nn.ModuleList of nn.Linear, so quantize_model_()
    naturally quantizes them. After quantization (or without it), call pack()
    to stack weights into 3D buffers for the MLX gather custom ops.

    Args:
        input_dims: Input feature dimension
        output_dims: Output feature dimension
        num_experts: Number of experts
        bias: Whether to use bias (default: False)
    """

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        num_experts: int,
        bias: bool = False,
    ):
        super().__init__()
        self.input_dims = input_dims
        self.output_dims = output_dims
        self.num_experts = num_experts
        self._packed = False
        self._is_quantized = False

        self.experts = nn.ModuleList(
            [nn.Linear(input_dims, output_dims, bias=bias) for _ in range(num_experts)]
        )

    def pack(self):
        """Stack per-expert weights into 3D buffers and delete the ModuleList.

        Must be called after quantization (if any) and before forward/export.
        Calling it again after a successful pack is a no-op.

        - Quantized (weight exposes ``qdata``): stacks ``qdata``, ``scale``
          and ``zero_point`` into ``[E, ...]`` buffers and records
          ``group_size`` from the weight's ``block_size`` metadata.
          ``forward`` passes these to mlx::gather_qmm with transpose=True.
        - Unquantized: stacks weight.data into [E, out, in], then
          pretransposes to [E, in, out] so gather_mm receives the layout it
          multiplies with directly (no runtime transpose needed).
        - Bias: when the experts carry a bias, the per-expert biases are
          stacked into an ``[E, out]`` buffer named ``bias``; otherwise
          ``bias`` is registered as ``None``.
        """
        if self._packed:
            return

        w0 = self.experts[0].weight
        self._is_quantized = hasattr(w0, "qdata")

        if self._is_quantized:
            _, metadata = w0.__tensor_flatten__()
            self.group_size = metadata["block_size"][-1]

            self.register_buffer(
                "qdata",
                torch.stack([e.weight.qdata for e in self.experts]),
            )
            self.register_buffer(
                "scale",
                torch.stack([e.weight.scale for e in self.experts]),
            )
            # NOTE(review): zero_point is forwarded as gather_qmm's `biases`
            # argument in forward(); confirm the op handler expects
            # zero-points rather than affine biases.
            self.register_buffer(
                "zero_point",
                torch.stack([e.weight.zero_point for e in self.experts]),
            )
        else:
            # Stack [E, out, in] then pretranspose to [E, in, out]
            stacked = torch.stack([e.weight.data for e in self.experts])
            self.register_buffer("weight", stacked.transpose(-1, -2).contiguous())

        # Keep the additive bias; previously it was dropped together with the
        # ModuleList, silently changing results for bias=True.
        if self.experts[0].bias is None:
            self.register_buffer("bias", None)
        else:
            self.register_buffer(
                "bias",
                torch.stack([e.bias.data for e in self.experts]),
            )

        del self.experts
        self._packed = True

    def forward(
        self,
        x: torch.Tensor,
        indices: torch.Tensor,
        sorted_indices: bool = False,
    ) -> torch.Tensor:
        """Apply the selected experts; no unsqueeze/squeeze is done here.

        The caller manages dimensions: ``x`` must already carry the matrix
        dims expected by the gather op (``[..., M, K]``), and the result keeps
        them (``[*indices.shape, M, out]`` for the shapes SwitchMLP passes).

        Args:
            x: Input activations, e.g. [N, 1, 1, D] with indices [N, top_k],
                or [N * top_k, 1, D] with indices [N * top_k].
            indices: Expert index for each gathered matmul.
            sorted_indices: Forwarded to the gather op as a hint that
                ``indices`` is sorted.

        Raises:
            RuntimeError: if pack() has not been called.
        """
        if not self._packed:
            raise RuntimeError("SwitchLinear.pack() must be called before forward().")

        if self._is_quantized:
            out = torch.ops.mlx.gather_qmm(
                x,
                self.qdata,
                self.scale,
                biases=self.zero_point,
                rhs_indices=indices,
                transpose=True,
                group_size=self.group_size,
                sorted_indices=sorted_indices,
            )
        else:
            out = torch.ops.mlx.gather_mm(
                x,
                self.weight,
                rhs_indices=indices,
                sorted_indices=sorted_indices,
            )

        if self.bias is not None:
            # [*indices.shape, out] -> [*indices.shape, 1, out] to broadcast
            # over the M (row) dim of the gather output.
            out = out + self.bias[indices].unsqueeze(-2)
        return out
class SwitchMLP(nn.Module):
    """Gated mixture-of-experts MLP built from SwitchLinear projections.

    Combines gate, up and down projections with a gated activation and
    optional sorting of tokens by expert. Any gated activation can be
    supplied (SwiGLU, GeGLU, ReGLU, etc.).

    With fuse_gate_up=True the gate and up projections share one
    SwitchLinear of output dim 2*intermediate_size, so a forward pass issues
    two gather calls (fused gate+up, then down) instead of three. The fused
    output is split with a tensor slice.

    Args:
        hidden_size: Model hidden dimension (input/output of MLP)
        intermediate_size: MLP intermediate dimension (per expert)
        num_experts: Number of experts
        activation: Gating activation function (default: F.silu for SwiGLU)
        bias: Whether expert linears use bias
        fuse_gate_up: Fuse gate and up projections into a single SwitchLinear
            (default: False).
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_experts: int,
        activation=None,
        bias: bool = False,
        fuse_gate_up: bool = False,
    ):
        super().__init__()
        self.activation = nn.functional.silu if activation is None else activation
        self.num_experts = num_experts
        self.intermediate_size = intermediate_size
        self.fuse_gate_up = fuse_gate_up

        if not fuse_gate_up:
            self.gate_proj = SwitchLinear(
                hidden_size, intermediate_size, num_experts, bias=bias
            )
            self.up_proj = SwitchLinear(
                hidden_size, intermediate_size, num_experts, bias=bias
            )
        else:
            self.gate_up_proj = SwitchLinear(
                hidden_size, 2 * intermediate_size, num_experts, bias=bias
            )
        self.down_proj = SwitchLinear(
            intermediate_size, hidden_size, num_experts, bias=bias
        )

    def forward(
        self,
        x: torch.Tensor,
        expert_weights: torch.Tensor,
        expert_indices: torch.Tensor,
        top_k: int,
        sort_experts: bool = False,
    ) -> torch.Tensor:
        """Run the gated MoE MLP.

        Args:
            x: Input activations [N, D]
            expert_weights: Routing weights [N, top_k] (already softmaxed)
            expert_indices: Expert assignments [N, top_k]
            top_k: Number of experts per token
            sort_experts: Sort the (token, expert) pairs by expert index
                before the gather calls and undo the sort afterwards.

        Returns:
            Output tensor [N, D]
        """
        num_tokens = x.shape[0]

        if sort_experts:
            flat = expert_indices.flatten()
            perm = flat.argsort().to(torch.int32)
            inverse_perm = perm.argsort().to(torch.int32)
            routed_idx = flat[perm].to(torch.int32)
            # Each flattened slot i belongs to token i // top_k.
            token_rows = (perm // top_k).to(torch.int64)
            mlp_in = x[token_rows].unsqueeze(-2)
        else:
            inverse_perm = None
            mlp_in = x.unsqueeze(-2).unsqueeze(-2)
            routed_idx = expert_indices

        if self.fuse_gate_up:
            fused = self.gate_up_proj(mlp_in, routed_idx, sorted_indices=sort_experts)
            gate = fused[..., : self.intermediate_size]
            up = fused[..., self.intermediate_size :]
        else:
            gate = self.gate_proj(mlp_in, routed_idx, sorted_indices=sort_experts)
            up = self.up_proj(mlp_in, routed_idx, sorted_indices=sort_experts)

        hidden = self.activation(gate) * up
        down = self.down_proj(hidden, routed_idx, sorted_indices=sort_experts)
        down = down.squeeze(-2)

        if sort_experts:
            # Undo the expert sort, then regroup the top_k results per token.
            down = down[inverse_perm].reshape(num_tokens, top_k, -1)

        weighted = down * expert_weights.unsqueeze(-1)
        return weighted.sum(dim=-2)


def pack_all_switch_linears(model: nn.Module) -> int:
    """Call pack() on every SwitchLinear found in ``model``.

    Args:
        model: The model to pack

    Returns:
        Number of SwitchLinear modules packed
    """
    packed = 0
    for module in model.modules():
        if not isinstance(module, SwitchLinear):
            continue
        module.pack()
        packed += 1
    if packed:
        logger.info(f"Packed {packed} SwitchLinear modules")
    return packed
+ +"""Model-specific MLX custom op implementations.""" diff --git a/backends/mlx/model_ops/gated_delta_rule.py b/backends/mlx/model_ops/gated_delta_rule.py new file mode 100644 index 00000000000..ead73b00ff5 --- /dev/null +++ b/backends/mlx/model_ops/gated_delta_rule.py @@ -0,0 +1,565 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +Gated delta rule custom op and pattern handler for MLX backend. + +This module defines: +1. mlx::gated_delta_rule custom op with mutates_args=("state",) +2. GatedDeltaRuleHandler pattern — matches the auto_functionalized_v2 + wrapper that edge decomposition inserts for mutating ops + +After edge decomposition the graph looks like: + auto_func = auto_functionalized_v2(mlx.gated_delta_rule, q=..., ...) + getitem = auto_func[0] # output tensor (USER_OUTPUT) + getitem_1 = auto_func[1] # mutated state (BUFFER_MUTATION) + return (getitem_1, getitem) + +The pattern handler uses HEAD = getitem[1] (same as ETKVCacheUpdateHandler) +because the partitioner needs the BUFFER_MUTATION node as a proper subgraph +output. getitem[0] is left for the normal _getitem_handler to process. +""" + +from __future__ import annotations + +from typing import List, Optional + +import torch +from torch import Tensor +from torch.fx.node import Node + + +@torch.library.custom_op("mlx::gated_delta_rule", mutates_args=("state",)) +def gated_delta_rule( + q: Tensor, # [B, T, Hk, Dk] + k: Tensor, # [B, T, Hk, Dk] + v: Tensor, # [B, T, Hv, Dv] + g: Tensor, # [B, T, Hv] — decay gate + beta: Tensor, # [B, T, Hv] — update gate + state: Tensor, # [B, Hv, Dv, Dk] — recurrent state (MUTATED in place) + use_custom_kernel: bool = True, +) -> Tensor: + """ + Gated delta rule recurrence — sequential scan over T. 
+ + Returns: + output: [B, T, Hv, Dv] + """ + B, T_len, Hk, Dk = q.shape + Hv, Dv = v.shape[-2:] + + s = state.clone() + + ys = [] + for t in range(T_len): + q_t = q[:, t] + k_t = k[:, t] + v_t = v[:, t] + g_t = g[:, t] + beta_t = beta[:, t] + + s = s * g_t[:, :, None, None] + kv_mem = (s * k_t[:, :, None, :]).sum(dim=-1) + delta = (v_t - kv_mem) * beta_t[:, :, None] + s = s + k_t[:, :, None, :] * delta[:, :, :, None] + y_t = (s * q_t[:, :, None, :]).sum(dim=-1) + ys.append(y_t) + + state.copy_(s) + + return torch.stack(ys, dim=1) + + +@torch.library.register_fake("mlx::gated_delta_rule") +def gated_delta_rule_fake( + q: Tensor, + k: Tensor, + v: Tensor, + g: Tensor, + beta: Tensor, + state: Tensor, + use_custom_kernel: bool = True, +) -> Tensor: + B, T = q.shape[:2] + Hv, Dv = v.shape[-2:] + return v.new_empty(B, T, Hv, Dv) + + +from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type +from executorch.backends.mlx.builder.op_registry import PatternHandler, REGISTRY +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + AddNode, + ExpandDimsNode, + IdCopyNode, + IntOrVid, + MetalKernelNode, + MultiplyNode, + ScanNode, + SubtractNode, + SumNode, +) +from torch.export.exported_program import ExportedProgram + + +class GatedDeltaRuleHandler(PatternHandler): + """ + Pattern for gated delta rule state mutation. + + HEAD = getitem[1] (BUFFER_MUTATION — mutated state) + BODY = [auto_func_node, getitem_0] + + Both getitem nodes are handled by this pattern to prevent + _getitem_handler from calling slot_map on auto_func_node + (which would create a slot on the deferred body node and + fail _verify_build). The HEAD handler sets slots for both. 
class GatedDeltaRuleHandler(PatternHandler):
    """
    Pattern for gated delta rule state mutation.

    HEAD = getitem[1] (BUFFER_MUTATION — mutated state)
    BODY = [auto_func_node, getitem_0]

    Both getitem nodes are handled by this pattern to prevent
    _getitem_handler from calling slot_map on auto_func_node
    (which would create a slot on the deferred body node and
    fail _verify_build). The HEAD handler sets slots for both.
    """

    def __init__(
        self,
        head: Node,
        body: List[Node],
        auto_func_node: Node,
        getitem_0: Node,
        q: Node,
        k: Node,
        v: Node,
        g: Node,
        beta: Node,
        state: Node,
    ):
        """Record the matched nodes; see maybe_create for how they are found."""
        super().__init__(head, body)
        self.auto_func_node = auto_func_node
        self.getitem_0 = getitem_0
        self.q_node = q
        self.k_node = k
        self.v_node = v
        self.g_node = g
        self.beta_node = beta
        self.state_node = state

    @staticmethod
    def _is_auto_func_gated_delta_rule(node: Node) -> bool:
        """Return True if ``node`` is an auto_functionalized call wrapping mlx gated_delta_rule.

        Matching is string-based: the target must mention
        "auto_functionalized" and the first positional arg must mention both
        "gated_delta_rule" and "mlx".
        """
        if node.op != "call_function":
            return False
        if "auto_functionalized" not in str(node.target):
            return False
        if len(node.args) < 1:
            return False
        func_str = str(node.args[0]) if node.args[0] else ""
        return "gated_delta_rule" in func_str and "mlx" in func_str

    @classmethod
    def maybe_create(
        cls, ep: ExportedProgram, head: Node
    ) -> Optional["GatedDeltaRuleHandler"]:
        """
        Match HEAD = getitem[1] from auto_functionalized_v2(gated_delta_rule).

        Returns None unless ``head`` is ``getitem(auto_func, 1)``, the
        wrapper carries q/k/v/g/beta kwargs and a non-empty ``_all_bases``,
        and a sibling ``getitem(auto_func, 0)`` exists.
        """
        if head.op != "call_function" or "getitem" not in str(head.target):
            return None
        if len(head.args) < 2 or head.args[1] != 1:
            return None
        if not isinstance(head.args[0], Node):
            return None

        auto_func_node = head.args[0]
        if not cls._is_auto_func_gated_delta_rule(auto_func_node):
            return None

        kwargs = auto_func_node.kwargs
        q = kwargs.get("q")
        k = kwargs.get("k")
        v = kwargs.get("v")
        g = kwargs.get("g")
        beta = kwargs.get("beta")
        all_bases = kwargs.get("_all_bases", [])

        if not all([q, k, v, g, beta]) or not all_bases:
            return None

        # The mutated state is taken from the first entry of _all_bases.
        state = all_bases[0]

        # Find getitem[0] (output tensor) among auto_func's users
        getitem_0 = None
        for user in auto_func_node.users:
            if (
                user.op == "call_function"
                and "getitem" in str(user.target)
                and len(user.args) >= 2
                and user.args[1] == 0
            ):
                getitem_0 = user
                break

        if getitem_0 is None:
            return None

        return cls(
            head=head,
            body=[auto_func_node, getitem_0],
            auto_func_node=auto_func_node,
            getitem_0=getitem_0,
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            state=state,
        )

    def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot:
        """Emit the recurrence for HEAD and return the slot bound to the mutated state.

        Uses the fused Metal kernel when the op's ``use_custom_kernel`` kwarg
        is truthy (default), otherwise the ScanNode decomposition.

        Raises:
            ValueError: on the Metal path if Dk is not a multiple of 32 or
                Hv is not a positive multiple of Hk.
        """
        assert n == self.head

        q_meta = self.q_node.meta["val"]
        Dk = int(q_meta.shape[-1])

        # Read use_custom_kernel from the op's kwargs in the graph
        use_custom_kernel = self.auto_func_node.kwargs.get("use_custom_kernel", True)

        if use_custom_kernel:
            if Dk % 32 != 0:
                raise ValueError(
                    f"MetalKernelNode requires Dk to be a multiple of 32, got Dk={Dk}. "
                    f"Set use_custom_kernel=False to use the ScanNode fallback."
                )
            # The kernel computes hk_idx = hv_idx / (Hv / Hk); anything other
            # than a positive integer ratio divides by zero or indexes past
            # the q/k heads.
            Hk = int(q_meta.shape[-2])
            Hv = int(self.v_node.meta["val"].shape[-2])
            if Hk <= 0 or Hv % Hk != 0:
                raise ValueError(
                    f"MetalKernelNode requires Hv to be a positive multiple of Hk, "
                    f"got Hv={Hv}, Hk={Hk}."
                )
            return self._emit_metal_kernel(P, n)
        return self._emit_scan(P, n)

    def _emit_metal_kernel(self, P: MLXProgramBuilder, n: Node) -> Slot:
        """Emit a fused MetalKernelNode for the gated delta recurrence.

        B and T are read at runtime via SymSizeNode; Hk, Hv, Dk and Dv are
        baked in as template arguments. Binds HEAD to the kernel's
        ``state_out`` and getitem[0] to ``y``.
        """
        from executorch.backends.mlx.serialization.mlx_graph_schema import (
            FullNode,
            MultiplyIntNode,
            SymSizeNode,
        )

        q_slot, k_slot, v_slot, g_slot, beta_slot, state_slot = P.slot_map(
            [
                self.q_node,
                self.k_node,
                self.v_node,
                self.g_node,
                self.beta_node,
                self.state_node,
            ]
        )

        # Extract shapes from metadata
        q_meta = self.q_node.meta["val"]
        v_meta = self.v_node.meta["val"]
        _, _, Hk, Dk = q_meta.shape
        Hv, Dv = v_meta.shape[-2:]
        dtype_int = torch_dtype_to_scalar_type(q_meta.dtype)

        # B and T are potentially dynamic — extract as runtime Vids via SymSizeNode
        _, b_val = P.make_tmp_value_slot()
        P.emit(
            SymSizeNode(
                a=P.slot_to_tid(q_slot),
                dim=0,
                out=P.slot_to_vid(b_val),
            )
        )
        _, t_val = P.make_tmp_value_slot()
        P.emit(
            SymSizeNode(
                a=P.slot_to_tid(q_slot),
                dim=1,
                out=P.slot_to_vid(t_val),
            )
        )

        # grid[2] = B * Hv (computed at runtime)
        _, b_times_hv = P.make_tmp_value_slot()
        P.emit(
            MultiplyIntNode(
                a=P.to_int_or_vid(b_val),
                b=IntOrVid.from_literal(int(Hv)),
                out=P.slot_to_vid(b_times_hv),
            )
        )

        # T as a 0-D int32 tensor for the kernel input (created at runtime from Vid)
        _, t_tensor = P.make_tmp_slot()
        P.emit(
            FullNode(
                out=P.slot_to_tid(t_tensor),
                shape=[],
                v=P.to_float_or_vid(t_val),
                scalar_type=torch_dtype_to_scalar_type(torch.int32),
            )
        )

        # B as IntOrVid for output shapes
        b_iov = P.to_int_or_vid(b_val)
        t_iov = P.to_int_or_vid(t_val)

        # Output slot for y — use existing IO slot if getitem_0 is a graph output,
        # otherwise create a new temp slot.
        out = P.make_or_get_slot(self.getitem_0)

        # Output slot for state_out (carry)
        _, carry = P.make_tmp_slot()

        # Metal kernel source (non-vectorized, no mask variant from mlx-lm).
        # The state is accumulated in float and cast back to InT on store.
        source = """
            auto n = thread_position_in_grid.z;
            auto b_idx = n / Hv;
            auto hv_idx = n % Hv;
            auto hk_idx = hv_idx / (Hv / Hk);
            constexpr int n_per_t = Dk / 32;

            // q, k: [B, T, Hk, Dk]
            auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk;
            auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk;

            // v, y: [B, T, Hv, Dv]
            auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv;
            y += b_idx * T * Hv * Dv + hv_idx * Dv;

            auto dk_idx = thread_position_in_threadgroup.x;
            auto dv_idx = thread_position_in_grid.y;

            // state_in, state_out: [B, Hv, Dv, Dk]
            auto i_state = state_in + (n * Dv + dv_idx) * Dk;
            auto o_state = state_out + (n * Dv + dv_idx) * Dk;

            float state[n_per_t];
            for (int i = 0; i < n_per_t; ++i) {
              auto s_idx = n_per_t * dk_idx + i;
              state[i] = static_cast<float>(i_state[s_idx]);
            }

            // g: [B, T, Hv]
            auto g_ = g + b_idx * T * Hv;
            auto beta_ = beta + b_idx * T * Hv;

            for (int t = 0; t < T; ++t) {
              float kv_mem = 0.0f;
              for (int i = 0; i < n_per_t; ++i) {
                auto s_idx = n_per_t * dk_idx + i;
                state[i] = state[i] * g_[hv_idx];
                kv_mem += state[i] * k_[s_idx];
              }
              kv_mem = simd_sum(kv_mem);

              auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx];

              float out = 0.0f;
              for (int i = 0; i < n_per_t; ++i) {
                auto s_idx = n_per_t * dk_idx + i;
                state[i] = state[i] + k_[s_idx] * delta;
                out += state[i] * q_[s_idx];
              }
              out = simd_sum(out);
              if (thread_index_in_simdgroup == 0) {
                y[dv_idx] = static_cast<InT>(out);
              }
              // Increment data pointers to next time step
              q_ += Hk * Dk;
              k_ += Hk * Dk;
              v_ += Hv * Dv;
              y += Hv * Dv;
              g_ += Hv;
              beta_ += Hv;
            }
            for (int i = 0; i < n_per_t; ++i) {
              auto s_idx = n_per_t * dk_idx + i;
              o_state[s_idx] = static_cast<InT>(state[i]);
            }
        """

        # Output shapes: y=[B,T,Hv,Dv], state_out=[B,Hv,Dv,Dk]
        # B and T are dynamic (Vids), Hv/Dv/Dk are static literals
        output_shapes_flat = [
            # y shape
            b_iov,
            t_iov,
            IntOrVid.from_literal(int(Hv)),
            IntOrVid.from_literal(int(Dv)),
            # state_out shape
            b_iov,
            IntOrVid.from_literal(int(Hv)),
            IntOrVid.from_literal(int(Dv)),
            IntOrVid.from_literal(int(Dk)),
        ]
        output_shape_lengths = [4, 4]

        P.emit(
            MetalKernelNode(
                name="gated_delta_step",
                source=source,
                inputs=[
                    P.slot_to_tid(q_slot),
                    P.slot_to_tid(k_slot),
                    P.slot_to_tid(v_slot),
                    P.slot_to_tid(g_slot),
                    P.slot_to_tid(beta_slot),
                    P.slot_to_tid(state_slot),
                    P.slot_to_tid(t_tensor),
                ],
                outputs=[P.slot_to_tid(out), P.slot_to_tid(carry)],
                grid=[
                    IntOrVid.from_literal(32),
                    IntOrVid.from_literal(int(Dv)),
                    P.to_int_or_vid(b_times_hv),
                ],
                threadgroup=[
                    IntOrVid.from_literal(32),
                    IntOrVid.from_literal(4),
                    IntOrVid.from_literal(1),
                ],
                input_names=["q", "k", "v", "g", "beta", "state_in", "T"],
                output_names=["y", "state_out"],
                output_shapes_flat=output_shapes_flat,
                output_shape_lengths=output_shape_lengths,
                output_dtypes=[dtype_int, dtype_int],
                template_arg_names=["InT", "Dk", "Dv", "Hk", "Hv"],
                template_arg_kinds=[2, 0, 0, 0, 0],  # 2=dtype, 0=int
                template_arg_values=[dtype_int, int(Dk), int(Dv), int(Hk), int(Hv)],
            )
        )

        # HEAD is getitem[1] = mutated state → bind to carry
        P.set_slot(n, carry)
        P.set_slot(self.getitem_0, out)

        return carry

    def _emit_scan(self, P: MLXProgramBuilder, n: Node) -> Slot:
        """Emit ScanNode decomposition of the gated delta recurrence.

        The per-step body is emitted into its own instruction chain and
        referenced by the ScanNode, which scans q/k/v/g/beta along axis 1
        with the state as carry. Binds HEAD to the carry and getitem[0] to
        the scan output.
        """

        q_slot, k_slot, v_slot, g_slot, beta_slot, state_slot = P.slot_map(
            [
                self.q_node,
                self.k_node,
                self.v_node,
                self.g_node,
                self.beta_node,
                self.state_node,
            ]
        )

        # Carry needs a writable temp slot
        _, carry = P.make_tmp_slot()
        P.emit(IdCopyNode(x=P.slot_to_tid(state_slot), out=P.slot_to_tid(carry)))

        # Sliced temp slots for per-step inputs
        _, q_s = P.make_tmp_slot()
        _, k_s = P.make_tmp_slot()
        _, v_s = P.make_tmp_slot()
        _, g_s = P.make_tmp_slot()
        _, beta_s = P.make_tmp_slot()

        # Output slot for the recurrence output.
        out = P.make_or_get_slot(self.getitem_0)

        # Body temp slots
        _, t0 = P.make_tmp_slot()
        _, t1 = P.make_tmp_slot()
        _, t2 = P.make_tmp_slot()

        with P.new_chain() as body_idx:
            # state = state * g_t[:, :, None, None]
            P.emit(ExpandDimsNode(x=P.slot_to_tid(g_s), out=P.slot_to_tid(t0), axis=-1))
            P.emit(ExpandDimsNode(x=P.slot_to_tid(t0), out=P.slot_to_tid(t0), axis=-1))
            P.emit(
                MultiplyNode(
                    a=P.slot_to_tid(carry),
                    b=P.slot_to_tid(t0),
                    out=P.slot_to_tid(carry),
                )
            )

            # kv_mem = (state * k_t[:, :, None, :]).sum(-1)
            P.emit(ExpandDimsNode(x=P.slot_to_tid(k_s), out=P.slot_to_tid(t0), axis=-2))
            P.emit(
                MultiplyNode(
                    a=P.slot_to_tid(carry), b=P.slot_to_tid(t0), out=P.slot_to_tid(t1)
                )
            )
            P.emit(SumNode(x=P.slot_to_tid(t1), out=P.slot_to_tid(t1), axes=[-1]))

            # delta = (v_t - kv_mem) * beta_t[:, :, None]
            P.emit(
                SubtractNode(
                    a=P.slot_to_tid(v_s), b=P.slot_to_tid(t1), out=P.slot_to_tid(t1)
                )
            )
            P.emit(
                ExpandDimsNode(x=P.slot_to_tid(beta_s), out=P.slot_to_tid(t2), axis=-1)
            )
            P.emit(
                MultiplyNode(
                    a=P.slot_to_tid(t1), b=P.slot_to_tid(t2), out=P.slot_to_tid(t1)
                )
            )

            # state = state + k[:,:,None,:] * delta[:,:,:,None]
            P.emit(ExpandDimsNode(x=P.slot_to_tid(k_s), out=P.slot_to_tid(t2), axis=-2))
            P.emit(ExpandDimsNode(x=P.slot_to_tid(t1), out=P.slot_to_tid(t1), axis=-1))
            P.emit(
                MultiplyNode(
                    a=P.slot_to_tid(t2), b=P.slot_to_tid(t1), out=P.slot_to_tid(t2)
                )
            )
            P.emit(
                AddNode(
                    a=P.slot_to_tid(carry),
                    b=P.slot_to_tid(t2),
                    out=P.slot_to_tid(carry),
                )
            )

            # y_t = (state * q_t[:,:,None,:]).sum(-1)
            P.emit(ExpandDimsNode(x=P.slot_to_tid(q_s), out=P.slot_to_tid(t0), axis=-2))
            P.emit(
                MultiplyNode(
                    a=P.slot_to_tid(carry), b=P.slot_to_tid(t0), out=P.slot_to_tid(t0)
                )
            )
            P.emit(SumNode(x=P.slot_to_tid(t0), out=P.slot_to_tid(out), axes=[-1]))

        # Emit the ScanNode
        P.emit(
            ScanNode(
                body_chain_idx=body_idx,
                scan_axis=1,
                originals=[
                    P.slot_to_tid(s)
                    for s in [q_slot, k_slot, v_slot, g_slot, beta_slot]
                ],
                sliced=[P.slot_to_tid(s) for s in [q_s, k_s, v_s, g_s, beta_s]],
                outputs=[P.slot_to_tid(out)],
                carry=[P.slot_to_tid(carry)],
            )
        )

        # HEAD is getitem[1] = mutated state → bind to carry
        P.set_slot(n, carry)

        # Set getitem[0] slot → output tensor (for downstream computation)
        P.set_slot(self.getitem_0, out)

        return carry


# Guards against registering the pattern more than once per process.
_registered = False


def register():
    """Register GatedDeltaRuleHandler under the name GATED_DELTA_RULE (idempotent)."""
    global _registered
    if _registered:
        return
    REGISTRY.register_pattern(name="GATED_DELTA_RULE")(GatedDeltaRuleHandler)
    _registered = True


register()
+ +Usage: + # Run all configs: + python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run + + # Run with verbose output: + python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run -v + + # Rebuild C++ runner first: + python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run --rebuild +""" + +from typing import List, Tuple + +import executorch.backends.mlx.model_ops.gated_delta_rule # noqa: F401 + +import torch +import torch.nn as nn + +from executorch.backends.mlx.test.test_utils import OpTestCase + + +class GatedDeltaRuleModel(nn.Module): + """Model using mlx::gated_delta_rule for sequential recurrence.""" + + def __init__( + self, + batch_size: int, + num_heads: int, + head_dim: int, + value_dim: int, + use_custom_kernel: bool = True, + ): + super().__init__() + self.use_custom_kernel = use_custom_kernel + self.register_buffer( + "state", + torch.zeros(batch_size, num_heads, value_dim, head_dim), + ) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + ) -> torch.Tensor: + return torch.ops.mlx.gated_delta_rule( + q, k, v, g, beta, self.state, use_custom_kernel=self.use_custom_kernel + ) + + +class GatedDeltaRuleGQAModel(nn.Module): + """Model with Hk < Hv (GQA-style), matching real Qwen 3.5 config. + + Q and K have num_k_heads heads, V has num_v_heads heads. + Q and K are repeat_interleaved to match num_v_heads before the custom op call, + matching the pattern in _exportable_gated_delta_net_forward. 
+ """ + + def __init__( + self, + batch_size: int, + num_k_heads: int, + num_v_heads: int, + head_dim: int, + value_dim: int, + use_custom_kernel: bool = True, + ): + super().__init__() + assert num_v_heads % num_k_heads == 0 + self.num_k_heads = num_k_heads + self.num_v_heads = num_v_heads + self.head_repeat = num_v_heads // num_k_heads + self.use_custom_kernel = use_custom_kernel + self.register_buffer( + "state", + torch.zeros(batch_size, num_v_heads, value_dim, head_dim), + ) + + def forward( + self, + q: torch.Tensor, # [B, T, Hk, Dk] + k: torch.Tensor, # [B, T, Hk, Dk] + v: torch.Tensor, # [B, T, Hv, Dv] + g: torch.Tensor, # [B, T, Hv] + beta: torch.Tensor, # [B, T, Hv] + ) -> torch.Tensor: + if self.head_repeat > 1: + q = q.repeat_interleave(self.head_repeat, dim=2) + k = k.repeat_interleave(self.head_repeat, dim=2) + return torch.ops.mlx.gated_delta_rule( + q, k, v, g, beta, self.state, use_custom_kernel=self.use_custom_kernel + ) + + +class GatedDeltaRuleMultiStepModel(nn.Module): + """Model that calls gated_delta_rule TWICE to test state carry-forward. + + The second call reads the state mutated by the first call. + If state doesn't persist, out2 would be identical to running with + zero state — which is wrong. 
+ """ + + def __init__( + self, + batch_size: int, + num_heads: int, + head_dim: int, + value_dim: int, + use_custom_kernel: bool = False, + ): + super().__init__() + self.use_custom_kernel = use_custom_kernel + self.register_buffer( + "state", + torch.zeros(batch_size, num_heads, value_dim, head_dim), + ) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + ) -> torch.Tensor: + # Step 1: process inputs, mutates self.state + out1 = torch.ops.mlx.gated_delta_rule( + q, k, v, g, beta, self.state, use_custom_kernel=self.use_custom_kernel + ) + # Step 2: same inputs, but state carries from step 1 + out2 = torch.ops.mlx.gated_delta_rule( + q, k, v, g, beta, self.state, use_custom_kernel=self.use_custom_kernel + ) + # Return concatenated so we can verify both outputs + return torch.cat([out1, out2], dim=1) + + +class GatedDeltaRuleTest(OpTestCase): + """Test case for mlx::gated_delta_rule (ScanNode and MetalKernelNode).""" + + name = "gated_delta_rule" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + batch_size: int = 1, + seq_len: int = 4, + num_heads: int = 2, + head_dim: int = 16, + value_dim: int = 16, + dtype: torch.dtype = torch.float32, + rtol: float = 1e-4, + atol: float = 1e-4, + use_custom_kernel: bool = False, + ): + self.batch_size = batch_size + self.seq_len = seq_len + self.num_heads = num_heads + self.head_dim = head_dim + self.value_dim = value_dim + self.dtype = dtype + self.rtol = rtol + self.atol = atol + self.use_custom_kernel = use_custom_kernel + + parts = [ + "gated_delta_rule", + f"b{batch_size}", + f"t{seq_len}", + f"h{num_heads}", + f"dk{head_dim}", + f"dv{value_dim}", + ] + if dtype != torch.float32: + parts.append(str(dtype).split(".")[-1]) + parts.append("kernel" if use_custom_kernel else "scan") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["GatedDeltaRuleTest"]: + configs = [] + # Small dims (Dk not multiple of 32) — scan-only 
+ for use_kernel in [False]: + configs.append(cls(use_custom_kernel=use_kernel)) + configs.append(cls(seq_len=1, use_custom_kernel=use_kernel)) + configs.append( + cls( + seq_len=8, + num_heads=4, + head_dim=32, + value_dim=32, + use_custom_kernel=use_kernel, + ) + ) + configs.append( + cls( + dtype=torch.bfloat16, + rtol=0.05, + atol=0.15, + use_custom_kernel=use_kernel, + ) + ) + # Dims with Dk multiple of 32 — both scan and custom kernel + for use_kernel in [False, True]: + configs.append( + cls( + num_heads=2, head_dim=64, value_dim=64, use_custom_kernel=use_kernel + ) + ) + configs.append( + cls( + seq_len=1, + num_heads=2, + head_dim=64, + value_dim=64, + use_custom_kernel=use_kernel, + ) + ) + configs.append( + cls( + seq_len=8, + num_heads=4, + head_dim=64, + value_dim=64, + use_custom_kernel=use_kernel, + ) + ) + configs.append( + cls( + num_heads=2, + head_dim=64, + value_dim=64, + dtype=torch.bfloat16, + rtol=0.05, + atol=0.15, + use_custom_kernel=use_kernel, + ) + ) + configs.append( + cls( + num_heads=2, + head_dim=128, + value_dim=128, + use_custom_kernel=use_kernel, + ) + ) + return configs + + def create_model(self) -> nn.Module: + model = GatedDeltaRuleModel( + self.batch_size, + self.num_heads, + self.head_dim, + self.value_dim, + use_custom_kernel=self.use_custom_kernel, + ) + return model.to(self.dtype) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # Scale q and k by 1/√Dk to keep dot products in a reasonable range. + # Without this, bf16 accumulation diverges at larger head dims (dk64+) + # because sum-of-64 products grows to ~O(√Dk) per step, compounding + # exponentially through the recurrence. 
+ scale = self.head_dim**-0.5 + q = ( + torch.randn( + self.batch_size, + self.seq_len, + self.num_heads, + self.head_dim, + dtype=self.dtype, + ) + * scale + ) + k = ( + torch.randn( + self.batch_size, + self.seq_len, + self.num_heads, + self.head_dim, + dtype=self.dtype, + ) + * scale + ) + v = torch.randn( + self.batch_size, + self.seq_len, + self.num_heads, + self.value_dim, + dtype=self.dtype, + ) + g = torch.randn( + self.batch_size, self.seq_len, self.num_heads, dtype=self.dtype + ).sigmoid() + beta = torch.randn( + self.batch_size, self.seq_len, self.num_heads, dtype=self.dtype + ).sigmoid() + return (q, k, v, g, beta) + + +class GatedDeltaRuleDynamicSeqTest(OpTestCase): + """Test gated_delta_rule with dynamic seq_len. + + Exports with seq_len=export_seq_len using dynamic shapes, then runs + inference with seq_len=test_seq_len. Verifies the MetalKernelNode/ScanNode + handles runtime sequence lengths different from the trace-time value. + """ + + name = "gated_delta_rule_dynamic" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + export_seq_len: int = 4, + test_seq_len: int = 1, + num_heads: int = 2, + head_dim: int = 64, + value_dim: int = 64, + dtype: torch.dtype = torch.float32, + rtol: float = 1e-4, + atol: float = 1e-4, + use_custom_kernel: bool = False, + ): + self.batch_size = 1 + self.export_seq_len = export_seq_len + self.test_seq_len = test_seq_len + self.seq_len = export_seq_len # used by create_inputs (export tracing) + self.num_heads = num_heads + self.head_dim = head_dim + self.value_dim = value_dim + self.dtype = dtype + self.rtol = rtol + self.atol = atol + self.use_custom_kernel = use_custom_kernel + + parts = [ + "gated_delta_rule_dynamic", + f"export_t{export_seq_len}", + f"test_t{test_seq_len}", + f"h{num_heads}", + f"dk{head_dim}", + ] + parts.append("kernel" if use_custom_kernel else "scan") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["GatedDeltaRuleDynamicSeqTest"]: + configs = [] + for 
use_kernel in [False, True]: + # Export with T=4, test with T=1 (decode) + configs.append( + cls(export_seq_len=4, test_seq_len=1, use_custom_kernel=use_kernel) + ) + # Export with T=2, test with T=8 (longer prefill) + configs.append( + cls(export_seq_len=2, test_seq_len=8, use_custom_kernel=use_kernel) + ) + # Export with T=4, test with T=4 (same — control) + configs.append( + cls(export_seq_len=4, test_seq_len=4, use_custom_kernel=use_kernel) + ) + return configs + + def get_dynamic_shapes(self): + # All 5 inputs (q, k, v, g, beta) have seq_len at dim 1 + seq_dim = torch.export.Dim("seq_len", min=1, max=32) + return { + "q": {1: seq_dim}, + "k": {1: seq_dim}, + "v": {1: seq_dim}, + "g": {1: seq_dim}, + "beta": {1: seq_dim}, + } + + def create_model(self) -> nn.Module: + model = GatedDeltaRuleModel( + self.batch_size, + self.num_heads, + self.head_dim, + self.value_dim, + use_custom_kernel=self.use_custom_kernel, + ) + return model.to(self.dtype) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + """Inputs for export tracing (uses export_seq_len).""" + scale = self.head_dim**-0.5 + T = self.export_seq_len + q = ( + torch.randn( + self.batch_size, T, self.num_heads, self.head_dim, dtype=self.dtype + ) + * scale + ) + k = ( + torch.randn( + self.batch_size, T, self.num_heads, self.head_dim, dtype=self.dtype + ) + * scale + ) + v = torch.randn( + self.batch_size, T, self.num_heads, self.value_dim, dtype=self.dtype + ) + g = torch.randn(self.batch_size, T, self.num_heads, dtype=self.dtype).sigmoid() + beta = torch.randn( + self.batch_size, T, self.num_heads, dtype=self.dtype + ).sigmoid() + return (q, k, v, g, beta) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + """Inputs for runtime test (uses test_seq_len — may differ from export).""" + scale = self.head_dim**-0.5 + T = self.test_seq_len + q = ( + torch.randn( + self.batch_size, T, self.num_heads, self.head_dim, dtype=self.dtype + ) + * scale + ) + k = ( + torch.randn( + self.batch_size, T, 
self.num_heads, self.head_dim, dtype=self.dtype + ) + * scale + ) + v = torch.randn( + self.batch_size, T, self.num_heads, self.value_dim, dtype=self.dtype + ) + g = torch.randn(self.batch_size, T, self.num_heads, dtype=self.dtype).sigmoid() + beta = torch.randn( + self.batch_size, T, self.num_heads, dtype=self.dtype + ).sigmoid() + return (q, k, v, g, beta) + + +class GatedDeltaRuleGQATest(OpTestCase): + """Test gated_delta_rule with Hk < Hv (GQA-style head repeat). + + Matches real Qwen 3.5 config where num_k_heads=1, num_v_heads=2 (tiny) + or num_k_heads=8, num_v_heads=64 (full). Q and K get repeat_interleaved + before the custom op call. + """ + + name = "gated_delta_rule_gqa" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + batch_size: int = 1, + seq_len: int = 4, + num_k_heads: int = 1, + num_v_heads: int = 2, + head_dim: int = 64, + value_dim: int = 64, + dtype: torch.dtype = torch.float32, + rtol: float = 1e-4, + atol: float = 1e-4, + use_custom_kernel: bool = False, + ): + self.batch_size = batch_size + self.seq_len = seq_len + self.num_k_heads = num_k_heads + self.num_v_heads = num_v_heads + self.head_dim = head_dim + self.value_dim = value_dim + self.dtype = dtype + self.rtol = rtol + self.atol = atol + self.use_custom_kernel = use_custom_kernel + + parts = [ + "gated_delta_rule_gqa", + f"b{batch_size}", + f"t{seq_len}", + f"hk{num_k_heads}", + f"hv{num_v_heads}", + f"dk{head_dim}", + f"dv{value_dim}", + ] + if dtype != torch.float32: + parts.append(str(dtype).split(".")[-1]) + parts.append("kernel" if use_custom_kernel else "scan") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["GatedDeltaRuleGQATest"]: + configs = [] + for use_kernel in [False, True]: + # Tiny config: Hk=1, Hv=2 (head_repeat=2) + configs.append( + cls(num_k_heads=1, num_v_heads=2, use_custom_kernel=use_kernel) + ) + # Decode (T=1) with GQA + configs.append( + cls( + seq_len=1, + num_k_heads=1, + num_v_heads=2, + use_custom_kernel=use_kernel, + 
) + ) + # Larger head ratio: Hk=2, Hv=8 (head_repeat=4) + configs.append( + cls(num_k_heads=2, num_v_heads=8, use_custom_kernel=use_kernel) + ) + # bf16 with GQA + configs.append( + cls( + num_k_heads=1, + num_v_heads=2, + dtype=torch.bfloat16, + rtol=0.05, + atol=0.15, + use_custom_kernel=use_kernel, + ) + ) + return configs + + def create_model(self) -> nn.Module: + model = GatedDeltaRuleGQAModel( + self.batch_size, + self.num_k_heads, + self.num_v_heads, + self.head_dim, + self.value_dim, + use_custom_kernel=self.use_custom_kernel, + ) + return model.to(self.dtype) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + scale = self.head_dim**-0.5 + # Q and K have num_k_heads (model does repeat_interleave internally) + q = ( + torch.randn( + self.batch_size, + self.seq_len, + self.num_k_heads, + self.head_dim, + dtype=self.dtype, + ) + * scale + ) + k = ( + torch.randn( + self.batch_size, + self.seq_len, + self.num_k_heads, + self.head_dim, + dtype=self.dtype, + ) + * scale + ) + # V, g, beta have num_v_heads + v = torch.randn( + self.batch_size, + self.seq_len, + self.num_v_heads, + self.value_dim, + dtype=self.dtype, + ) + g = torch.randn( + self.batch_size, self.seq_len, self.num_v_heads, dtype=self.dtype + ).sigmoid() + beta = torch.randn( + self.batch_size, self.seq_len, self.num_v_heads, dtype=self.dtype + ).sigmoid() + return (q, k, v, g, beta) + + +class GatedDeltaRuleFloatCastModel(nn.Module): + """Model that mirrors the export pattern: bf16 inputs, fp32 state buffer. + + The recurrent state must be fp32 for numerical stability. Rather than + casting with .float() (which creates a temporary that breaks mutation + tracking), the state buffer is registered as fp32 from the start. + Inputs are cast to fp32 before the op call. + """ + + def __init__(self, batch_size: int, num_heads: int, head_dim: int, value_dim: int): + super().__init__() + # fp32 state buffer — NOT bf16. Avoids .float() cast that breaks mutation. 
+ self.register_buffer( + "state", + torch.zeros( + batch_size, num_heads, value_dim, head_dim, dtype=torch.float32 + ), + ) + + def forward( + self, + q: torch.Tensor, # bf16 + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + ) -> torch.Tensor: + output = torch.ops.mlx.gated_delta_rule( + q.float(), + k.float(), + v.float(), + g.float(), + beta.float(), + self.state, # already fp32, no cast needed + use_custom_kernel=False, + ) + return output.to(q.dtype) + + +class GatedDeltaRuleFloatCastTest(OpTestCase): + """Test gated_delta_rule with bf16 inputs cast to fp32 and an fp32 state buffer (mirrors export model).""" + + name = "gated_delta_rule_float_cast" + rtol = 0.05 + atol = 0.15 + + @classmethod + def get_test_configs(cls) -> List["GatedDeltaRuleFloatCastTest"]: + return [cls()] + + def __init__(self): + self.batch_size = 1 + self.seq_len = 4 + self.num_heads = 2 + self.head_dim = 16 + self.value_dim = 16 + self.dtype = torch.bfloat16 + self.name = "gated_delta_rule_float_cast_b1_t4_h2" + + def create_model(self) -> nn.Module: + return GatedDeltaRuleFloatCastModel( + self.batch_size, + self.num_heads, + self.head_dim, + self.value_dim, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + q = torch.randn( + self.batch_size, + self.seq_len, + self.num_heads, + self.head_dim, + dtype=self.dtype, + ) + k = torch.randn( + self.batch_size, + self.seq_len, + self.num_heads, + self.head_dim, + dtype=self.dtype, + ) + v = torch.randn( + self.batch_size, + self.seq_len, + self.num_heads, + self.value_dim, + dtype=self.dtype, + ) + g = torch.randn( + self.batch_size, self.seq_len, self.num_heads, dtype=self.dtype + ).sigmoid() + beta = torch.randn( + self.batch_size, self.seq_len, self.num_heads, dtype=self.dtype + ).sigmoid() + return (q, k, v, g, beta) + + +class GatedDeltaRuleWithProjectionModel(nn.Module): + """Model that derives Q, K, V from a shared projection, mimicking the real export. + + Uses bf16 weights with fp32 state buffer.
The .float() casts on Q, K, V, + g, beta create intermediate ASTYPE temp slots that the delete-as-you-go + allocator can free and reuse, potentially causing Q and K to share the + same slot in the ScanNode originals. + """ + + def __init__(self, batch_size: int, num_heads: int, head_dim: int, value_dim: int): + super().__init__() + qkv_dim = 2 * num_heads * head_dim + num_heads * value_dim + gate_dim = num_heads # g and beta each have num_heads dims + self.proj = nn.Linear(num_heads * head_dim, qkv_dim + 2 * gate_dim, bias=False) + self.num_heads = num_heads + self.head_dim = head_dim + self.value_dim = value_dim + # fp32 state buffer (same as export model) + self.register_buffer( + "state", + torch.zeros( + batch_size, num_heads, value_dim, head_dim, dtype=torch.float32 + ), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, T, _ = x.shape + nh = self.num_heads + hd = self.head_dim + vd = self.value_dim + + proj = self.proj(x) # bf16 + kd = nh * hd + + # Split into Q, K, V, g, beta — same pattern as the real GDN forward + q = proj[..., :kd].reshape(B, T, nh, hd) + k = proj[..., kd : 2 * kd].reshape(B, T, nh, hd) + v = proj[..., 2 * kd : 2 * kd + nh * vd].reshape(B, T, nh, vd) + g = proj[..., 2 * kd + nh * vd : 2 * kd + nh * vd + nh].sigmoid() + beta = proj[..., 2 * kd + nh * vd + nh :].sigmoid() + + # L2-normalize Q and K (creates intermediate temps) + q = torch.nn.functional.normalize(q, p=2, dim=-1) + k = torch.nn.functional.normalize(k, p=2, dim=-1) + + # .float() casts create ASTYPE temp slots that can alias + output = torch.ops.mlx.gated_delta_rule( + q.float(), + k.float(), + v.float(), + g.float(), + beta.float(), + self.state, # already fp32 + use_custom_kernel=False, + ) + return output.to(x.dtype) + + +class GatedDeltaRuleWithProjectionTest(OpTestCase): + """Test gated_delta_rule with shared projection → normalize → float cast. 
+ + This reproduces the slot aliasing bug where Q and K get the same temp slot + because the delete-as-you-go allocator frees Q's float cast slot before K's + float cast allocates one. + """ + + name = "gated_delta_rule_projection" + rtol = 1e-4 + atol = 1e-4 + + @classmethod + def get_test_configs(cls) -> List["GatedDeltaRuleWithProjectionTest"]: + return [cls()] + + def __init__(self): + self.batch_size = 1 + self.seq_len = 4 + self.num_heads = 2 + self.head_dim = 16 + self.value_dim = 16 + self.dtype = torch.bfloat16 # bf16 so .float() casts produce ASTYPE temp slots + self.rtol = 0.05 + self.atol = 0.15 + self.name = "gated_delta_rule_projection_b1_t4_h2" + + def create_model(self) -> nn.Module: + model = GatedDeltaRuleWithProjectionModel( + self.batch_size, + self.num_heads, + self.head_dim, + self.value_dim, + ) + return model.to(torch.bfloat16) # bf16 weights so .float() casts are real + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn( + self.batch_size, + self.seq_len, + self.num_heads * self.head_dim, + dtype=torch.bfloat16, + ) + return (x,) + + +class GatedDeltaRuleMultiStepTest(OpTestCase): + """Test that state carries forward between two gated_delta_rule calls. + + Uses GatedDeltaRuleMultiStepModel which calls the op twice in a single + forward. The second call reads the mutated state from the first. If state + doesn't persist, out2 would equal running with zero state — and the + concatenated output would mismatch the eager reference. 
+ """ + + name = "gated_delta_rule_multistep" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + batch_size: int = 1, + seq_len: int = 4, + num_heads: int = 2, + head_dim: int = 16, + value_dim: int = 16, + dtype: torch.dtype = torch.float32, + rtol: float = 1e-4, + atol: float = 1e-4, + use_custom_kernel: bool = False, + ): + self.batch_size = batch_size + self.seq_len = seq_len + self.num_heads = num_heads + self.head_dim = head_dim + self.value_dim = value_dim + self.dtype = dtype + self.rtol = rtol + self.atol = atol + self.use_custom_kernel = use_custom_kernel + + parts = [ + "gated_delta_rule_multistep", + f"b{batch_size}", + f"t{seq_len}", + f"h{num_heads}", + f"dk{head_dim}", + f"dv{value_dim}", + ] + if dtype != torch.float32: + parts.append(str(dtype).split(".")[-1]) + parts.append("kernel" if use_custom_kernel else "scan") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["GatedDeltaRuleMultiStepTest"]: + return [ + cls(), + cls(seq_len=1), + # Dk=64 with custom kernel — exposes [int8] serialization bug + # if state_out dtype is corrupted to u8 + cls(num_heads=2, head_dim=64, value_dim=64, use_custom_kernel=True), + cls( + seq_len=1, + num_heads=2, + head_dim=64, + value_dim=64, + use_custom_kernel=True, + ), + # bf16 multistep with kernel — tests precision over two calls + cls( + num_heads=2, + head_dim=64, + value_dim=64, + dtype=torch.bfloat16, + rtol=0.05, + atol=0.15, + use_custom_kernel=True, + ), + ] + + def create_model(self) -> nn.Module: + model = GatedDeltaRuleMultiStepModel( + self.batch_size, + self.num_heads, + self.head_dim, + self.value_dim, + use_custom_kernel=self.use_custom_kernel, + ) + return model.to(self.dtype) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + scale = self.head_dim**-0.5 + q = ( + torch.randn( + self.batch_size, + self.seq_len, + self.num_heads, + self.head_dim, + dtype=self.dtype, + ) + * scale + ) + k = ( + torch.randn( + self.batch_size, + self.seq_len, + 
self.num_heads, + self.head_dim, + dtype=self.dtype, + ) + * scale + ) + v = torch.randn( + self.batch_size, + self.seq_len, + self.num_heads, + self.value_dim, + dtype=self.dtype, + ) + g = torch.randn( + self.batch_size, self.seq_len, self.num_heads, dtype=self.dtype + ).sigmoid() + beta = torch.randn( + self.batch_size, self.seq_len, self.num_heads, dtype=self.dtype + ).sigmoid() + return (q, k, v, g, beta) + + +if __name__ == "__main__": # noqa: C901 + import argparse + import sys + + from executorch.backends.mlx.test.test_utils import rebuild_op_test_runner + + parser = argparse.ArgumentParser( + description="Test mlx::gated_delta_rule op (ScanNode)" + ) + parser.add_argument( + "action", + choices=["generate", "compare", "run", "list"], + help="Action: generate (export), compare (check outputs), run (full test), list (show configs)", + ) + parser.add_argument("--verbose", "-v", action="store_true") + parser.add_argument( + "--rebuild", action="store_true", help="Rebuild C++ runner first" + ) + parser.add_argument( + "--config", type=str, default=None, help="Run specific config by name" + ) + args = parser.parse_args() + + if args.rebuild: + if not rebuild_op_test_runner(verbose=args.verbose): + sys.exit(1) + + configs = ( + GatedDeltaRuleTest.get_test_configs() + + GatedDeltaRuleDynamicSeqTest.get_test_configs() + + GatedDeltaRuleGQATest.get_test_configs() + + GatedDeltaRuleFloatCastTest.get_test_configs() + + GatedDeltaRuleWithProjectionTest.get_test_configs() + + GatedDeltaRuleMultiStepTest.get_test_configs() + ) + + if args.action == "list": + for cfg in configs: + print(f" {cfg.name}") + sys.exit(0) + + if args.config: + configs = [c for c in configs if c.name == args.config] + if not configs: + print(f"No config matching '{args.config}'") + sys.exit(1) + + passed = 0 + failed = 0 + failed_names = [] + + for test in configs: + if args.action == "generate": + pte_path, input_path, expected_path = test.generate_test_files( + verbose=args.verbose + ) + 
print(f"Generated: {pte_path}") + elif args.action == "compare": + actual_path = test.get_test_dir() / "actual_output.bin" + ok, msg = test.compare_with_actual(actual_path) + print(f"{'✓' if ok else '✗'} {test.name}: {msg}") + if ok: + passed += 1 + else: + failed += 1 + failed_names.append(test.name) + elif args.action == "run": + ok = test.run_test(verbose=args.verbose) + if ok: + passed += 1 + else: + failed += 1 + failed_names.append(test.name) + + if args.action in ("run", "compare"): + print(f"\nPassed: {passed}, Failed: {failed}") + if failed_names: + print(f"Failed: {', '.join(failed_names)}") + sys.exit(0 if failed == 0 else 1) diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index 439d4569313..4dc891ee984 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -119,6 +119,7 @@ RopeNode, RoundNode, RsqrtNode, + ScatterAddNode, SigmoidNode, SignNode, SiluNode, @@ -1486,6 +1487,105 @@ def _split_with_sizes_handler(P: MLXProgramBuilder, n: Node) -> Slot: return output_slots +@REGISTRY.register(target=[torch.ops.mlx.gather_mm.default]) +def _gather_mm_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle mlx::gather_mm — fused gather + matmul for MoE experts.""" + from executorch.backends.mlx.serialization.mlx_graph_schema import GatherMmNode + + args = P.args(n) + kwargs = P.kwargs(n) + + a = args[0] + b = args[1] + rhs_indices = args[2] if len(args) > 2 else kwargs.get("rhs_indices") + lhs_indices = args[3] if len(args) > 3 else kwargs.get("lhs_indices") + sorted_indices = args[4] if len(args) > 4 else kwargs.get("sorted_indices", False) + + out = P.make_or_get_slot(n) + P.emit( + GatherMmNode( + a=P.slot_to_tid(a), + b=P.slot_to_tid(b), + out=P.slot_to_tid(out), + lhs_indices=P.slot_to_tid(lhs_indices) if lhs_indices is not None else None, + rhs_indices=P.slot_to_tid(rhs_indices) if rhs_indices is not None else None, + sorted_indices=sorted_indices, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.mlx.gather_qmm.default]) 
+def _gather_qmm_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle mlx::gather_qmm — fused gather + dequant + matmul for quantized MoE experts. + + Converts TorchAO quantization format to MLX format (unsigned + biases) + and emits a GatherQmmNode. + """ + from executorch.backends.mlx.serialization.mlx_graph_schema import GatherQmmNode + + args = P.args(n) + kwargs = P.kwargs(n) + + x = args[0] + w_node = n.args[1] # Need the original node for constant lookup + scales_node = n.args[2] + biases_node = n.args[3] if len(n.args) > 3 else n.kwargs.get("biases") + rhs_indices = args[4] if len(args) > 4 else kwargs.get("rhs_indices") + lhs_indices = args[5] if len(args) > 5 else kwargs.get("lhs_indices") + transpose = args[6] if len(args) > 6 else kwargs.get("transpose", True) + group_size = args[7] if len(args) > 7 else kwargs.get("group_size", 32) + bits = args[8] if len(args) > 8 else kwargs.get("bits", 4) + mode = args[9] if len(args) > 9 else kwargs.get("mode", "affine") + sorted_indices = args[10] if len(args) > 10 else kwargs.get("sorted_indices", False) + + # Convert quantized weights to MLX format + w_target, w_data = P.get_placeholder_target_and_tensor(w_node) + _, scale_data = P.get_placeholder_target_and_tensor(scales_node) + zp_target = None + zp_data = None + if biases_node is not None: + zp_target, zp_data = P.get_placeholder_target_and_tensor(biases_node) + + # Reshape 3D [E, out, in] to 2D for to_mlx_qparams, then reshape back + orig_shape = w_data.shape + E, out_dim = orig_shape[0], orig_shape[1] + w_2d = w_data.reshape(E * out_dim, -1) + s_2d = scale_data.reshape(E * out_dim, -1) + zp_2d = ( + zp_data.reshape(E * out_dim, -1) + if zp_data is not None + else torch.zeros_like(s_2d, dtype=torch.int8) + ) + + Q, B = to_mlx_qparams(w_2d, s_2d, zp_2d, bits) + Q = Q.reshape(E, out_dim, -1) + B = B.reshape(E, out_dim, -1) + + packed_slot = P.make_or_get_constant(f"{w_target}_to_packed", Q) + scale_slot = P.slot_map([scales_node])[0] + biases_slot = 
P.make_or_get_constant(f"{zp_target or w_target}_to_biases", B) + + out = P.make_or_get_slot(n) + P.emit( + GatherQmmNode( + x=P.slot_to_tid(x), + w=P.slot_to_tid(packed_slot), + scales=P.slot_to_tid(scale_slot), + out=P.slot_to_tid(out), + biases=P.slot_to_tid(biases_slot), + lhs_indices=P.slot_to_tid(lhs_indices) if lhs_indices is not None else None, + rhs_indices=P.slot_to_tid(rhs_indices) if rhs_indices is not None else None, + transpose=transpose, + group_size=group_size, + bits=bits, + mode=mode, + sorted_indices=sorted_indices, + ) + ) + return out + + @REGISTRY.register( target=[torch.ops.aten.split.Tensor, torch.ops.aten.split_copy.Tensor] ) @@ -1725,6 +1825,31 @@ def _slice_scatter_handler(P: MLXProgramBuilder, n: Node) -> Slot: return out +@REGISTRY.register(target=[torch.ops.aten.scatter_add.default]) +def _scatter_add_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.scatter_add: accumulate src into self at index positions along dim. + + scatter_add(self, dim, index, src) -> Tensor + + Maps to mlx::scatter_add_axis(a, indices, updates, axis).
+ """ + args = P.args(n) + require_args(args, 4, 4, "aten.scatter_add") + require_kwargs(P.kwargs(n), set(), "aten.scatter_add") + x, dim, indices, src = args + out = P.make_or_get_slot(n) + P.emit( + ScatterAddNode( + x=P.slot_to_tid(x), + indices=P.slot_to_tid(indices), + updates=P.slot_to_tid(src), + out=P.slot_to_tid(out), + axis=dim, + ) + ) + return out + + @REGISTRY.register(target=[torch.ops.aten.select.int, torch.ops.aten.select_copy.int]) def _select_handler(P: MLXProgramBuilder, n: Node) -> Slot: """ diff --git a/backends/mlx/pte_inspector.py b/backends/mlx/pte_inspector.py index d9e533b0b1e..df11a21f370 100644 --- a/backends/mlx/pte_inspector.py +++ b/backends/mlx/pte_inspector.py @@ -403,6 +403,11 @@ def _extract_field(node, accessor_name: str, kind: str) -> Any: # noqa: C901 items.append(f"tid {s.Idx()}" if s else None) return items + if kind == "string_list": + length = getattr(node, f"{accessor_name}Length")() + getter = getattr(node, accessor_name) + return [getter(i).decode("utf-8") if getter(i) else None for i in range(length)] + if kind == "int_or_vid_list": length = getattr(node, f"{accessor_name}Length")() getter = getattr(node, accessor_name) diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index a351fcfb619..9fa08ab722d 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -787,6 +787,16 @@ exec_gather(const GatherNode& n, ExecutionState& st, StreamOrDevice s) { st.set_tensor(n.out, gather(x, indices, n.axes, slice_sizes, s)); } +inline void exec_scatter_add( + const ScatterAddNode& n, + ExecutionState& st, + StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + const auto& indices = st.const_tensor_ref(n.indices); + const auto& updates = st.const_tensor_ref(n.updates); + st.set_tensor(n.out, scatter_add_axis(x, indices, updates, n.axis, s)); +} + inline void exec_slice(const SliceNode& n, ExecutionState& st, StreamOrDevice s) { const array& x = 
st.const_tensor_ref(n.x); @@ -843,6 +853,221 @@ inline void exec_quantized_matmul( st.set_tensor(n.out, std::move(Y)); } +inline void +exec_gather_mm(const GatherMmNode& n, ExecutionState& st, StreamOrDevice s) { + array A = st.const_tensor_ref(n.a); + array B = st.const_tensor_ref(n.b); + + std::optional lhs_idx = std::nullopt; + if (n.lhs_indices.has_value()) { + lhs_idx = st.const_tensor_ref(*n.lhs_indices); + } + std::optional rhs_idx = std::nullopt; + if (n.rhs_indices.has_value()) { + rhs_idx = st.const_tensor_ref(*n.rhs_indices); + } + + array Y = gather_mm(A, B, lhs_idx, rhs_idx, n.sorted_indices, s); + st.set_tensor(n.out, std::move(Y)); +} + +inline void +exec_gather_qmm(const GatherQmmNode& n, ExecutionState& st, StreamOrDevice s) { + array X = st.const_tensor_ref(n.x); + array Wq = st.const_tensor_ref(n.w); + array Sc = st.const_tensor_ref(n.scales); + + std::optional Qb = std::nullopt; + if (n.biases.has_value()) { + Qb = st.const_tensor_ref(*n.biases); + } + std::optional lhs_idx = std::nullopt; + if (n.lhs_indices.has_value()) { + lhs_idx = st.const_tensor_ref(*n.lhs_indices); + } + std::optional rhs_idx = std::nullopt; + if (n.rhs_indices.has_value()) { + rhs_idx = st.const_tensor_ref(*n.rhs_indices); + } + + array Y = gather_qmm( + X, + Wq, + Sc, + Qb, + lhs_idx, + rhs_idx, + n.transpose, + n.group_size, + n.bits, + n.mode, + n.sorted_indices, + s); + st.set_tensor(n.out, std::move(Y)); +} + +inline void exec_metal_kernel( + const MetalKernelNode& n, + ExecutionState& st, + StreamOrDevice s) { +#ifndef ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION + throw std::runtime_error( + "MetalKernelNode: custom kernel execution is disabled. " + "Rebuild with -DET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION to enable. 
" + "WARNING: enabling this allows .pte files to execute arbitrary GPU code."); +#else + + // Validate parallel array lengths + if (n.input_names.size() != n.inputs.size()) { + throw std::runtime_error( + "MetalKernelNode: input_names length (" + + std::to_string(n.input_names.size()) + ") must match inputs length (" + + std::to_string(n.inputs.size()) + ")"); + } + if (n.output_names.size() != n.outputs.size()) { + throw std::runtime_error( + "MetalKernelNode: output_names length (" + + std::to_string(n.output_names.size()) + + ") must match outputs length (" + std::to_string(n.outputs.size()) + + ")"); + } + if (n.outputs.empty()) { + throw std::runtime_error("MetalKernelNode: outputs must not be empty"); + } + if (n.inputs.empty()) { + throw std::runtime_error("MetalKernelNode: inputs must not be empty"); + } + if (n.output_shape_lengths.size() != n.outputs.size()) { + throw std::runtime_error( + "MetalKernelNode: output_shape_lengths length (" + + std::to_string(n.output_shape_lengths.size()) + + ") must match outputs length (" + std::to_string(n.outputs.size()) + + ")"); + } + if (n.output_dtypes.size() != n.outputs.size()) { + throw std::runtime_error( + "MetalKernelNode: output_dtypes length (" + + std::to_string(n.output_dtypes.size()) + + ") must match outputs length (" + std::to_string(n.outputs.size()) + + ")"); + } + + // Validate output_shapes_flat length matches sum of output_shape_lengths + size_t expected_flat_len = 0; + for (int32_t len : n.output_shape_lengths) { + if (len < 0) { + throw std::runtime_error( + "MetalKernelNode: output_shape_lengths contains negative value " + + std::to_string(len)); + } + expected_flat_len += static_cast(len); + } + if (n.output_shapes_flat.size() != expected_flat_len) { + throw std::runtime_error( + "MetalKernelNode: output_shapes_flat length (" + + std::to_string(n.output_shapes_flat.size()) + + ") must equal sum of output_shape_lengths (" + + std::to_string(expected_flat_len) + ")"); + } + + // Validate template 
arg parallel arrays + if (n.template_arg_kinds.size() != n.template_arg_names.size() || + n.template_arg_values.size() != n.template_arg_names.size()) { + throw std::runtime_error( + "MetalKernelNode: template_arg_names/kinds/values must have same length (" + + std::to_string(n.template_arg_names.size()) + "/" + + std::to_string(n.template_arg_kinds.size()) + "/" + + std::to_string(n.template_arg_values.size()) + ")"); + } + + // Build the kernel function (cached internally by MLX based on name+source) + auto kernel_fn = ::mlx::core::fast::metal_kernel( + n.name, + n.input_names, + n.output_names, + n.source, + n.header, + n.ensure_row_contiguous, + n.atomic_outputs); + + // Resolve inputs + std::vector inputs; + inputs.reserve(n.inputs.size()); + for (const auto& tid : n.inputs) { + inputs.push_back(st.const_tensor_ref(tid)); + } + + // Resolve grid and threadgroup (IntOrVid → int) + auto grid_ints = resolve_ints(n.grid, st); + auto tg_ints = resolve_ints(n.threadgroup, st); + if (grid_ints.size() != 3) { + throw std::runtime_error( + "MetalKernelNode: grid must have exactly 3 elements, got " + + std::to_string(grid_ints.size())); + } + if (tg_ints.size() != 3) { + throw std::runtime_error( + "MetalKernelNode: threadgroup must have exactly 3 elements, got " + + std::to_string(tg_ints.size())); + } + std::tuple grid{grid_ints[0], grid_ints[1], grid_ints[2]}; + std::tuple threadgroup{tg_ints[0], tg_ints[1], tg_ints[2]}; + + // Resolve output shapes from flattened representation (lengths already + // validated) + std::vector<::mlx::core::Shape> output_shapes; + size_t flat_offset = 0; + for (int32_t len : n.output_shape_lengths) { + ::mlx::core::Shape shape; + for (int32_t j = 0; j < len; ++j) { + shape.push_back(resolve_int(n.output_shapes_flat[flat_offset++], st)); + } + output_shapes.push_back(std::move(shape)); + } + + // Resolve output dtypes + std::vector<::mlx::core::Dtype> output_dtypes; + for (int32_t d : n.output_dtypes) { + 
output_dtypes.push_back(resolve_dtype(static_cast(d))); + } + + // Resolve template args from parallel arrays (lengths already validated) + std::vector> + template_args; + for (size_t i = 0; i < n.template_arg_names.size(); ++i) { + int32_t kind = n.template_arg_kinds[i]; + int32_t value = n.template_arg_values[i]; + ::mlx::core::fast::TemplateArg arg; + if (kind == 0) { + arg = value; // int + } else if (kind == 1) { + arg = static_cast(value); // bool + } else { + arg = resolve_dtype(static_cast(value)); // Dtype + } + template_args.push_back({n.template_arg_names[i], std::move(arg)}); + } + + // Invoke the kernel + auto results = kernel_fn( + inputs, + output_shapes, + output_dtypes, + grid, + threadgroup, + template_args, + n.init_value, + /*verbose=*/false, + s); + + // Store outputs + for (size_t i = 0; i < results.size() && i < n.outputs.size(); ++i) { + st.set_tensor(n.outputs[i], std::move(results[i])); + } + +#endif // ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION +} + inline void exec_concatenate( const ConcatenateNode& n, ExecutionState& st, @@ -1567,13 +1792,54 @@ class Interpreter { size_t idx = 0; for (const auto& instr : chain) { st.begin_op(idx, op_name(instr.op)); - dispatch(instr, st, stream); + if (instr.op == OpCode::SCAN) { + exec_scan(prog, std::get(instr.node), st, stream); + } else { + dispatch(instr, st, stream); + } st.end_op(); ++idx; } } private: + void exec_scan( + const MLXProgram& prog, + const ScanNode& n, + ExecutionState& st, + StreamOrDevice s) const { + int axis = n.scan_axis; + int T_int = st.const_tensor_ref(n.originals[0]).shape(axis); + size_t T = static_cast(T_int); + size_t num_outputs = n.outputs.size(); + + std::vector> collected(num_outputs); + for (size_t i = 0; i < num_outputs; ++i) { + collected[i].reserve(T); + } + + for (size_t t = 0; t < T; ++t) { + for (size_t i = 0; i < n.originals.size(); ++i) { + st.set_tensor( + n.sliced[i], + ::mlx::core::take( + st.const_tensor_ref(n.originals[i]), + static_cast(t), + axis, + s)); + 
} + + run_chain(prog, static_cast(n.body_chain_idx), st, s); + + for (size_t i = 0; i < num_outputs; ++i) { + collected[i].push_back(st.const_tensor_ref(n.outputs[i])); + } + } + + for (size_t i = 0; i < num_outputs; ++i) { + st.set_tensor(n.outputs[i], ::mlx::core::stack(collected[i], axis, s)); + } + } void dispatch(const Instruction& instr, ExecutionState& st, StreamOrDevice s) const { switch (instr.op) { @@ -1949,6 +2215,18 @@ class Interpreter { ops::exec_quantized_matmul( std::get(instr.node), st, s); break; + case OpCode::SCATTER_ADD: + ops::exec_scatter_add(std::get(instr.node), st, s); + break; + case OpCode::GATHER_MM: + ops::exec_gather_mm(std::get(instr.node), st, s); + break; + case OpCode::GATHER_QMM: + ops::exec_gather_qmm(std::get(instr.node), st, s); + break; + case OpCode::METAL_KERNEL: + ops::exec_metal_kernel(std::get(instr.node), st, s); + break; default: throw std::runtime_error( "Unknown opcode: " + std::to_string(static_cast(instr.op))); diff --git a/backends/mlx/serialization/generate.py b/backends/mlx/serialization/generate.py index 6f6ee11fe41..db3d4cd2d49 100755 --- a/backends/mlx/serialization/generate.py +++ b/backends/mlx/serialization/generate.py @@ -604,13 +604,29 @@ def generate_python_serializers(schema: FBSSchema) -> str: "", "", "def _build_int_vector(builder: flatbuffers.Builder, vec: List[int]) -> int:", - ' """Build a vector of int32."""', + ' """Pre-build a vector of int32 values (must be called before table Start)."""', " builder.StartVector(4, len(vec), 4)", " for v in reversed(vec):", " builder.PrependInt32(v)", " return builder.EndVector()", "", "", + "def _build_int8_vector(builder: flatbuffers.Builder, vec: List[int]) -> int:", + ' """Pre-build a vector of int8 values (must be called before table Start)."""', + " builder.StartVector(1, len(vec), 1)", + " for v in reversed(vec):", + " builder.PrependInt8(v)", + " return builder.EndVector()", + "", + "", + "def _build_uint8_vector(builder: flatbuffers.Builder, vec: 
List[int]) -> int:", + ' """Pre-build a vector of uint8 values (must be called before table Start)."""', + " builder.StartVector(1, len(vec), 1)", + " for v in reversed(vec):", + " builder.PrependUint8(v)", + " return builder.EndVector()", + "", + "", "class GeneratedOpBuilders:", ' """Mixin class with auto-generated op builder methods."""', "", @@ -694,6 +710,16 @@ def generate_python_serializers(schema: FBSSchema) -> str: " builder.PrependUint32(tid.idx)", " return builder.EndVector()", "", + " def _build_string_vector(", + " self, builder: flatbuffers.Builder, vec: List[str]", + " ) -> int:", + ' """Pre-build a vector of strings (offsets must be created before table Start)."""', + " offsets = [builder.CreateString(s) for s in vec]", + " builder.StartVector(4, len(offsets), 4)", + " for off in reversed(offsets):", + " builder.PrependUOffsetTRelative(off)", + " return builder.EndVector()", + "", ] ) @@ -766,8 +792,11 @@ def _generate_op_builder_method(table: FBSTable) -> str: _PY_PREBUILD_VECTOR = { "list_int": "_build_int_vector(builder, op.{name})", + "list_int8": "_build_int8_vector(builder, op.{name})", + "list_uint8": "_build_uint8_vector(builder, op.{name})", "list_int_or_vid": "self._build_int_or_vid_vector(builder, op.{name})", "list_tid": "self._build_tid_vector(builder, op.{name})", + "list_str": "self._build_string_vector(builder, op.{name})", } _PY_PREBUILD_OFFSET = { @@ -818,7 +847,14 @@ def _emit_py_add( if kind in ("str", "int_or_vid", "float_or_vid", "vid_or_tid", "int_or_vid_or_tid"): return [f" {add}(builder, {n}_off)"] # Pre-built vectors (required vs optional) - if kind in ("list_int", "list_int_or_vid", "list_tid"): + if kind in ( + "list_int", + "list_int8", + "list_uint8", + "list_int_or_vid", + "list_tid", + "list_str", + ): if fld.required: return [f" {add}(builder, {n}_vec)"] return [ @@ -859,11 +895,17 @@ def _get_field_kind(fld: FBSField, table: FBSTable) -> str: # noqa: C901 if t.startswith("[") and t.endswith("]"): inner = t[1:-1] if 
inner in FBS_INTEGER_TYPES: + if inner == "int8": + return "list_int8" + if inner == "uint8": + return "list_uint8" return "list_int" if inner == "IntOrVid": return "list_int_or_vid" if inner == "Tid": return "list_tid" + if inner == "string": + return "list_str" raise ValueError( f"Unrecognized array element type '{inner}' for field '{fld.name}' in table '{table.name}'. " f"Add a handler in _get_field_kind()." @@ -1140,7 +1182,7 @@ def _emit_cpp_load(kind: str, name: str, fb_name: str) -> "List[str] | None": " }", ] # Integer/bool vector via to_vector - if kind == "list_int": + if kind in ("list_int", "list_int8", "list_uint8"): return [f" node.{name} = to_vector(fb->{fb_name}());"] # Int-or-vid vector (indexed access) if kind == "list_int_or_vid": @@ -1160,6 +1202,15 @@ def _emit_cpp_load(kind: str, name: str, fb_name: str) -> "List[str] | None": " }", " }", ] + # String vector + if kind == "list_str": + return [ + f" if (fb->{fb_name}()) {{", + f" for (const auto* s : *fb->{fb_name}()) {{", + f" node.{name}.push_back(s ? 
s->str() : std::string{{}});", + " }", + " }", + ] return None @@ -1258,6 +1309,9 @@ def _fixup_flatc_imports() -> None: "list_int": "int_list", "list_int_or_vid": "int_or_vid_list", "list_tid": "tid_list", + "list_str": "string_list", + "list_int8": "int_list", + "list_uint8": "int_list", "int": "scalar", "optional_int": "scalar", "float": "scalar", @@ -1290,7 +1344,7 @@ def generate_inspector(schema: "Schema") -> str: # noqa: F821 "# Field kinds and their extractors", "# Each field is a tuple of (display_name, accessor_name, kind)", "# where kind is one of: 'tid', 'vid', 'int_or_vid', 'float_or_vid',", - "# 'int_list', 'int_or_vid_list', 'tid_list', 'scalar', 'string'", + "# 'int_list', 'int_or_vid_list', 'tid_list', 'string_list', 'scalar', 'string'", "", "FieldSpec = Tuple[str, str, str] # (display_name, accessor_name, kind)", "", diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs index b101b5756f7..6e8d6f47db8 100644 --- a/backends/mlx/serialization/schema.fbs +++ b/backends/mlx/serialization/schema.fbs @@ -449,6 +449,16 @@ table QuantizedMatmulNode { transpose: bool = true; } +// Scatter-add: accumulate updates into input at index positions along an axis +// Maps to mlx::scatter_add(a, indices, updates, axis) +table ScatterAddNode { + x: Tid (required); // Input tensor to scatter into + indices: Tid (required); // Index tensor + updates: Tid (required); // Values to accumulate + out: Tid (required); + axis: int32; // Dimension to scatter along +} + table ConcatenateNode { tensors: [Tid] (required); // List of tensors to concatenate out: Tid (required); @@ -899,6 +909,78 @@ table MedianNode { keepdims: bool = false; } +table GatherMmNode { + a: Tid (required); // Input activations + b: Tid (required); // Weight matrix [E, out, in] or similar + out: Tid (required); + lhs_indices: Tid; // optional - LHS gather indices + rhs_indices: Tid; // optional - RHS gather indices (expert selection) + sorted_indices: bool = false; +} 
+ +table GatherQmmNode { + x: Tid (required); // Input activations + w: Tid (required); // Quantized weight matrix [E, out, in_packed] + scales: Tid (required); // Quantization scales [E, out, in//gs] + out: Tid (required); + mode: string (required); // "affine", "fp", etc. + biases: Tid; // optional - for affine mode + lhs_indices: Tid; // optional - LHS gather indices + rhs_indices: Tid; // optional - RHS gather indices (expert selection) + transpose: bool = true; + group_size: int32; + bits: int32; + sorted_indices: bool = false; +} + +table ScanNode { + originals: [Tid] (required); // [B, T, ...] — read-only, not modified + sliced: [Tid] (required); // runtime writes per-step slices here (same length as originals) + outputs: [Tid] (required); // body writes [B, ...] per step, runtime stacks to [B, T, ...] + carry: [Tid] (required); // body reads/writes in place, persists across steps + + body_chain_idx: int32; // index into MLXGraph.instruction_chains + scan_axis: int32 = 1; // dimension to iterate over +} + +// Custom Metal kernel execution via mlx::core::fast::metal_kernel(). +// Two-phase API: +// 1. Factory: metal_kernel(name, input_names, output_names, source, header, +// ensure_row_contiguous, atomic_outputs) -> kernel_fn +// 2. Invocation: kernel_fn(inputs, output_shapes, output_dtypes, grid, +// threadgroup, template_args, init_value) +// +// Output shapes are flattened: output_shapes_flat contains all shape dims +// concatenated, output_shape_lengths gives the rank of each output. +// E.g. shapes [[B,T,H,D], [B,H,D,K]] -> flat=[B,T,H,D,B,H,D,K], lengths=[4,4] +// +// Template args are parallel arrays: template_arg_names[i] paired with +// template_arg_kinds[i] (0=int, 1=bool, 2=dtype/ScalarType) and +// template_arg_values[i] (int value, bool as 0/1, or ScalarType enum). 
+table MetalKernelNode { + // Required fields (no defaults) — must come first for Python dataclass ordering + name: string (required); + source: string (required); + inputs: [Tid] (required); + outputs: [Tid] (required); + grid: [IntOrVid] (required); + threadgroup: [IntOrVid] (required); + + // Optional / defaulted fields + header: string; + input_names: [string]; + output_names: [string]; + ensure_row_contiguous: bool = true; + atomic_outputs: bool = false; + output_shapes_flat: [IntOrVid]; + output_shape_lengths: [int32]; + output_dtypes: [int8]; + template_arg_names: [string]; + template_arg_kinds: [int8]; // 0=int, 1=bool, 2=dtype (ScalarType) + template_arg_values: [int32]; // int value, bool as 0/1, or ScalarType enum + init_value: float = null; +} + // ============================================================================= // Union of all op types // ============================================================================= @@ -1026,7 +1108,12 @@ union OpNode { ArgsortNode, PartitionNode, ArgPartitionNode, - QuantizedMatmulNode + QuantizedMatmulNode, + ScatterAddNode, + GatherMmNode, + GatherQmmNode, + ScanNode, + MetalKernelNode // BC: Add new op nodes here (append only) } diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index 35514f4df04..e5ece4931b9 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -5452,6 +5452,515 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]: return (x,) +@torch.library.custom_op("mlx_test::vadd", mutates_args=()) +def vadd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """Element-wise vector add via custom Metal kernel (reference impl).""" + return a + b + + +@torch.library.register_fake("mlx_test::vadd") +def vadd_fake(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return torch.empty_like(a) + + +class MetalKernelVaddModel(nn.Module): + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return torch.ops.mlx_test.vadd(a, b) + + 
+def _register_vadd_handler(): + """Register an MLX op handler that emits MetalKernelNode for vadd.""" + from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type + from executorch.backends.mlx.builder.op_registry import REGISTRY + from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder + from executorch.backends.mlx.builder.slot_manager import Slot + from executorch.backends.mlx.serialization.mlx_graph_schema import ( + IntOrVid, + MetalKernelNode, + ) + + vadd_source = """ + uint idx = thread_position_in_grid.x; + if (idx < a_shape[0]) { + out[idx] = a[idx] + b[idx]; + } + """ + + @REGISTRY.register(target=[torch.ops.mlx_test.vadd.default]) + def _vadd_handler(P: MLXProgramBuilder, n: "torch.fx.node.Node") -> Slot: + args = P.args(n) + a_slot, b_slot = args[0], args[1] + out = P.make_or_get_slot(n) + + a_meta = n.args[0].meta.get("val") + numel = a_meta.numel() + dtype_int = torch_dtype_to_scalar_type(a_meta.dtype) + + P.emit( + MetalKernelNode( + name="vadd", + source=vadd_source, + inputs=[P.slot_to_tid(a_slot), P.slot_to_tid(b_slot)], + outputs=[P.slot_to_tid(out)], + grid=[ + IntOrVid.from_literal(numel), + IntOrVid.from_literal(1), + IntOrVid.from_literal(1), + ], + threadgroup=[ + IntOrVid.from_literal(256), + IntOrVid.from_literal(1), + IntOrVid.from_literal(1), + ], + input_names=["a", "b"], + output_names=["out"], + output_shapes_flat=[IntOrVid.from_literal(d) for d in a_meta.shape], + output_shape_lengths=[len(a_meta.shape)], + output_dtypes=[dtype_int], + ) + ) + return out + + +_register_vadd_handler() + + +@register_test +class MetalKernelTest(OpTestCase): + name = "metal_kernel" + rtol = 1e-5 + atol = 1e-5 + + def __init__(self, size=1024): + self.size = size + self.name = f"metal_kernel_{size}" + + @classmethod + def get_test_configs(cls) -> List["MetalKernelTest"]: + return [ + cls(size=1024), + cls(size=4096), + ] + + def create_model(self) -> nn.Module: + return MetalKernelVaddModel() + + def 
create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.size), torch.randn(self.size)) + + +class SwitchLinearModel(nn.Module): + """Model using SwitchLinear for expert selection + matmul.""" + + def __init__(self, num_experts: int, in_features: int, out_features: int): + super().__init__() + from executorch.backends.mlx.llm.switch import SwitchLinear + + self.switch = SwitchLinear(in_features, out_features, num_experts) + self.switch.pack() + + def forward(self, x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + return self.switch(x, indices) + + +@register_test +class SwitchLinearTest(OpTestCase): + """Test case for SwitchLinear (unquantized).""" + + name = "switch_linear" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + num_experts: int = 4, + in_features: int = 64, + out_features: int = 128, + batch_size: int = 2, + dtype: torch.dtype = torch.float32, + ): + self.num_experts = num_experts + self.in_features = in_features + self.out_features = out_features + self.batch_size = batch_size + self.dtype = dtype + + parts = [ + "switch_linear", + f"e{num_experts}", + f"i{in_features}", + f"o{out_features}", + ] + if dtype != torch.float32: + parts.append(str(dtype).split(".")[-1]) + if batch_size != 2: + parts.append(f"b{batch_size}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["SwitchLinearTest"]: + return [ + cls(), + cls(num_experts=8, in_features=128, out_features=256), + cls(dtype=torch.bfloat16), + cls(batch_size=1), + ] + + def create_model(self) -> nn.Module: + model = SwitchLinearModel(self.num_experts, self.in_features, self.out_features) + return model.to(self.dtype) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.in_features, dtype=self.dtype) + indices = torch.randint(0, self.num_experts, (self.batch_size,)) + return (x, indices) + + +class QuantizedSwitchLinearModel(nn.Module): + """Model using quantized SwitchLinear for expert 
selection + matmul.""" + + def __init__( + self, + num_experts: int, + in_features: int, + out_features: int, + group_size: int = 32, + ): + super().__init__() + from executorch.backends.mlx.llm.quantization import quantize_model_ + from executorch.backends.mlx.llm.switch import SwitchLinear + + self.switch = SwitchLinear(in_features, out_features, num_experts) + quantize_model_( + nn.ModuleDict({"switch": self.switch}), + qlinear_config="4w", + qlinear_group_size=group_size, + ) + self.switch.pack() + + def forward(self, x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + return self.switch(x, indices) + + +@register_test +class QuantizedSwitchLinearTest(OpTestCase): + """Test case for SwitchLinear (quantized).""" + + name = "quantized_switch_linear" + rtol = 0.1 + atol = 0.1 + + def __init__( + self, + num_experts: int = 4, + in_features: int = 64, + out_features: int = 128, + batch_size: int = 2, + group_size: int = 32, + dtype: torch.dtype = torch.float32, + ): + self.num_experts = num_experts + self.in_features = in_features + self.out_features = out_features + self.batch_size = batch_size + self.group_size = group_size + self.dtype = dtype + + parts = [ + "quantized_switch_linear", + f"e{num_experts}", + f"i{in_features}", + f"o{out_features}", + ] + if dtype != torch.float32: + parts.append(str(dtype).split(".")[-1]) + if batch_size != 2: + parts.append(f"b{batch_size}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["QuantizedSwitchLinearTest"]: + return [ + cls(), + cls(num_experts=8, in_features=128, out_features=256), + cls(dtype=torch.bfloat16), + cls(batch_size=1), + ] + + def get_edge_compile_config(self): + from executorch.exir import EdgeCompileConfig + + return EdgeCompileConfig(_check_ir_validity=False) + + def create_model(self) -> nn.Module: + model = QuantizedSwitchLinearModel( + self.num_experts, + self.in_features, + self.out_features, + self.group_size, + ) + return model.to(self.dtype) + + def 
create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.in_features, dtype=self.dtype) + indices = torch.randint(0, self.num_experts, (self.batch_size,)) + return (x, indices) + + +class GatherMmModel(nn.Module): + """Model using mlx::gather_mm for expert selection + matmul.""" + + def __init__(self, num_experts: int, in_features: int, out_features: int): + super().__init__() + self.register_buffer( + "weight", + torch.randn(num_experts, out_features, in_features), + ) + + def forward(self, x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + import executorch.backends.mlx.custom_ops as _ # noqa + + # Unsqueeze x to [N, 1, K] for gather_mm (MLX expects [..., M, K]) + # Transpose weight from [E, out, in] to [E, in, out] + # gather_mm returns [N, 1, out], squeeze dim -2 + return torch.ops.mlx.gather_mm( + x.unsqueeze(-2), self.weight.transpose(-1, -2), rhs_indices=indices + ).squeeze(-2) + + +@register_test +class GatherMmTest(OpTestCase): + """Test case for mlx::gather_mm.""" + + name = "gather_mm" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + num_experts: int = 4, + in_features: int = 64, + out_features: int = 128, + batch_size: int = 2, + dtype: torch.dtype = torch.float32, + ): + self.num_experts = num_experts + self.in_features = in_features + self.out_features = out_features + self.batch_size = batch_size + self.dtype = dtype + + parts = ["gather_mm", f"e{num_experts}", f"i{in_features}", f"o{out_features}"] + if dtype != torch.float32: + parts.append(str(dtype).split(".")[-1]) + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["GatherMmTest"]: + return [ + cls(), + cls(num_experts=8, in_features=128, out_features=256), + cls(dtype=torch.bfloat16), + cls(batch_size=1), + ] + + def create_model(self) -> nn.Module: + model = GatherMmModel(self.num_experts, self.in_features, self.out_features) + return model.to(self.dtype) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = 
torch.randn(self.batch_size, self.in_features, dtype=self.dtype) + indices = torch.randint(0, self.num_experts, (self.batch_size,)) + return (x, indices) + + +class GatherQmmModel(nn.Module): + """Model using mlx::gather_qmm for quantized expert selection + matmul. + + Uses pack_experts() from UnfusedMoEExperts to create properly quantized + stacked expert weights. + """ + + def __init__( + self, + num_experts: int, + in_features: int, + out_features: int, + group_size: int = 32, + ): + super().__init__() + self.out_features = out_features + self.group_size = group_size + + # Create per-expert nn.Linear, quantize, extract inner tensors + from executorch.backends.mlx.llm.quantization import quantize_model_ + + experts = nn.ModuleList( + [ + nn.Linear(in_features, out_features, bias=False) + for _ in range(num_experts) + ] + ) + # Quantize + wrapper = nn.ModuleDict({"experts": experts}) + quantize_model_(wrapper, qlinear_config="4w", qlinear_group_size=group_size) + + # Extract and stack quantized inner tensors + self.register_buffer( + "qdata", + torch.stack([e.weight.qdata for e in experts]), + ) + self.register_buffer( + "scale", + torch.stack([e.weight.scale for e in experts]), + ) + self.register_buffer( + "zero_point", + torch.stack([e.weight.zero_point for e in experts]), + ) + + def forward(self, x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + import executorch.backends.mlx.custom_ops as _ # noqa + + # Unsqueeze x to [N, 1, K] for gather_qmm (MLX expects [..., M, K]) + # gather_qmm returns [N, 1, out], squeeze dim -2 + return torch.ops.mlx.gather_qmm( + x.unsqueeze(-2), + self.qdata, + self.scale, + biases=self.zero_point, + rhs_indices=indices, + group_size=self.group_size, + ).squeeze(-2) + + +@register_test +class GatherQmmTest(OpTestCase): + """Test case for mlx::gather_qmm.""" + + name = "gather_qmm" + rtol = 0.1 + atol = 0.1 + + def __init__( + self, + num_experts: int = 4, + in_features: int = 64, + out_features: int = 128, + batch_size: int 
= 2, + group_size: int = 32, + dtype: torch.dtype = torch.float32, + ): + self.num_experts = num_experts + self.in_features = in_features + self.out_features = out_features + self.batch_size = batch_size + self.group_size = group_size + self.dtype = dtype + + parts = [ + "gather_qmm", + f"e{num_experts}", + f"i{in_features}", + f"o{out_features}", + f"g{group_size}", + ] + if dtype != torch.float32: + parts.append(str(dtype).split(".")[-1]) + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["GatherQmmTest"]: + return [ + cls(), + cls(num_experts=8, in_features=128, out_features=256), + cls(dtype=torch.bfloat16), + cls(batch_size=1), + ] + + def get_edge_compile_config(self): + from executorch.exir import EdgeCompileConfig + + return EdgeCompileConfig(_check_ir_validity=False) + + def create_model(self) -> nn.Module: + model = GatherQmmModel( + self.num_experts, + self.in_features, + self.out_features, + self.group_size, + ) + return model.to(self.dtype) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.in_features, dtype=self.dtype) + indices = torch.randint(0, self.num_experts, (self.batch_size,)) + return (x, indices) + + +class ScatterAddModel(nn.Module): + """Model that performs scatter_add along a dimension.""" + + def __init__(self, dim: int = 0): + super().__init__() + self.dim = dim + + def forward( + self, x: torch.Tensor, index: torch.Tensor, src: torch.Tensor + ) -> torch.Tensor: + return x.scatter_add(self.dim, index, src) + + +@register_test +class ScatterAddTest(OpTestCase): + """Test case for aten.scatter_add op. + + scatter_add(self, dim, index, src) accumulates src values into self + at positions given by index along the specified dimension. + """ + + name = "scatter_add" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] 
= (4, 8), + dim: int = 1, + num_indices: int = 3, + ): + self.shape = shape + self.dim = dim + self.num_indices = num_indices + shape_str = "x".join(str(s) for s in shape) + self.name = f"scatter_add_{shape_str}_dim{dim}_idx{num_indices}" + + @classmethod + def get_test_configs(cls) -> List["ScatterAddTest"]: + return [ + # 2D, scatter along dim 1 + cls(shape=(4, 8), dim=1, num_indices=3), + # 2D, scatter along dim 0 + cls(shape=(4, 8), dim=0, num_indices=3), + # 3D, scatter along last dim + cls(shape=(2, 4, 8), dim=2, num_indices=4), + # 3D, scatter along dim 1 + cls(shape=(2, 4, 8), dim=1, num_indices=2), + ] + + def create_model(self) -> nn.Module: + return ScatterAddModel(dim=self.dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + src_shape = list(self.shape) + src_shape[self.dim] = self.num_indices + src = torch.randn(src_shape) + dim_size = self.shape[self.dim] + index = torch.randint(0, dim_size, src_shape, dtype=torch.long) + return (x, index, src) + + @register_test class QuantizedEmbeddingTest(OpTestCase): """Test case for TorchAO int4 quantized nn.Embedding.""" diff --git a/examples/models/qwen3_5_moe/README.md b/examples/models/qwen3_5_moe/README.md index 947f02acee6..83373a804f4 100644 --- a/examples/models/qwen3_5_moe/README.md +++ b/examples/models/qwen3_5_moe/README.md @@ -4,7 +4,7 @@ Self-contained ExecuTorch implementation of [Qwen3.5-35B-A3B](https://huggingface.co/Qwen/Qwen3.5-35B-A3B), a ~35B total / ~3B active parameter Mixture-of-Experts language model. Weights are loaded directly from the HuggingFace safetensors checkpoint. -CUDA backend only. See [model.md](model.md) for architecture and +Supports CUDA and MLX backends. See [model.md](model.md) for architecture and implementation details. ## Overview @@ -30,6 +30,16 @@ Export produces a `model.pte` and `aoti_cuda_blob.ptd` containing the compiled CUDA kernels and quantized weights. 
Int4 quantization is recommended — the model is too large to fit in VRAM at bf16. +```bash +python export.py \ + --model-id Qwen/Qwen3.5-35B-A3B \ + --output-dir ./qwen35_moe_exports \ + --qlinear 4w \ + --qembedding 8w +``` + +Or with a local directory: + ```bash python export.py \ --model-dir ~/models/Qwen3.5-35B-A3B \ @@ -42,7 +52,8 @@ python export.py \ | Flag | Default | Description | |------|---------|-------------| -| `--model-dir` | (required) | HuggingFace model directory with `config.json` + safetensors | +| `--model-id` | (none) | HuggingFace model ID (e.g. `Qwen/Qwen3.5-35B-A3B`). Downloads automatically. | +| `--model-dir` | (none) | Local HuggingFace model directory with `config.json` + safetensors | | `--output-dir` | `./qwen35_moe_exports` | Output directory | | `--max-seq-len` | `4096` | KV cache length | | `--qlinear` | (none) | Linear layer quantization: `4w`, `8w`, `8da4w`, `8da8w` | @@ -135,3 +146,73 @@ cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner \ memory. - **Missing `aoti_cuda_blob.ptd`**: This file is produced during export alongside the `.pte`. Both files are required for inference. + +## MLX Backend (Apple Silicon) + +The MLX backend enables running Qwen 3.5 MoE on Apple Silicon GPUs. +It replaces the Triton-dependent modules (FusedMoEExperts, GatedDeltaNet) +with MLX custom ops (`mlx::gather_qmm`, `mlx::gated_delta_rule`, `mlx::rope`). 
+ +### Export (MLX) + +```bash +python export.py \ + --model-id Qwen/Qwen3.5-35B-A3B \ + --backend mlx \ + --qlinear 4w \ + --qlinear-group-size 64 \ + --output-dir ./qwen35_moe_mlx +``` + +Or with a local directory: + +```bash +python export.py \ + --model-dir ~/models/Qwen3.5-35B-A3B \ + --backend mlx \ + --qlinear 4w \ + --qlinear-group-size 64 \ + --output-dir ./qwen35_moe_mlx +``` + +### MLX Options + +| Flag | Default | Description | +|------|---------|-------------| +| `--backend mlx` | `cuda` | Use MLX backend for Apple Silicon | +| `--model-id` | (none) | HuggingFace model ID (downloads automatically) | +| `--model-dir` | (none) | Local model directory | +| `--qlinear` | (none) | Linear layer quantization: `4w`, `8w` | +| `--qlinear-group-size` | `32` | Group size (64 recommended for MLX) | +| `--qembedding` | (none) | Embedding quantization: `8w` | +| `--tiny-test` | off | Build tiny model with random weights for CI testing | + +### Run (MLX) + +```bash +python -m executorch.examples.models.qwen3_5_moe.run \ + --pte ./qwen35_moe_mlx/model.pte \ + --tokenizer Qwen/Qwen3.5-35B-A3B \ + --prompt "What is the capital of France?" \ + --max-new-tokens 50 +``` + +### Tiny Model Test + +For CI or quick pipeline validation (no model download needed): + +```bash +# Export tiny model (~1 MB, random weights) +python export.py \ + --tiny-test \ + --backend mlx \ + --qlinear 4w \ + --qlinear-group-size 32 \ + --output-dir /tmp/qwen35_moe_mlx_tiny + +# Run inference (random tokens, no tokenizer needed) +python -m executorch.examples.models.qwen3_5_moe.run \ + --pte /tmp/qwen35_moe_mlx_tiny/model.pte \ + --prompt-len 4 \ + --max-new-tokens 5 +``` diff --git a/examples/models/qwen3_5_moe/export.py b/examples/models/qwen3_5_moe/export.py index 19a720a2e79..398df1bb086 100644 --- a/examples/models/qwen3_5_moe/export.py +++ b/examples/models/qwen3_5_moe/export.py @@ -1,10 +1,14 @@ """ -Export Qwen 3.5 MoE to ExecuTorch .pte format (CUDA only). 
def _prepare_and_quantize_mlx(model, config, args):
    """Prepare ``model`` in place for MLX export.

    Steps, in order: cast to bf16, replace meta-device buffers with zero-filled
    CPU buffers, apply the MLX source transformations, optionally quantize
    linear / embedding layers, and finally pack the switch linears.

    Args:
        model: Model to transform; mutated in place.
        config: Model config forwarded to the source transformations.
        args: Parsed CLI namespace; reads ``qlinear``, ``qlinear_group_size``,
            ``qembedding`` and (if present) ``qembedding_group_size``.
    """
    from executorch.backends.mlx.llm.switch import pack_all_switch_linears
    from executorch.examples.models.qwen3_5_moe.mlx_source_transformations import (
        mlx_source_transformations,
    )

    target_dtype = torch.bfloat16
    model.to(dtype=target_dtype)

    # Meta buffers carry no storage, so give each one real zeroed memory
    # before the source transformations run. The buffer list is snapshotted
    # because re-registering mutates the module while we walk it.
    for qualified_name, buffer in list(model.named_buffers()):
        if buffer.device.type != "meta":
            continue
        owner_path, _, leaf_name = qualified_name.rpartition(".")
        owner = model.get_submodule(owner_path) if owner_path else model
        replacement = torch.zeros(buffer.shape, dtype=buffer.dtype, device="cpu")
        owner.register_buffer(leaf_name, replacement)

    mlx_source_transformations(
        model,
        model_dtype=target_dtype,
        config=config,
        sort_experts=True,
        fuse_gate_up=False,
    )

    wants_quantization = bool(args.qlinear or args.qembedding)
    if wants_quantization:
        from executorch.extension.llm.export.quantize import quantize_model_

        quantize_model_(
            model,
            qlinear_config=args.qlinear,
            qlinear_group_size=args.qlinear_group_size,
            qembedding_config=args.qembedding,
            qembedding_group_size=getattr(args, "qembedding_group_size", None),
        )

    # Packing runs after quantization, and also when no quantization was
    # requested.
    pack_all_switch_linears(model)
+ For MLX: applies source transforms first, then quantizes via torchao, then packs. Returns (model, config) ready for export. """ - if args.prequantized: - return load_prequantized_model(args.prequantized, args.max_seq_len) + backend = getattr(args, "backend", "cuda") + + if not args.prequantized: + if getattr(args, "tiny_test", False): + # Build tiny model with random weights for CI testing. + # Exercises the same architectural features as the real model: + # - GQA in full attention (n_heads=4, n_kv_heads=2 → 2:1 ratio) + # - GDN key/value head ratio (k_heads=2, v_heads=4 → 1:2 ratio) + # - Partial RoPE (25% of head_dim) + # - Mixed attention (full_attention_interval=2 → alternating layers) + # - Top-k MoE routing (top_k=2 from 8 experts) + # - Shared expert with gating + # - Fused gate+up expert weights [E, 2*inter, D] + # - Depthwise conv1d with state (kernel_dim=4) + tiny_config = Qwen35MoEConfig( + vocab_size=256, + hidden_size=128, + num_hidden_layers=4, # 4 layers: 2 linear + 2 full attention + num_attention_heads=4, # GQA: 4 heads with 2 KV heads (2:1 ratio) + num_kv_heads=2, + head_dim=64, + partial_rotary_factor=0.25, + linear_num_key_heads=2, # GDN: 2 key heads, 4 value heads (1:2 ratio) + linear_num_value_heads=4, + linear_key_head_dim=64, + linear_value_head_dim=64, + linear_conv_kernel_dim=4, + num_experts=8, # 8 experts, top-2 routing + num_experts_per_tok=2, + moe_intermediate_size=128, + shared_expert_intermediate_size=128, + full_attention_interval=2, # alternating: linear, full, linear, full + rms_norm_eps=1e-6, + rope_theta=10_000.0, + max_seq_len=64, + ) + print("Building tiny model with random weights...") + torch.manual_seed(42) + model = Qwen35MoE(tiny_config) + model.to(dtype=torch.bfloat16) + for p in model.parameters(): + if p.device.type != "meta": + p.data.normal_(0, 0.02) + model.eval() + for p in model.parameters(): + p.requires_grad_(False) + config = tiny_config + print( + f"Tiny model: {config.num_hidden_layers} layers, " + 
f"{config.num_experts} experts top-{config.num_experts_per_tok}, " + f"layer_types={config.layer_types}" + ) + else: + print("Loading model...") + model, config = Qwen35MoE.from_hf_checkpoint( + args.model_dir, max_seq_len=args.max_seq_len + ) + model.eval() + print( + f"Model: {config.num_hidden_layers} layers, {config.hidden_size}d, " + f"{config.num_experts} experts top-{config.num_experts_per_tok}" + ) - print("Loading model...") - model, config = Qwen35MoE.from_hf_checkpoint( - args.model_dir, max_seq_len=args.max_seq_len - ) - model.eval() - print( - f"Model: {config.num_hidden_layers} layers, {config.hidden_size}d, " - f"{config.num_experts} experts top-{config.num_experts_per_tok}" - ) + if backend == "mlx": + if args.prequantized: + raise ValueError( + "MLX backend does not support custom prequantized weights. Use a prequantized torchao checkpoint instead." + ) + _prepare_and_quantize_mlx(model, config, args) + + elif backend == "cuda": + if args.prequantized: + return load_prequantized_model(args.prequantized, args.max_seq_len) + + # CUDA: quantize experts with packed INT4 for Triton kernel + if args.qlinear or args.qembedding: + _quantize(model, config, args) + else: + model.to(dtype=torch.bfloat16) - if args.qlinear or args.qembedding: - _quantize(model, config, args) else: - model.to(dtype=torch.bfloat16) + raise ValueError(f"Unsupported backend: {backend}") return model, config @@ -381,6 +492,96 @@ def _apply_turboquant(model, config): def export_and_lower(model, config, args): + """Export model to .pte via torch.export + backend-specific lowering.""" + backend = getattr(args, "backend", "cuda") + + if backend == "mlx": + _export_mlx(model, config, args) + else: + _export_cuda(model, config, args) + + +def _export_mlx(model, config, args): + """Export model to .pte via torch.export + MLX backend.""" + import gc + + from executorch.backends.mlx import MLXPartitioner + from executorch.backends.mlx.passes import get_default_passes + from executorch.exir 
import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, + ) + from executorch.exir.passes import MemoryPlanningPass + from torch.export import Dim, export + + example_tokens = torch.tensor([[0, 1]], dtype=torch.long) + example_input_pos = torch.tensor([0, 1], dtype=torch.long) + seq_dim = Dim("seq_len", min=1, max=config.max_seq_len - 1) + dynamic_shapes = ({1: seq_dim}, {0: seq_dim}) + + print("Exporting with torch.export...") + with torch.no_grad(): + exported = export( + model, + (example_tokens, example_input_pos), + dynamic_shapes=dynamic_shapes, + strict=True, + ) + print("Export successful!") + + del model + gc.collect() + + print("Lowering to ExecuTorch with MLX backend...") + metadata = { + "get_max_seq_len": config.max_seq_len, + "get_vocab_size": config.vocab_size, + "get_n_layers": config.num_hidden_layers, + "use_kv_cache": True, + "use_sdpa_with_kv_cache": False, + "enable_dynamic_shape": True, + } + et_prog = to_edge_transform_and_lower( + exported, + transform_passes=get_default_passes(), + partitioner=[MLXPartitioner()], + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + constant_methods=metadata, + ) + + del exported + gc.collect() + + et_program = et_prog.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), + ) + + del et_prog + gc.collect() + + os.makedirs(args.output_dir, exist_ok=True) + pte_path = os.path.join(args.output_dir, "model.pte") + print(f"Saving to {pte_path}...") + with open(pte_path, "wb") as f: + et_program.write_to_file(f) + size_mb = os.path.getsize(pte_path) / (1024 * 1024) + print(f"Saved {size_mb:.1f} MB") + + if et_program._tensor_data: + et_program.write_tensor_data_to_file(args.output_dir) + print(f"Saved tensor data to {args.output_dir}/") + + print("Done!") + + +def _export_cuda(model, config, args): """Export model to .pte via torch.export + 
CUDA backend. Exports two methods: @@ -509,17 +710,28 @@ def export_and_lower(model, config, args): def main(): parser = argparse.ArgumentParser( - description="Export Qwen3.5 MoE to ExecuTorch (CUDA)" + description="Export Qwen3.5 MoE to ExecuTorch (CUDA or MLX)" ) parser.add_argument( "--model-dir", default=None, help="HuggingFace model directory (not needed with --prequantized)", ) + parser.add_argument( + "--model-id", + default=None, + help="HuggingFace model-id", + ) parser.add_argument( "--output-dir", default="./qwen35_moe_exports", help="Output directory" ) parser.add_argument("--max-seq-len", type=int, default=4096, help="KV cache length") + parser.add_argument( + "--backend", + default="cuda", + choices=["cuda", "mlx"], + help="Backend for export: cuda (default) or mlx.", + ) parser.add_argument( "--qlinear", default=None, @@ -535,6 +747,12 @@ def main(): parser.add_argument( "--qembedding", default=None, choices=["8w"], help="Quantize embedding layers." ) + parser.add_argument( + "--qembedding-group-size", + type=int, + default=None, + help="Group size for embedding quantization.", + ) parser.add_argument( "--hqq", action="store_true", @@ -552,22 +770,47 @@ def main(): action="store_true", help="Enable TurboQuant TQ4 KV cache compression (3.8x cache savings).", ) + parser.add_argument( + "--tiny-test", + action="store_true", + default=False, + help="Build a tiny model with random weights for CI pipeline testing. " + "No checkpoint download needed. 
Tests all architectural features " + "(GQA, GDN head ratio, mixed attention, MoE routing) at small scale.", + ) args = parser.parse_args() - if not args.prequantized and not args.model_dir: - parser.error("--model-dir is required unless --prequantized is provided.") + if args.model_id: + if args.model_dir is not None: + raise ValueError("Cannot specify model_dir when model_id is provided.") + from huggingface_hub import snapshot_download + + args.model_dir = snapshot_download(repo_id=args.model_id) + + if not args.prequantized and not args.model_dir and not args.tiny_test: + parser.error( + "--model-dir is required unless --prequantized or --tiny-test is provided." + ) if args.hqq and not args.qlinear: parser.error("--hqq requires --qlinear") - # Register FLA Triton kernel - import executorch.backends.cuda.triton.kernels # noqa: F401 + if args.backend == "cuda": + # Register FLA Triton kernel (CUDA only) + import executorch.backends.cuda.triton.kernels # noqa: F401 + + if args.backend == "mlx": + if args.prequantized: + parser.error("--prequantized is not supported with --backend mlx") + if args.turboquant: + parser.error("--turboquant is not supported with --backend mlx") model, config = load_and_quantize(args) - _materialize_buffers(model, config) - if args.turboquant: - _apply_turboquant(model, config) + if args.backend == "cuda": + _materialize_buffers(model, config) + if args.turboquant: + _apply_turboquant(model, config) export_and_lower(model, config, args) diff --git a/examples/models/qwen3_5_moe/mlx_source_transformations.py b/examples/models/qwen3_5_moe/mlx_source_transformations.py new file mode 100644 index 00000000000..25605fb6342 --- /dev/null +++ b/examples/models/qwen3_5_moe/mlx_source_transformations.py @@ -0,0 +1,375 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +MLX source transformations for Qwen 3.5 MoE. + +Replaces Triton-dependent modules (FusedMoEExperts, GatedDeltaNet) with +pure-PyTorch + MLX custom op equivalents that can be exported and lowered +to the MLX delegate. +""" + +import logging +import types + +import torch +import torch.nn as nn + +from executorch.examples.models.qwen3_5_moe.model import ( + FullAttention, + FusedMoEExperts, + GatedDeltaNet, + GemmaRMSNorm, + KVCache, + SparseMoE, +) + +logger = logging.getLogger(__name__) + + +def _rms_norm_gated_forward(self, x, z): + """Export-friendly RMSNormGated: avoids explicit .float() / .type_as() casts. + + Uses F.rms_norm which maps to fast::rms_norm (handles precision internally) + and F.silu which also handles bf16 natively in MLX. + """ + return torch.nn.functional.rms_norm( + x, (x.shape[-1],), self.weight, self.eps + ) * torch.nn.functional.silu(z) + + +def _gemma_rms_norm_forward(self, x): + """Export-friendly GemmaRMSNorm: avoids explicit .float() / .type_as() casts. + + The original does x.float() → normalize → (1+weight).float() → type_as, + producing 2+ AsType nodes per norm. F.rms_norm handles precision internally. + The (1+weight) offset is precomputed by the swap code below. + """ + return torch.nn.functional.rms_norm(x, (x.shape[-1],), self._rms_weight, self.eps) + + +def _sparse_moe_forward(self, x): + """Export-friendly SparseMoE: removes .float() on expert_weights. + + The original passes expert_weights.float() to the experts, causing + bf16→f32 casts. GatherMm/GatherQmm handle bf16 weights natively. 
+ """ + B, T, C = x.size() + x_flat = x.view(-1, C) + + scores = self.gate(x_flat) + expert_weights, expert_indices = torch.topk(scores, self.top_k, dim=-1) + expert_weights = expert_weights.softmax(dim=-1) + + routed_out = self.experts( + x_flat, + expert_weights, + expert_indices, + self.top_k, + sort_experts=getattr(self, "_sort_experts", False), + ) + + shared_out = self.shared_expert(x_flat) + shared_gate = torch.sigmoid(self.shared_expert_gate(x_flat)) + return (routed_out + shared_gate * shared_out).view(B, T, C) + + +def _full_attention_forward(self, x, input_pos): + """Export-friendly FullAttention: uses mlx::rope custom op. + + Replaces the decomposed RotaryEmbedding (~14 ops: outer, cos, sin, slice, + multiply, subtract, cat, AsType) with 2 RopeNode ops that fuse to + fast::rope. Also removes unnecessary .to(dtype) casts. + """ + B, T, _ = x.size() + + qkv = self.qkv_proj(x) + q_and_gate = qkv[..., : self.q_dim].view(B, T, self.n_heads, self.head_dim * 2) + q = q_and_gate[..., : self.head_dim] + gate = q_and_gate[..., self.head_dim :] + + k = qkv[..., self.q_dim : self.q_dim + self.k_dim].view( + B, T, self.n_kv_heads, self.head_dim + ) + v = qkv[..., self.q_dim + self.k_dim :].view(B, T, self.n_kv_heads, self.head_dim) + + q = self.q_norm(q) + k = self.k_norm(k) + + # Transpose to BHSD before RoPE (mlx::rope expects B,H,T,D) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Use mlx::rope custom op — fuses to a single RopeNode per tensor, + # replacing ~14 decomposed ops (outer, cos, sin, slice, mul, cat, etc.) 
+ pos = input_pos[0].item() + q = torch.ops.mlx.rope(q, self._rope_dims, pos, False, self._rope_base, 1.0, None) + k = torch.ops.mlx.rope(k, self._rope_dims, pos, False, self._rope_base, 1.0, None) + + k, v = self.kv_cache.update(input_pos, k, v) + + if self.n_kv_groups > 1: + k = k.repeat_interleave(self.n_kv_groups, dim=1) + v = v.repeat_interleave(self.n_kv_groups, dim=1) + + attn_mask = self.mask[input_pos].unsqueeze(0).unsqueeze(0) + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + + y = y.transpose(1, 2).contiguous().view(B, T, -1) + + gate = gate.reshape(B, T, -1) + y = y * torch.sigmoid(gate) + + return self.o_proj(y) + + +def _exportable_gated_delta_net_forward(self, x, input_pos): + """Pure PyTorch replacement for GatedDeltaNet.forward(). + + Identical pre/post-processing to the original, but replaces + torch.ops.triton.chunk_gated_delta_rule with a pure PyTorch + recurrent implementation via mlx::gated_delta_rule custom op. + """ + + B, T, _ = x.size() + + # Reset state at position 0 (in-place to preserve buffer identity) + reset = (input_pos[0] == 0).to(self.conv_state.dtype) + keep = 1.0 - reset + self.conv_state.mul_(keep) + self.recurrent_state.mul_(keep) + + # Fused projection: split into qkv, z, b, a + proj = self.in_proj(x) + cd = self.conv_dim + vd = self.value_dim + nh = self.num_v_heads + mixed_qkv = proj[..., :cd] + z = proj[..., cd : cd + vd].reshape(B, T, self.num_v_heads, self.head_v_dim) + b = proj[..., cd + vd : cd + vd + nh] + a = proj[..., cd + vd + nh :] + + # Causal depthwise conv1d with state + qkv_t = mixed_qkv.transpose(1, 2) # [B, C, T] + conv_input = torch.cat([self.conv_state[:B], qkv_t], dim=-1) + conv_len = conv_input.shape[-1] + # Update conv_state in-place to preserve buffer identity + # (attribute reassignment would break mutation tracking) + self.conv_state[:B].copy_(conv_input[:, :, conv_len - self.conv_kernel_size :]) + + conv_out = torch.nn.functional.conv1d( + conv_input, 
def _swap_moe_experts(model, fuse_gate_up):
    """Replace every ``FusedMoEExperts`` submodule with a ``SwitchMLP``.

    The new module is created in the dtype of the fused ``w1_weight`` and the
    per-expert weights are copied over. ``w1_weight[e]`` holds gate and up
    stacked along dim 0 (rows ``[:intermediate_size]`` are gate, the rest are
    up); it is copied whole when ``fuse_gate_up`` is true and split otherwise.

    Args:
        model: Root module; matching submodules are replaced in place.
        fuse_gate_up: Forwarded to ``SwitchMLP``; selects the fused or split
            gate/up layout.

    Returns:
        Number of modules replaced.
    """
    from executorch.backends.mlx.llm.switch import SwitchMLP

    # Collect the targets first: replacing children while the named_modules()
    # generator is still walking the tree relies on traversal internals.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, FusedMoEExperts)
    ]

    for name, module in targets:
        switch_mlp = SwitchMLP(
            module.hidden_size,
            module.intermediate_size,
            module.num_experts,
            fuse_gate_up=fuse_gate_up,
        )
        switch_mlp.to(dtype=module.w1_weight.dtype)

        inter = module.intermediate_size
        with torch.no_grad():
            for e in range(module.num_experts):
                if fuse_gate_up:
                    switch_mlp.gate_up_proj.experts[e].weight.copy_(
                        module.w1_weight[e]
                    )
                else:
                    switch_mlp.gate_proj.experts[e].weight.copy_(
                        module.w1_weight[e, :inter, :]
                    )
                    switch_mlp.up_proj.experts[e].weight.copy_(
                        module.w1_weight[e, inter:, :]
                    )
                # The down projection is laid out the same way in both modes.
                switch_mlp.down_proj.experts[e].weight.copy_(module.w2_weight[e])

        # Re-attach under the same attribute name on the owning module.
        parent_path, _, attr_name = name.rpartition(".")
        parent = model.get_submodule(parent_path) if parent_path else model
        setattr(parent, attr_name, switch_mlp)

    return len(targets)
def _swap_rms_norm(model):
    """Patch every ``GemmaRMSNorm`` to use the cast-free ``F.rms_norm`` forward.

    For each matching module the ``(1 + weight)`` scale is precomputed once
    and stored as ``_rms_weight``, then ``forward`` is rebound to
    ``_gemma_rms_norm_forward``, which reads that attribute.

    Args:
        model: Root module; matching submodules are patched in place.

    Returns:
        Number of modules patched.
    """
    count = 0
    for _name, module in model.named_modules():
        if not isinstance(module, GemmaRMSNorm):
            continue
        # nn.Parameter defaults to requires_grad=True. This transform only
        # runs on the export path, so register the derived scale as frozen
        # rather than introducing a trainable tensor.
        module._rms_weight = nn.Parameter(
            1.0 + module.weight.data, requires_grad=False
        )
        module.forward = types.MethodType(_gemma_rms_norm_forward, module)
        count += 1
    return count
+ model_dtype: Target dtype for the model (default: bf16). + config: Model config (Qwen35MoEConfig). + sort_experts: Sort tokens by expert index for coalesced memory access. + fuse_gate_up: Fuse gate+up into single SwitchLinear. + """ + count_moe = _swap_moe_experts(model, fuse_gate_up) + count_gdn = _swap_gated_delta_net(model, model_dtype) + count_attn = _swap_full_attention(model, config) + count_kv = _swap_kv_cache(model, model_dtype) + count_norm = _swap_rms_norm(model) + count_moe_fwd = _swap_sparse_moe(model, sort_experts) + + logger.info(f"Replaced {count_moe} FusedMoEExperts → SwitchMLP") + logger.info(f"Replaced {count_gdn} GatedDeltaNet → exportable PyTorch forward") + logger.info(f"Replaced {count_attn} FullAttention → mlx::rope + causal mask") + logger.info(f"Replaced {count_kv} KVCache → MLX KVCache (mlx::kv_cache_update)") + logger.info(f"Replaced {count_norm} GemmaRMSNorm → F.rms_norm (no .float() casts)") + logger.info(f"Replaced {count_moe_fwd} SparseMoE → no .float() on expert_weights") diff --git a/examples/models/qwen3_5_moe/run.py b/examples/models/qwen3_5_moe/run.py new file mode 100644 index 00000000000..94350afe6db --- /dev/null +++ b/examples/models/qwen3_5_moe/run.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Run exported Qwen 3.5 MoE model using ExecuTorch pybindings. + +Companion to export.py --backend mlx. Supports both real model inference +(with HuggingFace tokenizer) and fake-weight validation (random tokens). + +Usage: + # Run with real tokenizer: + python -m executorch.examples.models.qwen3_5_moe.run \ + --pte qwen35_moe_mlx.pte \ + --tokenizer Qwen/Qwen3.5-35B-A3B \ + --prompt "Hello, world!" 
def _select_next_token(next_token_logits, temperature):
    """Return a token id from 1-D logits: sample if temperature > 0, else argmax."""
    if temperature > 0:
        probs = torch.softmax(next_token_logits / temperature, dim=-1)
        return torch.multinomial(probs, 1).item()
    return torch.argmax(next_token_logits).item()


def run_inference(
    pte_path: str,
    tokenizer_id: str = None,
    prompt: str = None,
    prompt_len: int = 4,
    max_new_tokens: int = 10,
    vocab_size: int = 248320,
    temperature: float = 0.0,
) -> None:
    """Run inference on the exported Qwen 3.5 MoE model.

    Loads the program, runs one warmup step, prefills the prompt, then decodes
    token by token and prints timing plus the generated text or token ids.

    Args:
        pte_path: Path to the exported ``.pte`` program.
        tokenizer_id: HuggingFace tokenizer ID. Only used together with
            ``prompt``; otherwise a seeded random prompt is generated.
        prompt: Text prompt, wrapped in the tokenizer's chat template.
        prompt_len: Number of random prompt tokens when no text prompt is
            used; overwritten by the tokenized length otherwise.
        max_new_tokens: Upper bound on generated tokens. The prefill step
            always yields one token, so at least one token is produced.
        vocab_size: Fallback vocabulary size for random prompts when the
            program has no ``get_vocab_size`` method.
        temperature: Sampling temperature; ``0.0`` selects greedy argmax.
    """
    logger.info(f"Loading model from {pte_path}...")
    et_runtime = Runtime.get()
    program = et_runtime.load_program(pte_path, verification=Verification.Minimal)
    forward = program.load_method("forward")
    logger.info("Model loaded successfully")

    # Read vocab size from model metadata if available. This is best-effort:
    # any failure falls back to the caller-supplied default.
    try:
        meta_method = program.load_method("get_vocab_size")
        result = meta_method.execute([])
        model_vocab_size = result[0] if isinstance(result[0], int) else result[0].item()
        logger.info(f"Vocab size from model metadata: {model_vocab_size}")
        vocab_size = model_vocab_size
    except Exception:
        logger.info(f"No vocab size in metadata, using default: {vocab_size}")

    # Tokenize or generate random tokens
    tokenizer = None
    stop_token_ids = set()
    if tokenizer_id and prompt:
        from transformers import AutoTokenizer

        # NOTE(security): trust_remote_code=True lets the tokenizer repo run
        # its own Python code; only pass tokenizer IDs you trust.
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True)

        # Apply chat template so the model sees proper conversation boundaries
        # and knows when to stop generating (at <|im_end|>)
        messages = [{"role": "user", "content": prompt}]
        templated = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        input_ids = tokenizer.encode(templated, return_tensors="pt").to(torch.long)
        prompt_len = input_ids.shape[1]
        logger.info(f"Prompt: {prompt!r} ({prompt_len} tokens)")

        # Collect stop token ids (EOS + any end-of-turn markers)
        if tokenizer.eos_token_id is not None:
            stop_token_ids.add(tokenizer.eos_token_id)
        # <|im_end|> is the stop token for Qwen chat models
        im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
        if isinstance(im_end_id, int) and im_end_id != tokenizer.unk_token_id:
            stop_token_ids.add(im_end_id)
    else:
        # Fixed seed keeps the random prompt reproducible across runs.
        torch.manual_seed(42)
        input_ids = torch.randint(0, vocab_size, (1, prompt_len), dtype=torch.long)
        logger.info(f"Random prompt ({prompt_len} tokens): {input_ids[0].tolist()}")

    # --- Warmup run (JIT compile Metal kernels, warm GPU caches) ---
    logger.info("Running warmup...")
    warmup_tokens = torch.zeros((1, 1), dtype=torch.long)
    warmup_pos = torch.tensor([0], dtype=torch.long)
    forward.execute([warmup_tokens, warmup_pos])

    # --- Prefill ---
    logger.info(f"Running prefill ({prompt_len} tokens)...")
    start_time = time.time()

    input_pos = torch.arange(prompt_len, dtype=torch.long)
    outputs = forward.execute([input_ids, input_pos])
    logits = outputs[0]

    prefill_time = time.time() - start_time
    logger.info(
        f"Prefill: {prefill_time:.3f}s " f"({prompt_len / prefill_time:.1f} tokens/sec)"
    )

    # First generated token
    next_token = _select_next_token(logits[0, -1, :], temperature)
    generated_tokens = [next_token]

    # --- Decode ---
    logger.info(f"Generating up to {max_new_tokens} tokens...")
    decode_start = time.time()
    t_execute = 0
    t_prep = 0
    t_post = 0

    for _i in range(max_new_tokens - 1):
        # Stop on EOS / end-of-turn token. Checking before each step also
        # covers the first token, which comes from prefill and was previously
        # never tested against the stop set.
        if next_token in stop_token_ids:
            break

        t0 = time.time()
        pos = prompt_len + len(generated_tokens) - 1
        input_pos = torch.tensor([pos], dtype=torch.long)
        token_input = torch.tensor([[next_token]], dtype=torch.long)
        t1 = time.time()
        t_prep += t1 - t0

        outputs = forward.execute([token_input, input_pos])
        logits = outputs[0]
        t2 = time.time()
        t_execute += t2 - t1

        next_token = _select_next_token(logits[0, -1, :], temperature)
        generated_tokens.append(next_token)
        t3 = time.time()
        t_post += t3 - t2

    decode_time = time.time() - decode_start
    num_generated = len(generated_tokens)
    tokens_per_sec = num_generated / decode_time if decode_time > 0 else 0

    # Print decode timing breakdown
    n_decode = num_generated - 1  # exclude first token (from prefill)
    if n_decode > 0:
        print(f"\nDecode timing breakdown ({n_decode} steps):")
        print(
            f"  Prep (tensor creation):   {t_prep*1000:.1f}ms total, {t_prep/n_decode*1000:.2f}ms/step"
        )
        print(
            f"  Execute (forward.execute): {t_execute*1000:.1f}ms total, {t_execute/n_decode*1000:.2f}ms/step"
        )
        print(
            f"  Post (argmax/sample):      {t_post*1000:.1f}ms total, {t_post/n_decode*1000:.2f}ms/step"
        )

    # Print results
    print(f"\nPrefill: {prefill_time:.3f}s ({prompt_len / prefill_time:.1f} tok/s)")
    print(
        f"Decode:  {decode_time:.3f}s "
        f"({num_generated} tokens, {tokens_per_sec:.1f} tok/s)"
    )

    if tokenizer:
        generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        print(f"\nPrompt: {prompt}")
        print(f"Generated: {generated_text}")
    else:
        print(f"\nGenerated token ids: {generated_tokens}")
parser.add_argument( + "--tokenizer", + type=str, + default=None, + help="HuggingFace tokenizer ID (e.g. Qwen/Qwen3.5-35B-A3B)", + ) + parser.add_argument( + "--prompt", + type=str, + default=None, + help="Text prompt for generation (requires --tokenizer)", + ) + parser.add_argument( + "--prompt-len", + type=int, + default=4, + help="Number of random tokens for the prompt (when no --prompt given)", + ) + parser.add_argument( + "--max-new-tokens", + type=int, + default=10, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--vocab-size", + type=int, + default=248320, + help="Vocab size for random token generation", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.0, + help="Sampling temperature (0.0 = greedy)", + ) + + args = parser.parse_args() + + if args.prompt and not args.tokenizer: + parser.error("--prompt requires --tokenizer") + + run_inference( + pte_path=args.pte, + tokenizer_id=args.tokenizer, + prompt=args.prompt, + prompt_len=args.prompt_len, + max_new_tokens=args.max_new_tokens, + vocab_size=args.vocab_size, + temperature=args.temperature, + ) + + +if __name__ == "__main__": + main()