From af6810bfa4d8c2fc3ceb84e33ebff31bc59cbd7d Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 2 Mar 2026 17:46:59 -0800 Subject: [PATCH 01/34] up --- .github/workflows/mlx.yml | 99 ++ .gitignore | 2 + .gitmodules | 4 + CMakeLists.txt | 15 + CMakePresets.json | 85 +- backends/mlx/CMakeLists.txt | 330 ++++ backends/mlx/README.md | 499 ++++++ backends/mlx/__init__.py | 17 + backends/mlx/_logging.py | 40 + backends/mlx/builder/__init__.py | 16 + backends/mlx/builder/op_helpers.py | 275 ++++ backends/mlx/builder/op_registry.py | 151 ++ backends/mlx/builder/pattern_matcher.py | 64 + backends/mlx/builder/program_builder.py | 1018 ++++++++++++ backends/mlx/builder/slot_manager.py | 187 +++ backends/mlx/custom_ops.py | 15 + backends/mlx/ops.py | 294 ++++ backends/mlx/partitioner.py | 298 ++++ backends/mlx/passes.py | 20 + backends/mlx/patches/mlx_json.patch | 29 + backends/mlx/pattern_utils.py | 360 +++++ backends/mlx/patterns.py | 14 + backends/mlx/preprocess.py | 168 ++ backends/mlx/pte_inspector.py | 897 ++++++++++ backends/mlx/runtime/MLXBackend.cpp | 419 +++++ backends/mlx/runtime/MLXExecutor.h | 878 ++++++++++ backends/mlx/runtime/MLXInterpreter.h | 169 ++ backends/mlx/serialization/MLXLoader.cpp.tmpl | 324 ++++ backends/mlx/serialization/MLXLoader.h.tmpl | 343 ++++ backends/mlx/serialization/README.md | 130 ++ backends/mlx/serialization/__init__.py | 32 + backends/mlx/serialization/generate.py | 1437 +++++++++++++++++ .../mlx/serialization/mlx_graph_serialize.py | 416 +++++ backends/mlx/serialization/schema.fbs | 192 +++ backends/mlx/test/CMakeLists.txt | 51 + backends/mlx/test/README.md | 164 ++ backends/mlx/test/__init__.py | 5 + backends/mlx/test/op_test_runner.cpp | 395 +++++ backends/mlx/test/run_all_tests.py | 496 ++++++ backends/mlx/test/strict_compile_test.cpp | 45 + backends/mlx/test/test_ops.py | 176 ++ backends/mlx/test/test_partitioner.py | 45 + backends/mlx/test/test_passes.py | 6 + backends/mlx/test/test_pattern_utils.py | 592 +++++++ backends/mlx/test/test_utils.py | 1122 +++++++++++++ backends/mlx/test/tester.py | 78 + backends/mlx/third-party/mlx | 1 + backends/test/suite/flow.py | 11 +- backends/test/suite/flows/mlx.py | 14 + exir/_serialize/_program.py | 67 + setup.py | 33 + tools/cmake/Utils.cmake | 33 + tools/cmake/executorch-config.cmake | 45 + tools/cmake/preset/default.cmake | 1 + tools/cmake/preset/pybind.cmake | 18 + 55 files changed, 12633 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/mlx.yml create mode 100644 backends/mlx/CMakeLists.txt create mode 100644 backends/mlx/README.md create mode 100644 backends/mlx/__init__.py create mode 100644 backends/mlx/_logging.py create mode 100644 backends/mlx/builder/__init__.py create mode 100644 backends/mlx/builder/op_helpers.py create mode 100644 backends/mlx/builder/op_registry.py create mode 100644 backends/mlx/builder/pattern_matcher.py create mode 100644 backends/mlx/builder/program_builder.py create mode 100644 backends/mlx/builder/slot_manager.py create mode 100644 backends/mlx/custom_ops.py create mode 100644 backends/mlx/ops.py create mode 100644 backends/mlx/partitioner.py create mode 100644 backends/mlx/passes.py create mode 100644 backends/mlx/patches/mlx_json.patch create mode 100644 backends/mlx/pattern_utils.py create mode 100644 backends/mlx/patterns.py create mode 100644 backends/mlx/preprocess.py create mode 100644 backends/mlx/pte_inspector.py create mode 100644 backends/mlx/runtime/MLXBackend.cpp create mode 100644 backends/mlx/runtime/MLXExecutor.h create mode 100644 backends/mlx/runtime/MLXInterpreter.h create mode 100644 backends/mlx/serialization/MLXLoader.cpp.tmpl create mode 100644 backends/mlx/serialization/MLXLoader.h.tmpl create mode 100644 backends/mlx/serialization/README.md create mode 100644 backends/mlx/serialization/__init__.py create mode 100755 backends/mlx/serialization/generate.py create mode 100644 backends/mlx/serialization/mlx_graph_serialize.py create mode 100644 backends/mlx/serialization/schema.fbs create mode 100644 backends/mlx/test/CMakeLists.txt create mode 100644 backends/mlx/test/README.md create mode 100644 backends/mlx/test/__init__.py create mode 100644 backends/mlx/test/op_test_runner.cpp create mode 100644 backends/mlx/test/run_all_tests.py create mode 100644 backends/mlx/test/strict_compile_test.cpp create mode 100644 backends/mlx/test/test_ops.py create mode 100644 backends/mlx/test/test_partitioner.py create mode 100644 backends/mlx/test/test_passes.py create mode 100644 backends/mlx/test/test_pattern_utils.py create mode 100644 backends/mlx/test/test_utils.py create mode 100644 backends/mlx/test/tester.py create mode 160000 backends/mlx/third-party/mlx create mode 100644 backends/test/suite/flows/mlx.py diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml new file mode 100644 index 00000000000..2e8ca7aa3b7 --- /dev/null +++ b/.github/workflows/mlx.yml @@ -0,0 +1,99 @@ +name: MLX + +on: + push: + branches: + - main + - release/* + pull_request: + paths: + - .github/workflows/mlx.yml + - backends/mlx/** + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +jobs: + test-mlx: + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + with: + job-name: test-mlx + runner: macos-14-xlarge + python-version: "3.12" + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 90 + script: | + set -eux + + echo "::group::Install ExecuTorch and configure build" + ${CONDA_RUN} python install_executorch.py > /dev/null + # The sanitizers fail on github VM runner, but pass on real device + # TODO: figure out why + ${CONDA_RUN} cmake --preset mlx-release -DEXECUTORCH_BUILD_TESTS=ON -DEXECUTORCH_MLX_ENABLE_SANITIZERS=OFF + echo "::endgroup::" + + ${CONDA_RUN} pip list + + echo "::group::Build test runners" + ${CONDA_RUN} cmake --build cmake-out --target op_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 )) + echo "::endgroup::" + + echo "::group::Run op unit tests" + ${CONDA_RUN} python -m executorch.backends.mlx.test.run_all_tests -j4 --max-tasks-per-worker 10 --clean-after + echo "::endgroup::" + + echo "::group::Run Python unit tests" + ${CONDA_RUN} python -m pytest \ + backends/mlx/test/test_passes.py \ + backends/mlx/test/test_pattern_utils.py \ + backends/mlx/test/test_partitioner.py \ + -v + echo "::endgroup::" + + backend-tester: + strategy: + fail-fast: false + matrix: + suite: [models, operators] + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + with: + job-name: test-mlx-backend-${{ matrix.suite }} + runner: macos-14-xlarge + python-version: "3.12" + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 120 + script: | + set -eux + + echo "::group::Install ExecuTorch" + ${CONDA_RUN} python install_executorch.py > /dev/null + echo "::endgroup::" + + ${CONDA_RUN} pip list + + echo "::group::Run backend test suite (${{ matrix.suite }})" + ${CONDA_RUN} pytest -c /dev/null backends/test/suite/${{ matrix.suite }}/ -m flow_mlx -n auto 2>&1 | tee pytest_output.txt || true + echo "::endgroup::" + + # Parse pytest summary and check failure threshold + if grep -E "^=+ .* =+$" pytest_output.txt | tail -1 | grep -q "failed"; then + FAILED=$(grep -E "^=+ .* =+$" pytest_output.txt | tail -1 | grep -oE "[0-9]+ failed" | grep -oE "[0-9]+") + else + FAILED=0 + fi + + if [ "${{ matrix.suite }}" = "operators" ]; then + MAX_FAILURES=0 + else + MAX_FAILURES=3 + fi + + echo "Failed tests: $FAILED (max allowed: $MAX_FAILURES)" + if [ "$FAILED" -gt "$MAX_FAILURES" ]; then + echo "::error::Too many test failures: $FAILED > $MAX_FAILURES" + exit 1 + fi diff --git a/.gitignore b/.gitignore index 4ddbb7c49ad..3453b7e9676 100644 --- a/.gitignore +++ b/.gitignore @@ -74,5 +74,7 @@ xcuserdata/ *.dll *.pyd + # Agents .claude/*.local.* +extension/pybindings/mlx.metallib diff --git a/.gitmodules b/.gitmodules index 1f202d4fdec..917e755da27 100644 --- a/.gitmodules +++ b/.gitmodules @@ -67,3 +67,7 @@ [submodule "third-party/json"] path = third-party/json url = https://github.com/nlohmann/json.git +[submodule "backends/mlx/third-party/mlx"] + path = backends/mlx/third-party/mlx + url = https://github.com/ml-explore/mlx.git + shallow = true diff --git a/CMakeLists.txt b/CMakeLists.txt index 995a75c342b..2297a8142f3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -659,6 +659,11 @@ if(EXECUTORCH_BUILD_MPS) list(APPEND _executorch_backends mpsdelegate) endif() +if(EXECUTORCH_BUILD_MLX) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/mlx) + list(APPEND _executorch_backends mlxdelegate) +endif() + if(EXECUTORCH_BUILD_NEURON) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/mediatek) list(APPEND _executorch_backends neuron_backend) @@ -956,6 +961,10 @@ if(EXECUTORCH_BUILD_PYBIND) list(APPEND _dep_libs mpsdelegate) endif() + if(EXECUTORCH_BUILD_MLX) + list(APPEND _dep_libs mlxdelegate) + endif() + if(EXECUTORCH_BUILD_OPENVINO) list(APPEND _dep_libs openvino_backend) endif() @@ -1056,6 +1065,12 @@ if(EXECUTORCH_BUILD_PYBIND) install(TARGETS data_loader LIBRARY DESTINATION executorch/extension/pybindings ) + + # Copy MLX metallib next to _portable_lib.so for editable installs. MLX uses + # dladdr() to find the directory containing the library with MLX code, then + # looks for mlx.metallib in that directory. When MLX is statically linked into + # _portable_lib.so, we need the metallib colocated with it. + executorch_target_copy_mlx_metallib(portable_lib) endif() if(EXECUTORCH_BUILD_WASM) diff --git a/CMakePresets.json b/CMakePresets.json index ca4da226ba1..fa1d77623d9 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -110,7 +110,7 @@ "inherits": ["common"], "cacheVariables": { "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/pybind.cmake", - "CMAKE_OSX_DEPLOYMENT_TARGET": "12.0" + "CMAKE_OSX_DEPLOYMENT_TARGET": "14.0" }, "condition": { "type": "inList", @@ -294,6 +294,43 @@ "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/arm_ethosu_linux.cmake", "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/examples/arm/ethos-u-setup/aarch64-linux-musl-toolchain.cmake" } + }, + { + "name": "mlx", + "displayName": "Build MLX delegate", + "inherits": ["common"], + "cacheVariables": { + "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/mlx.cmake", + "EXECUTORCH_ENABLE_LOGGING": "ON", + "CMAKE_OSX_DEPLOYMENT_TARGET": "14.0" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Darwin" + } + }, + { + "name": "mlx-release", + "displayName": "MLX delegate release build", + "inherits": ["mlx"], + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_INSTALL_PREFIX": "${sourceDir}/cmake-out", + "ET_MLX_ENABLE_OP_LOGGING": "OFF", + "ET_MIN_LOG_LEVEL": "Error" + } + }, + { + "name": "mlx-debug", + "displayName": "MLX delegate debug build with op logging", + "inherits": ["mlx"], + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_INSTALL_PREFIX": "${sourceDir}/cmake-out", + "ET_MLX_ENABLE_OP_LOGGING": "ON", + "ET_MIN_LOG_LEVEL": "Debug" + } } ], "buildPresets": [ @@ -362,6 +399,24 @@ "install" ], "jobs": 0 + }, + { + "name": "mlx-release-install", + "displayName": "Build and install MLX delegate release artifacts", + "configurePreset": "mlx-release", + "targets": [ + "install" + ], + "jobs": 0 + }, + { + "name": "mlx-debug-install", + "displayName": "Build and install MLX delegate debug artifacts", + "configurePreset": "mlx-debug", + "targets": [ + "install" + ], + "jobs": 0 } ], "workflowPresets": [ @@ -462,6 +517,34 @@ "name": "llm-metal-stats-install" } ] + }, + { + "name": "mlx-release", + "displayName": "Configure, build and install ExecuTorch MLX delegate", + "steps": [ + { + "type": "configure", + "name": "mlx-release" + }, + { + "type": "build", + "name": "mlx-release-install" + } + ] + }, + { + "name": "mlx-debug", + "displayName": "Configure, build and install ExecuTorch MLX delegate with op logging (Debug)", + "steps": [ + { + "type": "configure", + "name": "mlx-debug" + }, + { + "type": "build", + "name": "mlx-debug-install" + } + ] } ] } diff --git a/backends/mlx/CMakeLists.txt b/backends/mlx/CMakeLists.txt new file mode 100644 index 00000000000..00e7c497b1c --- /dev/null +++ b/backends/mlx/CMakeLists.txt @@ -0,0 +1,330 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +cmake_minimum_required(VERSION 3.19) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +if(NOT CMAKE_CXX_STANDARD) + set(CMAKE_CXX_STANDARD 17) +endif() + +# Source root directory for executorch. +if(NOT EXECUTORCH_ROOT) + set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) +endif() + +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) + +set(_common_compile_options -Wall -Werror -Wno-deprecated-declarations) + +# Sanitizer flags (asan + ubsan) for security hardening — CI only. Enable via: +# cmake --preset mlx-release -DEXECUTORCH_MLX_ENABLE_SANITIZERS=ON +option(EXECUTORCH_MLX_ENABLE_SANITIZERS + "Enable ASan + UBSan for MLX delegate and tests" OFF +) +if(EXECUTORCH_MLX_ENABLE_SANITIZERS) + list(APPEND _common_compile_options -fsanitize=address,undefined + -fno-omit-frame-pointer + ) + set(_mlx_sanitizer_link_options -fsanitize=address,undefined) +endif() + +# ----------------------------------------------------------------------------- +# Code generation from schema.fbs +# ----------------------------------------------------------------------------- +# +# The generate.py script generates all code from schema.fbs: Python: +# mlx_graph_schema.py, _generated_serializers.py, _generated/ C++: MLXLoader.h, +# MLXLoader.cpp, schema_generated.h +# +# We run generate.py at build time so these files don't need to be checked in. +# ----------------------------------------------------------------------------- + +set(_mlx_generate_script + "${CMAKE_CURRENT_SOURCE_DIR}/serialization/generate.py" +) +set(_mlx_schema_fbs "${CMAKE_CURRENT_SOURCE_DIR}/serialization/schema.fbs") + +# Generated C++ files that we need for compilation +set(_mlx_generated_cpp_files + "${CMAKE_CURRENT_SOURCE_DIR}/runtime/schema_generated.h" + "${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.h" + "${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp" +) + +# Generated Python files (tracked for dependency purposes) +set(_mlx_generated_python_files + "${CMAKE_CURRENT_SOURCE_DIR}/serialization/mlx_graph_schema.py" + "${CMAKE_CURRENT_SOURCE_DIR}/serialization/_generated_serializers.py" +) + +# Run generate.py to create all generated files from schema.fbs Find Python - +# prefer Python3_EXECUTABLE if set (from FindPython3), otherwise use +# PYTHON_EXECUTABLE +if(Python3_EXECUTABLE) + set(_python_executable ${Python3_EXECUTABLE}) +elseif(PYTHON_EXECUTABLE) + set(_python_executable ${PYTHON_EXECUTABLE}) +else() + find_package( + Python3 + COMPONENTS Interpreter + REQUIRED + ) + set(_python_executable ${Python3_EXECUTABLE}) +endif() + +add_custom_command( + OUTPUT ${_mlx_generated_cpp_files} ${_mlx_generated_python_files} + COMMAND ${_python_executable} ${_mlx_generate_script} --flatc + $ + WORKING_DIRECTORY ${EXECUTORCH_ROOT} + DEPENDS ${_mlx_schema_fbs} ${_mlx_generate_script} flatc + COMMENT "Generating MLX delegate code from schema.fbs" + VERBATIM +) + +# Custom target to trigger generation +add_custom_target( + mlx_generate_code DEPENDS ${_mlx_generated_cpp_files} + ${_mlx_generated_python_files} +) + +# Interface library for schema includes +add_library(mlx_schema INTERFACE) +add_dependencies(mlx_schema mlx_generate_code) +target_include_directories( + mlx_schema + INTERFACE + $ + $ +) + +# ----------------------------------------------------------------------------- +# MLX dependency (from submodule) +# ----------------------------------------------------------------------------- + +set(MLX_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third-party/mlx) + +# Check that submodule is initialized +if(NOT EXISTS "${MLX_SOURCE_DIR}/CMakeLists.txt") + message( + FATAL_ERROR "MLX submodule not initialized.\n" + "Run: git submodule update --init backends/mlx/third-party/mlx" + ) +endif() + +# Validate deployment target - MLX requires macOS 14.0+ / iOS 17.0+ +# +# The macOS preset uses ios.toolchain.cmake (with PLATFORM=MAC_ARM64), so +# DEPLOYMENT_TARGET is set for both macOS and iOS builds. We check PLATFORM to +# distinguish them rather than relying on which variable is set. +set(_mlx_deployment_target_ok TRUE) +if(PLATFORM AND PLATFORM MATCHES "^MAC") + # macOS build via ios.toolchain.cmake (e.g., MAC_ARM64, MAC_UNIVERSAL) + if(DEPLOYMENT_TARGET) + set(_mlx_dt_value ${DEPLOYMENT_TARGET}) + elseif(CMAKE_OSX_DEPLOYMENT_TARGET) + set(_mlx_dt_value ${CMAKE_OSX_DEPLOYMENT_TARGET}) + endif() + if(_mlx_dt_value AND _mlx_dt_value VERSION_LESS "14.0") + set(_mlx_deployment_target_ok FALSE) + set(_mlx_deployment_target_value ${_mlx_dt_value}) + set(_mlx_deployment_target_min "14.0") + endif() +elseif(DEPLOYMENT_TARGET) + # iOS/tvOS/watchOS/visionOS builds via ios.toolchain.cmake + if(DEPLOYMENT_TARGET VERSION_LESS "17.0") + set(_mlx_deployment_target_ok FALSE) + set(_mlx_deployment_target_value ${DEPLOYMENT_TARGET}) + set(_mlx_deployment_target_min "17.0") + endif() +elseif(CMAKE_OSX_DEPLOYMENT_TARGET) + # Plain macOS build (no ios.toolchain.cmake) + if(CMAKE_OSX_DEPLOYMENT_TARGET VERSION_LESS "14.0") + set(_mlx_deployment_target_ok FALSE) + set(_mlx_deployment_target_value ${CMAKE_OSX_DEPLOYMENT_TARGET}) + set(_mlx_deployment_target_min "14.0") + endif() +endif() + +if(NOT _mlx_deployment_target_ok) + message( + FATAL_ERROR + "MLX requires deployment target >= ${_mlx_deployment_target_min}, got ${_mlx_deployment_target_value}.\n" + "Either increase the deployment target or disable MLX with -DEXECUTORCH_BUILD_MLX=OFF" + ) +endif() + +# MLX build options - we only need the C++ library with Metal +set(MLX_BUILD_PYTHON_BINDINGS + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_TESTS + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_EXAMPLES + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_BENCHMARKS + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_PYTHON_STUBS + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_CUDA + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_CPU + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_METAL + ON + CACHE BOOL "" FORCE +) +set(MLX_BUILD_SHARED_LIBS + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_GGUF + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_SAFETENSORS + OFF + CACHE BOOL "" FORCE +) +set(MLX_METAL_JIT + ON + CACHE BOOL "Use JIT compiled Metal kernels" +) + +# Auto-apply patches to MLX submodule. Each patch is applied idempotently: `git +# apply --check` tests whether the patch is still applicable (i.e. not yet +# applied), and only then applies it. +set(_mlx_patches "${CMAKE_CURRENT_SOURCE_DIR}/patches/mlx_json.patch") +foreach(_patch IN LISTS _mlx_patches) + if(EXISTS "${_patch}" AND EXISTS "${MLX_SOURCE_DIR}") + get_filename_component(_patch_name "${_patch}" NAME) + execute_process( + COMMAND git apply --check "${_patch}" + WORKING_DIRECTORY ${MLX_SOURCE_DIR} + RESULT_VARIABLE _patch_check + OUTPUT_QUIET ERROR_QUIET + ) + if(_patch_check EQUAL 0) + execute_process( + COMMAND git apply "${_patch}" WORKING_DIRECTORY ${MLX_SOURCE_DIR} + ) + message(STATUS "Applied ${_patch_name} to MLX submodule") + else() + message(STATUS "${_patch_name} already applied or not applicable") + endif() + endif() +endforeach() + +# Add MLX subdirectory +message(STATUS "Adding MLX from submodule: ${MLX_SOURCE_DIR}") +add_subdirectory(${MLX_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/mlx) + +# ----------------------------------------------------------------------------- +# MLX Backend library +# ----------------------------------------------------------------------------- + +# Op logging option (for debugging) - OFF by default for performance +option(ET_MLX_ENABLE_OP_LOGGING "Enable per-op logging in MLX delegate" OFF) + +set(_mlx_backend__srcs ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp +) + +add_library(mlxdelegate ${_mlx_backend__srcs}) + +# Ensure schema is generated before compiling +add_dependencies(mlxdelegate mlx_schema) + +# Add logging flag if enabled +if(ET_MLX_ENABLE_OP_LOGGING) + target_compile_definitions(mlxdelegate PRIVATE ET_MLX_ENABLE_OP_LOGGING=1) + message(STATUS "MLX delegate op logging ENABLED") +endif() + +target_include_directories( + mlxdelegate PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/runtime +) + +# Link against MLX and executorch mlx is only available at BUILD_INTERFACE - +# consumers must link to mlx separately +target_link_libraries( + mlxdelegate PRIVATE mlx_schema executorch_core $ +) + +executorch_target_link_options_shared_lib(mlxdelegate) +target_compile_options(mlxdelegate PRIVATE ${_common_compile_options}) +if(EXECUTORCH_MLX_ENABLE_SANITIZERS) + target_link_options(mlxdelegate PRIVATE ${_mlx_sanitizer_link_options}) +endif() + +install( + TARGETS mlxdelegate mlx_schema + EXPORT ExecuTorchTargets + DESTINATION ${CMAKE_INSTALL_LIBDIR} +) + +# Install mlx library for downstream consumers +install(TARGETS mlx DESTINATION ${CMAKE_INSTALL_LIBDIR}) + +# Install mlx headers for downstream consumers that may need mlx types +install( + DIRECTORY ${MLX_SOURCE_DIR}/mlx/ + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/mlx + FILES_MATCHING + PATTERN "*.h" +) + +# Install mlx.metallib (Metal GPU kernels) for runtime execution +# +# MLX searches for metallib in this order (see mlx/backend/metal/device.cpp): 1. +# {binary_dir}/mlx.metallib - colocated with the .so/.dylib 2. +# {binary_dir}/Resources/mlx/ - Resources subdirectory 3. SwiftPM bundle - +# not applicable for us 4. {binary_dir}/Resources/default/ - Resources +# subdirectory 5. METAL_PATH (compile-time) - hardcoded build path (won't +# exist) +# +# where {binary_dir} is determined at runtime via dladdr() on the library +# containing MLX code. When MLX is statically linked into _portable_lib.so, this +# is the directory containing _portable_lib.so. +# +# For the installed library, we put metallib in lib/ alongside libmlx.a +install( + FILES ${CMAKE_CURRENT_BINARY_DIR}/mlx/mlx/backend/metal/kernels/mlx.metallib + DESTINATION ${CMAKE_INSTALL_LIBDIR} +) + +# Cache the metallib path for pybindings to copy it next to _portable_lib.so +# This enables editable installs to work correctly +set(MLX_METALLIB_PATH + "${CMAKE_CURRENT_BINARY_DIR}/mlx/mlx/backend/metal/kernels/mlx.metallib" + CACHE INTERNAL "Path to mlx.metallib for pybindings" +) + +# ----------------------------------------------------------------------------- +# Tests (off by default; CI passes -DEXECUTORCH_BUILD_TESTS=ON) +# ----------------------------------------------------------------------------- + +if(EXECUTORCH_BUILD_TESTS) + add_subdirectory(test) +endif() diff --git a/backends/mlx/README.md b/backends/mlx/README.md new file mode 100644 index 00000000000..ebab893385a --- /dev/null +++ b/backends/mlx/README.md @@ -0,0 +1,499 @@ +# MLX Delegate for ExecuTorch + +The MLX delegate compiles PyTorch models to run on Apple Silicon GPUs via the +[MLX](https://github.com/ml-explore/mlx) framework. It consists of: + +- A **Python compilation pipeline** that converts ExportedPrograms (Edge IR) into + a custom FlatBuffer bytecode format. +- A **C++ runtime** that loads the bytecode and executes it using MLX GPU + primitives. + +> **Adding a new op?** Jump to [How to Add a New Op](#how-to-add-a-new-op). + +## Getting Started + +The MLX delegate requires **Apple Silicon** (M1 or later) and the **Metal +compiler**, which ships with Xcode (not the standalone Command Line Tools). + +**Check if Metal is available:** + +```bash +xcrun -sdk macosx --find metal +``` + +If this prints a path (e.g. `/Applications/Xcode.app/.../metal`), you're set. +If it errors, you either need to install Xcode from the +[App Store](https://apps.apple.com/us/app/xcode/id497799835) or +, or — if Xcode is already installed but the +command line developer directory points at Command Line Tools — switch it: + +```bash +sudo xcode-select -s /Applications/Xcode.app/Contents/Developer +``` + +### Python (pybindings) + +The simplest way to get started is to install ExecuTorch with Python bindings. +From the repo root: + +```bash +python install_executorch.py +``` + +This builds and installs the `executorch` pip package with pybindings. On Apple +Silicon, when the Metal compiler is available, the MLX backend is automatically +included. You can then export models in Python using the MLX partitioner and run +them via the ExecuTorch Python API. + +### C++ (CMake preset) + +To build the C++ runtime with the MLX delegate, use the `mlx-release` CMake +workflow preset from the repo root: + +```bash +cmake --workflow --preset mlx-release +``` + +This configures and builds a Release build of the ExecuTorch runtime with the +MLX delegate and installs artifacts into `cmake-out/`. The preset enables the +MLX delegate along with commonly needed extensions (module, data loader, flat +tensor, LLM runner, etc.). + +Downstream C++ apps can then `find_package(executorch)` and link against +`mlxdelegate` and `mlx`. See +[`examples/models/llama/CMakeLists.txt`](../../examples/models/llama/CMakeLists.txt) +for a working example. + +There is also an `mlx-debug` preset that enables debug symbols and compiles in +per-op logging support, which is useful during development: + +```bash +cmake --workflow --preset mlx-debug +``` + +The debug build compiles in the logging code, but to actually see per-op output +you must also set the environment variable when running the binary: + +```bash +ET_MLX_ENABLE_OP_LOGGING=1 ./cmake-out/my_app +``` + +### Debugging + +Set `ET_MLX_DEBUG=1` during AOT (export/compilation) to see detailed debug +logging from the partitioner and preprocessor — including ops-to-not-decompose +lists, graph dumps, per-node support decisions, and serialization details: + +```bash +ET_MLX_DEBUG=1 python -m executorch.backends.mlx.examples.llm.export_llm_hf ... +``` + +--- + +## Directory Layout + +``` +backends/mlx/ +├── serialization/ # Schema + code generation +│ ├── schema.fbs # ← Source of truth (FlatBuffer schema) +│ ├── generate.py # Code generator (schema.fbs → everything else) +│ ├── mlx_graph_schema.py # [GENERATED] Python dataclasses for IR nodes +│ ├── mlx_graph_serialize.py # Serialization to FlatBuffer binary +│ ├── _generated_serializers.py # [GENERATED] Per-op FlatBuffer builders +│ └── _generated/ # [GENERATED] FlatBuffer Python bindings (flatc) +├── runtime/ # C++ runtime (loaded at inference time) +│ ├── MLXBackend.cpp # BackendInterface (init / execute / destroy) +│ ├── MLXLoader.h/.cpp # [GENERATED] FlatBuffer → C++ structs +│ ├── MLXExecutor.h # ExecutionState, constant loading, helpers +│ ├── MLXInterpreter.h # Op dispatch loop + per-op exec_* functions +│ └── schema_generated.h # [GENERATED] FlatBuffer C++ bindings (flatc) +├── llm/ # LLM infrastructure (KV cache, attention, etc.) +│ ├── cache.py # KV cache implementations (ET + HF static cache) +│ ├── et_attention.py # ExecuTorch custom SDPA attention +│ ├── hf_attention.py # HuggingFace custom SDPA attention +│ ├── quantization.py # TorchAO quantization helpers +│ └── source_transformation.py # Source transforms for MLX export +├── _generated_inspector.py # [GENERATED] Inspector utilities for .pte debugging +├── _logging.py # Debug logging utilities (ET_MLX_DEBUG) +├── builder/ # Core build infrastructure +│ ├── op_registry.py # REGISTRY (op handler registration) +│ ├── op_helpers.py # Helper utilities for op handlers +│ ├── pattern_matcher.py # Pattern matching for multi-node fusions +│ ├── program_builder.py # MLXProgramBuilder +│ └── slot_manager.py # Tensor/value slot allocation +├── ops.py # Op handlers (ATen target → MLX IR node) +├── patterns.py # Pattern handlers (multi-node fusions) +├── passes.py # Graph passes (RMSNorm fusion, CSE, etc.) +├── pattern_utils.py # Pattern matching utilities for passes +├── partitioner.py # Decides which ops to delegate to MLX +├── preprocess.py # BackendDetails.preprocess() entry point +├── custom_ops.py # Custom torch ops (kv_cache_update, custom_sdpa, rope) +├── pte_inspector.py # .pte file inspection/debugging tool +├── test/ +│ ├── test_ops.py # Op test definitions (models + configs) +│ ├── test_utils.py # OpTestCase base class + helpers +│ ├── op_test_runner.cpp # C++ test runner (loads .pte, runs, compares) +│ └── run_all_tests.py # End-to-end: export → C++ run → compare +└── examples/ + ├── llm/ # LLM export + run via HuggingFace + └── whisper/ # Whisper export + run +``` + +Files marked **[GENERATED]** are NOT CHECKED IN CODE and are produced by running: + +```bash +python backends/mlx/serialization/generate.py +``` + +--- + +## Compilation Pipeline + +The compilation pipeline converts a PyTorch model into a `.pte` file containing +the MLX delegate payload. The high-level flow: + +``` +torch.export() → ExportedProgram (ATen IR) +to_edge_transform_and_lower() → Edge IR + partitioning + lowering +``` + +Within that flow, the MLX-specific steps are: + +1. **Partitioning** (`partitioner.py`) — `MLXPartitioner` walks the Edge IR + graph and tags nodes that MLX can handle. It uses `MLXProgramBuilder` in a + dry-run mode to determine support — so partitioning and compilation use the + exact same logic. Unsupported ops fall back to ExecuTorch's portable + runtime. + +2. **Preprocessing** (`preprocess.py`) — For each partitioned subgraph, + `MLXBackend.preprocess()` is called. It builds an `MLXGraph` via + `MLXProgramBuilder`, serializes it to FlatBuffer, and returns a + `PreprocessResult` with the binary payload and constant data. + +3. **Op handling** (`ops.py`, `patterns.py`) — During the build, + `MLXProgramBuilder` walks the FX graph node-by-node and dispatches to + registered handlers. Single-op handlers live in `ops.py`; multi-node fused + patterns (e.g., quantized linear, SDPA, KV cache update) live in + `patterns.py`. + +4. **Serialization** (`serialization/`) — The `MLXGraph` dataclass tree is + serialized to a FlatBuffer binary. See [Serialization](#serialization) below. + +The complete preprocessing flow: + +``` +ExportedProgram (subgraph) + → MLXProgramBuilder.build() # walks FX graph, calls op handlers + → MLXGraph # Python IR (dataclasses from mlx_graph_schema.py) + → MLXGraphSerializer.serialize() # FlatBuffer binary + → PreprocessResult # returned to ExecuTorch +``` + +--- + +## How to Add a New Op + +This section walks through adding a new op end-to-end, using **`aten.linear`** +as an example. + +### Step 1: Add the Node to `schema.fbs` + +Add a new table in the "Op nodes" section and add it to the `OpNode` union: + +```fbs +table LinearNode { + x: Tid (required); + weight: Tid (required); + out: Tid (required); + bias: Tid; // optional +} +``` + +Then add `LinearNode` to the `union OpNode { ... }` list. + +### Step 2: Run the Code Generator + +```bash +python backends/mlx/serialization/generate.py +``` + +This regenerates: + +- `mlx_graph_schema.py` — adds `LinearNode` Python dataclass +- `_generated_serializers.py` — adds `_build_LinearNode` serializer +- `runtime/MLXLoader.h` — adds `LinearNode` C++ struct, `OpCode::LINEAR`, loader +- `runtime/MLXLoader.cpp` — adds FlatBuffer → `LinearNode` deserialization +- `runtime/schema_generated.h` — FlatBuffer C++ bindings + +### Step 3: Add the Python Op Handler (`ops.py`) + +Register a handler that converts the ATen op to your new node. Make sure to +import `LinearNode` from `mlx_graph_schema`: + +```python +from executorch.backends.mlx.serialization.mlx_graph_schema import LinearNode + +@REGISTRY.register(target=[torch.ops.aten.linear.default]) +def _linear_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 3, "aten.linear") + require_kwargs(P.kwargs(n), set(), "aten.linear") + x, w = args[0], args[1] + b = args[2] if len(args) > 2 else None + out = P.make_or_get_slot(n) + P.emit( + LinearNode( + x=P.slot_to_tid(x), + weight=P.slot_to_tid(w), + out=P.slot_to_tid(out), + bias=P.slot_to_tid(b) if b else None, + ) + ) + return out +``` + +Key APIs: +- **`P.args(n)`** — resolves FX node args to `Slot` objects (tensor/value references) +- **`P.make_or_get_slot(n)`** — allocates the output tensor slot +- **`P.slot_to_tid(slot)`** — converts a `Slot` to a `Tid` for the IR node +- **`P.emit(node)`** — appends the instruction to the graph + +### Step 4: Add the C++ Op Handler (`MLXInterpreter.h`) + +Add an `exec_*` function in the `ops` namespace: + +```cpp +inline void exec_linear(const LinearNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& X = st.const_tensor_ref(n.x); + auto W = transpose(st.const_tensor_ref(n.weight), {1, 0}, s); + array Y = n.bias + ? addmm(st.const_tensor_ref(*n.bias), X, W, 1.0f, 1.0f, s) + : matmul(X, W, s); + st.set_tensor(n.out, std::move(Y)); +} +``` + +Then add the dispatch case in `Interpreter::execute_instruction()`: + +```cpp +case OpCode::LINEAR: + ops::exec_linear(std::get(instr.node), st, s); + break; +``` + +### Step 5: Write a Test (`test/test_ops.py`) + +Each test follows a standard pattern: + +1. **Define a `nn.Module`** that uses the op. +2. **Define an `OpTestCase` subclass** that specifies test configurations. +3. **Decorate with `@register_test`** to register it with the test runner. + +```python +class LinearModel(nn.Module): + def __init__(self, in_features=64, out_features=128, bias=True): + super().__init__() + self.linear = nn.Linear(in_features, out_features, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + +@register_test +class LinearTest(OpTestCase): + name = "linear" + rtol = 1e-4 + atol = 1e-4 + + def __init__(self, in_features=64, out_features=128, bias=True): + self.in_features = in_features + self.out_features = out_features + self.bias = bias + + @classmethod + def get_test_configs(cls): + return [cls(), cls(bias=False)] + + def create_model(self): + return LinearModel(self.in_features, self.out_features, bias=self.bias) + + def create_inputs(self): + return (torch.randn(2, 16, self.in_features),) +``` + +### Step 6: Run Tests + +Tests are end-to-end: export `.pte` → run via C++ `op_test_runner` → compare +outputs against PyTorch reference. Since adding a new op always involves C++ +changes, use `--rebuild` to recompile the runtime: + +```bash +python -m executorch.backends.mlx.test.run_all_tests --rebuild linear +``` + +Run all tests in parallel: + +```bash +python -m executorch.backends.mlx.test.run_all_tests --rebuild -j4 --clean-after +``` + +Other useful flags: + +| Flag | Purpose | +|---|---| +| `--rebuild` | Rebuild the C++ `op_test_runner` before running | +| `-j N` / `--parallel N` | Run N tests in parallel | +| `--clean-after` | Remove generated test artifacts after running | +| `--list` | List all available test names and exit | +| `-v` / `--verbose` | Verbose output | + +Test artifacts are saved to `test/op_tests//` (`.pte`, input/output +`.bin` files). See [`test/README.md`](test/README.md) for full details on test +architecture, prerequisites, and the `OpTestCase` API. + +### Checklist + +- [ ] Add `*Node` table to `schema.fbs` + add to `OpNode` union +- [ ] Run `python backends/mlx/serialization/generate.py` +- [ ] Add `@REGISTRY.register` handler in `ops.py` (and import the new node class) +- [ ] Add `exec_*` function in `runtime/MLXInterpreter.h` +- [ ] Add `case OpCode::*` in `Interpreter::execute_instruction()` +- [ ] Add test model + `OpTestCase` in `test/test_ops.py` +- [ ] Run `python -m executorch.backends.mlx.test.run_all_tests --rebuild ` + +--- + +## Serialization + +### Overview + +The serialization system converts a Python `MLXGraph` dataclass tree into a +FlatBuffer binary that the C++ runtime can load. The source of truth is +**`schema.fbs`** — a single FlatBuffer schema file from which all code on both +sides is generated. + +### Schema (`schema.fbs`) + +The schema defines: + +| Concept | FlatBuffer type | Purpose | +|---|---|---| +| **`Tid`** | struct | Tensor slot index (indexes into the runtime tensor array) | +| **`Vid`** | struct | Value slot index (for scalar `int32`/`float`/`bool` values) | +| **`IntOrVid`** | table | A field that is either a literal `int64` or a runtime `Vid` reference (for dynamic shapes) | +| **`FloatOrVid`** | table | Same idea for floats | +| **`TidOrVid`** | table | Either a tensor or a scalar value | +| **Op node tables** | table | One per op (e.g. `AddNode`, `SiluNode`, `ReshapeNode`). Each declares its inputs/outputs as `Tid`/`Vid` references and any scalar parameters. | +| **`OpNode`** | union | Union of all op node tables | +| **`Instruction`** | table | Wraps an `OpNode` union | +| **`MLXGraph`** | table (root) | The complete program: slot counts, instruction list, I/O maps, named slots, tensor metadata | + +Key design points: + +- **No embedded weights.** Constants are stored in ExecuTorch's `named_data_map` + and loaded by name at runtime. This enables zero-copy on unified memory. +- **Tensor IDs (`Tid`) are globally ordered:** Constants → Inputs → Outputs → + Mutable Buffers → Temps. The runtime uses this ordering for O(1) type lookup. +- **Dynamic shapes** are supported via `IntOrVid` — a shape dimension can be + either a literal integer or a reference to a runtime value produced by + `sym_size` / `item()` ops. + +### Code Generation (`generate.py`) + +`generate.py` parses `schema.fbs` and generates **all** boilerplate on both the +Python and C++ sides: + +| Generated file | What it contains | +|---|---| +| `mlx_graph_schema.py` | Python `@dataclass` for every op node, `Tid`, `Vid`, `IntOrVid`, etc. | +| `_generated_serializers.py` | `GeneratedOpBuilders` mixin class with `_build_*Node` methods for every op | +| `_generated_inspector.py` | Inspector utilities for debugging `.pte` files | +| `runtime/MLXLoader.h` | C++ structs for every op node, `OpCode` enum, `NodeVariant`, `Instruction`, `MLXProgram` | +| `runtime/MLXLoader.cpp` | `load_instruction()` and `load_program()` — FlatBuffer → C++ struct conversion | +| `runtime/schema_generated.h` | Standard FlatBuffer C++ bindings (via `flatc`) | +| `_generated/` directory | Standard FlatBuffer Python bindings (via `flatc`) | + +Running the generator: + +```bash +python backends/mlx/serialization/generate.py +``` + +Use `--skip-flatc` if you only changed op node definitions (not core types) and +want to skip the `flatc` invocation. + +### Serialization Format + +The binary payload embedded in the `.pte` file has this layout: + +``` +[Header: 24 bytes] + 4 bytes padding (zeros) + 4 bytes magic ("MLX0") + 8 bytes data_segment_offset (uint64 LE) + 8 bytes data_segment_size (uint64 LE) +[FlatBuffer payload] +[Padding to 16-byte alignment] +[Data segment (currently unused — constants go via named_data_map)] +``` + +The `MLXGraphSerializer` class (in `mlx_graph_serialize.py`) drives +serialization. It inherits `GeneratedOpBuilders` for the per-op builders and +adds the root-table construction, I/O maps, tensor metadata, and header. + +--- + +## Runtime + +### Initialization (`init`) + +When ExecuTorch loads a `.pte` with an MLX delegate blob, `MLXBackend::init()` +is called: + +1. **Parse FlatBuffer** — `loader::load_program()` deserializes the binary into + an `MLXProgram` struct (C++ mirrors of the schema). +2. **Load constants** — Iterates `named_slots`, calls + `named_data_map->get_data(name)` for each constant tensor, wraps the buffer + as an `mlx::core::array` (zero-copy when possible on unified memory). +3. **Initialize mutable buffers** — Creates zero-filled MLX arrays for + persistent state (e.g., KV cache). These live across `execute()` calls. +4. **Bind execution state** — `ExecutionState::bind()` pre-computes tensor ID + ranges for O(1) routing. + +### Execution (`execute`) + +Each `execute()` call: + +1. **Reset** per-execution state (inputs/outputs/temps cleared; mutable buffers + and constants are retained). +2. **Bind inputs** — Walk `input_map`, convert each ExecuTorch tensor to an + `mlx::core::array` (zero-copy pointer wrap). +3. **Run instructions** — `Interpreter::run()` dispatches each `Instruction` + through a `switch` on `OpCode`, calling the corresponding `exec_*` function. +4. **Evaluate** — Call `mlx::core::eval()` on output tensors to trigger + lazy GPU computation. +5. **Copy outputs** — Convert MLX arrays back to ExecuTorch tensors via + `memcpy`. + +### Tensor ID Layout + +Tensor slot IDs are assigned in a fixed order during compilation: + +``` + ┌──────────┬──────────┬──────────┬────────────────┬──────────┐ + │ Constants│ Inputs │ Outputs │ Mutable Buffers│ Temps │ + │ 0..C-1 │ C..I-1 │ I..O-1 │ O..M-1 │ M..T-1 │ + └──────────┴──────────┴──────────┴────────────────┴──────────┘ +``` + +The runtime stores constants and mutable buffers in separate containers +(`ConstantData`, `MutableBufferData`). Inputs, outputs, and temps share a flat +`vector>` in `ExecutionState`. + +### Key Runtime Files + +| File | Role | +|---|---| +| `MLXBackend.cpp` | `init()` / `execute()` / `destroy()` — the ExecuTorch `BackendInterface` | +| `MLXLoader.h/.cpp` | [GENERATED] Deserializes FlatBuffer into `MLXProgram` (C++ structs) | +| `MLXExecutor.h` | `ExecutionState`, `ConstantData`, `MutableBufferData`, constant loading, dtype conversion | +| `MLXInterpreter.h` | The op dispatch switch + all `exec_*` implementations | diff --git a/backends/mlx/__init__.py b/backends/mlx/__init__.py new file mode 100644 index 00000000000..48f4c28f5ca --- /dev/null +++ b/backends/mlx/__init__.py @@ -0,0 +1,17 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""MLX backend for ExecuTorch - executes models on Apple Silicon using MLX.""" + +# Import custom_ops module to register custom ATen ops (rope, etc.) +from executorch.backends.mlx import custom_ops as _custom_ops # noqa: F401 +from executorch.backends.mlx.partitioner import MLXPartitioner + +from executorch.backends.mlx.preprocess import MLXBackend + +__all__ = ["MLXBackend", "MLXPartitioner"] diff --git a/backends/mlx/_logging.py b/backends/mlx/_logging.py new file mode 100644 index 00000000000..eff472550f9 --- /dev/null +++ b/backends/mlx/_logging.py @@ -0,0 +1,40 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +Centralized logging for the MLX backend. + +Usage: + from executorch.backends.mlx._logging import logger + + logger.info("Always visible (e.g., unsupported ops summary)") + logger.debug("Only visible when ET_MLX_DEBUG=1") + logger.warning("Always visible") + +The logger is set to INFO by default, so logger.info() always prints. +Set ET_MLX_DEBUG=1 to lower the threshold to DEBUG for verbose output +(graph dumps, per-node traces, ops_to_not_decompose lists, etc.). +""" + +import logging +import os + +_MLX_DEBUG = os.environ.get("ET_MLX_DEBUG", "") not in ("", "0") + +logger = logging.getLogger("executorch.backends.mlx") +logger.setLevel(logging.DEBUG if _MLX_DEBUG else logging.INFO) +logger.propagate = False + +if not logger.handlers: + _handler = logging.StreamHandler() + _handler.setFormatter( + logging.Formatter( + "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" + ) + ) + logger.addHandler(_handler) diff --git a/backends/mlx/builder/__init__.py b/backends/mlx/builder/__init__.py new file mode 100644 index 00000000000..ce793ed9a15 --- /dev/null +++ b/backends/mlx/builder/__init__.py @@ -0,0 +1,16 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +# Trigger op/pattern handler registration. +# ops.py and patterns.py use @REGISTRY.register() decorators at import time. +# This must happen after REGISTRY is defined (in op_registry.py). +from executorch.backends.mlx import ops, patterns # noqa: F401 +from executorch.backends.mlx.builder.op_registry import REGISTRY # noqa: F401 +from executorch.backends.mlx.builder.program_builder import ( # noqa: F401 + MLXProgramBuilder, +) diff --git a/backends/mlx/builder/op_helpers.py b/backends/mlx/builder/op_helpers.py new file mode 100644 index 00000000000..5e082cdf386 --- /dev/null +++ b/backends/mlx/builder/op_helpers.py @@ -0,0 +1,275 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +from __future__ import annotations + +from typing import Dict, Optional, Tuple, TYPE_CHECKING, Union + +import torch +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.exir.scalar_type import ScalarType +from torch.fx.node import Node + +if TYPE_CHECKING: + from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder + + +def get_aten_target(target): + """ + Unwrap EdgeOpOverload to get the underlying ATen op. + + In Edge IR, ops are wrapped in EdgeOpOverload. This extracts the + underlying ATen op for consistent comparison. + """ + if hasattr(target, "_op") and "EdgeOpOverload" in type(target).__name__: + return target._op + return target + + +# Mapping from _copy variants to their non-copy equivalents. +# Edge IR uses _copy variants for certain ops, but for pattern matching +# we want to compare against the semantic operation. +_COPY_TO_NON_COPY = { + torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor, + torch.ops.aten.transpose_copy.int: torch.ops.aten.transpose.int, + torch.ops.aten.view_copy.default: torch.ops.aten.view.default, + torch.ops.aten.permute_copy.default: torch.ops.aten.permute.default, + torch.ops.aten.unsqueeze_copy.default: torch.ops.aten.unsqueeze.default, + torch.ops.aten.squeeze_copy.dim: torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze_copy.dims: torch.ops.aten.squeeze.dims, + torch.ops.aten.squeeze_copy.default: torch.ops.aten.squeeze.default, + torch.ops.aten.expand_copy.default: torch.ops.aten.expand.default, + torch.ops.aten.alias_copy.default: torch.ops.aten.alias.default, +} + + +def get_aten_target_normalized(target): + """ + Get ATen target, mapping _copy variants to their non-copy equivalents. + + Use this for pattern matching where Edge IR uses _copy variants but + we want to match the semantic operation. + + E.g., aten.transpose_copy.int -> aten.transpose.int + """ + target = get_aten_target(target) + return _COPY_TO_NON_COPY.get(target, target) + + +def emit_stop_position( + P: "MLXProgramBuilder", + start: "Union[int, Slot]", + length_tensor: "Slot", + length_dim: int, + length_meta: "Optional[torch.Tensor]" = None, +) -> "Union[int, Slot]": + """ + Emit nodes to compute stop = start + length for slice operations. + + May emit SymSizeNode and/or AddIntNode depending on whether + start and length are static or dynamic. + + Args: + P: The program builder + start: Start position (int or Slot) + length_tensor: The tensor slot whose dimension gives the length + length_dim: Which dimension of length_tensor contains the length + length_meta: Optional tensor metadata for static length extraction + + Returns: + stop position as int (if fully static) or Slot (if any dynamic) + """ + from executorch.backends.mlx.serialization.mlx_graph_schema import ( + AddIntNode, + IntOrVid, + SymSizeNode, + ) + + # Check if seq_len is symbolic (dynamic) + seq_len_is_symbolic = False + seq_len_concrete = None + + if length_meta is not None: + seq_len_dim = length_meta.shape[length_dim] + if hasattr(seq_len_dim, "node"): + seq_len_is_symbolic = True + else: + seq_len_concrete = int(seq_len_dim) + + if seq_len_is_symbolic or length_meta is None: + # Dynamic seq_len: emit SymSizeNode to get length at runtime + _, seq_len_slot = P.slot_manager.make_tmp_value_slot() + P.emit( + SymSizeNode( + a=P.slot_to_tid(length_tensor), + dim=length_dim, + out=P.slot_to_vid(seq_len_slot), + ) + ) + _, stop_slot = P.slot_manager.make_tmp_value_slot() + if isinstance(start, Slot): + start_iov = P.to_int_or_vid(start) + else: + start_iov = IntOrVid.from_literal(int(start)) + P.emit( + AddIntNode( + a=start_iov, + b=IntOrVid.from_vid(P.slot_to_vid(seq_len_slot)), + out=P.slot_to_vid(stop_slot), + ) + ) + return stop_slot + else: + # Static seq_len + if isinstance(start, Slot): + # Dynamic start + static length + _, stop_slot = P.slot_manager.make_tmp_value_slot() + P.emit( + AddIntNode( + a=P.to_int_or_vid(start), + b=IntOrVid.from_literal(seq_len_concrete), + out=P.slot_to_vid(stop_slot), + ) + ) + return stop_slot + else: + # Both static - just return the sum + return start + seq_len_concrete + + +def emit_lifted_constant(P: "MLXProgramBuilder", value, dtype: torch.dtype) -> Slot: + """Lift a scalar to a 0-D tensor. + + Concrete scalars (int/float/bool) become deduplicated constants. + Dynamic values (SymInt Slots) emit a FullNode at runtime. + """ + + if isinstance(value, (int, float, bool)): + return P.make_or_get_constant( + f"_scalar_{value}", torch.tensor(value, dtype=dtype) # 0-D + ) + + from executorch.backends.mlx.serialization.mlx_graph_schema import FullNode + + _, slot = P.make_tmp_slot() + P.emit( + FullNode( + shape=[], + v=P.to_float_or_vid(value), + scalar_type=torch_dtype_to_scalar_type(dtype), + out=P.slot_to_tid(slot), + ) + ) + return slot + + +def to_mlx_qparams( + qdata: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + bits: int, + compute_biases: bool = True, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Convert TorchAO quantization params to MLX format. + + TorchAO uses: s * (q - z), with q signed + MLX uses: S * Q + B, with Q unsigned + + s * (q - z) + = s ((q + offset) - (z + offset)) + = s Q + B, + where Q = q + offset, B = -s * (z + offset) + + Args: + compute_biases: If False, skip bias computation (for scale_only mode). + Returns (Q, None) in this case. This is valid when + zero_point is all zeros, as the C++ runtime will compute + biases = -scales * 2^(bits-1). + """ + assert qdata.dtype == torch.int8 + offset = 2 ** (bits - 1) + Q = qdata.to(torch.int32) + offset + + # Pack data tightly into uint32 + assert 32 % bits == 0 + vals_per_uint32 = 32 // bits + assert qdata.shape[1] % vals_per_uint32 == 0 + + Q = Q.reshape(-1, vals_per_uint32) + shifts = torch.arange(0, 32, bits, dtype=torch.int64) + + # Convert to int64 for shift/packing + Q = Q.to(torch.int64) + Q = (Q << shifts).sum(dim=-1) + Q = Q.to(torch.uint32) + Q = Q.reshape(qdata.shape[0], -1) + + if compute_biases: + B = -scale * (zero_point.to(scale.dtype) + offset) + return Q, B + else: + return Q, None + + +def parse_dequant_node( + node: Node, +) -> Optional[Tuple[Node, Node, Node, int, int, Optional[torch.dtype], int]]: + """Parse a torchao.dequantize_affine node. + + Accepts N-dimensional block_size with a single non-1 element identifying + the quantized dimension and group_size. For example: + - Linear weights (2D): block_size=[1, 32] → quantized_dim=1 + - Conv2d weights (4D): block_size=[1, 32, 1, 1] → quantized_dim=1 + + Returns (qdata, scale, zero_point, group_size, bits, out_dtype, quantized_dim) + or None if unsupported. + """ + qdata, block_size, scale, zero_point, dtype, qmin, qmax = node.args[0:7] + out_dtype = ( + node.args[7] if len(node.args) > 7 else node.kwargs.get("output_dtype", None) + ) + if dtype != torch.int8: + return None + if len(block_size) < 2: + return None + non_one = [(i, d) for i, d in enumerate(block_size) if d != 1] + if len(non_one) != 1: + return None + quantized_dim, group_size = non_one[0] + if group_size not in [32, 64, 128]: + return None + if qmin == -8 and qmax == 7: + bits = 4 + elif qmin == -128 and qmax == 127: + bits = 8 + else: + return None + return qdata, scale, zero_point, group_size, bits, out_dtype, quantized_dim + + +# Mapping from torch dtype to ET ScalarType int value +# See executorch/exir/scalar_type.py for ScalarType enum +_TORCH_DTYPE_TO_SCALAR_TYPE: Dict[torch.dtype, int] = { + torch.float16: ScalarType.HALF, + torch.float32: ScalarType.FLOAT, + torch.bfloat16: ScalarType.BFLOAT16, + torch.int32: ScalarType.INT, + torch.int64: ScalarType.LONG, + torch.uint32: ScalarType.UINT32, + torch.uint8: ScalarType.BYTE, + torch.bool: ScalarType.BOOL, + torch.int8: ScalarType.CHAR, +} + + +def torch_dtype_to_scalar_type(dtype: torch.dtype) -> int: + """Convert torch dtype to ET ScalarType int value.""" + if dtype not in _TORCH_DTYPE_TO_SCALAR_TYPE: + raise ValueError(f"Unsupported dtype: {dtype}") + return int(_TORCH_DTYPE_TO_SCALAR_TYPE[dtype]) diff --git a/backends/mlx/builder/op_registry.py b/backends/mlx/builder/op_registry.py new file mode 100644 index 00000000000..19668ca2c1b --- /dev/null +++ b/backends/mlx/builder/op_registry.py @@ -0,0 +1,151 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +from __future__ import annotations + +from typing import Callable, Dict, List, Optional, Tuple, Type, TYPE_CHECKING, Union + +from executorch.backends.mlx._logging import logger +from torch.fx.node import Node + +if TYPE_CHECKING: + from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder + from executorch.backends.mlx.builder.slot_manager import Slot + from torch.export import ExportedProgram + +# Handler type: takes (builder, node) and returns optional slot(s) +Handler = Callable[ + ["MLXProgramBuilder", Node], Optional[Union["Slot", Tuple["Slot", ...]]] +] + + +class PatternHandler: + def __init__(self, head: Node, body: List[Node]) -> None: + self.head: Node = head + self.body: List[Node] = body + + @classmethod + def deferred_handler(cls, P: MLXProgramBuilder, n: Node) -> None: + pass + + @classmethod + def maybe_create(cls, ep: ExportedProgram, head: Node) -> Optional[PatternHandler]: + raise NotImplementedError + + def __call__(self, P: MLXProgramBuilder, n: Node) -> None: + raise NotImplementedError + + def set_handlers(self, P: MLXProgramBuilder): + if P.node_info[self.head].handler is not None: + raise AssertionError( + f"Head node {self.head.name} already has handler {P.node_info[self.head].handler}, " + f"cannot set pattern {self.__class__.__name__}" + ) + for n in self.body: + if P.node_info[n].handler is not None: + raise AssertionError( + f"Body node {n.name} already has handler {P.node_info[n].handler}, " + f"cannot set pattern {self.__class__.__name__}" + ) + + logger.debug( + f"Pattern {self.__class__.__name__}: " + f"HEAD={self.head.name}, BODY={[n.name for n in self.body]}" + ) + P.node_info[self.head].handler = self + for n in self.body: + P.node_info[n].handler = PatternHandler.deferred_handler + + +class MLXOpRegistry: + """Registry for op handlers and pattern handlers.""" + + def __init__(self): + self._handlers: Dict[Union[str, Callable], Handler] = {} + self._patterns: Dict[str, Type[PatternHandler]] = {} + + def reset(self) -> None: + """Reset the registry to empty state. Useful for testing.""" + self._handlers.clear() + self._patterns.clear() + + def register(self, target: Union[str, Callable, list, tuple]): + """Decorator to register a handler for one or more op targets.""" + + def deco(fn: Handler): + targets = target if isinstance(target, (list, tuple)) else [target] + for t in targets: + if t in self._handlers: + raise ValueError(f"Target {t} already registered") + self._handlers[t] = fn + return fn + + return deco + + def get_handler(self, node: Node) -> Optional[Handler]: + """Get the handler for a node, or None if not registered.""" + t = node.target + if t in self._handlers: + return self._handlers[t] + # Handle EdgeOpOverload by extracting the underlying ATen op + if hasattr(t, "_op") and t._op in self._handlers: + return self._handlers[t._op] + # Check for string-based targets (e.g., higher_order ops) + target_str = str(t) + if target_str in self._handlers: + return self._handlers[target_str] + return None + + def registered_ops(self) -> set: + """Return all registered op targets.""" + return set(self._handlers.keys()) + + def unregister(self, target: Union[str, Callable, list, tuple]) -> None: + """Remove a handler for one or more op targets. + + This is useful for debugging - allows temporarily disabling specific + handlers to test if they are causing issues. + + Args: + target: Single target or list of targets to unregister + """ + targets = target if isinstance(target, (list, tuple)) else [target] + for t in targets: + if t in self._handlers: + del self._handlers[t] + + def register_pattern(self, name: str): + """Decorator to register a pattern handler class.""" + + def deco(cls: Type[PatternHandler]): + if not issubclass(cls, PatternHandler): + raise TypeError( + "register_pattern must decorate a PatternHandler subclass" + ) + if name in self._patterns: + raise ValueError(f"Pattern '{name}' already registered") + self._patterns[name] = cls + return cls + + return deco + + def get_pattern_cls(self, name: str) -> Optional[Type[PatternHandler]]: + """Get a pattern handler class by name.""" + return self._patterns.get(name) + + def get_noop_handler(self) -> Optional[Handler]: + """Get the NOOP handler, if registered.""" + return self._handlers.get("NOOP") + + def patterns(self): + """Return all registered pattern names.""" + return self._patterns.keys() + + +# Global registry +REGISTRY = MLXOpRegistry() diff --git a/backends/mlx/builder/pattern_matcher.py b/backends/mlx/builder/pattern_matcher.py new file mode 100644 index 00000000000..2db422e3f68 --- /dev/null +++ b/backends/mlx/builder/pattern_matcher.py @@ -0,0 +1,64 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +from __future__ import annotations + +from typing import List, TYPE_CHECKING + +from executorch.backends.mlx._logging import logger +from executorch.backends.mlx.builder.op_registry import PatternHandler + +if TYPE_CHECKING: + from executorch.backends.mlx.builder.op_registry import MLXOpRegistry + from torch.export import ExportedProgram + + +class PatternMatcher: + """ + Discovers and applies pattern handlers to an FX graph. + + Pattern handlers match multi-node subgraphs and lower them to optimized + MLX operations. This class orchestrates the pattern discovery process: + + 1. Iterates through all registered pattern types + 2. For each pattern, tries to match it against every node in the graph + 3. When a match is found, assigns handlers to the head and body nodes + + The ordering matters: patterns are matched before dead code elimination + because some pattern body nodes (e.g., update_cache) have no users + since they mutate in-place, but they're not dead. + """ + + def __init__(self, ep: ExportedProgram, registry: "MLXOpRegistry"): + self.ep = ep + self.registry = registry + self._matches: List[PatternHandler] = [] + + def find_patterns(self) -> List[PatternHandler]: + """ + Find all pattern matches in the graph. + + Returns a list of PatternHandler instances, one for each match found. + Patterns are tried in registration order. + """ + self._matches = [] + for name in self.registry.patterns(): + self._find_pattern(name) + return self._matches + + def _find_pattern(self, name: str) -> None: + """Try to match a single pattern type against all nodes.""" + pattern_cls = self.registry.get_pattern_cls(name) + if pattern_cls is None: + return + + for n in self.ep.graph.nodes: + handler = pattern_cls.maybe_create(self.ep, n) + if handler is not None: + logger.debug(f"Pattern {name} matched at node {n.name}") + self._matches.append(handler) diff --git a/backends/mlx/builder/program_builder.py b/backends/mlx/builder/program_builder.py new file mode 100644 index 00000000000..60d5ebbdbfe --- /dev/null +++ b/backends/mlx/builder/program_builder.py @@ -0,0 +1,1018 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +MLX Program Builder - converts an ExportedProgram to an MLXGraph. + +This module is responsible for: +1. Walking the FX graph from an ExportedProgram +2. Converting each node to the corresponding MLX op +3. Managing tensor and value slots +4. Building the final MLXGraph dataclass for serialization + +Op handlers are registered in ops.py. +Pattern handlers are registered in patterns.py. +""" + +from __future__ import annotations + +import traceback +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Union + +import torch + +from executorch.backends.mlx._logging import logger +from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type +from executorch.backends.mlx.builder.op_registry import ( + Handler, + PatternHandler, + REGISTRY, +) +from executorch.backends.mlx.builder.pattern_matcher import PatternMatcher +from executorch.backends.mlx.builder.slot_manager import ( + IdSpace, + IdType, + Slot, + SlotManager, +) +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + FloatOrVid, + IdCopyNode, + Instruction, + InstructionChain, + IntOrVid, + IntOrVidOrTid, + MLXGraph, + NamedSlot, + OpNodeUnion, + ShapeDim, + SlotType, + SlotVariant, + TensorMeta, + Tid, + Vid, +) +from executorch.exir._serialize._named_data_store import NamedDataStore +from torch.export.exported_program import ExportedProgram +from torch.fx.node import Node +from torch.utils import _pytree as pytree + + +def _check_dtype(node: Node) -> Optional[str]: + """ + Check if a node has a supported dtype. + + Args: + node: The FX node to check + + Returns: + None if the node's dtype is supported, otherwise an error message string + """ + fake_val = node.meta.get("val", None) + if fake_val is not None and hasattr(fake_val, "dtype"): + try: + torch_dtype_to_scalar_type(fake_val.dtype) + except ValueError: + return f"has unsupported dtype: {fake_val.dtype}" + return None + + +def _check_input_dtypes(node: Node) -> Optional[str]: + """ + Check if all input tensors to a node have supported dtypes. + + Args: + node: The FX node to check + + Returns: + None if all input dtypes are supported, otherwise an error message string + describing which input (arg position or kwarg name) has an unsupported dtype + """ + # Check positional args + for i, arg in enumerate(node.args): + if isinstance(arg, Node): + dtype_error = _check_dtype(arg) + if dtype_error is not None: + return f"arg[{i}] ({arg.name}) {dtype_error}" + + # Check kwargs + for kwarg_name, kwarg_val in node.kwargs.items(): + if isinstance(kwarg_val, Node): + dtype_error = _check_dtype(kwarg_val) + if dtype_error is not None: + return f"kwarg '{kwarg_name}' ({kwarg_val.name}) {dtype_error}" + + return None + + +@dataclass +class NodeInfo: + handled: bool = False + handler: Optional[Union[Handler, PatternHandler]] = None + supported: bool = False + unsupported_reason: Optional[str] = None + name: Optional[str] = None + remaining_reads: int = 0 + + +class MLXProgramBuilder: + """ + Builds an MLXGraph from an ExportedProgram. + + Args: + ep: The ExportedProgram to build from + """ + + def __init__(self, ep: ExportedProgram, named_data_key_prefix: str = ""): + self.ep: ExportedProgram = ep + self._instrs: List[Instruction] = [] + self.extra_constants: Dict[str, torch.Tensor] = {} + self.slot_manager = SlotManager() + self.node_info: DefaultDict[Node, NodeInfo] = defaultdict(NodeInfo) + self._mlx_graph: Optional[MLXGraph] = None + # Map from SymInt symbol names (e.g., "s77") to the FX Node that produces them. + # This is used to resolve symbolic tensor dimensions to Vid references. + self._symint_to_node: Dict[str, Node] = {} + # Maps for remapping local slot indices to global Tid/Vid indices during build + self._tid_slot_map: List[Tuple[Tid, Slot]] = [] + self._vid_slot_map: List[Tuple[Vid, Slot]] = [] + # Prefix for named_data_store keys and named_slots to avoid collisions + # in multi-method programs where different methods may have lifted tensor + # constants with the same auto-generated name. + self._named_data_key_prefix: str = named_data_key_prefix + # Unprefixed canonical-name → Slot for constants, populated by _build_io_maps(). + # Used by get_named_data_store() to look up tensors without prefix interference. + self._constant_name_to_slot: Dict[str, Slot] = {} + + def _prefix_key(self, name: str) -> str: + """Apply the named-data key prefix for the .pte namespace. + + This is the single point where canonical (unprefixed) names are + transformed into the external keys used in the .pte's ``named_data`` + section and the FlatBuffer ``named_slots`` field. + """ + if self._named_data_key_prefix: + return f"{self._named_data_key_prefix}/{name}" + return name + + def emit(self, op: OpNodeUnion) -> None: + self._instrs.append(Instruction(op=op)) + + def args(self, node: Node) -> Tuple[Any, ...]: + return self.slot_map(node.args) + + def kwargs(self, node: Node) -> Dict[str, Any]: + return self.slot_map(node.kwargs) + + def slot_map(self, tree): + leaves, spec = pytree.tree_flatten(tree) + new_leaves = [] + for a in leaves: + if isinstance(a, Node): + # Use make_or_get_slots which handles both single and multi-output nodes. + # For single-output nodes, returns a 1-tuple; for multi-output, returns n-tuple. + # We unwrap single-element tuples for convenience. + slots = self.make_or_get_slots(a) + if len(slots) == 1: + new_leaves.append(slots[0]) + else: + new_leaves.append(slots) + else: + new_leaves.append(a) + + for a in new_leaves: + if isinstance(a, Slot): + assert self.slot_manager.is_alive( + a + ), f"Slot {a} is not alive; it was either already freed or never created" + + return pytree.tree_unflatten(new_leaves, spec) + + def make_or_get_slots( + self, node: Node, id_space: IdSpace = IdSpace.Temp + ) -> Tuple[Slot, ...]: + """Get or create slots for a multi-output node. Always returns a tuple.""" + return self.slot_manager.make_or_get_slots(node, id_space) + + def make_or_get_slot(self, node: Node, id_space: IdSpace = IdSpace.Temp) -> Slot: + """Get or create a slot for a single-output node. Returns a single Slot.""" + return self.slot_manager.make_or_get_slot(node, id_space) + + def set_slot(self, node: Node, slot: Slot): + self.slot_manager.set_slot(node, slot) + + def make_tmp_slot(self) -> Tuple[str, Slot]: + """Create a temporary tensor slot.""" + return self.slot_manager.make_tmp_slot() + + def make_tmp_value_slot(self) -> Tuple[str, Slot]: + """Create a temporary value (SymInt) slot.""" + return self.slot_manager.make_tmp_value_slot() + + def make_or_get_constant(self, name: str, tensor: torch.Tensor) -> Slot: + """ + Creates an extra constant outside of the ExportedProgram state_dict. + Ops can use this to add constants during build that do not exist in the + ExportedProgram state_dict, e.g., doing naive packing of quantized ops. + """ + assert name not in self.ep.state_dict + assert name not in self.ep.constants + + if name in self.extra_constants: + # During fake tensor tracing, we can't use torch.equal + # Just assume tensors with same name are the same + slot = self.slot_manager.get_slot(name) + assert slot is not None + return slot + + slot = self.slot_manager.make_constant_slot(name) + self.extra_constants[name] = tensor + return slot + + def get_placeholder_target_and_tensor(self, node: Node) -> Tuple[str, torch.Tensor]: + assert node.op == "placeholder" + placeholder_name = node.name + + sig = self.ep.graph_signature + sd = self.ep.state_dict + consts = self.ep.constants + + for ispec in sig.input_specs: + if ispec.arg.name != placeholder_name: + continue + target = ispec.target + if target is None: + continue + if target in sd: + return (target, sd[target]) + if target in consts: + return (target, consts[target]) + + raise KeyError(f"Unable to resolve placeholder {placeholder_name}") + + def slot_to_tid(self, slot: Slot) -> Tid: + """Convert a tensor Slot to a Tid, recording it for later remapping.""" + assert slot.id_type == IdType.Tensor + # Use local slot.idx as placeholder - will be remapped to global idx in build() + tid = Tid(idx=slot.idx) + self._tid_slot_map.append((tid, slot)) + return tid + + def slot_to_vid(self, slot: Slot) -> Vid: + """Convert a value Slot to a Vid, recording it for later remapping.""" + assert slot.id_type != IdType.Tensor + vid = Vid(idx=slot.idx) + self._vid_slot_map.append((vid, slot)) + return vid + + def to_int_or_vid(self, v: Union[int, Slot]) -> IntOrVid: + if isinstance(v, Slot): + return IntOrVid.from_vid(self.slot_to_vid(v)) + return IntOrVid.from_literal(int(v)) + + def to_float_or_vid(self, v: Union[float, int, Slot]) -> FloatOrVid: + if isinstance(v, Slot): + return FloatOrVid.from_vid(self.slot_to_vid(v)) + return FloatOrVid.from_literal(float(v)) + + def to_int_or_vid_or_tid(self, v: Union[int, Slot]) -> IntOrVidOrTid: + if isinstance(v, Slot): + if v.id_type == IdType.Tensor: + return IntOrVidOrTid.from_tid(self.slot_to_tid(v)) + return IntOrVidOrTid.from_vid(self.slot_to_vid(v)) + return IntOrVidOrTid.from_literal(int(v)) + + def _mark_read(self, node: Node): + assert self.node_info[node].handled, f"Node {node} is not handled" + assert ( + self.node_info[node].remaining_reads > 0 + ), f"Reading node {node}, but it has no remaining reads" + self.node_info[node].remaining_reads -= 1 + + if self.node_info[node].remaining_reads == 0: + slot = self.slot_manager.get_slot(node) + if slot is None: + return + if not isinstance(slot, tuple): + slot = (slot,) + for s in slot: + if s.id_space != IdSpace.Temp: + continue + if s.id_type == IdType.Tensor: + self.slot_manager.tid_managers[IdSpace.Temp].return_id(s.idx) + else: + self.slot_manager.vid_managers[IdSpace.Temp].return_id(s.idx) + + def _mark_node_handled(self, node: Node, *, handler: Optional[Handler] = None): + if self.node_info[node].handled: + return + self.node_info[node].handled = True + self.node_info[node].remaining_reads = len(node.users) + self.node_info[node].handler = handler + + if handler == PatternHandler.deferred_handler: + return + + def mark_read(n: Node): + flat_args, spec = pytree.tree_flatten((n.args, n.kwargs)) + seen = set() + for a in flat_args: + if isinstance(a, Node): + if a not in seen: + self._mark_read(a) + seen.add(a) + + if isinstance(handler, PatternHandler): + for n in handler.body: + mark_read(n) + mark_read(node) + + def _mark_node_supported(self, node: Node, *, handler: Optional[Handler] = None): + self.node_info[node].supported = True + self._mark_node_handled(node, handler=handler) + + def _mark_node_unsupported(self, node: Node, reason: str): + self.node_info[node].supported = False + self.node_info[node].unsupported_reason = reason + self._mark_node_handled(node) + + def _is_handled(self, node: Node) -> bool: + return self.node_info[node].handled + + def _mark_supported( + self, nodes: Union[List[Node], Node], *, handler: Optional[Handler] = None + ) -> None: + if isinstance(nodes, Node): + nodes = [nodes] + for node in nodes: + self._mark_node_supported(node, handler=handler) + + def _mark_unsupported(self, nodes: Union[List[Node], Node], reason: str) -> None: + if isinstance(nodes, Node): + nodes = [nodes] + for node in nodes: + self._mark_node_unsupported(node, reason) + + def _make_io_slots(self): # noqa: C901 + from torch.export.graph_signature import ( + InputKind, + OutputKind, + SymIntArgument, + TensorArgument, + ) + + output_kind_targets = defaultdict(set) + constant_tensors = [] + user_inputs = [] + user_outputs = [] + mutable_buffers = [] + + for ospec in self.ep.graph_signature.output_specs: + kind = ospec.kind + arg = ospec.arg + name = arg.name + target = ospec.target + if target is not None: + output_kind_targets[kind].add(target) + if kind in (OutputKind.USER_OUTPUT, OutputKind.USER_INPUT_MUTATION): + user_outputs.append(name) + + for ispec in self.ep.graph_signature.input_specs: + kind = ispec.kind + arg = ispec.arg + name = arg.name + target = ispec.target + + if isinstance(arg, TensorArgument): + if kind == InputKind.PARAMETER: + # Parameters are treated as constants (not mutated) + constant_tensors.append(name) + elif kind == InputKind.BUFFER: + if target in output_kind_targets[OutputKind.BUFFER_MUTATION]: + mutable_buffers.append(name) + else: + # Non-mutated buffers (like lifted tensor constants) are constants + constant_tensors.append(name) + elif kind == InputKind.USER_INPUT: + user_inputs.append(name) + elif kind == InputKind.CONSTANT_TENSOR: + constant_tensors.append(name) + else: + raise NotImplementedError( + f"Support for input {arg} is not implemented" + ) + elif isinstance(arg, SymIntArgument): + if kind == InputKind.USER_INPUT: + user_inputs.append(name) + else: + raise NotImplementedError( + f"Support for input {arg} is not implemented" + ) + else: + raise NotImplementedError(f"Support for input {arg} is not implemented") + + for node in self.ep.graph.nodes: + if node.op == "placeholder": + if node.users == {}: + continue + if node.name in constant_tensors: + self.make_or_get_slot(node, id_space=IdSpace.Constant) + elif node.name in user_inputs: + val = node.meta.get("val", None) + if isinstance(val, torch.Tensor) and not val.is_contiguous(): + raise ValueError( + f"MLX backend requires contiguous input tensors, " + f"but input '{node.name}' has non-contiguous strides. " + f"shape={list(val.shape)}, stride={list(val.stride())}. " + f"Ensure example inputs passed to torch.export.export() " + f"are contiguous (call .contiguous() on them)." + ) + self.make_or_get_slot(node, id_space=IdSpace.Input) + elif node.name in mutable_buffers: + self.make_or_get_slot(node, id_space=IdSpace.MutableBuffer) + else: + raise NotImplementedError( + f"Support for placeholder {node.name} is not implemented" + ) + elif node.op == "output": + outs, _ = pytree.tree_flatten(node.args) + for o in outs: + if isinstance(o, Node) and o.name in user_outputs: + self.make_or_get_slot(o, id_space=IdSpace.Output) + + def _mark_noop(self): + """Mark noops and dead nodes.""" + dead = set() + noop_handler = REGISTRY.get_noop_handler() + if noop_handler is None: + return + + for n in reversed(self.ep.graph.nodes): + handler = REGISTRY.get_handler(n) + if handler == noop_handler: + dead.add(n) + + if n.op != "output" and all(user in dead for user in n.users): + self.node_info[n].handler = noop_handler + dead.add(n) + + def _apply_patterns(self) -> None: + """ + Find and apply pattern handlers to the graph. + + Uses PatternMatcher to discover multi-node patterns and assigns + handlers to matched nodes. This must run BEFORE _mark_noop so + pattern body nodes don't get incorrectly marked as dead. + """ + matcher = PatternMatcher(self.ep, REGISTRY) + for handler in matcher.find_patterns(): + handler.set_handlers(self) + + def _process_nodes(self) -> None: # noqa C901 + """ + Common logic for processing all nodes: create slots, match patterns, run handlers. + + This method: + 1. Creates I/O slots for placeholders and outputs + 2. Matches patterns FIRST (so body nodes get handlers and aren't marked dead) + 3. Marks dead/noop nodes + 4. Runs handlers for remaining nodes, marking them supported/unsupported + + The ordering is important: patterns must be matched before noops because + some pattern body nodes (e.g., update_cache) have no users since they + mutate in-place, but they're not dead - they're handled by the pattern. + """ + self._make_io_slots() + + # Apply patterns BEFORE _mark_noop so pattern body nodes don't get + # incorrectly marked as dead (e.g., update_cache nodes have no users + # since they mutate in-place, but they're not dead) + self._apply_patterns() + self._mark_noop() + + for n in self.ep.graph.nodes: + if self._is_handled(n): + continue + + if self.node_info[n].handler is not None: + handler = self.node_info[n].handler + handler(self, n) + self._mark_supported(n, handler=handler) + continue + + # Check input dtypes before processing node + unsupported_dtype_msg = _check_input_dtypes(n) + if unsupported_dtype_msg is not None: + if n.meta.get("val", None) is not None: + self.slot_manager.make_or_get_slots(n) + self._mark_unsupported(n, unsupported_dtype_msg) + continue + + if n.op in ("placeholder", "output"): + dtype_error = _check_dtype(n) + if dtype_error is not None: + self._mark_unsupported(n, f"{n.op} {dtype_error}") + continue + self._mark_supported(n) + continue + + handler = REGISTRY.get_handler(n) + if handler is None: + msg = f"no handler for target={n.target}" + if n.meta.get("val", None) is not None: + self.slot_manager.make_or_get_slots(n) + self._mark_unsupported(n, msg) + continue + + try: + handler(self, n) + self._mark_supported(n, handler=handler) + except Exception as e: + trace_str = traceback.format_exc() + msg = f"{handler} failed for {n.target}: {e}.\n{trace_str}" + if n.meta.get("val", None) is not None: + self.slot_manager.make_or_get_slots(n) + self._mark_unsupported(n, msg) + + def check_support_only(self) -> None: + """ + Check which nodes are supported without building the full MLXGraph. + + This method populates node_info with supported/unsupported status for each + node, but avoids calling _build_mlx_graph() which can corrupt the shape_env + by evaluating symbolic shapes. + + Use this method for ops_to_not_decompose() and similar queries where you + only need to know support status, not the full compiled graph. + """ + self._process_nodes() + # NOTE: We intentionally skip _verify_build() and _build_mlx_graph() here + # because _build_mlx_graph() calls int() on tensor shapes which evaluates + # SymInts and corrupts the shape_env. This method is used for + # ops_to_not_decompose() where we only need support status. + + def _emit_buffer_mutation_writebacks(self): + """Emit copy-back instructions for BUFFER_MUTATION outputs. + + When a model mutates a buffer (e.g., via .copy_() or .mul_()), + torch.export functionalizes it: the new value is a computation result, + and the output spec marks it as BUFFER_MUTATION with a target buffer. + + This method emits an IdCopyNode for each BUFFER_MUTATION output, + copying the computation result back to the mutable buffer slot so + the updated value persists across execution calls. + """ + from torch.export.graph_signature import InputKind, OutputKind + + # Map buffer target name -> input placeholder name + target_to_placeholder = {} + for ispec in self.ep.graph_signature.input_specs: + if ispec.kind == InputKind.BUFFER and ispec.target is not None: + target_to_placeholder[ispec.target] = ispec.arg.name + + for ospec in self.ep.graph_signature.output_specs: + if ospec.kind != OutputKind.BUFFER_MUTATION: + continue + + result_slot = self.slot_manager.get_slot(ospec.arg.name) + placeholder_name = target_to_placeholder.get(ospec.target) + if result_slot is None or placeholder_name is None: + continue + + buffer_slot = self.slot_manager.get_slot(placeholder_name) + if buffer_slot is None or buffer_slot.id_space != IdSpace.MutableBuffer: + continue + + self.emit( + IdCopyNode( + x=self.slot_to_tid(result_slot), + out=self.slot_to_tid(buffer_slot), + ) + ) + + def build(self) -> MLXGraph: + if self._mlx_graph is not None: + return self._mlx_graph + + self._process_nodes() + self._emit_buffer_mutation_writebacks() + self._verify_build() + self._mlx_graph = self._build_mlx_graph() + return self._mlx_graph + + def _verify_build(self): + noop_handler = REGISTRY.get_noop_handler() + + for n, info in self.node_info.items(): + assert info.handled + assert ( + info.remaining_reads == 0 + ), f"Expected {n} to have no remaining reads, but it has {info.remaining_reads}" + if n.op == "output": + assert self.slot_manager.get_slot(n) is None + continue + if ( + info.handler in (noop_handler, PatternHandler.deferred_handler) + or n.users == {} + ): + assert ( + self.slot_manager.get_slot(n) is None + ), f"Did not expect node {n} handled by {info.handler} to have a slot" + else: + assert ( + self.slot_manager.get_slot(n) is not None + ), f"Expected slot for node {n}" + + def _collect_used_slots( + self, + ) -> Tuple[Set[Slot], Dict[IdSpace, int], Dict[IdSpace, int]]: + """ + Collect all used slots and count tensors/values per IdSpace. + + For constants and temps, only includes those actually referenced by + instructions. This ensures unused slots are not serialized or counted. + + Returns: + (used_slots, num_tensors, num_values) + """ + # Get slots actually referenced by instructions + instruction_referenced: Set[Slot] = {slot for _, slot in self._tid_slot_map} + instruction_referenced.update({slot for _, slot in self._vid_slot_map}) + + used_slots: Set[Slot] = set() + for _n, slot in self.slot_manager.name_to_slot.items(): + if not isinstance(slot, tuple): + slot = (slot,) + for s in slot: + # For constants and temps, only include if referenced by instructions + if s.id_space in (IdSpace.Constant, IdSpace.Temp): + if s in instruction_referenced: + used_slots.add(s) + else: + # Inputs, outputs, mutable buffers - always include + used_slots.add(s) + + num_tensors: Dict[IdSpace, int] = defaultdict(int) + num_values: Dict[IdSpace, int] = defaultdict(int) + seen: Set[Slot] = set() + for s in used_slots: + if s in seen: + continue + seen.add(s) + if s.id_type == IdType.Tensor: + num_tensors[s.id_space] += 1 + else: + num_values[s.id_space] += 1 + + return used_slots, num_tensors, num_values + + def _create_slot_mappings( + self, used_slots: Set[Slot] + ) -> Tuple[Dict[Slot, int], Dict[Slot, int]]: + """ + Create slot→Tid and slot→Vid mappings, and remap existing references. + + Returns: + (slot_to_tid, slot_to_vid) + """ + id_space_order = { + IdSpace.Constant: 0, + IdSpace.Input: 1, + IdSpace.Output: 2, + IdSpace.MutableBuffer: 3, + IdSpace.Temp: 4, + } + + # Create Tid mapping + slot_to_tid = sorted( + [s for s in used_slots if s.id_type == IdType.Tensor], + key=lambda s: (id_space_order[s.id_space], s.idx), + ) + slot_to_tid = {s: idx for idx, s in enumerate(slot_to_tid)} + + # Create Vid mapping + slot_to_vid = sorted( + [s for s in used_slots if s.id_type != IdType.Tensor], + key=lambda s: (id_space_order[s.id_space], s.idx), + ) + slot_to_vid = {s: idx for idx, s in enumerate(slot_to_vid)} + + # Remap all Tid/Vid values in instructions to use global indices + if hasattr(self, "_tid_slot_map"): + for tid, slot in self._tid_slot_map: + if slot in slot_to_tid: + tid.idx = slot_to_tid[slot] + else: + logger.warning(f"Slot {slot} not found in slot_to_tid mapping") + + if hasattr(self, "_vid_slot_map"): + for vid, slot in self._vid_slot_map: + if slot in slot_to_vid: + vid.idx = slot_to_vid[slot] + else: + logger.warning(f"Slot {slot} not found in slot_to_vid mapping") + + return slot_to_tid, slot_to_vid + + def _to_slot_variant( + self, + slot: Slot, + slot_to_tid: Dict[Slot, int], + slot_to_vid: Dict[Slot, int], + ) -> SlotVariant: + """Convert a Slot to a SlotVariant using the provided mappings.""" + if slot.id_type == IdType.Tensor: + idx = slot_to_tid[slot] + slot_type = SlotType.TensorSlot + elif slot.id_type == IdType.SymInt: + idx = slot_to_vid[slot] + slot_type = SlotType.IntValueSlot + elif slot.id_type == IdType.SymBool: + idx = slot_to_vid[slot] + slot_type = SlotType.BoolValueSlot + else: + raise NotImplementedError(f"Unsupported slot type {slot.id_type}") + return SlotVariant(idx=idx, slot_type=slot_type) + + def _build_io_maps( + self, + used_slots: Set[Slot], + slot_to_tid: Dict[Slot, int], + slot_to_vid: Dict[Slot, int], + ) -> Tuple[ + List[SlotVariant], List[SlotVariant], List[SlotVariant], List[NamedSlot] + ]: + """ + Build input/output/mutable_buffer maps and named slots. + + Returns: + (input_map, output_map, mutable_buffer_map, named_slots) + """ + input_map: List[SlotVariant] = [] + output_map: List[SlotVariant] = [] + mutable_buffer_map: List[SlotVariant] = [] + # Canonical (unprefixed) name → Slot. The prefix is applied only at + # the exit boundaries: NamedSlot construction and NamedDataStore keys. + name_to_slot: Dict[str, Slot] = {} + + for ispec in self.ep.graph_signature.input_specs: + slot = self.slot_manager.get_slot(ispec.arg.name) + if slot is None: + continue + assert isinstance(slot, Slot) + name = ispec.target if ispec.target is not None else ispec.arg.name + if slot.id_space == IdSpace.Input: + input_map.append(self._to_slot_variant(slot, slot_to_tid, slot_to_vid)) + name_to_slot[name] = slot + elif slot.id_space == IdSpace.MutableBuffer: + mutable_buffer_map.append( + self._to_slot_variant(slot, slot_to_tid, slot_to_vid) + ) + name_to_slot[name] = slot + else: + if slot in used_slots: + name_to_slot[name] = slot + + for ospec in self.ep.graph_signature.output_specs: + name = ospec.arg.name + slot = self.slot_manager.get_slot(name) + if slot is None: + continue + assert isinstance(slot, Slot) + if slot.id_space == IdSpace.Output: + output_map.append(self._to_slot_variant(slot, slot_to_tid, slot_to_vid)) + name = ospec.target if ospec.target is not None else ospec.arg.name + name_to_slot[name] = slot + elif slot.id_space == IdSpace.MutableBuffer: + name = ospec.target if ospec.target is not None else ospec.arg.name + name_to_slot[name] = slot + + for name in self.extra_constants: + slot = self.slot_manager.get_slot(name) + assert slot is not None and isinstance(slot, Slot) + if slot in used_slots: + name_to_slot[name] = slot + + # Store unprefixed constant mapping for get_named_data_store() + self._constant_name_to_slot = { + n: s for n, s in name_to_slot.items() if s.id_space == IdSpace.Constant + } + + # Apply prefix at the exit boundary — the FlatBuffer named_slots + named_slots = [ + NamedSlot( + name=self._prefix_key(n), + slot=self._to_slot_variant(s, slot_to_tid, slot_to_vid), + ) + for n, s in name_to_slot.items() + ] + + return input_map, output_map, mutable_buffer_map, named_slots + + def _build_tensor_meta( # noqa: C901 + self, + used_slots: Set[Slot], + slot_to_tid: Dict[Slot, int], + slot_to_vid: Dict[Slot, int], + num_tensors: Dict[IdSpace, int], + ) -> List[TensorMeta]: + """ + Build tensor metadata list with shape/dtype information. + + Static dimensions are stored as ShapeDim(value=N). + Dynamic dimensions (SymInt) are stored as ShapeDim(value=-1) + with min/max bounds from the shape_env. + + Note: tensor_meta shapes are only consumed by the runtime for + constant and mutable buffer allocation (which are always static). + Dynamic dim metadata is informational — the runtime resolves + dynamic shapes via SymSizeNode at execution time. + """ + + def _get_dim_bounds(dim: torch.SymInt) -> tuple: + """Get (min, max) bounds for a symbolic dimension.""" + try: + node = dim.node + shape_env = node.shape_env + if shape_env is not None: + expr = node.expr + lower = int(shape_env.bound_sympy(expr).lower) + upper = int(shape_env.bound_sympy(expr).upper) + if upper > 2**30: + return (lower, -1) # treat as unbounded + return (lower, upper) + except Exception: + pass + return (0, -1) # unbounded fallback + + def to_tensor_meta(t: torch.Tensor) -> TensorMeta: + shape: List[ShapeDim] = [] + for dim in t.shape: + if isinstance(dim, torch.SymInt): + lo, hi = _get_dim_bounds(dim) + shape.append(ShapeDim(value=-1, min_value=lo, max_value=hi)) + else: + shape.append(ShapeDim(value=int(dim))) + + dim_order = list(range(len(t.shape))) if len(t.shape) > 0 else None + + return TensorMeta( + shape=shape, + scalar_type=torch_dtype_to_scalar_type(t.dtype), + dim_order=dim_order, + ) + + tensor_meta: Dict[int, TensorMeta] = {} + for n in self.node_info: + slot = self.slot_manager.get_slot(n) + if not isinstance(slot, tuple): + slot = (slot,) + fake_val = n.meta.get("val", None) + if not isinstance(fake_val, tuple): + fake_val = (fake_val,) + for s, fv in zip(slot, fake_val): + if s not in used_slots: + continue + if s.id_type != IdType.Tensor: + continue + if s.id_space == IdSpace.Temp: + continue + idx = slot_to_tid[s] + if fv is not None and hasattr(fv, "shape"): + tensor_meta[idx] = to_tensor_meta(fv) + + for name, t in self.extra_constants.items(): + slot = self.slot_manager.get_slot(name) + assert slot is not None and isinstance(slot, Slot) + if slot in used_slots: + idx = slot_to_tid[slot] + tensor_meta[idx] = to_tensor_meta(t) + + num_non_temp_tensors = sum(num_tensors.values()) - num_tensors[IdSpace.Temp] + return [tensor_meta.get(i) for i in range(num_non_temp_tensors)] + + def _build_mlx_graph(self) -> MLXGraph: + # Check support + for node, info in self.node_info.items(): + if not info.supported: + raise ValueError( + f"Found unsupported node: {node}\nReason: {info.unsupported_reason}" + ) + + # Collect slots and create mappings + used_slots, num_tensors, num_values = self._collect_used_slots() + slot_to_tid, slot_to_vid = self._create_slot_mappings(used_slots) + + # Store for use in get_constant_data() - needed to serialize in Tid order + self._slot_to_final_tid = slot_to_tid + + # Build I/O maps and metadata + input_map, output_map, mutable_buffer_map, named_slots = self._build_io_maps( + used_slots, slot_to_tid, slot_to_vid + ) + tensor_meta_list = self._build_tensor_meta( + used_slots, slot_to_tid, slot_to_vid, num_tensors + ) + + # Compute final counts + num_constant_tensors = num_tensors[IdSpace.Constant] + num_temp_tensors = num_tensors[IdSpace.Temp] + num_values_count = sum(num_values.values()) + + return MLXGraph( + version="1", + num_constant_tensors=num_constant_tensors, + num_input_tensors=num_tensors[IdSpace.Input], + num_output_tensors=num_tensors[IdSpace.Output], + num_mutable_buffer_tensors=num_tensors[IdSpace.MutableBuffer], + num_temp_tensors=num_temp_tensors, + num_values=num_values_count, + instruction_chains=[InstructionChain(instructions=self._instrs)], + main_chain_idx=0, + init_chain_idx=-1, + input_map=input_map, + output_map=output_map, + mutable_buffer_map=mutable_buffer_map, + named_slots=named_slots, + tensor_meta=tensor_meta_list, + ) + + def get_named_data_store(self) -> NamedDataStore: + """ + Get a NamedDataStore containing all constant tensors. + + Uses the unprefixed canonical-name → Slot mapping built by + ``_build_io_maps()`` so that tensor lookups hit ``ep.state_dict`` / + ``ep.constants`` / ``extra_constants`` (which all use unprefixed + keys). The prefix is applied at the exit boundary — the + ``NamedDataStore`` key — so it matches the FlatBuffer ``named_slots``. + """ + named_data_store = NamedDataStore() + + # Sort by final TID for deterministic ordering + entries = sorted( + self._constant_name_to_slot.items(), + key=lambda x: self._slot_to_final_tid.get(x[1], 0), + ) + + logger.debug(f"Adding {len(entries)} constants to NamedDataStore...") + for canonical_name, _slot in entries: + tensor = self._find_constant_tensor(canonical_name) + if tensor is None: + continue + + t = tensor.detach().cpu().contiguous() + named_data_store.add_named_data( + key=self._prefix_key(canonical_name), + data=t, + alignment=16, + ) + logger.debug("Done adding constants to NamedDataStore") + + return named_data_store + + def get_mutable_buffer_names(self) -> List[str]: + """ + Get the names of all mutable buffers in Tid order. + + Returns: + List of mutable buffer names in the order they appear in mutable_buffer_map. + """ + assert self._mlx_graph is not None, "Must call build() first" + + names = [] + for name, slot in self.slot_manager.name_to_slot.items(): + if isinstance(slot, tuple): + continue + if slot.id_space != IdSpace.MutableBuffer: + continue + if slot in self._slot_to_final_tid: + names.append((name, self._slot_to_final_tid[slot])) + + # Sort by Tid and return just the names + names.sort(key=lambda x: x[1]) + return [n for n, _ in names] + + def _find_constant_tensor(self, name: str) -> Optional[torch.Tensor]: + """Find a constant tensor by name from various sources.""" + if name in self.ep.state_dict: + return self.ep.state_dict[name] + if name in self.ep.constants: + return self.ep.constants[name] + if name in self.extra_constants: + return self.extra_constants[name] + # Look up by target + for ispec in self.ep.graph_signature.input_specs: + if ispec.arg.name == name and ispec.target is not None: + if ispec.target in self.ep.state_dict: + return self.ep.state_dict[ispec.target] + if ispec.target in self.ep.constants: + return self.ep.constants[ispec.target] + return None diff --git a/backends/mlx/builder/slot_manager.py b/backends/mlx/builder/slot_manager.py new file mode 100644 index 00000000000..b1884a76a68 --- /dev/null +++ b/backends/mlx/builder/slot_manager.py @@ -0,0 +1,187 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +import uuid +from collections import defaultdict +from dataclasses import dataclass +from enum import auto, Enum +from typing import Dict, Optional, Tuple, Union + +import torch +from torch.fx.node import Node + + +class IdType(Enum): + Tensor = auto() + SymInt = auto() + SymBool = auto() + + +class IdSpace(Enum): + Constant = auto() + Input = auto() + Output = auto() + MutableBuffer = auto() + Temp = auto() + + +@dataclass(frozen=True) +class Slot: + id_type: IdType + id_space: IdSpace + idx: Optional[int] = None + + +class IdManager: + def __init__(self): + self.free: set[int] = set() + self.next_new_id = 0 + + def get_id(self): + return self.free.pop() if self.free else self._bump() + + def _bump(self): + idx = self.next_new_id + self.next_new_id += 1 + return idx + + def return_id(self, idx): + self.free.add(idx) + + +class SlotManager: + def __init__(self): + self.tid_managers: Dict[IdSpace, IdManager] = defaultdict(IdManager) + self.vid_managers: Dict[IdSpace, IdManager] = defaultdict(IdManager) + self.name_to_slot: Dict[str, Slot] = {} + + def set_slot(self, node_or_name: Union[Node, str], slot: Slot): + if isinstance(node_or_name, Node): + node_or_name = node_or_name.name + # Allow setting a slot to the same value (e.g., for in-place ops like SLICE_UPDATE) + existing = self.name_to_slot.get(node_or_name) + if existing is not None: + # If already set to the same slot, it's fine + if existing == slot: + return + raise AssertionError( + f"Slot for {node_or_name} already set to {existing}, trying to set to {slot}" + ) + self.name_to_slot[node_or_name] = slot + + def get_slot( + self, node_or_name: Union[Node, str] + ) -> Optional[Union[Tuple[Slot], Slot]]: + if isinstance(node_or_name, Node): + node_or_name = node_or_name.name + return self.name_to_slot.get(node_or_name, None) + + def _val_to_idtype(self, v) -> IdType: + from torch._subclasses.fake_tensor import FakeTensor + + if isinstance(v, FakeTensor): + return IdType.Tensor + elif isinstance(v, torch.SymInt): + return IdType.SymInt + elif isinstance(v, torch.SymBool): + return IdType.SymBool + else: + raise NotImplementedError(f"val_to_idtype: {v}") + + def is_alive(self, slot: Slot) -> bool: + if slot.id_type == IdType.Tensor: + manager = self.tid_managers[slot.id_space] + else: + manager = self.vid_managers[slot.id_space] + idx = slot.idx + if idx >= manager.next_new_id: + return False + if idx in manager.free: + return False + return True + + def make_constant_slot(self, name: str) -> Slot: + assert name not in self.name_to_slot + id_space = IdSpace.Constant + manager = self.tid_managers[id_space] + idx = manager.get_id() + slot = Slot(id_type=IdType.Tensor, id_space=id_space, idx=idx) + self.name_to_slot[name] = slot + return slot + + def make_tmp_slot(self) -> Tuple[str, Slot]: + name = f"tmp_{uuid.uuid4().hex}" + id_space = IdSpace.Temp + manager = self.tid_managers[id_space] + idx = manager.get_id() + slot = Slot(id_type=IdType.Tensor, id_space=id_space, idx=idx) + self.name_to_slot[name] = slot + return name, slot + + def make_tmp_value_slot(self) -> Tuple[str, Slot]: + """Create a temporary SymInt slot and register it.""" + name = f"tmp_val_{uuid.uuid4().hex}" + id_space = IdSpace.Temp + manager = self.vid_managers[id_space] + idx = manager.get_id() + slot = Slot(id_type=IdType.SymInt, id_space=id_space, idx=idx) + self.name_to_slot[name] = slot + return name, slot + + def make_or_get_slots( + self, node: Node, id_space: IdSpace = IdSpace.Temp + ) -> Tuple[Slot, ...]: + """ + Get or create slots for a node. Always returns a tuple of slots. + + Use this for multi-output ops (e.g., topk returns (values, indices)). + For single-output ops, prefer make_or_get_slot() which returns a single Slot. + """ + if node.name in self.name_to_slot: + slot = self.name_to_slot[node.name] + # Normalize to tuple for consistent return type + if not isinstance(slot, tuple): + return (slot,) + return slot + + val = node.meta.get("val", None) + assert val is not None, f"Node {node} has no val" + if not isinstance(val, (list, tuple)): + val = (val,) + + slots = [] + for v in val: + id_type = self._val_to_idtype(v) + if id_type == IdType.Tensor: + manager = self.tid_managers[id_space] + else: + manager = self.vid_managers[id_space] + idx = manager.get_id() + slots.append(Slot(id_type=id_type, id_space=id_space, idx=idx)) + slots = tuple(slots) + + # Store in the format that matches the node's output structure + if len(slots) == 1: + self.set_slot(node, slots[0]) + else: + self.set_slot(node, slots) + return slots + + def make_or_get_slot(self, node: Node, id_space: IdSpace = IdSpace.Temp) -> Slot: + """ + Get or create a slot for a single-output node. Returns a single Slot. + + Use this for single-output ops (the common case). + For multi-output ops, use make_or_get_slots() instead. + """ + slots = self.make_or_get_slots(node, id_space) + assert len(slots) == 1, ( + f"Expected single output for node {node.name}, got {len(slots)}. " + f"Use make_or_get_slots() for multi-output ops." + ) + return slots[0] diff --git a/backends/mlx/custom_ops.py b/backends/mlx/custom_ops.py new file mode 100644 index 00000000000..81853adbd6d --- /dev/null +++ b/backends/mlx/custom_ops.py @@ -0,0 +1,15 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +Custom MLX operator definitions. + +This module defines custom operators that are supported by the MLX backend. +These ops are used during model export to represent operations that MLX +can execute efficiently but may not have direct PyTorch equivalents. +""" diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py new file mode 100644 index 00000000000..6e8516e86b1 --- /dev/null +++ b/backends/mlx/ops.py @@ -0,0 +1,294 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +MLX Op Handlers - registered handlers for converting ATen/custom ops to MLX. + +This module contains all the op handler functions registered with the MLXOpRegistry. +Each handler converts a specific PyTorch operation to the corresponding MLX graph node. +""" + +from __future__ import annotations + +import operator +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +import torch +from executorch.backends.mlx.builder.op_registry import REGISTRY +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.serialization.mlx_graph_schema import AddmmNode +from torch.fx.node import Node + + +def require_static_int(value: Any, param_name: str, op_name: str) -> None: + """ + Validate that a parameter is a static integer (not a Slot/SymInt). + + Raises NotImplementedError if the value is dynamic. + + Args: + value: The parameter value to check + param_name: Name of the parameter (for error message) + op_name: Name of the operation (for error message) + """ + if isinstance(value, Slot) or not isinstance(value, int): + raise NotImplementedError( + f"{op_name} with dynamic {param_name} is not supported. " + f"{param_name} requires a static int32 value, but got {value} (type={type(value).__name__})." + ) + + +def require_static_float(value: Any, param_name: str, op_name: str) -> None: + """ + Validate that a parameter is a static float (not a Slot/SymFloat). + + Raises NotImplementedError if the value is dynamic. + + Args: + value: The parameter value to check + param_name: Name of the parameter (for error message) + op_name: Name of the operation (for error message) + """ + if isinstance(value, Slot) or not isinstance(value, (int, float)): + raise NotImplementedError( + f"{op_name} with dynamic {param_name} is not supported. " + f"{param_name} requires a static float value, but got {value} (type={type(value).__name__})." + ) + + +def require_static_ints( + values: Union[List[Any], Any], param_name: str, op_name: str +) -> None: + """ + Validate that all values in a list are static integers (not Slots/SymInts). + + Raises NotImplementedError if any value is dynamic. + + Args: + values: List of values to check, or a single value + param_name: Name of the parameter (for error message) + op_name: Name of the operation (for error message) + """ + if not isinstance(values, list): + values = [values] + + for v in values: + require_static_int(v, param_name, op_name) + + +def require_args( + args: List[Any], + min_count: int, + max_count: int, + op_name: str, +) -> None: + """ + Validate that args count is within expected range. + + Raises ValueError if the count is outside the expected range. + + Args: + args: The handler args list + min_count: Minimum number of args expected + max_count: Maximum number of args expected + op_name: Name of the operation (for error message) + """ + if not (min_count <= len(args) <= max_count): + if min_count == max_count: + raise ValueError(f"{op_name}: expected {min_count} args, got {len(args)}") + raise ValueError( + f"{op_name}: expected {min_count}-{max_count} args, got {len(args)}" + ) + + +def require_kwargs( + kwargs: Dict[str, Any], + allowed: Set[str], + op_name: str, +) -> None: + """ + Validate that only allowed kwargs are present. + + Raises ValueError if unexpected kwargs are found. + + Args: + kwargs: The handler kwargs dict + allowed: Set of allowed kwarg names + op_name: Name of the operation (for error message) + """ + unexpected = set(kwargs.keys()) - allowed + if unexpected: + raise ValueError(f"{op_name}: unexpected kwargs: {unexpected}") + + +def require_contiguous_format( + *, + layout=None, + memory_format=None, + dim_order=None, + op_name: str, +) -> None: + """ + Validate that layout/memory_format/dim_order specify contiguous format. + + MLX only supports contiguous (strided) tensors. Raises ValueError if + sparse layouts or non-contiguous memory formats are requested. + + Args: + layout: The torch layout (e.g., torch.strided, torch.sparse_coo) + memory_format: The torch memory format (e.g., torch.contiguous_format, + torch.channels_last) + dim_order: The dimension order (list of ints, identity = contiguous) + op_name: Name of the operation (for error message) + """ + if layout is not None and layout != torch.strided: + raise ValueError(f"{op_name}: only strided layout supported, got {layout}") + + if memory_format is not None and memory_format not in ( + torch.contiguous_format, + torch.preserve_format, + ): + raise ValueError( + f"{op_name}: only contiguous memory format supported, got {memory_format}" + ) + + if dim_order is not None: + if list(dim_order) != list(range(len(dim_order))): + raise ValueError( + f"{op_name}: only contiguous dim_order supported, got {dim_order}" + ) + + +def is_static_value(value: Any) -> bool: + """ + Check if a value is static (not a Slot/SymInt). + + Returns: + True if the value is a static scalar (int, float, bool), False otherwise + """ + return not isinstance(value, Slot) + + +def used_getitem_indices(n: Node) -> Set[int]: + """Return the set of getitem indices actually consumed downstream. + + Only includes indices where the getitem node has at least one user. + """ + return { + user.args[1] + for user in n.users + if user.target == operator.getitem and len(user.users) > 0 + } + + +def normalize_reduction_dim( + args: List[Any], start_idx: int = 1 +) -> Tuple[Optional[List[int]], bool]: + """ + Normalize dim argument for reduction operations. + + Extracts and normalizes the dim argument from handler args, returning a list of axes + and the keepdim flag. Handles both list-based dims (e.g., sum.dim_IntList) and + single int dims (e.g., prod.dim_int). + + Args: + args: The handler args list + start_idx: Index where the dim argument starts (default 1, after self) + + Returns: + Tuple of (axes, keepdim) where: + - axes: List of dimension indices, or empty list for reduce-all + - keepdim: Boolean keepdim flag (default False) + """ + if len(args) > start_idx and isinstance(args[start_idx], (list, tuple)): + dim = list(args[start_idx]) + keepdim = args[start_idx + 1] if len(args) > start_idx + 1 else False + elif len(args) > start_idx and isinstance(args[start_idx], int): + dim = [args[start_idx]] + keepdim = args[start_idx + 1] if len(args) > start_idx + 1 else False + else: + dim = [] + keepdim = False + + return dim, keepdim + + +@REGISTRY.register(target=[torch.ops.aten.addmm.default]) +def _addmm_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle addmm: self + (mat1 @ mat2). + + addmm(self, mat1, mat2, *, beta=1, alpha=1) computes: + beta * self + alpha * (mat1 @ mat2) + + This is typically the result of decomposing linear(x, w, b) in Edge IR: + permute(w) -> addmm(b, x, permuted_w) + + For the common case where beta=1 and alpha=1, this is equivalent to: + mat1 @ mat2 + self + + We use AddmmNode which calls matmul directly (no transposition needed). + """ + args = P.args(n) + kwargs = P.kwargs(n) + require_args(args, 3, 3, "aten.addmm") + require_kwargs(kwargs, {"beta", "alpha"}, "aten.addmm") + bias, mat1, mat2 = args[0], args[1], args[2] + + beta = kwargs.get("beta", 1) + alpha = kwargs.get("alpha", 1) + + out = P.make_or_get_slot(n) + + # Emit AddmmNode with alpha and beta parameters + P.emit( + AddmmNode( + mat1=P.slot_to_tid(mat1), + mat2=P.slot_to_tid(mat2), + out=P.slot_to_tid(out), + bias=P.slot_to_tid(bias), + alpha=float(alpha), + beta=float(beta), + ) + ) + return out + + +@REGISTRY.register( + target=[ + torch.ops.aten.mm.default, + torch.ops.aten.bmm.default, + torch.ops.aten.matmul.default, + ] +) +def _mm_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle mm/bmm/matmul: matrix multiplication without bias. + + All three ops compute matrix products with different dimension expectations: + - mm: 2D x 2D + - bmm: 3D x 3D (batched) + - matmul: arbitrary dimensions (NumPy semantics) + + MLX's matmul handles all cases, so we emit AddmmNode with bias=None. + """ + args = P.args(n) + require_args(args, 2, 2, "aten.mm/bmm/matmul") + require_kwargs(P.kwargs(n), set(), "aten.mm/bmm/matmul") + mat1, mat2 = args[0], args[1] + + out = P.make_or_get_slot(n) + + P.emit( + AddmmNode( + mat1=P.slot_to_tid(mat1), + mat2=P.slot_to_tid(mat2), + out=P.slot_to_tid(out), + bias=None, + ) + ) + return out diff --git a/backends/mlx/partitioner.py b/backends/mlx/partitioner.py new file mode 100644 index 00000000000..0896cafc301 --- /dev/null +++ b/backends/mlx/partitioner.py @@ -0,0 +1,298 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +MLX Partitioner - decides which ops should run on the MLX delegate. + +This module provides a Partitioner implementation that analyzes an EdgeIR +graph and marks supported operations for delegation to MLX. +""" + +from __future__ import annotations + +import inspect +from typing import Any, Callable, Dict, List, Tuple, Union + +import torch +from executorch.backends.mlx._logging import logger +from executorch.backends.mlx.preprocess import MLXBackend +from executorch.exir.backend.backend_details import CompileSpec +from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( + generate_partitions_from_list_of_nodes, +) +from executorch.exir.backend.partitioner import ( + DelegationSpec, + Partitioner, + PartitionResult, +) +from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer +from torch.export.exported_program import ExportedProgram +from torch.fx.passes.infra.partitioner import Partition +from torch.fx.passes.operator_support import OperatorSupportBase + + +class MLXOperatorSupport(OperatorSupportBase): + """ + Determines which operators are supported by the MLX delegate. + + Uses MLXProgramBuilder to determine support - this ensures the partitioner + uses the exact same logic as the actual compilation. A node is supported + if the builder can handle it (either via direct handler or pattern match). + """ + + def __init__( + self, + edge_program: torch.export.ExportedProgram, + compile_specs: List[CompileSpec], + ): + self.edge_program = edge_program + self.compile_specs = compile_specs + + # Run the builder to determine which nodes are supported + # The builder populates node_info with supported/unsupported status + from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder + + self._builder = MLXProgramBuilder(edge_program) + self._builder.check_support_only() + + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + + # Check if builder determined this node is supported + info = self._builder.node_info.get(node) + if info is not None and info.supported: + logger.debug(f"[SUPPORTED] Node {node.target}") + return True + + logger.debug(f"[UNSUPPORTED] Node {node.target}") + return False + + +class MLXPartitioner(Partitioner): + """ + Partitioner for the MLX delegate. + + Analyzes an EdgeIR graph and partitions supported operations + for delegation to MLX. + """ + + def __init__(self, compile_specs: List[CompileSpec] | None = None) -> None: + self.compile_specs = compile_specs or [] + self.delegation_spec = DelegationSpec(MLXBackend.__name__, self.compile_specs) + self.partition_tags: Dict[str, DelegationSpec] = {} + + def ops_to_not_decompose( + self, ep: ExportedProgram + ) -> tuple[list[torch._ops.OpOverload], Callable[[torch.fx.Node], bool] | None]: + """ + Return ops that should NOT be decomposed during edge lowering. + + This runs the MLXProgramBuilder to trace through the graph and determine + which nodes are supported (either via direct handlers or patterns). + Only ops for nodes that are actually supported should be preserved. + + This is called by to_edge_transform_and_lower to determine which + ops to preserve before partitioning. + + NOTE: We use check_support_only() instead of build() to avoid corrupting + the shape_env. build() calls _build_mlx_graph() which evaluates SymInts + to concrete values when converting tensor shapes, which corrupts the + shape_env and causes dynamic shapes to be lost during decomposition. + """ + from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder + + # Check if the graph already contains lowered modules (post-partitioning pass) + # In this case, we should return empty since partitioning is already done + for node in ep.graph.nodes: + if node.op == "get_attr" and "lowered_module" in node.name: + logger.debug( + "MLX ops_to_not_decompose: Graph already partitioned, returning empty" + ) + return ([], None) + + # Run the builder to determine which nodes are supported + builder = MLXProgramBuilder(ep) + builder.check_support_only() + + # Collect ops for nodes that are actually supported + do_not_decompose: list[torch._ops.OpOverload] = [] + + for node in ep.graph.nodes: + if node.op == "call_function" and isinstance( + node.target, torch._ops.OpOverload + ): + info = builder.node_info.get(node) + if info is not None and info.supported: + if node.target not in do_not_decompose: + do_not_decompose.append(node.target) + + logger.debug( + f"MLX ops_to_not_decompose: {[str(op) for op in do_not_decompose]}" + ) + return (do_not_decompose, None) + + def generate_partitions(self, edge_program: ExportedProgram) -> List[Any]: + """Generate partitions of supported nodes.""" + self.supported_ops = MLXOperatorSupport( + edge_program=edge_program, + compile_specs=self.delegation_spec.compile_specs, + ) + + # Collect unsupported ops, aggregated by target + unsupported_by_target: Dict[str, Tuple[int, str]] = ( + {} + ) # target -> (count, reason) + for node in edge_program.graph.nodes: + is_supported = self.supported_ops.is_node_supported({}, node) + if not is_supported and node.op == "call_function": + target_str = str(node.target) + info = self.supported_ops._builder.node_info.get(node) + reason = info.unsupported_reason if info else "No handler registered" + if target_str in unsupported_by_target: + count, _ = unsupported_by_target[target_str] + unsupported_by_target[target_str] = (count + 1, reason) + else: + unsupported_by_target[target_str] = (1, reason) + + logger.info("=" * 80) + logger.info("MLX Partitioner: UNSUPPORTED OPS SUMMARY") + logger.info("=" * 80) + if unsupported_by_target: + for target, (count, reason) in unsupported_by_target.items(): + logger.info(f" [UNSUPPORTED x{count}] {target}") + logger.info(f" Reason: {reason}") + else: + logger.info(" (All call_function nodes are supported!)") + logger.info("=" * 80) + + partitions = generate_partitions_from_list_of_nodes( + edge_program.graph_module, + op_support=self.supported_ops, + ) + + # WORKAROUND: Include sym_size nodes in partitions when any of their + # users are in the partition. Without this, sym_size nodes stay outside + # the partition and their results cross the partition boundary as concrete + # inputs, losing dynamic shape information during delegate lowering. + # By pulling them inside, the MLX runtime can execute SYM_SIZE at runtime, + # keeping shapes dynamic. + partitions = self._include_sym_size_nodes_in_partitions( + edge_program.graph_module, partitions + ) + + return partitions + + def _include_sym_size_nodes_in_partitions( + self, gm: torch.fx.GraphModule, partitions: List[Partition] + ) -> List[Partition]: + """ + Include sym_size nodes in partitions when any of their users are in the partition. + + This is a workaround for the dynamic shapes bug where symbolic shapes are lost + during delegate lowering if the sym_size node is not included in the partition. + """ + from executorch.exir.dialects.edge._ops import EdgeOpOverload + + for partition in partitions: + partition_nodes = set(partition.nodes) + nodes_to_add = [] + + for node in gm.graph.nodes: + if node.op != "call_function": + continue + + # Check if this is a sym_size node + target = node.target + if isinstance(target, EdgeOpOverload): + target = target._op + + if target != torch.ops.aten.sym_size.int: + continue + + # Check if any user of this sym_size node is in the partition + for user in node.users: + if user in partition_nodes: + # Add sym_size to partition if not already there + if node not in partition_nodes: + nodes_to_add.append(node) + logger.debug( + f"Adding sym_size node {node.name} to partition " + f"(used by {user.name})" + ) + break + + # Add the sym_size nodes to the partition + for node in nodes_to_add: + partition.add_node(node) + + return partitions + + def tag_nodes(self, partitions: List[Partition]) -> None: + """Tag nodes in each partition for delegation.""" + for partition in partitions: + delegation_tag = f"mlx_{partition.id}" + for node in partition.nodes: + node.meta["delegation_tag"] = delegation_tag + self.partition_tags[delegation_tag] = self.delegation_spec + + @staticmethod + def check_partitions(partitions: Union[dict, list]) -> bool: + """Check if any partitions were found.""" + pl = len(partitions) + if pl == 0: + logger.warning("MLX: Nothing can be partitioned!") + else: + logger.info(f"MLX: Found {pl} subgraphs to be partitioned.") + return pl != 0 + + @staticmethod + def _is_to_edge_transform_and_lower() -> bool: + """Check whether we are being called from to_edge_transform_and_lower.""" + for frame_info in inspect.stack(): + if frame_info.function == "to_edge_transform_and_lower": + return True + return False + + def partition(self, edge_program: ExportedProgram) -> PartitionResult: + """ + Partition the edge program for MLX delegation. + + Args: + edge_program: The ExportedProgram to partition. + + Returns: + PartitionResult with tagged nodes and partition specs. + + Raises: + RuntimeError: If called from the deprecated ``to_edge`` workflow. + """ + if not self._is_to_edge_transform_and_lower(): + raise RuntimeError( + "MLXPartitioner must be used with to_edge_transform_and_lower(). " + "The to_edge() + to_backend() workflow is not supported because " + "it decomposes ops that MLX has optimized implementations for. " + "Please use:\n" + " exir.to_edge_transform_and_lower(\n" + ' {"forward": exported_program},\n' + " partitioner=[MLXPartitioner()],\n" + " )" + ) + partitions = self.generate_partitions(edge_program=edge_program) + if self.check_partitions(partitions): + self.tag_nodes(partitions) + # Tag constant data that are used by the supported ops + tag_constant_data(edge_program) + # Tag mutated buffers so they are included in the partition + # This ensures the partitioned subgraph has proper mutation tracking + tag_mutated_buffer(edge_program) + + return PartitionResult( + tagged_exported_program=edge_program, + partition_tags=self.partition_tags, + ) diff --git a/backends/mlx/passes.py b/backends/mlx/passes.py new file mode 100644 index 00000000000..c7efdf561de --- /dev/null +++ b/backends/mlx/passes.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Graph transformation passes for the MLX backend. +""" + +from typing import List + +from executorch.exir.pass_base import ExportPass + + +def get_default_passes() -> List[ExportPass]: + """ + Returns a list of passes that are enabled by default for the MLX backend. + """ + return [] diff --git a/backends/mlx/patches/mlx_json.patch b/backends/mlx/patches/mlx_json.patch new file mode 100644 index 00000000000..4760403c8e6 --- /dev/null +++ b/backends/mlx/patches/mlx_json.patch @@ -0,0 +1,29 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -304,12 +304,18 @@ else() + set(MLX_BUILD_ACCELERATE OFF) + endif() + +-message(STATUS "Downloading json") +-FetchContent_Declare( +- json +- URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz) +-FetchContent_MakeAvailable(json) +-target_include_directories( +- mlx PRIVATE $) ++# Only fetch json if nlohmann_json target doesn't already exist ++# (ExecuTorch provides its own copy) ++if(NOT TARGET nlohmann_json) ++ message(STATUS "Downloading json") ++ FetchContent_Declare( ++ json ++ URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz) ++ FetchContent_MakeAvailable(json) ++ target_include_directories( ++ mlx PRIVATE $) ++else() ++ message(STATUS "Using existing nlohmann_json target") ++endif() + + add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx) diff --git a/backends/mlx/pattern_utils.py b/backends/mlx/pattern_utils.py new file mode 100644 index 00000000000..0d3d86430eb --- /dev/null +++ b/backends/mlx/pattern_utils.py @@ -0,0 +1,360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Shared pattern matching utilities for MLX backend. + +This module provides common utilities used by both: +- passes.py: Graph transformation passes (ExportPass) +- patterns.py: MLX lowering pattern handlers (PatternHandler) + +The core abstraction is the `PatternMatch` base class which provides: +- `maybe_create(head)` - Class method to match a pattern from a head node +- Captured values as typed fields +- `body` list of intermediate nodes to remove + +Usage in passes.py: + class FuseRMSNormPass(ExportPass): + def call(self, graph_module): + for node in graph.nodes: + if match := RMSNormMatch.maybe_create(node): + replacement = self._emit_fused_op(graph, match) + node.replace_all_uses_with(replacement) + match.remove_body_nodes(graph) + +Usage in patterns.py: + class RMSNormHandler(PatternHandler): + @classmethod + def maybe_create(cls, ep, head): + if match := RMSNormMatch.maybe_create(head): + return cls(head, match.body, match.input, match.weight, match.eps) + return None +""" + +from dataclasses import dataclass, field +from typing import Any, Callable, List, Optional, Set, Tuple, Union + +from executorch.backends.mlx.builder.op_helpers import get_aten_target_normalized +from torch.fx import Graph +from torch.fx.node import Node + + +# Type alias for walk_back result entries +# Each entry corresponds to an OpStep: +# - Node: matched node (for regular steps) +# - None: optional step that didn't match +# - List[Node]: repeat step (0 or more matches) +WalkBackEntry = Union[Node, None, List[Node]] + + +def match_target(node: Node, op: Any) -> bool: + """ + Check if a node's normalized aten target matches the given op. + + Uses get_aten_target_normalized to handle edge dialect ops. + This means slice_copy matches slice, etc. + + Args: + node: The node to check + op: The op to match (e.g., torch.ops.aten.mul.Tensor) + """ + return node.op == "call_function" and get_aten_target_normalized(node.target) == op + + +def has_single_user(node: Node) -> bool: + return len(node.users) == 1 + + +def has_no_users(node: Node) -> bool: + return len(node.users) == 0 + + +def extract_lifted_tensor_constant(node: Node) -> Optional[float]: + """ + Extract scalar value from a lifted tensor constant node. + + Lifted constants are created during torch.export and contain small + constant tensors (like epsilon values). The actual value is stored + in node.meta["val"]. + + Args: + node: A node that may be a lifted tensor constant + + Returns: + The scalar float value, or None if not a lifted constant or not scalar + """ + if not isinstance(node, Node): + return None + if "lifted_tensor_constant" not in node.name: + return None + val = node.meta.get("val") + if val is None: + return None + if not hasattr(val, "item"): + return None + try: + return float(val.item()) + except (RuntimeError, ValueError): + return None + + +@dataclass +class OpStep: + """ + One step in a backward walk through the graph. + + Used with walk_back() to define pattern chains. Supports both exact op + matching and predicate-based matching. + + Attributes: + op: Specific op to match (e.g., torch.ops.aten.rsqrt.default) + predicate: Alternative to op - a function that returns True for matching nodes + optional: If True, skip this step if it doesn't match + repeat: If True, match this step 0 or more times (like regex *) + require_single_user: If True (default), only match nodes with exactly one user + nargs: Number of args required. Can be: + - int: minimum number of args (default 1, since we advance via args[0]) + - tuple (min, max): range of args required (inclusive) + kwargs: Set of kwargs we handle (node's kwargs must be subset of this) + arg_index: Which arg to follow when advancing (default 0) + + Examples: + # Match specific op + OpStep(op=torch.ops.aten.rsqrt.default) + + # Match with predicate (for matching families of ops) + OpStep(predicate=lambda n: match_target(n, torch.ops.aten.select.int)) + + # Match chain of same op type (0 or more) + OpStep(op=torch.ops.aten.select.int, repeat=True) + + # Optional dtype conversion + OpStep(op=torch.ops.aten._to_copy.default, optional=True) + + # Require between 2 and 4 args + OpStep(op=torch.ops.aten.some_op.default, nargs=(2, 4)) + + # Declare that we handle 'dtype' kwarg + OpStep(op=torch.ops.aten._to_copy.default, kwargs={"dtype"}) + + # Follow second arg (e.g., mul(x, rsqrt(y)) -> follow rsqrt in args[1]) + OpStep(op=torch.ops.aten.mul.Tensor, arg_index=1) + """ + + op: Any = None + predicate: Optional[Callable[[Node], bool]] = None + optional: bool = False + repeat: bool = False + require_single_user: bool = True + nargs: Union[int, Tuple[int, int]] = 1 + kwargs: Set[str] = field(default_factory=set) # Empty = no kwargs allowed + arg_index: int = 0 + + def matches(self, node: Node) -> bool: + """Check if this step fully matches the given node.""" + # Check op or predicate + if self.op is not None: + if not match_target(node, self.op): + return False + elif self.predicate is not None: + if not self.predicate(node): + return False + else: + return False + + # Check single user requirement + if self.require_single_user and not has_single_user(node): + return False + + # Check nargs and kwargs + if not self._check_nargs(node): + return False + if not self._check_kwargs(node): + return False + + return True + + def _check_nargs(self, node: Node) -> bool: + """Check if node has the required number of args.""" + n = len(node.args) + if isinstance(self.nargs, tuple): + min_args, max_args = self.nargs + # Must be in range AND enough to access arg_index + return min_args <= n <= max_args and n > self.arg_index + else: + # Must have at least nargs, AND enough to access arg_index + return n >= self.nargs and n > self.arg_index + + def _check_kwargs(self, node: Node) -> bool: + """Check that node's kwargs are all declared in self.kwargs (no unhandled kwargs).""" + return set(node.kwargs.keys()).issubset(self.kwargs) + + +def walk_back( # noqa: C901 + node: Node, + steps: List[OpStep], + debug: bool = False, +) -> Optional[Tuple[Node, List[WalkBackEntry]]]: + """ + Walk backwards through a chain of ops, matching against a pattern. + + Starting from *node*, try to match each step against the current node. + At every matched step the walk advances to ``cur.args[step.arg_index]``. + Optional steps are silently skipped when they don't match. Repeat steps + match 0 or more times. + + Args: + node: Starting node + steps: List of OpStep to match in order + + Returns: + ``(base_node, entries)`` if the full chain matches, else ``None``. + *base_node* is the input to the first (deepest) op in the chain. + *entries* is a list with one entry per OpStep: + - Node: matched node (for regular steps) + - None: optional step that didn't match + - List[Node]: repeat step (0 or more matches) + + Examples: + # Match: rsqrt(add(mean(pow(x, 2)), eps)) + result = walk_back(rsqrt_node, [ + OpStep(op=torch.ops.aten.rsqrt.default), + OpStep(op=torch.ops.aten.add.Tensor), + OpStep(op=torch.ops.aten.mean.dim), + OpStep(op=torch.ops.aten.pow.Tensor_Scalar), + ]) + if result: + base, entries = result + rsqrt, add, mean, pow = entries # Each is a Node + + # Match chain of select ops (like tensor[0][0]) + result = walk_back(node, [ + OpStep(op=torch.ops.aten.select.int, repeat=True), + ]) + if result: + base, entries = result + select_nodes = entries[0] # List[Node], may be empty + + # Skip optional _to_copy, then match rsqrt + result = walk_back(node, [ + OpStep(op=torch.ops.aten._to_copy.default, optional=True), + OpStep(op=torch.ops.aten.rsqrt.default), + ]) + if result: + base, entries = result + to_copy, rsqrt = entries # to_copy may be None + """ + entries: List[WalkBackEntry] = [] + cur = node + + for i, step in enumerate(steps): + if not isinstance(cur, Node): + if debug: + print( + f" [walk_back] step {i}: cur is not a Node ({type(cur).__name__})" + ) + return None + + if step.repeat: + # Match 0 or more times, return as list + matched_nodes: List[Node] = [] + while isinstance(cur, Node) and step.matches(cur): + matched_nodes.append(cur) + cur = cur.args[step.arg_index] + entries.append(matched_nodes) + if debug: + print( + f" [walk_back] step {i} (repeat): matched {len(matched_nodes)} nodes" + ) + # repeat always succeeds (matches 0 or more) + continue + + if step.matches(cur): + entries.append(cur) + if debug: + print(f" [walk_back] step {i}: matched {cur.name}") + cur = cur.args[step.arg_index] + elif step.optional: + entries.append(None) + if debug: + print(f" [walk_back] step {i} (optional): skipped, cur={cur.name}") + continue + else: + if debug: + print( + f" [walk_back] step {i}: FAILED at cur={cur.name}, target={cur.target}, step.op={step.op}" + ) + return None + + if not isinstance(cur, Node): + return None + + return cur, entries + + +@dataclass +class PatternMatch: + """ + Base class for pattern match results. + + Subclasses should: + 1. Add fields for captured values (input nodes, constants, etc.) + 2. Implement maybe_create() classmethod for pattern matching + 3. Optionally implement emit_* methods for specific backends + + Example: + @dataclass + class RMSNormMatch(PatternMatch): + input_node: Node + weight_node: Node + eps: float + + @classmethod + def maybe_create(cls, head: Node) -> Optional["RMSNormMatch"]: + # Pattern matching logic... + if not matched: + return None + return cls( + head=head, + body=body_nodes, + input_node=input_node, + weight_node=weight_node, + eps=eps_value, + ) + """ + + head: Node # The output node of the matched pattern + body: List[Node] = field(default_factory=list) # Intermediate nodes + + @classmethod + def maybe_create(cls, head: Node, **context) -> Optional["PatternMatch"]: + """ + Try to match the pattern starting from head node. + + Override in subclasses to implement pattern-specific matching. + + Args: + head: Candidate head node to match from + **context: Additional context (e.g., ExportedProgram for patterns.py) + + Returns: + PatternMatch instance with captured values, or None if no match + """ + return None + + def remove_body_nodes(self, graph: Graph) -> None: + """ + Remove body nodes from the graph (in reverse order for safety). + + Call after replacing head with fused op. + """ + for node in reversed(self.body): + if has_no_users(node): + graph.erase_node(node) + + def all_nodes(self) -> List[Node]: + """Return all nodes in the pattern (head + body).""" + return [self.head] + self.body diff --git a/backends/mlx/patterns.py b/backends/mlx/patterns.py new file mode 100644 index 00000000000..c8bef1f91ca --- /dev/null +++ b/backends/mlx/patterns.py @@ -0,0 +1,14 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +MLX Pattern Handlers - pattern-based op lowering for fused operations. + +This module contains pattern handlers that match multi-node subgraphs and lower +them to optimized MLX operations. +""" diff --git a/backends/mlx/preprocess.py b/backends/mlx/preprocess.py new file mode 100644 index 00000000000..315835f1689 --- /dev/null +++ b/backends/mlx/preprocess.py @@ -0,0 +1,168 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +MLX Backend preprocessing - converts EdgeIR to MLX delegate payload. + +This module implements the BackendDetails.preprocess() method which: +1. Takes an ExportedProgram (edge dialect) +2. Builds an MLXGraph using MLXProgramBuilder +3. Serializes to FlatBuffer (no embedded constants - those come via named_data_map) +4. Returns PreprocessResult with the binary and data_store_output for constants +""" + +from __future__ import annotations + +import hashlib +from typing import ClassVar, final, List + +from executorch.backends.mlx._logging import logger +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.serialization.mlx_graph_serialize import ( + HEADER_LENGTH, + MAGIC, + serialize_mlx_graph, +) +from executorch.exir.backend.backend_details import ( + BackendDetails, + CompileSpec, + PreprocessResult, +) +from torch.export.exported_program import ExportedProgram + + +@final +class MLXBackend(BackendDetails): + """ + ExecuTorch backend for MLX (Apple Silicon GPU compute framework). + + This backend compiles EdgeIR programs to a custom bytecode format + that can be executed by the MLX C++ runtime. + + Constants (weights) are stored in ExecuTorch's named_data_map rather than + embedded in the delegate payload. This allows ExecuTorch to own the constant + data and provide it to the backend at runtime. + """ + + MAGIC_IX: ClassVar[slice] = slice(4, 8) + DATA_SEGMENT_OFFSET_IX: ClassVar[slice] = slice(8, 16) + DATA_SEGMENT_SIZE_IX: ClassVar[slice] = slice(16, 24) + + EXPECTED_MAGIC: ClassVar[bytes] = MAGIC + EXPECTED_LENGTH: ClassVar[int] = HEADER_LENGTH + + @staticmethod + def preprocess( + edge_program: ExportedProgram, + compile_specs: List[CompileSpec], + ) -> PreprocessResult: + """ + Convert an ExportedProgram to MLX delegate payload. + + Args: + edge_program: The ExportedProgram in edge dialect to compile. + compile_specs: List of compilation options. + + Returns: + PreprocessResult containing the serialized MLX program and + data_store_output with constant tensor data. + """ + logger.debug("MLXBackend.preprocess() called") + logger.debug(f"Edge program:\n{edge_program}") + + # Build MLXGraph from ExportedProgram + # Use a deterministic 4-hex prefix derived from the edge program to + # namespace named_data keys, avoiding collisions in multi-method + # programs where different methods may have lifted tensor constants + # with the same auto-generated name. + prefix = hashlib.sha256(str(edge_program).encode()).hexdigest()[:4] + builder = MLXProgramBuilder(edge_program, named_data_key_prefix=prefix) + mlx_graph = builder.build() + + # Get constant data as NamedDataStore (ET will own this data) + named_data_store = builder.get_named_data_store() + + logger.debug(f" named_data_store entries: {len(named_data_store.pte_data)}") + _log_mlx_graph(mlx_graph) + + # Serialize to bytes (no constant data embedded) + serialized = serialize_mlx_graph(mlx_graph) + + logger.debug(f"MLXBackend.preprocess() complete: {len(serialized)} bytes") + + return PreprocessResult( + processed_bytes=serialized, + data_store_output=named_data_store.get_named_data_store_output(), + ) + + +def _format_tensor_meta(meta) -> str: + """Format a TensorMeta for display.""" + shape_parts = [] + for dim in meta.shape: + if dim.value == -1: + # Dynamic dim + if dim.max_value == -1: + shape_parts.append(f"dyn(min={dim.min_value})") + else: + shape_parts.append(f"dyn({dim.min_value}..{dim.max_value})") + else: + shape_parts.append(str(dim.value)) + shape_str = f"[{', '.join(shape_parts)}]" + dtype_str = f"dtype={meta.scalar_type}" if meta.scalar_type is not None else "" + dim_order_str = f"dim_order={meta.dim_order}" if meta.dim_order is not None else "" + parts = [shape_str] + if dtype_str: + parts.append(dtype_str) + if dim_order_str: + parts.append(dim_order_str) + return ", ".join(parts) + + +def _log_mlx_graph(mlx_graph) -> None: # noqa: C901 + """Log MLXGraph contents at DEBUG level for debugging.""" + logger.debug("MLXGraph:") + logger.debug(f" version: {mlx_graph.version}") + logger.debug(f" num_constant_tensors: {mlx_graph.num_constant_tensors}") + logger.debug(f" num_input_tensors: {mlx_graph.num_input_tensors}") + logger.debug(f" num_output_tensors: {mlx_graph.num_output_tensors}") + logger.debug( + f" num_mutable_buffer_tensors: {mlx_graph.num_mutable_buffer_tensors}" + ) + logger.debug(f" num_temp_tensors: {mlx_graph.num_temp_tensors}") + logger.debug(f" num_values: {mlx_graph.num_values}") + logger.debug(f" instruction_chains ({len(mlx_graph.instruction_chains)}):") + for c, chain in enumerate(mlx_graph.instruction_chains): + label = "" + if c == mlx_graph.main_chain_idx: + label = " (main)" + elif c == mlx_graph.init_chain_idx: + label = " (init)" + logger.debug(f" chain {c}{label} ({len(chain.instructions)} instructions):") + for i, instr in enumerate(chain.instructions): + logger.debug(f" [{i}]: {type(instr.op).__name__}") + if mlx_graph.input_map: + logger.debug(f" input_map ({len(mlx_graph.input_map)}):") + for i, slot in enumerate(mlx_graph.input_map): + logger.debug(f" [{i}]: {slot}") + if mlx_graph.output_map: + logger.debug(f" output_map ({len(mlx_graph.output_map)}):") + for i, slot in enumerate(mlx_graph.output_map): + logger.debug(f" [{i}]: {slot}") + if mlx_graph.mutable_buffer_map: + logger.debug(f" mutable_buffer_map ({len(mlx_graph.mutable_buffer_map)}):") + for i, slot in enumerate(mlx_graph.mutable_buffer_map): + logger.debug(f" [{i}]: {slot}") + if mlx_graph.named_slots: + logger.debug(f" named_slots ({len(mlx_graph.named_slots)}):") + for ns in mlx_graph.named_slots: + logger.debug(f" {ns.name}: {ns.slot}") + if mlx_graph.tensor_meta: + logger.debug(f" tensor_meta ({len(mlx_graph.tensor_meta)}):") + for i, meta in enumerate(mlx_graph.tensor_meta): + logger.debug(f" t{i}: {_format_tensor_meta(meta)}") diff --git a/backends/mlx/pte_inspector.py b/backends/mlx/pte_inspector.py new file mode 100644 index 00000000000..d9e533b0b1e --- /dev/null +++ b/backends/mlx/pte_inspector.py @@ -0,0 +1,897 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +PTE Inspector - Extract and dump data from ExecuTorch .pte files. + +This utility can: +1. Parse the PTE file structure (header, flatbuffer, segments) +2. Extract delegate payloads (e.g., MLX backend data) +3. Convert FlatBuffer data to JSON for inspection + +Usage: + python pte_inspector.py mlx_mlp.pte + python pte_inspector.py mlx_mlp.pte --output output.json + python pte_inspector.py mlx_mlp.pte --extract-delegate mlx --output mlx_payload.bin +""" + +from __future__ import annotations + +import argparse +import json +import sys +import traceback +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple + +from executorch.backends.mlx._generated_inspector import OP_NODE_FIELDS +from executorch.backends.mlx.serialization._generated_serializers import ( + MLX_OP_TYPE_NAMES, +) +from executorch.exir._serialize._program import ( + _ExtendedHeader, + _extract_delegate_payload as extract_delegate_payload, +) + +MLX_MAGIC = b"MLX0" +MLX_HEADER_LENGTH = 24 + +_SLOT_TYPE_NAMES = {0: "Tensor", 1: "Int", 2: "Float", 3: "Bool"} + + +@dataclass +class MLXHeader: + + magic: bytes + data_segment_offset: int + data_segment_size: int + + @classmethod + def from_bytes(cls, data: bytes) -> "MLXHeader": + if len(data) < MLX_HEADER_LENGTH: + raise ValueError( + f"Not enough data for MLX header: {len(data)} < {MLX_HEADER_LENGTH}" + ) + + # Layout: [4 bytes padding][4 bytes magic][8 bytes offset][8 bytes size] + magic = data[4:8] + data_segment_offset = int.from_bytes(data[8:16], byteorder="little") + data_segment_size = int.from_bytes(data[16:24], byteorder="little") + + return cls( + magic=magic, + data_segment_offset=data_segment_offset, + data_segment_size=data_segment_size, + ) + + def is_valid(self) -> bool: + return self.magic == MLX_MAGIC + + def to_dict(self) -> Dict[str, Any]: + return { + "magic": self.magic.decode("utf-8", errors="replace"), + "data_segment_offset": self.data_segment_offset, + "data_segment_size": self.data_segment_size, + } + + +@dataclass +class MLXPayload: + """Parsed MLX delegate payload: header + flatbuffer bytes.""" + + header: MLXHeader + fb_data: bytes + raw: bytes + + +def _load_mlx_payload(pte_data: bytes, delegate_index: int = 0) -> MLXPayload: + """Extract MLX delegate payload from PTE data and parse its header. + + Raises ``ValueError`` if the delegate cannot be found or the MLX header is + invalid. + """ + payload = extract_delegate_payload(pte_data, "mlx", delegate_index=delegate_index) + if payload is None: + raise ValueError(f"Could not extract MLX delegate {delegate_index}") + + header = MLXHeader.from_bytes(payload) + if not header.is_valid(): + raise ValueError(f"Invalid MLX magic: {header.magic!r}") + + fb_data = payload[MLX_HEADER_LENGTH : header.data_segment_offset] + return MLXPayload(header=header, fb_data=fb_data, raw=payload) + + +def _find_mlx_delegates(pte_data: bytes) -> List[Tuple[int, Dict]]: + """Return list of ``(plan_index, delegate_dict)`` for every MLX delegate.""" + from executorch.exir._serialize._flatbuffer import _program_flatbuffer_to_json + + program_data = json.loads(_program_flatbuffer_to_json(pte_data)) + delegates: List[Tuple[int, Dict]] = [] + for plan in program_data.get("execution_plan", []): + for i, delegate in enumerate(plan.get("delegates", [])): + if "mlx" in delegate.get("id", "").lower(): + delegates.append((i, delegate)) + return delegates + + +def _get_fb_graph(fb_data: bytes): + """Return the FlatBuffer MLXGraph root object.""" + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + MLXGraph as FBMLXGraph, + ) + + return FBMLXGraph.MLXGraph.GetRootAs(fb_data, 0) + + +def _parse_graph_info(graph) -> Dict[str, Any]: + """Extract top-level graph scalars (tensor counts, chain counts, etc.).""" + return { + "version": graph.Version().decode("utf-8") if graph.Version() else None, + "num_constant_tensors": graph.NumConstantTensors(), + "num_input_tensors": graph.NumInputTensors(), + "num_output_tensors": graph.NumOutputTensors(), + "num_mutable_buffer_tensors": graph.NumMutableBufferTensors(), + "num_temp_tensors": graph.NumTempTensors(), + "num_values": graph.NumValues(), + "num_instruction_chains": graph.InstructionChainsLength(), + "main_chain_idx": graph.MainChainIdx(), + "init_chain_idx": graph.InitChainIdx(), + "input_map_length": graph.InputMapLength(), + "output_map_length": graph.OutputMapLength(), + "mutable_buffer_map_length": graph.MutableBufferMapLength(), + "named_slots_length": graph.NamedSlotsLength(), + "tensor_meta_length": graph.TensorMetaLength(), + } + + +def _parse_instructions(graph) -> List[Dict[str, Any]]: + """Parse all instruction chains and their op nodes.""" + chains: List[Dict[str, Any]] = [] + for c in range(graph.InstructionChainsLength()): + chain = graph.InstructionChains(c) + chain_info: Dict[str, Any] = {"chain_index": c, "instructions": []} + if chain: + for i in range(chain.InstructionsLength()): + try: + instr = chain.Instructions(i) + if instr: + op_type = instr.OpType() + op_name = MLX_OP_TYPE_NAMES.get(op_type, f"Unknown({op_type})") + instr_info: Dict[str, Any] = { + "instr_idx": i, + "op_type": op_type, + "op_name": op_name, + } + op_data = _parse_op_node(instr, op_name) + if op_data: + instr_info.update(op_data) + chain_info["instructions"].append(instr_info) + except Exception as e: + chain_info["instructions"].append( + {"instr_idx": i, "error": f"parse_failed: {e}"} + ) + chains.append(chain_info) + return chains + + +def _parse_named_slots(graph) -> List[Dict[str, Any]]: + slots: List[Dict[str, Any]] = [] + for i in range(graph.NamedSlotsLength()): + try: + ns = graph.NamedSlots(i) + if ns: + info: Dict[str, Any] = { + "name": ns.Name().decode("utf-8") if ns.Name() else None, + } + slot = ns.Slot() + if slot: + info["slot_idx"] = slot.Idx() + info["slot_type"] = slot.SlotType() + slots.append(info) + except Exception as e: + slots.append({"instr_idx": i, "error": f"parse_failed: {e}"}) + return slots + + +def _parse_tensor_meta(graph) -> List[Dict[str, Any]]: + metas: List[Dict[str, Any]] = [] + for i in range(graph.TensorMetaLength()): + try: + tm = graph.TensorMeta(i) + if tm: + shape: List[Any] = [] + for j in range(tm.ShapeLength()): + sd = tm.Shape(j) + if sd.Value() == -1: + lo = sd.MinValue() + hi = sd.MaxValue() + if hi == -1: + shape.append(f"dyn(min={lo})") + else: + shape.append(f"dyn({lo}..{hi})") + else: + shape.append(sd.Value()) + meta: Dict[str, Any] = { + "index": i, + "dtype": tm.Dtype(), + "shape": shape, + } + if tm.StridesLength() > 0: + meta["strides"] = [tm.Strides(j) for j in range(tm.StridesLength())] + metas.append(meta) + except Exception as e: + metas.append({"instr_idx": i, "error": f"parse_failed: {e}"}) + return metas + + +def _parse_io_maps( + graph, +) -> Tuple[List[Dict], List[Dict], List[Dict]]: + """Return (input_map, output_map, mutable_buffer_map) as slot-variant dicts.""" + + def _extract( + length_fn: Callable[[], int], getter_fn: Callable[[int], Any] + ) -> List[Dict]: + result = [] + for i in range(length_fn()): + try: + sv = getter_fn(i) + if sv: + result.append({"idx": sv.Idx(), "slot_type": sv.SlotType()}) + except Exception as e: + result.append({"instr_idx": i, "error": f"parse_failed: {e}"}) + return result + + return ( + _extract(graph.InputMapLength, graph.InputMap), + _extract(graph.OutputMapLength, graph.OutputMap), + _extract(graph.MutableBufferMapLength, graph.MutableBufferMap), + ) + + +def parse_mlx_flatbuffer(fb_data: bytes) -> Dict[str, Any]: + """Parse MLX FlatBuffer data into a dict using the generated FlatBuffer bindings.""" + result: Dict[str, Any] = {} + try: + graph = _get_fb_graph(fb_data) + + result = _parse_graph_info(graph) + result["instruction_chains"] = _parse_instructions(graph) + result["named_slots"] = _parse_named_slots(graph) + result["tensor_meta"] = _parse_tensor_meta(graph) + + input_map, output_map, mutable_buffer_map = _parse_io_maps(graph) + result["input_map"] = input_map + result["output_map"] = output_map + result["mutable_buffer_map"] = mutable_buffer_map + + try: + cs = graph.ConstantSegment() + if cs: + result["constant_segment"] = { + "offset": cs.Offset(), + "size": cs.Size(), + } + except Exception as e: + result["constant_segment_error"] = f"parse_failed: {e}" + + except ImportError as e: + result["error"] = f"FlatBuffer bindings not available: {e}" + result["_fallback"] = "Using basic header parsing only" + except Exception as e: + result["error"] = f"FlatBuffer parse error: {e}" + result["traceback"] = traceback.format_exc() + + return result + + +def _parse_op_node(instr, op_name: str) -> Optional[Dict[str, Any]]: + """Parse the specific op node fields from an instruction. + + Uses the generated field mappings in ``OP_NODE_FIELDS`` to extract + op-specific fields without manually maintaining per-op logic. + """ + try: + op = instr.Op() + if op is None: + return None + + if op_name not in OP_NODE_FIELDS: + return {"error": f"Unknown op type: {op_name}"} + + module = __import__( + f"executorch.backends.mlx.serialization._generated.mlx_delegate.{op_name}", + fromlist=[op_name], + ) + node_class = getattr(module, op_name) + node = node_class() + node.Init(op.Bytes, op.Pos) + + result: Dict[str, Any] = {} + for field_name, accessor_name, kind in OP_NODE_FIELDS[op_name]: + try: + result[field_name] = _extract_field(node, accessor_name, kind) + except Exception as e: + result[field_name] = {"error": str(e)} + + result = {k: v for k, v in result.items() if v is not None} + return result if result else None + + except Exception as e: + return {"parse_error": str(e), "traceback": traceback.format_exc()} + + +def _extract_vid_or_tid(obj) -> Optional[Dict[str, Any]]: + """Extract a VidOrTid FlatBuffer object into a dict. + + VidOrTid has: .IsVid() -> bool, .Vid() -> Vid|None, .Tid() -> Tid|None. + Same pattern as IntOrVid but references value/tensor slots instead of + holding a literal. + """ + if obj is None: + return None + if obj.IsVid(): + v = obj.Vid() + return {"vid": v.Idx()} if v else None + t = obj.Tid() + return {"tid": t.Idx()} if t else None + + +def _extract_field(node, accessor_name: str, kind: str) -> Any: # noqa: C901 + """Extract a single field from a FlatBuffer op node based on its *kind*.""" + if kind == "tid": + t = getattr(node, accessor_name)() + return {"tid": t.Idx()} if t else None + + if kind == "vid": + v = getattr(node, accessor_name)() + return {"vid": v.Idx()} if v else None + + if kind == "vid_or_tid": + return _extract_vid_or_tid(getattr(node, accessor_name)()) + + if kind == "int_or_vid_or_tid": + ivt = getattr(node, accessor_name)() + if ivt is None: + return None + k = ivt.Kind() + if k == 0: # literal int + return {"literal": ivt.Literal()} + elif k == 1: # Vid + v = ivt.Vid() + return {"vid": v.Idx()} if v else None + elif k == 2: # Tid + t = ivt.Tid() + return {"tid": t.Idx()} if t else None + return {"kind": k} + + if kind == "int_or_vid": + iov = getattr(node, accessor_name)() + if iov is None: + return None + if iov.IsVid(): + v = iov.Vid() + return {"vid": v.Idx()} if v else None + return {"literal": iov.Literal()} + + if kind == "float_or_vid": + fov = getattr(node, accessor_name)() + if fov is None: + return None + if fov.IsVid(): + v = fov.Vid() + return {"vid": v.Idx()} if v else None + return {"literal": fov.Literal()} + + if kind == "int_list": + length = getattr(node, f"{accessor_name}Length")() + getter = getattr(node, accessor_name) + return [getter(i) for i in range(length)] + + if kind == "tid_list": + length = getattr(node, f"{accessor_name}Length")() + getter = getattr(node, accessor_name) + items = [] + for i in range(length): + s = getter(i) + items.append(f"tid {s.Idx()}" if s else None) + return items + + if kind == "int_or_vid_list": + length = getattr(node, f"{accessor_name}Length")() + getter = getattr(node, accessor_name) + items = [] + for i in range(length): + iov = getter(i) + if iov is None: + items.append(None) + elif iov.IsVid(): + v = iov.Vid() + items.append({"vid": v.Idx()} if v else None) + else: + items.append({"literal": iov.Literal()}) + return items + + if kind == "string": + val = getattr(node, accessor_name)() + return val.decode("utf-8") if val else None + + # scalar (default) + return getattr(node, accessor_name)() + + +def parse_mlx_payload(payload: bytes) -> Dict[str, Any]: + """Parse raw MLX delegate payload bytes into a dict. + + This is the public entry point for callers that already have the raw + delegate payload (e.g. from ``extract_delegate_payload``). + """ + header = MLXHeader.from_bytes(payload) + + if not header.is_valid(): + return { + "error": f"Invalid MLX magic: {header.magic!r}", + "header": header.to_dict(), + } + + fb_data = payload[MLX_HEADER_LENGTH : header.data_segment_offset] + result: Dict[str, Any] = { + "header": header.to_dict(), + "flatbuffer_size": len(fb_data), + "graph": parse_mlx_flatbuffer(fb_data), + } + + if header.data_segment_size > 0: + result["constant_data_size"] = header.data_segment_size + + return result + + +def parse_executorch_program(pte_data: bytes) -> Dict[str, Any]: # noqa: C901 + result: Dict[str, Any] = {} + + if len(pte_data) < 8: + raise ValueError("File too small to be a valid PTE file") + + fb_magic = pte_data[4:8] + result["flatbuffer_magic"] = fb_magic.decode("utf-8", errors="replace") + + extended_header_offset = 8 + if len(pte_data) > extended_header_offset + 32: + try: + header = _ExtendedHeader.from_bytes(pte_data[extended_header_offset:]) + if header.is_valid(): + result["extended_header"] = { + "magic": header.magic.decode("utf-8", errors="replace"), + "length": header.length, + "program_size": header.program_size, + "segment_base_offset": header.segment_base_offset, + "segment_data_size": header.segment_data_size, + } + fb_start = extended_header_offset + header.length + result["flatbuffer_offset"] = fb_start + result["flatbuffer_size"] = header.program_size + result["segment_offset"] = header.segment_base_offset + result["segment_size"] = header.segment_data_size + except Exception as e: + result["header_parse_error"] = str(e) + + try: + from executorch.exir._serialize._flatbuffer import _program_flatbuffer_to_json + + program_data = json.loads(_program_flatbuffer_to_json(pte_data)) + result["program"] = program_data + + if "execution_plan" in program_data: + delegates = [] + for plan in program_data["execution_plan"]: + if "delegates" in plan: + for delegate in plan["delegates"]: + delegate_info: Dict[str, Any] = { + "id": delegate.get("id"), + "processed_type": delegate.get("processed", {}).get( + "location" + ), + } + processed = delegate.get("processed", {}) + if "data" in processed: + delegate_info["inline_data_size"] = len(processed["data"]) + if "location" in processed: + delegate_info["location"] = processed["location"] + delegates.append(delegate_info) + result["delegates"] = delegates + + except ImportError: + result["program_parse_error"] = "ExecuTorch FlatBuffer parsing not available" + except Exception as e: + result["program_parse_error"] = str(e) + + return result + + +def _slot_type_display(slot_type: int, style: str = "full") -> str: + """Return display string for a slot type. + + *style* controls the format: + - ``"full"``: "Tensor", "Int", etc. (for summary tables) + - ``"short"``: "tid", "vid" (for instruction I/O lists) + """ + if style == "short": + return "tid" if slot_type == 0 else "vid" + return _SLOT_TYPE_NAMES.get(slot_type, "Unknown") + + +def _print_slot_map(label: str, slots: List[Dict]) -> None: + """Print a list of slot-variant dicts with their type names.""" + if not slots: + return + print(f"\n {label}:") + for i, slot in enumerate(slots): + type_name = _slot_type_display(slot.get("slot_type", 0)) + print(f" [{i}]: idx={slot.get('idx')}, type={type_name}") + + +def show_mlx_summary(pte_data: bytes) -> None: # noqa: C901 + try: + mlx_delegates = _find_mlx_delegates(pte_data) + if not mlx_delegates: + print("No MLX delegates found in this PTE file.") + return + + print(f"\n{'='*70}") + print("MLX DELEGATE SUMMARY") + print(f"{'='*70}") + print(f"File contains {len(mlx_delegates)} MLX delegate(s)\n") + + for idx, (delegate_idx, delegate) in enumerate(mlx_delegates): + print(f"\n--- Delegate {idx} (plan index {delegate_idx}) ---") + print(f"ID: {delegate.get('id', 'unknown')}") + + try: + mlx = _load_mlx_payload(pte_data, delegate_index=idx) + except ValueError as e: + print(f" {e}") + continue + + graph_info = parse_mlx_flatbuffer(mlx.fb_data) + + print("\nMLX Graph Info:") + for key in ( + "num_constant_tensors", + "num_input_tensors", + "num_output_tensors", + "num_mutable_buffer_tensors", + "num_temp_tensors", + "num_values", + "num_instruction_chains", + ): + label = f" {key + ':':<29}" + print(f"{label}{graph_info.get(key, '?')}") + + main_idx = graph_info.get("main_chain_idx", 0) + chains = graph_info.get("instruction_chains", []) + main_num = "?" + if main_idx < len(chains): + main_num = len(chains[main_idx].get("instructions", [])) + print(f" {'main_chain_idx:':<29}{main_idx} ({main_num} instructions)") + print(f" {'init_chain_idx:':<29}{graph_info.get('init_chain_idx', '?')}") + + print("\nI/O Maps:") + print( + f" {'input_map length:':<29}{graph_info.get('input_map_length', '?')}" + ) + print( + f" {'output_map length:':<29}{graph_info.get('output_map_length', '?')}" + ) + print( + f" {'mutable_buffer_map length:':<29}{graph_info.get('mutable_buffer_map_length', '?')}" + ) + + input_len = graph_info.get("input_map_length", 0) + mutable_len = graph_info.get("mutable_buffer_map_length", 0) + if input_len and mutable_len is not None: + print( + f" => regular inputs expected: {input_len - mutable_len} (input_map - mutable_buffer_map)" + ) + + _print_slot_map("Input Map Details", graph_info.get("input_map", [])) + if graph_info.get("mutable_buffer_map"): + _print_slot_map( + "Mutable Buffer Map Details", + graph_info["mutable_buffer_map"], + ) + _print_slot_map("Output Map Details", graph_info.get("output_map", [])) + + if mlx.header.data_segment_size > 0: + print(f"\n Constant data size: {mlx.header.data_segment_size:,} bytes") + + print(f"\n{'='*70}\n") + + except Exception as e: + print(f"Error showing MLX summary: {e}", file=sys.stderr) + traceback.print_exc() + + +def show_mlx_instructions(pte_data: bytes) -> None: # noqa: C901 + try: + mlx_delegates = _find_mlx_delegates(pte_data) + if not mlx_delegates: + print("No MLX delegates found in this PTE file.", file=sys.stderr) + sys.exit(1) + + if len(mlx_delegates) > 1: + print( + f"Found {len(mlx_delegates)} MLX delegate(s) in PTE file\n", + file=sys.stderr, + ) + + for idx, (delegate_idx, _delegate) in enumerate(mlx_delegates): + try: + mlx = _load_mlx_payload(pte_data, delegate_index=idx) + except ValueError as e: + print(f"\nError: {e}", file=sys.stderr) + continue + + graph = parse_mlx_flatbuffer(mlx.fb_data) + if "error" in graph: + print( + f"\nError parsing delegate {idx}: {graph['error']}", + file=sys.stderr, + ) + continue + + # Print delegate header + if len(mlx_delegates) > 1: + print("\n" + "=" * 70) + print(f"MLX DELEGATE {idx} (plan index {delegate_idx})") + print("=" * 70) + else: + print("\n" + "=" * 70) + print("MLX Graph Summary") + print("=" * 70) + + # Basic info + print(f"Version: {graph.get('version', 'unknown')}") + print(f"Constant tensors: {graph.get('num_constant_tensors', 0)}") + print(f"Input tensors: {graph.get('num_input_tensors', 0)}") + print(f"Output tensors: {graph.get('num_output_tensors', 0)}") + print( + f"Mutable buffer tensors: {graph.get('num_mutable_buffer_tensors', 0)}" + ) + print(f"Temp tensors: {graph.get('num_temp_tensors', 0)}") + print(f"Values: {graph.get('num_values', 0)}") + num_chains = graph.get("num_instruction_chains", 0) + main_idx = graph.get("main_chain_idx", 0) + init_idx = graph.get("init_chain_idx", -1) + print(f"Instruction chains: {num_chains}") + print(f"Main chain idx: {main_idx}") + if init_idx >= 0: + print(f"Init chain idx: {init_idx}") + + constant_seg = graph.get("constant_segment", {}) + if constant_seg: + print(f"Constant data: {constant_seg.get('size', 0):,} bytes") + + # Instruction chains + for chain_info in graph.get("instruction_chains", []): + chain_idx = chain_info.get("chain_index", "?") + label = "" + if chain_idx == main_idx: + label = " (main)" + elif chain_idx == init_idx: + label = " (init)" + instructions = chain_info.get("instructions", []) + print(f"\nChain {chain_idx}{label} ({len(instructions)} instructions):") + for instr in instructions: + op_name = instr.get("op_name", f"op_{instr.get('op_type', '?')}") + print(f" [{instr.get('instr_idx', '?')}] {op_name}") + + for key, value in instr.items(): + if key in ("instr_idx", "op_type", "op_name"): + continue + if isinstance(value, dict): + if "tid" in value: + print(f" {key}: tid {value['tid']}") + elif "vid" in value: + print(f" {key}: vid {value['vid']}") + else: + print(f" {key}: {value}") + elif value is not None: + print(f" {key}: {value}") + + # Named slots + named_slots = graph.get("named_slots", []) + if named_slots: + print("\nNamed Slots:") + for slot in named_slots: + slot_type = _slot_type_display( + slot.get("slot_type", 0), style="short" + ) + print( + f" [{slot.get('slot_idx', '?')}] {slot.get('name', '?')} ({slot_type})" + ) + + # Input/Output maps + input_map = graph.get("input_map", []) + output_map = graph.get("output_map", []) + + if input_map: + print("\nInputs:") + for inp in input_map: + slot_type = _slot_type_display( + inp.get("slot_type", 0), style="short" + ) + print(f" {slot_type} {inp.get('idx', '?')}") + + if output_map: + print("\nOutputs:") + for out in output_map: + slot_type = _slot_type_display( + out.get("slot_type", 0), style="short" + ) + print(f" {slot_type} {out.get('idx', '?')}") + + print("=" * 70 + "\n") + + except Exception as e: + print(f"Error showing MLX instructions: {e}", file=sys.stderr) + traceback.print_exc() + sys.exit(1) + + +def main(): # noqa: C901 + parser = argparse.ArgumentParser( + description="Inspect ExecuTorch .pte files and extract data", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +MLX-Specific Options: + --mlx-summary Show high-level summary (tensor counts, I/O maps) + --mlx-instructions Show detailed instruction list with operation parameters + (use this to verify quantization, inspect ops, etc.) + +Examples: + # Basic PTE file inspection + python -m executorch.backends.mlx.pte_inspector model.pte + + # Show high-level MLX delegate summary + python -m executorch.backends.mlx.pte_inspector model.pte --mlx-summary + + # Show detailed MLX instructions (verify quantization, inspect operations) + python -m executorch.backends.mlx.pte_inspector model.pte --mlx-instructions + + # Extract raw delegate payload to binary file + python -m executorch.backends.mlx.pte_inspector model.pte \\ + --extract-delegate MLXBackend -o delegate.bin + """, + ) + parser.add_argument("pte_file", type=Path, help="Path to the .pte file") + parser.add_argument( + "--output", "-o", type=Path, help="Output file (default: stdout)" + ) + parser.add_argument( + "--extract-delegate", + type=str, + metavar="ID", + help="Extract delegate payload by ID (e.g., 'mlx')", + ) + parser.add_argument( + "--delegate-index", + type=int, + default=None, + metavar="N", + help="Index of delegate to extract (0-based). If not specified, extracts first matching delegate.", + ) + parser.add_argument( + "--parse-mlx", + action="store_true", + help="Parse extracted MLX payload (use with --extract-delegate mlx)", + ) + parser.add_argument( + "--mlx-summary", + action="store_true", + help="Show summary of all MLX delegates (input/output/mutable buffer counts)", + ) + parser.add_argument( + "--mlx-instructions", + action="store_true", + help="Show detailed MLX instruction list with operands and quantization details", + ) + parser.add_argument( + "--format", + choices=["json", "summary"], + default="json", + help="Output format (default: json)", + ) + parser.add_argument( + "--indent", + type=int, + default=2, + help="JSON indentation (default: 2)", + ) + + args = parser.parse_args() + + if not args.pte_file.exists(): + print(f"Error: File not found: {args.pte_file}", file=sys.stderr) + sys.exit(1) + + pte_data = args.pte_file.read_bytes() + print(f"Loaded {len(pte_data)} bytes from {args.pte_file}", file=sys.stderr) + + if args.mlx_instructions: + show_mlx_instructions(pte_data) + return + + if args.mlx_summary: + show_mlx_summary(pte_data) + return + + if args.extract_delegate: + payload = extract_delegate_payload( + pte_data, args.extract_delegate, delegate_index=args.delegate_index + ) + if payload is None: + print( + f"Error: Delegate '{args.extract_delegate}' not found", file=sys.stderr + ) + sys.exit(1) + + if args.parse_mlx and args.extract_delegate.lower() == "mlx": + result = parse_mlx_payload(payload) + + output = json.dumps(result, indent=args.indent, default=str) + + if args.output: + args.output.write_text(output) + print(f"Wrote parsed MLX data to {args.output}", file=sys.stderr) + else: + print(output) + else: + if args.output: + args.output.write_bytes(payload) + print(f"Wrote {len(payload)} bytes to {args.output}", file=sys.stderr) + else: + print(f"Delegate payload: {len(payload)} bytes", file=sys.stderr) + if len(payload) >= MLX_HEADER_LENGTH: + header = MLXHeader.from_bytes(payload) + print(f" Magic: {header.magic!r}", file=sys.stderr) + print( + f" Data offset: {header.data_segment_offset}", file=sys.stderr + ) + print(f" Data size: {header.data_segment_size}", file=sys.stderr) + return + + result = parse_executorch_program(pte_data) + result["file_size"] = len(pte_data) + result["file_path"] = str(args.pte_file) + + if args.format == "summary": + print(f"PTE File: {args.pte_file}") + print(f" Size: {len(pte_data):,} bytes") + if "extended_header" in result: + h = result["extended_header"] + print(f" Program size: {h['program_size']:,} bytes") + print(f" Segment offset: {h['segment_base_offset']:,}") + print(f" Segment size: {h['segment_data_size']:,} bytes") + if "delegates" in result: + print(f" Delegates: {len(result['delegates'])}") + for d in result["delegates"]: + print(f" - {d.get('id', 'unknown')}") + else: + output = json.dumps(result, indent=args.indent, default=str) + + if args.output: + args.output.write_text(output) + print(f"Wrote JSON to {args.output}", file=sys.stderr) + else: + print(output) + + +if __name__ == "__main__": + main() diff --git a/backends/mlx/runtime/MLXBackend.cpp b/backends/mlx/runtime/MLXBackend.cpp new file mode 100644 index 00000000000..38dff189935 --- /dev/null +++ b/backends/mlx/runtime/MLXBackend.cpp @@ -0,0 +1,419 @@ +// +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +// + +#include "MLXExecutor.h" +#include "MLXInterpreter.h" +#include "MLXLoader.h" + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace mlx { + +// Note: We use fully qualified executorch::aten::Tensor because MLXExecutor.h +// defines Tensor as mlx::core::array in the executorch::backends::mlx +// namespace. +using ETTensor = ::executorch::aten::Tensor; +using ::executorch::runtime::ArrayRef; +using ::executorch::runtime::Backend; +using ::executorch::runtime::BackendExecutionContext; +using ::executorch::runtime::BackendInitContext; +using ::executorch::runtime::CompileSpec; +using ::executorch::runtime::DelegateHandle; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; +using ::executorch::runtime::FreeableBuffer; +using ::executorch::runtime::Result; +using ::executorch::runtime::Span; + +using ::mlx::core::array; +using ::mlx::core::Dtype; +using ::mlx::core::eval; + +namespace { + +array tensor_to_mlx( + const ETTensor& t, + const std::optional& expected_meta = std::nullopt) { + if (!executorch::runtime::tensor_is_contiguous(t)) { + throw std::runtime_error("tensor_to_mlx: input tensor is not contiguous"); + } + + Dtype dtype = + resolve_dtype(static_cast(t.scalar_type())); + + if (expected_meta.has_value()) { + Dtype expected_dtype = resolve_dtype(expected_meta->scalar_type); + if (dtype != expected_dtype) { + throw std::runtime_error( + std::string("tensor_to_mlx: dtype mismatch - input tensor has ") + + ExecutionState::dtype_str(dtype) + " but model expects " + + ExecutionState::dtype_str(expected_dtype)); + } + } + + ::mlx::core::Shape shape; + for (int i = 0; i < t.dim(); ++i) { + auto dim_size = t.size(i); + if (dim_size > std::numeric_limits::max() || + dim_size < std::numeric_limits::min()) { + throw std::runtime_error( + "tensor_to_mlx: dimension " + std::to_string(i) + " size " + + std::to_string(dim_size) + " exceeds int range"); + } + shape.push_back(static_cast(dim_size)); + } + + // SAFETY: MLX reads this data during async_eval() Metal command encoding, + // which completes before the lock is released. The ET tensor must remain + // valid until async_eval returns. + const void* cptr = t.const_data_ptr(); + if (!cptr) { + throw std::runtime_error("tensor_to_mlx: tensor has null data pointer"); + } + void* data_ptr = const_cast(cptr); + auto deleter = [](void*) {}; + return array(data_ptr, shape, dtype, deleter); +} + +// Build the contiguous + dtype conversion pipeline for an output array. +// Returns a lazy array (not yet evaluated) ready for async_eval. +array prepare_output( + const array& arr, + Dtype expected_dtype, + const ::mlx::core::Stream& stream) { + array result = + ::mlx::core::contiguous(arr, /*allow_col_major=*/false, stream); + if (result.dtype() != expected_dtype) { + result = ::mlx::core::astype(result, expected_dtype, stream); + } + return result; +} + +// Wait for a prepared output array and copy its data to an ET tensor. +// The array must have been submitted via async_eval before calling this. +void write_output(array& arr, ETTensor& out) { + arr.wait(); + + // Resize output tensor if shape doesn't match (dynamic shapes) + const auto& mlx_shape = arr.shape(); + auto out_sizes = out.sizes(); + + bool shape_matches = (mlx_shape.size() == static_cast(out.dim())); + if (shape_matches) { + for (size_t i = 0; i < mlx_shape.size(); ++i) { + if (static_cast(mlx_shape[i]) != + static_cast(out_sizes[i])) { + shape_matches = false; + break; + } + } + } + + if (!shape_matches) { + std::vector new_sizes; + new_sizes.reserve(mlx_shape.size()); + for (auto d : mlx_shape) { + new_sizes.push_back(static_cast(d)); + } + auto err = resize_tensor( + out, + ArrayRef( + new_sizes.data(), new_sizes.size())); + if (err != Error::Ok) { + throw std::runtime_error("write_output: failed to resize output tensor"); + } + } + + size_t mlx_nbytes = arr.nbytes(); + size_t out_nbytes = out.nbytes(); + if (mlx_nbytes != out_nbytes) { + throw std::runtime_error( + "write_output: size mismatch - MLX has " + std::to_string(mlx_nbytes) + + " bytes, output has " + std::to_string(out_nbytes) + " bytes"); + } + + const void* src = arr.data(); + if (!src) { + throw std::runtime_error( + "write_output: arr.data() is null after wait()"); + } + std::memcpy(out.mutable_data_ptr(), src, out_nbytes); +} + +} // namespace + +struct MLXHandle { + MLXProgram program; + ConstantData constants; + MutableBufferData mutable_buffers; + ExecutionState state; // Reusable execution state + Interpreter interpreter; + ::mlx::core::Stream stream; // Dedicated GPU stream for this handle + + // Keep the constant buffers alive for zero-copy constants + // Each FreeableBuffer must outlive the MLX arrays that reference it + std::vector constant_buffers; + + MLXHandle() : stream(::mlx::core::new_stream(::mlx::core::Device::gpu)) {} + ~MLXHandle() = default; + + MLXHandle(const MLXHandle&) = delete; + MLXHandle& operator=(const MLXHandle&) = delete; +}; + +// MLX is not thread-safe: its computation graph is global shared state. +// A global mutex serializes graph construction and command submission +// across all handles. GPU execution and output copies can proceed +// without the lock (see execute() for the async pipeline design). +static std::mutex& mlx_global_mutex() { + static std::mutex m; + return m; +} + +class MLXBackend final : public ::executorch::runtime::BackendInterface { + public: + ~MLXBackend() override = default; + + bool is_available() const override { + return ::mlx::core::metal::is_available(); + } + + Result init( + BackendInitContext& context, + FreeableBuffer* processed, + ArrayRef compile_specs) const override { + std::lock_guard lock(mlx_global_mutex()); + auto* handle = + context.get_runtime_allocator()->allocateInstance(); + if (handle == nullptr) { + return Error::MemoryAllocationFailed; + } + + try { + new (handle) MLXHandle(); + + if (!processed || !processed->data() || processed->size() == 0) { + throw std::runtime_error("init: null or empty delegate payload"); + } + + handle->program = loader::load_program( + static_cast(processed->data()), processed->size()); + + // Validate schema version + if (handle->program.version != "1") { + throw std::runtime_error( + "Unsupported MLX schema version '" + handle->program.version + + "' (expected '1'). Rebuild the .pte with a matching SDK version."); + } + + // Load constants from named_data_map + // Constants are stored by name in the .pte file and provided by ET at + // runtime + const runtime::NamedDataMap* named_data_map = + context.get_named_data_map(); + load_constants( + handle->program, + named_data_map, + handle->constants, + handle->constant_buffers); + + // Delegate payload no longer needed after constants are loaded + processed->Free(); + processed = nullptr; + + // Load mutable buffers (e.g., KV cache) + load_mutable_buffers(handle->program, handle->mutable_buffers); + + // Bind execution state (reused across execute() calls) + handle->state.bind( + handle->program, handle->constants, handle->mutable_buffers); + + // Run init chain if present. + // SAFETY: The >= 0 check ensures init_chain_idx is non-negative, so the + // static_cast cannot produce UINT32_MAX from a -1 sentinel. + if (handle->program.init_chain_idx >= 0) { + handle->interpreter.run_chain( + handle->program, + static_cast(handle->program.init_chain_idx), + handle->state, + handle->stream); + } + + } catch (const std::exception& e) { + ET_LOG(Error, "Failed to load MLX program: %s", e.what()); + handle->~MLXHandle(); + if (processed != nullptr) { + processed->Free(); + } + return Error::InvalidProgram; + } + + return handle; + } + + Error execute( + ET_UNUSED BackendExecutionContext& context, + DelegateHandle* handle, + Span args) const override { + try { + std::vector prepared_outputs; + struct OutputInfo { + size_t arg_idx; + size_t prepared_idx; + }; + + std::vector tensor_output_info; + size_t arg_idx = 0; + + auto* h = static_cast(handle); + const auto& program = h->program; + + // Graph construction + async GPU dispatch (locked) + { + std::lock_guard lock(mlx_global_mutex()); + + h->state.reset(); + + const size_t n_inputs = program.input_map.size(); + const size_t n_outputs = program.output_map.size(); + if (n_inputs > SIZE_MAX - n_outputs) { + throw std::runtime_error("execute: input + output count overflow"); + } + const size_t expected_args = n_inputs + n_outputs; + if (args.size() != expected_args) { + ET_LOG( + Error, "Expected %zu args, got %zu", expected_args, args.size()); + return Error::InvalidArgument; + } + + // Bind inputs + for (const auto& slot : program.input_map) { + if (arg_idx >= args.size()) { + throw std::runtime_error( + "execute: arg_idx " + std::to_string(arg_idx) + + " out of bounds (args.size()=" + std::to_string(args.size()) + + ")"); + } + if (slot.slot_type == SlotType::TensorSlot) { + const ETTensor& tensor = args[arg_idx++]->toTensor(); + Tid tid{slot.idx}; + std::optional expected_meta = std::nullopt; + if (tid.idx < program.tensor_meta.size()) { + expected_meta = program.tensor_meta[tid.idx]; + } + h->state.set_tensor(tid, tensor_to_mlx(tensor, expected_meta)); + } else if (slot.slot_type == SlotType::IntValueSlot) { + int64_t val = args[arg_idx]->toInt(); + arg_idx++; + if (val > std::numeric_limits::max() || + val < std::numeric_limits::min()) { + ET_LOG( + Error, + "Int input value %lld exceeds int32 range", + static_cast(val)); + return Error::InvalidArgument; + } + h->state.set_value(Vid{slot.idx}, static_cast(val)); + } else { + throw std::runtime_error( + "Unhandled input slot type: " + + std::to_string(static_cast(slot.slot_type))); + } + } + + // Run the MLX program (builds lazy computation graph) + h->interpreter.run(program, h->state, h->stream); + + // Prepare output pipeline and collect int outputs + // Build contiguous + dtype conversion lazily for each tensor output, + // and extract int outputs (which don't need GPU) while still locked. + prepared_outputs.reserve(program.num_output_tensors); + + for (const auto& slot : program.output_map) { + if (slot.slot_type == SlotType::TensorSlot) { + ETTensor& out_tensor = args[arg_idx]->toTensor(); + Dtype expected_dtype = + resolve_dtype(static_cast( + out_tensor.scalar_type())); + array out_arr = prepare_output( + h->state.const_tensor_ref(Tid{slot.idx}), + expected_dtype, + h->stream); + tensor_output_info.push_back({arg_idx, prepared_outputs.size()}); + prepared_outputs.push_back(std::move(out_arr)); + arg_idx++; + } else if (slot.slot_type == SlotType::IntValueSlot) { + Vid vid{slot.idx}; + int64_t int_val = + static_cast(h->state.const_value_ref(vid)); + *args[arg_idx] = EValue(int_val); + arg_idx++; + } else { + throw std::runtime_error( + "Unhandled output slot type: " + + std::to_string(static_cast(slot.slot_type))); + } + } + + // Submit all output work to GPU asynchronously + // async_eval encodes Metal commands and returns immediately. + // The GPU will signal events on completion. + if (!prepared_outputs.empty()) { + ::mlx::core::async_eval(prepared_outputs); + } + + } // Lock released — GPU is still executing + + for (auto& info : tensor_output_info) { + ETTensor& out_tensor = args[info.arg_idx]->toTensor(); + + // write_output waits on arr to be ready + write_output(prepared_outputs[info.prepared_idx], out_tensor); + } + + h->state.reset(); // Release temp GPU buffers back to MLX cache + + return Error::Ok; + } catch (const std::exception& e) { + ET_LOG(Error, "MLX execute failed: %s", e.what()); + return Error::Internal; + } + } + + void destroy(DelegateHandle* handle) const override { + std::lock_guard lock(mlx_global_mutex()); + if (handle != nullptr) { + auto* mlx_handle = static_cast(handle); + mlx_handle->~MLXHandle(); + } + } +}; + +namespace { +auto cls = MLXBackend(); +Backend backend{"MLXBackend", &cls}; +static auto success_with_compiler = register_backend(backend); +} // namespace + +} // namespace mlx +} // namespace backends +} // namespace executorch diff --git a/backends/mlx/runtime/MLXExecutor.h b/backends/mlx/runtime/MLXExecutor.h new file mode 100644 index 00000000000..32d623790ab --- /dev/null +++ b/backends/mlx/runtime/MLXExecutor.h @@ -0,0 +1,878 @@ +// +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +// + +#pragma once + +#include "MLXLoader.h" + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +// ============================================================================= +// Op Logging - compile-time gate + runtime env var check +// +// Compile flag (CMake: -DET_MLX_ENABLE_OP_LOGGING=1) controls whether logging +// code is compiled in at all. When off, all logging is stripped (zero +// overhead). When on, the env var ET_MLX_ENABLE_OP_LOGGING=1 must also be set +// at runtime to actually produce output. +// ============================================================================= +#ifndef ET_MLX_ENABLE_OP_LOGGING +#define ET_MLX_ENABLE_OP_LOGGING 0 +#endif + +// ============================================================================= +// Constant Zero-Copy - Enable via CMake: -DET_MLX_ENABLE_CONSTANT_ZERO_COPY=1 +// When enabled, attempts to load model constants (weights) using zero-copy +// on Apple Silicon's unified memory. Falls back to copying if zero-copy fails. +// Disable if you want predictable memory usage (always copies). +// ============================================================================= +#ifndef ET_MLX_ENABLE_CONSTANT_ZERO_COPY +#define ET_MLX_ENABLE_CONSTANT_ZERO_COPY 1 // Enabled by default +#endif + +namespace executorch { +namespace backends { +namespace mlx { + +/// Multiply two unsigned values, throw on overflow. +template +inline T safe_mul(T a, T b, const char* context) { + static_assert(std::is_unsigned::value, "safe_mul requires unsigned type"); + T result; + if (__builtin_mul_overflow(a, b, &result)) { + throw std::runtime_error(std::string(context) + ": unsigned mul overflow"); + } + return result; +} + +// Runtime check for op logging (only callable when compiled in) +#if ET_MLX_ENABLE_OP_LOGGING +inline bool isOpLoggingEnabled() { + static const bool enabled = []() { + const char* val = std::getenv("ET_MLX_ENABLE_OP_LOGGING"); + return val != nullptr && std::string(val) == "1"; + }(); + return enabled; +} +#else +constexpr bool isOpLoggingEnabled() { + return false; +} +#endif + +// Compile-time constant zero-copy flag +constexpr bool kEnableConstantZeroCopy = ET_MLX_ENABLE_CONSTANT_ZERO_COPY; + +using Tensor = ::mlx::core::array; +using Value = std::variant; +using StreamOrDevice = ::mlx::core::StreamOrDevice; + +struct ConstantData { + std::vector tensors; + + inline const Tensor& get(Tid id) const { + if (id.idx >= tensors.size()) { + throw std::out_of_range("ConstantData::get: id out of range"); + } + return tensors[id.idx]; + } + + inline void add(Tensor t) { + tensors.push_back(std::move(t)); + } +}; + +struct MutableBufferData { + // Maps tensor slot idx to MLX array + // Using vector of optional since mlx::array has no default constructor + std::vector> tensors; + + inline void resize(size_t n) { + tensors.resize(n, std::nullopt); + } + + inline bool has(Tid id) const { + return id.idx < tensors.size() && tensors[id.idx].has_value(); + } + + inline Tensor& get(Tid id) { + if (id.idx >= tensors.size() || !tensors[id.idx].has_value()) { + throw std::out_of_range("MutableBufferData::get: id not found or unset"); + } + return *tensors[id.idx]; + } + + inline const Tensor& get(Tid id) const { + if (id.idx >= tensors.size() || !tensors[id.idx].has_value()) { + throw std::out_of_range("MutableBufferData::get: id not found or unset"); + } + return *tensors[id.idx]; + } + + inline void set(Tid id, Tensor t) { + if (id.idx >= tensors.size()) { + throw std::out_of_range("MutableBufferData::set: id out of range"); + } + tensors[id.idx] = std::move(t); + } + + inline void clear() { + tensors.clear(); + } +}; + +struct ExecutionState { + const MLXProgram* program{nullptr}; + const ConstantData* constants{nullptr}; // Shared, read-only + MutableBufferData* mutable_buffers{nullptr}; // Per-handle, persistent + + // Per-execution tensors: inputs, outputs, temps (NOT constants or mutable + // buffers) + std::vector> tensors; + + // Non-constant values (SymInt, etc.) + std::vector> values; + + // Logging context + size_t current_op_idx{0}; + const char* current_op_name{nullptr}; + + // Tensor ID range boundaries for O(1) type lookup (computed at bind time) + uint32_t num_constants{0}; + uint32_t input_end{0}; + uint32_t output_end{0}; + uint32_t mutable_buffer_end{0}; + + void bind( + const MLXProgram& prog, + const ConstantData& const_data, + MutableBufferData& mut_bufs) { + program = &prog; + constants = &const_data; + mutable_buffers = &mut_bufs; + + // Allocate space for inputs, outputs, and temps only (not constants or + // mutable buffers) + uint64_t num_per_execution_tensors = + static_cast(prog.num_input_tensors) + + prog.num_output_tensors + prog.num_temp_tensors; + if (num_per_execution_tensors > 1'000'000) { + throw std::runtime_error( + "bind: num_per_execution_tensors " + + std::to_string(num_per_execution_tensors) + " exceeds limit"); + } + tensors.assign( + static_cast(num_per_execution_tensors), std::nullopt); + if (prog.num_values > 1'000'000) { + throw std::runtime_error( + "bind: num_values " + std::to_string(prog.num_values) + + " exceeds limit"); + } + values.assign(prog.num_values, std::nullopt); + + // Compute tensor ID range boundaries for fast type lookup + // ID assignment order: Constant -> Input -> Output -> MutableBuffer -> Temp + num_constants = prog.num_constant_tensors; + uint64_t ie = static_cast(num_constants) + prog.num_input_tensors; + uint64_t oe = ie + prog.num_output_tensors; + uint64_t me = oe + prog.num_mutable_buffer_tensors; + if (me > std::numeric_limits::max()) { + throw std::runtime_error("bind: tensor ID range overflow"); + } + input_end = static_cast(ie); + output_end = static_cast(oe); + mutable_buffer_end = static_cast(me); + } + + // Check if a tensor ID is a mutable buffer + inline bool is_mutable_buffer(Tid id) const { + return id.idx >= output_end && id.idx < mutable_buffer_end; + } + + // Convert tensor ID to index in the tensors vector + // Accounts for constants and mutable buffers not being in the vector + inline uint32_t tensor_index(Tid id) const { + if (id.idx < num_constants) { + throw std::runtime_error( + "tensor_index: called with constant tensor id " + + std::to_string(id.idx)); + } + if (is_mutable_buffer(id)) { + throw std::runtime_error( + "tensor_index: called with mutable buffer tensor id " + + std::to_string(id.idx)); + } + uint32_t idx = id.idx - num_constants; + // If this ID is after mutable buffer range, subtract mutable buffer count + if (id.idx >= mutable_buffer_end) { + if (idx < program->num_mutable_buffer_tensors) { + throw std::runtime_error( + "tensor_index: underflow for tensor id " + std::to_string(id.idx)); + } + idx -= program->num_mutable_buffer_tensors; + } + if (idx >= tensors.size()) { + throw std::out_of_range( + "tensor_index: computed index " + std::to_string(idx) + + " out of range (size=" + std::to_string(tensors.size()) + + ") for tensor id " + std::to_string(id.idx)); + } + return idx; + } + + void reset() { + // Clear per-execution tensors (inputs, outputs, temps) + // Constants and mutable buffers are not in this vector + for (auto& t : tensors) { + t = std::nullopt; + } + for (auto& v : values) { + v = std::nullopt; + } + } + + static inline const char* dtype_str(::mlx::core::Dtype dtype) { + using namespace ::mlx::core; + switch (dtype.val()) { + case float32.val(): + return "f32"; + case float16.val(): + return "f16"; + case bfloat16.val(): + return "bf16"; + case int32.val(): + return "i32"; + case int64.val(): + return "i64"; + case int16.val(): + return "i16"; + case int8.val(): + return "i8"; + case uint32.val(): + return "u32"; + case uint8.val(): + return "u8"; + case bool_.val(): + return "bool"; + default: + return "?"; + } + } + + static inline std::string format_tensor_info(const Tensor& t) { + std::ostringstream ss; + ss << dtype_str(t.dtype()); + ss << "("; + const auto& shape = t.shape(); + for (size_t i = 0; i < shape.size(); ++i) { + if (i > 0) + ss << ","; + ss << shape[i]; + } + ss << ")"; + return ss.str(); + } + + // Compute tensor stats: min, max, mean, nan_count + // Uses MLX ops for GPU-accelerated computation + static inline std::string format_tensor_stats(const Tensor& t) { + using namespace ::mlx::core; + + try { + std::ostringstream ss; + + size_t numel = t.size(); + if (numel == 0) { + ss << "[empty]"; + return ss.str(); + } + + // Cast to float32 for stats computation (handles bf16/fp16/int/bool) + Tensor t_float = astype(t, float32); + + // Use MLX ops for efficient GPU-based stats + Tensor nan_mask = isnan(t_float); + Tensor inf_mask = isinf(t_float); + Tensor nan_count_arr = sum(astype(nan_mask, int32)); + Tensor inf_count_arr = sum(astype(inf_mask, int32)); + + // For min/max/mean, we need to handle NaN/Inf - replace with 0 + Tensor valid_mask = logical_not(logical_or(nan_mask, inf_mask)); + Tensor t_valid = where(valid_mask, t_float, zeros_like(t_float)); + + Tensor min_arr = min(t_valid); + Tensor max_arr = max(t_valid); + Tensor mean_arr = mean(t_valid); + + // Evaluate all at once + eval({nan_count_arr, inf_count_arr, min_arr, max_arr, mean_arr}); + + int nan_count = nan_count_arr.item(); + int inf_count = inf_count_arr.item(); + float min_val = min_arr.item(); + float max_val = max_arr.item(); + float mean_val = mean_arr.item(); + + ss << std::fixed << std::setprecision(4); + ss << "[min=" << min_val << " max=" << max_val << " mean=" << mean_val; + if (nan_count > 0) { + ss << " NaN=" << nan_count; + } + if (inf_count > 0) { + ss << " Inf=" << inf_count; + } + ss << "]"; + return ss.str(); + } catch (const std::exception& e) { + return std::string("[stats error: ") + e.what() + "]"; + } catch (...) { + return "[stats error: unknown]"; + } + } + + // Get tensor type prefix for logging: "c", "i", "o", "b", "t" + inline const char* tensor_type_prefix(Tid id) const { + if (!program) + return "?"; + + uint32_t tid = id.idx; + + // Check each range in order (mutually exclusive ranges) + if (tid < program->num_constant_tensors) + return "c"; // Constant + if (tid < input_end) + return "i"; // User Input + if (tid < output_end) + return "o"; // User Output + if (tid < mutable_buffer_end) + return "b"; // Mutable Buffer + return "t"; // Temp + } + + inline void begin_op(size_t idx, const char* name) { + current_op_idx = idx; + current_op_name = name; + if (isOpLoggingEnabled()) { + std::cout << "[" << idx << "] " << name << std::endl; + } + } + + inline void end_op() { + if (isOpLoggingEnabled()) { + std::cout << "----\n"; + } + } + + inline Tensor& tensor_ref(Tid id) { + if (isOpLoggingEnabled()) { + std::cout << " ref " << tensor_type_prefix(id) << id.idx << std::flush; + } + if (!program) { + throw std::runtime_error("tensor_ref: Program not bound"); + } + if (id.idx >= program->num_tensors()) { + throw std::out_of_range("tensor_ref: id out of range"); + } + if (program->is_constant_tensor(id)) { + throw std::runtime_error("tensor_ref: cannot mutate constant tensor"); + } + // Route to mutable buffers or per-execution tensors + Tensor* t = nullptr; + if (is_mutable_buffer(id)) { + if (!mutable_buffers) { + throw std::runtime_error("tensor_ref: mutable_buffers not bound"); + } + t = &mutable_buffers->get(id); + } else { + uint32_t idx = tensor_index(id); + if (idx >= tensors.size()) { + throw std::out_of_range("tensor_ref: tensor idx out of range"); + } + auto& opt = tensors[idx]; + if (!opt) { + throw std::runtime_error( + "tensor_ref: uninitialized tensor idx=" + std::to_string(id.idx)); + } + t = &*opt; + } + if (isOpLoggingEnabled()) { + std::cout << " " << format_tensor_info(*t) << "\n"; + } + return *t; + } + + inline const Tensor& const_tensor_ref(Tid id) const { + if (isOpLoggingEnabled()) { + std::cout << " in " << tensor_type_prefix(id) << id.idx << std::flush; + } + if (!program) { + throw std::runtime_error("const_tensor_ref: Program not bound"); + } + if (id.idx >= program->num_tensors()) { + throw std::out_of_range("const_tensor_ref: id out of range"); + } + + const Tensor* t = nullptr; + if (program->is_constant_tensor(id)) { + // Route to constants + if (!constants) { + throw std::runtime_error("const_tensor_ref: constants not bound"); + } + t = &constants->get(id); + } else if (is_mutable_buffer(id)) { + // Route to mutable buffers + if (!mutable_buffers) { + throw std::runtime_error("const_tensor_ref: mutable_buffers not bound"); + } + t = &mutable_buffers->get(id); + } else { + // Route to per-execution tensors + uint32_t idx = tensor_index(id); + if (idx >= tensors.size()) { + throw std::out_of_range("const_tensor_ref: tensor idx out of range"); + } + const auto& opt = tensors[idx]; + if (!opt) { + throw std::runtime_error( + "const_tensor_ref: uninitialized tensor idx=" + + std::to_string(id.idx)); + } + t = &*opt; + } + + if (isOpLoggingEnabled()) { + std::cout << " " << format_tensor_info(*t) << " " + << format_tensor_stats(*t) << "\n"; + } + return *t; + } + + // Set a tensor output + inline void set_tensor(Tid id, Tensor arr) { + if (isOpLoggingEnabled()) { + std::cout << " out " << tensor_type_prefix(id) << id.idx << " " + << format_tensor_info(arr) << " " << format_tensor_stats(arr) + << "\n"; + } + if (!program) { + throw std::runtime_error("set_tensor: Program not bound"); + } + if (id.idx < program->num_constant_tensors) { + throw std::runtime_error("set_tensor: cannot write to constant tensor"); + } + // Route to mutable buffers or per-execution tensors + if (is_mutable_buffer(id)) { + if (!mutable_buffers) { + throw std::runtime_error("set_tensor: mutable_buffers not bound"); + } + mutable_buffers->set(id, std::move(arr)); + } else { + uint32_t idx = tensor_index(id); + if (idx >= tensors.size()) { + throw std::out_of_range("set_tensor: tensor idx out of range"); + } + tensors[idx] = std::move(arr); + } + } + + template + inline T& value_ref(Vid id) { + if (isOpLoggingEnabled()) { + std::cout << " ref v" << id.idx << std::flush; + } + if (id.idx >= values.size()) { + throw std::out_of_range("value_ref: id out of range"); + } + auto& opt = values[id.idx]; + if (!opt) { + throw std::runtime_error( + "value_ref: uninitialized value idx=" + std::to_string(id.idx)); + } + if (isOpLoggingEnabled()) { + std::cout << " " << std::get(*opt) << "\n"; + } + return std::get(*opt); + } + + template + inline const T& const_value_ref(Vid id) const { + if (isOpLoggingEnabled()) { + std::cout << " in v" << id.idx << std::flush; + } + if (id.idx >= values.size()) { + throw std::out_of_range("const_value_ref: id out of range"); + } + const auto& opt = values[id.idx]; + if (!opt) { + throw std::runtime_error( + "const_value_ref: uninitialized value idx=" + std::to_string(id.idx)); + } + if (isOpLoggingEnabled()) { + std::cout << " " << std::get(*opt) << "\n"; + } + return std::get(*opt); + } + + inline const Value& const_value(Vid id) const { + if (isOpLoggingEnabled()) { + std::cout << " in v" << id.idx << std::flush; + } + if (id.idx >= values.size()) { + throw std::out_of_range("const_value: id out of range"); + } + const auto& opt = values[id.idx]; + if (!opt) { + throw std::runtime_error( + "const_value: uninitialized value idx=" + std::to_string(id.idx)); + } + if (isOpLoggingEnabled()) { + std::visit([](auto&& arg) { std::cout << " " << arg << "\n"; }, *opt); + } + return *opt; + } + + template + inline void set_value(Vid id, T val) { + if (isOpLoggingEnabled()) { + std::cout << " out v" << id.idx << " " << val << "\n"; + } + if (id.idx >= values.size()) { + throw std::out_of_range("set_value: id out of range"); + } + values[id.idx] = val; + } +}; + +inline ::mlx::core::Dtype resolve_dtype(ScalarType d) { + using namespace ::mlx::core; + switch (d) { + case ScalarType::Half: + return float16; + case ScalarType::Float: + return float32; + case ScalarType::BFloat16: + return bfloat16; + case ScalarType::Int: + return int32; + case ScalarType::Short: + return int16; + case ScalarType::Long: + return int64; + case ScalarType::UInt32: + return uint32; + case ScalarType::Byte: + return uint8; + case ScalarType::Bool: + return bool_; + case ScalarType::Char: + return int8; + default: + throw std::runtime_error( + "Unsupported ScalarType: " + std::to_string(static_cast(d))); + } +} + +inline ::mlx::core::Dtype resolve_dtype(int8_t d) { + return resolve_dtype(static_cast(d)); +} + +// Maximum allocation size for any single tensor created from untrusted data. +// This bounds GPU memory allocation from malformed payloads. +constexpr size_t kMaxAllocationBytes = + static_cast(4) * 1024 * 1024 * 1024; // 4 GB + +/// Validate that a tensor with the given shape and dtype does not exceed +/// kMaxAllocationBytes. Throws std::runtime_error on invalid dimensions +/// or if the total size exceeds the limit. +inline void check_allocation_bounded( + const ::mlx::core::Shape& shape, + ::mlx::core::Dtype dtype, + const char* context) { + size_t elem_size = ::mlx::core::size_of(dtype); + size_t numel = 1; + for (auto d : shape) { + if (d <= 0) { + throw std::runtime_error( + std::string(context) + ": invalid dimension " + std::to_string(d)); + } + numel = safe_mul(numel, static_cast(d), context); + } + size_t total_bytes = safe_mul(numel, elem_size, context); + if (total_bytes > kMaxAllocationBytes) { + throw std::runtime_error( + std::string(context) + ": allocation exceeds 4GB limit"); + } +} + +inline int32_t clamp_to_int32(int64_t val64) { + // INT64_MAX is commonly used as a sentinel for "slice to end". + // Non-sentinel large values are silently clamped, which may change + // slice semantics — but this matches PyTorch behavior. + if (val64 >= static_cast(std::numeric_limits::max())) { + return std::numeric_limits::max(); + } else if ( + val64 <= static_cast(std::numeric_limits::min())) { + return std::numeric_limits::min(); + } + return static_cast(val64); +} + +inline int32_t resolve_int( + const std::variant& v, + const ExecutionState& st) { + if (std::holds_alternative(v)) { + return clamp_to_int32(std::get(v)); + } + return st.const_value_ref(std::get(v)); +} + +inline std::vector resolve_ints( + const std::vector>& v, + const ExecutionState& st) { + std::vector out; + out.reserve(v.size()); + for (const auto& elem : v) { + out.push_back(resolve_int(elem, st)); + } + return out; +} + +inline float resolve_float( + const std::variant& v, + const ExecutionState& st) { + if (std::holds_alternative(v)) { + return static_cast(std::get(v)); + } + // The value may be stored as int32_t (from SymInt computations) or float. + const auto& val = st.const_value(std::get(v)); + return std::visit( + [](auto&& arg) -> float { return static_cast(arg); }, val); +} + +inline ::mlx::core::Shape to_shape( + const std::vector>& dims, + const ExecutionState& st) { + auto resolved = resolve_ints(dims, st); + return ::mlx::core::Shape(resolved.begin(), resolved.end()); +} + +inline ::mlx::core::Shape to_shape(const std::vector& dims) { + return ::mlx::core::Shape(dims.begin(), dims.end()); +} + +// Overload for static shapes (used when loading constants where all dims must +// be literals) +// Convert ShapeDim vector to MLX Shape (for constants and mutable buffers). +// Only static dimensions are allowed — dynamic dims (value == -1) are rejected. +inline ::mlx::core::Shape to_shape(const std::vector& dims) { + ::mlx::core::Shape out; + out.reserve(dims.size()); + for (const auto& d : dims) { + if (d.is_dynamic()) { + throw std::runtime_error( + "to_shape: expected static shape but found dynamic dimension"); + } + out.push_back(d.value); + } + return out; +} + +// Load constants from ExecuTorch's NamedDataMap. +// Constants are stored by name in the .pte file and loaded via the +// named_data_map interface. This allows ExecuTorch to own the constant data and +// enables zero-copy on Apple Silicon unified memory. +// +// Parameters: +// program: The loaded MLXProgram containing tensor metadata and named_slots +// named_data_map: ExecuTorch's interface for accessing named data +// store: Output storage for loaded constant tensors +// constant_buffers: Vector to store FreeableBuffers (must outlive store for +// zero-copy) +inline void load_constants( + const MLXProgram& program, + const runtime::NamedDataMap* named_data_map, + ConstantData& store, + std::vector& constant_buffers) { + using namespace ::mlx::core; + + store.tensors.clear(); + constant_buffers.clear(); + + if (program.num_constant_tensors == 0) { + return; + } + + store.tensors.reserve(program.num_constant_tensors); + constant_buffers.reserve(program.num_constant_tensors); + + // Build tid -> name map for O(1) lookup + std::unordered_map tid_to_name; + tid_to_name.reserve(program.named_slots.size()); + for (const auto& ns : program.named_slots) { + if (ns.slot.slot_type == SlotType::TensorSlot) { + tid_to_name[ns.slot.idx] = &ns.name; + } + } + + // Load each constant tensor by name + for (uint32_t tid = 0; tid < program.num_constant_tensors; ++tid) { + // Get tensor metadata + if (tid >= program.tensor_meta.size() || !program.tensor_meta[tid]) { + throw std::runtime_error( + "load_constants: missing metadata for constant " + + std::to_string(tid)); + } + + // Find the name for this tensor ID + auto it = tid_to_name.find(tid); + const std::string* name = (it != tid_to_name.end()) ? it->second : nullptr; + if (!name) { + throw std::runtime_error( + "load_constants: no name found for constant tensor " + + std::to_string(tid)); + } + + // Get data from named_data_map + if (named_data_map == nullptr) { + throw std::runtime_error( + "load_constants: named_data_map is null but program has constants"); + } + + auto data_result = named_data_map->get_data(name->c_str()); + if (!data_result.ok()) { + throw std::runtime_error( + "load_constants: failed to get data for constant '" + *name + + "': error " + std::to_string(static_cast(data_result.error()))); + } + + // Move the buffer into our storage (keeps it alive for zero-copy) + constant_buffers.push_back(std::move(data_result.get())); + runtime::FreeableBuffer& buffer = constant_buffers.back(); + + const auto& meta = *program.tensor_meta[tid]; + Shape shape = to_shape(meta.shape); + Dtype dtype = resolve_dtype(meta.scalar_type); + + // Create MLX array with zero-copy when enabled. + // SAFETY: Constants are read-only; the program builder ensures no in-place + // ops target constant tensors. The const_cast is required by MLX's array + // constructor but the data will not be mutated + void* data_ptr = const_cast(buffer.data()); + + if constexpr (kEnableConstantZeroCopy) { + // Zero-copy: wrap pointer directly with no-op deleter + // The FreeableBuffer in constant_buffers keeps the data alive + auto deleter = [](void*) { + // Data lifetime managed by FreeableBuffer in + // MLXHandle::constant_buffers + }; + array arr = array(data_ptr, shape, dtype, deleter); + store.add(std::move(arr)); + } else { + // No deleter = MLX copies the data into its own memory + store.add(array(static_cast(data_ptr), shape, dtype)); + } + } + + // Evaluate all constants immediately to prepare Metal buffers + // This trades init time for faster first inference + eval(store.tensors); +} + +inline void load_mutable_buffers( + const MLXProgram& program, + MutableBufferData& store) { + using namespace ::mlx::core; + + store.clear(); + + if (program.mutable_buffer_map.empty()) { + return; + } + + // Pre-size the storage to fit all tensor IDs + // Mutable buffer IDs are in the global tensor ID space + uint32_t max_tid = 0; + for (const auto& slot : program.mutable_buffer_map) { + if (slot.idx > max_tid) { + max_tid = slot.idx; + } + } + if (max_tid >= 1'000'000) { + throw std::runtime_error( + "load_mutable_buffers: max_tid " + std::to_string(max_tid) + + " exceeds limit"); + } + store.resize(max_tid + 1); + + for (const auto& slot : program.mutable_buffer_map) { + if (slot.slot_type != SlotType::TensorSlot) { + throw std::runtime_error( + "load_mutable_buffers: unexpected slot type " + + std::to_string(static_cast(slot.slot_type))); + } + + Tid tid{slot.idx}; + + // Get metadata for this tensor + if (tid.idx >= program.tensor_meta.size()) { + ET_LOG( + Error, + "load_mutable_buffers: tid %u >= tensor_meta.size() %zu", + tid.idx, + program.tensor_meta.size()); + throw std::runtime_error( + "load_mutable_buffers: tensor index out of range for tensor " + + std::to_string(tid.idx)); + } + + if (!program.tensor_meta[tid.idx]) { + ET_LOG( + Error, + "load_mutable_buffers: missing metadata for tensor %u", + tid.idx); + throw std::runtime_error( + "load_mutable_buffers: missing metadata for tensor " + + std::to_string(tid.idx)); + } + + const auto& meta = *program.tensor_meta[tid.idx]; + auto shape = to_shape(meta.shape); + auto dtype = resolve_dtype(meta.scalar_type); + + check_allocation_bounded(shape, dtype, "load_mutable_buffers"); + + // Initialize mutable buffer to zeros + // This matches the typical initialization of KV cache buffers + auto arr = zeros(shape, dtype); + + // Evaluate immediately to allocate in GPU memory + eval(arr); + + store.set(tid, std::move(arr)); + } +} + +} // namespace mlx +} // namespace backends +} // namespace executorch diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h new file mode 100644 index 00000000000..f3b6e9b720f --- /dev/null +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -0,0 +1,169 @@ +// +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +// + +#pragma once + +#include "MLXExecutor.h" + +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace mlx { + +namespace ops { + +using namespace ::mlx::core; + +/** + * Normalize axis to be in range [0, rank) and validate. + * @param axis The axis value (can be negative) + * @param rank The tensor rank + * @param op_name Name of the operation for error messages + * @return Normalized axis in range [0, rank) + * @throws std::out_of_range if axis is out of range + */ +inline int normalize_axis(int axis, int rank, const char* op_name) { + if (axis < -rank || axis >= rank) { + throw std::out_of_range(std::string(op_name) + ": axis out of range"); + } + if (axis < 0) + axis += rank; + return axis; +} + +/** + * Infers dimensions with -1 in a reshape-like operation. + * + * PyTorch allows -1 in shapes to mean "infer this dimension from total size". + * MLX requires concrete positive integers, so we must resolve -1 values. + * + * @param shape The shape to resolve (may contain -1) + * @param input_size Total number of elements in the input tensor + * @return Resolved shape with all positive integers + * @throws std::runtime_error if shape has multiple -1 or incompatible sizes + */ +inline std::vector infer_shape_with_minus_one( + const std::vector& shape, + size_t input_size) { + std::vector resolved_shape = shape; + int neg_one_idx = -1; + int64_t known_size = 1; // Use int64_t to avoid overflow + + // Find -1 dimension and compute product of known dimensions + for (size_t i = 0; i < resolved_shape.size(); i++) { + if (resolved_shape[i] == -1) { + if (neg_one_idx != -1) { + throw std::runtime_error("infer_shape: only one dimension can be -1"); + } + neg_one_idx = static_cast(i); + } else { + known_size *= static_cast(resolved_shape[i]); + } + } + + // Infer the -1 dimension if present + if (neg_one_idx != -1) { + if (known_size == 0) { + throw std::runtime_error( + "infer_shape: cannot infer -1 dimension when known product is 0"); + } + int64_t input_size_i64 = static_cast(input_size); + if (input_size_i64 % known_size != 0) { + throw std::runtime_error( + "infer_shape: cannot infer dimension - size mismatch"); + } + int64_t inferred_dim = input_size_i64 / known_size; + + // Check that inferred dimension fits in int + if (inferred_dim > std::numeric_limits::max()) { + throw std::runtime_error( + "infer_shape: inferred dimension exceeds int max"); + } + + resolved_shape[static_cast(neg_one_idx)] = + static_cast(inferred_dim); + } + + return resolved_shape; +} + +inline void exec_noop(const NoopNode&, ExecutionState&, StreamOrDevice) {} + +inline void +exec_addmm(const AddmmNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& mat1 = st.const_tensor_ref(n.mat1); + const auto& mat2 = st.const_tensor_ref(n.mat2); + + array Y = n.bias ? addmm( + st.const_tensor_ref(*n.bias), + mat1, + mat2, + /*alpha=*/n.alpha, + /*beta=*/n.beta, + s) + : matmul(mat1, mat2, s); + + st.set_tensor(n.out, std::move(Y)); +} + +} // namespace ops + +class Interpreter { + public: + void run( + const MLXProgram& prog, + ExecutionState& st, + StreamOrDevice stream = {}) const { + run_chain(prog, prog.main_chain_idx, st, stream); + } + + void run_chain( + const MLXProgram& prog, + uint32_t chain_idx, + ExecutionState& st, + StreamOrDevice stream = {}) const { + if (chain_idx >= prog.instruction_chains.size()) { + throw std::runtime_error( + "run_chain: chain_idx " + std::to_string(chain_idx) + + " out of range (num_chains=" + + std::to_string(prog.instruction_chains.size()) + ")"); + } + const auto& chain = prog.instruction_chains[chain_idx]; + size_t idx = 0; + for (const auto& instr : chain) { + st.begin_op(idx, op_name(instr.op)); + dispatch(instr, st, stream); + st.end_op(); + ++idx; + } + } + + private: + void dispatch(const Instruction& instr, ExecutionState& st, StreamOrDevice s) + const { + switch (instr.op) { + case OpCode::NOOP: + ops::exec_noop(std::get(instr.node), st, s); + break; + case OpCode::ADDMM: + ops::exec_addmm(std::get(instr.node), st, s); + break; + default: + throw std::runtime_error( + "Unknown opcode: " + std::to_string(static_cast(instr.op))); + } + } +}; + +} // namespace mlx +} // namespace backends +} // namespace executorch diff --git a/backends/mlx/serialization/MLXLoader.cpp.tmpl b/backends/mlx/serialization/MLXLoader.cpp.tmpl new file mode 100644 index 00000000000..aa4716d7a4a --- /dev/null +++ b/backends/mlx/serialization/MLXLoader.cpp.tmpl @@ -0,0 +1,324 @@ +// -*- c++ -*- + +#include "MLXLoader.h" + +#include +#include + +namespace executorch { +namespace backends { +namespace mlx { +namespace loader { + +namespace { + +// Header structure for MLX payload +constexpr size_t kHeaderSize = 24; +constexpr uint32_t kMagic = 0x30584C4D; // "MLX0" in little-endian + +struct MLXHeader { + uint32_t padding; + uint32_t magic; + uint64_t data_offset; + uint64_t data_size; +}; +static_assert(sizeof(MLXHeader) == kHeaderSize, "MLXHeader size mismatch"); + +bool parse_header(const void* data, size_t size, MLXHeader& header) { + if (size < kHeaderSize) { + return false; + } + std::memcpy(&header, data, sizeof(MLXHeader)); + if (header.magic != kMagic) { + return false; + } + // Validate data_offset: must be strictly greater than kHeaderSize (so the + // FlatBuffer region is non-empty) and must not exceed the total buffer size. + if (header.data_offset <= kHeaderSize || header.data_offset > size) { + return false; + } + return true; +} + +// Helper to convert FlatBuffer vectors to std::vector. +// Caps size to prevent unbounded allocations from malformed payloads. +template +std::vector to_vector(const flatbuffers::Vector* fb_vec) { + if (!fb_vec) { + return {}; + } + constexpr size_t kMaxVectorSize = 1'000'000; + if (fb_vec->size() > kMaxVectorSize) { + throw std::runtime_error( + "FlatBuffer vector size " + std::to_string(fb_vec->size()) + + " exceeds maximum of " + std::to_string(kMaxVectorSize)); + } + return std::vector(fb_vec->begin(), fb_vec->end()); +} + +} // namespace + +// ============================================================================= +// load_instruction - AUTO-GENERATED switch statement +// ============================================================================= + +Instruction load_instruction(const mlx_delegate::Instruction* fb_instr) { + Instruction instr; + + if (!fb_instr || !fb_instr->op()) { + instr.op = OpCode::NOOP; + instr.node = NoopNode{}; + return instr; + } + + auto op_type = fb_instr->op_type(); + + switch (op_type) { +{{LOAD_INSTRUCTION_CASES}} + default: + throw std::runtime_error( + "Unknown op_type in load_instruction: " + + std::to_string(static_cast(op_type)) + + ". The .pte was built with a newer schema than this binary. " + "Rebuild with the latest runtime."); + } + + return instr; +} + +// ============================================================================= +// load_program +// ============================================================================= + +MLXProgram load_program(const void* data, size_t size) { + MLXHeader header; + if (!parse_header(data, size, header)) { + throw std::runtime_error("Invalid MLX header"); + } + + // Defense-in-depth: parse_header already validates this, but guard the + // unsigned subtraction against underflow in case the call site ever changes. + if (header.data_offset <= kHeaderSize || header.data_offset > size) { + throw std::runtime_error("data_offset out of range"); + } + const uint8_t* fb_data = static_cast(data) + kHeaderSize; + size_t fb_size = header.data_offset - kHeaderSize; + + flatbuffers::Verifier verifier(fb_data, fb_size); + if (!mlx_delegate::VerifyMLXGraphBuffer(verifier)) { + throw std::runtime_error("Invalid FlatBuffer data"); + } + + const auto* fb_graph = mlx_delegate::GetMLXGraph(fb_data); + if (!fb_graph) { + throw std::runtime_error("Failed to parse MLXGraph"); + } + + MLXProgram program; + + if (fb_graph->version()) { + program.version = fb_graph->version()->str(); + } + + program.num_constant_tensors = fb_graph->num_constant_tensors(); + program.num_input_tensors = fb_graph->num_input_tensors(); + program.num_output_tensors = fb_graph->num_output_tensors(); + program.num_mutable_buffer_tensors = fb_graph->num_mutable_buffer_tensors(); + program.num_temp_tensors = fb_graph->num_temp_tensors(); + program.num_values = fb_graph->num_values(); + + // Cap all counts/collection sizes to prevent unbounded allocations from + // malformed FlatBuffer payloads + constexpr size_t kMaxCollectionSize = 1'000'000; + auto check_collection_size = [](size_t sz, const char* name) { + if (sz > kMaxCollectionSize) { + throw std::runtime_error( + std::string("Malformed program: ") + name + " size " + + std::to_string(sz) + " exceeds maximum of " + + std::to_string(kMaxCollectionSize)); + } + }; + + check_collection_size(program.num_tensors(), "num_tensors()"); + check_collection_size(program.num_values, "num_values"); + + if (fb_graph->instruction_chains()) { + check_collection_size(fb_graph->instruction_chains()->size(), "instruction_chains"); + program.instruction_chains.reserve(fb_graph->instruction_chains()->size()); + for (size_t c = 0; c < fb_graph->instruction_chains()->size(); ++c) { + const auto* fb_chain = fb_graph->instruction_chains()->Get(static_cast(c)); + std::vector chain; + if (fb_chain && fb_chain->instructions()) { + check_collection_size(fb_chain->instructions()->size(), "instructions in chain"); + chain.reserve(fb_chain->instructions()->size()); + for (size_t i = 0; i < fb_chain->instructions()->size(); ++i) { + chain.push_back(load_instruction(fb_chain->instructions()->Get(static_cast(i)))); + } + } + program.instruction_chains.push_back(std::move(chain)); + } + } + + program.main_chain_idx = fb_graph->main_chain_idx(); + program.init_chain_idx = fb_graph->init_chain_idx(); + + // Validate chain indices against actual instruction_chains size. + if (program.main_chain_idx >= program.instruction_chains.size()) { + throw std::runtime_error( + "Invalid main_chain_idx " + + std::to_string(program.main_chain_idx) + + " (only " + std::to_string(program.instruction_chains.size()) + + " chains loaded)"); + } + if (program.init_chain_idx >= 0 && + static_cast(program.init_chain_idx) >= + program.instruction_chains.size()) { + throw std::runtime_error( + "Invalid init_chain_idx " + + std::to_string(program.init_chain_idx) + + " (only " + std::to_string(program.instruction_chains.size()) + + " chains loaded)"); + } + + if (fb_graph->input_map()) { + check_collection_size(fb_graph->input_map()->size(), "input_map"); + for (size_t i = 0; i < fb_graph->input_map()->size(); ++i) { + const auto* slot = fb_graph->input_map()->Get(static_cast(i)); + auto sv = convert_slot_variant(slot); + if (sv.slot_type == SlotType::TensorSlot && + sv.idx >= program.num_tensors()) { + throw std::runtime_error( + "input_map: slot index " + std::to_string(sv.idx) + + " exceeds num_tensors " + + std::to_string(program.num_tensors())); + } + program.input_map.push_back(sv); + } + } + + if (fb_graph->output_map()) { + check_collection_size(fb_graph->output_map()->size(), "output_map"); + for (size_t i = 0; i < fb_graph->output_map()->size(); ++i) { + const auto* slot = fb_graph->output_map()->Get(static_cast(i)); + auto sv = convert_slot_variant(slot); + if (sv.slot_type == SlotType::TensorSlot && + sv.idx >= program.num_tensors()) { + throw std::runtime_error( + "output_map: slot index " + std::to_string(sv.idx) + + " exceeds num_tensors " + + std::to_string(program.num_tensors())); + } + program.output_map.push_back(sv); + } + } + + if (fb_graph->mutable_buffer_map()) { + check_collection_size(fb_graph->mutable_buffer_map()->size(), "mutable_buffer_map"); + for (size_t i = 0; i < fb_graph->mutable_buffer_map()->size(); ++i) { + const auto* slot = fb_graph->mutable_buffer_map()->Get(static_cast(i)); + auto sv = convert_slot_variant(slot); + if (sv.slot_type == SlotType::TensorSlot && + sv.idx >= program.num_tensors()) { + throw std::runtime_error( + "mutable_buffer_map: slot index " + std::to_string(sv.idx) + + " exceeds num_tensors " + + std::to_string(program.num_tensors())); + } + program.mutable_buffer_map.push_back(sv); + } + } + + if (fb_graph->named_slots()) { + check_collection_size(fb_graph->named_slots()->size(), "named_slots"); + for (size_t i = 0; i < fb_graph->named_slots()->size(); ++i) { + const auto* fb_slot = fb_graph->named_slots()->Get(static_cast(i)); + if (!fb_slot || !fb_slot->name()) { + throw std::runtime_error( + "Malformed program: named_slot at index " + std::to_string(i) + + " is null or has null name"); + } + NamedSlot slot; + slot.name = fb_slot->name()->str(); + slot.slot = convert_slot_variant(fb_slot->slot()); + program.named_slots.push_back(std::move(slot)); + } + } + + if (fb_graph->tensor_meta()) { + check_collection_size(fb_graph->tensor_meta()->size(), "tensor_meta"); + for (size_t i = 0; i < fb_graph->tensor_meta()->size(); ++i) { + const auto* fb_meta = fb_graph->tensor_meta()->Get(static_cast(i)); + if (fb_meta) { + TensorMeta meta; + if (fb_meta->shape()) { + // Validate tensor rank against kTensorDimensionLimit to prevent + // stack overflows from unchecked rank + constexpr size_t kTensorDimensionLimit = 16; + if (fb_meta->shape()->size() > kTensorDimensionLimit) { + throw std::runtime_error( + "Tensor at index " + std::to_string(i) + + " has rank " + std::to_string(fb_meta->shape()->size()) + + " exceeding kTensorDimensionLimit (" + + std::to_string(kTensorDimensionLimit) + ")"); + } + for (size_t j = 0; j < fb_meta->shape()->size(); ++j) { + const auto* fb_dim = fb_meta->shape()->Get(static_cast(j)); + if (!fb_dim) { + throw std::runtime_error( + "Null ShapeDim at index " + std::to_string(j) + + " in tensor_meta " + std::to_string(i)); + } + ShapeDim dim; + dim.value = fb_dim->value(); + dim.min_value = fb_dim->min_value(); + dim.max_value = fb_dim->max_value(); + if (dim.value < -1) { + throw std::runtime_error( + "Invalid ShapeDim value " + std::to_string(dim.value) + + " at index " + std::to_string(j) + + " in tensor_meta " + std::to_string(i)); + } + if (dim.is_dynamic()) { + if (dim.min_value < 0) { + throw std::runtime_error( + "Invalid ShapeDim min_value " + std::to_string(dim.min_value) + + " at index " + std::to_string(j) + + " in tensor_meta " + std::to_string(i)); + } + if (dim.max_value != -1 && dim.max_value < dim.min_value) { + throw std::runtime_error( + "ShapeDim max_value " + std::to_string(dim.max_value) + + " < min_value " + std::to_string(dim.min_value) + + " at index " + std::to_string(j) + + " in tensor_meta " + std::to_string(i)); + } + } + meta.shape.push_back(dim); + } + } + auto raw_scalar_type = fb_meta->scalar_type(); + if (raw_scalar_type < 0 || + raw_scalar_type >= + static_cast(ScalarType::NumOptions)) { + throw std::runtime_error( + "Invalid scalar_type " + std::to_string(raw_scalar_type) + + " in tensor_meta at index " + std::to_string(i)); + } + meta.scalar_type = static_cast(raw_scalar_type); + if (fb_meta->dim_order()) { + meta.dim_order = to_vector(fb_meta->dim_order()); + } + program.tensor_meta.push_back(std::move(meta)); + } else { + program.tensor_meta.push_back(std::nullopt); + } + } + } + + return program; +} + +} // namespace loader +} // namespace mlx +} // namespace backends +} // namespace executorch diff --git a/backends/mlx/serialization/MLXLoader.h.tmpl b/backends/mlx/serialization/MLXLoader.h.tmpl new file mode 100644 index 00000000000..0930d5e00e1 --- /dev/null +++ b/backends/mlx/serialization/MLXLoader.h.tmpl @@ -0,0 +1,343 @@ +// -*- c++ -*- + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "schema_generated.h" + +// ExecuTorch scalar type for dtype representation +#include + +namespace executorch { +namespace backends { +namespace mlx { + +// ============================================================================= +// Core types matching the Python side +// ============================================================================= + +struct Tid { + uint32_t idx{}; +}; + +struct Vid { + uint32_t idx{}; +}; + +// ============================================================================= +// Tensor metadata +// ============================================================================= + +// Import ScalarType from ExecuTorch +using ScalarType = ::executorch::runtime::etensor::ScalarType; + +struct ShapeDim { + int32_t value{-1}; // Static dim (>= 0), or -1 for dynamic + int32_t min_value{0}; // Lower bound (when value == -1) + int32_t max_value{-1}; // Upper bound (-1 = unbounded, when value == -1) + + bool is_dynamic() const { return value < 0; } +}; + +struct TensorMeta { + std::vector shape; + ScalarType scalar_type{ScalarType::Float}; // ET ScalarType + std::vector dim_order; +}; + +// VidOrTid: either a scalar value (Vid) or a tensor (Tid) +struct VidOrTid { + Vid vid{}; + Tid tid{}; + bool is_vid{false}; // false = use tid, true = use vid +}; + +// IntOrVidOrTid: a literal int, a runtime Vid, or a tensor (Tid) +struct IntOrVidOrTid { + int64_t literal{0}; + Vid vid{}; + Tid tid{}; + uint8_t kind{0}; // 0 = literal int, 1 = vid, 2 = tid +}; + +// ============================================================================= +// Op node types (AUTO-GENERATED from schema.fbs) +// ============================================================================= + +{{OP_NODE_STRUCTS}} + +// ============================================================================= +// OpCode enum (AUTO-GENERATED from schema.fbs) +// ============================================================================= + +enum class OpCode : uint8_t { +{{OPCODE_ENUM_VALUES}} +}; + +// OpCode to string conversion (for logging) +inline const char* op_name(OpCode op) { + switch (op) { +{{OP_NAME_CASES}} + } + return "UNKNOWN"; +} + +// ============================================================================= +// NodeVariant for type-erased op storage (AUTO-GENERATED) +// ============================================================================= + +using NodeVariant = std::variant< +{{NODE_VARIANT_TYPES}} +>; + +// ============================================================================= +// Instruction +// ============================================================================= + +struct Instruction { + OpCode op{OpCode::NOOP}; + NodeVariant node; + + template + T& get() { + return std::get(node); + } + + template + const T& get() const { + return std::get(node); + } +}; + +// ============================================================================= +// SlotVariant for I/O mapping +// ============================================================================= + +enum class SlotType : uint8_t { + TensorSlot = 0, + IntValueSlot = 1, + FloatValueSlot = 2, + BoolValueSlot = 3, +}; + +struct SlotVariant { + uint32_t idx; + SlotType slot_type; +}; + +// ============================================================================= +// Named slot (name -> slot mapping) +// ============================================================================= + +struct NamedSlot { + std::string name; + SlotVariant slot; +}; + +// ============================================================================= +// MLXProgram - the loaded program ready for execution +// ============================================================================= + +struct MLXProgram { + std::string version; + + // Tensor/value slot counts (in Tid assignment order) + uint32_t num_constant_tensors{0}; + uint32_t num_input_tensors{0}; + uint32_t num_output_tensors{0}; + uint32_t num_mutable_buffer_tensors{0}; + uint32_t num_temp_tensors{0}; + uint32_t num_values{0}; + + // Instruction chains + std::vector> instruction_chains; + uint32_t main_chain_idx{0}; + int32_t init_chain_idx{-1}; // -1 = no init chain + + // I/O mappings + std::vector input_map; + std::vector output_map; + std::vector mutable_buffer_map; + + // Name to slot lookup + std::vector named_slots; + + // Tensor metadata + std::vector> tensor_meta; + + // Helper methods + inline uint64_t num_tensors() const { + return static_cast(num_constant_tensors) + + num_input_tensors + num_output_tensors + + num_mutable_buffer_tensors + num_temp_tensors; + } + + inline bool is_constant_tensor(Tid id) const { + return id.idx < num_constant_tensors; + } + + inline size_t num_inputs() const { + return input_map.size(); + } + + inline size_t num_outputs() const { + return output_map.size(); + } +}; + +// ============================================================================= +// FlatBuffer loading functions +// ============================================================================= + +namespace loader { + +// Convert FlatBuffer SlotType to our SlotType +inline SlotType convert_slot_type(mlx_delegate::SlotType fb_type) { + switch (fb_type) { + case mlx_delegate::SlotType_TensorSlot: + return SlotType::TensorSlot; + case mlx_delegate::SlotType_IntValueSlot: + return SlotType::IntValueSlot; + case mlx_delegate::SlotType_FloatValueSlot: + return SlotType::FloatValueSlot; + case mlx_delegate::SlotType_BoolValueSlot: + return SlotType::BoolValueSlot; + default: + throw std::runtime_error("Unknown SlotType: " + + std::to_string(static_cast(fb_type))); + } +} + +// Convert FlatBuffer Tid +inline Tid convert_tid(const mlx_delegate::Tid* fb_tid) { + if (!fb_tid) { + throw std::runtime_error("Null Tid in FlatBuffer"); + } + return Tid{fb_tid->idx()}; +} + +// Convert FlatBuffer Vid +inline Vid convert_vid(const mlx_delegate::Vid* fb_vid) { + if (!fb_vid) { + throw std::runtime_error("Null Vid in FlatBuffer"); + } + return Vid{fb_vid->idx()}; +} + +// Convert FlatBuffer IntOrVid +inline std::variant convert_int_or_vid( + const mlx_delegate::IntOrVid* fb) { + if (!fb) { + throw std::runtime_error("Null IntOrVid in FlatBuffer"); + } + if (!fb->is_vid()) { + return fb->literal(); + } + const auto* vid_ptr = fb->vid(); + if (!vid_ptr) { + throw std::runtime_error("IntOrVid has is_vid=true but vid pointer is null"); + } + return Vid{vid_ptr->idx()}; +} + +// Convert FlatBuffer FloatOrVid +inline std::variant convert_float_or_vid( + const mlx_delegate::FloatOrVid* fb) { + if (!fb) { + throw std::runtime_error("Null FloatOrVid in FlatBuffer"); + } + if (!fb->is_vid()) { + return fb->literal(); + } + const auto* vid_ptr = fb->vid(); + if (!vid_ptr) { + throw std::runtime_error("FloatOrVid has is_vid=true but vid pointer is null"); + } + return Vid{vid_ptr->idx()}; +} + +// Convert FlatBuffer VidOrTid (scalar value or tensor) +inline VidOrTid convert_vid_or_tid( + const mlx_delegate::VidOrTid* fb) { + if (!fb) { + throw std::runtime_error("Null VidOrTid in FlatBuffer"); + } + VidOrTid result; + result.is_vid = fb->is_vid(); + if (result.is_vid) { + if (!fb->vid()) { + throw std::runtime_error("VidOrTid has is_vid=true but vid pointer is null"); + } + result.vid = Vid{fb->vid()->idx()}; + } else { + if (!fb->tid()) { + throw std::runtime_error("VidOrTid has is_vid=false but tid pointer is null"); + } + result.tid = Tid{fb->tid()->idx()}; + } + return result; +} + +// Convert FlatBuffer IntOrVidOrTid (literal int, Vid, or Tid) +inline IntOrVidOrTid convert_int_or_vid_or_tid( + const mlx_delegate::IntOrVidOrTid* fb) { + if (!fb) { + throw std::runtime_error("Null IntOrVidOrTid in FlatBuffer"); + } + IntOrVidOrTid result; + result.kind = fb->kind(); + switch (result.kind) { + case 0: // literal int + result.literal = fb->literal(); + break; + case 1: { // Vid + const auto* vid_ptr = fb->vid(); + if (!vid_ptr) { + throw std::runtime_error( + "IntOrVidOrTid has kind=1 (Vid) but vid pointer is null"); + } + result.vid = Vid{vid_ptr->idx()}; + break; + } + case 2: { // Tid + const auto* tid_ptr = fb->tid(); + if (!tid_ptr) { + throw std::runtime_error( + "IntOrVidOrTid has kind=2 (Tid) but tid pointer is null"); + } + result.tid = Tid{tid_ptr->idx()}; + break; + } + default: + throw std::runtime_error( + "IntOrVidOrTid has invalid kind: " + std::to_string(result.kind)); + } + return result; +} + +// Convert FlatBuffer SlotVariant +inline SlotVariant convert_slot_variant(const mlx_delegate::SlotVariant* fb) { + if (!fb) { + throw std::runtime_error("Null SlotVariant in FlatBuffer"); + } + return SlotVariant{fb->idx(), convert_slot_type(fb->slot_type())}; +} + +// Load an instruction from FlatBuffer +Instruction load_instruction(const mlx_delegate::Instruction* fb_instr); + +// Load the full MLXProgram from FlatBuffer data +MLXProgram load_program(const void* data, size_t size); + +} // namespace loader + +} // namespace mlx +} // namespace backends +} // namespace executorch diff --git a/backends/mlx/serialization/README.md b/backends/mlx/serialization/README.md new file mode 100644 index 00000000000..f2c022d0c80 --- /dev/null +++ b/backends/mlx/serialization/README.md @@ -0,0 +1,130 @@ +# MLX Delegate Serialization + +This directory contains the serialization code for the MLX delegate, which converts +Python graph representations to FlatBuffer format for execution on Apple Silicon. + +## Single Source of Truth: `schema.fbs` + +The FlatBuffer schema file `schema.fbs` is the **single source of truth** for all +serialization-related code. When you need to add a new op or modify existing types, +edit `schema.fbs` and regenerate all derived files. + +## Code Generator + +The `generate.py` script parses `schema.fbs` and generates: + +| Generated File | Description | +|----------------|-------------| +| `mlx_graph_schema.py` | Python dataclasses for all schema types | +| `_generated_serializers.py` | Python FlatBuffer serialization methods | +| `_generated/` | Python FlatBuffer reader classes (via `flatc`) | +| `../runtime/MLXLoader.h` | C++ structs, OpCode enum, NodeVariant | +| `../runtime/MLXLoader.cpp` | C++ `load_instruction()` switch statement | +| `../runtime/schema_generated.h` | C++ FlatBuffer reader classes (via `flatc`) | + +## Usage + +### Regenerate all files + +From the executorch root directory: + +```bash +python backends/mlx/serialization/generate.py +``` + +Or with explicit flatc path: + +```bash +python backends/mlx/serialization/generate.py --flatc /path/to/flatc +``` + +### Options + +``` +--flatc PATH Path to flatc compiler (default: "flatc") +--skip-flatc Skip running flatc (use existing FlatBuffer bindings) +--dry-run Print what would be generated without writing files +``` + +## File Structure + +``` +serialization/ +├── README.md # This file +├── schema.fbs # SOURCE OF TRUTH - FlatBuffer schema +├── generate.py # Code generator script +├── mlx_graph_schema.py # [GENERATED] Python dataclasses +├── mlx_graph_serialize.py # Main serializer (uses generated code) +├── _generated_serializers.py # [GENERATED] Op serialization methods +└── _generated/ # [GENERATED] FlatBuffer Python bindings + └── mlx_delegate/ + ├── *.py # One file per table/enum + +runtime/ +├── MLXLoader.h # [GENERATED] C++ types and loader decls +├── MLXLoader.cpp # [GENERATED] C++ loader implementation +├── schema_generated.h # [GENERATED] FlatBuffer C++ bindings +├── MLXInterpreter.h # C++ executor (manual) +├── MLXExecutor.h # C++ executor interface (manual) +└── MLXBackend.cpp # ExecuTorch backend integration (manual) +``` + +## Schema Design Notes + +### Field Types + +- `Tid` - Tensor slot identifier (indexes into tensor array) +- `Vid` - Value slot identifier (indexes into values array for scalars) +- `IntOrVid` - Either a literal int64 or a Vid (for dynamic shapes) +- `FloatOrVid` - Either a literal double or a Vid +- `DTypeId` - Data type enum (f16, f32, bf16, i32, etc.) + +### Optional Fields + +FlatBuffer fields without `(required)` are optional. In the generated Python +dataclasses, these become `Optional[T]` with default `None`. + +For optional scalar fields that need a sentinel (to distinguish None from 0), +use the `= null` default: + +```flatbuffers +table MyNode { + value: float = null; // None by default, distinguishes None from 0.0 +} +``` + +This requires FlatBuffers 2.0+ (ExecuTorch uses 24.3.25). The generated Python +dataclass will have `value: Optional[float] = None`. + +## Troubleshooting + +### flatc not found + +Install FlatBuffers or specify the path: + +```bash +# macOS +brew install flatbuffers + +# Or specify path +python generate.py --flatc /usr/local/bin/flatc +``` + +### Import errors after regeneration + +Make sure you're running from the correct environment: + +```bash +conda run -n et-mlx python backends/mlx/serialization/generate.py +``` + +### Generated code doesn't match schema + +Delete all generated files and regenerate: + +```bash +rm -rf backends/mlx/serialization/_generated +rm backends/mlx/serialization/mlx_graph_schema.py +rm backends/mlx/serialization/_generated_serializers.py +python backends/mlx/serialization/generate.py +``` diff --git a/backends/mlx/serialization/__init__.py b/backends/mlx/serialization/__init__.py new file mode 100644 index 00000000000..35a4f0cef8a --- /dev/null +++ b/backends/mlx/serialization/__init__.py @@ -0,0 +1,32 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""Serialization utilities for MLX delegate.""" + +from pathlib import Path + + +_schema_py = Path(__file__).parent / "mlx_graph_schema.py" +if not _schema_py.exists(): + raise ImportError( + "MLX delegate generated files not found. " + "Run 'python install_executorch.py' first." + ) + +# Export serialization functions for convenience +from executorch.backends.mlx.serialization.mlx_graph_serialize import ( # noqa: F401, E501 + deserialize_to_json, + parse_header, + serialize_mlx_graph, +) + +__all__ = [ + "deserialize_to_json", + "parse_header", + "serialize_mlx_graph", +] diff --git a/backends/mlx/serialization/generate.py b/backends/mlx/serialization/generate.py new file mode 100755 index 00000000000..d12743906db --- /dev/null +++ b/backends/mlx/serialization/generate.py @@ -0,0 +1,1437 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +""" +Code generator for MLX delegate. + +This is the SINGLE SOURCE OF TRUTH generator. Edit schema.fbs, then run: + python generate.py + +Generates: +1. FlatBuffer bindings (via flatc): + - _generated/ (Python) + - ../runtime/schema_generated.h (C++) +2. mlx_graph_schema.py (Python dataclasses) +3. _generated_serializers.py (Python serialization code) +4. ../runtime/MLXLoader.h (C++ structs, enums) - PARTIAL +5. ../runtime/MLXLoader.cpp (C++ loader switch) - PARTIAL + +Usage: + python generate.py [--flatc PATH_TO_FLATC] [--skip-flatc] +""" + +from __future__ import annotations + +import argparse +import re +import subprocess +from dataclasses import dataclass +from pathlib import Path +from typing import List, Optional, Tuple + + +SCRIPT_DIR = Path(__file__).parent +SCHEMA_FBS = SCRIPT_DIR / "schema.fbs" +GENERATED_DIR = SCRIPT_DIR / "_generated" +GENERATED_SERIALIZERS = SCRIPT_DIR / "_generated_serializers.py" +GENERATED_SCHEMA_PY = SCRIPT_DIR / "mlx_graph_schema.py" +GENERATED_INSPECTOR = SCRIPT_DIR.parent / "_generated_inspector.py" +RUNTIME_DIR = SCRIPT_DIR.parent / "runtime" +LOADER_H_TMPL = SCRIPT_DIR / "MLXLoader.h.tmpl" +LOADER_CPP_TMPL = SCRIPT_DIR / "MLXLoader.cpp.tmpl" +LOADER_H = RUNTIME_DIR / "MLXLoader.h" +LOADER_CPP = RUNTIME_DIR / "MLXLoader.cpp" + + +@dataclass +class FBSEnum: + name: str + base_type: str # e.g., "byte" + values: List[Tuple[str, Optional[int]]] # (name, explicit_value or None) + + +@dataclass +class FBSField: + name: str + type_str: str + required: bool + default: Optional[str] + + +# FBS integer types (signed and unsigned) +FBS_INTEGER_TYPES = frozenset( + { + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + } +) + +# FBS float types +FBS_FLOAT_TYPES = frozenset({"float", "double"}) + +# All FBS primitive scalar types (numbers + bool) +FBS_SCALAR_TYPES = FBS_INTEGER_TYPES | FBS_FLOAT_TYPES | frozenset({"bool"}) + +# Compound "or" types that wrap a literal + Vid +FBS_COMPOUND_TYPES = frozenset({"IntOrVid", "FloatOrVid", "VidOrTid", "IntOrVidOrTid"}) + +# Python type mapping for FBS primitives +FBS_TO_PYTHON = { + "int8": "int", + "int16": "int", + "int32": "int", + "int64": "int", + "uint8": "int", + "uint16": "int", + "uint32": "int", + "uint64": "int", + "float": "float", + "double": "float", + "bool": "bool", + "string": "str", + "byte": "int", +} + +# C++ type mapping for FBS primitives +FBS_TO_CPP = { + "int8": "int8_t", + "int16": "int16_t", + "int32": "int32_t", + "int64": "int64_t", + "uint8": "uint8_t", + "uint16": "uint16_t", + "uint32": "uint32_t", + "uint64": "uint64_t", + "float": "float", + "double": "double", + "bool": "bool", + "string": "std::string", + "byte": "uint8_t", + "Tid": "Tid", + "Vid": "Vid", + "IntOrVid": "std::variant", + "FloatOrVid": "std::variant", +} + + +def _section_header(comment: str, title: str) -> List[str]: + """Generate a section-header banner for generated output.""" + sep = f"{comment} {'=' * 76}" + return [sep, f"{comment} {title}", sep, ""] + + +def _file_header(comment: str, description: str = "") -> List[str]: + """Generate a standard auto-generated file header. + + Args: + comment: Comment prefix, e.g. '#' for Python or '//' for C++. + description: Optional description appended after the banner. + """ + sep = f"{comment} {'=' * 76}" + lines = [ + f"{comment}", + f"{comment} Copyright (c) Meta Platforms, Inc. and affiliates.", + f"{comment} All rights reserved.", + f"{comment}", + f"{comment} This source code is licensed under the BSD-style license found in the", + f"{comment} LICENSE file in the root directory of this source tree.", + f"{comment}", + sep, + f"{comment} AUTO-GENERATED FILE - DO NOT EDIT MANUALLY", + sep, + f"{comment}", + f"{comment} This file was generated from schema.fbs by the MLX delegate code generator.", + f"{comment}", + f"{comment} Source: backends/mlx/serialization/schema.fbs", + f"{comment} Generator: backends/mlx/serialization/generate.py", + f"{comment}", + f"{comment} To regenerate, run from the executorch root:", + f"{comment} python backends/mlx/serialization/generate.py", + f"{comment}", + sep, + ] + if description: + lines.append(f"{comment}") + lines.append(f"{comment} {description}") + return lines + + +@dataclass +class FBSStruct: + name: str + fields: List[FBSField] + + +@dataclass +class FBSTable: + name: str + fields: List[FBSField] + + +@dataclass +class FBSUnion: + name: str + types: List[str] + + +@dataclass +class FBSSchema: + namespace: str + enums: List[FBSEnum] + structs: List[FBSStruct] + tables: List[FBSTable] + unions: List[FBSUnion] + + def get_op_nodes(self) -> List[FBSTable]: + """Get all tables that are part of the OpNode union.""" + op_union = next((u for u in self.unions if u.name == "OpNode"), None) + if not op_union: + return [] + op_names = set(op_union.types) + return [t for t in self.tables if t.name in op_names] + + +def parse_fbs(fbs_path: Path) -> FBSSchema: + """Parse a FlatBuffer schema file.""" + with open(fbs_path) as f: + content = f.read() + + # Remove comments + content = re.sub(r"//.*$", "", content, flags=re.MULTILINE) + + namespace = "" + enums: List[FBSEnum] = [] + structs: List[FBSStruct] = [] + tables: List[FBSTable] = [] + unions: List[FBSUnion] = [] + + # Parse namespace + ns_match = re.search(r"namespace\s+(\w+)\s*;", content) + if ns_match: + namespace = ns_match.group(1) + + # Parse enums + for match in re.finditer(r"enum\s+(\w+)\s*:\s*(\w+)\s*\{([^}]+)\}", content): + enum_name = match.group(1) + base_type = match.group(2) + body = match.group(3) + values = [] + for val_match in re.finditer(r"(\w+)\s*(?:=\s*(\d+))?", body): + name = val_match.group(1) + explicit_val = int(val_match.group(2)) if val_match.group(2) else None + values.append((name, explicit_val)) + enums.append(FBSEnum(enum_name, base_type, values)) + + # Parse structs + for match in re.finditer(r"struct\s+(\w+)\s*\{([^}]+)\}", content): + struct_name = match.group(1) + body = match.group(2) + fields = _parse_fields(body) + structs.append(FBSStruct(struct_name, fields)) + + # Parse tables + for match in re.finditer(r"table\s+(\w+)\s*\{([^}]*)\}", content): + table_name = match.group(1) + body = match.group(2) + fields = _parse_fields(body) + tables.append(FBSTable(table_name, fields)) + + # Parse unions + for match in re.finditer(r"union\s+(\w+)\s*\{([^}]+)\}", content): + union_name = match.group(1) + body = match.group(2) + types = [t.strip() for t in body.split(",") if t.strip()] + unions.append(FBSUnion(union_name, types)) + + return FBSSchema(namespace, enums, structs, tables, unions) + + +def _parse_fields(body: str) -> List[FBSField]: + """Parse fields from a struct/table body.""" + fields = [] + for line in body.split(";"): + line = line.strip() + if not line: + continue + + # Parse: name: type (attributes) = default + match = re.match( + r"(\w+)\s*:\s*(\[?\w+\]?)\s*(?:\(([^)]*)\))?\s*(?:=\s*([^;]+))?", line + ) + if match: + name = match.group(1) + type_str = match.group(2) + attrs = match.group(3) or "" + default = match.group(4).strip() if match.group(4) else None + required = "required" in attrs + fields.append(FBSField(name, type_str, required, default)) + + return fields + + +# Config for compound type factory methods. +# Maps compound type name -> (primary_field_name, primary_python_type, description) +_COMPOUND_TYPE_CONFIG = { + "IntOrVid": ("literal", "int", "a literal integer"), + "FloatOrVid": ("literal", "float", "a literal float"), + "VidOrTid": ("tid", "Tid", "a tensor reference"), + "IntOrVidOrTid": ("literal", "int", "a literal integer"), +} + + +def _generate_compound_type(table: FBSTable) -> List[str]: # noqa: C901 + """Generate a Python dataclass for a compound type (IntOrVid, etc.) from schema.""" + name = table.name + config = _COMPOUND_TYPE_CONFIG.get(name) + if not config: + raise ValueError(f"No compound type config for '{name}'") + + primary_field, primary_py_type, primary_desc = config + + # Build the docstring from the schema structure + lines = [ + "@dataclass", + f"class {name}:", + ] + + # Docstring: describe the two alternatives + lines.append( + f' """Represents either {primary_desc} or a runtime Vid reference."""' + ) + + # Dataclass fields from the parsed schema + for fld in table.fields: + if fld.default == "false": + default = "False" + elif fld.default == "true": + default = "True" + elif fld.type_str in ("Tid", "Vid"): + default = "None" + elif fld.default is not None: + default = fld.default + elif fld.type_str in FBS_INTEGER_TYPES: + default = "0" + elif fld.type_str in FBS_FLOAT_TYPES: + default = "0.0" + else: + default = "None" + truly_required = default != "None" + py_type = _fbs_type_to_python(fld.type_str, truly_required) + lines.append(f" {fld.name}: {py_type} = {default}") + + # Check if this is a 3-way discriminator (IntOrVidOrTid uses 'kind') + has_kind = any(fld.name == "kind" for fld in table.fields) + has_tid = any(fld.name == "tid" for fld in table.fields) + + # Factory: from_primary (e.g. from_literal, from_tid) + lines.append("") + lines.append(" @classmethod") + lines.append( + f' def from_{primary_field}(cls, value: {primary_py_type}) -> "{name}":' + ) + lines.append(f' """Create a {name} from {primary_desc}."""') + if has_kind: + lines.append(f" return cls({primary_field}=value, kind=0)") + else: + lines.append(f" return cls({primary_field}=value, is_vid=False)") + + # Factory: from_vid + lines.append("") + lines.append(" @classmethod") + lines.append(f' def from_vid(cls, vid: Vid) -> "{name}":') + lines.append(f' """Create a {name} from a Vid reference."""') + if has_kind: + lines.append(" return cls(vid=vid, kind=1)") + else: + lines.append(" return cls(vid=vid, is_vid=True)") + + # Factory: from_tid (only for types with a tid field) + if has_tid: + lines.append("") + lines.append(" @classmethod") + lines.append(f' def from_tid(cls, tid: Tid) -> "{name}":') + lines.append(f' """Create a {name} from a Tid tensor reference."""') + if has_kind: + lines.append(" return cls(tid=tid, kind=2)") + else: + lines.append(" return cls(tid=tid, is_vid=False)") + + lines.append("") + return lines + + +def _generate_dataclass(table: FBSTable) -> List[str]: + """Generate a Python @dataclass from a parsed FBS table. + + Handles field ordering (required/defaulted before optional), skips + _is_set sentinel fields, and emits proper type annotations with defaults. + """ + lines = ["@dataclass", f"class {table.name}:"] + fields = [f for f in table.fields if not f.name.endswith("_is_set")] + if not fields: + lines.append(" pass") + else: + required_fields = [f for f in fields if f.required or f.default is not None] + optional_fields = [f for f in fields if not f.required and f.default is None] + + for fld in required_fields: + py_type = _fbs_type_to_python(fld.type_str, True) + default = _fbs_default_to_python(fld.default, fld.type_str) + if default is not None: + lines.append(f" {fld.name}: {py_type} = {default}") + else: + lines.append(f" {fld.name}: {py_type}") + + for fld in optional_fields: + py_type = _fbs_type_to_python(fld.type_str, fld.required) + lines.append(f" {fld.name}: {py_type} = None") + + lines.extend(["", ""]) + return lines + + +def generate_python_schema(schema: FBSSchema) -> str: # noqa: C901 + """Generate mlx_graph_schema.py from parsed FBS.""" + lines = _file_header("#") + lines.extend( + [ + "", + "from __future__ import annotations", + "", + "from dataclasses import dataclass, field", + "from enum import IntEnum", + "from typing import List, Optional, Union", + "", + "", + *_section_header("#", "Enums"), + ] + ) + + # Generate enums + for enum in schema.enums: + lines.append(f"class {enum.name}(IntEnum):") + val = 0 + for name, explicit_val in enum.values: + if explicit_val is not None: + val = explicit_val + lines.append(f" {name} = {val}") + val += 1 + lines.append("") + lines.append("") + + lines.extend(_section_header("#", "Core types")) + + # Generate structs (Tid, Vid) + for struct in schema.structs: + lines.append("@dataclass") + lines.append(f"class {struct.name}:") + for fld in struct.fields: + py_type = _fbs_type_to_python(fld.type_str, fld.required) + default = _fbs_default_to_python(fld.default, fld.type_str) + if default: + lines.append(f" {fld.name}: {py_type} = {default}") + else: + lines.append(f" {fld.name}: {py_type}") + lines.append("") + lines.append("") + + # Generate compound types (IntOrVid, FloatOrVid, TidOrVid) from schema + for type_name in sorted(FBS_COMPOUND_TYPES): + table = next((t for t in schema.tables if t.name == type_name), None) + if table: + lines.extend(_generate_compound_type(table)) + lines.append("") + + # Generate ShapeDim, SlotVariant, NamedSlot, TensorMeta (but not Instruction/MLXGraph yet - they reference OpNode) + other_tables = ["ShapeDim", "SlotVariant", "NamedSlot", "TensorMeta"] + for table_name in other_tables: + table = next((t for t in schema.tables if t.name == table_name), None) + if table: + lines.extend(_generate_dataclass(table)) + + lines.extend(_section_header("#", "Op nodes")) + + # Generate op node dataclasses + op_nodes = schema.get_op_nodes() + for table in op_nodes: + lines.extend(_generate_dataclass(table)) + + # Generate OpNodeUnion type alias + op_names = [t.name for t in op_nodes] + lines.append("# Union of all op types") + lines.append("OpNodeUnion = Union[") + for name in op_names: + lines.append(f" {name},") + lines.append("]") + lines.append("") + + # Generate Instruction and MLXGraph (these reference OpNode so must come after) + lines.extend( + [ + *_section_header("#", "Container types (reference OpNodeUnion)"), + "@dataclass", + "class Instruction:", + " op: OpNodeUnion", + "", + "", + "@dataclass", + "class InstructionChain:", + " instructions: List[Instruction]", + "", + "", + "@dataclass", + "class MLXGraph:", + " instruction_chains: List[InstructionChain]", + " version: Optional[str] = None", + " num_constant_tensors: int = 0", + " num_input_tensors: int = 0", + " num_output_tensors: int = 0", + " num_mutable_buffer_tensors: int = 0", + " num_temp_tensors: int = 0", + " num_values: int = 0", + " main_chain_idx: int = 0", + " init_chain_idx: int = -1", + " input_map: Optional[List[SlotVariant]] = None", + " output_map: Optional[List[SlotVariant]] = None", + " mutable_buffer_map: Optional[List[SlotVariant]] = None", + " named_slots: Optional[List[NamedSlot]] = None", + " tensor_meta: Optional[List[TensorMeta]] = None", + "", + ] + ) + + return "\n".join(lines) + + +def _fbs_type_to_python(fbs_type: str, required: bool) -> str: + """Convert FBS type to Python type annotation. + + When required=False, the result is wrapped in Optional[…] for all types + (scalars, lists, and reference types alike). + """ + # Handle arrays + if fbs_type.startswith("[") and fbs_type.endswith("]"): + inner = fbs_type[1:-1] + inner_py = _fbs_type_to_python(inner, True) + base = f"List[{inner_py}]" + return base if required else f"Optional[{base}]" + + py_type = FBS_TO_PYTHON.get(fbs_type, fbs_type) + + if not required: + return f"Optional[{py_type}]" + + return py_type + + +def _fbs_default_to_python(default: Optional[str], fbs_type: str) -> Optional[str]: + """Convert FBS default value to Python.""" + if default is None: + return None + + if default == "false": + return "False" + if default == "true": + return "True" + if default == "null": + return "None" + + # Handle enum defaults like 'TensorSlot' + if fbs_type == "SlotType": + return f"SlotType.{default}" + + # Numeric defaults + return default + + +def generate_python_serializers(schema: FBSSchema) -> str: + """Generate _generated_serializers.py from parsed FBS.""" + op_nodes = schema.get_op_nodes() + op_union = next((u for u in schema.unions if u.name == "OpNode"), None) + + header = _file_header( + "#", + "This file contains auto-generated serializer methods for all op types.", + ) + + # Imports and module-level code + op_imports = ",\n".join(f" {t.name}" for t in op_nodes) + lines = [ + *header, + "", + "from __future__ import annotations", + "", + "from typing import List, Tuple, Dict", + "", + "import flatbuffers", + "", + ] + + # Generate op type names dict from union order + lines.append( + "# FlatBuffer union indices: 0 = NONE, then 1-indexed from union order" + ) + lines.append("MLX_OP_TYPE_NAMES = {") + lines.append(' 0: "NONE",') + if op_union: + for i, type_name in enumerate(op_union.types, start=1): + lines.append(f' {i}: "{type_name}",') + lines.append("}") + lines.append("") + + lines.extend( + [ + "from executorch.backends.mlx.serialization.mlx_graph_schema import (", + f"{op_imports},", + " IntOrVid,", + " FloatOrVid,", + " VidOrTid,", + " IntOrVidOrTid,", + " Tid,", + " Vid,", + ")", + "", + "", + "def _build_int_vector(builder: flatbuffers.Builder, vec: List[int]) -> int:", + ' """Build a vector of int32."""', + " builder.StartVector(4, len(vec), 4)", + " for v in reversed(vec):", + " builder.PrependInt32(v)", + " return builder.EndVector()", + "", + "", + "class GeneratedOpBuilders:", + ' """Mixin class with auto-generated op builder methods."""', + "", + " def _build_int_or_vid(self, builder: flatbuffers.Builder, iov: IntOrVid) -> int:", + ' """Build an IntOrVid table."""', + " from executorch.backends.mlx.serialization._generated.mlx_delegate import IntOrVid as FBIntOrVidModule", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", + "", + " FBIntOrVidModule.Start(builder)", + " FBIntOrVidModule.AddLiteral(builder, iov.literal)", + " FBIntOrVidModule.AddIsVid(builder, iov.is_vid)", + " if iov.vid is not None:", + " # Vid is an inline struct - must be added last for proper FlatBuffer layout", + " FBIntOrVidModule.AddVid(builder, CreateVid(builder, iov.vid.idx))", + " return FBIntOrVidModule.End(builder)", + "", + " def _build_float_or_vid(self, builder: flatbuffers.Builder, fov: FloatOrVid) -> int:", + ' """Build a FloatOrVid table."""', + " from executorch.backends.mlx.serialization._generated.mlx_delegate import FloatOrVid as FBFloatOrVidModule", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", + "", + " FBFloatOrVidModule.Start(builder)", + " FBFloatOrVidModule.AddLiteral(builder, fov.literal)", + " FBFloatOrVidModule.AddIsVid(builder, fov.is_vid)", + " if fov.vid is not None:", + " FBFloatOrVidModule.AddVid(builder, CreateVid(builder, fov.vid.idx))", + " return FBFloatOrVidModule.End(builder)", + "", + " def _build_vid_or_tid(self, builder: flatbuffers.Builder, vot: VidOrTid) -> int:", + ' """Build a TidOrVid table."""', + " from executorch.backends.mlx.serialization._generated.mlx_delegate import VidOrTid as FBVidOrTidModule", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", + "", + " FBVidOrTidModule.Start(builder)", + " FBVidOrTidModule.AddIsVid(builder, vot.is_vid)", + " if vot.tid is not None:", + " FBVidOrTidModule.AddTid(builder, CreateTid(builder, vot.tid.idx))", + " if vot.vid is not None:", + " FBVidOrTidModule.AddVid(builder, CreateVid(builder, vot.vid.idx))", + " return FBVidOrTidModule.End(builder)", + "", + " def _build_int_or_vid_or_tid(self, builder: flatbuffers.Builder, ivt: IntOrVidOrTid) -> int:", + ' """Build an IntOrVidOrTid table."""', + " from executorch.backends.mlx.serialization._generated.mlx_delegate import IntOrVidOrTid as FBIntOrVidOrTidModule", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", + "", + " FBIntOrVidOrTidModule.Start(builder)", + " FBIntOrVidOrTidModule.AddLiteral(builder, ivt.literal)", + " FBIntOrVidOrTidModule.AddKind(builder, ivt.kind)", + " if ivt.tid is not None:", + " FBIntOrVidOrTidModule.AddTid(builder, CreateTid(builder, ivt.tid.idx))", + " if ivt.vid is not None:", + " FBIntOrVidOrTidModule.AddVid(builder, CreateVid(builder, ivt.vid.idx))", + " return FBIntOrVidOrTidModule.End(builder)", + "", + " def _build_int_or_vid_vector(", + " self, builder: flatbuffers.Builder, vec: List[IntOrVid]", + " ) -> int:", + ' """Build a vector of IntOrVid tables."""', + " offsets = []", + " for iov in vec:", + " offsets.append(self._build_int_or_vid(builder, iov))", + " builder.StartVector(4, len(offsets), 4)", + " for off in reversed(offsets):", + " builder.PrependUOffsetTRelative(off)", + " return builder.EndVector()", + "", + " def _build_tid_vector(", + " self, builder: flatbuffers.Builder, vec: List[Tid]", + " ) -> int:", + ' """Build a vector of Tid structs."""', + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid", + "", + " # For vectors of structs, we need to build the vector differently", + " # Each Tid struct is 4 bytes (uint32), so we manually write them", + " builder.StartVector(4, len(vec), 4)", + " for tid in reversed(vec):", + " builder.Prep(4, 0) # Align for struct", + " builder.PrependUint32(tid.idx)", + " return builder.EndVector()", + "", + ] + ) + + # Generate builder methods for each op + for table in op_nodes: + lines.append(_generate_op_builder_method(table)) + + return "\n".join(lines) + + +def _generate_op_builder_method(table: FBSTable) -> str: + """Generate a _build_XxxNode method for the serializer class.""" + class_name = table.name + fb_module_name = f"FB{class_name}Module" + + lines = [ + f" def _build_{class_name}(", + f" self, builder: flatbuffers.Builder, op: {class_name}", + " ) -> Tuple[int, int]:", + f' """Auto-generated builder for {class_name}."""', + " # Import the MODULE (not class) to access builder functions like Start(), Add*(), End()", + f" from executorch.backends.mlx.serialization._generated.mlx_delegate import {class_name} as {fb_module_name}", + " from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", + "", + ] + + # Pre-build any strings or vectors (must be done before Start) + prebuild_lines = [] + for fld in table.fields: + if fld.name.endswith("_is_set"): + continue + kind = _get_field_kind(fld, table) + pb = _emit_py_prebuild(kind, fld) + if pb: + prebuild_lines.extend(pb) + + if prebuild_lines: + lines.extend(prebuild_lines) + lines.append("") + + # Start the FlatBuffer table + lines.append(f" {fb_module_name}.Start(builder)") + + # Add each field + for fld in table.fields: + if fld.name.endswith("_is_set"): + continue + fb_field_name = _to_pascal_case(fld.name) + kind = _get_field_kind(fld, table) + add_lines = _emit_py_add(kind, fld, fb_module_name, fb_field_name) + if add_lines is None: + raise ValueError( + f"Unhandled field kind '{kind}' for field '{fld.name}' in table '{table.name}'. " + f"Add a handler in _emit_py_add()." + ) + lines.extend(add_lines) + + # End the FlatBuffer table and return offset + union type + lines.append(f" offset = {fb_module_name}.End(builder)") + lines.append(f" return offset, FBOpNodeModule.OpNode.{class_name}") + lines.append("") + + return "\n".join(lines) + + +# Prebuild emitters: return list of lines or None if no prebuild needed. +# These build offsets/vectors that must be created before FlatBuffer Start(). + +_PY_PREBUILD_VECTOR = { + "list_int": "_build_int_vector(builder, op.{name})", + "list_int_or_vid": "self._build_int_or_vid_vector(builder, op.{name})", + "list_tid": "self._build_tid_vector(builder, op.{name})", +} + +_PY_PREBUILD_OFFSET = { + "str": "builder.CreateString(op.{name})", + "int_or_vid": "self._build_int_or_vid(builder, op.{name})", + "float_or_vid": "self._build_float_or_vid(builder, op.{name})", + "vid_or_tid": "self._build_vid_or_tid(builder, op.{name})", + "int_or_vid_or_tid": "self._build_int_or_vid_or_tid(builder, op.{name})", + "optional_str": "builder.CreateString(op.{name}) if op.{name} is not None else None", +} + + +def _emit_py_prebuild(kind: str, fld: FBSField) -> List[str]: + """Emit prebuild lines for a field kind, or empty list if none needed.""" + n = fld.name + if kind in _PY_PREBUILD_VECTOR: + expr = _PY_PREBUILD_VECTOR[kind].format(name=n) + if fld.required: + return [f" {n}_vec = {expr}"] + else: + return [f" {n}_vec = {expr} if op.{n} is not None else None"] + if kind in _PY_PREBUILD_OFFSET: + suffix = "_off" + expr = _PY_PREBUILD_OFFSET[kind].format(name=n) + return [f" {n}{suffix} = {expr}"] + return [] + + +# Maps struct kinds to their Python Create function name +_PY_STRUCT_CREATOR = {"tid": "CreateTid", "vid": "CreateVid"} + + +def _emit_py_add( + kind: str, fld: FBSField, mod: str, fb_name: str +) -> "List[str] | None": + """Emit Add lines for a field kind, or None if kind is unrecognized.""" + n = fld.name + add = f"{mod}.Add{fb_name}" + + # Required struct via inline Create call + if kind in _PY_STRUCT_CREATOR: + creator = _PY_STRUCT_CREATOR[kind] + return [f" {add}(builder, {creator}(builder, op.{n}.idx))"] + # Scalars (direct value) + if kind in ("int", "float", "bool"): + return [f" {add}(builder, op.{n})"] + # Pre-built offsets (string, compound types) + if kind in ("str", "int_or_vid", "float_or_vid", "vid_or_tid", "int_or_vid_or_tid"): + return [f" {add}(builder, {n}_off)"] + # Pre-built vectors (required vs optional) + if kind in ("list_int", "list_int_or_vid", "list_tid"): + if fld.required: + return [f" {add}(builder, {n}_vec)"] + return [ + f" if {n}_vec is not None:", + f" {add}(builder, {n}_vec)", + ] + # Optional struct via inline Create call + if kind in ("optional_tid", "optional_vid"): + creator = _PY_STRUCT_CREATOR[kind.removeprefix("optional_")] + return [ + f" if op.{n} is not None:", + f" {add}(builder, {creator}(builder, op.{n}.idx))", + ] + # Optional scalars + if kind in ("optional_float", "optional_int"): + return [ + f" if op.{n} is not None:", + f" {add}(builder, op.{n})", + ] + # Optional string offset + if kind == "optional_str": + return [ + f" if {n}_off is not None:", + f" {add}(builder, {n}_off)", + ] + return None + + +def _get_field_kind(fld: FBSField, table: FBSTable) -> str: # noqa: C901 + """Classify a field into a canonical kind string. + + This is the single source of truth for field classification, used by all + generators (Python builder, C++ loader, and inspector via _INSPECTOR_KIND_MAP). + """ + t = fld.type_str + + # Handle arrays + if t.startswith("[") and t.endswith("]"): + inner = t[1:-1] + if inner in FBS_INTEGER_TYPES: + return "list_int" + if inner == "IntOrVid": + return "list_int_or_vid" + if inner == "Tid": + return "list_tid" + raise ValueError( + f"Unrecognized array element type '{inner}' for field '{fld.name}' in table '{table.name}'. " + f"Add a handler in _get_field_kind()." + ) + + # Handle basic types + if t == "Tid": + return "optional_tid" if not fld.required else "tid" + if t == "Vid": + return "optional_vid" if not fld.required else "vid" + if t == "IntOrVid": + return "int_or_vid" + if t == "FloatOrVid": + return "float_or_vid" + if t == "VidOrTid": + return "vid_or_tid" + if t == "IntOrVidOrTid": + return "int_or_vid_or_tid" + if t in FBS_INTEGER_TYPES: + if fld.default == "null": + return "optional_int" + return "int" + if t in FBS_FLOAT_TYPES: + # Check if this is optional (has = null default) + if fld.default == "null": + return "optional_float" + return "float" + if t == "bool": + return "bool" + if t == "string": + return "optional_str" if not fld.required else "str" + + raise ValueError( + f"Unrecognized field type '{t}' for field '{fld.name}' in table '{table.name}'. " + f"Add a handler in _get_field_kind()." + ) + + +def _to_pascal_case(name: str) -> str: + """Convert snake_case to PascalCase.""" + # Handle special cases + if name == "table_": + return "Table_" + parts = name.split("_") + return "".join(p.capitalize() for p in parts) + + +def generate_cpp_loader_h(schema: FBSSchema) -> str: + """Generate MLXLoader.h from parsed FBS using template.""" + op_nodes = schema.get_op_nodes() + + struct_lines = [] + for table in op_nodes: + struct_lines.append(f"struct {table.name} {{") + if not table.fields: + struct_lines.append("};") + else: + for fld in table.fields: + if fld.name.endswith("_is_set"): + continue + cpp_type = _fbs_type_to_cpp(fld.type_str, fld.required, table, fld) + struct_lines.append(f" {cpp_type} {fld.name};") + struct_lines.append("};") + struct_lines.append("") + + enum_lines = [] + for table in op_nodes: + enum_lines.append(f" {_table_name_to_opcode(table.name)},") + + name_lines = [] + for table in op_nodes: + op_code = _table_name_to_opcode(table.name) + name_lines.append(f" case OpCode::{op_code}:") + name_lines.append(f' return "{op_code}";') + + variant_lines = [] + for i, table in enumerate(op_nodes): + comma = "," if i < len(op_nodes) - 1 else "" + variant_lines.append(f" {table.name}{comma}") + + # Read template and fill placeholders + header = "\n".join(_file_header("//")) + "\n//\n" + tmpl = LOADER_H_TMPL.read_text() + result = tmpl.replace("{{OP_NODE_STRUCTS}}", "\n".join(struct_lines)) + result = result.replace("{{OPCODE_ENUM_VALUES}}", "\n".join(enum_lines)) + result = result.replace("{{OP_NAME_CASES}}", "\n".join(name_lines)) + result = result.replace("{{NODE_VARIANT_TYPES}}", "\n".join(variant_lines)) + return header + result + + +def _fbs_type_to_cpp( + fbs_type: str, + required: bool, + table: Optional["FBSTable"] = None, + fld: Optional["FBSField"] = None, +) -> str: + """Convert FBS type to C++ type. + + Args: + fbs_type: The FlatBuffer type string + required: Whether the field is required + table: Optional table context for type inference + fld: Optional field context for the current field + + Note: Most scalar types (float, int, etc.) are never optional in C++. + The Python serialization layer is responsible for ensuring scalar fields + have values (using defaults if user doesn't provide them). + Reference types (Tid, Vid) and DTypeId with '= null' default can be optional. + """ + # Handle arrays + if fbs_type.startswith("[") and fbs_type.endswith("]"): + inner = fbs_type[1:-1] + inner_cpp = _fbs_type_to_cpp(inner, True) + return f"std::vector<{inner_cpp}>" + + cpp_type = FBS_TO_CPP.get(fbs_type, fbs_type) + + # Handle optional types + if not required: + if fbs_type == "Tid": + return "std::optional" + if fbs_type == "Vid": + return "std::optional" + if fld is not None and fld.default == "null" and fbs_type in FBS_TO_CPP: + return f"std::optional<{cpp_type}>" + + return cpp_type + + +_OPCODE_OVERRIDES = { + "ARange": "ARANGE", + "AsType": "ASTYPE", + "Conv1D": "CONV1D", + "Conv2D": "CONV2D", + "Conv3D": "CONV3D", + "ConvTranspose1D": "CONV_TRANSPOSE1D", + "ConvTranspose2D": "CONV_TRANSPOSE2D", + "ConvTranspose3D": "CONV_TRANSPOSE3D", +} + + +def _table_name_to_opcode(name: str) -> str: + """Convert table name like 'LinearNode' to opcode like 'LINEAR'. + + Uses regex-based camelCase → UPPER_SNAKE_CASE conversion with a small + override dict for names whose conventional opcode doesn't follow the + normal camelCase splitting rules (e.g. Conv1D → CONV1D, not CONV1_D). + """ + name = name.removesuffix("Node") + if name in _OPCODE_OVERRIDES: + return _OPCODE_OVERRIDES[name] + s = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", name) + s = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1_\2", s) + return s.upper() + + +def generate_cpp_loader_cpp(schema: FBSSchema) -> str: + """Generate MLXLoader.cpp from parsed FBS using template.""" + op_nodes = schema.get_op_nodes() + + case_lines = [] + for table in op_nodes: + case_lines.extend(_generate_loader_case(table)) + + # Read template and fill placeholders + header = "\n".join(_file_header("//")) + "\n" + tmpl = LOADER_CPP_TMPL.read_text() + result = tmpl.replace("{{LOAD_INSTRUCTION_CASES}}", "\n".join(case_lines)) + return header + result + + +def _generate_loader_case(table: FBSTable) -> List[str]: + """Generate a switch case for loading an op node.""" + class_name = table.name + op_code = _table_name_to_opcode(class_name) + + lines = [ + f" case mlx_delegate::OpNode_{class_name}: {{", + ] + + if not table.fields: + # NoopNode case + lines.extend( + [ + f" instr.op = OpCode::{op_code};", + f" instr.node = {class_name}{{}};", + " break;", + " }", + "", + ] + ) + return lines + + lines.append(f" auto fb = fb_instr->op_as_{class_name}();") + lines.append(" if (!fb) {{") + lines.append( + ' throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}");' + ) + lines.append(" }}") + lines.append(f" {class_name} node;") + + for fld in table.fields: + if fld.name.endswith("_is_set"): + continue + + fb_field_name = fld.name + kind = _get_field_kind(fld, table) + load_lines = _emit_cpp_load(kind, fld.name, fb_field_name) + if load_lines is None: + raise ValueError( + f"Unhandled field kind '{kind}' for field '{fld.name}' in table '{table.name}'. " + f"Add a handler in _emit_cpp_load()." + ) + lines.extend(load_lines) + + lines.extend( + [ + f" instr.op = OpCode::{op_code};", + " instr.node = std::move(node);", + " break;", + " }", + "", + ] + ) + + return lines + + +# Maps kinds to their C++ converter function name +_CPP_CONVERTER = { + "tid": "convert_tid", + "vid": "convert_vid", + "int_or_vid": "convert_int_or_vid", + "float_or_vid": "convert_float_or_vid", + "vid_or_tid": "convert_vid_or_tid", + "int_or_vid_or_tid": "convert_int_or_vid_or_tid", +} + + +def _emit_cpp_load(kind: str, name: str, fb_name: str) -> "List[str] | None": + """Emit C++ load lines for a field kind, or None if kind is unrecognized.""" + # Required struct / compound via converter + if kind in _CPP_CONVERTER: + conv = _CPP_CONVERTER[kind] + return [f" node.{name} = {conv}(fb->{fb_name}());"] + # Scalars (direct value) + if kind in ("int", "float", "bool"): + return [f" node.{name} = fb->{fb_name}();"] + # Required string + if kind == "str": + return [f' node.{name} = fb->{fb_name}() ? fb->{fb_name}()->str() : "";'] + # Optional struct / compound via guarded converter + base_kind = kind.removeprefix("optional_") + if kind.startswith("optional_") and base_kind in _CPP_CONVERTER: + conv = _CPP_CONVERTER[base_kind] + return [ + f" if (fb->{fb_name}()) {{", + f" node.{name} = {conv}(fb->{fb_name}());", + " }", + ] + # Optional scalar (FlatBuffers returns flatbuffers::Optional) + if kind in ("optional_float", "optional_int"): + return [ + f" auto {fb_name}_opt = fb->{fb_name}();", + f" if ({fb_name}_opt.has_value()) {{", + f" node.{name} = {fb_name}_opt.value();", + " }", + ] + # Optional string + if kind == "optional_str": + return [ + f" if (fb->{fb_name}()) {{", + f" node.{name} = fb->{fb_name}()->str();", + " }", + ] + # Integer/bool vector via to_vector + if kind == "list_int": + return [f" node.{name} = to_vector(fb->{fb_name}());"] + # Int-or-vid vector (indexed access) + if kind == "list_int_or_vid": + return [ + f" if (fb->{fb_name}()) {{", + f" for (size_t i = 0; i < fb->{fb_name}()->size(); ++i) {{", + f" node.{name}.push_back(convert_int_or_vid(fb->{fb_name}()->Get(static_cast(i))));", + " }", + " }", + ] + # Tid vector (range-based iteration) + if kind == "list_tid": + return [ + f" if (fb->{fb_name}()) {{", + f" for (auto fb_tid : *fb->{fb_name}()) {{", + f" node.{name}.push_back(convert_tid(fb_tid));", + " }", + " }", + ] + return None + + +def run_flatc(flatc_path: str = "flatc") -> bool: + """Run flatc to generate Python and C++ bindings.""" + print(f"Running flatc on {SCHEMA_FBS}...") + + # Create output directories + GENERATED_DIR.mkdir(parents=True, exist_ok=True) + + success = True + + # Generate Python bindings + cmd_py = [ + flatc_path, + "--python", + "-o", + str(GENERATED_DIR), + str(SCHEMA_FBS), + ] + try: + result = subprocess.run(cmd_py, capture_output=True, text=True) + if result.returncode != 0: + print(f"flatc (Python) failed: {result.stderr}") + success = False + else: + print(f"Generated FlatBuffer Python bindings in {GENERATED_DIR}") + except FileNotFoundError: + print(f"flatc not found at '{flatc_path}'. Skipping FlatBuffer generation.") + success = False + + # Generate C++ bindings + cmd_cpp = [ + flatc_path, + "--cpp", + "-o", + str(RUNTIME_DIR), + str(SCHEMA_FBS), + ] + try: + result = subprocess.run(cmd_cpp, capture_output=True, text=True) + if result.returncode != 0: + print(f"flatc (C++) failed: {result.stderr}") + success = False + else: + print(f"Generated FlatBuffer C++ bindings in {RUNTIME_DIR}") + except FileNotFoundError: + success = False + + return success + + +_FLATC_IMPORT_PREFIX = "executorch.backends.mlx.serialization._generated." + + +def _fixup_flatc_imports() -> None: + """Rewrite bare ``from mlx_delegate.X`` imports in generated FlatBuffer code. + + ``flatc --python`` emits lazy imports like ``from mlx_delegate.Tid import Tid`` + inside accessor methods. These only resolve if the ``_generated/`` directory is + on ``sys.path``. We rewrite them to fully-qualified imports so no ``sys.path`` + manipulation is needed at runtime. + """ + fb_dir = GENERATED_DIR / "mlx_delegate" + if not fb_dir.exists(): + return + + count = 0 + for py_file in fb_dir.glob("*.py"): + content = py_file.read_text() + if "from mlx_delegate." not in content: + continue + new_content = content.replace( + "from mlx_delegate.", f"from {_FLATC_IMPORT_PREFIX}mlx_delegate." + ) + if new_content != content: + py_file.write_text(new_content) + count += 1 + + if count: + print(f"Fixed bare imports in {count} generated FlatBuffer file(s)") + + +# Mapping from fine-grained field kinds (from _get_field_kind) to inspector +# display kinds. The inspector uses coarser categories: optional/required +# distinctions collapse, and int/float/bool all map to "scalar". +_INSPECTOR_KIND_MAP = { + "tid": "tid", + "optional_tid": "tid", + "vid": "vid", + "optional_vid": "vid", + "int_or_vid": "int_or_vid", + "float_or_vid": "float_or_vid", + "vid_or_tid": "vid_or_tid", + "int_or_vid_or_tid": "int_or_vid_or_tid", + "list_int": "int_list", + "list_int_or_vid": "int_or_vid_list", + "list_tid": "tid_list", + "int": "scalar", + "optional_int": "scalar", + "float": "scalar", + "optional_float": "scalar", + "bool": "scalar", + "str": "string", + "optional_str": "string", +} + + +def generate_inspector(schema: "Schema") -> str: # noqa: F821 + """Generate the inspector field mappings file.""" + lines = _file_header("#") + lines.extend( + [ + "", + '"""', + "Auto-generated inspector field mappings for MLX delegate.", + "", + "This module provides field metadata for each op node type, enabling", + "the pte_inspector to parse FlatBuffer op nodes without manually", + "maintaining field mappings.", + '"""', + "", + "from __future__ import annotations", + "", + "from typing import Dict, List, Tuple", + "", + "", + "# Field kinds and their extractors", + "# Each field is a tuple of (display_name, accessor_name, kind)", + "# where kind is one of: 'tid', 'vid', 'int_or_vid', 'float_or_vid',", + "# 'int_list', 'int_or_vid_list', 'tid_list', 'scalar', 'string'", + "", + "FieldSpec = Tuple[str, str, str] # (display_name, accessor_name, kind)", + "", + "", + "# Mapping from op node name to list of field specs", + "OP_NODE_FIELDS: Dict[str, List[FieldSpec]] = {", + ] + ) + + op_nodes = schema.get_op_nodes() + + for table in op_nodes: + lines.append(f' "{table.name}": [') + for fld in table.fields: + # Skip fields ending in _is_set (legacy pattern) + if fld.name.endswith("_is_set"): + continue + + kind = _get_field_kind(fld, table) + inspector_kind = _INSPECTOR_KIND_MAP.get(kind) + if inspector_kind is None: + raise ValueError( + f"No inspector mapping for field kind '{kind}' " + f"(field '{fld.name}' in table '{table.name}'). " + f"Add a mapping in _INSPECTOR_KIND_MAP." + ) + accessor = _to_pascal_case(fld.name) + lines.append(f' ("{fld.name}", "{accessor}", "{inspector_kind}"),') + lines.append(" ],") + + lines.append("}") + lines.append("") + lines.append("") + + # Add the list of op node names for import generation + lines.append("# List of all op node names (for dynamic imports)") + lines.append("OP_NODE_NAMES: List[str] = [") + for table in op_nodes: + lines.append(f' "{table.name}",') + lines.append("]") + lines.append("") + + return "\n".join(lines) + + +def main(): # noqa: C901 + parser = argparse.ArgumentParser( + description="Generate MLX delegate code from schema.fbs" + ) + parser.add_argument( + "--flatc", + default="flatc", + help="Path to flatc compiler", + ) + parser.add_argument( + "--skip-flatc", + action="store_true", + help="Skip running flatc (use existing generated files)", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Print what would be generated without writing files", + ) + args = parser.parse_args() + + print(f"Parsing {SCHEMA_FBS}...") + schema = parse_fbs(SCHEMA_FBS) + print( + f" Found {len(schema.enums)} enums, {len(schema.structs)} structs, " + f"{len(schema.tables)} tables, {len(schema.unions)} unions" + ) + print(f" Op nodes: {len(schema.get_op_nodes())}") + + # Run flatc + if not args.skip_flatc: + run_flatc(args.flatc) + _fixup_flatc_imports() + + # Generate all code files + generators = [ + (generate_python_schema, GENERATED_SCHEMA_PY, "mlx_graph_schema.py"), + ( + generate_python_serializers, + GENERATED_SERIALIZERS, + "_generated_serializers.py", + ), + (generate_cpp_loader_h, LOADER_H, "MLXLoader.h"), + (generate_cpp_loader_cpp, LOADER_CPP, "MLXLoader.cpp"), + (generate_inspector, GENERATED_INSPECTOR, "_generated_inspector.py"), + ] + for gen_fn, output_path, label in generators: + print(f"Generating {output_path}...") + content = gen_fn(schema) + if args.dry_run: + print(f"--- {label} (first 50 lines) ---") + print("\n".join(content.split("\n")[:50])) + else: + with open(output_path, "w") as f: + f.write(content) + + # Create __init__.py for _generated package that re-exports from mlx_delegate + init_file = GENERATED_DIR / "__init__.py" + if not args.dry_run: + init_file.parent.mkdir(parents=True, exist_ok=True) + + # Get all the exports from mlx_delegate (tables, enums, structs, and unions) + exports = [] + for table in schema.tables: + exports.append(table.name) + for enum in schema.enums: + exports.append(enum.name) + for struct in schema.structs: + exports.append(struct.name) + for union in schema.unions: + exports.append(union.name) + + # Create __init__.py with re-exports + init_content = """# Auto-generated FlatBuffer bindings +# Re-exports from mlx_delegate namespace for convenient imports + +""" + # Add imports from mlx_delegate + for export in sorted(exports): + init_content += f"from executorch.backends.mlx.serialization._generated.mlx_delegate.{export} import {export}\n" + + init_content += f"\n__all__ = {sorted(exports)!r}\n" + init_file.write_text(init_content) + + print("Done!") + print("") + print("Generated files:") + print(f" - {GENERATED_SCHEMA_PY}") + print(f" - {GENERATED_SERIALIZERS}") + print(f" - {GENERATED_INSPECTOR}") + print(f" - {LOADER_H}") + print(f" - {LOADER_CPP}") + if not args.skip_flatc: + print(f" - {GENERATED_DIR}/ (FlatBuffer Python bindings)") + print(f" - {RUNTIME_DIR}/schema_generated.h (FlatBuffer C++ bindings)") + + +if __name__ == "__main__": + main() diff --git a/backends/mlx/serialization/mlx_graph_serialize.py b/backends/mlx/serialization/mlx_graph_serialize.py new file mode 100644 index 00000000000..db5acc9048f --- /dev/null +++ b/backends/mlx/serialization/mlx_graph_serialize.py @@ -0,0 +1,416 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +Serialization utilities for MLX delegate. + +Converts MLXGraph dataclasses to FlatBuffer binary format. + +Constants are NOT embedded in the delegate payload - they are provided by +ExecuTorch via named_data_map at runtime. + +Layout: + [Header: 24 bytes] + - Padding: 4 bytes (zeros) + - Magic: 4 bytes ("MLX0") + - Reserved: 16 bytes (zeros, for future use) + [FlatBuffer payload] +""" + +from __future__ import annotations + +import struct +from typing import Any, List, Tuple + +import flatbuffers + +# Import auto-generated serializers +from executorch.backends.mlx.serialization._generated_serializers import ( + GeneratedOpBuilders, +) +from executorch.backends.mlx.serialization.mlx_graph_schema import ( # noqa: F401 + FloatOrVid, + Instruction, + IntOrVid, + MLXGraph, + NamedSlot, + OpNodeUnion, + SlotType, + SlotVariant, + TensorMeta, + Tid, + Vid, +) +from executorch.exir._serialize._program import Cord + +HEADER_LENGTH = 24 +MAGIC = b"MLX0" +ALIGNMENT = 16 + + +def _padding_required(offset: int, alignment: int) -> int: + remainder = offset % alignment + return (alignment - remainder) % alignment + + +def _build_tid(builder: flatbuffers.Builder, tid: Tid) -> int: + return tid.idx + + +def _build_vid(builder: flatbuffers.Builder, vid: Vid) -> int: + return vid.idx + + +def _build_int_or_vid(builder: flatbuffers.Builder, iov: IntOrVid) -> int: + # Import the MODULE (not class) to access builder functions + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + IntOrVid as FBIntOrVidModule, + ) + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import ( + CreateVid, + ) + + FBIntOrVidModule.Start(builder) + FBIntOrVidModule.AddLiteral(builder, iov.literal) + FBIntOrVidModule.AddIsVid(builder, iov.is_vid) + if iov.vid is not None: + # Vid is an inline struct - must be added last for proper FlatBuffer layout + FBIntOrVidModule.AddVid(builder, CreateVid(builder, iov.vid.idx)) + return FBIntOrVidModule.End(builder) + + +def _build_string(builder: flatbuffers.Builder, s: str) -> int: + return builder.CreateString(s) + + +def _build_int_vector(builder: flatbuffers.Builder, vec: List[int]) -> int: + # FlatBuffers vectors must be created before the table that contains them + builder.StartVector(4, len(vec), 4) # elem_size=4, num_elems, alignment + for v in reversed(vec): + builder.PrependInt32(v) + return builder.EndVector() + + +class MLXGraphSerializer(GeneratedOpBuilders): + """ + Serializes MLXGraph to bytes with separate constant data segment. + + Inherits auto-generated op builders from GeneratedOpBuilders mixin. + """ + + def __init__(self, graph: MLXGraph, constant_data: bytes = b""): + self.graph = graph + self.constant_data = constant_data + + def serialize(self) -> bytes: + """ + Serialize the graph to bytes. + + Returns: + Complete serialized payload with header, flatbuffer, and data segment. + """ + # Build FlatBuffer + fb_bytes = self._build_flatbuffer() + + # Calculate offsets + data_segment_offset = HEADER_LENGTH + len(fb_bytes) + padding_len = _padding_required(data_segment_offset, ALIGNMENT) + data_segment_offset += padding_len + data_segment_size = len(self.constant_data) + + # Build header + header = ( + b"\x00\x00\x00\x00" # 4 bytes padding + + MAGIC # 4 bytes magic + + struct.pack(" 0: + result.append(b"\x00" * padding_len) + result.append(self.constant_data) + + return bytes(result) + + def _build_flatbuffer(self) -> bytes: + builder = flatbuffers.Builder(4096) + + # Build all components bottom-up (FlatBuffers requirement) + + # 1. Build instruction chains + chain_offsets = [] + for chain in self.graph.instruction_chains: + instr_offsets = [] + for instr in chain.instructions: + instr_offsets.append(self._build_instruction(builder, instr)) + instr_vec = self._build_offset_vector(builder, instr_offsets) + + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + InstructionChain as FBInstructionChainModule, + ) + + FBInstructionChainModule.Start(builder) + FBInstructionChainModule.AddInstructions(builder, instr_vec) + chain_offsets.append(FBInstructionChainModule.End(builder)) + + chains_vec = self._build_offset_vector(builder, chain_offsets) + + # 2. Build I/O maps + input_map_vec = self._build_slot_variant_vector(builder, self.graph.input_map) + output_map_vec = self._build_slot_variant_vector(builder, self.graph.output_map) + mutable_buffer_map_vec = self._build_slot_variant_vector( + builder, self.graph.mutable_buffer_map + ) + + # 3. Build named slots + named_slots_offsets = [] + for ns in self.graph.named_slots: + named_slots_offsets.append(self._build_named_slot(builder, ns)) + named_slots_vec = self._build_offset_vector(builder, named_slots_offsets) + + # 4. Build tensor metadata + tensor_meta_offsets = [] + for tm in self.graph.tensor_meta: + if tm is not None: + tensor_meta_offsets.append(self._build_tensor_meta(builder, tm)) + else: + tensor_meta_offsets.append(0) # null + tensor_meta_vec = self._build_offset_vector(builder, tensor_meta_offsets) + + # 5. Build version string (must be created before the table that uses it) + version_off = builder.CreateString(self.graph.version) + + # 6. Build the root MLXGraph table + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + MLXGraph as FBMLXGraphModule, + ) + + FBMLXGraphModule.Start(builder) + FBMLXGraphModule.AddVersion(builder, version_off) + FBMLXGraphModule.AddNumConstantTensors(builder, self.graph.num_constant_tensors) + FBMLXGraphModule.AddNumInputTensors(builder, self.graph.num_input_tensors) + FBMLXGraphModule.AddNumOutputTensors(builder, self.graph.num_output_tensors) + FBMLXGraphModule.AddNumMutableBufferTensors( + builder, self.graph.num_mutable_buffer_tensors + ) + FBMLXGraphModule.AddNumTempTensors(builder, self.graph.num_temp_tensors) + FBMLXGraphModule.AddNumValues(builder, self.graph.num_values) + FBMLXGraphModule.AddInstructionChains(builder, chains_vec) + FBMLXGraphModule.AddMainChainIdx(builder, self.graph.main_chain_idx) + FBMLXGraphModule.AddInitChainIdx(builder, self.graph.init_chain_idx) + FBMLXGraphModule.AddInputMap(builder, input_map_vec) + FBMLXGraphModule.AddOutputMap(builder, output_map_vec) + FBMLXGraphModule.AddMutableBufferMap(builder, mutable_buffer_map_vec) + FBMLXGraphModule.AddNamedSlots(builder, named_slots_vec) + FBMLXGraphModule.AddTensorMeta(builder, tensor_meta_vec) + root = FBMLXGraphModule.End(builder) + + builder.Finish(root) + return bytes(builder.Output()) + + def _build_instruction( + self, builder: flatbuffers.Builder, instr: Instruction + ) -> int: + op_offset, op_type = self._build_op_node(builder, instr.op) + + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + Instruction as FBInstructionModule, + ) + + FBInstructionModule.Start(builder) + FBInstructionModule.AddOpType(builder, op_type) + FBInstructionModule.AddOp(builder, op_offset) + return FBInstructionModule.End(builder) + + def _build_op_node( + self, builder: flatbuffers.Builder, op: OpNodeUnion + ) -> Tuple[int, int]: + """ + Build an op node and return (offset, union_type). + + This is the main dispatch for all op types. + """ + # Map Python class to FlatBuffer union type and builder + # This would ideally be auto-generated + + op_type = type(op).__name__ + builder_method = getattr(self, f"_build_{op_type}", None) + + if builder_method is None: + raise NotImplementedError(f"No builder for op type: {op_type}") + + return builder_method(builder, op) + + def _build_offset_vector( + self, builder: flatbuffers.Builder, offsets: List[int] + ) -> int: + builder.StartVector(4, len(offsets), 4) + for off in reversed(offsets): + builder.PrependUOffsetTRelative(off) + return builder.EndVector() + + def _build_slot_variant_vector( + self, builder: flatbuffers.Builder, slots: List[SlotVariant] + ) -> int: + offsets = [] + for slot in slots: + offsets.append(self._build_slot_variant(builder, slot)) + return self._build_offset_vector(builder, offsets) + + def _build_slot_variant( + self, builder: flatbuffers.Builder, slot: SlotVariant + ) -> int: + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + SlotVariant as FBSlotVariantModule, + ) + + FBSlotVariantModule.Start(builder) + FBSlotVariantModule.AddIdx(builder, slot.idx) + FBSlotVariantModule.AddSlotType(builder, slot.slot_type) + return FBSlotVariantModule.End(builder) + + def _build_named_slot(self, builder: flatbuffers.Builder, ns: NamedSlot) -> int: + name_off = builder.CreateString(ns.name) + slot_off = self._build_slot_variant(builder, ns.slot) + + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + NamedSlot as FBNamedSlotModule, + ) + + FBNamedSlotModule.Start(builder) + FBNamedSlotModule.AddName(builder, name_off) + FBNamedSlotModule.AddSlot(builder, slot_off) + return FBNamedSlotModule.End(builder) + + def _build_tensor_meta(self, builder: flatbuffers.Builder, tm: TensorMeta) -> int: + # Shape is a vector of ShapeDim tables + shape_offsets = [] + for dim in tm.shape: + shape_offsets.append(self._build_shape_dim(builder, dim)) + shape_vec = self._build_offset_vector(builder, shape_offsets) + + # Build dim_order vector (uint8) + dim_order_vec = 0 + if tm.dim_order: + builder.StartVector(1, len(tm.dim_order), 1) # elem_size=1 for uint8 + for d in reversed(tm.dim_order): + builder.PrependUint8(d) + dim_order_vec = builder.EndVector() + + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + TensorMeta as FBTensorMetaModule, + ) + + FBTensorMetaModule.Start(builder) + FBTensorMetaModule.AddShape(builder, shape_vec) + if tm.scalar_type is not None: + FBTensorMetaModule.AddScalarType(builder, tm.scalar_type) + if dim_order_vec: + FBTensorMetaModule.AddDimOrder(builder, dim_order_vec) + return FBTensorMetaModule.End(builder) + + def _build_shape_dim(self, builder: flatbuffers.Builder, dim) -> int: + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + ShapeDim as FBShapeDimModule, + ) + + FBShapeDimModule.Start(builder) + FBShapeDimModule.AddValue(builder, dim.value) + FBShapeDimModule.AddMinValue(builder, dim.min_value) + FBShapeDimModule.AddMaxValue(builder, dim.max_value) + return FBShapeDimModule.End(builder) + + +def serialize_mlx_graph(graph: MLXGraph, constant_data: bytes = b"") -> bytes: + """ + Serialize an MLXGraph to bytes. + + Args: + graph: The MLXGraph to serialize. + constant_data: Raw bytes for constant tensors. + + Returns: + Serialized bytes with header, flatbuffer, and data segment. + """ + serializer = MLXGraphSerializer(graph, constant_data) + return serializer.serialize() + + +def parse_header(data: bytes) -> Tuple[int, int, int, int]: + """ + Parse the MLX delegate header. + + Returns: + (flatbuffer_offset, flatbuffer_size, data_segment_offset, data_segment_size) + """ + if len(data) < HEADER_LENGTH: + raise ValueError(f"Data too short: {len(data)} < {HEADER_LENGTH}") + + magic = data[4:8] + if magic != MAGIC: + raise ValueError(f"Invalid magic: {magic!r} (expected {MAGIC!r})") + + data_segment_offset = struct.unpack(" dict: + """ + Deserialize MLX delegate payload to a JSON-compatible dict. + + Useful for debugging - extracts the FlatBuffer and dumps it as JSON. + """ + fb_off, fb_size, ds_off, ds_size = parse_header(data) + + # Extract FlatBuffer portion + fb_data = data[fb_off : fb_off + fb_size] + + # Parse using generated FlatBuffer code + from executorch.backends.mlx.serialization._generated.mlx_delegate.MLXGraph import ( + MLXGraph as FBMLXGraphClass, + ) + + graph = FBMLXGraphClass.GetRootAs(fb_data, 0) + + # Convert to dict (recursive) + result = _fb_to_dict(graph) + result["_constant_segment_size"] = ds_size + + return result + + +def _fb_to_dict(obj: Any) -> Any: + if obj is None: + return None + if isinstance(obj, (int, float, str, bool, bytes)): + return obj + if isinstance(obj, (list, tuple)): + return [_fb_to_dict(item) for item in obj] + + # FlatBuffer object - extract fields + result = {} + for attr in dir(obj): + if attr.startswith("_") or attr[0].islower(): + continue + try: + value = getattr(obj, attr)() + result[attr] = _fb_to_dict(value) + except (TypeError, AttributeError): + pass + + return result diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs new file mode 100644 index 00000000000..945186ebef8 --- /dev/null +++ b/backends/mlx/serialization/schema.fbs @@ -0,0 +1,192 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// FlatBuffer schema for MLX delegate - THIS IS THE SOURCE OF TRUTH +// Defines the IR that gets serialized into the .pte file and executed by MLX runtime +// +// After editing this file, regenerate dependent files with: +// python backends/mlx/serialization/generate.py +// +// BACKWARD COMPATIBILITY RULES: +// - New fields in tables: APPEND ONLY (add at the end, with a default value) +// - New union members: APPEND ONLY (add at the end of the union) +// - New tables: Safe to add freely +// - New enum values: APPEND ONLY +// - NEVER remove, reorder, or change the type of existing fields/members + +namespace mlx_delegate; + +// ============================================================================= +// Core types +// ============================================================================= + +// We use ET's ScalarType (int8) directly. +// See runtime/core/portable_type/scalar_type.h for ScalarType values. + +// Tensor slot identifier - indexes into tensors array +struct Tid { + idx: uint32; +} + +// Value slot identifier - indexes into values array +// Values are stored as variant at runtime +struct Vid { + idx: uint32; +} + +// NOTE: These compound types use tables with manual discriminators rather than +// FlatBuffers unions because IntOrVid is used in vectors ([IntOrVid]), and +// FlatBuffers does not support vectors of unions. + +// For fields that can be either a literal int or a runtime Vid +table IntOrVid { + literal: int64; // widened to int64 for future-proofing + vid: Vid; + is_vid: bool = false; +} + +// For fields that can be either a literal float or a runtime Vid +table FloatOrVid { + literal: double; // widened to double for future-proofing + vid: Vid; + is_vid: bool = false; +} + +// For fields that can be either a tensor (Tid) or a scalar value (Vid) +table VidOrTid { + vid: Vid; + tid: Tid; + is_vid: bool = false; // false = use tid, true = use vid +} + +// For fields that can be a literal int, a runtime Vid, or a tensor (Tid) +table IntOrVidOrTid { + literal: int64; + vid: Vid; + tid: Tid; + kind: uint8 = 0; // 0 = literal int, 1 = vid, 2 = tid +} + +// ============================================================================= +// Op nodes - mirrors ops_schema.py dataclasses +// ============================================================================= + +table NoopNode {} + +table AddmmNode { + mat1: Tid (required); // First matrix + mat2: Tid (required); // Second matrix + out: Tid (required); + bias: Tid; // optional - added to result + alpha: float = 1.0; // Scalar multiplier for mat1 @ mat2 + beta: float = 1.0; // Scalar multiplier for bias +} + +// ============================================================================= +// Union of all op types +// ============================================================================= + +// BC: APPEND ONLY — new op nodes must be added at the end of this union. +// Reordering or removing members changes numeric type IDs and breaks existing .pte files. +union OpNode { + NoopNode, + AddmmNode + // BC: Add new op nodes here (append only) +} + +// ============================================================================= +// Instruction wrapper +// ============================================================================= + +table Instruction { + op: OpNode (required); +} + +// ============================================================================= +// Instruction chain (basic block of sequential instructions) +// ============================================================================= + +table InstructionChain { + instructions: [Instruction] (required); + // BC: New fields must be appended here with a default value +} + +// ============================================================================= +// Tensor metadata +// ============================================================================= + +// Shape dimension: static value, or dynamic with optional bounds +table ShapeDim { + value: int32 = -1; // Static dim (>= 0), or -1 for dynamic + min_value: int32 = 0; // Lower bound (only when value == -1) + max_value: int32 = -1; // Upper bound (-1 = unbounded, only when value == -1) +} + +table TensorMeta { + shape: [ShapeDim] (required); // Dimension info with static/dynamic distinction + scalar_type: int8; // ET ScalarType value (see runtime/core/portable_type/scalar_type.h) + dim_order: [uint8]; // Memory layout order (matches TensorLayout.dim_order, DimOrderType = uint8_t) +} + +// ============================================================================= +// Slot variant for I/O mapping +// ============================================================================= + +enum SlotType : byte { + TensorSlot = 0, + IntValueSlot = 1, + FloatValueSlot = 2, + BoolValueSlot = 3 +} + +table SlotVariant { + idx: uint32; + slot_type: SlotType = TensorSlot; +} + +// ============================================================================= +// Name to slot mapping entry +// ============================================================================= + +table NamedSlot { + name: string (required); + slot: SlotVariant (required); +} + +// ============================================================================= +// Root type: MLX Graph +// ============================================================================= + +// BC: New fields must be appended at the end of this table with a default value. +table MLXGraph { + // Version for compatibility + version: string; + + // Tensor slot counts + + num_constant_tensors: uint32; + num_input_tensors: uint32; + num_output_tensors: uint32; + num_mutable_buffer_tensors: uint32; + num_temp_tensors: uint32; + num_values: uint32; + + // Instruction chains (basic blocks of sequential instructions) + instruction_chains: [InstructionChain] (required); + main_chain_idx: uint32 = 0; // Chain to run every execute() call + init_chain_idx: int32 = -1; // Chain to run once at init(), -1 = none + + // I/O mappings + input_map: [SlotVariant]; + output_map: [SlotVariant]; + mutable_buffer_map: [SlotVariant]; + + // Name to slot lookup (used for constant/mutable buffer keys in named_data_map) + named_slots: [NamedSlot]; + + // Tensor metadata (for non-temp tensors), indexed by Tid + tensor_meta: [TensorMeta]; + + // BC: New fields must be appended here with a default value +} + +root_type MLXGraph; diff --git a/backends/mlx/test/CMakeLists.txt b/backends/mlx/test/CMakeLists.txt new file mode 100644 index 00000000000..2a709a63412 --- /dev/null +++ b/backends/mlx/test/CMakeLists.txt @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# MLX backend tests + +# Strict compiler flags for the test runner — mlxdelegate uses PRIVATE so these +# don't propagate to downstream consumers +set(_mlx_test_compile_options -Wall -Werror -Wconversion -Wsign-conversion + -Wshorten-64-to-32 +) + +# Sanitizers are inherited from parent via EXECUTORCH_MLX_ENABLE_SANITIZERS +if(EXECUTORCH_MLX_ENABLE_SANITIZERS) + list(APPEND _mlx_test_compile_options -fsanitize=address,undefined + -fno-omit-frame-pointer + ) +endif() + +# Op test runner - generic test binary for testing individual ops +add_executable(op_test_runner op_test_runner.cpp) + +target_compile_options(op_test_runner PRIVATE ${_mlx_test_compile_options}) +if(EXECUTORCH_MLX_ENABLE_SANITIZERS) + target_link_options(op_test_runner PRIVATE ${_mlx_sanitizer_link_options}) +endif() + +target_link_libraries( + op_test_runner PRIVATE extension_module extension_tensor executorch + mlxdelegate +) + +# -------------------------------------------------------------------------- +# Compile-only strict warnings test for delegate headers +# +# Verifies MLXExecutor.h, MLXInterpreter.h, MLXLoader.h compile cleanly under +# -Wconversion -Wsign-conversion -Wshorten-64-to-32 -Werror. ExecuTorch and MLX +# headers are suppressed via pragma in the source file. This target is never +# linked or run — a successful compile is the test. +# -------------------------------------------------------------------------- +add_library(strict_compile_test OBJECT strict_compile_test.cpp) +target_compile_options(strict_compile_test PRIVATE ${_mlx_test_compile_options}) +target_include_directories( + strict_compile_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../runtime +) +target_link_libraries( + strict_compile_test PRIVATE mlx_schema executorch_core mlx +) +add_dependencies(op_test_runner strict_compile_test) diff --git a/backends/mlx/test/README.md b/backends/mlx/test/README.md new file mode 100644 index 00000000000..6d90d513fec --- /dev/null +++ b/backends/mlx/test/README.md @@ -0,0 +1,164 @@ +# MLX Backend Tests + +This directory contains end-to-end tests for the MLX backend. Each test verifies that a specific op or pattern is correctly lowered to MLX and produces matching outputs between PyTorch and the MLX runtime. + +## Setup + +### 1. Install ExecuTorch Python package (if not already installed) + +```bash +python install_executorch.py --editable +``` + +### 2. Configure CMake with MLX preset + +From the ExecuTorch root directory: + +```bash +cmake --preset mlx-release -DEXECUTORCH_BUILD_TESTS=ON +``` + +This configures the build with MLX delegate support and test targets. Build files are generated in `cmake-out/`. + +### 3. Build the test runner + +```bash +cmake --build cmake-out --target op_test_runner +``` + +This builds the `op_test_runner` binary that executes `.pte` models using the MLX runtime. + + + +## Prerequisites + +1. **Python environment**: Tests must be run in an environment where the `executorch` Python package is installed +2. **Built C++ runtime**: The `op_test_runner` binary must be built (see Setup above) + +## Running Tests + +### Run All Tests + +To run all registered tests: + +```bash +python -m executorch.backends.mlx.test.run_all_tests -j4 --clean-after +``` + +### Options + +| Flag | Description | +|------|-------------| +| `-j N` / `--parallel N` | Run tests in parallel with N workers | +| `--clean-after` | Clean up generated test files after running | +| `--clean` | Clean up generated test files and exit | +| `--rebuild` | Rebuild the C++ test runner before running | +| `--list` | List available tests and exit | +| `-v` / `--verbose` | Verbose output | +| `--timeout SECS` | Timeout per test in seconds (default: 300) | + +### Memory Management Options + +Running many tests can accumulate memory (torch/MLX/Metal allocations). These flags help manage memory: + +| Flag | Description | +|------|-------------| +| `--isolate` | Run each test in a separate subprocess (sequential mode only). Provides full memory isolation but is slower due to Python/torch import overhead per test. | +| `--max-tasks-per-worker N` | Recycle parallel workers after N tests (parallel mode only). Workers are terminated and replaced after completing N tests, releasing accumulated memory. | + +**Comparison:** + +| Mode | Memory Isolation | Speed | +|------|------------------|-------| +| `-j 4` | None (workers reused) | Fastest | +| `-j 4 --max-tasks-per-worker 10` | Bounded (recycled every 10 tests) | Fast | +| `-j 4 --max-tasks-per-worker 1` | Full (new process per test) | Slower | +| `--isolate` | Full (subprocess per test) | Slowest (sequential) | + +**Recommended for CI with memory constraints:** + +```bash +python -m executorch.backends.mlx.test.run_all_tests -j4 --max-tasks-per-worker 10 --clean-after +``` + +### Run a Specific Test + +To run a specific test by name (e.g., `linear`): + +```bash +python -m executorch.backends.mlx.test.run_all_tests linear +``` + +With verbose output: + +```bash +python -m executorch.backends.mlx.test.run_all_tests -v linear +``` + +### List Available Tests + +```bash +python -m executorch.backends.mlx.test.run_all_tests --list +``` + +## Test Architecture + +All tests are defined in `test_ops.py`. Each test follows a common pattern: + +1. **Define a model** - A simple `nn.Module` that uses the op being tested +2. **Create test inputs** - Generate random input tensors +3. **Export and lower** - Export the model and lower it to the MLX backend +4. **Run C++ binary** - Execute the lowered model using `op_test_runner` +5. **Compare outputs** - Verify PyTorch and MLX outputs match within tolerance + +### Test Class Structure + +Tests inherit from `OpTestCase` and implement: + +```python +@register_test +class MyTest(OpTestCase): + name = "my_test" # Test name (used for output directory) + rtol = 1e-5 # Relative tolerance for comparison + atol = 1e-5 # Absolute tolerance for comparison + + def create_model(self) -> nn.Module: + """Return the model to test.""" + ... + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + """Return input tensors for export.""" + ... + + def get_dynamic_shapes(self) -> Optional[Dict]: + """Return dynamic shape specs, or None for static shapes.""" + ... + + @classmethod + def get_test_configs(cls) -> List["MyTest"]: + """Return list of test configurations to run.""" + ... +``` + +## Test Output + +Test artifacts are saved to `op_tests//`: +- `model.pte` - Exported ExecuTorch model +- `input.bin` - Serialized input tensors +- `expected_output.bin` - PyTorch reference output +- `actual_output.bin` - MLX runtime output + +## Adding a New Test + +1. Add a new model class and `OpTestCase` subclass to `test_ops.py` +2. Use the `@register_test` decorator on the test class +3. Implement `create_model()`, `create_inputs()`, and `get_test_configs()` +4. Run the test to verify it works E2E + +## Test harness + +MLX also plugs into the ExecuTorch test harness for even more coverage. To run, use the following command from the ExecuTorch root directory: + +```bash +pytest -c /dev/null backends/test/suite/operators/ -m flow_mlx +``` diff --git a/backends/mlx/test/__init__.py b/backends/mlx/test/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/mlx/test/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/mlx/test/op_test_runner.cpp b/backends/mlx/test/op_test_runner.cpp new file mode 100644 index 00000000000..6bed13d7a56 --- /dev/null +++ b/backends/mlx/test/op_test_runner.cpp @@ -0,0 +1,395 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * Generic op test runner for MLX delegate. + * + * Loads a .pte file, reads inputs from .bin files, runs the model, + * and writes outputs to .bin files. + * + * Build: + * cd cmake-out-mlx && cmake --build . --target op_test_runner + * + * Usage: + * ./cmake-out-mlx/backends/mlx/test/op_test_runner \ + * --pte \ + * --input \ + * --output + * + * Binary file format: + * - 4 bytes: number of tensors (uint32_t) + * For each tensor: + * - 4 bytes: dtype (0=float32, 1=float16, 2=int32, 3=int64, 4=bfloat16, + * 5=bool) + * - 4 bytes: number of dimensions (uint32_t) + * - 4 bytes * ndim: shape (int32_t each) + * - N bytes: data (size = product of shape * sizeof(dtype)) + */ + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wconversion" +#pragma clang diagnostic ignored "-Wsign-conversion" +#pragma clang diagnostic ignored "-Wshorten-64-to-32" +#pragma clang diagnostic ignored "-Wimplicit-float-conversion" +#include +#include +#pragma clang diagnostic pop + +#include +#include +#include +#include +#include +#include +#include + +using namespace ::executorch::extension; +using namespace ::executorch::runtime; + +enum class DType : uint32_t { + Float32 = 0, + Float16 = 1, + Int32 = 2, + Int64 = 3, + BFloat16 = 4, + Bool = 5, +}; + +size_t dtype_size(DType dtype) { + switch (dtype) { + case DType::Float32: + return 4; + case DType::Float16: + return 2; + case DType::Int32: + return 4; + case DType::Int64: + return 8; + case DType::BFloat16: + return 2; + case DType::Bool: + return 1; + default: + return 4; + } +} + +exec_aten::ScalarType dtype_to_scalar_type(DType dtype) { + switch (dtype) { + case DType::Float32: + return exec_aten::ScalarType::Float; + case DType::Float16: + return exec_aten::ScalarType::Half; + case DType::Int32: + return exec_aten::ScalarType::Int; + case DType::Int64: + return exec_aten::ScalarType::Long; + case DType::BFloat16: + return exec_aten::ScalarType::BFloat16; + case DType::Bool: + return exec_aten::ScalarType::Bool; + default: + return exec_aten::ScalarType::Float; + } +} + +DType scalar_type_to_dtype(exec_aten::ScalarType stype) { + switch (stype) { + case exec_aten::ScalarType::Float: + return DType::Float32; + case exec_aten::ScalarType::Half: + return DType::Float16; + case exec_aten::ScalarType::Int: + return DType::Int32; + case exec_aten::ScalarType::Long: + return DType::Int64; + case exec_aten::ScalarType::BFloat16: + return DType::BFloat16; + case exec_aten::ScalarType::Bool: + return DType::Bool; + default: + return DType::Float32; + } +} + +struct TensorData { + DType dtype; + std::vector shape; + std::vector data; +}; + +std::vector read_tensors_from_bin(const std::string& path) { + std::ifstream file(path, std::ios::binary); + if (!file) { + throw std::runtime_error("Failed to open input file: " + path); + } + + uint32_t num_tensors; + file.read(reinterpret_cast(&num_tensors), sizeof(num_tensors)); + + std::vector tensors; + tensors.reserve(num_tensors); + + for (uint32_t i = 0; i < num_tensors; ++i) { + TensorData t; + + uint32_t dtype_val; + file.read(reinterpret_cast(&dtype_val), sizeof(dtype_val)); + t.dtype = static_cast(dtype_val); + + uint32_t ndim; + file.read(reinterpret_cast(&ndim), sizeof(ndim)); + + t.shape.resize(ndim); + file.read(reinterpret_cast(t.shape.data()), ndim * sizeof(int32_t)); + + size_t numel = 1; + for (int32_t s : t.shape) { + numel *= static_cast(s); + } + size_t data_size = numel * dtype_size(t.dtype); + + t.data.resize(data_size); + file.read( + reinterpret_cast(t.data.data()), + static_cast(data_size)); + + tensors.push_back(std::move(t)); + } + + return tensors; +} + +void write_tensors_to_bin( + const std::string& path, + const std::vector& tensors) { + std::ofstream file(path, std::ios::binary); + if (!file) { + throw std::runtime_error("Failed to open output file: " + path); + } + + uint32_t num_tensors = static_cast(tensors.size()); + file.write(reinterpret_cast(&num_tensors), sizeof(num_tensors)); + + for (const auto& t : tensors) { + uint32_t dtype_val = static_cast(t.dtype); + file.write(reinterpret_cast(&dtype_val), sizeof(dtype_val)); + + uint32_t ndim = static_cast(t.shape.size()); + file.write(reinterpret_cast(&ndim), sizeof(ndim)); + + file.write( + reinterpret_cast(t.shape.data()), ndim * sizeof(int32_t)); + + file.write( + reinterpret_cast(t.data.data()), + static_cast(t.data.size())); + } +} + +void print_usage(const char* prog_name) { + std::cerr << "Usage: " << prog_name << " [options]\n" + << "Options:\n" + << " --pte Path to .pte model file (required)\n" + << " --input Path to input .bin file (required)\n" + << " --output Path to output .bin file (required)\n" + << " --verbose Print verbose output\n" + << std::endl; +} + +int main(int argc, char* argv[]) { + std::string pte_path; + std::string input_path; + std::string output_path; + bool verbose = false; + + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg == "--pte" && i + 1 < argc) { + pte_path = argv[++i]; + } else if (arg == "--input" && i + 1 < argc) { + input_path = argv[++i]; + } else if (arg == "--output" && i + 1 < argc) { + output_path = argv[++i]; + } else if (arg == "--verbose") { + verbose = true; + } else if (arg == "--help" || arg == "-h") { + print_usage(argv[0]); + return 0; + } else { + std::cerr << "Unknown argument: " << arg << std::endl; + print_usage(argv[0]); + return 1; + } + } + + if (pte_path.empty() || input_path.empty() || output_path.empty()) { + std::cerr << "Error: --pte, --input, and --output are required\n"; + print_usage(argv[0]); + return 1; + } + + try { + if (verbose) { + std::cout << "Loading model from: " << pte_path << std::endl; + } + + Module module(pte_path); + auto load_error = module.load(); + if (load_error != Error::Ok) { + std::cerr << "Failed to load model: " << static_cast(load_error) + << std::endl; + return 1; + } + + if (verbose) { + std::cout << "Model loaded successfully" << std::endl; + } + + auto load_method_error = module.load_method("forward"); + if (load_method_error != Error::Ok) { + std::cerr << "Failed to load forward method: " + << static_cast(load_method_error) << std::endl; + return 1; + } + + if (verbose) { + std::cout << "Reading inputs from: " << input_path << std::endl; + } + + auto input_tensors = read_tensors_from_bin(input_path); + + if (verbose) { + std::cout << "Read " << input_tensors.size() << " input tensors" + << std::endl; + for (size_t i = 0; i < input_tensors.size(); ++i) { + std::cout << " Input " << i + << ": dtype=" << static_cast(input_tensors[i].dtype) + << ", shape=["; + for (size_t j = 0; j < input_tensors[i].shape.size(); ++j) { + std::cout << input_tensors[i].shape[j]; + if (j < input_tensors[i].shape.size() - 1) + std::cout << ", "; + } + std::cout << "]" << std::endl; + } + } + + std::vector tensor_ptrs; + std::vector inputs; + tensor_ptrs.reserve(input_tensors.size()); + inputs.reserve(input_tensors.size()); + + for (const auto& t : input_tensors) { + std::vector sizes(t.shape.begin(), t.shape.end()); + + TensorPtr tensor_ptr; + if (t.dtype == DType::Float32) { + std::vector data(t.data.size() / sizeof(float)); + std::memcpy(data.data(), t.data.data(), t.data.size()); + tensor_ptr = make_tensor_ptr(sizes, std::move(data)); + } else if (t.dtype == DType::Float16) { + std::vector data( + t.data.size() / sizeof(exec_aten::Half)); + std::memcpy(data.data(), t.data.data(), t.data.size()); + tensor_ptr = make_tensor_ptr(sizes, std::move(data)); + } else if (t.dtype == DType::BFloat16) { + std::vector data( + t.data.size() / sizeof(exec_aten::BFloat16)); + std::memcpy(data.data(), t.data.data(), t.data.size()); + tensor_ptr = make_tensor_ptr(sizes, std::move(data)); + } else if (t.dtype == DType::Int32) { + std::vector data(t.data.size() / sizeof(int32_t)); + std::memcpy(data.data(), t.data.data(), t.data.size()); + tensor_ptr = make_tensor_ptr(sizes, std::move(data)); + } else if (t.dtype == DType::Int64) { + std::vector data(t.data.size() / sizeof(int64_t)); + std::memcpy(data.data(), t.data.data(), t.data.size()); + tensor_ptr = make_tensor_ptr(sizes, std::move(data)); + } else if (t.dtype == DType::Bool) { + std::vector data(t.data.size()); + std::memcpy(data.data(), t.data.data(), t.data.size()); + tensor_ptr = make_tensor_ptr( + sizes, std::move(data), {}, {}, exec_aten::ScalarType::Bool); + } else { + std::cerr << "Unsupported dtype: " << static_cast(t.dtype) + << std::endl; + return 1; + } + + tensor_ptrs.push_back(tensor_ptr); + inputs.push_back(tensor_ptr); + } + + if (verbose) { + std::cout << "Executing forward..." << std::endl; + } + + auto result = module.forward(inputs); + if (result.error() != Error::Ok) { + std::cerr << "Execution failed: " << static_cast(result.error()) + << std::endl; + return 1; + } + + if (verbose) { + std::cout << "Execution succeeded, " << result->size() << " outputs" + << std::endl; + } + + std::vector output_tensors; + output_tensors.reserve(result->size()); + + for (size_t i = 0; i < result->size(); ++i) { + const auto& evalue = result->at(i); + if (!evalue.isTensor()) { + std::cerr << "Output " << i << " is not a tensor" << std::endl; + return 1; + } + + const auto& tensor = evalue.toTensor(); + TensorData t; + t.dtype = scalar_type_to_dtype(tensor.scalar_type()); + + t.shape.resize(static_cast(tensor.dim())); + for (size_t d = 0; d < static_cast(tensor.dim()); ++d) { + t.shape[d] = static_cast(tensor.size(static_cast(d))); + } + + size_t data_size = tensor.nbytes(); + t.data.resize(data_size); + std::memcpy(t.data.data(), tensor.const_data_ptr(), data_size); + + if (verbose) { + std::cout << " Output " << i << ": dtype=" << static_cast(t.dtype) + << ", shape=["; + for (size_t j = 0; j < t.shape.size(); ++j) { + std::cout << t.shape[j]; + if (j < t.shape.size() - 1) + std::cout << ", "; + } + std::cout << "]" << std::endl; + } + + output_tensors.push_back(std::move(t)); + } + + if (verbose) { + std::cout << "Writing outputs to: " << output_path << std::endl; + } + + write_tensors_to_bin(output_path, output_tensors); + + std::cout << "OK" << std::endl; + return 0; + + } catch (const std::exception& e) { + std::cerr << "Error: " << e.what() << std::endl; + return 1; + } +} diff --git a/backends/mlx/test/run_all_tests.py b/backends/mlx/test/run_all_tests.py new file mode 100644 index 00000000000..3cda35da275 --- /dev/null +++ b/backends/mlx/test/run_all_tests.py @@ -0,0 +1,496 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Run all MLX delegate op tests. + +Usage: + # Run all tests (all configurations): + python -m executorch.backends.mlx.test.run_all_tests + + # Run specific test (all its configurations): + python -m executorch.backends.mlx.test.run_all_tests add + + # Run specific test configuration: + python -m executorch.backends.mlx.test.run_all_tests add_scalar + + # List available tests: + python -m executorch.backends.mlx.test.run_all_tests --list + + # Rebuild C++ runner before running: + python -m executorch.backends.mlx.test.run_all_tests --rebuild + + # Run tests in parallel: + python -m executorch.backends.mlx.test.run_all_tests -j 4 + + # Run with custom timeout: + python -m executorch.backends.mlx.test.run_all_tests --timeout 60 +""" + +import argparse +import importlib +import multiprocessing +import subprocess +import sys +from multiprocessing import Pool +from typing import List, Optional, Tuple + +from .test_utils import ( + clean_test_outputs, + DEFAULT_TEST_TIMEOUT, + get_all_test_configs, + get_registered_tests, + get_test_output_size, + rebuild_op_test_runner, +) + + +def discover_and_import_tests(): + """ + Import test_ops.py module which contains all test definitions. + This triggers registration of all tests. + """ + importlib.import_module(".test_ops", package=__package__) + + +def _run_single_test( + test_class_name: str, + config_name: str, + config_kwargs: dict, + verbose: bool, + timeout: int, +) -> Tuple[str, bool, Optional[str]]: + """ + Run a single test configuration in a subprocess. + + Called via multiprocessing.Pool.starmap for parallel execution. + Recreates the test instance from the class name and kwargs. + + Args: + test_class_name: Name of the test class module.path + config_name: Name of this configuration + config_kwargs: Kwargs to recreate the test instance + verbose: Whether to print verbose output + timeout: Timeout in seconds + + Returns: + (config_name, passed, error_message) + """ + try: + # Re-discover and import tests in this subprocess + discover_and_import_tests() + + # Find the test config by name + all_configs = get_all_test_configs() + test_instance = None + for name, instance in all_configs: + if name == config_name: + test_instance = instance + break + + if test_instance is None: + return (config_name, False, f"Could not find test config: {config_name}") + + # Run the test + passed = test_instance.run_test(verbose=verbose, timeout=timeout) + return (config_name, passed, None) + + except Exception as e: + import traceback + + return (config_name, False, f"Exception: {e}\n{traceback.format_exc()}") + + +def run_tests_sequential( + configs_to_run: List[Tuple[str, object]], + verbose: bool = False, + timeout: int = DEFAULT_TEST_TIMEOUT, + clean_after_each: bool = False, + isolate: bool = False, +) -> Tuple[int, int, List[str]]: + """ + Run tests sequentially. + + Args: + configs_to_run: List of (config_name, test_instance) tuples. + verbose: Whether to print verbose output. + timeout: Timeout in seconds per test. + clean_after_each: Whether to clean up test outputs after each test. + isolate: Whether to run each test in a subprocess to prevent memory + accumulation across tests (torch/MLX/Metal allocations). + + Returns: + (passed_count, failed_count, failed_test_names) + """ + passed = 0 + failed = 0 + failed_tests = [] + + for config_name, test in configs_to_run: + if isolate: + test_passed = _run_test_in_subprocess( + config_name, verbose=verbose, timeout=timeout + ) + else: + try: + test_passed = test.run_test(verbose=verbose, timeout=timeout) + except Exception as e: + print(f"✗ FAILED: {config_name} - Exception: {e}") + import traceback + + traceback.print_exc() + test_passed = False + + if test_passed: + passed += 1 + else: + failed += 1 + failed_tests.append(config_name) + + if clean_after_each: + clean_test_outputs([config_name], verbose=False) + + return passed, failed, failed_tests + + +def _run_test_in_subprocess( + config_name: str, + verbose: bool = False, + timeout: int = DEFAULT_TEST_TIMEOUT, +) -> bool: + """ + Run a single test in an isolated subprocess. + + Each test gets its own Python interpreter so torch/MLX/Metal memory is + fully released between tests, preventing OOM on CI runners. + + Args: + config_name: Name of the test configuration to run. + verbose: Whether to print verbose output. + timeout: Timeout in seconds. + + Returns: + True if test passed, False otherwise. + """ + cmd = [ + sys.executable, + "-m", + "executorch.backends.mlx.test.test_utils", + config_name, + "run", + ] + if verbose: + cmd.append("--verbose") + + try: + result = subprocess.run( + cmd, + timeout=timeout, + capture_output=False, + ) + return result.returncode == 0 + except subprocess.TimeoutExpired: + print(f"✗ FAILED: {config_name} - Timeout after {timeout}s") + return False + except Exception as e: + print(f"✗ FAILED: {config_name} - Subprocess error: {e}") + return False + + +def run_tests_parallel( + configs_to_run: List[Tuple[str, object]], + num_workers: int, + verbose: bool = False, + timeout: int = DEFAULT_TEST_TIMEOUT, + max_tasks_per_worker: Optional[int] = None, +) -> Tuple[int, int, List[str]]: + """ + Run tests in parallel using multiprocessing.Pool. + + Args: + configs_to_run: List of (config_name, test_instance) tuples. + num_workers: Number of parallel workers. + verbose: Whether to print verbose output. + timeout: Timeout in seconds per test. + max_tasks_per_worker: Maximum tasks per worker before recycling. + When set, worker processes are terminated and replaced after + completing this many tests, which releases accumulated memory + (torch/MLX/Metal allocations). None means workers are never recycled. + + Returns: + (passed_count, failed_count, failed_test_names) + """ + passed = 0 + failed = 0 + failed_tests = [] + + # Prepare test args for parallel execution + # We pass config names and let subprocesses recreate the test instances + test_args = [("", name, {}, verbose, timeout) for name, _ in configs_to_run] + + recycle_msg = "" + if max_tasks_per_worker is not None: + recycle_msg = f", recycling workers every {max_tasks_per_worker} tests" + print( + f"\nRunning {len(test_args)} tests with {num_workers} workers{recycle_msg}...\n" + ) + + with Pool(processes=num_workers, maxtasksperchild=max_tasks_per_worker) as pool: + results = pool.starmap(_run_single_test, test_args) + + for result_name, result_passed, error_msg in results: + if result_passed: + print(f"✓ PASSED: {result_name}") + passed += 1 + else: + if error_msg: + print(f"✗ FAILED: {result_name} - {error_msg}") + else: + print(f"✗ FAILED: {result_name}") + failed += 1 + failed_tests.append(result_name) + + return passed, failed, failed_tests + + +def run_tests( + test_filter: List[str], + verbose: bool = False, + parallel: int = 1, + timeout: int = DEFAULT_TEST_TIMEOUT, + clean_after_each: bool = False, + isolate: bool = False, + max_tasks_per_worker: Optional[int] = None, +) -> Tuple[int, int, List[str]]: + """ + Run tests matching the filter. + + Args: + test_filter: List of test names/patterns to run. If empty, runs all tests. + Can match either base test name (e.g., "add") or config name (e.g., "add_scalar"). + verbose: Whether to print verbose output. + parallel: Number of parallel workers (1 = sequential). + timeout: Timeout in seconds per test. + clean_after_each: Whether to clean up test outputs after each test (sequential only). + isolate: Whether to run each test in a subprocess (sequential only). + max_tasks_per_worker: Maximum tasks per worker before recycling (parallel only). + + Returns: + (passed_count, failed_count, failed_test_names) + """ + all_configs = get_all_test_configs() + registry = get_registered_tests() + + # Determine which configs to run + configs_to_run = [] + if not test_filter: + # Run all + configs_to_run = all_configs + else: + for pattern in test_filter: + matched = False + + # Check if pattern matches a base test name + if pattern in registry: + configs_to_run.extend(registry[pattern]) + matched = True + else: + # Check if pattern matches a config name + for config_name, config in all_configs: + if config_name == pattern: + configs_to_run.append((config_name, config)) + matched = True + + if not matched: + print(f"Warning: No test matching '{pattern}', skipping") + + if not configs_to_run: + print("No tests to run.") + return 0, 0, [] + + # Run tests + if parallel > 1: + return run_tests_parallel( + configs_to_run, parallel, verbose, timeout, max_tasks_per_worker + ) + else: + return run_tests_sequential( + configs_to_run, verbose, timeout, clean_after_each, isolate + ) + + +def main(): # noqa: C901 + # Get CPU count for default parallel workers + cpu_count = multiprocessing.cpu_count() + + parser = argparse.ArgumentParser(description="Run all MLX delegate op tests") + parser.add_argument( + "tests", + nargs="*", + help="Specific tests to run (default: all). Can be base name (e.g., 'add') or config name (e.g., 'add_scalar')", + ) + parser.add_argument( + "--list", + action="store_true", + help="List available tests and exit", + ) + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Verbose output", + ) + parser.add_argument( + "--rebuild", + action="store_true", + help="Rebuild the C++ test runner before running", + ) + parser.add_argument( + "--clean", + action="store_true", + help="Clean up generated test files and exit", + ) + parser.add_argument( + "--clean-after", + action="store_true", + help="Clean up generated test files after running tests", + ) + parser.add_argument( + "--isolate", + action="store_true", + help="Run each test in a separate subprocess to prevent memory accumulation", + ) + parser.add_argument( + "-j", + "--parallel", + type=int, + default=1, + metavar="N", + help=f"Run tests in parallel with N workers (default: 1, max: {cpu_count})", + ) + parser.add_argument( + "--timeout", + type=int, + default=DEFAULT_TEST_TIMEOUT, + metavar="SECS", + help=f"Timeout per test in seconds (default: {DEFAULT_TEST_TIMEOUT})", + ) + parser.add_argument( + "--max-tasks-per-worker", + type=int, + default=None, + metavar="N", + help="Recycle parallel workers after N tests to release memory (default: no recycling)", + ) + args = parser.parse_args() + + # Validate parallel workers + if args.parallel < 1: + args.parallel = 1 + elif args.parallel > cpu_count: + print( + f"Warning: --parallel {args.parallel} exceeds CPU count ({cpu_count}), using {cpu_count}" + ) + args.parallel = cpu_count + + # Auto-discover and import all test modules + discover_and_import_tests() + + # Handle --clean flag + if args.clean: + # Determine which tests to clean + test_names = None + if args.tests: + # Get config names for the specified tests + registry = get_registered_tests() + test_names = [] + for pattern in args.tests: + if pattern in registry: + test_names.extend(cfg_name for cfg_name, _ in registry[pattern]) + else: + test_names.append(pattern) + + # Show current size + current_size = get_test_output_size(test_names) + if current_size > 0: + print(f"Current test output size: {current_size / 1024 / 1024:.2f} MB") + + # Clean + files_removed = clean_test_outputs(test_names, verbose=args.verbose) + if files_removed > 0: + print(f"Removed {files_removed} files") + else: + print("No files to clean") + sys.exit(0) + + # List tests + if args.list: + registry = get_registered_tests() + print("Available tests:") + for base_name in sorted(registry.keys()): + configs = registry[base_name] + if len(configs) == 1 and configs[0][0] == base_name: + # Single config with same name as base + print(f" {base_name}") + else: + # Multiple configs or different name + print(f" {base_name}:") + for config_name, _ in configs: + print(f" - {config_name}") + sys.exit(0) + + # Rebuild if requested + if args.rebuild: + if not rebuild_op_test_runner(verbose=args.verbose): + sys.exit(1) + + # Run tests + passed, failed, failed_tests = run_tests( + args.tests, + verbose=args.verbose, + parallel=args.parallel, + timeout=args.timeout, + clean_after_each=args.clean_after, + isolate=args.isolate, + max_tasks_per_worker=args.max_tasks_per_worker, + ) + + # Print summary + print("\n" + "=" * 60) + print("TEST SUMMARY") + print("=" * 60) + print(f"Passed: {passed}") + print(f"Failed: {failed}") + if failed_tests: + print(f"Failed tests: {', '.join(failed_tests)}") + print("=" * 60) + + # Clean up after tests if requested + if args.clean_after: + # Determine which tests to clean (same logic as --clean) + test_names = None + if args.tests: + registry = get_registered_tests() + test_names = [] + for pattern in args.tests: + if pattern in registry: + test_names.extend(cfg_name for cfg_name, _ in registry[pattern]) + else: + test_names.append(pattern) + + current_size = get_test_output_size(test_names) + files_removed = clean_test_outputs(test_names, verbose=args.verbose) + if files_removed > 0: + print( + f"\nCleaned up {files_removed} files ({current_size / 1024 / 1024:.2f} MB)" + ) + + sys.exit(0 if failed == 0 else 1) + + +if __name__ == "__main__": + main() diff --git a/backends/mlx/test/strict_compile_test.cpp b/backends/mlx/test/strict_compile_test.cpp new file mode 100644 index 00000000000..28df78a7d5a --- /dev/null +++ b/backends/mlx/test/strict_compile_test.cpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * Compile-only test to verify MLX delegate headers are clean under strict + * warnings (-Wconversion, -Wsign-conversion, -Wshorten-64-to-32, -Werror). + * + * This file includes the delegate headers and instantiates key types to ensure + * template code is also checked. It is never linked or executed — a successful + * compilation is the test. + */ + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wconversion" +#pragma clang diagnostic ignored "-Wsign-conversion" +#pragma clang diagnostic ignored "-Wshorten-64-to-32" +#include +#include +#include +#include +#pragma clang diagnostic pop + +// These are the headers we want to verify under strict warnings +#include "MLXExecutor.h" +#include "MLXInterpreter.h" +#include "MLXLoader.h" + +// Instantiate key types to ensure template code is checked +namespace { +[[maybe_unused]] void force_instantiation() { + using namespace executorch::backends::mlx; + + // Force safe_mul template instantiation + (void)safe_mul(0, 0, "test"); + + // Force check_allocation_bounded instantiation + ::mlx::core::Shape shape = {1, 2, 3}; + check_allocation_bounded(shape, ::mlx::core::float32, "test"); +} +} // namespace diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py new file mode 100644 index 00000000000..01286f75f16 --- /dev/null +++ b/backends/mlx/test/test_ops.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Consolidated op tests for the MLX delegate. + +This file contains all op tests organized by category. Each test class inherits +from OpTestCase and can be run via the run_all_tests.py script. + +Usage: + # Run all tests (with 4 parallel workers, cleanup after) + python -m executorch.backends.mlx.test.run_all_tests -j4 --clean-after + + # Run specific test + python -m executorch.backends.mlx.test.run_all_tests add + + # List available tests + python -m executorch.backends.mlx.test.run_all_tests --list + +See README.md in this directory for full documentation. +""" + +from typing import List, Tuple + +import torch +import torch.nn as nn + +# Import custom ops for RoPE and KV cache tests +from executorch.backends.mlx import ( # noqa: F401 - registers mlx ops # noqa: F401 - registers mlx.rope + custom_ops, + ops, +) + +from .test_utils import OpTestCase, register_test + + +class BmmModel(nn.Module): + """Model that performs batch matrix multiplication.""" + + def __init__(self, batch_size: int, n: int, m: int, p: int): + super().__init__() + self.weight = nn.Parameter(torch.randn(batch_size, m, p)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.bmm(x, self.weight) + + +@register_test +class BmmTest(OpTestCase): + """Test case for bmm (batch matrix multiplication).""" + + name = "bmm" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + batch_size: int = 4, + n: int = 8, + m: int = 16, + p: int = 32, + ): + self.batch_size = batch_size + self.n = n + self.m = m + self.p = p + self.name = f"bmm_{batch_size}x{n}x{m}x{p}" + + @classmethod + def get_test_configs(cls) -> List["BmmTest"]: + return [ + cls(batch_size=4, n=8, m=16, p=32), + cls(batch_size=2, n=64, m=64, p=32), + ] + + def create_model(self) -> nn.Module: + return BmmModel(self.batch_size, self.n, self.m, self.p) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.n, self.m) + return (x,) + + +class AddmmModel(nn.Module): + """Model that performs addmm: bias + (mat1 @ mat2).""" + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + alpha: float = 1.0, + beta: float = 1.0, + ): + super().__init__() + self.weight = nn.Parameter(torch.randn(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.randn(out_features)) + else: + self.bias = None + self.alpha = alpha + self.beta = beta + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.bias is not None: + return torch.addmm( + self.bias, x, self.weight.t(), beta=self.beta, alpha=self.alpha + ) + else: + return torch.mm(x, self.weight.t()) + + +@register_test +class AddmmTest(OpTestCase): + """Test case for addmm.""" + + name = "addmm" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + batch_size: int = 2, + in_features: int = 64, + out_features: int = 32, + bias: bool = True, + alpha: float = 1.0, + beta: float = 1.0, + ): + self.batch_size = batch_size + self.in_features = in_features + self.out_features = out_features + self.bias = bias + self.alpha = alpha + self.beta = beta + + # Build unique test name + if not bias: + name = f"addmm_{in_features}x{out_features}_no_bias" + elif alpha != 1.0 or beta != 1.0: + name = f"addmm_{in_features}x{out_features}_a{alpha}_b{beta}" + else: + name = f"addmm_{in_features}x{out_features}" + self.name = name + + @classmethod + def get_test_configs(cls) -> List["AddmmTest"]: + return [ + cls( + batch_size=2, in_features=64, out_features=32 + ), # with bias, default alpha/beta + cls( + batch_size=2, in_features=64, out_features=32, bias=False + ), # without bias + cls(batch_size=4, in_features=128, out_features=64), # larger size + cls( + batch_size=2, in_features=64, out_features=32, alpha=2.0, beta=0.5 + ), # custom alpha/beta + ] + + def create_model(self) -> nn.Module: + return AddmmModel( + self.in_features, + self.out_features, + bias=self.bias, + alpha=self.alpha, + beta=self.beta, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.in_features) + return (x,) diff --git a/backends/mlx/test/test_partitioner.py b/backends/mlx/test/test_partitioner.py new file mode 100644 index 00000000000..4a5833aa656 --- /dev/null +++ b/backends/mlx/test/test_partitioner.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for the MLX partitioner. +""" + +import unittest + +import torch +import torch.nn as nn +from executorch.backends.mlx.partitioner import MLXPartitioner +from executorch.exir import EdgeCompileConfig, to_edge +from torch.export import export + + +class TestMLXPartitionerRejectsToEdge(unittest.TestCase): + """MLXPartitioner must only be used via to_edge_transform_and_lower.""" + + def test_to_edge_then_to_backend_raises(self): + class M(nn.Module): + def forward(self, x): + return x + 1 + + ep = export(M(), (torch.randn(4),), strict=False) + edge = to_edge( + ep, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + ) + + with self.assertRaises(RuntimeError) as ctx: + edge.to_backend(MLXPartitioner()) + + self.assertIn("to_edge_transform_and_lower", str(ctx.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/mlx/test/test_passes.py b/backends/mlx/test/test_passes.py new file mode 100644 index 00000000000..a9fdb3b996b --- /dev/null +++ b/backends/mlx/test/test_passes.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/mlx/test/test_pattern_utils.py b/backends/mlx/test/test_pattern_utils.py new file mode 100644 index 00000000000..48495a469d7 --- /dev/null +++ b/backends/mlx/test/test_pattern_utils.py @@ -0,0 +1,592 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for pattern_utils.py - shared pattern matching utilities. +""" + +import unittest + +import torch +from torch.export import export + + +def get_exported_graph(module, example_inputs): + """Export a module and return the graph with ATen ops.""" + ep = export(module, example_inputs) + return ep.graph_module.graph + + +def find_node_by_target(graph, target_name): + """Find first call_function node whose target contains target_name.""" + for node in graph.nodes: + if node.op == "call_function" and target_name in str(node.target): + return node + return None + + +def find_all_nodes_by_target(graph, target_name): + """Find all call_function nodes whose target contains target_name.""" + return [ + node + for node in graph.nodes + if node.op == "call_function" and target_name in str(node.target) + ] + + +class TestMatchTarget(unittest.TestCase): + """Tests for match_target function.""" + + def test_match_target_basic(self): + """Test basic op matching.""" + from executorch.backends.mlx.pattern_utils import match_target + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + self.assertIsNotNone(rsqrt_node) + self.assertTrue(match_target(rsqrt_node, torch.ops.aten.rsqrt.default)) + self.assertFalse(match_target(rsqrt_node, torch.ops.aten.add.Tensor)) + + def test_match_target_non_call_function(self): + """Test that non-call_function nodes don't match.""" + from executorch.backends.mlx.pattern_utils import match_target + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + + # Find a placeholder node + placeholder_node = None + for node in graph.nodes: + if node.op == "placeholder": + placeholder_node = node + break + + self.assertIsNotNone(placeholder_node) + self.assertFalse(match_target(placeholder_node, torch.ops.aten.rsqrt.default)) + + +class TestHasSingleUser(unittest.TestCase): + """Tests for has_single_user function.""" + + def test_single_user(self): + """Test node with single user.""" + from executorch.backends.mlx.pattern_utils import has_single_user + + class SingleUserModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) # Single use + return y + 1 + + graph = get_exported_graph(SingleUserModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + + self.assertIsNotNone(neg_node) + self.assertTrue(has_single_user(neg_node)) + + def test_multiple_users(self): + """Test node with multiple users.""" + from executorch.backends.mlx.pattern_utils import has_single_user + + class MultiUserModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) # Used by both add and mul + a = y + 1 + b = y * 2 + return a + b + + graph = get_exported_graph(MultiUserModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + + self.assertIsNotNone(neg_node) + self.assertFalse(has_single_user(neg_node)) + + +class TestHasNoUsers(unittest.TestCase): + """Tests for has_no_users function.""" + + def test_has_users(self): + """Test node that has users.""" + from executorch.backends.mlx.pattern_utils import has_no_users + + class SimpleModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) + return y + 1 + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + + self.assertIsNotNone(neg_node) + self.assertFalse(has_no_users(neg_node)) + + def test_no_users_after_removal(self): + """Test has_no_users returns True for orphaned nodes.""" + from executorch.backends.mlx.pattern_utils import has_no_users + + class SimpleModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) + z = torch.rsqrt(y) + return z + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # Initially neg has a user (rsqrt) + self.assertFalse(has_no_users(neg_node)) + + # Replace rsqrt's input with placeholder to orphan neg + placeholder = None + for node in graph.nodes: + if node.op == "placeholder": + placeholder = node + break + rsqrt_node.replace_input_with(neg_node, placeholder) + + # Now neg has no users + self.assertTrue(has_no_users(neg_node)) + + +class TestOpStep(unittest.TestCase): + """Tests for OpStep dataclass.""" + + def test_matches_with_op(self): + """Test OpStep.matches with op field.""" + from executorch.backends.mlx.pattern_utils import OpStep + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + step = OpStep(op=torch.ops.aten.rsqrt.default) + self.assertTrue(step.matches(rsqrt_node)) + + step_wrong = OpStep(op=torch.ops.aten.neg.default) + self.assertFalse(step_wrong.matches(rsqrt_node)) + + def test_matches_with_predicate(self): + """Test OpStep.matches with predicate field.""" + from executorch.backends.mlx.pattern_utils import OpStep + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # Predicate that always returns True + step_true = OpStep(predicate=lambda n: True) + self.assertTrue(step_true.matches(rsqrt_node)) + + # Predicate that always returns False + step_false = OpStep(predicate=lambda n: False) + self.assertFalse(step_false.matches(rsqrt_node)) + + def test_matches_no_op_no_predicate(self): + """Test OpStep.matches returns False when neither op nor predicate set.""" + from executorch.backends.mlx.pattern_utils import OpStep + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + step_empty = OpStep() + self.assertFalse(step_empty.matches(rsqrt_node)) + + def test_matches_require_single_user_true(self): + """Test OpStep.matches with require_single_user=True (default).""" + from executorch.backends.mlx.pattern_utils import OpStep + + class MultiUserModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) # Used by both add and mul + a = y + 1 + b = y * 2 + return a + b + + graph = get_exported_graph(MultiUserModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + + # Default require_single_user=True, neg has multiple users + step = OpStep(op=torch.ops.aten.neg.default) + self.assertFalse(step.matches(neg_node)) + + def test_matches_require_single_user_false(self): + """Test OpStep.matches with require_single_user=False.""" + from executorch.backends.mlx.pattern_utils import OpStep + + class MultiUserModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) # Used by both add and mul + a = y + 1 + b = y * 2 + return a + b + + graph = get_exported_graph(MultiUserModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + + # With require_single_user=False, should match despite multiple users + step = OpStep(op=torch.ops.aten.neg.default, require_single_user=False) + self.assertTrue(step.matches(neg_node)) + + def test_matches_nargs_int(self): + """Test OpStep.matches with nargs as int (minimum).""" + from executorch.backends.mlx.pattern_utils import OpStep + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) # rsqrt has 1 arg + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # nargs=1 should match (rsqrt has 1 arg) + step = OpStep(op=torch.ops.aten.rsqrt.default, nargs=1) + self.assertTrue(step.matches(rsqrt_node)) + + # nargs=2 should fail (rsqrt only has 1 arg) + step_too_many = OpStep(op=torch.ops.aten.rsqrt.default, nargs=2) + self.assertFalse(step_too_many.matches(rsqrt_node)) + + def test_matches_nargs_tuple(self): + """Test OpStep.matches with nargs as tuple (range).""" + from executorch.backends.mlx.pattern_utils import OpStep + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) # rsqrt has 1 arg + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # nargs=(1, 3) should match (rsqrt has 1 arg, in range) + step = OpStep(op=torch.ops.aten.rsqrt.default, nargs=(1, 3)) + self.assertTrue(step.matches(rsqrt_node)) + + # nargs=(2, 4) should fail (rsqrt has 1 arg, not in range) + step_out_of_range = OpStep(op=torch.ops.aten.rsqrt.default, nargs=(2, 4)) + self.assertFalse(step_out_of_range.matches(rsqrt_node)) + + def test_matches_kwargs_empty(self): + """Test OpStep.matches with empty kwargs (node must have no kwargs).""" + from executorch.backends.mlx.pattern_utils import OpStep + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) # No kwargs + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # Empty kwargs set() means node must have no kwargs (default) + step = OpStep(op=torch.ops.aten.rsqrt.default, kwargs=set()) + self.assertTrue(step.matches(rsqrt_node)) + + # Default is also empty set (strict checking) + step_default = OpStep(op=torch.ops.aten.rsqrt.default) + self.assertTrue(step_default.matches(rsqrt_node)) + + def test_matches_kwargs_declared(self): + """Test OpStep.matches with declared kwargs.""" + from executorch.backends.mlx.pattern_utils import OpStep + + class CastModule(torch.nn.Module): + def forward(self, x): + return x.to(torch.float16) + + graph = get_exported_graph(CastModule(), (torch.randn(4, 4),)) + to_copy_node = find_node_by_target(graph, "_to_copy") + + if to_copy_node is not None: + # Check what kwargs exist + node_kwargs = set(to_copy_node.kwargs.keys()) + + # If we declare all kwargs, should match + step_all = OpStep( + op=torch.ops.aten._to_copy.default, + kwargs=node_kwargs, + ) + self.assertTrue(step_all.matches(to_copy_node)) + + # If we don't declare some kwargs, should fail + if node_kwargs: + step_missing = OpStep( + op=torch.ops.aten._to_copy.default, + kwargs=set(), # Empty, but node has kwargs + ) + self.assertFalse(step_missing.matches(to_copy_node)) + + def test_matches_arg_index(self): + """Test OpStep.matches validates arg_index is accessible.""" + from executorch.backends.mlx.pattern_utils import OpStep + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) # rsqrt has 1 arg + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # arg_index=0 should work (rsqrt has 1 arg) + step = OpStep(op=torch.ops.aten.rsqrt.default, arg_index=0) + self.assertTrue(step.matches(rsqrt_node)) + + # arg_index=1 should fail (rsqrt only has 1 arg, can't access args[1]) + step_bad_index = OpStep(op=torch.ops.aten.rsqrt.default, arg_index=1) + self.assertFalse(step_bad_index.matches(rsqrt_node)) + + +class TestWalkBack(unittest.TestCase): + """Tests for walk_back function.""" + + def test_walk_back_single_step(self): + """Test walk_back with a single step.""" + from executorch.backends.mlx.pattern_utils import OpStep, walk_back + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + result = walk_back(rsqrt_node, [OpStep(op=torch.ops.aten.rsqrt.default)]) + + self.assertIsNotNone(result) + base_node, entries = result + self.assertEqual(len(entries), 1) + self.assertEqual(entries[0], rsqrt_node) + # base_node should be the input to rsqrt + self.assertEqual(base_node.op, "placeholder") + + def test_walk_back_chain(self): + """Test walk_back with multiple steps in a chain.""" + from executorch.backends.mlx.pattern_utils import OpStep, walk_back + + class ChainModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) + z = torch.rsqrt(y) + return z + + graph = get_exported_graph(ChainModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # Match rsqrt -> neg chain + result = walk_back( + rsqrt_node, + [ + OpStep(op=torch.ops.aten.rsqrt.default), + OpStep(op=torch.ops.aten.neg.default), + ], + ) + + self.assertIsNotNone(result) + base_node, entries = result + self.assertEqual(len(entries), 2) + self.assertEqual(base_node.op, "placeholder") + + def test_walk_back_no_match(self): + """Test walk_back returns None when pattern doesn't match.""" + from executorch.backends.mlx.pattern_utils import OpStep, walk_back + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # Try to match neg which isn't there + result = walk_back(rsqrt_node, [OpStep(op=torch.ops.aten.neg.default)]) + + self.assertIsNone(result) + + def test_walk_back_optional_step(self): + """Test walk_back with optional step that doesn't match.""" + from executorch.backends.mlx.pattern_utils import OpStep, walk_back + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # Match rsqrt, skip optional neg (not present) + result = walk_back( + rsqrt_node, + [ + OpStep(op=torch.ops.aten.rsqrt.default), + OpStep(op=torch.ops.aten.neg.default, optional=True), + ], + ) + + self.assertIsNotNone(result) + base_node, entries = result + self.assertEqual(len(entries), 2) # One for each step + self.assertIsNotNone(entries[0]) # rsqrt matched + self.assertIsNone(entries[1]) # neg is None (optional, not matched) + + def test_walk_back_repeat_step(self): + """Test walk_back with repeat step.""" + from executorch.backends.mlx.pattern_utils import OpStep, walk_back + + class RepeatModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) + z = torch.neg(y) + w = torch.neg(z) + return w + + graph = get_exported_graph(RepeatModule(), (torch.randn(4, 4),)) + + # Find the last neg node (output of the chain) + neg_nodes = find_all_nodes_by_target(graph, "neg") + self.assertEqual(len(neg_nodes), 3) + last_neg = neg_nodes[-1] + + # Match chain of neg ops + result = walk_back( + last_neg, + [OpStep(op=torch.ops.aten.neg.default, repeat=True)], + ) + + self.assertIsNotNone(result) + base_node, entries = result + self.assertEqual(len(entries), 1) # One entry for the repeat step + self.assertIsInstance(entries[0], list) # Repeat returns list + self.assertEqual(len(entries[0]), 3) # Three neg nodes matched + self.assertEqual(base_node.op, "placeholder") + + def test_walk_back_repeat_zero_matches(self): + """Test walk_back with repeat step matching zero times then another step.""" + from executorch.backends.mlx.pattern_utils import OpStep, walk_back + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # Try to match neg (repeat, 0 matches) then rsqrt + # neg doesn't exist at rsqrt, so 0 matches, then we match rsqrt + result = walk_back( + rsqrt_node, + [ + OpStep(op=torch.ops.aten.neg.default, repeat=True), + OpStep(op=torch.ops.aten.rsqrt.default), + ], + ) + + # This should match: neg repeat matches 0 times, rsqrt matches + self.assertIsNotNone(result) + base_node, entries = result + self.assertEqual(len(entries), 2) # One for each step + self.assertIsInstance(entries[0], list) # Repeat returns list + self.assertEqual(len(entries[0]), 0) # Zero neg nodes matched + self.assertIsNotNone(entries[1]) # rsqrt matched + + def test_walk_back_arg_index(self): + """Test walk_back with arg_index to follow non-first argument.""" + from executorch.backends.mlx.pattern_utils import OpStep, walk_back + + class BinaryModule(torch.nn.Module): + def forward(self, x): + y = torch.rsqrt(x) + return x * y # mul(x, rsqrt(x)) + + graph = get_exported_graph(BinaryModule(), (torch.randn(4, 4),)) + mul_node = find_node_by_target(graph, "mul") + rsqrt_node = find_node_by_target(graph, "rsqrt") + + self.assertIsNotNone(mul_node) + self.assertIsNotNone(rsqrt_node) + + # Follow args[1] (rsqrt) instead of args[0] (placeholder) + result = walk_back( + mul_node, + [ + OpStep(op=torch.ops.aten.mul.Tensor, nargs=2, arg_index=1), + OpStep(op=torch.ops.aten.rsqrt.default), + ], + ) + + self.assertIsNotNone(result) + base_node, entries = result + self.assertEqual(len(entries), 2) # mul and rsqrt + self.assertEqual(entries[0], mul_node) + self.assertEqual(entries[1], rsqrt_node) + # base_node should be the input to rsqrt (placeholder) + self.assertEqual(base_node.op, "placeholder") + + +class TestPatternMatch(unittest.TestCase): + """Tests for PatternMatch base class.""" + + def test_all_nodes(self): + """Test all_nodes returns head + body.""" + from executorch.backends.mlx.pattern_utils import PatternMatch + + class ChainModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) + z = torch.rsqrt(y) + return z + + graph = get_exported_graph(ChainModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + rsqrt_node = find_node_by_target(graph, "rsqrt") + + match = PatternMatch(head=rsqrt_node, body=[neg_node]) + self.assertEqual(match.all_nodes(), [rsqrt_node, neg_node]) + + def test_remove_body_nodes(self): + """Test remove_body_nodes removes unused nodes.""" + from executorch.backends.mlx.pattern_utils import PatternMatch + + class ChainModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) + z = torch.rsqrt(y) + return z + + graph = get_exported_graph(ChainModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # To test remove_body_nodes, we'd need to first replace rsqrt's uses + # and then call remove_body_nodes. For this test, just verify the + # method exists and doesn't crash when nodes have users. + match = PatternMatch(head=rsqrt_node, body=[neg_node]) + + # This won't remove neg because it still has a user (rsqrt) + match.remove_body_nodes(graph) + + # neg should still exist because it has a user + self.assertIn(neg_node, graph.nodes) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/mlx/test/test_utils.py b/backends/mlx/test/test_utils.py new file mode 100644 index 00000000000..090bceabf08 --- /dev/null +++ b/backends/mlx/test/test_utils.py @@ -0,0 +1,1122 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utilities for MLX delegate op testing. + +This module provides functions to: +1. Save/load tensors to/from binary files (compatible with C++ op_test_runner) +2. Export simple models to .pte files +3. Compare expected vs actual outputs +4. Run the C++ op_test_runner binary +""" + +import json +import os +import struct +import subprocess +import tempfile +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch + + +DEFAULT_TEST_TIMEOUT = 300 # 5 minutes default timeout + + +class TestTimeoutError(Exception): + """Raised when a test exceeds its timeout.""" + + pass + + +# DType enum values matching C++ op_test_runner +DTYPE_FLOAT32 = 0 +DTYPE_FLOAT16 = 1 +DTYPE_INT32 = 2 +DTYPE_INT64 = 3 +DTYPE_BFLOAT16 = 4 +DTYPE_BOOL = 5 + + +# Default tolerance presets for different data types. +# These are based on the precision characteristics of each dtype: +# - FP32: ~7 decimal digits of precision +# - FP16: ~3-4 decimal digits of precision +# - BF16: ~2-3 decimal digits of precision (same exponent range as FP32) +TOLERANCE_PRESETS = { + torch.float32: {"rtol": 1e-5, "atol": 1e-5}, + torch.float16: {"rtol": 1e-3, "atol": 1e-3}, + torch.bfloat16: {"rtol": 1e-2, "atol": 1e-2}, + # Integer types should match exactly + torch.int32: {"rtol": 0, "atol": 0}, + torch.int64: {"rtol": 0, "atol": 0}, +} + + +def get_tolerance_for_dtype(dtype: torch.dtype) -> Tuple[float, float]: + """ + Get appropriate (rtol, atol) tolerances for a given dtype. + + Args: + dtype: The torch dtype to get tolerances for. + + Returns: + (rtol, atol) tuple with appropriate tolerances for the dtype. + """ + if dtype in TOLERANCE_PRESETS: + preset = TOLERANCE_PRESETS[dtype] + return preset["rtol"], preset["atol"] + # Default to FP32 tolerances for unknown types + return 1e-5, 1e-5 + + +def get_tolerance_for_dtypes(dtypes: List[torch.dtype]) -> Tuple[float, float]: + """ + Get tolerances that work for a list of dtypes (uses the loosest tolerances). + + Args: + dtypes: List of torch dtypes. + + Returns: + (rtol, atol) tuple with tolerances that accommodate all dtypes. + """ + if not dtypes: + return 1e-5, 1e-5 + + max_rtol = 0.0 + max_atol = 0.0 + for dtype in dtypes: + rtol, atol = get_tolerance_for_dtype(dtype) + max_rtol = max(max_rtol, rtol) + max_atol = max(max_atol, atol) + + return max_rtol, max_atol + + +def torch_dtype_to_bin_dtype(dtype: torch.dtype) -> int: + """Convert torch dtype to binary file dtype enum value.""" + mapping = { + torch.float32: DTYPE_FLOAT32, + torch.float16: DTYPE_FLOAT16, + torch.int32: DTYPE_INT32, + torch.int64: DTYPE_INT64, + torch.bfloat16: DTYPE_BFLOAT16, + torch.bool: DTYPE_BOOL, + } + if dtype not in mapping: + raise ValueError(f"Unsupported dtype: {dtype}") + return mapping[dtype] + + +def bin_dtype_to_torch_dtype(dtype_val: int) -> torch.dtype: + """Convert binary file dtype enum value to torch dtype.""" + mapping = { + DTYPE_FLOAT32: torch.float32, + DTYPE_FLOAT16: torch.float16, + DTYPE_INT32: torch.int32, + DTYPE_INT64: torch.int64, + DTYPE_BFLOAT16: torch.bfloat16, + DTYPE_BOOL: torch.bool, + } + if dtype_val not in mapping: + raise ValueError(f"Unknown dtype value: {dtype_val}") + return mapping[dtype_val] + + +def _atomic_write_binary(path: Path, data: bytes) -> None: + """ + Atomically write binary data to a file. + + Writes to a temporary file in the same directory, then atomically replaces + the target path. This prevents race conditions when multiple parallel + workers write to the same ``op_tests/`` tree. + """ + path.parent.mkdir(parents=True, exist_ok=True) + fd, tmp = tempfile.mkstemp(dir=path.parent, suffix=".tmp") + closed = False + try: + os.write(fd, data) + os.close(fd) + closed = True + os.replace(tmp, path) + except BaseException: + if not closed: + os.close(fd) + if os.path.exists(tmp): + os.unlink(tmp) + raise + + +def save_tensors_to_bin(tensors: List[torch.Tensor], path: Union[str, Path]) -> None: + """ + Save a list of tensors to a binary file. + + Binary format: + - 4 bytes: number of tensors (uint32) + For each tensor: + - 4 bytes: dtype enum (uint32) + - 4 bytes: number of dimensions (uint32) + - 4 bytes * ndim: shape (int32 each) + - N bytes: tensor data + """ + path = Path(path) + + buf = bytearray() + # Write number of tensors + buf += struct.pack("I", len(tensors)) + + for tensor in tensors: + # Ensure contiguous + tensor = tensor.contiguous() + + # Write dtype + dtype_val = torch_dtype_to_bin_dtype(tensor.dtype) + buf += struct.pack("I", dtype_val) + + # Write ndim + buf += struct.pack("I", tensor.dim()) + + # Write shape + for s in tensor.shape: + buf += struct.pack("i", s) + + # Write data - bf16 needs special handling since numpy doesn't support it + if tensor.dtype == torch.bfloat16: + # View bf16 as uint16 to preserve raw bytes + buf += tensor.view(torch.uint16).numpy().tobytes() + else: + buf += tensor.numpy().tobytes() + + _atomic_write_binary(path, bytes(buf)) + + +def load_tensors_from_bin(path: Union[str, Path]) -> List[torch.Tensor]: + path = Path(path) + + # Mapping from torch dtype to numpy dtype + np_dtype_map = { + torch.float32: np.float32, + torch.float16: np.float16, + torch.int32: np.int32, + torch.int64: np.int64, + torch.bool: np.bool_, + # bfloat16 needs special handling - read as uint16 + } + + # Element size for each dtype + elem_size_map = { + torch.float32: 4, + torch.float16: 2, + torch.int32: 4, + torch.int64: 8, + torch.bfloat16: 2, + torch.bool: 1, + } + + tensors = [] + with open(path, "rb") as f: + # Read number of tensors + num_tensors = struct.unpack("I", f.read(4))[0] + + for _ in range(num_tensors): + # Read dtype + dtype_val = struct.unpack("I", f.read(4))[0] + dtype = bin_dtype_to_torch_dtype(dtype_val) + + # Read ndim + ndim = struct.unpack("I", f.read(4))[0] + + # Read shape + shape = [] + for _ in range(ndim): + shape.append(struct.unpack("i", f.read(4))[0]) + + # Read data + numel = 1 + for s in shape: + numel *= s + + elem_size = elem_size_map[dtype] + data_bytes = f.read(numel * elem_size) + + # Convert to tensor + if dtype == torch.bfloat16: + # Read as uint16 and view as bfloat16 + arr = np.frombuffer(data_bytes, dtype=np.uint16).reshape(shape) + tensor = torch.tensor(arr).view(torch.bfloat16) + else: + arr = np.frombuffer(data_bytes, dtype=np_dtype_map[dtype]).reshape( + shape + ) + tensor = torch.from_numpy(arr.copy()) + + tensors.append(tensor) + + return tensors + + +def export_model_to_pte( + model: torch.nn.Module, + example_inputs: Tuple[torch.Tensor, ...], + output_path: Union[str, Path], + dynamic_shapes: Optional[Dict] = None, + verbose: bool = False, +) -> None: + """ + Export a PyTorch model to a .pte file using the MLX delegate. + + Args: + model: The PyTorch model to export. + example_inputs: Example inputs for tracing. + output_path: Path to save the .pte file. + dynamic_shapes: + dynamic_shapes: Optional dynamic shapes specification for torch.export. + Example: {0: {0: Dim("batch", min=1, max=32)}} for dynamic batch on first input. + verbose: Whether to print the exported program for debugging. + """ + import executorch.exir as exir + from executorch.backends.mlx import MLXPartitioner + from executorch.exir.capture._config import ExecutorchBackendConfig + from torch.export import export + + model = model.eval() + + # Export with torch.export + exported_program = export( + model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True + ) + + # Print exported program if verbose + if verbose: + print("\n" + "=" * 60) + print("EXPORTED PROGRAM (torch.export)") + print("=" * 60) + print(exported_program) + + # Lower to edge and delegate to MLX + edge_program = exir.to_edge_transform_and_lower( + exported_program, + partitioner=[MLXPartitioner()], + ) + + # Print edge program if verbose + if verbose: + print("\n" + "=" * 60) + print("EDGE PROGRAM (after decomposition)") + print("=" * 60) + print(edge_program.exported_program()) + + # Export to ExecuTorch + executorch_program = edge_program.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=True) + ) + + # Save to file + output_path = Path(output_path) + _atomic_write_binary(output_path, executorch_program.buffer) + + +def inspect_pte_file(pte_path: Union[str, Path]) -> Dict: + """ + Inspect a PTE file and return the MLX graph information. + + Returns: + Dictionary with MLX graph details + """ + from executorch.backends.mlx.pte_inspector import ( + extract_delegate_payload, + parse_mlx_payload, + ) + + pte_path = Path(pte_path) + pte_data = pte_path.read_bytes() + + # Extract MLX delegate payload + payload = extract_delegate_payload(pte_data, "MLXBackend") + if payload is None: + return {"error": "Could not extract MLX delegate payload"} + + # Parse the MLX payload + mlx_data = parse_mlx_payload(payload) + return mlx_data + + +def print_mlx_graph_summary(pte_path: Union[str, Path]) -> None: + """ + Print a human-readable summary of the MLX graph in a PTE file. + + This function uses the pte_inspector module to display the MLX graph. + """ + from executorch.backends.mlx.pte_inspector import show_mlx_instructions + + pte_path = Path(pte_path) + pte_data = pte_path.read_bytes() + show_mlx_instructions(pte_data) + + +def count_mlx_delegate_segments(pte_path: Union[str, Path]) -> int: + """ + Count the number of MLX delegate segments in a PTE file. + + Args: + pte_path: Path to the .pte file + + Returns: + Number of MLX delegate segments found + """ + from executorch.exir._serialize._flatbuffer import _program_flatbuffer_to_json + + pte_path = Path(pte_path) + pte_data = pte_path.read_bytes() + + try: + program_json = _program_flatbuffer_to_json(pte_data) + program_data = json.loads(program_json) + + # Count all MLX delegates across all execution plans + count = 0 + for plan in program_data.get("execution_plan", []): + for delegate in plan.get("delegates", []): + delegate_name = delegate.get("id", "") + # Match MLXBackend (case-insensitive) + if "mlx" in delegate_name.lower(): + count += 1 + + return count + except Exception as e: + print(f"Error counting MLX segments: {e}") + return 0 + + +def get_mlx_node_counts(pte_path: Union[str, Path]) -> Dict[str, int]: + """ + Get a count of each MLX op node type in a serialized .pte file. + + Args: + pte_path: Path to the .pte file + + Returns: + Dictionary mapping op name (e.g. "SdpaNode", "SliceUpdateNode") to count. + """ + data = inspect_pte_file(pte_path) + graph = data.get("graph", {}) + counts: Dict[str, int] = {} + for chain_info in graph.get("instruction_chains", []): + for instr in chain_info.get("instructions", []): + op_name = instr.get("op_name") + if op_name: + counts[op_name] = counts.get(op_name, 0) + 1 + return counts + + +def compare_outputs( + expected: List[torch.Tensor], + actual: List[torch.Tensor], + rtol: float = 1e-5, + atol: float = 1e-5, +) -> Tuple[bool, str]: + """ + Compare expected and actual outputs using torch.allclose. + + Returns: + (passed, message) tuple + """ + if len(expected) != len(actual): + return ( + False, + f"Output count mismatch: expected {len(expected)}, got {len(actual)}", + ) + + for i, (exp, act) in enumerate(zip(expected, actual)): + if exp.shape != act.shape: + return ( + False, + f"Output {i} shape mismatch: expected {exp.shape}, got {act.shape}", + ) + + if exp.dtype != act.dtype: + # Convert both to float32 for comparison + exp = exp.float() + act = act.float() + + # For bool tensors, use exact comparison + if exp.dtype == torch.bool: + if not torch.equal(exp, act): + mismatches = (exp != act).sum().item() + total = exp.numel() + return False, ( + f"Output {i} values do not match:\n" + f" {mismatches}/{total} elements differ\n" + f" expected[:5]={exp.flatten()[:5].tolist()}\n" + f" actual[:5]={act.flatten()[:5].tolist()}" + ) + elif not torch.allclose(exp, act, rtol=rtol, atol=atol): + diff = (exp - act).abs() + max_diff = diff.max().item() + mean_diff = diff.float().mean().item() + return False, ( + f"Output {i} values do not match:\n" + f" max_diff={max_diff:.6e}, mean_diff={mean_diff:.6e}\n" + f" rtol={rtol}, atol={atol}\n" + f" expected[:5]={exp.flatten()[:5].tolist()}\n" + f" actual[:5]={act.flatten()[:5].tolist()}" + ) + + return True, "All outputs match" + + +def find_executorch_root() -> Path: # noqa: C901 + """Find the executorch root directory.""" + test_dir = Path(__file__).parent + + # Walk up to find the executorch root (has CMakeLists.txt and backends dir at root) + executorch_root = test_dir + for _ in range(10): # Max 10 levels up + if (executorch_root / "CMakeLists.txt").exists() and ( + executorch_root / "backends" + ).exists(): + # Check if we're in src/executorch (editable install) + if ( + executorch_root.name == "executorch" + and executorch_root.parent.name == "src" + ): + executorch_root = executorch_root.parent.parent + break + executorch_root = executorch_root.parent + + # If we didn't find a valid root (e.g. running from a pip-installed + # site-packages), fall back to cwd which is typically the repo root. + if not (executorch_root / "CMakeLists.txt").exists(): + cwd = Path.cwd() + if (cwd / "CMakeLists.txt").exists() and (cwd / "backends").exists(): + executorch_root = cwd + + return executorch_root + + +def find_build_dir(): + """Find the cmake build directory containing op_test_runner.""" + executorch_root = find_executorch_root() + + # Check common build locations + candidates = [ + executorch_root / "cmake-out-mlx", + executorch_root / "cmake-out", + executorch_root / "build", + ] + + for candidate in candidates: + runner_path = candidate / "backends" / "mlx" / "test" / "op_test_runner" + if runner_path.exists(): + return candidate + + # Return first candidate that exists as a directory (for rebuild) + for candidate in candidates: + if candidate.is_dir(): + return candidate + + return None + + +def find_op_test_runner() -> Path: + """Find the op_test_runner binary.""" + executorch_root = find_executorch_root() + + # Check common build locations + candidates = [ + executorch_root + / "cmake-out-mlx" + / "backends" + / "mlx" + / "test" + / "op_test_runner", + executorch_root / "cmake-out" / "backends" / "mlx" / "test" / "op_test_runner", + executorch_root / "build" / "backends" / "mlx" / "test" / "op_test_runner", + ] + + for candidate in candidates: + if candidate.exists(): + return candidate + + raise FileNotFoundError( + "Could not find op_test_runner binary. Tried:\n" + + "\n".join(f" - {c}" for c in candidates) + + "\n\nBuild with:\n" + + " cmake --preset mlx-release -DEXECUTORCH_BUILD_TESTS=ON\n" + + " cmake --build cmake-out --target op_test_runner" + ) + + +def rebuild_op_test_runner(verbose: bool = False) -> bool: + """ + Rebuild the op_test_runner binary using cmake. + + Args: + verbose: Whether to print build output. + + Returns: + True if build succeeded, False otherwise. + """ + build_dir = find_build_dir() + if build_dir is None: + print("Error: Could not find cmake build directory.") + print("Make sure you have run cmake configuration first.") + return False + + print(f"Rebuilding op_test_runner in {build_dir}...") + + cmd = ["cmake", "--build", str(build_dir), "--target", "op_test_runner", "-j8"] + + if verbose: + print(f"Running: {' '.join(cmd)}") + + result = subprocess.run( + cmd, + capture_output=not verbose, + text=True, + ) + + if result.returncode != 0: + print(f"Build failed with exit code {result.returncode}") + if not verbose and result.stderr: + print(f"stderr: {result.stderr}") + if not verbose and result.stdout: + print(f"stdout: {result.stdout}") + return False + + print("Build succeeded.") + return True + + +def run_cpp_test_runner( + pte_path: Path, + input_path: Path, + output_path: Path, + verbose: bool = False, + timeout: Optional[int] = None, +) -> bool: + """ + Run the C++ op_test_runner binary. + + Args: + pte_path: Path to the .pte model file. + input_path: Path to input .bin file. + output_path: Path to write output .bin file. + verbose: Whether to print verbose output. + timeout: Timeout in seconds. None means use DEFAULT_TEST_TIMEOUT. + + Returns: + True if execution succeeded, False otherwise. + """ + if timeout is None: + timeout = DEFAULT_TEST_TIMEOUT + + runner = find_op_test_runner() + + cmd = [ + str(runner), + "--pte", + str(pte_path), + "--input", + str(input_path), + "--output", + str(output_path), + ] + if verbose: + cmd.append("--verbose") + + print(f"Running: {' '.join(cmd)}") + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=timeout, + ) + except subprocess.TimeoutExpired: + print(f"TIMEOUT: C++ runner exceeded {timeout}s timeout") + return False + + if result.returncode != 0: + print(f"FAILED: {result.stderr}") + print(f"stdout: {result.stdout}") + return False + + print(f"C++ binary output: {result.stdout.strip()}") + return True + + +# Files that are generated during tests and can be safely cleaned up +GENERATED_TEST_FILES = [ + "model.pte", + "input.bin", + "expected_output.bin", + "actual_output.bin", +] + + +def clean_test_outputs( + test_names: Optional[List[str]] = None, verbose: bool = False +) -> int: + """ + Clean up generated test output files. + + Args: + test_names: Optional list of test names to clean. If None, cleans all tests. + verbose: Whether to print verbose output. + + Returns: + Number of files removed. + """ + test_dir = Path(__file__).parent / "op_tests" + if not test_dir.exists(): + if verbose: + print(f"Test directory does not exist: {test_dir}") + return 0 + + files_removed = 0 + + # Get directories to clean + if test_names: + dirs_to_clean = [ + test_dir / name for name in test_names if (test_dir / name).exists() + ] + else: + dirs_to_clean = [d for d in test_dir.iterdir() if d.is_dir()] + + for subdir in dirs_to_clean: + for filename in GENERATED_TEST_FILES: + filepath = subdir / filename + if filepath.exists(): + if verbose: + print(f"Removing: {filepath}") + filepath.unlink() + files_removed += 1 + + # Remove empty directories + if subdir.exists() and not any(subdir.iterdir()): + if verbose: + print(f"Removing empty directory: {subdir}") + subdir.rmdir() + + return files_removed + + +def get_test_output_size(test_names: Optional[List[str]] = None) -> int: + """ + Get total size of generated test output files in bytes. + + Args: + test_names: Optional list of test names to check. If None, checks all tests. + + Returns: + Total size in bytes. + """ + test_dir = Path(__file__).parent / "op_tests" + if not test_dir.exists(): + return 0 + + total_size = 0 + + # Get directories to check + if test_names: + dirs_to_check = [ + test_dir / name for name in test_names if (test_dir / name).exists() + ] + else: + dirs_to_check = [d for d in test_dir.iterdir() if d.is_dir()] + + for subdir in dirs_to_check: + for filename in GENERATED_TEST_FILES: + filepath = subdir / filename + if filepath.exists(): + total_size += filepath.stat().st_size + + return total_size + + +# Global registry: maps base_name -> (test_class, get_test_configs method) +# Tests are instantiated lazily when actually run, not at import time +_TEST_REGISTRY: Dict[str, type] = {} + + +def register_test(test_class: type) -> type: + """ + Class decorator to register a test class. + + The test class must have: + - A class attribute `name` (str) - the base test name + - A class method `get_test_configs()` that returns a list of OpTestCase instances + + Test instances are created LAZILY when tests are actually run, not at import time. + This avoids creating random tensors at import time and keeps memory usage low. + + Example: + @register_test + class AddTest(OpTestCase): + name = "add" + + @classmethod + def get_test_configs(cls) -> List["OpTestCase"]: + return [ + cls(), # default config + cls(scalar=2.5), # scalar variant + ] + """ + if not hasattr(test_class, "name"): + raise ValueError( + f"Test class {test_class.__name__} must have a 'name' attribute" + ) + + base_name = test_class.name + _TEST_REGISTRY[base_name] = test_class + + return test_class + + +def get_registered_tests() -> Dict[str, List[Tuple[str, "OpTestCase"]]]: + """ + Get all registered tests with their configurations. + + Returns dict mapping base_name -> list of (config_name, test_instance). + Test instances are created fresh each time this is called. + """ + result = {} + for base_name, test_class in _TEST_REGISTRY.items(): + if hasattr(test_class, "get_test_configs"): + configs = test_class.get_test_configs() + else: + configs = [test_class()] + result[base_name] = [(cfg.name, cfg) for cfg in configs] + return result + + +def get_test_names() -> List[str]: + """Get list of registered base test names.""" + return list(_TEST_REGISTRY.keys()) + + +def get_all_test_configs() -> List[Tuple[str, "OpTestCase"]]: + """ + Get flat list of all (config_name, test_instance) tuples. + + Test instances are created fresh each time this is called. + """ + result = [] + for _base_name, test_class in _TEST_REGISTRY.items(): + if hasattr(test_class, "get_test_configs"): + configs = test_class.get_test_configs() + else: + configs = [test_class()] + result.extend((cfg.name, cfg) for cfg in configs) + return result + + +class OpTestCase: + """ + Base class for op test cases. + + Subclasses should implement: + - name: str - test name + - create_model() -> nn.Module + - create_inputs() -> Tuple[torch.Tensor, ...] + + Optionally override: + - get_dynamic_shapes() -> Optional[Dict] - for dynamic shape testing + - create_test_inputs() -> Tuple[torch.Tensor, ...] - inputs for testing (may differ from export inputs) + - expected_mlx_segments: int - expected number of MLX delegate segments (default: 1) + """ + + name: str = "base_test" + rtol: float = 1e-5 + atol: float = 1e-5 + seed: int = 42 # Default seed for reproducibility + timeout: int = DEFAULT_TEST_TIMEOUT # Timeout in seconds + skip_comparison: bool = False # Skip output comparison (for pattern-only tests) + skip_comparison_reason: str = "" # Reason for skipping comparison + expected_mlx_segments: int = 1 # Expected number of MLX delegate segments + expected_node_counts: Optional[Dict[str, int]] = ( + None # Expected serialized node counts + ) + + def _set_seed(self) -> None: + """Set random seed for reproducibility.""" + torch.manual_seed(self.seed) + + def create_model(self) -> torch.nn.Module: + raise NotImplementedError + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + """Create inputs for export (tracing).""" + raise NotImplementedError + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + """Create inputs for testing. Override for dynamic shape tests.""" + return self.create_inputs() + + def get_dynamic_shapes(self) -> Optional[Dict]: + """Return dynamic shapes specification for torch.export, or None for static shapes.""" + return None + + def get_test_dir(self) -> Path: + """Get the directory for this test's files.""" + test_dir = Path(__file__).parent / "op_tests" / self.name + test_dir.mkdir(parents=True, exist_ok=True) + return test_dir + + def generate_test_files(self, verbose: bool = False) -> Tuple[Path, Path, Path]: + """ + Generate .pte, input.bin, and expected_output.bin files. + + Args: + verbose: Whether to print the exported program for debugging. + + Returns: + (pte_path, input_path, expected_output_path) + """ + test_dir = self.get_test_dir() + + pte_path = test_dir / "model.pte" + input_path = test_dir / "input.bin" + expected_path = test_dir / "expected_output.bin" + + # Set seed for reproducibility + self._set_seed() + + # Create model and inputs + model = self.create_model() + export_inputs = self.create_inputs() + + # Set seed again before creating test inputs (in case they differ) + self._set_seed() + test_inputs = self.create_test_inputs() + + # Get expected outputs using test inputs + model.eval() + with torch.no_grad(): + if isinstance(test_inputs, torch.Tensor): + test_inputs = (test_inputs,) + expected_outputs = model(*test_inputs) + if isinstance(expected_outputs, torch.Tensor): + expected_outputs = [expected_outputs] + else: + expected_outputs = list(expected_outputs) + + # Export model with export inputs (and potentially dynamic shapes) + print(f"Exporting model to {pte_path}") + if isinstance(export_inputs, torch.Tensor): + export_inputs = (export_inputs,) + + dynamic_shapes = self.get_dynamic_shapes() + if dynamic_shapes: + print(f" Using dynamic shapes: {dynamic_shapes}") + + export_model_to_pte( + model, + export_inputs, + pte_path, + dynamic_shapes=dynamic_shapes, + verbose=verbose, + ) + + # Save test inputs + print(f"Saving inputs to {input_path}") + if isinstance(test_inputs, torch.Tensor): + test_inputs = [test_inputs] + else: + test_inputs = list(test_inputs) + save_tensors_to_bin(test_inputs, input_path) + + # Save expected outputs + print(f"Saving expected outputs to {expected_path}") + save_tensors_to_bin(expected_outputs, expected_path) + + return pte_path, input_path, expected_path + + def compare_with_actual( + self, actual_output_path: Union[str, Path], use_dtype_tolerances: bool = False + ) -> Tuple[bool, str]: + """ + Compare actual outputs with expected outputs. + + Args: + actual_output_path: Path to the actual output file. + use_dtype_tolerances: If True, uses tolerance presets based on output dtypes + instead of the test's rtol/atol values. + """ + test_dir = self.get_test_dir() + expected_path = test_dir / "expected_output.bin" + + expected = load_tensors_from_bin(expected_path) + actual = load_tensors_from_bin(actual_output_path) + + # Determine tolerances + if use_dtype_tolerances: + # Use dtype-based tolerances (loosest tolerance across all output dtypes) + output_dtypes = [t.dtype for t in expected] + rtol, atol = get_tolerance_for_dtypes(output_dtypes) + else: + rtol, atol = self.rtol, self.atol + + return compare_outputs(expected, actual, rtol=rtol, atol=atol) + + def run_test(self, verbose: bool = False, timeout: Optional[int] = None) -> bool: + """ + Run the full test: generate files, run C++, compare outputs. + + Args: + verbose: Whether to print verbose output. + timeout: Timeout in seconds. None means use self.timeout. + + Returns: + True if test passed, False otherwise. + """ + if timeout is None: + timeout = self.timeout + + print(f"\n{'='*60}") + print(f"Running test: {self.name}") + print(f"{'='*60}\n") + + # Generate test files + print("Step 1: Generating test files...") + pte_path, input_path, expected_path = self.generate_test_files(verbose=verbose) + + # Print MLX graph summary + print_mlx_graph_summary(pte_path) + + # Verify expected number of MLX delegate segments + print("\nStep 2: Verifying MLX delegation...") + actual_segments = count_mlx_delegate_segments(pte_path) + print(f" Expected MLX segments: {self.expected_mlx_segments}") + print(f" Actual MLX segments: {actual_segments}") + + if actual_segments != self.expected_mlx_segments: + print("✗ FAILED: MLX delegation mismatch!") + print( + f" Expected {self.expected_mlx_segments} segment(s), but found {actual_segments}" + ) + return False + print("✓ MLX delegation verified") + + # Verify expected node counts if specified + if self.expected_node_counts is not None: + print("\n Verifying serialized node counts...") + actual_counts = get_mlx_node_counts(pte_path) + for node_name, expected_count in self.expected_node_counts.items(): + actual_count = actual_counts.get(node_name, 0) + if actual_count != expected_count: + print(f"✗ FAILED: Node count mismatch for {node_name}!") + print(f" Expected {expected_count}, got {actual_count}") + print(f" All node counts: {actual_counts}") + return False + print(f" ✓ {node_name}: {actual_count}") + print(" ✓ All node counts verified") + + # Run C++ binary + print("\nStep 3: Running C++ binary...") + actual_path = self.get_test_dir() / "actual_output.bin" + if not run_cpp_test_runner( + pte_path, input_path, actual_path, verbose=verbose, timeout=timeout + ): + return False + + # Compare outputs (or skip if configured) + print("\nStep 4: Comparing outputs...") + if self.skip_comparison: + reason = self.skip_comparison_reason or "skip_comparison=True" + print(f"NOTE: Output comparison skipped ({reason})") + print("✓ PASSED (runtime execution succeeded)") + return True + + passed, message = self.compare_with_actual(actual_path) + + if passed: + print(f"✓ PASSED: {message}") + else: + print(f"✗ FAILED: {message}") + + return passed + + +def run_op_test_main( + test_factory, + description: str, + add_args_fn=None, +): + """ + Common main() function for op tests. + + This handles the common argparse setup, rebuild logic, and generate/compare/run + action handling that is shared across all op tests. + + Args: + test_factory: A callable that takes parsed args (argparse.Namespace) and + returns an OpTestCase instance. + description: Description for the argparse help message. + add_args_fn: Optional callable that takes a parser and adds test-specific + arguments. Signature: add_args_fn(parser) -> None + """ + import argparse + import sys + + parser = argparse.ArgumentParser(description=description) + parser.add_argument( + "action", + choices=["generate", "compare", "run"], + help="Action to perform: generate (create test files), compare (compare outputs), run (full test)", + ) + parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") + parser.add_argument( + "--rebuild", + action="store_true", + help="Rebuild the C++ test runner before running", + ) + + # Add test-specific arguments + if add_args_fn is not None: + add_args_fn(parser) + + args = parser.parse_args() + + # Rebuild if requested + if args.rebuild: + if not rebuild_op_test_runner(verbose=args.verbose): + sys.exit(1) + + # Create test case from factory + test = test_factory(args) + + if args.action == "generate": + pte_path, input_path, expected_path = test.generate_test_files( + verbose=args.verbose + ) + print("\nGenerated files:") + print(f" PTE: {pte_path}") + print(f" Input: {input_path}") + print(f" Expected: {expected_path}") + print_mlx_graph_summary(pte_path) + + elif args.action == "compare": + actual_path = test.get_test_dir() / "actual_output.bin" + if not actual_path.exists(): + print(f"Error: {actual_path} not found. Run the C++ binary first.") + sys.exit(1) + + passed, message = test.compare_with_actual(actual_path) + if passed: + print(f"✓ PASSED: {message}") + else: + print(f"✗ FAILED: {message}") + sys.exit(0 if passed else 1) + + elif args.action == "run": + passed = test.run_test(verbose=args.verbose) + sys.exit(0 if passed else 1) diff --git a/backends/mlx/test/tester.py b/backends/mlx/test/tester.py new file mode 100644 index 00000000000..7a929ea7c3b --- /dev/null +++ b/backends/mlx/test/tester.py @@ -0,0 +1,78 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import functools +from typing import Any, List, Optional, Tuple + +import executorch +import executorch.backends.test.harness.stages as BaseStages +import torch + +from executorch.backends.mlx.partitioner import MLXPartitioner +from executorch.backends.test.harness import Tester as TesterBase +from executorch.backends.test.harness.stages import StageType +from executorch.exir import EdgeCompileConfig +from executorch.exir.backend.backend_details import CompileSpec +from executorch.exir.backend.partitioner import Partitioner + + +def _create_default_partitioner( + compile_specs: List[CompileSpec] | None = None, +) -> MLXPartitioner: + return MLXPartitioner(compile_specs=compile_specs) + + +class Partition(BaseStages.Partition): + def __init__( + self, + partitioner: Optional[Partitioner] = None, + compile_specs: Optional[List[CompileSpec]] = None, + ): + super().__init__( + partitioner=partitioner or _create_default_partitioner(compile_specs), + ) + + +class ToEdgeTransformAndLower(BaseStages.ToEdgeTransformAndLower): + def __init__( + self, + partitioners: Optional[List[Partitioner]] = None, + edge_compile_config: Optional[EdgeCompileConfig] = None, + compile_specs: Optional[List[CompileSpec]] = None, + ): + super().__init__( + default_partitioner_cls=lambda: _create_default_partitioner(compile_specs), + partitioners=partitioners, + edge_compile_config=edge_compile_config, + ) + + +class MLXTester(TesterBase): + def __init__( + self, + module: torch.nn.Module, + example_inputs: Tuple[torch.Tensor], + dynamic_shapes: Optional[Tuple[Any]] = None, + compile_specs: Optional[List[CompileSpec]] = None, + ): + stage_classes = ( + executorch.backends.test.harness.Tester.default_stage_classes() + | { + StageType.PARTITION: functools.partial( + Partition, compile_specs=compile_specs + ), + StageType.TO_EDGE_TRANSFORM_AND_LOWER: functools.partial( + ToEdgeTransformAndLower, compile_specs=compile_specs + ), + } + ) + + super().__init__( + module=module, + stage_classes=stage_classes, + example_inputs=example_inputs, + dynamic_shapes=dynamic_shapes, + ) diff --git a/backends/mlx/third-party/mlx b/backends/mlx/third-party/mlx new file mode 160000 index 00000000000..72e94c81e16 --- /dev/null +++ b/backends/mlx/third-party/mlx @@ -0,0 +1 @@ +Subproject commit 72e94c81e1685c90679ef03532c4b8897010abf9 diff --git a/backends/test/suite/flow.py b/backends/test/suite/flow.py index f3c9ee75083..c9142581f33 100644 --- a/backends/test/suite/flow.py +++ b/backends/test/suite/flow.py @@ -53,7 +53,7 @@ def __str__(self): return self.name -def all_flows() -> dict[str, TestFlow]: +def all_flows() -> dict[str, TestFlow]: # noqa: C901 flows = [] from executorch.backends.test.suite.flows.portable import PORTABLE_TEST_FLOW @@ -147,4 +147,13 @@ def all_flows() -> dict[str, TestFlow]: except Exception as e: logger.info(f"Skipping ARM flow registration: {e}") + try: + from executorch.backends.test.suite.flows.mlx import MLX_TEST_FLOW + + flows += [ + MLX_TEST_FLOW, + ] + except Exception as e: + logger.info(f"Skipping MLX flow registration: {e}") + return {f.name: f for f in flows if f is not None} diff --git a/backends/test/suite/flows/mlx.py b/backends/test/suite/flows/mlx.py new file mode 100644 index 00000000000..d70db46b73c --- /dev/null +++ b/backends/test/suite/flows/mlx.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.mlx.test.tester import MLXTester +from executorch.backends.test.suite.flow import TestFlow + +MLX_TEST_FLOW = TestFlow( + name="mlx", + backend="mlx", + tester_factory=MLXTester, +) diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py index be7bf0bd56f..e2c31f0c5fc 100644 --- a/exir/_serialize/_program.py +++ b/exir/_serialize/_program.py @@ -765,3 +765,70 @@ def deserialize_pte_binary(program_data: bytes) -> PTEFile: ) return PTEFile(program=program, mutable_data=None, named_data=None) + + +def _extract_delegate_payload( + pte_data: bytes, backend_id: str, delegate_index: int = 0 +) -> Optional[bytes]: + """Extract a delegate payload from a serialized PTE file. + + Parses the PTE file structure, finds the delegate matching the given + backend ID, and returns its raw payload bytes. Handles both inline + delegate data and segment-based storage. + + Args: + pte_data: Raw bytes of the PTE file. + backend_id: ID substring to match (case-insensitive). + For example, 'mlx' matches 'MLXBackend'. + delegate_index: Which matching delegate to extract (0-based). + Defaults to 0 (first match). + + Returns: + Delegate payload bytes, or None if not found. + """ + # Parse the extended header + extended_header = _get_extended_header(pte_data) + + # Determine program size from header or use full data + if extended_header is not None: + program_size = extended_header.program_size + else: + program_size = len(pte_data) + + # Parse the program flatbuffer + program: Program = _json_to_program( + _program_flatbuffer_to_json(pte_data[:program_size]) + ) + + # Search for the matching delegate + match_count = 0 + for plan in program.execution_plan: + for delegate in plan.delegates: + if backend_id.lower() not in delegate.id.lower(): + continue + if match_count != delegate_index: + match_count += 1 + continue + + processed = delegate.processed + + # Inline data + if processed.location == DataLocation.INLINE: + inline_data = program.backend_delegate_data[processed.index] + if inline_data.data: + return bytes(inline_data.data) + return None + + # Segment data + if processed.location == DataLocation.SEGMENT: + if extended_header is None: + return None + + segment = program.segments[processed.index] + offset = extended_header.segment_base_offset + segment.offset + size = segment.size + return pte_data[offset : offset + size] + + return None + + return None diff --git a/setup.py b/setup.py index f05951012e3..d07736128c8 100644 --- a/setup.py +++ b/setup.py @@ -624,6 +624,26 @@ def run(self): # the input file is read-only. self.copy_file(src, dst, preserve_mode=False) + # Copy CMake-generated Python directories that setuptools missed. + # Setuptools discovers packages at configuration time, before CMake + # runs. Directories created by CMake during the build (e.g. by + # generate.py) are not in the package list and must be copied manually. + generated_dirs = [ + "backends/mlx/serialization/_generated", + ] + for rel_dir in generated_dirs: + src_dir = os.path.join("src/executorch", rel_dir) + if not os.path.isdir(src_dir): + continue + dst_dir = os.path.join(dst_root, rel_dir) + for dirpath, _dirnames, filenames in os.walk(src_dir): + for filename in filenames: + src_file = os.path.join(dirpath, filename) + rel_path = os.path.relpath(src_file, src_dir) + dst_file = os.path.join(dst_dir, rel_path) + self.mkpath(os.path.dirname(dst_file)) + self.copy_file(src_file, dst_file, preserve_mode=False) + class Buck2EnvironmentFixer(contextlib.AbstractContextManager): """Removes HOME from the environment when running as root. @@ -786,6 +806,9 @@ def run(self): # noqa C901 if cmake_cache.is_enabled("EXECUTORCH_BUILD_COREML"): cmake_build_args += ["--target", "executorchcoreml"] + if cmake_cache.is_enabled("EXECUTORCH_BUILD_MLX"): + cmake_build_args += ["--target", "mlxdelegate"] + if cmake_cache.is_enabled("EXECUTORCH_BUILD_KERNELS_LLM_AOT"): cmake_build_args += ["--target", "custom_ops_aot_lib"] cmake_build_args += ["--target", "quantized_ops_aot_lib"] @@ -846,6 +869,16 @@ def run(self): # noqa C901 modpath="executorch.extension.pybindings.data_loader", dependent_cmake_flags=["EXECUTORCH_BUILD_PYBIND"], ), + # MLX metallib (Metal GPU kernels) must be colocated with _portable_lib.so + # because MLX uses dladdr() to find the directory containing the library, + # then looks for mlx.metallib in that directory at runtime. + # After submodule migration, the path is backends/mlx/mlx/... + BuiltFile( + src_dir="%CMAKE_CACHE_DIR%/backends/mlx/mlx/mlx/backend/metal/kernels/", + src_name="mlx.metallib", + dst="executorch/extension/pybindings/", + dependent_cmake_flags=["EXECUTORCH_BUILD_MLX"], + ), BuiltExtension( src="extension/training/_training_lib.*", # @lint-ignore https://github.com/pytorch/executorch/blob/cb3eba0d7f630bc8cec0a9cc1df8ae2f17af3f7a/scripts/lint_xrefs.sh modpath="executorch.extension.training.pybindings._training_lib", diff --git a/tools/cmake/Utils.cmake b/tools/cmake/Utils.cmake index 74f2be78804..3295036663c 100644 --- a/tools/cmake/Utils.cmake +++ b/tools/cmake/Utils.cmake @@ -178,3 +178,36 @@ function(executorch_add_prefix_to_public_headers targetName prefix) TARGET "${targetName}" PROPERTY PUBLIC_HEADER ${FIXED_PUBLIC_HEADERS} ) endfunction() + +# ----------------------------------------------------------------------------- +# MLX metallib distribution helper +# ----------------------------------------------------------------------------- +# Copies mlx.metallib next to the target executable so MLX can find it at +# runtime. +# +# MLX uses dladdr() to find the directory containing the binary with MLX code, +# then looks for mlx.metallib in that directory. When MLX is statically linked +# into an executable or shared library, this function ensures the metallib is +# colocated with that binary. +# +# Usage: executorch_target_copy_mlx_metallib(my_executable) +# +function(executorch_target_copy_mlx_metallib target) + if(EXECUTORCH_BUILD_MLX) + if(DEFINED MLX_METALLIB_PATH AND EXISTS "${MLX_METALLIB_PATH}") + add_custom_command( + TARGET ${target} + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${MLX_METALLIB_PATH}" + "$/mlx.metallib" + COMMENT "Copying mlx.metallib for ${target}" + ) + elseif(DEFINED MLX_METALLIB_PATH) + message( + WARNING + "MLX_METALLIB_PATH is set to ${MLX_METALLIB_PATH} but file does not exist. " + "metallib will not be copied for ${target}." + ) + endif() + endif() +endfunction() diff --git a/tools/cmake/executorch-config.cmake b/tools/cmake/executorch-config.cmake index dc4d34d8701..524e1be36ec 100644 --- a/tools/cmake/executorch-config.cmake +++ b/tools/cmake/executorch-config.cmake @@ -63,6 +63,8 @@ set(optional_lib_list coreml_inmemoryfs coremldelegate mpsdelegate + mlxdelegate + mlx metal_backend neuron_backend qnn_executorch_backend @@ -118,3 +120,46 @@ set_property( TARGET executorch_core PROPERTY INTERFACE_LINK_LIBRARIES ${FIXED_EXECUTORCH_CORE_LINK_LIBRARIES} ) + +# Expose MLX library and metallib path for downstream consumers +if(TARGET mlxdelegate) + # Create imported target for mlx library if not already defined (mlx is built + # by MLX's CMake but we need to expose it for linking) + if(NOT TARGET mlx) + find_library( + _mlx_library mlx + HINTS "${_root}/lib" + CMAKE_FIND_ROOT_PATH_BOTH + ) + if(_mlx_library) + add_library(mlx STATIC IMPORTED) + set_target_properties(mlx PROPERTIES IMPORTED_LOCATION "${_mlx_library}") + # MLX requires Metal and Foundation frameworks on Apple platforms + if(APPLE) + find_library(METAL_FRAMEWORK Metal) + find_library(FOUNDATION_FRAMEWORK Foundation) + if(METAL_FRAMEWORK AND FOUNDATION_FRAMEWORK) + set_target_properties( + mlx PROPERTIES INTERFACE_LINK_LIBRARIES + "${METAL_FRAMEWORK};${FOUNDATION_FRAMEWORK}" + ) + endif() + endif() + message(STATUS "Found mlx library at: ${_mlx_library}") + endif() + endif() + + # Find metallib for runtime distribution + find_file( + _mlx_metallib mlx.metallib + HINTS "${_root}/lib" + CMAKE_FIND_ROOT_PATH_BOTH + ) + if(_mlx_metallib) + set(MLX_METALLIB_PATH + "${_mlx_metallib}" + CACHE FILEPATH "Path to mlx.metallib for runtime distribution" + ) + message(STATUS "Found mlx.metallib at: ${MLX_METALLIB_PATH}") + endif() +endif() diff --git a/tools/cmake/preset/default.cmake b/tools/cmake/preset/default.cmake index 1caf8ea9602..9280d0db915 100644 --- a/tools/cmake/preset/default.cmake +++ b/tools/cmake/preset/default.cmake @@ -121,6 +121,7 @@ define_overridable_option( EXECUTORCH_BUILD_EXTENSION_APPLE "Build the Apple extension" BOOL OFF ) define_overridable_option(EXECUTORCH_BUILD_MPS "Build the MPS backend" BOOL OFF) +define_overridable_option(EXECUTORCH_BUILD_MLX "Build the MLX backend" BOOL OFF) define_overridable_option( EXECUTORCH_BUILD_NEURON "Build the backends/mediatek directory" BOOL OFF ) diff --git a/tools/cmake/preset/pybind.cmake b/tools/cmake/preset/pybind.cmake index 699a7c50358..dc60dc7d820 100644 --- a/tools/cmake/preset/pybind.cmake +++ b/tools/cmake/preset/pybind.cmake @@ -31,6 +31,24 @@ if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") set_overridable_option(EXECUTORCH_BUILD_EXTENSION_TRAINING ON) set_overridable_option(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER ON) set_overridable_option(EXECUTORCH_BUILD_EXTENSION_LLM ON) + # MLX requires Apple Silicon (ARM64) and the Metal compiler (xcrun -sdk macosx + # metal) which is only available with Xcode, not Command Line Tools + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + execute_process( + COMMAND xcrun -sdk macosx --find metal + RESULT_VARIABLE _metal_compiler_result + OUTPUT_QUIET ERROR_QUIET + ) + if(_metal_compiler_result EQUAL 0) + set_overridable_option(EXECUTORCH_BUILD_MLX ON) + set_overridable_option(ET_MLX_ENABLE_OP_LOGGING ON) + else() + message( + STATUS + "Metal compiler not found, disabling MLX backend. Install Xcode to enable MLX." + ) + endif() + endif() elseif(CMAKE_SYSTEM_NAME STREQUAL "Linux") set_overridable_option(EXECUTORCH_BUILD_COREML ON) set_overridable_option(EXECUTORCH_BUILD_EXTENSION_TRAINING ON) From 0f03f2b8032289b997125356c44727c6d587a052 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 2 Mar 2026 18:06:21 -0800 Subject: [PATCH 02/34] up --- tools/cmake/preset/mlx.cmake | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 tools/cmake/preset/mlx.cmake diff --git a/tools/cmake/preset/mlx.cmake b/tools/cmake/preset/mlx.cmake new file mode 100644 index 00000000000..d8ea7fe237f --- /dev/null +++ b/tools/cmake/preset/mlx.cmake @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# MLX delegate preset - builds ExecuTorch with MLX backend for Apple Silicon + +# Core ExecuTorch options +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_MODULE ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_TENSOR ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_LLM ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_ASR_RUNNER ON) +set_overridable_option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED ON) + +# Build the MLX delegate +set_overridable_option(EXECUTORCH_BUILD_MLX ON) From bf3673208f279bee45fa8f261e6c5fd9da3b7397 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 2 Mar 2026 18:21:54 -0800 Subject: [PATCH 03/34] up --- backends/mlx/runtime/MLXInterpreter.h | 8 ++++++++ backends/mlx/serialization/schema.fbs | 6 ++++++ 2 files changed, 14 insertions(+) diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index f3b6e9b720f..bfd593c162b 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -98,6 +98,11 @@ inline std::vector infer_shape_with_minus_one( inline void exec_noop(const NoopNode&, ExecutionState&, StreamOrDevice) {} +inline void +exec_id_copy(const IdCopyNode& n, ExecutionState& st, StreamOrDevice) { + st.set_tensor(n.out, st.const_tensor_ref(n.x)); +} + inline void exec_addmm(const AddmmNode& n, ExecutionState& st, StreamOrDevice s) { const auto& mat1 = st.const_tensor_ref(n.mat1); @@ -154,6 +159,9 @@ class Interpreter { case OpCode::NOOP: ops::exec_noop(std::get(instr.node), st, s); break; + case OpCode::ID_COPY: + ops::exec_id_copy(std::get(instr.node), st, s); + break; case OpCode::ADDMM: ops::exec_addmm(std::get(instr.node), st, s); break; diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs index 945186ebef8..8b159314760 100644 --- a/backends/mlx/serialization/schema.fbs +++ b/backends/mlx/serialization/schema.fbs @@ -72,6 +72,11 @@ table IntOrVidOrTid { table NoopNode {} +table IdCopyNode { + x: Tid (required); + out: Tid (required); +} + table AddmmNode { mat1: Tid (required); // First matrix mat2: Tid (required); // Second matrix @@ -89,6 +94,7 @@ table AddmmNode { // Reordering or removing members changes numeric type IDs and breaks existing .pte files. union OpNode { NoopNode, + IdCopyNode, AddmmNode // BC: Add new op nodes here (append only) } From 6a2d4556701c5db003df99c0d6de95f5a50db3ea Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 2 Mar 2026 18:50:55 -0800 Subject: [PATCH 04/34] up --- backends/mlx/ops.py | 6 +++ backends/mlx/test/test_ops.py | 91 ----------------------------------- 2 files changed, 6 insertions(+), 91 deletions(-) diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index 6e8516e86b1..4c9e0d6f796 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -219,6 +219,12 @@ def normalize_reduction_dim( return dim, keepdim +@REGISTRY.register(target=["NOOP", torch.ops.aten._assert_scalar.default]) +def _noop_handler(P: MLXProgramBuilder, n: Node) -> None: + """No-op handler for nodes that don't emit any MLX instructions.""" + return None + + @REGISTRY.register(target=[torch.ops.aten.addmm.default]) def _addmm_handler(P: MLXProgramBuilder, n: Node) -> Slot: """Handle addmm: self + (mat1 @ mat2). diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index 01286f75f16..0ba98b532ad 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -83,94 +83,3 @@ def create_model(self) -> nn.Module: def create_inputs(self) -> Tuple[torch.Tensor, ...]: x = torch.randn(self.batch_size, self.n, self.m) return (x,) - - -class AddmmModel(nn.Module): - """Model that performs addmm: bias + (mat1 @ mat2).""" - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - alpha: float = 1.0, - beta: float = 1.0, - ): - super().__init__() - self.weight = nn.Parameter(torch.randn(out_features, in_features)) - if bias: - self.bias = nn.Parameter(torch.randn(out_features)) - else: - self.bias = None - self.alpha = alpha - self.beta = beta - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.bias is not None: - return torch.addmm( - self.bias, x, self.weight.t(), beta=self.beta, alpha=self.alpha - ) - else: - return torch.mm(x, self.weight.t()) - - -@register_test -class AddmmTest(OpTestCase): - """Test case for addmm.""" - - name = "addmm" - rtol = 1e-4 - atol = 1e-4 - - def __init__( - self, - batch_size: int = 2, - in_features: int = 64, - out_features: int = 32, - bias: bool = True, - alpha: float = 1.0, - beta: float = 1.0, - ): - self.batch_size = batch_size - self.in_features = in_features - self.out_features = out_features - self.bias = bias - self.alpha = alpha - self.beta = beta - - # Build unique test name - if not bias: - name = f"addmm_{in_features}x{out_features}_no_bias" - elif alpha != 1.0 or beta != 1.0: - name = f"addmm_{in_features}x{out_features}_a{alpha}_b{beta}" - else: - name = f"addmm_{in_features}x{out_features}" - self.name = name - - @classmethod - def get_test_configs(cls) -> List["AddmmTest"]: - return [ - cls( - batch_size=2, in_features=64, out_features=32 - ), # with bias, default alpha/beta - cls( - batch_size=2, in_features=64, out_features=32, bias=False - ), # without bias - cls(batch_size=4, in_features=128, out_features=64), # larger size - cls( - batch_size=2, in_features=64, out_features=32, alpha=2.0, beta=0.5 - ), # custom alpha/beta - ] - - def create_model(self) -> nn.Module: - return AddmmModel( - self.in_features, - self.out_features, - bias=self.bias, - alpha=self.alpha, - beta=self.beta, - ) - - def create_inputs(self) -> Tuple[torch.Tensor, ...]: - x = torch.randn(self.batch_size, self.in_features) - return (x,) From 493d9ea5e76525ccbdc66b39dca624ad29c67a91 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 4 Mar 2026 21:32:00 -0800 Subject: [PATCH 05/34] up --- backends/mlx/builder/program_builder.py | 19 +++++++++++++----- backends/mlx/runtime/MLXBackend.cpp | 26 ++++++++++++++++++++++--- backends/mlx/runtime/MLXExecutor.h | 20 ++++++++++++++++++- backends/mlx/test/test_utils.py | 10 +++++++++- 4 files changed, 65 insertions(+), 10 deletions(-) diff --git a/backends/mlx/builder/program_builder.py b/backends/mlx/builder/program_builder.py index 60d5ebbdbfe..2add4f1b7a3 100644 --- a/backends/mlx/builder/program_builder.py +++ b/backends/mlx/builder/program_builder.py @@ -27,7 +27,6 @@ from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Union import torch - from executorch.backends.mlx._logging import logger from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type from executorch.backends.mlx.builder.op_registry import ( @@ -132,7 +131,9 @@ class MLXProgramBuilder: def __init__(self, ep: ExportedProgram, named_data_key_prefix: str = ""): self.ep: ExportedProgram = ep - self._instrs: List[Instruction] = [] + self._chains: List[List[Instruction]] = [[]] # chain 0 = main + self._current_chain: int = 0 + self.init_chain_idx: int = -1 self.extra_constants: Dict[str, torch.Tensor] = {} self.slot_manager = SlotManager() self.node_info: DefaultDict[Node, NodeInfo] = defaultdict(NodeInfo) @@ -163,7 +164,13 @@ def _prefix_key(self, name: str) -> str: return name def emit(self, op: OpNodeUnion) -> None: - self._instrs.append(Instruction(op=op)) + self._chains[self._current_chain].append(Instruction(op=op)) + + def emit_init(self, op: OpNodeUnion) -> None: + if self.init_chain_idx == -1: + self.init_chain_idx = len(self._chains) + self._chains.append([]) + self._chains[self.init_chain_idx].append(Instruction(op=op)) def args(self, node: Node) -> Tuple[Any, ...]: return self.slot_map(node.args) @@ -934,9 +941,11 @@ def _build_mlx_graph(self) -> MLXGraph: num_mutable_buffer_tensors=num_tensors[IdSpace.MutableBuffer], num_temp_tensors=num_temp_tensors, num_values=num_values_count, - instruction_chains=[InstructionChain(instructions=self._instrs)], + instruction_chains=[ + InstructionChain(instructions=chain) for chain in self._chains + ], main_chain_idx=0, - init_chain_idx=-1, + init_chain_idx=self.init_chain_idx, input_map=input_map, output_map=output_map, mutable_buffer_map=mutable_buffer_map, diff --git a/backends/mlx/runtime/MLXBackend.cpp b/backends/mlx/runtime/MLXBackend.cpp index 38dff189935..99e20114ea7 100644 --- a/backends/mlx/runtime/MLXBackend.cpp +++ b/backends/mlx/runtime/MLXBackend.cpp @@ -219,10 +219,24 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { static_cast(processed->data()), processed->size()); // Validate schema version - if (handle->program.version != "1") { + int schema_version = 1; + if (!handle->program.version.empty()) { + try { + schema_version = std::stoi(handle->program.version); + } catch (...) { + throw std::runtime_error( + "Invalid MLX schema version '" + handle->program.version + + "' (expected integer)"); + } + } + constexpr int kMaxSupportedVersion = 1; + if (schema_version > kMaxSupportedVersion) { throw std::runtime_error( - "Unsupported MLX schema version '" + handle->program.version + - "' (expected '1'). Rebuild the .pte with a matching SDK version."); + "This .pte requires ExecuTorch MLX runtime version " + + std::to_string(schema_version) + + " but this runtime only supports up to version " + + std::to_string(kMaxSupportedVersion) + + ". Upgrade ExecuTorch to a newer version."); } // Load constants from named_data_map @@ -251,11 +265,17 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { // SAFETY: The >= 0 check ensures init_chain_idx is non-negative, so the // static_cast cannot produce UINT32_MAX from a -1 sentinel. if (handle->program.init_chain_idx >= 0) { + handle->state.is_init_chain = true; handle->interpreter.run_chain( handle->program, static_cast(handle->program.init_chain_idx), handle->state, handle->stream); + handle->state.is_init_chain = false; + + // Evaluate any constants written by the init chain so the first + // execute() doesn't pay the cost of materializing them. + eval(handle->constants.tensors); } } catch (const std::exception& e) { diff --git a/backends/mlx/runtime/MLXExecutor.h b/backends/mlx/runtime/MLXExecutor.h index 32d623790ab..978eaadabba 100644 --- a/backends/mlx/runtime/MLXExecutor.h +++ b/backends/mlx/runtime/MLXExecutor.h @@ -97,6 +97,13 @@ struct ConstantData { return tensors[id.idx]; } + inline void set(Tid id, Tensor t) { + if (id.idx >= tensors.size()) { + throw std::out_of_range("ConstantData::set: id out of range"); + } + tensors[id.idx] = std::move(t); + } + inline void add(Tensor t) { tensors.push_back(std::move(t)); } @@ -153,6 +160,9 @@ struct ExecutionState { // Non-constant values (SymInt, etc.) std::vector> values; + // Init chain flag: when true, set_tensor allows writing to constants + bool is_init_chain{false}; + // Logging context size_t current_op_idx{0}; const char* current_op_name{nullptr}; @@ -478,7 +488,15 @@ struct ExecutionState { throw std::runtime_error("set_tensor: Program not bound"); } if (id.idx < program->num_constant_tensors) { - throw std::runtime_error("set_tensor: cannot write to constant tensor"); + if (!is_init_chain) { + throw std::runtime_error("set_tensor: cannot write to constant tensor"); + } + // Init chain can write over constants + if (!constants) { + throw std::runtime_error("set_tensor: constants not bound"); + } + const_cast(constants)->set(id, std::move(arr)); + return; } // Route to mutable buffers or per-execution tensors if (is_mutable_buffer(id)) { diff --git a/backends/mlx/test/test_utils.py b/backends/mlx/test/test_utils.py index 090bceabf08..660968195b7 100644 --- a/backends/mlx/test/test_utils.py +++ b/backends/mlx/test/test_utils.py @@ -23,6 +23,7 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple, Union +import executorch.exir as exir import numpy as np import torch @@ -268,6 +269,7 @@ def export_model_to_pte( output_path: Union[str, Path], dynamic_shapes: Optional[Dict] = None, verbose: bool = False, + edge_compile_config: Optional[exir.EdgeCompileConfig] = None, ) -> None: """ Export a PyTorch model to a .pte file using the MLX delegate. @@ -281,7 +283,6 @@ def export_model_to_pte( Example: {0: {0: Dim("batch", min=1, max=32)}} for dynamic batch on first input. verbose: Whether to print the exported program for debugging. """ - import executorch.exir as exir from executorch.backends.mlx import MLXPartitioner from executorch.exir.capture._config import ExecutorchBackendConfig from torch.export import export @@ -301,9 +302,11 @@ def export_model_to_pte( print(exported_program) # Lower to edge and delegate to MLX + compile_config = edge_compile_config or exir.EdgeCompileConfig() edge_program = exir.to_edge_transform_and_lower( exported_program, partitioner=[MLXPartitioner()], + compile_config=compile_config, ) # Print edge program if verbose @@ -865,6 +868,10 @@ def get_dynamic_shapes(self) -> Optional[Dict]: """Return dynamic shapes specification for torch.export, or None for static shapes.""" return None + def get_edge_compile_config(self) -> Optional[exir.EdgeCompileConfig]: + """Return EdgeCompileConfig for export, or None for default.""" + return None + def get_test_dir(self) -> Path: """Get the directory for this test's files.""" test_dir = Path(__file__).parent / "op_tests" / self.name @@ -924,6 +931,7 @@ def generate_test_files(self, verbose: bool = False) -> Tuple[Path, Path, Path]: pte_path, dynamic_shapes=dynamic_shapes, verbose=verbose, + edge_compile_config=self.get_edge_compile_config(), ) # Save test inputs From 5ee8ac41c1be9f4a496c3e7f890fc75caacab7c1 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 4 Mar 2026 21:33:28 -0800 Subject: [PATCH 06/34] up --- backends/mlx/third-party/mlx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/mlx/third-party/mlx b/backends/mlx/third-party/mlx index 72e94c81e16..365d6f29b47 160000 --- a/backends/mlx/third-party/mlx +++ b/backends/mlx/third-party/mlx @@ -1 +1 @@ -Subproject commit 72e94c81e1685c90679ef03532c4b8897010abf9 +Subproject commit 365d6f29b47686a9f5401f6a9ec5825fee162d69 From 0df21d99285cb1e97c8819bb04ae789e0d1c8c4e Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 4 Mar 2026 21:44:36 -0800 Subject: [PATCH 07/34] up --- .github/workflows/mlx.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 2e8ca7aa3b7..ea0bce96e1a 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -9,11 +9,13 @@ on: paths: - .github/workflows/mlx.yml - backends/mlx/** + - extension/llm/export/** + - extension/audio/** + - examples/models/parakeet/** + - examples/models/voxtral_realtime/** workflow_dispatch: -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} - cancel-in-progress: true +permissions: {} jobs: test-mlx: From 93afd3e1ac787d5deecb1179df9cb5bac266e89e Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 4 Mar 2026 21:50:26 -0800 Subject: [PATCH 08/34] up --- .github/workflows/mlx.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index ea0bce96e1a..cc83c90e23e 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -9,10 +9,6 @@ on: paths: - .github/workflows/mlx.yml - backends/mlx/** - - extension/llm/export/** - - extension/audio/** - - examples/models/parakeet/** - - examples/models/voxtral_realtime/** workflow_dispatch: permissions: {} From 0adbe8c752f49911d541038bada8b73259b95752 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Thu, 5 Mar 2026 11:23:22 -0800 Subject: [PATCH 09/34] up --- backends/mlx/README.md | 123 ++++++++++++++++--------- backends/mlx/serialization/generate.py | 2 +- 2 files changed, 82 insertions(+), 43 deletions(-) diff --git a/backends/mlx/README.md b/backends/mlx/README.md index ebab893385a..eea60fe2d00 100644 --- a/backends/mlx/README.md +++ b/backends/mlx/README.md @@ -193,7 +193,7 @@ ExportedProgram (subgraph) ## How to Add a New Op -This section walks through adding a new op end-to-end, using **`aten.linear`** +This section walks through adding a new op end-to-end, using **`aten.addmm`** as an example. ### Step 1: Add the Node to `schema.fbs` @@ -201,15 +201,15 @@ as an example. Add a new table in the "Op nodes" section and add it to the `OpNode` union: ```fbs -table LinearNode { - x: Tid (required); - weight: Tid (required); +table AddmmNode { + mat1: Tid (required); + mat2: Tid (required); out: Tid (required); bias: Tid; // optional } ``` -Then add `LinearNode` to the `union OpNode { ... }` list. +Then add `AddmmNode` to the `union OpNode { ... }` list. ### Step 2: Run the Code Generator @@ -219,34 +219,40 @@ python backends/mlx/serialization/generate.py This regenerates: -- `mlx_graph_schema.py` — adds `LinearNode` Python dataclass -- `_generated_serializers.py` — adds `_build_LinearNode` serializer -- `runtime/MLXLoader.h` — adds `LinearNode` C++ struct, `OpCode::LINEAR`, loader -- `runtime/MLXLoader.cpp` — adds FlatBuffer → `LinearNode` deserialization +- `mlx_graph_schema.py` — adds `AddmmNode` Python dataclass +- `_generated_serializers.py` — adds `_build_AddmmNode` serializer +- `runtime/MLXLoader.h` — adds `AddmmNode` C++ struct, `OpCode::ADDMM`, loader +- `runtime/MLXLoader.cpp` — adds FlatBuffer → `AddmmNode` deserialization - `runtime/schema_generated.h` — FlatBuffer C++ bindings ### Step 3: Add the Python Op Handler (`ops.py`) Register a handler that converts the ATen op to your new node. Make sure to -import `LinearNode` from `mlx_graph_schema`: +import `AddmmNode` from `mlx_graph_schema`: ```python -from executorch.backends.mlx.serialization.mlx_graph_schema import LinearNode +from executorch.backends.mlx.serialization.mlx_graph_schema import AddmmNode -@REGISTRY.register(target=[torch.ops.aten.linear.default]) -def _linear_handler(P: MLXProgramBuilder, n: Node) -> Slot: +@REGISTRY.register(target=[torch.ops.aten.addmm.default]) +def _addmm_handler(P: MLXProgramBuilder, n: Node) -> Slot: args = P.args(n) - require_args(args, 2, 3, "aten.linear") - require_kwargs(P.kwargs(n), set(), "aten.linear") - x, w = args[0], args[1] - b = args[2] if len(args) > 2 else None + kwargs = P.kwargs(n) + require_args(args, 3, 3, "aten.addmm") + require_kwargs(kwargs, {"beta", "alpha"}, "aten.addmm") + bias, mat1, mat2 = args[0], args[1], args[2] + + beta = kwargs.get("beta", 1) + alpha = kwargs.get("alpha", 1) + out = P.make_or_get_slot(n) P.emit( - LinearNode( - x=P.slot_to_tid(x), - weight=P.slot_to_tid(w), + AddmmNode( + mat1=P.slot_to_tid(mat1), + mat2=P.slot_to_tid(mat2), out=P.slot_to_tid(out), - bias=P.slot_to_tid(b) if b else None, + bias=P.slot_to_tid(bias), + alpha=float(alpha), + beta=float(beta), ) ) return out @@ -263,21 +269,28 @@ Key APIs: Add an `exec_*` function in the `ops` namespace: ```cpp -inline void exec_linear(const LinearNode& n, ExecutionState& st, StreamOrDevice s) { - const auto& X = st.const_tensor_ref(n.x); - auto W = transpose(st.const_tensor_ref(n.weight), {1, 0}, s); - array Y = n.bias - ? addmm(st.const_tensor_ref(*n.bias), X, W, 1.0f, 1.0f, s) - : matmul(X, W, s); +inline void exec_addmm(const AddmmNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& mat1 = st.const_tensor_ref(n.mat1); + const auto& mat2 = st.const_tensor_ref(n.mat2); + + array Y = n.bias ? addmm( + st.const_tensor_ref(*n.bias), + mat1, + mat2, + /*alpha=*/n.alpha, + /*beta=*/n.beta, + s) + : matmul(mat1, mat2, s); + st.set_tensor(n.out, std::move(Y)); } ``` -Then add the dispatch case in `Interpreter::execute_instruction()`: +Then add the dispatch case in `Interpreter::dispatch()`: ```cpp -case OpCode::LINEAR: - ops::exec_linear(std::get(instr.node), st, s); +case OpCode::ADDMM: + ops::exec_addmm(std::get(instr.node), st, s); break; ``` @@ -290,34 +303,60 @@ Each test follows a standard pattern: 3. **Decorate with `@register_test`** to register it with the test runner. ```python -class LinearModel(nn.Module): - def __init__(self, in_features=64, out_features=128, bias=True): +class AddmmModel(nn.Module): + """Model that performs addmm: bias + (mat1 @ mat2).""" + + def __init__(self, in_features, out_features, bias=True, alpha=1.0, beta=1.0): super().__init__() - self.linear = nn.Linear(in_features, out_features, bias=bias) + self.weight = nn.Parameter(torch.randn(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.randn(out_features)) + else: + self.bias = None + self.alpha = alpha + self.beta = beta def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear(x) + if self.bias is not None: + return torch.addmm( + self.bias, x, self.weight.t(), beta=self.beta, alpha=self.alpha + ) + else: + return torch.mm(x, self.weight.t()) @register_test -class LinearTest(OpTestCase): - name = "linear" +class AddmmTest(OpTestCase): + name = "addmm" rtol = 1e-4 atol = 1e-4 - def __init__(self, in_features=64, out_features=128, bias=True): + def __init__(self, batch_size=2, in_features=64, out_features=32, + bias=True, alpha=1.0, beta=1.0): + self.batch_size = batch_size self.in_features = in_features self.out_features = out_features self.bias = bias + self.alpha = alpha + self.beta = beta + self.name = f"addmm_{in_features}x{out_features}" @classmethod def get_test_configs(cls): - return [cls(), cls(bias=False)] + return [ + cls(batch_size=2, in_features=64, out_features=32), + cls(batch_size=2, in_features=64, out_features=32, bias=False), + cls(batch_size=4, in_features=128, out_features=64), + cls(batch_size=2, in_features=64, out_features=32, alpha=2.0, beta=0.5), + ] def create_model(self): - return LinearModel(self.in_features, self.out_features, bias=self.bias) + return AddmmModel( + self.in_features, self.out_features, + bias=self.bias, alpha=self.alpha, beta=self.beta, + ) def create_inputs(self): - return (torch.randn(2, 16, self.in_features),) + return (torch.randn(self.batch_size, self.in_features),) ``` ### Step 6: Run Tests @@ -327,7 +366,7 @@ outputs against PyTorch reference. Since adding a new op always involves C++ changes, use `--rebuild` to recompile the runtime: ```bash -python -m executorch.backends.mlx.test.run_all_tests --rebuild linear +python -m executorch.backends.mlx.test.run_all_tests --rebuild addmm ``` Run all tests in parallel: @@ -356,7 +395,7 @@ architecture, prerequisites, and the `OpTestCase` API. - [ ] Run `python backends/mlx/serialization/generate.py` - [ ] Add `@REGISTRY.register` handler in `ops.py` (and import the new node class) - [ ] Add `exec_*` function in `runtime/MLXInterpreter.h` -- [ ] Add `case OpCode::*` in `Interpreter::execute_instruction()` +- [ ] Add `case OpCode::*` in `Interpreter::dispatch()` - [ ] Add test model + `OpTestCase` in `test/test_ops.py` - [ ] Run `python -m executorch.backends.mlx.test.run_all_tests --rebuild ` diff --git a/backends/mlx/serialization/generate.py b/backends/mlx/serialization/generate.py index d12743906db..6f6ee11fe41 100755 --- a/backends/mlx/serialization/generate.py +++ b/backends/mlx/serialization/generate.py @@ -1006,7 +1006,7 @@ def _fbs_type_to_cpp( def _table_name_to_opcode(name: str) -> str: - """Convert table name like 'LinearNode' to opcode like 'LINEAR'. + """Convert table name like 'AddNode' to opcode like 'ADD'. Uses regex-based camelCase → UPPER_SNAKE_CASE conversion with a small override dict for names whose conventional opcode doesn't follow the From f0b8e71be86745dd84f1a3186ff34a978f423295 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 2 Mar 2026 17:46:59 -0800 Subject: [PATCH 10/34] up --- .github/workflows/mlx.yml | 1 + backends/mlx/ops.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index cc83c90e23e..cf4e887d898 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -9,6 +9,7 @@ on: paths: - .github/workflows/mlx.yml - backends/mlx/** + - extension/llm/export/** workflow_dispatch: permissions: {} diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index 4c9e0d6f796..6743943f9e2 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -219,12 +219,15 @@ def normalize_reduction_dim( return dim, keepdim +<<<<<<< HEAD @REGISTRY.register(target=["NOOP", torch.ops.aten._assert_scalar.default]) def _noop_handler(P: MLXProgramBuilder, n: Node) -> None: """No-op handler for nodes that don't emit any MLX instructions.""" return None +======= +>>>>>>> e3b488076f (up) @REGISTRY.register(target=[torch.ops.aten.addmm.default]) def _addmm_handler(P: MLXProgramBuilder, n: Node) -> Slot: """Handle addmm: self + (mat1 @ mat2). From 5493ea107713e95814b8c3d04ad5b105ca230d32 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 2 Mar 2026 18:03:02 -0800 Subject: [PATCH 11/34] up --- .github/workflows/mlx.yml | 10 +- backends/mlx/custom_ops.py | 256 + backends/mlx/llm/__init__.py | 6 + backends/mlx/llm/cache.py | 429 ++ backends/mlx/ops.py | 3668 +++++++++- backends/mlx/patterns.py | 939 +++ backends/mlx/runtime/MLXInterpreter.h | 1873 ++++- backends/mlx/serialization/schema.fbs | 945 ++- backends/mlx/test/CMakeLists.txt | 20 + .../test/export_multi_thread_test_model.py | 124 + .../mlx/test/multi_thread_test_runner.cpp | 204 + backends/mlx/test/test_ops.py | 6031 ++++++++++++++++- 12 files changed, 14462 insertions(+), 43 deletions(-) create mode 100644 backends/mlx/llm/__init__.py create mode 100644 backends/mlx/llm/cache.py create mode 100644 backends/mlx/test/export_multi_thread_test_model.py create mode 100644 backends/mlx/test/multi_thread_test_runner.cpp diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index cf4e887d898..53c7b9360cd 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -37,7 +37,7 @@ jobs: ${CONDA_RUN} pip list echo "::group::Build test runners" - ${CONDA_RUN} cmake --build cmake-out --target op_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 )) + ${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 )) echo "::endgroup::" echo "::group::Run op unit tests" @@ -52,6 +52,14 @@ jobs: -v echo "::endgroup::" + echo "::group::Run multi-thread stress test" + ${CONDA_RUN} python backends/mlx/test/export_multi_thread_test_model.py /tmp/multi_thread_test_model.pte + ET_TESTING_MODEL_PATH=/tmp/multi_thread_test_model.pte \ + ET_TESTING_NUM_THREADS=50 \ + ET_PREDICTIONS_PER_THREAD=100 \ + ./cmake-out/backends/mlx/test/multi_thread_test_runner + echo "::endgroup::" + backend-tester: strategy: fail-fast: false diff --git a/backends/mlx/custom_ops.py b/backends/mlx/custom_ops.py index 81853adbd6d..8ad891e3568 100644 --- a/backends/mlx/custom_ops.py +++ b/backends/mlx/custom_ops.py @@ -13,3 +13,259 @@ These ops are used during model export to represent operations that MLX can execute efficiently but may not have direct PyTorch equivalents. """ + +from typing import Optional + +import torch +from torch import Tensor + + +@torch.library.custom_op("mlx::kv_cache_update", mutates_args=("cache",)) +def kv_cache_update( + cache: Tensor, # [B, H, S_max, D] - mutated in place + new_values: Tensor, # [B, H, S, D] + start_pos: int, + ring_size: int = 0, +) -> Tensor: + """ + Mutating KV cache update that modifies cache in place. + + This op updates the cache at positions [start_pos, start_pos + S) with + new_values. The cache is mutated in place, similar to llama.update_cache. + + Args: + cache: Cache tensor of shape [B, H, S_max, D] (BHSD layout) - mutated + new_values: New values to insert of shape [B, H, S, D] + start_pos: Starting position index for insertion + ring_size: If > 0, treat as ring buffer of this size: write position + is start_pos % ring_size and writes wrap around. If 0 (default), + linear update at start_pos with no wrapping. + + Returns: + A dummy tensor (1,) - the return value is not semantically meaningful + but is required for slot management during export. This follows the + same pattern as llama.update_cache. + + Note: + The BHSD layout matches what torch SDPA expects, avoiding transposition. + """ + seq_len = new_values.size(2) + + if ring_size > 0: + write_pos = start_pos % ring_size + end_pos = write_pos + seq_len + if end_pos <= ring_size: + cache[:, :, write_pos:end_pos, :] = new_values + else: + first_part = ring_size - write_pos + cache[:, :, write_pos:ring_size, :] = new_values[:, :, :first_part, :] + cache[:, :, 0 : seq_len - first_part, :] = new_values[:, :, first_part:, :] + else: + end_pos = start_pos + seq_len + assert end_pos <= cache.size(2), ( + f"kv_cache_update: write [{start_pos}, {end_pos}) exceeds " + f"cache size {cache.size(2)}. Use ring_size > 0 for wrapping." + ) + cache[:, :, start_pos:end_pos, :] = new_values + + return torch.empty((1,), dtype=new_values.dtype, device=new_values.device) + + +@torch.library.register_fake("mlx::kv_cache_update") +def kv_cache_update_fake( + cache: Tensor, + new_values: Tensor, + start_pos: int, + ring_size: int = 0, +) -> Tensor: + """Fake implementation for tracing - returns dummy tensor like llama.update_cache.""" + return torch.empty((1,), dtype=new_values.dtype, device="meta") + + +@torch.library.custom_op("mlx::custom_sdpa", mutates_args=()) +def mlx_custom_sdpa( + query: Tensor, # [B, num_heads, seq_len, head_dim] - BHSD + key: Tensor, # [B, num_kv_heads, kv_len, head_dim] - BHSD (FULL cache) + value: Tensor, # [B, num_kv_heads, kv_len, head_dim] - BHSD (FULL cache) + start_pos: int, # FIRST position in current batch (0-indexed) + attn_mask: Optional[Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, +) -> Tensor: + """ + MLX custom SDPA with K/V cache slicing. + + This op uses BHSD layout (matching PyTorch SDPA and MLX's SdpaNode). + It receives the FULL K/V cache and slices to [0:stop_pos] before computing + attention, where stop_pos = start_pos + query_seq_len. + + The semantics follow executorch's llama.custom_sdpa: + - start_pos: FIRST position of the current query batch + - For prefill with 7 tokens at positions [0,1,2,3,4,5,6]: start_pos=0, stop_pos=7 + - For decode at position 10: start_pos=10, stop_pos=11 + + Args: + query: Query tensor [B, num_heads, seq_len, head_dim] + key: Key cache [B, num_kv_heads, kv_len, head_dim] - FULL cache + value: Value cache [B, num_kv_heads, kv_len, head_dim] - FULL cache + start_pos: FIRST position in current batch (SymInt) + attn_mask: Optional attention mask (only used when is_causal=False) + dropout_p: Dropout probability (default 0.0) + is_causal: Whether to apply causal masking (default False) + scale: Attention scale factor (default 1/sqrt(head_dim)) + + Returns: + Attention output [B, num_heads, seq_len, head_dim] - BHSD + """ + if scale is None: + scale = query.shape[-1] ** -0.5 + + # Compute stop_pos = start_pos + query_seq_len + # BHSD layout: seq_len is at dim 2 + query_seq_len = query.shape[2] + stop_pos = start_pos + query_seq_len + + # Constrain symbolic shapes so torch.export can resolve guards. + # start_pos is data-dependent (from input_pos), so the slice + # stop_pos > kv_len comparison is unresolvable without these hints. + torch._check(start_pos >= 0) + torch._check(stop_pos <= key.shape[2]) + + # Slice K/V to valid cache entries [0:stop_pos] + key_sliced = key[:, :, :stop_pos, :] + value_sliced = value[:, :, :stop_pos, :] + + # Handle GQA: expand K/V heads to match query heads + num_heads = query.shape[1] + num_kv_heads = key.shape[1] + if num_kv_heads != num_heads: + num_groups = num_heads // num_kv_heads + key_sliced = key_sliced.repeat_interleave(num_groups, dim=1) + value_sliced = value_sliced.repeat_interleave(num_groups, dim=1) + + # Build explicit lower-right aligned causal mask to match MLX's SdpaNode. + # PyTorch's is_causal=True uses upper-left alignment when Q_len != K_len, + # but for KV-cache inference q[i] is at context position (start_pos + i) + # and should attend to all positions 0..start_pos+i (lower-right). + if is_causal: + L, S = query.shape[2], key_sliced.shape[2] + offset = S - L # equals start_pos + mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril( + diagonal=offset + ) + attn_mask = torch.where(mask, 0.0, float("-inf")).to(query.dtype) + + # Compute SDPA - returns BHSD + return torch.nn.functional.scaled_dot_product_attention( + query, + key_sliced, + value_sliced, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=False, + scale=scale, + ) + + +@torch.library.register_fake("mlx::custom_sdpa") +def mlx_custom_sdpa_fake( + query: Tensor, + key: Tensor, + value: Tensor, + start_pos: int, + attn_mask: Optional[Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, +) -> Tensor: + """Fake implementation for tracing - returns BHSD shape (same as query).""" + return query.new_empty(query.shape) + + +@torch.library.custom_op("mlx::rope", mutates_args=()) +def rope( + x: Tensor, # (B, H, T, D) + dims: int, + pos: int, # int, not tensor + traditional: bool = False, + base: float = 500000.0, + scale: float = 1.0, + freqs: Optional[Tensor] = None, +) -> Tensor: + """ + Apply Rotary Position Embedding to a single tensor. + + Args: + x: Input tensor of shape (B, H, T, D) + dims: Number of feature dimensions to rotate. If less than D, + only the first `dims` dimensions are rotated and the rest + are left unchanged. + pos: Starting position index (int, not tensor) + traditional: Whether to use traditional RoPE formulation + base: Base for frequency computation + scale: Scale factor for frequencies + freqs: Optional precomputed frequencies + + Returns: + Rotated tensor of the same shape + """ + Dh = int(dims) + + B, H, T, _ = x.shape + half = Dh // 2 + + if freqs is None: + # [1, 1, 1, half] to broadcast over B,H,T + i = torch.arange(half, device=x.device, dtype=torch.float32) + inv_freq = (base ** (-2.0 * i / Dh)).view(1, 1, 1, half) + + # positions: [1, 1, T, 1] + pos_range = torch.arange( + pos, pos + T, device=x.device, dtype=torch.float32 + ).view(1, 1, T, 1) + + # final angles: [1, 1, T, half] + angles = (pos_range * inv_freq) * float(scale) + else: + # assume freqs is already per-position, just reshape to [1,1,T,half] + angles = freqs.to(torch.float32).view(1, 1, T, half) + + cos = angles.cos().to(x.dtype) # [1,1,T,half] + sin = angles.sin().to(x.dtype) # [1,1,T,half] + + # Split into rotated and unrotated portions + x_rot = x[..., :Dh] + x_pass = x[..., Dh:] + + if traditional: + # Interleaved pairs: (x[0],x[1]), (x[2],x[3]), ... + x1 = x_rot[..., 0::2] # even indices + x2 = x_rot[..., 1::2] # odd indices + xr = x1 * cos - x2 * sin + xi = x1 * sin + x2 * cos + rotated = torch.stack([xr, xi], dim=-1).flatten(-2) + else: + # Split-half: first half paired with second half + x1, x2 = x_rot[..., :half], x_rot[..., half:] + xr = x1 * cos - x2 * sin + xi = x1 * sin + x2 * cos + rotated = torch.cat([xr, xi], dim=-1) + + if x_pass.shape[-1] > 0: + return torch.cat([rotated, x_pass], dim=-1) + return rotated + + +@torch.library.register_fake("mlx::rope") +def rope_fake( + x: Tensor, + dims: int, + pos: int, + traditional: bool = False, + base: float = 500000.0, + scale: float = 1.0, + freqs: Optional[Tensor] = None, +) -> Tensor: + """Fake implementation for tracing.""" + return x.new_empty(x.shape) diff --git a/backends/mlx/llm/__init__.py b/backends/mlx/llm/__init__.py new file mode 100644 index 00000000000..f557ef26c5b --- /dev/null +++ b/backends/mlx/llm/__init__.py @@ -0,0 +1,6 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/mlx/llm/cache.py b/backends/mlx/llm/cache.py new file mode 100644 index 00000000000..9709980689b --- /dev/null +++ b/backends/mlx/llm/cache.py @@ -0,0 +1,429 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Shared KV cache utilities for MLX delegate examples. + +Provides reusable KV cache implementations optimized for the MLX backend: +""" + +from typing import Tuple + +import torch +import torch.nn as nn + +# Import MLX custom ops to register mlx::kv_cache_update +from executorch.backends.mlx import custom_ops as _mlx_custom_ops # noqa: F401 + + +class KVCache(nn.Module): + """ + MLX-optimized KV cache with ExecutorTorch llama KVCache interface. + + This class follows the same interface as examples/models/llama/attention.py KVCache, + making it a drop-in replacement, but uses the mlx::kv_cache_update op internally + which is optimized for the MLX delegate. + + The cache uses BHSD layout [B, H, S, D] which matches what torch SDPA expects. + + The ``update`` method accepts ``input_pos`` as either a ``torch.Tensor`` or a + plain ``int`` / SymInt. When a tensor is passed, ``item()`` is called internally + to extract the start position, which introduces an unbacked SymInt during + ``torch.export``. Extracting a SymInt has a cost because it creates a new + symbolic variable and associated constraints in the exported program. In a + multi-layer model, prefer extracting the SymInt once and passing the resulting + int/SymInt to every layer's ``update`` call rather than passing the tensor + repeatedly: + + .. code-block:: python + + # Preferred: extract once, pass to all layers + start_pos = input_pos[0].item() + for layer_cache in caches: + layer_cache.update(start_pos, k_val, v_val) + + # Avoid: each layer re-extracts from the tensor + for layer_cache in caches: + layer_cache.update(input_pos, k_val, v_val) + + Example: + >>> cache = KVCache( + ... max_batch_size=1, + ... max_context_length=4096, + ... n_heads=32, + ... head_dim=128, + ... enable_dynamic_shape=True, + ... ) + >>> # With tensor input_pos + >>> input_pos = torch.tensor([0]) + >>> k_val = torch.randn(1, 32, 10, 128) # [B, H, S, D] + >>> v_val = torch.randn(1, 32, 10, 128) # [B, H, S, D] + >>> k_cache, v_cache = cache.update(input_pos, k_val, v_val) + >>> + >>> # With int/SymInt input_pos (preferred in multi-layer loops) + >>> start_pos = input_pos[0].item() + >>> k_cache, v_cache = cache.update(start_pos, k_val, v_val) + """ + + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + enable_dynamic_shape: bool, + dtype: torch.dtype = torch.float32, + ): + """ + Initialize KV cache buffers. + + Args: + max_batch_size: Maximum batch size + max_context_length: Maximum sequence length the cache can hold + n_heads: Number of attention heads (key/value heads for GQA) + head_dim: Dimension per head + enable_dynamic_shape: Whether dynamic shapes are enabled (kept for interface + compatibility, but MLX always uses dynamic-style update) + dtype: Data type for cache buffers + """ + super().__init__() + assert ( + max_batch_size == 1 + ), f"Only max_batch_size=1 is supported, but got {max_batch_size}" + self.max_batch_size = max_batch_size + self.max_context_length = max_context_length + self.n_heads = n_heads + self.head_dim = head_dim + self.enable_dynamic_shape = enable_dynamic_shape + + # Initialize cache buffers [B, H, T_max, D] - BHSD layout + cache_shape = (max_batch_size, n_heads, max_context_length, head_dim) + self.register_buffer( + "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") + ) + self.register_buffer( + "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") + ) + + def update( + self, input_pos: torch.Tensor | int, k_val: torch.Tensor, v_val: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update cache with new K/V states and return FULL cache. + + This method follows the same signature as examples/models/llama/attention.py KVCache. + + Args: + input_pos: Start position — either a position tensor [S] or an int/SymInt + k_val: New key states [B, H, S, D] + v_val: New value states [B, H, S, D] + + Returns: + Tuple of (k_cache, v_cache) - slices of the FULL cache buffers + """ + + if isinstance(input_pos, torch.Tensor): + start_pos = input_pos[0].item() + seq_len = k_val.size(2) + torch._check(seq_len == v_val.size(2)) + torch._check(start_pos >= 0) + torch._check(start_pos + seq_len <= self.max_context_length) + else: + start_pos = input_pos + + torch.ops.mlx.kv_cache_update(self.k_cache, k_val, start_pos) + torch.ops.mlx.kv_cache_update(self.v_cache, v_val, start_pos) + + # Return full slices of the cache (creates new tensor nodes in the graph) + # This avoids the issue where the same tensor is both BUFFER_MUTATION and USER_OUTPUT + return self.k_cache[:, :, :, :], self.v_cache[:, :, :, :] + + +class RingBufferKVCache(nn.Module): + """ + Ring buffer KV cache for sliding window attention. + + Instead of a linear cache that fills up and stops, this cache wraps around: + write_pos = start_pos % window_size. When the cache is full, new tokens + overwrite the oldest ones, enabling infinite-length generation. + + The attention mask is computed branchlessly from ``start_pos`` and + ``window_size`` alone using ``torch.where`` — no mutable position-tracking + buffers and no Python if/else that would create torch.export guards. + + Mask creation is NOT done here — following optimum-executorch's pattern, + the attention function creates the mask lazily by accessing the cache + via a closure. This avoids tracing issues with torch.export. + + Layout: BHSD [batch_size, num_heads, window_size, head_dim] + + Example: + >>> cache = RingBufferKVCache( + ... max_batch_size=1, + ... max_context_length=512, + ... n_heads=4, + ... head_dim=256, + ... dtype=torch.bfloat16, + ... ) + >>> k_val = torch.randn(1, 4, 1, 256) + >>> v_val = torch.randn(1, 4, 1, 256) + >>> k_cache, v_cache = cache.update(start_pos=0, k_val=k_val, v_val=v_val) + """ + + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + dtype: torch.dtype = torch.float32, + ): + super().__init__() + assert ( + max_batch_size == 1 + ), f"Only max_batch_size=1 is supported, but got {max_batch_size}" + self.max_batch_size = max_batch_size + self.max_context_length = max_context_length + self.window_size = max_context_length + self.buffer_size = 2 * max_context_length + self.n_heads = n_heads + self.head_dim = head_dim + + # Cache buffers [B, H, 2*window_size, D] + # 2× buffer ensures multi-token writes never overwrite data that + # earlier queries in the same batch still need (matches ET behavior). + cache_shape = (max_batch_size, n_heads, self.buffer_size, head_dim) + self.register_buffer( + "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") + ) + self.register_buffer( + "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") + ) + + def update( + self, input_pos: torch.Tensor | int, k_val: torch.Tensor, v_val: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update cache with new K/V states using ring buffer semantics. + + Args: + input_pos: Start position — either a position tensor [S] or an int/SymInt + k_val: New key states [B, H, S, D] + v_val: New value states [B, H, S, D] + + Returns: + Tuple of (k_cache, v_cache) — full ring buffer slices + """ + if isinstance(input_pos, torch.Tensor): + start_pos = input_pos[0].item() + seq_len = k_val.size(2) + torch._check(seq_len == v_val.size(2)) + torch._check(start_pos >= 0) + torch._check(seq_len <= self.window_size) + else: + start_pos = input_pos + + torch.ops.mlx.kv_cache_update( + self.k_cache, k_val, start_pos, ring_size=self.buffer_size + ) + torch.ops.mlx.kv_cache_update( + self.v_cache, v_val, start_pos, ring_size=self.buffer_size + ) + + return self.k_cache[:, :, :, :], self.v_cache[:, :, :, :] + + def create_sliding_window_mask(self, start_pos: int, seq_len: int) -> torch.Tensor: + """ + Build attention mask for the ring buffer — branchless, no mutable state. + + Reconstructs the slot→position mapping from ``start_pos`` and + ``buffer_size`` alone using ``torch.where``, avoiding both Python + if/else (which creates torch.export guards) and mutable position- + tracking buffers (which require extra kv_cache_update calls and + complicate partitioning). + + Returns: + Additive mask [1, 1, seq_len, buffer_size] in the cache's dtype, + where 0 = attend, -inf = block. + """ + w = self.window_size + b = self.buffer_size + end_pos = start_pos + seq_len + + # Slot indices [buffer_size] + slots = torch.arange(b, dtype=torch.long) + + last_write_slot = (end_pos - 1) % b + current_cycle_base = end_pos - 1 - last_write_slot + pos_current = current_cycle_base + slots + pos_previous = current_cycle_base - b + slots + + cache_pos = torch.where(slots <= last_write_slot, pos_current, pos_previous) + + # Query positions [seq_len, 1] + pos_q = (start_pos + torch.arange(seq_len, dtype=torch.long)).view(-1, 1) + + # Delta from query to each cached position [seq_len, buffer_size] + delta = pos_q - cache_pos + + # A slot is attendable if: filled (pos >= 0), causal (delta >= 0), + # and within the sliding window (delta < w) + attn_mask = (cache_pos >= 0) & (delta >= 0) & (delta < w) + + # Use cache dtype (e.g. bf16) to avoid float32 AsTypeNode casts in SDPA + dtype = self.k_cache.dtype + zero = torch.zeros(1, dtype=dtype) + neg_inf = torch.full((1,), float("-inf"), dtype=dtype) + return torch.where(attn_mask, zero, neg_inf).unsqueeze(0).unsqueeze(0) + + +from transformers.cache_utils import StaticCache + + +class HFStaticCache(StaticCache): + """ + MLX-optimized Static KV Cache that follows HuggingFace's StaticCache interface. + + This cache is designed to be a drop-in replacement for HuggingFace's StaticCache + when exporting models for the MLX backend. It uses mlx::kv_cache_update internally + which is optimized for the MLX delegate. + + The cache supports multi-layer models by maintaining separate K/V buffers per layer, + matching the HF StaticCache behavior where `update()` takes a `layer_idx` argument. + + Layout: BHSD [batch_size, num_heads, max_cache_len, head_dim] + + Example: + >>> from transformers import AutoConfig + >>> config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> cache = HFStaticCache(config, max_batch_size=1, max_cache_len=4096) + >>> # In attention layer: + >>> k_out, v_out = cache.update(k_states, v_states, layer_idx=0, + ... cache_kwargs={"cache_position": pos_tensor}) + """ + + def __init__( + self, + config, + max_batch_size: int = 1, + max_cache_len: int | None = None, + device: torch.device | str | None = None, + dtype: torch.dtype = torch.float32, + ): + """ + Initialize MLX Static Cache. + + Args: + config: HuggingFace model config with num_hidden_layers, num_key_value_heads, + num_attention_heads, hidden_size, and optionally head_dim + max_batch_size: Maximum batch size (default: 1) + max_cache_len: Maximum cache length. If None, uses config.max_position_embeddings + device: Device for cache tensors (default: None = CPU) + dtype: Data type for cache tensors (default: torch.float32) + """ + # Resolve dimensions from config BEFORE calling parent + num_layers = config.num_hidden_layers + num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) + head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + actual_max_cache_len = max_cache_len or getattr( + config, "max_position_embeddings", 2048 + ) + + # Initialize parent StaticCache with required arguments + super().__init__( + config=config, + max_batch_size=max_batch_size, + max_cache_len=actual_max_cache_len, + device=device, + dtype=dtype, + ) + # Call early_initialization to ensure parent's layers are fully initialized + self.early_initialization( + batch_size=max_batch_size, + num_heads=num_heads, + head_dim=head_dim, + dtype=dtype, + device=device, + ) + + # Store dimensions as instance attributes + self.num_layers = num_layers + self.num_heads = num_heads + self.head_dim = head_dim + + # Create KVCache wrappers for each layer - these use mlx::kv_cache_update + # Named 'kv_cache' to match optimum-executorch's ETCustomStaticCache pattern + self.kv_cache = nn.ModuleList( + [ + KVCache( + max_batch_size=max_batch_size, + max_context_length=actual_max_cache_len, + n_heads=num_heads, + head_dim=head_dim, + enable_dynamic_shape=True, + dtype=dtype, + ) + for _ in range(num_layers) + ] + ) + + # Move to device if specified + if device is not None: + self.to(device) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict | None = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update the cache with new key/value states for a specific layer. + + This method follows HuggingFace's StaticCache.update() signature. + + Args: + key_states: New key states [batch_size, num_heads, seq_len, head_dim] + value_states: New value states [batch_size, num_heads, seq_len, head_dim] + layer_idx: Index of the layer to update + cache_kwargs: Dictionary containing 'cache_position' tensor with start position + + Returns: + Tuple of (key_cache, value_cache) for the full cache after update + """ + assert ( + cache_kwargs is not None + ), "cache_kwargs must be provided with 'cache_position'" + cache_position = cache_kwargs.get("cache_position") + assert ( + cache_position is not None + ), "cache_position must be provided in cache_kwargs" + assert isinstance( + cache_position, torch.Tensor + ), "cache_position must be a tensor" + + # Pass cache_position tensor directly to KVCache.update() + # KVCache extracts start_pos internally via input_pos[0].item() + return self.kv_cache[layer_idx].update(cache_position, key_states, value_states) + + def get_seq_length(self, layer_idx: int = 0) -> int: + """Approximate sequence length (counts non-zero cache positions).""" + k_cache = self.kv_cache[layer_idx].k_cache + # Check if any value in the head_dim is non-zero for each position + return (k_cache[0, 0, :, 0] != 0).sum().item() + + def get_max_cache_shape(self, layer_idx: int = 0) -> int: + return self.max_cache_len + + def reset(self): + for layer_cache in self.kv_cache: + layer_cache.k_cache.zero_() + layer_cache.v_cache.zero_() diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index 6743943f9e2..0d5f21ebd7d 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -19,10 +19,141 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union import torch +from executorch.backends.mlx.builder.op_helpers import ( + emit_lifted_constant, + parse_dequant_node, + to_mlx_qparams, + torch_dtype_to_scalar_type, +) from executorch.backends.mlx.builder.op_registry import REGISTRY from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder -from executorch.backends.mlx.builder.slot_manager import Slot -from executorch.backends.mlx.serialization.mlx_graph_schema import AddmmNode +from executorch.backends.mlx.builder.slot_manager import IdType, Slot +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + AbsNode, + AddIntNode, + AddmmNode, + AddNode, + AllNode, + AnyNode, + ARangeNode, + ArccoshNode, + ArccosNode, + ArcsinhNode, + ArcsinNode, + ArctanhNode, + ArctanNode, + ArgmaxNode, + ArgminNode, + ArgPartitionNode, + ArgsortNode, + AsStridedNode, + AsTypeNode, + Atan2Node, + BroadcastToNode, + CeilNode, + ClipNode, + ConcatenateNode, + ContiguousNode, + Conv1DNode, + Conv2DNode, + Conv3DNode, + ConvTranspose1DNode, + ConvTranspose2DNode, + ConvTranspose3DNode, + CoshNode, + CosNode, + CumsumNode, + DequantizeNode, + DivideNode, + EqualNode, + ErfNode, + ExpandDimsNode, + Expm1Node, + ExpNode, + FloatOrVid, + FloorDivideIntNode, + FloorDivideNode, + FloorNode, + FullLikeNode, + FullNode, + GatherNode, + GeluNode, + GreaterEqualNode, + GreaterNode, + IdCopyNode, + IntOrVid, + IntOrVidOrTid, + ItemIntNode, + LayerNormNode, + LessEqualNode, + LessNode, + Log10Node, + Log1pNode, + Log2Node, + LogAddExpNode, + LogicalAndNode, + LogicalNotNode, + LogicalOrNode, + LogNode, + LogSumExpNode, + MaximumNode, + MaxNode, + MeanNode, + MinimumNode, + MinNode, + ModIntNode, + MultiplyIntNode, + MultiplyNode, + NegNode, + NotEqualNode, + PadNode, + PartitionNode, + PowerNode, + ProdNode, + ReciprocalNode, + RemainderNode, + RepeatNode, + ReshapeNode, + RMSNormNode, + RopeNode, + RoundNode, + RsqrtNode, + SigmoidNode, + SignNode, + SiluNode, + SinhNode, + SinNode, + SliceNode, + SliceUpdateNode, + SoftmaxNode, + SortNode, + SplitNode, + SqrtNode, + SquareNode, + SqueezeNode, + StackNode, + StdNode, + SubtractIntNode, + SubtractNode, + SumNode, + SymSizeNode, + TakeAlongAxisNode, + TakeNode, + TanhNode, + TanNode, + TileNode, + TransposeNode, + TrilNode, + TriuNode, + VarNode, + VidOrTid, + WhereNode, +) + +# The coding style is for handlers to register against aten targets +# The corresponding edge ops are automatically registered +# For ops that are not in aten (e.g., dim order ops), directly register on exir_ops +from executorch.exir.dialects._ops import ops as exir_ops from torch.fx.node import Node @@ -219,15 +350,440 @@ def normalize_reduction_dim( return dim, keepdim -<<<<<<< HEAD +_UNARY_OPS: List[Tuple[Any, Any, str]] = [ + # Activations + (torch.ops.aten.silu.default, SiluNode, "aten.silu"), + (torch.ops.aten.sigmoid.default, SigmoidNode, "aten.sigmoid"), + (torch.ops.aten.tanh.default, TanhNode, "aten.tanh"), + # Reciprocal square root + (torch.ops.aten.rsqrt.default, RsqrtNode, "aten.rsqrt"), + # Rounding + (torch.ops.aten.floor.default, FloorNode, "aten.floor"), + (torch.ops.aten.ceil.default, CeilNode, "aten.ceil"), + # Powers / roots + (torch.ops.aten.square.default, SquareNode, "aten.square"), + (torch.ops.aten.exp.default, ExpNode, "aten.exp"), + (torch.ops.aten.sqrt.default, SqrtNode, "aten.sqrt"), + (torch.ops.aten.reciprocal.default, ReciprocalNode, "aten.reciprocal"), + # Trigonometric + (torch.ops.aten.sin.default, SinNode, "aten.sin"), + (torch.ops.aten.cos.default, CosNode, "aten.cos"), + (torch.ops.aten.tan.default, TanNode, "aten.tan"), + (torch.ops.aten.asin.default, ArcsinNode, "aten.asin"), + (torch.ops.aten.acos.default, ArccosNode, "aten.acos"), + (torch.ops.aten.atan.default, ArctanNode, "aten.atan"), + # Hyperbolic + (torch.ops.aten.sinh.default, SinhNode, "aten.sinh"), + (torch.ops.aten.cosh.default, CoshNode, "aten.cosh"), + (torch.ops.aten.asinh.default, ArcsinhNode, "aten.asinh"), + (torch.ops.aten.acosh.default, ArccoshNode, "aten.acosh"), + (torch.ops.aten.atanh.default, ArctanhNode, "aten.atanh"), + # Logarithmic + (torch.ops.aten.log.default, LogNode, "aten.log"), + (torch.ops.aten.log2.default, Log2Node, "aten.log2"), + (torch.ops.aten.log10.default, Log10Node, "aten.log10"), + (torch.ops.aten.log1p.default, Log1pNode, "aten.log1p"), + # Special + (torch.ops.aten.erf.default, ErfNode, "aten.erf"), + (torch.ops.aten.expm1.default, Expm1Node, "aten.expm1"), + # Sign / magnitude + (torch.ops.aten.abs.default, AbsNode, "aten.abs"), + (torch.ops.aten.neg.default, NegNode, "aten.neg"), + (torch.ops.aten.sign.default, SignNode, "aten.sign"), + # Logical + (torch.ops.aten.logical_not.default, LogicalNotNode, "aten.logical_not"), +] + + +def _make_unary_handler(node_cls: Any, op_name: str): + """Create a handler for a simple unary op: x → node_cls(x, out).""" + + def handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 1, 1, op_name) + require_kwargs(P.kwargs(n), set(), op_name) + x = args[0] + out = P.make_or_get_slot(n) + P.emit(node_cls(x=P.slot_to_tid(x), out=P.slot_to_tid(out))) + return out + + handler.__name__ = f"_{op_name.replace('.', '_')}_handler" + handler.__doc__ = f"Handle {op_name} (table-driven unary op)." + return handler + + +for _target, _node_cls, _op_name in _UNARY_OPS: + REGISTRY.register(target=[_target])(_make_unary_handler(_node_cls, _op_name)) + + +_BINARY_OPS: List[Tuple[List[Any], Any, str, bool]] = [ + ( + [torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar], + MultiplyNode, + "aten.mul", + True, + ), + ( + [torch.ops.aten.div.Tensor, torch.ops.aten.div.Scalar], + DivideNode, + "aten.div", + True, + ), + ( + [torch.ops.aten.remainder.Tensor, torch.ops.aten.remainder.Scalar], + RemainderNode, + "aten.remainder", + True, + ), + ( + [torch.ops.aten.pow.Tensor_Tensor, torch.ops.aten.pow.Tensor_Scalar], + PowerNode, + "aten.pow", + True, + ), + ( + [torch.ops.aten.floor_divide.default], + FloorDivideNode, + "aten.floor_divide", + False, + ), + ([torch.ops.aten.maximum.default], MaximumNode, "aten.maximum", False), + ([torch.ops.aten.minimum.default], MinimumNode, "aten.minimum", False), + ([torch.ops.aten.atan2.default], Atan2Node, "aten.atan2", False), + ([torch.ops.aten.logaddexp.default], LogAddExpNode, "aten.logaddexp", False), + ([torch.ops.aten.logical_or.default], LogicalOrNode, "aten.logical_or", False), + ( + [torch.ops.aten.lt.Tensor, torch.ops.aten.lt.Scalar], + LessNode, + "aten.lt", + True, + ), + ( + [torch.ops.aten.le.Tensor, torch.ops.aten.le.Scalar], + LessEqualNode, + "aten.le", + True, + ), + ( + [torch.ops.aten.gt.Tensor, torch.ops.aten.gt.Scalar], + GreaterNode, + "aten.gt", + True, + ), + ( + [torch.ops.aten.ge.Tensor, torch.ops.aten.ge.Scalar], + GreaterEqualNode, + "aten.ge", + True, + ), + ( + [torch.ops.aten.eq.Tensor, torch.ops.aten.eq.Scalar], + EqualNode, + "aten.eq", + True, + ), + ( + [torch.ops.aten.ne.Tensor, torch.ops.aten.ne.Scalar], + NotEqualNode, + "aten.ne", + True, + ), +] + + +def _make_binary_handler(node_cls: Any, op_name: str, lift_b: bool): + """Create a handler for a binary op: (a, b) -> node_cls(a, b, out). + + When lift_b is True, scalar b values are lifted to 0-D constant tensors + via emit_lifted_constant, using a's dtype. + """ + + def handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 2, op_name) + require_kwargs(P.kwargs(n), set(), op_name) + a, b = args[0], args[1] + if lift_b and (not isinstance(b, Slot) or b.id_type != IdType.Tensor): + input_meta = n.args[0].meta.get("val") + dtype = input_meta.dtype if input_meta is not None else torch.float32 + b = emit_lifted_constant(P, b, dtype) + out = P.make_or_get_slot(n) + P.emit(node_cls(a=P.slot_to_tid(a), b=P.slot_to_tid(b), out=P.slot_to_tid(out))) + return out + + handler.__name__ = f"_{op_name.replace('.', '_')}_handler" + handler.__doc__ = f"Handle {op_name} (table-driven binary op)." + return handler + + +for _targets, _node_cls, _op_name, _lift_b in _BINARY_OPS: + REGISTRY.register(target=_targets)( + _make_binary_handler(_node_cls, _op_name, _lift_b) + ) + + +_SCALAR_INT_OPS: List[Tuple[Any, Any, str]] = [ + (operator.add, AddIntNode, "operator.add"), + (operator.sub, SubtractIntNode, "operator.sub"), + (operator.mul, MultiplyIntNode, "operator.mul"), + (operator.floordiv, FloorDivideIntNode, "operator.floordiv"), + (operator.mod, ModIntNode, "operator.mod"), +] + + +def _make_scalar_int_handler(node_cls: Any, op_name: str): + """Create a handler for a scalar int op: (a, b) -> node_cls(a, b, out).""" + + def handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 2, op_name) + require_kwargs(P.kwargs(n), set(), op_name) + a, b = args + out = P.make_or_get_slot(n) + P.emit( + node_cls( + a=P.to_int_or_vid(a), + b=P.to_int_or_vid(b), + out=P.slot_to_vid(out), + ) + ) + return out + + handler.__name__ = f"_{op_name.replace('.', '_')}_handler" + handler.__doc__ = f"Handle {op_name} (table-driven scalar int op)." + return handler + + +for _target, _node_cls, _op_name in _SCALAR_INT_OPS: + REGISTRY.register(target=[_target])(_make_scalar_int_handler(_node_cls, _op_name)) + + +_REDUCTION_OPS: List[Tuple[List[Any], Any, str, int]] = [ + ( + [torch.ops.aten.sum.dim_IntList, torch.ops.aten.sum.default], + SumNode, + "aten.sum", + 4, + ), + ([torch.ops.aten.mean.dim, torch.ops.aten.mean.default], MeanNode, "aten.mean", 4), + ( + [torch.ops.aten.prod.dim_int, torch.ops.aten.prod.default], + ProdNode, + "aten.prod", + 4, + ), + ([torch.ops.aten.amax.default], MaxNode, "aten.amax", 3), + ([torch.ops.aten.amin.default], MinNode, "aten.amin", 3), + ([torch.ops.aten.any.dim, torch.ops.aten.any.default], AnyNode, "aten.any", 3), + ([torch.ops.aten.all.dim, torch.ops.aten.all.default], AllNode, "aten.all", 3), +] + + +def _make_reduction_handler(node_cls: Any, op_name: str, max_args: int): + """Create a handler for a reduction op: x -> node_cls(x, out, axes, keepdims).""" + + def handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 1, max_args, op_name) + require_kwargs(P.kwargs(n), set(), op_name) + x = args[0] + axes, keepdim = normalize_reduction_dim(args) + out = P.make_or_get_slot(n) + P.emit( + node_cls( + x=P.slot_to_tid(x), out=P.slot_to_tid(out), axes=axes, keepdims=keepdim + ) + ) + return out + + handler.__name__ = f"_{op_name.replace('.', '_')}_handler" + handler.__doc__ = f"Handle {op_name} (table-driven reduction op)." + return handler + + +for _targets, _node_cls, _op_name, _max_args in _REDUCTION_OPS: + REGISTRY.register(target=_targets)( + _make_reduction_handler(_node_cls, _op_name, _max_args) + ) + + +_FULL_OPS: List[Tuple[List[Any], str, Optional[float]]] = [ + ([torch.ops.aten.full.default], "aten.full", None), + ([torch.ops.aten.zeros.default], "aten.zeros", 0.0), + ([torch.ops.aten.ones.default], "aten.ones", 1.0), +] + + +def _make_full_handler(op_name: str, fixed_fill: Optional[float]): + """Create a handler for full/zeros/ones: shape -> FullNode(shape, v, dtype).""" + + has_fill_arg = fixed_fill is None + n_args = 2 if has_fill_arg else 1 + + def handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, n_args, n_args, op_name) + kwargs = P.kwargs(n) + require_kwargs(kwargs, {"dtype", "layout", "device", "pin_memory"}, op_name) + require_contiguous_format(layout=kwargs.get("layout"), op_name=op_name) + + shape = args[0] + shape_iovs = [P.to_int_or_vid(d) for d in shape] + v = ( + P.to_float_or_vid(args[1]) + if has_fill_arg + else FloatOrVid.from_literal(fixed_fill) + ) + dtype = n.kwargs.get("dtype") + if dtype is None: + dtype = torch.float32 + + out = P.make_or_get_slot(n) + P.emit( + FullNode( + out=P.slot_to_tid(out), + shape=shape_iovs, + v=v, + scalar_type=torch_dtype_to_scalar_type(dtype), + ) + ) + return out + + handler.__name__ = f"_{op_name.replace('.', '_')}_handler" + handler.__doc__ = f"Handle {op_name} (table-driven full op)." + return handler + + +for _targets, _op_name, _fixed_fill in _FULL_OPS: + REGISTRY.register(target=_targets)(_make_full_handler(_op_name, _fixed_fill)) + + +_FULL_LIKE_OPS: List[Tuple[List[Any], str, Optional[float]]] = [ + ([torch.ops.aten.full_like.default], "aten.full_like", None), + ([torch.ops.aten.zeros_like.default], "aten.zeros_like", 0.0), + ([torch.ops.aten.ones_like.default], "aten.ones_like", 1.0), +] + + +def _make_full_like_handler(op_name: str, fixed_fill: Optional[float]): + """Create a handler for full_like/zeros_like/ones_like: x -> FullLikeNode(x, v, dtype).""" + + has_fill_arg = fixed_fill is None + n_args = 2 if has_fill_arg else 1 + + def handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, n_args, n_args, op_name) + kwargs = P.kwargs(n) + require_kwargs( + kwargs, + {"dtype", "layout", "device", "pin_memory", "memory_format"}, + op_name, + ) + require_contiguous_format( + layout=kwargs.get("layout"), + memory_format=kwargs.get("memory_format"), + op_name=op_name, + ) + + x = args[0] + v = ( + P.to_float_or_vid(args[1]) + if has_fill_arg + else FloatOrVid.from_literal(fixed_fill) + ) + dtype = n.kwargs.get("dtype") + + out = P.make_or_get_slot(n) + P.emit( + FullLikeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + v=v, + scalar_type=( + torch_dtype_to_scalar_type(dtype) if dtype is not None else None + ), + ) + ) + return out + + handler.__name__ = f"_{op_name.replace('.', '_')}_handler" + handler.__doc__ = f"Handle {op_name} (table-driven full_like op)." + return handler + + +for _targets, _op_name, _fixed_fill in _FULL_LIKE_OPS: + REGISTRY.register(target=_targets)(_make_full_like_handler(_op_name, _fixed_fill)) + + +@REGISTRY.register(target=["NOOP", torch.ops.aten._assert_scalar.default]) +def _noop_handler(P: MLXProgramBuilder, n: Node) -> None: + """No-op handler for nodes that don't emit any MLX instructions.""" + return None + + +# Handler for auto_functionalized_v2 higher-order op +# This handles mutating ops that have been functionalized +@REGISTRY.register(target=[torch.ops.higher_order.auto_functionalized_v2]) +def _auto_functionalized_v2_handler(P: MLXProgramBuilder, n: Node): + """ + Handler for auto_functionalized_v2 higher-order op. + + auto_functionalized_v2 wraps mutating ops after functionalization. + It returns a tuple of (token, mutated_values...). + + This handler emits the actual lowering instructions and returns a tuple + of slots that getitem can index into. + """ + if len(n.args) < 1: + raise ValueError( + f"auto_functionalized_v2 requires at least 1 arg, got {len(n.args)}" + ) + + wrapped_op = n.args[0] + + # Unknown wrapped op - not supported + raise NotImplementedError( + f"auto_functionalized_v2 wrapping '{wrapped_op}' is not supported." + ) + + +@REGISTRY.register(target=[torch.ops.aten.linear.default]) +def _linear_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 3, "aten.linear") + require_kwargs(P.kwargs(n), set(), "aten.linear") + x, w = args[0], args[1] + b = args[2] if len(args) > 2 else None + out = P.make_or_get_slot(n) + + # Transpose weight: linear(x, w) = x @ w.T + _, w_t = P.make_tmp_slot() + P.emit( + TransposeNode( + x=P.slot_to_tid(w), + out=P.slot_to_tid(w_t), + perm=[1, 0], + ) + ) + + P.emit( + AddmmNode( + mat1=P.slot_to_tid(x), + mat2=P.slot_to_tid(w_t), + out=P.slot_to_tid(out), + bias=P.slot_to_tid(b) if b else None, + ) + ) + return out + + @REGISTRY.register(target=["NOOP", torch.ops.aten._assert_scalar.default]) def _noop_handler(P: MLXProgramBuilder, n: Node) -> None: """No-op handler for nodes that don't emit any MLX instructions.""" return None -======= ->>>>>>> e3b488076f (up) @REGISTRY.register(target=[torch.ops.aten.addmm.default]) def _addmm_handler(P: MLXProgramBuilder, n: Node) -> Slot: """Handle addmm: self + (mat1 @ mat2). @@ -301,3 +857,3105 @@ def _mm_handler(P: MLXProgramBuilder, n: Node) -> Slot: ) ) return out + + +@REGISTRY.register( + target=[ + torch.ops.aten.view.default, + torch.ops.aten.view_copy.default, + torch.ops.aten.reshape.default, + ] +) +def _view_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 2, "aten.view") + require_kwargs(P.kwargs(n), set(), "aten.view") + x, shape = args + out = P.make_or_get_slot(n) + + shape_iovs = [P.to_int_or_vid(s) for s in shape] + P.emit( + ReshapeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + shape=shape_iovs, + ) + ) + return out + + +@REGISTRY.register( + target=[ + torch.ops.aten.clone.default, + torch.ops.aten.alias.default, + torch.ops.aten.alias_copy.default, + ] +) +def _clone_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + kwargs = P.kwargs(n) + require_args(args, 1, 1, "aten.clone") + require_kwargs(kwargs, {"memory_format"}, "aten.clone") + require_contiguous_format( + memory_format=kwargs.get("memory_format"), + op_name="aten.clone", + ) + (x,) = args + out = P.make_or_get_slot(n) + P.emit( + ContiguousNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.copy.default]) +def _copy_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.copy - copy data from src to self. + + Schema: aten::copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor + In functionalized Edge IR, this returns a copy of src (args[1]). + """ + args = P.args(n) + require_args(args, 2, 2, "aten.copy") + require_kwargs(P.kwargs(n), {"non_blocking"}, "aten.copy") + src = args[1] + out = P.make_or_get_slot(n) + P.emit( + ContiguousNode( + x=P.slot_to_tid(src), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[exir_ops.edge.dim_order_ops._clone_dim_order.default]) +def _dim_order_clone_handler(P: MLXProgramBuilder, n: Node) -> Slot: + # dim_order_ops._clone_dim_order(Tensor self, *, bool non_blocking=False, int[]? dim_order=None) -> Tensor + # This is essentially a contiguous/clone operation for memory layout + args = P.args(n) + kwargs = P.kwargs(n) + require_args(args, 1, 1, "dim_order_ops._clone_dim_order") + require_kwargs( + kwargs, {"non_blocking", "dim_order"}, "dim_order_ops._clone_dim_order" + ) + require_contiguous_format( + dim_order=kwargs.get("dim_order"), + op_name="dim_order_ops._clone_dim_order", + ) + x = args[0] + out = P.make_or_get_slot(n) + P.emit( + ContiguousNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + ) + ) + return out + + +# Handle Edge IR's dim_order_ops._to_dim_order_copy (dtype conversion) +# This is what x.to(dtype) becomes after to_edge() transformation +@REGISTRY.register(target=[exir_ops.edge.dim_order_ops._to_dim_order_copy.default]) +def _dim_order_copy_handler(P: MLXProgramBuilder, n: Node) -> Slot: + # dim_order_ops._to_dim_order_copy(Tensor self, *, ScalarType? dtype=None, ...) + # If dtype is specified, this is a dtype conversion (use AsTypeNode) + # If dtype is None/same, this is just a memory layout copy (use ContiguousNode) + args = P.args(n) + kwargs = P.kwargs(n) + require_args(args, 1, 1, "dim_order_ops._to_dim_order_copy") + require_kwargs( + kwargs, + {"dtype", "device", "layout", "non_blocking", "dim_order"}, + "dim_order_ops._to_dim_order_copy", + ) + require_contiguous_format( + layout=kwargs.get("layout"), + dim_order=kwargs.get("dim_order"), + op_name="dim_order_ops._to_dim_order_copy", + ) + x = args[0] + out = P.make_or_get_slot(n) + + dtype = kwargs.get("dtype") + if dtype is not None: + # Dtype conversion + P.emit( + AsTypeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + scalar_type=torch_dtype_to_scalar_type(dtype), + ) + ) + else: + # No dtype change, just memory layout (contiguous) + P.emit( + ContiguousNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten._to_copy.default]) +def _to_copy_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten._to_copy - lower-level dtype/device conversion.""" + # aten._to_copy(Tensor self, *, ScalarType? dtype=None, ...) + args = P.args(n) + kwargs = P.kwargs(n) + require_args(args, 1, 1, "aten._to_copy") + require_kwargs( + kwargs, {"dtype", "device", "layout", "memory_format"}, "aten._to_copy" + ) + require_contiguous_format( + layout=kwargs.get("layout"), + memory_format=kwargs.get("memory_format"), + op_name="aten._to_copy", + ) + x = args[0] + out = P.make_or_get_slot(n) + + dtype = kwargs.get("dtype") + if dtype is not None: + # Dtype conversion + P.emit( + AsTypeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + scalar_type=torch_dtype_to_scalar_type(dtype), + ) + ) + else: + # No dtype change, just copy (use contiguous) + P.emit( + ContiguousNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.embedding.default]) +def _embedding_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 3, "aten.embedding") + # "padding_idx", "scale_grad_by_freq", "sparse" are training only args + # and ignored + require_kwargs( + P.kwargs(n), {"padding_idx", "scale_grad_by_freq", "sparse"}, "aten.embedding" + ) + w, x = args[0], args[1] + # padding_idx (args[2] if present) is ignored - only affects gradients + out = P.make_or_get_slot(n) + P.emit( + TakeNode( + x=P.slot_to_tid(w), + index=IntOrVidOrTid.from_tid(P.slot_to_tid(x)), + out=P.slot_to_tid(out), + axis=0, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.add.Tensor, torch.ops.aten.add.Scalar]) +def _add_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.add.Tensor: a + alpha * b.""" + args = P.args(n) + require_args(args, 2, 2, "aten.add.Tensor") + require_kwargs(P.kwargs(n), {"alpha"}, "aten.add.Tensor") + a, b = args + input_meta = n.args[0].meta.get("val") + dtype = input_meta.dtype if input_meta is not None else torch.float32 + if not isinstance(b, Slot): + b = emit_lifted_constant(P, b, dtype) + alpha = P.kwargs(n).get("alpha", 1) + if alpha != 1: + alpha_slot = emit_lifted_constant(P, alpha, dtype) + _, tmp = P.make_tmp_slot() + P.emit( + MultiplyNode( + a=P.slot_to_tid(b), + b=P.slot_to_tid(alpha_slot), + out=P.slot_to_tid(tmp), + ) + ) + b = tmp + out = P.make_or_get_slot(n) + P.emit( + AddNode( + a=P.slot_to_tid(a), + b=P.slot_to_tid(b), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.div.Tensor_mode]) +def _div_tensor_mode_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.div.Tensor_mode with rounding mode.""" + args = P.args(n) + kwargs = P.kwargs(n) + require_args(args, 2, 2, "aten.div.Tensor_mode") + require_kwargs(kwargs, {"rounding_mode"}, "aten.div.Tensor_mode") + out = P.make_or_get_slot(n) + a = args[0] + b = args[1] + rounding_mode = kwargs.get("rounding_mode", None) + + # Handle scalar b by creating a constant tensor + if not isinstance(b, Slot): + b = P.make_or_get_constant( + f"_scalar_{b}", torch.tensor([b], dtype=n.meta["val"].dtype) + ) + + # Handle scalar a + if not isinstance(a, Slot): + a = P.make_or_get_constant( + f"_scalar_{a}", torch.tensor([a], dtype=n.meta["val"].dtype) + ) + + if rounding_mode == "trunc": + raise NotImplementedError( + "aten.div.Tensor_mode with rounding_mode='trunc' is not supported. " + "MLX does not have a truncate operation." + ) + elif rounding_mode == "floor": + P.emit( + FloorDivideNode( + a=P.slot_to_tid(a), + b=P.slot_to_tid(b), + out=P.slot_to_tid(out), + ) + ) + else: + # rounding_mode is None - true division + P.emit( + DivideNode( + a=P.slot_to_tid(a), + b=P.slot_to_tid(b), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten._softmax.default]) +def _softmax_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle softmax: computes softmax along the specified dimension. + + aten._softmax(self, dim, half_to_float) computes: + softmax(self, axis=dim) + + The half_to_float parameter is for type conversion and is ignored for MLX. + """ + args = P.args(n) + require_args(args, 3, 3, "aten._softmax") + require_kwargs(P.kwargs(n), set(), "aten._softmax") + x, dim, _ = args[0], args[1], args[2] # half_to_float is unused for MLX + + out = P.make_or_get_slot(n) + + # Emit SoftmaxNode with the specified axis + P.emit( + SoftmaxNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + axis=dim, + precise=False, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.gelu.default]) +def _gelu_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + kwargs = P.kwargs(n) + require_args(args, 1, 1, "aten.gelu") + require_kwargs(kwargs, {"approximate"}, "aten.gelu") + (x,) = args + # GELU approximate mode: 'none' (default) or 'tanh' + approximate = kwargs.get("approximate", "none") + out = P.make_or_get_slot(n) + P.emit( + GeluNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + approximate=approximate, + ) + ) + return out + + +@REGISTRY.register( + target=[torch.ops.aten.permute.default, torch.ops.aten.permute_copy.default] +) +def _permute_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 2, "aten.permute") + require_kwargs(P.kwargs(n), set(), "aten.permute") + x, dims = args + out = P.make_or_get_slot(n) + P.emit( + TransposeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + perm=list(dims), + ) + ) + return out + + +@REGISTRY.register( + target=[torch.ops.aten.transpose.int, torch.ops.aten.transpose_copy.int] +) +def _transpose_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 3, 3, "aten.transpose") + require_kwargs(P.kwargs(n), set(), "aten.transpose") + x, dim0, dim1 = args + perm = list(range(len(n.meta["val"].shape))) + perm[dim0], perm[dim1] = perm[dim1], perm[dim0] + out = P.make_or_get_slot(n) + P.emit( + TransposeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + perm=perm, + ) + ) + return out + + +@REGISTRY.register( + target=[torch.ops.aten.slice.Tensor, torch.ops.aten.slice_copy.Tensor] +) +def _slice_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 4, 5, "aten.slice") + require_kwargs(P.kwargs(n), set(), "aten.slice") + x, dim, start, stop = args[0], args[1], args[2], args[3] + step = args[4] if len(args) > 4 else 1 + if start is None: + start = 0 + require_static_int(step, "step", "aten.slice") + assert step >= 1, f"aten.slice: step must be >= 1, got {step}" + out = P.make_or_get_slot(n) + P.emit( + SliceNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + axis=P.to_int_or_vid(dim), + start=P.to_int_or_vid(start), + stop=P.to_int_or_vid(stop), + step=step, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.narrow.default]) +def _narrow_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """ + Handle narrow(input, dim, start, length) -> slice(input, dim, start, start+length). + + This is needed for KV cache updates with dynamic positions where narrow + is preferred over slice syntax for better torch.export compatibility. + """ + args = P.args(n) + require_args(args, 4, 4, "aten.narrow") + require_kwargs(P.kwargs(n), set(), "aten.narrow") + x, dim, start, length = args + out = P.make_or_get_slot(n) + + # Convert narrow (start, length) to slice (start, end) + # The end is start + length + start_iov = P.to_int_or_vid(start) + length_iov = P.to_int_or_vid(length) + + # For stop = start + length, we need to emit an ADD_SCALAR if either is a Vid + if isinstance(start_iov, IntOrVid) and start_iov.vid is not None: + # start is a Vid, need to add at runtime + if isinstance(length_iov, IntOrVid) and length_iov.vid is not None: + # Both are Vids - emit add to compute stop + _, stop_slot = P.make_tmp_value_slot() + stop_vid = P.slot_to_vid(stop_slot) + P.emit( + AddIntNode( + a=start_iov.vid, + b=length_iov.vid, + out=stop_vid, + ) + ) + stop_iov = IntOrVid(int64=None, vid=stop_vid) + else: + # start is Vid, length is int - emit add scalar + _, stop_slot = P.make_tmp_value_slot() + stop_vid = P.slot_to_vid(stop_slot) + P.emit( + AddIntNode( + a=start_iov.vid, + b=( + length_iov.int64 + if isinstance(length_iov, IntOrVid) + else length_iov + ), + out=stop_vid, + ) + ) + stop_iov = IntOrVid(int64=None, vid=stop_vid) + elif isinstance(length_iov, IntOrVid) and length_iov.vid is not None: + # length is Vid, start is int - emit add scalar + start_val = start_iov.int64 if isinstance(start_iov, IntOrVid) else start_iov + _, stop_slot = P.make_tmp_value_slot() + stop_vid = P.slot_to_vid(stop_slot) + P.emit( + AddIntNode( + a=length_iov.vid, + b=start_val, + out=stop_vid, + ) + ) + stop_iov = IntOrVid(int64=None, vid=stop_vid) + else: + # Both are concrete ints + start_val = start_iov.int64 if isinstance(start_iov, IntOrVid) else start_iov + length_val = ( + length_iov.int64 if isinstance(length_iov, IntOrVid) else length_iov + ) + stop_iov = IntOrVid(int64=start_val + length_val, vid=None) + + P.emit( + SliceNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + axis=P.to_int_or_vid(dim), + start=start_iov, + stop=stop_iov, + ) + ) + return out + + +@REGISTRY.register( + target=[torch.ops.aten.unsqueeze.default, torch.ops.aten.unsqueeze_copy.default] +) +def _unsqueeze_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 2, "aten.unsqueeze") + require_kwargs(P.kwargs(n), set(), "aten.unsqueeze") + x, dim = args + out = P.make_or_get_slot(n) + P.emit( + ExpandDimsNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + axis=dim, + ) + ) + return out + + +@REGISTRY.register( + target=[torch.ops.aten.squeeze.dims, torch.ops.aten.squeeze_copy.dims] +) +def _squeeze_dims_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle squeeze operation for specific dimensions. + + Removes dimensions of size 1 from the tensor at specified positions. + """ + args = P.args(n) + require_args(args, 2, 2, "aten.squeeze.dims") + require_kwargs(P.kwargs(n), set(), "aten.squeeze.dims") + x, dims = args + out = P.make_or_get_slot(n) + + dims_list = list(dims) if dims is not None else None + + P.emit( + SqueezeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + dims=dims_list, + ) + ) + return out + + +@REGISTRY.register( + target=[torch.ops.aten.squeeze.default, torch.ops.aten.squeeze_copy.default] +) +def _squeeze_default_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle squeeze operation without specified dimensions. + + Removes all dimensions of size 1 from the tensor. + """ + args = P.args(n) + require_args(args, 1, 1, "aten.squeeze.default") + require_kwargs(P.kwargs(n), set(), "aten.squeeze.default") + (x,) = args + out = P.make_or_get_slot(n) + + P.emit( + SqueezeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + dims=None, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.cat.default]) +def _cat_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle concatenation of a list of tensors. + + Concatenates tensors along a specified dimension. + All tensors must have the same shape except in the concatenating dimension. + """ + args = P.args(n) + require_args(args, 1, 2, "aten.cat") + require_kwargs(P.kwargs(n), set(), "aten.cat") + # aten.cat.default signature: cat(Tensor[] tensors, int dim=0) -> Tensor + # args can be (tensors_list,) or (tensors_list, dim) + tensors_list = args[0] + dim = args[1] if len(args) > 1 else 0 + + out = P.make_or_get_slot(n) + + # Convert list of tensor slots to list of Tids + tensor_tids = [P.slot_to_tid(t) for t in tensors_list] + + # dim is typically an int + axis = dim if dim is not None else 0 + + P.emit( + ConcatenateNode( + tensors=tensor_tids, + out=P.slot_to_tid(out), + axis=axis, + ) + ) + return out + + +@REGISTRY.register( + target=[ + torch.ops.aten.split_with_sizes.default, + torch.ops.aten.split_with_sizes_copy.default, + ] +) +def _split_with_sizes_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle split_with_sizes operation. + + Splits a tensor into chunks with specified sizes along a dimension. + Returns a tuple of output slots that getitem can extract from. + + PyTorch: split_with_sizes(x, [2, 3, 4], dim=1) + MLX: split(x, indices=[2, 5], axis=1) # indices are cumulative positions + """ + args = P.args(n) + require_args(args, 2, 3, "aten.split_with_sizes") + require_kwargs(P.kwargs(n), set(), "aten.split_with_sizes") + x = args[0] + sizes = args[1] + dim = args[2] if len(args) > 2 else 0 # dim has default value of 0 + + # Convert sizes to IntOrVid (supports both static ints and dynamic Vids) + sizes_int_or_vid = [P.to_int_or_vid(s) for s in sizes] + + axis = dim if dim is not None else 0 + + # Create output slots for multi-output operation + # make_or_get_slots automatically creates slots based on node.meta["val"] + output_slots = P.make_or_get_slots(n) + + # Emit SplitNode with all output slots + P.emit( + SplitNode( + x=P.slot_to_tid(x), + outs=[P.slot_to_tid(s) for s in output_slots], + sizes=sizes_int_or_vid, + axis=axis, + ) + ) + + # Return tuple of slots - getitem will extract individual elements + return output_slots + + +@REGISTRY.register( + target=[torch.ops.aten.split.Tensor, torch.ops.aten.split_copy.Tensor] +) +def _split_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle split operation with uniform chunk size. + + Splits a tensor into chunks of a given size along a dimension. + The last chunk may be smaller if the dimension does not divide evenly. + + PyTorch: split(x, split_size, dim=0) + + We pass [split_size] to the interpreter, which computes the actual + chunk sizes based on the tensor dimension. + """ + args = P.args(n) + require_args(args, 2, 3, "aten.split") + require_kwargs(P.kwargs(n), set(), "aten.split") + x = args[0] + split_size = args[1] + dim = args[2] if len(args) > 2 else 0 + + axis = dim if dim is not None else 0 + if axis < 0: + x_meta = n.args[0].meta.get("val") + if x_meta is None: + raise RuntimeError("split: missing tensor metadata for negative axis") + axis += len(x_meta.shape) + + # Create output slots for multi-output operation + output_slots = P.make_or_get_slots(n) + + # Emit SplitNode - interpreter computes actual chunk sizes from split_size + P.emit( + SplitNode( + x=P.slot_to_tid(x), + outs=[P.slot_to_tid(s) for s in output_slots], + sizes=[P.to_int_or_vid(split_size)], + axis=axis, + ) + ) + + return output_slots + + +@REGISTRY.register(target=[torch.ops.aten.repeat.default]) +def _repeat_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 2, "aten.repeat") + require_kwargs(P.kwargs(n), set(), "aten.repeat") + x, reps = args + + # Convert reps to IntOrVid (supports both static ints and dynamic Vids) + reps_int_or_vid = [P.to_int_or_vid(r) for r in reps] + + out = P.make_or_get_slot(n) + P.emit( + TileNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + reps=reps_int_or_vid, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.index.Tensor]) +def _index_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 2, "aten.index.Tensor") + require_kwargs(P.kwargs(n), set(), "aten.index.Tensor") + x, idx_list = args + if not isinstance(idx_list, list) or len(idx_list) == 0: + raise ValueError( + f"aten.index.Tensor requires a list of index tensors, " + f"got {type(idx_list)}" + ) + + x_meta = n.args[0].meta.get("val") + x_ndim = len(x_meta.shape) if x_meta is not None else None + + # Filter out None indices and track which axes they correspond to + non_none = [(i, idx) for i, idx in enumerate(idx_list) if idx is not None] + + if len(non_none) == 0: + raise ValueError("aten.index.Tensor: all indices are None") + + if len(non_none) == 1: + axis, idx = non_none[0] + idx_meta = n.args[1][axis].meta.get("val") + ndim_match = ( + x_meta is not None + and idx_meta is not None + and len(x_meta.shape) == len(idx_meta.shape) + ) + out = P.make_or_get_slot(n) + if ndim_match: + # Same ndim: use TakeAlongAxisNode (element-wise gather) + P.emit( + TakeAlongAxisNode( + x=P.slot_to_tid(x), + indices=P.slot_to_tid(idx), + out=P.slot_to_tid(out), + axis=axis, + ) + ) + else: + # Different ndim (e.g. 1D indices into 3D tensor): use TakeNode + P.emit( + TakeNode( + x=P.slot_to_tid(x), + index=IntOrVidOrTid.from_tid(P.slot_to_tid(idx)), + out=P.slot_to_tid(out), + axis=axis, + ) + ) + return out + + # Multi-index: use GatherNode (maps to mlx::gather) + if x_meta is None or x_ndim is None: + raise ValueError( + "aten.index.Tensor with multiple indices requires input shape metadata" + ) + + indices = [P.slot_to_tid(idx) for _, idx in non_none] + axes = [i for i, _ in non_none] + + # slice_sizes: 1 for indexed axes, full dim size for non-indexed axes + # Use int() to handle SymInt values from dynamic shapes + indexed_axes = set(axes) + slice_sizes = [] + for dim in range(x_ndim): + if dim in indexed_axes: + slice_sizes.append(1) + else: + dim_size = x_meta.shape[dim] + if not isinstance(dim_size, int): + raise ValueError( + f"aten.index.Tensor: non-indexed dimension {dim} has dynamic size " + f"{dim_size}, which is not supported with multi-index gather" + ) + slice_sizes.append(dim_size) + + # Emit gather — output shape is broadcast(indices).shape + slice_sizes + _, gather_slot = P.make_tmp_slot() + P.emit( + GatherNode( + x=P.slot_to_tid(x), + indices=indices, + out=P.slot_to_tid(gather_slot), + axes=axes, + slice_sizes=slice_sizes, + ) + ) + + # Reshape to match aten.index.Tensor output shape, which strips the + # trailing dimensions introduced by gather's slice_sizes + out_meta = n.meta.get("val") + if out_meta is None: + raise ValueError( + "aten.index.Tensor: output shape metadata required for reshape after gather" + ) + out_shape = [P.to_int_or_vid(int(d)) for d in out_meta.shape] + + out = P.make_or_get_slot(n) + P.emit( + ReshapeNode( + x=P.slot_to_tid(gather_slot), + out=P.slot_to_tid(out), + shape=out_shape, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.index_select.default]) +def _index_select_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.index_select: select elements along an axis using a 1D index tensor. + + index_select(input, dim, index) returns input.take(index, axis=dim). + Unlike select (which takes a scalar index and removes the dim), + index_select takes a tensor of indices and preserves the dim. + """ + args = P.args(n) + require_args(args, 3, 3, "aten.index_select") + require_kwargs(P.kwargs(n), set(), "aten.index_select") + x, dim, indices = args + out = P.make_or_get_slot(n) + P.emit( + TakeNode( + x=P.slot_to_tid(x), + index=IntOrVidOrTid.from_tid(P.slot_to_tid(indices)), + out=P.slot_to_tid(out), + axis=dim, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.slice_scatter.default]) +def _slice_scatter_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.slice_scatter: return a copy of self with self[dim][start:end:step] = src.""" + args = P.args(n) + require_args(args, 2, 6, "aten.slice_scatter") + require_kwargs(P.kwargs(n), set(), "aten.slice_scatter") + self_tensor = args[0] + src = args[1] + dim = args[2] if len(args) > 2 else 0 + start = args[3] if len(args) > 3 else 0 + end = args[4] if len(args) > 4 else None + step = args[5] if len(args) > 5 else 1 + + # If end is None, default to dim size + if end is None: + input_meta = n.args[0].meta.get("val") + if input_meta is not None: + end = input_meta.shape[dim] + else: + raise ValueError( + "aten.slice_scatter: end=None requires input shape metadata" + ) + + require_static_int(step, "step", "aten.slice_scatter") + assert step >= 1, f"aten.slice_scatter: step must be >= 1, got {step}" + + out = P.make_or_get_slot(n) + P.emit( + SliceUpdateNode( + dst=P.slot_to_tid(self_tensor), + update=P.slot_to_tid(src), + out=P.slot_to_tid(out), + axis=P.to_int_or_vid(dim), + start=P.to_int_or_vid(start), + stop=P.to_int_or_vid(end), + step=step, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.select.int, torch.ops.aten.select_copy.int]) +def _select_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """ + Handle aten.select_copy.int - select a single index along a dimension. + + select_copy(input, dim, index) returns input[..., index, ...] where the + indexing happens at dimension `dim`. The selected dimension is removed. + + Maps to MLX's take(array, int index, axis) which also removes the dimension. + """ + args = P.args(n) + require_args(args, 3, 3, "aten.select_copy.int") + require_kwargs(P.kwargs(n), set(), "aten.select_copy.int") + x, dim, index = args + out = P.make_or_get_slot(n) + P.emit( + TakeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + index=P.to_int_or_vid_or_tid(index), + axis=dim, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.sym_size.int]) +def _sym_size_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 2, "aten.sym_size.int") + require_kwargs(P.kwargs(n), set(), "aten.sym_size.int") + a, dim = args + out = P.make_or_get_slot(n) + P.emit( + SymSizeNode( + a=P.slot_to_tid(a), + dim=dim, + out=P.slot_to_vid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.item.default]) +def _item_handler(P: MLXProgramBuilder, n: Node) -> Slot: + if not isinstance(n.meta["val"], torch.SymInt): + raise ValueError("item only supported if it returns a SymInt") + args = P.args(n) + require_args(args, 1, 1, "aten.item") + require_kwargs(P.kwargs(n), set(), "aten.item") + (x,) = args + out = P.make_or_get_slot(n) + P.emit( + ItemIntNode( + x=P.slot_to_tid(x), + out=P.slot_to_vid(out), + ) + ) + return out + + +@REGISTRY.register(target=[operator.getitem]) +def _getitem_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """ + Handle getitem(tuple, idx) - extracts element from a tuple of slots. + + The source tuple comes from ops that return multiple values (like + auto_functionalized_v2). Those handlers return tuples of slots, + and we just ID_COPY the selected element to a new output slot. + """ + args = P.args(n) + require_args(args, 2, 2, "operator.getitem") + require_kwargs(P.kwargs(n), set(), "operator.getitem") + a, idx = args + out = P.make_or_get_slot(n) + P.emit( + IdCopyNode( + x=P.slot_to_tid(a[idx]), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.layer_norm.default]) +def _layer_norm_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 5, "aten.layer_norm") + require_kwargs(P.kwargs(n), set(), "aten.layer_norm") + x, shape = args[0:2] + if len(shape) > 1: + raise ValueError( + "LayerNorm is only supported when normalizing over the last dimension" + ) + w = args[2] if len(args) > 2 else None + bias = args[3] if len(args) > 3 else None + eps = args[4] if len(args) > 4 else 1e-5 + + out = P.make_or_get_slot(n) + P.emit( + LayerNormNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + weight=P.slot_to_tid(w) if w else None, + bias=P.slot_to_tid(bias) if bias else None, + eps=eps, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.native_layer_norm.default]) +def _native_layer_norm_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle native_layer_norm which returns (output, mean, rstd). + + Only the normalized output (index 0) is computed via fast::layer_norm; + mean and rstd (indices 1 and 2) are needed only for backward. + """ + # Verify mean/rstd outputs are unused — we only compute the normalized output. + unsupported = used_getitem_indices(n) & {1, 2} + if unsupported: + raise ValueError( + f"native_layer_norm outputs {unsupported} (mean/rstd) are used, " + "but only the normalized output (index 0) is supported" + ) + + args = P.args(n) + require_args(args, 2, 5, "aten.native_layer_norm") + require_kwargs(P.kwargs(n), set(), "aten.native_layer_norm") + x, shape = args[0:2] + if len(shape) > 1: + raise ValueError( + "LayerNorm is only supported when normalizing over the last dimension" + ) + w = args[2] if len(args) > 2 else None + bias = args[3] if len(args) > 3 else None + eps = args[4] if len(args) > 4 else 1e-5 + + # native_layer_norm returns (output, mean, rstd) — allocate all 3 slots + output_slots = P.make_or_get_slots(n) + + P.emit( + LayerNormNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(output_slots[0]), + weight=P.slot_to_tid(w) if w else None, + bias=P.slot_to_tid(bias) if bias else None, + eps=eps, + ) + ) + return output_slots + + +@REGISTRY.register(target=[torch.ops.aten.arange.default]) +def _arange_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle arange with just stop, or (start, stop) or (start, stop, step). + + Supports both static (literal int) and dynamic (Slot from item()) values. + """ + args = P.args(n) + kwargs = P.kwargs(n) + require_args(args, 1, 3, "aten.arange") + require_kwargs(kwargs, {"dtype", "layout", "device", "pin_memory"}, "aten.arange") + require_contiguous_format( + layout=kwargs.get("layout"), + op_name="aten.arange", + ) + if len(args) == 1: + start = 0 + stop = args[0] + else: + start, stop = args[0:2] + step = args[2] if len(args) > 2 else 1 + + # arange defaults to int64 when dtype is not specified (like torch.arange) + dtype = kwargs.get("dtype", torch.int64) + scalar_type_val = torch_dtype_to_scalar_type(dtype) + + out = P.make_or_get_slot(n) + P.emit( + ARangeNode( + out=P.slot_to_tid(out), + start=P.to_int_or_vid(start), + stop=P.to_int_or_vid(stop), + step=P.to_int_or_vid(step), + scalar_type=scalar_type_val, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.arange.start_step]) +def _arange_start_step_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle arange with start, end, and step arguments. + + Supports both static (literal int) and dynamic (Slot from item()) start/stop/step. + """ + args = P.args(n) + kwargs = P.kwargs(n) + require_args(args, 2, 3, "aten.arange.start_step") + require_kwargs( + kwargs, {"dtype", "layout", "device", "pin_memory"}, "aten.arange.start_step" + ) + require_contiguous_format( + layout=kwargs.get("layout"), + op_name="aten.arange.start_step", + ) + start = args[0] + stop = args[1] + step = args[2] if len(args) > 2 else 1 + + # arange defaults to int64 when dtype is not specified (like torch.arange) + dtype = kwargs.get("dtype", torch.int64) + scalar_type_val = torch_dtype_to_scalar_type(dtype) + + out = P.make_or_get_slot(n) + P.emit( + ARangeNode( + out=P.slot_to_tid(out), + start=P.to_int_or_vid(start), + stop=P.to_int_or_vid(stop), + step=P.to_int_or_vid(step), + scalar_type=scalar_type_val, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.rms_norm.default]) +def _aten_rms_norm_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 4, "aten.rms_norm") + require_kwargs(P.kwargs(n), set(), "aten.rms_norm") + x, normalized_shape = args[0], args[1] + if len(normalized_shape) > 1: + raise ValueError( + "RMSNorm is only supported when normalizing over the last dimension" + ) + w = args[2] if len(args) > 2 else None + eps = args[3] if len(args) > 3 else 1e-5 + out = P.make_or_get_slot(n) + P.emit( + RMSNormNode( + x=P.slot_to_tid(x), + weight=P.slot_to_tid(w) if w else None, + out=P.slot_to_tid(out), + eps=eps, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.mlx.rope.default]) +def _rope_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 3, 7, "mlx.rope") + require_kwargs(P.kwargs(n), set(), "mlx.rope") + x, dims, pos = args[0], args[1], args[2] + traditional = args[3] if len(args) > 3 else False + base = args[4] if len(args) > 4 else 500000.0 + scale = args[5] if len(args) > 5 else 1.0 + freqs = args[6] if len(args) > 6 else None + out = P.make_or_get_slot(n) + + # pos must be a Slot (SymInt) from input_pos.item() during tracing + # The schema supports both Vid (scalar) and Tid (tensor) for offset + if not isinstance(pos, Slot): + raise ValueError( + f"RopeNode.offset must be a SymInt (traced via tensor.item()), got {type(pos)}. " + "Make sure input_pos is a tensor and you call input_pos.item() to get a SymInt." + ) + + P.emit( + RopeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + dims=dims, + offset=VidOrTid.from_vid(P.slot_to_vid(pos)), + freqs=P.slot_to_tid(freqs) if freqs else None, + traditional=traditional, + base=base, + scale=scale, + ) + ) + + return out + + +def _emit_channel_last_weight(P: MLXProgramBuilder, w_node: Node, perm: list) -> Slot: + """Get convolution weight in channel-last format. + + If the weight is a placeholder (static parameter), permute at compile time + and store as a constant. If it comes from another node (e.g. dequantize + output), emit a runtime TransposeNode instead. + """ + if w_node.op == "placeholder": + w_target, w_tensor = P.get_placeholder_target_and_tensor(w_node) + return P.make_or_get_constant( + f"{w_target}_channel_last", w_tensor.permute(perm).contiguous() + ) + else: + w_slot = P.slot_map([w_node])[0] + _, w = P.make_tmp_slot() + P.emit( + TransposeNode( + x=P.slot_to_tid(w_slot), + out=P.slot_to_tid(w), + perm=perm, + ) + ) + return w + + +def _emit_conv_transpose_weight( + P: MLXProgramBuilder, w_node: Node, groups: int, ndim: int +) -> Slot: + """Get conv_transpose weight in MLX format, handling grouped convolutions. + + PyTorch conv_transpose weight shape: [C_in, C_out/G, *K] + MLX expects: [C_out, *K, C_in/G] + + For groups=1, a simple permute suffices (C_in==C_in/G, C_out/G==C_out). + For groups>1, we need reshape-permute-reshape to rearrange the group dim: + [C_in, C_out/G, *K] -> [G, C_in/G, C_out/G, *K] + -> [G, C_out/G, *K, C_in/G] + -> [C_out, *K, C_in/G] + """ + if groups == 1: + # Simple permute: [C_in, C_out, *K] -> [C_out, *K, C_in] + # e.g. 1D: [1, 2, 0], 2D: [1, 2, 3, 0], 3D: [1, 2, 3, 4, 0] + perm = list(range(1, ndim + 2)) + [0] + return _emit_channel_last_weight(P, w_node, perm) + + # Grouped: need reshape-permute-reshape at compile time + if w_node.op != "placeholder": + raise ValueError( + f"conv_transpose with groups > 1 requires static weights, " + f"got dynamic weight from {w_node.op}" + ) + + w_target, w_tensor = P.get_placeholder_target_and_tensor(w_node) + c_in = w_tensor.shape[0] + c_out_per_g = w_tensor.shape[1] + kernel_shape = list(w_tensor.shape[2:]) + c_in_per_g = c_in // groups + + # [C_in, C_out/G, *K] -> [G, C_in/G, C_out/G, *K] + w = w_tensor.reshape([groups, c_in_per_g, c_out_per_g] + kernel_shape) + # [G, C_in/G, C_out/G, *K] -> [G, C_out/G, *K, C_in/G] + # perm: [0, 2, 3, ..., ndim+1, 1] + perm = [0, 2] + list(range(3, ndim + 3)) + [1] + w = w.permute(perm).contiguous() + # [G, C_out/G, *K, C_in/G] -> [C_out, *K, C_in/G] + c_out = groups * c_out_per_g + w = w.reshape([c_out] + kernel_shape + [c_in_per_g]) + + return P.make_or_get_constant(f"{w_target}_channel_last", w) + + +def _emit_conv_bias( + P: MLXProgramBuilder, bias: Optional[Slot], tmp: Slot, ndim: int +) -> None: + """Reshape conv bias to channel-last broadcast shape and add to tmp in-place. + + After the convolution the activation is in channel-last layout, so the bias + (shape ``[C_out]``) must be reshaped to ``[1, …, 1, -1]`` with *ndim* + dimensions before being added. Does nothing when *bias* is ``None``. + """ + if bias is None: + return + _, tmp2 = P.make_tmp_slot() + shape = [IntOrVid.from_literal(1)] * (ndim - 1) + [IntOrVid.from_literal(-1)] + P.emit( + ReshapeNode( + x=P.slot_to_tid(bias), + out=P.slot_to_tid(tmp2), + shape=shape, + ) + ) + P.emit( + AddNode( + a=P.slot_to_tid(tmp), + b=P.slot_to_tid(tmp2), + out=P.slot_to_tid(tmp), + ) + ) + + +def _emit_conv( + P: MLXProgramBuilder, + n: Node, + x_node: Node, + w_node: Node, + bias_node, + stride: list, + padding: list, + dilation: list, + groups: int, + ndim: int, +) -> Slot: + """Shared logic for regular convolution emission. + + Handles weight transform, input/output transposition, bias, and node emission + for all spatial dimensions (1D, 2D, 3D). + + Weight: [C_out, C_in/G, *K] -> [C_out, *K, C_in/G] + Input: (N, C, *spatial) -> (N, *spatial, C) + Output: (N, *spatial, C) -> (N, C, *spatial) + """ + if ndim == 3 and groups != 1: + raise ValueError( + "conv3d with groups != 1 is not supported by MLX. " f"Got groups={groups}." + ) + + # Permutation: channels-first [N, C, *spatial] <-> channels-last [N, *spatial, C] + ch_first_to_last = [0] + list(range(2, ndim + 2)) + [1] + ch_last_to_first = [0, ndim + 1] + list(range(1, ndim + 1)) + + # Weight: [C_out, C_in/G, *K] -> [C_out, *K, C_in/G] (same permutation) + w = _emit_channel_last_weight(P, w_node, ch_first_to_last) + + x, bias = P.slot_map([x_node, bias_node]) + + _, tmp = P.make_tmp_slot() + P.emit( + TransposeNode(x=P.slot_to_tid(x), out=P.slot_to_tid(tmp), perm=ch_first_to_last) + ) + + if ndim == 1: + P.emit( + Conv1DNode( + x=P.slot_to_tid(tmp), + w=P.slot_to_tid(w), + out=P.slot_to_tid(tmp), + stride=stride[0], + padding=padding[0], + dilation=dilation[0], + groups=groups, + ) + ) + elif ndim == 2: + P.emit( + Conv2DNode( + x=P.slot_to_tid(tmp), + w=P.slot_to_tid(w), + out=P.slot_to_tid(tmp), + stride_h=stride[0], + stride_w=stride[1], + padding_h=padding[0], + padding_w=padding[1], + dilation_h=dilation[0], + dilation_w=dilation[1], + groups=groups, + ) + ) + elif ndim == 3: + P.emit( + Conv3DNode( + x=P.slot_to_tid(tmp), + w=P.slot_to_tid(w), + out=P.slot_to_tid(tmp), + stride_d=stride[0], + stride_h=stride[1], + stride_w=stride[2], + padding_d=padding[0], + padding_h=padding[1], + padding_w=padding[2], + dilation_d=dilation[0], + dilation_h=dilation[1], + dilation_w=dilation[2], + groups=groups, + ) + ) + + _emit_conv_bias(P, bias, tmp, ndim=ndim + 2) + + out = P.make_or_get_slot(n) + P.emit( + TransposeNode( + x=P.slot_to_tid(tmp), out=P.slot_to_tid(out), perm=ch_last_to_first + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.conv1d.default]) +def _conv1d_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.conv1d: (input, weight, bias, stride, padding, dilation, groups).""" + require_args(n.args, 2, 7, "aten.conv1d") + require_kwargs(P.kwargs(n), set(), "aten.conv1d") + x_node, w_node = n.args[0:2] + bias_node = n.args[2] if len(n.args) > 2 else None + groups = n.args[6] if len(n.args) > 6 else 1 + stride = _normalize_conv_param(n.args[3] if len(n.args) > 3 else 1, 1, 1) + padding = _normalize_conv_param(n.args[4] if len(n.args) > 4 else 0, 1, 0) + dilation = _normalize_conv_param(n.args[5] if len(n.args) > 5 else 1, 1, 1) + return _emit_conv( + P, n, x_node, w_node, bias_node, stride, padding, dilation, groups, ndim=1 + ) + + +@REGISTRY.register(target=[torch.ops.aten.conv2d.default]) +def _conv2d_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.conv2d: (input, weight, bias, stride, padding, dilation, groups).""" + require_args(n.args, 2, 7, "aten.conv2d") + require_kwargs(P.kwargs(n), set(), "aten.conv2d") + x_node, w_node = n.args[0:2] + bias_node = n.args[2] if len(n.args) > 2 else None + groups = n.args[6] if len(n.args) > 6 else 1 + stride = _normalize_conv_param(n.args[3] if len(n.args) > 3 else [1, 1], 2, 1) + padding = _normalize_conv_param(n.args[4] if len(n.args) > 4 else [0, 0], 2, 0) + dilation = _normalize_conv_param(n.args[5] if len(n.args) > 5 else [1, 1], 2, 1) + return _emit_conv( + P, n, x_node, w_node, bias_node, stride, padding, dilation, groups, ndim=2 + ) + + +@REGISTRY.register(target=[torch.ops.aten.conv3d.default]) +def _conv3d_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.conv3d: (input, weight, bias, stride, padding, dilation, groups).""" + require_args(n.args, 2, 7, "aten.conv3d") + require_kwargs(P.kwargs(n), set(), "aten.conv3d") + x_node, w_node = n.args[0:2] + bias_node = n.args[2] if len(n.args) > 2 else None + groups = n.args[6] if len(n.args) > 6 else 1 + stride = _normalize_conv_param(n.args[3] if len(n.args) > 3 else [1, 1, 1], 3, 1) + padding = _normalize_conv_param(n.args[4] if len(n.args) > 4 else [0, 0, 0], 3, 0) + dilation = _normalize_conv_param(n.args[5] if len(n.args) > 5 else [1, 1, 1], 3, 1) + return _emit_conv( + P, n, x_node, w_node, bias_node, stride, padding, dilation, groups, ndim=3 + ) + + +def _emit_conv_transpose( + P: MLXProgramBuilder, + n: Node, + x_node: Node, + w_node: Node, + bias_node, + stride: list, + padding: list, + dilation: list, + output_padding: list, + groups: int, + ndim: int, +) -> Slot: + """Shared logic for transposed convolution emission. + + Handles weight transform, input/output transposition, bias, and node emission + for all spatial dimensions. Called by both the specific conv_transpose handlers + and the unified aten.convolution.default handler. + """ + if ndim == 3 and groups != 1: + raise ValueError( + "conv_transpose with groups != 1 is not supported for 3D by MLX" + ) + + w = _emit_conv_transpose_weight(P, w_node, groups, ndim=ndim) + x, bias = P.slot_map([x_node, bias_node]) + + # Transpose input: channels-first -> channels-last + ch_first_to_last = list(range(ndim + 2)) + ch_first_to_last = [0] + list(range(2, ndim + 2)) + [1] + ch_last_to_first = [0, ndim + 1] + list(range(1, ndim + 1)) + + _, tmp = P.make_tmp_slot() + P.emit( + TransposeNode(x=P.slot_to_tid(x), out=P.slot_to_tid(tmp), perm=ch_first_to_last) + ) + + if ndim == 1: + P.emit( + ConvTranspose1DNode( + x=P.slot_to_tid(tmp), + w=P.slot_to_tid(w), + out=P.slot_to_tid(tmp), + stride=stride[0], + padding=padding[0], + dilation=dilation[0], + output_padding=output_padding[0], + groups=groups, + ) + ) + elif ndim == 2: + P.emit( + ConvTranspose2DNode( + x=P.slot_to_tid(tmp), + w=P.slot_to_tid(w), + out=P.slot_to_tid(tmp), + stride_h=stride[0], + stride_w=stride[1], + padding_h=padding[0], + padding_w=padding[1], + dilation_h=dilation[0], + dilation_w=dilation[1], + output_padding_h=output_padding[0], + output_padding_w=output_padding[1], + groups=groups, + ) + ) + elif ndim == 3: + P.emit( + ConvTranspose3DNode( + x=P.slot_to_tid(tmp), + w=P.slot_to_tid(w), + out=P.slot_to_tid(tmp), + stride_d=stride[0], + stride_h=stride[1], + stride_w=stride[2], + padding_d=padding[0], + padding_h=padding[1], + padding_w=padding[2], + dilation_d=dilation[0], + dilation_h=dilation[1], + dilation_w=dilation[2], + output_padding_d=output_padding[0], + output_padding_h=output_padding[1], + output_padding_w=output_padding[2], + groups=groups, + ) + ) + + _emit_conv_bias(P, bias, tmp, ndim=ndim + 2) + + out = P.make_or_get_slot(n) + P.emit( + TransposeNode( + x=P.slot_to_tid(tmp), out=P.slot_to_tid(out), perm=ch_last_to_first + ) + ) + return out + + +def _normalize_conv_param(val, ndim, default=0): + """Normalize a conv parameter (stride/padding/etc.) to a list of length ndim.""" + if isinstance(val, int): + return [val] * ndim + if isinstance(val, list): + if len(val) == 1: + return val * ndim + return val + return [default] * ndim + + +@REGISTRY.register(target=[torch.ops.aten.convolution.default]) +def _convolution_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.convolution.default — the unified convolution op. + + Args layout: convolution(input, weight, bias, stride, padding, dilation, + transposed, output_padding, groups) + + This op appears when PyTorch doesn't decompose to specific conv ops + (e.g. grouped conv_transpose). + """ + raw_args = n.args + x_node, w_node = raw_args[0], raw_args[1] + bias_node = raw_args[2] if len(raw_args) > 2 else None + transposed = raw_args[6] if len(raw_args) > 6 else False + groups = raw_args[8] if len(raw_args) > 8 else 1 + + if not transposed: + raise ValueError( + "aten.convolution with transposed=False: use aten.conv{1,2,3}d instead" + ) + + x_meta = x_node.meta.get("val") + if x_meta is None: + raise ValueError("aten.convolution: input shape metadata required") + ndim = len(x_meta.shape) - 2 + + stride = _normalize_conv_param(raw_args[3] if len(raw_args) > 3 else 1, ndim, 1) + padding = _normalize_conv_param(raw_args[4] if len(raw_args) > 4 else 0, ndim, 0) + dilation = _normalize_conv_param(raw_args[5] if len(raw_args) > 5 else 1, ndim, 1) + output_padding = _normalize_conv_param( + raw_args[7] if len(raw_args) > 7 else 0, ndim, 0 + ) + + return _emit_conv_transpose( + P, + n, + x_node, + w_node, + bias_node, + stride, + padding, + dilation, + output_padding, + groups, + ndim, + ) + + +@REGISTRY.register(target=[torch.ops.aten.conv_transpose1d.default]) +def _conv_transpose1d_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.conv_transpose1d: (input, weight, bias, stride, padding, output_padding, groups, dilation).""" + require_args(n.args, 2, 8, "aten.conv_transpose1d") + require_kwargs(P.kwargs(n), set(), "aten.conv_transpose1d") + x_node, w_node = n.args[0:2] + bias_node = n.args[2] if len(n.args) > 2 else None + groups = n.args[6] if len(n.args) > 6 else 1 + + stride = _normalize_conv_param(n.args[3] if len(n.args) > 3 else 1, 1, 1) + padding = _normalize_conv_param(n.args[4] if len(n.args) > 4 else 0, 1, 0) + output_padding = _normalize_conv_param(n.args[5] if len(n.args) > 5 else 0, 1, 0) + dilation = _normalize_conv_param(n.args[7] if len(n.args) > 7 else 1, 1, 1) + + return _emit_conv_transpose( + P, + n, + x_node, + w_node, + bias_node, + stride, + padding, + dilation, + output_padding, + groups, + ndim=1, + ) + + +@REGISTRY.register(target=[torch.ops.aten.conv_transpose2d.input]) +def _conv_transpose2d_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.conv_transpose2d: (input, weight, bias, stride, padding, output_padding, groups, dilation).""" + require_args(n.args, 2, 8, "aten.conv_transpose2d") + require_kwargs(P.kwargs(n), set(), "aten.conv_transpose2d") + x_node, w_node = n.args[0:2] + bias_node = n.args[2] if len(n.args) > 2 else None + groups = n.args[6] if len(n.args) > 6 else 1 + + stride = _normalize_conv_param(n.args[3] if len(n.args) > 3 else [1, 1], 2, 1) + padding = _normalize_conv_param(n.args[4] if len(n.args) > 4 else [0, 0], 2, 0) + output_padding = _normalize_conv_param( + n.args[5] if len(n.args) > 5 else [0, 0], 2, 0 + ) + dilation = _normalize_conv_param(n.args[7] if len(n.args) > 7 else [1, 1], 2, 1) + + return _emit_conv_transpose( + P, + n, + x_node, + w_node, + bias_node, + stride, + padding, + dilation, + output_padding, + groups, + ndim=2, + ) + + +@REGISTRY.register(target=[torch.ops.aten.conv_transpose3d.input]) +def _conv_transpose3d_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.conv_transpose3d: (input, weight, bias, stride, padding, output_padding, groups, dilation).""" + require_args(n.args, 2, 8, "aten.conv_transpose3d") + require_kwargs(P.kwargs(n), set(), "aten.conv_transpose3d") + x_node, w_node = n.args[0:2] + bias_node = n.args[2] if len(n.args) > 2 else None + groups = n.args[6] if len(n.args) > 6 else 1 + + stride = _normalize_conv_param(n.args[3] if len(n.args) > 3 else [1, 1, 1], 3, 1) + padding = _normalize_conv_param(n.args[4] if len(n.args) > 4 else [0, 0, 0], 3, 0) + output_padding = _normalize_conv_param( + n.args[5] if len(n.args) > 5 else [0, 0, 0], 3, 0 + ) + dilation = _normalize_conv_param(n.args[7] if len(n.args) > 7 else [1, 1, 1], 3, 1) + + return _emit_conv_transpose( + P, + n, + x_node, + w_node, + bias_node, + stride, + padding, + dilation, + output_padding, + groups, + ndim=3, + ) + + +@REGISTRY.register(target=[torch.ops.aten.sub.Tensor, torch.ops.aten.sub.Scalar]) +def _sub_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.sub.Tensor: a - alpha * b.""" + args = P.args(n) + require_args(args, 2, 2, "aten.sub.Tensor") + require_kwargs(P.kwargs(n), {"alpha"}, "aten.sub.Tensor") + a, b = args + input_meta = n.args[0].meta.get("val") + dtype = input_meta.dtype if input_meta is not None else torch.float32 + if not isinstance(b, Slot): + b = emit_lifted_constant(P, b, dtype) + alpha = P.kwargs(n).get("alpha", 1) + if alpha != 1: + alpha_slot = emit_lifted_constant(P, alpha, dtype) + _, tmp = P.make_tmp_slot() + P.emit( + MultiplyNode( + a=P.slot_to_tid(b), + b=P.slot_to_tid(alpha_slot), + out=P.slot_to_tid(tmp), + ) + ) + b = tmp + out = P.make_or_get_slot(n) + P.emit( + SubtractNode( + a=P.slot_to_tid(a), + b=P.slot_to_tid(b), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.relu.default]) +def _relu_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.relu.default - rectified linear unit. + + ReLU(x) = max(x, 0), implemented using MaximumNode with a scalar zero. + Uses broadcasting in maximum operation for efficiency. + """ + args = P.args(n) + require_args(args, 1, 1, "aten.relu") + require_kwargs(P.kwargs(n), set(), "aten.relu") + (x,) = args # x is already a Slot + + x_meta = n.args[0].meta.get("val") + if x_meta is None: + raise ValueError("Input tensor metadata not found for relu") + dtype = x_meta.dtype + + zero_slot = emit_lifted_constant(P, 0.0, dtype) + + out = P.make_or_get_slot(n) + P.emit( + MaximumNode( + a=P.slot_to_tid(x), + b=P.slot_to_tid(zero_slot), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten._log_softmax.default]) +def _log_softmax_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten._log_softmax.default - log of softmax. + + LogSoftmax(x, dim) = x - logsumexp(x, dim, keepdims=True) + + This is numerically stable because it avoids computing softmax + (which can underflow to 0) followed by log (which gives -inf for 0). + """ + args = P.args(n) + require_args(args, 3, 3, "aten._log_softmax") + require_kwargs(P.kwargs(n), set(), "aten._log_softmax") + x, dim, _half_to_float = args # x is already a Slot + + # Create temporary slot for logsumexp output + _, logsumexp_slot = P.make_tmp_slot() + + # Emit LogSumExpNode with keepdims=True + P.emit( + LogSumExpNode( + x=P.slot_to_tid(x), + axes=[dim], + keepdims=True, + out=P.slot_to_tid(logsumexp_slot), + ) + ) + + # Emit SubtractNode: x - logsumexp(x) + out = P.make_or_get_slot(n) + P.emit( + SubtractNode( + a=P.slot_to_tid(x), + b=P.slot_to_tid(logsumexp_slot), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.constant_pad_nd.default]) +def _constant_pad_nd_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.constant_pad_nd - pad with a constant value. + + PyTorch pad format: [left_0, right_0, left_1, right_1, ...] + MLX pad_width format: [(before_0, after_0), (before_1, after_1), ...] + + Note: PyTorch pads in reverse order (last dimensions first). + """ + args = P.args(n) + require_args(args, 2, 3, "aten.constant_pad_nd") + require_kwargs(P.kwargs(n), set(), "aten.constant_pad_nd") + x_node, pad = args[0], args[1] + value = args[2] if len(args) > 2 else 0 + + if not isinstance(value, (int, float)): + raise ValueError( + f"aten.constant_pad_nd: constant value must be a scalar, got {type(value)}" + ) + + # Convert PyTorch pad format to MLX pad_width format + # PyTorch: [left_D, right_D, left_D-1, right_D-1, ...] + # MLX: [(before_0, after_0), (before_1, after_1), ..., (before_D, after_D)] + if len(pad) % 2 != 0: + raise ValueError( + f"aten.constant_pad_nd: pad length must be even, got {len(pad)}" + ) + + x = P.slot_map([x_node])[0] + x_meta = n.args[0].meta.get("val") + if x_meta is None: + raise ValueError("Input tensor metadata not found for constant_pad_nd") + + ndim = len(x_meta.shape) + num_pad_dims = len(pad) // 2 + + if num_pad_dims > ndim: + raise ValueError( + f"aten.constant_pad_nd: trying to pad {num_pad_dims} dimensions " + f"but input has only {ndim} dimensions" + ) + + # Build MLX pad_width: start with zeros for non-padded dims + pad_width = [] + for _ in range(ndim - num_pad_dims): + pad_width.extend([0, 0]) # No padding for these dimensions + + # Add padding for the padded dimensions (reverse order) + for i in range(num_pad_dims - 1, -1, -1): + left = pad[i * 2] + right = pad[i * 2 + 1] + pad_width.extend([left, right]) + + out = P.make_or_get_slot(n) + P.emit( + PadNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + pad_width=[P.to_int_or_vid(v) for v in pad_width], + mode="constant", + constant_value=float(value), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.clamp.default, torch.ops.aten.clamp.Tensor]) +def _clamp_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.clamp - clamp values to [min, max] range. + + clamp(input, min=None, max=None) -> Tensor + + Clamps all elements in input into the range [min, max]. + If min is None, there is no lower bound. If max is None, there is no upper bound. + """ + args = P.args(n) + require_args(args, 1, 3, "aten.clamp") + require_kwargs(P.kwargs(n), set(), "aten.clamp") + + x = args[0] + min_val = args[1] if len(args) > 1 else None + max_val = args[2] if len(args) > 2 else None + + x_meta = n.args[0].meta.get("val") + if x_meta is None: + raise ValueError("Input tensor metadata not found for clamp") + dtype = x_meta.dtype + + out = P.make_or_get_slot(n) + + # Lift scalar bounds to 0-D constant tensors + a_min_tid = None + a_max_tid = None + if min_val is not None: + if isinstance(min_val, Slot) and min_val.id_type == IdType.Tensor: + a_min_tid = P.slot_to_tid(min_val) + else: + a_min_tid = P.slot_to_tid(emit_lifted_constant(P, float(min_val), dtype)) + if max_val is not None: + if isinstance(max_val, Slot) and max_val.id_type == IdType.Tensor: + a_max_tid = P.slot_to_tid(max_val) + else: + a_max_tid = P.slot_to_tid(emit_lifted_constant(P, float(max_val), dtype)) + + P.emit( + ClipNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + a_min=a_min_tid, + a_max=a_max_tid, + ) + ) + return out + + +@REGISTRY.register( + target=[torch.ops.aten.expand.default, torch.ops.aten.expand_copy.default] +) +def _expand_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle expand: broadcasts dimensions of size 1 to larger sizes.""" + args = P.args(n) + require_args(args, 2, 2, "aten.expand") + require_kwargs(P.kwargs(n), set(), "aten.expand") + x, size = args + out = P.make_or_get_slot(n) + + shape_iovs = [P.to_int_or_vid(s) for s in size] + P.emit( + BroadcastToNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + shape=shape_iovs, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten._native_batch_norm_legit_no_training.default]) +def _native_batch_norm_legit_no_training_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle batch norm inference (no training). + + Formula: output = (input - mean) / sqrt(var + eps) * weight + bias + + Args: + input: [N, C, ...] tensor + weight: [C] gamma parameter + bias: [C] beta parameter + running_mean: [C] + running_var: [C] + momentum: float (unused in inference) + eps: float + + Returns: + Tuple of (output, empty, empty) - save_mean and save_invstd are empty for no_training + """ + args = P.args(n) + require_args(args, 7, 7, "aten._native_batch_norm_legit_no_training") + require_kwargs(P.kwargs(n), set(), "aten._native_batch_norm_legit_no_training") + x = args[0] + weight = args[1] # gamma [C] - optional (None if affine=False) + bias = args[2] # beta [C] - optional (None if affine=False) + mean = args[3] # running_mean [C] + var = args[4] # running_var [C] + # momentum = args[5] - not used in inference + eps = args[6] # epsilon + + # Get output slots (3 outputs: normalized, save_mean, save_invstd) + output_slots = P.make_or_get_slots(n) + out = output_slots[0] # Main output + + # Get input ndim to determine reshape dimensions + # For BatchNorm1d: input is [N, C, L] -> reshape params to [1, C, 1] + # For BatchNorm2d: input is [N, C, H, W] -> reshape params to [1, C, 1, 1] + input_node = n.args[0] + input_ndim = len(input_node.meta["val"].shape) + + # Validate input dimensionality (only 3D and 4D supported) + if input_ndim not in (3, 4): + raise NotImplementedError( + f"MLX batch norm handler only supports 3D (BatchNorm1d) and 4D (BatchNorm2d) inputs. " + f"Got {input_ndim}D input." + ) + + def reshape_for_broadcast(slot, name_suffix): + """Reshape a [C] tensor for broadcasting with input.""" + _, reshaped = P.make_tmp_slot() + # Build shape: [1, -1] + [1] * (ndim - 2) + shape = [P.to_int_or_vid(1), P.to_int_or_vid(-1)] + for _ in range(input_ndim - 2): + shape.append(P.to_int_or_vid(1)) + P.emit( + ReshapeNode( + x=P.slot_to_tid(slot), + shape=shape, + out=P.slot_to_tid(reshaped), + ) + ) + return reshaped + + mean_reshaped = reshape_for_broadcast(mean, "mean") + var_reshaped = reshape_for_broadcast(var, "var") + + # Step 1: x_centered = x - mean + _, tmp_centered = P.make_tmp_slot() + P.emit( + SubtractNode( + a=P.slot_to_tid(x), + b=P.slot_to_tid(mean_reshaped), + out=P.slot_to_tid(tmp_centered), + ) + ) + + # Step 2: var_eps = var + eps + eps_slot = emit_lifted_constant(P, float(eps), torch.float32) + _, tmp_var_eps = P.make_tmp_slot() + P.emit( + AddNode( + a=P.slot_to_tid(var_reshaped), + b=P.slot_to_tid(eps_slot), + out=P.slot_to_tid(tmp_var_eps), + ) + ) + + # Step 3: inv_std = rsqrt(var_eps) + _, tmp_inv_std = P.make_tmp_slot() + P.emit(RsqrtNode(x=P.slot_to_tid(tmp_var_eps), out=P.slot_to_tid(tmp_inv_std))) + + # Step 4: x_normalized = x_centered * inv_std + _, tmp_normalized = P.make_tmp_slot() + P.emit( + MultiplyNode( + a=P.slot_to_tid(tmp_centered), + b=P.slot_to_tid(tmp_inv_std), + out=P.slot_to_tid(tmp_normalized), + ) + ) + + # Step 5: x_scaled = x_normalized * weight (skip if weight is None, i.e. affine=False) + if weight is not None: + weight_reshaped = reshape_for_broadcast(weight, "weight") + _, tmp_scaled = P.make_tmp_slot() + P.emit( + MultiplyNode( + a=P.slot_to_tid(tmp_normalized), + b=P.slot_to_tid(weight_reshaped), + out=P.slot_to_tid(tmp_scaled), + ) + ) + current_result = tmp_scaled + else: + current_result = tmp_normalized + + # Step 6: out = current_result + bias (skip if bias is None, i.e. affine=False) + if bias is not None: + bias_reshaped = reshape_for_broadcast(bias, "bias") + P.emit( + AddNode( + a=P.slot_to_tid(current_result), + b=P.slot_to_tid(bias_reshaped), + out=P.slot_to_tid(out), + ) + ) + else: + # No bias - just copy the result to output + P.emit( + IdCopyNode( + x=P.slot_to_tid(current_result), + out=P.slot_to_tid(out), + ) + ) + + # For no_training mode, outputs 1 and 2 (save_mean, save_invstd) are empty + # They should already be allocated by make_or_get_slots but we don't write to them + # PyTorch returns empty tensors for these in no_training mode + + return output_slots + + +@REGISTRY.register(target=[torch.ops.aten.where.self]) +def _where_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle where: select from x or y according to condition. + + where(condition, x, y) returns elements from x where condition is True, + and elements from y where condition is False. + """ + args = P.args(n) + require_args(args, 3, 3, "aten.where") + require_kwargs(P.kwargs(n), set(), "aten.where") + condition, x, y = args + out = P.make_or_get_slot(n) + + P.emit( + WhereNode( + condition=P.slot_to_tid(condition), + x=P.slot_to_tid(x), + y=P.slot_to_tid(y), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.bitwise_not.default]) +def _bitwise_not_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.bitwise_not - for boolean tensors, dispatch to logical_not.""" + args = P.args(n) + require_args(args, 1, 1, "aten.bitwise_not") + require_kwargs(P.kwargs(n), set(), "aten.bitwise_not") + x_meta = n.args[0].meta.get("val") + + if x_meta is not None and x_meta.dtype == torch.bool: + # For boolean tensors, bitwise_not is equivalent to logical_not + out = P.make_or_get_slot(n) + P.emit( + LogicalNotNode( + x=P.slot_to_tid(args[0]), + out=P.slot_to_tid(out), + ) + ) + return out + else: + raise NotImplementedError( + f"aten.bitwise_not is only supported for boolean tensors. " + f"Got dtype={x_meta.dtype if x_meta else 'unknown'}" + ) + + +@REGISTRY.register( + target=[torch.ops.aten.logical_and.default, torch.ops.aten.bitwise_and.Tensor] +) +def _logical_and_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.logical_and / aten.bitwise_and on bool tensors.""" + args = P.args(n) + require_args(args, 2, 2, "aten.logical_and/bitwise_and") + require_kwargs(P.kwargs(n), set(), "aten.logical_and/bitwise_and") + + # bitwise_and is only equivalent to logical_and for bool tensors. + if n.target == torch.ops.aten.bitwise_and.Tensor: + dtype = n.args[0].meta.get("val", None) + if dtype is not None and hasattr(dtype, "dtype") and dtype.dtype != torch.bool: + raise ValueError( + f"aten.bitwise_and on non-bool dtype {dtype.dtype} is not supported; " + "only bool tensors can be lowered via LogicalAndNode" + ) + out = P.make_or_get_slot(n) + P.emit( + LogicalAndNode( + a=P.slot_to_tid(args[0]), + b=P.slot_to_tid(args[1]), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.scalar_tensor.default]) +def _scalar_tensor_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """This is equivalent to torch.full([], scalar, dtype=dtype).""" + args = P.args(n) + kwargs = P.kwargs(n) + require_args(args, 1, 1, "aten.scalar_tensor") + require_kwargs( + kwargs, {"dtype", "layout", "device", "pin_memory"}, "aten.scalar_tensor" + ) + require_contiguous_format( + layout=kwargs.get("layout"), + op_name="aten.scalar_tensor", + ) + scalar_value = args[0] + + out = P.make_or_get_slot(n) + + # Get dtype from kwargs, default to float32 + dtype = n.kwargs.get("dtype") + if dtype is None: + # Infer dtype from scalar type + if isinstance(scalar_value, bool): + dtype = torch.bool + elif isinstance(scalar_value, int): + dtype = torch.int64 + else: + dtype = torch.float32 + + P.emit( + FullNode( + out=P.slot_to_tid(out), + shape=[], # 0-D tensor (scalar) + v=P.to_float_or_vid(scalar_value), + scalar_type=torch_dtype_to_scalar_type(dtype), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.tril.default]) +def _tril_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.tril - extract lower triangular part of matrix. + + tril(input, diagonal=0) -> Tensor + + Returns the lower triangular part of the matrix, with all elements above + the diagonal set to zero. The diagonal parameter controls which diagonal + to consider: 0 = main diagonal, positive = above main, negative = below main. + """ + args = P.args(n) + require_args(args, 1, 2, "aten.tril") + require_kwargs(P.kwargs(n), set(), "aten.tril") + x = args[0] + diagonal = args[1] if len(args) > 1 else 0 + + out = P.make_or_get_slot(n) + P.emit( + TrilNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + k=diagonal, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.triu.default]) +def _triu_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.triu - extract upper triangular part of matrix. + + triu(input, diagonal=0) -> Tensor + + Returns the upper triangular part of the matrix, with all elements below + the diagonal set to zero. The diagonal parameter controls which diagonal + to consider: 0 = main diagonal, positive = above main, negative = below main. + """ + args = P.args(n) + require_args(args, 1, 2, "aten.triu") + require_kwargs(P.kwargs(n), set(), "aten.triu") + x = args[0] + diagonal = args[1] if len(args) > 1 else 0 + + out = P.make_or_get_slot(n) + P.emit( + TriuNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + k=diagonal, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.round.default]) +def _round_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.round - round elements to nearest integer. + + Note: round.decimals variant is not supported as it's not in Core ATen. + """ + args = P.args(n) + require_args(args, 1, 1, "aten.round") + require_kwargs(P.kwargs(n), set(), "aten.round") + x = args[0] + out = P.make_or_get_slot(n) + P.emit(RoundNode(x=P.slot_to_tid(x), out=P.slot_to_tid(out), decimals=0)) + return out + + +@REGISTRY.register(target=[torch.ops.aten.logsumexp.default]) +def _logsumexp_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.logsumexp - log(sum(exp(x))) along axes.""" + args = P.args(n) + require_args(args, 1, 3, "aten.logsumexp") + require_kwargs(P.kwargs(n), set(), "aten.logsumexp") + x = args[0] + dim = args[1] if len(args) > 1 else None + keepdim = args[2] if len(args) > 2 else False + + # Normalize dim to list + if dim is None: + axes = [] + elif isinstance(dim, int): + axes = [dim] + else: + axes = list(dim) + + out = P.make_or_get_slot(n) + P.emit( + LogSumExpNode( + x=P.slot_to_tid(x), out=P.slot_to_tid(out), axes=axes, keepdims=keepdim + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.var.correction, torch.ops.aten.var.dim]) +def _var_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.var - variance of elements along axes.""" + args = P.args(n) + require_args(args, 1, 2, "aten.var") + require_kwargs(P.kwargs(n), {"correction", "keepdim"}, "aten.var") + x = args[0] + axes, _ = normalize_reduction_dim(args) + + # Get correction/ddof and keepdim from kwargs + correction = n.kwargs.get("correction", None) + keepdim = n.kwargs.get("keepdim", False) + ddof = int(correction) if correction is not None else 1 + + out = P.make_or_get_slot(n) + P.emit( + VarNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + axes=axes, + keepdims=keepdim, + ddof=ddof, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.std.correction]) +def _std_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.std - standard deviation of elements along axes.""" + args = P.args(n) + require_args(args, 1, 2, "aten.std") + require_kwargs(P.kwargs(n), {"correction", "keepdim"}, "aten.std") + x = args[0] + axes, _ = normalize_reduction_dim(args) + + correction = n.kwargs.get("correction", None) + keepdim = n.kwargs.get("keepdim", False) + ddof = int(correction) if correction is not None else 1 + + out = P.make_or_get_slot(n) + P.emit( + StdNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + axes=axes, + keepdims=keepdim, + ddof=ddof, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.max.default]) +def _max_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.max.default - global max (reduce all axes).""" + args = P.args(n) + require_args(args, 1, 1, "aten.max") + require_kwargs(P.kwargs(n), set(), "aten.max") + x = args[0] + + out = P.make_or_get_slot(n) + P.emit(MaxNode(x=P.slot_to_tid(x), out=P.slot_to_tid(out), axes=[], keepdims=False)) + return out + + +@REGISTRY.register(target=[torch.ops.aten.min.default]) +def _min_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.min.default - global min (reduce all axes).""" + args = P.args(n) + require_args(args, 1, 1, "aten.min") + require_kwargs(P.kwargs(n), set(), "aten.min") + x = args[0] + + out = P.make_or_get_slot(n) + P.emit(MinNode(x=P.slot_to_tid(x), out=P.slot_to_tid(out), axes=[], keepdims=False)) + return out + + +@REGISTRY.register(target=[torch.ops.aten.argmax.default]) +def _argmax_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.argmax - index of max element along axis.""" + args = P.args(n) + require_args(args, 1, 3, "aten.argmax") + require_kwargs(P.kwargs(n), set(), "aten.argmax") + x = args[0] + dim = args[1] if len(args) > 1 else None + keepdim = args[2] if len(args) > 2 else False + + out = P.make_or_get_slot(n) + + if dim is None: + # argmax without dim: flatten tensor to 1D, then argmax over axis 0 + # Result is a scalar index into the flattened tensor + _, flat_slot = P.make_tmp_slot() + + # Get total number of elements from input shape + x_meta = n.args[0].meta.get("val") + if x_meta is None: + raise ValueError("Input tensor metadata not found for argmax") + numel = x_meta.numel() + + P.emit( + ReshapeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(flat_slot), + shape=[P.to_int_or_vid(numel)], + ) + ) + P.emit( + ArgmaxNode( + x=P.slot_to_tid(flat_slot), + out=P.slot_to_tid(out), + axis=0, + keepdims=False, + ) + ) + else: + P.emit( + ArgmaxNode( + x=P.slot_to_tid(x), out=P.slot_to_tid(out), axis=dim, keepdims=keepdim + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.argmin.default]) +def _argmin_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.argmin - index of min element along axis.""" + args = P.args(n) + require_args(args, 1, 3, "aten.argmin") + require_kwargs(P.kwargs(n), set(), "aten.argmin") + x = args[0] + dim = args[1] if len(args) > 1 else None + keepdim = args[2] if len(args) > 2 else False + + out = P.make_or_get_slot(n) + + if dim is None: + # argmin without dim: flatten tensor to 1D, then argmin over axis 0 + # Result is a scalar index into the flattened tensor + _, flat_slot = P.make_tmp_slot() + + # Get total number of elements from input shape + x_meta = n.args[0].meta.get("val") + if x_meta is None: + raise ValueError("Input tensor metadata not found for argmin") + numel = x_meta.numel() + + P.emit( + ReshapeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(flat_slot), + shape=[P.to_int_or_vid(numel)], + ) + ) + P.emit( + ArgminNode( + x=P.slot_to_tid(flat_slot), + out=P.slot_to_tid(out), + axis=0, + keepdims=False, + ) + ) + else: + P.emit( + ArgminNode( + x=P.slot_to_tid(x), out=P.slot_to_tid(out), axis=dim, keepdims=keepdim + ) + ) + return out + + +def _parse_pool_args(args, ndim, op_name, is_avg_pool=False): # noqa: C901 + """Parse pooling op arguments, normalizing scalars to lists. + + ATen pooling signatures: + max_pool{N}d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) + avg_pool{N}d(input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + + Extra args beyond (input, kernel_size, stride, padding) are accepted only + when they match safe defaults: + max_pool: dilation=1, ceil_mode=False + avg_pool: ceil_mode=False, count_include_pad=True, divisor_override=None + + Returns (kernel_size, stride, padding) as lists of length ndim. + """ + if is_avg_pool: + require_args(args, 2, 7, op_name) + # args[4] = ceil_mode (must be False) + if len(args) > 4 and args[4]: + raise ValueError(f"{op_name}: ceil_mode=True is not supported.") + # args[5] = count_include_pad (must be True) + if len(args) > 5 and not args[5]: + raise ValueError(f"{op_name}: count_include_pad=False is not supported.") + # args[6] = divisor_override (must be None) + if len(args) > 6 and args[6] is not None: + raise ValueError(f"{op_name}: divisor_override is not supported.") + else: + require_args(args, 2, 6, op_name) + # args[4] = dilation (must be 1) + if len(args) > 4: + dilation = args[4] + if isinstance(dilation, list): + if any(d != 1 for d in dilation): + raise ValueError( + f"{op_name}: dilation != 1 is not supported, got {dilation}." + ) + elif dilation != 1: + raise ValueError( + f"{op_name}: dilation != 1 is not supported, got {dilation}." + ) + # args[5] = ceil_mode (must be False) + if len(args) > 5 and args[5]: + raise ValueError(f"{op_name}: ceil_mode=True is not supported.") + + kernel_size = args[1] + if isinstance(kernel_size, int): + kernel_size = [kernel_size] * ndim + + stride = args[2] if len(args) > 2 and args[2] else kernel_size + if isinstance(stride, int): + stride = [stride] * ndim + if not stride: # empty list means default to kernel_size + stride = list(kernel_size) + + padding = args[3] if len(args) > 3 else [0] * ndim + if isinstance(padding, int): + padding = [padding] * ndim + + return list(kernel_size), list(stride), list(padding) + + +def _emit_pool_nd( + P: MLXProgramBuilder, + n: Node, + ndim: int, + reduce_node_cls: type, + padding_value: float, + kernel_size: List[int], + stride: List[int], + padding: List[int], +) -> Slot: + """Emit IR nodes for N-dimensional pooling. + + Decomposes pooling into: + Transpose (channels-first -> channels-last) + -> Pad (if needed) + -> Reshape+Transpose (fast path) or AsStrided (general path) + -> Max/Mean reduction over kernel dims + -> Transpose (channels-last -> channels-first) + + Works for 1D, 2D, and 3D pooling uniformly. + + Args: + P: Program builder. + n: FX graph node for the pooling op. + ndim: Spatial dimensionality (1, 2, or 3). + reduce_node_cls: MaxNode or MeanNode. + padding_value: Padding fill value (-inf for max, 0 for avg). + kernel_size: Kernel size per spatial dim, length ndim. + stride: Stride per spatial dim, length ndim. + padding: Padding per spatial dim, length ndim. + + Returns: + Output Slot with shape [N, C, *out_spatial]. + """ + x_node = P.args(n)[0] + (x,) = P.slot_map([x_node]) + x_meta = n.args[0].meta["val"] + shape = list(x_meta.shape) # [N, C, *spatial] + + N = shape[0] + C = shape[1] + spatial = shape[2:] # length == ndim + + # 1. Transpose: channels-first [N, C, *spatial] -> channels-last [N, *spatial, C] + to_cl = [0] + list(range(2, ndim + 2)) + [1] + _, cur = P.make_tmp_slot() + P.emit( + TransposeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(cur), + perm=to_cl, + ) + ) + + # 2. Pad spatial dims if needed + spatial_padded = [s + 2 * p for s, p in zip(spatial, padding)] + if any(p > 0 for p in padding): + pad_width = [0, 0] # batch dim: no pad + for p in padding: + pad_width += [p, p] + pad_width += [0, 0] # channel dim: no pad + P.emit( + PadNode( + x=P.slot_to_tid(cur), + out=P.slot_to_tid(cur), + pad_width=[P.to_int_or_vid(v) for v in pad_width], + mode="constant", + constant_value=padding_value, + ) + ) + + # 3. Sliding windows -> [N, *out_spatial, *kernel_size, C] + out_spatial = [ + (sp - k) // s + 1 for sp, k, s in zip(spatial_padded, kernel_size, stride) + ] + + can_fast_path = all( + k == s and sp % k == 0 for k, s, sp in zip(kernel_size, stride, spatial_padded) + ) + + if can_fast_path: + # Fast path: reshape + transpose (no AsStridedNode needed). + # [N, *spatial_padded, C] + # -> reshape [N, sp0//k0, k0, sp1//k1, k1, ..., C] + # -> transpose to gather output-spatial dims, then kernel dims, then C + reshape_shape = [N] + for sp, k in zip(spatial_padded, kernel_size): + reshape_shape += [sp // k, k] + reshape_shape += [C] + + P.emit( + ReshapeNode( + x=P.slot_to_tid(cur), + out=P.slot_to_tid(cur), + shape=[IntOrVid.from_literal(d) for d in reshape_shape], + ) + ) + + # Transpose: gather output-spatial (odd indices), then kernel (even indices after batch) + # Reshaped tensor axes: [0=batch, 1=out0, 2=k0, 3=out1, 4=k1, ..., last=C] + last = 2 * ndim + 1 + out_spatial_axes = list(range(1, last, 2)) # [1, 3, 5, ...] + kernel_axes = list(range(2, last, 2)) # [2, 4, 6, ...] + perm = [0] + out_spatial_axes + kernel_axes + [last] + + P.emit( + TransposeNode( + x=P.slot_to_tid(cur), + out=P.slot_to_tid(cur), + perm=perm, + ) + ) + else: + # General path: as_strided to create sliding window view. + # Input layout: [N, *spatial_padded, C] (channels-last, row-major) + dims = [N] + spatial_padded + [C] + elem_strides = [] + acc = 1 + for d in reversed(dims): + elem_strides.append(acc) + acc *= d + elem_strides.reverse() + + # as_strided shape: [N, *out_spatial, *kernel_size, C] + as_shape = [N] + out_spatial + kernel_size + [C] + + # as_strided strides: + # batch: elem_strides[0] + # out_spatial[i]: elem_strides[i+1] * stride[i] (skip by pool stride) + # kernel[i]: elem_strides[i+1] (consecutive rows/cols) + # channel: 1 + as_strides = [elem_strides[0]] + for i in range(ndim): + as_strides.append(elem_strides[i + 1] * stride[i]) + for i in range(ndim): + as_strides.append(elem_strides[i + 1]) + as_strides.append(1) + + P.emit( + AsStridedNode( + x=P.slot_to_tid(cur), + out=P.slot_to_tid(cur), + shape=[IntOrVid.from_literal(d) for d in as_shape], + strides=[IntOrVid.from_literal(d) for d in as_strides], + offset=0, + ) + ) + + # 4. Reduce over kernel dims (axes [ndim+1 .. 2*ndim]) + reduce_axes = list(range(ndim + 1, 2 * ndim + 1)) + _, reduced = P.make_tmp_slot() + P.emit( + reduce_node_cls( + x=P.slot_to_tid(cur), + out=P.slot_to_tid(reduced), + axes=reduce_axes, + keepdims=False, + ) + ) + + # 5. Transpose: channels-last [N, *out_spatial, C] -> channels-first [N, C, *out_spatial] + to_cf = [0, ndim + 1] + list(range(1, ndim + 1)) + output_slots = P.make_or_get_slots(n) + out = output_slots[0] + P.emit( + TransposeNode( + x=P.slot_to_tid(reduced), + out=P.slot_to_tid(out), + perm=to_cf, + ) + ) + return out + + +_POOL_OPS: List[Tuple[Any, int, type, float, str, bool]] = [ + # (target, ndim, reduce_cls, pad_value, op_name, returns_indices) + ( + torch.ops.aten.max_pool1d.default, + 1, + MaxNode, + float("-inf"), + "aten.max_pool1d", + False, + ), + ( + torch.ops.aten.max_pool1d_with_indices.default, + 1, + MaxNode, + float("-inf"), + "aten.max_pool1d_with_indices", + True, + ), + ( + torch.ops.aten.max_pool2d_with_indices.default, + 2, + MaxNode, + float("-inf"), + "aten.max_pool2d_with_indices", + True, + ), + ( + torch.ops.aten.max_pool3d_with_indices.default, + 3, + MaxNode, + float("-inf"), + "aten.max_pool3d_with_indices", + True, + ), + (torch.ops.aten.avg_pool1d.default, 1, MeanNode, 0.0, "aten.avg_pool1d", False), + (torch.ops.aten.avg_pool2d.default, 2, MeanNode, 0.0, "aten.avg_pool2d", False), + (torch.ops.aten.avg_pool3d.default, 3, MeanNode, 0.0, "aten.avg_pool3d", False), +] + + +def _make_pool_handler( + ndim: int, + reduce_node_cls: type, + padding_value: float, + op_name: str, + returns_indices: bool, +): + """Create a handler for an N-dimensional pooling op.""" + + is_avg = reduce_node_cls is MeanNode + + def handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + kernel_size, stride, padding = _parse_pool_args( + args, ndim, op_name, is_avg_pool=is_avg + ) + result = _emit_pool_nd( + P, n, ndim, reduce_node_cls, padding_value, kernel_size, stride, padding + ) + if not returns_indices: + return result + + handler.__name__ = f"_{op_name.replace('.', '_')}_handler" + handler.__doc__ = f"Handle {op_name} (table-driven pool op)." + return handler + + +for _target, _ndim, _cls, _pad, _name, _indices in _POOL_OPS: + REGISTRY.register(target=[_target])( + _make_pool_handler(_ndim, _cls, _pad, _name, _indices) + ) + + +@REGISTRY.register(target=[torch.ops.torchao.dequantize_affine.default]) +def _dequantize_affine_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle standalone torchao.dequantize_affine (not fused with linear/embedding). + + MLX's dequantize always operates along the last axis. When the quantized + dimension is not last (e.g. Conv2d with block_size=[1,32,1,1]), we permute + the constant weight/scale/zero_point tensors at compile time so the + quantized dim becomes last, emit the DequantizeNode, then emit a + TransposeNode with the inverse permutation to restore the original layout. + """ + parsed = parse_dequant_node(n) + if parsed is None: + raise NotImplementedError( + f"dequantize_affine: unsupported quantization config at {n}" + ) + ( + qdata_node, + scale_node, + zero_point_node, + group_size, + bits, + out_dtype, + quantized_dim, + ) = parsed + + qdata_target, qdata = P.get_placeholder_target_and_tensor(qdata_node) + zero_point_target, zero_point = P.get_placeholder_target_and_tensor(zero_point_node) + scale_target, scale = P.get_placeholder_target_and_tensor(scale_node) + + if out_dtype is None: + out_dtype = scale_node.meta["val"].dtype + out_scalar_type = torch_dtype_to_scalar_type(out_dtype) + + ndim = qdata.ndim + needs_permute = quantized_dim != ndim - 1 + + if needs_permute: + perm = list(range(ndim)) + perm.remove(quantized_dim) + perm.append(quantized_dim) + qdata = qdata.permute(perm).contiguous() + scale = scale.permute(perm).contiguous() + zero_point = zero_point.permute(perm).contiguous() + + # to_mlx_qparams expects 2D tensors; flatten N-D to 2D for packing, + # then restore the (possibly permuted) leading dimensions afterward. + permuted_shape = qdata.shape + qdata_2d = qdata.reshape(-1, qdata.shape[-1]) + scale_2d = scale.reshape(-1, scale.shape[-1]) + zero_point_2d = zero_point.reshape(-1, zero_point.shape[-1]) + + Q, B = to_mlx_qparams(qdata_2d, scale_2d, zero_point_2d, bits) + + leading_dims = permuted_shape[:-1] + Q = Q.reshape(*leading_dims, Q.shape[-1]) + scale_nd = scale_2d.reshape(*leading_dims, scale_2d.shape[-1]) + if B is not None: + B = B.reshape(*leading_dims, B.shape[-1]) + + w = P.make_or_get_constant(f"{qdata_target}_to_packed", Q) + biases = P.make_or_get_constant(f"{zero_point_target}_to_biases", B) + scale_const = P.make_or_get_constant(f"{scale_target}_scale", scale_nd) + + if needs_permute: + _, dequant_tmp = P.make_tmp_slot() + else: + dequant_tmp = P.make_or_get_slot(n) + + P.emit( + DequantizeNode( + w=P.slot_to_tid(w), + scales=P.slot_to_tid(scale_const), + out=P.slot_to_tid(dequant_tmp), + biases=P.slot_to_tid(biases), + group_size=group_size, + bits=bits, + mode="affine", + dtype=out_scalar_type, + ) + ) + + if needs_permute: + inv_perm = [0] * ndim + for i, p in enumerate(perm): + inv_perm[p] = i + out = P.make_or_get_slot(n) + P.emit( + TransposeNode( + x=P.slot_to_tid(dequant_tmp), + out=P.slot_to_tid(out), + perm=inv_perm, + ) + ) + else: + out = dequant_tmp + + return out + + +@REGISTRY.register(target=[torch.ops.aten.cumsum.default]) +def _cumsum_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.cumsum - cumulative sum along an axis.""" + args = P.args(n) + require_args(args, 2, 3, "aten.cumsum") + require_kwargs(P.kwargs(n), {"dtype"}, "aten.cumsum") + x = args[0] + dim = args[1] + + out = P.make_or_get_slot(n) + P.emit( + CumsumNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + axis=dim, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.stack.default]) +def _stack_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.stack - stack tensors along a new axis.""" + args = P.args(n) + require_args(args, 1, 2, "aten.stack") + require_kwargs(P.kwargs(n), set(), "aten.stack") + tensors_list = args[0] + dim = args[1] if len(args) > 1 else 0 + + out = P.make_or_get_slot(n) + tensor_tids = [P.slot_to_tid(t) for t in tensors_list] + P.emit( + StackNode( + tensors=tensor_tids, + out=P.slot_to_tid(out), + axis=dim, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.repeat_interleave.self_int]) +def _repeat_interleave_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.repeat_interleave - repeat each element along an axis.""" + args = P.args(n) + require_args(args, 2, 4, "aten.repeat_interleave") + require_kwargs(P.kwargs(n), {"output_size"}, "aten.repeat_interleave") + x = args[0] + repeats = args[1] + dim = args[2] if len(args) > 2 else 0 + + out = P.make_or_get_slot(n) + P.emit( + RepeatNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + repeats=P.to_int_or_vid(repeats), + axis=dim, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.sort.default]) +def _sort_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.sort - sort elements along an axis. + + Returns (values, indices) as a tuple of output slots. + """ + args = P.args(n) + require_args(args, 1, 3, "aten.sort") + require_kwargs(P.kwargs(n), set(), "aten.sort") + x = args[0] + dim = args[1] if len(args) > 1 else -1 + + # torch.sort returns (values, indices) - 2 outputs + output_slots = P.make_or_get_slots(n) + values_slot, indices_slot = output_slots + + used = used_getitem_indices(n) + + if 0 in used: + P.emit( + SortNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(values_slot), + axis=dim, + ) + ) + if 1 in used: + P.emit( + ArgsortNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(indices_slot), + axis=dim, + ) + ) + + return output_slots + + +@REGISTRY.register(target=[torch.ops.aten.argsort.default]) +def _argsort_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.argsort - indices that sort elements along an axis.""" + args = P.args(n) + require_args(args, 1, 3, "aten.argsort") + require_kwargs(P.kwargs(n), set(), "aten.argsort") + x = args[0] + dim = args[1] if len(args) > 1 else -1 + + out = P.make_or_get_slot(n) + P.emit( + ArgsortNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + axis=dim, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.topk.default]) +def _topk_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.topk - top-k elements along an axis. + + Decomposes into: partition → slice → sort → reverse (for values) + argpartition → slice → gather → argsort → reverse → reorder (for indices) + + torch.topk returns (values, indices) sorted descending. + """ + args = P.args(n) + require_args(args, 2, 5, "aten.topk") + require_kwargs(P.kwargs(n), set(), "aten.topk") + x = args[0] + k = args[1] + dim = args[2] if len(args) > 2 else -1 + + output_slots = P.make_or_get_slots(n) + values_slot, indices_slot = output_slots + + used = used_getitem_indices(n) + + # Get dim size from input metadata for forward slice stop + x_meta = n.args[0].meta.get("val") + if x_meta is None: + raise ValueError("Input tensor metadata not found for topk") + norm_axis = dim if dim >= 0 else dim + len(x_meta.shape) + dim_size = x_meta.shape[norm_axis] + + # Compute -k for partition index and forward slice start + if isinstance(k, int): + neg_k = P.to_int_or_vid(-k) + # Reverse slice: start=k-1, stop=-(k+1) on the k-sized sliced tensor + rev_start = P.to_int_or_vid(k - 1) + rev_stop = P.to_int_or_vid(-(k + 1)) + else: + # k is dynamic — emit neg_k = k * -1 at runtime + _, neg_k_slot = P.make_tmp_value_slot() + P.emit( + MultiplyIntNode( + a=P.to_int_or_vid(k), + b=IntOrVid.from_literal(-1), + out=P.slot_to_vid(neg_k_slot), + ) + ) + neg_k = P.to_int_or_vid(neg_k_slot) + # rev_start = k - 1 + _, rev_start_slot = P.make_tmp_value_slot() + P.emit( + AddIntNode( + a=P.to_int_or_vid(k), + b=IntOrVid.from_literal(-1), + out=P.slot_to_vid(rev_start_slot), + ) + ) + rev_start = P.to_int_or_vid(rev_start_slot) + # rev_stop = -(k + 1) = neg_k - 1 + _, rev_stop_slot = P.make_tmp_value_slot() + P.emit( + AddIntNode( + a=neg_k, + b=IntOrVid.from_literal(-1), + out=P.slot_to_vid(rev_stop_slot), + ) + ) + rev_stop = P.to_int_or_vid(rev_stop_slot) + + stop_val = P.to_int_or_vid(dim_size) + + def emit_partition_and_slice(node_cls): + """Emit partition/argpartition → slice last k elements.""" + _, part_tmp = P.make_tmp_slot() + P.emit( + node_cls( + x=P.slot_to_tid(x), + out=P.slot_to_tid(part_tmp), + kth=neg_k, + axis=dim, + ) + ) + _, slice_tmp = P.make_tmp_slot() + P.emit( + SliceNode( + x=P.slot_to_tid(part_tmp), + out=P.slot_to_tid(slice_tmp), + axis=P.to_int_or_vid(dim), + start=neg_k, + stop=stop_val, + step=1, + ) + ) + return slice_tmp + + def emit_reverse(in_slot, out_slot): + """Reverse a tensor along dim using slice with step=-1.""" + P.emit( + SliceNode( + x=P.slot_to_tid(in_slot), + out=P.slot_to_tid(out_slot), + axis=P.to_int_or_vid(dim), + start=rev_start, + stop=rev_stop, + step=-1, + ) + ) + + if 0 in used: + # partition → slice last k → sort ascending → reverse to descending + slice_tmp = emit_partition_and_slice(PartitionNode) + _, sort_tmp = P.make_tmp_slot() + P.emit( + SortNode( + x=P.slot_to_tid(slice_tmp), + out=P.slot_to_tid(sort_tmp), + axis=dim, + ) + ) + emit_reverse(sort_tmp, values_slot) + + if 1 in used: + # argpartition → slice last k → gather values → argsort → reverse → reorder + idx_slice_tmp = emit_partition_and_slice(ArgPartitionNode) + # Gather original values at the partitioned indices + _, gathered_tmp = P.make_tmp_slot() + P.emit( + TakeAlongAxisNode( + x=P.slot_to_tid(x), + indices=P.slot_to_tid(idx_slice_tmp), + out=P.slot_to_tid(gathered_tmp), + axis=dim, + ) + ) + # Argsort gathered values ascending → reverse → descending order + _, order_tmp = P.make_tmp_slot() + P.emit( + ArgsortNode( + x=P.slot_to_tid(gathered_tmp), + out=P.slot_to_tid(order_tmp), + axis=dim, + ) + ) + _, rev_order_tmp = P.make_tmp_slot() + emit_reverse(order_tmp, rev_order_tmp) + # Apply descending order to indices + P.emit( + TakeAlongAxisNode( + x=P.slot_to_tid(idx_slice_tmp), + indices=P.slot_to_tid(rev_order_tmp), + out=P.slot_to_tid(indices_slot), + axis=dim, + ) + ) + + return output_slots diff --git a/backends/mlx/patterns.py b/backends/mlx/patterns.py index c8bef1f91ca..908fa52b448 100644 --- a/backends/mlx/patterns.py +++ b/backends/mlx/patterns.py @@ -12,3 +12,942 @@ This module contains pattern handlers that match multi-node subgraphs and lower them to optimized MLX operations. """ + +from __future__ import annotations + +from typing import Any, List, Optional, Tuple + +import torch +from executorch.backends.mlx.builder.op_helpers import ( + emit_stop_position, + parse_dequant_node, + to_mlx_qparams, + torch_dtype_to_scalar_type, +) +from executorch.backends.mlx.builder.op_registry import PatternHandler, REGISTRY +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.pattern_utils import ( + has_single_user, + match_target, + OpStep, + walk_back, +) +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + AddIntNode, + DequantizeNode, + IndexCopyNode, + IntOrVid, + IntOrVidOrTid, + ModIntNode, + QuantizedLinearNode, + SdpaNode, + SliceNode, + SliceUpdateNode, + SubtractIntNode, + SymSizeNode, + TakeNode, +) +from torch.export.exported_program import ExportedProgram +from torch.fx.node import Node + +# When True, always serialize the biases tensor for quantized ops (existing behavior). +# When False, use scale_only=True optimization when zero_point is all zeros, +# which avoids serializing the biases tensor (C++ runtime computes: biases = -scales * 2^(bits-1)). +QUANTIZED_SERIALIZE_BIASES = True + + +@REGISTRY.register_pattern(name="INDEX_COPY") +class IndexCopyHandler(PatternHandler): + """ + Pattern for index-based updates on mutable buffers. + """ + + def __init__( + self, + head: Node, + body: List[Node], + dst: Node, + update: Node, + indices: Node, + axis: int, + ): + super().__init__(head, body) + self.dst = dst + self.update = update + self.indices = indices + self.axis = axis + + @classmethod + def maybe_create( # noqa: C901 + cls, ep: ExportedProgram, head: Node + ) -> Optional["IndexCopyHandler"]: + index_copy_node = head + if not match_target(index_copy_node, torch.ops.aten.index_copy.default): + return None + + # index_copy should write to a mutable input/buffer to be an index update. + if (index_copy_node.name not in ep.graph_signature.buffers_to_mutate) and ( + index_copy_node.name not in ep.graph_signature.user_inputs_to_mutate + ): + return None + + # index_copy(dst, axis, indices, update) + if len(index_copy_node.args) != 4: + return None + dst, axis, indices, update = index_copy_node.args + + # axis must be a literal int + if not isinstance(axis, int): + return None + + return cls( + head=index_copy_node, + body=[], + dst=dst, + update=update, + indices=indices, + axis=axis, + ) + + def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: + assert n == self.head + dst, update, indices = P.slot_map([self.dst, self.update, self.indices]) + + P.emit( + IndexCopyNode( + dst=P.slot_to_tid(dst), + update=P.slot_to_tid(update), + indices=P.slot_to_tid(indices), + out=P.slot_to_tid(dst), + axis=self.axis, + ) + ) + + P.set_slot(n, dst) + return dst + + +@REGISTRY.register_pattern(name="ET_KV_CACHE_UPDATE") +class ETKVCacheUpdateHandler(PatternHandler): + """ + Pattern for KV cache updates using torch.ops.mlx.kv_cache_update. + + Matches: auto_functionalized → getitem[1] + HEAD = getitem[1] (no alias_copy required) + + Graph structure: + auto_func = auto_functionalized_v2(mlx.kv_cache_update, new_values=k_val, ...) + getitem_1 = getitem(auto_func, 1) # HEAD - updated cache + """ + + def __init__( + self, + head: Node, + body: List[Node], + cache: Node, + update: Node, + start_pos: Any, + ring_size: int = 0, + ): + super().__init__(head, body) + self.cache = cache + self.update = update + self.start_pos = start_pos + self.ring_size = ring_size + + @staticmethod + def _is_auto_func_et_kv_cache_update(node: Node) -> bool: + """Check if a node is auto_functionalized_v2 wrapping mlx.kv_cache_update.""" + if node.op != "call_function": + return False + target_str = str(node.target) + if "auto_functionalized" not in target_str: + return False + if len(node.args) < 1: + return False + func_arg = node.args[0] + func_str = str(func_arg) if func_arg else "" + return "kv_cache_update" in func_str and "mlx" in func_str + + @classmethod + def maybe_create( + cls, ep: ExportedProgram, head: Node + ) -> Optional["ETKVCacheUpdateHandler"]: + """ + Match the ET_KV_CACHE_UPDATE pattern. + + Pattern (HEAD = getitem): + auto_func = auto_functionalized_v2(mlx.kv_cache_update, ...) + getitem_1 = getitem(auto_func, 1) # HEAD + """ + + # HEAD must be getitem with idx=1 + if head.op != "call_function" or "getitem" not in str(head.target): + return None + + if len(head.args) < 2 or head.args[1] != 1: + return None + + # getitem's source should be auto_functionalized_v2 wrapping mlx.kv_cache_update + if not isinstance(head.args[0], Node): + return None + + auto_func_node = head.args[0] + if not cls._is_auto_func_et_kv_cache_update(auto_func_node): + return None + + # Extract info from auto_functionalized_v2 kwargs + kwargs = auto_func_node.kwargs + new_values_node = kwargs.get("new_values") + start_pos_node = kwargs.get("start_pos") + all_bases = kwargs.get("_all_bases", []) + + if not new_values_node or not all_bases: + return None + + cache_node = all_bases[0] + + body = [auto_func_node] + + return cls( + head=head, + body=body, + cache=cache_node, + update=new_values_node, + start_pos=start_pos_node, + ring_size=kwargs.get("ring_size", 0), + ) + + def __call__(self, P: "MLXProgramBuilder", n: Node) -> Slot: + assert n == self.head + + cache_slot, update_slot, start_slot = P.slot_map( + [self.cache, self.update, self.start_pos] + ) + + if self.ring_size > 0: + self._emit_ring_buffer(P, cache_slot, update_slot, start_slot) + else: + self._emit_linear(P, cache_slot, update_slot, start_slot) + + P.set_slot(n, cache_slot) + return cache_slot + + def _emit_linear(self, P: "MLXProgramBuilder", cache_slot, update_slot, start_slot): + """Emit a single SliceUpdate for linear (non-ring) cache.""" + update_meta = self.update.meta.get("val") + stop_slot = emit_stop_position( + P, + start=start_slot, + length_tensor=update_slot, + length_dim=2, # S_step is dim 2 in [B, H, S_step, D] + length_meta=update_meta, + ) + + # This updates cache[:, :, start:stop, :] = update + # SliceUpdateNode on axis=2 + # cache is [B, H, S, D], update is [B, H, S_step, D] + P.emit( + SliceUpdateNode( + dst=P.slot_to_tid(cache_slot), + update=P.slot_to_tid(update_slot), + out=P.slot_to_tid(cache_slot), + axis=IntOrVid.from_literal(2), # S dimension in [B, H, S, D] + start=P.to_int_or_vid(start_slot), + stop=P.to_int_or_vid(stop_slot), + ) + ) + + def _emit_ring_buffer( + self, P: "MLXProgramBuilder", cache_slot, update_slot, start_slot + ): + """ + Emit two unconditional SliceUpdates for ring buffer wrapping. + + write_pos = start_pos % ring_size + first_len = ring_size - write_pos + first_chunk = update[:, :, :first_len, :] (Slice clamps to seq_len) + actual_first = first_chunk.shape[2] (min(first_len, seq_len)) + rest_chunk = update[:, :, actual_first:seq_len, :] + overflow = seq_len - actual_first + SliceUpdate(cache, first_chunk, write_pos, write_pos + actual_first) + SliceUpdate(cache, rest_chunk, 0, overflow) + + When no wrap: actual_first == seq_len, rest_chunk is zero-length, + second SliceUpdate is a no-op (guarded in exec_slice_update). + """ + ring_size = self.ring_size + + # write_pos = start_pos % ring_size + _, write_pos_slot = P.slot_manager.make_tmp_value_slot() + P.emit( + ModIntNode( + a=P.to_int_or_vid(start_slot), + b=IntOrVid.from_literal(ring_size), + out=P.slot_to_vid(write_pos_slot), + ) + ) + + # seq_len = update.shape[2] + _, seq_len_slot = P.slot_manager.make_tmp_value_slot() + P.emit( + SymSizeNode( + a=P.slot_to_tid(update_slot), + dim=2, + out=P.slot_to_vid(seq_len_slot), + ) + ) + + # first_len = ring_size - write_pos (may be > seq_len) + _, first_len_slot = P.slot_manager.make_tmp_value_slot() + P.emit( + SubtractIntNode( + a=IntOrVid.from_literal(ring_size), + b=P.to_int_or_vid(write_pos_slot), + out=P.slot_to_vid(first_len_slot), + ) + ) + + # first_chunk = update[:, :, :first_len, :] (Slice clamps to seq_len) + _, first_chunk_slot = P.make_tmp_slot() + P.emit( + SliceNode( + x=P.slot_to_tid(update_slot), + out=P.slot_to_tid(first_chunk_slot), + axis=IntOrVid.from_literal(2), + start=IntOrVid.from_literal(0), + stop=P.to_int_or_vid(first_len_slot), + ) + ) + + # actual_first = first_chunk.shape[2] (= min(first_len, seq_len)) + _, actual_first_slot = P.slot_manager.make_tmp_value_slot() + P.emit( + SymSizeNode( + a=P.slot_to_tid(first_chunk_slot), + dim=2, + out=P.slot_to_vid(actual_first_slot), + ) + ) + + # rest_chunk = update[:, :, actual_first:seq_len, :] + _, rest_chunk_slot = P.make_tmp_slot() + P.emit( + SliceNode( + x=P.slot_to_tid(update_slot), + out=P.slot_to_tid(rest_chunk_slot), + axis=IntOrVid.from_literal(2), + start=P.to_int_or_vid(actual_first_slot), + stop=P.to_int_or_vid(seq_len_slot), + ) + ) + + # stop1 = write_pos + actual_first + _, stop1_slot = P.slot_manager.make_tmp_value_slot() + P.emit( + AddIntNode( + a=P.to_int_or_vid(write_pos_slot), + b=P.to_int_or_vid(actual_first_slot), + out=P.slot_to_vid(stop1_slot), + ) + ) + + # overflow = seq_len - actual_first + _, overflow_slot = P.slot_manager.make_tmp_value_slot() + P.emit( + SubtractIntNode( + a=P.to_int_or_vid(seq_len_slot), + b=P.to_int_or_vid(actual_first_slot), + out=P.slot_to_vid(overflow_slot), + ) + ) + + # SliceUpdate 1: cache[:, :, write_pos:stop1, :] = first_chunk + P.emit( + SliceUpdateNode( + dst=P.slot_to_tid(cache_slot), + update=P.slot_to_tid(first_chunk_slot), + out=P.slot_to_tid(cache_slot), + axis=IntOrVid.from_literal(2), + start=P.to_int_or_vid(write_pos_slot), + stop=P.to_int_or_vid(stop1_slot), + ) + ) + + # SliceUpdate 2: cache[:, :, 0:overflow, :] = rest_chunk + # Zero-length no-op when no wrap (overflow=0) + P.emit( + SliceUpdateNode( + dst=P.slot_to_tid(cache_slot), + update=P.slot_to_tid(rest_chunk_slot), + out=P.slot_to_tid(cache_slot), + axis=IntOrVid.from_literal(2), + start=IntOrVid.from_literal(0), + stop=P.to_int_or_vid(overflow_slot), + ) + ) + + +@REGISTRY.register_pattern(name="SDPA") +class SDPAHandler(PatternHandler): + """ + Pattern for Scaled Dot Product Attention with optional GQA. + + Matches: scaled_dot_product_attention + Optionally with repeat_interleave for grouped query attention. + """ + + def __init__( + self, + head: Node, + body: List[Node], + q_node: Node, + k_node: Node, + v_node: Node, + ): + super().__init__(head, body) + self.q_node = q_node + self.k_node = k_node + self.v_node = v_node + + @classmethod + def _parse_sdpa_args_and_kwargs(cls, sdpa_node: Node): + q, k, v = sdpa_node.args[0:3] + attn_mask = sdpa_node.args[3] if len(sdpa_node.args) > 3 else None + dropout_p = sdpa_node.args[4] if len(sdpa_node.args) > 4 else 0.0 + is_causal = sdpa_node.args[5] if len(sdpa_node.args) > 5 else False + enable_gqa = sdpa_node.args[6] if len(sdpa_node.args) > 6 else False + scale = sdpa_node.kwargs.get("scale", None) + return q, k, v, attn_mask, dropout_p, is_causal, scale, enable_gqa + + @classmethod + def _try_unwrap_repeat_kv(cls, node: Node) -> Optional[Tuple[Node, List[Node]]]: + """Try to unwrap a HuggingFace repeat_kv pattern. + + HuggingFace's repeat_kv expands KV heads for grouped query attention: + hidden_states[:, :, None, :, :].expand(B, n_kv, n_rep, T, D) + .clone().reshape(B, n_heads, T, D) + + In Edge IR this becomes: + unsqueeze_copy(x, 2) → expand_copy → clone → view_copy + + Returns: + (base_node, body_nodes) if pattern matches, else None. + base_node is the original [B, n_kv, T, D] tensor. + body_nodes are the intermediate nodes to absorb. + """ + result = walk_back( + node, + [ + OpStep(op=torch.ops.aten.view.default, nargs=2), + OpStep(op=torch.ops.aten.clone.default, optional=True), + OpStep(op=torch.ops.aten.expand.default, nargs=2), + OpStep(op=torch.ops.aten.unsqueeze.default, nargs=2), + ], + ) + if result is None: + return None + + base, entries = result + _view, _clone, _expand, unsqueeze = entries + + # unsqueeze must be on dim=2 + if unsqueeze.args[1] != 2: + return None + + body = [e for e in entries if e is not None] + return base, body + + @classmethod + def maybe_create(cls, ep: ExportedProgram, head: Node) -> Optional["SDPAHandler"]: + sdpa_node = head + if not match_target( + sdpa_node, torch.ops.aten.scaled_dot_product_attention.default + ): + return None + + q, k, v, _, _, _, _, _ = cls._parse_sdpa_args_and_kwargs(sdpa_node) + + # Detect grouped kv attention pattern with repeat_interleave before SDPA + is_grouped_kv = False + k_base = k + v_base = v + body: List[Node] = [] + if ( + match_target(k, torch.ops.aten.repeat_interleave.self_int) + and has_single_user(k) + and (len(k.args) == 3) + and (len(k.kwargs) == 0) + and match_target(v, torch.ops.aten.repeat_interleave.self_int) + and has_single_user(v) + and (len(v.args) == 3) + and (len(v.kwargs) == 0) + ): + k_unrepeated, k_reps, k_dim = k.args + v_unrepeated, v_reps, v_dim = v.args + + if (k_dim == 1 and v_dim == 1) and (k_reps == v_reps): + is_grouped_kv = True + k_base = k_unrepeated + v_base = v_unrepeated + body = [k, v] + + # Detect HuggingFace repeat_kv pattern: + # unsqueeze(dim=2) → expand → clone → view + if not is_grouped_kv: + k_unwrap = cls._try_unwrap_repeat_kv(k) + v_unwrap = cls._try_unwrap_repeat_kv(v) + if k_unwrap is not None and v_unwrap is not None: + k_base, k_body = k_unwrap + v_base, v_body = v_unwrap + is_grouped_kv = True + body = k_body + v_body + + head = sdpa_node + if not is_grouped_kv: + body = [] + return SDPAHandler(head, body, q_node=q, k_node=k_base, v_node=v_base) + + def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: + assert n == self.head + q, k, v, attn_mask, dropout_p, is_causal, scale, enable_gqa = ( + SDPAHandler._parse_sdpa_args_and_kwargs(n) + ) + head_dim = q.meta["val"].shape[-1] + if scale is None: + scale = head_dim**-0.5 + + q = self.q_node + k = self.k_node + v = self.v_node + + assert dropout_p == 0.0, "SDPA with dropout is not supported" + + q, k, v, attn_mask = P.slot_map([q, k, v, attn_mask]) + out = P.make_or_get_slot(n) + P.emit( + SdpaNode( + q=P.slot_to_tid(q), + k=P.slot_to_tid(k), + v=P.slot_to_tid(v), + out=P.slot_to_tid(out), + scale=scale, + mask=P.slot_to_tid(attn_mask) if attn_mask else None, + causal=is_causal, + ) + ) + return out + + +@REGISTRY.register_pattern(name="MLX_CUSTOM_SDPA") +class MLXCustomSdpaHandler(PatternHandler): + """ + Pattern handler for mlx::custom_sdpa custom op. + + This op follows the optimum-executorch pattern: + - Input: Q, K, V in BHSD format [B, num_heads, seq_len, head_dim] + - start_pos: FIRST position of current query batch (not last!) + - stop_pos: computed as start_pos + query_seq_len + - K/V are FULL cache, sliced internally to [:, :, :stop_pos, :] + + For prefill with 7 tokens at positions [0,1,2,3,4,5,6]: start_pos=0, stop_pos=7 + For decode at position 10: start_pos=10, stop_pos=11 + + Decomposes into: + - SliceNode (K): slice to [:, :, :stop_pos, :] + - SliceNode (V): slice to [:, :, :stop_pos, :] + - SdpaNode: scaled dot-product attention (handles GQA internally) + """ + + def __init__( + self, + head: Node, + body: List[Node], + query: Node, + key: Node, + value: Node, + start_pos: Any, # int or Node (SymInt) + attn_mask: Optional[Node], + scale: Optional[float], + is_causal: bool, + ): + super().__init__(head, body) + self.query = query + self.key = key + self.value = value + self.start_pos = start_pos + self.attn_mask = attn_mask + self.scale = scale + self.is_causal = is_causal + + @classmethod + def maybe_create( + cls, ep: ExportedProgram, head: Node + ) -> Optional["MLXCustomSdpaHandler"]: + """Match the mlx::custom_sdpa custom op.""" + if head.op != "call_function": + return None + + target_str = str(head.target) + if "custom_sdpa" not in target_str or "mlx" not in target_str: + return None + + # Op signature: custom_sdpa(query, key, value, start_pos, attn_mask, dropout_p, is_causal, scale) + # start_pos is a SymInt (int), not a Tensor + args = head.args + kwargs = head.kwargs + + if len(args) < 4: + return None + + query = args[0] + key = args[1] + value = args[2] + start_pos = args[3] # int or SymInt (Node) + + # Get optional args + attn_mask = args[4] if len(args) > 4 else kwargs.get("attn_mask", None) + dropout_p = args[5] if len(args) > 5 else kwargs.get("dropout_p", 0.0) + is_causal = args[6] if len(args) > 6 else kwargs.get("is_causal", False) + scale = args[7] if len(args) > 7 else kwargs.get("scale", None) + + if dropout_p != 0.0: + return None + + return MLXCustomSdpaHandler( + head=head, + body=[], + query=query, + key=key, + value=value, + start_pos=start_pos, + attn_mask=attn_mask, + scale=scale, + is_causal=is_causal, + ) + + def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: + from executorch.backends.mlx.serialization.mlx_graph_schema import ( + IntOrVid, + SdpaNode, + SliceNode, + ) + + assert n == self.head + + # Get slots for Q, K, V + q_slot, k_slot, v_slot = P.slot_map([self.query, self.key, self.value]) + + # Get scale from metadata if not provided + q_meta = self.query.meta.get("val") + head_dim = q_meta.shape[-1] + scale = self.scale if self.scale is not None else head_dim**-0.5 + + # Resolve start_pos to int or Slot (same pattern as KVCacheUpdateHandler) + if isinstance(self.start_pos, Node): + start_slot = P.slot_map([self.start_pos])[0] + else: + start_slot = self.start_pos + + # Compute stop = start_pos + seq_len using emit_stop_position, + # which handles static/dynamic seq_len (SymInt) and start_pos correctly. + # BHSD layout: q is [B, num_heads, seq_len, head_dim], seq_len is dim 2. + stop = emit_stop_position( + P, + start=start_slot, + length_tensor=q_slot, + length_dim=2, + length_meta=q_meta, + ) + slice_stop = P.to_int_or_vid(stop) + + # Step 1: Slice K to [:, :, :stop_pos, :] where stop_pos = start_pos + query_seq_len + _, k_sliced_slot = P.make_tmp_slot() + P.emit( + SliceNode( + x=P.slot_to_tid(k_slot), + out=P.slot_to_tid(k_sliced_slot), + axis=IntOrVid.from_literal(2), + start=IntOrVid.from_literal(0), + stop=slice_stop, + ) + ) + + # Step 2: Slice V to [:, :, :stop_pos, :] where stop_pos = start_pos + query_seq_len + _, v_sliced_slot = P.make_tmp_slot() + P.emit( + SliceNode( + x=P.slot_to_tid(v_slot), + out=P.slot_to_tid(v_sliced_slot), + axis=IntOrVid.from_literal(2), + start=IntOrVid.from_literal(0), + stop=slice_stop, + ) + ) + + # Step 3: SDPA (handles GQA internally) - outputs BHSD + out_slot = P.make_or_get_slot(n) + P.emit( + SdpaNode( + q=P.slot_to_tid(q_slot), + k=P.slot_to_tid(k_sliced_slot), + v=P.slot_to_tid(v_sliced_slot), + out=P.slot_to_tid(out_slot), + scale=scale, + mask=( + P.slot_to_tid(P.slot_map([self.attn_mask])[0]) + if self.attn_mask is not None + else None + ), + causal=self.is_causal, + ) + ) + + return out_slot + + +@REGISTRY.register_pattern(name="QUANTIZED_LINEAR") +class QuantizedLinearHandler(PatternHandler): + """ + Pattern for quantized linear: dequantize_affine + linear. + """ + + def __init__( + self, + head: Node, + body: List[Node], + qdata: Node, + scale: Node, + zero_point: Node, + group_size: int, + bits: int, + out_dtype: torch.dtype, + ): + super().__init__(head, body) + self.qdata = qdata + self.scale = scale + self.zero_point = zero_point + self.group_size = group_size + self.bits = bits + self.out_dtype = out_dtype + + @classmethod + def maybe_create( + cls, ep: ExportedProgram, head: Node + ) -> Optional["QuantizedLinearHandler"]: + linear_node = head + if not match_target(linear_node, torch.ops.aten.linear.default): + return None + + x, w = linear_node.args[0:2] + dequant_node = w + if not match_target(dequant_node, torch.ops.torchao.dequantize_affine.default): + return None + if not has_single_user(dequant_node): + return None + + parsed = parse_dequant_node(dequant_node) + if parsed is None: + return None + qdata, scale, zero_point, group_size, bits, out_dtype, _quantized_dim = parsed + out_dtype = x.meta["val"].dtype if out_dtype is None else out_dtype + + head = linear_node + body = [dequant_node] + return QuantizedLinearHandler( + head, + body, + qdata=qdata, + scale=scale, + zero_point=zero_point, + group_size=group_size, + bits=bits, + out_dtype=out_dtype, + ) + + def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: + assert n == self.head + + x, w = n.args[0:2] + b = n.args[2] if len(n.args) > 2 else None + + qdata_target, qdata = P.get_placeholder_target_and_tensor(self.qdata) + zero_point_target, zero_point = P.get_placeholder_target_and_tensor( + self.zero_point + ) + _, scale = P.get_placeholder_target_and_tensor(self.scale) + + out_scalar_type = torch_dtype_to_scalar_type(self.out_dtype) + + # Check if we can use scale_only optimization: + # When zero_point is all zeros, biases = -scales * 2^(bits-1) + # which can be computed at runtime instead of serialized. + # Note: During partitioning, tensors are FakeTensors so we skip the check. + # The optimization is only applied during preprocess when we have real tensors. + use_scale_only = False + if not QUANTIZED_SERIALIZE_BIASES: + from torch._subclasses.fake_tensor import FakeTensor + + if not isinstance(zero_point, FakeTensor): + if torch.sum(torch.abs(zero_point)).item() == 0: + use_scale_only = True + + Q, B = to_mlx_qparams( + qdata, scale, zero_point, self.bits, compute_biases=not use_scale_only + ) + w = P.make_or_get_constant(f"{qdata_target}_to_packed", Q) + + if use_scale_only: + biases_tid = None + else: + biases = P.make_or_get_constant(f"{zero_point_target}_to_biases", B) + biases_tid = P.slot_to_tid(biases) + + x, scale_slot, b = P.slot_map([x, self.scale, b]) + out = P.make_or_get_slot(n) + P.emit( + QuantizedLinearNode( + x=P.slot_to_tid(x), + w=P.slot_to_tid(w), + scales=P.slot_to_tid(scale_slot), + out=P.slot_to_tid(out), + biases=biases_tid, + bias=P.slot_to_tid(b) if b else None, + group_size=self.group_size, + bits=self.bits, + mode="affine", + out_scalar_type=out_scalar_type, + scale_only=use_scale_only, + ) + ) + return out + + +@REGISTRY.register_pattern(name="QUANTIZED_EMBEDDING") +class QuantizedEmbeddingHandler(PatternHandler): + """ + Pattern for quantized embedding: dequantize_affine + embedding. + """ + + def __init__( + self, + head: Node, + body: List[Node], + qdata: Node, + scale: Node, + zero_point: Node, + group_size: int, + bits: int, + out_dtype: torch.dtype, + ): + super().__init__(head, body) + self.qdata = qdata + self.scale = scale + self.zero_point = zero_point + self.group_size = group_size + self.bits = bits + self.out_dtype = out_dtype + + @classmethod + def maybe_create( + cls, ep: ExportedProgram, head: Node + ) -> Optional["QuantizedEmbeddingHandler"]: + embedding_node = head + if not match_target(embedding_node, torch.ops.aten.embedding.default): + return None + + w, x = embedding_node.args[0:2] + + dequant_node = w + if not match_target(dequant_node, torch.ops.torchao.dequantize_affine.default): + return None + if not has_single_user(dequant_node): + return None + + parsed = parse_dequant_node(dequant_node) + if parsed is None: + return None + qdata, scale, zero_point, group_size, bits, out_dtype, _quantized_dim = parsed + out_dtype = scale.meta["val"].dtype if out_dtype is None else out_dtype + + head = embedding_node + body = [dequant_node] + return QuantizedEmbeddingHandler( + head, + body, + qdata=qdata, + scale=scale, + zero_point=zero_point, + group_size=group_size, + bits=bits, + out_dtype=out_dtype, + ) + + def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: + assert n == self.head + w, x = n.args[0:2] + + qdata_target, qdata = P.get_placeholder_target_and_tensor(self.qdata) + zero_point_target, zero_point = P.get_placeholder_target_and_tensor( + self.zero_point + ) + _, scale = P.get_placeholder_target_and_tensor(self.scale) + + Q, B = to_mlx_qparams(qdata, scale, zero_point, self.bits) + out_scalar_type = torch_dtype_to_scalar_type(self.out_dtype) + + w = P.make_or_get_constant(f"{qdata_target}_to_packed", Q) + biases = P.make_or_get_constant(f"{zero_point_target}_to_biases", B) + + x, scale_slot = P.slot_map([x, self.scale]) + ids_index = IntOrVidOrTid.from_tid(P.slot_to_tid(x)) + + # Gather quantized weights by ids + _, wq_sel = P.make_tmp_slot() + P.emit( + TakeNode( + x=P.slot_to_tid(w), + index=ids_index, + out=P.slot_to_tid(wq_sel), + axis=0, + ) + ) + + # Gather scales by ids + _, sc_sel = P.make_tmp_slot() + P.emit( + TakeNode( + x=P.slot_to_tid(scale_slot), + index=ids_index, + out=P.slot_to_tid(sc_sel), + axis=0, + ) + ) + + # Gather biases by ids + _, b_sel = P.make_tmp_slot() + P.emit( + TakeNode( + x=P.slot_to_tid(biases), + index=ids_index, + out=P.slot_to_tid(b_sel), + axis=0, + ) + ) + + # Dequantize the gathered slices + out = P.make_or_get_slot(n) + P.emit( + DequantizeNode( + w=P.slot_to_tid(wq_sel), + scales=P.slot_to_tid(sc_sel), + out=P.slot_to_tid(out), + biases=P.slot_to_tid(b_sel), + group_size=self.group_size, + bits=self.bits, + mode="affine", + out_scalar_type=out_scalar_type, + ) + ) + return out diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index bfd593c162b..f6aabd9af8f 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -96,28 +96,1504 @@ inline std::vector infer_shape_with_minus_one( return resolved_shape; } +// Formula: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))) +inline array gelu_tanh_impl(const array& x, StreamOrDevice s = {}) { + constexpr float sqrt_2_over_pi = 0.7978845608f; + auto dtype = x.dtype(); + + auto x3 = multiply(x, multiply(x, x, s), s); + auto term = multiply(array(0.044715f, dtype), x3, s); + auto inner = add(x, term, s); + inner = multiply(array(sqrt_2_over_pi, dtype), inner, s); + auto tanh_val = tanh(inner, s); + auto one_plus_tanh = add(array(1.0f, dtype), tanh_val, s); + auto out = multiply(x, one_plus_tanh, s); + out = multiply(array(0.5f, dtype), out, s); + return out; +} + +// Formula: 0.5 * x * (1 + erf(x / sqrt(2))) +inline array gelu_none_impl(const array& x, StreamOrDevice s = {}) { + constexpr float inv_sqrt_2 = 0.7071067812f; + auto dtype = x.dtype(); + + auto scaled = multiply(array(inv_sqrt_2, dtype), x, s); + auto erf_val = erf(scaled, s); + auto one_plus_erf = add(array(1.0f, dtype), erf_val, s); + auto out = multiply(x, one_plus_erf, s); + out = multiply(array(0.5f, dtype), out, s); + return out; +} + inline void exec_noop(const NoopNode&, ExecutionState&, StreamOrDevice) {} inline void -exec_id_copy(const IdCopyNode& n, ExecutionState& st, StreamOrDevice) { - st.set_tensor(n.out, st.const_tensor_ref(n.x)); +exec_id_copy(const IdCopyNode& n, ExecutionState& st, StreamOrDevice) { + st.set_tensor(n.out, st.const_tensor_ref(n.x)); +} + +inline void +exec_addmm(const AddmmNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& mat1 = st.const_tensor_ref(n.mat1); + const auto& mat2 = st.const_tensor_ref(n.mat2); + + array Y = n.bias ? addmm( + st.const_tensor_ref(*n.bias), + mat1, + mat2, + /*alpha=*/n.alpha, + /*beta=*/n.beta, + s) + : matmul(mat1, mat2, s); + + st.set_tensor(n.out, std::move(Y)); +} + +inline void +exec_linear(const LinearNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& X = st.const_tensor_ref(n.x); + auto W = st.const_tensor_ref(n.weight); + W = transpose(W, {1, 0}, s); + + array Y = n.bias ? addmm( + st.const_tensor_ref(*n.bias), + X, + W, + /*alpha=*/1.0f, + /*beta=*/1.0f, + s) + : matmul(X, W, s); + + st.set_tensor(n.out, std::move(Y)); +} + +inline void +exec_item_int(const ItemIntNode& n, ExecutionState& st, StreamOrDevice) { + // Intentional sync: item() requires a concrete scalar value for SymInt + // shape computation, so we must force GPU evaluation here. + auto x = st.const_tensor_ref(n.x); + eval(x); + int item = x.item(); + st.set_value(n.out, item); +} + +inline void exec_expand_dims( + const ExpandDimsNode& n, + ExecutionState& st, + StreamOrDevice s) { + st.set_tensor(n.out, expand_dims(st.const_tensor_ref(n.x), n.axis, s)); +} + +inline void exec_tile(const TileNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + auto reps = resolve_ints(n.reps, st); + st.set_tensor(n.out, tile(x, reps, s)); +} + +inline void exec_take_along_axis( + const TakeAlongAxisNode& n, + ExecutionState& st, + StreamOrDevice s) { + st.set_tensor( + n.out, + take_along_axis( + st.const_tensor_ref(n.x), st.const_tensor_ref(n.indices), n.axis, s)); +} + +inline void exec_take(const TakeNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + int axis = normalize_axis(n.axis, static_cast(x.ndim()), "Take"); + switch (n.index.kind) { + case 0: { // literal int + int index = normalize_axis( + clamp_to_int32(n.index.literal), x.shape(axis), "Take"); + st.set_tensor(n.out, take(x, index, axis, s)); + break; + } + case 1: { // Vid (dynamic int) + int index = normalize_axis( + st.const_value_ref(n.index.vid), x.shape(axis), "Take"); + st.set_tensor(n.out, take(x, index, axis, s)); + break; + } + case 2: { // Tid (tensor of indices) + const auto& indices = st.const_tensor_ref(n.index.tid); + st.set_tensor(n.out, take(x, indices, axis, s)); + break; + } + default: + throw std::runtime_error( + "TakeNode: invalid index kind: " + std::to_string(n.index.kind)); + } +} + +inline void +exec_rms_norm(const RMSNormNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + const auto& w = st.const_tensor_ref(n.weight); + st.set_tensor(n.out, fast::rms_norm(x, w, n.eps, s)); +} + +inline void +exec_layer_norm(const LayerNormNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + + std::optional w = std::nullopt; + if (n.weight) { + w = st.const_tensor_ref(*n.weight); + } + std::optional bias = std::nullopt; + if (n.bias) { + bias = st.const_tensor_ref(*n.bias); + } + st.set_tensor(n.out, fast::layer_norm(x, w, bias, n.eps, s)); +} + +inline void exec_rope(const RopeNode& n, ExecutionState& st, StreamOrDevice s) { + const array& x = st.const_tensor_ref(n.x); + + std::optional freqs_arr = std::nullopt; + if (n.freqs) { + freqs_arr = st.const_tensor_ref(*n.freqs); + } + + // MLX has two overloads: rope(..., int offset, ...) and rope(..., const + // array& offset, ...) Call the appropriate one based on is_vid + if (n.offset.is_vid) { + // Scalar offset from Vid + int offset = st.const_value_ref(n.offset.vid); + st.set_tensor( + n.out, + fast::rope( + x, n.dims, n.traditional, n.base, n.scale, offset, freqs_arr, s)); + } else { + // Tensor offset from Tid + const array& offset = st.const_tensor_ref(n.offset.tid); + st.set_tensor( + n.out, + fast::rope( + x, n.dims, n.traditional, n.base, n.scale, offset, freqs_arr, s)); + } +} + +inline void exec_sdpa(const SdpaNode& n, ExecutionState& st, StreamOrDevice s) { + array Q = st.const_tensor_ref(n.q); + array K = st.const_tensor_ref(n.k); + array V = st.const_tensor_ref(n.v); + + std::string mask_mode = ""; + std::optional mask_arr = std::nullopt; + std::optional sinks = std::nullopt; + + if (n.mask) { + array M = st.const_tensor_ref(*n.mask); + // MLX's SDPA handles bool masks natively (True=attend, False=masked) + // For non-bool masks, ensure dtype matches Q + if (M.dtype() != bool_ && M.dtype() != Q.dtype()) { + M = astype(M, Q.dtype(), s); + } + mask_arr = std::move(M); + } + if (n.causal) { + mask_mode = "causal"; + } + + array out = fast::scaled_dot_product_attention( + Q, K, V, static_cast(n.scale), mask_mode, mask_arr, sinks, s); + st.set_tensor(n.out, std::move(out)); +} + +inline void exec_add(const AddNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, add(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_add_int(const AddIntNode& n, ExecutionState& st, StreamOrDevice) { + int64_t a = resolve_int(n.a, st); + int64_t b = resolve_int(n.b, st); + int64_t result = a + b; + if (result > std::numeric_limits::max() || + result < std::numeric_limits::min()) { + throw std::runtime_error("add_int: overflow"); + } + st.set_value(n.out, static_cast(result)); +} + +inline void exec_subtract_int( + const SubtractIntNode& n, + ExecutionState& st, + StreamOrDevice) { + int64_t a = resolve_int(n.a, st); + int64_t b = resolve_int(n.b, st); + int64_t result = a - b; + if (result > std::numeric_limits::max() || + result < std::numeric_limits::min()) { + throw std::runtime_error("subtract_int: overflow"); + } + st.set_value(n.out, static_cast(result)); +} + +inline void exec_multiply_int( + const MultiplyIntNode& n, + ExecutionState& st, + StreamOrDevice) { + int64_t a = resolve_int(n.a, st); + int64_t b = resolve_int(n.b, st); + int64_t result = a * b; + if (result > std::numeric_limits::max() || + result < std::numeric_limits::min()) { + throw std::runtime_error("multiply_int: overflow"); + } + st.set_value(n.out, static_cast(result)); +} + +inline void exec_floor_divide_int( + const FloorDivideIntNode& n, + ExecutionState& st, + StreamOrDevice) { + int32_t a = resolve_int(n.a, st); + int32_t b = resolve_int(n.b, st); + if (b == 0) { + throw std::runtime_error("floor_divide_int: division by zero"); + } + if (a == std::numeric_limits::min() && b == -1) { + throw std::runtime_error("floor_divide_int: overflow (INT32_MIN / -1)"); + } + // Floor division for integers (Python semantics: rounds towards negative + // infinity) + int32_t result = a / b; + // Adjust for floor division when signs differ and there's a remainder + if ((a % b != 0) && ((a < 0) != (b < 0))) { + result -= 1; + } + st.set_value(n.out, result); +} + +inline void +exec_mod_int(const ModIntNode& n, ExecutionState& st, StreamOrDevice) { + int32_t a = resolve_int(n.a, st); + int32_t b = resolve_int(n.b, st); + if (b == 0) { + throw std::runtime_error("mod_int: division by zero"); + } + // Python modulo semantics: result has same sign as divisor + int32_t result = a % b; + if ((result != 0) && ((result < 0) != (b < 0))) { + result += b; + } + st.set_value(n.out, result); +} + +inline void +exec_sym_size(const SymSizeNode& n, ExecutionState& st, StreamOrDevice) { + const array& a = st.const_tensor_ref(n.a); + int rank = static_cast(a.ndim()); + int dim = n.dim; + if (dim < 0) { + dim += rank; + } + if (dim < 0 || dim >= rank) { + throw std::out_of_range("SYM_SIZE: dim out of range"); + } + int32_t size = static_cast(a.shape()[static_cast(dim)]); + st.set_value(n.out, size); +} + +inline void +exec_multiply(const MultiplyNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, multiply(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_divide(const DivideNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, divide(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_subtract(const SubtractNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, subtract(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_conv1d(const Conv1DNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + const auto& w = st.const_tensor_ref(n.w); + auto out = conv1d(x, w, n.stride, n.padding, n.dilation, n.groups, s); + st.set_tensor(n.out, std::move(out)); +} + +inline void +exec_conv2d(const Conv2DNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + const auto& w = st.const_tensor_ref(n.w); + + std::pair stride = {n.stride_h, n.stride_w}; + std::pair padding = {n.padding_h, n.padding_w}; + std::pair dilation = {n.dilation_h, n.dilation_w}; + + auto out = conv2d(x, w, stride, padding, dilation, n.groups, s); + st.set_tensor(n.out, std::move(out)); +} + +inline void +exec_conv3d(const Conv3DNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + const auto& w = st.const_tensor_ref(n.w); + + std::tuple stride = {n.stride_d, n.stride_h, n.stride_w}; + std::tuple padding = {n.padding_d, n.padding_h, n.padding_w}; + std::tuple dilation = { + n.dilation_d, n.dilation_h, n.dilation_w}; + + auto out = conv3d(x, w, stride, padding, dilation, n.groups, s); + st.set_tensor(n.out, std::move(out)); +} + +inline void exec_conv_transpose1d( + const ConvTranspose1DNode& n, + ExecutionState& st, + StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + const auto& w = st.const_tensor_ref(n.w); + auto out = conv_transpose1d( + x, w, n.stride, n.padding, n.dilation, n.output_padding, n.groups, s); + st.set_tensor(n.out, std::move(out)); +} + +inline void exec_conv_transpose2d( + const ConvTranspose2DNode& n, + ExecutionState& st, + StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + const auto& w = st.const_tensor_ref(n.w); + + std::pair stride = {n.stride_h, n.stride_w}; + std::pair padding = {n.padding_h, n.padding_w}; + std::pair dilation = {n.dilation_h, n.dilation_w}; + std::pair output_padding = {n.output_padding_h, n.output_padding_w}; + + auto out = conv_transpose2d( + x, w, stride, padding, dilation, output_padding, n.groups, s); + st.set_tensor(n.out, std::move(out)); +} + +inline void exec_conv_transpose3d( + const ConvTranspose3DNode& n, + ExecutionState& st, + StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + const auto& w = st.const_tensor_ref(n.w); + + std::tuple stride = {n.stride_d, n.stride_h, n.stride_w}; + std::tuple padding = {n.padding_d, n.padding_h, n.padding_w}; + std::tuple dilation = { + n.dilation_d, n.dilation_h, n.dilation_w}; + std::tuple output_padding = { + n.output_padding_d, n.output_padding_h, n.output_padding_w}; + + auto out = conv_transpose3d( + x, w, stride, padding, dilation, output_padding, n.groups, s); + st.set_tensor(n.out, std::move(out)); +} + +inline void exec_gelu(const GeluNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + if (n.approximate == "tanh") { + st.set_tensor(n.out, gelu_tanh_impl(x, s)); + } else { + // "none" or any other value uses exact GELU + st.set_tensor(n.out, gelu_none_impl(x, s)); + } +} + +inline void +exec_arange(const ARangeNode& n, ExecutionState& st, StreamOrDevice s) { + // Get start, stop, step - may be literal int64 or dynamic Vid + int start_val = resolve_int(n.start, st); + int stop_val = resolve_int(n.stop, st); + int step_val = resolve_int(n.step, st); + + if (step_val == 0) { + throw std::runtime_error("arange: step must not be zero"); + } + + // Bound the output size: numel = ceil((stop - start) / step) + int64_t range = static_cast(stop_val) - start_val; + int64_t numel = 0; + if ((range > 0 && step_val > 0) || (range < 0 && step_val < 0)) { + numel = (range / step_val) + (range % step_val != 0 ? 1 : 0); + } + auto dtype = n.scalar_type.has_value() ? resolve_dtype(n.scalar_type.value()) + : ::mlx::core::int32; + check_allocation_bounded( + {static_cast(std::min( + numel, static_cast(std::numeric_limits::max())))}, + dtype, + "arange"); + + if (n.scalar_type.has_value()) { + st.set_tensor(n.out, arange(start_val, stop_val, step_val, dtype, s)); + } else { + // No dtype specified - use MLX's default (infers from inputs). + // The bounds check above conservatively assumes int32. + st.set_tensor(n.out, arange(start_val, stop_val, step_val, s)); + } +} + +inline void exec_silu(const SiluNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, multiply(x, sigmoid(x, s), s)); +} + +inline void +exec_sigmoid(const SigmoidNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, sigmoid(x, s)); +} + +inline void exec_tanh(const TanhNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, tanh(x, s)); +} + +inline void +exec_squeeze(const SqueezeNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + const auto& dims_fb = n.dims; + + if (dims_fb.size() > 0) { + // Squeeze specific dimensions, filtering out non-size-1 dims to match + // PyTorch semantics where squeeze on a non-size-1 dim is a no-op. + std::vector dims; + for (auto d : dims_fb) { + int axis = d < 0 ? d + static_cast(x.ndim()) : d; + if (axis >= 0 && axis < static_cast(x.ndim()) && + x.shape(axis) == 1) { + dims.push_back(d); + } + } + if (dims.size() > 0) { + st.set_tensor(n.out, squeeze(x, dims, s)); + } else { + st.set_tensor(n.out, x); + } + } else { + // Squeeze all dimensions of size 1 + st.set_tensor(n.out, squeeze(x, s)); + } +} + +inline void +exec_split(const SplitNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + + // Resolve dynamic sizes to std::vector + std::vector sizes_vec = resolve_ints(n.sizes, st); + + // Get results based on split mode + auto outs_fb = n.outs; + + if (sizes_vec.size() == 1) { + // Single value means split_size (chunk size) + // Compute actual sizes: e.g., dim_size=10, split_size=3 -> [3, 3, 3, 1] + int split_size = sizes_vec[0]; + if (split_size <= 0) { + throw std::runtime_error( + "split: split_size must be positive, got " + + std::to_string(split_size)); + } + int axis = n.axis < 0 ? n.axis + static_cast(x.ndim()) : n.axis; + int dim_size = x.shape(axis); + + std::vector indices; + for (int pos = split_size; pos < dim_size; pos += split_size) { + indices.push_back(pos); + } + + auto results = split(x, to_shape(indices), n.axis, s); + if (results.size() != outs_fb.size()) { + throw std::runtime_error("Split: output count mismatch"); + } + for (size_t i = 0; i < results.size(); ++i) { + st.set_tensor(outs_fb[i], std::move(results[i])); + } + } else { + // Multiple sizes: convert to cumulative indices for MLX + // sizes=[10, 20, 30] -> indices=[10, 30] (split at positions 10 and 30) + std::vector indices; + indices.reserve(sizes_vec.size() - 1); + int64_t cumsum = 0; + for (size_t i = 0; i < sizes_vec.size() - 1; ++i) { + cumsum += static_cast(sizes_vec[i]); + if (cumsum > std::numeric_limits::max() || + cumsum < std::numeric_limits::min()) { + throw std::runtime_error("split: cumulative size overflow"); + } + indices.push_back(static_cast(cumsum)); + } + auto results = split(x, to_shape(indices), n.axis, s); + if (results.size() != outs_fb.size()) { + throw std::runtime_error("Split: output count mismatch"); + } + for (size_t i = 0; i < results.size(); ++i) { + st.set_tensor(outs_fb[i], std::move(results[i])); + } + } +} + +inline void +exec_rsqrt(const RsqrtNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, rsqrt(x, s)); +} + +inline void +exec_maximum(const MaximumNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, maximum(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_minimum(const MinimumNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, minimum(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void exec_log(const LogNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, log(x, s)); +} + +inline void +exec_softmax(const SoftmaxNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, softmax(x, n.axis, n.precise, s)); +} + +inline void exec_broadcast_to( + const BroadcastToNode& n, + ExecutionState& st, + StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + auto shape_vec = resolve_ints(n.shape, st); + + // Replace -1 with actual input dimensions (PyTorch expand semantics: + // -1 means "keep this dimension unchanged from input"). + // Dimensions are aligned from the RIGHT (broadcast semantics). + const auto& x_shape = x.shape(); + int offset = + static_cast(shape_vec.size()) - static_cast(x_shape.size()); + for (size_t i = 0; i < shape_vec.size(); i++) { + if (shape_vec[i] == -1) { + int input_dim = static_cast(i) - offset; + if (input_dim >= 0 && input_dim < static_cast(x_shape.size())) { + shape_vec[i] = + static_cast(x_shape[static_cast(input_dim)]); + } + } + } + + st.set_tensor( + n.out, + broadcast_to( + x, ::mlx::core::Shape(shape_vec.begin(), shape_vec.end()), s)); +} + +inline void exec_pad(const PadNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + + // Convert flat pad_width to vector of pairs + std::vector> pad_width_pairs; + auto pad_width_resolved = resolve_ints(n.pad_width, st); + if (pad_width_resolved.size() % 2 != 0) { + throw std::runtime_error( + "pad: pad_width must have even length, got " + + std::to_string(pad_width_resolved.size())); + } + for (size_t i = 0; i < pad_width_resolved.size(); i += 2) { + pad_width_pairs.push_back( + {pad_width_resolved[i], pad_width_resolved[i + 1]}); + } + + // MLX pad signature: pad(array, pad_width, pad_value, mode, stream) + if (n.mode == "constant") { + array pad_value(n.constant_value); + st.set_tensor(n.out, pad(x, pad_width_pairs, pad_value, "constant", s)); + } else if (n.mode == "edge") { + array pad_value(0.0f); + st.set_tensor(n.out, pad(x, pad_width_pairs, pad_value, "edge", s)); + } else { + throw std::runtime_error("Unsupported pad mode: " + n.mode); + } +} + +inline void +exec_where(const WhereNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& condition = st.const_tensor_ref(n.condition); + const auto& x = st.const_tensor_ref(n.x); + const auto& y = st.const_tensor_ref(n.y); + st.set_tensor(n.out, where(condition, x, y, s)); +} + +inline void +exec_reshape(const ReshapeNode& n, ExecutionState& st, StreamOrDevice s) { + auto new_shape = to_shape(n.shape, st); + st.set_tensor(n.out, reshape(st.const_tensor_ref(n.x), new_shape, s)); +} + +inline void +exec_transpose(const TransposeNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, transpose(st.const_tensor_ref(n.x), n.perm, s)); +} + +inline void +exec_as_strided(const AsStridedNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + auto shape = to_shape(n.shape, st); + auto resolved_strides = resolve_ints(n.strides, st); + Strides strides(resolved_strides.begin(), resolved_strides.end()); + st.set_tensor(n.out, as_strided(x, shape, strides, n.offset, s)); +} + +inline void +exec_contiguous(const ContiguousNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, contiguous(st.const_tensor_ref(n.x), false, s)); +} + +inline void +exec_id_copy(const IdCopyNode& n, ExecutionState& st, StreamOrDevice) { + st.set_tensor(n.out, st.const_tensor_ref(n.x)); +} + +inline void +exec_gather(const GatherNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + const int rank = static_cast(x.ndim()); + + if (n.indices.size() != n.axes.size()) { + throw std::runtime_error( + "GatherNode: indices count (" + std::to_string(n.indices.size()) + + ") must match axes count (" + std::to_string(n.axes.size()) + ")"); + } + + if (static_cast(n.slice_sizes.size()) != rank) { + throw std::runtime_error( + "GatherNode: slice_sizes length (" + + std::to_string(n.slice_sizes.size()) + ") must match input ndim (" + + std::to_string(rank) + ")"); + } + + for (auto axis : n.axes) { + if (axis < 0 || axis >= rank) { + throw std::runtime_error( + "GatherNode: axis " + std::to_string(axis) + + " out of range for input with ndim " + std::to_string(rank)); + } + } + + Shape slice_sizes(n.slice_sizes.begin(), n.slice_sizes.end()); + check_allocation_bounded(slice_sizes, x.dtype(), "gather"); + + std::vector indices; + indices.reserve(n.indices.size()); + for (auto tid : n.indices) { + indices.push_back(st.const_tensor_ref(tid)); + } + + st.set_tensor(n.out, gather(x, indices, n.axes, slice_sizes, s)); +} + +inline void +exec_slice(const SliceNode& n, ExecutionState& st, StreamOrDevice s) { + const array& x = st.const_tensor_ref(n.x); + const int rank = static_cast(x.ndim()); + + int axis = normalize_axis(resolve_int(n.axis, st), rank, "Slice"); + int start = resolve_int(n.start, st); + int stop = resolve_int(n.stop, st); + + std::vector vstart(static_cast(rank), 0); + std::vector vstop; + vstop.reserve(static_cast(rank)); + auto sh = x.shape(); + for (size_t i = 0; i < static_cast(rank); ++i) { + vstop.push_back(static_cast(sh[i])); + } + + if (n.step == 0) { + throw std::invalid_argument("Slice: step must not be 0"); + } + + vstart[static_cast(axis)] = start; + vstop[static_cast(axis)] = stop; + + std::vector vstrides(static_cast(rank), 1); + vstrides[static_cast(axis)] = n.step; + st.set_tensor( + n.out, + slice(x, to_shape(vstart), to_shape(vstop), to_shape(vstrides), s)); +} + +inline void +exec_astype(const AsTypeNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, astype(st.const_tensor_ref(n.x), resolve_dtype(n.scalar_type), s)); +} + +inline void exec_quantized_linear( + const QuantizedLinearNode& n, + ExecutionState& st, + StreamOrDevice s) { + // scale_only means biases should be computed, not provided + assert( + !(n.scale_only && n.biases) && + "scale_only=true but biases tensor also provided"); + + array X = st.const_tensor_ref(n.x); + array Wq = st.const_tensor_ref(n.w); + array Sc = st.const_tensor_ref(n.scales); + + if (n.bits <= 0 || n.bits > 8) { + throw std::runtime_error( + "exec_quantized_linear: bits must be in [1, 8], got " + + std::to_string(n.bits)); + } + + std::optional Qb = std::nullopt; + if (n.biases) { + Qb = st.const_tensor_ref(*n.biases); + } else if (n.scale_only) { + // Compute biases from scales: B = -scales * 2^(bits-1) + float offset = static_cast(1 << (n.bits - 1)); + Qb = multiply(Sc, array(-offset, Sc.dtype()), s); + } + + array Y = quantized_matmul( + X, + Wq, + Sc, + Qb, + /*transpose=*/true, + n.group_size, + n.bits, + n.mode, + s); + + if (n.bias) { + const auto& b = st.const_tensor_ref(*n.bias); + Y = add(Y, b, s); + } + + Dtype out_dtype = resolve_dtype(n.out_scalar_type); + if (out_dtype != Y.dtype()) { + Y = astype(Y, out_dtype, s); + } + + st.set_tensor(n.out, std::move(Y)); +} + +inline void exec_concatenate( + const ConcatenateNode& n, + ExecutionState& st, + StreamOrDevice s) { + auto tensors_fb = n.tensors; + std::vector tensors; + for (auto tid : tensors_fb) { + tensors.push_back(st.const_tensor_ref(tid)); + } + st.set_tensor(n.out, concatenate(tensors, n.axis, s)); +} + +inline void exec_full(const FullNode& n, ExecutionState& st, StreamOrDevice s) { + auto shape = to_shape(n.shape, st); + auto dtype = resolve_dtype(n.scalar_type); + check_allocation_bounded(shape, dtype, "full"); + st.set_tensor(n.out, full(shape, resolve_float(n.v, st), dtype, s)); +} + +inline void +exec_full_like(const FullLikeNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + // Use input dtype if not specified + auto dtype = n.scalar_type.has_value() ? resolve_dtype(n.scalar_type.value()) + : x.dtype(); + st.set_tensor(n.out, full_like(x, resolve_float(n.v, st), dtype, s)); +} + +inline void exec_slice_update( + const SliceUpdateNode& n, + ExecutionState& st, + StreamOrDevice s) { + // When out == dst, use direct assignment to preserve MLX buffer donation. + // TODO: I'm not sure if this is needed as a special case since the standard + // st.set_tensor does a std::move. Keeping for now, but should investigate and + // possibly remove in future. + const bool in_place = (n.out.idx == n.dst.idx); + array& dst = st.tensor_ref(n.dst); + const array& upd = st.const_tensor_ref(n.update); + + const int rank = static_cast(dst.ndim()); + + int axis = normalize_axis(resolve_int(n.axis, st), rank, "SliceUpdate"); + int start = resolve_int(n.start, st); + int stop = resolve_int(n.stop, st); + + std::vector vstart(static_cast(rank), 0); + std::vector vstop; + vstop.reserve(static_cast(rank)); + auto sh = dst.shape(); + for (size_t i = 0; i < static_cast(rank); ++i) { + vstop.push_back(static_cast(sh[i])); + } + + const int dst_dim = vstop[static_cast(axis)]; + + if (start < 0) + start += dst_dim; + start = std::max(0, std::min(start, dst_dim)); + if (stop < 0) + stop += dst_dim; + stop = std::max(0, std::min(stop, dst_dim)); + + vstart[static_cast(axis)] = start; + vstop[static_cast(axis)] = stop; + + std::vector vstrides(static_cast(rank), 1); + if (n.step < 1) { + throw std::invalid_argument( + "SliceUpdate: step must be >= 1, got " + std::to_string(n.step) + ""); + } + vstrides[static_cast(axis)] = n.step; + + if (in_place) { + if (start == stop) { + return; + } + if (n.step == 1) { + dst = slice_update(dst, upd, to_shape(vstart), to_shape(vstop), s); + } else { + dst = slice_update( + dst, upd, to_shape(vstart), to_shape(vstop), to_shape(vstrides), s); + } + + } else { + if (start == stop) { + st.set_tensor(n.out, dst); + return; + } + if (n.step == 1) { + st.set_tensor( + n.out, slice_update(dst, upd, to_shape(vstart), to_shape(vstop), s)); + } else { + st.set_tensor( + n.out, + slice_update( + dst, + upd, + to_shape(vstart), + to_shape(vstop), + to_shape(vstrides), + s)); + } + } +} + +// Helper: finds next contiguous run in indices starting at offset +// Returns (dst_start, dst_stop, upd_start, upd_stop) for the run +// Returns (0, 0, 0, 0) when no more runs +inline std::tuple next_contiguous_run( + const std::vector& indices, + size_t offset) { + if (offset >= indices.size()) + return {0, 0, 0, 0}; + + int dst_start = indices[offset]; + int upd_start = static_cast(offset); + size_t len = 1; + while (offset + len < indices.size() && + len < static_cast(std::numeric_limits::max()) && + indices[offset + len] == dst_start + static_cast(len)) { + ++len; + } + int dst_stop = dst_start + static_cast(len); + int upd_stop = upd_start + static_cast(len); + return {dst_start, dst_stop, upd_start, upd_stop}; +} + +// Copies update tensor into dst at positions specified by 1D indices along axis +// Optimizes into slice_update calls for contiguous runs +inline void +exec_index_copy(const IndexCopyNode& n, ExecutionState& st, StreamOrDevice s) { + array& dst = st.tensor_ref(n.dst); + const array& upd = st.const_tensor_ref(n.update); + const array& indices = st.const_tensor_ref(n.indices); + if (indices.ndim() != 1) { + throw std::invalid_argument("IndexCopyNode: indices must be 1D"); + } + const int rank = static_cast(dst.ndim()); + int axis = normalize_axis(n.axis, rank, "IndexCopyNode"); + const size_t uaxis = static_cast(axis); + const int dst_dim = static_cast(dst.shape()[uaxis]); + + // Get indices as a vector of ints, handling negative indices + // Note: PyTorch uses int64 for indices, so we read as int64_t + eval(indices); // Ensure indices are materialized before accessing data + if (indices.dtype() != ::mlx::core::int64) { + throw std::invalid_argument( + std::string("IndexCopyNode: expected int64 indices, got ") + + ExecutionState::dtype_str(indices.dtype())); + } + std::vector idx_vec(indices.size()); + auto idx_data = indices.data(); + for (size_t i = 0; i < indices.size(); ++i) { + int64_t idx = idx_data[i]; + if (idx < 0) { + idx += dst_dim; + } + if (idx < 0 || idx >= dst_dim) { + throw std::out_of_range( + "IndexCopyNode: index " + std::to_string(idx_data[i]) + + " out of range for axis " + std::to_string(axis) + " with size " + + std::to_string(dst_dim)); + } + if (idx > std::numeric_limits::max()) { + throw std::out_of_range( + "IndexCopyNode: index " + std::to_string(idx) + + " exceeds int32 range"); + } + idx_vec[i] = static_cast(idx); + } + + // When out == dst, use direct assignment to preserve MLX buffer donation. + // TODO: I'm not sure if this is needed as a special case since the standard + // st.set_tensor does a std::move. Keeping for now, but should investigate and + // possibly remove in future. + const bool in_place = (n.out.idx == n.dst.idx); + + if (idx_vec.empty()) { + if (!in_place) { + st.set_tensor(n.out, dst); + } + return; + } + + // Build base start/stop vectors for slice_update + const size_t urank = static_cast(rank); + std::vector dst_vstart(urank, 0); + std::vector dst_vstop; + dst_vstop.reserve(urank); + auto sh = dst.shape(); + for (size_t i = 0; i < urank; ++i) { + dst_vstop.push_back(static_cast(sh[i])); + } + + std::vector upd_vstart(urank, 0); + std::vector upd_vstop; + upd_vstop.reserve(urank); + auto upd_sh = upd.shape(); + for (size_t i = 0; i < urank; ++i) { + upd_vstop.push_back(static_cast(upd_sh[i])); + } + + array result = dst; // copy of dst to accumulate into + + // Process contiguous runs + size_t offset = 0; + while (offset < idx_vec.size()) { + auto [dst_start, dst_stop, upd_start, upd_stop] = + next_contiguous_run(idx_vec, offset); + + // Set axis range for dst + dst_vstart[uaxis] = dst_start; + dst_vstop[uaxis] = dst_stop; + + // Set axis range for upd slice + upd_vstart[uaxis] = upd_start; + upd_vstop[uaxis] = upd_stop; + + // Slice update - skip slicing if using entire update tensor + array upd_slice = + (upd_start == 0 && upd_stop == static_cast(upd_sh[uaxis])) + ? upd + : slice(upd, to_shape(upd_vstart), to_shape(upd_vstop), s); + + if (in_place) { + dst = slice_update( + dst, upd_slice, to_shape(dst_vstart), to_shape(dst_vstop), s); + } else { + result = slice_update( + result, upd_slice, to_shape(dst_vstart), to_shape(dst_vstop), s); + } + + offset = static_cast(upd_stop); + } + + if (!in_place) { + st.set_tensor(n.out, result); + } +} + +inline void +exec_dequantize(const DequantizeNode& n, ExecutionState& st, StreamOrDevice s) { + array Wq = st.const_tensor_ref(n.w); + array Sc = st.const_tensor_ref(n.scales); + + std::optional Qb = std::nullopt; + if (n.biases) { + Qb = st.const_tensor_ref(*n.biases); + } + + array Y = dequantize( + Wq, + Sc, + Qb, + n.group_size, + n.bits, + n.mode, + std::nullopt, // dtype - let MLX infer + s); + + Dtype out_dtype = resolve_dtype(n.out_scalar_type); + if (out_dtype != Y.dtype()) { + Y = astype(Y, out_dtype, s); + } + + st.set_tensor(n.out, std::move(Y)); +} + +inline void exec_less(const LessNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, less(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_less_equal(const LessEqualNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, less_equal(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_greater(const GreaterNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, greater(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void exec_greater_equal( + const GreaterEqualNode& n, + ExecutionState& st, + StreamOrDevice s) { + st.set_tensor( + n.out, + greater_equal(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_equal(const EqualNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, equal(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_not_equal(const NotEqualNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, not_equal(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void exec_logical_not( + const LogicalNotNode& n, + ExecutionState& st, + StreamOrDevice s) { + st.set_tensor(n.out, logical_not(st.const_tensor_ref(n.x), s)); +} + +inline void exec_logical_and( + const LogicalAndNode& n, + ExecutionState& st, + StreamOrDevice s) { + st.set_tensor( + n.out, + logical_and(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_logical_or(const LogicalOrNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, logical_or(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void exec_tri(const TriNode& n, ExecutionState& st, StreamOrDevice s) { + int rows = resolve_int(n.n, st); + int cols = resolve_int(n.m, st); + auto dtype = resolve_dtype(n.scalar_type); + check_allocation_bounded({rows, cols}, dtype, "tri"); + st.set_tensor(n.out, tri(rows, cols, n.k, dtype, s)); +} + +inline void exec_tril(const TrilNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, tril(x, n.k, s)); +} + +inline void exec_triu(const TriuNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, triu(x, n.k, s)); +} + +inline void +exec_floor(const FloorNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, floor(st.const_tensor_ref(n.x), s)); +} + +inline void exec_ceil(const CeilNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, ceil(st.const_tensor_ref(n.x), s)); +} + +inline void +exec_square(const SquareNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, square(st.const_tensor_ref(n.x), s)); +} + +inline void exec_exp(const ExpNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, exp(st.const_tensor_ref(n.x), s)); +} + +inline void exec_sin(const SinNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, sin(st.const_tensor_ref(n.x), s)); +} + +inline void exec_cos(const CosNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, cos(st.const_tensor_ref(n.x), s)); +} + +inline void exec_tan(const TanNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, tan(st.const_tensor_ref(n.x), s)); +} + +inline void +exec_arcsin(const ArcsinNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, arcsin(st.const_tensor_ref(n.x), s)); } inline void -exec_addmm(const AddmmNode& n, ExecutionState& st, StreamOrDevice s) { - const auto& mat1 = st.const_tensor_ref(n.mat1); - const auto& mat2 = st.const_tensor_ref(n.mat2); +exec_arccos(const ArccosNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, arccos(st.const_tensor_ref(n.x), s)); +} - array Y = n.bias ? addmm( - st.const_tensor_ref(*n.bias), - mat1, - mat2, - /*alpha=*/n.alpha, - /*beta=*/n.beta, - s) - : matmul(mat1, mat2, s); +inline void +exec_arctan(const ArctanNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, arctan(st.const_tensor_ref(n.x), s)); +} - st.set_tensor(n.out, std::move(Y)); +inline void exec_sinh(const SinhNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, sinh(st.const_tensor_ref(n.x), s)); +} + +inline void exec_cosh(const CoshNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, cosh(st.const_tensor_ref(n.x), s)); +} + +inline void +exec_arcsinh(const ArcsinhNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, arcsinh(st.const_tensor_ref(n.x), s)); +} + +inline void +exec_arccosh(const ArccoshNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, arccosh(st.const_tensor_ref(n.x), s)); +} + +inline void +exec_arctanh(const ArctanhNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, arctanh(st.const_tensor_ref(n.x), s)); +} + +inline void exec_log2(const Log2Node& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, log2(st.const_tensor_ref(n.x), s)); +} + +inline void +exec_log10(const Log10Node& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, log10(st.const_tensor_ref(n.x), s)); +} + +inline void +exec_log1p(const Log1pNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, log1p(st.const_tensor_ref(n.x), s)); +} + +inline void exec_erf(const ErfNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, erf(st.const_tensor_ref(n.x), s)); +} + +inline void +exec_expm1(const Expm1Node& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, expm1(st.const_tensor_ref(n.x), s)); +} + +inline void +exec_round(const RoundNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, round(st.const_tensor_ref(n.x), n.decimals, s)); +} + +inline void +exec_reciprocal(const ReciprocalNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, reciprocal(st.const_tensor_ref(n.x), s)); +} + +inline void exec_sqrt(const SqrtNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, sqrt(st.const_tensor_ref(n.x), s)); +} + +inline void exec_abs(const AbsNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, abs(st.const_tensor_ref(n.x), s)); +} + +inline void exec_neg(const NegNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, negative(st.const_tensor_ref(n.x), s)); +} + +inline void +exec_atan2(const Atan2Node& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, arctan2(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_logaddexp(const LogAddExpNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, logaddexp(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void exec_floor_divide( + const FloorDivideNode& n, + ExecutionState& st, + StreamOrDevice s) { + const array& a = st.const_tensor_ref(n.a); + const array& b = st.const_tensor_ref(n.b); + + if (!issubdtype(a.dtype(), inexact)) { + // mlx::floor_divide for integer types uses C++ truncation toward zero, + // but PyTorch floor_divide floors toward negative infinity. + // Adjust: floor_div(a, b) = trunc_div(a, b) - ((a % b != 0) & (sign(a) != + // sign(b))) + auto quot = divide(a, b, s); + auto rem = remainder(a, b, s); + auto zero = array(0, a.dtype()); + auto has_rem = not_equal(rem, zero, s); + auto a_neg = less(a, zero, s); + auto b_neg = less(b, zero, s); + auto signs_differ = not_equal(a_neg, b_neg, s); + auto adjust = logical_and(has_rem, signs_differ, s); + st.set_tensor(n.out, subtract(quot, astype(adjust, a.dtype(), s), s)); + } else { + st.set_tensor(n.out, floor_divide(a, b, s)); + } +} + +inline void +exec_remainder(const RemainderNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, remainder(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_power(const PowerNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, power(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_logsumexp(const LogSumExpNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + st.set_tensor(n.out, logsumexp(x, axes, n.keepdims, s)); +} + +inline void exec_sum(const SumNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + if (axes.empty()) { + st.set_tensor(n.out, sum(x, n.keepdims, s)); + } else { + st.set_tensor(n.out, sum(x, axes, n.keepdims, s)); + } +} + +inline void exec_mean(const MeanNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + if (axes.empty()) { + st.set_tensor(n.out, mean(x, n.keepdims, s)); + } else { + st.set_tensor(n.out, mean(x, axes, n.keepdims, s)); + } +} + +inline void exec_var(const VarNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + if (axes.empty()) { + st.set_tensor(n.out, var(x, n.keepdims, n.ddof, s)); + } else { + st.set_tensor(n.out, var(x, axes, n.keepdims, n.ddof, s)); + } +} + +inline void exec_std(const StdNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + if (axes.empty()) { + st.set_tensor(n.out, ::mlx::core::std(x, n.keepdims, n.ddof, s)); + } else { + st.set_tensor(n.out, ::mlx::core::std(x, axes, n.keepdims, n.ddof, s)); + } +} + +inline void exec_prod(const ProdNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + if (axes.empty()) { + st.set_tensor(n.out, prod(x, n.keepdims, s)); + } else { + st.set_tensor(n.out, prod(x, axes, n.keepdims, s)); + } +} + +inline void exec_max(const MaxNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + if (axes.empty()) { + st.set_tensor(n.out, max(x, n.keepdims, s)); + } else { + st.set_tensor(n.out, max(x, axes, n.keepdims, s)); + } +} + +inline void exec_min(const MinNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + if (axes.empty()) { + st.set_tensor(n.out, min(x, n.keepdims, s)); + } else { + st.set_tensor(n.out, min(x, axes, n.keepdims, s)); + } +} + +inline void +exec_argmax(const ArgmaxNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, argmax(x, n.axis, n.keepdims, s)); +} + +inline void +exec_argmin(const ArgminNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, argmin(x, n.axis, n.keepdims, s)); +} + +inline void +exec_median(const MedianNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + if (axes.empty()) { + st.set_tensor(n.out, median(x, n.keepdims, s)); + } else { + st.set_tensor(n.out, median(x, axes, n.keepdims, s)); + } +} + +inline void exec_clip(const ClipNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::optional a_min = n.a_min + ? std::optional(st.const_tensor_ref(*n.a_min)) + : std::nullopt; + std::optional a_max = n.a_max + ? std::optional(st.const_tensor_ref(*n.a_max)) + : std::nullopt; + st.set_tensor(n.out, clip(x, a_min, a_max, s)); +} + +inline void +exec_cumsum(const CumsumNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, cumsum(x, n.axis, n.reverse, n.inclusive, s)); +} + +inline void +exec_stack(const StackNode& n, ExecutionState& st, StreamOrDevice s) { + std::vector tensors; + for (auto tid : n.tensors) { + tensors.push_back(st.const_tensor_ref(tid)); + } + st.set_tensor(n.out, stack(tensors, n.axis, s)); +} + +inline void exec_sign(const SignNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, sign(st.const_tensor_ref(n.x), s)); +} + +inline void exec_any(const AnyNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + if (axes.empty()) { + st.set_tensor(n.out, any(x, n.keepdims, s)); + } else { + st.set_tensor(n.out, any(x, axes, n.keepdims, s)); + } +} + +inline void exec_all(const AllNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + if (axes.empty()) { + st.set_tensor(n.out, all(x, n.keepdims, s)); + } else { + st.set_tensor(n.out, all(x, axes, n.keepdims, s)); + } +} + +inline void +exec_repeat(const RepeatNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + int repeats = static_cast(resolve_int(n.repeats, st)); + if (repeats < 0) { + throw std::invalid_argument( + "repeat: repeats must be non-negative, got " + std::to_string(repeats)); + } + auto out_shape = x.shape(); + int axis = n.axis < 0 ? n.axis + static_cast(x.ndim()) : n.axis; + out_shape[static_cast(axis)] *= repeats; + check_allocation_bounded(out_shape, x.dtype(), "repeat"); + st.set_tensor(n.out, repeat(x, repeats, n.axis, s)); +} + +inline void exec_sort(const SortNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, sort(st.const_tensor_ref(n.x), n.axis, s)); +} + +inline void +exec_argsort(const ArgsortNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, argsort(st.const_tensor_ref(n.x), n.axis, s)); +} + +inline void +exec_partition(const PartitionNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + int kth = static_cast(resolve_int(n.kth, st)); + st.set_tensor(n.out, partition(x, kth, n.axis, s)); +} + +inline void exec_argpartition( + const ArgPartitionNode& n, + ExecutionState& st, + StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + int kth = static_cast(resolve_int(n.kth, st)); + st.set_tensor(n.out, argpartition(x, kth, n.axis, s)); } } // namespace ops @@ -165,6 +1641,375 @@ class Interpreter { case OpCode::ADDMM: ops::exec_addmm(std::get(instr.node), st, s); break; + case OpCode::LINEAR: + ops::exec_linear(std::get(instr.node), st, s); + break; + case OpCode::ITEM_INT: + ops::exec_item_int(std::get(instr.node), st, s); + break; + case OpCode::EXPAND_DIMS: + ops::exec_expand_dims(std::get(instr.node), st, s); + break; + case OpCode::TILE: + ops::exec_tile(std::get(instr.node), st, s); + break; + case OpCode::TAKE_ALONG_AXIS: + ops::exec_take_along_axis( + std::get(instr.node), st, s); + break; + case OpCode::TAKE: + ops::exec_take(std::get(instr.node), st, s); + break; + case OpCode::RMS_NORM: + ops::exec_rms_norm(std::get(instr.node), st, s); + break; + case OpCode::LAYER_NORM: + ops::exec_layer_norm(std::get(instr.node), st, s); + break; + case OpCode::ROPE: + ops::exec_rope(std::get(instr.node), st, s); + break; + case OpCode::SDPA: + ops::exec_sdpa(std::get(instr.node), st, s); + break; + case OpCode::ADD: + ops::exec_add(std::get(instr.node), st, s); + break; + case OpCode::ADD_INT: + ops::exec_add_int(std::get(instr.node), st, s); + break; + case OpCode::SUBTRACT_INT: + ops::exec_subtract_int(std::get(instr.node), st, s); + break; + case OpCode::MULTIPLY_INT: + ops::exec_multiply_int(std::get(instr.node), st, s); + break; + case OpCode::FLOOR_DIVIDE_INT: + ops::exec_floor_divide_int( + std::get(instr.node), st, s); + break; + case OpCode::MOD_INT: + ops::exec_mod_int(std::get(instr.node), st, s); + break; + case OpCode::SYM_SIZE: + ops::exec_sym_size(std::get(instr.node), st, s); + break; + case OpCode::MULTIPLY: + ops::exec_multiply(std::get(instr.node), st, s); + break; + case OpCode::DIVIDE: + ops::exec_divide(std::get(instr.node), st, s); + break; + case OpCode::SUBTRACT: + ops::exec_subtract(std::get(instr.node), st, s); + break; + case OpCode::CONV1D: + ops::exec_conv1d(std::get(instr.node), st, s); + break; + case OpCode::CONV2D: + ops::exec_conv2d(std::get(instr.node), st, s); + break; + case OpCode::CONV3D: + ops::exec_conv3d(std::get(instr.node), st, s); + break; + case OpCode::GELU: + ops::exec_gelu(std::get(instr.node), st, s); + break; + case OpCode::ARANGE: + ops::exec_arange(std::get(instr.node), st, s); + break; + case OpCode::SILU: + ops::exec_silu(std::get(instr.node), st, s); + break; + case OpCode::SIGMOID: + ops::exec_sigmoid(std::get(instr.node), st, s); + break; + case OpCode::TANH: + ops::exec_tanh(std::get(instr.node), st, s); + break; + case OpCode::SQUEEZE: + ops::exec_squeeze(std::get(instr.node), st, s); + break; + case OpCode::SPLIT: + ops::exec_split(std::get(instr.node), st, s); + break; + case OpCode::RSQRT: + ops::exec_rsqrt(std::get(instr.node), st, s); + break; + case OpCode::MAXIMUM: + ops::exec_maximum(std::get(instr.node), st, s); + break; + case OpCode::MINIMUM: + ops::exec_minimum(std::get(instr.node), st, s); + break; + case OpCode::LOG: + ops::exec_log(std::get(instr.node), st, s); + break; + case OpCode::SOFTMAX: + ops::exec_softmax(std::get(instr.node), st, s); + break; + case OpCode::BROADCAST_TO: + ops::exec_broadcast_to(std::get(instr.node), st, s); + break; + case OpCode::PAD: + ops::exec_pad(std::get(instr.node), st, s); + break; + case OpCode::WHERE: + ops::exec_where(std::get(instr.node), st, s); + break; + case OpCode::RESHAPE: + ops::exec_reshape(std::get(instr.node), st, s); + break; + case OpCode::TRANSPOSE: + ops::exec_transpose(std::get(instr.node), st, s); + break; + case OpCode::AS_STRIDED: + ops::exec_as_strided(std::get(instr.node), st, s); + break; + case OpCode::CONTIGUOUS: + ops::exec_contiguous(std::get(instr.node), st, s); + break; + case OpCode::ID_COPY: + ops::exec_id_copy(std::get(instr.node), st, s); + break; + case OpCode::GATHER: + ops::exec_gather(std::get(instr.node), st, s); + break; + case OpCode::SLICE: + ops::exec_slice(std::get(instr.node), st, s); + break; + case OpCode::ASTYPE: + ops::exec_astype(std::get(instr.node), st, s); + break; + case OpCode::QUANTIZED_LINEAR: + ops::exec_quantized_linear( + std::get(instr.node), st, s); + break; + case OpCode::CONCATENATE: + ops::exec_concatenate(std::get(instr.node), st, s); + break; + case OpCode::FULL: + ops::exec_full(std::get(instr.node), st, s); + break; + case OpCode::FULL_LIKE: + ops::exec_full_like(std::get(instr.node), st, s); + break; + case OpCode::ARGMAX: + ops::exec_argmax(std::get(instr.node), st, s); + break; + case OpCode::SLICE_UPDATE: + ops::exec_slice_update(std::get(instr.node), st, s); + break; + case OpCode::INDEX_COPY: + ops::exec_index_copy(std::get(instr.node), st, s); + break; + case OpCode::DEQUANTIZE: + ops::exec_dequantize(std::get(instr.node), st, s); + break; + case OpCode::LESS: + ops::exec_less(std::get(instr.node), st, s); + break; + case OpCode::LESS_EQUAL: + ops::exec_less_equal(std::get(instr.node), st, s); + break; + case OpCode::GREATER: + ops::exec_greater(std::get(instr.node), st, s); + break; + case OpCode::GREATER_EQUAL: + ops::exec_greater_equal(std::get(instr.node), st, s); + break; + case OpCode::EQUAL: + ops::exec_equal(std::get(instr.node), st, s); + break; + case OpCode::NOT_EQUAL: + ops::exec_not_equal(std::get(instr.node), st, s); + break; + case OpCode::LOGICAL_NOT: + ops::exec_logical_not(std::get(instr.node), st, s); + break; + case OpCode::LOGICAL_AND: + ops::exec_logical_and(std::get(instr.node), st, s); + break; + case OpCode::LOGICAL_OR: + ops::exec_logical_or(std::get(instr.node), st, s); + break; + case OpCode::TRI: + ops::exec_tri(std::get(instr.node), st, s); + break; + case OpCode::TRIL: + ops::exec_tril(std::get(instr.node), st, s); + break; + case OpCode::TRIU: + ops::exec_triu(std::get(instr.node), st, s); + break; + // Math ops - Unary + case OpCode::FLOOR: + ops::exec_floor(std::get(instr.node), st, s); + break; + case OpCode::CEIL: + ops::exec_ceil(std::get(instr.node), st, s); + break; + case OpCode::SQUARE: + ops::exec_square(std::get(instr.node), st, s); + break; + case OpCode::EXP: + ops::exec_exp(std::get(instr.node), st, s); + break; + case OpCode::SIN: + ops::exec_sin(std::get(instr.node), st, s); + break; + case OpCode::COS: + ops::exec_cos(std::get(instr.node), st, s); + break; + case OpCode::TAN: + ops::exec_tan(std::get(instr.node), st, s); + break; + case OpCode::ARCSIN: + ops::exec_arcsin(std::get(instr.node), st, s); + break; + case OpCode::ARCCOS: + ops::exec_arccos(std::get(instr.node), st, s); + break; + case OpCode::ARCTAN: + ops::exec_arctan(std::get(instr.node), st, s); + break; + case OpCode::SINH: + ops::exec_sinh(std::get(instr.node), st, s); + break; + case OpCode::COSH: + ops::exec_cosh(std::get(instr.node), st, s); + break; + case OpCode::ARCSINH: + ops::exec_arcsinh(std::get(instr.node), st, s); + break; + case OpCode::ARCCOSH: + ops::exec_arccosh(std::get(instr.node), st, s); + break; + case OpCode::ARCTANH: + ops::exec_arctanh(std::get(instr.node), st, s); + break; + case OpCode::LOG2: + ops::exec_log2(std::get(instr.node), st, s); + break; + case OpCode::LOG10: + ops::exec_log10(std::get(instr.node), st, s); + break; + case OpCode::LOG1P: + ops::exec_log1p(std::get(instr.node), st, s); + break; + case OpCode::ERF: + ops::exec_erf(std::get(instr.node), st, s); + break; + case OpCode::EXPM1: + ops::exec_expm1(std::get(instr.node), st, s); + break; + case OpCode::ROUND: + ops::exec_round(std::get(instr.node), st, s); + break; + case OpCode::RECIPROCAL: + ops::exec_reciprocal(std::get(instr.node), st, s); + break; + case OpCode::SQRT: + ops::exec_sqrt(std::get(instr.node), st, s); + break; + case OpCode::ABS: + ops::exec_abs(std::get(instr.node), st, s); + break; + case OpCode::NEG: + ops::exec_neg(std::get(instr.node), st, s); + break; + // Math ops - Binary + case OpCode::ATAN2: + ops::exec_atan2(std::get(instr.node), st, s); + break; + case OpCode::LOG_ADD_EXP: + ops::exec_logaddexp(std::get(instr.node), st, s); + break; + case OpCode::FLOOR_DIVIDE: + ops::exec_floor_divide(std::get(instr.node), st, s); + break; + case OpCode::REMAINDER: + ops::exec_remainder(std::get(instr.node), st, s); + break; + case OpCode::POWER: + ops::exec_power(std::get(instr.node), st, s); + break; + // Math ops - Reduction + case OpCode::LOG_SUM_EXP: + ops::exec_logsumexp(std::get(instr.node), st, s); + break; + case OpCode::SUM: + ops::exec_sum(std::get(instr.node), st, s); + break; + case OpCode::MEAN: + ops::exec_mean(std::get(instr.node), st, s); + break; + case OpCode::VAR: + ops::exec_var(std::get(instr.node), st, s); + break; + case OpCode::STD: + ops::exec_std(std::get(instr.node), st, s); + break; + case OpCode::PROD: + ops::exec_prod(std::get(instr.node), st, s); + break; + case OpCode::MAX: + ops::exec_max(std::get(instr.node), st, s); + break; + case OpCode::MIN: + ops::exec_min(std::get(instr.node), st, s); + break; + case OpCode::ARGMIN: + ops::exec_argmin(std::get(instr.node), st, s); + break; + case OpCode::MEDIAN: + ops::exec_median(std::get(instr.node), st, s); + break; + case OpCode::CONV_TRANSPOSE1D: + ops::exec_conv_transpose1d( + std::get(instr.node), st, s); + break; + case OpCode::CONV_TRANSPOSE2D: + ops::exec_conv_transpose2d( + std::get(instr.node), st, s); + break; + case OpCode::CONV_TRANSPOSE3D: + ops::exec_conv_transpose3d( + std::get(instr.node), st, s); + break; + case OpCode::CLIP: + ops::exec_clip(std::get(instr.node), st, s); + break; + case OpCode::CUMSUM: + ops::exec_cumsum(std::get(instr.node), st, s); + break; + case OpCode::STACK: + ops::exec_stack(std::get(instr.node), st, s); + break; + case OpCode::SIGN: + ops::exec_sign(std::get(instr.node), st, s); + break; + case OpCode::ANY: + ops::exec_any(std::get(instr.node), st, s); + break; + case OpCode::ALL: + ops::exec_all(std::get(instr.node), st, s); + break; + case OpCode::REPEAT: + ops::exec_repeat(std::get(instr.node), st, s); + break; + case OpCode::SORT: + ops::exec_sort(std::get(instr.node), st, s); + break; + case OpCode::ARGSORT: + ops::exec_argsort(std::get(instr.node), st, s); + break; + case OpCode::PARTITION: + ops::exec_partition(std::get(instr.node), st, s); + break; + case OpCode::ARG_PARTITION: + ops::exec_argpartition(std::get(instr.node), st, s); + break; default: throw std::runtime_error( "Unknown opcode: " + std::to_string(static_cast(instr.op))); diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs index 8b159314760..f27e37056e4 100644 --- a/backends/mlx/serialization/schema.fbs +++ b/backends/mlx/serialization/schema.fbs @@ -86,6 +86,827 @@ table AddmmNode { beta: float = 1.0; // Scalar multiplier for bias } +table LinearNode { + x: Tid (required); + weight: Tid (required); + out: Tid (required); + bias: Tid; // optional +} + +table ItemIntNode { + x: Tid (required); + out: Vid (required); +} + +table ExpandDimsNode { + x: Tid (required); + out: Tid (required); + axis: int32; +} + +table TileNode { + x: Tid (required); + out: Tid (required); + reps: [IntOrVid] (required); +} + +table TakeAlongAxisNode { + x: Tid (required); + indices: Tid (required); + out: Tid (required); + axis: int32; +} + +table TakeNode { + x: Tid (required); + out: Tid (required); + index: IntOrVidOrTid (required); // Scalar int, dynamic Vid, or tensor of indices + axis: int32; // Axis along which to select +} + +table RMSNormNode { + x: Tid (required); + weight: Tid (required); + out: Tid (required); + eps: float; +} + +table LayerNormNode { + x: Tid (required); + out: Tid (required); + weight: Tid; // optional + bias: Tid; // optional + eps: float; +} + +table RopeNode { + x: Tid (required); + out: Tid (required); + dims: int32; + offset: VidOrTid (required); // Position offset: scalar (Vid) or tensor of positions (Tid) + freqs: Tid; // optional + traditional: bool = false; + base: float = 500000.0; // Llama 3 default + scale: float = 1.0; +} + +table SdpaNode { + q: Tid (required); + k: Tid (required); + v: Tid (required); + out: Tid (required); + scale: float; + mask: Tid; // optional + causal: bool = false; +} + +table AddNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table AddIntNode { + a: IntOrVid (required); + b: IntOrVid (required); + out: Vid (required); +} + +table SubtractIntNode { + a: IntOrVid (required); + b: IntOrVid (required); + out: Vid (required); +} + +table MultiplyIntNode { + a: IntOrVid (required); + b: IntOrVid (required); + out: Vid (required); +} + +table FloorDivideIntNode { + a: IntOrVid (required); + b: IntOrVid (required); + out: Vid (required); +} + +table ModIntNode { + a: IntOrVid (required); + b: IntOrVid (required); + out: Vid (required); +} + +table SymSizeNode { + a: Tid (required); + dim: int32; + out: Vid (required); +} + +table MultiplyNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table DivideNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table SubtractNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table Conv1DNode { + x: Tid (required); + w: Tid (required); + out: Tid (required); + stride: int32 = 1; + padding: int32 = 0; + dilation: int32 = 1; + groups: int32 = 1; +} + +table Conv2DNode { + x: Tid (required); + w: Tid (required); + out: Tid (required); + stride_h: int32 = 1; + stride_w: int32 = 1; + padding_h: int32 = 0; + padding_w: int32 = 0; + dilation_h: int32 = 1; + dilation_w: int32 = 1; + groups: int32 = 1; +} + +table Conv3DNode { + x: Tid (required); + w: Tid (required); + out: Tid (required); + stride_d: int32 = 1; + stride_h: int32 = 1; + stride_w: int32 = 1; + padding_d: int32 = 0; + padding_h: int32 = 0; + padding_w: int32 = 0; + dilation_d: int32 = 1; + dilation_h: int32 = 1; + dilation_w: int32 = 1; + groups: int32 = 1; +} + +table ConvTranspose1DNode { + x: Tid (required); + w: Tid (required); + out: Tid (required); + stride: int32 = 1; + padding: int32 = 0; + dilation: int32 = 1; + output_padding: int32 = 0; + groups: int32 = 1; +} + +table ConvTranspose2DNode { + x: Tid (required); + w: Tid (required); + out: Tid (required); + stride_h: int32 = 1; + stride_w: int32 = 1; + padding_h: int32 = 0; + padding_w: int32 = 0; + dilation_h: int32 = 1; + dilation_w: int32 = 1; + output_padding_h: int32 = 0; + output_padding_w: int32 = 0; + groups: int32 = 1; +} + +table ConvTranspose3DNode { + x: Tid (required); + w: Tid (required); + out: Tid (required); + stride_d: int32 = 1; + stride_h: int32 = 1; + stride_w: int32 = 1; + padding_d: int32 = 0; + padding_h: int32 = 0; + padding_w: int32 = 0; + dilation_d: int32 = 1; + dilation_h: int32 = 1; + dilation_w: int32 = 1; + output_padding_d: int32 = 0; + output_padding_h: int32 = 0; + output_padding_w: int32 = 0; + groups: int32 = 1; +} + +table GeluNode { + x: Tid (required); + out: Tid (required); + approximate: string (required); // "none" or "tanh" +} + +table ARangeNode { + out: Tid (required); + start: IntOrVid (required); // Can be literal or dynamic (from item()) + stop: IntOrVid (required); // Can be literal or dynamic (from item()) + step: IntOrVid (required); // Can be literal or dynamic + scalar_type: int8 = null; // ET ScalarType (optional - None means infer from context) +} + +table SiluNode { + x: Tid (required); + out: Tid (required); +} + +table SigmoidNode { + x: Tid (required); + out: Tid (required); +} + +table TanhNode { + x: Tid (required); + out: Tid (required); +} + +table SqueezeNode { + x: Tid (required); + out: Tid (required); + dims: [int32]; // Optional list of dimensions to squeeze. If empty, squeeze all dims of size 1 +} + +table SplitNode { + x: Tid (required); + outs: [Tid] (required); // Multiple output tensor IDs (one for each split chunk) + sizes: [IntOrVid] (required); // Split sizes (can be dynamic) + axis: int32; +} + +table RsqrtNode { + x: Tid (required); + out: Tid (required); +} + +table MaximumNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table MinimumNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table LogNode { + x: Tid (required); + out: Tid (required); +} + +table SoftmaxNode { + x: Tid (required); + out: Tid (required); + axis: int32; // Dimension to compute softmax over + precise: bool = false; // Use precise (slow) implementation +} + +table BroadcastToNode { + x: Tid (required); + out: Tid (required); + shape: [IntOrVid] (required); // Target shape to broadcast to +} + +table PadNode { + x: Tid (required); + out: Tid (required); + pad_width: [IntOrVid] (required); // Padding pairs: [(before_0, after_0), (before_1, after_1), ...] + mode: string (required); // "constant" or "edge" + constant_value: float = 0.0; // Value to pad with (for constant mode) +} + +table WhereNode { + condition: Tid (required); + x: Tid (required); + y: Tid (required); + out: Tid (required); +} + +table ReshapeNode { + x: Tid (required); + out: Tid (required); + shape: [IntOrVid] (required); +} + +table TransposeNode { + x: Tid (required); + out: Tid (required); + perm: [int32] (required); +} + +table AsStridedNode { + x: Tid (required); + out: Tid (required); + shape: [IntOrVid] (required); // Output view shape (can be dynamic) + strides: [IntOrVid] (required); // Element strides per dimension (can be dynamic) + offset: uint64 = 0; // Element offset into source +} + +table ContiguousNode { + x: Tid (required); + out: Tid (required); +} + +table GatherNode { + x: Tid (required); + indices: [Tid] (required); // Index tensors (one per indexed axis) + out: Tid (required); + axes: [int32] (required); // Which axes to gather along + slice_sizes: [int32] (required); // Size of slice per dimension of x +} + +table SliceNode { + x: Tid (required); + out: Tid (required); + axis: IntOrVid (required); + start: IntOrVid (required); + stop: IntOrVid (required); + step: int32 = 1; +} + +table AsTypeNode { + x: Tid (required); + out: Tid (required); + scalar_type: int8; // ET ScalarType +} + +table QuantizedLinearNode { + x: Tid (required); + w: Tid (required); + scales: Tid (required); + out: Tid (required); + biases: Tid; // optional - quantization biases (required if scale_only=false) + bias: Tid; // optional - neural network bias + group_size: int32; + bits: int32; + mode: string (required); + out_scalar_type: int8; // ET ScalarType for output + scale_only: bool = false; // if true, compute biases = -scales * 2^(bits-1); if false, biases tensor required +} + +table ConcatenateNode { + tensors: [Tid] (required); // List of tensors to concatenate + out: Tid (required); + axis: int32; +} + +table FullNode { + out: Tid (required); + shape: [IntOrVid] (required); + v: FloatOrVid (required); // Fill value (can be dynamic from item()) + scalar_type: int8; // ET ScalarType +} + +table FullLikeNode { + x: Tid (required); // Input tensor to copy shape from + out: Tid (required); + v: FloatOrVid (required); // Fill value (can be dynamic from item()) + scalar_type: int8 = null; // ET ScalarType (optional - if null, use x's dtype) +} + +table ArgmaxNode { + x: Tid (required); + out: Tid (required); + axis: int32; + keepdims: bool = false; +} + +table SliceUpdateNode { + dst: Tid (required); + update: Tid (required); + out: Tid (required); // Can be same as dst + axis: IntOrVid (required); + start: IntOrVid (required); + stop: IntOrVid (required); + step: int32 = 1; +} + +// Index-based update: copies update tensor into dst at positions specified by 1D indices +// Runtime optimizes these into slice_update calls for contiguous runs +table IndexCopyNode { + dst: Tid (required); // destination tensor to update + update: Tid (required); // source tensor to copy from + indices: Tid (required); // 1D tensor of indices along axis + out: Tid (required); // output tensor (can be same as dst) + axis: int32; // dimension to index along +} + + +table DequantizeNode { + w: Tid (required); // Quantized matrix to dequantize + scales: Tid (required); // Scales per group_size elements + out: Tid (required); + biases: Tid; // optional - biases per group_size elements + group_size: int32; + bits: int32; + mode: string (required); // Quantization mode (e.g. "affine") + out_scalar_type: int8; // ET ScalarType for output dtype +} + +// Comparison ops (return bool arrays) +table LessNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table LessEqualNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table GreaterNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table GreaterEqualNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table EqualNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table NotEqualNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +// Logical ops +table LogicalNotNode { + x: Tid (required); + out: Tid (required); +} + +table LogicalAndNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table LogicalOrNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +// Triangular matrix ops +table TriNode { + out: Tid (required); + n: IntOrVid (required); // Number of rows + m: IntOrVid (required); // Number of columns + k: int32 = 0; // Diagonal offset: 0=main, +above, -below + scalar_type: int8; // ET ScalarType +} + +table TrilNode { + x: Tid (required); + out: Tid (required); + k: int32 = 0; // Diagonal offset: 0=main, +above, -below +} + +table TriuNode { + x: Tid (required); + out: Tid (required); + k: int32 = 0; // Diagonal offset: 0=main, +above, -below +} + +table ClipNode { + x: Tid (required); + out: Tid (required); + a_min: Tid; // optional lower bound + a_max: Tid; // optional upper bound +} + +table CumsumNode { + x: Tid (required); + out: Tid (required); + axis: int32; + reverse: bool = false; + inclusive: bool = true; +} + +table StackNode { + tensors: [Tid] (required); + out: Tid (required); + axis: int32 = 0; +} + +table SignNode { + x: Tid (required); + out: Tid (required); +} + +table AnyNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; // Empty = reduce all + keepdims: bool = false; +} + +table AllNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; // Empty = reduce all + keepdims: bool = false; +} + +table RepeatNode { + x: Tid (required); + out: Tid (required); + repeats: IntOrVid (required); // Number of times to repeat each element (can be dynamic) + axis: int32; // Axis along which to repeat +} + +table SortNode { + x: Tid (required); + out: Tid (required); + axis: int32; +} + +table ArgsortNode { + x: Tid (required); + out: Tid (required); + axis: int32; +} + +table PartitionNode { + x: Tid (required); + out: Tid (required); + kth: IntOrVid (required); // Partition index + axis: int32; +} + +table ArgPartitionNode { + x: Tid (required); + out: Tid (required); + kth: IntOrVid (required); // Partition index + axis: int32; +} + + +// ============================================================================= +// Math ops - Unary element-wise +// ============================================================================= + +table FloorNode { + x: Tid (required); + out: Tid (required); +} + +table CeilNode { + x: Tid (required); + out: Tid (required); +} + +table SquareNode { + x: Tid (required); + out: Tid (required); +} + +table ExpNode { + x: Tid (required); + out: Tid (required); +} + +table SinNode { + x: Tid (required); + out: Tid (required); +} + +table CosNode { + x: Tid (required); + out: Tid (required); +} + +table TanNode { + x: Tid (required); + out: Tid (required); +} + +table ArcsinNode { + x: Tid (required); + out: Tid (required); +} + +table ArccosNode { + x: Tid (required); + out: Tid (required); +} + +table ArctanNode { + x: Tid (required); + out: Tid (required); +} + +table SinhNode { + x: Tid (required); + out: Tid (required); +} + +table CoshNode { + x: Tid (required); + out: Tid (required); +} + +table ArcsinhNode { + x: Tid (required); + out: Tid (required); +} + +table ArccoshNode { + x: Tid (required); + out: Tid (required); +} + +table ArctanhNode { + x: Tid (required); + out: Tid (required); +} + +table Log2Node { + x: Tid (required); + out: Tid (required); +} + +table Log10Node { + x: Tid (required); + out: Tid (required); +} + +table Log1pNode { + x: Tid (required); + out: Tid (required); +} + +table ErfNode { + x: Tid (required); + out: Tid (required); +} + +table Expm1Node { + x: Tid (required); + out: Tid (required); +} + +table RoundNode { + x: Tid (required); + out: Tid (required); + decimals: int32 = 0; +} + +table ReciprocalNode { + x: Tid (required); + out: Tid (required); +} + +table SqrtNode { + x: Tid (required); + out: Tid (required); +} + +table AbsNode { + x: Tid (required); + out: Tid (required); +} + +table NegNode { + x: Tid (required); + out: Tid (required); +} + +// ============================================================================= +// Math ops - Binary element-wise +// ============================================================================= + +table Atan2Node { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table LogAddExpNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table FloorDivideNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table RemainderNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table PowerNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +// ============================================================================= +// Math ops - Reduction +// ============================================================================= + +table LogSumExpNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; + keepdims: bool = false; +} + +table SumNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; // Empty = reduce all + keepdims: bool = false; +} + +table MeanNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; // Empty = reduce all + keepdims: bool = false; +} + +table VarNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; // Empty = reduce all + keepdims: bool = false; + ddof: int32 = 0; // Delta degrees of freedom (0=population var, 1=sample var) +} + +table StdNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; // Empty = reduce all + keepdims: bool = false; + ddof: int32 = 0; // Delta degrees of freedom +} + +table ProdNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; // Empty = reduce all + keepdims: bool = false; +} + +table MaxNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; // Empty = reduce all + keepdims: bool = false; +} + +table MinNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; // Empty = reduce all + keepdims: bool = false; +} + +table ArgminNode { + x: Tid (required); + out: Tid (required); + axis: int32; + keepdims: bool = false; +} + +table MedianNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; // Empty = reduce all + keepdims: bool = false; +} + // ============================================================================= // Union of all op types // ============================================================================= @@ -95,8 +916,128 @@ table AddmmNode { union OpNode { NoopNode, IdCopyNode, - AddmmNode - // BC: Add new op nodes here (append only) + AddmmNode, + LinearNode, + ItemIntNode, + ExpandDimsNode, + TileNode, + TakeAlongAxisNode, + TakeNode, + RMSNormNode, + LayerNormNode, + RopeNode, + SdpaNode, + AddNode, + AddIntNode, + SubtractIntNode, + MultiplyIntNode, + FloorDivideIntNode, + SymSizeNode, + MultiplyNode, + DivideNode, + SubtractNode, + Conv1DNode, + Conv2DNode, + Conv3DNode, + GeluNode, + ARangeNode, + SiluNode, + SigmoidNode, + TanhNode, + SqueezeNode, + SplitNode, + RsqrtNode, + MaximumNode, + MinimumNode, + LogNode, + SoftmaxNode, + BroadcastToNode, + PadNode, + WhereNode, + ReshapeNode, + TransposeNode, + AsStridedNode, + ContiguousNode, + GatherNode, + SliceNode, + AsTypeNode, + QuantizedLinearNode, + ConcatenateNode, + FullNode, + FullLikeNode, + ArgmaxNode, + SliceUpdateNode, + IndexCopyNode, + DequantizeNode, + LessNode, + LessEqualNode, + GreaterNode, + GreaterEqualNode, + EqualNode, + NotEqualNode, + LogicalNotNode, + LogicalAndNode, + LogicalOrNode, + TriNode, + TrilNode, + TriuNode, + FloorNode, + CeilNode, + SquareNode, + ExpNode, + SinNode, + CosNode, + TanNode, + ArcsinNode, + ArccosNode, + ArctanNode, + SinhNode, + CoshNode, + ArcsinhNode, + ArccoshNode, + ArctanhNode, + Log2Node, + Log10Node, + Log1pNode, + ErfNode, + Expm1Node, + RoundNode, + ReciprocalNode, + SqrtNode, + AbsNode, + NegNode, + Atan2Node, + LogAddExpNode, + FloorDivideNode, + PowerNode, + LogSumExpNode, + SumNode, + MeanNode, + VarNode, + StdNode, + ProdNode, + MaxNode, + MinNode, + ArgminNode, + MedianNode, + ModIntNode, + RemainderNode, + ConvTranspose1DNode, + ConvTranspose2DNode, + ConvTranspose3DNode, + ClipNode, + CumsumNode, + StackNode, + SignNode, + AnyNode, + AllNode, + RepeatNode, + SortNode, + ArgsortNode, + PartitionNode, + ArgPartitionNode + // BC: Add new op nodes here (append only) +>>>>>>> 7e54fa1e87 (up) } // ============================================================================= diff --git a/backends/mlx/test/CMakeLists.txt b/backends/mlx/test/CMakeLists.txt index 2a709a63412..39024639d1d 100644 --- a/backends/mlx/test/CMakeLists.txt +++ b/backends/mlx/test/CMakeLists.txt @@ -49,3 +49,23 @@ target_link_libraries( strict_compile_test PRIVATE mlx_schema executorch_core mlx ) add_dependencies(op_test_runner strict_compile_test) + +# Multi-threaded inference test +include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake) + +et_cxx_test( + multi_thread_test_runner + SOURCES + ${CMAKE_CURRENT_LIST_DIR}/multi_thread_test_runner.cpp + EXTRA_LIBS + extension_module + extension_tensor + mlxdelegate +) + +# Add sanitizer link flags to multi_thread_test_runner if enabled +if(EXECUTORCH_MLX_ENABLE_SANITIZERS) + target_link_options( + multi_thread_test_runner PRIVATE ${_mlx_sanitizer_link_options} + ) +endif() diff --git a/backends/mlx/test/export_multi_thread_test_model.py b/backends/mlx/test/export_multi_thread_test_model.py new file mode 100644 index 00000000000..3c6500cad78 --- /dev/null +++ b/backends/mlx/test/export_multi_thread_test_model.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Export a test model for the multi-threaded inference test. + +The model exercises multiple ops and a mutable buffer (KV cache), +producing deterministic outputs that the C++ test can verify. + +Model behavior (accumulation via KV cache): + forward(x, input_pos): + x: [1, 1, 1, dim] (input tensor) + input_pos: [1] (cache write position, always 0) + + z = relu(x * 2 + 1) # always 3.0 with ones input + old_k = cache.k_cache[:, :, :1, :] # read old cache at pos 0 + new_val = z + old_k # accumulate: 3 + old + k_cache, v_cache = cache.update(input_pos, new_val, new_val) + return k_cache[:, :, :1, :] + v_cache[:, :, :1, :] + +With all-ones input and input_pos=[0], calling forward N times: + Call 1: old=0, new_val=3, cache=3. Output = 3 + 3 = 6.0 + Call 2: old=3, new_val=6, cache=6. Output = 6 + 6 = 12.0 + Call N: Output = 6.0 * N + +The C++ test can verify: output == 6.0 * call_number (all elements). + +Usage: + python export_multi_thread_test_model.py /tmp/multi_thread_test_model.pte +""" + +import argparse + +import torch +import torch.nn as nn + +from executorch.backends.mlx.llm.cache import KVCache +from executorch.backends.mlx.partitioner import MLXPartitioner +from executorch.exir import to_edge_transform_and_lower +from executorch.exir.capture._config import ExecutorchBackendConfig + + +class MultiOpCacheModel(nn.Module): + """ + A model with multiple ops and a mutable KV cache buffer that accumulates. + + Each forward() call: + 1. Computes z = relu(x * 2 + 1) — mul, add, relu (= 3.0 with ones) + 2. Reads old cache value at pos 0 — old_k + 3. Accumulates: new_val = z + old_k — add + 4. Writes new_val to both k and v caches — mutable buffer via kv_cache_update + 5. Returns k_cache + v_cache at pos 0 — sum of both cache slices + + With ones input, output = 6.0 * call_number (all elements). + """ + + def __init__(self, dim=4, max_len=8): + super().__init__() + self.cache = KVCache( + max_batch_size=1, + max_context_length=max_len, + n_heads=1, + head_dim=dim, + enable_dynamic_shape=True, + ) + + def forward(self, x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor: + z = torch.relu(x * 2.0 + 1.0) + old_k = self.cache.k_cache[:, :, :1, :] + new_val = z + old_k + k_cache, v_cache = self.cache.update(input_pos, new_val, new_val) + return k_cache[:, :, :1, :] + v_cache[:, :, :1, :] + + +def export_model(output_path: str, dim=4, max_len=8): + model = MultiOpCacheModel(dim=dim, max_len=max_len) + example_inputs = ( + torch.randn(1, 1, 1, dim), # x: [B, H, S, D] + torch.tensor([0], dtype=torch.int64), # input_pos + ) + + with torch.no_grad(): + exported = torch.export.export(model, example_inputs) + exported = exported.run_decompositions({}) + + et_program = to_edge_transform_and_lower(exported, partitioner=[MLXPartitioner()]) + et_program = et_program.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=True) + ) + with open(output_path, "wb") as f: + f.write(et_program.buffer) + print(f"Exported model to {output_path}") + + # Verify accumulation pattern + model_ref = MultiOpCacheModel(dim=dim, max_len=max_len) + x = torch.ones(1, 1, 1, dim) + input_pos = torch.tensor([0], dtype=torch.int64) + print(f"Reference (ones input, dim={dim}, max_len={max_len}):") + for i in range(1, 4): + result = model_ref(x, input_pos) + expected = 6.0 * i + actual = result[0, 0, 0, 0].item() + status = "OK" if abs(actual - expected) < 1e-6 else "FAIL" + print(f" Call {i}: output={actual:.1f}, expected={expected:.1f} [{status}]") + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "output", + nargs="?", + default="/tmp/multi_thread_test_model.pte", + help="Output .pte path (default: /tmp/multi_thread_test_model.pte)", + ) + args = parser.parse_args() + export_model(args.output) + + +if __name__ == "__main__": + main() diff --git a/backends/mlx/test/multi_thread_test_runner.cpp b/backends/mlx/test/multi_thread_test_runner.cpp new file mode 100644 index 00000000000..72c0917d81e --- /dev/null +++ b/backends/mlx/test/multi_thread_test_runner.cpp @@ -0,0 +1,204 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * Multi-threaded inference stress test for the MLX delegate. + * + * Loads a .pte model on multiple threads (each with its own Module instance) + * and runs forward passes in parallel, verifying that all succeed and + * produce correct outputs. + * + * The model accumulates via KV cache: with all-ones input and input_pos=[0], + * call N produces output = 6.0 * N (all elements). Each thread has its own + * Module (and cache state), so correctness is verified independently. + * + * The test expects a model exported by export_multi_thread_test_model.py. + * + * Build: + * cmake --preset mlx + * cmake --build cmake-out --target multi_thread_test_runner + * + * Usage: + * ET_TESTING_MODEL_PATH=/tmp/multi_thread_test_model.pte \ + * ./cmake-out/backends/mlx/test/multi_thread_test_runner + * + * Environment variables: + * ET_TESTING_MODEL_PATH Path to .pte model file (required) + * ET_TESTING_NUM_THREADS Number of parallel threads (default: 4) + * ET_PREDICTIONS_PER_THREAD Inferences per thread (default: 10) + */ + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +using namespace ::executorch::runtime; +using namespace ::executorch::extension; + +const std::string kTestPTEPath = [] { + if (const char* env_p = std::getenv("ET_TESTING_MODEL_PATH")) { + return std::string(env_p); + } + return std::string("model.pte"); +}(); + +const int kNumThreads = [] { + if (const char* env_p = std::getenv("ET_TESTING_NUM_THREADS")) { + try { + return std::stoi(env_p); + } catch (...) { + } + } + return 4; +}(); + +const int kPredictionsPerThread = [] { + if (const char* env_p = std::getenv("ET_PREDICTIONS_PER_THREAD")) { + try { + return std::stoi(env_p); + } catch (...) { + } + } + return 10; +}(); + +std::vector get_ones_inputs(Module& module) { + const auto method_meta = module.method_meta("forward"); + const auto num_inputs = method_meta->num_inputs(); + + std::vector tensors; + tensors.reserve(num_inputs); + + for (auto index = 0; index < num_inputs; ++index) { + const auto input_tag = method_meta->input_tag(index); + + switch (*input_tag) { + case Tag::Tensor: { + const auto tensor_meta = method_meta->input_tensor_meta(index); + const auto sizes = tensor_meta->sizes(); + if (tensor_meta->scalar_type() == exec_aten::ScalarType::Long) { + tensors.emplace_back( + zeros({sizes.begin(), sizes.end()}, tensor_meta->scalar_type())); + } else { + tensors.emplace_back( + ones({sizes.begin(), sizes.end()}, tensor_meta->scalar_type())); + } + } break; + default: + throw std::runtime_error( + "Unsupported input tag at index " + std::to_string(index)); + } + } + return tensors; +} + +struct ThreadResult { + size_t success_count{0}; + size_t correctness_failures{0}; + std::string error_message; +}; + +void run_predict( + int thread_id, + const std::string& model_path, + ThreadResult& result) { + Module module(model_path); + + for (int pred = 0; pred < kPredictionsPerThread; pred++) { + auto inputs = get_ones_inputs(module); + for (int i = 0; i < inputs.size(); i++) { + if (module.set_input(inputs[i], i) != Error::Ok) { + std::cerr << "Thread " << thread_id << ", prediction " << pred + << ": set_input(" << i << ") failed" << std::endl; + break; + } + } + + const auto forward_result = module.forward(); + + if (!forward_result.ok()) { + std::cerr << "Thread " << thread_id << ", prediction " << pred + << ": forward() failed with error " + << static_cast(forward_result.error()) << std::endl; + continue; + } + + const auto outputs = forward_result.get(); + if (outputs.empty() || !outputs[0].isTensor()) { + std::cerr << "Thread " << thread_id << ", prediction " << pred + << ": no tensor output" << std::endl; + continue; + } + + const auto& output_tensor = outputs[0].toTensor(); + const float* data = output_tensor.const_data_ptr(); + const float expected = 6.0f * (pred + 1); + bool correct = true; + for (ssize_t j = 0; j < output_tensor.numel(); j++) { + if (std::fabs(data[j] - expected) > 1e-4f) { + std::cerr << "Thread " << thread_id << ", prediction " << pred + << ": output[" << j << "] = " << data[j] << ", expected " + << expected << std::endl; + correct = false; + break; + } + } + if (!correct) { + result.correctness_failures++; + } + + result.success_count++; + } +} + +TEST(MLXMultiThreadTest, LoadAndRunParallel) { + ASSERT_FALSE(kTestPTEPath.empty()) << "ET_TESTING_MODEL_PATH must be set"; + ASSERT_GT(kNumThreads, 0) << "ET_TESTING_NUM_THREADS must be > 0"; + ASSERT_GT(kPredictionsPerThread, 0) + << "ET_PREDICTIONS_PER_THREAD must be > 0"; + + std::cout << "Running " << kNumThreads << " threads x " + << kPredictionsPerThread + << " predictions with model: " << kTestPTEPath << std::endl; + + std::vector threads(kNumThreads); + std::vector results(kNumThreads); + + for (int i = 0; i < kNumThreads; i++) { + threads[i] = + std::thread([&, i]() { run_predict(i, kTestPTEPath, results[i]); }); + } + for (int i = 0; i < kNumThreads; i++) { + threads[i].join(); + } + + size_t total_success = 0; + size_t total_correctness_failures = 0; + for (int i = 0; i < kNumThreads; i++) { + total_success += results[i].success_count; + total_correctness_failures += results[i].correctness_failures; + } + + const size_t total = kNumThreads * kPredictionsPerThread; + std::cout << "Success: " << total_success << "/" << total << std::endl; + std::cout << "Correctness failures: " << total_correctness_failures + << std::endl; + + ASSERT_EQ(total_success, total) << "Some forward() calls failed"; + ASSERT_EQ(total_correctness_failures, 0) << "Some outputs were incorrect"; +} diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index 0ba98b532ad..164e94e3e78 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -24,7 +24,7 @@ See README.md in this directory for full documentation. """ -from typing import List, Tuple +from typing import Callable, Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -34,52 +34,6041 @@ custom_ops, ops, ) +from torch.export import Dim from .test_utils import OpTestCase, register_test +class AddTensorModel(nn.Module): + """Add two tensors, optionally with alpha.""" + + def __init__(self, alpha: Optional[float] = None): + super().__init__() + self.alpha = alpha + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + if self.alpha is not None: + return torch.add(x, y, alpha=self.alpha) + return x + y + + +class AddScalarModel(nn.Module): + """Add tensor and scalar.""" + + def __init__(self, scalar: float = 1.0): + super().__init__() + self.scalar = scalar + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.scalar + + +@register_test +class AddTest(OpTestCase): + """Test case for add op.""" + + name = "add" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 16, 64), + scalar: Optional[float] = None, + alpha: Optional[float] = None, + ): + self.shape = shape + self.scalar = scalar + self.alpha = alpha + + if alpha is not None: + self.name = "add_alpha" + elif scalar is not None: + self.name = "add_scalar" + else: + self.name = "add" + + @classmethod + def get_test_configs(cls) -> List["AddTest"]: + return [ + cls(), # tensor + tensor + cls(scalar=2.5), # tensor + scalar + cls(alpha=2.0), # tensor + alpha * tensor + ] + + def create_model(self) -> nn.Module: + if self.scalar is not None: + return AddScalarModel(self.scalar) + else: + return AddTensorModel(self.alpha) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + if self.scalar is not None: + return (x,) + else: + y = torch.randn(self.shape) + return (x, y) + + +class SubModel(nn.Module): + """Model that performs element-wise subtraction, optionally with alpha.""" + + def __init__(self, alpha: Optional[float] = None): + super().__init__() + self.alpha = alpha + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + if self.alpha is not None: + return torch.sub(x, y, alpha=self.alpha) + return torch.sub(x, y) + + +@register_test +class SubTest(OpTestCase): + name = "sub" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 3, 4), + scalar_sub: bool = False, + alpha: Optional[float] = None, + ): + self.shape = shape + self.scalar_sub = scalar_sub + self.alpha = alpha + shape_str = "x".join(str(s) for s in shape) + if alpha is not None: + self.name = f"sub_{shape_str}_alpha" + elif scalar_sub: + self.name = f"sub_{shape_str}_scalar" + else: + self.name = f"sub_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["SubTest"]: + return [ + cls(shape=(2, 3, 4)), + cls(shape=(10,)), + cls(shape=(4, 8)), + cls(shape=(2, 8, 16)), + cls(shape=(1, 128, 128)), + cls(shape=(2, 3, 4), scalar_sub=True), + cls(shape=(2, 3, 4), alpha=2.0), + ] + + def create_model(self) -> nn.Module: + return SubModel(self.alpha) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + if self.scalar_sub: + y = torch.randn(()) + else: + y = torch.randn(self.shape) + return (x, y) + + +class MulTensorModel(nn.Module): + """Multiply two tensors.""" + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x * y + + +class MulScalarModel(nn.Module): + """Multiply tensor and scalar.""" + + def __init__(self, scalar: float = 1.0): + super().__init__() + self.scalar = scalar + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * self.scalar + + +@register_test +class MulTest(OpTestCase): + """Test case for mul op.""" + + name = "mul" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 16, 64), + scalar: Optional[float] = None, + ): + self.shape = shape + self.scalar = scalar + + if scalar is not None: + self.name = "mul_scalar" + else: + self.name = "mul" + + @classmethod + def get_test_configs(cls) -> List["MulTest"]: + return [ + cls(), + cls(scalar=2.5), + ] + + def create_model(self) -> nn.Module: + if self.scalar is not None: + return MulScalarModel(self.scalar) + else: + return MulTensorModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + if self.scalar is not None: + return (x,) + else: + y = torch.randn(self.shape) + return (x, y) + + +class DivModel(nn.Module): + """Model that performs element-wise division.""" + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.div(x, y) + + +@register_test +class DivTest(OpTestCase): + name = "div" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 3, 4), + scalar_divisor: bool = False, + ): + self.shape = shape + self.scalar_divisor = scalar_divisor + shape_str = "x".join(str(s) for s in shape) + if scalar_divisor: + self.name = f"div_{shape_str}_scalar" + else: + self.name = f"div_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["DivTest"]: + return [ + cls(shape=(2, 3, 4)), + cls(shape=(10,)), + cls(shape=(4, 8)), + cls(shape=(2, 8, 16)), + cls(shape=(1, 128, 64)), + cls(shape=(2, 3, 4), scalar_divisor=True), + ] + + def create_model(self) -> nn.Module: + return DivModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + 2.0 + if self.scalar_divisor: + y = torch.randn(()) + 2.0 + else: + y = torch.randn(self.shape) + 2.0 + return (x, y) + + +class ClampModel(nn.Module): + """Model that applies clamp with min and max.""" + + def __init__(self, min_val: Optional[float], max_val: Optional[float]): + super().__init__() + self.min_val = min_val + self.max_val = max_val + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.clamp(x, min=self.min_val, max=self.max_val) + + +@register_test +class ClampTest(OpTestCase): + """Test case for clamp op with various min/max combinations.""" + + name = "clamp" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 3, 4), + min_val: Optional[float] = None, + max_val: Optional[float] = None, + ): + self.shape = shape + self.min_val = min_val + self.max_val = max_val + + # Build descriptive name + parts = ["clamp"] + if min_val is not None: + parts.append(f"min{min_val}") + if max_val is not None: + parts.append(f"max{max_val}") + if min_val is None and max_val is None: + parts.append("none") + shape_str = "x".join(str(s) for s in shape) + parts.append(shape_str) + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["ClampTest"]: + return [ + # Only min specified + cls(shape=(2, 3, 4), min_val=-0.5, max_val=None), + # Only max specified + cls(shape=(2, 3, 4), min_val=None, max_val=0.5), + # Both min and max specified + cls(shape=(2, 3, 4), min_val=-0.5, max_val=0.5), + # Different shapes + cls(shape=(10,), min_val=-1.0, max_val=1.0), + cls(shape=(4, 8), min_val=0.0, max_val=None), # ReLU-like + cls(shape=(2, 8, 16), min_val=-0.25, max_val=0.75), + ] + + def create_model(self) -> nn.Module: + return ClampModel(self.min_val, self.max_val) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # Create inputs with values that span beyond typical clamp range + x = torch.randn(self.shape) * 2 # values roughly in [-4, 4] + return (x,) + + +class GELUModel(nn.Module): + """Simple model using GELU activation.""" + + def __init__(self, approximate: str = "none"): + super().__init__() + self.gelu = nn.GELU(approximate=approximate) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.gelu(x) + + +@register_test +class GELUTest(OpTestCase): + """Test case for GELU activation.""" + + name = "gelu" + + def __init__(self, shape: Tuple[int, ...] = (2, 16, 64), approximate: str = "none"): + self.shape = shape + self.approximate = approximate + self.name = f"gelu_{approximate}" if approximate != "none" else "gelu" + + @classmethod + def get_test_configs(cls) -> List["GELUTest"]: + return [ + cls(), + cls(shape=(4, 32, 128)), + cls(approximate="tanh"), + cls(shape=(4, 32, 128), approximate="tanh"), + ] + + def create_model(self) -> nn.Module: + return GELUModel(approximate=self.approximate) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + return (x,) + + +class SoftmaxModel(nn.Module): + """Model that performs softmax along a specified dimension.""" + + def __init__(self, dim: int = -1): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.softmax(x, dim=self.dim) + + +@register_test +class SoftmaxTest(OpTestCase): + """Test case for softmax op.""" + + name = "softmax" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 3, 4), + dim: int = -1, + ): + self.shape = shape + self.dim = dim + shape_str = "x".join(str(s) for s in shape) + self.name = f"softmax_{shape_str}_dim{dim}" + + @classmethod + def get_test_configs(cls) -> List["SoftmaxTest"]: + return [ + cls(shape=(2, 3, 4), dim=-1), + cls(shape=(2, 3, 4), dim=1), + cls(shape=(4, 8), dim=-1), + cls(shape=(2, 4, 8, 16), dim=-1), + ] + + def create_model(self) -> nn.Module: + return SoftmaxModel(dim=self.dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + return (x,) + + +class LogSoftmaxModel(nn.Module): + """Model that applies log_softmax.""" + + def __init__(self, dim: int = -1): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.log_softmax(x, dim=self.dim) + + +@register_test +class LogSoftmaxTest(OpTestCase): + name = "log_softmax" + rtol = 1e-5 + atol = 1e-5 + + def __init__(self, shape: Tuple[int, ...] = (2, 3, 4), dim: int = -1): + self.shape = shape + self.dim = dim + shape_str = "x".join(str(s) for s in shape) + self.name = f"log_softmax_{shape_str}_dim{dim}" + + @classmethod + def get_test_configs(cls) -> List["LogSoftmaxTest"]: + return [ + cls(shape=(2, 3, 4), dim=-1), + cls(shape=(10,), dim=0), + cls(shape=(4, 8), dim=1), + cls(shape=(2, 8, 16), dim=1), + cls(shape=(1, 128, 512), dim=-1), + ] + + def create_model(self) -> nn.Module: + return LogSoftmaxModel(dim=self.dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + return (x,) + + +class SqueezeModel(nn.Module): + """Model that squeezes a tensor at specified dimensions.""" + + def __init__(self, dims: Optional[Tuple[int, ...]] = None): + super().__init__() + self.dims = dims + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.dims is None: + return torch.squeeze(x) + else: + return torch.squeeze(x, dim=self.dims) + + +@register_test +class SqueezeTest(OpTestCase): + """Test case for squeeze op.""" + + name = "squeeze" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] = (1, 3, 1, 4), + dims: Optional[Tuple[int, ...]] = (0, 2), + ): + self.shape = shape + self.dims = dims + shape_str = "x".join(str(s) for s in shape) + if dims is None: + dims_str = "all" + elif len(dims) == 0: + dims_str = "empty" + else: + dims_str = "_".join(str(d) for d in dims) + self.name = f"squeeze_{shape_str}_dims{dims_str}" + + @classmethod + def get_test_configs(cls) -> List["SqueezeTest"]: + return [ + cls(shape=(1, 3, 1, 4), dims=(0, 2)), + cls(shape=(1, 5, 1, 1), dims=(0,)), + cls(shape=(3, 1, 4), dims=(1,)), + cls(shape=(1, 1, 8), dims=(0, 1)), + cls(shape=(2, 1, 3, 1), dims=(1, 3)), + # Squeeze all singleton dims (no dims specified) + cls(shape=(1, 3, 1, 4), dims=None), + # Dims include non-size-1 axes (should be no-op for those axes) + cls(shape=(1, 1, 1, 8198), dims=(0, 1, 2, 3)), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + return (x,) + + def create_model(self) -> nn.Module: + return SqueezeModel(self.dims) + + +class UnsqueezeModel(nn.Module): + """Model that unsqueezes a tensor at a given dimension.""" + + def __init__(self, dim: int = 0): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.unsqueeze(self.dim) + + +@register_test +class UnsqueezeTest(OpTestCase): + """Test case for unsqueeze op.""" + + name = "unsqueeze" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 16, 64), + dim: int = 0, + ): + self.shape = shape + self.dim = dim + self.name = f"unsqueeze_dim{dim}" + + @classmethod + def get_test_configs(cls) -> List["UnsqueezeTest"]: + return [ + cls(dim=0), + cls(dim=1), + cls(dim=-1), + ] + + def create_model(self) -> nn.Module: + return UnsqueezeModel(self.dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + return (x,) + + +class PermuteModel(nn.Module): + """Model that permutes tensor dimensions.""" + + def __init__(self, dims: Tuple[int, ...] = (0, 2, 1, 3)): + super().__init__() + self.dims = dims + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.permute(self.dims) + + +class TransposeModel(nn.Module): + """Model that transposes two dimensions.""" + + def __init__(self, dim0: int = 1, dim1: int = 2): + super().__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.transpose(self.dim0, self.dim1) + + +@register_test +class PermuteTest(OpTestCase): + """Test case for permute and transpose ops.""" + + name = "permute" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 8, 16, 64), + variant: str = "permute", + permute_dims: Tuple[int, ...] = (0, 2, 1, 3), + transpose_dims: Tuple[int, int] = (1, 2), + ): + self.shape = shape + self.variant = variant + self.permute_dims = permute_dims + self.transpose_dims = transpose_dims + + if variant == "transpose": + self.name = "transpose" + else: + self.name = "permute" + + @classmethod + def get_test_configs(cls) -> List["PermuteTest"]: + return [ + cls(variant="permute", permute_dims=(0, 2, 1, 3)), + cls(variant="transpose", transpose_dims=(1, 2)), + ] + + def create_model(self) -> nn.Module: + if self.variant == "transpose": + return TransposeModel(self.transpose_dims[0], self.transpose_dims[1]) + else: + return PermuteModel(self.permute_dims) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + return (x,) + + +class NarrowModel(nn.Module): + """Model that narrows a tensor along a dimension.""" + + def __init__(self, dim: int, start: int, length: int): + super().__init__() + self.dim = dim + self.start = start + self.length = length + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.narrow(self.dim, self.start, self.length) + + +@register_test +class NarrowTest(OpTestCase): + """Test case for tensor.narrow().""" + + name = "narrow" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + shape: Tuple[int, ...] = (4, 16, 8), + dim: int = 1, + start: int = 2, + length: int = 8, + ): + self.shape = shape + self.dim = dim + self.start = start + self.length = length + self.name = f"narrow_dim{dim}_start{start}_len{length}" + + @classmethod + def get_test_configs(cls) -> List["NarrowTest"]: + return [ + cls(shape=(4, 16, 8), dim=1, start=2, length=8), + cls(shape=(8, 8), dim=0, start=1, length=4), + cls(shape=(2, 32, 4), dim=1, start=0, length=16), + ] + + def create_model(self) -> nn.Module: + return NarrowModel(self.dim, self.start, self.length) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + return (x,) + + +class SelectModel(nn.Module): + """Model that selects a single index along a dimension. + + torch.select(input, dim, index) returns input[..., index, ...] where + the indexing happens at dimension `dim`. The selected dimension is removed. + Maps to aten.select_copy.int -> MLX take(array, index, axis). + """ + + def __init__(self, dim: int, index: int): + super().__init__() + self.dim = dim + self.index = index + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.select(x, self.dim, self.index) + + +@register_test +class SelectTest(OpTestCase): + """Test case for torch.select (aten.select_copy.int -> TakeNode).""" + + name = "select" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] = (4, 8, 16), + dim: int = 1, + index: int = 3, + ): + self.shape = shape + self.dim = dim + self.index = index + self.name = f"select_dim{dim}_idx{index}" + + @classmethod + def get_test_configs(cls) -> List["SelectTest"]: + return [ + cls(shape=(4, 8, 16), dim=0, index=2), + cls(shape=(4, 8, 16), dim=1, index=3), + cls(shape=(4, 8, 16), dim=2, index=0), + cls(shape=(4, 8, 16), dim=-1, index=5), + cls(shape=(2, 3), dim=0, index=1), + ] + + def create_model(self) -> nn.Module: + return SelectModel(self.dim, self.index) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + return (x,) + + +class SliceModel(nn.Module): + """Model that slices a tensor along dimension 1.""" + + def __init__(self, start: int, stop: int): + super().__init__() + self.start = start + self.stop = stop + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x[:, self.start : self.stop] + + +class SliceDim0Model(nn.Module): + """Model that slices a tensor along dimension 0.""" + + def __init__(self, start: int, stop: int): + super().__init__() + self.start = start + self.stop = stop + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x[self.start : self.stop] + + +@register_test +class SliceTest(OpTestCase): + """Test case for tensor slicing.""" + + name = "slice" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + shape: Tuple[int, ...] = (4, 16, 8), + dim: int = 1, + start: int = 2, + stop: int = 10, + ): + self.shape = shape + self.dim = dim + self.start = start + self.stop = stop + self.name = f"slice_dim{dim}_{start}to{stop}" + + @classmethod + def get_test_configs(cls) -> List["SliceTest"]: + return [ + cls(shape=(4, 16, 8), dim=1, start=2, stop=10), + cls(shape=(8, 8), dim=0, start=1, stop=5), + cls(shape=(2, 32, 4), dim=1, start=0, stop=16), + ] + + def create_model(self) -> nn.Module: + if self.dim == 0: + return SliceDim0Model(self.start, self.stop) + return SliceModel(self.start, self.stop) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + return (x,) + + +class RepeatModel(nn.Module): + """Model that repeats a tensor along specified dimensions.""" + + def __init__(self, repeats: Tuple[int, ...]): + super().__init__() + self.repeats = repeats + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.repeat(*self.repeats) + + +@register_test +class RepeatTest(OpTestCase): + """Test case for tensor.repeat().""" + + name = "repeat" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + input_shape: Tuple[int, ...] = (2, 3, 4), + repeats: Tuple[int, ...] = (2, 1, 3), + ): + self.input_shape = input_shape + self.repeats = repeats + repeat_str = "x".join(str(r) for r in repeats) + self.name = f"repeat_{repeat_str}" + + @classmethod + def get_test_configs(cls) -> List["RepeatTest"]: + return [ + cls(input_shape=(2, 3), repeats=(2, 3)), + cls(input_shape=(2, 3, 4), repeats=(1, 2, 1)), + cls(input_shape=(4, 4), repeats=(3, 3)), + ] + + def create_model(self) -> nn.Module: + return RepeatModel(self.repeats) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.input_shape) + return (x,) + + +class CatNModel(nn.Module): + """Model that concatenates N tensors along a dimension.""" + + def __init__(self, dim: int = 0, n: int = 3): + super().__init__() + self.dim = dim + self.n = n + + def forward(self, *tensors: torch.Tensor) -> torch.Tensor: + return torch.cat(tensors[: self.n], dim=self.dim) + + +@register_test +class CatTest(OpTestCase): + name = "cat" + rtol = 1e-5 + atol = 1e-5 + + def __init__(self, shapes: List[Tuple[int, ...]], dim: int = 0, tag: str = ""): + self.shapes = shapes + self.dim = dim + self.name = f"cat_{tag}" if tag else "cat" + + @classmethod + def get_test_configs(cls) -> List["CatTest"]: + return [ + cls(shapes=[(2, 3), (4, 3), (1, 3)], dim=0, tag="2d_dim0"), + cls(shapes=[(3, 2), (3, 4), (3, 1)], dim=1, tag="2d_dim1"), + cls(shapes=[(2, 3, 4), (5, 3, 4), (3, 3, 4)], dim=0, tag="3d_dim0"), + cls(shapes=[(3, 4), (2, 4)], dim=0, tag="two_tensors"), + cls(shapes=[(3, 2, 4), (3, 5, 4), (3, 1, 4)], dim=-2, tag="neg_dim"), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return tuple(torch.randn(s) for s in self.shapes) + + def create_model(self) -> nn.Module: + return CatNModel(dim=self.dim, n=len(self.shapes)) + + +class WhereModel(nn.Module): + """Model that conditionally selects from x or y based on condition.""" + + def forward( + self, condition: torch.Tensor, x: torch.Tensor, y: torch.Tensor + ) -> torch.Tensor: + return torch.where(condition, x, y) + + +@register_test +class WhereTest(OpTestCase): + """Test case for where op.""" + + name = "where" + rtol = 1e-5 + atol = 1e-5 + + def __init__(self, shape: Tuple[int, ...] = (2, 3, 4)): + self.shape = shape + shape_str = "x".join(str(s) for s in shape) + self.name = f"where_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["WhereTest"]: + return [ + cls(shape=(2, 3, 4)), + cls(shape=(4, 8)), + cls(shape=(2, 8, 16, 16)), + cls(shape=(1, 1, 128, 128)), + ] + + def create_model(self) -> nn.Module: + return WhereModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + condition = torch.rand(self.shape) > 0.5 + x = torch.randn(self.shape) + y = torch.randn(self.shape) + return (condition, x, y) + + +class PadModel(nn.Module): + """Model that pads a tensor with a constant value.""" + + def __init__(self, pad: Tuple[int, ...], value: float = 0.0): + super().__init__() + self.pad = pad + self.value = value + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.pad(x, self.pad, mode="constant", value=self.value) + + +@register_test +class PadTest(OpTestCase): + name = "pad" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 3, 4), + pad: Tuple[int, ...] = (1, 1, 1, 1), + value: float = 0.0, + ): + self.shape = shape + self.pad = pad + self.value = value + shape_str = "x".join(str(s) for s in shape) + pad_str = "_".join(str(p) for p in pad) + self.name = f"pad_{shape_str}_p{pad_str}_v{int(value)}" + + @classmethod + def get_test_configs(cls) -> List["PadTest"]: + return [ + cls(shape=(2, 3, 4), pad=(1, 1, 1, 1), value=0.0), + cls(shape=(10,), pad=(2, 3), value=0.0), + cls(shape=(4, 8), pad=(1, 2), value=0.0), + cls(shape=(2, 8, 16), pad=(1, 1, 2, 2), value=0.0), + cls(shape=(1, 3, 32, 32), pad=(1, 1, 1, 1), value=0.0), + cls(shape=(2, 3, 4), pad=(1, 1, 1, 1), value=1.0), + ] + + def create_model(self) -> nn.Module: + return PadModel(self.pad, self.value) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + return (x,) + + +class LinearModel(nn.Module): + """Simple linear layer for testing.""" + + def __init__( + self, in_features: int = 64, out_features: int = 128, bias: bool = True + ): + super().__init__() + self.linear = nn.Linear(in_features, out_features, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +@register_test +class LinearTest(OpTestCase): + """Test case for nn.Linear.""" + + name = "linear" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + in_features: int = 64, + out_features: int = 128, + batch_size: int = 2, + seq_len: int = 16, + bias: bool = True, + ): + self.in_features = in_features + self.out_features = out_features + self.batch_size = batch_size + self.seq_len = seq_len + self.bias = bias + + if not bias: + self.name = "linear_no_bias" + else: + self.name = "linear" + + @classmethod + def get_test_configs(cls) -> List["LinearTest"]: + return [ + cls(), + cls(bias=False), + ] + + def create_model(self) -> nn.Module: + return LinearModel(self.in_features, self.out_features, bias=self.bias) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.seq_len, self.in_features) + return (x,) + + +class EmbeddingModel(nn.Module): + """Simple embedding layer for testing.""" + + def __init__(self, num_embeddings: int = 1000, embedding_dim: int = 64): + super().__init__() + self.embedding = nn.Embedding(num_embeddings, embedding_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.embedding(x) + + +@register_test +class EmbeddingTest(OpTestCase): + """Test case for nn.Embedding.""" + + name = "embedding" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + num_embeddings: int = 1000, + embedding_dim: int = 64, + batch_size: int = 2, + seq_len: int = 16, + ): + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.batch_size = batch_size + self.seq_len = seq_len + self.name = "embedding" + + @classmethod + def get_test_configs(cls) -> List["EmbeddingTest"]: + return [ + cls(), + cls(num_embeddings=512, embedding_dim=128), + ] + + def create_model(self) -> nn.Module: + return EmbeddingModel(self.num_embeddings, self.embedding_dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randint(0, self.num_embeddings, (self.batch_size, self.seq_len)) + return (x,) + + +class MaxPool1dModel(nn.Module): + def __init__(self, kernel_size, stride=None, padding=0): + super().__init__() + self.pool = nn.MaxPool1d(kernel_size, stride=stride, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.pool(x) + + +@register_test +class MaxPool1dTest(OpTestCase): + name = "max_pool1d" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + kernel_size=2, + stride=2, + padding=0, + in_channels: int = 8, + seq_len: int = 32, + batch_size: int = 1, + tag: str = "", + ): + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.in_channels = in_channels + self.seq_len = seq_len + self.batch_size = batch_size + + if tag: + self.name = f"max_pool1d_{tag}" + else: + parts = ["max_pool1d", f"k{kernel_size}", f"s{stride}"] + if padding != 0: + parts.append(f"p{padding}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["MaxPool1dTest"]: + return [ + # Fast path: kernel == stride + cls(kernel_size=2, stride=2), + # General path: overlapping windows with padding + cls(kernel_size=3, stride=2, padding=1), + # Fast path: larger kernel + cls(kernel_size=4, stride=4, seq_len=64), + # stride=None (defaults to kernel_size) + cls(kernel_size=4, stride=None, seq_len=64, tag="stride_none"), + # Single channel + cls(kernel_size=2, stride=2, in_channels=1, tag="c1"), + # Global pooling: kernel == spatial size + cls(kernel_size=32, stride=32, tag="global"), + # Batch > 1 + cls(kernel_size=2, stride=2, batch_size=4, tag="batch4"), + # Stride > kernel (gaps between windows) + cls(kernel_size=2, stride=3, seq_len=32, tag="stride_gt_kernel"), + ] + + def create_model(self) -> nn.Module: + return MaxPool1dModel(self.kernel_size, self.stride, self.padding) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.batch_size, self.in_channels, self.seq_len),) + + +class MaxPool2dModel(nn.Module): + def __init__(self, kernel_size, stride=None, padding=0): + super().__init__() + self.pool = nn.MaxPool2d(kernel_size, stride=stride, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.pool(x) + + +@register_test +class MaxPool2dTest(OpTestCase): + name = "max_pool2d" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + kernel_size=2, + stride=2, + padding=0, + in_channels: int = 16, + input_size: Tuple[int, int] = (32, 32), + batch_size: int = 1, + tag: str = "", + ): + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.in_channels = in_channels + self.input_size = input_size + self.batch_size = batch_size + + if tag: + self.name = f"max_pool2d_{tag}" + else: + parts = ["max_pool2d", f"k{kernel_size}", f"s{stride}"] + if padding != 0: + parts.append(f"p{padding}") + parts.append(f"{input_size[0]}x{input_size[1]}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["MaxPool2dTest"]: + return [ + # Fast path: kernel == stride, evenly divisible + cls(kernel_size=2, stride=2, input_size=(32, 32)), + # General path: overlapping windows + cls(kernel_size=3, stride=2, padding=1, input_size=(32, 32)), + # Fast path: 4x4 pooling + cls(kernel_size=4, stride=4, input_size=(64, 64)), + # General path: stride != kernel, no padding + cls(kernel_size=3, stride=1, input_size=(16, 16)), + # Batch > 1 + cls(kernel_size=2, stride=2, input_size=(32, 32), batch_size=4), + # stride=None (defaults to kernel_size) + cls(kernel_size=2, stride=None, input_size=(32, 32), tag="stride_none"), + # Single channel + cls(kernel_size=2, stride=2, in_channels=1, input_size=(16, 16), tag="c1"), + # Global pooling + cls(kernel_size=8, stride=8, input_size=(8, 8), tag="global"), + # Non-square kernel/stride + cls( + kernel_size=(2, 3), + stride=(2, 3), + input_size=(16, 18), + tag="nonsquare_fast", + ), + cls( + kernel_size=(3, 2), + stride=(1, 2), + padding=(1, 0), + input_size=(16, 16), + tag="nonsquare_general", + ), + # Stride > kernel (gaps between windows) + cls(kernel_size=2, stride=3, input_size=(16, 16), tag="stride_gt_kernel"), + # Non-square input with square kernel + cls(kernel_size=2, stride=2, input_size=(16, 32), tag="nonsquare_input"), + ] + + def create_model(self) -> nn.Module: + return MaxPool2dModel(self.kernel_size, self.stride, self.padding) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return ( + torch.randn( + self.batch_size, + self.in_channels, + self.input_size[0], + self.input_size[1], + ), + ) + + +class MaxPool3dModel(nn.Module): + def __init__(self, kernel_size, stride=None, padding=0): + super().__init__() + self.pool = nn.MaxPool3d(kernel_size, stride=stride, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.pool(x) + + +@register_test +class MaxPool3dTest(OpTestCase): + name = "max_pool3d" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + kernel_size=2, + stride=2, + padding=0, + in_channels: int = 8, + input_size: Tuple[int, int, int] = (8, 16, 16), + batch_size: int = 1, + tag: str = "", + ): + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.in_channels = in_channels + self.input_size = input_size + self.batch_size = batch_size + + if tag: + self.name = f"max_pool3d_{tag}" + else: + parts = ["max_pool3d", f"k{kernel_size}", f"s{stride}"] + if padding != 0: + parts.append(f"p{padding}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["MaxPool3dTest"]: + return [ + # Fast path: kernel == stride + cls(kernel_size=2, stride=2), + # General path: overlapping windows with padding + cls(kernel_size=3, stride=2, padding=1), + # Batch > 1 + cls(kernel_size=2, stride=2, batch_size=2, tag="batch2"), + # Single channel + cls(kernel_size=2, stride=2, in_channels=1, tag="c1"), + # Non-cubic kernel/stride + cls( + kernel_size=(2, 2, 4), + stride=(2, 2, 4), + input_size=(8, 16, 16), + tag="noncubic_fast", + ), + # Stride > kernel + cls( + kernel_size=2, stride=3, input_size=(8, 16, 16), tag="stride_gt_kernel" + ), + # stride=None (defaults to kernel_size) + cls(kernel_size=2, stride=None, tag="stride_none"), + # Global pooling: kernel == spatial + cls( + kernel_size=(8, 16, 16), + stride=(8, 16, 16), + input_size=(8, 16, 16), + tag="global", + ), + # Non-cubic general path (stride != kernel) + cls( + kernel_size=(3, 2, 2), + stride=(1, 2, 2), + padding=(1, 0, 0), + input_size=(8, 16, 16), + tag="noncubic_general", + ), + # Non-cubic input with cubic kernel + cls(kernel_size=2, stride=2, input_size=(4, 8, 16), tag="nonsquare_input"), + ] + + def create_model(self) -> nn.Module: + return MaxPool3dModel(self.kernel_size, self.stride, self.padding) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return ( + torch.randn( + self.batch_size, + self.in_channels, + *self.input_size, + ), + ) + + +class AvgPool1dModel(nn.Module): + def __init__(self, kernel_size, stride=None, padding=0): + super().__init__() + self.pool = nn.AvgPool1d(kernel_size, stride=stride, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.pool(x) + + +@register_test +class AvgPool1dTest(OpTestCase): + name = "avg_pool1d" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + kernel_size=2, + stride=2, + padding=0, + in_channels: int = 8, + seq_len: int = 32, + batch_size: int = 1, + tag: str = "", + ): + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.in_channels = in_channels + self.seq_len = seq_len + self.batch_size = batch_size + + if tag: + self.name = f"avg_pool1d_{tag}" + else: + parts = ["avg_pool1d", f"k{kernel_size}", f"s{stride}"] + if padding != 0: + parts.append(f"p{padding}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["AvgPool1dTest"]: + return [ + # Fast path: kernel == stride + cls(kernel_size=2, stride=2), + # General path: overlapping windows with padding + cls(kernel_size=3, stride=2, padding=1), + # Fast path: larger kernel + cls(kernel_size=4, stride=4, seq_len=64), + # stride=None (defaults to kernel_size) + cls(kernel_size=4, stride=None, seq_len=64, tag="stride_none"), + # Single channel + cls(kernel_size=2, stride=2, in_channels=1, tag="c1"), + # Global pooling + cls(kernel_size=32, stride=32, tag="global"), + # Batch > 1 + cls(kernel_size=2, stride=2, batch_size=4, tag="batch4"), + # Stride > kernel (gaps between windows) + cls(kernel_size=2, stride=3, seq_len=32, tag="stride_gt_kernel"), + ] + + def create_model(self) -> nn.Module: + return AvgPool1dModel(self.kernel_size, self.stride, self.padding) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.batch_size, self.in_channels, self.seq_len),) + + +class AvgPool2dModel(nn.Module): + def __init__(self, kernel_size, stride=None, padding=0): + super().__init__() + self.pool = nn.AvgPool2d(kernel_size, stride=stride, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.pool(x) + + +@register_test +class AvgPool2dTest(OpTestCase): + name = "avg_pool2d" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + kernel_size=2, + stride=2, + padding=0, + in_channels: int = 16, + input_size: Tuple[int, int] = (32, 32), + batch_size: int = 1, + tag: str = "", + ): + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.in_channels = in_channels + self.input_size = input_size + self.batch_size = batch_size + + if tag: + self.name = f"avg_pool2d_{tag}" + else: + parts = ["avg_pool2d", f"k{kernel_size}", f"s{stride}"] + if padding != 0: + parts.append(f"p{padding}") + parts.append(f"{input_size[0]}x{input_size[1]}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["AvgPool2dTest"]: + return [ + # Fast path: kernel == stride + cls(kernel_size=2, stride=2, input_size=(32, 32)), + # General path: overlapping windows + cls(kernel_size=3, stride=2, padding=1, input_size=(32, 32)), + # Fast path: 4x4 pooling + cls(kernel_size=4, stride=4, input_size=(64, 64)), + # General path: stride != kernel + cls(kernel_size=3, stride=1, input_size=(16, 16)), + # Batch > 1 + cls(kernel_size=2, stride=2, input_size=(32, 32), batch_size=4), + # stride=None + cls(kernel_size=2, stride=None, input_size=(32, 32), tag="stride_none"), + # Single channel + cls(kernel_size=2, stride=2, in_channels=1, input_size=(16, 16), tag="c1"), + # Global pooling + cls(kernel_size=8, stride=8, input_size=(8, 8), tag="global"), + # Non-square kernel/stride + cls( + kernel_size=(2, 3), + stride=(2, 3), + input_size=(16, 18), + tag="nonsquare_fast", + ), + cls( + kernel_size=(3, 2), + stride=(1, 2), + padding=(1, 0), + input_size=(16, 16), + tag="nonsquare_general", + ), + # Stride > kernel (gaps between windows) + cls(kernel_size=2, stride=3, input_size=(16, 16), tag="stride_gt_kernel"), + # Non-square input with square kernel + cls(kernel_size=2, stride=2, input_size=(16, 32), tag="nonsquare_input"), + ] + + def create_model(self) -> nn.Module: + return AvgPool2dModel(self.kernel_size, self.stride, self.padding) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return ( + torch.randn( + self.batch_size, + self.in_channels, + self.input_size[0], + self.input_size[1], + ), + ) + + +class AvgPool3dModel(nn.Module): + def __init__(self, kernel_size, stride=None, padding=0): + super().__init__() + self.pool = nn.AvgPool3d(kernel_size, stride=stride, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.pool(x) + + +@register_test +class AvgPool3dTest(OpTestCase): + name = "avg_pool3d" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + kernel_size=2, + stride=2, + padding=0, + in_channels: int = 8, + input_size: Tuple[int, int, int] = (8, 16, 16), + batch_size: int = 1, + tag: str = "", + ): + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.in_channels = in_channels + self.input_size = input_size + self.batch_size = batch_size + + if tag: + self.name = f"avg_pool3d_{tag}" + else: + parts = ["avg_pool3d", f"k{kernel_size}", f"s{stride}"] + if padding != 0: + parts.append(f"p{padding}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["AvgPool3dTest"]: + return [ + # Fast path: kernel == stride + cls(kernel_size=2, stride=2), + # General path: overlapping windows with padding + cls(kernel_size=3, stride=2, padding=1), + # Batch > 1 + cls(kernel_size=2, stride=2, batch_size=2, tag="batch2"), + # Single channel + cls(kernel_size=2, stride=2, in_channels=1, tag="c1"), + # Non-cubic kernel/stride + cls( + kernel_size=(2, 2, 4), + stride=(2, 2, 4), + input_size=(8, 16, 16), + tag="noncubic_fast", + ), + # Stride > kernel + cls( + kernel_size=2, stride=3, input_size=(8, 16, 16), tag="stride_gt_kernel" + ), + # stride=None (defaults to kernel_size) + cls(kernel_size=2, stride=None, tag="stride_none"), + # Global pooling: kernel == spatial + cls( + kernel_size=(8, 16, 16), + stride=(8, 16, 16), + input_size=(8, 16, 16), + tag="global", + ), + # Non-cubic general path (stride != kernel) + cls( + kernel_size=(3, 2, 2), + stride=(1, 2, 2), + padding=(1, 0, 0), + input_size=(8, 16, 16), + tag="noncubic_general", + ), + # Non-cubic input with cubic kernel + cls(kernel_size=2, stride=2, input_size=(4, 8, 16), tag="nonsquare_input"), + ] + + def create_model(self) -> nn.Module: + return AvgPool3dModel(self.kernel_size, self.stride, self.padding) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return ( + torch.randn( + self.batch_size, + self.in_channels, + *self.input_size, + ), + ) + + +class RMSNormModel(nn.Module): + """Model using torch.nn.functional.rms_norm.""" + + def __init__(self, hidden_dim: int = 64, eps: float = 1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_dim)) + self.hidden_dim = hidden_dim + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.rms_norm( + x, (self.hidden_dim,), self.weight, self.eps + ) + + +@register_test +class RMSNormTest(OpTestCase): + """Test case for torch.nn.functional.rms_norm (aten.rms_norm).""" + + name = "aten_rms_norm" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + hidden_dim: int = 64, + batch_size: int = 2, + seq_len: int = 16, + eps: float = 1e-5, + ): + self.hidden_dim = hidden_dim + self.batch_size = batch_size + self.seq_len = seq_len + self.eps = eps + self.name = "aten_rms_norm" + + @classmethod + def get_test_configs(cls) -> List["RMSNormTest"]: + return [ + cls(), + cls(hidden_dim=128, eps=1e-6), + ] + + def create_model(self) -> nn.Module: + return RMSNormModel(self.hidden_dim, self.eps) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.seq_len, self.hidden_dim) + return (x,) + + +class RopeModel(nn.Module): + """Model that applies RoPE with dynamic position.""" + + def __init__( + self, + dims: int = 64, + traditional: bool = False, + base: float = 500000.0, + scale: float = 1.0, + ): + super().__init__() + self.dims = dims + self.traditional = traditional + self.base = base + self.scale = scale + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + pos_tensor: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + pos = pos_tensor.item() + q_rot = torch.ops.mlx.rope( + q, self.dims, pos, self.traditional, self.base, self.scale, None + ) + k_rot = torch.ops.mlx.rope( + k, self.dims, pos, self.traditional, self.base, self.scale, None + ) + return q_rot, k_rot + + +@register_test +class RopeTest(OpTestCase): + """Test case for RoPE.""" + + name = "rope" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + batch_size: int = 1, + num_heads: int = 8, + seq_len: int = 16, + head_dim: int = 64, + dims: Optional[int] = None, + pos: int = 0, + traditional: bool = False, + base: float = 500000.0, + scale: float = 1.0, + ): + self.batch_size = batch_size + self.num_heads = num_heads + self.seq_len = seq_len + self.head_dim = head_dim + self.dims = dims if dims is not None else head_dim + self.pos = pos + self.traditional = traditional + self.base = base + self.scale = scale + self.name = "rope" + + @classmethod + def get_test_configs(cls) -> List["RopeTest"]: + configs = [ + cls(), + cls(traditional=True), + cls(head_dim=64, dims=32), + cls(head_dim=64, dims=32, traditional=True), + ] + for cfg in configs: + parts = ["rope"] + if cfg.traditional: + parts.append("traditional") + if cfg.dims != cfg.head_dim: + parts.append(f"dims{cfg.dims}") + cfg.name = "_".join(parts) + return configs + + def create_model(self) -> nn.Module: + return RopeModel( + dims=self.dims, + traditional=self.traditional, + base=self.base, + scale=self.scale, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + q = torch.randn(self.batch_size, self.num_heads, self.seq_len, self.head_dim) + k = torch.randn(self.batch_size, self.num_heads, self.seq_len, self.head_dim) + pos_tensor = torch.tensor(self.pos, dtype=torch.int64) + return (q, k, pos_tensor) + + +from executorch.backends.mlx.llm.cache import KVCache + + +class KVCacheModel(nn.Module): + """ + Test model wrapping KVCache from cache.py. + + This tests the ExecutorTorch llama KVCache-compatible interface that uses + the mlx::kv_cache_update op internally. + """ + + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + enable_dynamic_shape: bool = True, + ): + super().__init__() + self.cache = KVCache( + max_batch_size=max_batch_size, + max_context_length=max_context_length, + n_heads=n_heads, + head_dim=head_dim, + enable_dynamic_shape=enable_dynamic_shape, + ) + + def forward( + self, + input_pos: torch.Tensor, # [S] position tensor + k_val: torch.Tensor, # [B, H, S, D] + v_val: torch.Tensor, # [B, H, S, D] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Update cache and return full cache tensors.""" + k_cache, v_cache = self.cache.update(input_pos, k_val, v_val) + return k_cache, v_cache + + +@register_test +class KVCacheTest(OpTestCase): + """ + Test case for MLX KVCache with ExecutorTorch llama KVCache interface. + + This verifies that KVCache: + 1. Accepts the ET llama KVCache update interface + 2. Correctly delegates to mlx::kv_cache_update custom op + 3. Produces correct outputs for both export and test inputs + """ + + name = "kv_cache" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + max_batch_size: int = 1, + n_heads: int = 4, + head_dim: int = 64, + max_context_length: int = 128, + seq_step: int = 8, + enable_dynamic_shape: bool = True, + ): + self.max_batch_size = max_batch_size + self.n_heads = n_heads + self.head_dim = head_dim + self.max_context_length = max_context_length + self.seq_step = seq_step + self.enable_dynamic_shape = enable_dynamic_shape + + @classmethod + def get_test_configs(cls) -> List["KVCacheTest"]: + return [ + cls(), # default config + cls(n_heads=8, head_dim=32), # different head config + cls(enable_dynamic_shape=False), # static shape mode + ] + + def create_model(self) -> nn.Module: + return KVCacheModel( + max_batch_size=self.max_batch_size, + max_context_length=self.max_context_length, + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=self.enable_dynamic_shape, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + input_pos = torch.tensor([0], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, self.n_heads, self.seq_step, self.head_dim + ) + v_val = torch.randn( + self.max_batch_size, self.n_heads, self.seq_step, self.head_dim + ) + return (input_pos, k_val, v_val) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + # Note: KVCache.update() takes (input_pos, k_val, v_val) - position first + # Test with different position and different seq_step + test_seq_step = self.seq_step + 4 + input_pos = torch.tensor([16], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, self.n_heads, test_seq_step, self.head_dim + ) + v_val = torch.randn( + self.max_batch_size, self.n_heads, test_seq_step, self.head_dim + ) + return (input_pos, k_val, v_val) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + # seq_step (dim 2) is dynamic for k_val and v_val + seq_dim = Dim("seq_step", min=1, max=self.max_context_length) + return { + "input_pos": None, + "k_val": {2: seq_dim}, + "v_val": {2: seq_dim}, + } + + +class KVCacheIntModel(nn.Module): + """ + Test model that passes int/SymInt (not tensor) to KVCache.update(). + + This tests the "int route" where the caller extracts the start position + from the tensor before calling update, which is the preferred pattern + in multi-layer models to avoid redundant SymInt extraction. + """ + + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + enable_dynamic_shape: bool = True, + ): + super().__init__() + self.cache = KVCache( + max_batch_size=max_batch_size, + max_context_length=max_context_length, + n_heads=n_heads, + head_dim=head_dim, + enable_dynamic_shape=enable_dynamic_shape, + ) + + def forward( + self, + input_pos: torch.Tensor, # [S] position tensor + k_val: torch.Tensor, # [B, H, S, D] + v_val: torch.Tensor, # [B, H, S, D] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Extract int from tensor, then pass to update (the int route).""" + start_pos = input_pos[0].item() + return self.cache.update(start_pos, k_val, v_val) + + +@register_test +class KVCacheIntTest(OpTestCase): + """ + Test case for MLX KVCache with int/SymInt input_pos. + + This verifies the "int route" where the caller extracts the start position + before calling update, matching the recommended pattern for multi-layer models. + """ + + name = "kv_cache_int" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + max_batch_size: int = 1, + n_heads: int = 4, + head_dim: int = 64, + max_context_length: int = 128, + seq_step: int = 8, + enable_dynamic_shape: bool = True, + ): + self.max_batch_size = max_batch_size + self.n_heads = n_heads + self.head_dim = head_dim + self.max_context_length = max_context_length + self.seq_step = seq_step + self.enable_dynamic_shape = enable_dynamic_shape + + @classmethod + def get_test_configs(cls) -> List["KVCacheIntTest"]: + return [ + cls(), # default config + cls(n_heads=8, head_dim=32), # different head config + ] + + def create_model(self) -> nn.Module: + return KVCacheIntModel( + max_batch_size=self.max_batch_size, + max_context_length=self.max_context_length, + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=self.enable_dynamic_shape, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + input_pos = torch.tensor([0], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, self.n_heads, self.seq_step, self.head_dim + ) + v_val = torch.randn( + self.max_batch_size, self.n_heads, self.seq_step, self.head_dim + ) + return (input_pos, k_val, v_val) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + test_seq_step = self.seq_step + 4 + input_pos = torch.tensor([16], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, self.n_heads, test_seq_step, self.head_dim + ) + v_val = torch.randn( + self.max_batch_size, self.n_heads, test_seq_step, self.head_dim + ) + return (input_pos, k_val, v_val) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + seq_dim = Dim("seq_step", min=1, max=self.max_context_length) + return { + "input_pos": None, + "k_val": {2: seq_dim}, + "v_val": {2: seq_dim}, + } + + +class KVCacheSliceModel(nn.Module): + """ + Test model that updates KVCache then slices the result. + + This tests that operations on the returned cache work correctly, + matching the pattern used in attention implementations. + """ + + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + enable_dynamic_shape: bool = True, + ): + super().__init__() + self.max_context_length = max_context_length + self.cache = KVCache( + max_batch_size=max_batch_size, + max_context_length=max_context_length, + n_heads=n_heads, + head_dim=head_dim, + enable_dynamic_shape=enable_dynamic_shape, + ) + + def forward( + self, + input_pos: torch.Tensor, # [S] position tensor + k_val: torch.Tensor, # [B, H, S, D] + v_val: torch.Tensor, # [B, H, S, D] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Update cache and return sliced result.""" + start_pos = input_pos[0].item() + seq_len = k_val.size(2) + end_pos = start_pos + seq_len + + torch._check(start_pos >= 0) + torch._check(end_pos <= self.max_context_length) + torch._check(end_pos >= 0) + + k_cache, v_cache = self.cache.update(input_pos, k_val, v_val) + + k_valid = k_cache[:, :, :end_pos, :] + v_valid = v_cache[:, :, :end_pos, :] + return k_valid, v_valid + + +@register_test +class KVCacheSliceTest(OpTestCase): + """ + Test case for MLX KVCache update followed by slicing. + + This verifies that: + 1. The ET llama KVCache-compatible interface works correctly + 2. Subsequent slice operations on the returned cache work correctly + """ + + name = "kv_cache_slice" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + max_batch_size: int = 1, + n_heads: int = 4, + head_dim: int = 64, + max_context_length: int = 128, + seq_step: int = 8, + enable_dynamic_shape: bool = True, + ): + self.max_batch_size = max_batch_size + self.n_heads = n_heads + self.head_dim = head_dim + self.max_context_length = max_context_length + self.seq_step = seq_step + self.enable_dynamic_shape = enable_dynamic_shape + + @classmethod + def get_test_configs(cls) -> List["KVCacheSliceTest"]: + return [ + cls(), + cls(n_heads=8, head_dim=32), + ] + + def create_model(self) -> nn.Module: + return KVCacheSliceModel( + max_batch_size=self.max_batch_size, + max_context_length=self.max_context_length, + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=self.enable_dynamic_shape, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + input_pos = torch.tensor([0], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, self.n_heads, self.seq_step, self.head_dim + ) + v_val = torch.randn( + self.max_batch_size, self.n_heads, self.seq_step, self.head_dim + ) + return (input_pos, k_val, v_val) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + test_seq_step = self.seq_step + 4 + input_pos = torch.tensor([16], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, self.n_heads, test_seq_step, self.head_dim + ) + v_val = torch.randn( + self.max_batch_size, self.n_heads, test_seq_step, self.head_dim + ) + return (input_pos, k_val, v_val) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + seq_dim = Dim("seq_step", min=1, max=self.max_context_length) + return { + "input_pos": None, + "k_val": {2: seq_dim}, + "v_val": {2: seq_dim}, + } + + +class RingBufferKVCacheModel(nn.Module): + """ + Test model wrapping RingBufferKVCache from cache.py. + + Updates the ring buffer cache and returns the full cache contents. + Uses kv_cache_update with ring_size > 0, which should emit + ModIntNode + SubtractIntNode + two SliceUpdateNodes. + """ + + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + ): + super().__init__() + from executorch.backends.mlx.llm.cache import RingBufferKVCache + + self.cache = RingBufferKVCache( + max_batch_size=max_batch_size, + max_context_length=max_context_length, + n_heads=n_heads, + head_dim=head_dim, + ) + + def forward( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + start_pos = input_pos[0].item() + torch._check(start_pos >= 0) + + k_cache, v_cache = self.cache.update(start_pos, k_val, v_val) + return k_cache, v_cache + + +@register_test +class RingBufferKVCacheTest(OpTestCase): + """ + Test case for RingBufferKVCache with ring_size > 0. + + Verifies that kv_cache_update with ring_size emits the ring buffer + SliceUpdate pattern (ModInt + SubtractInt + 2x Slice + 2x SliceUpdate) + and produces correct results. + """ + + name = "ring_buffer_kv_cache" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + max_batch_size: int = 1, + n_heads: int = 4, + head_dim: int = 64, + max_context_length: int = 64, + seq_step: int = 4, + ): + self.max_batch_size = max_batch_size + self.n_heads = n_heads + self.head_dim = head_dim + self.max_context_length = max_context_length + self.seq_step = seq_step + + @classmethod + def get_test_configs(cls) -> List["RingBufferKVCacheTest"]: + return [ + cls(), + cls(n_heads=8, head_dim=32, max_context_length=32, seq_step=2), + ] + + def create_model(self) -> nn.Module: + return RingBufferKVCacheModel( + max_batch_size=self.max_batch_size, + max_context_length=self.max_context_length, + n_heads=self.n_heads, + head_dim=self.head_dim, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + input_pos = torch.tensor([0], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, self.n_heads, self.seq_step, self.head_dim + ) + v_val = torch.randn( + self.max_batch_size, self.n_heads, self.seq_step, self.head_dim + ) + return (input_pos, k_val, v_val) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + test_seq_step = self.seq_step + 2 + input_pos = torch.tensor([8], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, self.n_heads, test_seq_step, self.head_dim + ) + v_val = torch.randn( + self.max_batch_size, self.n_heads, test_seq_step, self.head_dim + ) + return (input_pos, k_val, v_val) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + seq_dim = Dim("seq_step", min=1, max=self.max_context_length) + return { + "input_pos": None, + "k_val": {2: seq_dim}, + "v_val": {2: seq_dim}, + } + + def get_expected_node_counts(self) -> Optional[Dict[str, int]]: + return { + "ItemIntNode": 1, + "ModIntNode": 2, + "SymSizeNode": 4, + "SubtractIntNode": 4, + "AddIntNode": 2, + "SliceNode": 4, + "SliceUpdateNode": 4, + "IdCopyNode": 2, + } + + +class MockModelConfig: + """ + Mock HuggingFace model config for testing HFStaticCache. + + This simulates the config structure expected by HFStaticCache. + """ + + def __init__( + self, + num_hidden_layers: int = 2, + num_attention_heads: int = 4, + num_key_value_heads: int | None = None, + hidden_size: int = 256, + head_dim: int | None = None, + max_position_embeddings: int = 128, + ): + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads or num_attention_heads + self.hidden_size = hidden_size + self.head_dim = head_dim or (hidden_size // num_attention_heads) + self.max_position_embeddings = max_position_embeddings + + def get_text_config(self, **kwargs): + """Return self for HF StaticCache compatibility.""" + return self + + +class HFStaticCacheModel(nn.Module): + """ + Test model wrapping HFStaticCache from cache.py. + + This tests the HuggingFace-compatible StaticCache interface. + """ + + def __init__( + self, + config: MockModelConfig, + layer_idx: int = 0, + ): + super().__init__() + from executorch.backends.mlx.llm.cache import HFStaticCache + + self.cache = HFStaticCache(config) + self.layer_idx = layer_idx + + # Register buffers explicitly so torch.export treats them as mutable + # buffers rather than constants. This mirrors what replace_hf_cache_with_mlx() does. + for i, layer_cache in enumerate(self.cache.kv_cache): + self.register_buffer( + f"key_cache_{i}", layer_cache.k_cache, persistent=False + ) + self.register_buffer( + f"value_cache_{i}", layer_cache.v_cache, persistent=False + ) + + def forward( + self, + k_val: torch.Tensor, # [B, H, S, D] + v_val: torch.Tensor, # [B, H, S, D] + cache_position: torch.Tensor, # 1D tensor with start position + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Update cache using HuggingFace-style interface.""" + return self.cache.update( + k_val, + v_val, + self.layer_idx, + cache_kwargs={"cache_position": cache_position}, + ) + + +@register_test +class HFStaticCacheTest(OpTestCase): + """Test case for HFStaticCache with HuggingFace-compatible interface.""" + + name = "hf_static_cache" + rtol = 1e-5 + atol = 1e-5 + expected_node_counts = { + "ItemIntNode": 1, + "SymSizeNode": 2, + "AddIntNode": 2, + "SliceUpdateNode": 2, + "IdCopyNode": 2, + } + + def __init__( + self, + num_heads: int = 4, + head_dim: int = 64, + num_layers: int = 2, + max_seq_len: int = 128, + seq_step: int = 8, + layer_idx: int = 0, + ): + self.num_heads = num_heads + self.head_dim = head_dim + self.num_layers = num_layers + self.max_seq_len = max_seq_len + self.seq_step = seq_step + self.layer_idx = layer_idx + + @classmethod + def get_test_configs(cls) -> List["HFStaticCacheTest"]: + return [ + cls(), # default config, layer 0 + cls(num_heads=8, head_dim=32, layer_idx=1), # different config, layer 1 + ] + + def create_model(self) -> nn.Module: + config = MockModelConfig( + num_hidden_layers=self.num_layers, + num_attention_heads=self.num_heads, + hidden_size=self.num_heads * self.head_dim, + head_dim=self.head_dim, + max_position_embeddings=self.max_seq_len, + ) + return HFStaticCacheModel(config, layer_idx=self.layer_idx) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # BHSD layout [B, H, S, D] + k_val = torch.randn(1, self.num_heads, self.seq_step, self.head_dim) + v_val = torch.randn(1, self.num_heads, self.seq_step, self.head_dim) + cache_position = torch.tensor([0], dtype=torch.int64) # 1D tensor + return (k_val, v_val, cache_position) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + # Test with different position and different seq_step + test_seq_step = self.seq_step + 4 # Different from export seq_step + k_val = torch.randn(1, self.num_heads, test_seq_step, self.head_dim) + v_val = torch.randn(1, self.num_heads, test_seq_step, self.head_dim) + cache_position = torch.tensor([16], dtype=torch.int64) # 1D tensor + return (k_val, v_val, cache_position) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + # seq_step (dim 2) is dynamic for k_val and v_val + seq_dim = Dim("seq_step", min=1, max=self.max_seq_len) + return { + "k_val": {2: seq_dim}, + "v_val": {2: seq_dim}, + "cache_position": None, + } + + +class HFStaticCacheSliceModel(nn.Module): + """ + Test model that updates HFStaticCache then slices the result. + + This tests that operations on the returned cache work correctly + with the HuggingFace-compatible interface. + """ + + def __init__( + self, + config: MockModelConfig, + layer_idx: int = 0, + ): + super().__init__() + from executorch.backends.mlx.llm.cache import HFStaticCache + + self.max_seq_len = config.max_position_embeddings + self.cache = HFStaticCache(config) + self.layer_idx = layer_idx + + # Register buffers explicitly so torch.export treats them as mutable + # buffers rather than constants. This mirrors what replace_hf_cache_with_mlx() does. + for i, layer_cache in enumerate(self.cache.kv_cache): + self.register_buffer( + f"key_cache_{i}", layer_cache.k_cache, persistent=False + ) + self.register_buffer( + f"value_cache_{i}", layer_cache.v_cache, persistent=False + ) + + def forward( + self, + k_val: torch.Tensor, # [B, H, S, D] + v_val: torch.Tensor, # [B, H, S, D] + cache_position: torch.Tensor, # 1D tensor with start position + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Update cache and return sliced cache (only the valid portion).""" + pos = cache_position[0].item() + seq_len = k_val.size(2) + end_pos = pos + seq_len + + # Add constraints for dynamic shapes + torch._check(pos >= 0) + torch._check(end_pos <= self.max_seq_len) + torch._check(end_pos >= 0) + + # Update cache using HuggingFace-style interface + k_cache, v_cache = self.cache.update( + k_val, + v_val, + self.layer_idx, + cache_kwargs={"cache_position": cache_position}, + ) + + # Slice to get only the valid portion [0:end_pos] + k_valid = k_cache[:, :, :end_pos, :] + v_valid = v_cache[:, :, :end_pos, :] + + return k_valid, v_valid + + +@register_test +class HFStaticCacheSliceTest(OpTestCase): + """ + Test case for HFStaticCache update followed by slicing. + + This verifies that: + 1. The HuggingFace-compatible interface works correctly + 2. Subsequent slice operations on the returned cache work correctly + """ + + name = "hf_static_cache_slice" + rtol = 1e-5 + atol = 1e-5 + expected_node_counts = { + "ItemIntNode": 2, + "SymSizeNode": 3, + "AddIntNode": 3, + "SliceUpdateNode": 2, + "IdCopyNode": 2, + "SliceNode": 2, + } + + def __init__( + self, + num_heads: int = 4, + head_dim: int = 64, + num_layers: int = 2, + max_seq_len: int = 128, + seq_step: int = 8, + layer_idx: int = 0, + ): + self.num_heads = num_heads + self.head_dim = head_dim + self.num_layers = num_layers + self.max_seq_len = max_seq_len + self.seq_step = seq_step + self.layer_idx = layer_idx + + @classmethod + def get_test_configs(cls) -> List["HFStaticCacheSliceTest"]: + return [ + cls(), # default config, layer 0 + cls(num_heads=8, head_dim=32, layer_idx=1), # different config, layer 1 + ] + + def create_model(self) -> nn.Module: + config = MockModelConfig( + num_hidden_layers=self.num_layers, + num_attention_heads=self.num_heads, + hidden_size=self.num_heads * self.head_dim, + head_dim=self.head_dim, + max_position_embeddings=self.max_seq_len, + ) + return HFStaticCacheSliceModel(config, layer_idx=self.layer_idx) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # BHSD layout [B, H, S, D] + k_val = torch.randn(1, self.num_heads, self.seq_step, self.head_dim) + v_val = torch.randn(1, self.num_heads, self.seq_step, self.head_dim) + cache_position = torch.tensor([0], dtype=torch.int64) # 1D tensor + return (k_val, v_val, cache_position) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + # Test with different position and different seq_step + test_seq_step = self.seq_step + 4 # Different from export seq_step + k_val = torch.randn(1, self.num_heads, test_seq_step, self.head_dim) + v_val = torch.randn(1, self.num_heads, test_seq_step, self.head_dim) + cache_position = torch.tensor([16], dtype=torch.int64) # 1D tensor + return (k_val, v_val, cache_position) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + # seq_step (dim 2) is dynamic for k_val and v_val + seq_dim = Dim("seq_step", min=1, max=self.max_seq_len) + return { + "k_val": {2: seq_dim}, + "v_val": {2: seq_dim}, + "cache_position": None, + } + + +class DynamicArangeModel(nn.Module): + """Model that uses arange with dynamic start/stop from tensor.item().""" + + def __init__(self, length: int, vocab_size: int = 32): + super().__init__() + self.length = length + self.embed = nn.Embedding(vocab_size, 16) + + def forward(self, pos: torch.Tensor) -> torch.Tensor: + torch._check(pos.numel() == 1) + pos_int = pos.item() + torch._check(pos_int >= 0) + positions = torch.arange( + pos_int, pos_int + self.length, device=pos.device, dtype=torch.long + ) + return self.embed(positions) + + +@register_test +class DynamicArangeTest(OpTestCase): + """Test case for torch.arange() with dynamic start/stop.""" + + name = "arange_dynamic" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + position: int = 4, + length: int = 4, + vocab_size: int = 32, + ): + self.position = position + self.length = length + self.vocab_size = vocab_size + self.name = f"arange_dynamic_pos{position}_len{length}" + + @classmethod + def get_test_configs(cls) -> List["DynamicArangeTest"]: + return [ + cls(position=0, length=4), + cls(position=4, length=4), + cls(position=10, length=8), + ] + + def create_model(self) -> nn.Module: + return DynamicArangeModel(length=self.length, vocab_size=self.vocab_size) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + pos = torch.tensor([self.position], dtype=torch.long) + return (pos,) + + +class LayerNormModel(nn.Module): + """Simple model using LayerNorm.""" + + def __init__(self, normalized_shape: int = 64, eps: float = 1e-5): + super().__init__() + self.layer_norm = nn.LayerNorm(normalized_shape, eps=eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.layer_norm(x) + + +@register_test +class LayerNormTest(OpTestCase): + """Test case for nn.LayerNorm.""" + + name = "layer_norm" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + normalized_shape: int = 64, + batch_size: int = 2, + seq_len: int = 16, + eps: float = 1e-5, + ): + self.normalized_shape = normalized_shape + self.batch_size = batch_size + self.seq_len = seq_len + self.eps = eps + self.name = "layer_norm" + + @classmethod + def get_test_configs(cls) -> List["LayerNormTest"]: + return [ + cls(), + cls(normalized_shape=128, eps=1e-6), + ] + + def create_model(self) -> nn.Module: + return LayerNormModel(self.normalized_shape, self.eps) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.seq_len, self.normalized_shape) + return (x,) + + +class Conv1dModel(nn.Module): + """Simple model using Conv1d.""" + + def __init__( + self, + in_channels: int = 16, + out_channels: int = 32, + kernel_size: int = 3, + stride: int = 1, + padding: int = 1, + bias: bool = True, + ): + super().__init__() + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +@register_test +class Conv1dTest(OpTestCase): + """Test case for nn.Conv1d.""" + + name = "conv1d" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + in_channels: int = 16, + out_channels: int = 32, + kernel_size: int = 3, + stride: int = 1, + padding: int = 1, + bias: bool = True, + batch_size: int = 2, + seq_len: int = 64, + ): + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.bias = bias + self.batch_size = batch_size + self.seq_len = seq_len + + parts = ["conv1d"] + if not bias: + parts.append("no_bias") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["Conv1dTest"]: + return [ + cls(), + cls(bias=False), + ] + + def create_model(self) -> nn.Module: + return Conv1dModel( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.bias, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.in_channels, self.seq_len) + return (x,) + + +class Conv2DModel(nn.Module): + """Model that performs 2D convolution.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + bias: bool = True, + ): + super().__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +@register_test +class Conv2DTest(OpTestCase): + """Test case for conv2d op.""" + + name = "conv2d" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 16, + kernel_size: int = 3, + stride: int = 1, + padding: int = 0, + input_size: Tuple[int, int] = (32, 32), + batch_size: int = 1, + bias: bool = True, + ): + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.input_size = input_size + self.batch_size = batch_size + self.bias = bias + + parts = [ + "conv2d", + f"in{in_channels}", + f"out{out_channels}", + f"k{kernel_size}", + ] + if stride != 1: + parts.append(f"s{stride}") + if padding != 0: + parts.append(f"p{padding}") + parts.append(f"{input_size[0]}x{input_size[1]}") + if batch_size != 1: + parts.append(f"b{batch_size}") + if not bias: + parts.append("nobias") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["Conv2DTest"]: + return [ + cls( + in_channels=3, + out_channels=16, + kernel_size=3, + padding=1, + input_size=(32, 32), + ), + cls( + in_channels=16, + out_channels=32, + kernel_size=3, + stride=2, + padding=1, + input_size=(64, 64), + ), + cls(in_channels=64, out_channels=128, kernel_size=1, input_size=(16, 16)), + # 5x5 conv + cls( + in_channels=3, + out_channels=8, + kernel_size=5, + padding=2, + input_size=(28, 28), + ), + # Batch size > 1 + cls( + in_channels=3, + out_channels=16, + kernel_size=3, + padding=1, + input_size=(32, 32), + batch_size=4, + ), + cls( + in_channels=3, + out_channels=16, + kernel_size=3, + padding=1, + input_size=(32, 32), + bias=False, + ), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn( + self.batch_size, self.in_channels, self.input_size[0], self.input_size[1] + ) + return (x,) + + def create_model(self) -> nn.Module: + return Conv2DModel( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.bias, + ) + + +class Conv3DModel(nn.Module): + """Model that performs 3D convolution.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + bias: bool = True, + ): + super().__init__() + self.conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +@register_test +class Conv3DTest(OpTestCase): + """Test case for conv3d op.""" + + name = "conv3d" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 16, + kernel_size: int = 3, + stride: int = 1, + padding: int = 0, + input_size: Tuple[int, int, int] = (8, 16, 16), + batch_size: int = 1, + bias: bool = True, + ): + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.input_size = input_size + self.batch_size = batch_size + self.bias = bias + + parts = [ + "conv3d", + f"in{in_channels}", + f"out{out_channels}", + f"k{kernel_size}", + ] + if stride != 1: + parts.append(f"s{stride}") + if padding != 0: + parts.append(f"p{padding}") + parts.append(f"{input_size[0]}x{input_size[1]}x{input_size[2]}") + if batch_size != 1: + parts.append(f"b{batch_size}") + if not bias: + parts.append("nobias") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["Conv3DTest"]: + return [ + cls( + in_channels=3, + out_channels=16, + kernel_size=3, + padding=1, + input_size=(8, 16, 16), + ), + cls( + in_channels=16, + out_channels=32, + kernel_size=3, + stride=2, + padding=1, + input_size=(8, 16, 16), + ), + cls(in_channels=64, out_channels=128, kernel_size=1, input_size=(4, 8, 8)), + # Batch size > 1 + cls( + in_channels=3, + out_channels=16, + kernel_size=3, + padding=1, + input_size=(8, 16, 16), + batch_size=2, + ), + cls( + in_channels=3, + out_channels=16, + kernel_size=3, + padding=1, + input_size=(8, 16, 16), + bias=False, + ), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn( + self.batch_size, + self.in_channels, + self.input_size[0], + self.input_size[1], + self.input_size[2], + ) + return (x,) + + def create_model(self) -> nn.Module: + return Conv3DModel( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.bias, + ) + + +class ConvTranspose1dModel(nn.Module): + """Simple model using ConvTranspose1d.""" + + def __init__( + self, + in_channels: int = 16, + out_channels: int = 32, + kernel_size: int = 3, + stride: int = 1, + padding: int = 0, + output_padding: int = 0, + bias: bool = True, + groups: int = 1, + ): + super().__init__() + self.conv_transpose = nn.ConvTranspose1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + bias=bias, + groups=groups, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv_transpose(x) + + +@register_test +class ConvTranspose1dTest(OpTestCase): + """Test case for nn.ConvTranspose1d.""" + + name = "conv_transpose1d" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + in_channels: int = 16, + out_channels: int = 32, + kernel_size: int = 3, + stride: int = 1, + padding: int = 0, + output_padding: int = 0, + bias: bool = True, + groups: int = 1, + batch_size: int = 2, + seq_len: int = 64, + ): + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.output_padding = output_padding + self.bias = bias + self.groups = groups + self.batch_size = batch_size + self.seq_len = seq_len + + parts = ["conv_transpose1d"] + if stride != 1: + parts.append(f"s{stride}") + if padding != 0: + parts.append(f"p{padding}") + if output_padding != 0: + parts.append(f"op{output_padding}") + if not bias: + parts.append("no_bias") + if groups != 1: + parts.append(f"g{groups}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["ConvTranspose1dTest"]: + return [ + cls(), + cls(bias=False), + cls(stride=2), + cls(stride=2, output_padding=1), + cls(padding=1), + cls(in_channels=8, out_channels=8, groups=8), # depthwise + cls(in_channels=6, out_channels=6, groups=3), # grouped + ] + + def create_model(self) -> nn.Module: + return ConvTranspose1dModel( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.output_padding, + self.bias, + self.groups, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.in_channels, self.seq_len) + return (x,) + + +class ConvTranspose2DModel(nn.Module): + """Model that performs 2D transposed convolution.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + output_padding: int = 0, + bias: bool = True, + groups: int = 1, + ): + super().__init__() + self.conv_transpose = nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + bias=bias, + groups=groups, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv_transpose(x) + + +@register_test +class ConvTranspose2DTest(OpTestCase): + """Test case for nn.ConvTranspose2d.""" + + name = "conv_transpose2d" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 16, + kernel_size: int = 3, + stride: int = 1, + padding: int = 0, + output_padding: int = 0, + input_size: Tuple[int, int] = (32, 32), + batch_size: int = 1, + bias: bool = True, + groups: int = 1, + ): + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.output_padding = output_padding + self.input_size = input_size + self.batch_size = batch_size + self.bias = bias + self.groups = groups + + parts = [ + "conv_transpose2d", + f"in{in_channels}", + f"out{out_channels}", + f"k{kernel_size}", + ] + if stride != 1: + parts.append(f"s{stride}") + if padding != 0: + parts.append(f"p{padding}") + if output_padding != 0: + parts.append(f"op{output_padding}") + parts.append(f"{input_size[0]}x{input_size[1]}") + if not bias: + parts.append("nobias") + if groups != 1: + parts.append(f"g{groups}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["ConvTranspose2DTest"]: + return [ + cls(in_channels=3, out_channels=16, kernel_size=3, padding=1), + cls(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1), + cls( + in_channels=16, + out_channels=32, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + ), + cls(in_channels=64, out_channels=128, kernel_size=1), + cls(in_channels=3, out_channels=16, kernel_size=3, padding=1, bias=False), + cls( + in_channels=8, out_channels=8, kernel_size=3, padding=1, groups=8 + ), # depthwise + cls( + in_channels=6, out_channels=6, kernel_size=3, padding=1, groups=3 + ), # grouped + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn( + self.batch_size, + self.in_channels, + self.input_size[0], + self.input_size[1], + ) + return (x,) + + def create_model(self) -> nn.Module: + return ConvTranspose2DModel( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.output_padding, + self.bias, + self.groups, + ) + + +class ConvTranspose3DModel(nn.Module): + """Model that performs 3D transposed convolution.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + output_padding: int = 0, + bias: bool = True, + ): + super().__init__() + self.conv_transpose = nn.ConvTranspose3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv_transpose(x) + + +@register_test +class ConvTranspose3DTest(OpTestCase): + """Test case for nn.ConvTranspose3d.""" + + name = "conv_transpose3d" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 16, + kernel_size: int = 3, + stride: int = 1, + padding: int = 0, + output_padding: int = 0, + input_size: Tuple[int, int, int] = (8, 16, 16), + batch_size: int = 1, + bias: bool = True, + ): + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.output_padding = output_padding + self.input_size = input_size + self.batch_size = batch_size + self.bias = bias + + parts = [ + "conv_transpose3d", + f"in{in_channels}", + f"out{out_channels}", + f"k{kernel_size}", + ] + if stride != 1: + parts.append(f"s{stride}") + if padding != 0: + parts.append(f"p{padding}") + if output_padding != 0: + parts.append(f"op{output_padding}") + parts.append(f"{input_size[0]}x{input_size[1]}x{input_size[2]}") + if not bias: + parts.append("nobias") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["ConvTranspose3DTest"]: + return [ + cls(in_channels=3, out_channels=16, kernel_size=3, padding=1), + cls(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1), + cls(in_channels=64, out_channels=128, kernel_size=1), + cls(in_channels=3, out_channels=16, kernel_size=3, padding=1, bias=False), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn( + self.batch_size, + self.in_channels, + self.input_size[0], + self.input_size[1], + self.input_size[2], + ) + return (x,) + + def create_model(self) -> nn.Module: + return ConvTranspose3DModel( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.output_padding, + self.bias, + ) + + +class SliceScatterModel(nn.Module): + """Model that performs slice_scatter.""" + + def __init__(self, dim: int = 0, start: int = 0, end: int = 2, step: int = 1): + super().__init__() + self.dim = dim + self.start = start + self.end = end + self.step = step + + def forward(self, x: torch.Tensor, src: torch.Tensor) -> torch.Tensor: + return x.slice_scatter( + src, dim=self.dim, start=self.start, end=self.end, step=self.step + ) + + +@register_test +class SliceScatterTest(OpTestCase): + """Test case for aten.slice_scatter.""" + + name = "slice_scatter" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + input_shape: Tuple[int, ...] = (4, 8), + dim: int = 0, + start: int = 0, + end: int = 2, + step: int = 1, + ): + self.input_shape = input_shape + self.dim = dim + self.start = start + self.end = end + self.step = step + + parts = ["slice_scatter", f"d{dim}", f"s{start}", f"e{end}"] + if step != 1: + parts.append(f"step{step}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["SliceScatterTest"]: + return [ + # Basic: replace first 2 rows + cls(input_shape=(4, 8), dim=0, start=0, end=2), + # Replace middle rows + cls(input_shape=(4, 8), dim=0, start=1, end=3), + # Along dim=1 + cls(input_shape=(4, 8), dim=1, start=2, end=6), + # With step=2 + cls(input_shape=(4, 8), dim=0, start=0, end=4, step=2), + # 3D tensor + cls(input_shape=(2, 4, 8), dim=1, start=0, end=2), + ] + + def create_model(self) -> nn.Module: + return SliceScatterModel( + dim=self.dim, + start=self.start, + end=self.end, + step=self.step, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.input_shape) + # Compute the src shape: same as x but with the slice size along dim + src_shape = list(self.input_shape) + slice_len = len(range(self.start, self.end, self.step)) + src_shape[self.dim] = slice_len + src = torch.randn(src_shape) + return (x, src) + + class BmmModel(nn.Module): """Model that performs batch matrix multiplication.""" - def __init__(self, batch_size: int, n: int, m: int, p: int): + def __init__(self, batch_size: int, n: int, m: int, p: int): + super().__init__() + self.weight = nn.Parameter(torch.randn(batch_size, m, p)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.bmm(x, self.weight) + + +@register_test +class BmmTest(OpTestCase): + """Test case for bmm (batch matrix multiplication).""" + + name = "bmm" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + batch_size: int = 4, + n: int = 8, + m: int = 16, + p: int = 32, + ): + self.batch_size = batch_size + self.n = n + self.m = m + self.p = p + self.name = f"bmm_{batch_size}x{n}x{m}x{p}" + + @classmethod + def get_test_configs(cls) -> List["BmmTest"]: + return [ + cls(batch_size=4, n=8, m=16, p=32), + cls(batch_size=2, n=64, m=64, p=32), + ] + + def create_model(self) -> nn.Module: + return BmmModel(self.batch_size, self.n, self.m, self.p) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.n, self.m) + return (x,) + + +class AddmmModel(nn.Module): + """Model that performs addmm: bias + (mat1 @ mat2).""" + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + alpha: float = 1.0, + beta: float = 1.0, + ): + super().__init__() + self.weight = nn.Parameter(torch.randn(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.randn(out_features)) + else: + self.bias = None + self.alpha = alpha + self.beta = beta + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.bias is not None: + return torch.addmm( + self.bias, x, self.weight.t(), beta=self.beta, alpha=self.alpha + ) + else: + return torch.mm(x, self.weight.t()) + + +@register_test +class AddmmTest(OpTestCase): + """Test case for addmm.""" + + name = "addmm" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + batch_size: int = 2, + in_features: int = 64, + out_features: int = 32, + bias: bool = True, + alpha: float = 1.0, + beta: float = 1.0, + ): + self.batch_size = batch_size + self.in_features = in_features + self.out_features = out_features + self.bias = bias + self.alpha = alpha + self.beta = beta + + # Build unique test name + if not bias: + name = f"addmm_{in_features}x{out_features}_no_bias" + elif alpha != 1.0 or beta != 1.0: + name = f"addmm_{in_features}x{out_features}_a{alpha}_b{beta}" + else: + name = f"addmm_{in_features}x{out_features}" + self.name = name + + @classmethod + def get_test_configs(cls) -> List["AddmmTest"]: + return [ + cls( + batch_size=2, in_features=64, out_features=32 + ), # with bias, default alpha/beta + cls( + batch_size=2, in_features=64, out_features=32, bias=False + ), # without bias + cls(batch_size=4, in_features=128, out_features=64), # larger size + cls( + batch_size=2, in_features=64, out_features=32, alpha=2.0, beta=0.5 + ), # custom alpha/beta + ] + + def create_model(self) -> nn.Module: + return AddmmModel( + self.in_features, + self.out_features, + bias=self.bias, + alpha=self.alpha, + beta=self.beta, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.in_features) + return (x,) + + +class ExpandModel(nn.Module): + """Model that expands a tensor to a larger shape.""" + + def __init__(self, target_shape: Tuple[int, ...]): + super().__init__() + self.target_shape = target_shape + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.expand(self.target_shape) + + +@register_test +class ExpandTest(OpTestCase): + """Test case for expand (expand_copy) op.""" + + name = "expand" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + input_shape: Tuple[int, ...] = (1, 3, 1), + target_shape: Tuple[int, ...] = (2, 3, 4), + ): + self.input_shape = input_shape + self.target_shape = target_shape + + input_str = "x".join(str(s) for s in input_shape) + target_str = "x".join(str(s) for s in target_shape) + self.name = f"expand_{input_str}_to_{target_str}" + + @classmethod + def get_test_configs(cls) -> List["ExpandTest"]: + return [ + cls(input_shape=(2, 3, 1), target_shape=(2, 3, 4)), + cls(input_shape=(1, 3, 4), target_shape=(2, 3, 4)), + cls(input_shape=(1, 1, 4), target_shape=(2, 3, 4)), + cls(input_shape=(1, 1, 1), target_shape=(2, 3, 4)), + cls(input_shape=(1, 8), target_shape=(4, 8)), + cls(input_shape=(1, 1, 1, 64), target_shape=(2, 8, 16, 64)), + # Expand with -1 (keep dimension unchanged from input) + cls(input_shape=(93,), target_shape=(1, -1)), + # Multiple -1 dimensions (keep all but first) + cls(input_shape=(1, 1, 5, 8), target_shape=(1, -1, -1, -1)), + # Multiple -1 with actual expansion on first dim + cls(input_shape=(1, 3, 5, 8), target_shape=(2, -1, -1, -1)), + # Two -1 dimensions at start + cls(input_shape=(2, 3, 4), target_shape=(-1, -1, 4)), + ] + + def create_model(self) -> nn.Module: + return ExpandModel(self.target_shape) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.input_shape) + return (x,) + + +class IndexModel(nn.Module): + """Model that indexes a tensor using another tensor.""" + + def forward(self, x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + return x[indices] + + +@register_test +class IndexTest(OpTestCase): + """Test case for tensor indexing.""" + + name = "index" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + table_size: int = 100, + num_indices: int = 10, + ): + self.table_size = table_size + self.num_indices = num_indices + self.name = f"index_{table_size}_idx{num_indices}" + + @classmethod + def get_test_configs(cls) -> List["IndexTest"]: + return [ + cls(table_size=100, num_indices=10), + cls(table_size=50, num_indices=5), + ] + + def create_model(self) -> nn.Module: + return IndexModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.table_size) + indices = torch.randint(0, self.table_size, (self.num_indices,)) + return (x, indices) + + +class AdvancedIndexModel(nn.Module): + """Model that performs advanced (multi-index) tensor indexing. + + Implements x[i0, i1, ...] with multiple index tensors, which maps to + aten.index.Tensor with multiple non-None indices. + """ + + def __init__(self, num_indexed_dims: int): + super().__init__() + self.num_indexed_dims = num_indexed_dims + + def forward(self, x: torch.Tensor, *indices: torch.Tensor) -> torch.Tensor: + idx_list = list(indices) + return x[tuple(idx_list)] + + +@register_test +class AdvancedIndexTest(OpTestCase): + """Test case for multi-index tensor indexing (advanced/fancy indexing).""" + + name = "advanced_index" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + input_shape: Tuple[int, ...] = (4, 5, 6), + num_indexed_dims: int = 2, + num_indices: int = 3, + ): + self.input_shape = input_shape + self.num_indexed_dims = num_indexed_dims + self.num_indices = num_indices + self.name = ( + f"advanced_index_{'x'.join(str(s) for s in input_shape)}" + f"_dims{num_indexed_dims}_idx{num_indices}" + ) + + @classmethod + def get_test_configs(cls) -> List["AdvancedIndexTest"]: + return [ + # 2D input, index both dims + cls(input_shape=(8, 6), num_indexed_dims=2, num_indices=4), + # 3D input, index all 3 dims + cls(input_shape=(4, 5, 6), num_indexed_dims=3, num_indices=3), + # 4D input, index all 4 dims (the original failing case) + cls(input_shape=(2, 3, 4, 5), num_indexed_dims=4, num_indices=2), + # 3D input, index 2 of 3 dims + cls(input_shape=(4, 5, 6), num_indexed_dims=2, num_indices=5), + ] + + def create_model(self) -> nn.Module: + return AdvancedIndexModel(self.num_indexed_dims) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.input_shape) + indices = [] + for dim in range(self.num_indexed_dims): + idx = torch.randint(0, self.input_shape[dim], (self.num_indices,)) + indices.append(idx) + return (x, *indices) + + +class IndexUpdateModel(nn.Module): + """Model that performs index_copy on a mutable buffer. + + This triggers the INDEX_UPDATE pattern which matches aten.index_copy.default + on a mutable buffer and lowers it to IndexUpdateNode. + """ + + def __init__( + self, + buffer_size: int = 128, + feature_dim: int = 64, + axis: int = 0, + ): + super().__init__() + self.axis = axis + if axis == 0: + self.register_buffer("data", torch.zeros(buffer_size, feature_dim)) + else: + # axis == 1 + self.register_buffer("data", torch.zeros(feature_dim, buffer_size)) + + def forward(self, indices: torch.Tensor, update: torch.Tensor) -> torch.Tensor: + """Update buffer at indices along axis using index_copy.""" + self.data.index_copy_(self.axis, indices, update) + return self.data.clone() + + +@register_test +class IndexUpdateTest(OpTestCase): + """Test case for index_update pattern (index_copy on mutable buffer). + + This tests the INDEX_UPDATE pattern handler which recognizes + aten.index_copy.default on a mutable buffer and lowers it to IndexUpdateNode. + The buffer is managed internally by the MLX backend. + """ + + name = "index_update" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + buffer_size: int = 128, + feature_dim: int = 64, + num_indices: int = 8, + axis: int = 0, + ): + self.buffer_size = buffer_size + self.feature_dim = feature_dim + self.num_indices = num_indices + self.axis = axis + self.name = ( + f"index_update_axis{axis}_{buffer_size}x{feature_dim}_idx{num_indices}" + ) + + @classmethod + def get_test_configs(cls) -> List["IndexUpdateTest"]: + return [ + # Basic case: update along axis 0 + cls(buffer_size=128, feature_dim=64, num_indices=8, axis=0), + # Smaller buffer + cls(buffer_size=32, feature_dim=16, num_indices=4, axis=0), + # Update along axis 1 + cls(buffer_size=64, feature_dim=32, num_indices=8, axis=1), + ] + + def create_model(self) -> nn.Module: + return IndexUpdateModel( + buffer_size=self.buffer_size, + feature_dim=self.feature_dim, + axis=self.axis, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # Create unique indices (no duplicates) for index_copy + # PyTorch requires int64 (long) for indices + indices = torch.randperm(self.buffer_size)[: self.num_indices].to(torch.int64) + + # Create update tensor with shape matching the indexed dimension + if self.axis == 0: + update = torch.randn(self.num_indices, self.feature_dim) + else: + update = torch.randn(self.feature_dim, self.num_indices) + + return (indices, update) + + +class SplitWithSizesModel(nn.Module): + """Model that splits a tensor into chunks with specified sizes.""" + + def __init__(self, sizes, dim=0): + super().__init__() + self.sizes = sizes + self.dim = dim + + def forward(self, x): + chunks = torch.ops.aten.split_with_sizes_copy.default(x, self.sizes, self.dim) + return chunks[0] + + +class SplitWithSizesMultiOutputModel(nn.Module): + """Model that splits with specified sizes and uses multiple outputs.""" + + def __init__(self, sizes, dim=0): + super().__init__() + self.sizes = sizes + self.dim = dim + + def forward(self, x): + chunks = torch.ops.aten.split_with_sizes_copy.default(x, self.sizes, self.dim) + return chunks[0] + chunks[-1] + + +class SplitUniformModel(nn.Module): + """Model that splits a tensor into chunks of uniform size using torch.split.""" + + def __init__(self, split_size, dim=0): + super().__init__() + self.split_size = split_size + self.dim = dim + + def forward(self, x): + chunks = torch.split(x, self.split_size, dim=self.dim) + return chunks[0] + + +class SplitUniformMultiOutputModel(nn.Module): + """Model that splits uniformly and uses multiple outputs.""" + + def __init__(self, split_size, dim=0): + super().__init__() + self.split_size = split_size + self.dim = dim + + def forward(self, x): + chunks = torch.split(x, self.split_size, dim=self.dim) + return torch.cat([chunks[0], chunks[-1]], dim=self.dim) + + +@register_test +class SplitTest(OpTestCase): + name = "split" + rtol = 1e-5 + atol = 1e-5 + + def __init__(self, shape, model_cls, model_kwargs, tag=""): + self.shape = shape + self.model_cls = model_cls + self.model_kwargs = model_kwargs + self.name = f"split_{tag}" if tag else "split" + + @classmethod + def get_test_configs(cls) -> List["SplitTest"]: + return [ + # split_with_sizes_copy tests + cls( + shape=(9, 4), + model_cls=SplitWithSizesModel, + model_kwargs={"sizes": [2, 3, 4], "dim": 0}, + tag="sizes_dim0", + ), + cls( + shape=(3, 10), + model_cls=SplitWithSizesModel, + model_kwargs={"sizes": [2, 3, 5], "dim": 1}, + tag="sizes_dim1", + ), + cls( + shape=(2, 12, 4), + model_cls=SplitWithSizesModel, + model_kwargs={"sizes": [3, 4, 5], "dim": 1}, + tag="sizes_3d", + ), + cls( + shape=(8, 4), + model_cls=SplitWithSizesModel, + model_kwargs={"sizes": [3, 5], "dim": 0}, + tag="sizes_two", + ), + cls( + shape=(10, 3), + model_cls=SplitWithSizesMultiOutputModel, + model_kwargs={"sizes": [5, 5], "dim": 0}, + tag="sizes_multi", + ), + # torch.split (uniform) tests + cls( + shape=(10, 4), + model_cls=SplitUniformModel, + model_kwargs={"split_size": 3, "dim": 0}, + tag="uniform_dim0", + ), + cls( + shape=(3, 7), + model_cls=SplitUniformModel, + model_kwargs={"split_size": 4, "dim": 1}, + tag="uniform_dim1", + ), + cls( + shape=(11, 5), + model_cls=SplitUniformMultiOutputModel, + model_kwargs={"split_size": 3, "dim": 0}, + tag="uniform_multi", + ), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.shape),) + + def create_model(self) -> nn.Module: + return self.model_cls(**self.model_kwargs) + + +class ArangeModel(nn.Module): + """Model that creates a tensor using arange and multiplies with input.""" + + def __init__(self, stop: int, use_dtype: bool = True): + super().__init__() + self.stop = stop + self.use_dtype = use_dtype + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_dtype: + indices = torch.arange(self.stop, dtype=x.dtype, device=x.device) + else: + # No dtype - let MLX infer (defaults to int64 for integer inputs) + indices = torch.arange(self.stop, device=x.device) + indices = indices.to(x.dtype) # Cast for multiplication + return x * indices + + +@register_test +class ArangeTest(OpTestCase): + """Test case for torch.arange().""" + + name = "arange" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + stop: int = 10, + dtype: torch.dtype = torch.float32, + use_dtype: bool = True, + ): + self.stop = stop + self.dtype = dtype + self.use_dtype = use_dtype + dtype_name = str(dtype).split(".")[-1] + if use_dtype: + self.name = f"arange_{stop}_{dtype_name}" + else: + self.name = f"arange_{stop}_no_dtype" + + @classmethod + def get_test_configs(cls) -> List["ArangeTest"]: + return [ + # With explicit dtype + cls(stop=10, dtype=torch.float32, use_dtype=True), + cls(stop=32, dtype=torch.float32, use_dtype=True), + cls(stop=100, dtype=torch.float32, use_dtype=True), + cls(stop=16, dtype=torch.int32, use_dtype=True), + cls(stop=16, dtype=torch.int64, use_dtype=True), + # Without dtype (let MLX infer) + cls(stop=10, dtype=torch.float32, use_dtype=False), + cls(stop=32, dtype=torch.float32, use_dtype=False), + ] + + def create_model(self) -> nn.Module: + return ArangeModel(self.stop, use_dtype=self.use_dtype) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + if self.dtype in (torch.int32, torch.int64): + x = torch.randint(1, 10, (self.stop,), dtype=self.dtype) + else: + x = torch.randn(self.stop, dtype=self.dtype) + return (x,) + + +class UnaryOpModel(nn.Module): + """Generic model that applies a single unary torch op.""" + + def __init__(self, op_fn: Callable): + super().__init__() + self.op_fn = op_fn + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.op_fn(x) + + +def _input_fn( + uniform: bool = False, scale: float = 1.0, offset: float = 0.0, abs: bool = False +): + """Return a callable(shape, dtype) that generates a single-element input tuple. + + Args: + uniform: Use torch.rand (uniform [0,1]) instead of torch.randn (normal). + scale: Multiply the base tensor by this value. + offset: Add this value after scaling. + abs: Apply .abs() to the base tensor before scale/offset. + """ + + def fn(shape, dtype): + base = ( + torch.rand(shape, dtype=dtype) + if uniform + else torch.randn(shape, dtype=dtype) + ) + if abs: + base = base.abs() + return (base * scale + offset,) + + return fn + + +def _bool_input_fn(): + """Return a callable(shape, dtype) that generates a single-element bool tensor tuple.""" + + def fn(shape, _dtype): + return (torch.randint(0, 2, shape, dtype=torch.bool),) + + return fn + + +def _int_input_fn(low: int = -100, high: int = 100): + """Return a callable(shape, dtype) that generates a single-element integer tensor tuple.""" + + def fn(shape, dtype): + return (torch.randint(low, high, shape, dtype=dtype),) + + return fn + + +# Standard shape and dtype configs used by unary tests. +_SHAPES_3 = [(16,), (4, 4), (2, 3, 4)] +_SHAPES_2 = [(16,), (4, 4)] +_UNARY_DTYPES = [torch.float32, torch.float16, torch.bfloat16] + + +def _make_unary_op_test( + op_name: str, + op_fn: Callable, + shapes: List[Tuple[int, ...]] = None, + dtypes: List[torch.dtype] = None, + input_fn: Callable = None, +) -> type: + """Generate a registered OpTestCase subclass for a unary math op. + + Args: + op_name: Name used for test registration and output directories. + op_fn: The torch function to test (e.g. torch.floor). + shapes: List of input shapes. Defaults to _SHAPES_2. + dtypes: List of dtypes to test. Defaults to _UNARY_DTYPES. + input_fn: Callable(shape, dtype) -> Tuple[Tensor, ...] that creates inputs. + Defaults to _input_fn() (standard randn). + """ + if shapes is None: + shapes = _SHAPES_2 + if dtypes is None: + dtypes = _UNARY_DTYPES + if input_fn is None: + input_fn = _input_fn() + + class _Test(OpTestCase): + name = op_name + + def __init__( + self, + shape: Tuple[int, ...], + dtype: torch.dtype, + ): + self.shape = shape + self.dtype = dtype + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + self.name = f"{op_name}_{shape_str}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["_Test"]: + return [cls(shape=s, dtype=d) for s in shapes for d in dtypes] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return input_fn(self.shape, self.dtype) + + def create_model(self) -> nn.Module: + return UnaryOpModel(op_fn) + + _Test.__name__ = f"{op_name.title().replace('_', '')}Test" + _Test.__qualname__ = _Test.__name__ + return _Test + + +# fmt: off +# Each entry is a dict with required keys "op_name" and "op_fn". +# Optional keys: "shapes" (default _SHAPES_2), "dtypes" (default _UNARY_DTYPES), +# "input_fn" (default _input_fn()). +# _input_fn(uniform, scale, offset) — uniform=True uses rand, False uses randn. +_UNARY_OP_TESTS = [ + {"op_name": "floor", "op_fn": torch.floor, "shapes": _SHAPES_3, "input_fn": _input_fn(scale=10)}, + {"op_name": "ceil", "op_fn": torch.ceil, "shapes": _SHAPES_3, "input_fn": _input_fn(scale=10)}, + {"op_name": "square", "op_fn": torch.square, "shapes": _SHAPES_3}, + {"op_name": "exp", "op_fn": torch.exp, "shapes": _SHAPES_3}, + {"op_name": "sin", "op_fn": torch.sin, "shapes": _SHAPES_3, "input_fn": _input_fn(scale=3.14159)}, + {"op_name": "cos", "op_fn": torch.cos, "shapes": _SHAPES_3, "input_fn": _input_fn(scale=3.14159)}, + {"op_name": "tan", "op_fn": torch.tan, "input_fn": _input_fn(scale=0.5)}, + {"op_name": "asin", "op_fn": torch.asin, "input_fn": _input_fn(uniform=True, scale=2, offset=-1)}, + {"op_name": "acos", "op_fn": torch.acos, "input_fn": _input_fn(uniform=True, scale=2, offset=-1)}, + {"op_name": "atan", "op_fn": torch.atan}, + {"op_name": "sinh", "op_fn": torch.sinh}, + {"op_name": "cosh", "op_fn": torch.cosh}, + {"op_name": "asinh", "op_fn": torch.asinh}, + {"op_name": "acosh", "op_fn": torch.acosh, "input_fn": _input_fn(uniform=True, offset=1.0)}, + {"op_name": "atanh", "op_fn": torch.atanh, "input_fn": _input_fn(uniform=True, scale=1.8, offset=-0.9)}, + {"op_name": "log2", "op_fn": torch.log2, "input_fn": _input_fn(uniform=True, offset=0.1)}, + {"op_name": "log10", "op_fn": torch.log10, "input_fn": _input_fn(uniform=True, offset=0.1)}, + {"op_name": "log1p", "op_fn": torch.log1p, "input_fn": _input_fn(uniform=True)}, + {"op_name": "erf", "op_fn": torch.erf}, + {"op_name": "expm1", "op_fn": torch.expm1}, + {"op_name": "round", "op_fn": torch.round, "input_fn": _input_fn(scale=10)}, + {"op_name": "reciprocal", "op_fn": torch.reciprocal, "input_fn": _input_fn(offset=1.0)}, + {"op_name": "sqrt", "op_fn": torch.sqrt, "input_fn": _input_fn(uniform=True, offset=0.1)}, + {"op_name": "abs", "op_fn": torch.abs}, + {"op_name": "neg", "op_fn": torch.neg}, + {"op_name": "logical_not","op_fn": torch.logical_not, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn": _bool_input_fn()}, + # activations + {"op_name": "relu", "op_fn": torch.relu, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 128, 64)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2, offset=-1)}, + {"op_name": "sigmoid", "op_fn": torch.sigmoid, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 1, 128)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2)}, + {"op_name": "tanh", "op_fn": torch.tanh, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 1, 128)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=3)}, + {"op_name": "silu", "op_fn": nn.SiLU(), "shapes": [(2, 16, 64), (4, 32, 128)], "dtypes": [torch.float32]}, + # math + {"op_name": "rsqrt", "op_fn": torch.rsqrt, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 64)], "dtypes": [torch.float32], "input_fn": _input_fn(uniform=True, offset=0.1)}, + {"op_name": "clone", "op_fn": torch.clone, "shapes": [(2, 3, 4), (8, 8), (16,)], "dtypes": [torch.float32]}, +] +# fmt: on + +# Generate and register all unary math op test classes. +for _entry in _UNARY_OP_TESTS: + _cls = _make_unary_op_test(**_entry) + register_test(_cls) + globals()[_cls.__name__] = _cls + + +class BinaryOpModel(nn.Module): + def __init__(self, op_fn: Callable): + super().__init__() + self.op_fn = op_fn + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return self.op_fn(a, b) + + +class PowerScalarModel(nn.Module): + def __init__(self, exponent: float): + super().__init__() + self.exponent = exponent + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return torch.pow(a, self.exponent) + + +_BINARY_DTYPES = [torch.float32] + + +def _make_binary_op_test( + op_name: str, + op_fn: Callable, + shapes: List[Tuple[int, ...]] = None, + dtypes: List[torch.dtype] = None, + input_fn_a: Callable = None, + input_fn_b: Callable = None, +) -> type: + """Generate a registered OpTestCase subclass for a binary math op.""" + if shapes is None: + shapes = _SHAPES_3 + if dtypes is None: + dtypes = _BINARY_DTYPES + if input_fn_a is None: + input_fn_a = _input_fn() + if input_fn_b is None: + input_fn_b = _input_fn() + + class _Test(OpTestCase): + name = op_name + + def __init__( + self, + shape: Tuple[int, ...], + dtype: torch.dtype, + ): + self.shape = shape + self.dtype = dtype + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + self.name = f"{op_name}_{shape_str}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["_Test"]: + return [cls(shape=s, dtype=d) for s in shapes for d in dtypes] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return input_fn_a(self.shape, self.dtype) + input_fn_b( + self.shape, self.dtype + ) + + def create_model(self) -> nn.Module: + return BinaryOpModel(op_fn) + + _Test.__name__ = f"{op_name.title().replace('_', '')}Test" + _Test.__qualname__ = _Test.__name__ + return _Test + + +# fmt: off +_BINARY_OP_TESTS = [ + # math + {"op_name": "maximum", "op_fn": torch.maximum}, + {"op_name": "minimum", "op_fn": torch.minimum}, + {"op_name": "atan2", "op_fn": torch.atan2}, + {"op_name": "logaddexp", "op_fn": torch.logaddexp}, + {"op_name": "floor_divide", "op_fn": torch.floor_divide, "input_fn_a": _input_fn(scale=10), "input_fn_b": _input_fn(abs=True, offset=1)}, + {"op_name": "floor_divide_int", "op_fn": torch.floor_divide, "dtypes": [torch.int32], "input_fn_a": _int_input_fn(-100, 100), "input_fn_b": _int_input_fn(1, 10)}, + {"op_name": "remainder", "op_fn": torch.remainder, "input_fn_a": _input_fn(scale=10), "input_fn_b": _input_fn(abs=True, offset=1)}, + {"op_name": "remainder_int", "op_fn": torch.remainder, "dtypes": [torch.int32], "input_fn_a": _int_input_fn(-100, 100), "input_fn_b": _int_input_fn(1, 10)}, + {"op_name": "power", "op_fn": torch.pow, "input_fn_a": _input_fn(uniform=True, offset=0.5), "input_fn_b": _input_fn(uniform=True, scale=2)}, + # comparison + {"op_name": "less", "op_fn": torch.lt, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.float32, torch.bfloat16]}, + {"op_name": "less_equal", "op_fn": torch.le, "shapes": [(2, 3, 4), (10,)], "dtypes": [torch.float32]}, + {"op_name": "greater", "op_fn": torch.gt, "shapes": [(2, 3, 4), (10,)], "dtypes": [torch.float32]}, + {"op_name": "greater_equal", "op_fn": torch.ge, "shapes": [(2, 3, 4), (10,)], "dtypes": [torch.float32]}, + {"op_name": "equal", "op_fn": torch.eq, "shapes": [(2, 3, 4), (10,)], "dtypes": [torch.float32]}, + {"op_name": "not_equal", "op_fn": torch.ne, "shapes": [(2, 3, 4), (10,)], "dtypes": [torch.float32]}, + # logical + {"op_name": "logical_and", "op_fn": torch.logical_and, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn_a": _bool_input_fn(), "input_fn_b": _bool_input_fn()}, + {"op_name": "logical_or", "op_fn": torch.logical_or, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn_a": _bool_input_fn(), "input_fn_b": _bool_input_fn()}, +] +# fmt: on + + +for _entry in _BINARY_OP_TESTS: + _cls = _make_binary_op_test(**_entry) + register_test(_cls) + globals()[_cls.__name__] = _cls + + +@register_test +class PowerScalarTest(OpTestCase): + """Test case for aten.pow op (Tensor_Scalar variant).""" + + name = "power_scalar" + + def __init__( + self, + shape: Tuple[int, ...] = (4, 4), + exponent: float = 2.0, + dtype: torch.dtype = torch.float32, + ): + self.shape = shape + self.exponent = exponent + self.dtype = dtype + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + self.name = f"power_scalar_{shape_str}_exp{exponent}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["PowerScalarTest"]: + return [ + cls(shape=(16,), exponent=2.0, dtype=torch.float32), + cls(shape=(4, 4), exponent=0.5, dtype=torch.float32), + cls(shape=(4, 4), exponent=3.0, dtype=torch.float32), + cls(shape=(2, 3, 4), exponent=-1.0, dtype=torch.float32), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.rand(self.shape, dtype=self.dtype) + 0.5,) + + def create_model(self) -> nn.Module: + return PowerScalarModel(self.exponent) + + +class CompareScalarModel(nn.Module): + def __init__(self, op_fn: Callable, scalar: float): + super().__init__() + self.op_fn = op_fn + self.scalar = scalar + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return self.op_fn(a, self.scalar) + + +def _make_compare_scalar_test( + op_name: str, + op_fn: Callable, +) -> type: + """Generate a registered OpTestCase subclass for a comparison Scalar op.""" + + class _Test(OpTestCase): + name = op_name + + def __init__( + self, + shape: Tuple[int, ...], + scalar: float, + dtype: torch.dtype, + ): + self.shape = shape + self.scalar = scalar + self.dtype = dtype + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + self.name = f"{op_name}_{shape_str}_s{scalar}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["_Test"]: + return [ + cls(shape=(16,), scalar=0.0, dtype=torch.float32), + cls(shape=(4, 4), scalar=0.5, dtype=torch.float32), + cls(shape=(2, 3, 4), scalar=-1.0, dtype=torch.float32), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.shape, dtype=self.dtype),) + + def create_model(self) -> nn.Module: + return CompareScalarModel(op_fn, self.scalar) + + _Test.__name__ = f"{op_name.title().replace('_', '')}Test" + _Test.__qualname__ = _Test.__name__ + return _Test + + +_COMPARE_SCALAR_TESTS = [ + {"op_name": "less_scalar", "op_fn": torch.lt}, + {"op_name": "less_equal_scalar", "op_fn": torch.le}, + {"op_name": "greater_scalar", "op_fn": torch.gt}, + {"op_name": "greater_equal_scalar", "op_fn": torch.ge}, + {"op_name": "equal_scalar", "op_fn": torch.eq}, + {"op_name": "not_equal_scalar", "op_fn": torch.ne}, +] + +for _entry in _COMPARE_SCALAR_TESTS: + _cls = _make_compare_scalar_test(**_entry) + register_test(_cls) + globals()[_cls.__name__] = _cls + + +class ReductionOpModel(nn.Module): + def __init__(self, op_fn: Callable, dim=None, keepdim: bool = False): + super().__init__() + self.op_fn = op_fn + self.dim = dim + self.keepdim = keepdim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.dim is None: + return self.op_fn(x) + return self.op_fn(x, dim=self.dim, keepdim=self.keepdim) + + +class CorrectionReductionOpModel(nn.Module): + def __init__( + self, op_fn: Callable, dim=None, keepdim: bool = False, correction: int = 1 + ): + super().__init__() + self.op_fn = op_fn + self.dim = dim + self.keepdim = keepdim + self.correction = correction + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.dim is None: + return self.op_fn(x, correction=self.correction) + return self.op_fn( + x, dim=self.dim, keepdim=self.keepdim, correction=self.correction + ) + + +def _make_reduction_op_test( + op_name: str, + op_fn: Callable, + configs: List[dict], + input_fn: Callable = None, + has_correction: bool = False, +) -> type: + """Generate a registered OpTestCase subclass for a reduction op. + + Args: + op_name: Name used for test registration. + op_fn: The torch function (e.g. torch.sum). + configs: List of dicts with keys: shape, dim, keepdim, dtype, and + optionally correction (for var/std). + input_fn: Callable(shape, dtype) -> Tuple[Tensor, ...]. + has_correction: If True, use CorrectionReductionOpModel. + """ + if input_fn is None: + input_fn = _input_fn() + + class _Test(OpTestCase): + name = op_name + + def __init__(self, shape, dim, keepdim, dtype, correction=1): + self.shape = shape + self.dim = dim + self.keepdim = keepdim + self.dtype = dtype + self.correction = correction + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + dim_str = f"_dim{dim}" if dim is not None else "_all" + kd_str = "_kd" if keepdim else "" + corr_str = f"_corr{correction}" if has_correction else "" + self.name = f"{op_name}_{shape_str}{dim_str}{kd_str}{corr_str}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["_Test"]: + return [cls(**c) for c in configs] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return input_fn(self.shape, self.dtype) + + def create_model(self) -> nn.Module: + if has_correction: + return CorrectionReductionOpModel( + op_fn, + dim=self.dim, + keepdim=self.keepdim, + correction=self.correction, + ) + return ReductionOpModel(op_fn, dim=self.dim, keepdim=self.keepdim) + + _Test.__name__ = f"{op_name.title().replace('_', '')}Test" + _Test.__qualname__ = _Test.__name__ + return _Test + + +_REDUCTION_CONFIGS_6 = [ + {"shape": (16,), "dim": 0, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": -1, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": 0, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": -1, "keepdim": True, "dtype": torch.float32}, + {"shape": (2, 3, 4), "dim": 1, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": None, "keepdim": False, "dtype": torch.float32}, +] + +_REDUCTION_CONFIGS_5 = [ + {"shape": (16,), "dim": 0, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": -1, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": -1, "keepdim": True, "dtype": torch.float32}, + { + "shape": (4, 4), + "dim": -1, + "keepdim": False, + "dtype": torch.float32, + "correction": 0, + }, + {"shape": (2, 3, 4), "dim": 1, "keepdim": False, "dtype": torch.float32}, +] + +_PROD_CONFIGS = [ + {"shape": (8,), "dim": 0, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": -1, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": 0, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": -1, "keepdim": True, "dtype": torch.float32}, + {"shape": (2, 3, 4), "dim": 1, "keepdim": False, "dtype": torch.float32}, +] + +_LOGSUMEXP_CONFIGS = [ + {"shape": (16,), "dim": 0, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": -1, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": 0, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": -1, "keepdim": True, "dtype": torch.float32}, + {"shape": (2, 3, 4), "dim": 1, "keepdim": False, "dtype": torch.float32}, +] + +# fmt: off +_REDUCTION_OP_TESTS = [ + {"op_name": "sum", "op_fn": torch.sum, "configs": _REDUCTION_CONFIGS_6}, + {"op_name": "mean", "op_fn": torch.mean, "configs": _REDUCTION_CONFIGS_6}, + {"op_name": "amax", "op_fn": torch.amax, "configs": _REDUCTION_CONFIGS_6}, + {"op_name": "amin", "op_fn": torch.amin, "configs": _REDUCTION_CONFIGS_6}, + {"op_name": "argmax", "op_fn": torch.argmax, "configs": _REDUCTION_CONFIGS_6}, + {"op_name": "argmin", "op_fn": torch.argmin, "configs": _REDUCTION_CONFIGS_6}, + {"op_name": "prod", "op_fn": torch.prod, "configs": _PROD_CONFIGS, "input_fn": _input_fn(scale=0.5, offset=1.0)}, + {"op_name": "var", "op_fn": torch.var, "configs": _REDUCTION_CONFIGS_5, "has_correction": True}, + {"op_name": "std", "op_fn": torch.std, "configs": _REDUCTION_CONFIGS_5, "has_correction": True}, + {"op_name": "logsumexp", "op_fn": torch.logsumexp, "configs": _LOGSUMEXP_CONFIGS}, +] +# fmt: on + +for _entry in _REDUCTION_OP_TESTS: + _cls = _make_reduction_op_test(**_entry) + register_test(_cls) + globals()[_cls.__name__] = _cls + + +# --- Global max (aten.max.default) - no dim argument --- + + +class MaxGlobalModel(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.max(x) + + +@register_test +class MaxGlobalTest(OpTestCase): + name = "max_global" + + def __init__(self, shape=(3, 4), dtype=torch.float32): + self.shape = shape + self.dtype = dtype + + @classmethod + def get_test_configs(cls): + return [ + cls(shape=(16,)), + cls(shape=(3, 4)), + cls(shape=(2, 3, 4)), + cls(shape=(3, 4), dtype=torch.bfloat16), + ] + + def create_model(self): + return MaxGlobalModel() + + def create_inputs(self): + return (torch.randn(self.shape, dtype=self.dtype),) + + +class MinGlobalModel(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.min(x) + + +@register_test +class MinGlobalTest(OpTestCase): + name = "min_global" + + def __init__(self, shape=(3, 4), dtype=torch.float32): + self.shape = shape + self.dtype = dtype + + @classmethod + def get_test_configs(cls): + return [ + cls(shape=(16,)), + cls(shape=(3, 4)), + cls(shape=(2, 3, 4)), + cls(shape=(3, 4), dtype=torch.bfloat16), + ] + + def create_model(self): + return MinGlobalModel() + + def create_inputs(self): + return (torch.randn(self.shape, dtype=self.dtype),) + + +class TriangularModel(nn.Module): + def __init__(self, mode: str = "tril", diagonal: int = 0): + super().__init__() + self.op_fn = torch.tril if mode == "tril" else torch.triu + self.diagonal = diagonal + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.op_fn(x, diagonal=self.diagonal) + + +_TRIANGULAR_CONFIGS = [ + {"shape": (4, 4), "diagonal": 0, "dtype": torch.float32}, + {"shape": (8, 8), "diagonal": 0, "dtype": torch.float32}, + {"shape": (4, 6), "diagonal": 0, "dtype": torch.float32}, + {"shape": (6, 4), "diagonal": 0, "dtype": torch.float32}, + {"shape": (4, 4), "diagonal": 1, "dtype": torch.float32}, + {"shape": (4, 4), "diagonal": -1, "dtype": torch.float32}, + {"shape": (4, 4), "diagonal": 2, "dtype": torch.float32}, + {"shape": (4, 4), "diagonal": 0, "dtype": torch.bfloat16}, + {"shape": (2, 4, 4), "diagonal": 0, "dtype": torch.float32}, + {"shape": (2, 3, 4, 4), "diagonal": 0, "dtype": torch.float32}, +] + + +def _make_triangular_test(mode: str) -> type: + """Generate a registered OpTestCase subclass for tril or triu.""" + + class _Test(OpTestCase): + name = mode + + def __init__( + self, + shape: Tuple[int, ...] = (4, 4), + diagonal: int = 0, + dtype: torch.dtype = torch.float32, + ): + self.shape = shape + self.diagonal = diagonal + self.dtype = dtype + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + diag_str = f"d{diagonal}" if diagonal != 0 else "" + self.name = f"{mode}_{shape_str}_{dtype_str}{diag_str}" + + @classmethod + def get_test_configs(cls) -> List["_Test"]: + return [cls(**c) for c in _TRIANGULAR_CONFIGS] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.shape, dtype=self.dtype),) + + def create_model(self) -> nn.Module: + return TriangularModel(mode=mode, diagonal=self.diagonal) + + _Test.__name__ = f"{mode.title()}Test" + _Test.__qualname__ = _Test.__name__ + return _Test + + +TrilTest = _make_triangular_test("tril") +TriuTest = _make_triangular_test("triu") +register_test(TrilTest) +register_test(TriuTest) + + +class ZerosLikeModel(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.zeros_like(x) + + +class OnesLikeModel(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ones_like(x) + + +class FullLikeModel(nn.Module): + def __init__(self, fill_value: float, dtype: Optional[torch.dtype] = None): + super().__init__() + self.fill_value = fill_value + self.dtype = dtype + + def forward(self, x: torch.Tensor) -> torch.Tensor: + t = torch.full_like(x, self.fill_value, dtype=self.dtype) + if self.dtype is not None and self.dtype != x.dtype: + return x * t.to(x.dtype) + return t + + +@register_test +class ZerosLikeTest(OpTestCase): + """Test case for aten.zeros_like op.""" + + name = "zeros_like" + + def __init__( + self, + shape: Tuple[int, ...] = (4, 4), + dtype: torch.dtype = torch.float32, + ): + self.shape = shape + self.dtype = dtype + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + self.name = f"zeros_like_{shape_str}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["ZerosLikeTest"]: + return [ + cls(shape=(16,), dtype=torch.float32), + cls(shape=(4, 4), dtype=torch.float32), + cls(shape=(2, 3, 4), dtype=torch.float32), + cls(shape=(4, 4), dtype=torch.bfloat16), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.shape, dtype=self.dtype),) + + def create_model(self) -> nn.Module: + return ZerosLikeModel() + + +@register_test +class OnesLikeTest(OpTestCase): + """Test case for aten.ones_like op.""" + + name = "ones_like" + + def __init__( + self, + shape: Tuple[int, ...] = (4, 4), + dtype: torch.dtype = torch.float32, + ): + self.shape = shape + self.dtype = dtype + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + self.name = f"ones_like_{shape_str}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["OnesLikeTest"]: + return [ + cls(shape=(16,), dtype=torch.float32), + cls(shape=(4, 4), dtype=torch.float32), + cls(shape=(2, 3, 4), dtype=torch.float32), + cls(shape=(4, 4), dtype=torch.bfloat16), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.shape, dtype=self.dtype),) + + def create_model(self) -> nn.Module: + return OnesLikeModel() + + +@register_test +class FullLikeTest(OpTestCase): + """Test case for aten.full_like op.""" + + name = "full_like" + + def __init__( + self, + shape: Tuple[int, ...] = (4, 4), + fill_value: float = 3.14, + dtype: torch.dtype = torch.float32, + fill_dtype: Optional[torch.dtype] = None, + rtol: Optional[float] = None, + atol: Optional[float] = None, + ): + self.shape = shape + self.fill_value = fill_value + self.dtype = dtype + self.fill_dtype = fill_dtype + if rtol is not None: + self.rtol = rtol + if atol is not None: + self.atol = atol + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + fill_dtype_str = ( + f"_as_{str(fill_dtype).replace('torch.', '')}" if fill_dtype else "" + ) + self.name = f"full_like_{shape_str}_v{fill_value}_{dtype_str}{fill_dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["FullLikeTest"]: + return [ + cls(shape=(16,), fill_value=3.14, dtype=torch.float32), + cls(shape=(4, 4), fill_value=2.71, dtype=torch.float32), + cls(shape=(2, 3, 4), fill_value=-1.0, dtype=torch.float32), + cls(shape=(4, 4), fill_value=0.5, dtype=torch.bfloat16), + # Explicit fill_dtype exercises scalar_type serialization (optional_int). + # 1.005859375 rounds differently in bf16 vs f32, so the model multiplies + # the bf16 mask back into the f32 input to make the precision loss observable. + cls( + shape=(4, 4), + fill_value=1.005859375, + fill_dtype=torch.bfloat16, + rtol=0.0, + atol=0.0, + ), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + if self.fill_dtype is not None: + torch.manual_seed(42) + return (torch.randn(self.shape, dtype=self.dtype) * 100,) + return (torch.randn(self.shape, dtype=self.dtype),) + + def create_model(self) -> nn.Module: + return FullLikeModel(fill_value=self.fill_value, dtype=self.fill_dtype) + + +class FullModel(nn.Module): + def __init__(self, shape: Tuple[int, ...], fill_value: float, dtype: torch.dtype): + super().__init__() + self.shape = shape + self.fill_value = fill_value + self.dtype = dtype + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.full(self.shape, self.fill_value, dtype=self.dtype) + + +class ZerosModel(nn.Module): + def __init__(self, shape: Tuple[int, ...], dtype: torch.dtype): + super().__init__() + self.shape = shape + self.dtype = dtype + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.zeros(self.shape, dtype=self.dtype) + + +class OnesModel(nn.Module): + def __init__(self, shape: Tuple[int, ...], dtype: torch.dtype): + super().__init__() + self.shape = shape + self.dtype = dtype + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ones(self.shape, dtype=self.dtype) + + +@register_test +class FullTest(OpTestCase): + """Test case for aten.full op.""" + + name = "full" + rtol = 1e-3 + atol = 1e-3 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 3, 4), + fill_value: float = 1.5, + dtype: torch.dtype = torch.float32, + ): + self.shape = shape + self.fill_value = fill_value + self.dtype = dtype + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + self.name = f"full_{shape_str}_{fill_value}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["FullTest"]: + return [ + cls(shape=(2, 3, 4), fill_value=1.5, dtype=torch.float32), + cls(shape=(10,), fill_value=0.0, dtype=torch.float32), + cls(shape=(1, 128), fill_value=-2.5, dtype=torch.float32), + cls(shape=(4, 8, 16), fill_value=3.14159, dtype=torch.float32), + cls(shape=(2, 3, 4), fill_value=1.0, dtype=torch.bfloat16), + cls(shape=(8, 16), fill_value=-1.0, dtype=torch.bfloat16), + cls(shape=(2, 3, 4), fill_value=2.0, dtype=torch.float16), + # Integer fill values (matching individual test file) + cls(shape=(2, 3, 4), fill_value=0.0, dtype=torch.float32), + cls(shape=(2, 3, 4), fill_value=1.0, dtype=torch.float32), + cls(shape=(2, 3, 4), fill_value=-1.0, dtype=torch.float32), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(1, dtype=torch.float32) + return (x,) + + def create_model(self) -> nn.Module: + return FullModel(self.shape, self.fill_value, self.dtype) + + +@register_test +class ZerosTest(OpTestCase): + """Test case for aten.zeros op.""" + + name = "zeros" + rtol = 1e-3 + atol = 1e-3 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 3, 4), + dtype: torch.dtype = torch.float32, + ): + self.shape = shape + self.dtype = dtype + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + self.name = f"zeros_{shape_str}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["ZerosTest"]: + return [ + cls(shape=(2, 3, 4), dtype=torch.float32), + cls(shape=(10,), dtype=torch.float32), + cls(shape=(1, 128), dtype=torch.float32), + cls(shape=(4, 8, 16), dtype=torch.float32), + cls(shape=(2, 3, 4), dtype=torch.bfloat16), + cls(shape=(8, 16), dtype=torch.bfloat16), + cls(shape=(2, 3, 4), dtype=torch.float16), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(1, dtype=torch.float32) + return (x,) + + def create_model(self) -> nn.Module: + return ZerosModel(self.shape, self.dtype) + + +@register_test +class OnesTest(OpTestCase): + """Test case for aten.ones op.""" + + name = "ones" + rtol = 1e-3 + atol = 1e-3 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 3, 4), + dtype: torch.dtype = torch.float32, + ): + self.shape = shape + self.dtype = dtype + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + self.name = f"ones_{shape_str}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["OnesTest"]: + return [ + cls(shape=(2, 3, 4), dtype=torch.float32), + cls(shape=(10,), dtype=torch.float32), + cls(shape=(1, 128), dtype=torch.float32), + cls(shape=(4, 8, 16), dtype=torch.float32), + cls(shape=(2, 3, 4), dtype=torch.bfloat16), + cls(shape=(8, 16), dtype=torch.bfloat16), + cls(shape=(2, 3, 4), dtype=torch.float16), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(1, dtype=torch.float32) + return (x,) + + def create_model(self) -> nn.Module: + return OnesModel(self.shape, self.dtype) + + +class ToDtypeModel(nn.Module): + def __init__(self, target_dtype: torch.dtype): + super().__init__() + self.target_dtype = target_dtype + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.to(self.target_dtype) + + +@register_test +class ToDtypeTest(OpTestCase): + """Test case for to.dtype op.""" + + name = "to_dtype" + rtol = 1e-3 + atol = 1e-3 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 3, 4), + source_dtype: torch.dtype = torch.float32, + target_dtype: torch.dtype = torch.bfloat16, + ): + self.shape = shape + self.source_dtype = source_dtype + self.target_dtype = target_dtype + shape_str = "x".join(str(s) for s in shape) + src_str = str(source_dtype).replace("torch.", "") + tgt_str = str(target_dtype).replace("torch.", "") + self.name = f"to_dtype_{shape_str}_{src_str}_to_{tgt_str}" + + @classmethod + def get_test_configs(cls) -> List["ToDtypeTest"]: + return [ + cls( + shape=(2, 3, 4), source_dtype=torch.float32, target_dtype=torch.bfloat16 + ), + cls(shape=(10,), source_dtype=torch.float32, target_dtype=torch.bfloat16), + cls( + shape=(1, 128), source_dtype=torch.float32, target_dtype=torch.bfloat16 + ), + cls( + shape=(2, 3, 4), source_dtype=torch.bfloat16, target_dtype=torch.float32 + ), + cls( + shape=(4, 8, 16), + source_dtype=torch.bfloat16, + target_dtype=torch.float32, + ), + cls( + shape=(2, 3, 4), source_dtype=torch.float32, target_dtype=torch.float16 + ), + cls( + shape=(2, 3, 4), source_dtype=torch.float16, target_dtype=torch.float32 + ), + cls(shape=(2, 3, 4), source_dtype=torch.float32, target_dtype=torch.int32), + cls(shape=(2, 3, 4), source_dtype=torch.int32, target_dtype=torch.float32), + cls(shape=(2, 3, 4), source_dtype=torch.float32, target_dtype=torch.int64), + cls(shape=(2, 3, 4), source_dtype=torch.int64, target_dtype=torch.float32), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + if self.source_dtype in (torch.int32, torch.int64): + x = torch.randint(-100, 100, self.shape, dtype=self.source_dtype) + else: + x = torch.randn(self.shape, dtype=self.source_dtype) + return (x,) + + def create_model(self) -> nn.Module: + return ToDtypeModel(self.target_dtype) + + +class BatchNormModel(nn.Module): + def __init__(self, num_features: int, dtype: torch.dtype, affine: bool = True): + super().__init__() + self.bn = nn.BatchNorm2d(num_features, affine=affine, dtype=dtype) + self.bn.eval() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.bn(x) + + +class BatchNorm1dModel(nn.Module): + def __init__(self, num_features: int, dtype: torch.dtype, affine: bool = True): + super().__init__() + self.bn = nn.BatchNorm1d(num_features, affine=affine, dtype=dtype) + self.bn.eval() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.bn(x) + + +@register_test +class BatchNorm2dTest(OpTestCase): + """Test case for aten._native_batch_norm_legit_no_training op with 2D input.""" + + name = "batch_norm_2d" + rtol = 1e-3 + atol = 1e-3 + + def __init__( + self, + batch_size: int = 2, + num_features: int = 16, + height: int = 8, + width: int = 8, + dtype: torch.dtype = torch.float32, + affine: bool = True, + ): + self.batch_size = batch_size + self.num_features = num_features + self.height = height + self.width = width + self.dtype = dtype + self.affine = affine + dtype_str = str(dtype).replace("torch.", "") + prefix = "batch_norm_2d_no_affine" if not affine else "batch_norm_2d" + self.name = f"{prefix}_{batch_size}x{num_features}x{height}x{width}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["BatchNorm2dTest"]: + return [ + cls(batch_size=1, num_features=16, height=8, width=8, dtype=torch.float32), + cls( + batch_size=2, num_features=32, height=16, width=16, dtype=torch.float32 + ), + cls(batch_size=4, num_features=64, height=4, width=4, dtype=torch.float32), + cls(batch_size=2, num_features=16, height=8, width=8, dtype=torch.bfloat16), + cls(batch_size=1, num_features=32, height=4, width=4, dtype=torch.bfloat16), + cls(batch_size=2, num_features=16, height=8, width=8, dtype=torch.float16), + # No-affine variants (no weight/bias) + cls( + batch_size=1, + num_features=16, + height=8, + width=8, + dtype=torch.float32, + affine=False, + ), + cls( + batch_size=2, + num_features=32, + height=4, + width=4, + dtype=torch.bfloat16, + affine=False, + ), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn( + self.batch_size, + self.num_features, + self.height, + self.width, + dtype=self.dtype, + ) + return (x,) + + def create_model(self) -> nn.Module: + return BatchNormModel(self.num_features, self.dtype, affine=self.affine) + + +@register_test +class BatchNorm1dTest(OpTestCase): + """Test case for aten._native_batch_norm_legit_no_training op with 1D input.""" + + name = "batch_norm_1d" + rtol = 1e-3 + atol = 1e-3 + + def __init__( + self, + batch_size: int = 2, + num_features: int = 16, + seq_len: int = 32, + dtype: torch.dtype = torch.float32, + affine: bool = True, + ): + self.batch_size = batch_size + self.num_features = num_features + self.seq_len = seq_len + self.dtype = dtype + self.affine = affine + dtype_str = str(dtype).replace("torch.", "") + prefix = "batch_norm_1d_no_affine" if not affine else "batch_norm_1d" + self.name = f"{prefix}_{batch_size}x{num_features}x{seq_len}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["BatchNorm1dTest"]: + return [ + cls(batch_size=1, num_features=16, seq_len=32, dtype=torch.float32), + cls(batch_size=2, num_features=32, seq_len=64, dtype=torch.float32), + cls(batch_size=2, num_features=16, seq_len=32, dtype=torch.bfloat16), + cls(batch_size=2, num_features=16, seq_len=32, dtype=torch.float16), + # No-affine variants (no weight/bias) + cls( + batch_size=1, + num_features=16, + seq_len=32, + dtype=torch.float32, + affine=False, + ), + cls( + batch_size=2, + num_features=32, + seq_len=64, + dtype=torch.bfloat16, + affine=False, + ), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn( + self.batch_size, self.num_features, self.seq_len, dtype=self.dtype + ) + return (x,) + + def create_model(self) -> nn.Module: + return BatchNorm1dModel(self.num_features, self.dtype, affine=self.affine) + + +class SDPAModel(nn.Module): + """Basic scaled dot product attention.""" + + def __init__(self, is_causal: bool = False): + super().__init__() + self.is_causal = is_causal + + def forward( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + return torch.nn.functional.scaled_dot_product_attention( + q, k, v, is_causal=self.is_causal + ) + + +class SDPAWithMaskModel(nn.Module): + """SDPA with explicit attention mask (additive float format).""" + + def forward( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor + ) -> torch.Tensor: + return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask) + + +class SDPAWithBoolMaskModel(nn.Module): + """SDPA with boolean attention mask. + + This tests the case where a boolean mask is passed to SDPA. + PyTorch expects: True = attend, False = masked out. + """ + + def forward( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor + ) -> torch.Tensor: + return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask) + + +class GQAModel(nn.Module): + """Grouped Query Attention - fewer KV heads than Q heads.""" + + def __init__(self, num_heads: int, num_kv_heads: int, is_causal: bool = False): + super().__init__() + self.num_groups = num_heads // num_kv_heads + self.is_causal = is_causal + + def forward( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + k = k.repeat_interleave(self.num_groups, dim=1) + v = v.repeat_interleave(self.num_groups, dim=1) + return torch.nn.functional.scaled_dot_product_attention( + q, k, v, is_causal=self.is_causal + ) + + +@register_test +class SDPATest(OpTestCase): + """Test case for SDPA.""" + + name = "sdpa" + rtol = 1e-3 + atol = 1e-3 + expected_node_counts = {"SdpaNode": 1, "ExpandDimsNode": 0} + + def __init__( + self, + batch_size: int = 2, + num_heads: int = 8, + seq_len: int = 32, + head_dim: int = 64, + num_kv_heads: Optional[int] = None, + is_causal: bool = False, + use_mask: bool = False, + use_bool_mask: bool = False, + ): + self.batch_size = batch_size + self.num_heads = num_heads + self.seq_len = seq_len + self.head_dim = head_dim + self.num_kv_heads = num_kv_heads + self.is_causal = is_causal + self.use_mask = use_mask + self.use_bool_mask = use_bool_mask + + parts = ["sdpa"] + if num_kv_heads is not None: + parts.append(f"gqa{num_kv_heads}") + if is_causal: + parts.append("causal") + if use_mask: + parts.append("mask") + if use_bool_mask: + parts.append("bool_mask") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["SDPATest"]: + return [ + cls(), + cls(is_causal=True), + cls(num_kv_heads=4), + cls(use_mask=True), + cls(use_bool_mask=True), # Test boolean mask conversion + ] + + def create_model(self) -> nn.Module: + if self.use_mask: + return SDPAWithMaskModel() + elif self.use_bool_mask: + return SDPAWithBoolMaskModel() + elif self.num_kv_heads is not None: + return GQAModel(self.num_heads, self.num_kv_heads, self.is_causal) + else: + return SDPAModel(self.is_causal) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + q = torch.randn(self.batch_size, self.num_heads, self.seq_len, self.head_dim) + kv_heads = self.num_kv_heads if self.num_kv_heads else self.num_heads + k = torch.randn(self.batch_size, kv_heads, self.seq_len, self.head_dim) + v = torch.randn(self.batch_size, kv_heads, self.seq_len, self.head_dim) + + if self.use_mask: + # Additive float mask: 0 = attend, -inf = masked + mask = torch.zeros(self.batch_size, 1, self.seq_len, self.seq_len) + mask[:, :, :, : self.seq_len // 4] = float("-inf") + return (q, k, v, mask) + elif self.use_bool_mask: + # Boolean mask: True = attend, False = masked + # This tests that the backend correctly converts bool -> additive format + mask = torch.ones( + self.batch_size, 1, self.seq_len, self.seq_len, dtype=torch.bool + ) + mask[:, :, :, : self.seq_len // 4] = False # Mask out first quarter + return (q, k, v, mask) + return (q, k, v) + + +class CustomSDPAModel(nn.Module): + """ + Test model for mlx::custom_sdpa with KVCache. + + Simulates a single attention layer: updates the KV cache, then calls + mlx::custom_sdpa which slices K/V to [0:start_pos+seq_len] and runs SDPA. + """ + + def __init__( + self, + max_context_length: int, + n_heads: int, + n_kv_heads: int, + head_dim: int, + ): + super().__init__() + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.head_dim = head_dim + self.cache = KVCache( + max_batch_size=1, + max_context_length=max_context_length, + n_heads=n_kv_heads, + head_dim=head_dim, + enable_dynamic_shape=True, + ) + + def forward( + self, + input_pos: torch.Tensor, # [S] position indices + q: torch.Tensor, # [B, n_heads, S, D] + k_val: torch.Tensor, # [B, n_kv_heads, S, D] + v_val: torch.Tensor, # [B, n_kv_heads, S, D] + ) -> torch.Tensor: + # Update KV cache and get full cache tensors + k_cache, v_cache = self.cache.update(input_pos, k_val, v_val) + + start_pos = input_pos[0].item() + + output = torch.ops.mlx.custom_sdpa( + q, + k_cache, + v_cache, + start_pos=start_pos, + is_causal=True, + scale=self.head_dim**-0.5, + ) + return output + + +@register_test +class CustomSDPATest(OpTestCase): + """ + Test case for mlx::custom_sdpa with KV cache slicing. + + Verifies that custom_sdpa: + 1. Correctly slices K/V cache to [0:start_pos+seq_len] + 2. Produces numerically correct attention output + 3. Handles GQA (fewer KV heads than Q heads) + 4. Works with dynamic shapes (varying seq_len and start_pos) + """ + + name = "custom_sdpa" + rtol = 1e-3 + atol = 1e-3 + expected_node_counts = { + "SdpaNode": 1, + "SliceUpdateNode": 2, + "SliceNode": 2, + "IdCopyNode": 2, + "ExpandDimsNode": 0, + } + + def __init__( + self, + n_heads: int = 8, + n_kv_heads: int = 8, + head_dim: int = 64, + max_context_length: int = 128, + seq_len: int = 8, + ): + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.head_dim = head_dim + self.max_context_length = max_context_length + self.seq_len = seq_len + + parts = ["custom_sdpa"] + if n_kv_heads != n_heads: + parts.append(f"gqa{n_kv_heads}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["CustomSDPATest"]: + return [ + cls(), # MHA + cls(n_kv_heads=4), # GQA (8 Q heads, 4 KV heads) + cls(n_kv_heads=1), # MQA (8 Q heads, 1 KV head) + ] + + def create_model(self) -> nn.Module: + return CustomSDPAModel( + max_context_length=self.max_context_length, + n_heads=self.n_heads, + n_kv_heads=self.n_kv_heads, + head_dim=self.head_dim, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + input_pos = torch.tensor([0], dtype=torch.int64) + q = torch.randn(1, self.n_heads, self.seq_len, self.head_dim) + k = torch.randn(1, self.n_kv_heads, self.seq_len, self.head_dim) + v = torch.randn(1, self.n_kv_heads, self.seq_len, self.head_dim) + return (input_pos, q, k, v) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + test_seq_len = self.seq_len + 4 + input_pos = torch.tensor([16], dtype=torch.int64) + q = torch.randn(1, self.n_heads, test_seq_len, self.head_dim) + k = torch.randn(1, self.n_kv_heads, test_seq_len, self.head_dim) + v = torch.randn(1, self.n_kv_heads, test_seq_len, self.head_dim) + return (input_pos, q, k, v) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + seq_dim = Dim("seq_len", min=1, max=self.max_context_length) + return { + "input_pos": None, + "q": {2: seq_dim}, + "k_val": {2: seq_dim}, + "v_val": {2: seq_dim}, + } + + +class QuantizedLinearModel(nn.Module): + """Simple linear layer that will be quantized.""" + + def __init__( + self, in_features: int = 64, out_features: int = 128, bias: bool = True + ): + super().__init__() + self.linear = nn.Linear(in_features, out_features, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +@register_test +class QuantizedLinearTest(OpTestCase): + """Test case for TorchAO int4 quantized nn.Linear.""" + + name = "quantized_linear" + rtol = 0.1 + atol = 0.1 + + def __init__( + self, + in_features: int = 64, + out_features: int = 128, + batch_size: int = 2, + seq_len: int = 16, + bias: bool = True, + group_size: int = 32, + dtype: torch.dtype = torch.bfloat16, + ): + self.in_features = in_features + self.out_features = out_features + self.batch_size = batch_size + self.seq_len = seq_len + self.bias = bias + self.group_size = group_size + self.dtype = dtype + + parts = ["quantized_linear", f"g{group_size}"] + if not bias: + parts.append("no_bias") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["QuantizedLinearTest"]: + return [ + cls(), + ] + + def create_model(self) -> nn.Module: + model = QuantizedLinearModel( + self.in_features, self.out_features, bias=self.bias + ) + model = model.to(self.dtype) + + try: + from torchao.quantization.granularity import PerGroup + from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ + + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=torch.int4, granularity=PerGroup(self.group_size) + ), + ) + except ImportError: + raise RuntimeError("TorchAO not installed. Run: pip install torchao") + + return model + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn( + self.batch_size, self.seq_len, self.in_features, dtype=self.dtype + ) + return (x,) + + +class QuantizedEmbeddingModel(nn.Module): + """Simple embedding layer that will be quantized.""" + + def __init__( + self, + num_embeddings: int = 1000, + embedding_dim: int = 64, + ): super().__init__() - self.weight = nn.Parameter(torch.randn(batch_size, m, p)) + self.embedding = nn.Embedding(num_embeddings, embedding_dim) def forward(self, x: torch.Tensor) -> torch.Tensor: - return torch.bmm(x, self.weight) + return self.embedding(x) @register_test -class BmmTest(OpTestCase): - """Test case for bmm (batch matrix multiplication).""" +class QuantizedEmbeddingTest(OpTestCase): + """Test case for TorchAO int4 quantized nn.Embedding.""" - name = "bmm" - rtol = 1e-4 - atol = 1e-4 + name = "quantized_embedding" + rtol = 0.1 + atol = 0.1 def __init__( self, - batch_size: int = 4, - n: int = 8, - m: int = 16, - p: int = 32, + num_embeddings: int = 1000, + embedding_dim: int = 64, + batch_size: int = 2, + seq_len: int = 16, + group_size: int = 32, + dtype: torch.dtype = torch.bfloat16, + ): + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.batch_size = batch_size + self.seq_len = seq_len + self.group_size = group_size + self.dtype = dtype + + parts = ["quantized_embedding", f"g{group_size}"] + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["QuantizedEmbeddingTest"]: + return [ + cls(), + ] + + def create_model(self) -> nn.Module: + model = QuantizedEmbeddingModel(self.num_embeddings, self.embedding_dim) + model = model.to(self.dtype) + + try: + from torchao.quantization.granularity import PerGroup + from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ + + def embedding_filter(module: nn.Module, fqn: str) -> bool: + return isinstance(module, nn.Embedding) + + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=torch.int4, granularity=PerGroup(self.group_size) + ), + embedding_filter, + ) + except ImportError: + raise RuntimeError("TorchAO not installed. Run: pip install torchao") + + return model + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randint(0, self.num_embeddings, (self.batch_size, self.seq_len)) + return (x,) + + +class DequantizeConv2dModel(nn.Module): + """Conv2d layer whose weight will be quantized. + + The pattern matcher only fuses dequantize_affine with linear and embedding. + A quantized Conv2d produces a standalone dequantize_affine node in the graph, + exercising the DequantizeNode path. + """ + + def __init__( + self, + in_channels: int = 32, + out_channels: int = 64, + kernel_size: int = 3, + ): + super().__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size, padding=1, bias=False + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +@register_test +class DequantizeTest(OpTestCase): + """Test case for standalone TorchAO dequantize_affine (DequantizeNode). + + Uses a quantized Conv2d to produce a standalone dequantize_affine node, + since the pattern matcher only fuses dequantize with linear/embedding. + """ + + name = "dequantize" + rtol = 0.1 + atol = 0.1 + + def __init__( + self, + in_channels: int = 32, + out_channels: int = 64, + kernel_size: int = 3, + height: int = 8, + width: int = 8, + batch_size: int = 1, + group_size: int = 32, + dtype: torch.dtype = torch.bfloat16, ): + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.height = height + self.width = width self.batch_size = batch_size + self.group_size = group_size + self.dtype = dtype + + parts = ["dequantize", f"g{group_size}"] + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["DequantizeTest"]: + return [ + cls(), + ] + + def create_model(self) -> nn.Module: + model = DequantizeConv2dModel( + self.in_channels, self.out_channels, self.kernel_size + ) + model = model.to(self.dtype) + + from torchao.quantization.granularity import PerGroup + from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ + + def conv2d_filter(module: nn.Module, fqn: str) -> bool: + return isinstance(module, nn.Conv2d) + + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=torch.int4, granularity=PerGroup(self.group_size) + ), + conv2d_filter, + ) + + return model + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn( + self.batch_size, + self.in_channels, + self.height, + self.width, + dtype=self.dtype, + ) + return (x,) + + +class CumsumModel(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.cumsum(x, dim=self.dim) + + +@register_test +class CumsumTest(OpTestCase): + name = "cumsum" + rtol = 1e-5 + atol = 1e-5 + + def __init__(self, shape: Tuple[int, ...] = (3, 4), dim: int = 0): + self.shape = shape + self.dim = dim + shape_str = "x".join(str(s) for s in shape) + self.name = f"cumsum_dim{dim}_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["CumsumTest"]: + return [ + cls(shape=(8,), dim=0), + cls(shape=(3, 4), dim=0), + cls(shape=(3, 4), dim=1), + cls(shape=(2, 3, 4), dim=-1), + ] + + def create_model(self) -> nn.Module: + return CumsumModel(self.dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.shape),) + + +class StackModel(nn.Module): + def __init__(self, dim: int = 0, n: int = 3): + super().__init__() + self.dim = dim self.n = n - self.m = m - self.p = p - self.name = f"bmm_{batch_size}x{n}x{m}x{p}" + + def forward(self, *tensors: torch.Tensor) -> torch.Tensor: + return torch.stack(tensors[: self.n], dim=self.dim) + + +@register_test +class StackTest(OpTestCase): + name = "stack" + rtol = 1e-5 + atol = 1e-5 + + def __init__(self, shape: Tuple[int, ...] = (3, 4), dim: int = 0, n: int = 3): + self.shape = shape + self.dim = dim + self.n = n + shape_str = "x".join(str(s) for s in shape) + self.name = f"stack_dim{dim}_n{n}_{shape_str}" @classmethod - def get_test_configs(cls) -> List["BmmTest"]: + def get_test_configs(cls) -> List["StackTest"]: return [ - cls(batch_size=4, n=8, m=16, p=32), - cls(batch_size=2, n=64, m=64, p=32), + cls(shape=(3, 4), dim=0, n=3), + cls(shape=(3, 4), dim=1, n=2), + cls(shape=(2, 3), dim=-1, n=4), ] def create_model(self) -> nn.Module: - return BmmModel(self.batch_size, self.n, self.m, self.p) + return StackModel(dim=self.dim, n=self.n) def create_inputs(self) -> Tuple[torch.Tensor, ...]: - x = torch.randn(self.batch_size, self.n, self.m) + return tuple(torch.randn(self.shape) for _ in range(self.n)) + + +class SignModel(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) + + +@register_test +class SignTest(OpTestCase): + name = "sign" + rtol = 0.0 + atol = 0.0 + + def __init__(self, shape: Tuple[int, ...] = (3, 4)): + self.shape = shape + shape_str = "x".join(str(s) for s in shape) + self.name = f"sign_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["SignTest"]: + return [ + cls(shape=(8,)), + cls(shape=(3, 4)), + cls(shape=(2, 3, 4)), + ] + + def create_model(self) -> nn.Module: + return SignModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.shape),) + + +class AnyModel(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.any(x, dim=self.dim) + + +@register_test +class AnyTest(OpTestCase): + name = "any" + rtol = 0.0 + atol = 0.0 + + def __init__(self, shape: Tuple[int, ...] = (3, 4), dim: int = 0): + self.shape = shape + self.dim = dim + shape_str = "x".join(str(s) for s in shape) + self.name = f"any_dim{dim}_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["AnyTest"]: + return [ + cls(shape=(4, 6), dim=0), + cls(shape=(4, 6), dim=1), + cls(shape=(2, 3, 4), dim=-1), + ] + + def create_model(self) -> nn.Module: + return AnyModel(self.dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # Mix of True/False values + return (torch.randint(0, 2, self.shape).bool(),) + + +class AllModel(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.all(x, dim=self.dim) + + +@register_test +class AllTest(OpTestCase): + name = "all" + rtol = 0.0 + atol = 0.0 + + def __init__(self, shape: Tuple[int, ...] = (3, 4), dim: int = 0): + self.shape = shape + self.dim = dim + shape_str = "x".join(str(s) for s in shape) + self.name = f"all_dim{dim}_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["AllTest"]: + return [ + cls(shape=(4, 6), dim=0), + cls(shape=(4, 6), dim=1), + cls(shape=(2, 3, 4), dim=-1), + ] + + def create_model(self) -> nn.Module: + return AllModel(self.dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # Mostly True with some False + x = torch.ones(self.shape, dtype=torch.bool) + x[0] = False return (x,) + + +class RepeatInterleaveModel(nn.Module): + def __init__(self, repeats: int, dim: int): + super().__init__() + self.repeats = repeats + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.repeat_interleave(self.repeats, dim=self.dim) + + +@register_test +class RepeatInterleaveTest(OpTestCase): + name = "repeat_interleave" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 3, 4), + repeats: int = 2, + dim: int = 0, + ): + self.shape = shape + self.repeats = repeats + self.dim = dim + shape_str = "x".join(str(s) for s in shape) + self.name = f"repeat_interleave_r{repeats}_dim{dim}_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["RepeatInterleaveTest"]: + return [ + cls(shape=(2, 4), repeats=3, dim=0), + cls(shape=(2, 4), repeats=2, dim=1), + cls(shape=(1, 8, 4, 16), repeats=4, dim=1), # GQA-like pattern + ] + + def create_model(self) -> nn.Module: + return RepeatInterleaveModel(self.repeats, self.dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.shape),) + + +class SortModel(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Only return sorted values + return torch.sort(x, dim=self.dim)[0] + + +class SortIndicesModel(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Only return sort indices + return torch.sort(x, dim=self.dim)[1] + + +class SortBothModel(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + values, indices = torch.sort(x, dim=self.dim) + return values, indices + + +@register_test +class SortTest(OpTestCase): + name = "sort" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] = (3, 4), + dim: int = -1, + output: str = "values", + ): + self.shape = shape + self.dim = dim + self.output = output + shape_str = "x".join(str(s) for s in shape) + self.name = f"sort_{output}_dim{dim}_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["SortTest"]: + return [ + cls(shape=(8,), dim=0, output="values"), + cls(shape=(3, 4), dim=-1, output="values"), + cls(shape=(3, 4), dim=0, output="indices"), + cls(shape=(2, 3, 4), dim=1, output="both"), + ] + + def create_model(self) -> nn.Module: + if self.output == "values": + return SortModel(self.dim) + elif self.output == "indices": + return SortIndicesModel(self.dim) + else: + return SortBothModel(self.dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.shape),) + + def get_expected_node_counts(self) -> Optional[Dict[str, int]]: + if self.output == "values": + return {"SortNode": 1, "ArgsortNode": 0} + elif self.output == "indices": + return {"SortNode": 0, "ArgsortNode": 1} + else: + return {"SortNode": 1, "ArgsortNode": 1} + + +class ArgsortModel(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.argsort(x, dim=self.dim) + + +@register_test +class ArgsortTest(OpTestCase): + name = "argsort" + rtol = 0.0 + atol = 0.0 + + def __init__(self, shape: Tuple[int, ...] = (3, 4), dim: int = -1): + self.shape = shape + self.dim = dim + shape_str = "x".join(str(s) for s in shape) + self.name = f"argsort_dim{dim}_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["ArgsortTest"]: + return [ + cls(shape=(8,), dim=0), + cls(shape=(3, 4), dim=-1), + cls(shape=(3, 4), dim=0), + ] + + def create_model(self) -> nn.Module: + return ArgsortModel(self.dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.shape),) + + +class TopKValuesModel(nn.Module): + def __init__(self, k: int, dim: int): + super().__init__() + self.k = k + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.topk(x, self.k, dim=self.dim)[0] + + +class TopKIndicesModel(nn.Module): + def __init__(self, k: int, dim: int): + super().__init__() + self.k = k + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.topk(x, self.k, dim=self.dim)[1] + + +class TopKBothModel(nn.Module): + def __init__(self, k: int, dim: int): + super().__init__() + self.k = k + self.dim = dim + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + values, indices = torch.topk(x, self.k, dim=self.dim) + return values, indices + + +class TopKDynamicKModel(nn.Module): + """TopK with k derived from a dynamic tensor shape (exercises dynamic k path).""" + + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor, k_source: torch.Tensor) -> torch.Tensor: + k = k_source.shape[0] + return torch.topk(x, k, dim=self.dim)[0] + + +@register_test +class TopKTest(OpTestCase): + name = "topk" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] = (3, 8), + k: int = 3, + dim: int = -1, + output: str = "values", + ): + self.shape = shape + self.k = k + self.dim = dim + self.output = output + shape_str = "x".join(str(s) for s in shape) + self.name = f"topk_k{k}_dim{dim}_{output}_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["TopKTest"]: + return [ + # Values only + cls(shape=(16,), k=5, dim=0, output="values"), + cls(shape=(4, 8), k=3, dim=-1, output="values"), + cls(shape=(2, 4, 16), k=4, dim=-1, output="values"), + # Indices only + cls(shape=(4, 8), k=3, dim=-1, output="indices"), + # Both values and indices + cls(shape=(4, 8), k=3, dim=-1, output="both"), + # Dynamic k + cls(shape=(4, 8), k=3, dim=-1, output="dynamic_k"), + ] + + def create_model(self) -> nn.Module: + if self.output == "values": + return TopKValuesModel(self.k, self.dim) + elif self.output == "indices": + return TopKIndicesModel(self.k, self.dim) + elif self.output == "dynamic_k": + return TopKDynamicKModel(self.dim) + else: + return TopKBothModel(self.k, self.dim) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + if self.output == "dynamic_k": + k_dim = Dim("k", min=1, max=self.shape[self.dim]) + return {"x": None, "k_source": {0: k_dim}} + return None + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + if self.output == "dynamic_k": + return (torch.randn(self.shape), torch.randn(self.k)) + return (torch.randn(self.shape),) From 51462b366da43efa66c93e63f00e68e11f1b5b70 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 2 Mar 2026 18:24:57 -0800 Subject: [PATCH 12/34] up --- backends/mlx/runtime/MLXInterpreter.h | 8 -------- 1 file changed, 8 deletions(-) diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index f6aabd9af8f..1924f9faa35 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -764,11 +764,6 @@ exec_contiguous(const ContiguousNode& n, ExecutionState& st, StreamOrDevice s) { st.set_tensor(n.out, contiguous(st.const_tensor_ref(n.x), false, s)); } -inline void -exec_id_copy(const IdCopyNode& n, ExecutionState& st, StreamOrDevice) { - st.set_tensor(n.out, st.const_tensor_ref(n.x)); -} - inline void exec_gather(const GatherNode& n, ExecutionState& st, StreamOrDevice s) { const auto& x = st.const_tensor_ref(n.x); @@ -1769,9 +1764,6 @@ class Interpreter { case OpCode::CONTIGUOUS: ops::exec_contiguous(std::get(instr.node), st, s); break; - case OpCode::ID_COPY: - ops::exec_id_copy(std::get(instr.node), st, s); - break; case OpCode::GATHER: ops::exec_gather(std::get(instr.node), st, s); break; From 1add04d1e88ff363b7980ecec4803b540e23dac0 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 4 Mar 2026 21:30:07 -0800 Subject: [PATCH 13/34] up --- backends/mlx/builder/op_helpers.py | 77 +++++++ backends/mlx/ops.py | 5 +- backends/mlx/patterns.py | 302 ++++++++++++++++++++++---- backends/mlx/runtime/MLXInterpreter.h | 71 ++---- backends/mlx/serialization/schema.fbs | 16 +- backends/mlx/test/test_ops.py | 90 ++++++++ extension/llm/export/nvfp4.py | 190 ++++++++++++++++ 7 files changed, 648 insertions(+), 103 deletions(-) create mode 100644 extension/llm/export/nvfp4.py diff --git a/backends/mlx/builder/op_helpers.py b/backends/mlx/builder/op_helpers.py index 5e082cdf386..790fd63ebdc 100644 --- a/backends/mlx/builder/op_helpers.py +++ b/backends/mlx/builder/op_helpers.py @@ -18,6 +18,11 @@ if TYPE_CHECKING: from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +# When True, always serialize the biases tensor for quantized ops. +# When False, use init-time computation when zero_point is all zeros, +# computing biases = -scales * 2^(bits-1) during the init chain. +QUANTIZED_SERIALIZE_BIASES = False + def get_aten_target(target): """ @@ -168,6 +173,50 @@ def emit_lifted_constant(P: "MLXProgramBuilder", value, dtype: torch.dtype) -> S return slot +def emit_quantized_biases( + P: "MLXProgramBuilder", + zero_point_key: str, + scale: torch.Tensor, + zero_point: torch.Tensor, + bits: int, + B: torch.Tensor, + scale_slot: "Slot", +) -> "Slot": + """Emit biases for quantized ops, computing at init time when possible. + + When zero_point is all zeros and QUANTIZED_SERIALIZE_BIASES is False, + avoids serializing the biases tensor by computing biases = scales * -offset + during the init chain instead. + + Returns the biases Slot. + """ + from executorch.backends.mlx.serialization.mlx_graph_schema import MultiplyNode + from torch._subclasses.fake_tensor import FakeTensor + + is_scale_only = False + if not isinstance(zero_point, FakeTensor): + if torch.sum(torch.abs(zero_point)).item() == 0: + is_scale_only = True + + if QUANTIZED_SERIALIZE_BIASES or not is_scale_only: + return P.make_or_get_constant(f"{zero_point_key}_to_biases", B) + + scale_dtype = scale.dtype + offset = 1 << (bits - 1) + neg_offset = emit_lifted_constant(P, -offset, scale_dtype) + biases = P.make_or_get_constant( + f"{zero_point_key}_to_biases_dummy", torch.tensor(0.0, dtype=B.dtype) + ) + P.emit_init( + MultiplyNode( + a=P.slot_to_tid(scale_slot), + b=P.slot_to_tid(neg_offset), + out=P.slot_to_tid(biases), + ) + ) + return biases + + def to_mlx_qparams( qdata: torch.Tensor, scale: torch.Tensor, @@ -217,6 +266,34 @@ def to_mlx_qparams( return Q, None +def parse_dequant_nvfp4_node( + node: Node, +) -> Optional[Tuple[Node, Node, Node, torch.dtype]]: + """Parse a torchao.dequantize_nvfp4 node. + + Returns (qdata, scale, per_tensor_scale, output_dtype) or None if not a + dequantize_nvfp4 node or the custom op is not registered. + """ + target = get_aten_target(node.target) + try: + import executorch.extension.llm.export.nvfp4 # noqa: F401 + except ImportError: + return None + + if target is not torch.ops.torchao.dequantize_nvfp4.default: + return None + + qdata, scale, per_tensor_scale = node.args[0:3] + + output_dtype = torch.float32 + if len(node.args) > 4: + output_dtype = node.args[4] + elif "output_dtype" in node.kwargs: + output_dtype = node.kwargs["output_dtype"] + + return qdata, scale, per_tensor_scale, output_dtype + + def parse_dequant_node( node: Node, ) -> Optional[Tuple[Node, Node, Node, int, int, Optional[torch.dtype], int]]: diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index 0d5f21ebd7d..a181bf03422 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -21,6 +21,7 @@ import torch from executorch.backends.mlx.builder.op_helpers import ( emit_lifted_constant, + emit_quantized_biases, parse_dequant_node, to_mlx_qparams, torch_dtype_to_scalar_type, @@ -3646,8 +3647,10 @@ def _dequantize_affine_handler(P: MLXProgramBuilder, n: Node) -> Slot: B = B.reshape(*leading_dims, B.shape[-1]) w = P.make_or_get_constant(f"{qdata_target}_to_packed", Q) - biases = P.make_or_get_constant(f"{zero_point_target}_to_biases", B) scale_const = P.make_or_get_constant(f"{scale_target}_scale", scale_nd) + biases = emit_quantized_biases( + P, zero_point_target, scale, zero_point, bits, B, scale_const + ) if needs_permute: _, dequant_tmp = P.make_tmp_slot() diff --git a/backends/mlx/patterns.py b/backends/mlx/patterns.py index 908fa52b448..29e5e326c69 100644 --- a/backends/mlx/patterns.py +++ b/backends/mlx/patterns.py @@ -19,8 +19,10 @@ import torch from executorch.backends.mlx.builder.op_helpers import ( + emit_quantized_biases, emit_stop_position, parse_dequant_node, + parse_dequant_nvfp4_node, to_mlx_qparams, torch_dtype_to_scalar_type, ) @@ -35,12 +37,15 @@ ) from executorch.backends.mlx.serialization.mlx_graph_schema import ( AddIntNode, + AddNode, + AsTypeNode, DequantizeNode, IndexCopyNode, IntOrVid, IntOrVidOrTid, ModIntNode, - QuantizedLinearNode, + MultiplyNode, + QuantizedMatmulNode, SdpaNode, SliceNode, SliceUpdateNode, @@ -51,11 +56,6 @@ from torch.export.exported_program import ExportedProgram from torch.fx.node import Node -# When True, always serialize the biases tensor for quantized ops (existing behavior). -# When False, use scale_only=True optimization when zero_point is all zeros, -# which avoids serializing the biases tensor (C++ runtime computes: biases = -scales * 2^(bits-1)). -QUANTIZED_SERIALIZE_BIASES = True - @REGISTRY.register_pattern(name="INDEX_COPY") class IndexCopyHandler(PatternHandler): @@ -526,6 +526,7 @@ def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: q, k, v, attn_mask = P.slot_map([q, k, v, attn_mask]) out = P.make_or_get_slot(n) + P.emit( SdpaNode( q=P.slot_to_tid(q), @@ -540,6 +541,122 @@ def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: return out +@REGISTRY.register_pattern(name="NVFP4_QUANTIZED_EMBEDDING") +class NVFP4QuantizedEmbeddingHandler(PatternHandler): + """Fuse dequantize_nvfp4 + embedding into gather + DequantizeNode(mode="nvfp4"). + + Matches: + embedding(dequantize_nvfp4(qdata, scale, per_tensor_scale, ...), indices) + + Emits: + TakeNode(qdata) → TakeNode(scales) → DequantizeNode(mode="nvfp4") + [→ MultiplyNode(per_tensor_scale)] [→ AsTypeNode] + """ + + def __init__(self, head, body, qdata, scale, per_tensor_scale, output_dtype): + super().__init__(head, body) + self.qdata = qdata + self.scale = scale + self.per_tensor_scale = per_tensor_scale + self.output_dtype = output_dtype + + @classmethod + def maybe_create(cls, ep, head): + if not match_target(head, torch.ops.aten.embedding.default): + return None + + w, x = head.args[0:2] + if not isinstance(w, Node): + return None + if not has_single_user(w): + return None + parsed = parse_dequant_nvfp4_node(w) + if parsed is None: + return None + qdata, scale, per_tensor_scale, output_dtype = parsed + return cls(head, [w], qdata, scale, per_tensor_scale, output_dtype) + + def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: + assert n == self.head + w_node, x_node = n.args[0:2] + + has_per_tensor_scale = True + _, per_tensor_scale_value = P.get_placeholder_target_and_tensor( + self.per_tensor_scale + ) + from torch._subclasses.fake_tensor import FakeTensor + + if not isinstance(per_tensor_scale_value, FakeTensor): + if per_tensor_scale_value.item() == 1.0: + has_per_tensor_scale = False + + x_dtype = x_node.meta["val"].dtype + needs_cast = self.output_dtype != x_dtype + + x, scales_slot, per_tensor_scale, qdata_slot = P.slot_map( + [x_node, self.scale, self.per_tensor_scale, self.qdata] + ) + + ids_index = IntOrVidOrTid.from_tid(P.slot_to_tid(x)) + + # Gather quantized weights by indices + _, wq_sel = P.make_tmp_slot() + P.emit( + TakeNode( + x=P.slot_to_tid(qdata_slot), + index=ids_index, + out=P.slot_to_tid(wq_sel), + axis=0, + ) + ) + + # Gather scales by indices + _, sc_sel = P.make_tmp_slot() + P.emit( + TakeNode( + x=P.slot_to_tid(scales_slot), + index=ids_index, + out=P.slot_to_tid(sc_sel), + axis=0, + ) + ) + + # Dequantize the gathered slices + out = P.make_or_get_slot(n) + P.emit( + DequantizeNode( + w=P.slot_to_tid(wq_sel), + scales=P.slot_to_tid(sc_sel), + out=P.slot_to_tid(out), + biases=None, + group_size=16, + bits=4, + mode="nvfp4", + dtype=torch_dtype_to_scalar_type(self.output_dtype), + ) + ) + + if has_per_tensor_scale: + P.emit( + MultiplyNode( + a=P.slot_to_tid(out), + b=P.slot_to_tid(per_tensor_scale), + out=P.slot_to_tid(out), + ) + ) + + if needs_cast: + P.emit( + AsTypeNode( + x=P.slot_to_tid(out), + out=P.slot_to_tid(out), + scalar_type=torch_dtype_to_scalar_type(self.output_dtype), + ) + ) + + return out + + @REGISTRY.register_pattern(name="MLX_CUSTOM_SDPA") class MLXCustomSdpaHandler(PatternHandler): """ @@ -769,8 +886,8 @@ def maybe_create( def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: assert n == self.head - x, w = n.args[0:2] - b = n.args[2] if len(n.args) > 2 else None + x_node, w_node = n.args[0:2] + b_node = n.args[2] if len(n.args) > 2 else None qdata_target, qdata = P.get_placeholder_target_and_tensor(self.qdata) zero_point_target, zero_point = P.get_placeholder_target_and_tensor( @@ -778,49 +895,51 @@ def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: ) _, scale = P.get_placeholder_target_and_tensor(self.scale) - out_scalar_type = torch_dtype_to_scalar_type(self.out_dtype) - - # Check if we can use scale_only optimization: - # When zero_point is all zeros, biases = -scales * 2^(bits-1) - # which can be computed at runtime instead of serialized. - # Note: During partitioning, tensors are FakeTensors so we skip the check. - # The optimization is only applied during preprocess when we have real tensors. - use_scale_only = False - if not QUANTIZED_SERIALIZE_BIASES: - from torch._subclasses.fake_tensor import FakeTensor - - if not isinstance(zero_point, FakeTensor): - if torch.sum(torch.abs(zero_point)).item() == 0: - use_scale_only = True + x_slot, scale_slot, b_slot = P.slot_map([x_node, self.scale, b_node]) - Q, B = to_mlx_qparams( - qdata, scale, zero_point, self.bits, compute_biases=not use_scale_only - ) + Q, B = to_mlx_qparams(qdata, scale, zero_point, self.bits) w = P.make_or_get_constant(f"{qdata_target}_to_packed", Q) + biases = emit_quantized_biases( + P, zero_point_target, scale, zero_point, self.bits, B, scale_slot + ) - if use_scale_only: - biases_tid = None - else: - biases = P.make_or_get_constant(f"{zero_point_target}_to_biases", B) - biases_tid = P.slot_to_tid(biases) - - x, scale_slot, b = P.slot_map([x, self.scale, b]) out = P.make_or_get_slot(n) + has_bias = b_node is not None + x_dtype = x_node.meta["val"].dtype + needs_cast = self.out_dtype != x_dtype + P.emit( - QuantizedLinearNode( - x=P.slot_to_tid(x), + QuantizedMatmulNode( + x=P.slot_to_tid(x_slot), w=P.slot_to_tid(w), scales=P.slot_to_tid(scale_slot), out=P.slot_to_tid(out), - biases=biases_tid, - bias=P.slot_to_tid(b) if b else None, + biases=P.slot_to_tid(biases), group_size=self.group_size, bits=self.bits, mode="affine", - out_scalar_type=out_scalar_type, - scale_only=use_scale_only, + transpose=True, ) ) + + if has_bias: + P.emit( + AddNode( + a=P.slot_to_tid(out), + b=P.slot_to_tid(b_slot), + out=P.slot_to_tid(out), + ) + ) + + if needs_cast: + P.emit( + AsTypeNode( + x=P.slot_to_tid(out), + out=P.slot_to_tid(out), + scalar_type=torch_dtype_to_scalar_type(self.out_dtype), + ) + ) + return out @@ -898,9 +1017,11 @@ def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: out_scalar_type = torch_dtype_to_scalar_type(self.out_dtype) w = P.make_or_get_constant(f"{qdata_target}_to_packed", Q) - biases = P.make_or_get_constant(f"{zero_point_target}_to_biases", B) x, scale_slot = P.slot_map([x, self.scale]) + biases = emit_quantized_biases( + P, zero_point_target, scale, zero_point, self.bits, B, scale_slot + ) ids_index = IntOrVidOrTid.from_tid(P.slot_to_tid(x)) # Gather quantized weights by ids @@ -947,7 +1068,108 @@ def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: group_size=self.group_size, bits=self.bits, mode="affine", - out_scalar_type=out_scalar_type, + dtype=out_scalar_type, + ) + ) + return out + + +@REGISTRY.register_pattern(name="NVFP4_QUANTIZED_LINEAR") +class NVFP4QuantizedLinearHandler(PatternHandler): + """Fuse dequantize_nvfp4 + linear into QuantizedMatmulNode(mode="nvfp4"). + + Matches: + linear(x, dequantize_nvfp4(qdata, scale, block_size, [per_tensor_scale]), bias) + + Emits: + QuantizedMatmulNode [→ MultiplyNode(per_tensor_scale)] [→ AddNode(bias)] + """ + + def __init__(self, head, body, qdata, scale, per_tensor_scale, output_dtype): + super().__init__(head, body) + self.qdata = qdata + self.scale = scale + self.per_tensor_scale = per_tensor_scale + self.output_dtype = output_dtype + + @classmethod + def maybe_create(cls, ep, head): + if not match_target(head, torch.ops.aten.linear.default): + return None + x, dequant = head.args[0:2] + if not isinstance(dequant, Node): + return None + if not has_single_user(dequant): + return None + parsed = parse_dequant_nvfp4_node(dequant) + if parsed is None: + return None + qdata, scale, per_tensor_scale, output_dtype = parsed + return cls(head, [dequant], qdata, scale, per_tensor_scale, output_dtype) + + def __call__(self, P, n): + assert n == self.head + + x_node, w_node = n.args[0:2] + b_node = n.args[2] if len(n.args) > 2 else None + + needs_cast = x_node.meta["val"].dtype != self.output_dtype + has_bias = b_node is not None + has_per_tensor_scale = True + + _, per_tensor_scale_value = P.get_placeholder_target_and_tensor( + self.per_tensor_scale + ) + from torch._subclasses.fake_tensor import FakeTensor + + if not isinstance(per_tensor_scale_value, FakeTensor): + if per_tensor_scale_value.item() == 1.0: + has_per_tensor_scale = False + + x, w, scales, bias, per_tensor_scale = P.slot_map( + [x_node, self.qdata, self.scale, b_node, self.per_tensor_scale] + ) + + out = P.make_or_get_slot(n) + P.emit( + QuantizedMatmulNode( + x=P.slot_to_tid(x), + w=P.slot_to_tid(w), + scales=P.slot_to_tid(scales), + out=P.slot_to_tid(out), + biases=None, + group_size=16, + bits=4, + mode="nvfp4", + transpose=True, ) ) + + if has_per_tensor_scale: + P.emit( + MultiplyNode( + a=P.slot_to_tid(out), + b=P.slot_to_tid(per_tensor_scale), + out=P.slot_to_tid(out), + ) + ) + + if has_bias: + P.emit( + AddNode( + a=P.slot_to_tid(out), + b=P.slot_to_tid(bias), + out=P.slot_to_tid(out), + ) + ) + + if needs_cast: + P.emit( + AsTypeNode( + x=P.slot_to_tid(out), + out=P.slot_to_tid(out), + scalar_type=torch_dtype_to_scalar_type(self.output_dtype), + ) + ) + return out diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index 1924f9faa35..5acce31b47d 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -839,54 +839,21 @@ exec_astype(const AsTypeNode& n, ExecutionState& st, StreamOrDevice s) { n.out, astype(st.const_tensor_ref(n.x), resolve_dtype(n.scalar_type), s)); } -inline void exec_quantized_linear( - const QuantizedLinearNode& n, +inline void exec_quantized_matmul( + const QuantizedMatmulNode& n, ExecutionState& st, StreamOrDevice s) { - // scale_only means biases should be computed, not provided - assert( - !(n.scale_only && n.biases) && - "scale_only=true but biases tensor also provided"); - array X = st.const_tensor_ref(n.x); array Wq = st.const_tensor_ref(n.w); array Sc = st.const_tensor_ref(n.scales); - if (n.bits <= 0 || n.bits > 8) { - throw std::runtime_error( - "exec_quantized_linear: bits must be in [1, 8], got " + - std::to_string(n.bits)); - } - std::optional Qb = std::nullopt; - if (n.biases) { + if (n.biases.has_value()) { Qb = st.const_tensor_ref(*n.biases); - } else if (n.scale_only) { - // Compute biases from scales: B = -scales * 2^(bits-1) - float offset = static_cast(1 << (n.bits - 1)); - Qb = multiply(Sc, array(-offset, Sc.dtype()), s); } array Y = quantized_matmul( - X, - Wq, - Sc, - Qb, - /*transpose=*/true, - n.group_size, - n.bits, - n.mode, - s); - - if (n.bias) { - const auto& b = st.const_tensor_ref(*n.bias); - Y = add(Y, b, s); - } - - Dtype out_dtype = resolve_dtype(n.out_scalar_type); - if (out_dtype != Y.dtype()) { - Y = astype(Y, out_dtype, s); - } + X, Wq, Sc, Qb, n.transpose, n.group_size, n.bits, n.mode, s); st.set_tensor(n.out, std::move(Y)); } @@ -1142,21 +1109,19 @@ exec_dequantize(const DequantizeNode& n, ExecutionState& st, StreamOrDevice s) { Qb = st.const_tensor_ref(*n.biases); } - array Y = dequantize( - Wq, - Sc, - Qb, - n.group_size, - n.bits, - n.mode, - std::nullopt, // dtype - let MLX infer - s); + std::optional global_scale = std::nullopt; + if (n.global_scale) { + global_scale = st.const_tensor_ref(*n.global_scale); + } - Dtype out_dtype = resolve_dtype(n.out_scalar_type); - if (out_dtype != Y.dtype()) { - Y = astype(Y, out_dtype, s); + std::optional dtype = std::nullopt; + if (n.dtype) { + dtype = resolve_dtype(*n.dtype); } + array Y = dequantize( + Wq, Sc, Qb, n.group_size, n.bits, n.mode, global_scale, dtype, s); + st.set_tensor(n.out, std::move(Y)); } @@ -1773,10 +1738,6 @@ class Interpreter { case OpCode::ASTYPE: ops::exec_astype(std::get(instr.node), st, s); break; - case OpCode::QUANTIZED_LINEAR: - ops::exec_quantized_linear( - std::get(instr.node), st, s); - break; case OpCode::CONCATENATE: ops::exec_concatenate(std::get(instr.node), st, s); break; @@ -2002,6 +1963,10 @@ class Interpreter { case OpCode::ARG_PARTITION: ops::exec_argpartition(std::get(instr.node), st, s); break; + case OpCode::QUANTIZED_MATMUL: + ops::exec_quantized_matmul( + std::get(instr.node), st, s); + break; default: throw std::runtime_error( "Unknown opcode: " + std::to_string(static_cast(instr.op))); diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs index f27e37056e4..d8d72d4fbd9 100644 --- a/backends/mlx/serialization/schema.fbs +++ b/backends/mlx/serialization/schema.fbs @@ -444,18 +444,16 @@ table AsTypeNode { scalar_type: int8; // ET ScalarType } -table QuantizedLinearNode { +table QuantizedMatmulNode { x: Tid (required); w: Tid (required); scales: Tid (required); out: Tid (required); - biases: Tid; // optional - quantization biases (required if scale_only=false) - bias: Tid; // optional - neural network bias + biases: Tid; // optional - required for affine mode, null for nvfp4 group_size: int32; bits: int32; mode: string (required); - out_scalar_type: int8; // ET ScalarType for output - scale_only: bool = false; // if true, compute biases = -scales * 2^(bits-1); if false, biases tensor required + transpose: bool = true; } table ConcatenateNode { @@ -514,7 +512,8 @@ table DequantizeNode { group_size: int32; bits: int32; mode: string (required); // Quantization mode (e.g. "affine") - out_scalar_type: int8; // ET ScalarType for output dtype + global_scale: Tid; // optional - global scale for nvfp4 + dtype: int8 = null; // ET ScalarType for output dtype } // Comparison ops (return bool arrays) @@ -961,7 +960,6 @@ union OpNode { GatherNode, SliceNode, AsTypeNode, - QuantizedLinearNode, ConcatenateNode, FullNode, FullLikeNode, @@ -1035,9 +1033,9 @@ union OpNode { SortNode, ArgsortNode, PartitionNode, - ArgPartitionNode + ArgPartitionNode, + QuantizedMatmulNode // BC: Add new op nodes here (append only) ->>>>>>> 7e54fa1e87 (up) } // ============================================================================= diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index 164e94e3e78..bc7e01b3c18 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -5433,6 +5433,7 @@ def __init__( def get_test_configs(cls) -> List["QuantizedLinearTest"]: return [ cls(), + cls(bias=False), ] def create_model(self) -> nn.Module: @@ -6072,3 +6073,92 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]: if self.output == "dynamic_k": return (torch.randn(self.shape), torch.randn(self.k)) return (torch.randn(self.shape),) + + +class NVFP4QuantizedLinearModel(nn.Module): + """Simple linear layer that will be quantized with NVFP4.""" + + def __init__( + self, in_features: int = 64, out_features: int = 128, bias: bool = True + ): + super().__init__() + self.linear = nn.Linear(in_features, out_features, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +@register_test +class NVFP4QuantizedLinearTest(OpTestCase): + """Test case for NVFP4 quantized nn.Linear.""" + + name = "nvfp4_quantized_linear" + rtol = 0.1 + atol = 0.1 + + def __init__( + self, + in_features: int = 64, + out_features: int = 128, + batch_size: int = 2, + seq_len: int = 16, + bias: bool = True, + use_per_tensor_scale: bool = True, + dtype: torch.dtype = torch.float32, + ): + self.in_features = in_features + self.out_features = out_features + self.batch_size = batch_size + self.seq_len = seq_len + self.bias = bias + self.use_per_tensor_scale = use_per_tensor_scale + self.dtype = dtype + + parts = ["nvfp4_quantized_linear"] + if not bias: + parts.append("no_bias") + if not use_per_tensor_scale: + parts.append("no_pts") + if dtype != torch.float32: + parts.append(str(dtype).split(".")[-1]) + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["NVFP4QuantizedLinearTest"]: + return [ + cls(), + cls(bias=False), + cls(use_per_tensor_scale=False), + cls(bias=False, use_per_tensor_scale=False), + cls(dtype=torch.bfloat16), + cls(bias=False, dtype=torch.bfloat16), + cls(use_per_tensor_scale=False, dtype=torch.bfloat16), + cls(bias=False, use_per_tensor_scale=False, dtype=torch.bfloat16), + ] + + def get_edge_compile_config(self): + from executorch.exir import EdgeCompileConfig + + return EdgeCompileConfig(_check_ir_validity=False) + + def create_model(self) -> nn.Module: + model = NVFP4QuantizedLinearModel( + self.in_features, self.out_features, bias=self.bias + ) + model = model.to(self.dtype) + + from executorch.extension.llm.export.nvfp4 import ExportableNVFP4Config + from torchao.quantization import quantize_ + + quantize_( + model, + ExportableNVFP4Config(use_per_tensor_scale=self.use_per_tensor_scale), + ) + + return model + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn( + self.batch_size, self.seq_len, self.in_features, dtype=self.dtype + ) + return (x,) diff --git a/extension/llm/export/nvfp4.py b/extension/llm/export/nvfp4.py new file mode 100644 index 00000000000..40ffd4c1bb2 --- /dev/null +++ b/extension/llm/export/nvfp4.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +NVFP4 export-compatible quantization. + +Upstream NVFP4Tensor's dequantize() uses raw Python ops that don't survive +run_decompositions. This module registers a torch.library custom op +(torchao::dequantize_nvfp4) so the dequant node persists through export, +similar to how dequantize_affine works for int4. + +Usage: + from executorch.extension.llm.export.nvfp4 import ExportableNVFP4Config + from torchao.quantization import quantize_ + + quantize_(model, ExportableNVFP4Config()) +""" + +import types +from dataclasses import dataclass + +import torch +import torch.nn as nn +from torch import Tensor +from torchao.core.config import AOBaseConfig +from torchao.prototype.mx_formats.kernels import f4_unpacked_to_f32, unpack_uint4 +from torchao.prototype.mx_formats.nvfp4_tensor import ( + nvfp4_quantize, + per_tensor_amax_to_scale, +) +from torchao.quantization.quant_api import _quantization_type +from torchao.quantization.transform_module import register_quantize_module_handler +from torchao.utils import TorchAOBaseTensor + +aten = torch.ops.aten + + +from typing import Optional + + +@torch.library.custom_op("torchao::dequantize_nvfp4", mutates_args=()) +def nvfp4_dequantize( + qdata: Tensor, + scale: Tensor, + per_tensor_scale: Tensor, + block_size: int, + output_dtype: torch.dtype = torch.float32, +) -> Tensor: + """Dequantize NVFP4 packed data.""" + data_unpacked = unpack_uint4(qdata.view(torch.uint8).contiguous()) + data_f32 = f4_unpacked_to_f32(data_unpacked) + + M = data_f32.shape[0] + K = data_f32.shape[1] + + data_f32 = data_f32.view(M, K // block_size, block_size) + scale_fp8 = scale.view(torch.float8_e4m3fn) + scale_f32 = scale_fp8.to(torch.float32).view(M, K // block_size, 1) + scale_f32 = per_tensor_scale * scale_f32 + result = (data_f32 * scale_f32).view(M, K) + return result.to(output_dtype) + + +@nvfp4_dequantize.register_fake +def _(qdata, scale, per_tensor_scale, block_size, output_dtype=torch.float32): + M = qdata.shape[0] + K = qdata.shape[1] * 8 # 8 FP4 values per uint32 + return torch.empty(M, K, dtype=output_dtype, device=qdata.device) + + +class ExportableNVFP4Tensor(TorchAOBaseTensor): + """NVFP4 tensor subclass that dequantizes via a registered custom op.""" + + tensor_data_names = ["qdata", "scale", "per_tensor_scale"] + tensor_attribute_names = ["block_size", "orig_dtype"] + + def __new__(cls, qdata, scale, per_tensor_scale, block_size, orig_dtype): + K = qdata.shape[-1] * 8 # 8 FP4 values per uint32 + shape = (qdata.shape[0], K) + self = torch.Tensor._make_wrapper_subclass( + cls, shape, dtype=orig_dtype, device=qdata.device, requires_grad=False + ) + self.qdata = qdata + self.scale = scale + self.per_tensor_scale = per_tensor_scale + self.block_size = block_size + self.orig_dtype = orig_dtype + return self + + def dequantize(self, output_dtype=None): + dtype = output_dtype or self.orig_dtype + return torch.ops.torchao.dequantize_nvfp4( + self.qdata, + self.scale, + self.per_tensor_scale, + self.block_size, + output_dtype=dtype, + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + +implements = ExportableNVFP4Tensor.implements + + +@implements([aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor = args[0] + weight_tensor = args[1] + bias = args[2] if len(args) > 2 else None + weight_dequant = weight_tensor.dequantize() + return torch.nn.functional.linear(input_tensor, weight_dequant, bias) + + +@implements([aten.embedding.default]) +def _(func, types, args, kwargs): + weight_tensor = args[0] + indices = args[1] + weight_dequant = weight_tensor.dequantize() + return torch.nn.functional.embedding(indices, weight_dequant) + + +@implements([aten.t.default]) +def _(func, types, args, kwargs): + return args[0].dequantize().t() + + +@implements([aten.detach.default]) +def _(func, types, args, kwargs): + return args[0] + + +@implements([aten._to_copy.default]) +def _(func, types, args, kwargs): + dtype = kwargs.get("dtype", args[0].orig_dtype) + return args[0].dequantize(output_dtype=dtype) + + +@dataclass +class ExportableNVFP4Config(AOBaseConfig): + """NVFP4 weight-only quantization config for torch.export.""" + + use_per_tensor_scale: bool = True + + +def _linear_extra_repr(self): + return ( + f"in_features={self.weight.shape[1]}, " + f"out_features={self.weight.shape[0]}, " + f"weight={_quantization_type(self.weight)}" + ) + + +@register_quantize_module_handler(ExportableNVFP4Config) +def _exportable_nvfp4_transform(module: nn.Module, config: ExportableNVFP4Config): + weight = module.weight + + if weight.shape[-2] % 16 != 0 or weight.shape[-1] % 16 != 0: + raise RuntimeError( + f"NVFP4 requires weight dims divisible by 16, got {weight.shape}" + ) + + per_tensor_scale = 1.0 + if config.use_per_tensor_scale: + tensor_amax = torch.max(torch.abs(weight)) + per_tensor_scale = per_tensor_amax_to_scale(tensor_amax) + + scales_fp8, qdata_packed = nvfp4_quantize( + weight, block_size=16, per_tensor_scale=per_tensor_scale + ) + + qdata_u32 = qdata_packed.view(torch.uint32) + scales_u8 = scales_fp8.view(torch.uint8) + + pts = torch.tensor(per_tensor_scale, dtype=torch.float32) + quantized_weight = ExportableNVFP4Tensor( + qdata_u32, + scales_u8, + pts, + block_size=16, + orig_dtype=weight.dtype, + ) + module.weight = nn.Parameter(quantized_weight, requires_grad=False) + if isinstance(module, nn.Linear): + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module From d15ee3cfb1a186c1cdbdd1bef2cffb75779bc58c Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 4 Mar 2026 22:10:31 -0800 Subject: [PATCH 14/34] up --- backends/mlx/ops.py | 6 ------ extension/llm/export/nvfp4.py | 3 --- 2 files changed, 9 deletions(-) diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index a181bf03422..439d4569313 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -779,12 +779,6 @@ def _linear_handler(P: MLXProgramBuilder, n: Node) -> Slot: return out -@REGISTRY.register(target=["NOOP", torch.ops.aten._assert_scalar.default]) -def _noop_handler(P: MLXProgramBuilder, n: Node) -> None: - """No-op handler for nodes that don't emit any MLX instructions.""" - return None - - @REGISTRY.register(target=[torch.ops.aten.addmm.default]) def _addmm_handler(P: MLXProgramBuilder, n: Node) -> Slot: """Handle addmm: self + (mat1 @ mat2). diff --git a/extension/llm/export/nvfp4.py b/extension/llm/export/nvfp4.py index 40ffd4c1bb2..feeb95f50a6 100644 --- a/extension/llm/export/nvfp4.py +++ b/extension/llm/export/nvfp4.py @@ -39,9 +39,6 @@ aten = torch.ops.aten -from typing import Optional - - @torch.library.custom_op("torchao::dequantize_nvfp4", mutates_args=()) def nvfp4_dequantize( qdata: Tensor, From b7da263c437ef6bd0991525c69111121d74fc6ef Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Thu, 5 Mar 2026 11:09:26 -0800 Subject: [PATCH 15/34] up --- backends/mlx/builder/op_helpers.py | 10 +-- backends/mlx/runtime/MLXInterpreter.h | 21 ------ backends/mlx/serialization/schema.fbs | 8 --- backends/mlx/test/test_ops.py | 98 ++++++++++----------------- 4 files changed, 42 insertions(+), 95 deletions(-) diff --git a/backends/mlx/builder/op_helpers.py b/backends/mlx/builder/op_helpers.py index 790fd63ebdc..6ed6c9412ce 100644 --- a/backends/mlx/builder/op_helpers.py +++ b/backends/mlx/builder/op_helpers.py @@ -321,11 +321,11 @@ def parse_dequant_node( quantized_dim, group_size = non_one[0] if group_size not in [32, 64, 128]: return None - if qmin == -8 and qmax == 7: - bits = 4 - elif qmin == -128 and qmax == 127: - bits = 8 - else: + + # TODO: MLX supports 3, 5, and 7, but we need to figure out the + # packing story in to_mlx_qparams to use them + bits = (qmax - qmin + 1).bit_length() - 1 + if bits not in [2, 4, 8]: return None return qdata, scale, zero_point, group_size, bits, out_dtype, quantized_dim diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index 5acce31b47d..ddb8931b282 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -149,24 +149,6 @@ exec_addmm(const AddmmNode& n, ExecutionState& st, StreamOrDevice s) { st.set_tensor(n.out, std::move(Y)); } -inline void -exec_linear(const LinearNode& n, ExecutionState& st, StreamOrDevice s) { - const auto& X = st.const_tensor_ref(n.x); - auto W = st.const_tensor_ref(n.weight); - W = transpose(W, {1, 0}, s); - - array Y = n.bias ? addmm( - st.const_tensor_ref(*n.bias), - X, - W, - /*alpha=*/1.0f, - /*beta=*/1.0f, - s) - : matmul(X, W, s); - - st.set_tensor(n.out, std::move(Y)); -} - inline void exec_item_int(const ItemIntNode& n, ExecutionState& st, StreamOrDevice) { // Intentional sync: item() requires a concrete scalar value for SymInt @@ -1601,9 +1583,6 @@ class Interpreter { case OpCode::ADDMM: ops::exec_addmm(std::get(instr.node), st, s); break; - case OpCode::LINEAR: - ops::exec_linear(std::get(instr.node), st, s); - break; case OpCode::ITEM_INT: ops::exec_item_int(std::get(instr.node), st, s); break; diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs index d8d72d4fbd9..6b5132d6273 100644 --- a/backends/mlx/serialization/schema.fbs +++ b/backends/mlx/serialization/schema.fbs @@ -86,13 +86,6 @@ table AddmmNode { beta: float = 1.0; // Scalar multiplier for bias } -table LinearNode { - x: Tid (required); - weight: Tid (required); - out: Tid (required); - bias: Tid; // optional -} - table ItemIntNode { x: Tid (required); out: Vid (required); @@ -916,7 +909,6 @@ union OpNode { NoopNode, IdCopyNode, AddmmNode, - LinearNode, ItemIntNode, ExpandDimsNode, TileNode, diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index bc7e01b3c18..35514f4df04 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -5385,19 +5385,6 @@ def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: } -class QuantizedLinearModel(nn.Module): - """Simple linear layer that will be quantized.""" - - def __init__( - self, in_features: int = 64, out_features: int = 128, bias: bool = True - ): - super().__init__() - self.linear = nn.Linear(in_features, out_features, bias=bias) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear(x) - - @register_test class QuantizedLinearTest(OpTestCase): """Test case for TorchAO int4 quantized nn.Linear.""" @@ -5408,13 +5395,14 @@ class QuantizedLinearTest(OpTestCase): def __init__( self, - in_features: int = 64, + in_features: int = 128, out_features: int = 128, batch_size: int = 2, seq_len: int = 16, bias: bool = True, group_size: int = 32, dtype: torch.dtype = torch.bfloat16, + qdtype: torch.dtype = torch.int4, ): self.in_features = in_features self.out_features = out_features @@ -5423,8 +5411,9 @@ def __init__( self.bias = bias self.group_size = group_size self.dtype = dtype + self.qdtype = qdtype - parts = ["quantized_linear", f"g{group_size}"] + parts = ["quantized_linear", f"{qdtype}", f"g{group_size}"] if not bias: parts.append("no_bias") self.name = "_".join(parts) @@ -5434,26 +5423,25 @@ def get_test_configs(cls) -> List["QuantizedLinearTest"]: return [ cls(), cls(bias=False), + cls(group_size=64), + cls(group_size=128), + cls(qdtype=torch.int2), + cls(qdtype=torch.int8), ] def create_model(self) -> nn.Module: - model = QuantizedLinearModel( - self.in_features, self.out_features, bias=self.bias - ) + model = LinearModel(self.in_features, self.out_features, bias=self.bias) model = model.to(self.dtype) - try: - from torchao.quantization.granularity import PerGroup - from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ + from torchao.quantization.granularity import PerGroup + from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ - quantize_( - model, - IntxWeightOnlyConfig( - weight_dtype=torch.int4, granularity=PerGroup(self.group_size) - ), - ) - except ImportError: - raise RuntimeError("TorchAO not installed. Run: pip install torchao") + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=self.qdtype, granularity=PerGroup(self.group_size) + ), + ) return model @@ -5464,21 +5452,6 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]: return (x,) -class QuantizedEmbeddingModel(nn.Module): - """Simple embedding layer that will be quantized.""" - - def __init__( - self, - num_embeddings: int = 1000, - embedding_dim: int = 64, - ): - super().__init__() - self.embedding = nn.Embedding(num_embeddings, embedding_dim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.embedding(x) - - @register_test class QuantizedEmbeddingTest(OpTestCase): """Test case for TorchAO int4 quantized nn.Embedding.""" @@ -5490,11 +5463,12 @@ class QuantizedEmbeddingTest(OpTestCase): def __init__( self, num_embeddings: int = 1000, - embedding_dim: int = 64, + embedding_dim: int = 128, batch_size: int = 2, seq_len: int = 16, group_size: int = 32, dtype: torch.dtype = torch.bfloat16, + qdtype: torch.dtype = torch.int4, ): self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim @@ -5502,36 +5476,38 @@ def __init__( self.seq_len = seq_len self.group_size = group_size self.dtype = dtype + self.qdtype = qdtype - parts = ["quantized_embedding", f"g{group_size}"] + parts = ["quantized_embedding", f"{qdtype}", f"g{group_size}"] self.name = "_".join(parts) @classmethod def get_test_configs(cls) -> List["QuantizedEmbeddingTest"]: return [ cls(), + cls(group_size=64), + cls(group_size=128), + cls(qdtype=torch.int2), + cls(qdtype=torch.int8), ] def create_model(self) -> nn.Module: - model = QuantizedEmbeddingModel(self.num_embeddings, self.embedding_dim) + model = EmbeddingModel(self.num_embeddings, self.embedding_dim) model = model.to(self.dtype) - try: - from torchao.quantization.granularity import PerGroup - from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ + from torchao.quantization.granularity import PerGroup + from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ - def embedding_filter(module: nn.Module, fqn: str) -> bool: - return isinstance(module, nn.Embedding) + def embedding_filter(module: nn.Module, fqn: str) -> bool: + return isinstance(module, nn.Embedding) - quantize_( - model, - IntxWeightOnlyConfig( - weight_dtype=torch.int4, granularity=PerGroup(self.group_size) - ), - embedding_filter, - ) - except ImportError: - raise RuntimeError("TorchAO not installed. Run: pip install torchao") + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=torch.int4, granularity=PerGroup(self.group_size) + ), + embedding_filter, + ) return model From 681ae8c85fb7c0d74a6c9c7a92a8a0cae7c18852 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Tue, 3 Mar 2026 11:23:22 -0800 Subject: [PATCH 16/34] up --- .github/workflows/mlx.yml | 364 ++++++++++ Makefile | 38 +- backends/mlx/examples/__init__.py | 6 + backends/mlx/examples/llm/README.md | 100 +++ backends/mlx/examples/llm/export_llm_hf.py | 464 +++++++++++++ backends/mlx/examples/llm/run_llm_hf.py | 186 +++++ backends/mlx/examples/voxtral/__init__.py | 5 + .../mlx/examples/voxtral/export_voxtral_hf.py | 264 ++++++++ backends/mlx/examples/whisper/README.md | 82 +++ backends/mlx/examples/whisper/__init__.py | 7 + backends/mlx/examples/whisper/args.py | 151 +++++ .../mlx/examples/whisper/export_whisper.py | 639 ++++++++++++++++++ backends/mlx/examples/whisper/run_whisper.py | 276 ++++++++ backends/mlx/llm/et_attention.py | 252 +++++++ backends/mlx/llm/hf_attention.py | 224 ++++++ backends/mlx/llm/quantization.py | 151 +++++ backends/mlx/llm/source_transformation.py | 294 ++++++++ examples/models/llama/CMakeLists.txt | 20 +- examples/models/llama/export_llama_lib.py | 57 ++ examples/models/parakeet/CMakeLists.txt | 23 +- examples/models/parakeet/CMakePresets.json | 33 + examples/models/parakeet/README.md | 28 +- .../models/parakeet/export_parakeet_tdt.py | 18 +- examples/models/voxtral/CMakeLists.txt | 22 +- examples/models/voxtral/CMakePresets.json | 33 + examples/models/voxtral/README.md | 45 ++ .../models/voxtral_realtime/CMakeLists.txt | 13 +- .../models/voxtral_realtime/CMakePresets.json | 34 + examples/models/voxtral_realtime/README.md | 54 +- .../voxtral_realtime/export_voxtral_rt.py | 98 ++- examples/models/voxtral_realtime/model.py | 268 +++++++- extension/llm/export/config/llm_config.py | 19 + 32 files changed, 4214 insertions(+), 54 deletions(-) create mode 100644 backends/mlx/examples/__init__.py create mode 100644 backends/mlx/examples/llm/README.md create mode 100644 backends/mlx/examples/llm/export_llm_hf.py create mode 100644 backends/mlx/examples/llm/run_llm_hf.py create mode 100644 backends/mlx/examples/voxtral/__init__.py create mode 100644 backends/mlx/examples/voxtral/export_voxtral_hf.py create mode 100644 backends/mlx/examples/whisper/README.md create mode 100644 backends/mlx/examples/whisper/__init__.py create mode 100644 backends/mlx/examples/whisper/args.py create mode 100644 backends/mlx/examples/whisper/export_whisper.py create mode 100644 backends/mlx/examples/whisper/run_whisper.py create mode 100644 backends/mlx/llm/et_attention.py create mode 100644 backends/mlx/llm/hf_attention.py create mode 100644 backends/mlx/llm/quantization.py create mode 100644 backends/mlx/llm/source_transformation.py diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 53c7b9360cd..e35242b9191 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -10,6 +10,9 @@ on: - .github/workflows/mlx.yml - backends/mlx/** - extension/llm/export/** + - examples/models/parakeet/** + - examples/models/voxtral/** + - examples/models/voxtral_realtime/** workflow_dispatch: permissions: {} @@ -104,3 +107,364 @@ jobs: echo "::error::Too many test failures: $FAILED > $MAX_FAILURES" exit 1 fi + + test-mlx-parakeet: + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + with: + job-name: test-mlx-parakeet + runner: macos-14-xlarge + python-version: "3.12" + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 90 + script: | + set -eux + + echo "::group::Install ExecuTorch" + ${CONDA_RUN} python install_executorch.py > /dev/null + echo "::endgroup::" + + echo "::group::Install Parakeet requirements" + ${CONDA_RUN} pip install -r examples/models/parakeet/install_requirements.txt + echo "::endgroup::" + + ${CONDA_RUN} pip list + + echo "::group::Export Parakeet" + ${CONDA_RUN} python -m executorch.examples.models.parakeet.export_parakeet_tdt \ + --backend mlx \ + --dtype bf16 \ + --qlinear_encoder 4w \ + --qlinear_encoder_group_size 128 \ + --qlinear 4w \ + --qlinear_group_size 128 \ + --output-dir /tmp/parakeet_mlx + echo "::endgroup::" + + echo "::group::Build Parakeet MLX runner" + ${CONDA_RUN} make parakeet-mlx + echo "::endgroup::" + + echo "::group::Run Parakeet MLX runner" + curl -L https://dldata-public.s3.us-east-2.amazonaws.com/2086-149220-0033.wav -o /tmp/test_audio.wav + OUTPUT=$(./cmake-out/examples/models/parakeet/parakeet_runner \ + --model_path /tmp/parakeet_mlx/model.pte \ + --audio_path /tmp/test_audio.wav \ + --tokenizer_path /tmp/parakeet_mlx/tokenizer.model 2>&1) + echo "Runner output:" + echo "$OUTPUT" + if echo "$OUTPUT" | grep -iq "Phoebe"; then + echo "Success: 'Phoebe' found in output" + else + echo "Failed: Expected 'Phoebe' not found in output" + exit 1 + fi + echo "::endgroup::" + + test-mlx-voxtral: + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + secrets: inherit + with: + job-name: test-mlx-voxtral + runner: macos-14-xlarge + python-version: "3.12" + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + secrets-env: EXECUTORCH_HF_TOKEN + timeout: 90 + script: | + set -eux + + echo "::group::Install ExecuTorch" + ${CONDA_RUN} python install_executorch.py > /dev/null + echo "::endgroup::" + + echo "::group::Install Voxtral requirements" + ${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0" + ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN + ${CONDA_RUN} pip install mistral_common librosa soundfile datasets + OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) + ${CONDA_RUN} pip install "optimum-executorch @ git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}" + echo "::endgroup::" + + ${CONDA_RUN} pip list + + echo "::group::Export Voxtral" + ${CONDA_RUN} python -m executorch.backends.mlx.examples.voxtral.export_voxtral_hf \ + --output-dir /tmp/voxtral_mlx \ + --dtype bf16 \ + --quantize-linear int4 + echo "::endgroup::" + + echo "::group::Build Voxtral MLX runner" + ${CONDA_RUN} make voxtral-mlx + echo "::endgroup::" + + echo "::group::Run Voxtral MLX runner" + curl -L https://huggingface.co/mistralai/Voxtral-Mini-3B-2507/resolve/main/tekken.json -o /tmp/tekken.json + curl -L https://github.com/voxserv/audio_quality_testing_samples/raw/refs/heads/master/testaudio/16000/test01_20s.wav -o /tmp/test_audio.wav + OUTPUT=$(./cmake-out/examples/models/voxtral/voxtral_runner \ + --model_path /tmp/voxtral_mlx/model.pte \ + --tokenizer_path /tmp/tekken.json \ + --audio_path /tmp/test_audio.wav \ + --processor_path /tmp/voxtral_mlx/preprocessor.pte \ + --prompt "What is happening in this audio?" \ + --temperature 0 2>&1) + echo "Runner output:" + echo "$OUTPUT" + if echo "$OUTPUT" | grep -iq "poem"; then + echo "Success: 'poem' found in output" + else + echo "Failed: Expected 'poem' not found in output" + exit 1 + fi + echo "::endgroup::" + + test-mlx-voxtral-realtime: + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + secrets: inherit + with: + job-name: test-mlx-voxtral-realtime + runner: macos-14-xlarge + python-version: "3.12" + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + secrets-env: EXECUTORCH_HF_TOKEN + timeout: 90 + script: | + set -eux + + echo "::group::Install ExecuTorch" + ${CONDA_RUN} python install_executorch.py > /dev/null + echo "::endgroup::" + + echo "::group::Install Voxtral Realtime requirements" + ${CONDA_RUN} pip install -U "huggingface_hub[cli]" safetensors + ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN + echo "::endgroup::" + + ${CONDA_RUN} pip list + + echo "::group::Download model" + ${CONDA_RUN} huggingface-cli download mistralai/Voxtral-Mini-4B-Realtime-2602 + MODEL_PATH=$(${CONDA_RUN} python -c "from huggingface_hub import snapshot_download; print(snapshot_download('mistralai/Voxtral-Mini-4B-Realtime-2602'))") + echo "Model path: ${MODEL_PATH}" + echo "::endgroup::" + + echo "::group::Export Voxtral Realtime (streaming)" + ${CONDA_RUN} python -m executorch.examples.models.voxtral_realtime.export_voxtral_rt \ + --model-path "${MODEL_PATH}" \ + --backend mlx \ + --streaming \ + --output-dir /tmp/voxtral_rt_mlx \ + --qlinear-encoder 4w \ + --qlinear 4w \ + --qembedding 8w \ + --qembedding-group-size 128 \ + --export-preprocessor + echo "::endgroup::" + + echo "::group::Build Voxtral Realtime MLX runner" + ${CONDA_RUN} make voxtral_realtime-mlx + echo "::endgroup::" + + echo "::group::Run Voxtral Realtime MLX runner" + curl -L https://dldata-public.s3.us-east-2.amazonaws.com/2086-149220-0033.wav -o /tmp/test_audio.wav + OUTPUT=$(./cmake-out/examples/models/voxtral_realtime/voxtral_realtime_runner \ + --model_path /tmp/voxtral_rt_mlx/model.pte \ + --tokenizer_path "${MODEL_PATH}/tekken.json" \ + --preprocessor_path /tmp/voxtral_rt_mlx/preprocessor.pte \ + --audio_path /tmp/test_audio.wav \ + --streaming 2>&1) + echo "Runner output:" + echo "$OUTPUT" + if echo "$OUTPUT" | grep -iq "Phoebe"; then + echo "Success: 'Phoebe' found in output" + else + echo "Failed: Expected 'Phoebe' not found in output" + exit 1 + fi + echo "::endgroup::" + + test-mlx-whisper: + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + secrets: inherit + with: + job-name: test-mlx-whisper + runner: macos-14-xlarge + python-version: "3.12" + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + secrets-env: EXECUTORCH_HF_TOKEN + timeout: 90 + script: | + set -eux + + echo "::group::Install ExecuTorch and configure MLX build" + ${CONDA_RUN} python install_executorch.py > /dev/null + echo "::endgroup::" + + echo "::group::Install Whisper requirements" + ${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0" + ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN + ${CONDA_RUN} pip install transformers soundfile datasets librosa + echo "::endgroup::" + + ${CONDA_RUN} pip list + + echo "::group::Export Whisper" + ${CONDA_RUN} python -m executorch.backends.mlx.examples.whisper.export_whisper \ + --model-id "openai/whisper-tiny" \ + --output-dir /tmp/whisper_mlx \ + --dtype bf16 \ + --quantize-linear int4 + echo "::endgroup::" + + echo "::group::Run Whisper inference" + OUTPUT=$( ${CONDA_RUN} python -m executorch.backends.mlx.examples.whisper.run_whisper \ + --model-dir /tmp/whisper_mlx \ + --use-sample-audio 2>&1) + echo "$OUTPUT" + if echo "$OUTPUT" | grep -iq "Mr. Quilter"; then + echo "Success: 'Mr. Quilter' found in transcription" + else + echo "Failed: Expected 'Mr. Quilter' not found in transcription" + exit 1 + fi + echo "::endgroup::" + + + test-mlx-stories110m: + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + with: + job-name: test-mlx-stories110m + runner: macos-14-xlarge + python-version: "3.12" + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 90 + script: | + set -eux + + echo "::group::Install ExecuTorch" + ${CONDA_RUN} python install_executorch.py > /dev/null + echo "::endgroup::" + + echo "::group::Install Llama requirements" + ${CONDA_RUN} sh examples/models/llama/install_requirements.sh + echo "::endgroup::" + + ${CONDA_RUN} pip list + + echo "::group::Build ExecuTorch with MLX delegate" + ${CONDA_RUN} cmake --workflow --preset mlx-release + echo "::endgroup::" + + echo "::group::Build Llama runner with MLX" + pushd examples/models/llama + ${CONDA_RUN} cmake --workflow --preset llama-release + popd + echo "::endgroup::" + + echo "::group::Download stories110M artifacts" + curl -Ls "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt" --output stories110M.pt + curl -Ls "https://raw.githubusercontent.com/karpathy/llama2.c/master/tokenizer.model" --output tokenizer.model + echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json + echo "::endgroup::" + + echo "::group::Create tokenizer.bin" + ${CONDA_RUN} python -m pytorch_tokenizers.tools.llama2c.convert -t tokenizer.model -o tokenizer.bin + echo "::endgroup::" + + echo "::group::Export stories110M with MLX backend via export_llama_lib" + ${CONDA_RUN} python -m extension.llm.export.export_llm \ + base.checkpoint=stories110M.pt \ + base.params=params.json \ + model.use_kv_cache=true \ + model.dtype_override=fp32 \ + backend.mlx.enabled=true \ + quantization.qmode=4w \ + quantization.group_size=32 \ + export.output_name=/tmp/stories110m_mlx.pte + echo "::endgroup::" + + echo "::group::Run inference with C++ llama runner" + ./cmake-out/examples/models/llama/llama_main \ + --model_path=/tmp/stories110m_mlx.pte \ + --tokenizer_path=tokenizer.bin \ + --prompt="Once upon a time," \ + --temperature=0 \ + --seq_len=10 + echo "::endgroup::" + + test-mlx-llm: + strategy: + fail-fast: false + matrix: + model: + - id: "unsloth/Llama-3.2-1B-Instruct" + name: "llama-1b" + - id: "unsloth/Qwen3-0.6B" + name: "qwen3-0.6b" + - id: "unsloth/gemma-3-1b-it" + name: "gemma3-1b" + use-custom: [false, true] + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + secrets: inherit + with: + job-name: test-mlx-llm-${{ matrix.model.name }}${{ matrix.use-custom && '-custom' || '' }} + runner: macos-14-xlarge + python-version: "3.12" + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + secrets-env: EXECUTORCH_HF_TOKEN + timeout: 90 + script: | + set -eux + + MODEL_ID="${{ matrix.model.id }}" + MODEL_NAME="${{ matrix.model.name }}" + USE_CUSTOM="${{ matrix.use-custom }}" + + CUSTOM_ARGS="" + if [ "${USE_CUSTOM}" = "true" ]; then + CUSTOM_ARGS="--use-custom-sdpa --use-custom-kv-cache" + fi + + echo "::group::Install ExecuTorch and configure MLX build" + ${CONDA_RUN} python install_executorch.py > /dev/null + ${CONDA_RUN} cmake --preset mlx-release + echo "::endgroup::" + + echo "::group::Install LLM requirements" + ${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0" + ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN + OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) + ${CONDA_RUN} pip install transformers "optimum-executorch @ git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}" + echo "::endgroup::" + + ${CONDA_RUN} pip list + + echo "::group::Export ${MODEL_NAME}" + ${CONDA_RUN} python -m executorch.backends.mlx.examples.llm.export_llm_hf \ + --model-id "${MODEL_ID}" \ + --output /tmp/${MODEL_NAME}.pte \ + --quantize-linear int4 \ + --quantize-embeddings int4 \ + ${CUSTOM_ARGS} + echo "::endgroup::" + + echo "::group::Run ${MODEL_NAME} inference" + OUTPUT=$(${CONDA_RUN} python -m executorch.backends.mlx.examples.llm.run_llm_hf \ + --pte /tmp/${MODEL_NAME}.pte \ + --model-id "${MODEL_ID}" \ + --prompt "What is the capital of France?" \ + --max-new-tokens 50 2>&1) + echo "$OUTPUT" + if echo "$OUTPUT" | grep -iq "Paris"; then + echo "Success: 'Paris' found in output" + else + echo "Failed: Expected 'Paris' not found in output" + exit 1 + fi + echo "::endgroup::" diff --git a/Makefile b/Makefile index ad8544210f7..ab3cacf2659 100644 --- a/Makefile +++ b/Makefile @@ -14,10 +14,10 @@ # # SUPPORTED MODELS: # ----------------- -# - voxtral: Multimodal voice + text model (CPU, CUDA, Metal) -# - voxtral_realtime: Realtime speech-to-text model (CPU, CUDA, Metal) +# - voxtral: Multimodal voice + text model (CPU, CUDA, Metal, MLX) +# - voxtral_realtime: Realtime speech-to-text model (CPU, CUDA, Metal, MLX) # - whisper: Speech recognition model (CPU, CUDA, Metal) -# - parakeet: Speech recognition model (CPU, CUDA, Metal) +# - parakeet: Speech recognition model (CPU, CUDA, Metal, MLX) # - sortformer: Speaker diarization model (CPU) # - silero_vad: Voice activity detection model (CPU) # - llama: Text generation model (CPU) @@ -91,16 +91,18 @@ # # ============================================================================== -.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help +.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help help: @echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make \`. Available targets:" @echo " voxtral-cuda - Build Voxtral runner with CUDA backend" @echo " voxtral-cpu - Build Voxtral runner with CPU backend" @echo " voxtral-metal - Build Voxtral runner with Metal backend (macOS only)" + @echo " voxtral-mlx - Build Voxtral runner with MLX backend" @echo " voxtral_realtime-cuda - Build Voxtral Realtime runner with CUDA backend" @echo " voxtral_realtime-cpu - Build Voxtral Realtime runner with CPU backend" @echo " voxtral_realtime-metal - Build Voxtral Realtime runner with Metal backend (macOS only)" + @echo " voxtral_realtime-mlx - Build Voxtral Realtime runner with MLX backend" @echo " whisper-cuda - Build Whisper runner with CUDA backend" @echo " whisper-cuda-debug - Build Whisper runner with CUDA backend (debug mode)" @echo " whisper-cpu - Build Whisper runner with CPU backend" @@ -109,6 +111,7 @@ help: @echo " parakeet-cuda-debug - Build Parakeet runner with CUDA backend (debug mode)" @echo " parakeet-cpu - Build Parakeet runner with CPU backend" @echo " parakeet-metal - Build Parakeet runner with Metal backend (macOS only)" + @echo " parakeet-mlx - Build Parakeet runner with MLX backend" @echo " sortformer-cpu - Build Sortformer runner with CPU backend" @echo " silero-vad-cpu - Build Silero VAD runner with CPU backend" @echo " llama-cuda - Build Llama runner with CUDA backend" @@ -146,6 +149,15 @@ voxtral-metal: @echo "✓ Build complete!" @echo " Binary: cmake-out/examples/models/voxtral/voxtral_runner" +voxtral-mlx: + @echo "==> Building and installing ExecuTorch with MLX..." + cmake --workflow --preset mlx-release + @echo "==> Building Voxtral runner with MLX..." + cd examples/models/voxtral && cmake --workflow --preset voxtral-mlx + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/voxtral/voxtral_runner" + whisper-cuda: @echo "==> Building and installing ExecuTorch with CUDA..." cmake --workflow --preset llm-release-cuda @@ -218,6 +230,15 @@ parakeet-metal: @echo "✓ Build complete!" @echo " Binary: cmake-out/examples/models/parakeet/parakeet_runner" +parakeet-mlx: + @echo "==> Building and installing ExecuTorch with MLX..." + cmake --workflow --preset mlx-release + @echo "==> Building Parakeet runner with MLX..." + cd examples/models/parakeet && cmake --workflow --preset parakeet-mlx + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/parakeet/parakeet_runner" + sortformer-cpu: @echo "==> Building and installing ExecuTorch..." cmake --workflow --preset llm-release @@ -254,6 +275,15 @@ voxtral_realtime-cuda: @echo "✓ Build complete!" @echo " Binary: cmake-out/examples/models/voxtral_realtime/voxtral_realtime_runner" +voxtral_realtime-mlx: + @echo "==> Building and installing ExecuTorch with MLX..." + cmake --workflow --preset mlx-release + @echo "==> Building Voxtral Realtime runner with MLX..." + cd examples/models/voxtral_realtime && cmake --workflow --preset voxtral-realtime-mlx + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/voxtral_realtime/voxtral_realtime_runner" + silero-vad-cpu: @echo "==> Building and installing ExecuTorch..." cmake --workflow --preset llm-release diff --git a/backends/mlx/examples/__init__.py b/backends/mlx/examples/__init__.py new file mode 100644 index 00000000000..f557ef26c5b --- /dev/null +++ b/backends/mlx/examples/__init__.py @@ -0,0 +1,6 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/mlx/examples/llm/README.md b/backends/mlx/examples/llm/README.md new file mode 100644 index 00000000000..7346efcef69 --- /dev/null +++ b/backends/mlx/examples/llm/README.md @@ -0,0 +1,100 @@ +# LLM MLX Example + +This example demonstrates how to export and run LLMs using the MLX delegate for Apple Silicon. + +## Features + +- **Export**: Convert HuggingFace LLMs to ExecuTorch format with MLX delegate +- **Quantization**: Optional INT4/INT8 weight quantization via TorchAO +- **KV Cache**: Efficient KV cache implementation for autoregressive generation +- **Custom Ops**: Uses `mlx::custom_sdpa` and `mlx::kv_cache_update` for optimal execution on MLX +- **Pybindings**: Run inference using ExecuTorch Python bindings + +## Requirements + +```bash +pip install transformers optimum-executorch +``` + +## Scripts Overview + +| Script | Description | +|--------|-------------| +| `export_llm_hf` | Export LLMs using optimum-executorch pipeline, with optional custom MLX SDPA/KV cache | +| `run_llm_hf` | Run exported models with token-by-token generation | + +For exporting via the ExecuTorch LLM pipeline (e.g. `examples/models/llama`), use `--mlx` to enable the MLX delegate. + +--- + +## `export_llm_hf` + +Uses optimum-executorch's `CausalLMExportableModule` by default. Optional flags enable custom MLX-optimized components (custom SDPA and/or KV cache). + +```bash +# Baseline export using optimum-executorch +python -m executorch.backends.mlx.examples.llm.export_llm_hf \ + --model-id "unsloth/Llama-3.2-1B-Instruct" \ + --output llama_hf.pte + +# With custom MLX components +python -m executorch.backends.mlx.examples.llm.export_llm_hf \ + --model-id "unsloth/Llama-3.2-1B-Instruct" \ + --output llama_hf_mlx.pte \ + --use-custom-sdpa \ + --use-custom-kv-cache + +# With INT4 quantization +python -m executorch.backends.mlx.examples.llm.export_llm_hf \ + --model-id "unsloth/Llama-3.2-1B-Instruct" \ + --output llama_hf_int4.pte \ + --use-custom-sdpa \ + --use-custom-kv-cache \ + --quantize-linear int4 \ + --quantize-embeddings int4 +``` + +### Options + +| Option | Default | Description | +|--------|---------|-------------| +| `--model-id` | `unsloth/Llama-3.2-1B-Instruct` | HuggingFace model ID | +| `--output` | *(required)* | Output .pte file path | +| `--max-seq-len` | `1024` | Maximum sequence length for KV cache | +| `--dtype` | `bf16` | Model dtype (`fp32`, `fp16`, `bf16`) | +| `--quantize-linear` | None | Quantization for linear layers (`int4`, `int8`) | +| `--quantize-embeddings` | None | Quantization for embedding layers (`int4`, `int8`) | +| `--no-tie-word-embeddings` | `False` | Disable re-tying lm_head to embedding after quantization | +| `--use-custom-sdpa` | `False` | Use MLX custom SDPA (`mlx::custom_sdpa`) | +| `--use-custom-kv-cache` | `False` | Use MLX custom KV cache (`mlx::kv_cache_update`) | + +--- + +## `run_llm_hf` + +Run models exported with `export_llm_hf`. Supports both full-prompt prefill (dynamic seq len exports) and token-by-token prefill (fixed seq len exports). + +```bash +python -m executorch.backends.mlx.examples.llm.run_llm_hf \ + --pte llama_hf.pte \ + --model-id unsloth/Llama-3.2-1B-Instruct \ + --prompt "Explain quantum computing in simple terms" +``` + +### Options + +| Option | Default | Description | +|--------|---------|-------------| +| `--pte` | `llama_hf.pte` | Path to .pte file | +| `--model-id` | `unsloth/Llama-3.2-1B-Instruct` | HuggingFace model ID (for tokenizer) | +| `--prompt` | `The quick brown fox` | Input prompt | +| `--max-new-tokens` | `50` | Maximum tokens to generate | + +--- + +## Architecture + +The `export_llm_hf` script uses optimum-executorch's `CausalLMExportableModule` by default. When custom flags are enabled, it uses `TorchExportableModuleWithStaticCache` from HuggingFace transformers, with optional MLX-specific replacements: + +- `--use-custom-sdpa`: Registers `mlx::custom_sdpa` attention implementation +- `--use-custom-kv-cache`: Replaces HF's `StaticCache` with `HFStaticCache` using `mlx::kv_cache_update` diff --git a/backends/mlx/examples/llm/export_llm_hf.py b/backends/mlx/examples/llm/export_llm_hf.py new file mode 100644 index 00000000000..f00880ac9cb --- /dev/null +++ b/backends/mlx/examples/llm/export_llm_hf.py @@ -0,0 +1,464 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Export LLM model from HuggingFace to MLX backend. + +By default, uses optimum-executorch's CausalLMExportableModule which provides +a proven export pipeline. Optional flags enable custom MLX-optimized components: + + --use-custom-sdpa Register MLX attention (mlx::custom_sdpa) which handles + K/V slicing and causal masking internally. + --use-custom-kv-cache Replace HF's StaticCache with HFStaticCache that uses + mlx::kv_cache_update for optimized cache updates. + +When neither flag is set, the script behaves identically to the original +optimum-executorch export pipeline. + +Usage: + # Baseline (optimum-executorch pipeline): + python -m executorch.backends.mlx.examples.llm.export_llm_hf \\ + --model-id "unsloth/Llama-3.2-1B-Instruct" \\ + --output llama_hf.pte + + # With custom MLX components: + python -m executorch.backends.mlx.examples.llm.export_llm_hf \\ + --model-id "unsloth/Llama-3.2-1B-Instruct" \\ + --output llama_hf_mlx.pte \\ + --use-custom-sdpa \\ + --use-custom-kv-cache + +Requirements: + pip install transformers torch optimum-executorch +""" + +import argparse +import logging +import os +from typing import Optional + +import torch + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +def _export_with_optimum( + model_id: str, + output_path: str, + max_seq_len: int, + dtype: str, + quantize_linear: Optional[str], + quantize_embeddings: Optional[str], + no_tie_word_embeddings: bool = False, + linear_group_size: Optional[int] = None, + embeddings_group_size: Optional[int] = None, +) -> None: + import executorch.exir as exir + from executorch.backends.mlx import MLXPartitioner + from executorch.backends.mlx.passes import get_default_passes + from executorch.exir import EdgeCompileConfig + from executorch.exir.capture._config import ExecutorchBackendConfig + from executorch.exir.passes import MemoryPlanningPass + from optimum.exporters.executorch.tasks.causal_lm import load_causal_lm_model + + dtype_map = {"fp32": "float32", "fp16": "float16", "bf16": "bfloat16"} + dtype_str = dtype_map.get(dtype, "bfloat16") + + logger.info(f"Loading model using optimum-executorch: {model_id}") + exportable = load_causal_lm_model( + model_id, + dtype=dtype_str, + max_seq_len=max_seq_len, + ) + + from executorch.backends.mlx.llm.quantization import apply_quantization + + apply_quantization( + exportable.model, + quantize_linear, + quantize_embeddings, + tie_word_embeddings=getattr( + exportable.model.config, "tie_word_embeddings", False + ) + and not no_tie_word_embeddings, + linear_group_size=linear_group_size, + embeddings_group_size=embeddings_group_size, + ) + + logger.info("Exporting model with torch.export...") + exported_progs = exportable.export() + + logger.info("Delegating to MLX backend...") + edge_config = EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ) + + if len(exported_progs) == 1: + exported_progs = {"forward": next(iter(exported_progs.values()))} + + edge_program = exir.to_edge_transform_and_lower( + exported_progs, + transform_passes=get_default_passes(), + partitioner=[MLXPartitioner()], + compile_config=edge_config, + constant_methods=exportable.metadata, + ) + + logger.info("Exporting to ExecuTorch...") + executorch_program = edge_program.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ) + ) + + _save_program(executorch_program, output_path) + + +def _export_with_custom_components( + model_id: str, + output_path: str, + max_seq_len: int, + dtype: str, + quantize_linear: Optional[str], + quantize_embeddings: Optional[str], + use_custom_sdpa: bool, + use_custom_kv_cache: bool, + no_tie_word_embeddings: bool = False, + linear_group_size: Optional[int] = None, + embeddings_group_size: Optional[int] = None, +) -> None: + """ + Export using direct HF model with custom MLX components. + + Used when --use-custom-sdpa and/or --use-custom-kv-cache are set. + """ + import executorch.exir as exir + from executorch.backends.mlx import MLXPartitioner + from executorch.backends.mlx.passes import get_default_passes + from executorch.exir import EdgeCompileConfig + from executorch.exir.capture._config import ExecutorchBackendConfig + from executorch.exir.passes import MemoryPlanningPass + from transformers import AutoModelForCausalLM + from transformers.integrations.executorch import ( + TorchExportableModuleWithStaticCache, + ) + + torch_dtype_map = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, + } + torch_dtype = torch_dtype_map.get(dtype, torch.bfloat16) + + if use_custom_sdpa: + from executorch.backends.mlx.llm.hf_attention import register_mlx_attention + + register_mlx_attention() + logger.info("Registered MLX custom SDPA attention") + + attn_implementation = "mlx" if use_custom_sdpa else None + + # Detect sliding window models (e.g., gemma) + sliding_window = None + + logger.info(f"Loading HuggingFace model: {model_id}") + load_kwargs = { + "torch_dtype": torch_dtype, + "low_cpu_mem_usage": True, + } + if attn_implementation: + load_kwargs["attn_implementation"] = attn_implementation + model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs) + + # Check if model uses sliding window attention + sliding_window = getattr(model.config, "sliding_window", None) + if sliding_window is not None: + logger.info(f"Model has sliding_window={sliding_window}") + # Cap max_seq_len to sliding window size for cache allocation + effective_cache_len = min(max_seq_len, sliding_window) + logger.info(f" Capping cache length to sliding window: {effective_cache_len}") + else: + effective_cache_len = max_seq_len + + model.generation_config.cache_implementation = "static" + model.generation_config.cache_config = { + "batch_size": 1, + "max_cache_len": effective_cache_len, + } + model.eval() + + # Use HybridCache wrapper for sliding window models (stores cache as .cache), + # StaticCache wrapper for non-sliding-window models (stores cache as .static_cache). + # This matters because the sliding window SDPA closure looks up the cache via + # exportable_module.cache, matching the optimum-executorch convention. + if sliding_window is not None: + from transformers.integrations.executorch import ( + TorchExportableModuleWithHybridCache, + ) + + logger.info("Creating TorchExportableModuleWithHybridCache wrapper...") + exportable = TorchExportableModuleWithHybridCache( + model=model, + batch_size=1, + max_cache_len=effective_cache_len, + ) + else: + logger.info("Creating TorchExportableModuleWithStaticCache wrapper...") + exportable = TorchExportableModuleWithStaticCache( + model=model, + batch_size=1, + max_cache_len=effective_cache_len, + ) + + if use_custom_kv_cache: + if sliding_window is not None: + # Use ring buffer cache for sliding window models + from executorch.backends.mlx.llm.source_transformation import ( + replace_hf_cache_with_mlx_ring_buffer, + ) + + logger.info( + f"Replacing StaticCache with RingBuffer KV cache " + f"(window_size={effective_cache_len})..." + ) + replace_hf_cache_with_mlx_ring_buffer( + exportable, + model.config, + max_batch_size=1, + window_size=effective_cache_len, + dtype=torch_dtype, + ) + + if use_custom_sdpa: + # Re-register attention with sliding window closure + from executorch.backends.mlx.llm.hf_attention import ( + register_mlx_sliding_window_attention, + ) + + register_mlx_sliding_window_attention(exportable) + model.config._attn_implementation = "mlx_sliding_window" + logger.info( + " Registered sliding window attention (mlx_sliding_window)" + ) + + logger.info(" RingBuffer KV cache installed successfully") + else: + # Use standard linear cache for non-sliding-window models + from executorch.backends.mlx.llm.source_transformation import ( + replace_hf_cache_with_mlx, + ) + + logger.info("Replacing HuggingFace StaticCache with HFStaticCache...") + replace_hf_cache_with_mlx( + exportable, + model.config, + max_batch_size=1, + max_cache_len=effective_cache_len, + dtype=torch_dtype, + ) + logger.info(" HFStaticCache installed successfully") + + from executorch.backends.mlx.llm.quantization import apply_quantization + + apply_quantization( + exportable.model, + quantize_linear, + quantize_embeddings, + tie_word_embeddings=getattr(model.config, "tie_word_embeddings", False) + and not no_tie_word_embeddings, + linear_group_size=linear_group_size, + embeddings_group_size=embeddings_group_size, + ) + + logger.info("Exporting model with torch.export...") + seq_length = 3 + example_input_ids = torch.zeros((1, seq_length), dtype=torch.long) + example_cache_position = torch.arange(seq_length, dtype=torch.long) + + seq_len_dim = torch.export.Dim("seq_length_dim", max=effective_cache_len - 1) + dynamic_shapes = { + "input_ids": {1: seq_len_dim}, + "cache_position": {0: seq_len_dim}, + } + + with torch.no_grad(): + exported_program = torch.export.export( + exportable, + args=(), + kwargs={ + "input_ids": example_input_ids, + "cache_position": example_cache_position, + }, + dynamic_shapes=dynamic_shapes, + strict=True, + ) + + logger.info("Export completed successfully") + for sym, constraint in exported_program.range_constraints.items(): + logger.info(f" Range constraint: {sym}: {constraint}") + + logger.info("Delegating to MLX backend...") + edge_config = EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ) + + edge_program = exir.to_edge_transform_and_lower( + {"forward": exported_program}, + transform_passes=get_default_passes(), + partitioner=[MLXPartitioner()], + compile_config=edge_config, + ) + + logger.info("Exporting to ExecuTorch...") + executorch_program = edge_program.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=True), + ) + ) + + _save_program(executorch_program, output_path) + + +def _save_program(executorch_program, output_path: str) -> None: + """Save the ExecuTorch program to disk.""" + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + with open(output_path, "wb") as f: + f.write(executorch_program.buffer) + + logger.info(f"Saved model to: {output_path}") + logger.info(f"Program size: {len(executorch_program.buffer) / 1024 / 1024:.2f} MB") + + +def export_llama_hf( + model_id: str, + output_path: str, + max_seq_len: int = 1024, + dtype: str = "bf16", + quantize_linear: Optional[str] = None, + quantize_embeddings: Optional[str] = None, + use_custom_sdpa: bool = False, + use_custom_kv_cache: bool = False, + no_tie_word_embeddings: bool = False, + linear_group_size: Optional[int] = None, + embeddings_group_size: Optional[int] = None, +) -> None: + """ + Export a HuggingFace Llama model to ExecuTorch with MLX backend. + + Args: + model_id: HuggingFace model ID + output_path: Path to save the .pte file + max_seq_len: Maximum sequence length for KV cache + dtype: Model dtype ("fp32", "fp16", "bf16") + quantize_linear: Quantization for linear layers ("int4", "int8", or None) + quantize_embeddings: Quantization for embeddings ("int4", "int8", or None) + use_custom_sdpa: Use MLX custom SDPA (mlx::custom_sdpa) + use_custom_kv_cache: Use MLX custom KV cache (mlx::kv_cache_update) + """ + if use_custom_sdpa or use_custom_kv_cache: + logger.info( + f"Using custom components: sdpa={use_custom_sdpa}, " + f"kv_cache={use_custom_kv_cache}" + ) + _export_with_custom_components( + model_id=model_id, + output_path=output_path, + max_seq_len=max_seq_len, + dtype=dtype, + quantize_linear=quantize_linear, + quantize_embeddings=quantize_embeddings, + use_custom_sdpa=use_custom_sdpa, + use_custom_kv_cache=use_custom_kv_cache, + no_tie_word_embeddings=no_tie_word_embeddings, + linear_group_size=linear_group_size, + embeddings_group_size=embeddings_group_size, + ) + else: + logger.info("Using optimum-executorch pipeline (no custom components)") + _export_with_optimum( + model_id=model_id, + output_path=output_path, + max_seq_len=max_seq_len, + dtype=dtype, + quantize_linear=quantize_linear, + quantize_embeddings=quantize_embeddings, + no_tie_word_embeddings=no_tie_word_embeddings, + linear_group_size=linear_group_size, + embeddings_group_size=embeddings_group_size, + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Export HuggingFace Llama model to MLX backend" + ) + parser.add_argument( + "--model-id", + type=str, + default="unsloth/Llama-3.2-1B-Instruct", + help="HuggingFace model ID", + ) + parser.add_argument( + "--output", + type=str, + required=True, + help="Output .pte file path", + ) + parser.add_argument( + "--max-seq-len", + type=int, + default=1024, + help="Maximum sequence length for KV cache", + ) + parser.add_argument( + "--dtype", + type=str, + choices=["fp32", "fp16", "bf16"], + default="bf16", + help="Model dtype", + ) + from executorch.backends.mlx.llm.quantization import add_quantization_args + + add_quantization_args(parser) + parser.add_argument( + "--use-custom-sdpa", + action="store_true", + default=False, + help="Use MLX custom SDPA (mlx::custom_sdpa) for attention", + ) + parser.add_argument( + "--use-custom-kv-cache", + action="store_true", + default=False, + help="Use MLX custom KV cache (mlx::kv_cache_update)", + ) + + args = parser.parse_args() + + export_llama_hf( + model_id=args.model_id, + output_path=args.output, + max_seq_len=args.max_seq_len, + dtype=args.dtype, + quantize_linear=args.quantize_linear, + quantize_embeddings=args.quantize_embeddings, + use_custom_sdpa=args.use_custom_sdpa, + use_custom_kv_cache=args.use_custom_kv_cache, + no_tie_word_embeddings=args.no_tie_word_embeddings, + linear_group_size=args.linear_group_size, + embeddings_group_size=args.embeddings_group_size, + ) + + +if __name__ == "__main__": + main() diff --git a/backends/mlx/examples/llm/run_llm_hf.py b/backends/mlx/examples/llm/run_llm_hf.py new file mode 100644 index 00000000000..ca3d0468114 --- /dev/null +++ b/backends/mlx/examples/llm/run_llm_hf.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Run exported Llama model (from HuggingFace) using ExecuTorch pybindings. + +This script runs models exported using export_llm_hf.py. It loads the tokenizer +directly from HuggingFace using the same model ID used during export. + +Usage: + python -m executorch.backends.mlx.examples.llm.run_llm_hf \ + --pte llama_hf.pte \ + --model-id unsloth/Llama-3.2-1B-Instruct \ + --prompt "Hello, world!" +""" + +import argparse +import logging +import time + +import torch +from executorch.runtime import Runtime, Verification +from transformers import AutoTokenizer + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +def _get_max_input_seq_len(program) -> int: + """Inspect the .pte program metadata to determine the max input_ids seq len. + + Returns the static seq-len dimension of the first input tensor (input_ids). + For models exported with dynamic shapes this will be the upper-bound; for + models exported with a fixed (1,1) shape it will be 1. + """ + meta = program.metadata("forward") + input_ids_info = meta.input_tensor_meta(0) + sizes = input_ids_info.sizes() + # sizes is (batch, seq_len) + return sizes[1] if len(sizes) >= 2 else 1 + + +def run_inference( + pte_path: str, + model_id: str, + prompt: str, + max_new_tokens: int = 50, +) -> str: + """Run inference on the exported HuggingFace model.""" + logger.info(f"Loading tokenizer from HuggingFace: {model_id}...") + tokenizer = AutoTokenizer.from_pretrained(model_id) + + logger.info(f"Loading model from {pte_path}...") + et_runtime = Runtime.get() + program = et_runtime.load_program(pte_path, verification=Verification.Minimal) + + max_seq_len = _get_max_input_seq_len(program) + logger.info(f"Model input_ids max seq len: {max_seq_len}") + + forward = program.load_method("forward") + + logger.info(f"Encoding prompt: {prompt!r}") + messages = [{"role": "user", "content": prompt}] + formatted_prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + input_ids = tokenizer.encode(formatted_prompt, return_tensors="pt") + logger.info(f"Input shape: {input_ids.shape}") + + generated_tokens = input_ids[0].tolist() + seq_len = input_ids.shape[1] + + start_time = time.time() + + if max_seq_len == 1: + # Model was exported with fixed (1,1) input — token-by-token prefill + logger.info(f"Running token-by-token prefill ({seq_len} tokens)...") + for i in range(seq_len): + token_input = input_ids[:, i : i + 1] + cache_position = torch.tensor([i], dtype=torch.long) + outputs = forward.execute([token_input, cache_position]) + logits = outputs[0] + else: + # Model was exported with dynamic seq len — full-prompt prefill + logger.info(f"Running full-prompt prefill ({seq_len} tokens)...") + cache_position = torch.arange(seq_len, dtype=torch.long) + outputs = forward.execute([input_ids, cache_position]) + logits = outputs[0] + + prefill_time = time.time() - start_time + logger.info( + f"Prefill time: {prefill_time:.3f}s " + f"({seq_len / prefill_time:.1f} tokens/sec)" + ) + + # Get the next token from the last position + next_token_logits = logits[0, -1, :] + next_token = torch.argmax(next_token_logits).item() + generated_tokens.append(next_token) + + # Decode: generate tokens one at a time + logger.info(f"Generating up to {max_new_tokens} tokens...") + decode_start = time.time() + + for i in range(max_new_tokens - 1): + pos = len(generated_tokens) - 1 + cache_position = torch.tensor([pos], dtype=torch.long) + token_input = torch.tensor([[next_token]], dtype=torch.long) + + outputs = forward.execute([token_input, cache_position]) + logits = outputs[0] + + next_token_logits = logits[0, -1, :] + next_token = torch.argmax(next_token_logits).item() + generated_tokens.append(next_token) + + if next_token == tokenizer.eos_token_id: + logger.info(f"EOS token reached at position {i + 1}") + break + + decode_time = time.time() - decode_start + num_generated = len(generated_tokens) - seq_len + tokens_per_sec = num_generated / decode_time if decode_time > 0 else 0 + + print(f"\nPrefill time: {prefill_time:.3f}s ({seq_len / prefill_time:.1f} tok/s)") + print( + f"Decode time: {decode_time:.3f}s ({num_generated} tokens, {tokens_per_sec:.1f} tok/s)" + ) + + # Decode only the newly generated tokens (not the input prompt) + new_tokens = generated_tokens[seq_len:] + generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True) + return generated_text + + +def main(): + parser = argparse.ArgumentParser(description="Run exported HuggingFace Llama model") + parser.add_argument( + "--pte", + type=str, + default="llama_hf.pte", + help="Path to the .pte file", + ) + parser.add_argument( + "--model-id", + type=str, + default="unsloth/Llama-3.2-1B-Instruct", + help="HuggingFace model ID (used to load tokenizer)", + ) + parser.add_argument( + "--prompt", + type=str, + default="The quick brown fox", + help="Input prompt", + ) + parser.add_argument( + "--max-new-tokens", + type=int, + default=50, + help="Maximum number of new tokens to generate", + ) + + args = parser.parse_args() + + generated_text = run_inference( + pte_path=args.pte, + model_id=args.model_id, + prompt=args.prompt, + max_new_tokens=args.max_new_tokens, + ) + + print("\n" + "=" * 60) + print("Generated text:") + print("=" * 60) + print(generated_text) + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/backends/mlx/examples/voxtral/__init__.py b/backends/mlx/examples/voxtral/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/mlx/examples/voxtral/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/mlx/examples/voxtral/export_voxtral_hf.py b/backends/mlx/examples/voxtral/export_voxtral_hf.py new file mode 100644 index 00000000000..d2ae68f0d30 --- /dev/null +++ b/backends/mlx/examples/voxtral/export_voxtral_hf.py @@ -0,0 +1,264 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Export Voxtral model from HuggingFace using optimum-executorch, delegated to +the MLX backend. + +Voxtral is a multimodal audio-language model (mistralai/Voxtral-Mini-3B-2507). +The exported .pte contains three methods: + - audio_encoder : mel-spectrogram features → audio embeddings + - token_embedding : token ids → text embeddings + - text_decoder : embeddings + cache_position → next-token logits + +Usage: + python -m executorch.backends.mlx.examples.voxtral.export_voxtral_hf \ + --model-id "mistralai/Voxtral-Mini-3B-2507" \ + --output voxtral_mlx.pte + +Requirements: + pip install transformers torch optimum-executorch mistral-common librosa +""" + +import argparse +import logging +import os +from typing import Optional + +import torch + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +def export_preprocessor( + output_path: str, + feature_size: int = 128, + max_audio_len: int = 300, +) -> None: + """ + Export the Voxtral audio preprocessor (mel spectrogram) to MLX. + + Args: + output_path: Path to save the preprocessor .pte file + feature_size: Mel spectrogram feature dimension (128 for Voxtral) + max_audio_len: Maximum audio length in seconds + """ + import executorch.exir as exir + from executorch.backends.mlx import MLXPartitioner + from executorch.backends.mlx.passes import get_default_passes + from executorch.exir import EdgeCompileConfig + from executorch.exir.capture._config import ExecutorchBackendConfig + from executorch.exir.passes import MemoryPlanningPass + from executorch.extension.audio.mel_spectrogram import WhisperAudioProcessor + from torch.export import Dim + + logger.info("Exporting audio preprocessor with MLX backend...") + + model = WhisperAudioProcessor( + feature_size=feature_size, + max_audio_len=max_audio_len, + stack_output=True, + ) + + audio_tensor = torch.randn(93680) + shapes_collection = torch.export.ShapesCollection() + max_n_chunks = int(model.max_audio_len * model.n_samples) + shapes_collection[audio_tensor] = {0: Dim.DYNAMIC(max=max_n_chunks)} + + with torch.no_grad(), torch.fx.experimental._config.patch( + backed_size_oblivious=True + ): + ep = torch.export.export( + model, (audio_tensor,), dynamic_shapes=shapes_collection, strict=True + ) + + edge_program = exir.to_edge_transform_and_lower( + ep, + transform_passes=get_default_passes(), + partitioner=[MLXPartitioner()], + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + + executorch_program = edge_program.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ) + ) + + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + with open(output_path, "wb") as f: + f.write(executorch_program.buffer) + + logger.info(f"Saved preprocessor to: {output_path}") + logger.info( + f"Preprocessor size: {len(executorch_program.buffer) / 1024 / 1024:.2f} MB" + ) + + +def export_voxtral_hf( + model_id: str, + output_dir: str, + max_seq_len: int = 1024, + dtype: str = "bf16", + quantize_linear: Optional[str] = None, + quantize_embeddings: Optional[str] = None, + linear_group_size: Optional[int] = None, + embeddings_group_size: Optional[int] = None, + max_audio_len: int = 300, +) -> None: + """ + Export a HuggingFace Voxtral model using optimum-executorch, delegated to + the MLX backend. Outputs two files: + - model.pte: the main model (audio_encoder, token_embedding, text_decoder) + - preprocessor.pte: mel spectrogram preprocessor for raw audio + + Args: + model_id: HuggingFace model ID (e.g., "mistralai/Voxtral-Mini-3B-2507") + output_dir: Directory to save the .pte files + max_seq_len: Maximum sequence length for KV cache + dtype: Model dtype ("fp32", "fp16", "bf16") + quantize_linear: Quantization for linear layers ("int4", "int8", or None) + quantize_embeddings: Quantization for embedding layers ("int4", "int8", or None) + linear_group_size: Group size for linear quantization (default: 32 for int4, 128 for int8) + embeddings_group_size: Group size for embedding quantization (default: 32 for int4, 128 for int8) + max_audio_len: Maximum audio length in seconds for preprocessor + """ + from optimum.exporters.executorch.tasks.multimodal_text_to_text import ( + load_multimodal_text_to_text_model, + ) + + os.makedirs(output_dir, exist_ok=True) + + # --- Export preprocessor --- + export_preprocessor( + output_path=os.path.join(output_dir, "preprocessor.pte"), + max_audio_len=max_audio_len, + ) + + # --- Export model --- + logger.info(f"Loading model using optimum-executorch: {model_id}") + + dtype_map = {"fp32": "float32", "fp16": "float16", "bf16": "bfloat16"} + dtype_str = dtype_map.get(dtype, "bfloat16") + + exportable = load_multimodal_text_to_text_model( + model_id, + dtype=dtype_str, + max_seq_len=max_seq_len, + ) + + # Apply quantization if requested + from executorch.backends.mlx.llm.quantization import apply_quantization + + apply_quantization( + exportable.model, + quantize_linear, + quantize_embeddings, + linear_group_size=linear_group_size, + embeddings_group_size=embeddings_group_size, + ) + + logger.info("Exporting model with torch.export...") + exported_progs = exportable.export() + logger.info(f"Exported methods: {list(exported_progs.keys())}") + + logger.info("Delegating to MLX backend...") + import executorch.exir as exir + from executorch.backends.mlx import MLXPartitioner + from executorch.backends.mlx.passes import get_default_passes + from executorch.exir import EdgeCompileConfig + from executorch.exir.capture._config import ExecutorchBackendConfig + from executorch.exir.passes import MemoryPlanningPass + + edge_config = EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ) + + edge_program = exir.to_edge_transform_and_lower( + exported_progs, + transform_passes=get_default_passes(), + partitioner=[MLXPartitioner()], + compile_config=edge_config, + constant_methods=exportable.metadata, + ) + + logger.info("Exporting to ExecuTorch...") + executorch_program = edge_program.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ) + ) + + model_path = os.path.join(output_dir, "model.pte") + with open(model_path, "wb") as f: + f.write(executorch_program.buffer) + + logger.info(f"Saved model to: {model_path}") + logger.info(f"Model size: {len(executorch_program.buffer) / 1024 / 1024:.2f} MB") + + +def main(): + parser = argparse.ArgumentParser( + description="Export HuggingFace Voxtral model using optimum-executorch to MLX" + ) + parser.add_argument( + "--model-id", + type=str, + default="mistralai/Voxtral-Mini-3B-2507", + help="HuggingFace model ID", + ) + parser.add_argument( + "--output-dir", + type=str, + default="voxtral_mlx", + help="Output directory for model.pte and preprocessor.pte", + ) + parser.add_argument( + "--max-seq-len", + type=int, + default=1024, + help="Maximum sequence length for KV cache", + ) + parser.add_argument( + "--dtype", + type=str, + choices=["fp32", "fp16", "bf16"], + default="bf16", + help="Model dtype", + ) + from executorch.backends.mlx.llm.quantization import add_quantization_args + + add_quantization_args(parser) + parser.add_argument( + "--max-audio-len", + type=int, + default=300, + help="Maximum audio length in seconds for preprocessor", + ) + + args = parser.parse_args() + + export_voxtral_hf( + model_id=args.model_id, + output_dir=args.output_dir, + max_seq_len=args.max_seq_len, + dtype=args.dtype, + quantize_linear=args.quantize_linear, + quantize_embeddings=args.quantize_embeddings, + linear_group_size=args.linear_group_size, + embeddings_group_size=args.embeddings_group_size, + max_audio_len=args.max_audio_len, + ) + + +if __name__ == "__main__": + main() diff --git a/backends/mlx/examples/whisper/README.md b/backends/mlx/examples/whisper/README.md new file mode 100644 index 00000000000..ed7333d881d --- /dev/null +++ b/backends/mlx/examples/whisper/README.md @@ -0,0 +1,82 @@ +# Whisper MLX Examples + +Export and run [OpenAI Whisper](https://huggingface.co/openai/whisper-tiny) speech-to-text models on the MLX backend. + +## Scripts + +| Script | Description | +|---|---| +| `export_whisper.py` | Export with custom KV cache wrapper (3 separate `.pte` files) | +| `run_whisper.py` | Run models exported with `export_whisper` | + +## Quick start + +```bash +# Export +python -m executorch.backends.mlx.examples.whisper.export_whisper \ + --model-id openai/whisper-tiny \ + --output-dir /tmp/whisper_mlx + +# Run +python -m executorch.backends.mlx.examples.whisper.run_whisper \ + --model-dir /tmp/whisper_mlx \ + --use-sample-audio +``` + + +## export_whisper.py + +Custom export that splits the model into three programs: + +- **encoder.pte** — audio features → encoder hidden states +- **cross_kv.pte** — encoder hidden states → per-layer cross-attention K/V +- **decoder.pte** — token-by-token generation with self-attention KV cache + +```bash +python -m executorch.backends.mlx.examples.whisper.export_whisper \ + --model-id openai/whisper-tiny \ + --output-dir /tmp/whisper_mlx \ + --quantize-linear int4 +``` + +| Option | Default | Description | +|---|---|---| +| `--model-id` | `openai/whisper-tiny` | HuggingFace model ID | +| `--output-dir` | `whisper_mlx` | Output directory for `.pte` files | +| `--max-decoder-seq-len` | `256` | Maximum decoder sequence length | +| `--dtype` | `bf16` | Model dtype (`fp32`, `fp16`, `bf16`) | +| `--quantize-linear` | `None` | Quantize linear layers (`int4`, `int8`) | +| `--quantize-embeddings` | `None` | Quantize embedding layers (`int4`, `int8`) | +| `--linear-group-size` | `None` | Group size for linear quantization (32, 64, 128; default: 32 for int4, 128 for int8) | +| `--embeddings-group-size` | `None` | Group size for embedding quantization (32, 64, 128; default: 32 for int4, 128 for int8) | + +## run_whisper.py + +Run models exported with `export_whisper.py`. Loads encoder, cross_kv, and +decoder programs from a directory. + +```bash +python -m executorch.backends.mlx.examples.whisper.run_whisper \ + --model-dir /tmp/whisper_mlx \ + --use-sample-audio +``` + +| Option | Default | Description | +|---|---|---| +| `--model-dir` | `/tmp/whisper_mlx` | Directory containing exported `.pte` files | +| `--model-id` | `openai/whisper-tiny` | HuggingFace model ID (used to load processor) | +| `--audio-file` | `None` | Path to audio file (WAV, MP3, etc.) | +| `--use-sample-audio` | `False` | Use sample audio from HuggingFace datasets | +| `--max-new-tokens` | `256` | Maximum tokens to generate | +| `--language` | `en` | Language code | +| `--task` | `transcribe` | `transcribe` or `translate` | +| `--dtype` | `bf16` | Input dtype (must match export dtype) | + +## Requirements + +After installing ExecuTorch, install optimum-executorch: + +```bash +pip install optimum-executorch +pip install transformers torchao soundfile datasets +``` diff --git a/backends/mlx/examples/whisper/__init__.py b/backends/mlx/examples/whisper/__init__.py new file mode 100644 index 00000000000..0adc14c3f18 --- /dev/null +++ b/backends/mlx/examples/whisper/__init__.py @@ -0,0 +1,7 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# diff --git a/backends/mlx/examples/whisper/args.py b/backends/mlx/examples/whisper/args.py new file mode 100644 index 00000000000..82ed7371926 --- /dev/null +++ b/backends/mlx/examples/whisper/args.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Shared argument definitions for Whisper export and run scripts. +""" + +import argparse +import logging +from typing import Optional + +import torch + +logger = logging.getLogger(__name__) + + +def add_export_args(parser: argparse.ArgumentParser) -> None: + """Add common export arguments for Whisper scripts.""" + parser.add_argument( + "--model-id", + type=str, + default="openai/whisper-tiny", + help="HuggingFace model ID", + ) + parser.add_argument( + "--max-decoder-seq-len", + type=int, + default=256, + help="Maximum decoder sequence length", + ) + parser.add_argument( + "--dtype", + type=str, + choices=["fp32", "fp16", "bf16"], + default="bf16", + help="Model dtype", + ) + from executorch.backends.mlx.llm.quantization import add_quantization_args + + add_quantization_args(parser) + + +def add_run_args(parser: argparse.ArgumentParser) -> None: + """Add common runtime arguments for Whisper scripts.""" + parser.add_argument( + "--model-id", + type=str, + default="openai/whisper-tiny", + help="HuggingFace model ID (used to load processor)", + ) + parser.add_argument( + "--audio-file", + type=str, + default=None, + help="Path to audio file (WAV, MP3, etc.)", + ) + parser.add_argument( + "--use-sample-audio", + action="store_true", + help="Use sample audio from HuggingFace datasets", + ) + parser.add_argument( + "--max-new-tokens", + type=int, + default=256, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--language", + type=str, + default="en", + help="Language code for transcription", + ) + parser.add_argument( + "--task", + type=str, + choices=["transcribe", "translate"], + default="transcribe", + help="Task: transcribe or translate", + ) + parser.add_argument( + "--dtype", + type=str, + choices=["fp32", "fp16", "bf16"], + default="bf16", + help="Input dtype (must match the dtype used during export)", + ) + + +def load_audio( + audio_path: Optional[str], + use_sample_audio: bool, + processor, +) -> torch.Tensor: + """Load and preprocess audio input. + + Returns: + input_features: [1, n_mels, n_frames] tensor + """ + if use_sample_audio: + logger.info("Loading sample audio from HuggingFace datasets...") + try: + from datasets import load_dataset + except ImportError: + logger.error("datasets not installed. Run: pip install datasets") + raise + + dataset = load_dataset( + "distil-whisper/librispeech_long", + "clean", + split="validation", + ) + sample = dataset[0]["audio"] + audio_array = sample["array"] + sampling_rate = sample["sampling_rate"] + else: + if audio_path is None: + raise ValueError( + "Either --audio-file or --use-sample-audio must be provided" + ) + + logger.info(f"Loading audio from: {audio_path}") + try: + import soundfile as sf + except ImportError: + logger.error("soundfile not installed. Run: pip install soundfile") + raise + + audio_array, sampling_rate = sf.read(audio_path) + + input_features = processor( + audio_array, + return_tensors="pt", + truncation=False, + sampling_rate=sampling_rate, + ).input_features + + # Truncate to 30 seconds (3000 frames at 100 frames/sec) + max_frames = 3000 + if input_features.shape[2] > max_frames: + logger.info( + f"Truncating audio from {input_features.shape[2]} to {max_frames} frames" + ) + input_features = input_features[:, :, :max_frames].contiguous() + + return input_features diff --git a/backends/mlx/examples/whisper/export_whisper.py b/backends/mlx/examples/whisper/export_whisper.py new file mode 100644 index 00000000000..03123c08935 --- /dev/null +++ b/backends/mlx/examples/whisper/export_whisper.py @@ -0,0 +1,639 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Export Whisper model to MLX delegate using ExecuTorch. + +Exports three separate programs: +- encoder.pte: Audio features → encoder hidden states +- cross_kv.pte: Encoder hidden states → per-layer cross-attention K/V +- decoder.pte: Token-by-token generation with self-attention KV cache + +The decoder uses: +- llama.update_cache for self-attention KV cache updates +- Pre-computed cross-attention K/V passed as inputs + +Usage: + python -m executorch.backends.mlx.examples.whisper.export_whisper \ + --model-id "openai/whisper-tiny" \ + --output-dir /tmp/whisper_mlx \ + --quantize-linear int4 + +Requirements: + pip install transformers torchao +""" + +import argparse +import logging +import os +from typing import Optional, Tuple, TYPE_CHECKING + +if TYPE_CHECKING: + from transformers import WhisperForConditionalGeneration + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# Import shared KV cache module +from executorch.backends.mlx.llm.cache import KVCache +from executorch.backends.mlx.passes import get_default_passes + +# Import custom ops +from executorch.extension.llm.custom_ops import custom_ops # noqa: F401 + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +class WhisperEncoderExportable(nn.Module): + """ + Wrapper around Whisper's encoder for export. + + forward(input_features) -> encoder_hidden_states + """ + + def __init__(self, encoder: nn.Module): + super().__init__() + self.encoder = encoder + + def forward(self, input_features: torch.Tensor) -> torch.Tensor: + return self.encoder(input_features=input_features).last_hidden_state + + +class WhisperSelfAttentionWithCache(nn.Module): + """ + Whisper self-attention layer with static KV cache. + + Uses llama.update_cache pattern for cache updates. + """ + + def __init__( + self, + attn_module: nn.Module, + max_cache_len: int, + dtype: torch.dtype, + ): + super().__init__() + self.q_proj = attn_module.q_proj + self.k_proj = attn_module.k_proj + self.v_proj = attn_module.v_proj + self.out_proj = attn_module.out_proj + + self.num_heads = attn_module.num_heads + self.head_dim = attn_module.head_dim + self.scale = self.head_dim**-0.5 + self.max_cache_len = max_cache_len + + # Initialize KV cache module + self.kv_cache = KVCache( + max_batch_size=1, + max_context_length=max_cache_len, + n_heads=self.num_heads, + head_dim=self.head_dim, + enable_dynamic_shape=True, + dtype=dtype, + ) + + def forward( + self, + hidden_states: torch.Tensor, # [B, T, H*D] + pos_int: int, # Position as SymInt + ) -> torch.Tensor: + B, T, _ = hidden_states.shape + H, D = self.num_heads, self.head_dim + + # Linear projections + q = self.q_proj(hidden_states) # [B, T, H*D] + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # Reshape to [B, H, T, D] + q = q.view(B, T, H, D).transpose(1, 2) + k = k.view(B, T, H, D).transpose(1, 2) + v = v.view(B, T, H, D).transpose(1, 2) + + # Update KV cache + k_cache, v_cache = self.kv_cache.update(pos_int, k, v) + + # Explicit windowing: slice cache to valid positions + end_pos = pos_int + T + k_win = k_cache[:, :, :end_pos, :] + v_win = v_cache[:, :, :end_pos, :] + + # SDPA with causal mask + attn_out = F.scaled_dot_product_attention( + q, k_win, v_win, attn_mask=None, is_causal=True, scale=self.scale + ) + + # Reshape back + attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, H * D) + return self.out_proj(attn_out) + + +class WhisperCrossAttention(nn.Module): + """ + Whisper cross-attention layer. + + K/V are pre-computed from encoder output and passed as inputs. + No cache update needed - just uses the pre-computed K/V directly. + """ + + def __init__(self, attn_module: nn.Module): + super().__init__() + self.q_proj = attn_module.q_proj + self.out_proj = attn_module.out_proj + + self.num_heads = attn_module.num_heads + self.head_dim = attn_module.head_dim + self.scale = self.head_dim**-0.5 + + def forward( + self, + hidden_states: torch.Tensor, # [B, T_dec, H*D] + cross_k: torch.Tensor, # [B, H, T_enc, D] - pre-computed + cross_v: torch.Tensor, # [B, H, T_enc, D] - pre-computed + ) -> torch.Tensor: + B, T, _ = hidden_states.shape + H, D = self.num_heads, self.head_dim + + # Query projection + q = self.q_proj(hidden_states) + q = q.view(B, T, H, D).transpose(1, 2) # [B, H, T_dec, D] + + # SDPA with pre-computed K/V (no causal mask for cross-attention) + attn_out = F.scaled_dot_product_attention( + q, cross_k, cross_v, attn_mask=None, is_causal=False, scale=self.scale + ) + + # Reshape back + attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, H * D) + return self.out_proj(attn_out) + + +class WhisperDecoderLayerWithCache(nn.Module): + """ + Wrapper for a single Whisper decoder layer with KV cache. + """ + + def __init__( + self, + layer: nn.Module, + max_cache_len: int, + dtype: torch.dtype, + ): + super().__init__() + # Self-attention with cache + self.self_attn = WhisperSelfAttentionWithCache( + layer.self_attn, max_cache_len, dtype + ) + self.self_attn_layer_norm = layer.self_attn_layer_norm + + # Cross-attention (K/V passed as inputs) + self.encoder_attn = WhisperCrossAttention(layer.encoder_attn) + self.encoder_attn_layer_norm = layer.encoder_attn_layer_norm + + # FFN + self.fc1 = layer.fc1 + self.fc2 = layer.fc2 + self.final_layer_norm = layer.final_layer_norm + self.activation_fn = layer.activation_fn + + def forward( + self, + hidden_states: torch.Tensor, + pos_int: int, + cross_k: torch.Tensor, + cross_v: torch.Tensor, + ) -> torch.Tensor: + # Self-attention + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn(hidden_states, pos_int) + hidden_states = residual + hidden_states + + # Cross-attention + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + hidden_states = self.encoder_attn(hidden_states, cross_k, cross_v) + hidden_states = residual + hidden_states + + # FFN + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class WhisperDecoderWithCache(nn.Module): + """ + Whisper decoder wrapper with static KV cache. + + Takes: + - decoder_input_ids: [B, T_dec] token IDs + - cache_position: [1] tensor with start position + - cross_k_tuple: tuple of num_layers tensors [B, H, T_enc, D] - pre-computed cross K + - cross_v_tuple: tuple of num_layers tensors [B, H, T_enc, D] - pre-computed cross V + + Returns: + - logits: [B, T_dec, vocab_size] + """ + + def __init__( + self, + model: "WhisperForConditionalGeneration", + max_decoder_seq_len: int, + ): + super().__init__() + + decoder = model.get_decoder() + dtype = decoder.embed_tokens.weight.dtype + + self.embed_tokens = decoder.embed_tokens + self.embed_positions = decoder.embed_positions + self.layer_norm = decoder.layer_norm + self.proj_out = model.proj_out + + # Wrap decoder layers with cache + self.layers = nn.ModuleList( + [ + WhisperDecoderLayerWithCache(layer, max_decoder_seq_len, dtype) + for layer in decoder.layers + ] + ) + + self.num_layers = len(self.layers) + self.max_decoder_seq_len = max_decoder_seq_len + + def forward( + self, + decoder_input_ids: torch.Tensor, # [B, T_dec] + cache_position: torch.Tensor, # [1] tensor + cross_k_tuple: Tuple[torch.Tensor, ...], # num_layers x [B, H, T_enc, D] + cross_v_tuple: Tuple[torch.Tensor, ...], # num_layers x [B, H, T_enc, D] + ) -> torch.Tensor: + B, T = decoder_input_ids.shape + + # Get position as SymInt + torch._check(cache_position.numel() == 1) + pos_int = cache_position.item() + torch._check(pos_int >= 0) + torch._check(pos_int + T <= self.max_decoder_seq_len) + + # Token + positional embeddings + # Whisper uses absolute positions [pos_int, pos_int + T) + # Use F.embedding to ensure proper lowering (not aten.index.Tensor) + positions = torch.arange( + pos_int, pos_int + T, device=decoder_input_ids.device, dtype=torch.long + ) + hidden_states = self.embed_tokens(decoder_input_ids) + pos_embed = F.embedding(positions, self.embed_positions.weight) + hidden_states = hidden_states + pos_embed + + # Decoder layers + for i, layer in enumerate(self.layers): + hidden_states = layer( + hidden_states, pos_int, cross_k_tuple[i], cross_v_tuple[i] + ) + + hidden_states = self.layer_norm(hidden_states) + logits = self.proj_out(hidden_states) + return logits + + +class WhisperCrossKVProjection(nn.Module): + """ + Compute cross-attention K/V projections from encoder hidden states. + + forward(encoder_hidden_states) -> (k_tuple, v_tuple) + """ + + def __init__(self, model: "WhisperForConditionalGeneration"): + super().__init__() + decoder = model.get_decoder() + + # Store K/V projections for each layer + self.k_projs = nn.ModuleList() + self.v_projs = nn.ModuleList() + self.num_heads_list = [] + self.head_dim_list = [] + + for layer in decoder.layers: + self.k_projs.append(layer.encoder_attn.k_proj) + self.v_projs.append(layer.encoder_attn.v_proj) + self.num_heads_list.append(layer.encoder_attn.num_heads) + self.head_dim_list.append(layer.encoder_attn.head_dim) + + def forward( + self, encoder_hidden_states: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]: + """ + Returns: + (k_tuple, v_tuple) where each is a tuple of num_layers tensors, + each with shape [B, H, T_enc, D] + """ + B, T_enc, _ = encoder_hidden_states.shape + + k_list = [] + v_list = [] + + for i, (k_proj, v_proj) in enumerate(zip(self.k_projs, self.v_projs)): + H = self.num_heads_list[i] + D = self.head_dim_list[i] + + k = k_proj(encoder_hidden_states) # [B, T_enc, H*D] + v = v_proj(encoder_hidden_states) + + # Reshape to [B, H, T_enc, D] + k = k.view(B, T_enc, H, D).transpose(1, 2) + v = v.view(B, T_enc, H, D).transpose(1, 2) + + k_list.append(k) + v_list.append(v) + + return tuple(k_list), tuple(v_list) + + +def export_whisper_to_mlx( + model_id: str, + output_dir: str, + max_decoder_seq_len: int = 256, + dtype: str = "bf16", + quantize_linear: Optional[str] = None, + quantize_embeddings: Optional[str] = None, + linear_group_size: Optional[int] = None, + embeddings_group_size: Optional[int] = None, +) -> None: + """ + Export Whisper model components to MLX delegate. + + Exports: + - encoder.pte: Audio encoder + - cross_kv.pte: Cross-attention K/V projection + - decoder.pte: Decoder with self-attention KV cache + + Args: + model_id: HuggingFace model ID + output_dir: Directory to save .pte files + max_decoder_seq_len: Maximum decoder sequence length + dtype: Model dtype ("fp32", "fp16", "bf16") + quantize_linear: Quantization method for linear layers ("int4", "int8", or None) + quantize_embeddings: Quantization method for embedding layers ("int4", "int8", or None) + linear_group_size: Group size for linear quantization. Defaults to 32 for int4, 128 for int8. + embeddings_group_size: Group size for embedding quantization. Defaults to 32 for int4, 128 for int8. + """ + from transformers import AutoProcessor, WhisperForConditionalGeneration + + # Map dtype string to torch dtype + dtype_map = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, + } + torch_dtype = dtype_map.get(dtype, torch.float32) + + logger.info(f"Loading model: {model_id} (dtype={dtype})") + processor = AutoProcessor.from_pretrained(model_id) + model = WhisperForConditionalGeneration.from_pretrained( + model_id, torch_dtype=torch_dtype + ) + model.eval() + + # Create output directory + os.makedirs(output_dir, exist_ok=True) + + # Get feature extractor info + fe = processor.feature_extractor + batch_size = 1 + + # Create example encoder input + encoder_input = torch.zeros( + (batch_size, fe.feature_size, fe.nb_max_frames), dtype=torch_dtype + ) + + # Create wrappers + logger.info("Creating model wrappers...") + encoder_wrapper = WhisperEncoderExportable(model.get_encoder()).eval() + cross_kv_wrapper = WhisperCrossKVProjection(model).eval() + + # Get encoder output shape for decoder + with torch.no_grad(): + encoder_hidden_states = encoder_wrapper(encoder_input) + encoder_seq_len = encoder_hidden_states.shape[1] + + decoder_wrapper = WhisperDecoderWithCache(model, max_decoder_seq_len).eval() + + # Apply quantization if requested + # Whisper has 3 separate wrappers to quantize, and embed_positions must be + # excluded from embedding quantization (accessed via indexing). + if quantize_linear or quantize_embeddings: + from executorch.backends.mlx.llm.quantization import _default_group_size + + try: + from torchao.quantization.granularity import PerGroup + from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ + from torchao.quantization.quantize_.workflows import ( + IntxChooseQParamsAlgorithm, + ) + + qparams_algorithm = IntxChooseQParamsAlgorithm.HQQ_SCALE_ONLY + + if quantize_embeddings: + embed_dtype = ( + torch.int4 if quantize_embeddings == "int4" else torch.int8 + ) + embed_gs = embeddings_group_size or _default_group_size( + quantize_embeddings + ) + logger.info( + f"Quantizing embedding layers with {quantize_embeddings} " + f"(group size {embed_gs})..." + ) + quantize_( + decoder_wrapper, + IntxWeightOnlyConfig( + weight_dtype=embed_dtype, + granularity=PerGroup(embed_gs), + intx_choose_qparams_algorithm=qparams_algorithm, + ), + lambda m, fqn: isinstance(m, nn.Embedding) + and "embed_tokens" in fqn, + ) + + if quantize_linear: + linear_dtype = torch.int4 if quantize_linear == "int4" else torch.int8 + linear_gs = linear_group_size or _default_group_size(quantize_linear) + config = IntxWeightOnlyConfig( + weight_dtype=linear_dtype, + granularity=PerGroup(linear_gs), + intx_choose_qparams_algorithm=qparams_algorithm, + ) + logger.info( + f"Quantizing linear layers with {quantize_linear} " + f"(group size {linear_gs})..." + ) + for module in [encoder_wrapper, cross_kv_wrapper, decoder_wrapper]: + quantize_( + module, + config, + filter_fn=lambda m, fqn: isinstance(m, nn.Linear), + ) + + logger.info("Applied quantization successfully") + except ImportError: + logger.error("TorchAO not installed. Run: pip install torchao") + raise + + logger.info("Exporting encoder...") + + with torch.no_grad(): + encoder_ep = torch.export.export( + encoder_wrapper, (encoder_input,), dynamic_shapes=None, strict=True + ) + encoder_ep = encoder_ep.run_decompositions({}) + + _save_to_pte(encoder_ep, os.path.join(output_dir, "encoder.pte"), "encoder") + + logger.info("Exporting cross-KV projection...") + + with torch.no_grad(): + example_cross_k, example_cross_v = cross_kv_wrapper(encoder_hidden_states) + example_cross_k = tuple(k.contiguous() for k in example_cross_k) + example_cross_v = tuple(v.contiguous() for v in example_cross_v) + + cross_kv_ep = torch.export.export( + cross_kv_wrapper, + (encoder_hidden_states,), + dynamic_shapes=None, + strict=True, + ) + cross_kv_ep = cross_kv_ep.run_decompositions({}) + + _save_to_pte(cross_kv_ep, os.path.join(output_dir, "cross_kv.pte"), "cross_kv") + + logger.info("Exporting decoder...") + + # Example inputs for decoder + start_id = getattr(model.config, "decoder_start_token_id", 0) + decoder_input_ids = torch.tensor([[start_id]], dtype=torch.long) + cache_position = torch.tensor([0], dtype=torch.long) + + with torch.no_grad(): + # Build dynamic shapes for all inputs + # decoder_input_ids: [B, T_dec] - T_dec is dynamic + # cache_position: [1] - static + # cross_k_tuple: tuple of num_layers tensors - static + # cross_v_tuple: tuple of num_layers tensors - static + seq_dim = torch.export.Dim.AUTO(min=1, max=max_decoder_seq_len) + num_layers = decoder_wrapper.num_layers + dynamic_shapes = ( + {1: seq_dim}, # decoder_input_ids + None, # cache_position + tuple(None for _ in range(num_layers)), # cross_k_tuple + tuple(None for _ in range(num_layers)), # cross_v_tuple + ) + + decoder_ep = torch.export.export( + decoder_wrapper, + (decoder_input_ids, cache_position, example_cross_k, example_cross_v), + dynamic_shapes=dynamic_shapes, + strict=True, + ) + decoder_ep = decoder_ep.run_decompositions({}) + + _save_to_pte(decoder_ep, os.path.join(output_dir, "decoder.pte"), "decoder") + + # Save processor for inference + processor_path = os.path.join(output_dir, "processor") + processor.save_pretrained(processor_path) + logger.info(f"Saved processor to: {processor_path}") + + # Save metadata + metadata = { + "model_id": model_id, + "dtype": dtype, + "quantize_linear": quantize_linear, + "quantize_embeddings": quantize_embeddings, + "max_decoder_seq_len": max_decoder_seq_len, + "encoder_seq_len": encoder_seq_len, + "num_decoder_layers": decoder_wrapper.num_layers, + } + import json + + with open(os.path.join(output_dir, "metadata.json"), "w") as f: + json.dump(metadata, f, indent=2) + logger.info(f"Saved metadata to: {os.path.join(output_dir, 'metadata.json')}") + + +def _save_to_pte(ep, output_path: str, name: str) -> None: + """Lower and save an ExportedProgram to a .pte file.""" + import executorch.exir as exir + from executorch.backends.mlx import MLXPartitioner + from executorch.exir import EdgeCompileConfig + from executorch.exir.capture._config import ExecutorchBackendConfig + + # Allow repeat_interleave and sdpa ops + edge_config = EdgeCompileConfig( + _core_aten_ops_exception_list=[ + torch.ops.aten.repeat_interleave.self_int, + torch.ops.aten.scaled_dot_product_attention.default, + ] + ) + + edge_program = exir.to_edge_transform_and_lower( + ep, + transform_passes=get_default_passes(), + partitioner=[MLXPartitioner()], + compile_config=edge_config, + ) + + executorch_program = edge_program.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=True) + ) + + with open(output_path, "wb") as f: + f.write(executorch_program.buffer) + + logger.info( + f"Saved {name} to: {output_path} " + f"({len(executorch_program.buffer) / 1024 / 1024:.2f} MB)" + ) + + +def main(): + parser = argparse.ArgumentParser(description="Export Whisper model to MLX delegate") + from executorch.backends.mlx.examples.whisper.args import add_export_args + + add_export_args(parser) + parser.add_argument( + "--output-dir", + type=str, + default="whisper_mlx", + help="Output directory for .pte files", + ) + + args = parser.parse_args() + + export_whisper_to_mlx( + model_id=args.model_id, + output_dir=args.output_dir, + max_decoder_seq_len=args.max_decoder_seq_len, + dtype=args.dtype, + quantize_linear=args.quantize_linear, + quantize_embeddings=args.quantize_embeddings, + linear_group_size=args.linear_group_size, + embeddings_group_size=args.embeddings_group_size, + ) + + +if __name__ == "__main__": + main() diff --git a/backends/mlx/examples/whisper/run_whisper.py b/backends/mlx/examples/whisper/run_whisper.py new file mode 100644 index 00000000000..e20e7db6e2b --- /dev/null +++ b/backends/mlx/examples/whisper/run_whisper.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Run exported Whisper model using ExecuTorch pybindings. + +This script loads the three exported programs (encoder, cross_kv, decoder) +and performs speech-to-text transcription. + +Usage: + python -m executorch.backends.mlx.examples.whisper.run_whisper \ + --model-dir /tmp/whisper_mlx \ + --audio-file /path/to/audio.wav + + # Or use sample audio from HuggingFace: + python -m executorch.backends.mlx.examples.whisper.run_whisper \ + --model-dir /tmp/whisper_mlx \ + --use-sample-audio + +Requirements: + pip install transformers soundfile datasets +""" + +import argparse +import json +import logging +import os +import time +from typing import List, Optional + +import torch + +from executorch.backends.mlx.examples.whisper.args import load_audio + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +def run_whisper_inference( # noqa: C901 + model_dir: str, + audio_path: Optional[str] = None, + use_sample_audio: bool = False, + max_new_tokens: int = 256, + language: str = "en", + task: str = "transcribe", + dtype: str = "bf16", +) -> str: + """ + Run Whisper inference using exported ExecuTorch models. + + Args: + model_dir: Directory containing encoder.pte, cross_kv.pte, decoder.pte + audio_path: Path to audio file (WAV, MP3, etc.) + use_sample_audio: If True, use sample audio from HuggingFace + max_new_tokens: Maximum number of tokens to generate + language: Language code for transcription + task: "transcribe" or "translate" + dtype: Input dtype (must match the dtype used during export) + + Returns: + Transcribed text + """ + from executorch.runtime import Runtime, Verification + from transformers import AutoProcessor + + # Load metadata (for structural info like num_decoder_layers) + metadata_path = os.path.join(model_dir, "metadata.json") + with open(metadata_path, "r") as f: + metadata = json.load(f) + + num_layers = metadata["num_decoder_layers"] + + # Load processor + processor_path = os.path.join(model_dir, "processor") + logger.info(f"Loading processor from: {processor_path}") + processor = AutoProcessor.from_pretrained(processor_path) + + # Load audio + input_features = load_audio(audio_path, use_sample_audio, processor) + logger.info(f"Input features shape: {input_features.shape}") + + # Cast to model dtype + dtype_map = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} + model_dtype = dtype_map.get(dtype, torch.float32) + input_features = input_features.to(model_dtype) + logger.info(f"Input dtype: {input_features.dtype}") + + # Load ExecuTorch programs + et_runtime = Runtime.get() + + logger.info("Loading encoder...") + encoder_path = os.path.join(model_dir, "encoder.pte") + encoder_program = et_runtime.load_program( + encoder_path, verification=Verification.Minimal + ) + encoder_forward = encoder_program.load_method("forward") + + logger.info("Loading cross_kv...") + cross_kv_path = os.path.join(model_dir, "cross_kv.pte") + cross_kv_program = et_runtime.load_program( + cross_kv_path, verification=Verification.Minimal + ) + cross_kv_forward = cross_kv_program.load_method("forward") + + logger.info("Loading decoder...") + decoder_path = os.path.join(model_dir, "decoder.pte") + decoder_program = et_runtime.load_program( + decoder_path, verification=Verification.Minimal + ) + decoder_forward = decoder_program.load_method("forward") + + logger.info("Running encoder...") + overall_start = time.time() + start_time = time.time() + + encoder_outputs = encoder_forward.execute([input_features]) + encoder_hidden_states = encoder_outputs[0] + + encoder_time = time.time() - start_time + logger.info(f"Encoder time: {encoder_time:.3f}s") + logger.info(f"Encoder output shape: {encoder_hidden_states.shape}") + + logger.info("Computing cross-attention K/V...") + start_time = time.time() + + cross_kv_outputs = cross_kv_forward.execute([encoder_hidden_states]) + # Output is (k_tuple, v_tuple) flattened: [k0, k1, ..., v0, v1, ...] + # Each k_i, v_i has shape [B, H, T_enc, D] + cross_k_tuple = tuple(cross_kv_outputs[:num_layers]) + cross_v_tuple = tuple(cross_kv_outputs[num_layers:]) + + cross_kv_time = time.time() - start_time + logger.info(f"Cross-KV time: {cross_kv_time:.3f}s") + logger.info(f"Cross K/V: {num_layers} layers, each shape {cross_k_tuple[0].shape}") + + # Get forced decoder IDs for language/task + forced_decoder_ids = processor.get_decoder_prompt_ids( + language=language, + task=task, + ) + # Build forced tokens dict: position -> token_id + forced_tokens_dict = {} + if forced_decoder_ids is not None: + for item in forced_decoder_ids: + if isinstance(item, (list, tuple)) and len(item) == 2: + pos, tok_id = item + if tok_id is not None: + forced_tokens_dict[pos] = int(tok_id) + + # Start with decoder_start_token_id (start-of-transcript) + # Get from processor.tokenizer if available, otherwise use common ID + try: + sot_id = processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>") + except Exception: + sot_id = 50258 # Common Whisper SOT token ID + + # Also get EOS token ID + try: + eos_id = processor.tokenizer.convert_tokens_to_ids("<|endoftext|>") + except Exception: + eos_id = 50257 # Common Whisper EOS token ID + + generated_tokens: List[int] = [sot_id] + + logger.info(f"Generating up to {max_new_tokens} tokens...") + decode_start = time.time() + + # Initial decoder input + decoder_input_ids = torch.tensor([[sot_id]], dtype=torch.long) + cache_position = torch.tensor([0], dtype=torch.long) + + # Prefill with initial token + decoder_inputs = ( + [decoder_input_ids, cache_position] + list(cross_k_tuple) + list(cross_v_tuple) + ) + decoder_outputs = decoder_forward.execute(decoder_inputs) + logits = decoder_outputs[0] + + # Update cache position + cache_position = cache_position + decoder_input_ids.shape[1] + + # Generation loop + for _step in range(max_new_tokens): + current_pos = cache_position.item() + + # Check for forced token at this position + if current_pos in forced_tokens_dict: + next_token_id = forced_tokens_dict[current_pos] + else: + next_token_id = torch.argmax(logits[0, -1, :]).item() + + generated_tokens.append(next_token_id) + + # Check for EOS + if next_token_id == eos_id: + break + + # Prepare next decoder input + decoder_input_ids = torch.tensor([[next_token_id]], dtype=torch.long) + + # Run decoder + decoder_inputs = ( + [decoder_input_ids, cache_position] + + list(cross_k_tuple) + + list(cross_v_tuple) + ) + decoder_outputs = decoder_forward.execute(decoder_inputs) + logits = decoder_outputs[0] + + # Update cache position + cache_position = cache_position + 1 + + decode_time = time.time() - decode_start + total_time = time.time() - overall_start + tokens_generated = len(generated_tokens) - 1 # Exclude initial SOT + tokens_per_sec = tokens_generated / decode_time if decode_time > 0 else 0 + + print(f"\nEncoder time: {encoder_time:.3f}s") + print(f"Cross-KV time: {cross_kv_time:.3f}s") + print( + f"Decode time: {decode_time:.3f}s ({tokens_generated} tokens, {tokens_per_sec:.1f} tok/s)" + ) + print(f"Total time: {total_time:.3f}s") + + # Decode to text + transcript = processor.tokenizer.decode( + generated_tokens, + skip_special_tokens=True, + ) + + return transcript + + +def main(): + parser = argparse.ArgumentParser(description="Run exported Whisper model") + from executorch.backends.mlx.examples.whisper.args import add_run_args + + add_run_args(parser) + parser.add_argument( + "--model-dir", + type=str, + default="/tmp/whisper_mlx", + help="Directory containing exported .pte files", + ) + + args = parser.parse_args() + + if not args.audio_file and not args.use_sample_audio: + logger.warning("No audio specified. Using --use-sample-audio") + args.use_sample_audio = True + + transcript = run_whisper_inference( + model_dir=args.model_dir, + audio_path=args.audio_file, + use_sample_audio=args.use_sample_audio, + max_new_tokens=args.max_new_tokens, + language=args.language, + task=args.task, + dtype=args.dtype, + ) + + print("\n" + "=" * 60) + print("Transcript:") + print("=" * 60) + print(transcript) + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/backends/mlx/llm/et_attention.py b/backends/mlx/llm/et_attention.py new file mode 100644 index 00000000000..10c758f94fe --- /dev/null +++ b/backends/mlx/llm/et_attention.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +MLX-optimized attention for ExecutorTorch's Llama attention registry. + +Registers an "mlx" attention type that uses mlx::kv_cache_update and +mlx::custom_sdpa for efficient execution on Apple Silicon. + +Usage: + import executorch.backends.mlx.llm.et_attention # noqa: F401 + + model_args = ModelArgs(attention_type="mlx", ...) + transformer = construct_transformer(model_args) +""" + +from typing import Any, Optional, Tuple, TYPE_CHECKING + +import executorch.backends.mlx.custom_ops as _mlx_custom_ops # noqa: F401 + +import torch +import torch.nn as nn +from executorch.backends.mlx.llm.cache import KVCache +from executorch.examples.models.llama.attention import ( + Attention, + ForwardOptions, + register_attention, +) +from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.norm import RMSNorm +from executorch.examples.models.llama.rope import Rope + +if TYPE_CHECKING: + from executorch.examples.models.llama.attention import AttentionMHA + + +@register_attention("mlx") +class MLXAttentionMHA(Attention): + """ + MLX-optimized attention using mlx::kv_cache_update and mlx::custom_sdpa. + + Supports MHA, GQA, KV caching, and optional QK normalization. + Follows the same interface as AttentionMHA. + """ + + def __init__( + self, + args: ModelArgs, + layer_id: int, + rope: Rope, + **_kwargs: Any, + ): + super().__init__() + if not args.use_kv_cache: + raise ValueError("MLXAttention requires use_kv_cache=True") + + self.use_kv_cache = True + self.n_heads = args.n_heads + self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads + assert self.n_heads % self.n_kv_heads == 0 + model_parallel_size = 1 + self.n_local_heads = self.n_heads // model_parallel_size + self.n_local_kv_heads = self.n_kv_heads // model_parallel_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.head_dim + self.max_batch_size = args.max_batch_size + self.max_context_len = args.max_context_len + self.dim = args.dim + self.attention_qkv_bias = args.attention_qkv_bias + self.use_qk_norm = args.use_qk_norm + self.qk_norm_before_rope = args.qk_norm_before_rope + self.enable_dynamic_shape = args.enable_dynamic_shape + + if self.use_qk_norm: + self.q_norm_fn = RMSNorm(self.head_dim, eps=args.norm_eps) + self.k_norm_fn = RMSNorm(self.head_dim, eps=args.norm_eps) + + self.wq = nn.Linear( + self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias + ) + self.wk = nn.Linear( + self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias + ) + self.wv = nn.Linear( + self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias + ) + self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) + + self.layer_id = layer_id + self.rope = rope + self.rope_base = rope.params.rope_freq_base + self.use_fused_rope = self._can_use_fused_rope(rope.params) + self.rope_traditional = not rope.params.use_hf_rope + self.rope_dims = int(self.head_dim * rope.params.partial_rotary_factor) + + self.kv_cache = KVCache( + max_batch_size=args.max_batch_size, + max_context_length=args.max_context_len, + n_heads=self.n_kv_heads, + head_dim=self.head_dim, + enable_dynamic_shape=args.enable_dynamic_shape, + ) + + @staticmethod + def _can_use_fused_rope(params: ModelArgs) -> bool: + if params.no_rope_layer_interval is not None: + return False + return True + + @classmethod + def from_attention_mha( + cls, other: "AttentionMHA", dtype: Optional[torch.dtype] = None + ) -> "MLXAttentionMHA": + """ + Create an MLXAttentionMHA from an existing AttentionMHA. + + Shares weight references (wq, wk, wv, wo, rope, norm) and creates + a fresh KVCache. + """ + from executorch.examples.models.llama.attention import AttentionMHA + + assert isinstance(other, AttentionMHA) + + instance = cls.__new__(cls) + Attention.__init__(instance) + + # Copy all config attributes + instance.use_kv_cache = True + instance.n_heads = other.n_heads + instance.n_kv_heads = other.n_kv_heads + instance.n_local_heads = other.n_local_heads + instance.n_local_kv_heads = other.n_local_kv_heads + instance.n_rep = other.n_rep + instance.head_dim = other.head_dim + instance.max_batch_size = other.max_batch_size + instance.max_context_len = other.max_context_len + instance.dim = other.dim + instance.attention_qkv_bias = other.attention_qkv_bias + instance.use_qk_norm = other.use_qk_norm + instance.qk_norm_before_rope = other.qk_norm_before_rope + instance.enable_dynamic_shape = other.enable_dynamic_shape + + # Share weight references + instance.wq = other.wq + instance.wk = other.wk + instance.wv = other.wv + instance.wo = other.wo + instance.layer_id = other.layer_id + instance.rope = other.rope + instance.rope_base = other.rope.params.rope_freq_base + instance.use_fused_rope = cls._can_use_fused_rope(other.rope.params) + instance.rope_traditional = not other.rope.params.use_hf_rope + instance.rope_dims = int( + instance.head_dim * other.rope.params.partial_rotary_factor + ) + + if other.use_qk_norm: + instance.q_norm_fn = other.q_norm_fn + instance.k_norm_fn = other.k_norm_fn + + # Create fresh MLX KV cache + cache_dtype = dtype if dtype is not None else torch.float32 + if hasattr(other, "kv_cache") and hasattr(other.kv_cache, "k_cache"): + cache_dtype = dtype if dtype is not None else other.kv_cache.k_cache.dtype + instance.kv_cache = KVCache( + max_batch_size=other.max_batch_size, + max_context_length=other.max_context_len, + n_heads=instance.n_kv_heads, + head_dim=instance.head_dim, + enable_dynamic_shape=other.enable_dynamic_shape, + dtype=cache_dtype, + ) + + return instance + + def forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + **kwargs: ForwardOptions, + ) -> Tuple[torch.Tensor, Optional[Any]]: + input_pos = kwargs.get("input_pos") + assert input_pos is not None + bsz, seqlen, _ = x.shape + + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + + if self.use_qk_norm and self.qk_norm_before_rope: + q = self.q_norm_fn(q) + k = self.k_norm_fn(k) + + if "start_pos" in kwargs: + start_pos = kwargs["start_pos"] + else: + start_pos = input_pos[0].item() + + if self.use_fused_rope: + # Transpose to BHSD first (mlx::rope expects BHSD) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + q = torch.ops.mlx.rope( + q, + self.rope_dims, + start_pos, + self.rope_traditional, + self.rope_base, + 1.0, + None, + ) + k = torch.ops.mlx.rope( + k, + self.rope_dims, + start_pos, + self.rope_traditional, + self.rope_base, + 1.0, + None, + ) + else: + # Fallback: upstream rope (handles scaled rope, partial rotary, etc.) + q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + if self.use_qk_norm and not self.qk_norm_before_rope: + q = self.q_norm_fn(q) + k = self.k_norm_fn(k) + k, v = self.kv_cache.update(start_pos, k, v) + + output = torch.ops.mlx.custom_sdpa( + q, + k, + v, + start_pos=start_pos, + is_causal=True, + scale=self.head_dim**-0.5, + ) + + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + return self.wo(output), None diff --git a/backends/mlx/llm/hf_attention.py b/backends/mlx/llm/hf_attention.py new file mode 100644 index 00000000000..9e3c864dce6 --- /dev/null +++ b/backends/mlx/llm/hf_attention.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +MLX-optimized attention for HuggingFace models. + +Registers a custom attention implementation ("mlx") with HuggingFace's +attention interface, following the same pattern as optimum-executorch's +custom_sdpa: + +1. Mask function returns None (custom op handles causal masking internally) +2. Attention function extracts start_pos from position_ids[0][0] +3. mlx::custom_sdpa receives full K/V cache + start_pos, slices K/V internally +4. MLX pattern handler serializes custom_sdpa as SliceNode(K), SliceNode(V), SdpaNode + +Usage: + from executorch.backends.mlx.llm.hf_attention import register_mlx_attention + + register_mlx_attention() + + model = AutoModelForCausalLM.from_pretrained( + model_id, + attn_implementation="mlx", + ) +""" + +from typing import Callable, Optional, Tuple, Union + +import executorch.backends.mlx.custom_ops as _mlx_custom_ops # noqa: F401 + +import torch + + +def mlx_sdpa_with_start_pos_forward( + module: torch.nn.Module, + query: torch.Tensor, # [B, num_heads, seq_len, head_dim] - BHSD + key: torch.Tensor, # [B, num_kv_heads, kv_len, head_dim] - BHSD (full cache) + value: torch.Tensor, # [B, num_kv_heads, kv_len, head_dim] - BHSD (full cache) + attention_mask: Union[torch.Tensor, "BlockMask"], # noqa: F821 + position_ids: Optional[torch.Tensor] = None, + scaling: Optional[float] = None, + softcap: Optional[float] = None, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, None]: + """ + MLX-optimized SDPA following optimum-executorch's custom_sdpa pattern. + + Extracts start_pos from position_ids, then delegates to mlx::custom_sdpa + which handles K/V cache slicing, GQA expansion, and causal masking. + + Returns (output, None) where output is [B, seq_len, num_heads, head_dim] (BSHD). + """ + kwargs.pop("is_causal", None) + is_causal = getattr(module, "is_causal", True) + + if is_causal: + assert ( + position_ids is not None + ), "position_ids must be provided to find start position for causal attention" + start_pos = position_ids[0][0].item() + seq_len = query.shape[2] + torch._check(start_pos >= 0) + torch._check(start_pos + seq_len <= key.shape[2]) + attn_mask = None + else: + start_pos = 0 + attn_mask = attention_mask + + output = torch.ops.mlx.custom_sdpa( + query, + key, + value, + start_pos=start_pos, + attn_mask=attn_mask, + dropout_p=0.0, + is_causal=is_causal, + scale=scaling, + ) + + # Transpose BHSD → BSHD for HF + return output.transpose(1, 2).contiguous(), None + + +def sdpa_mask_passthrough( + batch_size: int, + cache_position: torch.Tensor, + kv_length: int, + kv_offset: int = 0, + mask_function: Optional[Callable] = None, + attention_mask: Optional[torch.Tensor] = None, + local_size: Optional[int] = None, + allow_is_causal_skip: bool = True, + allow_torch_fix: bool = True, + **kwargs, +) -> Optional[torch.Tensor]: + """Returns None — custom SDPA handles causal masking, avoiding bounded mask tensors.""" + return None + + +def register_mlx_attention(name: str = "mlx") -> None: + """ + Register MLX attention with HuggingFace's attention interfaces. + + After registration, models can use MLX attention via: + model = AutoModelForCausalLM.from_pretrained(..., attn_implementation="mlx") + """ + try: + from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + ALL_ATTENTION_FUNCTIONS.register(name, mlx_sdpa_with_start_pos_forward) + ALL_MASK_ATTENTION_FUNCTIONS.register(name, sdpa_mask_passthrough) + + except ImportError: + raise ImportError( + "transformers is not installed. Please install it: pip install transformers" + ) + + +def get_mlx_sliding_window_sdpa(exportable_module) -> Callable: + """ + Create a closure-based SDPA function for sliding window attention. + + Following optimum-executorch's pattern, the returned function captures + the model reference so it can access ring buffer caches at runtime to + create attention masks lazily — avoiding torch.export tracing issues. + + Args: + exportable_module: The model module containing .cache (HFStaticCache + or similar) with ring buffer layers accessible via .kv_cache[layer_idx]. + + Returns: + Attention function compatible with HuggingFace's attention interface. + """ + + def _sliding_window_sdpa_forward( + module: torch.nn.Module, + query: torch.Tensor, # [B, num_heads, seq_len, head_dim] - BHSD + key: torch.Tensor, # [B, num_kv_heads, window_size, head_dim] - BHSD + value: torch.Tensor, # [B, num_kv_heads, window_size, head_dim] - BHSD + attention_mask: Union[torch.Tensor, "BlockMask"], # noqa: F821 + position_ids: Optional[torch.Tensor] = None, + scaling: Optional[float] = None, + **kwargs, + ) -> Tuple[torch.Tensor, None]: + """ + MLX sliding window SDPA using ring buffer KV cache. + + Creates the attention mask lazily by reaching into the ring buffer + cache via the captured model reference. This keeps mask creation + in Python (not in the traced graph). + + Uses is_causal=False since the mask handles both causality and windowing. + """ + from executorch.backends.mlx.llm.cache import RingBufferKVCache + + layer_idx = getattr(module, "layer_idx", None) + seq_len = query.shape[2] + attn_mask = None + start_pos = 0 + + if layer_idx is not None and position_ids is not None: + start_pos = position_ids[0][0].item() + + # Reach into the model's cache to find the ring buffer for this layer. + # TorchExportableModuleWithHybridCache stores .cache (standard path). + cache = getattr(exportable_module, "cache", None) + + if cache is not None: + layer_cache = cache.kv_cache[layer_idx] + if isinstance(layer_cache, RingBufferKVCache): + attn_mask = layer_cache.create_sliding_window_mask( + start_pos, seq_len + ) + # Override start_pos so custom_sdpa slices the full buffer: + # stop_pos = start_pos + seq_len = buffer_size + start_pos = layer_cache.buffer_size - seq_len + + if attn_mask is None: + raise RuntimeError( + f"Sliding window attention at layer {layer_idx} requires a " + f"RingBufferKVCache, but none was found. Ensure the model's " + f"cache is set up with RingBufferKVCache for sliding window layers." + ) + + output = torch.ops.mlx.custom_sdpa( + query, + key, + value, + start_pos=start_pos, + attn_mask=attn_mask, + dropout_p=0.0, + is_causal=False, + scale=scaling, + ) + + # Transpose BHSD → BSHD for HF + return output.transpose(1, 2).contiguous(), None + + return _sliding_window_sdpa_forward + + +def register_mlx_sliding_window_attention( + exportable_module, name: str = "mlx_sliding_window" +) -> None: + """Register MLX sliding window attention with HuggingFace's attention interfaces.""" + try: + from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + sdpa_fn = get_mlx_sliding_window_sdpa(exportable_module) + ALL_ATTENTION_FUNCTIONS.register(name, sdpa_fn) + ALL_MASK_ATTENTION_FUNCTIONS.register(name, sdpa_mask_passthrough) + + except ImportError: + raise ImportError( + "transformers is not installed. Please install it: pip install transformers" + ) diff --git a/backends/mlx/llm/quantization.py b/backends/mlx/llm/quantization.py new file mode 100644 index 00000000000..0fdb988f9fa --- /dev/null +++ b/backends/mlx/llm/quantization.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Shared quantization utilities for MLX LLM export scripts. +""" + +import argparse +import logging +from typing import Optional + +import torch + +logger = logging.getLogger(__name__) + + +def add_quantization_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--quantize-linear", + type=str, + choices=["int4", "int8"], + default=None, + help="Quantization method for linear layers", + ) + parser.add_argument( + "--quantize-embeddings", + type=str, + choices=["int4", "int8"], + default=None, + help="Quantization method for embedding layers", + ) + parser.add_argument( + "--linear-group-size", + type=int, + choices=[32, 64, 128], + default=None, + help="Group size for linear layer quantization (default: 32 for int4, 128 for int8)", + ) + parser.add_argument( + "--embeddings-group-size", + type=int, + choices=[32, 64, 128], + default=None, + help="Group size for embedding layer quantization (default: 32 for int4, 128 for int8)", + ) + parser.add_argument( + "--no-tie-word-embeddings", + action="store_true", + default=False, + help="Disable tying lm_head weights to embedding after quantization, " + "even if the model config has tie_word_embeddings=True", + ) + + +def _default_group_size(dtype_str: str) -> int: + return 32 if dtype_str == "int4" else 128 + + +def apply_quantization( + model: torch.nn.Module, + quantize_linear: Optional[str], + quantize_embeddings: Optional[str], + tie_word_embeddings: bool = False, + linear_group_size: Optional[int] = None, + embeddings_group_size: Optional[int] = None, +) -> None: + """Apply TorchAO quantization to the model. + + Uses the HQQ (Half-Quadratic Quantization) scale-only algorithm for + choosing quantization parameters. + + Args: + model: The model to quantize. Expected to have model.model.embed_tokens + and model.lm_head attributes for weight tying. + quantize_linear: Quantization method for linear layers ("int4", "int8", or None) + quantize_embeddings: Quantization method for embedding layers ("int4", "int8", or None) + tie_word_embeddings: If True, re-tie lm_head.weight to embed_tokens.weight + after quantization. Should be set from the HF model config's + tie_word_embeddings field, and can be overridden with --no-tie-word-embeddings. + linear_group_size: Group size for linear quantization. Defaults to 32 for int4, 128 for int8. + embeddings_group_size: Group size for embedding quantization. Defaults to 32 for int4, 128 for int8. + """ + if not quantize_linear and not quantize_embeddings: + return + + logger.info("Applying quantization with TorchAO...") + try: + from torchao.quantization.granularity import PerGroup + from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ + from torchao.quantization.quantize_.workflows import IntxChooseQParamsAlgorithm + + qparams_algorithm = IntxChooseQParamsAlgorithm.HQQ_SCALE_ONLY + + if quantize_embeddings: + embed_dtype = torch.int4 if quantize_embeddings == "int4" else torch.int8 + embed_group_size = embeddings_group_size or _default_group_size( + quantize_embeddings + ) + logger.info( + f"Quantizing embedding layers with {quantize_embeddings} " + f"(group size {embed_group_size})..." + ) + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=embed_dtype, + granularity=PerGroup(embed_group_size), + intx_choose_qparams_algorithm=qparams_algorithm, + ), + filter_fn=lambda m, fqn: isinstance(m, torch.nn.Embedding), + ) + + if quantize_linear: + linear_dtype = torch.int4 if quantize_linear == "int4" else torch.int8 + linear_group_size = linear_group_size or _default_group_size( + quantize_linear + ) + logger.info( + f"Quantizing linear layers with {quantize_linear} " + f"(group size {linear_group_size})..." + ) + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=linear_dtype, + granularity=PerGroup(linear_group_size), + intx_choose_qparams_algorithm=qparams_algorithm, + ), + filter_fn=lambda m, fqn: isinstance(m, torch.nn.Linear), + ) + + if ( + tie_word_embeddings + and hasattr(model, "lm_head") + and hasattr(model, "model") + ): + embed = getattr(model.model, "embed_tokens", None) + if embed is not None: + model.lm_head.weight = embed.weight + logger.info( + "Re-tied lm_head weights to embedding (tie_word_embeddings=True)" + ) + + logger.info("Applied quantization successfully") + except ImportError: + logger.error("TorchAO not installed. Run: pip install torchao") + raise diff --git a/backends/mlx/llm/source_transformation.py b/backends/mlx/llm/source_transformation.py new file mode 100644 index 00000000000..d90073c633e --- /dev/null +++ b/backends/mlx/llm/source_transformation.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Source transformations for MLX backend export. + +Provides transforms that replace standard model components with MLX-optimized +versions +""" + +import logging +from typing import Callable + +import torch +import torch.nn as nn + +from executorch.backends.mlx.llm.cache import HFStaticCache, KVCache, RingBufferKVCache + +logger = logging.getLogger(__name__) + + +def _replace_modules( + module: nn.Module, + target_type: type, + factory: Callable[[nn.Module], nn.Module], + label: str, +) -> nn.Module: + """Recursively replace all instances of target_type using factory.""" + + def _recurse(parent: nn.Module) -> int: + count = 0 + for name, child in list(parent.named_children()): + if isinstance(child, target_type): + setattr(parent, name, factory(child)) + count += 1 + else: + count += _recurse(child) + return count + + count = _recurse(module) + if count > 0: + logger.info(f"Replaced {count} {label}") + return module + + +def replace_et_kv_cache_with_mlx( + module: nn.Module, dtype: torch.dtype = None +) -> nn.Module: + """ + Replace ET's KVCache with MLX-optimized KVCache. + + Recursively finds all KVCache instances (from examples/models/llama/attention.py) + and replaces them with KVCache, which uses mlx::kv_cache_update instead of + unsupported index_put operations. + + Args: + module: Model to modify (in place) + dtype: Optional dtype for cache tensors. If None, uses original cache dtype. + """ + try: + from executorch.examples.models.llama.attention import ( + KVCache as ETKVCache_Original, + ) + except ImportError: + return module + + def _make_mlx_cache(child): + cache_dtype = dtype if dtype is not None else child.k_cache.dtype + return KVCache( + max_batch_size=child.max_batch_size, + max_context_length=child.max_context_length, + n_heads=child.n_heads, + head_dim=child.head_dim, + enable_dynamic_shape=child.enable_dynamic_shape, + dtype=cache_dtype, + ) + + return _replace_modules( + module, + ETKVCache_Original, + _make_mlx_cache, + f"KVCache → KVCache (dtype={dtype})", + ) + + +def replace_hf_cache_with_mlx( + module: nn.Module, + config, + max_batch_size: int = 1, + max_cache_len: int | None = None, + dtype: torch.dtype = torch.float32, +) -> nn.Module: + """ + Replace HuggingFace's StaticCache with MLX-optimized HFStaticCache. + + Should be called on TorchExportableModuleWithStaticCache (from + transformers.integrations.executorch), NOT on CausalLMExportableModule + (from optimum-executorch). + + Args: + module: HF exportable module with static_cache or cache attribute + config: HF model config + max_batch_size: Maximum batch size (default: 1) + max_cache_len: Maximum cache length. If None, uses config.max_position_embeddings + dtype: Cache tensor dtype (default: torch.float32) + + Raises: + ValueError: If module has no recognized cache attribute + """ + from transformers.cache_utils import StaticCache + + mlx_cache = HFStaticCache( + config=config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + dtype=dtype, + ) + + def _install_cache(attr_name): + setattr(module, attr_name, mlx_cache) + for i, layer_cache in enumerate(mlx_cache.kv_cache): + setattr(module, f"key_cache_{i}", layer_cache.k_cache) + setattr(module, f"value_cache_{i}", layer_cache.v_cache) + + if hasattr(module, "static_cache"): + assert isinstance( + module.static_cache, StaticCache + ), f"Expected StaticCache, got {type(module.static_cache)}" + _install_cache("static_cache") + elif hasattr(module, "cache"): + if isinstance(module.cache, StaticCache): + _install_cache("cache") + else: + raise ValueError( + f"module.cache is not a StaticCache, got {type(module.cache)}" + ) + else: + raise ValueError("Module must have 'static_cache' or 'cache' attribute") + + return module + + +def replace_hf_cache_with_mlx_ring_buffer( + module: nn.Module, + config, + max_batch_size: int = 1, + window_size: int = 512, + dtype: torch.dtype = torch.float32, +) -> nn.Module: + """ + Replace HuggingFace's StaticCache with RingBufferKVCache for sliding window models. + + Creates a HFStaticCache-like structure where each layer uses a RingBufferKVCache + instead of a linear KVCache. This enables infinite-length generation for models + with sliding window attention (e.g., gemma). + + Args: + module: HF exportable module with static_cache or cache attribute + config: HF model config + max_batch_size: Maximum batch size (default: 1) + window_size: Sliding window size (cache capacity per layer) + dtype: Cache tensor dtype + + Raises: + ValueError: If module has no recognized cache attribute + """ + from transformers.cache_utils import StaticCache + + num_layers = config.num_hidden_layers + num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) + head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + + # Create HFStaticCache with ring buffer layers + mlx_cache = HFStaticCache( + config=config, + max_batch_size=max_batch_size, + max_cache_len=window_size, + dtype=dtype, + ) + + # Replace each layer's KVCache with RingBufferKVCache + for i in range(num_layers): + ring_cache = RingBufferKVCache( + max_batch_size=max_batch_size, + max_context_length=window_size, + n_heads=num_kv_heads, + head_dim=head_dim, + dtype=dtype, + ) + mlx_cache.kv_cache[i] = ring_cache + + def _install_cache(attr_name): + setattr(module, attr_name, mlx_cache) + for i, layer_cache in enumerate(mlx_cache.kv_cache): + setattr(module, f"key_cache_{i}", layer_cache.k_cache) + setattr(module, f"value_cache_{i}", layer_cache.v_cache) + + if hasattr(module, "static_cache"): + assert isinstance( + module.static_cache, StaticCache + ), f"Expected StaticCache, got {type(module.static_cache)}" + _install_cache("static_cache") + elif hasattr(module, "cache"): + if isinstance(module.cache, StaticCache): + _install_cache("cache") + else: + raise ValueError( + f"module.cache is not a StaticCache, got {type(module.cache)}" + ) + else: + raise ValueError("Module must have 'static_cache' or 'cache' attribute") + + logger.info( + f"Installed RingBufferKVCache: {num_layers} layers, " + f"window_size={window_size}, heads={num_kv_heads}, head_dim={head_dim}" + ) + + return module + + +class MLXRope(nn.Module): + """ + MLX-optimized Rotary Position Embedding. + + Wraps ET's Rope, currently delegating to the original implementation. + Can be extended to use torch.ops.mlx.rope. + """ + + def __init__(self, original_rope: nn.Module): + super().__init__() + self.params = original_rope.params + self.precompute_freqs_cis = original_rope.precompute_freqs_cis + self.apply_rotary_emb = original_rope.apply_rotary_emb + self.register_buffer("freqs_cos", original_rope.freqs_cos, persistent=False) + self.register_buffer("freqs_sin", original_rope.freqs_sin, persistent=False) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + return self.apply_rotary_emb(q, k, freqs_cos, freqs_sin) + + def get_freqs(self, input_pos, seq_len: int): + if self.params.use_kv_cache: + assert input_pos is not None + if self.params.enable_dynamic_shape: + input_pos_item = input_pos[-1].item() + torch._check(input_pos_item >= 0) + torch._check(input_pos_item < self.params.max_context_len) + freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len) + freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len) + else: + freqs_cos = self.freqs_cos[input_pos] + freqs_sin = self.freqs_sin[input_pos] + else: + assert input_pos is None + freqs_cos = self.freqs_cos[:seq_len] + freqs_sin = self.freqs_sin[:seq_len] + return freqs_cos, freqs_sin + + +def transform_attention_mha_to_mlx( + module: nn.Module, dtype: torch.dtype = None +) -> nn.Module: + """ + Replace AttentionMHA with MLXAttentionMHA throughout the model. + + Shares weight references (wq, wk, wv, wo, rope, norm) from the original + and creates a fresh KVCache for each attention layer. + + Args: + module: Model to modify (in place) + dtype: Optional dtype for KV cache. If None, inferred from original. + """ + from executorch.backends.mlx.llm.et_attention import MLXAttentionMHA + from executorch.examples.models.llama.attention import AttentionMHA + + _replace_modules( + module, + AttentionMHA, + lambda child: MLXAttentionMHA.from_attention_mha(child, dtype=dtype), + f"AttentionMHA → MLXAttentionMHA (cache dtype={dtype})", + ) + return module diff --git a/examples/models/llama/CMakeLists.txt b/examples/models/llama/CMakeLists.txt index a61d43f626e..6d5b5cc2566 100644 --- a/examples/models/llama/CMakeLists.txt +++ b/examples/models/llama/CMakeLists.txt @@ -107,8 +107,13 @@ else() endif() # quantized_ops_lib: Register quantized op kernels into the runtime -executorch_target_link_options_shared_lib(quantized_ops_lib) -list(APPEND link_libraries quantized_kernels quantized_ops_lib) +if(TARGET quantized_ops_lib) + list(APPEND link_libraries quantized_kernels quantized_ops_lib) + get_target_property(_quantized_imported quantized_ops_lib IMPORTED) + if(NOT _quantized_imported) + executorch_target_link_options_shared_lib(quantized_ops_lib) + endif() +endif() if(TARGET custom_ops) executorch_target_link_options_shared_lib(custom_ops) @@ -198,6 +203,12 @@ if(TARGET mpsdelegate) executorch_target_link_options_shared_lib(mpsdelegate) endif() +# MLX backend +if(TARGET mlxdelegate) + list(APPEND link_libraries mlxdelegate mlx) + executorch_target_link_options_shared_lib(mlxdelegate) +endif() + # Openvino backend if(TARGET openvino_backend) find_package(OpenVINO REQUIRED) @@ -226,6 +237,11 @@ endif() add_executable(llama_main ${_srcs}) +# Copy MLX metallib for runtime if MLX delegate is enabled +if(TARGET mlxdelegate) + executorch_target_copy_mlx_metallib(llama_main) +endif() + # Only strip symbols for Release and MinSizeRel builds. if(CMAKE_BUILD_TYPE STREQUAL "Release" OR CMAKE_BUILD_TYPE STREQUAL "MinSizeRel" diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index dd7aafdd024..3bc2dfc3952 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -496,6 +496,12 @@ def build_args_parser() -> argparse.ArgumentParser: help="Specify the device for Openvino (CPU, GPU or NPU).", ) + parser.add_argument( + "--mlx", + action="store_true", + help="Delegate to MLX backend (Apple Silicon). Use with --use_kv_cache=True.", + ) + parser.add_argument( "--expand_rope_table", default=False, @@ -766,6 +772,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager: coreml=llm_config.backend.coreml.enabled, coreml_ios=llm_config.backend.coreml.ios, vulkan=llm_config.backend.vulkan.enabled, + mlx=llm_config.backend.mlx.enabled, use_qat=llm_config.quantization.use_qat, use_lora=llm_config.base.use_lora, preq_mode=( @@ -1044,6 +1051,34 @@ def _to_edge_and_lower_llama_arm( return builder.to_executorch(passes=additional_passes) +def _to_edge_and_lower_llama_mlx( + builder_exported, + modelname, + quantizers, + additional_passes, + verbose: bool = False, +) -> LLMEdgeManager: + """ + Lower Llama model to MLX backend using to_edge_transform_and_lower. + """ + logging.info("Lowering model using MLX partitioner") + + from executorch.backends.mlx.partitioner import MLXPartitioner + from executorch.backends.mlx.passes import get_default_passes + + partitioners = [MLXPartitioner()] + + builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower( + partitioners, + transform_passes=get_default_passes(), + ) + + if verbose: + print_delegation_info(builder.edge_manager.exported_program().graph_module) + + return builder.to_executorch(passes=additional_passes) + + def _to_edge_and_lower_llama( # noqa: C901 builder_exported, modelname, @@ -1420,6 +1455,14 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 llm_config, verbose=llm_config.debug.verbose, ) + elif llm_config.backend.mlx.enabled: + builder = _to_edge_and_lower_llama_mlx( + builder_exported, + modelname, + quantizers, + additional_passes, + verbose=llm_config.debug.verbose, + ) else: builder = _to_edge_and_lower_llama( builder_exported, @@ -1597,6 +1640,7 @@ def _get_source_transforms( # noqa coreml: bool = False, coreml_ios: int = 15, vulkan: bool = False, + mlx: bool = False, use_qat: bool = False, use_lora: int = 0, preq_mode: Optional[str] = None, @@ -1774,6 +1818,19 @@ def _get_source_transforms( # noqa transforms.append(replace_sdpa_with_simple_sdpa) transforms.append(replace_kv_cache_with_coreml_kv_cache) + elif mlx: + from executorch.backends.mlx.llm.source_transformation import ( + replace_et_kv_cache_with_mlx, + transform_attention_mha_to_mlx, + ) + from executorch.examples.models.llama.source_transformation.rms_norm import ( + replace_rms_norm_with_native_rms_norm, + ) + + transforms.append(transform_attention_mha_to_mlx) + transforms.append(replace_et_kv_cache_with_mlx) + transforms.append(replace_rms_norm_with_native_rms_norm) + if local_global_attention: transforms.append( partial( diff --git a/examples/models/parakeet/CMakeLists.txt b/examples/models/parakeet/CMakeLists.txt index ec52a596af2..218b77d087a 100644 --- a/examples/models/parakeet/CMakeLists.txt +++ b/examples/models/parakeet/CMakeLists.txt @@ -23,6 +23,7 @@ find_package(gflags REQUIRED) # Find executorch libraries list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..) + find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH) get_target_property(_executorch_imported executorch IMPORTED) if(NOT _executorch_imported) @@ -42,9 +43,14 @@ endif() # CPU-only builds need quantized and custom ops if(NOT EXECUTORCH_BUILD_CUDA) - list(APPEND link_libraries quantized_ops_lib custom_ops) - executorch_target_link_options_shared_lib(quantized_ops_lib) - executorch_target_link_options_shared_lib(custom_ops) + if(TARGET quantized_ops_lib) + list(APPEND link_libraries quantized_ops_lib) + executorch_target_link_options_shared_lib(quantized_ops_lib) + endif() + if(TARGET custom_ops) + list(APPEND link_libraries custom_ops) + executorch_target_link_options_shared_lib(custom_ops) + endif() endif() # XNNPACK @@ -91,6 +97,12 @@ if(EXECUTORCH_BUILD_METAL) executorch_target_link_options_shared_lib(metal_backend) endif() +# Link MLX delegate +if(TARGET mlxdelegate) + list(APPEND link_libraries mlxdelegate mlx) + executorch_target_link_options_shared_lib(mlxdelegate) +endif() + add_executable(parakeet_runner main.cpp timestamp_utils.cpp tokenizer_utils.cpp) if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") target_link_options_gc_sections(parakeet_runner) @@ -99,6 +111,11 @@ if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") endif() endif() +# Copy MLX metallib for runtime if MLX delegate is enabled +if(TARGET mlxdelegate) + executorch_target_copy_mlx_metallib(parakeet_runner) +endif() + target_include_directories( parakeet_runner PUBLIC ${_common_include_directories} ) diff --git a/examples/models/parakeet/CMakePresets.json b/examples/models/parakeet/CMakePresets.json index ccb4f4fcdd2..9644378ed73 100644 --- a/examples/models/parakeet/CMakePresets.json +++ b/examples/models/parakeet/CMakePresets.json @@ -55,6 +55,19 @@ "type": "equals", "rhs": "Darwin" } + }, + { + "name": "parakeet-mlx", + "displayName": "Parakeet runner (MLX)", + "inherits": ["parakeet-base"], + "cacheVariables": { + "EXECUTORCH_BUILD_MLX": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Darwin" + } } ], "buildPresets": [ @@ -85,6 +98,12 @@ "configurePreset": "parakeet-metal", "configuration": "Release", "targets": ["parakeet_runner"] + }, + { + "name": "parakeet-mlx", + "displayName": "Build Parakeet runner (MLX)", + "configurePreset": "parakeet-mlx", + "targets": ["parakeet_runner"] } ], "workflowPresets": [ @@ -143,6 +162,20 @@ "name": "parakeet-metal" } ] + }, + { + "name": "parakeet-mlx", + "displayName": "Configure and build Parakeet runner (MLX)", + "steps": [ + { + "type": "configure", + "name": "parakeet-mlx" + }, + { + "type": "build", + "name": "parakeet-mlx" + } + ] } ] } diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md index 75611b3bd4f..b8f4bc0cd78 100644 --- a/examples/models/parakeet/README.md +++ b/examples/models/parakeet/README.md @@ -25,7 +25,7 @@ python export_parakeet_tdt.py --audio /path/to/audio.wav | Argument | Description | |----------|-------------| | `--output-dir` | Output directory for exports (default: `./parakeet_tdt_exports`) | -| `--backend` | Backend for acceleration: `portable`, `xnnpack`, `metal`, `cuda`, `cuda-windows` (default: `xnnpack`) | +| `--backend` | Backend for acceleration: `portable`, `xnnpack`, `metal`, `mlx`, `cuda`, `cuda-windows` (default: `xnnpack`) | | `--dtype` | Data type: `fp32`, `bf16`, `fp16` (default: `fp32`). Metal backend supports `fp32` and `bf16` only (no `fp16`). | | `--audio` | Path to audio file for transcription test | @@ -171,6 +171,23 @@ python export_parakeet_tdt.py --backend cuda-windows --output-dir ./parakeet_cud This generates: - `model.pte` - The compiled Parakeet TDT model - `aoti_cuda_blob.ptd` - CUDA kernel blob required at runtime + +### MLX Export (macOS) + +Export with MLX backend (bf16, int4 quantized, group size 128): +```bash +python export_parakeet_tdt.py \ + --backend mlx \ + --dtype bf16 \ + --qlinear_encoder 4w \ + --qlinear_encoder_group_size 128 \ + --qlinear 4w \ + --qlinear_group_size 128 \ + --output-dir ./parakeet_mlx_4w +``` + +This generates: +- `model.pte` - The compiled model with MLX delegate (~470 MB) - `tokenizer.model` - SentencePiece tokenizer ## C++ Runner @@ -188,6 +205,9 @@ make parakeet-metal # CUDA build (Linux) make parakeet-cuda + +# MLX build (macOS) +make parakeet-mlx ``` On Windows (PowerShell), use CMake workflow presets directly: @@ -222,6 +242,12 @@ DYLD_LIBRARY_PATH=/usr/lib ./cmake-out/examples/models/parakeet/parakeet_runner --data_path examples/models/parakeet/parakeet_cuda/aoti_cuda_blob.ptd \ --audio_path /path/to/audio.wav \ --tokenizer_path examples/models/parakeet/parakeet_cuda/tokenizer.model + +# MLX +./cmake-out/examples/models/parakeet/parakeet_runner \ + --model_path examples/models/parakeet/parakeet_mlx_4w/model.pte \ + --audio_path /path/to/audio.wav \ + --tokenizer_path examples/models/parakeet/parakeet_mlx_4w/tokenizer.model ``` Windows (PowerShell): diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py index 6747880cd9e..340ac02d833 100644 --- a/examples/models/parakeet/export_parakeet_tdt.py +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -389,7 +389,6 @@ def export_all( qlinear_group_size=qlinear_encoder_group_size, qlinear_packing_format=qlinear_encoder_packing_format, ) - programs["encoder"] = export( encoder_with_proj, (), @@ -560,11 +559,26 @@ def _create_cuda_partitioners(programs, is_windows=False): return partitioner, updated_programs +def _create_mlx_partitioners(programs): + """Create MLX partitioners for all programs.""" + from executorch.backends.mlx.partitioner import MLXPartitioner + + print("\nLowering to ExecuTorch with MLX...") + + partitioner = {} + for key in programs.keys(): + partitioner[key] = [MLXPartitioner()] + + return partitioner, programs + + def lower_to_executorch(programs, metadata=None, backend="portable"): if backend == "xnnpack": partitioner, programs = _create_xnnpack_partitioners(programs) elif backend == "metal": partitioner, programs = _create_metal_partitioners(programs) + elif backend == "mlx": + partitioner, programs = _create_mlx_partitioners(programs) elif backend in ("cuda", "cuda-windows"): partitioner, programs = _create_cuda_partitioners( programs, is_windows=(backend == "cuda-windows") @@ -607,7 +621,7 @@ def main(): "--backend", type=str, default="xnnpack", - choices=["portable", "xnnpack", "metal", "cuda", "cuda-windows"], + choices=["portable", "xnnpack", "metal", "mlx", "cuda", "cuda-windows"], help="Backend for acceleration (default: xnnpack)", ) parser.add_argument( diff --git a/examples/models/voxtral/CMakeLists.txt b/examples/models/voxtral/CMakeLists.txt index 036e6454efe..2e8ebb5c5e9 100644 --- a/examples/models/voxtral/CMakeLists.txt +++ b/examples/models/voxtral/CMakeLists.txt @@ -45,9 +45,14 @@ executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib) # CPU-only builds need quantized and custom ops if(NOT EXECUTORCH_BUILD_CUDA) - list(APPEND link_libraries quantized_ops_lib custom_ops) - executorch_target_link_options_shared_lib(quantized_ops_lib) - executorch_target_link_options_shared_lib(custom_ops) + if(TARGET quantized_ops_lib) + list(APPEND link_libraries quantized_ops_lib) + executorch_target_link_options_shared_lib(quantized_ops_lib) + endif() + if(TARGET custom_ops) + list(APPEND link_libraries custom_ops) + executorch_target_link_options_shared_lib(custom_ops) + endif() endif() # XNNPACK @@ -99,6 +104,12 @@ if(EXECUTORCH_BUILD_METAL) executorch_target_link_options_shared_lib(metal_backend) endif() +# Link MLX delegate +if(TARGET mlxdelegate) + list(APPEND link_libraries mlxdelegate mlx) + executorch_target_link_options_shared_lib(mlxdelegate) +endif() + # Add tokenizers list(APPEND link_libraries tokenizers::tokenizers) @@ -120,6 +131,11 @@ if(WIN32) target_link_options(voxtral_runner PRIVATE "/STACK:8388608") endif() +# Copy MLX metallib for runtime if MLX delegate is enabled +if(TARGET mlxdelegate) + executorch_target_copy_mlx_metallib(voxtral_runner) +endif() + # On Windows, copy required DLLs to the executable directory if(MSVC AND EXECUTORCH_BUILD_CUDA) add_custom_command( diff --git a/examples/models/voxtral/CMakePresets.json b/examples/models/voxtral/CMakePresets.json index d9e0ba6af19..e853604c1a1 100644 --- a/examples/models/voxtral/CMakePresets.json +++ b/examples/models/voxtral/CMakePresets.json @@ -41,6 +41,19 @@ "type": "equals", "rhs": "Darwin" } + }, + { + "name": "voxtral-mlx", + "displayName": "Voxtral runner (MLX)", + "inherits": ["voxtral-base"], + "cacheVariables": { + "EXECUTORCH_BUILD_MLX": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Darwin" + } } ], "buildPresets": [ @@ -61,6 +74,12 @@ "displayName": "Build Voxtral runner (Metal)", "configurePreset": "voxtral-metal", "targets": ["voxtral_runner"] + }, + { + "name": "voxtral-mlx", + "displayName": "Build Voxtral runner (MLX)", + "configurePreset": "voxtral-mlx", + "targets": ["voxtral_runner"] } ], "workflowPresets": [ @@ -105,6 +124,20 @@ "name": "voxtral-metal" } ] + }, + { + "name": "voxtral-mlx", + "displayName": "Configure and build Voxtral runner (MLX)", + "steps": [ + { + "type": "configure", + "name": "voxtral-mlx" + }, + { + "type": "build", + "name": "voxtral-mlx" + } + ] } ] } diff --git a/examples/models/voxtral/README.md b/examples/models/voxtral/README.md index 72d21425648..1ffae876a99 100644 --- a/examples/models/voxtral/README.md +++ b/examples/models/voxtral/README.md @@ -140,6 +140,51 @@ This will generate: See the "Building the multimodal runner" section below for instructions on building with Metal support, and the "Running the model" section for runtime instructions. +## MLX Support (macOS) +On Apple Silicon, you can export and run Voxtral using the [MLX backend](../../../backends/mlx), which provides accelerated inference via Apple's MLX framework. + +### Exporting with MLX +The MLX export script produces two `.pte` files — the model and the audio preprocessor — both delegated to MLX: +``` +python -m executorch.backends.mlx.examples.voxtral.export_voxtral_hf \ + --output-dir mlx_voxtral_int4_bf16 \ + --dtype bf16 \ + --quantize-linear int4 +``` + +This will generate: +- `model.pte` - The exported model with MLX delegate (audio_encoder, token_embedding, text_decoder) +- `preprocessor.pte` - The mel spectrogram audio preprocessor with MLX delegate + +#### Export arguments + +| Argument | Description | +|----------|-------------| +| `--model-id` | HuggingFace model ID (default: `mistralai/Voxtral-Mini-3B-2507`) | +| `--output-dir` | Output directory for `.pte` files (default: `voxtral_mlx`) | +| `--dtype` | Model dtype: `fp32`, `fp16`, `bf16` (default: `bf16`) | +| `--max-seq-len` | Maximum sequence length for KV cache (default: `1024`) | +| `--quantize-linear` | Quantization for linear layers: `int4`, `int8` (default: none) | +| `--quantize-linear-group-size` | Group size for linear quantization (default: `32`) | +| `--max-audio-len` | Maximum audio length in seconds for preprocessor (default: `300`) | + +### Building for MLX +From the ExecuTorch root directory: +``` +make voxtral-mlx +``` + +### Running with MLX +``` +./cmake-out/examples/models/voxtral/voxtral_runner \ + --model_path mlx_voxtral_int4_bf16/model.pte \ + --tokenizer_path path/to/tekken.json \ + --prompt "What is happening in this audio?" \ + --audio_path path/to/audio.wav \ + --processor_path mlx_voxtral_int4_bf16/preprocessor.pte \ + --temperature 0 +``` + # Running the model To run the model, we will use the Voxtral runner, which utilizes ExecuTorch's MultiModal runner API. The Voxtral runner will do the following things: diff --git a/examples/models/voxtral_realtime/CMakeLists.txt b/examples/models/voxtral_realtime/CMakeLists.txt index 5d047df51c0..28545f407ca 100644 --- a/examples/models/voxtral_realtime/CMakeLists.txt +++ b/examples/models/voxtral_realtime/CMakeLists.txt @@ -33,7 +33,7 @@ list(APPEND link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas) executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib) # CPU-only builds need quantized and custom ops -if(NOT EXECUTORCH_BUILD_CUDA) +if(NOT EXECUTORCH_BUILD_CUDA AND NOT EXECUTORCH_BUILD_MLX) list(APPEND link_libraries quantized_ops_lib custom_ops) executorch_target_link_options_shared_lib(quantized_ops_lib) executorch_target_link_options_shared_lib(custom_ops) @@ -87,6 +87,12 @@ if(EXECUTORCH_BUILD_METAL) executorch_target_link_options_shared_lib(metal_backend) endif() +# Link MLX delegate +if(TARGET mlxdelegate) + list(APPEND link_libraries mlxdelegate mlx) + executorch_target_link_options_shared_lib(mlxdelegate) +endif() + # Tokenizer list(APPEND link_libraries tokenizers::tokenizers) @@ -106,6 +112,11 @@ target_compile_options( voxtral_realtime_runner PUBLIC ${_common_compile_options} ) +# Copy MLX metallib for runtime if MLX delegate is enabled +if(TARGET mlxdelegate) + executorch_target_copy_mlx_metallib(voxtral_realtime_runner) +endif() + # On Windows, copy required DLLs to the executable directory if(MSVC AND EXECUTORCH_BUILD_CUDA) add_custom_command( diff --git a/examples/models/voxtral_realtime/CMakePresets.json b/examples/models/voxtral_realtime/CMakePresets.json index 707e94b0169..94f8411fb2d 100644 --- a/examples/models/voxtral_realtime/CMakePresets.json +++ b/examples/models/voxtral_realtime/CMakePresets.json @@ -41,6 +41,19 @@ "string": "${hostSystemName}", "list": ["Linux", "Windows"] } + }, + { + "name": "voxtral-realtime-mlx", + "displayName": "Voxtral Realtime runner (MLX)", + "inherits": ["voxtral-realtime-base"], + "cacheVariables": { + "EXECUTORCH_BUILD_MLX": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Darwin" + } } ], "buildPresets": [ @@ -62,6 +75,13 @@ "displayName": "Build Voxtral Realtime runner (CUDA)", "configurePreset": "voxtral-realtime-cuda", "targets": ["voxtral_realtime_runner"] + }, + { + "name": "voxtral-realtime-mlx", + "displayName": "Build Voxtral Realtime runner (MLX)", + "configurePreset": "voxtral-realtime-mlx", + "configuration": "Release", + "targets": ["voxtral_realtime_runner"] } ], "workflowPresets": [ @@ -106,6 +126,20 @@ "name": "voxtral-realtime-cuda" } ] + }, + { + "name": "voxtral-realtime-mlx", + "displayName": "Configure and build Voxtral Realtime runner (MLX)", + "steps": [ + { + "type": "configure", + "name": "voxtral-realtime-mlx" + }, + { + "type": "build", + "name": "voxtral-realtime-mlx" + } + ] } ] } diff --git a/examples/models/voxtral_realtime/README.md b/examples/models/voxtral_realtime/README.md index 7d29ba8c11b..6915fba3580 100644 --- a/examples/models/voxtral_realtime/README.md +++ b/examples/models/voxtral_realtime/README.md @@ -88,10 +88,11 @@ python export_voxtral_rt.py \ |---------|---------|-----------|--------------| | `xnnpack` | ✓ | ✓ | `4w`, `8w`, `8da4w`, `8da8w` | | `metal` | ✓ | ✓ | none (fp32) or `fpa4w` (Metal-specific 4-bit) | +| `mlx` | ✓ | ✓ | `4w`, `8w`, `nvfp4` (NVIDIA FP4 dtype) | | `cuda` | ✓ | ✓ | `4w`, `8w` | -Metal backend provides Apple GPU acceleration. CUDA backend provides NVIDIA GPU -acceleration via AOTInductor. + +MLX and Metal backends provide Apple GPU acceleration. CUDA backend provides NVIDIA GPU acceleration via AOTInductor. #### CUDA export examples @@ -163,12 +164,48 @@ Alternatively, you can build torchao with Metal support while installing ExecuTo EXECUTORCH_BUILD_KERNELS_TORCHAO=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 ./install_executorch.sh ``` +#### MLX export examples + +MLX backend uses the MLX delegate for Apple Silicon GPU acceleration. + +Offline: + +```bash +python export_voxtral_rt.py \ + --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \ + --backend mlx \ + --output-dir ./voxtral_rt_exports \ + --qlinear-encoder 4w \ + --qlinear 4w \ + --qembedding 8w \ + --qembedding-group-size 128 \ + --export-preprocessor +``` + +Streaming: + +```bash +python export_voxtral_rt.py \ + --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \ + --backend mlx \ + --streaming \ + --output-dir ./voxtral_rt_exports \ + --qlinear-encoder 4w \ + --qlinear 4w \ + --qembedding 8w \ + --qembedding-group-size 128 \ + --export-preprocessor +``` + +`--export-preprocessor` bundles the mel preprocessor into the output directory +using the MLX partitioner, so no separate preprocessor export step is needed. + ### Options | Flag | Default | Description | |------|---------|-------------| | `--model-path` | (required) | Directory with `params.json` + `consolidated.safetensors` | -| `--backend` | `xnnpack` | `xnnpack`, `metal`, `cuda`, or `portable` | +| `--backend` | `xnnpack` | `xnnpack`, `mlx`, `metal`, `cuda`, or `portable` | | `--dtype` | `fp32` | Model dtype: `fp32` or `bf16` | | `--output-dir` | `./voxtral_rt_exports` | Output directory | | `--max-seq-len` | `4096` | KV cache length | @@ -180,6 +217,8 @@ EXECUTORCH_BUILD_KERNELS_TORCHAO=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 ./install_ex | `--qlinear-encoder-group-size` | `32` | Group size for encoder linear quantization | | `--qlinear-encoder-packing-format` | (none) | Packing format for encoder 4w quantization (`tile_packed_to_4d` for CUDA) | | `--qembedding` | (none) | Embedding layer quantization (`8w`) | +| `--qembedding-group-size` | `0` | Group size for embedding quantization (0 = per-channel) | +| `--export-preprocessor` | off | Export `preprocessor.pte` alongside the model | | `--streaming` | off | Export streaming encoder with KV cache | | `--max-enc-len` | `750` | Encoder sliding window size (streaming only) | @@ -220,6 +259,15 @@ make voxtral_realtime-metal This builds ExecuTorch with Metal backend support. The runner binary is at the same path as above. Metal exports can only run on macOS with Apple Silicon. +### MLX (Apple GPU) + +```bash +make voxtral_realtime-mlx +``` + +This builds ExecuTorch with MLX backend support. MLX provides GPU acceleration +on Apple Silicon via the MLX delegate. + ## Run The runner requires: diff --git a/examples/models/voxtral_realtime/export_voxtral_rt.py b/examples/models/voxtral_realtime/export_voxtral_rt.py index 31f792232a3..8255aa9861c 100644 --- a/examples/models/voxtral_realtime/export_voxtral_rt.py +++ b/examples/models/voxtral_realtime/export_voxtral_rt.py @@ -39,9 +39,7 @@ import torch import torch.nn as nn - from executorch.examples.models.voxtral_realtime.model import load_model - from executorch.exir import ( EdgeCompileConfig, ExecutorchBackendConfig, @@ -112,6 +110,7 @@ def _export_decoder_and_embedding( qlinear_group_size, qlinear_packing_format, qembedding, + qembedding_group_size, device="cpu", ): """Export text_decoder and token_embedding into programs dict.""" @@ -157,6 +156,7 @@ def _export_decoder_and_embedding( quantize_model_( tok_emb, qembedding_config=qembedding, + qembedding_group_size=qembedding_group_size, ) tok_seq_dim = Dim("tok_seq_len", min=1, max=max_seq_len) @@ -174,12 +174,13 @@ def export_all( model, max_seq_len, qlinear_encoder=None, - qlinear_encoder_group_size=32, + qlinear_encoder_group_size=None, qlinear_encoder_packing_format=None, qlinear=None, - qlinear_group_size=32, + qlinear_group_size=None, qlinear_packing_format=None, qembedding=None, + qembedding_group_size=None, backend="xnnpack", ): """Export all three model components with per-component quantization.""" @@ -236,6 +237,7 @@ def export_all( qlinear_group_size, qlinear_packing_format, qembedding, + qembedding_group_size, device, ) @@ -258,12 +260,13 @@ def export_streaming( max_seq_len, max_enc_len=750, qlinear_encoder=None, - qlinear_encoder_group_size=32, + qlinear_encoder_group_size=None, qlinear_encoder_packing_format=None, qlinear=None, - qlinear_group_size=32, + qlinear_group_size=None, qlinear_packing_format=None, qembedding=None, + qembedding_group_size=None, backend="xnnpack", ): """Export streaming model components with per-component quantization.""" @@ -316,6 +319,7 @@ def export_streaming( qlinear_group_size, qlinear_packing_format, qembedding, + qembedding_group_size, device, ) @@ -373,6 +377,60 @@ def _linear_bias_decomposition(input, weight, bias=None): return out +def export_preprocessor(output_dir, backend="xnnpack", streaming=False): + """Export mel spectrogram preprocessor. + + Uses XNNPACK for all backends except MLX, which uses MLX partitioner. + """ + from executorch.extension.audio.mel_spectrogram import WhisperAudioProcessor + + # Use MLX partitioner for mlx backend, XNNPACK for everything else + pp_backend = "mlx" if backend == "mlx" else "xnnpack" + print(f" Using {pp_backend.upper()} partitioner for preprocessor...") + + model = WhisperAudioProcessor( + feature_size=128, + max_audio_len=300, + streaming=streaming, + ) + + audio_tensor = torch.randn(93680) + shapes_collection = torch.export.ShapesCollection() + max_n_chunks = int(model.max_audio_len * model.n_samples) + shapes_collection[audio_tensor] = {0: Dim.DYNAMIC(max=max_n_chunks)} + + with torch.no_grad(), torch.fx.experimental._config.patch( + backed_size_oblivious=True + ): + ep = export( + model, (audio_tensor,), dynamic_shapes=shapes_collection, strict=True + ) + + if pp_backend == "mlx": + from executorch.backends.mlx.partitioner import MLXPartitioner + + partitioner = [MLXPartitioner()] + else: + from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackPartitioner, + ) + + partitioner = [XnnpackPartitioner()] + + edge = to_edge_transform_and_lower( + ep, + partitioner=partitioner, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + exec_prog = edge.to_executorch() + + pp_path = os.path.join(output_dir, "preprocessor.pte") + with open(pp_path, "wb") as f: + exec_prog.write_to_file(f) + size_mb = os.path.getsize(pp_path) / (1024 * 1024) + print(f" Saved preprocessor to {pp_path} ({size_mb:.1f} MB)") + + def lower_to_executorch(programs, metadata, backend="xnnpack"): """Lower exported programs to ExecuTorch.""" if backend == "xnnpack": @@ -428,6 +486,11 @@ def lower_to_executorch(programs, metadata, backend="xnnpack"): if backend == "cuda-windows": compile_specs.append(CompileSpec("platform", b"windows")) partitioner[key] = [CudaPartitioner(compile_specs)] + elif backend == "mlx": + from executorch.backends.mlx.partitioner import MLXPartitioner + + print("\nLowering to ExecuTorch with MLX...") + partitioner = {key: [MLXPartitioner()] for key in programs} else: print("\nLowering to ExecuTorch (portable)...") partitioner = [] @@ -468,7 +531,7 @@ def main(): parser.add_argument( "--backend", default="xnnpack", - choices=["portable", "xnnpack", "metal", "cuda", "cuda-windows"], + choices=["portable", "xnnpack", "mlx", "metal", "cuda", "cuda-windows"], help="Backend for acceleration (default: xnnpack)", ) parser.add_argument( @@ -497,7 +560,7 @@ def main(): parser.add_argument( "--qlinear-group-size", type=int, - default=32, + default=None, help="Group size for decoder linear quantization (default: 32).", ) parser.add_argument( @@ -515,7 +578,7 @@ def main(): parser.add_argument( "--qlinear-encoder-group-size", type=int, - default=32, + default=None, help="Group size for encoder linear quantization (default: 32).", ) parser.add_argument( @@ -530,6 +593,12 @@ def main(): choices=["8w"], help="Quantize embedding layers (8-bit weight-only).", ) + parser.add_argument( + "--qembedding-group-size", + type=int, + default=None, + help="Group size for embedding quantization (default: 0 = per-channel).", + ) parser.add_argument( "--streaming", action="store_true", @@ -547,6 +616,11 @@ def main(): choices=["fp32", "bf16"], help="Model dtype (default: fp32).", ) + parser.add_argument( + "--export-preprocessor", + action="store_true", + help="Also export preprocessor.pte (uses XNNPACK, or MLX for --backend mlx).", + ) args = parser.parse_args() backend_for_export = "cuda" if args.backend == "cuda-windows" else args.backend @@ -591,6 +665,7 @@ def main(): "qlinear_group_size": args.qlinear_group_size, "qlinear_packing_format": args.qlinear_packing_format, "qembedding": args.qembedding, + "qembedding_group_size": args.qembedding_group_size, "backend": backend_for_export, } if args.streaming: @@ -616,6 +691,11 @@ def main(): et.write_tensor_data_to_file(args.output_dir) print(f"Saved tensor data to {args.output_dir}/") + # Export preprocessor if requested + if args.export_preprocessor: + print("\nExporting preprocessor...") + export_preprocessor(args.output_dir, args.backend, args.streaming) + print("\nDone!") diff --git a/examples/models/voxtral_realtime/model.py b/examples/models/voxtral_realtime/model.py index 26778413834..9640e09983d 100644 --- a/examples/models/voxtral_realtime/model.py +++ b/examples/models/voxtral_realtime/model.py @@ -15,7 +15,6 @@ import torch import torch.nn as nn import torch.nn.functional as F - from executorch.extension.llm.custom_ops import custom_ops as _custom_ops # noqa: F401 @@ -50,7 +49,7 @@ class VoxtralRealtimeConfig: downsample_factor: int = 4 # Runtime max_seq_len: int = 4096 - backend: str = "xnnpack" # "xnnpack", "metal", "cuda", or "portable" + backend: str = "xnnpack" # "xnnpack", "mlx", "metal", "cuda", or "portable" @staticmethod def from_params_json(path: str) -> "VoxtralRealtimeConfig": @@ -153,12 +152,14 @@ class EncoderAttention(nn.Module): """Multi-head attention with RoPE for the causal whisper encoder. Biases: wq yes, wk no, wv yes, wo yes. + Supports MLX backend for Apple Silicon GPU acceleration. """ - def __init__(self, dim: int, n_heads: int, head_dim: int): + def __init__(self, dim: int, n_heads: int, head_dim: int, backend: str = "xnnpack"): super().__init__() self.n_heads = n_heads self.head_dim = head_dim + self.backend = backend attn_dim = n_heads * head_dim self.wq = nn.Linear(dim, attn_dim, bias=True) self.wk = nn.Linear(dim, attn_dim, bias=False) @@ -177,7 +178,15 @@ def forward( v = self.wv(x).view(B, T, self.n_heads, self.head_dim) q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin) q, k, v = (t.transpose(1, 2) for t in (q, k, v)) - y = F.scaled_dot_product_attention(q, k, v, is_causal=True) + if self.backend == "mlx": + # Use MLX custom SDPA for Apple Silicon GPU + start_pos = 0 # Offline encoder always starts at 0 + scale = self.head_dim**-0.5 + y = torch.ops.mlx.custom_sdpa( + q, k, v, start_pos=start_pos, is_causal=True, scale=scale + ) + else: + y = F.scaled_dot_product_attention(q, k, v, is_causal=True) return self.wo(y.transpose(1, 2).contiguous().view(B, T, -1)) @@ -199,7 +208,10 @@ def __init__(self, config: VoxtralRealtimeConfig): super().__init__() self.attention_norm = RMSNorm(config.enc_dim, config.enc_norm_eps) self.attention = EncoderAttention( - config.enc_dim, config.enc_n_heads, config.enc_head_dim + config.enc_dim, + config.enc_n_heads, + config.enc_head_dim, + backend=config.backend, ) self.ffn_norm = RMSNorm(config.enc_dim, config.enc_norm_eps) self.feed_forward = EncoderSwiGLU(config.enc_dim, config.enc_hidden_dim) @@ -617,10 +629,146 @@ def forward( return y.view(bsz, seqlen, self.dim) +class MLXKVCache(nn.Module): + """Wrapper that adapts MLX BHSD KV cache for model's BSHD convention. + + The model's QKV projections produce [B, S, H, D] tensors, but MLX's + KVCache expects [B, H, S, D]. This wrapper transposes on the way in. + """ + + def __init__( + self, max_seq_len: int, n_kv_heads: int, head_dim: int, dtype: torch.dtype + ): + super().__init__() + from executorch.backends.mlx.llm.cache import KVCache as MLXKVCacheImpl + + self.cache = MLXKVCacheImpl( + max_batch_size=1, + max_context_length=max_seq_len, + n_heads=n_kv_heads, + head_dim=head_dim, + enable_dynamic_shape=True, + dtype=dtype, + ) + + def update( + self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + # Transpose BSHD -> BHSD for MLX cache + k_val = k_val.transpose(1, 2) + v_val = v_val.transpose(1, 2) + return self.cache.update(input_pos, k_val, v_val) + + +class MLXEncoderRingKVCache(nn.Module): + """Wrapper that adapts MLX RingBufferKVCache for the encoder's BSHD convention. + + The encoder's QKV projections produce [B, S, H, D] tensors, but MLX's + RingBufferKVCache expects [B, H, S, D]. This wrapper transposes on the + way in and delegates ring buffer semantics to the MLX implementation. + """ + + def __init__( + self, + window_size: int, + n_heads: int, + head_dim: int, + dtype: torch.dtype, + ): + super().__init__() + from executorch.backends.mlx.llm.cache import RingBufferKVCache + + self.ring_cache = RingBufferKVCache( + max_batch_size=1, + max_context_length=window_size, + n_heads=n_heads, + head_dim=head_dim, + dtype=dtype, + ) + + def update( + self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + # Transpose BSHD -> BHSD for MLX ring buffer + k_val = k_val.transpose(1, 2) + v_val = v_val.transpose(1, 2) + return self.ring_cache.update(input_pos, k_val, v_val) + + def create_causal_mask(self, start_pos: int, seq_len: int) -> torch.Tensor: + return self.ring_cache.create_sliding_window_mask(start_pos, seq_len) + + +class MLXSDPA(nn.Module): + """SDPA using MLX custom op for Apple Silicon GPU acceleration. + + Uses torch.ops.mlx.custom_sdpa which handles GQA expansion and causal + masking internally. KV cache is in BHSD layout, queries are in BSHD. + """ + + def __init__(self, n_heads: int, head_dim: int): + super().__init__() + self.dim = n_heads * head_dim + self.scale = head_dim**-0.5 + + def forward( + self, + input_pos: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + bsz: int, + seqlen: int, + mask: torch.Tensor | None = None, + ) -> torch.Tensor: + start_pos = input_pos[0].item() + q = q.transpose(1, 2) # BSHD -> BHSD + y = torch.ops.mlx.custom_sdpa( + q, k, v, start_pos=start_pos, is_causal=True, scale=self.scale + ) + return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + +class MLXEncoderSDPA(nn.Module): + """SDPA for streaming encoder with MLX ring buffer KV cache. + + Uses F.scaled_dot_product_attention with explicit attn_mask from the + ring buffer. KV cache is in BHSD layout, queries are in BSHD. + """ + + def __init__(self, n_heads: int, head_dim: int): + super().__init__() + self.dim = n_heads * head_dim + + def forward( + self, + input_pos: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + bsz: int, + seqlen: int, + mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Args: + input_pos: (seq_len,) position indices (unused, kept for interface). + q: (B, seq_len, n_heads, head_dim) in BSHD layout. + k, v: (B, n_heads, buf_size, head_dim) in BHSD from MLXEncoderRingKVCache. + bsz, seqlen: batch size and query length. + mask: (1, 1, seq_len, buf_size) additive attention mask from ring buffer. + """ + q = q.transpose(1, 2) # BSHD -> BHSD + + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False) + + return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + class LMAttention(nn.Module): """GQA with RoPE, KV cache, and SDPA. No biases. - Supports both custom ops (for XNNPACK) and standard PyTorch ops (for Metal/AOTI). + Supports custom ops (for XNNPACK), standard PyTorch ops (for Metal/AOTI), + and MLX backend ops (for Apple Silicon GPU acceleration via MLX delegate). """ def __init__(self, config: VoxtralRealtimeConfig, max_seq_len: int): @@ -630,6 +778,7 @@ def __init__(self, config: VoxtralRealtimeConfig, max_seq_len: int): self.head_dim = config.head_dim self.dim = config.dim self.backend = config.backend + self.rope_theta = config.rope_theta self.wq = nn.Linear(config.dim, self.n_heads * self.head_dim, bias=False) self.wk = nn.Linear(config.dim, self.n_kv_heads * self.head_dim, bias=False) @@ -637,7 +786,13 @@ def __init__(self, config: VoxtralRealtimeConfig, max_seq_len: int): self.wo = nn.Linear(self.n_heads * self.head_dim, config.dim, bias=False) # Choose KV cache and SDPA based on backend - if self.backend == "metal": + if self.backend == "mlx": + cache_dtype = self.wq.weight.dtype + self.kv_cache = MLXKVCache( + max_seq_len, self.n_kv_heads, self.head_dim, dtype=cache_dtype + ) + self.sdpa = MLXSDPA(self.n_heads, self.head_dim) + elif self.backend == "metal": self.kv_cache = StaticKVCache(max_seq_len, self.n_kv_heads, self.head_dim) self.sdpa = MetalSDPA(self.n_heads, self.n_kv_heads, self.head_dim) elif self.backend == "cuda": @@ -660,7 +815,24 @@ def forward( k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim) v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim) - q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin) + if self.backend == "mlx": + start_pos = input_pos[0].item() + q = torch.ops.mlx.rope( + q.transpose(1, 2), + self.head_dim, + start_pos, + traditional=True, + base=self.rope_theta, + ).transpose(1, 2) + k = torch.ops.mlx.rope( + k.transpose(1, 2), + self.head_dim, + start_pos, + traditional=True, + base=self.rope_theta, + ).transpose(1, 2) + else: + q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin) k, v = self.kv_cache.update(input_pos, k, v) @@ -978,6 +1150,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750): self.n_heads = config.enc_n_heads self.head_dim = config.enc_head_dim self.bool_mask = config.backend == "cuda" + self.enc_rope_theta = config.enc_rope_theta # Register conv states as buffers (mutable state for streaming) self.register_buffer("conv1_state", torch.zeros(1, config.num_mel_bins, 2)) @@ -986,23 +1159,43 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750): # Ring buffer KV caches for unlimited streaming. # Window size = max_enc_len (encoder sliding window from params.json). # Buffer is 2x internally for safe wraparound. - # Choose cache implementation based on backend - cache_class = ( - StandardEncoderRingKVCache - if config.backend in ("metal", "cuda") - else EncoderRingKVCache - ) - self.kv_caches = nn.ModuleList( - [ - cache_class(max_enc_len, config.enc_n_heads, config.enc_head_dim) - for _ in range(config.enc_n_layers) - ] - ) - - # Choose SDPA based on backend - if config.backend in ("metal", "cuda"): + # Choose cache and SDPA implementation based on backend + self.backend = config.backend + if config.backend == "mlx": + # Use the encoder layer weight dtype for cache buffers so they + # match Q/K/V projection outputs (avoids dtype mismatch in SDPA). + cache_dtype = self.layers[0].attention.wq.weight.dtype + self.kv_caches = nn.ModuleList( + [ + MLXEncoderRingKVCache( + max_enc_len, + config.enc_n_heads, + config.enc_head_dim, + dtype=cache_dtype, + ) + for _ in range(config.enc_n_layers) + ] + ) + self.sdpa = MLXEncoderSDPA(config.enc_n_heads, config.enc_head_dim) + elif config.backend in ("metal", "cuda"): + self.kv_caches = nn.ModuleList( + [ + StandardEncoderRingKVCache( + max_enc_len, config.enc_n_heads, config.enc_head_dim + ) + for _ in range(config.enc_n_layers) + ] + ) self.sdpa = StandardEncoderSDPA(config.enc_n_heads, config.enc_head_dim) else: + self.kv_caches = nn.ModuleList( + [ + EncoderRingKVCache( + max_enc_len, config.enc_n_heads, config.enc_head_dim + ) + for _ in range(config.enc_n_layers) + ] + ) self.sdpa = SDPA(config.enc_n_heads, config.enc_head_dim) # RoPE inverse frequencies for on-the-fly computation. @@ -1032,7 +1225,25 @@ def _streaming_encoder_layer( k = attn.wk(h).view(B, T, self.n_heads, self.head_dim) v = attn.wv(h).view(B, T, self.n_heads, self.head_dim) - q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin) + if self.backend == "mlx": + start_pos = input_pos[0].item() + q = torch.ops.mlx.rope( + q.transpose(1, 2), + self.head_dim, + start_pos, + traditional=True, + base=self.enc_rope_theta, + ).transpose(1, 2) + k = torch.ops.mlx.rope( + k.transpose(1, 2), + self.head_dim, + start_pos, + traditional=True, + base=self.enc_rope_theta, + ).transpose(1, 2) + else: + q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin) + k, v = self.kv_caches[layer_idx].update(input_pos, k, v) y = self.sdpa(input_pos, q, k, v, B, T, mask) @@ -1162,14 +1373,19 @@ def load_model( max_seq_len: Maximum sequence length for KV cache. n_delay_tokens: Transcription delay in tokens (default 6 = 480ms). dtype: Weight dtype (default: float32). - backend: Backend for acceleration ("xnnpack", "metal", "cuda", or "portable"). + backend: Backend for acceleration ("xnnpack", "mlx", "metal", "cuda", or "portable"). """ - _VALID_BACKENDS = ("xnnpack", "metal", "cuda", "portable") + _VALID_BACKENDS = ("xnnpack", "mlx", "metal", "cuda", "portable") + if backend not in _VALID_BACKENDS: raise ValueError( f"Unknown backend '{backend}'. Must be one of {_VALID_BACKENDS}." ) + # Import MLX custom ops for mlx backend + if backend == "mlx": + import executorch.backends.mlx.custom_ops as _mlx_custom_ops # noqa: F401 + from safetensors import safe_open model_dir = Path(model_path) diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index 8398c4c5306..4da706bb889 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -574,6 +574,15 @@ class VgfConfig: compiler_flags: List[str] = field(default_factory=list) +@dataclass +class MLXConfig: + """ + Configures the MLX backend for Apple Silicon. + """ + + enabled: bool = False + + @dataclass class BackendConfig: """ @@ -590,7 +599,11 @@ class BackendConfig: torchao: TorchAOKernelsConfig = field(default_factory=TorchAOKernelsConfig) tosa: TosaConfig = field(default_factory=TosaConfig) ethosu: EthosUConfig = field(default_factory=EthosUConfig) +<<<<<<< HEAD vgf: VgfConfig = field(default_factory=VgfConfig) +======= + mlx: MLXConfig = field(default_factory=MLXConfig) +>>>>>>> d44535c8a5 (up) ################################################################################ @@ -763,6 +776,12 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901 if hasattr(args, "mps"): llm_config.backend.mps.enabled = args.mps + # MLX - auto-enable use_kv_cache when MLX is enabled + if hasattr(args, "mlx"): + llm_config.backend.mlx.enabled = args.mlx + if args.mlx: + llm_config.model.use_kv_cache = True + # Openvino if hasattr(args, "openvino"): llm_config.backend.openvino.enabled = args.openvino From 848cbd3cb571ed56e82ce108c67d105cfd126f4b Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Tue, 3 Mar 2026 17:24:33 -0800 Subject: [PATCH 17/34] up --- backends/mlx/passes.py | 494 ++++++++++++++++++++++- backends/mlx/test/test_passes.py | 659 +++++++++++++++++++++++++++++++ 2 files changed, 1150 insertions(+), 3 deletions(-) diff --git a/backends/mlx/passes.py b/backends/mlx/passes.py index c7efdf561de..ef4c768a2f8 100644 --- a/backends/mlx/passes.py +++ b/backends/mlx/passes.py @@ -8,13 +8,501 @@ Graph transformation passes for the MLX backend. """ -from typing import List +from dataclasses import dataclass +from typing import List, Optional -from executorch.exir.pass_base import ExportPass +import torch +from executorch.backends.mlx.pattern_utils import ( + extract_lifted_tensor_constant, + match_target, + OpStep, + PatternMatch, + walk_back, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes.cse_pass import CSEPass +from torch.fx import GraphModule, Node def get_default_passes() -> List[ExportPass]: """ Returns a list of passes that are enabled by default for the MLX backend. """ - return [] + return [ + FuseRMSNormPass(), + CanonicalizePermutePass(), + CollapseViewCopyPass(), + CollapsePermutePass(), + CollapseDtypeConversionPass(), + RemoveNoOpsPass(), + CSEPass(), + ] + + +@dataclass +class RMSNormMatch(PatternMatch): + """ + Matched RMSNorm pattern. + + HuggingFace Llama's RMSNorm decomposes into: + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + eps) + return weight * hidden_states.to(input_dtype) + + Graph pattern: + _to_copy (to f32) [optional] + pow(x, 2) + mean_dim(pow_out, [-1], keepdim=True) + add(mean_out, eps_tensor) + rsqrt(add_out) + mul(to_copy_out, rsqrt_out) + _to_copy (back to original dtype) [optional] + mul(weight, to_copy_out) + """ + + input_node: Node = None # type: ignore[assignment] + weight_node: Node = None # type: ignore[assignment] + eps: float = 0.0 + + @classmethod + def maybe_create(cls, head: Node, **context) -> Optional["RMSNormMatch"]: + """Match RMSNorm pattern starting from final mul(weight, normalized).""" + # Head must be mul + if not match_target(head, torch.ops.aten.mul.Tensor): + return None + + if len(head.args) < 2: + return None + + # Try both orderings: mul(weight, normalized) or mul(normalized, weight) + for weight_idx, norm_idx in [(0, 1), (1, 0)]: + weight_node = head.args[weight_idx] + norm_node = head.args[norm_idx] + + if not isinstance(norm_node, Node): + continue + + # Match entire chain with single walk_back: + # [_to_copy] -> mul(input, rsqrt) -> rsqrt -> add -> mean -> pow -> [_to_copy] + # The mul follows arg_index=1 to get rsqrt (not input) + result = walk_back( + norm_node, + [ + OpStep( + op=torch.ops.aten._to_copy.default, + optional=True, + kwargs={ + "dtype", + "layout", + "device", + "pin_memory", + "non_blocking", + "memory_format", + }, + ), + OpStep(op=torch.ops.aten.mul.Tensor, nargs=2, arg_index=1), + OpStep(op=torch.ops.aten.rsqrt.default), + OpStep(op=torch.ops.aten.add.Tensor, nargs=2), + OpStep(op=torch.ops.aten.mean.dim, nargs=(2, 3), kwargs={"dtype"}), + OpStep(op=torch.ops.aten.pow.Tensor_Scalar, nargs=2), + OpStep( + op=torch.ops.aten._to_copy.default, + optional=True, + require_single_user=False, # _to_copy output used by both pow and mul + kwargs={ + "dtype", + "layout", + "device", + "pin_memory", + "non_blocking", + "memory_format", + }, + ), + ], + ) + if result is None: + continue + + original_input, entries = result + to_copy_out, mul, rsqrt, add, mean, pow, to_copy_in = entries + + # If input _to_copy matched, verify it has exactly 2 users: pow and mul + if to_copy_in is not None: + users = set(to_copy_in.users.keys()) + expected_users = {pow, mul} + if users != expected_users: + continue + + # Validate pow exponent is 2 + if pow.args[1] != 2: + continue + + # Extract epsilon from add node (it's a lifted tensor constant) + eps_value = None + for arg in add.args: + eps_value = extract_lifted_tensor_constant(arg) + if eps_value is not None: + break + + if eps_value is None: + continue + + # Build body from non-None entries + body = [n for n in entries if n is not None] + + return cls( + head=head, + body=body, + input_node=original_input, + weight_node=weight_node, + eps=eps_value, + ) + + return None + + +class FuseRMSNormPass(ExportPass): + """ + Fuses decomposed RMSNorm operations into aten.rms_norm. + + This reduces ~7 ops to 1 fused op per RMSNorm layer. + """ + + def call(self, graph_module: GraphModule) -> PassResult: + graph = graph_module.graph + modified = False + + for node in list(graph.nodes): + match = RMSNormMatch.maybe_create(node) + if match is None: + continue + + # Get input shape for normalized_shape + input_meta = match.input_node.meta.get("val") + if input_meta is None: + continue + + # Create fused rms_norm node + with graph.inserting_before(node): + normalized_shape = [input_meta.shape[-1]] + rms_norm_node = graph.call_function( + torch.ops.aten.rms_norm.default, + args=( + match.input_node, + normalized_shape, + match.weight_node, + match.eps, + ), + ) + rms_norm_node.meta = node.meta.copy() + + node.replace_all_uses_with(rms_norm_node) + match.remove_body_nodes(graph) + graph.erase_node(node) + modified = True + + if modified: + graph.eliminate_dead_code() + graph.lint() + + return PassResult(graph_module, modified) + + +class CanonicalizePermutePass(ExportPass): + """ + Converts transpose_copy to permute_copy in the edge dialect graph. + + transpose_copy(x, dim0, dim1) is equivalent to permute_copy(x, perm) + where perm is the identity permutation with dim0 and dim1 swapped. + This lets the backend handle a single permute op instead of both + transpose and permute. + """ + + def call(self, graph_module: GraphModule) -> PassResult: + graph = graph_module.graph + modified = False + + for node in list(graph.nodes): + if ( + node.op != "call_function" + or node.target != exir_ops.edge.aten.transpose_copy.int + ): + continue + + input_node = node.args[0] + input_val = ( + input_node.meta.get("val") if isinstance(input_node, Node) else None + ) + if input_val is None: + continue + + ndim = input_val.dim() + dim0 = node.args[1] + dim1 = node.args[2] + + # Normalize negative dims + if dim0 < 0: + dim0 += ndim + if dim1 < 0: + dim1 += ndim + + # Build permutation: identity with dim0 and dim1 swapped + perm = list(range(ndim)) + perm[dim0], perm[dim1] = perm[dim1], perm[dim0] + + node.target = exir_ops.edge.aten.permute_copy.default + node.args = (input_node, perm) + modified = True + + if modified: + graph.lint() + + return PassResult(graph_module, modified) + + +class CollapseViewCopyPass(ExportPass): + """ + Collapses consecutive view_copy nodes into a single view_copy. + + view_copy(view_copy(x, shape1), shape2) → view_copy(x, shape2) + + Only the final shape matters, so intermediate view_copys can be removed. + """ + + def call(self, graph_module: GraphModule) -> PassResult: + graph = graph_module.graph + modified = False + view_copy_target = exir_ops.edge.aten.view_copy.default + + for node in list(graph.nodes): + if node.op != "call_function" or node.target != view_copy_target: + continue + + parent = node.args[0] + if ( + isinstance(parent, Node) + and parent.op == "call_function" + and parent.target == view_copy_target + and len(parent.users) == 1 + ): + original_input = parent.args[0] + target_shape = node.args[1] + + # Check if final shape matches original input shape (identity). + # Compare meta shapes (not args) so SymInt dims are handled. + # Use try/except because shapes may contain unbacked SymInts + # (e.g. from .item() calls) that can't be guarded on. + original_val = ( + original_input.meta.get("val") + if isinstance(original_input, Node) + else None + ) + output_val = node.meta.get("val") + is_identity = False + if original_val is not None and output_val is not None: + try: + is_identity = original_val.shape == output_val.shape + except Exception: + is_identity = False + if is_identity: + # Identity — remove both view_copys + node.replace_all_uses_with(original_input) + graph.erase_node(node) + graph.erase_node(parent) + else: + # Collapse: view_copy(view_copy(x, s1), s2) → view_copy(x, s2) + node.args = (original_input, target_shape) + graph.erase_node(parent) + modified = True + + if modified: + graph.eliminate_dead_code() + graph.lint() + + return PassResult(graph_module, modified) + + +class CollapsePermutePass(ExportPass): + """ + Collapses consecutive permute_copy nodes into a single permute_copy. + + permute(permute(x, p1), p2) → permute(x, composed) + where composed[i] = p1[p2[i]]. + + If the composed permutation is the identity, the permute is removed entirely. + Must run after CanonicalizePermutePass so all transpose_copy nodes are permute_copy. + """ + + def call(self, graph_module: GraphModule) -> PassResult: + graph = graph_module.graph + modified = False + permute_target = exir_ops.edge.aten.permute_copy.default + + for node in list(graph.nodes): + if node.op != "call_function" or node.target != permute_target: + continue + + parent = node.args[0] + if ( + isinstance(parent, Node) + and parent.op == "call_function" + and parent.target == permute_target + and len(parent.users) == 1 + ): + p1 = parent.args[1] + p2 = node.args[1] + composed = [p1[p2[i]] for i in range(len(p2))] + + if composed == list(range(len(composed))): + # Identity permutation — remove both permutes + node.replace_all_uses_with(parent.args[0]) + graph.erase_node(node) + graph.erase_node(parent) + else: + node.args = (parent.args[0], composed) + graph.erase_node(parent) + + modified = True + + if modified: + graph.eliminate_dead_code() + graph.lint() + + return PassResult(graph_module, modified) + + +def _is_pure_dtype_cast(kwargs: dict) -> bool: + """Check that _to_copy kwargs only specify dtype (no device/layout/memory_format).""" + for k, v in kwargs.items(): + if k == "dtype": + continue + if v is not None: + return False + return "dtype" in kwargs + + +class CollapseDtypeConversionPass(ExportPass): + """ + Collapses consecutive _to_copy (dtype conversion) nodes into a single one. + + _to_copy(dtype=bf16)(_to_copy(dtype=f32)(x)) → _to_copy(dtype=bf16)(x) + + Only the final dtype matters. Only collapses when both nodes are pure dtype + conversions (no device/layout/memory_format changes). + """ + + def call(self, graph_module: GraphModule) -> PassResult: + graph = graph_module.graph + modified = False + to_copy_target = exir_ops.edge.aten._to_copy.default + + for node in list(graph.nodes): + if node.op != "call_function" or node.target != to_copy_target: + continue + + parent = node.args[0] + if not ( + isinstance(parent, Node) + and parent.op == "call_function" + and parent.target == to_copy_target + and len(parent.users) == 1 + ): + continue + + # Only collapse pure dtype conversions + node_kw = node.kwargs + parent_kw = parent.kwargs + if not _is_pure_dtype_cast(node_kw) or not _is_pure_dtype_cast(parent_kw): + continue + + # Rewrite: to_copy(to_copy(x, dtype=d1), dtype=d2) → to_copy(x, dtype=d2) + node.args = (parent.args[0],) + graph.erase_node(parent) + modified = True + + if modified: + graph.eliminate_dead_code() + graph.lint() + + return PassResult(graph_module, modified) + + +class RemoveNoOpsPass(ExportPass): + """ + Removes ops that are no-ops in the MLX backend. + + - alias_copy(x): always a no-op + - clone(x): only when memory_format is contiguous or absent + - _to_copy(x, dtype=d): when x already has dtype d + - view_copy(x, shape): when shape matches input shape + - permute_copy(x, [0,1,...,n-1]): identity permutation + - slice_copy(x, ...): when output shape matches input shape (full slice) + """ + + def call(self, graph_module: GraphModule) -> PassResult: # noqa: C901 + graph = graph_module.graph + modified = False + + for node in list(graph.nodes): + if node.op != "call_function": + continue + + input_node = ( + node.args[0] if node.args and isinstance(node.args[0], Node) else None + ) + if input_node is None: + continue + + remove = False + + if node.target == exir_ops.edge.aten.alias_copy.default: + remove = True + + elif node.target == exir_ops.edge.aten.clone.default: + mem_fmt = node.kwargs.get("memory_format") + if mem_fmt is None or mem_fmt == torch.contiguous_format: + remove = True + + elif node.target == exir_ops.edge.aten._to_copy.default: + if _is_pure_dtype_cast(node.kwargs): + input_val = input_node.meta.get("val") + target_dtype = node.kwargs.get("dtype") + if input_val is not None and input_val.dtype == target_dtype: + remove = True + + elif node.target == exir_ops.edge.aten.view_copy.default: + input_val = input_node.meta.get("val") + output_val = node.meta.get("val") + if input_val is not None and output_val is not None: + try: + if input_val.shape == output_val.shape: + remove = True + except Exception: + pass + + elif node.target == exir_ops.edge.aten.permute_copy.default: + perm = node.args[1] + if list(perm) == list(range(len(perm))): + remove = True + + elif node.target == exir_ops.edge.aten.slice_copy.Tensor: + input_val = input_node.meta.get("val") + output_val = node.meta.get("val") + if input_val is not None and output_val is not None: + try: + if input_val.shape == output_val.shape: + remove = True + except Exception: + pass + + if remove: + node.replace_all_uses_with(input_node) + graph.erase_node(node) + modified = True + + if modified: + graph.eliminate_dead_code() + graph.lint() + + return PassResult(graph_module, modified) diff --git a/backends/mlx/test/test_passes.py b/backends/mlx/test/test_passes.py index a9fdb3b996b..97172c1411a 100644 --- a/backends/mlx/test/test_passes.py +++ b/backends/mlx/test/test_passes.py @@ -4,3 +4,662 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + +""" +Tests for graph transformation passes in the MLX backend. +""" + +import unittest + +import executorch.exir as exir +import torch +import torch.nn as nn +from executorch.backends.mlx.partitioner import MLXPartitioner +from executorch.backends.mlx.passes import ( + _is_pure_dtype_cast, + CanonicalizePermutePass, + CollapseDtypeConversionPass, + CollapsePermutePass, + CollapseViewCopyPass, + FuseRMSNormPass, + RemoveNoOpsPass, +) +from executorch.exir import EdgeCompileConfig +from executorch.exir.backend.partitioner import PartitionResult +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export import export + + +class _PreserveOpsPartitioner(MLXPartitioner): + """MLXPartitioner that preserves ops (via ops_to_not_decompose) but skips delegation. + + This gives tests a real edge-dialect graph with MLX-relevant ops like + ``item`` preserved, without delegating nodes to the MLX backend. + """ + + def partition(self, edge_program): + return PartitionResult( + tagged_exported_program=edge_program, + partition_tags={}, + ) + + +def _to_edge_gm(module, example_inputs, dynamic_shapes=None): + """Export module and lower to edge dialect, returning the GraphModule.""" + ep = export(module, example_inputs, dynamic_shapes=dynamic_shapes, strict=False) + edge = exir.to_edge_transform_and_lower( + ep, + partitioner=[_PreserveOpsPartitioner()], + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + ) + return edge.exported_program().graph_module + + +def _count_ops(gm, target): + return sum( + 1 for n in gm.graph.nodes if n.op == "call_function" and n.target == target + ) + + +def _find_nodes(gm, target): + return [n for n in gm.graph.nodes if n.op == "call_function" and n.target == target] + + +def _has_op(gm, target): + return _count_ops(gm, target) > 0 + + +class TestIsPureDtypeCast(unittest.TestCase): + + def test_pure_dtype_only(self): + self.assertTrue(_is_pure_dtype_cast({"dtype": torch.float16})) + + def test_dtype_with_none_kwargs(self): + self.assertTrue( + _is_pure_dtype_cast( + { + "dtype": torch.float16, + "device": None, + "layout": None, + } + ) + ) + + def test_dtype_with_non_none_memory_format(self): + self.assertFalse( + _is_pure_dtype_cast( + { + "dtype": torch.float16, + "memory_format": torch.contiguous_format, + } + ) + ) + + def test_dtype_with_non_none_device(self): + self.assertFalse( + _is_pure_dtype_cast( + { + "dtype": torch.float16, + "device": torch.device("cpu"), + } + ) + ) + + def test_no_dtype_key(self): + self.assertFalse(_is_pure_dtype_cast({"device": None})) + + def test_empty_kwargs(self): + self.assertFalse(_is_pure_dtype_cast({})) + + +class TestCanonicalizePermutePass(unittest.TestCase): + + def test_transpose_becomes_permute(self): + class M(nn.Module): + def forward(self, x): + return x.transpose(0, 1) + + gm = _to_edge_gm(M(), (torch.randn(3, 4),)) + transpose_target = exir_ops.edge.aten.transpose_copy.int + + if not _has_op(gm, transpose_target): + self.skipTest("Edge lowering did not produce transpose_copy") + + result = CanonicalizePermutePass()(gm) + + self.assertTrue(result.modified) + self.assertFalse(_has_op(result.graph_module, transpose_target)) + self.assertTrue( + _has_op(result.graph_module, exir_ops.edge.aten.permute_copy.default) + ) + + nodes = _find_nodes( + result.graph_module, exir_ops.edge.aten.permute_copy.default + ) + self.assertEqual(nodes[0].args[1], [1, 0]) + + def test_negative_dims_normalized(self): + class M(nn.Module): + def forward(self, x): + return x.transpose(-2, -1) + + gm = _to_edge_gm(M(), (torch.randn(2, 3, 4),)) + result = CanonicalizePermutePass()(gm) + + nodes = _find_nodes( + result.graph_module, exir_ops.edge.aten.permute_copy.default + ) + self.assertEqual(len(nodes), 1) + # transpose(-2, -1) on 3D → [0, 2, 1] + self.assertEqual(nodes[0].args[1], [0, 2, 1]) + + def test_noop_when_no_transpose(self): + class M(nn.Module): + def forward(self, x): + return x + 1 + + gm = _to_edge_gm(M(), (torch.randn(3, 4),)) + result = CanonicalizePermutePass()(gm) + self.assertFalse(result.modified) + + +class TestCollapseViewCopyPass(unittest.TestCase): + + def test_consecutive_view_copys_collapsed(self): + """view_copy(view_copy(x, s1), s2) → view_copy(x, s2).""" + + class M(nn.Module): + def forward(self, x): + return x.view(2, 6).view(3, 4) + + gm = _to_edge_gm(M(), (torch.randn(12),)) + + target = exir_ops.edge.aten.view_copy.default + before = _count_ops(gm, target) + self.assertGreaterEqual(before, 2) + + result = CollapseViewCopyPass()(gm) + + self.assertTrue(result.modified) + self.assertEqual(_count_ops(result.graph_module, target), 1) + + def test_identity_view_copy_chain_removed(self): + """view_copy(view_copy(x, s1), original_shape) → removes both.""" + + class M(nn.Module): + def forward(self, x): + return x.view(12).view(3, 4) + + gm = _to_edge_gm(M(), (torch.randn(3, 4),)) + + result = CollapseViewCopyPass()(gm) + + self.assertTrue(result.modified) + self.assertEqual( + _count_ops(result.graph_module, exir_ops.edge.aten.view_copy.default), 0 + ) + + def test_single_view_copy_unchanged(self): + class M(nn.Module): + def forward(self, x): + return x.view(12) + + gm = _to_edge_gm(M(), (torch.randn(3, 4),)) + + result = CollapseViewCopyPass()(gm) + self.assertFalse(result.modified) + + def test_collapse_with_dynamic_batch(self): + """Consecutive view_copys with a dynamic leading dim should collapse.""" + from torch.export import Dim + + class M(nn.Module): + def forward(self, x): + return x.view(-1, 3, 4).view(-1, 2, 6) + + batch = Dim("batch", min=1, max=128) + gm = _to_edge_gm( + M(), + (torch.randn(4, 12),), + dynamic_shapes={"x": {0: batch}}, + ) + + target = exir_ops.edge.aten.view_copy.default + before = _count_ops(gm, target) + self.assertGreaterEqual(before, 2) + + result = CollapseViewCopyPass()(gm) + + self.assertTrue(result.modified) + self.assertEqual(_count_ops(result.graph_module, target), 1) + + def test_identity_chain_with_dynamic_batch(self): + """view_copy(view_copy(x, s1), original_shape) with dynamic dim → both removed.""" + from torch.export import Dim + + class M(nn.Module): + def forward(self, x): + return x.view(-1, 3, 4).view(-1, 12) + + batch = Dim("batch", min=1, max=128) + gm = _to_edge_gm( + M(), + (torch.randn(4, 12),), + dynamic_shapes={"x": {0: batch}}, + ) + + target = exir_ops.edge.aten.view_copy.default + before = _count_ops(gm, target) + self.assertGreaterEqual(before, 2) + + result = CollapseViewCopyPass()(gm) + self.assertTrue(result.modified) + # Meta-shape comparison resolves SymInt identity → both view_copys removed + self.assertEqual(_count_ops(result.graph_module, target), 0) + + +class TestCollapsePermutePass(unittest.TestCase): + + def test_inverse_permutations_removed(self): + """permute(permute(x, p), inverse(p)) → identity → removed.""" + + class M(nn.Module): + def forward(self, x): + return x.permute(2, 0, 1).permute(1, 2, 0) + + gm = _to_edge_gm(M(), (torch.randn(2, 3, 4),)) + target = exir_ops.edge.aten.permute_copy.default + self.assertEqual(_count_ops(gm, target), 2) + + result = CollapsePermutePass()(gm) + + self.assertTrue(result.modified) + self.assertEqual(_count_ops(result.graph_module, target), 0) + + def test_non_identity_composed(self): + """Non-identity composition yields a single permute.""" + + class M(nn.Module): + def forward(self, x): + return x.permute(1, 0, 2).permute(0, 2, 1) + + gm = _to_edge_gm(M(), (torch.randn(2, 3, 4),)) + target = exir_ops.edge.aten.permute_copy.default + self.assertEqual(_count_ops(gm, target), 2) + + result = CollapsePermutePass()(gm) + + self.assertTrue(result.modified) + self.assertEqual(_count_ops(result.graph_module, target), 1) + + # composed[i] = p1[p2[i]] where p1=[1,0,2], p2=[0,2,1] + # → [1, 2, 0] + nodes = _find_nodes(result.graph_module, target) + self.assertEqual(nodes[0].args[1], [1, 2, 0]) + + def test_single_permute_unchanged(self): + class M(nn.Module): + def forward(self, x): + return x.permute(1, 0, 2) + + gm = _to_edge_gm(M(), (torch.randn(2, 3, 4),)) + result = CollapsePermutePass()(gm) + self.assertFalse(result.modified) + + def test_multi_user_parent_not_collapsed(self): + """Don't collapse when the parent permute has multiple users.""" + + class M(nn.Module): + def forward(self, x): + y = x.permute(1, 0, 2) + a = y.permute(1, 0, 2) + b = y.sum() + return a + b + + gm = _to_edge_gm(M(), (torch.randn(2, 3, 4),)) + result = CollapsePermutePass()(gm) + # Parent permute has 2 users → should not be collapsed + self.assertFalse(result.modified) + + +class TestCollapseDtypeConversionPass(unittest.TestCase): + + def test_consecutive_casts_collapsed(self): + """_to_copy(f32→bf16→f16) → _to_copy(f32→f16).""" + + class M(nn.Module): + def forward(self, x): + return x.to(torch.bfloat16).to(torch.float16) + + gm = _to_edge_gm(M(), (torch.randn(4, 4),)) + target = exir_ops.edge.aten._to_copy.default + before = _count_ops(gm, target) + + if before < 2: + self.skipTest("Export optimized away double cast") + + result = CollapseDtypeConversionPass()(gm) + + self.assertTrue(result.modified) + self.assertEqual(_count_ops(result.graph_module, target), 1) + + # Remaining cast should be to float16 + nodes = _find_nodes(result.graph_module, target) + self.assertEqual(nodes[0].kwargs.get("dtype"), torch.float16) + + def test_single_cast_unchanged(self): + class M(nn.Module): + def forward(self, x): + return x.to(torch.float16) + + gm = _to_edge_gm(M(), (torch.randn(4, 4),)) + result = CollapseDtypeConversionPass()(gm) + self.assertFalse(result.modified) + + +class TestRemoveNoOpsPass(unittest.TestCase): + + def test_remove_clone(self): + class M(nn.Module): + def forward(self, x): + return x.clone() + + gm = _to_edge_gm(M(), (torch.randn(4, 4),)) + target = exir_ops.edge.aten.clone.default + + if not _has_op(gm, target): + self.skipTest("Export did not produce a clone op") + + result = RemoveNoOpsPass()(gm) + + self.assertTrue(result.modified) + self.assertFalse(_has_op(result.graph_module, target)) + + def test_remove_identity_view_copy(self): + """view_copy(x, same_shape) → removed.""" + + class M(nn.Module): + def forward(self, x): + return x.view(3, 4) + + gm = _to_edge_gm(M(), (torch.randn(3, 4),)) + target = exir_ops.edge.aten.view_copy.default + + if not _has_op(gm, target): + self.skipTest("Export optimized away identity view_copy") + + result = RemoveNoOpsPass()(gm) + + self.assertTrue(result.modified) + self.assertFalse(_has_op(result.graph_module, target)) + + def test_remove_identity_permute(self): + """permute_copy(x, [0, 1, ..., n-1]) → removed.""" + + class M(nn.Module): + def forward(self, x): + return x.permute(0, 1, 2) + + gm = _to_edge_gm(M(), (torch.randn(2, 3, 4),)) + target = exir_ops.edge.aten.permute_copy.default + + if not _has_op(gm, target): + self.skipTest("Export optimized away identity permute") + + result = RemoveNoOpsPass()(gm) + + self.assertTrue(result.modified) + self.assertFalse(_has_op(result.graph_module, target)) + + def test_identity_dtype_cast_removed_after_collapse(self): + """Chain: f32→f16→f32 collapses to f32→f32, then RemoveNoOps removes it.""" + + class M(nn.Module): + def forward(self, x): + return x.to(torch.float16).to(torch.float32) + + gm = _to_edge_gm(M(), (torch.randn(4, 4),)) + target = exir_ops.edge.aten._to_copy.default + + if _count_ops(gm, target) < 2: + self.skipTest("Export optimized away double cast") + + CollapseDtypeConversionPass()(gm) + result = RemoveNoOpsPass()(gm) + + self.assertTrue(result.modified) + self.assertEqual(_count_ops(result.graph_module, target), 0) + + def test_to_copy_with_memory_format_not_removed(self): + """_is_pure_dtype_cast rejects kwargs with non-None memory_format.""" + # Can't easily produce this through export, so test the guard directly + self.assertFalse( + _is_pure_dtype_cast( + { + "dtype": torch.float32, + "memory_format": torch.contiguous_format, + } + ) + ) + + def test_non_identity_view_copy_kept(self): + """view_copy to a different shape should NOT be removed.""" + + class M(nn.Module): + def forward(self, x): + return x.view(6, 2) + + gm = _to_edge_gm(M(), (torch.randn(3, 4),)) + + result = RemoveNoOpsPass()(gm) + self.assertFalse(result.modified) + + def test_noop_when_nothing_to_remove(self): + class M(nn.Module): + def forward(self, x): + return x + 1 + + gm = _to_edge_gm(M(), (torch.randn(3, 4),)) + result = RemoveNoOpsPass()(gm) + self.assertFalse(result.modified) + + def test_identity_view_copy_with_dynamic_batch(self): + """view_copy(x, same_shape) with a dynamic dim → removed via meta-shape comparison.""" + from torch.export import Dim + + class M(nn.Module): + def forward(self, x): + return x.view(-1, 4) + + batch = Dim("batch", min=1, max=128) + gm = _to_edge_gm( + M(), + (torch.randn(4, 4),), + dynamic_shapes={"x": {0: batch}}, + ) + + target = exir_ops.edge.aten.view_copy.default + if not _has_op(gm, target): + self.skipTest("Export optimized away identity view_copy") + + result = RemoveNoOpsPass()(gm) + self.assertTrue(result.modified) + self.assertFalse(_has_op(result.graph_module, target)) + + def test_non_identity_view_copy_with_dynamic_batch(self): + """view_copy(x, different_shape) with dynamic dim should be kept.""" + from torch.export import Dim + + class M(nn.Module): + def forward(self, x): + return x.view(-1, 2, 2) + + batch = Dim("batch", min=1, max=128) + gm = _to_edge_gm( + M(), + (torch.randn(4, 4),), + dynamic_shapes={"x": {0: batch}}, + ) + + target = exir_ops.edge.aten.view_copy.default + if not _has_op(gm, target): + self.skipTest("Export did not produce view_copy") + + result = RemoveNoOpsPass()(gm) + # Shape changes, so view_copy should be kept + self.assertFalse(result.modified) + + def test_full_slice_with_dynamic_batch(self): + """slice_copy shape comparison with dynamic dim should not crash.""" + from torch.export import Dim + + class M(nn.Module): + def forward(self, x): + a = x[:, :4] + b = x[:, 4:] + return torch.cat([b, a], dim=1) + + batch = Dim("batch", min=1, max=128) + gm = _to_edge_gm( + M(), + (torch.randn(4, 8),), + dynamic_shapes={"x": {0: batch}}, + ) + + target = exir_ops.edge.aten.slice_copy.Tensor + self.assertTrue(_has_op(gm, target), "Expected slice_copy in the graph") + + # Must not crash with symbolic shapes (input_val.shape has SymInt) + RemoveNoOpsPass()(gm) + + +class TestFuseRMSNormPass(unittest.TestCase): + + def test_rms_norm_fused(self): + """Decomposed RMSNorm should be fused into a single aten.rms_norm op.""" + + class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(dim)) + self.eps = eps + + def forward(self, x): + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + return self.weight * x + + model = RMSNorm(16) + model.eval() + gm = _to_edge_gm(model, (torch.randn(1, 4, 16),)) + + result = FuseRMSNormPass()(gm) + + self.assertTrue( + result.modified, "FuseRMSNormPass should fuse the RMSNorm pattern" + ) + + has_rms_norm = any( + n.op == "call_function" and "rms_norm" in str(n.target) + for n in result.graph_module.graph.nodes + ) + self.assertTrue(has_rms_norm) + + # Intermediate ops (pow, rsqrt, mean) should be removed + has_rsqrt = any( + n.op == "call_function" and "rsqrt" in str(n.target) + for n in result.graph_module.graph.nodes + ) + self.assertFalse(has_rsqrt) + + def test_noop_on_non_rms_norm(self): + class M(nn.Module): + def forward(self, x): + return x + 1 + + ep = export(M(), (torch.randn(4, 4),), strict=False) + result = FuseRMSNormPass()(ep.graph_module) + self.assertFalse(result.modified) + + +class TestPassComposition(unittest.TestCase): + + def test_collapse_view_copy(self): + class M(nn.Module): + def forward(self, x): + return x.view(2, 6).view(3, 4) + + gm = _to_edge_gm(M(), (torch.randn(12),)) + target = exir_ops.edge.aten.view_copy.default + + self.assertGreaterEqual(_count_ops(gm, target), 2) + + CollapseViewCopyPass()(gm) + self.assertEqual(_count_ops(gm, target), 1) + + def test_canonicalize_then_collapse_permute_identity(self): + """Double transpose = identity → both removed.""" + + class M(nn.Module): + def forward(self, x): + return x.transpose(0, 1).transpose(0, 1) + + gm = _to_edge_gm(M(), (torch.randn(3, 4),)) + target = exir_ops.edge.aten.permute_copy.default + + CanonicalizePermutePass()(gm) + self.assertEqual(_count_ops(gm, target), 2) + + CollapsePermutePass()(gm) + self.assertEqual(_count_ops(gm, target), 0) + + def test_full_pipeline_does_not_crash(self): + """Running the full default pass list should not crash.""" + from executorch.backends.mlx.passes import get_default_passes + + class M(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(16, 16) + + def forward(self, x): + return self.linear(x).to(torch.float16) + + gm = _to_edge_gm(M(), (torch.randn(1, 16),)) + + for p in get_default_passes(): + p(gm) + + gm.graph.lint() + + def test_correctness_after_all_passes(self): + """Output values should be preserved after running all passes.""" + from executorch.backends.mlx.passes import get_default_passes + + class M(nn.Module): + def forward(self, x): + y = x.reshape(12).reshape(3, 4) + return y.transpose(0, 1) + + module = M() + module.eval() + x = torch.randn(3, 4) + expected = module(x) + + gm = _to_edge_gm(module, (x,)) + + for p in get_default_passes(): + p(gm) + + actual = gm(x) + # Edge graph modules may return a tuple + if isinstance(actual, tuple): + actual = actual[0] + torch.testing.assert_close(actual, expected) + + +if __name__ == "__main__": + unittest.main() From 1ae519b990a64c0d282ad4f0bbc0798e8f4b3e2d Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 4 Mar 2026 21:24:29 -0800 Subject: [PATCH 18/34] up --- .github/workflows/mlx.yml | 25 +- backends/mlx/examples/llm/__init__.py | 6 + backends/mlx/examples/llm/export_llm_hf.py | 76 ++--- backends/mlx/examples/voxtral/README.md | 69 ++++ .../mlx/examples/voxtral/export_voxtral_hf.py | 113 ++----- backends/mlx/examples/whisper/README.md | 68 ++-- .../mlx/examples/whisper/export_whisper.py | 124 +++---- backends/mlx/llm/quantization.py | 126 +------ examples/models/parakeet/README.md | 36 +- .../models/parakeet/export_parakeet_tdt.py | 24 +- examples/models/voxtral_realtime/README.md | 75 ++++- .../voxtral_realtime/export_voxtral_rt.py | 72 +--- extension/audio/mel_spectrogram.py | 47 ++- extension/llm/export/quantize.py | 308 ++++++++++++------ 14 files changed, 596 insertions(+), 573 deletions(-) create mode 100644 backends/mlx/examples/llm/__init__.py create mode 100644 backends/mlx/examples/voxtral/README.md diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index e35242b9191..fb68d1739af 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -10,8 +10,8 @@ on: - .github/workflows/mlx.yml - backends/mlx/** - extension/llm/export/** + - extension/audio/** - examples/models/parakeet/** - - examples/models/voxtral/** - examples/models/voxtral_realtime/** workflow_dispatch: @@ -193,7 +193,7 @@ jobs: ${CONDA_RUN} python -m executorch.backends.mlx.examples.voxtral.export_voxtral_hf \ --output-dir /tmp/voxtral_mlx \ --dtype bf16 \ - --quantize-linear int4 + --qlinear 4w echo "::endgroup::" echo "::group::Build Voxtral MLX runner" @@ -251,6 +251,14 @@ jobs: echo "Model path: ${MODEL_PATH}" echo "::endgroup::" + echo "::group::Export preprocessor" + ${CONDA_RUN} python -m executorch.extension.audio.mel_spectrogram \ + --feature_size 128 \ + --streaming \ + --backend mlx \ + --output_file /tmp/voxtral_rt_mlx/preprocessor.pte + echo "::endgroup::" + echo "::group::Export Voxtral Realtime (streaming)" ${CONDA_RUN} python -m executorch.examples.models.voxtral_realtime.export_voxtral_rt \ --model-path "${MODEL_PATH}" \ @@ -260,8 +268,7 @@ jobs: --qlinear-encoder 4w \ --qlinear 4w \ --qembedding 8w \ - --qembedding-group-size 128 \ - --export-preprocessor + --qembedding-group-size 128 echo "::endgroup::" echo "::group::Build Voxtral Realtime MLX runner" @@ -317,7 +324,7 @@ jobs: --model-id "openai/whisper-tiny" \ --output-dir /tmp/whisper_mlx \ --dtype bf16 \ - --quantize-linear int4 + --qlinear 4w echo "::endgroup::" echo "::group::Run Whisper inference" @@ -409,10 +416,11 @@ jobs: - id: "unsloth/gemma-3-1b-it" name: "gemma3-1b" use-custom: [false, true] + qconfig: ["4w", "nvfp4"] uses: pytorch/test-infra/.github/workflows/macos_job.yml@main secrets: inherit with: - job-name: test-mlx-llm-${{ matrix.model.name }}${{ matrix.use-custom && '-custom' || '' }} + job-name: test-mlx-llm-${{ matrix.model.name }}${{ matrix.use-custom && '-custom' || '' }}-${{ matrix.qconfig }} runner: macos-14-xlarge python-version: "3.12" submodules: recursive @@ -425,6 +433,7 @@ jobs: MODEL_ID="${{ matrix.model.id }}" MODEL_NAME="${{ matrix.model.name }}" USE_CUSTOM="${{ matrix.use-custom }}" + QCONFIG="${{ matrix.qconfig }}" CUSTOM_ARGS="" if [ "${USE_CUSTOM}" = "true" ]; then @@ -449,8 +458,8 @@ jobs: ${CONDA_RUN} python -m executorch.backends.mlx.examples.llm.export_llm_hf \ --model-id "${MODEL_ID}" \ --output /tmp/${MODEL_NAME}.pte \ - --quantize-linear int4 \ - --quantize-embeddings int4 \ + --qlinear ${QCONFIG} \ + --qembedding ${QCONFIG} \ ${CUSTOM_ARGS} echo "::endgroup::" diff --git a/backends/mlx/examples/llm/__init__.py b/backends/mlx/examples/llm/__init__.py new file mode 100644 index 00000000000..f557ef26c5b --- /dev/null +++ b/backends/mlx/examples/llm/__init__.py @@ -0,0 +1,6 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/mlx/examples/llm/export_llm_hf.py b/backends/mlx/examples/llm/export_llm_hf.py index f00880ac9cb..39f13e434be 100644 --- a/backends/mlx/examples/llm/export_llm_hf.py +++ b/backends/mlx/examples/llm/export_llm_hf.py @@ -53,11 +53,11 @@ def _export_with_optimum( output_path: str, max_seq_len: int, dtype: str, - quantize_linear: Optional[str], - quantize_embeddings: Optional[str], + qlinear: Optional[str], + qembedding: Optional[str], no_tie_word_embeddings: bool = False, - linear_group_size: Optional[int] = None, - embeddings_group_size: Optional[int] = None, + qlinear_group_size: Optional[int] = None, + qembedding_group_size: Optional[int] = None, ) -> None: import executorch.exir as exir from executorch.backends.mlx import MLXPartitioner @@ -77,18 +77,18 @@ def _export_with_optimum( max_seq_len=max_seq_len, ) - from executorch.backends.mlx.llm.quantization import apply_quantization + from executorch.backends.mlx.llm.quantization import quantize_model_ - apply_quantization( + quantize_model_( exportable.model, - quantize_linear, - quantize_embeddings, + qlinear_config=qlinear, + qlinear_group_size=qlinear_group_size, + qembedding_config=qembedding, + qembedding_group_size=qembedding_group_size, tie_word_embeddings=getattr( exportable.model.config, "tie_word_embeddings", False ) and not no_tie_word_embeddings, - linear_group_size=linear_group_size, - embeddings_group_size=embeddings_group_size, ) logger.info("Exporting model with torch.export...") @@ -127,13 +127,13 @@ def _export_with_custom_components( output_path: str, max_seq_len: int, dtype: str, - quantize_linear: Optional[str], - quantize_embeddings: Optional[str], + qlinear: Optional[str], + qembedding: Optional[str], use_custom_sdpa: bool, use_custom_kv_cache: bool, no_tie_word_embeddings: bool = False, - linear_group_size: Optional[int] = None, - embeddings_group_size: Optional[int] = None, + qlinear_group_size: Optional[int] = None, + qembedding_group_size: Optional[int] = None, ) -> None: """ Export using direct HF model with custom MLX components. @@ -266,16 +266,16 @@ def _export_with_custom_components( ) logger.info(" HFStaticCache installed successfully") - from executorch.backends.mlx.llm.quantization import apply_quantization + from executorch.backends.mlx.llm.quantization import quantize_model_ - apply_quantization( + quantize_model_( exportable.model, - quantize_linear, - quantize_embeddings, + qlinear_config=qlinear, + qlinear_group_size=qlinear_group_size, + qembedding_config=qembedding, + qembedding_group_size=qembedding_group_size, tie_word_embeddings=getattr(model.config, "tie_word_embeddings", False) and not no_tie_word_embeddings, - linear_group_size=linear_group_size, - embeddings_group_size=embeddings_group_size, ) logger.info("Exporting model with torch.export...") @@ -344,13 +344,13 @@ def export_llama_hf( output_path: str, max_seq_len: int = 1024, dtype: str = "bf16", - quantize_linear: Optional[str] = None, - quantize_embeddings: Optional[str] = None, + qlinear: Optional[str] = None, + qembedding: Optional[str] = None, use_custom_sdpa: bool = False, use_custom_kv_cache: bool = False, no_tie_word_embeddings: bool = False, - linear_group_size: Optional[int] = None, - embeddings_group_size: Optional[int] = None, + qlinear_group_size: Optional[int] = None, + qembedding_group_size: Optional[int] = None, ) -> None: """ Export a HuggingFace Llama model to ExecuTorch with MLX backend. @@ -360,8 +360,8 @@ def export_llama_hf( output_path: Path to save the .pte file max_seq_len: Maximum sequence length for KV cache dtype: Model dtype ("fp32", "fp16", "bf16") - quantize_linear: Quantization for linear layers ("int4", "int8", or None) - quantize_embeddings: Quantization for embeddings ("int4", "int8", or None) + qlinear: Quantization for linear layers ("4w", "8w", "nvfp4", or None) + qembedding: Quantization for embeddings ("4w", "8w", "nvfp4", or None) use_custom_sdpa: Use MLX custom SDPA (mlx::custom_sdpa) use_custom_kv_cache: Use MLX custom KV cache (mlx::kv_cache_update) """ @@ -375,13 +375,13 @@ def export_llama_hf( output_path=output_path, max_seq_len=max_seq_len, dtype=dtype, - quantize_linear=quantize_linear, - quantize_embeddings=quantize_embeddings, + qlinear=qlinear, + qembedding=qembedding, use_custom_sdpa=use_custom_sdpa, use_custom_kv_cache=use_custom_kv_cache, no_tie_word_embeddings=no_tie_word_embeddings, - linear_group_size=linear_group_size, - embeddings_group_size=embeddings_group_size, + qlinear_group_size=qlinear_group_size, + qembedding_group_size=qembedding_group_size, ) else: logger.info("Using optimum-executorch pipeline (no custom components)") @@ -390,11 +390,11 @@ def export_llama_hf( output_path=output_path, max_seq_len=max_seq_len, dtype=dtype, - quantize_linear=quantize_linear, - quantize_embeddings=quantize_embeddings, + qlinear=qlinear, + qembedding=qembedding, no_tie_word_embeddings=no_tie_word_embeddings, - linear_group_size=linear_group_size, - embeddings_group_size=embeddings_group_size, + qlinear_group_size=qlinear_group_size, + qembedding_group_size=qembedding_group_size, ) @@ -450,13 +450,13 @@ def main(): output_path=args.output, max_seq_len=args.max_seq_len, dtype=args.dtype, - quantize_linear=args.quantize_linear, - quantize_embeddings=args.quantize_embeddings, + qlinear=args.qlinear, + qembedding=args.qembedding, use_custom_sdpa=args.use_custom_sdpa, use_custom_kv_cache=args.use_custom_kv_cache, no_tie_word_embeddings=args.no_tie_word_embeddings, - linear_group_size=args.linear_group_size, - embeddings_group_size=args.embeddings_group_size, + qlinear_group_size=args.qlinear_group_size, + qembedding_group_size=args.qembedding_group_size, ) diff --git a/backends/mlx/examples/voxtral/README.md b/backends/mlx/examples/voxtral/README.md new file mode 100644 index 00000000000..16d2384ed42 --- /dev/null +++ b/backends/mlx/examples/voxtral/README.md @@ -0,0 +1,69 @@ +# Voxtral MLX Export + +Export [mistralai/Voxtral-Mini-3B-2507](https://huggingface.co/mistralai/Voxtral-Mini-3B-2507) +multimodal audio-language model to ExecuTorch with the MLX backend. + +Uses [optimum-executorch](https://github.com/huggingface/optimum-executorch) for +the export pipeline. + +## Prerequisites + +```bash +pip install transformers torch optimum-executorch mistral-common librosa +``` + +## Export + +Export with int4 weight quantization (recommended): + +```bash +python -m executorch.backends.mlx.examples.voxtral.export_voxtral_hf \ + --output-dir voxtral_mlx \ + --dtype bf16 \ + --qlinear 4w +``` + +This produces: +- `model.pte` — the main model (audio_encoder, token_embedding, text_decoder) +- `preprocessor.pte` — mel spectrogram preprocessor for raw audio + +### Export Options + +| Argument | Default | Description | +|----------|---------|-------------| +| `--model-id` | `mistralai/Voxtral-Mini-3B-2507` | HuggingFace model ID | +| `--output-dir` | `voxtral_mlx` | Output directory | +| `--dtype` | `bf16` | Model dtype (`fp32`, `fp16`, `bf16`) | +| `--max-seq-len` | `1024` | Maximum sequence length for KV cache | +| `--max-audio-len` | `300` | Maximum audio length in seconds | +| `--qlinear` | `4w` | Linear layer quantization (`4w`, `8w`, `nvfp4`, or None) | +| `--qlinear-group-size` | auto | Group size for linear quantization | + +### Quantization + +The `4w` config uses int4 weight-only quantization with the HQQ algorithm for +optimal scale selection. This typically reduces model size by ~4x with minimal +quality loss. + +## Run + +Requires the C++ voxtral runner. Build with: + +```bash +make voxtral-mlx +``` + +Run inference: + +```bash +./cmake-out/examples/models/voxtral/voxtral_runner \ + --model_path voxtral_mlx/model.pte \ + --processor_path voxtral_mlx/preprocessor.pte \ + --tokenizer_path /path/to/tekken.json \ + --audio_path /path/to/audio.wav \ + --prompt "What is happening in this audio?" \ + --temperature 0 +``` + +The `tekken.json` tokenizer is included in the model weights directory +downloaded from HuggingFace. diff --git a/backends/mlx/examples/voxtral/export_voxtral_hf.py b/backends/mlx/examples/voxtral/export_voxtral_hf.py index d2ae68f0d30..b9ed2bccf1c 100644 --- a/backends/mlx/examples/voxtral/export_voxtral_hf.py +++ b/backends/mlx/examples/voxtral/export_voxtral_hf.py @@ -29,88 +29,20 @@ import os from typing import Optional -import torch - FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) -def export_preprocessor( - output_path: str, - feature_size: int = 128, - max_audio_len: int = 300, -) -> None: - """ - Export the Voxtral audio preprocessor (mel spectrogram) to MLX. - - Args: - output_path: Path to save the preprocessor .pte file - feature_size: Mel spectrogram feature dimension (128 for Voxtral) - max_audio_len: Maximum audio length in seconds - """ - import executorch.exir as exir - from executorch.backends.mlx import MLXPartitioner - from executorch.backends.mlx.passes import get_default_passes - from executorch.exir import EdgeCompileConfig - from executorch.exir.capture._config import ExecutorchBackendConfig - from executorch.exir.passes import MemoryPlanningPass - from executorch.extension.audio.mel_spectrogram import WhisperAudioProcessor - from torch.export import Dim - - logger.info("Exporting audio preprocessor with MLX backend...") - - model = WhisperAudioProcessor( - feature_size=feature_size, - max_audio_len=max_audio_len, - stack_output=True, - ) - - audio_tensor = torch.randn(93680) - shapes_collection = torch.export.ShapesCollection() - max_n_chunks = int(model.max_audio_len * model.n_samples) - shapes_collection[audio_tensor] = {0: Dim.DYNAMIC(max=max_n_chunks)} - - with torch.no_grad(), torch.fx.experimental._config.patch( - backed_size_oblivious=True - ): - ep = torch.export.export( - model, (audio_tensor,), dynamic_shapes=shapes_collection, strict=True - ) - - edge_program = exir.to_edge_transform_and_lower( - ep, - transform_passes=get_default_passes(), - partitioner=[MLXPartitioner()], - compile_config=EdgeCompileConfig(_check_ir_validity=False), - ) - - executorch_program = edge_program.to_executorch( - config=ExecutorchBackendConfig( - extract_delegate_segments=True, - memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), - ) - ) - - os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) - with open(output_path, "wb") as f: - f.write(executorch_program.buffer) - - logger.info(f"Saved preprocessor to: {output_path}") - logger.info( - f"Preprocessor size: {len(executorch_program.buffer) / 1024 / 1024:.2f} MB" - ) - - def export_voxtral_hf( model_id: str, output_dir: str, max_seq_len: int = 1024, dtype: str = "bf16", - quantize_linear: Optional[str] = None, - quantize_embeddings: Optional[str] = None, - linear_group_size: Optional[int] = None, - embeddings_group_size: Optional[int] = None, + qlinear: Optional[str] = None, + qembedding: Optional[str] = None, + qlinear_group_size: Optional[int] = None, + qembedding_group_size: Optional[int] = None, max_audio_len: int = 300, ) -> None: """ @@ -124,10 +56,10 @@ def export_voxtral_hf( output_dir: Directory to save the .pte files max_seq_len: Maximum sequence length for KV cache dtype: Model dtype ("fp32", "fp16", "bf16") - quantize_linear: Quantization for linear layers ("int4", "int8", or None) - quantize_embeddings: Quantization for embedding layers ("int4", "int8", or None) - linear_group_size: Group size for linear quantization (default: 32 for int4, 128 for int8) - embeddings_group_size: Group size for embedding quantization (default: 32 for int4, 128 for int8) + qlinear: Quantization for linear layers ("4w", "8w", "nvfp4", or None) + qembedding: Quantization for embeddings ("4w", "8w", "nvfp4", or None) + qlinear_group_size: Group size for linear quantization (default: auto) + qembedding_group_size: Group size for embedding quantization (default: auto) max_audio_len: Maximum audio length in seconds for preprocessor """ from optimum.exporters.executorch.tasks.multimodal_text_to_text import ( @@ -137,9 +69,14 @@ def export_voxtral_hf( os.makedirs(output_dir, exist_ok=True) # --- Export preprocessor --- - export_preprocessor( - output_path=os.path.join(output_dir, "preprocessor.pte"), + from executorch.extension.audio.mel_spectrogram import export_processor + + export_processor( + output_file=os.path.join(output_dir, "preprocessor.pte"), + backend="mlx", + feature_size=128, max_audio_len=max_audio_len, + stack_output=True, ) # --- Export model --- @@ -155,14 +92,14 @@ def export_voxtral_hf( ) # Apply quantization if requested - from executorch.backends.mlx.llm.quantization import apply_quantization + from executorch.backends.mlx.llm.quantization import quantize_model_ - apply_quantization( + quantize_model_( exportable.model, - quantize_linear, - quantize_embeddings, - linear_group_size=linear_group_size, - embeddings_group_size=embeddings_group_size, + qlinear_config=qlinear, + qlinear_group_size=qlinear_group_size, + qembedding_config=qembedding, + qembedding_group_size=qembedding_group_size, ) logger.info("Exporting model with torch.export...") @@ -252,10 +189,10 @@ def main(): output_dir=args.output_dir, max_seq_len=args.max_seq_len, dtype=args.dtype, - quantize_linear=args.quantize_linear, - quantize_embeddings=args.quantize_embeddings, - linear_group_size=args.linear_group_size, - embeddings_group_size=args.embeddings_group_size, + qlinear=args.qlinear, + qembedding=args.qembedding, + qlinear_group_size=args.qlinear_group_size, + qembedding_group_size=args.qembedding_group_size, max_audio_len=args.max_audio_len, ) diff --git a/backends/mlx/examples/whisper/README.md b/backends/mlx/examples/whisper/README.md index ed7333d881d..3e7749a3957 100644 --- a/backends/mlx/examples/whisper/README.md +++ b/backends/mlx/examples/whisper/README.md @@ -1,59 +1,42 @@ -# Whisper MLX Examples +# Whisper MLX Export Export and run [OpenAI Whisper](https://huggingface.co/openai/whisper-tiny) speech-to-text models on the MLX backend. -## Scripts - -| Script | Description | -|---|---| -| `export_whisper.py` | Export with custom KV cache wrapper (3 separate `.pte` files) | -| `run_whisper.py` | Run models exported with `export_whisper` | - -## Quick start +## Prerequisites ```bash -# Export -python -m executorch.backends.mlx.examples.whisper.export_whisper \ - --model-id openai/whisper-tiny \ - --output-dir /tmp/whisper_mlx - -# Run -python -m executorch.backends.mlx.examples.whisper.run_whisper \ - --model-dir /tmp/whisper_mlx \ - --use-sample-audio +pip install transformers torchao soundfile datasets ``` +## Export -## export_whisper.py - -Custom export that splits the model into three programs: +The export script splits the model into three programs: - **encoder.pte** — audio features → encoder hidden states - **cross_kv.pte** — encoder hidden states → per-layer cross-attention K/V - **decoder.pte** — token-by-token generation with self-attention KV cache +Export with int4 weight quantization: + ```bash python -m executorch.backends.mlx.examples.whisper.export_whisper \ --model-id openai/whisper-tiny \ --output-dir /tmp/whisper_mlx \ - --quantize-linear int4 + --dtype bf16 \ + --qlinear 4w ``` +### Export Options + | Option | Default | Description | -|---|---|---| +|--------|---------|-------------| | `--model-id` | `openai/whisper-tiny` | HuggingFace model ID | | `--output-dir` | `whisper_mlx` | Output directory for `.pte` files | | `--max-decoder-seq-len` | `256` | Maximum decoder sequence length | | `--dtype` | `bf16` | Model dtype (`fp32`, `fp16`, `bf16`) | -| `--quantize-linear` | `None` | Quantize linear layers (`int4`, `int8`) | -| `--quantize-embeddings` | `None` | Quantize embedding layers (`int4`, `int8`) | -| `--linear-group-size` | `None` | Group size for linear quantization (32, 64, 128; default: 32 for int4, 128 for int8) | -| `--embeddings-group-size` | `None` | Group size for embedding quantization (32, 64, 128; default: 32 for int4, 128 for int8) | -## run_whisper.py -Run models exported with `export_whisper.py`. Loads encoder, cross_kv, and -decoder programs from a directory. +## Run ```bash python -m executorch.backends.mlx.examples.whisper.run_whisper \ @@ -61,22 +44,23 @@ python -m executorch.backends.mlx.examples.whisper.run_whisper \ --use-sample-audio ``` +Or with a custom audio file: + +```bash +python -m executorch.backends.mlx.examples.whisper.run_whisper \ + --model-dir /tmp/whisper_mlx \ + --audio-file /path/to/audio.wav +``` + +### Run Options + | Option | Default | Description | -|---|---|---| +|--------|---------|-------------| | `--model-dir` | `/tmp/whisper_mlx` | Directory containing exported `.pte` files | | `--model-id` | `openai/whisper-tiny` | HuggingFace model ID (used to load processor) | -| `--audio-file` | `None` | Path to audio file (WAV, MP3, etc.) | -| `--use-sample-audio` | `False` | Use sample audio from HuggingFace datasets | +| `--audio-file` | None | Path to audio file (WAV, MP3, etc.) | +| `--use-sample-audio` | off | Use sample audio from HuggingFace datasets | | `--max-new-tokens` | `256` | Maximum tokens to generate | | `--language` | `en` | Language code | | `--task` | `transcribe` | `transcribe` or `translate` | | `--dtype` | `bf16` | Input dtype (must match export dtype) | - -## Requirements - -After installing ExecuTorch, install optimum-executorch: - -```bash -pip install optimum-executorch -pip install transformers torchao soundfile datasets -``` diff --git a/backends/mlx/examples/whisper/export_whisper.py b/backends/mlx/examples/whisper/export_whisper.py index 03123c08935..97d3a22bc79 100644 --- a/backends/mlx/examples/whisper/export_whisper.py +++ b/backends/mlx/examples/whisper/export_whisper.py @@ -368,10 +368,10 @@ def export_whisper_to_mlx( output_dir: str, max_decoder_seq_len: int = 256, dtype: str = "bf16", - quantize_linear: Optional[str] = None, - quantize_embeddings: Optional[str] = None, - linear_group_size: Optional[int] = None, - embeddings_group_size: Optional[int] = None, + qlinear: Optional[str] = None, + qembedding: Optional[str] = None, + qlinear_group_size: Optional[int] = None, + qembedding_group_size: Optional[int] = None, ) -> None: """ Export Whisper model components to MLX delegate. @@ -386,10 +386,10 @@ def export_whisper_to_mlx( output_dir: Directory to save .pte files max_decoder_seq_len: Maximum decoder sequence length dtype: Model dtype ("fp32", "fp16", "bf16") - quantize_linear: Quantization method for linear layers ("int4", "int8", or None) - quantize_embeddings: Quantization method for embedding layers ("int4", "int8", or None) - linear_group_size: Group size for linear quantization. Defaults to 32 for int4, 128 for int8. - embeddings_group_size: Group size for embedding quantization. Defaults to 32 for int4, 128 for int8. + qlinear: Quantization config for linear layers ("4w", "8w", "nvfp4", or None) + qembedding: Quantization config for embedding layers ("4w", "8w", "nvfp4", or None) + qlinear_group_size: Group size for linear quantization (default: auto) + qembedding_group_size: Group size for embedding quantization (default: auto) """ from transformers import AutoProcessor, WhisperForConditionalGeneration @@ -435,63 +435,44 @@ def export_whisper_to_mlx( # Apply quantization if requested # Whisper has 3 separate wrappers to quantize, and embed_positions must be # excluded from embedding quantization (accessed via indexing). - if quantize_linear or quantize_embeddings: - from executorch.backends.mlx.llm.quantization import _default_group_size - - try: - from torchao.quantization.granularity import PerGroup - from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ - from torchao.quantization.quantize_.workflows import ( - IntxChooseQParamsAlgorithm, - ) + if qlinear or qembedding: + from executorch.extension.llm.export.quantize import quantize_model_ + + if qlinear: + logger.info(f"Quantizing linear layers with {qlinear}...") + for module in [encoder_wrapper, cross_kv_wrapper, decoder_wrapper]: + quantize_model_( + module, + qlinear_config=qlinear, + qlinear_group_size=qlinear_group_size, + skip_incompatible_shapes=True, + ) - qparams_algorithm = IntxChooseQParamsAlgorithm.HQQ_SCALE_ONLY + if qembedding: + # Custom filter: only embed_tokens, not embed_positions + from executorch.extension.llm.export.quantize import ( + _default_group_size, + _make_embedding_config, + ) + from torchao.quantization.quant_api import quantize_ - if quantize_embeddings: - embed_dtype = ( - torch.int4 if quantize_embeddings == "int4" else torch.int8 - ) - embed_gs = embeddings_group_size or _default_group_size( - quantize_embeddings - ) - logger.info( - f"Quantizing embedding layers with {quantize_embeddings} " - f"(group size {embed_gs})..." - ) - quantize_( - decoder_wrapper, - IntxWeightOnlyConfig( - weight_dtype=embed_dtype, - granularity=PerGroup(embed_gs), - intx_choose_qparams_algorithm=qparams_algorithm, - ), - lambda m, fqn: isinstance(m, nn.Embedding) - and "embed_tokens" in fqn, - ) + gs = ( + qembedding_group_size + if qembedding_group_size is not None + else _default_group_size(qembedding) + ) + embed_config = _make_embedding_config(qembedding, gs) + logger.info( + f"Quantizing embedding layers with {qembedding} " + f"(group size {gs})..." + ) + quantize_( + decoder_wrapper, + embed_config, + lambda m, fqn: isinstance(m, nn.Embedding) and "embed_tokens" in fqn, + ) - if quantize_linear: - linear_dtype = torch.int4 if quantize_linear == "int4" else torch.int8 - linear_gs = linear_group_size or _default_group_size(quantize_linear) - config = IntxWeightOnlyConfig( - weight_dtype=linear_dtype, - granularity=PerGroup(linear_gs), - intx_choose_qparams_algorithm=qparams_algorithm, - ) - logger.info( - f"Quantizing linear layers with {quantize_linear} " - f"(group size {linear_gs})..." - ) - for module in [encoder_wrapper, cross_kv_wrapper, decoder_wrapper]: - quantize_( - module, - config, - filter_fn=lambda m, fqn: isinstance(m, nn.Linear), - ) - - logger.info("Applied quantization successfully") - except ImportError: - logger.error("TorchAO not installed. Run: pip install torchao") - raise + logger.info("Applied quantization successfully") logger.info("Exporting encoder...") @@ -561,8 +542,8 @@ def export_whisper_to_mlx( metadata = { "model_id": model_id, "dtype": dtype, - "quantize_linear": quantize_linear, - "quantize_embeddings": quantize_embeddings, + "quantize_linear": qlinear, + "quantize_embeddings": qembedding, "max_decoder_seq_len": max_decoder_seq_len, "encoder_seq_len": encoder_seq_len, "num_decoder_layers": decoder_wrapper.num_layers, @@ -581,12 +562,9 @@ def _save_to_pte(ep, output_path: str, name: str) -> None: from executorch.exir import EdgeCompileConfig from executorch.exir.capture._config import ExecutorchBackendConfig - # Allow repeat_interleave and sdpa ops edge_config = EdgeCompileConfig( - _core_aten_ops_exception_list=[ - torch.ops.aten.repeat_interleave.self_int, - torch.ops.aten.scaled_dot_product_attention.default, - ] + _check_ir_validity=False, + _skip_dim_order=True, ) edge_program = exir.to_edge_transform_and_lower( @@ -628,10 +606,10 @@ def main(): output_dir=args.output_dir, max_decoder_seq_len=args.max_decoder_seq_len, dtype=args.dtype, - quantize_linear=args.quantize_linear, - quantize_embeddings=args.quantize_embeddings, - linear_group_size=args.linear_group_size, - embeddings_group_size=args.embeddings_group_size, + qlinear=args.qlinear, + qembedding=args.qembedding, + qlinear_group_size=args.qlinear_group_size, + qembedding_group_size=args.qembedding_group_size, ) diff --git a/backends/mlx/llm/quantization.py b/backends/mlx/llm/quantization.py index 0fdb988f9fa..196e4a9ac1f 100644 --- a/backends/mlx/llm/quantization.py +++ b/backends/mlx/llm/quantization.py @@ -6,46 +6,47 @@ # LICENSE file in the root directory of this source tree. """ -Shared quantization utilities for MLX LLM export scripts. +Quantization argument helpers for MLX LLM export scripts. + +Re-exports quantize_model_ from the shared ExecuTorch LLM export library +and provides add_quantization_args for MLX export CLI scripts. """ import argparse -import logging -from typing import Optional -import torch +from executorch.extension.llm.export.quantize import quantize_model_ -logger = logging.getLogger(__name__) +__all__ = ["add_quantization_args", "quantize_model_"] def add_quantization_args(parser: argparse.ArgumentParser) -> None: parser.add_argument( - "--quantize-linear", + "--qlinear", type=str, - choices=["int4", "int8"], + choices=["4w", "8w", "nvfp4"], default=None, - help="Quantization method for linear layers", + help="Quantization config for linear layers", ) parser.add_argument( - "--quantize-embeddings", + "--qembedding", type=str, - choices=["int4", "int8"], + choices=["4w", "8w", "nvfp4"], default=None, - help="Quantization method for embedding layers", + help="Quantization config for embedding layers", ) parser.add_argument( - "--linear-group-size", + "--qlinear-group-size", type=int, choices=[32, 64, 128], default=None, - help="Group size for linear layer quantization (default: 32 for int4, 128 for int8)", + help="Group size for linear layer quantization (default: 32)", ) parser.add_argument( - "--embeddings-group-size", + "--qembedding-group-size", type=int, choices=[32, 64, 128], default=None, - help="Group size for embedding layer quantization (default: 32 for int4, 128 for int8)", + help="Group size for embedding layer quantization (default: 128)", ) parser.add_argument( "--no-tie-word-embeddings", @@ -54,98 +55,3 @@ def add_quantization_args(parser: argparse.ArgumentParser) -> None: help="Disable tying lm_head weights to embedding after quantization, " "even if the model config has tie_word_embeddings=True", ) - - -def _default_group_size(dtype_str: str) -> int: - return 32 if dtype_str == "int4" else 128 - - -def apply_quantization( - model: torch.nn.Module, - quantize_linear: Optional[str], - quantize_embeddings: Optional[str], - tie_word_embeddings: bool = False, - linear_group_size: Optional[int] = None, - embeddings_group_size: Optional[int] = None, -) -> None: - """Apply TorchAO quantization to the model. - - Uses the HQQ (Half-Quadratic Quantization) scale-only algorithm for - choosing quantization parameters. - - Args: - model: The model to quantize. Expected to have model.model.embed_tokens - and model.lm_head attributes for weight tying. - quantize_linear: Quantization method for linear layers ("int4", "int8", or None) - quantize_embeddings: Quantization method for embedding layers ("int4", "int8", or None) - tie_word_embeddings: If True, re-tie lm_head.weight to embed_tokens.weight - after quantization. Should be set from the HF model config's - tie_word_embeddings field, and can be overridden with --no-tie-word-embeddings. - linear_group_size: Group size for linear quantization. Defaults to 32 for int4, 128 for int8. - embeddings_group_size: Group size for embedding quantization. Defaults to 32 for int4, 128 for int8. - """ - if not quantize_linear and not quantize_embeddings: - return - - logger.info("Applying quantization with TorchAO...") - try: - from torchao.quantization.granularity import PerGroup - from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ - from torchao.quantization.quantize_.workflows import IntxChooseQParamsAlgorithm - - qparams_algorithm = IntxChooseQParamsAlgorithm.HQQ_SCALE_ONLY - - if quantize_embeddings: - embed_dtype = torch.int4 if quantize_embeddings == "int4" else torch.int8 - embed_group_size = embeddings_group_size or _default_group_size( - quantize_embeddings - ) - logger.info( - f"Quantizing embedding layers with {quantize_embeddings} " - f"(group size {embed_group_size})..." - ) - quantize_( - model, - IntxWeightOnlyConfig( - weight_dtype=embed_dtype, - granularity=PerGroup(embed_group_size), - intx_choose_qparams_algorithm=qparams_algorithm, - ), - filter_fn=lambda m, fqn: isinstance(m, torch.nn.Embedding), - ) - - if quantize_linear: - linear_dtype = torch.int4 if quantize_linear == "int4" else torch.int8 - linear_group_size = linear_group_size or _default_group_size( - quantize_linear - ) - logger.info( - f"Quantizing linear layers with {quantize_linear} " - f"(group size {linear_group_size})..." - ) - quantize_( - model, - IntxWeightOnlyConfig( - weight_dtype=linear_dtype, - granularity=PerGroup(linear_group_size), - intx_choose_qparams_algorithm=qparams_algorithm, - ), - filter_fn=lambda m, fqn: isinstance(m, torch.nn.Linear), - ) - - if ( - tie_word_embeddings - and hasattr(model, "lm_head") - and hasattr(model, "model") - ): - embed = getattr(model.model, "embed_tokens", None) - if embed is not None: - model.lm_head.weight = embed.weight - logger.info( - "Re-tied lm_head weights to embedding (tie_word_embeddings=True)" - ) - - logger.info("Applied quantization successfully") - except ImportError: - logger.error("TorchAO not installed. Run: pip install torchao") - raise diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md index b8f4bc0cd78..2a07fd06e73 100644 --- a/examples/models/parakeet/README.md +++ b/examples/models/parakeet/README.md @@ -39,24 +39,25 @@ The export script supports quantizing encoder and decoder linear layers using [t | Argument | Description | |----------|-------------| -| `--qlinear_encoder` | Quantization config for encoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w` | -| `--qlinear_encoder_group_size` | Group size for encoder linear quantization (default: 32) | +| `--qlinear_encoder` | Quantization config for encoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w`, `nvfp4` | +| `--qlinear_encoder_group_size` | Group size for encoder linear quantization (default: auto) | | `--qlinear_encoder_packing_format` | Packing format for encoder: `tile_packed_to_4d` | -| `--qlinear` | Quantization config for decoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w` | -| `--qlinear_group_size` | Group size for decoder linear quantization (default: 32) | +| `--qlinear` | Quantization config for decoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w`, `nvfp4` | +| `--qlinear_group_size` | Group size for decoder linear quantization (default: auto) | | `--qlinear_packing_format` | Packing format for decoder: `tile_packed_to_4d` | -| `--qembedding` | Quantization config for decoder embedding layer: `4w`, `8w` | -| `--qembedding_group_size` | Group size for embedding quantization (default: 0 = per-axis) | +| `--qembedding` | Quantization config for decoder embedding layer: `4w`, `8w`, `nvfp4` | +| `--qembedding_group_size` | Group size for embedding quantization (default: auto) | #### Quantization Configs | Config | Description | Backends | |--------|-------------|----------| -| `4w` | 4-bit weight only quantization | CUDA | -| `8w` | 8-bit weight only quantization | CUDA | -| `8da4w` | 8-bit dynamic activation, 4-bit weight | CUDA | -| `8da8w` | 8-bit dynamic activation, 8-bit weight | CUDA | +| `4w` | 4-bit weight only quantization | CUDA, MLX, XNNPACK (embedding only) | +| `8w` | 8-bit weight only quantization | CUDA, MLX, XNNPACK (embedding only) | +| `8da4w` | 8-bit dynamic activation, 4-bit weight | XNNPACK | +| `8da8w` | 8-bit dynamic activation, 8-bit weight | XNNPACK | | `fpa4w` | Floating point activation, 4-bit weight | Metal | +| `nvfp4` | 4-bit weight only quantization using NVIDIA's FP4 dtype | MLX | #### Example: Dynamic Quantization for XNNPACK @@ -172,7 +173,7 @@ This generates: - `model.pte` - The compiled Parakeet TDT model - `aoti_cuda_blob.ptd` - CUDA kernel blob required at runtime -### MLX Export (macOS) +### MLX Export Export with MLX backend (bf16, int4 quantized, group size 128): ```bash @@ -186,6 +187,19 @@ python export_parakeet_tdt.py \ --output-dir ./parakeet_mlx_4w ``` +Export with MLX backend (bf16, NVFP4 quantized): +```bash +python export_parakeet_tdt.py \ + --backend mlx \ + --dtype bf16 \ + --qlinear_encoder nvfp4 \ + --qlinear nvfp4 \ + --qembedding 4w \ + --output-dir ./parakeet_mlx_nvfp4 +``` + +> **Note:** Although MLX supports NVFP4 embedding quantization, Parakeet's embedding layer has dimensions not divisible by 16, which is incompatible with NVFP4. Use `4w` for embeddings instead. + This generates: - `model.pte` - The compiled model with MLX delegate (~470 MB) - `tokenizer.model` - SentencePiece tokenizer diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py index 340ac02d833..55821a98e16 100644 --- a/examples/models/parakeet/export_parakeet_tdt.py +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -304,15 +304,15 @@ def export_all( backend: Optional[str] = None, # Encoder quantization args qlinear_encoder: Optional[str] = None, - qlinear_encoder_group_size: int = 32, + qlinear_encoder_group_size: Optional[int] = None, qlinear_encoder_packing_format: Optional[str] = None, # Decoder quantization args qlinear: Optional[str] = None, - qlinear_group_size: int = 32, + qlinear_group_size: Optional[int] = None, qlinear_packing_format: Optional[str] = None, # Embedding quantization args (decoder has the embedding layer) qembedding: Optional[str] = None, - qembedding_group_size: int = 0, + qembedding_group_size: Optional[int] = None, ): """Export all model components. @@ -636,14 +636,14 @@ def main(): parser.add_argument( "--qlinear", type=str, - choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"], + choices=["4w", "8w", "8da4w", "8da8w", "fpa4w", "nvfp4"], help="Quantization config for decoder linear layers", ) parser.add_argument( "--qlinear_group_size", type=int, - default=32, - help="Group size for decoder linear quantization (default: 32)", + default=None, + help="Group size for decoder linear quantization", ) parser.add_argument( "--qlinear_packing_format", @@ -656,14 +656,14 @@ def main(): parser.add_argument( "--qlinear_encoder", type=str, - choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"], + choices=["4w", "8w", "8da4w", "8da8w", "fpa4w", "nvfp4"], help="Quantization config for encoder linear layers", ) parser.add_argument( "--qlinear_encoder_group_size", type=int, - default=32, - help="Group size for encoder linear quantization (default: 32)", + default=None, + help="Group size for encoder linear quantization", ) parser.add_argument( "--qlinear_encoder_packing_format", @@ -676,14 +676,14 @@ def main(): parser.add_argument( "--qembedding", type=str, - choices=["4w", "8w"], + choices=["4w", "8w", "nvfp4"], help="Quantization config for decoder embedding layer", ) parser.add_argument( "--qembedding_group_size", type=int, - default=0, - help="Group size for embedding quantization (default: 0 = per-axis)", + default=None, + help="Group size for embedding quantization", ) args = parser.parse_args() diff --git a/examples/models/voxtral_realtime/README.md b/examples/models/voxtral_realtime/README.md index 6915fba3580..7e1ea6c36b1 100644 --- a/examples/models/voxtral_realtime/README.md +++ b/examples/models/voxtral_realtime/README.md @@ -43,6 +43,16 @@ python -m executorch.extension.audio.mel_spectrogram \ --output_file ./voxtral_rt_exports/preprocessor.pte ``` +For MLX backend, use `--backend mlx`: + +```bash +python -m executorch.extension.audio.mel_spectrogram \ + --feature_size 128 \ + --max_audio_len 300 \ + --backend mlx \ + --output_file ./voxtral_rt_exports/preprocessor.pte +``` + For streaming, use a separate preprocessor with `--streaming` (no audio length limit): @@ -53,6 +63,16 @@ python -m executorch.extension.audio.mel_spectrogram \ --output_file ./voxtral_streaming_exports/preprocessor.pte ``` +For streaming with MLX backend: + +```bash +python -m executorch.extension.audio.mel_spectrogram \ + --feature_size 128 \ + --streaming \ + --backend mlx \ + --output_file ./voxtral_streaming_exports/preprocessor.pte +``` + ## Export Export produces a single `.pte` containing the audio encoder, text decoder, @@ -167,8 +187,34 @@ EXECUTORCH_BUILD_KERNELS_TORCHAO=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 ./install_ex #### MLX export examples MLX backend uses the MLX delegate for Apple Silicon GPU acceleration. +NVFP4 quantizes weights using NVIDIA's FP4 data type. -Offline: +Offline (NVFP4): + +```bash +python export_voxtral_rt.py \ + --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \ + --backend mlx \ + --output-dir ./voxtral_rt_exports \ + --qlinear-encoder nvfp4 \ + --qlinear nvfp4 \ + --qembedding nvfp4 +``` + +Streaming (NVFP4): + +```bash +python export_voxtral_rt.py \ + --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \ + --backend mlx \ + --streaming \ + --output-dir ./voxtral_rt_exports \ + --qlinear-encoder nvfp4 \ + --qlinear nvfp4 \ + --qembedding nvfp4 +``` + +Offline (int4 linear + int8 embedding): ```bash python export_voxtral_rt.py \ @@ -178,11 +224,10 @@ python export_voxtral_rt.py \ --qlinear-encoder 4w \ --qlinear 4w \ --qembedding 8w \ - --qembedding-group-size 128 \ - --export-preprocessor + --qembedding-group-size 128 ``` -Streaming: +Streaming (int4 linear + int8 embedding): ```bash python export_voxtral_rt.py \ @@ -193,13 +238,9 @@ python export_voxtral_rt.py \ --qlinear-encoder 4w \ --qlinear 4w \ --qembedding 8w \ - --qembedding-group-size 128 \ - --export-preprocessor + --qembedding-group-size 128 ``` -`--export-preprocessor` bundles the mel preprocessor into the output directory -using the MLX partitioner, so no separate preprocessor export step is needed. - ### Options | Flag | Default | Description | @@ -210,15 +251,13 @@ using the MLX partitioner, so no separate preprocessor export step is needed. | `--output-dir` | `./voxtral_rt_exports` | Output directory | | `--max-seq-len` | `4096` | KV cache length | | `--delay-tokens` | `6` | Transcription delay in tokens (6 = 480ms) | -| `--qlinear` | (none) | Decoder linear layer quantization (`4w`, `8w`, `8da4w`, `8da8w`, `fpa4w`) | -| `--qlinear-group-size` | `32` | Group size for decoder linear quantization | -| `--qlinear-packing-format` | (none) | Packing format for decoder 4w quantization (`tile_packed_to_4d` for CUDA) | -| `--qlinear-encoder` | (none) | Encoder linear layer quantization (`4w`, `8w`, `8da4w`, `8da8w`, `fpa4w`) | -| `--qlinear-encoder-group-size` | `32` | Group size for encoder linear quantization | -| `--qlinear-encoder-packing-format` | (none) | Packing format for encoder 4w quantization (`tile_packed_to_4d` for CUDA) | -| `--qembedding` | (none) | Embedding layer quantization (`8w`) | -| `--qembedding-group-size` | `0` | Group size for embedding quantization (0 = per-channel) | -| `--export-preprocessor` | off | Export `preprocessor.pte` alongside the model | + +| `--qlinear` | (none) | Decoder linear layer quantization (`4w`, `8w`, `8da4w`, `8da8w`, `fpa4w`, `nvfp4`) | +| `--qlinear-group-size` | auto | Group size for decoder linear quantization | +| `--qlinear-encoder` | (none) | Encoder linear layer quantization (`4w`, `8w`, `8da4w`, `8da8w`, `fpa4w`, `nvfp4`) | +| `--qlinear-encoder-group-size` | auto | Group size for encoder linear quantization | +| `--qembedding` | (none) | Embedding layer quantization (`4w`, `8w`, `nvfp4`) | +| `--qembedding-group-size` | auto | Group size for embedding quantization | | `--streaming` | off | Export streaming encoder with KV cache | | `--max-enc-len` | `750` | Encoder sliding window size (streaming only) | diff --git a/examples/models/voxtral_realtime/export_voxtral_rt.py b/examples/models/voxtral_realtime/export_voxtral_rt.py index 8255aa9861c..e93d1a6fbfa 100644 --- a/examples/models/voxtral_realtime/export_voxtral_rt.py +++ b/examples/models/voxtral_realtime/export_voxtral_rt.py @@ -377,60 +377,6 @@ def _linear_bias_decomposition(input, weight, bias=None): return out -def export_preprocessor(output_dir, backend="xnnpack", streaming=False): - """Export mel spectrogram preprocessor. - - Uses XNNPACK for all backends except MLX, which uses MLX partitioner. - """ - from executorch.extension.audio.mel_spectrogram import WhisperAudioProcessor - - # Use MLX partitioner for mlx backend, XNNPACK for everything else - pp_backend = "mlx" if backend == "mlx" else "xnnpack" - print(f" Using {pp_backend.upper()} partitioner for preprocessor...") - - model = WhisperAudioProcessor( - feature_size=128, - max_audio_len=300, - streaming=streaming, - ) - - audio_tensor = torch.randn(93680) - shapes_collection = torch.export.ShapesCollection() - max_n_chunks = int(model.max_audio_len * model.n_samples) - shapes_collection[audio_tensor] = {0: Dim.DYNAMIC(max=max_n_chunks)} - - with torch.no_grad(), torch.fx.experimental._config.patch( - backed_size_oblivious=True - ): - ep = export( - model, (audio_tensor,), dynamic_shapes=shapes_collection, strict=True - ) - - if pp_backend == "mlx": - from executorch.backends.mlx.partitioner import MLXPartitioner - - partitioner = [MLXPartitioner()] - else: - from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( - XnnpackPartitioner, - ) - - partitioner = [XnnpackPartitioner()] - - edge = to_edge_transform_and_lower( - ep, - partitioner=partitioner, - compile_config=EdgeCompileConfig(_check_ir_validity=False), - ) - exec_prog = edge.to_executorch() - - pp_path = os.path.join(output_dir, "preprocessor.pte") - with open(pp_path, "wb") as f: - exec_prog.write_to_file(f) - size_mb = os.path.getsize(pp_path) / (1024 * 1024) - print(f" Saved preprocessor to {pp_path} ({size_mb:.1f} MB)") - - def lower_to_executorch(programs, metadata, backend="xnnpack"): """Lower exported programs to ExecuTorch.""" if backend == "xnnpack": @@ -554,7 +500,7 @@ def main(): parser.add_argument( "--qlinear", default=None, - choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"], + choices=["4w", "8w", "8da4w", "8da8w", "fpa4w", "nvfp4"], help="Quantize decoder linear layers.", ) parser.add_argument( @@ -572,7 +518,7 @@ def main(): parser.add_argument( "--qlinear-encoder", default=None, - choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"], + choices=["4w", "8w", "8da4w", "8da8w", "fpa4w", "nvfp4"], help="Quantize encoder linear layers (separate from decoder).", ) parser.add_argument( @@ -590,8 +536,8 @@ def main(): parser.add_argument( "--qembedding", default=None, - choices=["8w"], - help="Quantize embedding layers (8-bit weight-only).", + choices=["4w", "8w", "nvfp4"], + help="Quantize embedding layers.", ) parser.add_argument( "--qembedding-group-size", @@ -616,11 +562,6 @@ def main(): choices=["fp32", "bf16"], help="Model dtype (default: fp32).", ) - parser.add_argument( - "--export-preprocessor", - action="store_true", - help="Also export preprocessor.pte (uses XNNPACK, or MLX for --backend mlx).", - ) args = parser.parse_args() backend_for_export = "cuda" if args.backend == "cuda-windows" else args.backend @@ -691,11 +632,6 @@ def main(): et.write_tensor_data_to_file(args.output_dir) print(f"Saved tensor data to {args.output_dir}/") - # Export preprocessor if requested - if args.export_preprocessor: - print("\nExporting preprocessor...") - export_preprocessor(args.output_dir, args.backend, args.streaming) - print("\nDone!") diff --git a/extension/audio/mel_spectrogram.py b/extension/audio/mel_spectrogram.py index 50b9ded01af..8d44cb00b48 100644 --- a/extension/audio/mel_spectrogram.py +++ b/extension/audio/mel_spectrogram.py @@ -6,11 +6,11 @@ import argparse import logging +import os import torch import torch.nn as nn import torch.nn.functional as F - from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner from executorch.exir import ( EdgeCompileConfig, @@ -188,9 +188,9 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor: return log_spec.unsqueeze(0) -def export_processor(model=None, output_file="whisper_preprocess.pte"): - if model is None: - model = WhisperAudioProcessor() +def _export_processor_model( + model, output_file="whisper_preprocess.pte", backend="xnnpack" +): audio_tensor = torch.randn(93680) shapes_collection = torch.export.ShapesCollection() @@ -204,10 +204,17 @@ def export_processor(model=None, output_file="whisper_preprocess.pte"): ) logging.debug(ep) + if backend == "mlx": + from executorch.backends.mlx.partitioner import MLXPartitioner + + partitioner = [MLXPartitioner()] + else: + partitioner = [XnnpackPartitioner()] + # to edge edge: EdgeProgramManager = to_edge_transform_and_lower( ep, - partitioner=[XnnpackPartitioner()], + partitioner=partitioner, compile_config=EdgeCompileConfig( _check_ir_validity=False, ), @@ -216,12 +223,28 @@ def export_processor(model=None, output_file="whisper_preprocess.pte"): # to executorch exec_prog = edge.to_executorch() + os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True) with open(output_file, "wb") as file: exec_prog.write_to_file(file) logging.debug("Done") +def export_processor( + output_file="whisper_preprocess.pte", backend="xnnpack", **model_kwargs +): + """Export a WhisperAudioProcessor to a .pte file. + + Args: + output_file: Output .pte file path. + backend: Backend for partitioning ("xnnpack" or "mlx"). + **model_kwargs: Passed to WhisperAudioProcessor constructor + (e.g. feature_size, max_audio_len, stack_output, streaming). + """ + model = WhisperAudioProcessor(**model_kwargs) + _export_processor_model(model, output_file, backend) + + def main(): parser = argparse.ArgumentParser( description="Export WhisperAudioProcessor to ExecuTorch" @@ -276,9 +299,19 @@ def main(): help="Streaming mode: skip 30-second chunk padding, produce mel frames proportional to input length. For use with real-time audio input.", ) + parser.add_argument( + "--backend", + type=str, + default="xnnpack", + choices=["xnnpack", "mlx"], + help="Backend for partitioning (default: xnnpack)", + ) + args = parser.parse_args() - model = WhisperAudioProcessor( + export_processor( + output_file=args.output_file, + backend=args.backend, feature_size=args.feature_size, sampling_rate=args.sampling_rate, hop_length=args.hop_length, @@ -289,8 +322,6 @@ def main(): streaming=args.streaming, ) - export_processor(model, args.output_file) - if __name__ == "__main__": main() diff --git a/extension/llm/export/quantize.py b/extension/llm/export/quantize.py index b372bbb9db8..fb2678ff60f 100644 --- a/extension/llm/export/quantize.py +++ b/extension/llm/export/quantize.py @@ -4,8 +4,8 @@ torch.export(). This is the source-transform counterpart to quantizer_lib.py (which handles PT2E graph-mode quantization). -Supported linear configs: "4w", "8w", "8da4w", "8da8w", "fpa4w" (Metal). -Supported embedding configs: "4w", "8w". +Supported linear configs: "4w", "8w", "8da4w", "8da8w", "fpa4w", "nvfp4". +Supported embedding configs: "4w", "8w", "nvfp4". Usage: from executorch.extension.llm.export.quantize import quantize_model_ @@ -18,14 +18,168 @@ from executorch.exir._warnings import experimental +def _make_granularity(group_size: int): + """Create PerAxis(0) or PerGroup granularity.""" + from torchao.quantization.granularity import PerAxis, PerGroup + + return PerAxis(0) if group_size == 0 else PerGroup(group_size) + + +def _make_linear_config(config_name: str, group_size: int, packing_format=None): + """Build a TorchAO config for linear layer quantization.""" + from torchao.quantization.quant_api import ( + Int4WeightOnlyConfig, + Int8DynamicActivationIntxWeightConfig, + IntxWeightOnlyConfig, + ) + + granularity = _make_granularity(group_size) + + if config_name == "nvfp4": + from executorch.extension.llm.export.nvfp4 import ExportableNVFP4Config + + assert group_size == 16, "NVFP4 requires group_size=16" + return ExportableNVFP4Config(use_per_tensor_scale=False) + elif config_name == "4w": + if packing_format: + return Int4WeightOnlyConfig( + group_size=group_size, + int4_packing_format=packing_format, + int4_choose_qparams_algorithm="hqq", + ) + return IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=granularity, + intx_choose_qparams_algorithm="hqq_scale_only", + ) + elif config_name == "8w": + return IntxWeightOnlyConfig( + weight_dtype=torch.int8, + granularity=granularity, + intx_choose_qparams_algorithm="hqq_scale_only", + ) + elif config_name == "8da4w": + return Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int4, + weight_granularity=granularity, + intx_choose_qparams_algorithm="hqq_scale_only", + ) + elif config_name == "8da8w": + return Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int8, + weight_granularity=granularity, + intx_choose_qparams_algorithm="hqq_scale_only", + ) + else: + raise ValueError(f"Unsupported qlinear_config: {config_name}") + + +def _make_embedding_config(config_name: str, group_size: int): + """Build a TorchAO config for embedding layer quantization.""" + from torchao.quantization.quant_api import IntxWeightOnlyConfig + + if group_size != 0: + assert group_size % 2 == 0, "Embedding group size must be a multiple of 2." + + granularity = _make_granularity(group_size) + + if config_name == "nvfp4": + from executorch.extension.llm.export.nvfp4 import ExportableNVFP4Config + + return ExportableNVFP4Config(use_per_tensor_scale=False) + elif config_name == "4w": + return IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=granularity, + intx_choose_qparams_algorithm="hqq_scale_only", + ) + elif config_name == "8w": + return IntxWeightOnlyConfig( + weight_dtype=torch.int8, + granularity=granularity, + intx_choose_qparams_algorithm="hqq_scale_only", + ) + else: + raise ValueError(f"Unsupported qembedding_config: {config_name}") + + +def _check_shape_compatible(m, fqn, config_name, group_size, skip_incompatible_shapes): + """Check shape compatibility. Returns True if compatible, False if skipped. + + Raises RuntimeError if incompatible and skip_incompatible_shapes is False. + """ + shape = m.weight.shape + if config_name == "nvfp4": + compatible = shape[-2] % group_size == 0 and shape[-1] % group_size == 0 + elif group_size != 0: + compatible = shape[-1] % group_size == 0 + else: + compatible = True + + if compatible: + return True + if not skip_incompatible_shapes: + raise RuntimeError( + f"Layer {fqn} has weight shape {shape} " + f"incompatible with {config_name} (group_size={group_size}). " + f"Use skip_incompatible_shapes=True to skip instead of failing." + ) + print( + f" Skipping {fqn}: weight shape {shape} " + f"incompatible with {config_name} (group_size={group_size})" + ) + return False + + +def _make_linear_filter( + config_name: str, group_size: int, skip_incompatible_shapes: bool = False +): + """Create a filter_fn for linear layers, skipping incompatible shapes.""" + + def linear_filter(m, fqn): + if not isinstance(m, torch.nn.Linear): + return False + return _check_shape_compatible( + m, fqn, config_name, group_size, skip_incompatible_shapes + ) + + return linear_filter + + +def _make_embedding_filter( + config_name: str, group_size: int, skip_incompatible_shapes: bool = False +): + """Create a filter_fn for embedding layers, skipping incompatible shapes.""" + + def embedding_filter(m, fqn): + if not isinstance(m, torch.nn.Embedding): + return False + return _check_shape_compatible( + m, fqn, config_name, group_size, skip_incompatible_shapes + ) + + return embedding_filter + + +def _default_group_size(config_name: Optional[str]) -> int: + """Return the default group size for a quantization config.""" + if config_name == "nvfp4": + return 16 + if config_name in ("8w", "8da8w"): + return 0 + return 32 + + @experimental("quantize_model_ is experimental and may change without notice.") -def quantize_model_( # noqa: C901 +def quantize_model_( module: torch.nn.Module, qlinear_config: Optional[str] = None, - qlinear_group_size: int = 32, + qlinear_group_size: Optional[int] = None, qlinear_packing_format: Optional[str] = None, qembedding_config: Optional[str] = None, - qembedding_group_size: int = 0, + qembedding_group_size: Optional[int] = None, + tie_word_embeddings: bool = False, + skip_incompatible_shapes: bool = False, ) -> None: """Quantize linear and embedding layers in a module in-place. @@ -36,20 +190,32 @@ def quantize_model_( # noqa: C901 Args: module: The PyTorch module to quantize. qlinear_config: Quantization config for linear layers - ("4w", "8w", "8da4w", "8da8w", "fpa4w"). - qlinear_group_size: Group size for linear quantization (default: 32). + ("4w", "8w", "8da4w", "8da8w", "fpa4w", "nvfp4"). + qlinear_group_size: Group size for linear quantization. + Defaults to 16 for nvfp4, 32 for 4w/8da4w, 0 (per-axis) for 8w/8da8w. qlinear_packing_format: Packing format for linear layers (e.g., "tile_packed_to_4d"). - qembedding_config: Quantization config for embedding layers ("4w", "8w"). - qembedding_group_size: Group size for embedding quantization - (default: 0 = per-axis). + qembedding_config: Quantization config for embedding layers + ("4w", "8w", "nvfp4"). + qembedding_group_size: Group size for embedding quantization. + Defaults to 16 for nvfp4, 32 for 4w, 0 (per-axis) for 8w. + tie_word_embeddings: If True and both linear and embedding use the + same quantization, re-tie lm_head.weight to embed_tokens.weight + after quantization. + skip_incompatible_shapes: If True, silently skip layers with + incompatible weight shapes. If False (default), raise RuntimeError. """ if not qlinear_config and not qembedding_config: return + if qlinear_group_size is None: + qlinear_group_size = _default_group_size(qlinear_config) + if qembedding_group_size is None: + qembedding_group_size = _default_group_size(qembedding_config) + from torchao.quantization.quant_api import quantize_ - # Metal (MPS) quantization uses different API + # Metal (MPS) quantization uses a separate API if qlinear_config == "fpa4w": import torchao.experimental.ops.mps # noqa: F401 from torchao.experimental.quant_api import UIntxWeightOnlyConfig @@ -59,111 +225,59 @@ def quantize_model_( # noqa: C901 bitwidth=4, uintx_choose_qparams_algorithm="hqq", ) - - def linear_filter(m, fqn): - if isinstance(m, torch.nn.Linear): - if m.weight.shape[1] % qlinear_group_size != 0: - raise ValueError( - f"Metal int4 quantization requires weight dimension (K) " - f"to be multiple of group_size. Layer {fqn} has weight " - f"shape {m.weight.shape} (K={m.weight.shape[1]}, " - f"group_size={qlinear_group_size})" - ) - return True - return False - print( f" Applying {qlinear_config} linear quantization " f"(group_size={qlinear_group_size})..." ) - quantize_(module, config, filter_fn=linear_filter) + quantize_( + module, + config, + filter_fn=_make_linear_filter( + "fpa4w", qlinear_group_size, skip_incompatible_shapes + ), + ) return - from torchao.quantization.granularity import PerAxis, PerGroup - from torchao.quantization.quant_api import ( - Int4WeightOnlyConfig, - Int8DynamicActivationIntxWeightConfig, - IntxWeightOnlyConfig, - ) - # Quantize embedding layers first if qembedding_config: - if qembedding_group_size == 0: - embedding_granularity = PerAxis(0) - else: - assert ( - qembedding_group_size % 2 == 0 - ), "Embedding group size must be a multiple of 2." - embedding_granularity = PerGroup(qembedding_group_size) - - embedding_config = IntxWeightOnlyConfig( - weight_dtype=torch.int4 if qembedding_config == "4w" else torch.int8, - granularity=embedding_granularity, - ) - + config = _make_embedding_config(qembedding_config, qembedding_group_size) print( f" Applying {qembedding_config} embedding quantization " f"(group_size={qembedding_group_size})..." ) quantize_( module, - embedding_config, - lambda m, fqn: isinstance(m, torch.nn.Embedding), + config, + filter_fn=_make_embedding_filter( + qembedding_config, qembedding_group_size, skip_incompatible_shapes + ), ) # Quantize linear layers if qlinear_config: - if qlinear_group_size == 0: - granularity = PerAxis(0) - else: - granularity = PerGroup(qlinear_group_size) - - if qlinear_config == "4w": - if qlinear_packing_format: - config = Int4WeightOnlyConfig( - group_size=qlinear_group_size, - int4_packing_format=qlinear_packing_format, - int4_choose_qparams_algorithm="hqq", - ) - else: - config = IntxWeightOnlyConfig( - weight_dtype=torch.int4, - granularity=granularity, - ) - elif qlinear_config == "8w": - config = IntxWeightOnlyConfig( - weight_dtype=torch.int8, - granularity=granularity, - ) - elif qlinear_config == "8da4w": - config = Int8DynamicActivationIntxWeightConfig( - weight_dtype=torch.int4, - weight_granularity=granularity, - intx_choose_qparams_algorithm="hqq_scale_only", - ) - elif qlinear_config == "8da8w": - config = Int8DynamicActivationIntxWeightConfig( - weight_dtype=torch.int8, - weight_granularity=PerAxis(0), - ) - else: - raise ValueError(f"Unsupported qlinear_config: {qlinear_config}") - - def linear_filter(m, fqn): - if isinstance(m, torch.nn.Linear): - if qlinear_group_size == 0: - return True - if m.weight.shape[1] % qlinear_group_size != 0: - print( - f" Skipping {fqn}: weight shape {m.weight.shape} " - f"incompatible with group_size={qlinear_group_size}" - ) - return False - return True - return False - + config = _make_linear_config( + qlinear_config, qlinear_group_size, qlinear_packing_format + ) print( f" Applying {qlinear_config} linear quantization " f"(group_size={qlinear_group_size}, packing={qlinear_packing_format})..." ) - quantize_(module, config, filter_fn=linear_filter) + quantize_( + module, + config, + filter_fn=_make_linear_filter( + qlinear_config, qlinear_group_size, skip_incompatible_shapes + ), + ) + + # Re-tie word embeddings after quantization if both use the same config + if ( + tie_word_embeddings + and qlinear_config == qembedding_config + and hasattr(module, "lm_head") + and hasattr(module, "model") + ): + embed = getattr(module.model, "embed_tokens", None) + if embed is not None: + module.lm_head.weight = embed.weight + print(" Re-tied lm_head weights to embedding (tie_word_embeddings=True)") From b61faab12dd3d6a0aceee6281c628c5c57708196 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Thu, 5 Mar 2026 10:24:39 -0800 Subject: [PATCH 19/34] up --- extension/llm/export/config/llm_config.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index 4da706bb889..2027d76c461 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -599,11 +599,8 @@ class BackendConfig: torchao: TorchAOKernelsConfig = field(default_factory=TorchAOKernelsConfig) tosa: TosaConfig = field(default_factory=TosaConfig) ethosu: EthosUConfig = field(default_factory=EthosUConfig) -<<<<<<< HEAD vgf: VgfConfig = field(default_factory=VgfConfig) -======= mlx: MLXConfig = field(default_factory=MLXConfig) ->>>>>>> d44535c8a5 (up) ################################################################################ From eb22885868e1c2c6d2ab556a441c471ce866f696 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Thu, 5 Mar 2026 12:20:10 -0800 Subject: [PATCH 20/34] up --- examples/models/voxtral_realtime/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/models/voxtral_realtime/model.py b/examples/models/voxtral_realtime/model.py index 9640e09983d..8ae17fc7f82 100644 --- a/examples/models/voxtral_realtime/model.py +++ b/examples/models/voxtral_realtime/model.py @@ -694,7 +694,7 @@ def update( v_val = v_val.transpose(1, 2) return self.ring_cache.update(input_pos, k_val, v_val) - def create_causal_mask(self, start_pos: int, seq_len: int) -> torch.Tensor: + def create_causal_mask(self, start_pos, seq_len, bool_mask=False) -> torch.Tensor: return self.ring_cache.create_sliding_window_mask(start_pos, seq_len) From b8f0fa6f14e2721b4f8d526f4d9a4e61a185ceb7 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Thu, 5 Mar 2026 13:01:16 -0800 Subject: [PATCH 21/34] up --- examples/models/voxtral_realtime/export_voxtral_rt.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/models/voxtral_realtime/export_voxtral_rt.py b/examples/models/voxtral_realtime/export_voxtral_rt.py index e93d1a6fbfa..df906d114a2 100644 --- a/examples/models/voxtral_realtime/export_voxtral_rt.py +++ b/examples/models/voxtral_realtime/export_voxtral_rt.py @@ -379,6 +379,8 @@ def _linear_bias_decomposition(input, weight, bias=None): def lower_to_executorch(programs, metadata, backend="xnnpack"): """Lower exported programs to ExecuTorch.""" + transform_passes = None + if backend == "xnnpack": from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( XnnpackDynamicallyQuantizedPartitioner, @@ -434,15 +436,18 @@ def lower_to_executorch(programs, metadata, backend="xnnpack"): partitioner[key] = [CudaPartitioner(compile_specs)] elif backend == "mlx": from executorch.backends.mlx.partitioner import MLXPartitioner + from executorch.backends.mlx.passes import get_default_passes print("\nLowering to ExecuTorch with MLX...") partitioner = {key: [MLXPartitioner()] for key in programs} + transform_passes = get_default_passes() else: print("\nLowering to ExecuTorch (portable)...") partitioner = [] et_prog = to_edge_transform_and_lower( programs, + transform_passes=transform_passes, partitioner=partitioner, compile_config=EdgeCompileConfig( _check_ir_validity=False, From 1ed89ee1f761b147aa29794c49eed1f533771624 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 2 Mar 2026 17:46:59 -0800 Subject: [PATCH 22/34] up --- .github/workflows/mlx.yml | 99 ++ .gitignore | 2 + .gitmodules | 4 + CMakeLists.txt | 15 + CMakePresets.json | 85 +- backends/mlx/CMakeLists.txt | 330 ++++ backends/mlx/README.md | 499 ++++++ backends/mlx/__init__.py | 17 + backends/mlx/_logging.py | 40 + backends/mlx/builder/__init__.py | 16 + backends/mlx/builder/op_helpers.py | 275 ++++ backends/mlx/builder/op_registry.py | 151 ++ backends/mlx/builder/pattern_matcher.py | 64 + backends/mlx/builder/program_builder.py | 1018 ++++++++++++ backends/mlx/builder/slot_manager.py | 187 +++ backends/mlx/custom_ops.py | 15 + backends/mlx/ops.py | 294 ++++ backends/mlx/partitioner.py | 298 ++++ backends/mlx/passes.py | 20 + backends/mlx/patches/mlx_json.patch | 29 + backends/mlx/pattern_utils.py | 360 +++++ backends/mlx/patterns.py | 14 + backends/mlx/preprocess.py | 168 ++ backends/mlx/pte_inspector.py | 897 ++++++++++ backends/mlx/runtime/MLXBackend.cpp | 419 +++++ backends/mlx/runtime/MLXExecutor.h | 878 ++++++++++ backends/mlx/runtime/MLXInterpreter.h | 169 ++ backends/mlx/serialization/MLXLoader.cpp.tmpl | 324 ++++ backends/mlx/serialization/MLXLoader.h.tmpl | 343 ++++ backends/mlx/serialization/README.md | 130 ++ backends/mlx/serialization/__init__.py | 32 + backends/mlx/serialization/generate.py | 1437 +++++++++++++++++ .../mlx/serialization/mlx_graph_serialize.py | 416 +++++ backends/mlx/serialization/schema.fbs | 192 +++ backends/mlx/test/CMakeLists.txt | 51 + backends/mlx/test/README.md | 164 ++ backends/mlx/test/__init__.py | 5 + backends/mlx/test/op_test_runner.cpp | 395 +++++ backends/mlx/test/run_all_tests.py | 496 ++++++ backends/mlx/test/strict_compile_test.cpp | 45 + backends/mlx/test/test_ops.py | 176 ++ backends/mlx/test/test_partitioner.py | 45 + backends/mlx/test/test_passes.py | 6 + backends/mlx/test/test_pattern_utils.py | 592 +++++++ backends/mlx/test/test_utils.py | 1122 +++++++++++++ backends/mlx/test/tester.py | 78 + backends/mlx/third-party/mlx | 1 + backends/test/suite/flow.py | 14 + backends/test/suite/flows/mlx.py | 14 + exir/_serialize/_program.py | 67 + setup.py | 33 + tools/cmake/Utils.cmake | 33 + tools/cmake/executorch-config.cmake | 45 + tools/cmake/preset/default.cmake | 1 + tools/cmake/preset/pybind.cmake | 18 + 55 files changed, 12637 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/mlx.yml create mode 100644 backends/mlx/CMakeLists.txt create mode 100644 backends/mlx/README.md create mode 100644 backends/mlx/__init__.py create mode 100644 backends/mlx/_logging.py create mode 100644 backends/mlx/builder/__init__.py create mode 100644 backends/mlx/builder/op_helpers.py create mode 100644 backends/mlx/builder/op_registry.py create mode 100644 backends/mlx/builder/pattern_matcher.py create mode 100644 backends/mlx/builder/program_builder.py create mode 100644 backends/mlx/builder/slot_manager.py create mode 100644 backends/mlx/custom_ops.py create mode 100644 backends/mlx/ops.py create mode 100644 backends/mlx/partitioner.py create mode 100644 backends/mlx/passes.py create mode 100644 backends/mlx/patches/mlx_json.patch create mode 100644 backends/mlx/pattern_utils.py create mode 100644 backends/mlx/patterns.py create mode 100644 backends/mlx/preprocess.py create mode 100644 backends/mlx/pte_inspector.py create mode 100644 backends/mlx/runtime/MLXBackend.cpp create mode 100644 backends/mlx/runtime/MLXExecutor.h create mode 100644 backends/mlx/runtime/MLXInterpreter.h create mode 100644 backends/mlx/serialization/MLXLoader.cpp.tmpl create mode 100644 backends/mlx/serialization/MLXLoader.h.tmpl create mode 100644 backends/mlx/serialization/README.md create mode 100644 backends/mlx/serialization/__init__.py create mode 100755 backends/mlx/serialization/generate.py create mode 100644 backends/mlx/serialization/mlx_graph_serialize.py create mode 100644 backends/mlx/serialization/schema.fbs create mode 100644 backends/mlx/test/CMakeLists.txt create mode 100644 backends/mlx/test/README.md create mode 100644 backends/mlx/test/__init__.py create mode 100644 backends/mlx/test/op_test_runner.cpp create mode 100644 backends/mlx/test/run_all_tests.py create mode 100644 backends/mlx/test/strict_compile_test.cpp create mode 100644 backends/mlx/test/test_ops.py create mode 100644 backends/mlx/test/test_partitioner.py create mode 100644 backends/mlx/test/test_passes.py create mode 100644 backends/mlx/test/test_pattern_utils.py create mode 100644 backends/mlx/test/test_utils.py create mode 100644 backends/mlx/test/tester.py create mode 160000 backends/mlx/third-party/mlx create mode 100644 backends/test/suite/flows/mlx.py diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml new file mode 100644 index 00000000000..2e8ca7aa3b7 --- /dev/null +++ b/.github/workflows/mlx.yml @@ -0,0 +1,99 @@ +name: MLX + +on: + push: + branches: + - main + - release/* + pull_request: + paths: + - .github/workflows/mlx.yml + - backends/mlx/** + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +jobs: + test-mlx: + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + with: + job-name: test-mlx + runner: macos-14-xlarge + python-version: "3.12" + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 90 + script: | + set -eux + + echo "::group::Install ExecuTorch and configure build" + ${CONDA_RUN} python install_executorch.py > /dev/null + # The sanitizers fail on github VM runner, but pass on real device + # TODO: figure out why + ${CONDA_RUN} cmake --preset mlx-release -DEXECUTORCH_BUILD_TESTS=ON -DEXECUTORCH_MLX_ENABLE_SANITIZERS=OFF + echo "::endgroup::" + + ${CONDA_RUN} pip list + + echo "::group::Build test runners" + ${CONDA_RUN} cmake --build cmake-out --target op_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 )) + echo "::endgroup::" + + echo "::group::Run op unit tests" + ${CONDA_RUN} python -m executorch.backends.mlx.test.run_all_tests -j4 --max-tasks-per-worker 10 --clean-after + echo "::endgroup::" + + echo "::group::Run Python unit tests" + ${CONDA_RUN} python -m pytest \ + backends/mlx/test/test_passes.py \ + backends/mlx/test/test_pattern_utils.py \ + backends/mlx/test/test_partitioner.py \ + -v + echo "::endgroup::" + + backend-tester: + strategy: + fail-fast: false + matrix: + suite: [models, operators] + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + with: + job-name: test-mlx-backend-${{ matrix.suite }} + runner: macos-14-xlarge + python-version: "3.12" + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 120 + script: | + set -eux + + echo "::group::Install ExecuTorch" + ${CONDA_RUN} python install_executorch.py > /dev/null + echo "::endgroup::" + + ${CONDA_RUN} pip list + + echo "::group::Run backend test suite (${{ matrix.suite }})" + ${CONDA_RUN} pytest -c /dev/null backends/test/suite/${{ matrix.suite }}/ -m flow_mlx -n auto 2>&1 | tee pytest_output.txt || true + echo "::endgroup::" + + # Parse pytest summary and check failure threshold + if grep -E "^=+ .* =+$" pytest_output.txt | tail -1 | grep -q "failed"; then + FAILED=$(grep -E "^=+ .* =+$" pytest_output.txt | tail -1 | grep -oE "[0-9]+ failed" | grep -oE "[0-9]+") + else + FAILED=0 + fi + + if [ "${{ matrix.suite }}" = "operators" ]; then + MAX_FAILURES=0 + else + MAX_FAILURES=3 + fi + + echo "Failed tests: $FAILED (max allowed: $MAX_FAILURES)" + if [ "$FAILED" -gt "$MAX_FAILURES" ]; then + echo "::error::Too many test failures: $FAILED > $MAX_FAILURES" + exit 1 + fi diff --git a/.gitignore b/.gitignore index 4ddbb7c49ad..3453b7e9676 100644 --- a/.gitignore +++ b/.gitignore @@ -74,5 +74,7 @@ xcuserdata/ *.dll *.pyd + # Agents .claude/*.local.* +extension/pybindings/mlx.metallib diff --git a/.gitmodules b/.gitmodules index 1f202d4fdec..917e755da27 100644 --- a/.gitmodules +++ b/.gitmodules @@ -67,3 +67,7 @@ [submodule "third-party/json"] path = third-party/json url = https://github.com/nlohmann/json.git +[submodule "backends/mlx/third-party/mlx"] + path = backends/mlx/third-party/mlx + url = https://github.com/ml-explore/mlx.git + shallow = true diff --git a/CMakeLists.txt b/CMakeLists.txt index 5a7595d99f4..6126b421919 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -682,6 +682,11 @@ if(EXECUTORCH_BUILD_MPS) list(APPEND _executorch_backends mpsdelegate) endif() +if(EXECUTORCH_BUILD_MLX) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/mlx) + list(APPEND _executorch_backends mlxdelegate) +endif() + if(EXECUTORCH_BUILD_NEURON) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/mediatek) list(APPEND _executorch_backends neuron_backend) @@ -979,6 +984,10 @@ if(EXECUTORCH_BUILD_PYBIND) list(APPEND _dep_libs mpsdelegate) endif() + if(EXECUTORCH_BUILD_MLX) + list(APPEND _dep_libs mlxdelegate) + endif() + if(EXECUTORCH_BUILD_OPENVINO) list(APPEND _dep_libs openvino_backend) endif() @@ -1079,6 +1088,12 @@ if(EXECUTORCH_BUILD_PYBIND) install(TARGETS data_loader LIBRARY DESTINATION executorch/extension/pybindings ) + + # Copy MLX metallib next to _portable_lib.so for editable installs. MLX uses + # dladdr() to find the directory containing the library with MLX code, then + # looks for mlx.metallib in that directory. When MLX is statically linked into + # _portable_lib.so, we need the metallib colocated with it. + executorch_target_copy_mlx_metallib(portable_lib) endif() if(EXECUTORCH_BUILD_WASM) diff --git a/CMakePresets.json b/CMakePresets.json index ffcab2b3d6b..f12fabbc232 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -110,7 +110,7 @@ "inherits": ["common"], "cacheVariables": { "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/pybind.cmake", - "CMAKE_OSX_DEPLOYMENT_TARGET": "12.0" + "CMAKE_OSX_DEPLOYMENT_TARGET": "14.0" }, "condition": { "type": "inList", @@ -310,6 +310,43 @@ "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/arm_ethosu_linux.cmake", "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/examples/arm/ethos-u-setup/aarch64-linux-musl-toolchain.cmake" } + }, + { + "name": "mlx", + "displayName": "Build MLX delegate", + "inherits": ["common"], + "cacheVariables": { + "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/mlx.cmake", + "EXECUTORCH_ENABLE_LOGGING": "ON", + "CMAKE_OSX_DEPLOYMENT_TARGET": "14.0" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Darwin" + } + }, + { + "name": "mlx-release", + "displayName": "MLX delegate release build", + "inherits": ["mlx"], + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_INSTALL_PREFIX": "${sourceDir}/cmake-out", + "ET_MLX_ENABLE_OP_LOGGING": "OFF", + "ET_MIN_LOG_LEVEL": "Error" + } + }, + { + "name": "mlx-debug", + "displayName": "MLX delegate debug build with op logging", + "inherits": ["mlx"], + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_INSTALL_PREFIX": "${sourceDir}/cmake-out", + "ET_MLX_ENABLE_OP_LOGGING": "ON", + "ET_MIN_LOG_LEVEL": "Debug" + } } ], "buildPresets": [ @@ -387,6 +424,24 @@ "install" ], "jobs": 0 + }, + { + "name": "mlx-release-install", + "displayName": "Build and install MLX delegate release artifacts", + "configurePreset": "mlx-release", + "targets": [ + "install" + ], + "jobs": 0 + }, + { + "name": "mlx-debug-install", + "displayName": "Build and install MLX delegate debug artifacts", + "configurePreset": "mlx-debug", + "targets": [ + "install" + ], + "jobs": 0 } ], "workflowPresets": [ @@ -501,6 +556,34 @@ "name": "llm-metal-stats-install" } ] + }, + { + "name": "mlx-release", + "displayName": "Configure, build and install ExecuTorch MLX delegate", + "steps": [ + { + "type": "configure", + "name": "mlx-release" + }, + { + "type": "build", + "name": "mlx-release-install" + } + ] + }, + { + "name": "mlx-debug", + "displayName": "Configure, build and install ExecuTorch MLX delegate with op logging (Debug)", + "steps": [ + { + "type": "configure", + "name": "mlx-debug" + }, + { + "type": "build", + "name": "mlx-debug-install" + } + ] } ] } diff --git a/backends/mlx/CMakeLists.txt b/backends/mlx/CMakeLists.txt new file mode 100644 index 00000000000..00e7c497b1c --- /dev/null +++ b/backends/mlx/CMakeLists.txt @@ -0,0 +1,330 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +cmake_minimum_required(VERSION 3.19) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +if(NOT CMAKE_CXX_STANDARD) + set(CMAKE_CXX_STANDARD 17) +endif() + +# Source root directory for executorch. +if(NOT EXECUTORCH_ROOT) + set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) +endif() + +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) + +set(_common_compile_options -Wall -Werror -Wno-deprecated-declarations) + +# Sanitizer flags (asan + ubsan) for security hardening — CI only. Enable via: +# cmake --preset mlx-release -DEXECUTORCH_MLX_ENABLE_SANITIZERS=ON +option(EXECUTORCH_MLX_ENABLE_SANITIZERS + "Enable ASan + UBSan for MLX delegate and tests" OFF +) +if(EXECUTORCH_MLX_ENABLE_SANITIZERS) + list(APPEND _common_compile_options -fsanitize=address,undefined + -fno-omit-frame-pointer + ) + set(_mlx_sanitizer_link_options -fsanitize=address,undefined) +endif() + +# ----------------------------------------------------------------------------- +# Code generation from schema.fbs +# ----------------------------------------------------------------------------- +# +# The generate.py script generates all code from schema.fbs: Python: +# mlx_graph_schema.py, _generated_serializers.py, _generated/ C++: MLXLoader.h, +# MLXLoader.cpp, schema_generated.h +# +# We run generate.py at build time so these files don't need to be checked in. +# ----------------------------------------------------------------------------- + +set(_mlx_generate_script + "${CMAKE_CURRENT_SOURCE_DIR}/serialization/generate.py" +) +set(_mlx_schema_fbs "${CMAKE_CURRENT_SOURCE_DIR}/serialization/schema.fbs") + +# Generated C++ files that we need for compilation +set(_mlx_generated_cpp_files + "${CMAKE_CURRENT_SOURCE_DIR}/runtime/schema_generated.h" + "${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.h" + "${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp" +) + +# Generated Python files (tracked for dependency purposes) +set(_mlx_generated_python_files + "${CMAKE_CURRENT_SOURCE_DIR}/serialization/mlx_graph_schema.py" + "${CMAKE_CURRENT_SOURCE_DIR}/serialization/_generated_serializers.py" +) + +# Run generate.py to create all generated files from schema.fbs Find Python - +# prefer Python3_EXECUTABLE if set (from FindPython3), otherwise use +# PYTHON_EXECUTABLE +if(Python3_EXECUTABLE) + set(_python_executable ${Python3_EXECUTABLE}) +elseif(PYTHON_EXECUTABLE) + set(_python_executable ${PYTHON_EXECUTABLE}) +else() + find_package( + Python3 + COMPONENTS Interpreter + REQUIRED + ) + set(_python_executable ${Python3_EXECUTABLE}) +endif() + +add_custom_command( + OUTPUT ${_mlx_generated_cpp_files} ${_mlx_generated_python_files} + COMMAND ${_python_executable} ${_mlx_generate_script} --flatc + $ + WORKING_DIRECTORY ${EXECUTORCH_ROOT} + DEPENDS ${_mlx_schema_fbs} ${_mlx_generate_script} flatc + COMMENT "Generating MLX delegate code from schema.fbs" + VERBATIM +) + +# Custom target to trigger generation +add_custom_target( + mlx_generate_code DEPENDS ${_mlx_generated_cpp_files} + ${_mlx_generated_python_files} +) + +# Interface library for schema includes +add_library(mlx_schema INTERFACE) +add_dependencies(mlx_schema mlx_generate_code) +target_include_directories( + mlx_schema + INTERFACE + $ + $ +) + +# ----------------------------------------------------------------------------- +# MLX dependency (from submodule) +# ----------------------------------------------------------------------------- + +set(MLX_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third-party/mlx) + +# Check that submodule is initialized +if(NOT EXISTS "${MLX_SOURCE_DIR}/CMakeLists.txt") + message( + FATAL_ERROR "MLX submodule not initialized.\n" + "Run: git submodule update --init backends/mlx/third-party/mlx" + ) +endif() + +# Validate deployment target - MLX requires macOS 14.0+ / iOS 17.0+ +# +# The macOS preset uses ios.toolchain.cmake (with PLATFORM=MAC_ARM64), so +# DEPLOYMENT_TARGET is set for both macOS and iOS builds. We check PLATFORM to +# distinguish them rather than relying on which variable is set. +set(_mlx_deployment_target_ok TRUE) +if(PLATFORM AND PLATFORM MATCHES "^MAC") + # macOS build via ios.toolchain.cmake (e.g., MAC_ARM64, MAC_UNIVERSAL) + if(DEPLOYMENT_TARGET) + set(_mlx_dt_value ${DEPLOYMENT_TARGET}) + elseif(CMAKE_OSX_DEPLOYMENT_TARGET) + set(_mlx_dt_value ${CMAKE_OSX_DEPLOYMENT_TARGET}) + endif() + if(_mlx_dt_value AND _mlx_dt_value VERSION_LESS "14.0") + set(_mlx_deployment_target_ok FALSE) + set(_mlx_deployment_target_value ${_mlx_dt_value}) + set(_mlx_deployment_target_min "14.0") + endif() +elseif(DEPLOYMENT_TARGET) + # iOS/tvOS/watchOS/visionOS builds via ios.toolchain.cmake + if(DEPLOYMENT_TARGET VERSION_LESS "17.0") + set(_mlx_deployment_target_ok FALSE) + set(_mlx_deployment_target_value ${DEPLOYMENT_TARGET}) + set(_mlx_deployment_target_min "17.0") + endif() +elseif(CMAKE_OSX_DEPLOYMENT_TARGET) + # Plain macOS build (no ios.toolchain.cmake) + if(CMAKE_OSX_DEPLOYMENT_TARGET VERSION_LESS "14.0") + set(_mlx_deployment_target_ok FALSE) + set(_mlx_deployment_target_value ${CMAKE_OSX_DEPLOYMENT_TARGET}) + set(_mlx_deployment_target_min "14.0") + endif() +endif() + +if(NOT _mlx_deployment_target_ok) + message( + FATAL_ERROR + "MLX requires deployment target >= ${_mlx_deployment_target_min}, got ${_mlx_deployment_target_value}.\n" + "Either increase the deployment target or disable MLX with -DEXECUTORCH_BUILD_MLX=OFF" + ) +endif() + +# MLX build options - we only need the C++ library with Metal +set(MLX_BUILD_PYTHON_BINDINGS + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_TESTS + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_EXAMPLES + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_BENCHMARKS + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_PYTHON_STUBS + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_CUDA + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_CPU + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_METAL + ON + CACHE BOOL "" FORCE +) +set(MLX_BUILD_SHARED_LIBS + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_GGUF + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_SAFETENSORS + OFF + CACHE BOOL "" FORCE +) +set(MLX_METAL_JIT + ON + CACHE BOOL "Use JIT compiled Metal kernels" +) + +# Auto-apply patches to MLX submodule. Each patch is applied idempotently: `git +# apply --check` tests whether the patch is still applicable (i.e. not yet +# applied), and only then applies it. +set(_mlx_patches "${CMAKE_CURRENT_SOURCE_DIR}/patches/mlx_json.patch") +foreach(_patch IN LISTS _mlx_patches) + if(EXISTS "${_patch}" AND EXISTS "${MLX_SOURCE_DIR}") + get_filename_component(_patch_name "${_patch}" NAME) + execute_process( + COMMAND git apply --check "${_patch}" + WORKING_DIRECTORY ${MLX_SOURCE_DIR} + RESULT_VARIABLE _patch_check + OUTPUT_QUIET ERROR_QUIET + ) + if(_patch_check EQUAL 0) + execute_process( + COMMAND git apply "${_patch}" WORKING_DIRECTORY ${MLX_SOURCE_DIR} + ) + message(STATUS "Applied ${_patch_name} to MLX submodule") + else() + message(STATUS "${_patch_name} already applied or not applicable") + endif() + endif() +endforeach() + +# Add MLX subdirectory +message(STATUS "Adding MLX from submodule: ${MLX_SOURCE_DIR}") +add_subdirectory(${MLX_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/mlx) + +# ----------------------------------------------------------------------------- +# MLX Backend library +# ----------------------------------------------------------------------------- + +# Op logging option (for debugging) - OFF by default for performance +option(ET_MLX_ENABLE_OP_LOGGING "Enable per-op logging in MLX delegate" OFF) + +set(_mlx_backend__srcs ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp +) + +add_library(mlxdelegate ${_mlx_backend__srcs}) + +# Ensure schema is generated before compiling +add_dependencies(mlxdelegate mlx_schema) + +# Add logging flag if enabled +if(ET_MLX_ENABLE_OP_LOGGING) + target_compile_definitions(mlxdelegate PRIVATE ET_MLX_ENABLE_OP_LOGGING=1) + message(STATUS "MLX delegate op logging ENABLED") +endif() + +target_include_directories( + mlxdelegate PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/runtime +) + +# Link against MLX and executorch mlx is only available at BUILD_INTERFACE - +# consumers must link to mlx separately +target_link_libraries( + mlxdelegate PRIVATE mlx_schema executorch_core $ +) + +executorch_target_link_options_shared_lib(mlxdelegate) +target_compile_options(mlxdelegate PRIVATE ${_common_compile_options}) +if(EXECUTORCH_MLX_ENABLE_SANITIZERS) + target_link_options(mlxdelegate PRIVATE ${_mlx_sanitizer_link_options}) +endif() + +install( + TARGETS mlxdelegate mlx_schema + EXPORT ExecuTorchTargets + DESTINATION ${CMAKE_INSTALL_LIBDIR} +) + +# Install mlx library for downstream consumers +install(TARGETS mlx DESTINATION ${CMAKE_INSTALL_LIBDIR}) + +# Install mlx headers for downstream consumers that may need mlx types +install( + DIRECTORY ${MLX_SOURCE_DIR}/mlx/ + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/mlx + FILES_MATCHING + PATTERN "*.h" +) + +# Install mlx.metallib (Metal GPU kernels) for runtime execution +# +# MLX searches for metallib in this order (see mlx/backend/metal/device.cpp): 1. +# {binary_dir}/mlx.metallib - colocated with the .so/.dylib 2. +# {binary_dir}/Resources/mlx/ - Resources subdirectory 3. SwiftPM bundle - +# not applicable for us 4. {binary_dir}/Resources/default/ - Resources +# subdirectory 5. METAL_PATH (compile-time) - hardcoded build path (won't +# exist) +# +# where {binary_dir} is determined at runtime via dladdr() on the library +# containing MLX code. When MLX is statically linked into _portable_lib.so, this +# is the directory containing _portable_lib.so. +# +# For the installed library, we put metallib in lib/ alongside libmlx.a +install( + FILES ${CMAKE_CURRENT_BINARY_DIR}/mlx/mlx/backend/metal/kernels/mlx.metallib + DESTINATION ${CMAKE_INSTALL_LIBDIR} +) + +# Cache the metallib path for pybindings to copy it next to _portable_lib.so +# This enables editable installs to work correctly +set(MLX_METALLIB_PATH + "${CMAKE_CURRENT_BINARY_DIR}/mlx/mlx/backend/metal/kernels/mlx.metallib" + CACHE INTERNAL "Path to mlx.metallib for pybindings" +) + +# ----------------------------------------------------------------------------- +# Tests (off by default; CI passes -DEXECUTORCH_BUILD_TESTS=ON) +# ----------------------------------------------------------------------------- + +if(EXECUTORCH_BUILD_TESTS) + add_subdirectory(test) +endif() diff --git a/backends/mlx/README.md b/backends/mlx/README.md new file mode 100644 index 00000000000..ebab893385a --- /dev/null +++ b/backends/mlx/README.md @@ -0,0 +1,499 @@ +# MLX Delegate for ExecuTorch + +The MLX delegate compiles PyTorch models to run on Apple Silicon GPUs via the +[MLX](https://github.com/ml-explore/mlx) framework. It consists of: + +- A **Python compilation pipeline** that converts ExportedPrograms (Edge IR) into + a custom FlatBuffer bytecode format. +- A **C++ runtime** that loads the bytecode and executes it using MLX GPU + primitives. + +> **Adding a new op?** Jump to [How to Add a New Op](#how-to-add-a-new-op). + +## Getting Started + +The MLX delegate requires **Apple Silicon** (M1 or later) and the **Metal +compiler**, which ships with Xcode (not the standalone Command Line Tools). + +**Check if Metal is available:** + +```bash +xcrun -sdk macosx --find metal +``` + +If this prints a path (e.g. `/Applications/Xcode.app/.../metal`), you're set. +If it errors, you either need to install Xcode from the +[App Store](https://apps.apple.com/us/app/xcode/id497799835) or +, or — if Xcode is already installed but the +command line developer directory points at Command Line Tools — switch it: + +```bash +sudo xcode-select -s /Applications/Xcode.app/Contents/Developer +``` + +### Python (pybindings) + +The simplest way to get started is to install ExecuTorch with Python bindings. +From the repo root: + +```bash +python install_executorch.py +``` + +This builds and installs the `executorch` pip package with pybindings. On Apple +Silicon, when the Metal compiler is available, the MLX backend is automatically +included. You can then export models in Python using the MLX partitioner and run +them via the ExecuTorch Python API. + +### C++ (CMake preset) + +To build the C++ runtime with the MLX delegate, use the `mlx-release` CMake +workflow preset from the repo root: + +```bash +cmake --workflow --preset mlx-release +``` + +This configures and builds a Release build of the ExecuTorch runtime with the +MLX delegate and installs artifacts into `cmake-out/`. The preset enables the +MLX delegate along with commonly needed extensions (module, data loader, flat +tensor, LLM runner, etc.). + +Downstream C++ apps can then `find_package(executorch)` and link against +`mlxdelegate` and `mlx`. See +[`examples/models/llama/CMakeLists.txt`](../../examples/models/llama/CMakeLists.txt) +for a working example. + +There is also an `mlx-debug` preset that enables debug symbols and compiles in +per-op logging support, which is useful during development: + +```bash +cmake --workflow --preset mlx-debug +``` + +The debug build compiles in the logging code, but to actually see per-op output +you must also set the environment variable when running the binary: + +```bash +ET_MLX_ENABLE_OP_LOGGING=1 ./cmake-out/my_app +``` + +### Debugging + +Set `ET_MLX_DEBUG=1` during AOT (export/compilation) to see detailed debug +logging from the partitioner and preprocessor — including ops-to-not-decompose +lists, graph dumps, per-node support decisions, and serialization details: + +```bash +ET_MLX_DEBUG=1 python -m executorch.backends.mlx.examples.llm.export_llm_hf ... +``` + +--- + +## Directory Layout + +``` +backends/mlx/ +├── serialization/ # Schema + code generation +│ ├── schema.fbs # ← Source of truth (FlatBuffer schema) +│ ├── generate.py # Code generator (schema.fbs → everything else) +│ ├── mlx_graph_schema.py # [GENERATED] Python dataclasses for IR nodes +│ ├── mlx_graph_serialize.py # Serialization to FlatBuffer binary +│ ├── _generated_serializers.py # [GENERATED] Per-op FlatBuffer builders +│ └── _generated/ # [GENERATED] FlatBuffer Python bindings (flatc) +├── runtime/ # C++ runtime (loaded at inference time) +│ ├── MLXBackend.cpp # BackendInterface (init / execute / destroy) +│ ├── MLXLoader.h/.cpp # [GENERATED] FlatBuffer → C++ structs +│ ├── MLXExecutor.h # ExecutionState, constant loading, helpers +│ ├── MLXInterpreter.h # Op dispatch loop + per-op exec_* functions +│ └── schema_generated.h # [GENERATED] FlatBuffer C++ bindings (flatc) +├── llm/ # LLM infrastructure (KV cache, attention, etc.) +│ ├── cache.py # KV cache implementations (ET + HF static cache) +│ ├── et_attention.py # ExecuTorch custom SDPA attention +│ ├── hf_attention.py # HuggingFace custom SDPA attention +│ ├── quantization.py # TorchAO quantization helpers +│ └── source_transformation.py # Source transforms for MLX export +├── _generated_inspector.py # [GENERATED] Inspector utilities for .pte debugging +├── _logging.py # Debug logging utilities (ET_MLX_DEBUG) +├── builder/ # Core build infrastructure +│ ├── op_registry.py # REGISTRY (op handler registration) +│ ├── op_helpers.py # Helper utilities for op handlers +│ ├── pattern_matcher.py # Pattern matching for multi-node fusions +│ ├── program_builder.py # MLXProgramBuilder +│ └── slot_manager.py # Tensor/value slot allocation +├── ops.py # Op handlers (ATen target → MLX IR node) +├── patterns.py # Pattern handlers (multi-node fusions) +├── passes.py # Graph passes (RMSNorm fusion, CSE, etc.) +├── pattern_utils.py # Pattern matching utilities for passes +├── partitioner.py # Decides which ops to delegate to MLX +├── preprocess.py # BackendDetails.preprocess() entry point +├── custom_ops.py # Custom torch ops (kv_cache_update, custom_sdpa, rope) +├── pte_inspector.py # .pte file inspection/debugging tool +├── test/ +│ ├── test_ops.py # Op test definitions (models + configs) +│ ├── test_utils.py # OpTestCase base class + helpers +│ ├── op_test_runner.cpp # C++ test runner (loads .pte, runs, compares) +│ └── run_all_tests.py # End-to-end: export → C++ run → compare +└── examples/ + ├── llm/ # LLM export + run via HuggingFace + └── whisper/ # Whisper export + run +``` + +Files marked **[GENERATED]** are NOT CHECKED IN CODE and are produced by running: + +```bash +python backends/mlx/serialization/generate.py +``` + +--- + +## Compilation Pipeline + +The compilation pipeline converts a PyTorch model into a `.pte` file containing +the MLX delegate payload. The high-level flow: + +``` +torch.export() → ExportedProgram (ATen IR) +to_edge_transform_and_lower() → Edge IR + partitioning + lowering +``` + +Within that flow, the MLX-specific steps are: + +1. **Partitioning** (`partitioner.py`) — `MLXPartitioner` walks the Edge IR + graph and tags nodes that MLX can handle. It uses `MLXProgramBuilder` in a + dry-run mode to determine support — so partitioning and compilation use the + exact same logic. Unsupported ops fall back to ExecuTorch's portable + runtime. + +2. **Preprocessing** (`preprocess.py`) — For each partitioned subgraph, + `MLXBackend.preprocess()` is called. It builds an `MLXGraph` via + `MLXProgramBuilder`, serializes it to FlatBuffer, and returns a + `PreprocessResult` with the binary payload and constant data. + +3. **Op handling** (`ops.py`, `patterns.py`) — During the build, + `MLXProgramBuilder` walks the FX graph node-by-node and dispatches to + registered handlers. Single-op handlers live in `ops.py`; multi-node fused + patterns (e.g., quantized linear, SDPA, KV cache update) live in + `patterns.py`. + +4. **Serialization** (`serialization/`) — The `MLXGraph` dataclass tree is + serialized to a FlatBuffer binary. See [Serialization](#serialization) below. + +The complete preprocessing flow: + +``` +ExportedProgram (subgraph) + → MLXProgramBuilder.build() # walks FX graph, calls op handlers + → MLXGraph # Python IR (dataclasses from mlx_graph_schema.py) + → MLXGraphSerializer.serialize() # FlatBuffer binary + → PreprocessResult # returned to ExecuTorch +``` + +--- + +## How to Add a New Op + +This section walks through adding a new op end-to-end, using **`aten.linear`** +as an example. + +### Step 1: Add the Node to `schema.fbs` + +Add a new table in the "Op nodes" section and add it to the `OpNode` union: + +```fbs +table LinearNode { + x: Tid (required); + weight: Tid (required); + out: Tid (required); + bias: Tid; // optional +} +``` + +Then add `LinearNode` to the `union OpNode { ... }` list. + +### Step 2: Run the Code Generator + +```bash +python backends/mlx/serialization/generate.py +``` + +This regenerates: + +- `mlx_graph_schema.py` — adds `LinearNode` Python dataclass +- `_generated_serializers.py` — adds `_build_LinearNode` serializer +- `runtime/MLXLoader.h` — adds `LinearNode` C++ struct, `OpCode::LINEAR`, loader +- `runtime/MLXLoader.cpp` — adds FlatBuffer → `LinearNode` deserialization +- `runtime/schema_generated.h` — FlatBuffer C++ bindings + +### Step 3: Add the Python Op Handler (`ops.py`) + +Register a handler that converts the ATen op to your new node. Make sure to +import `LinearNode` from `mlx_graph_schema`: + +```python +from executorch.backends.mlx.serialization.mlx_graph_schema import LinearNode + +@REGISTRY.register(target=[torch.ops.aten.linear.default]) +def _linear_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 3, "aten.linear") + require_kwargs(P.kwargs(n), set(), "aten.linear") + x, w = args[0], args[1] + b = args[2] if len(args) > 2 else None + out = P.make_or_get_slot(n) + P.emit( + LinearNode( + x=P.slot_to_tid(x), + weight=P.slot_to_tid(w), + out=P.slot_to_tid(out), + bias=P.slot_to_tid(b) if b else None, + ) + ) + return out +``` + +Key APIs: +- **`P.args(n)`** — resolves FX node args to `Slot` objects (tensor/value references) +- **`P.make_or_get_slot(n)`** — allocates the output tensor slot +- **`P.slot_to_tid(slot)`** — converts a `Slot` to a `Tid` for the IR node +- **`P.emit(node)`** — appends the instruction to the graph + +### Step 4: Add the C++ Op Handler (`MLXInterpreter.h`) + +Add an `exec_*` function in the `ops` namespace: + +```cpp +inline void exec_linear(const LinearNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& X = st.const_tensor_ref(n.x); + auto W = transpose(st.const_tensor_ref(n.weight), {1, 0}, s); + array Y = n.bias + ? addmm(st.const_tensor_ref(*n.bias), X, W, 1.0f, 1.0f, s) + : matmul(X, W, s); + st.set_tensor(n.out, std::move(Y)); +} +``` + +Then add the dispatch case in `Interpreter::execute_instruction()`: + +```cpp +case OpCode::LINEAR: + ops::exec_linear(std::get(instr.node), st, s); + break; +``` + +### Step 5: Write a Test (`test/test_ops.py`) + +Each test follows a standard pattern: + +1. **Define a `nn.Module`** that uses the op. +2. **Define an `OpTestCase` subclass** that specifies test configurations. +3. **Decorate with `@register_test`** to register it with the test runner. + +```python +class LinearModel(nn.Module): + def __init__(self, in_features=64, out_features=128, bias=True): + super().__init__() + self.linear = nn.Linear(in_features, out_features, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + +@register_test +class LinearTest(OpTestCase): + name = "linear" + rtol = 1e-4 + atol = 1e-4 + + def __init__(self, in_features=64, out_features=128, bias=True): + self.in_features = in_features + self.out_features = out_features + self.bias = bias + + @classmethod + def get_test_configs(cls): + return [cls(), cls(bias=False)] + + def create_model(self): + return LinearModel(self.in_features, self.out_features, bias=self.bias) + + def create_inputs(self): + return (torch.randn(2, 16, self.in_features),) +``` + +### Step 6: Run Tests + +Tests are end-to-end: export `.pte` → run via C++ `op_test_runner` → compare +outputs against PyTorch reference. Since adding a new op always involves C++ +changes, use `--rebuild` to recompile the runtime: + +```bash +python -m executorch.backends.mlx.test.run_all_tests --rebuild linear +``` + +Run all tests in parallel: + +```bash +python -m executorch.backends.mlx.test.run_all_tests --rebuild -j4 --clean-after +``` + +Other useful flags: + +| Flag | Purpose | +|---|---| +| `--rebuild` | Rebuild the C++ `op_test_runner` before running | +| `-j N` / `--parallel N` | Run N tests in parallel | +| `--clean-after` | Remove generated test artifacts after running | +| `--list` | List all available test names and exit | +| `-v` / `--verbose` | Verbose output | + +Test artifacts are saved to `test/op_tests//` (`.pte`, input/output +`.bin` files). See [`test/README.md`](test/README.md) for full details on test +architecture, prerequisites, and the `OpTestCase` API. + +### Checklist + +- [ ] Add `*Node` table to `schema.fbs` + add to `OpNode` union +- [ ] Run `python backends/mlx/serialization/generate.py` +- [ ] Add `@REGISTRY.register` handler in `ops.py` (and import the new node class) +- [ ] Add `exec_*` function in `runtime/MLXInterpreter.h` +- [ ] Add `case OpCode::*` in `Interpreter::execute_instruction()` +- [ ] Add test model + `OpTestCase` in `test/test_ops.py` +- [ ] Run `python -m executorch.backends.mlx.test.run_all_tests --rebuild ` + +--- + +## Serialization + +### Overview + +The serialization system converts a Python `MLXGraph` dataclass tree into a +FlatBuffer binary that the C++ runtime can load. The source of truth is +**`schema.fbs`** — a single FlatBuffer schema file from which all code on both +sides is generated. + +### Schema (`schema.fbs`) + +The schema defines: + +| Concept | FlatBuffer type | Purpose | +|---|---|---| +| **`Tid`** | struct | Tensor slot index (indexes into the runtime tensor array) | +| **`Vid`** | struct | Value slot index (for scalar `int32`/`float`/`bool` values) | +| **`IntOrVid`** | table | A field that is either a literal `int64` or a runtime `Vid` reference (for dynamic shapes) | +| **`FloatOrVid`** | table | Same idea for floats | +| **`TidOrVid`** | table | Either a tensor or a scalar value | +| **Op node tables** | table | One per op (e.g. `AddNode`, `SiluNode`, `ReshapeNode`). Each declares its inputs/outputs as `Tid`/`Vid` references and any scalar parameters. | +| **`OpNode`** | union | Union of all op node tables | +| **`Instruction`** | table | Wraps an `OpNode` union | +| **`MLXGraph`** | table (root) | The complete program: slot counts, instruction list, I/O maps, named slots, tensor metadata | + +Key design points: + +- **No embedded weights.** Constants are stored in ExecuTorch's `named_data_map` + and loaded by name at runtime. This enables zero-copy on unified memory. +- **Tensor IDs (`Tid`) are globally ordered:** Constants → Inputs → Outputs → + Mutable Buffers → Temps. The runtime uses this ordering for O(1) type lookup. +- **Dynamic shapes** are supported via `IntOrVid` — a shape dimension can be + either a literal integer or a reference to a runtime value produced by + `sym_size` / `item()` ops. + +### Code Generation (`generate.py`) + +`generate.py` parses `schema.fbs` and generates **all** boilerplate on both the +Python and C++ sides: + +| Generated file | What it contains | +|---|---| +| `mlx_graph_schema.py` | Python `@dataclass` for every op node, `Tid`, `Vid`, `IntOrVid`, etc. | +| `_generated_serializers.py` | `GeneratedOpBuilders` mixin class with `_build_*Node` methods for every op | +| `_generated_inspector.py` | Inspector utilities for debugging `.pte` files | +| `runtime/MLXLoader.h` | C++ structs for every op node, `OpCode` enum, `NodeVariant`, `Instruction`, `MLXProgram` | +| `runtime/MLXLoader.cpp` | `load_instruction()` and `load_program()` — FlatBuffer → C++ struct conversion | +| `runtime/schema_generated.h` | Standard FlatBuffer C++ bindings (via `flatc`) | +| `_generated/` directory | Standard FlatBuffer Python bindings (via `flatc`) | + +Running the generator: + +```bash +python backends/mlx/serialization/generate.py +``` + +Use `--skip-flatc` if you only changed op node definitions (not core types) and +want to skip the `flatc` invocation. + +### Serialization Format + +The binary payload embedded in the `.pte` file has this layout: + +``` +[Header: 24 bytes] + 4 bytes padding (zeros) + 4 bytes magic ("MLX0") + 8 bytes data_segment_offset (uint64 LE) + 8 bytes data_segment_size (uint64 LE) +[FlatBuffer payload] +[Padding to 16-byte alignment] +[Data segment (currently unused — constants go via named_data_map)] +``` + +The `MLXGraphSerializer` class (in `mlx_graph_serialize.py`) drives +serialization. It inherits `GeneratedOpBuilders` for the per-op builders and +adds the root-table construction, I/O maps, tensor metadata, and header. + +--- + +## Runtime + +### Initialization (`init`) + +When ExecuTorch loads a `.pte` with an MLX delegate blob, `MLXBackend::init()` +is called: + +1. **Parse FlatBuffer** — `loader::load_program()` deserializes the binary into + an `MLXProgram` struct (C++ mirrors of the schema). +2. **Load constants** — Iterates `named_slots`, calls + `named_data_map->get_data(name)` for each constant tensor, wraps the buffer + as an `mlx::core::array` (zero-copy when possible on unified memory). +3. **Initialize mutable buffers** — Creates zero-filled MLX arrays for + persistent state (e.g., KV cache). These live across `execute()` calls. +4. **Bind execution state** — `ExecutionState::bind()` pre-computes tensor ID + ranges for O(1) routing. + +### Execution (`execute`) + +Each `execute()` call: + +1. **Reset** per-execution state (inputs/outputs/temps cleared; mutable buffers + and constants are retained). +2. **Bind inputs** — Walk `input_map`, convert each ExecuTorch tensor to an + `mlx::core::array` (zero-copy pointer wrap). +3. **Run instructions** — `Interpreter::run()` dispatches each `Instruction` + through a `switch` on `OpCode`, calling the corresponding `exec_*` function. +4. **Evaluate** — Call `mlx::core::eval()` on output tensors to trigger + lazy GPU computation. +5. **Copy outputs** — Convert MLX arrays back to ExecuTorch tensors via + `memcpy`. + +### Tensor ID Layout + +Tensor slot IDs are assigned in a fixed order during compilation: + +``` + ┌──────────┬──────────┬──────────┬────────────────┬──────────┐ + │ Constants│ Inputs │ Outputs │ Mutable Buffers│ Temps │ + │ 0..C-1 │ C..I-1 │ I..O-1 │ O..M-1 │ M..T-1 │ + └──────────┴──────────┴──────────┴────────────────┴──────────┘ +``` + +The runtime stores constants and mutable buffers in separate containers +(`ConstantData`, `MutableBufferData`). Inputs, outputs, and temps share a flat +`vector>` in `ExecutionState`. + +### Key Runtime Files + +| File | Role | +|---|---| +| `MLXBackend.cpp` | `init()` / `execute()` / `destroy()` — the ExecuTorch `BackendInterface` | +| `MLXLoader.h/.cpp` | [GENERATED] Deserializes FlatBuffer into `MLXProgram` (C++ structs) | +| `MLXExecutor.h` | `ExecutionState`, `ConstantData`, `MutableBufferData`, constant loading, dtype conversion | +| `MLXInterpreter.h` | The op dispatch switch + all `exec_*` implementations | diff --git a/backends/mlx/__init__.py b/backends/mlx/__init__.py new file mode 100644 index 00000000000..48f4c28f5ca --- /dev/null +++ b/backends/mlx/__init__.py @@ -0,0 +1,17 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""MLX backend for ExecuTorch - executes models on Apple Silicon using MLX.""" + +# Import custom_ops module to register custom ATen ops (rope, etc.) +from executorch.backends.mlx import custom_ops as _custom_ops # noqa: F401 +from executorch.backends.mlx.partitioner import MLXPartitioner + +from executorch.backends.mlx.preprocess import MLXBackend + +__all__ = ["MLXBackend", "MLXPartitioner"] diff --git a/backends/mlx/_logging.py b/backends/mlx/_logging.py new file mode 100644 index 00000000000..eff472550f9 --- /dev/null +++ b/backends/mlx/_logging.py @@ -0,0 +1,40 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +Centralized logging for the MLX backend. + +Usage: + from executorch.backends.mlx._logging import logger + + logger.info("Always visible (e.g., unsupported ops summary)") + logger.debug("Only visible when ET_MLX_DEBUG=1") + logger.warning("Always visible") + +The logger is set to INFO by default, so logger.info() always prints. +Set ET_MLX_DEBUG=1 to lower the threshold to DEBUG for verbose output +(graph dumps, per-node traces, ops_to_not_decompose lists, etc.). +""" + +import logging +import os + +_MLX_DEBUG = os.environ.get("ET_MLX_DEBUG", "") not in ("", "0") + +logger = logging.getLogger("executorch.backends.mlx") +logger.setLevel(logging.DEBUG if _MLX_DEBUG else logging.INFO) +logger.propagate = False + +if not logger.handlers: + _handler = logging.StreamHandler() + _handler.setFormatter( + logging.Formatter( + "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" + ) + ) + logger.addHandler(_handler) diff --git a/backends/mlx/builder/__init__.py b/backends/mlx/builder/__init__.py new file mode 100644 index 00000000000..ce793ed9a15 --- /dev/null +++ b/backends/mlx/builder/__init__.py @@ -0,0 +1,16 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +# Trigger op/pattern handler registration. +# ops.py and patterns.py use @REGISTRY.register() decorators at import time. +# This must happen after REGISTRY is defined (in op_registry.py). +from executorch.backends.mlx import ops, patterns # noqa: F401 +from executorch.backends.mlx.builder.op_registry import REGISTRY # noqa: F401 +from executorch.backends.mlx.builder.program_builder import ( # noqa: F401 + MLXProgramBuilder, +) diff --git a/backends/mlx/builder/op_helpers.py b/backends/mlx/builder/op_helpers.py new file mode 100644 index 00000000000..5e082cdf386 --- /dev/null +++ b/backends/mlx/builder/op_helpers.py @@ -0,0 +1,275 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +from __future__ import annotations + +from typing import Dict, Optional, Tuple, TYPE_CHECKING, Union + +import torch +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.exir.scalar_type import ScalarType +from torch.fx.node import Node + +if TYPE_CHECKING: + from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder + + +def get_aten_target(target): + """ + Unwrap EdgeOpOverload to get the underlying ATen op. + + In Edge IR, ops are wrapped in EdgeOpOverload. This extracts the + underlying ATen op for consistent comparison. + """ + if hasattr(target, "_op") and "EdgeOpOverload" in type(target).__name__: + return target._op + return target + + +# Mapping from _copy variants to their non-copy equivalents. +# Edge IR uses _copy variants for certain ops, but for pattern matching +# we want to compare against the semantic operation. +_COPY_TO_NON_COPY = { + torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor, + torch.ops.aten.transpose_copy.int: torch.ops.aten.transpose.int, + torch.ops.aten.view_copy.default: torch.ops.aten.view.default, + torch.ops.aten.permute_copy.default: torch.ops.aten.permute.default, + torch.ops.aten.unsqueeze_copy.default: torch.ops.aten.unsqueeze.default, + torch.ops.aten.squeeze_copy.dim: torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze_copy.dims: torch.ops.aten.squeeze.dims, + torch.ops.aten.squeeze_copy.default: torch.ops.aten.squeeze.default, + torch.ops.aten.expand_copy.default: torch.ops.aten.expand.default, + torch.ops.aten.alias_copy.default: torch.ops.aten.alias.default, +} + + +def get_aten_target_normalized(target): + """ + Get ATen target, mapping _copy variants to their non-copy equivalents. + + Use this for pattern matching where Edge IR uses _copy variants but + we want to match the semantic operation. + + E.g., aten.transpose_copy.int -> aten.transpose.int + """ + target = get_aten_target(target) + return _COPY_TO_NON_COPY.get(target, target) + + +def emit_stop_position( + P: "MLXProgramBuilder", + start: "Union[int, Slot]", + length_tensor: "Slot", + length_dim: int, + length_meta: "Optional[torch.Tensor]" = None, +) -> "Union[int, Slot]": + """ + Emit nodes to compute stop = start + length for slice operations. + + May emit SymSizeNode and/or AddIntNode depending on whether + start and length are static or dynamic. + + Args: + P: The program builder + start: Start position (int or Slot) + length_tensor: The tensor slot whose dimension gives the length + length_dim: Which dimension of length_tensor contains the length + length_meta: Optional tensor metadata for static length extraction + + Returns: + stop position as int (if fully static) or Slot (if any dynamic) + """ + from executorch.backends.mlx.serialization.mlx_graph_schema import ( + AddIntNode, + IntOrVid, + SymSizeNode, + ) + + # Check if seq_len is symbolic (dynamic) + seq_len_is_symbolic = False + seq_len_concrete = None + + if length_meta is not None: + seq_len_dim = length_meta.shape[length_dim] + if hasattr(seq_len_dim, "node"): + seq_len_is_symbolic = True + else: + seq_len_concrete = int(seq_len_dim) + + if seq_len_is_symbolic or length_meta is None: + # Dynamic seq_len: emit SymSizeNode to get length at runtime + _, seq_len_slot = P.slot_manager.make_tmp_value_slot() + P.emit( + SymSizeNode( + a=P.slot_to_tid(length_tensor), + dim=length_dim, + out=P.slot_to_vid(seq_len_slot), + ) + ) + _, stop_slot = P.slot_manager.make_tmp_value_slot() + if isinstance(start, Slot): + start_iov = P.to_int_or_vid(start) + else: + start_iov = IntOrVid.from_literal(int(start)) + P.emit( + AddIntNode( + a=start_iov, + b=IntOrVid.from_vid(P.slot_to_vid(seq_len_slot)), + out=P.slot_to_vid(stop_slot), + ) + ) + return stop_slot + else: + # Static seq_len + if isinstance(start, Slot): + # Dynamic start + static length + _, stop_slot = P.slot_manager.make_tmp_value_slot() + P.emit( + AddIntNode( + a=P.to_int_or_vid(start), + b=IntOrVid.from_literal(seq_len_concrete), + out=P.slot_to_vid(stop_slot), + ) + ) + return stop_slot + else: + # Both static - just return the sum + return start + seq_len_concrete + + +def emit_lifted_constant(P: "MLXProgramBuilder", value, dtype: torch.dtype) -> Slot: + """Lift a scalar to a 0-D tensor. + + Concrete scalars (int/float/bool) become deduplicated constants. + Dynamic values (SymInt Slots) emit a FullNode at runtime. + """ + + if isinstance(value, (int, float, bool)): + return P.make_or_get_constant( + f"_scalar_{value}", torch.tensor(value, dtype=dtype) # 0-D + ) + + from executorch.backends.mlx.serialization.mlx_graph_schema import FullNode + + _, slot = P.make_tmp_slot() + P.emit( + FullNode( + shape=[], + v=P.to_float_or_vid(value), + scalar_type=torch_dtype_to_scalar_type(dtype), + out=P.slot_to_tid(slot), + ) + ) + return slot + + +def to_mlx_qparams( + qdata: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + bits: int, + compute_biases: bool = True, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Convert TorchAO quantization params to MLX format. + + TorchAO uses: s * (q - z), with q signed + MLX uses: S * Q + B, with Q unsigned + + s * (q - z) + = s ((q + offset) - (z + offset)) + = s Q + B, + where Q = q + offset, B = -s * (z + offset) + + Args: + compute_biases: If False, skip bias computation (for scale_only mode). + Returns (Q, None) in this case. This is valid when + zero_point is all zeros, as the C++ runtime will compute + biases = -scales * 2^(bits-1). + """ + assert qdata.dtype == torch.int8 + offset = 2 ** (bits - 1) + Q = qdata.to(torch.int32) + offset + + # Pack data tightly into uint32 + assert 32 % bits == 0 + vals_per_uint32 = 32 // bits + assert qdata.shape[1] % vals_per_uint32 == 0 + + Q = Q.reshape(-1, vals_per_uint32) + shifts = torch.arange(0, 32, bits, dtype=torch.int64) + + # Convert to int64 for shift/packing + Q = Q.to(torch.int64) + Q = (Q << shifts).sum(dim=-1) + Q = Q.to(torch.uint32) + Q = Q.reshape(qdata.shape[0], -1) + + if compute_biases: + B = -scale * (zero_point.to(scale.dtype) + offset) + return Q, B + else: + return Q, None + + +def parse_dequant_node( + node: Node, +) -> Optional[Tuple[Node, Node, Node, int, int, Optional[torch.dtype], int]]: + """Parse a torchao.dequantize_affine node. + + Accepts N-dimensional block_size with a single non-1 element identifying + the quantized dimension and group_size. For example: + - Linear weights (2D): block_size=[1, 32] → quantized_dim=1 + - Conv2d weights (4D): block_size=[1, 32, 1, 1] → quantized_dim=1 + + Returns (qdata, scale, zero_point, group_size, bits, out_dtype, quantized_dim) + or None if unsupported. + """ + qdata, block_size, scale, zero_point, dtype, qmin, qmax = node.args[0:7] + out_dtype = ( + node.args[7] if len(node.args) > 7 else node.kwargs.get("output_dtype", None) + ) + if dtype != torch.int8: + return None + if len(block_size) < 2: + return None + non_one = [(i, d) for i, d in enumerate(block_size) if d != 1] + if len(non_one) != 1: + return None + quantized_dim, group_size = non_one[0] + if group_size not in [32, 64, 128]: + return None + if qmin == -8 and qmax == 7: + bits = 4 + elif qmin == -128 and qmax == 127: + bits = 8 + else: + return None + return qdata, scale, zero_point, group_size, bits, out_dtype, quantized_dim + + +# Mapping from torch dtype to ET ScalarType int value +# See executorch/exir/scalar_type.py for ScalarType enum +_TORCH_DTYPE_TO_SCALAR_TYPE: Dict[torch.dtype, int] = { + torch.float16: ScalarType.HALF, + torch.float32: ScalarType.FLOAT, + torch.bfloat16: ScalarType.BFLOAT16, + torch.int32: ScalarType.INT, + torch.int64: ScalarType.LONG, + torch.uint32: ScalarType.UINT32, + torch.uint8: ScalarType.BYTE, + torch.bool: ScalarType.BOOL, + torch.int8: ScalarType.CHAR, +} + + +def torch_dtype_to_scalar_type(dtype: torch.dtype) -> int: + """Convert torch dtype to ET ScalarType int value.""" + if dtype not in _TORCH_DTYPE_TO_SCALAR_TYPE: + raise ValueError(f"Unsupported dtype: {dtype}") + return int(_TORCH_DTYPE_TO_SCALAR_TYPE[dtype]) diff --git a/backends/mlx/builder/op_registry.py b/backends/mlx/builder/op_registry.py new file mode 100644 index 00000000000..19668ca2c1b --- /dev/null +++ b/backends/mlx/builder/op_registry.py @@ -0,0 +1,151 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +from __future__ import annotations + +from typing import Callable, Dict, List, Optional, Tuple, Type, TYPE_CHECKING, Union + +from executorch.backends.mlx._logging import logger +from torch.fx.node import Node + +if TYPE_CHECKING: + from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder + from executorch.backends.mlx.builder.slot_manager import Slot + from torch.export import ExportedProgram + +# Handler type: takes (builder, node) and returns optional slot(s) +Handler = Callable[ + ["MLXProgramBuilder", Node], Optional[Union["Slot", Tuple["Slot", ...]]] +] + + +class PatternHandler: + def __init__(self, head: Node, body: List[Node]) -> None: + self.head: Node = head + self.body: List[Node] = body + + @classmethod + def deferred_handler(cls, P: MLXProgramBuilder, n: Node) -> None: + pass + + @classmethod + def maybe_create(cls, ep: ExportedProgram, head: Node) -> Optional[PatternHandler]: + raise NotImplementedError + + def __call__(self, P: MLXProgramBuilder, n: Node) -> None: + raise NotImplementedError + + def set_handlers(self, P: MLXProgramBuilder): + if P.node_info[self.head].handler is not None: + raise AssertionError( + f"Head node {self.head.name} already has handler {P.node_info[self.head].handler}, " + f"cannot set pattern {self.__class__.__name__}" + ) + for n in self.body: + if P.node_info[n].handler is not None: + raise AssertionError( + f"Body node {n.name} already has handler {P.node_info[n].handler}, " + f"cannot set pattern {self.__class__.__name__}" + ) + + logger.debug( + f"Pattern {self.__class__.__name__}: " + f"HEAD={self.head.name}, BODY={[n.name for n in self.body]}" + ) + P.node_info[self.head].handler = self + for n in self.body: + P.node_info[n].handler = PatternHandler.deferred_handler + + +class MLXOpRegistry: + """Registry for op handlers and pattern handlers.""" + + def __init__(self): + self._handlers: Dict[Union[str, Callable], Handler] = {} + self._patterns: Dict[str, Type[PatternHandler]] = {} + + def reset(self) -> None: + """Reset the registry to empty state. Useful for testing.""" + self._handlers.clear() + self._patterns.clear() + + def register(self, target: Union[str, Callable, list, tuple]): + """Decorator to register a handler for one or more op targets.""" + + def deco(fn: Handler): + targets = target if isinstance(target, (list, tuple)) else [target] + for t in targets: + if t in self._handlers: + raise ValueError(f"Target {t} already registered") + self._handlers[t] = fn + return fn + + return deco + + def get_handler(self, node: Node) -> Optional[Handler]: + """Get the handler for a node, or None if not registered.""" + t = node.target + if t in self._handlers: + return self._handlers[t] + # Handle EdgeOpOverload by extracting the underlying ATen op + if hasattr(t, "_op") and t._op in self._handlers: + return self._handlers[t._op] + # Check for string-based targets (e.g., higher_order ops) + target_str = str(t) + if target_str in self._handlers: + return self._handlers[target_str] + return None + + def registered_ops(self) -> set: + """Return all registered op targets.""" + return set(self._handlers.keys()) + + def unregister(self, target: Union[str, Callable, list, tuple]) -> None: + """Remove a handler for one or more op targets. + + This is useful for debugging - allows temporarily disabling specific + handlers to test if they are causing issues. + + Args: + target: Single target or list of targets to unregister + """ + targets = target if isinstance(target, (list, tuple)) else [target] + for t in targets: + if t in self._handlers: + del self._handlers[t] + + def register_pattern(self, name: str): + """Decorator to register a pattern handler class.""" + + def deco(cls: Type[PatternHandler]): + if not issubclass(cls, PatternHandler): + raise TypeError( + "register_pattern must decorate a PatternHandler subclass" + ) + if name in self._patterns: + raise ValueError(f"Pattern '{name}' already registered") + self._patterns[name] = cls + return cls + + return deco + + def get_pattern_cls(self, name: str) -> Optional[Type[PatternHandler]]: + """Get a pattern handler class by name.""" + return self._patterns.get(name) + + def get_noop_handler(self) -> Optional[Handler]: + """Get the NOOP handler, if registered.""" + return self._handlers.get("NOOP") + + def patterns(self): + """Return all registered pattern names.""" + return self._patterns.keys() + + +# Global registry +REGISTRY = MLXOpRegistry() diff --git a/backends/mlx/builder/pattern_matcher.py b/backends/mlx/builder/pattern_matcher.py new file mode 100644 index 00000000000..2db422e3f68 --- /dev/null +++ b/backends/mlx/builder/pattern_matcher.py @@ -0,0 +1,64 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +from __future__ import annotations + +from typing import List, TYPE_CHECKING + +from executorch.backends.mlx._logging import logger +from executorch.backends.mlx.builder.op_registry import PatternHandler + +if TYPE_CHECKING: + from executorch.backends.mlx.builder.op_registry import MLXOpRegistry + from torch.export import ExportedProgram + + +class PatternMatcher: + """ + Discovers and applies pattern handlers to an FX graph. + + Pattern handlers match multi-node subgraphs and lower them to optimized + MLX operations. This class orchestrates the pattern discovery process: + + 1. Iterates through all registered pattern types + 2. For each pattern, tries to match it against every node in the graph + 3. When a match is found, assigns handlers to the head and body nodes + + The ordering matters: patterns are matched before dead code elimination + because some pattern body nodes (e.g., update_cache) have no users + since they mutate in-place, but they're not dead. + """ + + def __init__(self, ep: ExportedProgram, registry: "MLXOpRegistry"): + self.ep = ep + self.registry = registry + self._matches: List[PatternHandler] = [] + + def find_patterns(self) -> List[PatternHandler]: + """ + Find all pattern matches in the graph. + + Returns a list of PatternHandler instances, one for each match found. + Patterns are tried in registration order. + """ + self._matches = [] + for name in self.registry.patterns(): + self._find_pattern(name) + return self._matches + + def _find_pattern(self, name: str) -> None: + """Try to match a single pattern type against all nodes.""" + pattern_cls = self.registry.get_pattern_cls(name) + if pattern_cls is None: + return + + for n in self.ep.graph.nodes: + handler = pattern_cls.maybe_create(self.ep, n) + if handler is not None: + logger.debug(f"Pattern {name} matched at node {n.name}") + self._matches.append(handler) diff --git a/backends/mlx/builder/program_builder.py b/backends/mlx/builder/program_builder.py new file mode 100644 index 00000000000..60d5ebbdbfe --- /dev/null +++ b/backends/mlx/builder/program_builder.py @@ -0,0 +1,1018 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +MLX Program Builder - converts an ExportedProgram to an MLXGraph. + +This module is responsible for: +1. Walking the FX graph from an ExportedProgram +2. Converting each node to the corresponding MLX op +3. Managing tensor and value slots +4. Building the final MLXGraph dataclass for serialization + +Op handlers are registered in ops.py. +Pattern handlers are registered in patterns.py. +""" + +from __future__ import annotations + +import traceback +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Union + +import torch + +from executorch.backends.mlx._logging import logger +from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type +from executorch.backends.mlx.builder.op_registry import ( + Handler, + PatternHandler, + REGISTRY, +) +from executorch.backends.mlx.builder.pattern_matcher import PatternMatcher +from executorch.backends.mlx.builder.slot_manager import ( + IdSpace, + IdType, + Slot, + SlotManager, +) +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + FloatOrVid, + IdCopyNode, + Instruction, + InstructionChain, + IntOrVid, + IntOrVidOrTid, + MLXGraph, + NamedSlot, + OpNodeUnion, + ShapeDim, + SlotType, + SlotVariant, + TensorMeta, + Tid, + Vid, +) +from executorch.exir._serialize._named_data_store import NamedDataStore +from torch.export.exported_program import ExportedProgram +from torch.fx.node import Node +from torch.utils import _pytree as pytree + + +def _check_dtype(node: Node) -> Optional[str]: + """ + Check if a node has a supported dtype. + + Args: + node: The FX node to check + + Returns: + None if the node's dtype is supported, otherwise an error message string + """ + fake_val = node.meta.get("val", None) + if fake_val is not None and hasattr(fake_val, "dtype"): + try: + torch_dtype_to_scalar_type(fake_val.dtype) + except ValueError: + return f"has unsupported dtype: {fake_val.dtype}" + return None + + +def _check_input_dtypes(node: Node) -> Optional[str]: + """ + Check if all input tensors to a node have supported dtypes. + + Args: + node: The FX node to check + + Returns: + None if all input dtypes are supported, otherwise an error message string + describing which input (arg position or kwarg name) has an unsupported dtype + """ + # Check positional args + for i, arg in enumerate(node.args): + if isinstance(arg, Node): + dtype_error = _check_dtype(arg) + if dtype_error is not None: + return f"arg[{i}] ({arg.name}) {dtype_error}" + + # Check kwargs + for kwarg_name, kwarg_val in node.kwargs.items(): + if isinstance(kwarg_val, Node): + dtype_error = _check_dtype(kwarg_val) + if dtype_error is not None: + return f"kwarg '{kwarg_name}' ({kwarg_val.name}) {dtype_error}" + + return None + + +@dataclass +class NodeInfo: + handled: bool = False + handler: Optional[Union[Handler, PatternHandler]] = None + supported: bool = False + unsupported_reason: Optional[str] = None + name: Optional[str] = None + remaining_reads: int = 0 + + +class MLXProgramBuilder: + """ + Builds an MLXGraph from an ExportedProgram. + + Args: + ep: The ExportedProgram to build from + """ + + def __init__(self, ep: ExportedProgram, named_data_key_prefix: str = ""): + self.ep: ExportedProgram = ep + self._instrs: List[Instruction] = [] + self.extra_constants: Dict[str, torch.Tensor] = {} + self.slot_manager = SlotManager() + self.node_info: DefaultDict[Node, NodeInfo] = defaultdict(NodeInfo) + self._mlx_graph: Optional[MLXGraph] = None + # Map from SymInt symbol names (e.g., "s77") to the FX Node that produces them. + # This is used to resolve symbolic tensor dimensions to Vid references. + self._symint_to_node: Dict[str, Node] = {} + # Maps for remapping local slot indices to global Tid/Vid indices during build + self._tid_slot_map: List[Tuple[Tid, Slot]] = [] + self._vid_slot_map: List[Tuple[Vid, Slot]] = [] + # Prefix for named_data_store keys and named_slots to avoid collisions + # in multi-method programs where different methods may have lifted tensor + # constants with the same auto-generated name. + self._named_data_key_prefix: str = named_data_key_prefix + # Unprefixed canonical-name → Slot for constants, populated by _build_io_maps(). + # Used by get_named_data_store() to look up tensors without prefix interference. + self._constant_name_to_slot: Dict[str, Slot] = {} + + def _prefix_key(self, name: str) -> str: + """Apply the named-data key prefix for the .pte namespace. + + This is the single point where canonical (unprefixed) names are + transformed into the external keys used in the .pte's ``named_data`` + section and the FlatBuffer ``named_slots`` field. + """ + if self._named_data_key_prefix: + return f"{self._named_data_key_prefix}/{name}" + return name + + def emit(self, op: OpNodeUnion) -> None: + self._instrs.append(Instruction(op=op)) + + def args(self, node: Node) -> Tuple[Any, ...]: + return self.slot_map(node.args) + + def kwargs(self, node: Node) -> Dict[str, Any]: + return self.slot_map(node.kwargs) + + def slot_map(self, tree): + leaves, spec = pytree.tree_flatten(tree) + new_leaves = [] + for a in leaves: + if isinstance(a, Node): + # Use make_or_get_slots which handles both single and multi-output nodes. + # For single-output nodes, returns a 1-tuple; for multi-output, returns n-tuple. + # We unwrap single-element tuples for convenience. + slots = self.make_or_get_slots(a) + if len(slots) == 1: + new_leaves.append(slots[0]) + else: + new_leaves.append(slots) + else: + new_leaves.append(a) + + for a in new_leaves: + if isinstance(a, Slot): + assert self.slot_manager.is_alive( + a + ), f"Slot {a} is not alive; it was either already freed or never created" + + return pytree.tree_unflatten(new_leaves, spec) + + def make_or_get_slots( + self, node: Node, id_space: IdSpace = IdSpace.Temp + ) -> Tuple[Slot, ...]: + """Get or create slots for a multi-output node. Always returns a tuple.""" + return self.slot_manager.make_or_get_slots(node, id_space) + + def make_or_get_slot(self, node: Node, id_space: IdSpace = IdSpace.Temp) -> Slot: + """Get or create a slot for a single-output node. Returns a single Slot.""" + return self.slot_manager.make_or_get_slot(node, id_space) + + def set_slot(self, node: Node, slot: Slot): + self.slot_manager.set_slot(node, slot) + + def make_tmp_slot(self) -> Tuple[str, Slot]: + """Create a temporary tensor slot.""" + return self.slot_manager.make_tmp_slot() + + def make_tmp_value_slot(self) -> Tuple[str, Slot]: + """Create a temporary value (SymInt) slot.""" + return self.slot_manager.make_tmp_value_slot() + + def make_or_get_constant(self, name: str, tensor: torch.Tensor) -> Slot: + """ + Creates an extra constant outside of the ExportedProgram state_dict. + Ops can use this to add constants during build that do not exist in the + ExportedProgram state_dict, e.g., doing naive packing of quantized ops. + """ + assert name not in self.ep.state_dict + assert name not in self.ep.constants + + if name in self.extra_constants: + # During fake tensor tracing, we can't use torch.equal + # Just assume tensors with same name are the same + slot = self.slot_manager.get_slot(name) + assert slot is not None + return slot + + slot = self.slot_manager.make_constant_slot(name) + self.extra_constants[name] = tensor + return slot + + def get_placeholder_target_and_tensor(self, node: Node) -> Tuple[str, torch.Tensor]: + assert node.op == "placeholder" + placeholder_name = node.name + + sig = self.ep.graph_signature + sd = self.ep.state_dict + consts = self.ep.constants + + for ispec in sig.input_specs: + if ispec.arg.name != placeholder_name: + continue + target = ispec.target + if target is None: + continue + if target in sd: + return (target, sd[target]) + if target in consts: + return (target, consts[target]) + + raise KeyError(f"Unable to resolve placeholder {placeholder_name}") + + def slot_to_tid(self, slot: Slot) -> Tid: + """Convert a tensor Slot to a Tid, recording it for later remapping.""" + assert slot.id_type == IdType.Tensor + # Use local slot.idx as placeholder - will be remapped to global idx in build() + tid = Tid(idx=slot.idx) + self._tid_slot_map.append((tid, slot)) + return tid + + def slot_to_vid(self, slot: Slot) -> Vid: + """Convert a value Slot to a Vid, recording it for later remapping.""" + assert slot.id_type != IdType.Tensor + vid = Vid(idx=slot.idx) + self._vid_slot_map.append((vid, slot)) + return vid + + def to_int_or_vid(self, v: Union[int, Slot]) -> IntOrVid: + if isinstance(v, Slot): + return IntOrVid.from_vid(self.slot_to_vid(v)) + return IntOrVid.from_literal(int(v)) + + def to_float_or_vid(self, v: Union[float, int, Slot]) -> FloatOrVid: + if isinstance(v, Slot): + return FloatOrVid.from_vid(self.slot_to_vid(v)) + return FloatOrVid.from_literal(float(v)) + + def to_int_or_vid_or_tid(self, v: Union[int, Slot]) -> IntOrVidOrTid: + if isinstance(v, Slot): + if v.id_type == IdType.Tensor: + return IntOrVidOrTid.from_tid(self.slot_to_tid(v)) + return IntOrVidOrTid.from_vid(self.slot_to_vid(v)) + return IntOrVidOrTid.from_literal(int(v)) + + def _mark_read(self, node: Node): + assert self.node_info[node].handled, f"Node {node} is not handled" + assert ( + self.node_info[node].remaining_reads > 0 + ), f"Reading node {node}, but it has no remaining reads" + self.node_info[node].remaining_reads -= 1 + + if self.node_info[node].remaining_reads == 0: + slot = self.slot_manager.get_slot(node) + if slot is None: + return + if not isinstance(slot, tuple): + slot = (slot,) + for s in slot: + if s.id_space != IdSpace.Temp: + continue + if s.id_type == IdType.Tensor: + self.slot_manager.tid_managers[IdSpace.Temp].return_id(s.idx) + else: + self.slot_manager.vid_managers[IdSpace.Temp].return_id(s.idx) + + def _mark_node_handled(self, node: Node, *, handler: Optional[Handler] = None): + if self.node_info[node].handled: + return + self.node_info[node].handled = True + self.node_info[node].remaining_reads = len(node.users) + self.node_info[node].handler = handler + + if handler == PatternHandler.deferred_handler: + return + + def mark_read(n: Node): + flat_args, spec = pytree.tree_flatten((n.args, n.kwargs)) + seen = set() + for a in flat_args: + if isinstance(a, Node): + if a not in seen: + self._mark_read(a) + seen.add(a) + + if isinstance(handler, PatternHandler): + for n in handler.body: + mark_read(n) + mark_read(node) + + def _mark_node_supported(self, node: Node, *, handler: Optional[Handler] = None): + self.node_info[node].supported = True + self._mark_node_handled(node, handler=handler) + + def _mark_node_unsupported(self, node: Node, reason: str): + self.node_info[node].supported = False + self.node_info[node].unsupported_reason = reason + self._mark_node_handled(node) + + def _is_handled(self, node: Node) -> bool: + return self.node_info[node].handled + + def _mark_supported( + self, nodes: Union[List[Node], Node], *, handler: Optional[Handler] = None + ) -> None: + if isinstance(nodes, Node): + nodes = [nodes] + for node in nodes: + self._mark_node_supported(node, handler=handler) + + def _mark_unsupported(self, nodes: Union[List[Node], Node], reason: str) -> None: + if isinstance(nodes, Node): + nodes = [nodes] + for node in nodes: + self._mark_node_unsupported(node, reason) + + def _make_io_slots(self): # noqa: C901 + from torch.export.graph_signature import ( + InputKind, + OutputKind, + SymIntArgument, + TensorArgument, + ) + + output_kind_targets = defaultdict(set) + constant_tensors = [] + user_inputs = [] + user_outputs = [] + mutable_buffers = [] + + for ospec in self.ep.graph_signature.output_specs: + kind = ospec.kind + arg = ospec.arg + name = arg.name + target = ospec.target + if target is not None: + output_kind_targets[kind].add(target) + if kind in (OutputKind.USER_OUTPUT, OutputKind.USER_INPUT_MUTATION): + user_outputs.append(name) + + for ispec in self.ep.graph_signature.input_specs: + kind = ispec.kind + arg = ispec.arg + name = arg.name + target = ispec.target + + if isinstance(arg, TensorArgument): + if kind == InputKind.PARAMETER: + # Parameters are treated as constants (not mutated) + constant_tensors.append(name) + elif kind == InputKind.BUFFER: + if target in output_kind_targets[OutputKind.BUFFER_MUTATION]: + mutable_buffers.append(name) + else: + # Non-mutated buffers (like lifted tensor constants) are constants + constant_tensors.append(name) + elif kind == InputKind.USER_INPUT: + user_inputs.append(name) + elif kind == InputKind.CONSTANT_TENSOR: + constant_tensors.append(name) + else: + raise NotImplementedError( + f"Support for input {arg} is not implemented" + ) + elif isinstance(arg, SymIntArgument): + if kind == InputKind.USER_INPUT: + user_inputs.append(name) + else: + raise NotImplementedError( + f"Support for input {arg} is not implemented" + ) + else: + raise NotImplementedError(f"Support for input {arg} is not implemented") + + for node in self.ep.graph.nodes: + if node.op == "placeholder": + if node.users == {}: + continue + if node.name in constant_tensors: + self.make_or_get_slot(node, id_space=IdSpace.Constant) + elif node.name in user_inputs: + val = node.meta.get("val", None) + if isinstance(val, torch.Tensor) and not val.is_contiguous(): + raise ValueError( + f"MLX backend requires contiguous input tensors, " + f"but input '{node.name}' has non-contiguous strides. " + f"shape={list(val.shape)}, stride={list(val.stride())}. " + f"Ensure example inputs passed to torch.export.export() " + f"are contiguous (call .contiguous() on them)." + ) + self.make_or_get_slot(node, id_space=IdSpace.Input) + elif node.name in mutable_buffers: + self.make_or_get_slot(node, id_space=IdSpace.MutableBuffer) + else: + raise NotImplementedError( + f"Support for placeholder {node.name} is not implemented" + ) + elif node.op == "output": + outs, _ = pytree.tree_flatten(node.args) + for o in outs: + if isinstance(o, Node) and o.name in user_outputs: + self.make_or_get_slot(o, id_space=IdSpace.Output) + + def _mark_noop(self): + """Mark noops and dead nodes.""" + dead = set() + noop_handler = REGISTRY.get_noop_handler() + if noop_handler is None: + return + + for n in reversed(self.ep.graph.nodes): + handler = REGISTRY.get_handler(n) + if handler == noop_handler: + dead.add(n) + + if n.op != "output" and all(user in dead for user in n.users): + self.node_info[n].handler = noop_handler + dead.add(n) + + def _apply_patterns(self) -> None: + """ + Find and apply pattern handlers to the graph. + + Uses PatternMatcher to discover multi-node patterns and assigns + handlers to matched nodes. This must run BEFORE _mark_noop so + pattern body nodes don't get incorrectly marked as dead. + """ + matcher = PatternMatcher(self.ep, REGISTRY) + for handler in matcher.find_patterns(): + handler.set_handlers(self) + + def _process_nodes(self) -> None: # noqa C901 + """ + Common logic for processing all nodes: create slots, match patterns, run handlers. + + This method: + 1. Creates I/O slots for placeholders and outputs + 2. Matches patterns FIRST (so body nodes get handlers and aren't marked dead) + 3. Marks dead/noop nodes + 4. Runs handlers for remaining nodes, marking them supported/unsupported + + The ordering is important: patterns must be matched before noops because + some pattern body nodes (e.g., update_cache) have no users since they + mutate in-place, but they're not dead - they're handled by the pattern. + """ + self._make_io_slots() + + # Apply patterns BEFORE _mark_noop so pattern body nodes don't get + # incorrectly marked as dead (e.g., update_cache nodes have no users + # since they mutate in-place, but they're not dead) + self._apply_patterns() + self._mark_noop() + + for n in self.ep.graph.nodes: + if self._is_handled(n): + continue + + if self.node_info[n].handler is not None: + handler = self.node_info[n].handler + handler(self, n) + self._mark_supported(n, handler=handler) + continue + + # Check input dtypes before processing node + unsupported_dtype_msg = _check_input_dtypes(n) + if unsupported_dtype_msg is not None: + if n.meta.get("val", None) is not None: + self.slot_manager.make_or_get_slots(n) + self._mark_unsupported(n, unsupported_dtype_msg) + continue + + if n.op in ("placeholder", "output"): + dtype_error = _check_dtype(n) + if dtype_error is not None: + self._mark_unsupported(n, f"{n.op} {dtype_error}") + continue + self._mark_supported(n) + continue + + handler = REGISTRY.get_handler(n) + if handler is None: + msg = f"no handler for target={n.target}" + if n.meta.get("val", None) is not None: + self.slot_manager.make_or_get_slots(n) + self._mark_unsupported(n, msg) + continue + + try: + handler(self, n) + self._mark_supported(n, handler=handler) + except Exception as e: + trace_str = traceback.format_exc() + msg = f"{handler} failed for {n.target}: {e}.\n{trace_str}" + if n.meta.get("val", None) is not None: + self.slot_manager.make_or_get_slots(n) + self._mark_unsupported(n, msg) + + def check_support_only(self) -> None: + """ + Check which nodes are supported without building the full MLXGraph. + + This method populates node_info with supported/unsupported status for each + node, but avoids calling _build_mlx_graph() which can corrupt the shape_env + by evaluating symbolic shapes. + + Use this method for ops_to_not_decompose() and similar queries where you + only need to know support status, not the full compiled graph. + """ + self._process_nodes() + # NOTE: We intentionally skip _verify_build() and _build_mlx_graph() here + # because _build_mlx_graph() calls int() on tensor shapes which evaluates + # SymInts and corrupts the shape_env. This method is used for + # ops_to_not_decompose() where we only need support status. + + def _emit_buffer_mutation_writebacks(self): + """Emit copy-back instructions for BUFFER_MUTATION outputs. + + When a model mutates a buffer (e.g., via .copy_() or .mul_()), + torch.export functionalizes it: the new value is a computation result, + and the output spec marks it as BUFFER_MUTATION with a target buffer. + + This method emits an IdCopyNode for each BUFFER_MUTATION output, + copying the computation result back to the mutable buffer slot so + the updated value persists across execution calls. + """ + from torch.export.graph_signature import InputKind, OutputKind + + # Map buffer target name -> input placeholder name + target_to_placeholder = {} + for ispec in self.ep.graph_signature.input_specs: + if ispec.kind == InputKind.BUFFER and ispec.target is not None: + target_to_placeholder[ispec.target] = ispec.arg.name + + for ospec in self.ep.graph_signature.output_specs: + if ospec.kind != OutputKind.BUFFER_MUTATION: + continue + + result_slot = self.slot_manager.get_slot(ospec.arg.name) + placeholder_name = target_to_placeholder.get(ospec.target) + if result_slot is None or placeholder_name is None: + continue + + buffer_slot = self.slot_manager.get_slot(placeholder_name) + if buffer_slot is None or buffer_slot.id_space != IdSpace.MutableBuffer: + continue + + self.emit( + IdCopyNode( + x=self.slot_to_tid(result_slot), + out=self.slot_to_tid(buffer_slot), + ) + ) + + def build(self) -> MLXGraph: + if self._mlx_graph is not None: + return self._mlx_graph + + self._process_nodes() + self._emit_buffer_mutation_writebacks() + self._verify_build() + self._mlx_graph = self._build_mlx_graph() + return self._mlx_graph + + def _verify_build(self): + noop_handler = REGISTRY.get_noop_handler() + + for n, info in self.node_info.items(): + assert info.handled + assert ( + info.remaining_reads == 0 + ), f"Expected {n} to have no remaining reads, but it has {info.remaining_reads}" + if n.op == "output": + assert self.slot_manager.get_slot(n) is None + continue + if ( + info.handler in (noop_handler, PatternHandler.deferred_handler) + or n.users == {} + ): + assert ( + self.slot_manager.get_slot(n) is None + ), f"Did not expect node {n} handled by {info.handler} to have a slot" + else: + assert ( + self.slot_manager.get_slot(n) is not None + ), f"Expected slot for node {n}" + + def _collect_used_slots( + self, + ) -> Tuple[Set[Slot], Dict[IdSpace, int], Dict[IdSpace, int]]: + """ + Collect all used slots and count tensors/values per IdSpace. + + For constants and temps, only includes those actually referenced by + instructions. This ensures unused slots are not serialized or counted. + + Returns: + (used_slots, num_tensors, num_values) + """ + # Get slots actually referenced by instructions + instruction_referenced: Set[Slot] = {slot for _, slot in self._tid_slot_map} + instruction_referenced.update({slot for _, slot in self._vid_slot_map}) + + used_slots: Set[Slot] = set() + for _n, slot in self.slot_manager.name_to_slot.items(): + if not isinstance(slot, tuple): + slot = (slot,) + for s in slot: + # For constants and temps, only include if referenced by instructions + if s.id_space in (IdSpace.Constant, IdSpace.Temp): + if s in instruction_referenced: + used_slots.add(s) + else: + # Inputs, outputs, mutable buffers - always include + used_slots.add(s) + + num_tensors: Dict[IdSpace, int] = defaultdict(int) + num_values: Dict[IdSpace, int] = defaultdict(int) + seen: Set[Slot] = set() + for s in used_slots: + if s in seen: + continue + seen.add(s) + if s.id_type == IdType.Tensor: + num_tensors[s.id_space] += 1 + else: + num_values[s.id_space] += 1 + + return used_slots, num_tensors, num_values + + def _create_slot_mappings( + self, used_slots: Set[Slot] + ) -> Tuple[Dict[Slot, int], Dict[Slot, int]]: + """ + Create slot→Tid and slot→Vid mappings, and remap existing references. + + Returns: + (slot_to_tid, slot_to_vid) + """ + id_space_order = { + IdSpace.Constant: 0, + IdSpace.Input: 1, + IdSpace.Output: 2, + IdSpace.MutableBuffer: 3, + IdSpace.Temp: 4, + } + + # Create Tid mapping + slot_to_tid = sorted( + [s for s in used_slots if s.id_type == IdType.Tensor], + key=lambda s: (id_space_order[s.id_space], s.idx), + ) + slot_to_tid = {s: idx for idx, s in enumerate(slot_to_tid)} + + # Create Vid mapping + slot_to_vid = sorted( + [s for s in used_slots if s.id_type != IdType.Tensor], + key=lambda s: (id_space_order[s.id_space], s.idx), + ) + slot_to_vid = {s: idx for idx, s in enumerate(slot_to_vid)} + + # Remap all Tid/Vid values in instructions to use global indices + if hasattr(self, "_tid_slot_map"): + for tid, slot in self._tid_slot_map: + if slot in slot_to_tid: + tid.idx = slot_to_tid[slot] + else: + logger.warning(f"Slot {slot} not found in slot_to_tid mapping") + + if hasattr(self, "_vid_slot_map"): + for vid, slot in self._vid_slot_map: + if slot in slot_to_vid: + vid.idx = slot_to_vid[slot] + else: + logger.warning(f"Slot {slot} not found in slot_to_vid mapping") + + return slot_to_tid, slot_to_vid + + def _to_slot_variant( + self, + slot: Slot, + slot_to_tid: Dict[Slot, int], + slot_to_vid: Dict[Slot, int], + ) -> SlotVariant: + """Convert a Slot to a SlotVariant using the provided mappings.""" + if slot.id_type == IdType.Tensor: + idx = slot_to_tid[slot] + slot_type = SlotType.TensorSlot + elif slot.id_type == IdType.SymInt: + idx = slot_to_vid[slot] + slot_type = SlotType.IntValueSlot + elif slot.id_type == IdType.SymBool: + idx = slot_to_vid[slot] + slot_type = SlotType.BoolValueSlot + else: + raise NotImplementedError(f"Unsupported slot type {slot.id_type}") + return SlotVariant(idx=idx, slot_type=slot_type) + + def _build_io_maps( + self, + used_slots: Set[Slot], + slot_to_tid: Dict[Slot, int], + slot_to_vid: Dict[Slot, int], + ) -> Tuple[ + List[SlotVariant], List[SlotVariant], List[SlotVariant], List[NamedSlot] + ]: + """ + Build input/output/mutable_buffer maps and named slots. + + Returns: + (input_map, output_map, mutable_buffer_map, named_slots) + """ + input_map: List[SlotVariant] = [] + output_map: List[SlotVariant] = [] + mutable_buffer_map: List[SlotVariant] = [] + # Canonical (unprefixed) name → Slot. The prefix is applied only at + # the exit boundaries: NamedSlot construction and NamedDataStore keys. + name_to_slot: Dict[str, Slot] = {} + + for ispec in self.ep.graph_signature.input_specs: + slot = self.slot_manager.get_slot(ispec.arg.name) + if slot is None: + continue + assert isinstance(slot, Slot) + name = ispec.target if ispec.target is not None else ispec.arg.name + if slot.id_space == IdSpace.Input: + input_map.append(self._to_slot_variant(slot, slot_to_tid, slot_to_vid)) + name_to_slot[name] = slot + elif slot.id_space == IdSpace.MutableBuffer: + mutable_buffer_map.append( + self._to_slot_variant(slot, slot_to_tid, slot_to_vid) + ) + name_to_slot[name] = slot + else: + if slot in used_slots: + name_to_slot[name] = slot + + for ospec in self.ep.graph_signature.output_specs: + name = ospec.arg.name + slot = self.slot_manager.get_slot(name) + if slot is None: + continue + assert isinstance(slot, Slot) + if slot.id_space == IdSpace.Output: + output_map.append(self._to_slot_variant(slot, slot_to_tid, slot_to_vid)) + name = ospec.target if ospec.target is not None else ospec.arg.name + name_to_slot[name] = slot + elif slot.id_space == IdSpace.MutableBuffer: + name = ospec.target if ospec.target is not None else ospec.arg.name + name_to_slot[name] = slot + + for name in self.extra_constants: + slot = self.slot_manager.get_slot(name) + assert slot is not None and isinstance(slot, Slot) + if slot in used_slots: + name_to_slot[name] = slot + + # Store unprefixed constant mapping for get_named_data_store() + self._constant_name_to_slot = { + n: s for n, s in name_to_slot.items() if s.id_space == IdSpace.Constant + } + + # Apply prefix at the exit boundary — the FlatBuffer named_slots + named_slots = [ + NamedSlot( + name=self._prefix_key(n), + slot=self._to_slot_variant(s, slot_to_tid, slot_to_vid), + ) + for n, s in name_to_slot.items() + ] + + return input_map, output_map, mutable_buffer_map, named_slots + + def _build_tensor_meta( # noqa: C901 + self, + used_slots: Set[Slot], + slot_to_tid: Dict[Slot, int], + slot_to_vid: Dict[Slot, int], + num_tensors: Dict[IdSpace, int], + ) -> List[TensorMeta]: + """ + Build tensor metadata list with shape/dtype information. + + Static dimensions are stored as ShapeDim(value=N). + Dynamic dimensions (SymInt) are stored as ShapeDim(value=-1) + with min/max bounds from the shape_env. + + Note: tensor_meta shapes are only consumed by the runtime for + constant and mutable buffer allocation (which are always static). + Dynamic dim metadata is informational — the runtime resolves + dynamic shapes via SymSizeNode at execution time. + """ + + def _get_dim_bounds(dim: torch.SymInt) -> tuple: + """Get (min, max) bounds for a symbolic dimension.""" + try: + node = dim.node + shape_env = node.shape_env + if shape_env is not None: + expr = node.expr + lower = int(shape_env.bound_sympy(expr).lower) + upper = int(shape_env.bound_sympy(expr).upper) + if upper > 2**30: + return (lower, -1) # treat as unbounded + return (lower, upper) + except Exception: + pass + return (0, -1) # unbounded fallback + + def to_tensor_meta(t: torch.Tensor) -> TensorMeta: + shape: List[ShapeDim] = [] + for dim in t.shape: + if isinstance(dim, torch.SymInt): + lo, hi = _get_dim_bounds(dim) + shape.append(ShapeDim(value=-1, min_value=lo, max_value=hi)) + else: + shape.append(ShapeDim(value=int(dim))) + + dim_order = list(range(len(t.shape))) if len(t.shape) > 0 else None + + return TensorMeta( + shape=shape, + scalar_type=torch_dtype_to_scalar_type(t.dtype), + dim_order=dim_order, + ) + + tensor_meta: Dict[int, TensorMeta] = {} + for n in self.node_info: + slot = self.slot_manager.get_slot(n) + if not isinstance(slot, tuple): + slot = (slot,) + fake_val = n.meta.get("val", None) + if not isinstance(fake_val, tuple): + fake_val = (fake_val,) + for s, fv in zip(slot, fake_val): + if s not in used_slots: + continue + if s.id_type != IdType.Tensor: + continue + if s.id_space == IdSpace.Temp: + continue + idx = slot_to_tid[s] + if fv is not None and hasattr(fv, "shape"): + tensor_meta[idx] = to_tensor_meta(fv) + + for name, t in self.extra_constants.items(): + slot = self.slot_manager.get_slot(name) + assert slot is not None and isinstance(slot, Slot) + if slot in used_slots: + idx = slot_to_tid[slot] + tensor_meta[idx] = to_tensor_meta(t) + + num_non_temp_tensors = sum(num_tensors.values()) - num_tensors[IdSpace.Temp] + return [tensor_meta.get(i) for i in range(num_non_temp_tensors)] + + def _build_mlx_graph(self) -> MLXGraph: + # Check support + for node, info in self.node_info.items(): + if not info.supported: + raise ValueError( + f"Found unsupported node: {node}\nReason: {info.unsupported_reason}" + ) + + # Collect slots and create mappings + used_slots, num_tensors, num_values = self._collect_used_slots() + slot_to_tid, slot_to_vid = self._create_slot_mappings(used_slots) + + # Store for use in get_constant_data() - needed to serialize in Tid order + self._slot_to_final_tid = slot_to_tid + + # Build I/O maps and metadata + input_map, output_map, mutable_buffer_map, named_slots = self._build_io_maps( + used_slots, slot_to_tid, slot_to_vid + ) + tensor_meta_list = self._build_tensor_meta( + used_slots, slot_to_tid, slot_to_vid, num_tensors + ) + + # Compute final counts + num_constant_tensors = num_tensors[IdSpace.Constant] + num_temp_tensors = num_tensors[IdSpace.Temp] + num_values_count = sum(num_values.values()) + + return MLXGraph( + version="1", + num_constant_tensors=num_constant_tensors, + num_input_tensors=num_tensors[IdSpace.Input], + num_output_tensors=num_tensors[IdSpace.Output], + num_mutable_buffer_tensors=num_tensors[IdSpace.MutableBuffer], + num_temp_tensors=num_temp_tensors, + num_values=num_values_count, + instruction_chains=[InstructionChain(instructions=self._instrs)], + main_chain_idx=0, + init_chain_idx=-1, + input_map=input_map, + output_map=output_map, + mutable_buffer_map=mutable_buffer_map, + named_slots=named_slots, + tensor_meta=tensor_meta_list, + ) + + def get_named_data_store(self) -> NamedDataStore: + """ + Get a NamedDataStore containing all constant tensors. + + Uses the unprefixed canonical-name → Slot mapping built by + ``_build_io_maps()`` so that tensor lookups hit ``ep.state_dict`` / + ``ep.constants`` / ``extra_constants`` (which all use unprefixed + keys). The prefix is applied at the exit boundary — the + ``NamedDataStore`` key — so it matches the FlatBuffer ``named_slots``. + """ + named_data_store = NamedDataStore() + + # Sort by final TID for deterministic ordering + entries = sorted( + self._constant_name_to_slot.items(), + key=lambda x: self._slot_to_final_tid.get(x[1], 0), + ) + + logger.debug(f"Adding {len(entries)} constants to NamedDataStore...") + for canonical_name, _slot in entries: + tensor = self._find_constant_tensor(canonical_name) + if tensor is None: + continue + + t = tensor.detach().cpu().contiguous() + named_data_store.add_named_data( + key=self._prefix_key(canonical_name), + data=t, + alignment=16, + ) + logger.debug("Done adding constants to NamedDataStore") + + return named_data_store + + def get_mutable_buffer_names(self) -> List[str]: + """ + Get the names of all mutable buffers in Tid order. + + Returns: + List of mutable buffer names in the order they appear in mutable_buffer_map. + """ + assert self._mlx_graph is not None, "Must call build() first" + + names = [] + for name, slot in self.slot_manager.name_to_slot.items(): + if isinstance(slot, tuple): + continue + if slot.id_space != IdSpace.MutableBuffer: + continue + if slot in self._slot_to_final_tid: + names.append((name, self._slot_to_final_tid[slot])) + + # Sort by Tid and return just the names + names.sort(key=lambda x: x[1]) + return [n for n, _ in names] + + def _find_constant_tensor(self, name: str) -> Optional[torch.Tensor]: + """Find a constant tensor by name from various sources.""" + if name in self.ep.state_dict: + return self.ep.state_dict[name] + if name in self.ep.constants: + return self.ep.constants[name] + if name in self.extra_constants: + return self.extra_constants[name] + # Look up by target + for ispec in self.ep.graph_signature.input_specs: + if ispec.arg.name == name and ispec.target is not None: + if ispec.target in self.ep.state_dict: + return self.ep.state_dict[ispec.target] + if ispec.target in self.ep.constants: + return self.ep.constants[ispec.target] + return None diff --git a/backends/mlx/builder/slot_manager.py b/backends/mlx/builder/slot_manager.py new file mode 100644 index 00000000000..b1884a76a68 --- /dev/null +++ b/backends/mlx/builder/slot_manager.py @@ -0,0 +1,187 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +import uuid +from collections import defaultdict +from dataclasses import dataclass +from enum import auto, Enum +from typing import Dict, Optional, Tuple, Union + +import torch +from torch.fx.node import Node + + +class IdType(Enum): + Tensor = auto() + SymInt = auto() + SymBool = auto() + + +class IdSpace(Enum): + Constant = auto() + Input = auto() + Output = auto() + MutableBuffer = auto() + Temp = auto() + + +@dataclass(frozen=True) +class Slot: + id_type: IdType + id_space: IdSpace + idx: Optional[int] = None + + +class IdManager: + def __init__(self): + self.free: set[int] = set() + self.next_new_id = 0 + + def get_id(self): + return self.free.pop() if self.free else self._bump() + + def _bump(self): + idx = self.next_new_id + self.next_new_id += 1 + return idx + + def return_id(self, idx): + self.free.add(idx) + + +class SlotManager: + def __init__(self): + self.tid_managers: Dict[IdSpace, IdManager] = defaultdict(IdManager) + self.vid_managers: Dict[IdSpace, IdManager] = defaultdict(IdManager) + self.name_to_slot: Dict[str, Slot] = {} + + def set_slot(self, node_or_name: Union[Node, str], slot: Slot): + if isinstance(node_or_name, Node): + node_or_name = node_or_name.name + # Allow setting a slot to the same value (e.g., for in-place ops like SLICE_UPDATE) + existing = self.name_to_slot.get(node_or_name) + if existing is not None: + # If already set to the same slot, it's fine + if existing == slot: + return + raise AssertionError( + f"Slot for {node_or_name} already set to {existing}, trying to set to {slot}" + ) + self.name_to_slot[node_or_name] = slot + + def get_slot( + self, node_or_name: Union[Node, str] + ) -> Optional[Union[Tuple[Slot], Slot]]: + if isinstance(node_or_name, Node): + node_or_name = node_or_name.name + return self.name_to_slot.get(node_or_name, None) + + def _val_to_idtype(self, v) -> IdType: + from torch._subclasses.fake_tensor import FakeTensor + + if isinstance(v, FakeTensor): + return IdType.Tensor + elif isinstance(v, torch.SymInt): + return IdType.SymInt + elif isinstance(v, torch.SymBool): + return IdType.SymBool + else: + raise NotImplementedError(f"val_to_idtype: {v}") + + def is_alive(self, slot: Slot) -> bool: + if slot.id_type == IdType.Tensor: + manager = self.tid_managers[slot.id_space] + else: + manager = self.vid_managers[slot.id_space] + idx = slot.idx + if idx >= manager.next_new_id: + return False + if idx in manager.free: + return False + return True + + def make_constant_slot(self, name: str) -> Slot: + assert name not in self.name_to_slot + id_space = IdSpace.Constant + manager = self.tid_managers[id_space] + idx = manager.get_id() + slot = Slot(id_type=IdType.Tensor, id_space=id_space, idx=idx) + self.name_to_slot[name] = slot + return slot + + def make_tmp_slot(self) -> Tuple[str, Slot]: + name = f"tmp_{uuid.uuid4().hex}" + id_space = IdSpace.Temp + manager = self.tid_managers[id_space] + idx = manager.get_id() + slot = Slot(id_type=IdType.Tensor, id_space=id_space, idx=idx) + self.name_to_slot[name] = slot + return name, slot + + def make_tmp_value_slot(self) -> Tuple[str, Slot]: + """Create a temporary SymInt slot and register it.""" + name = f"tmp_val_{uuid.uuid4().hex}" + id_space = IdSpace.Temp + manager = self.vid_managers[id_space] + idx = manager.get_id() + slot = Slot(id_type=IdType.SymInt, id_space=id_space, idx=idx) + self.name_to_slot[name] = slot + return name, slot + + def make_or_get_slots( + self, node: Node, id_space: IdSpace = IdSpace.Temp + ) -> Tuple[Slot, ...]: + """ + Get or create slots for a node. Always returns a tuple of slots. + + Use this for multi-output ops (e.g., topk returns (values, indices)). + For single-output ops, prefer make_or_get_slot() which returns a single Slot. + """ + if node.name in self.name_to_slot: + slot = self.name_to_slot[node.name] + # Normalize to tuple for consistent return type + if not isinstance(slot, tuple): + return (slot,) + return slot + + val = node.meta.get("val", None) + assert val is not None, f"Node {node} has no val" + if not isinstance(val, (list, tuple)): + val = (val,) + + slots = [] + for v in val: + id_type = self._val_to_idtype(v) + if id_type == IdType.Tensor: + manager = self.tid_managers[id_space] + else: + manager = self.vid_managers[id_space] + idx = manager.get_id() + slots.append(Slot(id_type=id_type, id_space=id_space, idx=idx)) + slots = tuple(slots) + + # Store in the format that matches the node's output structure + if len(slots) == 1: + self.set_slot(node, slots[0]) + else: + self.set_slot(node, slots) + return slots + + def make_or_get_slot(self, node: Node, id_space: IdSpace = IdSpace.Temp) -> Slot: + """ + Get or create a slot for a single-output node. Returns a single Slot. + + Use this for single-output ops (the common case). + For multi-output ops, use make_or_get_slots() instead. + """ + slots = self.make_or_get_slots(node, id_space) + assert len(slots) == 1, ( + f"Expected single output for node {node.name}, got {len(slots)}. " + f"Use make_or_get_slots() for multi-output ops." + ) + return slots[0] diff --git a/backends/mlx/custom_ops.py b/backends/mlx/custom_ops.py new file mode 100644 index 00000000000..81853adbd6d --- /dev/null +++ b/backends/mlx/custom_ops.py @@ -0,0 +1,15 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +Custom MLX operator definitions. + +This module defines custom operators that are supported by the MLX backend. +These ops are used during model export to represent operations that MLX +can execute efficiently but may not have direct PyTorch equivalents. +""" diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py new file mode 100644 index 00000000000..6e8516e86b1 --- /dev/null +++ b/backends/mlx/ops.py @@ -0,0 +1,294 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +MLX Op Handlers - registered handlers for converting ATen/custom ops to MLX. + +This module contains all the op handler functions registered with the MLXOpRegistry. +Each handler converts a specific PyTorch operation to the corresponding MLX graph node. +""" + +from __future__ import annotations + +import operator +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +import torch +from executorch.backends.mlx.builder.op_registry import REGISTRY +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.serialization.mlx_graph_schema import AddmmNode +from torch.fx.node import Node + + +def require_static_int(value: Any, param_name: str, op_name: str) -> None: + """ + Validate that a parameter is a static integer (not a Slot/SymInt). + + Raises NotImplementedError if the value is dynamic. + + Args: + value: The parameter value to check + param_name: Name of the parameter (for error message) + op_name: Name of the operation (for error message) + """ + if isinstance(value, Slot) or not isinstance(value, int): + raise NotImplementedError( + f"{op_name} with dynamic {param_name} is not supported. " + f"{param_name} requires a static int32 value, but got {value} (type={type(value).__name__})." + ) + + +def require_static_float(value: Any, param_name: str, op_name: str) -> None: + """ + Validate that a parameter is a static float (not a Slot/SymFloat). + + Raises NotImplementedError if the value is dynamic. + + Args: + value: The parameter value to check + param_name: Name of the parameter (for error message) + op_name: Name of the operation (for error message) + """ + if isinstance(value, Slot) or not isinstance(value, (int, float)): + raise NotImplementedError( + f"{op_name} with dynamic {param_name} is not supported. " + f"{param_name} requires a static float value, but got {value} (type={type(value).__name__})." + ) + + +def require_static_ints( + values: Union[List[Any], Any], param_name: str, op_name: str +) -> None: + """ + Validate that all values in a list are static integers (not Slots/SymInts). + + Raises NotImplementedError if any value is dynamic. + + Args: + values: List of values to check, or a single value + param_name: Name of the parameter (for error message) + op_name: Name of the operation (for error message) + """ + if not isinstance(values, list): + values = [values] + + for v in values: + require_static_int(v, param_name, op_name) + + +def require_args( + args: List[Any], + min_count: int, + max_count: int, + op_name: str, +) -> None: + """ + Validate that args count is within expected range. + + Raises ValueError if the count is outside the expected range. + + Args: + args: The handler args list + min_count: Minimum number of args expected + max_count: Maximum number of args expected + op_name: Name of the operation (for error message) + """ + if not (min_count <= len(args) <= max_count): + if min_count == max_count: + raise ValueError(f"{op_name}: expected {min_count} args, got {len(args)}") + raise ValueError( + f"{op_name}: expected {min_count}-{max_count} args, got {len(args)}" + ) + + +def require_kwargs( + kwargs: Dict[str, Any], + allowed: Set[str], + op_name: str, +) -> None: + """ + Validate that only allowed kwargs are present. + + Raises ValueError if unexpected kwargs are found. + + Args: + kwargs: The handler kwargs dict + allowed: Set of allowed kwarg names + op_name: Name of the operation (for error message) + """ + unexpected = set(kwargs.keys()) - allowed + if unexpected: + raise ValueError(f"{op_name}: unexpected kwargs: {unexpected}") + + +def require_contiguous_format( + *, + layout=None, + memory_format=None, + dim_order=None, + op_name: str, +) -> None: + """ + Validate that layout/memory_format/dim_order specify contiguous format. + + MLX only supports contiguous (strided) tensors. Raises ValueError if + sparse layouts or non-contiguous memory formats are requested. + + Args: + layout: The torch layout (e.g., torch.strided, torch.sparse_coo) + memory_format: The torch memory format (e.g., torch.contiguous_format, + torch.channels_last) + dim_order: The dimension order (list of ints, identity = contiguous) + op_name: Name of the operation (for error message) + """ + if layout is not None and layout != torch.strided: + raise ValueError(f"{op_name}: only strided layout supported, got {layout}") + + if memory_format is not None and memory_format not in ( + torch.contiguous_format, + torch.preserve_format, + ): + raise ValueError( + f"{op_name}: only contiguous memory format supported, got {memory_format}" + ) + + if dim_order is not None: + if list(dim_order) != list(range(len(dim_order))): + raise ValueError( + f"{op_name}: only contiguous dim_order supported, got {dim_order}" + ) + + +def is_static_value(value: Any) -> bool: + """ + Check if a value is static (not a Slot/SymInt). + + Returns: + True if the value is a static scalar (int, float, bool), False otherwise + """ + return not isinstance(value, Slot) + + +def used_getitem_indices(n: Node) -> Set[int]: + """Return the set of getitem indices actually consumed downstream. + + Only includes indices where the getitem node has at least one user. + """ + return { + user.args[1] + for user in n.users + if user.target == operator.getitem and len(user.users) > 0 + } + + +def normalize_reduction_dim( + args: List[Any], start_idx: int = 1 +) -> Tuple[Optional[List[int]], bool]: + """ + Normalize dim argument for reduction operations. + + Extracts and normalizes the dim argument from handler args, returning a list of axes + and the keepdim flag. Handles both list-based dims (e.g., sum.dim_IntList) and + single int dims (e.g., prod.dim_int). + + Args: + args: The handler args list + start_idx: Index where the dim argument starts (default 1, after self) + + Returns: + Tuple of (axes, keepdim) where: + - axes: List of dimension indices, or empty list for reduce-all + - keepdim: Boolean keepdim flag (default False) + """ + if len(args) > start_idx and isinstance(args[start_idx], (list, tuple)): + dim = list(args[start_idx]) + keepdim = args[start_idx + 1] if len(args) > start_idx + 1 else False + elif len(args) > start_idx and isinstance(args[start_idx], int): + dim = [args[start_idx]] + keepdim = args[start_idx + 1] if len(args) > start_idx + 1 else False + else: + dim = [] + keepdim = False + + return dim, keepdim + + +@REGISTRY.register(target=[torch.ops.aten.addmm.default]) +def _addmm_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle addmm: self + (mat1 @ mat2). + + addmm(self, mat1, mat2, *, beta=1, alpha=1) computes: + beta * self + alpha * (mat1 @ mat2) + + This is typically the result of decomposing linear(x, w, b) in Edge IR: + permute(w) -> addmm(b, x, permuted_w) + + For the common case where beta=1 and alpha=1, this is equivalent to: + mat1 @ mat2 + self + + We use AddmmNode which calls matmul directly (no transposition needed). + """ + args = P.args(n) + kwargs = P.kwargs(n) + require_args(args, 3, 3, "aten.addmm") + require_kwargs(kwargs, {"beta", "alpha"}, "aten.addmm") + bias, mat1, mat2 = args[0], args[1], args[2] + + beta = kwargs.get("beta", 1) + alpha = kwargs.get("alpha", 1) + + out = P.make_or_get_slot(n) + + # Emit AddmmNode with alpha and beta parameters + P.emit( + AddmmNode( + mat1=P.slot_to_tid(mat1), + mat2=P.slot_to_tid(mat2), + out=P.slot_to_tid(out), + bias=P.slot_to_tid(bias), + alpha=float(alpha), + beta=float(beta), + ) + ) + return out + + +@REGISTRY.register( + target=[ + torch.ops.aten.mm.default, + torch.ops.aten.bmm.default, + torch.ops.aten.matmul.default, + ] +) +def _mm_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle mm/bmm/matmul: matrix multiplication without bias. + + All three ops compute matrix products with different dimension expectations: + - mm: 2D x 2D + - bmm: 3D x 3D (batched) + - matmul: arbitrary dimensions (NumPy semantics) + + MLX's matmul handles all cases, so we emit AddmmNode with bias=None. + """ + args = P.args(n) + require_args(args, 2, 2, "aten.mm/bmm/matmul") + require_kwargs(P.kwargs(n), set(), "aten.mm/bmm/matmul") + mat1, mat2 = args[0], args[1] + + out = P.make_or_get_slot(n) + + P.emit( + AddmmNode( + mat1=P.slot_to_tid(mat1), + mat2=P.slot_to_tid(mat2), + out=P.slot_to_tid(out), + bias=None, + ) + ) + return out diff --git a/backends/mlx/partitioner.py b/backends/mlx/partitioner.py new file mode 100644 index 00000000000..0896cafc301 --- /dev/null +++ b/backends/mlx/partitioner.py @@ -0,0 +1,298 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +MLX Partitioner - decides which ops should run on the MLX delegate. + +This module provides a Partitioner implementation that analyzes an EdgeIR +graph and marks supported operations for delegation to MLX. +""" + +from __future__ import annotations + +import inspect +from typing import Any, Callable, Dict, List, Tuple, Union + +import torch +from executorch.backends.mlx._logging import logger +from executorch.backends.mlx.preprocess import MLXBackend +from executorch.exir.backend.backend_details import CompileSpec +from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( + generate_partitions_from_list_of_nodes, +) +from executorch.exir.backend.partitioner import ( + DelegationSpec, + Partitioner, + PartitionResult, +) +from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer +from torch.export.exported_program import ExportedProgram +from torch.fx.passes.infra.partitioner import Partition +from torch.fx.passes.operator_support import OperatorSupportBase + + +class MLXOperatorSupport(OperatorSupportBase): + """ + Determines which operators are supported by the MLX delegate. + + Uses MLXProgramBuilder to determine support - this ensures the partitioner + uses the exact same logic as the actual compilation. A node is supported + if the builder can handle it (either via direct handler or pattern match). + """ + + def __init__( + self, + edge_program: torch.export.ExportedProgram, + compile_specs: List[CompileSpec], + ): + self.edge_program = edge_program + self.compile_specs = compile_specs + + # Run the builder to determine which nodes are supported + # The builder populates node_info with supported/unsupported status + from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder + + self._builder = MLXProgramBuilder(edge_program) + self._builder.check_support_only() + + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + + # Check if builder determined this node is supported + info = self._builder.node_info.get(node) + if info is not None and info.supported: + logger.debug(f"[SUPPORTED] Node {node.target}") + return True + + logger.debug(f"[UNSUPPORTED] Node {node.target}") + return False + + +class MLXPartitioner(Partitioner): + """ + Partitioner for the MLX delegate. + + Analyzes an EdgeIR graph and partitions supported operations + for delegation to MLX. + """ + + def __init__(self, compile_specs: List[CompileSpec] | None = None) -> None: + self.compile_specs = compile_specs or [] + self.delegation_spec = DelegationSpec(MLXBackend.__name__, self.compile_specs) + self.partition_tags: Dict[str, DelegationSpec] = {} + + def ops_to_not_decompose( + self, ep: ExportedProgram + ) -> tuple[list[torch._ops.OpOverload], Callable[[torch.fx.Node], bool] | None]: + """ + Return ops that should NOT be decomposed during edge lowering. + + This runs the MLXProgramBuilder to trace through the graph and determine + which nodes are supported (either via direct handlers or patterns). + Only ops for nodes that are actually supported should be preserved. + + This is called by to_edge_transform_and_lower to determine which + ops to preserve before partitioning. + + NOTE: We use check_support_only() instead of build() to avoid corrupting + the shape_env. build() calls _build_mlx_graph() which evaluates SymInts + to concrete values when converting tensor shapes, which corrupts the + shape_env and causes dynamic shapes to be lost during decomposition. + """ + from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder + + # Check if the graph already contains lowered modules (post-partitioning pass) + # In this case, we should return empty since partitioning is already done + for node in ep.graph.nodes: + if node.op == "get_attr" and "lowered_module" in node.name: + logger.debug( + "MLX ops_to_not_decompose: Graph already partitioned, returning empty" + ) + return ([], None) + + # Run the builder to determine which nodes are supported + builder = MLXProgramBuilder(ep) + builder.check_support_only() + + # Collect ops for nodes that are actually supported + do_not_decompose: list[torch._ops.OpOverload] = [] + + for node in ep.graph.nodes: + if node.op == "call_function" and isinstance( + node.target, torch._ops.OpOverload + ): + info = builder.node_info.get(node) + if info is not None and info.supported: + if node.target not in do_not_decompose: + do_not_decompose.append(node.target) + + logger.debug( + f"MLX ops_to_not_decompose: {[str(op) for op in do_not_decompose]}" + ) + return (do_not_decompose, None) + + def generate_partitions(self, edge_program: ExportedProgram) -> List[Any]: + """Generate partitions of supported nodes.""" + self.supported_ops = MLXOperatorSupport( + edge_program=edge_program, + compile_specs=self.delegation_spec.compile_specs, + ) + + # Collect unsupported ops, aggregated by target + unsupported_by_target: Dict[str, Tuple[int, str]] = ( + {} + ) # target -> (count, reason) + for node in edge_program.graph.nodes: + is_supported = self.supported_ops.is_node_supported({}, node) + if not is_supported and node.op == "call_function": + target_str = str(node.target) + info = self.supported_ops._builder.node_info.get(node) + reason = info.unsupported_reason if info else "No handler registered" + if target_str in unsupported_by_target: + count, _ = unsupported_by_target[target_str] + unsupported_by_target[target_str] = (count + 1, reason) + else: + unsupported_by_target[target_str] = (1, reason) + + logger.info("=" * 80) + logger.info("MLX Partitioner: UNSUPPORTED OPS SUMMARY") + logger.info("=" * 80) + if unsupported_by_target: + for target, (count, reason) in unsupported_by_target.items(): + logger.info(f" [UNSUPPORTED x{count}] {target}") + logger.info(f" Reason: {reason}") + else: + logger.info(" (All call_function nodes are supported!)") + logger.info("=" * 80) + + partitions = generate_partitions_from_list_of_nodes( + edge_program.graph_module, + op_support=self.supported_ops, + ) + + # WORKAROUND: Include sym_size nodes in partitions when any of their + # users are in the partition. Without this, sym_size nodes stay outside + # the partition and their results cross the partition boundary as concrete + # inputs, losing dynamic shape information during delegate lowering. + # By pulling them inside, the MLX runtime can execute SYM_SIZE at runtime, + # keeping shapes dynamic. + partitions = self._include_sym_size_nodes_in_partitions( + edge_program.graph_module, partitions + ) + + return partitions + + def _include_sym_size_nodes_in_partitions( + self, gm: torch.fx.GraphModule, partitions: List[Partition] + ) -> List[Partition]: + """ + Include sym_size nodes in partitions when any of their users are in the partition. + + This is a workaround for the dynamic shapes bug where symbolic shapes are lost + during delegate lowering if the sym_size node is not included in the partition. + """ + from executorch.exir.dialects.edge._ops import EdgeOpOverload + + for partition in partitions: + partition_nodes = set(partition.nodes) + nodes_to_add = [] + + for node in gm.graph.nodes: + if node.op != "call_function": + continue + + # Check if this is a sym_size node + target = node.target + if isinstance(target, EdgeOpOverload): + target = target._op + + if target != torch.ops.aten.sym_size.int: + continue + + # Check if any user of this sym_size node is in the partition + for user in node.users: + if user in partition_nodes: + # Add sym_size to partition if not already there + if node not in partition_nodes: + nodes_to_add.append(node) + logger.debug( + f"Adding sym_size node {node.name} to partition " + f"(used by {user.name})" + ) + break + + # Add the sym_size nodes to the partition + for node in nodes_to_add: + partition.add_node(node) + + return partitions + + def tag_nodes(self, partitions: List[Partition]) -> None: + """Tag nodes in each partition for delegation.""" + for partition in partitions: + delegation_tag = f"mlx_{partition.id}" + for node in partition.nodes: + node.meta["delegation_tag"] = delegation_tag + self.partition_tags[delegation_tag] = self.delegation_spec + + @staticmethod + def check_partitions(partitions: Union[dict, list]) -> bool: + """Check if any partitions were found.""" + pl = len(partitions) + if pl == 0: + logger.warning("MLX: Nothing can be partitioned!") + else: + logger.info(f"MLX: Found {pl} subgraphs to be partitioned.") + return pl != 0 + + @staticmethod + def _is_to_edge_transform_and_lower() -> bool: + """Check whether we are being called from to_edge_transform_and_lower.""" + for frame_info in inspect.stack(): + if frame_info.function == "to_edge_transform_and_lower": + return True + return False + + def partition(self, edge_program: ExportedProgram) -> PartitionResult: + """ + Partition the edge program for MLX delegation. + + Args: + edge_program: The ExportedProgram to partition. + + Returns: + PartitionResult with tagged nodes and partition specs. + + Raises: + RuntimeError: If called from the deprecated ``to_edge`` workflow. + """ + if not self._is_to_edge_transform_and_lower(): + raise RuntimeError( + "MLXPartitioner must be used with to_edge_transform_and_lower(). " + "The to_edge() + to_backend() workflow is not supported because " + "it decomposes ops that MLX has optimized implementations for. " + "Please use:\n" + " exir.to_edge_transform_and_lower(\n" + ' {"forward": exported_program},\n' + " partitioner=[MLXPartitioner()],\n" + " )" + ) + partitions = self.generate_partitions(edge_program=edge_program) + if self.check_partitions(partitions): + self.tag_nodes(partitions) + # Tag constant data that are used by the supported ops + tag_constant_data(edge_program) + # Tag mutated buffers so they are included in the partition + # This ensures the partitioned subgraph has proper mutation tracking + tag_mutated_buffer(edge_program) + + return PartitionResult( + tagged_exported_program=edge_program, + partition_tags=self.partition_tags, + ) diff --git a/backends/mlx/passes.py b/backends/mlx/passes.py new file mode 100644 index 00000000000..c7efdf561de --- /dev/null +++ b/backends/mlx/passes.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Graph transformation passes for the MLX backend. +""" + +from typing import List + +from executorch.exir.pass_base import ExportPass + + +def get_default_passes() -> List[ExportPass]: + """ + Returns a list of passes that are enabled by default for the MLX backend. + """ + return [] diff --git a/backends/mlx/patches/mlx_json.patch b/backends/mlx/patches/mlx_json.patch new file mode 100644 index 00000000000..4760403c8e6 --- /dev/null +++ b/backends/mlx/patches/mlx_json.patch @@ -0,0 +1,29 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -304,12 +304,18 @@ else() + set(MLX_BUILD_ACCELERATE OFF) + endif() + +-message(STATUS "Downloading json") +-FetchContent_Declare( +- json +- URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz) +-FetchContent_MakeAvailable(json) +-target_include_directories( +- mlx PRIVATE $) ++# Only fetch json if nlohmann_json target doesn't already exist ++# (ExecuTorch provides its own copy) ++if(NOT TARGET nlohmann_json) ++ message(STATUS "Downloading json") ++ FetchContent_Declare( ++ json ++ URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz) ++ FetchContent_MakeAvailable(json) ++ target_include_directories( ++ mlx PRIVATE $) ++else() ++ message(STATUS "Using existing nlohmann_json target") ++endif() + + add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx) diff --git a/backends/mlx/pattern_utils.py b/backends/mlx/pattern_utils.py new file mode 100644 index 00000000000..0d3d86430eb --- /dev/null +++ b/backends/mlx/pattern_utils.py @@ -0,0 +1,360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Shared pattern matching utilities for MLX backend. + +This module provides common utilities used by both: +- passes.py: Graph transformation passes (ExportPass) +- patterns.py: MLX lowering pattern handlers (PatternHandler) + +The core abstraction is the `PatternMatch` base class which provides: +- `maybe_create(head)` - Class method to match a pattern from a head node +- Captured values as typed fields +- `body` list of intermediate nodes to remove + +Usage in passes.py: + class FuseRMSNormPass(ExportPass): + def call(self, graph_module): + for node in graph.nodes: + if match := RMSNormMatch.maybe_create(node): + replacement = self._emit_fused_op(graph, match) + node.replace_all_uses_with(replacement) + match.remove_body_nodes(graph) + +Usage in patterns.py: + class RMSNormHandler(PatternHandler): + @classmethod + def maybe_create(cls, ep, head): + if match := RMSNormMatch.maybe_create(head): + return cls(head, match.body, match.input, match.weight, match.eps) + return None +""" + +from dataclasses import dataclass, field +from typing import Any, Callable, List, Optional, Set, Tuple, Union + +from executorch.backends.mlx.builder.op_helpers import get_aten_target_normalized +from torch.fx import Graph +from torch.fx.node import Node + + +# Type alias for walk_back result entries +# Each entry corresponds to an OpStep: +# - Node: matched node (for regular steps) +# - None: optional step that didn't match +# - List[Node]: repeat step (0 or more matches) +WalkBackEntry = Union[Node, None, List[Node]] + + +def match_target(node: Node, op: Any) -> bool: + """ + Check if a node's normalized aten target matches the given op. + + Uses get_aten_target_normalized to handle edge dialect ops. + This means slice_copy matches slice, etc. + + Args: + node: The node to check + op: The op to match (e.g., torch.ops.aten.mul.Tensor) + """ + return node.op == "call_function" and get_aten_target_normalized(node.target) == op + + +def has_single_user(node: Node) -> bool: + return len(node.users) == 1 + + +def has_no_users(node: Node) -> bool: + return len(node.users) == 0 + + +def extract_lifted_tensor_constant(node: Node) -> Optional[float]: + """ + Extract scalar value from a lifted tensor constant node. + + Lifted constants are created during torch.export and contain small + constant tensors (like epsilon values). The actual value is stored + in node.meta["val"]. + + Args: + node: A node that may be a lifted tensor constant + + Returns: + The scalar float value, or None if not a lifted constant or not scalar + """ + if not isinstance(node, Node): + return None + if "lifted_tensor_constant" not in node.name: + return None + val = node.meta.get("val") + if val is None: + return None + if not hasattr(val, "item"): + return None + try: + return float(val.item()) + except (RuntimeError, ValueError): + return None + + +@dataclass +class OpStep: + """ + One step in a backward walk through the graph. + + Used with walk_back() to define pattern chains. Supports both exact op + matching and predicate-based matching. + + Attributes: + op: Specific op to match (e.g., torch.ops.aten.rsqrt.default) + predicate: Alternative to op - a function that returns True for matching nodes + optional: If True, skip this step if it doesn't match + repeat: If True, match this step 0 or more times (like regex *) + require_single_user: If True (default), only match nodes with exactly one user + nargs: Number of args required. Can be: + - int: minimum number of args (default 1, since we advance via args[0]) + - tuple (min, max): range of args required (inclusive) + kwargs: Set of kwargs we handle (node's kwargs must be subset of this) + arg_index: Which arg to follow when advancing (default 0) + + Examples: + # Match specific op + OpStep(op=torch.ops.aten.rsqrt.default) + + # Match with predicate (for matching families of ops) + OpStep(predicate=lambda n: match_target(n, torch.ops.aten.select.int)) + + # Match chain of same op type (0 or more) + OpStep(op=torch.ops.aten.select.int, repeat=True) + + # Optional dtype conversion + OpStep(op=torch.ops.aten._to_copy.default, optional=True) + + # Require between 2 and 4 args + OpStep(op=torch.ops.aten.some_op.default, nargs=(2, 4)) + + # Declare that we handle 'dtype' kwarg + OpStep(op=torch.ops.aten._to_copy.default, kwargs={"dtype"}) + + # Follow second arg (e.g., mul(x, rsqrt(y)) -> follow rsqrt in args[1]) + OpStep(op=torch.ops.aten.mul.Tensor, arg_index=1) + """ + + op: Any = None + predicate: Optional[Callable[[Node], bool]] = None + optional: bool = False + repeat: bool = False + require_single_user: bool = True + nargs: Union[int, Tuple[int, int]] = 1 + kwargs: Set[str] = field(default_factory=set) # Empty = no kwargs allowed + arg_index: int = 0 + + def matches(self, node: Node) -> bool: + """Check if this step fully matches the given node.""" + # Check op or predicate + if self.op is not None: + if not match_target(node, self.op): + return False + elif self.predicate is not None: + if not self.predicate(node): + return False + else: + return False + + # Check single user requirement + if self.require_single_user and not has_single_user(node): + return False + + # Check nargs and kwargs + if not self._check_nargs(node): + return False + if not self._check_kwargs(node): + return False + + return True + + def _check_nargs(self, node: Node) -> bool: + """Check if node has the required number of args.""" + n = len(node.args) + if isinstance(self.nargs, tuple): + min_args, max_args = self.nargs + # Must be in range AND enough to access arg_index + return min_args <= n <= max_args and n > self.arg_index + else: + # Must have at least nargs, AND enough to access arg_index + return n >= self.nargs and n > self.arg_index + + def _check_kwargs(self, node: Node) -> bool: + """Check that node's kwargs are all declared in self.kwargs (no unhandled kwargs).""" + return set(node.kwargs.keys()).issubset(self.kwargs) + + +def walk_back( # noqa: C901 + node: Node, + steps: List[OpStep], + debug: bool = False, +) -> Optional[Tuple[Node, List[WalkBackEntry]]]: + """ + Walk backwards through a chain of ops, matching against a pattern. + + Starting from *node*, try to match each step against the current node. + At every matched step the walk advances to ``cur.args[step.arg_index]``. + Optional steps are silently skipped when they don't match. Repeat steps + match 0 or more times. + + Args: + node: Starting node + steps: List of OpStep to match in order + + Returns: + ``(base_node, entries)`` if the full chain matches, else ``None``. + *base_node* is the input to the first (deepest) op in the chain. + *entries* is a list with one entry per OpStep: + - Node: matched node (for regular steps) + - None: optional step that didn't match + - List[Node]: repeat step (0 or more matches) + + Examples: + # Match: rsqrt(add(mean(pow(x, 2)), eps)) + result = walk_back(rsqrt_node, [ + OpStep(op=torch.ops.aten.rsqrt.default), + OpStep(op=torch.ops.aten.add.Tensor), + OpStep(op=torch.ops.aten.mean.dim), + OpStep(op=torch.ops.aten.pow.Tensor_Scalar), + ]) + if result: + base, entries = result + rsqrt, add, mean, pow = entries # Each is a Node + + # Match chain of select ops (like tensor[0][0]) + result = walk_back(node, [ + OpStep(op=torch.ops.aten.select.int, repeat=True), + ]) + if result: + base, entries = result + select_nodes = entries[0] # List[Node], may be empty + + # Skip optional _to_copy, then match rsqrt + result = walk_back(node, [ + OpStep(op=torch.ops.aten._to_copy.default, optional=True), + OpStep(op=torch.ops.aten.rsqrt.default), + ]) + if result: + base, entries = result + to_copy, rsqrt = entries # to_copy may be None + """ + entries: List[WalkBackEntry] = [] + cur = node + + for i, step in enumerate(steps): + if not isinstance(cur, Node): + if debug: + print( + f" [walk_back] step {i}: cur is not a Node ({type(cur).__name__})" + ) + return None + + if step.repeat: + # Match 0 or more times, return as list + matched_nodes: List[Node] = [] + while isinstance(cur, Node) and step.matches(cur): + matched_nodes.append(cur) + cur = cur.args[step.arg_index] + entries.append(matched_nodes) + if debug: + print( + f" [walk_back] step {i} (repeat): matched {len(matched_nodes)} nodes" + ) + # repeat always succeeds (matches 0 or more) + continue + + if step.matches(cur): + entries.append(cur) + if debug: + print(f" [walk_back] step {i}: matched {cur.name}") + cur = cur.args[step.arg_index] + elif step.optional: + entries.append(None) + if debug: + print(f" [walk_back] step {i} (optional): skipped, cur={cur.name}") + continue + else: + if debug: + print( + f" [walk_back] step {i}: FAILED at cur={cur.name}, target={cur.target}, step.op={step.op}" + ) + return None + + if not isinstance(cur, Node): + return None + + return cur, entries + + +@dataclass +class PatternMatch: + """ + Base class for pattern match results. + + Subclasses should: + 1. Add fields for captured values (input nodes, constants, etc.) + 2. Implement maybe_create() classmethod for pattern matching + 3. Optionally implement emit_* methods for specific backends + + Example: + @dataclass + class RMSNormMatch(PatternMatch): + input_node: Node + weight_node: Node + eps: float + + @classmethod + def maybe_create(cls, head: Node) -> Optional["RMSNormMatch"]: + # Pattern matching logic... + if not matched: + return None + return cls( + head=head, + body=body_nodes, + input_node=input_node, + weight_node=weight_node, + eps=eps_value, + ) + """ + + head: Node # The output node of the matched pattern + body: List[Node] = field(default_factory=list) # Intermediate nodes + + @classmethod + def maybe_create(cls, head: Node, **context) -> Optional["PatternMatch"]: + """ + Try to match the pattern starting from head node. + + Override in subclasses to implement pattern-specific matching. + + Args: + head: Candidate head node to match from + **context: Additional context (e.g., ExportedProgram for patterns.py) + + Returns: + PatternMatch instance with captured values, or None if no match + """ + return None + + def remove_body_nodes(self, graph: Graph) -> None: + """ + Remove body nodes from the graph (in reverse order for safety). + + Call after replacing head with fused op. + """ + for node in reversed(self.body): + if has_no_users(node): + graph.erase_node(node) + + def all_nodes(self) -> List[Node]: + """Return all nodes in the pattern (head + body).""" + return [self.head] + self.body diff --git a/backends/mlx/patterns.py b/backends/mlx/patterns.py new file mode 100644 index 00000000000..c8bef1f91ca --- /dev/null +++ b/backends/mlx/patterns.py @@ -0,0 +1,14 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +MLX Pattern Handlers - pattern-based op lowering for fused operations. + +This module contains pattern handlers that match multi-node subgraphs and lower +them to optimized MLX operations. +""" diff --git a/backends/mlx/preprocess.py b/backends/mlx/preprocess.py new file mode 100644 index 00000000000..315835f1689 --- /dev/null +++ b/backends/mlx/preprocess.py @@ -0,0 +1,168 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +MLX Backend preprocessing - converts EdgeIR to MLX delegate payload. + +This module implements the BackendDetails.preprocess() method which: +1. Takes an ExportedProgram (edge dialect) +2. Builds an MLXGraph using MLXProgramBuilder +3. Serializes to FlatBuffer (no embedded constants - those come via named_data_map) +4. Returns PreprocessResult with the binary and data_store_output for constants +""" + +from __future__ import annotations + +import hashlib +from typing import ClassVar, final, List + +from executorch.backends.mlx._logging import logger +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.serialization.mlx_graph_serialize import ( + HEADER_LENGTH, + MAGIC, + serialize_mlx_graph, +) +from executorch.exir.backend.backend_details import ( + BackendDetails, + CompileSpec, + PreprocessResult, +) +from torch.export.exported_program import ExportedProgram + + +@final +class MLXBackend(BackendDetails): + """ + ExecuTorch backend for MLX (Apple Silicon GPU compute framework). + + This backend compiles EdgeIR programs to a custom bytecode format + that can be executed by the MLX C++ runtime. + + Constants (weights) are stored in ExecuTorch's named_data_map rather than + embedded in the delegate payload. This allows ExecuTorch to own the constant + data and provide it to the backend at runtime. + """ + + MAGIC_IX: ClassVar[slice] = slice(4, 8) + DATA_SEGMENT_OFFSET_IX: ClassVar[slice] = slice(8, 16) + DATA_SEGMENT_SIZE_IX: ClassVar[slice] = slice(16, 24) + + EXPECTED_MAGIC: ClassVar[bytes] = MAGIC + EXPECTED_LENGTH: ClassVar[int] = HEADER_LENGTH + + @staticmethod + def preprocess( + edge_program: ExportedProgram, + compile_specs: List[CompileSpec], + ) -> PreprocessResult: + """ + Convert an ExportedProgram to MLX delegate payload. + + Args: + edge_program: The ExportedProgram in edge dialect to compile. + compile_specs: List of compilation options. + + Returns: + PreprocessResult containing the serialized MLX program and + data_store_output with constant tensor data. + """ + logger.debug("MLXBackend.preprocess() called") + logger.debug(f"Edge program:\n{edge_program}") + + # Build MLXGraph from ExportedProgram + # Use a deterministic 4-hex prefix derived from the edge program to + # namespace named_data keys, avoiding collisions in multi-method + # programs where different methods may have lifted tensor constants + # with the same auto-generated name. + prefix = hashlib.sha256(str(edge_program).encode()).hexdigest()[:4] + builder = MLXProgramBuilder(edge_program, named_data_key_prefix=prefix) + mlx_graph = builder.build() + + # Get constant data as NamedDataStore (ET will own this data) + named_data_store = builder.get_named_data_store() + + logger.debug(f" named_data_store entries: {len(named_data_store.pte_data)}") + _log_mlx_graph(mlx_graph) + + # Serialize to bytes (no constant data embedded) + serialized = serialize_mlx_graph(mlx_graph) + + logger.debug(f"MLXBackend.preprocess() complete: {len(serialized)} bytes") + + return PreprocessResult( + processed_bytes=serialized, + data_store_output=named_data_store.get_named_data_store_output(), + ) + + +def _format_tensor_meta(meta) -> str: + """Format a TensorMeta for display.""" + shape_parts = [] + for dim in meta.shape: + if dim.value == -1: + # Dynamic dim + if dim.max_value == -1: + shape_parts.append(f"dyn(min={dim.min_value})") + else: + shape_parts.append(f"dyn({dim.min_value}..{dim.max_value})") + else: + shape_parts.append(str(dim.value)) + shape_str = f"[{', '.join(shape_parts)}]" + dtype_str = f"dtype={meta.scalar_type}" if meta.scalar_type is not None else "" + dim_order_str = f"dim_order={meta.dim_order}" if meta.dim_order is not None else "" + parts = [shape_str] + if dtype_str: + parts.append(dtype_str) + if dim_order_str: + parts.append(dim_order_str) + return ", ".join(parts) + + +def _log_mlx_graph(mlx_graph) -> None: # noqa: C901 + """Log MLXGraph contents at DEBUG level for debugging.""" + logger.debug("MLXGraph:") + logger.debug(f" version: {mlx_graph.version}") + logger.debug(f" num_constant_tensors: {mlx_graph.num_constant_tensors}") + logger.debug(f" num_input_tensors: {mlx_graph.num_input_tensors}") + logger.debug(f" num_output_tensors: {mlx_graph.num_output_tensors}") + logger.debug( + f" num_mutable_buffer_tensors: {mlx_graph.num_mutable_buffer_tensors}" + ) + logger.debug(f" num_temp_tensors: {mlx_graph.num_temp_tensors}") + logger.debug(f" num_values: {mlx_graph.num_values}") + logger.debug(f" instruction_chains ({len(mlx_graph.instruction_chains)}):") + for c, chain in enumerate(mlx_graph.instruction_chains): + label = "" + if c == mlx_graph.main_chain_idx: + label = " (main)" + elif c == mlx_graph.init_chain_idx: + label = " (init)" + logger.debug(f" chain {c}{label} ({len(chain.instructions)} instructions):") + for i, instr in enumerate(chain.instructions): + logger.debug(f" [{i}]: {type(instr.op).__name__}") + if mlx_graph.input_map: + logger.debug(f" input_map ({len(mlx_graph.input_map)}):") + for i, slot in enumerate(mlx_graph.input_map): + logger.debug(f" [{i}]: {slot}") + if mlx_graph.output_map: + logger.debug(f" output_map ({len(mlx_graph.output_map)}):") + for i, slot in enumerate(mlx_graph.output_map): + logger.debug(f" [{i}]: {slot}") + if mlx_graph.mutable_buffer_map: + logger.debug(f" mutable_buffer_map ({len(mlx_graph.mutable_buffer_map)}):") + for i, slot in enumerate(mlx_graph.mutable_buffer_map): + logger.debug(f" [{i}]: {slot}") + if mlx_graph.named_slots: + logger.debug(f" named_slots ({len(mlx_graph.named_slots)}):") + for ns in mlx_graph.named_slots: + logger.debug(f" {ns.name}: {ns.slot}") + if mlx_graph.tensor_meta: + logger.debug(f" tensor_meta ({len(mlx_graph.tensor_meta)}):") + for i, meta in enumerate(mlx_graph.tensor_meta): + logger.debug(f" t{i}: {_format_tensor_meta(meta)}") diff --git a/backends/mlx/pte_inspector.py b/backends/mlx/pte_inspector.py new file mode 100644 index 00000000000..d9e533b0b1e --- /dev/null +++ b/backends/mlx/pte_inspector.py @@ -0,0 +1,897 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +PTE Inspector - Extract and dump data from ExecuTorch .pte files. + +This utility can: +1. Parse the PTE file structure (header, flatbuffer, segments) +2. Extract delegate payloads (e.g., MLX backend data) +3. Convert FlatBuffer data to JSON for inspection + +Usage: + python pte_inspector.py mlx_mlp.pte + python pte_inspector.py mlx_mlp.pte --output output.json + python pte_inspector.py mlx_mlp.pte --extract-delegate mlx --output mlx_payload.bin +""" + +from __future__ import annotations + +import argparse +import json +import sys +import traceback +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple + +from executorch.backends.mlx._generated_inspector import OP_NODE_FIELDS +from executorch.backends.mlx.serialization._generated_serializers import ( + MLX_OP_TYPE_NAMES, +) +from executorch.exir._serialize._program import ( + _ExtendedHeader, + _extract_delegate_payload as extract_delegate_payload, +) + +MLX_MAGIC = b"MLX0" +MLX_HEADER_LENGTH = 24 + +_SLOT_TYPE_NAMES = {0: "Tensor", 1: "Int", 2: "Float", 3: "Bool"} + + +@dataclass +class MLXHeader: + + magic: bytes + data_segment_offset: int + data_segment_size: int + + @classmethod + def from_bytes(cls, data: bytes) -> "MLXHeader": + if len(data) < MLX_HEADER_LENGTH: + raise ValueError( + f"Not enough data for MLX header: {len(data)} < {MLX_HEADER_LENGTH}" + ) + + # Layout: [4 bytes padding][4 bytes magic][8 bytes offset][8 bytes size] + magic = data[4:8] + data_segment_offset = int.from_bytes(data[8:16], byteorder="little") + data_segment_size = int.from_bytes(data[16:24], byteorder="little") + + return cls( + magic=magic, + data_segment_offset=data_segment_offset, + data_segment_size=data_segment_size, + ) + + def is_valid(self) -> bool: + return self.magic == MLX_MAGIC + + def to_dict(self) -> Dict[str, Any]: + return { + "magic": self.magic.decode("utf-8", errors="replace"), + "data_segment_offset": self.data_segment_offset, + "data_segment_size": self.data_segment_size, + } + + +@dataclass +class MLXPayload: + """Parsed MLX delegate payload: header + flatbuffer bytes.""" + + header: MLXHeader + fb_data: bytes + raw: bytes + + +def _load_mlx_payload(pte_data: bytes, delegate_index: int = 0) -> MLXPayload: + """Extract MLX delegate payload from PTE data and parse its header. + + Raises ``ValueError`` if the delegate cannot be found or the MLX header is + invalid. + """ + payload = extract_delegate_payload(pte_data, "mlx", delegate_index=delegate_index) + if payload is None: + raise ValueError(f"Could not extract MLX delegate {delegate_index}") + + header = MLXHeader.from_bytes(payload) + if not header.is_valid(): + raise ValueError(f"Invalid MLX magic: {header.magic!r}") + + fb_data = payload[MLX_HEADER_LENGTH : header.data_segment_offset] + return MLXPayload(header=header, fb_data=fb_data, raw=payload) + + +def _find_mlx_delegates(pte_data: bytes) -> List[Tuple[int, Dict]]: + """Return list of ``(plan_index, delegate_dict)`` for every MLX delegate.""" + from executorch.exir._serialize._flatbuffer import _program_flatbuffer_to_json + + program_data = json.loads(_program_flatbuffer_to_json(pte_data)) + delegates: List[Tuple[int, Dict]] = [] + for plan in program_data.get("execution_plan", []): + for i, delegate in enumerate(plan.get("delegates", [])): + if "mlx" in delegate.get("id", "").lower(): + delegates.append((i, delegate)) + return delegates + + +def _get_fb_graph(fb_data: bytes): + """Return the FlatBuffer MLXGraph root object.""" + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + MLXGraph as FBMLXGraph, + ) + + return FBMLXGraph.MLXGraph.GetRootAs(fb_data, 0) + + +def _parse_graph_info(graph) -> Dict[str, Any]: + """Extract top-level graph scalars (tensor counts, chain counts, etc.).""" + return { + "version": graph.Version().decode("utf-8") if graph.Version() else None, + "num_constant_tensors": graph.NumConstantTensors(), + "num_input_tensors": graph.NumInputTensors(), + "num_output_tensors": graph.NumOutputTensors(), + "num_mutable_buffer_tensors": graph.NumMutableBufferTensors(), + "num_temp_tensors": graph.NumTempTensors(), + "num_values": graph.NumValues(), + "num_instruction_chains": graph.InstructionChainsLength(), + "main_chain_idx": graph.MainChainIdx(), + "init_chain_idx": graph.InitChainIdx(), + "input_map_length": graph.InputMapLength(), + "output_map_length": graph.OutputMapLength(), + "mutable_buffer_map_length": graph.MutableBufferMapLength(), + "named_slots_length": graph.NamedSlotsLength(), + "tensor_meta_length": graph.TensorMetaLength(), + } + + +def _parse_instructions(graph) -> List[Dict[str, Any]]: + """Parse all instruction chains and their op nodes.""" + chains: List[Dict[str, Any]] = [] + for c in range(graph.InstructionChainsLength()): + chain = graph.InstructionChains(c) + chain_info: Dict[str, Any] = {"chain_index": c, "instructions": []} + if chain: + for i in range(chain.InstructionsLength()): + try: + instr = chain.Instructions(i) + if instr: + op_type = instr.OpType() + op_name = MLX_OP_TYPE_NAMES.get(op_type, f"Unknown({op_type})") + instr_info: Dict[str, Any] = { + "instr_idx": i, + "op_type": op_type, + "op_name": op_name, + } + op_data = _parse_op_node(instr, op_name) + if op_data: + instr_info.update(op_data) + chain_info["instructions"].append(instr_info) + except Exception as e: + chain_info["instructions"].append( + {"instr_idx": i, "error": f"parse_failed: {e}"} + ) + chains.append(chain_info) + return chains + + +def _parse_named_slots(graph) -> List[Dict[str, Any]]: + slots: List[Dict[str, Any]] = [] + for i in range(graph.NamedSlotsLength()): + try: + ns = graph.NamedSlots(i) + if ns: + info: Dict[str, Any] = { + "name": ns.Name().decode("utf-8") if ns.Name() else None, + } + slot = ns.Slot() + if slot: + info["slot_idx"] = slot.Idx() + info["slot_type"] = slot.SlotType() + slots.append(info) + except Exception as e: + slots.append({"instr_idx": i, "error": f"parse_failed: {e}"}) + return slots + + +def _parse_tensor_meta(graph) -> List[Dict[str, Any]]: + metas: List[Dict[str, Any]] = [] + for i in range(graph.TensorMetaLength()): + try: + tm = graph.TensorMeta(i) + if tm: + shape: List[Any] = [] + for j in range(tm.ShapeLength()): + sd = tm.Shape(j) + if sd.Value() == -1: + lo = sd.MinValue() + hi = sd.MaxValue() + if hi == -1: + shape.append(f"dyn(min={lo})") + else: + shape.append(f"dyn({lo}..{hi})") + else: + shape.append(sd.Value()) + meta: Dict[str, Any] = { + "index": i, + "dtype": tm.Dtype(), + "shape": shape, + } + if tm.StridesLength() > 0: + meta["strides"] = [tm.Strides(j) for j in range(tm.StridesLength())] + metas.append(meta) + except Exception as e: + metas.append({"instr_idx": i, "error": f"parse_failed: {e}"}) + return metas + + +def _parse_io_maps( + graph, +) -> Tuple[List[Dict], List[Dict], List[Dict]]: + """Return (input_map, output_map, mutable_buffer_map) as slot-variant dicts.""" + + def _extract( + length_fn: Callable[[], int], getter_fn: Callable[[int], Any] + ) -> List[Dict]: + result = [] + for i in range(length_fn()): + try: + sv = getter_fn(i) + if sv: + result.append({"idx": sv.Idx(), "slot_type": sv.SlotType()}) + except Exception as e: + result.append({"instr_idx": i, "error": f"parse_failed: {e}"}) + return result + + return ( + _extract(graph.InputMapLength, graph.InputMap), + _extract(graph.OutputMapLength, graph.OutputMap), + _extract(graph.MutableBufferMapLength, graph.MutableBufferMap), + ) + + +def parse_mlx_flatbuffer(fb_data: bytes) -> Dict[str, Any]: + """Parse MLX FlatBuffer data into a dict using the generated FlatBuffer bindings.""" + result: Dict[str, Any] = {} + try: + graph = _get_fb_graph(fb_data) + + result = _parse_graph_info(graph) + result["instruction_chains"] = _parse_instructions(graph) + result["named_slots"] = _parse_named_slots(graph) + result["tensor_meta"] = _parse_tensor_meta(graph) + + input_map, output_map, mutable_buffer_map = _parse_io_maps(graph) + result["input_map"] = input_map + result["output_map"] = output_map + result["mutable_buffer_map"] = mutable_buffer_map + + try: + cs = graph.ConstantSegment() + if cs: + result["constant_segment"] = { + "offset": cs.Offset(), + "size": cs.Size(), + } + except Exception as e: + result["constant_segment_error"] = f"parse_failed: {e}" + + except ImportError as e: + result["error"] = f"FlatBuffer bindings not available: {e}" + result["_fallback"] = "Using basic header parsing only" + except Exception as e: + result["error"] = f"FlatBuffer parse error: {e}" + result["traceback"] = traceback.format_exc() + + return result + + +def _parse_op_node(instr, op_name: str) -> Optional[Dict[str, Any]]: + """Parse the specific op node fields from an instruction. + + Uses the generated field mappings in ``OP_NODE_FIELDS`` to extract + op-specific fields without manually maintaining per-op logic. + """ + try: + op = instr.Op() + if op is None: + return None + + if op_name not in OP_NODE_FIELDS: + return {"error": f"Unknown op type: {op_name}"} + + module = __import__( + f"executorch.backends.mlx.serialization._generated.mlx_delegate.{op_name}", + fromlist=[op_name], + ) + node_class = getattr(module, op_name) + node = node_class() + node.Init(op.Bytes, op.Pos) + + result: Dict[str, Any] = {} + for field_name, accessor_name, kind in OP_NODE_FIELDS[op_name]: + try: + result[field_name] = _extract_field(node, accessor_name, kind) + except Exception as e: + result[field_name] = {"error": str(e)} + + result = {k: v for k, v in result.items() if v is not None} + return result if result else None + + except Exception as e: + return {"parse_error": str(e), "traceback": traceback.format_exc()} + + +def _extract_vid_or_tid(obj) -> Optional[Dict[str, Any]]: + """Extract a VidOrTid FlatBuffer object into a dict. + + VidOrTid has: .IsVid() -> bool, .Vid() -> Vid|None, .Tid() -> Tid|None. + Same pattern as IntOrVid but references value/tensor slots instead of + holding a literal. + """ + if obj is None: + return None + if obj.IsVid(): + v = obj.Vid() + return {"vid": v.Idx()} if v else None + t = obj.Tid() + return {"tid": t.Idx()} if t else None + + +def _extract_field(node, accessor_name: str, kind: str) -> Any: # noqa: C901 + """Extract a single field from a FlatBuffer op node based on its *kind*.""" + if kind == "tid": + t = getattr(node, accessor_name)() + return {"tid": t.Idx()} if t else None + + if kind == "vid": + v = getattr(node, accessor_name)() + return {"vid": v.Idx()} if v else None + + if kind == "vid_or_tid": + return _extract_vid_or_tid(getattr(node, accessor_name)()) + + if kind == "int_or_vid_or_tid": + ivt = getattr(node, accessor_name)() + if ivt is None: + return None + k = ivt.Kind() + if k == 0: # literal int + return {"literal": ivt.Literal()} + elif k == 1: # Vid + v = ivt.Vid() + return {"vid": v.Idx()} if v else None + elif k == 2: # Tid + t = ivt.Tid() + return {"tid": t.Idx()} if t else None + return {"kind": k} + + if kind == "int_or_vid": + iov = getattr(node, accessor_name)() + if iov is None: + return None + if iov.IsVid(): + v = iov.Vid() + return {"vid": v.Idx()} if v else None + return {"literal": iov.Literal()} + + if kind == "float_or_vid": + fov = getattr(node, accessor_name)() + if fov is None: + return None + if fov.IsVid(): + v = fov.Vid() + return {"vid": v.Idx()} if v else None + return {"literal": fov.Literal()} + + if kind == "int_list": + length = getattr(node, f"{accessor_name}Length")() + getter = getattr(node, accessor_name) + return [getter(i) for i in range(length)] + + if kind == "tid_list": + length = getattr(node, f"{accessor_name}Length")() + getter = getattr(node, accessor_name) + items = [] + for i in range(length): + s = getter(i) + items.append(f"tid {s.Idx()}" if s else None) + return items + + if kind == "int_or_vid_list": + length = getattr(node, f"{accessor_name}Length")() + getter = getattr(node, accessor_name) + items = [] + for i in range(length): + iov = getter(i) + if iov is None: + items.append(None) + elif iov.IsVid(): + v = iov.Vid() + items.append({"vid": v.Idx()} if v else None) + else: + items.append({"literal": iov.Literal()}) + return items + + if kind == "string": + val = getattr(node, accessor_name)() + return val.decode("utf-8") if val else None + + # scalar (default) + return getattr(node, accessor_name)() + + +def parse_mlx_payload(payload: bytes) -> Dict[str, Any]: + """Parse raw MLX delegate payload bytes into a dict. + + This is the public entry point for callers that already have the raw + delegate payload (e.g. from ``extract_delegate_payload``). + """ + header = MLXHeader.from_bytes(payload) + + if not header.is_valid(): + return { + "error": f"Invalid MLX magic: {header.magic!r}", + "header": header.to_dict(), + } + + fb_data = payload[MLX_HEADER_LENGTH : header.data_segment_offset] + result: Dict[str, Any] = { + "header": header.to_dict(), + "flatbuffer_size": len(fb_data), + "graph": parse_mlx_flatbuffer(fb_data), + } + + if header.data_segment_size > 0: + result["constant_data_size"] = header.data_segment_size + + return result + + +def parse_executorch_program(pte_data: bytes) -> Dict[str, Any]: # noqa: C901 + result: Dict[str, Any] = {} + + if len(pte_data) < 8: + raise ValueError("File too small to be a valid PTE file") + + fb_magic = pte_data[4:8] + result["flatbuffer_magic"] = fb_magic.decode("utf-8", errors="replace") + + extended_header_offset = 8 + if len(pte_data) > extended_header_offset + 32: + try: + header = _ExtendedHeader.from_bytes(pte_data[extended_header_offset:]) + if header.is_valid(): + result["extended_header"] = { + "magic": header.magic.decode("utf-8", errors="replace"), + "length": header.length, + "program_size": header.program_size, + "segment_base_offset": header.segment_base_offset, + "segment_data_size": header.segment_data_size, + } + fb_start = extended_header_offset + header.length + result["flatbuffer_offset"] = fb_start + result["flatbuffer_size"] = header.program_size + result["segment_offset"] = header.segment_base_offset + result["segment_size"] = header.segment_data_size + except Exception as e: + result["header_parse_error"] = str(e) + + try: + from executorch.exir._serialize._flatbuffer import _program_flatbuffer_to_json + + program_data = json.loads(_program_flatbuffer_to_json(pte_data)) + result["program"] = program_data + + if "execution_plan" in program_data: + delegates = [] + for plan in program_data["execution_plan"]: + if "delegates" in plan: + for delegate in plan["delegates"]: + delegate_info: Dict[str, Any] = { + "id": delegate.get("id"), + "processed_type": delegate.get("processed", {}).get( + "location" + ), + } + processed = delegate.get("processed", {}) + if "data" in processed: + delegate_info["inline_data_size"] = len(processed["data"]) + if "location" in processed: + delegate_info["location"] = processed["location"] + delegates.append(delegate_info) + result["delegates"] = delegates + + except ImportError: + result["program_parse_error"] = "ExecuTorch FlatBuffer parsing not available" + except Exception as e: + result["program_parse_error"] = str(e) + + return result + + +def _slot_type_display(slot_type: int, style: str = "full") -> str: + """Return display string for a slot type. + + *style* controls the format: + - ``"full"``: "Tensor", "Int", etc. (for summary tables) + - ``"short"``: "tid", "vid" (for instruction I/O lists) + """ + if style == "short": + return "tid" if slot_type == 0 else "vid" + return _SLOT_TYPE_NAMES.get(slot_type, "Unknown") + + +def _print_slot_map(label: str, slots: List[Dict]) -> None: + """Print a list of slot-variant dicts with their type names.""" + if not slots: + return + print(f"\n {label}:") + for i, slot in enumerate(slots): + type_name = _slot_type_display(slot.get("slot_type", 0)) + print(f" [{i}]: idx={slot.get('idx')}, type={type_name}") + + +def show_mlx_summary(pte_data: bytes) -> None: # noqa: C901 + try: + mlx_delegates = _find_mlx_delegates(pte_data) + if not mlx_delegates: + print("No MLX delegates found in this PTE file.") + return + + print(f"\n{'='*70}") + print("MLX DELEGATE SUMMARY") + print(f"{'='*70}") + print(f"File contains {len(mlx_delegates)} MLX delegate(s)\n") + + for idx, (delegate_idx, delegate) in enumerate(mlx_delegates): + print(f"\n--- Delegate {idx} (plan index {delegate_idx}) ---") + print(f"ID: {delegate.get('id', 'unknown')}") + + try: + mlx = _load_mlx_payload(pte_data, delegate_index=idx) + except ValueError as e: + print(f" {e}") + continue + + graph_info = parse_mlx_flatbuffer(mlx.fb_data) + + print("\nMLX Graph Info:") + for key in ( + "num_constant_tensors", + "num_input_tensors", + "num_output_tensors", + "num_mutable_buffer_tensors", + "num_temp_tensors", + "num_values", + "num_instruction_chains", + ): + label = f" {key + ':':<29}" + print(f"{label}{graph_info.get(key, '?')}") + + main_idx = graph_info.get("main_chain_idx", 0) + chains = graph_info.get("instruction_chains", []) + main_num = "?" + if main_idx < len(chains): + main_num = len(chains[main_idx].get("instructions", [])) + print(f" {'main_chain_idx:':<29}{main_idx} ({main_num} instructions)") + print(f" {'init_chain_idx:':<29}{graph_info.get('init_chain_idx', '?')}") + + print("\nI/O Maps:") + print( + f" {'input_map length:':<29}{graph_info.get('input_map_length', '?')}" + ) + print( + f" {'output_map length:':<29}{graph_info.get('output_map_length', '?')}" + ) + print( + f" {'mutable_buffer_map length:':<29}{graph_info.get('mutable_buffer_map_length', '?')}" + ) + + input_len = graph_info.get("input_map_length", 0) + mutable_len = graph_info.get("mutable_buffer_map_length", 0) + if input_len and mutable_len is not None: + print( + f" => regular inputs expected: {input_len - mutable_len} (input_map - mutable_buffer_map)" + ) + + _print_slot_map("Input Map Details", graph_info.get("input_map", [])) + if graph_info.get("mutable_buffer_map"): + _print_slot_map( + "Mutable Buffer Map Details", + graph_info["mutable_buffer_map"], + ) + _print_slot_map("Output Map Details", graph_info.get("output_map", [])) + + if mlx.header.data_segment_size > 0: + print(f"\n Constant data size: {mlx.header.data_segment_size:,} bytes") + + print(f"\n{'='*70}\n") + + except Exception as e: + print(f"Error showing MLX summary: {e}", file=sys.stderr) + traceback.print_exc() + + +def show_mlx_instructions(pte_data: bytes) -> None: # noqa: C901 + try: + mlx_delegates = _find_mlx_delegates(pte_data) + if not mlx_delegates: + print("No MLX delegates found in this PTE file.", file=sys.stderr) + sys.exit(1) + + if len(mlx_delegates) > 1: + print( + f"Found {len(mlx_delegates)} MLX delegate(s) in PTE file\n", + file=sys.stderr, + ) + + for idx, (delegate_idx, _delegate) in enumerate(mlx_delegates): + try: + mlx = _load_mlx_payload(pte_data, delegate_index=idx) + except ValueError as e: + print(f"\nError: {e}", file=sys.stderr) + continue + + graph = parse_mlx_flatbuffer(mlx.fb_data) + if "error" in graph: + print( + f"\nError parsing delegate {idx}: {graph['error']}", + file=sys.stderr, + ) + continue + + # Print delegate header + if len(mlx_delegates) > 1: + print("\n" + "=" * 70) + print(f"MLX DELEGATE {idx} (plan index {delegate_idx})") + print("=" * 70) + else: + print("\n" + "=" * 70) + print("MLX Graph Summary") + print("=" * 70) + + # Basic info + print(f"Version: {graph.get('version', 'unknown')}") + print(f"Constant tensors: {graph.get('num_constant_tensors', 0)}") + print(f"Input tensors: {graph.get('num_input_tensors', 0)}") + print(f"Output tensors: {graph.get('num_output_tensors', 0)}") + print( + f"Mutable buffer tensors: {graph.get('num_mutable_buffer_tensors', 0)}" + ) + print(f"Temp tensors: {graph.get('num_temp_tensors', 0)}") + print(f"Values: {graph.get('num_values', 0)}") + num_chains = graph.get("num_instruction_chains", 0) + main_idx = graph.get("main_chain_idx", 0) + init_idx = graph.get("init_chain_idx", -1) + print(f"Instruction chains: {num_chains}") + print(f"Main chain idx: {main_idx}") + if init_idx >= 0: + print(f"Init chain idx: {init_idx}") + + constant_seg = graph.get("constant_segment", {}) + if constant_seg: + print(f"Constant data: {constant_seg.get('size', 0):,} bytes") + + # Instruction chains + for chain_info in graph.get("instruction_chains", []): + chain_idx = chain_info.get("chain_index", "?") + label = "" + if chain_idx == main_idx: + label = " (main)" + elif chain_idx == init_idx: + label = " (init)" + instructions = chain_info.get("instructions", []) + print(f"\nChain {chain_idx}{label} ({len(instructions)} instructions):") + for instr in instructions: + op_name = instr.get("op_name", f"op_{instr.get('op_type', '?')}") + print(f" [{instr.get('instr_idx', '?')}] {op_name}") + + for key, value in instr.items(): + if key in ("instr_idx", "op_type", "op_name"): + continue + if isinstance(value, dict): + if "tid" in value: + print(f" {key}: tid {value['tid']}") + elif "vid" in value: + print(f" {key}: vid {value['vid']}") + else: + print(f" {key}: {value}") + elif value is not None: + print(f" {key}: {value}") + + # Named slots + named_slots = graph.get("named_slots", []) + if named_slots: + print("\nNamed Slots:") + for slot in named_slots: + slot_type = _slot_type_display( + slot.get("slot_type", 0), style="short" + ) + print( + f" [{slot.get('slot_idx', '?')}] {slot.get('name', '?')} ({slot_type})" + ) + + # Input/Output maps + input_map = graph.get("input_map", []) + output_map = graph.get("output_map", []) + + if input_map: + print("\nInputs:") + for inp in input_map: + slot_type = _slot_type_display( + inp.get("slot_type", 0), style="short" + ) + print(f" {slot_type} {inp.get('idx', '?')}") + + if output_map: + print("\nOutputs:") + for out in output_map: + slot_type = _slot_type_display( + out.get("slot_type", 0), style="short" + ) + print(f" {slot_type} {out.get('idx', '?')}") + + print("=" * 70 + "\n") + + except Exception as e: + print(f"Error showing MLX instructions: {e}", file=sys.stderr) + traceback.print_exc() + sys.exit(1) + + +def main(): # noqa: C901 + parser = argparse.ArgumentParser( + description="Inspect ExecuTorch .pte files and extract data", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +MLX-Specific Options: + --mlx-summary Show high-level summary (tensor counts, I/O maps) + --mlx-instructions Show detailed instruction list with operation parameters + (use this to verify quantization, inspect ops, etc.) + +Examples: + # Basic PTE file inspection + python -m executorch.backends.mlx.pte_inspector model.pte + + # Show high-level MLX delegate summary + python -m executorch.backends.mlx.pte_inspector model.pte --mlx-summary + + # Show detailed MLX instructions (verify quantization, inspect operations) + python -m executorch.backends.mlx.pte_inspector model.pte --mlx-instructions + + # Extract raw delegate payload to binary file + python -m executorch.backends.mlx.pte_inspector model.pte \\ + --extract-delegate MLXBackend -o delegate.bin + """, + ) + parser.add_argument("pte_file", type=Path, help="Path to the .pte file") + parser.add_argument( + "--output", "-o", type=Path, help="Output file (default: stdout)" + ) + parser.add_argument( + "--extract-delegate", + type=str, + metavar="ID", + help="Extract delegate payload by ID (e.g., 'mlx')", + ) + parser.add_argument( + "--delegate-index", + type=int, + default=None, + metavar="N", + help="Index of delegate to extract (0-based). If not specified, extracts first matching delegate.", + ) + parser.add_argument( + "--parse-mlx", + action="store_true", + help="Parse extracted MLX payload (use with --extract-delegate mlx)", + ) + parser.add_argument( + "--mlx-summary", + action="store_true", + help="Show summary of all MLX delegates (input/output/mutable buffer counts)", + ) + parser.add_argument( + "--mlx-instructions", + action="store_true", + help="Show detailed MLX instruction list with operands and quantization details", + ) + parser.add_argument( + "--format", + choices=["json", "summary"], + default="json", + help="Output format (default: json)", + ) + parser.add_argument( + "--indent", + type=int, + default=2, + help="JSON indentation (default: 2)", + ) + + args = parser.parse_args() + + if not args.pte_file.exists(): + print(f"Error: File not found: {args.pte_file}", file=sys.stderr) + sys.exit(1) + + pte_data = args.pte_file.read_bytes() + print(f"Loaded {len(pte_data)} bytes from {args.pte_file}", file=sys.stderr) + + if args.mlx_instructions: + show_mlx_instructions(pte_data) + return + + if args.mlx_summary: + show_mlx_summary(pte_data) + return + + if args.extract_delegate: + payload = extract_delegate_payload( + pte_data, args.extract_delegate, delegate_index=args.delegate_index + ) + if payload is None: + print( + f"Error: Delegate '{args.extract_delegate}' not found", file=sys.stderr + ) + sys.exit(1) + + if args.parse_mlx and args.extract_delegate.lower() == "mlx": + result = parse_mlx_payload(payload) + + output = json.dumps(result, indent=args.indent, default=str) + + if args.output: + args.output.write_text(output) + print(f"Wrote parsed MLX data to {args.output}", file=sys.stderr) + else: + print(output) + else: + if args.output: + args.output.write_bytes(payload) + print(f"Wrote {len(payload)} bytes to {args.output}", file=sys.stderr) + else: + print(f"Delegate payload: {len(payload)} bytes", file=sys.stderr) + if len(payload) >= MLX_HEADER_LENGTH: + header = MLXHeader.from_bytes(payload) + print(f" Magic: {header.magic!r}", file=sys.stderr) + print( + f" Data offset: {header.data_segment_offset}", file=sys.stderr + ) + print(f" Data size: {header.data_segment_size}", file=sys.stderr) + return + + result = parse_executorch_program(pte_data) + result["file_size"] = len(pte_data) + result["file_path"] = str(args.pte_file) + + if args.format == "summary": + print(f"PTE File: {args.pte_file}") + print(f" Size: {len(pte_data):,} bytes") + if "extended_header" in result: + h = result["extended_header"] + print(f" Program size: {h['program_size']:,} bytes") + print(f" Segment offset: {h['segment_base_offset']:,}") + print(f" Segment size: {h['segment_data_size']:,} bytes") + if "delegates" in result: + print(f" Delegates: {len(result['delegates'])}") + for d in result["delegates"]: + print(f" - {d.get('id', 'unknown')}") + else: + output = json.dumps(result, indent=args.indent, default=str) + + if args.output: + args.output.write_text(output) + print(f"Wrote JSON to {args.output}", file=sys.stderr) + else: + print(output) + + +if __name__ == "__main__": + main() diff --git a/backends/mlx/runtime/MLXBackend.cpp b/backends/mlx/runtime/MLXBackend.cpp new file mode 100644 index 00000000000..38dff189935 --- /dev/null +++ b/backends/mlx/runtime/MLXBackend.cpp @@ -0,0 +1,419 @@ +// +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +// + +#include "MLXExecutor.h" +#include "MLXInterpreter.h" +#include "MLXLoader.h" + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace mlx { + +// Note: We use fully qualified executorch::aten::Tensor because MLXExecutor.h +// defines Tensor as mlx::core::array in the executorch::backends::mlx +// namespace. +using ETTensor = ::executorch::aten::Tensor; +using ::executorch::runtime::ArrayRef; +using ::executorch::runtime::Backend; +using ::executorch::runtime::BackendExecutionContext; +using ::executorch::runtime::BackendInitContext; +using ::executorch::runtime::CompileSpec; +using ::executorch::runtime::DelegateHandle; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; +using ::executorch::runtime::FreeableBuffer; +using ::executorch::runtime::Result; +using ::executorch::runtime::Span; + +using ::mlx::core::array; +using ::mlx::core::Dtype; +using ::mlx::core::eval; + +namespace { + +array tensor_to_mlx( + const ETTensor& t, + const std::optional& expected_meta = std::nullopt) { + if (!executorch::runtime::tensor_is_contiguous(t)) { + throw std::runtime_error("tensor_to_mlx: input tensor is not contiguous"); + } + + Dtype dtype = + resolve_dtype(static_cast(t.scalar_type())); + + if (expected_meta.has_value()) { + Dtype expected_dtype = resolve_dtype(expected_meta->scalar_type); + if (dtype != expected_dtype) { + throw std::runtime_error( + std::string("tensor_to_mlx: dtype mismatch - input tensor has ") + + ExecutionState::dtype_str(dtype) + " but model expects " + + ExecutionState::dtype_str(expected_dtype)); + } + } + + ::mlx::core::Shape shape; + for (int i = 0; i < t.dim(); ++i) { + auto dim_size = t.size(i); + if (dim_size > std::numeric_limits::max() || + dim_size < std::numeric_limits::min()) { + throw std::runtime_error( + "tensor_to_mlx: dimension " + std::to_string(i) + " size " + + std::to_string(dim_size) + " exceeds int range"); + } + shape.push_back(static_cast(dim_size)); + } + + // SAFETY: MLX reads this data during async_eval() Metal command encoding, + // which completes before the lock is released. The ET tensor must remain + // valid until async_eval returns. + const void* cptr = t.const_data_ptr(); + if (!cptr) { + throw std::runtime_error("tensor_to_mlx: tensor has null data pointer"); + } + void* data_ptr = const_cast(cptr); + auto deleter = [](void*) {}; + return array(data_ptr, shape, dtype, deleter); +} + +// Build the contiguous + dtype conversion pipeline for an output array. +// Returns a lazy array (not yet evaluated) ready for async_eval. +array prepare_output( + const array& arr, + Dtype expected_dtype, + const ::mlx::core::Stream& stream) { + array result = + ::mlx::core::contiguous(arr, /*allow_col_major=*/false, stream); + if (result.dtype() != expected_dtype) { + result = ::mlx::core::astype(result, expected_dtype, stream); + } + return result; +} + +// Wait for a prepared output array and copy its data to an ET tensor. +// The array must have been submitted via async_eval before calling this. +void write_output(array& arr, ETTensor& out) { + arr.wait(); + + // Resize output tensor if shape doesn't match (dynamic shapes) + const auto& mlx_shape = arr.shape(); + auto out_sizes = out.sizes(); + + bool shape_matches = (mlx_shape.size() == static_cast(out.dim())); + if (shape_matches) { + for (size_t i = 0; i < mlx_shape.size(); ++i) { + if (static_cast(mlx_shape[i]) != + static_cast(out_sizes[i])) { + shape_matches = false; + break; + } + } + } + + if (!shape_matches) { + std::vector new_sizes; + new_sizes.reserve(mlx_shape.size()); + for (auto d : mlx_shape) { + new_sizes.push_back(static_cast(d)); + } + auto err = resize_tensor( + out, + ArrayRef( + new_sizes.data(), new_sizes.size())); + if (err != Error::Ok) { + throw std::runtime_error("write_output: failed to resize output tensor"); + } + } + + size_t mlx_nbytes = arr.nbytes(); + size_t out_nbytes = out.nbytes(); + if (mlx_nbytes != out_nbytes) { + throw std::runtime_error( + "write_output: size mismatch - MLX has " + std::to_string(mlx_nbytes) + + " bytes, output has " + std::to_string(out_nbytes) + " bytes"); + } + + const void* src = arr.data(); + if (!src) { + throw std::runtime_error( + "write_output: arr.data() is null after wait()"); + } + std::memcpy(out.mutable_data_ptr(), src, out_nbytes); +} + +} // namespace + +struct MLXHandle { + MLXProgram program; + ConstantData constants; + MutableBufferData mutable_buffers; + ExecutionState state; // Reusable execution state + Interpreter interpreter; + ::mlx::core::Stream stream; // Dedicated GPU stream for this handle + + // Keep the constant buffers alive for zero-copy constants + // Each FreeableBuffer must outlive the MLX arrays that reference it + std::vector constant_buffers; + + MLXHandle() : stream(::mlx::core::new_stream(::mlx::core::Device::gpu)) {} + ~MLXHandle() = default; + + MLXHandle(const MLXHandle&) = delete; + MLXHandle& operator=(const MLXHandle&) = delete; +}; + +// MLX is not thread-safe: its computation graph is global shared state. +// A global mutex serializes graph construction and command submission +// across all handles. GPU execution and output copies can proceed +// without the lock (see execute() for the async pipeline design). +static std::mutex& mlx_global_mutex() { + static std::mutex m; + return m; +} + +class MLXBackend final : public ::executorch::runtime::BackendInterface { + public: + ~MLXBackend() override = default; + + bool is_available() const override { + return ::mlx::core::metal::is_available(); + } + + Result init( + BackendInitContext& context, + FreeableBuffer* processed, + ArrayRef compile_specs) const override { + std::lock_guard lock(mlx_global_mutex()); + auto* handle = + context.get_runtime_allocator()->allocateInstance(); + if (handle == nullptr) { + return Error::MemoryAllocationFailed; + } + + try { + new (handle) MLXHandle(); + + if (!processed || !processed->data() || processed->size() == 0) { + throw std::runtime_error("init: null or empty delegate payload"); + } + + handle->program = loader::load_program( + static_cast(processed->data()), processed->size()); + + // Validate schema version + if (handle->program.version != "1") { + throw std::runtime_error( + "Unsupported MLX schema version '" + handle->program.version + + "' (expected '1'). Rebuild the .pte with a matching SDK version."); + } + + // Load constants from named_data_map + // Constants are stored by name in the .pte file and provided by ET at + // runtime + const runtime::NamedDataMap* named_data_map = + context.get_named_data_map(); + load_constants( + handle->program, + named_data_map, + handle->constants, + handle->constant_buffers); + + // Delegate payload no longer needed after constants are loaded + processed->Free(); + processed = nullptr; + + // Load mutable buffers (e.g., KV cache) + load_mutable_buffers(handle->program, handle->mutable_buffers); + + // Bind execution state (reused across execute() calls) + handle->state.bind( + handle->program, handle->constants, handle->mutable_buffers); + + // Run init chain if present. + // SAFETY: The >= 0 check ensures init_chain_idx is non-negative, so the + // static_cast cannot produce UINT32_MAX from a -1 sentinel. + if (handle->program.init_chain_idx >= 0) { + handle->interpreter.run_chain( + handle->program, + static_cast(handle->program.init_chain_idx), + handle->state, + handle->stream); + } + + } catch (const std::exception& e) { + ET_LOG(Error, "Failed to load MLX program: %s", e.what()); + handle->~MLXHandle(); + if (processed != nullptr) { + processed->Free(); + } + return Error::InvalidProgram; + } + + return handle; + } + + Error execute( + ET_UNUSED BackendExecutionContext& context, + DelegateHandle* handle, + Span args) const override { + try { + std::vector prepared_outputs; + struct OutputInfo { + size_t arg_idx; + size_t prepared_idx; + }; + + std::vector tensor_output_info; + size_t arg_idx = 0; + + auto* h = static_cast(handle); + const auto& program = h->program; + + // Graph construction + async GPU dispatch (locked) + { + std::lock_guard lock(mlx_global_mutex()); + + h->state.reset(); + + const size_t n_inputs = program.input_map.size(); + const size_t n_outputs = program.output_map.size(); + if (n_inputs > SIZE_MAX - n_outputs) { + throw std::runtime_error("execute: input + output count overflow"); + } + const size_t expected_args = n_inputs + n_outputs; + if (args.size() != expected_args) { + ET_LOG( + Error, "Expected %zu args, got %zu", expected_args, args.size()); + return Error::InvalidArgument; + } + + // Bind inputs + for (const auto& slot : program.input_map) { + if (arg_idx >= args.size()) { + throw std::runtime_error( + "execute: arg_idx " + std::to_string(arg_idx) + + " out of bounds (args.size()=" + std::to_string(args.size()) + + ")"); + } + if (slot.slot_type == SlotType::TensorSlot) { + const ETTensor& tensor = args[arg_idx++]->toTensor(); + Tid tid{slot.idx}; + std::optional expected_meta = std::nullopt; + if (tid.idx < program.tensor_meta.size()) { + expected_meta = program.tensor_meta[tid.idx]; + } + h->state.set_tensor(tid, tensor_to_mlx(tensor, expected_meta)); + } else if (slot.slot_type == SlotType::IntValueSlot) { + int64_t val = args[arg_idx]->toInt(); + arg_idx++; + if (val > std::numeric_limits::max() || + val < std::numeric_limits::min()) { + ET_LOG( + Error, + "Int input value %lld exceeds int32 range", + static_cast(val)); + return Error::InvalidArgument; + } + h->state.set_value(Vid{slot.idx}, static_cast(val)); + } else { + throw std::runtime_error( + "Unhandled input slot type: " + + std::to_string(static_cast(slot.slot_type))); + } + } + + // Run the MLX program (builds lazy computation graph) + h->interpreter.run(program, h->state, h->stream); + + // Prepare output pipeline and collect int outputs + // Build contiguous + dtype conversion lazily for each tensor output, + // and extract int outputs (which don't need GPU) while still locked. + prepared_outputs.reserve(program.num_output_tensors); + + for (const auto& slot : program.output_map) { + if (slot.slot_type == SlotType::TensorSlot) { + ETTensor& out_tensor = args[arg_idx]->toTensor(); + Dtype expected_dtype = + resolve_dtype(static_cast( + out_tensor.scalar_type())); + array out_arr = prepare_output( + h->state.const_tensor_ref(Tid{slot.idx}), + expected_dtype, + h->stream); + tensor_output_info.push_back({arg_idx, prepared_outputs.size()}); + prepared_outputs.push_back(std::move(out_arr)); + arg_idx++; + } else if (slot.slot_type == SlotType::IntValueSlot) { + Vid vid{slot.idx}; + int64_t int_val = + static_cast(h->state.const_value_ref(vid)); + *args[arg_idx] = EValue(int_val); + arg_idx++; + } else { + throw std::runtime_error( + "Unhandled output slot type: " + + std::to_string(static_cast(slot.slot_type))); + } + } + + // Submit all output work to GPU asynchronously + // async_eval encodes Metal commands and returns immediately. + // The GPU will signal events on completion. + if (!prepared_outputs.empty()) { + ::mlx::core::async_eval(prepared_outputs); + } + + } // Lock released — GPU is still executing + + for (auto& info : tensor_output_info) { + ETTensor& out_tensor = args[info.arg_idx]->toTensor(); + + // write_output waits on arr to be ready + write_output(prepared_outputs[info.prepared_idx], out_tensor); + } + + h->state.reset(); // Release temp GPU buffers back to MLX cache + + return Error::Ok; + } catch (const std::exception& e) { + ET_LOG(Error, "MLX execute failed: %s", e.what()); + return Error::Internal; + } + } + + void destroy(DelegateHandle* handle) const override { + std::lock_guard lock(mlx_global_mutex()); + if (handle != nullptr) { + auto* mlx_handle = static_cast(handle); + mlx_handle->~MLXHandle(); + } + } +}; + +namespace { +auto cls = MLXBackend(); +Backend backend{"MLXBackend", &cls}; +static auto success_with_compiler = register_backend(backend); +} // namespace + +} // namespace mlx +} // namespace backends +} // namespace executorch diff --git a/backends/mlx/runtime/MLXExecutor.h b/backends/mlx/runtime/MLXExecutor.h new file mode 100644 index 00000000000..32d623790ab --- /dev/null +++ b/backends/mlx/runtime/MLXExecutor.h @@ -0,0 +1,878 @@ +// +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +// + +#pragma once + +#include "MLXLoader.h" + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +// ============================================================================= +// Op Logging - compile-time gate + runtime env var check +// +// Compile flag (CMake: -DET_MLX_ENABLE_OP_LOGGING=1) controls whether logging +// code is compiled in at all. When off, all logging is stripped (zero +// overhead). When on, the env var ET_MLX_ENABLE_OP_LOGGING=1 must also be set +// at runtime to actually produce output. +// ============================================================================= +#ifndef ET_MLX_ENABLE_OP_LOGGING +#define ET_MLX_ENABLE_OP_LOGGING 0 +#endif + +// ============================================================================= +// Constant Zero-Copy - Enable via CMake: -DET_MLX_ENABLE_CONSTANT_ZERO_COPY=1 +// When enabled, attempts to load model constants (weights) using zero-copy +// on Apple Silicon's unified memory. Falls back to copying if zero-copy fails. +// Disable if you want predictable memory usage (always copies). +// ============================================================================= +#ifndef ET_MLX_ENABLE_CONSTANT_ZERO_COPY +#define ET_MLX_ENABLE_CONSTANT_ZERO_COPY 1 // Enabled by default +#endif + +namespace executorch { +namespace backends { +namespace mlx { + +/// Multiply two unsigned values, throw on overflow. +template +inline T safe_mul(T a, T b, const char* context) { + static_assert(std::is_unsigned::value, "safe_mul requires unsigned type"); + T result; + if (__builtin_mul_overflow(a, b, &result)) { + throw std::runtime_error(std::string(context) + ": unsigned mul overflow"); + } + return result; +} + +// Runtime check for op logging (only callable when compiled in) +#if ET_MLX_ENABLE_OP_LOGGING +inline bool isOpLoggingEnabled() { + static const bool enabled = []() { + const char* val = std::getenv("ET_MLX_ENABLE_OP_LOGGING"); + return val != nullptr && std::string(val) == "1"; + }(); + return enabled; +} +#else +constexpr bool isOpLoggingEnabled() { + return false; +} +#endif + +// Compile-time constant zero-copy flag +constexpr bool kEnableConstantZeroCopy = ET_MLX_ENABLE_CONSTANT_ZERO_COPY; + +using Tensor = ::mlx::core::array; +using Value = std::variant; +using StreamOrDevice = ::mlx::core::StreamOrDevice; + +struct ConstantData { + std::vector tensors; + + inline const Tensor& get(Tid id) const { + if (id.idx >= tensors.size()) { + throw std::out_of_range("ConstantData::get: id out of range"); + } + return tensors[id.idx]; + } + + inline void add(Tensor t) { + tensors.push_back(std::move(t)); + } +}; + +struct MutableBufferData { + // Maps tensor slot idx to MLX array + // Using vector of optional since mlx::array has no default constructor + std::vector> tensors; + + inline void resize(size_t n) { + tensors.resize(n, std::nullopt); + } + + inline bool has(Tid id) const { + return id.idx < tensors.size() && tensors[id.idx].has_value(); + } + + inline Tensor& get(Tid id) { + if (id.idx >= tensors.size() || !tensors[id.idx].has_value()) { + throw std::out_of_range("MutableBufferData::get: id not found or unset"); + } + return *tensors[id.idx]; + } + + inline const Tensor& get(Tid id) const { + if (id.idx >= tensors.size() || !tensors[id.idx].has_value()) { + throw std::out_of_range("MutableBufferData::get: id not found or unset"); + } + return *tensors[id.idx]; + } + + inline void set(Tid id, Tensor t) { + if (id.idx >= tensors.size()) { + throw std::out_of_range("MutableBufferData::set: id out of range"); + } + tensors[id.idx] = std::move(t); + } + + inline void clear() { + tensors.clear(); + } +}; + +struct ExecutionState { + const MLXProgram* program{nullptr}; + const ConstantData* constants{nullptr}; // Shared, read-only + MutableBufferData* mutable_buffers{nullptr}; // Per-handle, persistent + + // Per-execution tensors: inputs, outputs, temps (NOT constants or mutable + // buffers) + std::vector> tensors; + + // Non-constant values (SymInt, etc.) + std::vector> values; + + // Logging context + size_t current_op_idx{0}; + const char* current_op_name{nullptr}; + + // Tensor ID range boundaries for O(1) type lookup (computed at bind time) + uint32_t num_constants{0}; + uint32_t input_end{0}; + uint32_t output_end{0}; + uint32_t mutable_buffer_end{0}; + + void bind( + const MLXProgram& prog, + const ConstantData& const_data, + MutableBufferData& mut_bufs) { + program = &prog; + constants = &const_data; + mutable_buffers = &mut_bufs; + + // Allocate space for inputs, outputs, and temps only (not constants or + // mutable buffers) + uint64_t num_per_execution_tensors = + static_cast(prog.num_input_tensors) + + prog.num_output_tensors + prog.num_temp_tensors; + if (num_per_execution_tensors > 1'000'000) { + throw std::runtime_error( + "bind: num_per_execution_tensors " + + std::to_string(num_per_execution_tensors) + " exceeds limit"); + } + tensors.assign( + static_cast(num_per_execution_tensors), std::nullopt); + if (prog.num_values > 1'000'000) { + throw std::runtime_error( + "bind: num_values " + std::to_string(prog.num_values) + + " exceeds limit"); + } + values.assign(prog.num_values, std::nullopt); + + // Compute tensor ID range boundaries for fast type lookup + // ID assignment order: Constant -> Input -> Output -> MutableBuffer -> Temp + num_constants = prog.num_constant_tensors; + uint64_t ie = static_cast(num_constants) + prog.num_input_tensors; + uint64_t oe = ie + prog.num_output_tensors; + uint64_t me = oe + prog.num_mutable_buffer_tensors; + if (me > std::numeric_limits::max()) { + throw std::runtime_error("bind: tensor ID range overflow"); + } + input_end = static_cast(ie); + output_end = static_cast(oe); + mutable_buffer_end = static_cast(me); + } + + // Check if a tensor ID is a mutable buffer + inline bool is_mutable_buffer(Tid id) const { + return id.idx >= output_end && id.idx < mutable_buffer_end; + } + + // Convert tensor ID to index in the tensors vector + // Accounts for constants and mutable buffers not being in the vector + inline uint32_t tensor_index(Tid id) const { + if (id.idx < num_constants) { + throw std::runtime_error( + "tensor_index: called with constant tensor id " + + std::to_string(id.idx)); + } + if (is_mutable_buffer(id)) { + throw std::runtime_error( + "tensor_index: called with mutable buffer tensor id " + + std::to_string(id.idx)); + } + uint32_t idx = id.idx - num_constants; + // If this ID is after mutable buffer range, subtract mutable buffer count + if (id.idx >= mutable_buffer_end) { + if (idx < program->num_mutable_buffer_tensors) { + throw std::runtime_error( + "tensor_index: underflow for tensor id " + std::to_string(id.idx)); + } + idx -= program->num_mutable_buffer_tensors; + } + if (idx >= tensors.size()) { + throw std::out_of_range( + "tensor_index: computed index " + std::to_string(idx) + + " out of range (size=" + std::to_string(tensors.size()) + + ") for tensor id " + std::to_string(id.idx)); + } + return idx; + } + + void reset() { + // Clear per-execution tensors (inputs, outputs, temps) + // Constants and mutable buffers are not in this vector + for (auto& t : tensors) { + t = std::nullopt; + } + for (auto& v : values) { + v = std::nullopt; + } + } + + static inline const char* dtype_str(::mlx::core::Dtype dtype) { + using namespace ::mlx::core; + switch (dtype.val()) { + case float32.val(): + return "f32"; + case float16.val(): + return "f16"; + case bfloat16.val(): + return "bf16"; + case int32.val(): + return "i32"; + case int64.val(): + return "i64"; + case int16.val(): + return "i16"; + case int8.val(): + return "i8"; + case uint32.val(): + return "u32"; + case uint8.val(): + return "u8"; + case bool_.val(): + return "bool"; + default: + return "?"; + } + } + + static inline std::string format_tensor_info(const Tensor& t) { + std::ostringstream ss; + ss << dtype_str(t.dtype()); + ss << "("; + const auto& shape = t.shape(); + for (size_t i = 0; i < shape.size(); ++i) { + if (i > 0) + ss << ","; + ss << shape[i]; + } + ss << ")"; + return ss.str(); + } + + // Compute tensor stats: min, max, mean, nan_count + // Uses MLX ops for GPU-accelerated computation + static inline std::string format_tensor_stats(const Tensor& t) { + using namespace ::mlx::core; + + try { + std::ostringstream ss; + + size_t numel = t.size(); + if (numel == 0) { + ss << "[empty]"; + return ss.str(); + } + + // Cast to float32 for stats computation (handles bf16/fp16/int/bool) + Tensor t_float = astype(t, float32); + + // Use MLX ops for efficient GPU-based stats + Tensor nan_mask = isnan(t_float); + Tensor inf_mask = isinf(t_float); + Tensor nan_count_arr = sum(astype(nan_mask, int32)); + Tensor inf_count_arr = sum(astype(inf_mask, int32)); + + // For min/max/mean, we need to handle NaN/Inf - replace with 0 + Tensor valid_mask = logical_not(logical_or(nan_mask, inf_mask)); + Tensor t_valid = where(valid_mask, t_float, zeros_like(t_float)); + + Tensor min_arr = min(t_valid); + Tensor max_arr = max(t_valid); + Tensor mean_arr = mean(t_valid); + + // Evaluate all at once + eval({nan_count_arr, inf_count_arr, min_arr, max_arr, mean_arr}); + + int nan_count = nan_count_arr.item(); + int inf_count = inf_count_arr.item(); + float min_val = min_arr.item(); + float max_val = max_arr.item(); + float mean_val = mean_arr.item(); + + ss << std::fixed << std::setprecision(4); + ss << "[min=" << min_val << " max=" << max_val << " mean=" << mean_val; + if (nan_count > 0) { + ss << " NaN=" << nan_count; + } + if (inf_count > 0) { + ss << " Inf=" << inf_count; + } + ss << "]"; + return ss.str(); + } catch (const std::exception& e) { + return std::string("[stats error: ") + e.what() + "]"; + } catch (...) { + return "[stats error: unknown]"; + } + } + + // Get tensor type prefix for logging: "c", "i", "o", "b", "t" + inline const char* tensor_type_prefix(Tid id) const { + if (!program) + return "?"; + + uint32_t tid = id.idx; + + // Check each range in order (mutually exclusive ranges) + if (tid < program->num_constant_tensors) + return "c"; // Constant + if (tid < input_end) + return "i"; // User Input + if (tid < output_end) + return "o"; // User Output + if (tid < mutable_buffer_end) + return "b"; // Mutable Buffer + return "t"; // Temp + } + + inline void begin_op(size_t idx, const char* name) { + current_op_idx = idx; + current_op_name = name; + if (isOpLoggingEnabled()) { + std::cout << "[" << idx << "] " << name << std::endl; + } + } + + inline void end_op() { + if (isOpLoggingEnabled()) { + std::cout << "----\n"; + } + } + + inline Tensor& tensor_ref(Tid id) { + if (isOpLoggingEnabled()) { + std::cout << " ref " << tensor_type_prefix(id) << id.idx << std::flush; + } + if (!program) { + throw std::runtime_error("tensor_ref: Program not bound"); + } + if (id.idx >= program->num_tensors()) { + throw std::out_of_range("tensor_ref: id out of range"); + } + if (program->is_constant_tensor(id)) { + throw std::runtime_error("tensor_ref: cannot mutate constant tensor"); + } + // Route to mutable buffers or per-execution tensors + Tensor* t = nullptr; + if (is_mutable_buffer(id)) { + if (!mutable_buffers) { + throw std::runtime_error("tensor_ref: mutable_buffers not bound"); + } + t = &mutable_buffers->get(id); + } else { + uint32_t idx = tensor_index(id); + if (idx >= tensors.size()) { + throw std::out_of_range("tensor_ref: tensor idx out of range"); + } + auto& opt = tensors[idx]; + if (!opt) { + throw std::runtime_error( + "tensor_ref: uninitialized tensor idx=" + std::to_string(id.idx)); + } + t = &*opt; + } + if (isOpLoggingEnabled()) { + std::cout << " " << format_tensor_info(*t) << "\n"; + } + return *t; + } + + inline const Tensor& const_tensor_ref(Tid id) const { + if (isOpLoggingEnabled()) { + std::cout << " in " << tensor_type_prefix(id) << id.idx << std::flush; + } + if (!program) { + throw std::runtime_error("const_tensor_ref: Program not bound"); + } + if (id.idx >= program->num_tensors()) { + throw std::out_of_range("const_tensor_ref: id out of range"); + } + + const Tensor* t = nullptr; + if (program->is_constant_tensor(id)) { + // Route to constants + if (!constants) { + throw std::runtime_error("const_tensor_ref: constants not bound"); + } + t = &constants->get(id); + } else if (is_mutable_buffer(id)) { + // Route to mutable buffers + if (!mutable_buffers) { + throw std::runtime_error("const_tensor_ref: mutable_buffers not bound"); + } + t = &mutable_buffers->get(id); + } else { + // Route to per-execution tensors + uint32_t idx = tensor_index(id); + if (idx >= tensors.size()) { + throw std::out_of_range("const_tensor_ref: tensor idx out of range"); + } + const auto& opt = tensors[idx]; + if (!opt) { + throw std::runtime_error( + "const_tensor_ref: uninitialized tensor idx=" + + std::to_string(id.idx)); + } + t = &*opt; + } + + if (isOpLoggingEnabled()) { + std::cout << " " << format_tensor_info(*t) << " " + << format_tensor_stats(*t) << "\n"; + } + return *t; + } + + // Set a tensor output + inline void set_tensor(Tid id, Tensor arr) { + if (isOpLoggingEnabled()) { + std::cout << " out " << tensor_type_prefix(id) << id.idx << " " + << format_tensor_info(arr) << " " << format_tensor_stats(arr) + << "\n"; + } + if (!program) { + throw std::runtime_error("set_tensor: Program not bound"); + } + if (id.idx < program->num_constant_tensors) { + throw std::runtime_error("set_tensor: cannot write to constant tensor"); + } + // Route to mutable buffers or per-execution tensors + if (is_mutable_buffer(id)) { + if (!mutable_buffers) { + throw std::runtime_error("set_tensor: mutable_buffers not bound"); + } + mutable_buffers->set(id, std::move(arr)); + } else { + uint32_t idx = tensor_index(id); + if (idx >= tensors.size()) { + throw std::out_of_range("set_tensor: tensor idx out of range"); + } + tensors[idx] = std::move(arr); + } + } + + template + inline T& value_ref(Vid id) { + if (isOpLoggingEnabled()) { + std::cout << " ref v" << id.idx << std::flush; + } + if (id.idx >= values.size()) { + throw std::out_of_range("value_ref: id out of range"); + } + auto& opt = values[id.idx]; + if (!opt) { + throw std::runtime_error( + "value_ref: uninitialized value idx=" + std::to_string(id.idx)); + } + if (isOpLoggingEnabled()) { + std::cout << " " << std::get(*opt) << "\n"; + } + return std::get(*opt); + } + + template + inline const T& const_value_ref(Vid id) const { + if (isOpLoggingEnabled()) { + std::cout << " in v" << id.idx << std::flush; + } + if (id.idx >= values.size()) { + throw std::out_of_range("const_value_ref: id out of range"); + } + const auto& opt = values[id.idx]; + if (!opt) { + throw std::runtime_error( + "const_value_ref: uninitialized value idx=" + std::to_string(id.idx)); + } + if (isOpLoggingEnabled()) { + std::cout << " " << std::get(*opt) << "\n"; + } + return std::get(*opt); + } + + inline const Value& const_value(Vid id) const { + if (isOpLoggingEnabled()) { + std::cout << " in v" << id.idx << std::flush; + } + if (id.idx >= values.size()) { + throw std::out_of_range("const_value: id out of range"); + } + const auto& opt = values[id.idx]; + if (!opt) { + throw std::runtime_error( + "const_value: uninitialized value idx=" + std::to_string(id.idx)); + } + if (isOpLoggingEnabled()) { + std::visit([](auto&& arg) { std::cout << " " << arg << "\n"; }, *opt); + } + return *opt; + } + + template + inline void set_value(Vid id, T val) { + if (isOpLoggingEnabled()) { + std::cout << " out v" << id.idx << " " << val << "\n"; + } + if (id.idx >= values.size()) { + throw std::out_of_range("set_value: id out of range"); + } + values[id.idx] = val; + } +}; + +inline ::mlx::core::Dtype resolve_dtype(ScalarType d) { + using namespace ::mlx::core; + switch (d) { + case ScalarType::Half: + return float16; + case ScalarType::Float: + return float32; + case ScalarType::BFloat16: + return bfloat16; + case ScalarType::Int: + return int32; + case ScalarType::Short: + return int16; + case ScalarType::Long: + return int64; + case ScalarType::UInt32: + return uint32; + case ScalarType::Byte: + return uint8; + case ScalarType::Bool: + return bool_; + case ScalarType::Char: + return int8; + default: + throw std::runtime_error( + "Unsupported ScalarType: " + std::to_string(static_cast(d))); + } +} + +inline ::mlx::core::Dtype resolve_dtype(int8_t d) { + return resolve_dtype(static_cast(d)); +} + +// Maximum allocation size for any single tensor created from untrusted data. +// This bounds GPU memory allocation from malformed payloads. +constexpr size_t kMaxAllocationBytes = + static_cast(4) * 1024 * 1024 * 1024; // 4 GB + +/// Validate that a tensor with the given shape and dtype does not exceed +/// kMaxAllocationBytes. Throws std::runtime_error on invalid dimensions +/// or if the total size exceeds the limit. +inline void check_allocation_bounded( + const ::mlx::core::Shape& shape, + ::mlx::core::Dtype dtype, + const char* context) { + size_t elem_size = ::mlx::core::size_of(dtype); + size_t numel = 1; + for (auto d : shape) { + if (d <= 0) { + throw std::runtime_error( + std::string(context) + ": invalid dimension " + std::to_string(d)); + } + numel = safe_mul(numel, static_cast(d), context); + } + size_t total_bytes = safe_mul(numel, elem_size, context); + if (total_bytes > kMaxAllocationBytes) { + throw std::runtime_error( + std::string(context) + ": allocation exceeds 4GB limit"); + } +} + +inline int32_t clamp_to_int32(int64_t val64) { + // INT64_MAX is commonly used as a sentinel for "slice to end". + // Non-sentinel large values are silently clamped, which may change + // slice semantics — but this matches PyTorch behavior. + if (val64 >= static_cast(std::numeric_limits::max())) { + return std::numeric_limits::max(); + } else if ( + val64 <= static_cast(std::numeric_limits::min())) { + return std::numeric_limits::min(); + } + return static_cast(val64); +} + +inline int32_t resolve_int( + const std::variant& v, + const ExecutionState& st) { + if (std::holds_alternative(v)) { + return clamp_to_int32(std::get(v)); + } + return st.const_value_ref(std::get(v)); +} + +inline std::vector resolve_ints( + const std::vector>& v, + const ExecutionState& st) { + std::vector out; + out.reserve(v.size()); + for (const auto& elem : v) { + out.push_back(resolve_int(elem, st)); + } + return out; +} + +inline float resolve_float( + const std::variant& v, + const ExecutionState& st) { + if (std::holds_alternative(v)) { + return static_cast(std::get(v)); + } + // The value may be stored as int32_t (from SymInt computations) or float. + const auto& val = st.const_value(std::get(v)); + return std::visit( + [](auto&& arg) -> float { return static_cast(arg); }, val); +} + +inline ::mlx::core::Shape to_shape( + const std::vector>& dims, + const ExecutionState& st) { + auto resolved = resolve_ints(dims, st); + return ::mlx::core::Shape(resolved.begin(), resolved.end()); +} + +inline ::mlx::core::Shape to_shape(const std::vector& dims) { + return ::mlx::core::Shape(dims.begin(), dims.end()); +} + +// Overload for static shapes (used when loading constants where all dims must +// be literals) +// Convert ShapeDim vector to MLX Shape (for constants and mutable buffers). +// Only static dimensions are allowed — dynamic dims (value == -1) are rejected. +inline ::mlx::core::Shape to_shape(const std::vector& dims) { + ::mlx::core::Shape out; + out.reserve(dims.size()); + for (const auto& d : dims) { + if (d.is_dynamic()) { + throw std::runtime_error( + "to_shape: expected static shape but found dynamic dimension"); + } + out.push_back(d.value); + } + return out; +} + +// Load constants from ExecuTorch's NamedDataMap. +// Constants are stored by name in the .pte file and loaded via the +// named_data_map interface. This allows ExecuTorch to own the constant data and +// enables zero-copy on Apple Silicon unified memory. +// +// Parameters: +// program: The loaded MLXProgram containing tensor metadata and named_slots +// named_data_map: ExecuTorch's interface for accessing named data +// store: Output storage for loaded constant tensors +// constant_buffers: Vector to store FreeableBuffers (must outlive store for +// zero-copy) +inline void load_constants( + const MLXProgram& program, + const runtime::NamedDataMap* named_data_map, + ConstantData& store, + std::vector& constant_buffers) { + using namespace ::mlx::core; + + store.tensors.clear(); + constant_buffers.clear(); + + if (program.num_constant_tensors == 0) { + return; + } + + store.tensors.reserve(program.num_constant_tensors); + constant_buffers.reserve(program.num_constant_tensors); + + // Build tid -> name map for O(1) lookup + std::unordered_map tid_to_name; + tid_to_name.reserve(program.named_slots.size()); + for (const auto& ns : program.named_slots) { + if (ns.slot.slot_type == SlotType::TensorSlot) { + tid_to_name[ns.slot.idx] = &ns.name; + } + } + + // Load each constant tensor by name + for (uint32_t tid = 0; tid < program.num_constant_tensors; ++tid) { + // Get tensor metadata + if (tid >= program.tensor_meta.size() || !program.tensor_meta[tid]) { + throw std::runtime_error( + "load_constants: missing metadata for constant " + + std::to_string(tid)); + } + + // Find the name for this tensor ID + auto it = tid_to_name.find(tid); + const std::string* name = (it != tid_to_name.end()) ? it->second : nullptr; + if (!name) { + throw std::runtime_error( + "load_constants: no name found for constant tensor " + + std::to_string(tid)); + } + + // Get data from named_data_map + if (named_data_map == nullptr) { + throw std::runtime_error( + "load_constants: named_data_map is null but program has constants"); + } + + auto data_result = named_data_map->get_data(name->c_str()); + if (!data_result.ok()) { + throw std::runtime_error( + "load_constants: failed to get data for constant '" + *name + + "': error " + std::to_string(static_cast(data_result.error()))); + } + + // Move the buffer into our storage (keeps it alive for zero-copy) + constant_buffers.push_back(std::move(data_result.get())); + runtime::FreeableBuffer& buffer = constant_buffers.back(); + + const auto& meta = *program.tensor_meta[tid]; + Shape shape = to_shape(meta.shape); + Dtype dtype = resolve_dtype(meta.scalar_type); + + // Create MLX array with zero-copy when enabled. + // SAFETY: Constants are read-only; the program builder ensures no in-place + // ops target constant tensors. The const_cast is required by MLX's array + // constructor but the data will not be mutated + void* data_ptr = const_cast(buffer.data()); + + if constexpr (kEnableConstantZeroCopy) { + // Zero-copy: wrap pointer directly with no-op deleter + // The FreeableBuffer in constant_buffers keeps the data alive + auto deleter = [](void*) { + // Data lifetime managed by FreeableBuffer in + // MLXHandle::constant_buffers + }; + array arr = array(data_ptr, shape, dtype, deleter); + store.add(std::move(arr)); + } else { + // No deleter = MLX copies the data into its own memory + store.add(array(static_cast(data_ptr), shape, dtype)); + } + } + + // Evaluate all constants immediately to prepare Metal buffers + // This trades init time for faster first inference + eval(store.tensors); +} + +inline void load_mutable_buffers( + const MLXProgram& program, + MutableBufferData& store) { + using namespace ::mlx::core; + + store.clear(); + + if (program.mutable_buffer_map.empty()) { + return; + } + + // Pre-size the storage to fit all tensor IDs + // Mutable buffer IDs are in the global tensor ID space + uint32_t max_tid = 0; + for (const auto& slot : program.mutable_buffer_map) { + if (slot.idx > max_tid) { + max_tid = slot.idx; + } + } + if (max_tid >= 1'000'000) { + throw std::runtime_error( + "load_mutable_buffers: max_tid " + std::to_string(max_tid) + + " exceeds limit"); + } + store.resize(max_tid + 1); + + for (const auto& slot : program.mutable_buffer_map) { + if (slot.slot_type != SlotType::TensorSlot) { + throw std::runtime_error( + "load_mutable_buffers: unexpected slot type " + + std::to_string(static_cast(slot.slot_type))); + } + + Tid tid{slot.idx}; + + // Get metadata for this tensor + if (tid.idx >= program.tensor_meta.size()) { + ET_LOG( + Error, + "load_mutable_buffers: tid %u >= tensor_meta.size() %zu", + tid.idx, + program.tensor_meta.size()); + throw std::runtime_error( + "load_mutable_buffers: tensor index out of range for tensor " + + std::to_string(tid.idx)); + } + + if (!program.tensor_meta[tid.idx]) { + ET_LOG( + Error, + "load_mutable_buffers: missing metadata for tensor %u", + tid.idx); + throw std::runtime_error( + "load_mutable_buffers: missing metadata for tensor " + + std::to_string(tid.idx)); + } + + const auto& meta = *program.tensor_meta[tid.idx]; + auto shape = to_shape(meta.shape); + auto dtype = resolve_dtype(meta.scalar_type); + + check_allocation_bounded(shape, dtype, "load_mutable_buffers"); + + // Initialize mutable buffer to zeros + // This matches the typical initialization of KV cache buffers + auto arr = zeros(shape, dtype); + + // Evaluate immediately to allocate in GPU memory + eval(arr); + + store.set(tid, std::move(arr)); + } +} + +} // namespace mlx +} // namespace backends +} // namespace executorch diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h new file mode 100644 index 00000000000..f3b6e9b720f --- /dev/null +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -0,0 +1,169 @@ +// +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +// + +#pragma once + +#include "MLXExecutor.h" + +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace mlx { + +namespace ops { + +using namespace ::mlx::core; + +/** + * Normalize axis to be in range [0, rank) and validate. + * @param axis The axis value (can be negative) + * @param rank The tensor rank + * @param op_name Name of the operation for error messages + * @return Normalized axis in range [0, rank) + * @throws std::out_of_range if axis is out of range + */ +inline int normalize_axis(int axis, int rank, const char* op_name) { + if (axis < -rank || axis >= rank) { + throw std::out_of_range(std::string(op_name) + ": axis out of range"); + } + if (axis < 0) + axis += rank; + return axis; +} + +/** + * Infers dimensions with -1 in a reshape-like operation. + * + * PyTorch allows -1 in shapes to mean "infer this dimension from total size". + * MLX requires concrete positive integers, so we must resolve -1 values. + * + * @param shape The shape to resolve (may contain -1) + * @param input_size Total number of elements in the input tensor + * @return Resolved shape with all positive integers + * @throws std::runtime_error if shape has multiple -1 or incompatible sizes + */ +inline std::vector infer_shape_with_minus_one( + const std::vector& shape, + size_t input_size) { + std::vector resolved_shape = shape; + int neg_one_idx = -1; + int64_t known_size = 1; // Use int64_t to avoid overflow + + // Find -1 dimension and compute product of known dimensions + for (size_t i = 0; i < resolved_shape.size(); i++) { + if (resolved_shape[i] == -1) { + if (neg_one_idx != -1) { + throw std::runtime_error("infer_shape: only one dimension can be -1"); + } + neg_one_idx = static_cast(i); + } else { + known_size *= static_cast(resolved_shape[i]); + } + } + + // Infer the -1 dimension if present + if (neg_one_idx != -1) { + if (known_size == 0) { + throw std::runtime_error( + "infer_shape: cannot infer -1 dimension when known product is 0"); + } + int64_t input_size_i64 = static_cast(input_size); + if (input_size_i64 % known_size != 0) { + throw std::runtime_error( + "infer_shape: cannot infer dimension - size mismatch"); + } + int64_t inferred_dim = input_size_i64 / known_size; + + // Check that inferred dimension fits in int + if (inferred_dim > std::numeric_limits::max()) { + throw std::runtime_error( + "infer_shape: inferred dimension exceeds int max"); + } + + resolved_shape[static_cast(neg_one_idx)] = + static_cast(inferred_dim); + } + + return resolved_shape; +} + +inline void exec_noop(const NoopNode&, ExecutionState&, StreamOrDevice) {} + +inline void +exec_addmm(const AddmmNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& mat1 = st.const_tensor_ref(n.mat1); + const auto& mat2 = st.const_tensor_ref(n.mat2); + + array Y = n.bias ? addmm( + st.const_tensor_ref(*n.bias), + mat1, + mat2, + /*alpha=*/n.alpha, + /*beta=*/n.beta, + s) + : matmul(mat1, mat2, s); + + st.set_tensor(n.out, std::move(Y)); +} + +} // namespace ops + +class Interpreter { + public: + void run( + const MLXProgram& prog, + ExecutionState& st, + StreamOrDevice stream = {}) const { + run_chain(prog, prog.main_chain_idx, st, stream); + } + + void run_chain( + const MLXProgram& prog, + uint32_t chain_idx, + ExecutionState& st, + StreamOrDevice stream = {}) const { + if (chain_idx >= prog.instruction_chains.size()) { + throw std::runtime_error( + "run_chain: chain_idx " + std::to_string(chain_idx) + + " out of range (num_chains=" + + std::to_string(prog.instruction_chains.size()) + ")"); + } + const auto& chain = prog.instruction_chains[chain_idx]; + size_t idx = 0; + for (const auto& instr : chain) { + st.begin_op(idx, op_name(instr.op)); + dispatch(instr, st, stream); + st.end_op(); + ++idx; + } + } + + private: + void dispatch(const Instruction& instr, ExecutionState& st, StreamOrDevice s) + const { + switch (instr.op) { + case OpCode::NOOP: + ops::exec_noop(std::get(instr.node), st, s); + break; + case OpCode::ADDMM: + ops::exec_addmm(std::get(instr.node), st, s); + break; + default: + throw std::runtime_error( + "Unknown opcode: " + std::to_string(static_cast(instr.op))); + } + } +}; + +} // namespace mlx +} // namespace backends +} // namespace executorch diff --git a/backends/mlx/serialization/MLXLoader.cpp.tmpl b/backends/mlx/serialization/MLXLoader.cpp.tmpl new file mode 100644 index 00000000000..aa4716d7a4a --- /dev/null +++ b/backends/mlx/serialization/MLXLoader.cpp.tmpl @@ -0,0 +1,324 @@ +// -*- c++ -*- + +#include "MLXLoader.h" + +#include +#include + +namespace executorch { +namespace backends { +namespace mlx { +namespace loader { + +namespace { + +// Header structure for MLX payload +constexpr size_t kHeaderSize = 24; +constexpr uint32_t kMagic = 0x30584C4D; // "MLX0" in little-endian + +struct MLXHeader { + uint32_t padding; + uint32_t magic; + uint64_t data_offset; + uint64_t data_size; +}; +static_assert(sizeof(MLXHeader) == kHeaderSize, "MLXHeader size mismatch"); + +bool parse_header(const void* data, size_t size, MLXHeader& header) { + if (size < kHeaderSize) { + return false; + } + std::memcpy(&header, data, sizeof(MLXHeader)); + if (header.magic != kMagic) { + return false; + } + // Validate data_offset: must be strictly greater than kHeaderSize (so the + // FlatBuffer region is non-empty) and must not exceed the total buffer size. + if (header.data_offset <= kHeaderSize || header.data_offset > size) { + return false; + } + return true; +} + +// Helper to convert FlatBuffer vectors to std::vector. +// Caps size to prevent unbounded allocations from malformed payloads. +template +std::vector to_vector(const flatbuffers::Vector* fb_vec) { + if (!fb_vec) { + return {}; + } + constexpr size_t kMaxVectorSize = 1'000'000; + if (fb_vec->size() > kMaxVectorSize) { + throw std::runtime_error( + "FlatBuffer vector size " + std::to_string(fb_vec->size()) + + " exceeds maximum of " + std::to_string(kMaxVectorSize)); + } + return std::vector(fb_vec->begin(), fb_vec->end()); +} + +} // namespace + +// ============================================================================= +// load_instruction - AUTO-GENERATED switch statement +// ============================================================================= + +Instruction load_instruction(const mlx_delegate::Instruction* fb_instr) { + Instruction instr; + + if (!fb_instr || !fb_instr->op()) { + instr.op = OpCode::NOOP; + instr.node = NoopNode{}; + return instr; + } + + auto op_type = fb_instr->op_type(); + + switch (op_type) { +{{LOAD_INSTRUCTION_CASES}} + default: + throw std::runtime_error( + "Unknown op_type in load_instruction: " + + std::to_string(static_cast(op_type)) + + ". The .pte was built with a newer schema than this binary. " + "Rebuild with the latest runtime."); + } + + return instr; +} + +// ============================================================================= +// load_program +// ============================================================================= + +MLXProgram load_program(const void* data, size_t size) { + MLXHeader header; + if (!parse_header(data, size, header)) { + throw std::runtime_error("Invalid MLX header"); + } + + // Defense-in-depth: parse_header already validates this, but guard the + // unsigned subtraction against underflow in case the call site ever changes. + if (header.data_offset <= kHeaderSize || header.data_offset > size) { + throw std::runtime_error("data_offset out of range"); + } + const uint8_t* fb_data = static_cast(data) + kHeaderSize; + size_t fb_size = header.data_offset - kHeaderSize; + + flatbuffers::Verifier verifier(fb_data, fb_size); + if (!mlx_delegate::VerifyMLXGraphBuffer(verifier)) { + throw std::runtime_error("Invalid FlatBuffer data"); + } + + const auto* fb_graph = mlx_delegate::GetMLXGraph(fb_data); + if (!fb_graph) { + throw std::runtime_error("Failed to parse MLXGraph"); + } + + MLXProgram program; + + if (fb_graph->version()) { + program.version = fb_graph->version()->str(); + } + + program.num_constant_tensors = fb_graph->num_constant_tensors(); + program.num_input_tensors = fb_graph->num_input_tensors(); + program.num_output_tensors = fb_graph->num_output_tensors(); + program.num_mutable_buffer_tensors = fb_graph->num_mutable_buffer_tensors(); + program.num_temp_tensors = fb_graph->num_temp_tensors(); + program.num_values = fb_graph->num_values(); + + // Cap all counts/collection sizes to prevent unbounded allocations from + // malformed FlatBuffer payloads + constexpr size_t kMaxCollectionSize = 1'000'000; + auto check_collection_size = [](size_t sz, const char* name) { + if (sz > kMaxCollectionSize) { + throw std::runtime_error( + std::string("Malformed program: ") + name + " size " + + std::to_string(sz) + " exceeds maximum of " + + std::to_string(kMaxCollectionSize)); + } + }; + + check_collection_size(program.num_tensors(), "num_tensors()"); + check_collection_size(program.num_values, "num_values"); + + if (fb_graph->instruction_chains()) { + check_collection_size(fb_graph->instruction_chains()->size(), "instruction_chains"); + program.instruction_chains.reserve(fb_graph->instruction_chains()->size()); + for (size_t c = 0; c < fb_graph->instruction_chains()->size(); ++c) { + const auto* fb_chain = fb_graph->instruction_chains()->Get(static_cast(c)); + std::vector chain; + if (fb_chain && fb_chain->instructions()) { + check_collection_size(fb_chain->instructions()->size(), "instructions in chain"); + chain.reserve(fb_chain->instructions()->size()); + for (size_t i = 0; i < fb_chain->instructions()->size(); ++i) { + chain.push_back(load_instruction(fb_chain->instructions()->Get(static_cast(i)))); + } + } + program.instruction_chains.push_back(std::move(chain)); + } + } + + program.main_chain_idx = fb_graph->main_chain_idx(); + program.init_chain_idx = fb_graph->init_chain_idx(); + + // Validate chain indices against actual instruction_chains size. + if (program.main_chain_idx >= program.instruction_chains.size()) { + throw std::runtime_error( + "Invalid main_chain_idx " + + std::to_string(program.main_chain_idx) + + " (only " + std::to_string(program.instruction_chains.size()) + + " chains loaded)"); + } + if (program.init_chain_idx >= 0 && + static_cast(program.init_chain_idx) >= + program.instruction_chains.size()) { + throw std::runtime_error( + "Invalid init_chain_idx " + + std::to_string(program.init_chain_idx) + + " (only " + std::to_string(program.instruction_chains.size()) + + " chains loaded)"); + } + + if (fb_graph->input_map()) { + check_collection_size(fb_graph->input_map()->size(), "input_map"); + for (size_t i = 0; i < fb_graph->input_map()->size(); ++i) { + const auto* slot = fb_graph->input_map()->Get(static_cast(i)); + auto sv = convert_slot_variant(slot); + if (sv.slot_type == SlotType::TensorSlot && + sv.idx >= program.num_tensors()) { + throw std::runtime_error( + "input_map: slot index " + std::to_string(sv.idx) + + " exceeds num_tensors " + + std::to_string(program.num_tensors())); + } + program.input_map.push_back(sv); + } + } + + if (fb_graph->output_map()) { + check_collection_size(fb_graph->output_map()->size(), "output_map"); + for (size_t i = 0; i < fb_graph->output_map()->size(); ++i) { + const auto* slot = fb_graph->output_map()->Get(static_cast(i)); + auto sv = convert_slot_variant(slot); + if (sv.slot_type == SlotType::TensorSlot && + sv.idx >= program.num_tensors()) { + throw std::runtime_error( + "output_map: slot index " + std::to_string(sv.idx) + + " exceeds num_tensors " + + std::to_string(program.num_tensors())); + } + program.output_map.push_back(sv); + } + } + + if (fb_graph->mutable_buffer_map()) { + check_collection_size(fb_graph->mutable_buffer_map()->size(), "mutable_buffer_map"); + for (size_t i = 0; i < fb_graph->mutable_buffer_map()->size(); ++i) { + const auto* slot = fb_graph->mutable_buffer_map()->Get(static_cast(i)); + auto sv = convert_slot_variant(slot); + if (sv.slot_type == SlotType::TensorSlot && + sv.idx >= program.num_tensors()) { + throw std::runtime_error( + "mutable_buffer_map: slot index " + std::to_string(sv.idx) + + " exceeds num_tensors " + + std::to_string(program.num_tensors())); + } + program.mutable_buffer_map.push_back(sv); + } + } + + if (fb_graph->named_slots()) { + check_collection_size(fb_graph->named_slots()->size(), "named_slots"); + for (size_t i = 0; i < fb_graph->named_slots()->size(); ++i) { + const auto* fb_slot = fb_graph->named_slots()->Get(static_cast(i)); + if (!fb_slot || !fb_slot->name()) { + throw std::runtime_error( + "Malformed program: named_slot at index " + std::to_string(i) + + " is null or has null name"); + } + NamedSlot slot; + slot.name = fb_slot->name()->str(); + slot.slot = convert_slot_variant(fb_slot->slot()); + program.named_slots.push_back(std::move(slot)); + } + } + + if (fb_graph->tensor_meta()) { + check_collection_size(fb_graph->tensor_meta()->size(), "tensor_meta"); + for (size_t i = 0; i < fb_graph->tensor_meta()->size(); ++i) { + const auto* fb_meta = fb_graph->tensor_meta()->Get(static_cast(i)); + if (fb_meta) { + TensorMeta meta; + if (fb_meta->shape()) { + // Validate tensor rank against kTensorDimensionLimit to prevent + // stack overflows from unchecked rank + constexpr size_t kTensorDimensionLimit = 16; + if (fb_meta->shape()->size() > kTensorDimensionLimit) { + throw std::runtime_error( + "Tensor at index " + std::to_string(i) + + " has rank " + std::to_string(fb_meta->shape()->size()) + + " exceeding kTensorDimensionLimit (" + + std::to_string(kTensorDimensionLimit) + ")"); + } + for (size_t j = 0; j < fb_meta->shape()->size(); ++j) { + const auto* fb_dim = fb_meta->shape()->Get(static_cast(j)); + if (!fb_dim) { + throw std::runtime_error( + "Null ShapeDim at index " + std::to_string(j) + + " in tensor_meta " + std::to_string(i)); + } + ShapeDim dim; + dim.value = fb_dim->value(); + dim.min_value = fb_dim->min_value(); + dim.max_value = fb_dim->max_value(); + if (dim.value < -1) { + throw std::runtime_error( + "Invalid ShapeDim value " + std::to_string(dim.value) + + " at index " + std::to_string(j) + + " in tensor_meta " + std::to_string(i)); + } + if (dim.is_dynamic()) { + if (dim.min_value < 0) { + throw std::runtime_error( + "Invalid ShapeDim min_value " + std::to_string(dim.min_value) + + " at index " + std::to_string(j) + + " in tensor_meta " + std::to_string(i)); + } + if (dim.max_value != -1 && dim.max_value < dim.min_value) { + throw std::runtime_error( + "ShapeDim max_value " + std::to_string(dim.max_value) + + " < min_value " + std::to_string(dim.min_value) + + " at index " + std::to_string(j) + + " in tensor_meta " + std::to_string(i)); + } + } + meta.shape.push_back(dim); + } + } + auto raw_scalar_type = fb_meta->scalar_type(); + if (raw_scalar_type < 0 || + raw_scalar_type >= + static_cast(ScalarType::NumOptions)) { + throw std::runtime_error( + "Invalid scalar_type " + std::to_string(raw_scalar_type) + + " in tensor_meta at index " + std::to_string(i)); + } + meta.scalar_type = static_cast(raw_scalar_type); + if (fb_meta->dim_order()) { + meta.dim_order = to_vector(fb_meta->dim_order()); + } + program.tensor_meta.push_back(std::move(meta)); + } else { + program.tensor_meta.push_back(std::nullopt); + } + } + } + + return program; +} + +} // namespace loader +} // namespace mlx +} // namespace backends +} // namespace executorch diff --git a/backends/mlx/serialization/MLXLoader.h.tmpl b/backends/mlx/serialization/MLXLoader.h.tmpl new file mode 100644 index 00000000000..0930d5e00e1 --- /dev/null +++ b/backends/mlx/serialization/MLXLoader.h.tmpl @@ -0,0 +1,343 @@ +// -*- c++ -*- + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "schema_generated.h" + +// ExecuTorch scalar type for dtype representation +#include + +namespace executorch { +namespace backends { +namespace mlx { + +// ============================================================================= +// Core types matching the Python side +// ============================================================================= + +struct Tid { + uint32_t idx{}; +}; + +struct Vid { + uint32_t idx{}; +}; + +// ============================================================================= +// Tensor metadata +// ============================================================================= + +// Import ScalarType from ExecuTorch +using ScalarType = ::executorch::runtime::etensor::ScalarType; + +struct ShapeDim { + int32_t value{-1}; // Static dim (>= 0), or -1 for dynamic + int32_t min_value{0}; // Lower bound (when value == -1) + int32_t max_value{-1}; // Upper bound (-1 = unbounded, when value == -1) + + bool is_dynamic() const { return value < 0; } +}; + +struct TensorMeta { + std::vector shape; + ScalarType scalar_type{ScalarType::Float}; // ET ScalarType + std::vector dim_order; +}; + +// VidOrTid: either a scalar value (Vid) or a tensor (Tid) +struct VidOrTid { + Vid vid{}; + Tid tid{}; + bool is_vid{false}; // false = use tid, true = use vid +}; + +// IntOrVidOrTid: a literal int, a runtime Vid, or a tensor (Tid) +struct IntOrVidOrTid { + int64_t literal{0}; + Vid vid{}; + Tid tid{}; + uint8_t kind{0}; // 0 = literal int, 1 = vid, 2 = tid +}; + +// ============================================================================= +// Op node types (AUTO-GENERATED from schema.fbs) +// ============================================================================= + +{{OP_NODE_STRUCTS}} + +// ============================================================================= +// OpCode enum (AUTO-GENERATED from schema.fbs) +// ============================================================================= + +enum class OpCode : uint8_t { +{{OPCODE_ENUM_VALUES}} +}; + +// OpCode to string conversion (for logging) +inline const char* op_name(OpCode op) { + switch (op) { +{{OP_NAME_CASES}} + } + return "UNKNOWN"; +} + +// ============================================================================= +// NodeVariant for type-erased op storage (AUTO-GENERATED) +// ============================================================================= + +using NodeVariant = std::variant< +{{NODE_VARIANT_TYPES}} +>; + +// ============================================================================= +// Instruction +// ============================================================================= + +struct Instruction { + OpCode op{OpCode::NOOP}; + NodeVariant node; + + template + T& get() { + return std::get(node); + } + + template + const T& get() const { + return std::get(node); + } +}; + +// ============================================================================= +// SlotVariant for I/O mapping +// ============================================================================= + +enum class SlotType : uint8_t { + TensorSlot = 0, + IntValueSlot = 1, + FloatValueSlot = 2, + BoolValueSlot = 3, +}; + +struct SlotVariant { + uint32_t idx; + SlotType slot_type; +}; + +// ============================================================================= +// Named slot (name -> slot mapping) +// ============================================================================= + +struct NamedSlot { + std::string name; + SlotVariant slot; +}; + +// ============================================================================= +// MLXProgram - the loaded program ready for execution +// ============================================================================= + +struct MLXProgram { + std::string version; + + // Tensor/value slot counts (in Tid assignment order) + uint32_t num_constant_tensors{0}; + uint32_t num_input_tensors{0}; + uint32_t num_output_tensors{0}; + uint32_t num_mutable_buffer_tensors{0}; + uint32_t num_temp_tensors{0}; + uint32_t num_values{0}; + + // Instruction chains + std::vector> instruction_chains; + uint32_t main_chain_idx{0}; + int32_t init_chain_idx{-1}; // -1 = no init chain + + // I/O mappings + std::vector input_map; + std::vector output_map; + std::vector mutable_buffer_map; + + // Name to slot lookup + std::vector named_slots; + + // Tensor metadata + std::vector> tensor_meta; + + // Helper methods + inline uint64_t num_tensors() const { + return static_cast(num_constant_tensors) + + num_input_tensors + num_output_tensors + + num_mutable_buffer_tensors + num_temp_tensors; + } + + inline bool is_constant_tensor(Tid id) const { + return id.idx < num_constant_tensors; + } + + inline size_t num_inputs() const { + return input_map.size(); + } + + inline size_t num_outputs() const { + return output_map.size(); + } +}; + +// ============================================================================= +// FlatBuffer loading functions +// ============================================================================= + +namespace loader { + +// Convert FlatBuffer SlotType to our SlotType +inline SlotType convert_slot_type(mlx_delegate::SlotType fb_type) { + switch (fb_type) { + case mlx_delegate::SlotType_TensorSlot: + return SlotType::TensorSlot; + case mlx_delegate::SlotType_IntValueSlot: + return SlotType::IntValueSlot; + case mlx_delegate::SlotType_FloatValueSlot: + return SlotType::FloatValueSlot; + case mlx_delegate::SlotType_BoolValueSlot: + return SlotType::BoolValueSlot; + default: + throw std::runtime_error("Unknown SlotType: " + + std::to_string(static_cast(fb_type))); + } +} + +// Convert FlatBuffer Tid +inline Tid convert_tid(const mlx_delegate::Tid* fb_tid) { + if (!fb_tid) { + throw std::runtime_error("Null Tid in FlatBuffer"); + } + return Tid{fb_tid->idx()}; +} + +// Convert FlatBuffer Vid +inline Vid convert_vid(const mlx_delegate::Vid* fb_vid) { + if (!fb_vid) { + throw std::runtime_error("Null Vid in FlatBuffer"); + } + return Vid{fb_vid->idx()}; +} + +// Convert FlatBuffer IntOrVid +inline std::variant convert_int_or_vid( + const mlx_delegate::IntOrVid* fb) { + if (!fb) { + throw std::runtime_error("Null IntOrVid in FlatBuffer"); + } + if (!fb->is_vid()) { + return fb->literal(); + } + const auto* vid_ptr = fb->vid(); + if (!vid_ptr) { + throw std::runtime_error("IntOrVid has is_vid=true but vid pointer is null"); + } + return Vid{vid_ptr->idx()}; +} + +// Convert FlatBuffer FloatOrVid +inline std::variant convert_float_or_vid( + const mlx_delegate::FloatOrVid* fb) { + if (!fb) { + throw std::runtime_error("Null FloatOrVid in FlatBuffer"); + } + if (!fb->is_vid()) { + return fb->literal(); + } + const auto* vid_ptr = fb->vid(); + if (!vid_ptr) { + throw std::runtime_error("FloatOrVid has is_vid=true but vid pointer is null"); + } + return Vid{vid_ptr->idx()}; +} + +// Convert FlatBuffer VidOrTid (scalar value or tensor) +inline VidOrTid convert_vid_or_tid( + const mlx_delegate::VidOrTid* fb) { + if (!fb) { + throw std::runtime_error("Null VidOrTid in FlatBuffer"); + } + VidOrTid result; + result.is_vid = fb->is_vid(); + if (result.is_vid) { + if (!fb->vid()) { + throw std::runtime_error("VidOrTid has is_vid=true but vid pointer is null"); + } + result.vid = Vid{fb->vid()->idx()}; + } else { + if (!fb->tid()) { + throw std::runtime_error("VidOrTid has is_vid=false but tid pointer is null"); + } + result.tid = Tid{fb->tid()->idx()}; + } + return result; +} + +// Convert FlatBuffer IntOrVidOrTid (literal int, Vid, or Tid) +inline IntOrVidOrTid convert_int_or_vid_or_tid( + const mlx_delegate::IntOrVidOrTid* fb) { + if (!fb) { + throw std::runtime_error("Null IntOrVidOrTid in FlatBuffer"); + } + IntOrVidOrTid result; + result.kind = fb->kind(); + switch (result.kind) { + case 0: // literal int + result.literal = fb->literal(); + break; + case 1: { // Vid + const auto* vid_ptr = fb->vid(); + if (!vid_ptr) { + throw std::runtime_error( + "IntOrVidOrTid has kind=1 (Vid) but vid pointer is null"); + } + result.vid = Vid{vid_ptr->idx()}; + break; + } + case 2: { // Tid + const auto* tid_ptr = fb->tid(); + if (!tid_ptr) { + throw std::runtime_error( + "IntOrVidOrTid has kind=2 (Tid) but tid pointer is null"); + } + result.tid = Tid{tid_ptr->idx()}; + break; + } + default: + throw std::runtime_error( + "IntOrVidOrTid has invalid kind: " + std::to_string(result.kind)); + } + return result; +} + +// Convert FlatBuffer SlotVariant +inline SlotVariant convert_slot_variant(const mlx_delegate::SlotVariant* fb) { + if (!fb) { + throw std::runtime_error("Null SlotVariant in FlatBuffer"); + } + return SlotVariant{fb->idx(), convert_slot_type(fb->slot_type())}; +} + +// Load an instruction from FlatBuffer +Instruction load_instruction(const mlx_delegate::Instruction* fb_instr); + +// Load the full MLXProgram from FlatBuffer data +MLXProgram load_program(const void* data, size_t size); + +} // namespace loader + +} // namespace mlx +} // namespace backends +} // namespace executorch diff --git a/backends/mlx/serialization/README.md b/backends/mlx/serialization/README.md new file mode 100644 index 00000000000..f2c022d0c80 --- /dev/null +++ b/backends/mlx/serialization/README.md @@ -0,0 +1,130 @@ +# MLX Delegate Serialization + +This directory contains the serialization code for the MLX delegate, which converts +Python graph representations to FlatBuffer format for execution on Apple Silicon. + +## Single Source of Truth: `schema.fbs` + +The FlatBuffer schema file `schema.fbs` is the **single source of truth** for all +serialization-related code. When you need to add a new op or modify existing types, +edit `schema.fbs` and regenerate all derived files. + +## Code Generator + +The `generate.py` script parses `schema.fbs` and generates: + +| Generated File | Description | +|----------------|-------------| +| `mlx_graph_schema.py` | Python dataclasses for all schema types | +| `_generated_serializers.py` | Python FlatBuffer serialization methods | +| `_generated/` | Python FlatBuffer reader classes (via `flatc`) | +| `../runtime/MLXLoader.h` | C++ structs, OpCode enum, NodeVariant | +| `../runtime/MLXLoader.cpp` | C++ `load_instruction()` switch statement | +| `../runtime/schema_generated.h` | C++ FlatBuffer reader classes (via `flatc`) | + +## Usage + +### Regenerate all files + +From the executorch root directory: + +```bash +python backends/mlx/serialization/generate.py +``` + +Or with explicit flatc path: + +```bash +python backends/mlx/serialization/generate.py --flatc /path/to/flatc +``` + +### Options + +``` +--flatc PATH Path to flatc compiler (default: "flatc") +--skip-flatc Skip running flatc (use existing FlatBuffer bindings) +--dry-run Print what would be generated without writing files +``` + +## File Structure + +``` +serialization/ +├── README.md # This file +├── schema.fbs # SOURCE OF TRUTH - FlatBuffer schema +├── generate.py # Code generator script +├── mlx_graph_schema.py # [GENERATED] Python dataclasses +├── mlx_graph_serialize.py # Main serializer (uses generated code) +├── _generated_serializers.py # [GENERATED] Op serialization methods +└── _generated/ # [GENERATED] FlatBuffer Python bindings + └── mlx_delegate/ + ├── *.py # One file per table/enum + +runtime/ +├── MLXLoader.h # [GENERATED] C++ types and loader decls +├── MLXLoader.cpp # [GENERATED] C++ loader implementation +├── schema_generated.h # [GENERATED] FlatBuffer C++ bindings +├── MLXInterpreter.h # C++ executor (manual) +├── MLXExecutor.h # C++ executor interface (manual) +└── MLXBackend.cpp # ExecuTorch backend integration (manual) +``` + +## Schema Design Notes + +### Field Types + +- `Tid` - Tensor slot identifier (indexes into tensor array) +- `Vid` - Value slot identifier (indexes into values array for scalars) +- `IntOrVid` - Either a literal int64 or a Vid (for dynamic shapes) +- `FloatOrVid` - Either a literal double or a Vid +- `DTypeId` - Data type enum (f16, f32, bf16, i32, etc.) + +### Optional Fields + +FlatBuffer fields without `(required)` are optional. In the generated Python +dataclasses, these become `Optional[T]` with default `None`. + +For optional scalar fields that need a sentinel (to distinguish None from 0), +use the `= null` default: + +```flatbuffers +table MyNode { + value: float = null; // None by default, distinguishes None from 0.0 +} +``` + +This requires FlatBuffers 2.0+ (ExecuTorch uses 24.3.25). The generated Python +dataclass will have `value: Optional[float] = None`. + +## Troubleshooting + +### flatc not found + +Install FlatBuffers or specify the path: + +```bash +# macOS +brew install flatbuffers + +# Or specify path +python generate.py --flatc /usr/local/bin/flatc +``` + +### Import errors after regeneration + +Make sure you're running from the correct environment: + +```bash +conda run -n et-mlx python backends/mlx/serialization/generate.py +``` + +### Generated code doesn't match schema + +Delete all generated files and regenerate: + +```bash +rm -rf backends/mlx/serialization/_generated +rm backends/mlx/serialization/mlx_graph_schema.py +rm backends/mlx/serialization/_generated_serializers.py +python backends/mlx/serialization/generate.py +``` diff --git a/backends/mlx/serialization/__init__.py b/backends/mlx/serialization/__init__.py new file mode 100644 index 00000000000..35a4f0cef8a --- /dev/null +++ b/backends/mlx/serialization/__init__.py @@ -0,0 +1,32 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""Serialization utilities for MLX delegate.""" + +from pathlib import Path + + +_schema_py = Path(__file__).parent / "mlx_graph_schema.py" +if not _schema_py.exists(): + raise ImportError( + "MLX delegate generated files not found. " + "Run 'python install_executorch.py' first." + ) + +# Export serialization functions for convenience +from executorch.backends.mlx.serialization.mlx_graph_serialize import ( # noqa: F401, E501 + deserialize_to_json, + parse_header, + serialize_mlx_graph, +) + +__all__ = [ + "deserialize_to_json", + "parse_header", + "serialize_mlx_graph", +] diff --git a/backends/mlx/serialization/generate.py b/backends/mlx/serialization/generate.py new file mode 100755 index 00000000000..d12743906db --- /dev/null +++ b/backends/mlx/serialization/generate.py @@ -0,0 +1,1437 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +""" +Code generator for MLX delegate. + +This is the SINGLE SOURCE OF TRUTH generator. Edit schema.fbs, then run: + python generate.py + +Generates: +1. FlatBuffer bindings (via flatc): + - _generated/ (Python) + - ../runtime/schema_generated.h (C++) +2. mlx_graph_schema.py (Python dataclasses) +3. _generated_serializers.py (Python serialization code) +4. ../runtime/MLXLoader.h (C++ structs, enums) - PARTIAL +5. ../runtime/MLXLoader.cpp (C++ loader switch) - PARTIAL + +Usage: + python generate.py [--flatc PATH_TO_FLATC] [--skip-flatc] +""" + +from __future__ import annotations + +import argparse +import re +import subprocess +from dataclasses import dataclass +from pathlib import Path +from typing import List, Optional, Tuple + + +SCRIPT_DIR = Path(__file__).parent +SCHEMA_FBS = SCRIPT_DIR / "schema.fbs" +GENERATED_DIR = SCRIPT_DIR / "_generated" +GENERATED_SERIALIZERS = SCRIPT_DIR / "_generated_serializers.py" +GENERATED_SCHEMA_PY = SCRIPT_DIR / "mlx_graph_schema.py" +GENERATED_INSPECTOR = SCRIPT_DIR.parent / "_generated_inspector.py" +RUNTIME_DIR = SCRIPT_DIR.parent / "runtime" +LOADER_H_TMPL = SCRIPT_DIR / "MLXLoader.h.tmpl" +LOADER_CPP_TMPL = SCRIPT_DIR / "MLXLoader.cpp.tmpl" +LOADER_H = RUNTIME_DIR / "MLXLoader.h" +LOADER_CPP = RUNTIME_DIR / "MLXLoader.cpp" + + +@dataclass +class FBSEnum: + name: str + base_type: str # e.g., "byte" + values: List[Tuple[str, Optional[int]]] # (name, explicit_value or None) + + +@dataclass +class FBSField: + name: str + type_str: str + required: bool + default: Optional[str] + + +# FBS integer types (signed and unsigned) +FBS_INTEGER_TYPES = frozenset( + { + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + } +) + +# FBS float types +FBS_FLOAT_TYPES = frozenset({"float", "double"}) + +# All FBS primitive scalar types (numbers + bool) +FBS_SCALAR_TYPES = FBS_INTEGER_TYPES | FBS_FLOAT_TYPES | frozenset({"bool"}) + +# Compound "or" types that wrap a literal + Vid +FBS_COMPOUND_TYPES = frozenset({"IntOrVid", "FloatOrVid", "VidOrTid", "IntOrVidOrTid"}) + +# Python type mapping for FBS primitives +FBS_TO_PYTHON = { + "int8": "int", + "int16": "int", + "int32": "int", + "int64": "int", + "uint8": "int", + "uint16": "int", + "uint32": "int", + "uint64": "int", + "float": "float", + "double": "float", + "bool": "bool", + "string": "str", + "byte": "int", +} + +# C++ type mapping for FBS primitives +FBS_TO_CPP = { + "int8": "int8_t", + "int16": "int16_t", + "int32": "int32_t", + "int64": "int64_t", + "uint8": "uint8_t", + "uint16": "uint16_t", + "uint32": "uint32_t", + "uint64": "uint64_t", + "float": "float", + "double": "double", + "bool": "bool", + "string": "std::string", + "byte": "uint8_t", + "Tid": "Tid", + "Vid": "Vid", + "IntOrVid": "std::variant", + "FloatOrVid": "std::variant", +} + + +def _section_header(comment: str, title: str) -> List[str]: + """Generate a section-header banner for generated output.""" + sep = f"{comment} {'=' * 76}" + return [sep, f"{comment} {title}", sep, ""] + + +def _file_header(comment: str, description: str = "") -> List[str]: + """Generate a standard auto-generated file header. + + Args: + comment: Comment prefix, e.g. '#' for Python or '//' for C++. + description: Optional description appended after the banner. + """ + sep = f"{comment} {'=' * 76}" + lines = [ + f"{comment}", + f"{comment} Copyright (c) Meta Platforms, Inc. and affiliates.", + f"{comment} All rights reserved.", + f"{comment}", + f"{comment} This source code is licensed under the BSD-style license found in the", + f"{comment} LICENSE file in the root directory of this source tree.", + f"{comment}", + sep, + f"{comment} AUTO-GENERATED FILE - DO NOT EDIT MANUALLY", + sep, + f"{comment}", + f"{comment} This file was generated from schema.fbs by the MLX delegate code generator.", + f"{comment}", + f"{comment} Source: backends/mlx/serialization/schema.fbs", + f"{comment} Generator: backends/mlx/serialization/generate.py", + f"{comment}", + f"{comment} To regenerate, run from the executorch root:", + f"{comment} python backends/mlx/serialization/generate.py", + f"{comment}", + sep, + ] + if description: + lines.append(f"{comment}") + lines.append(f"{comment} {description}") + return lines + + +@dataclass +class FBSStruct: + name: str + fields: List[FBSField] + + +@dataclass +class FBSTable: + name: str + fields: List[FBSField] + + +@dataclass +class FBSUnion: + name: str + types: List[str] + + +@dataclass +class FBSSchema: + namespace: str + enums: List[FBSEnum] + structs: List[FBSStruct] + tables: List[FBSTable] + unions: List[FBSUnion] + + def get_op_nodes(self) -> List[FBSTable]: + """Get all tables that are part of the OpNode union.""" + op_union = next((u for u in self.unions if u.name == "OpNode"), None) + if not op_union: + return [] + op_names = set(op_union.types) + return [t for t in self.tables if t.name in op_names] + + +def parse_fbs(fbs_path: Path) -> FBSSchema: + """Parse a FlatBuffer schema file.""" + with open(fbs_path) as f: + content = f.read() + + # Remove comments + content = re.sub(r"//.*$", "", content, flags=re.MULTILINE) + + namespace = "" + enums: List[FBSEnum] = [] + structs: List[FBSStruct] = [] + tables: List[FBSTable] = [] + unions: List[FBSUnion] = [] + + # Parse namespace + ns_match = re.search(r"namespace\s+(\w+)\s*;", content) + if ns_match: + namespace = ns_match.group(1) + + # Parse enums + for match in re.finditer(r"enum\s+(\w+)\s*:\s*(\w+)\s*\{([^}]+)\}", content): + enum_name = match.group(1) + base_type = match.group(2) + body = match.group(3) + values = [] + for val_match in re.finditer(r"(\w+)\s*(?:=\s*(\d+))?", body): + name = val_match.group(1) + explicit_val = int(val_match.group(2)) if val_match.group(2) else None + values.append((name, explicit_val)) + enums.append(FBSEnum(enum_name, base_type, values)) + + # Parse structs + for match in re.finditer(r"struct\s+(\w+)\s*\{([^}]+)\}", content): + struct_name = match.group(1) + body = match.group(2) + fields = _parse_fields(body) + structs.append(FBSStruct(struct_name, fields)) + + # Parse tables + for match in re.finditer(r"table\s+(\w+)\s*\{([^}]*)\}", content): + table_name = match.group(1) + body = match.group(2) + fields = _parse_fields(body) + tables.append(FBSTable(table_name, fields)) + + # Parse unions + for match in re.finditer(r"union\s+(\w+)\s*\{([^}]+)\}", content): + union_name = match.group(1) + body = match.group(2) + types = [t.strip() for t in body.split(",") if t.strip()] + unions.append(FBSUnion(union_name, types)) + + return FBSSchema(namespace, enums, structs, tables, unions) + + +def _parse_fields(body: str) -> List[FBSField]: + """Parse fields from a struct/table body.""" + fields = [] + for line in body.split(";"): + line = line.strip() + if not line: + continue + + # Parse: name: type (attributes) = default + match = re.match( + r"(\w+)\s*:\s*(\[?\w+\]?)\s*(?:\(([^)]*)\))?\s*(?:=\s*([^;]+))?", line + ) + if match: + name = match.group(1) + type_str = match.group(2) + attrs = match.group(3) or "" + default = match.group(4).strip() if match.group(4) else None + required = "required" in attrs + fields.append(FBSField(name, type_str, required, default)) + + return fields + + +# Config for compound type factory methods. +# Maps compound type name -> (primary_field_name, primary_python_type, description) +_COMPOUND_TYPE_CONFIG = { + "IntOrVid": ("literal", "int", "a literal integer"), + "FloatOrVid": ("literal", "float", "a literal float"), + "VidOrTid": ("tid", "Tid", "a tensor reference"), + "IntOrVidOrTid": ("literal", "int", "a literal integer"), +} + + +def _generate_compound_type(table: FBSTable) -> List[str]: # noqa: C901 + """Generate a Python dataclass for a compound type (IntOrVid, etc.) from schema.""" + name = table.name + config = _COMPOUND_TYPE_CONFIG.get(name) + if not config: + raise ValueError(f"No compound type config for '{name}'") + + primary_field, primary_py_type, primary_desc = config + + # Build the docstring from the schema structure + lines = [ + "@dataclass", + f"class {name}:", + ] + + # Docstring: describe the two alternatives + lines.append( + f' """Represents either {primary_desc} or a runtime Vid reference."""' + ) + + # Dataclass fields from the parsed schema + for fld in table.fields: + if fld.default == "false": + default = "False" + elif fld.default == "true": + default = "True" + elif fld.type_str in ("Tid", "Vid"): + default = "None" + elif fld.default is not None: + default = fld.default + elif fld.type_str in FBS_INTEGER_TYPES: + default = "0" + elif fld.type_str in FBS_FLOAT_TYPES: + default = "0.0" + else: + default = "None" + truly_required = default != "None" + py_type = _fbs_type_to_python(fld.type_str, truly_required) + lines.append(f" {fld.name}: {py_type} = {default}") + + # Check if this is a 3-way discriminator (IntOrVidOrTid uses 'kind') + has_kind = any(fld.name == "kind" for fld in table.fields) + has_tid = any(fld.name == "tid" for fld in table.fields) + + # Factory: from_primary (e.g. from_literal, from_tid) + lines.append("") + lines.append(" @classmethod") + lines.append( + f' def from_{primary_field}(cls, value: {primary_py_type}) -> "{name}":' + ) + lines.append(f' """Create a {name} from {primary_desc}."""') + if has_kind: + lines.append(f" return cls({primary_field}=value, kind=0)") + else: + lines.append(f" return cls({primary_field}=value, is_vid=False)") + + # Factory: from_vid + lines.append("") + lines.append(" @classmethod") + lines.append(f' def from_vid(cls, vid: Vid) -> "{name}":') + lines.append(f' """Create a {name} from a Vid reference."""') + if has_kind: + lines.append(" return cls(vid=vid, kind=1)") + else: + lines.append(" return cls(vid=vid, is_vid=True)") + + # Factory: from_tid (only for types with a tid field) + if has_tid: + lines.append("") + lines.append(" @classmethod") + lines.append(f' def from_tid(cls, tid: Tid) -> "{name}":') + lines.append(f' """Create a {name} from a Tid tensor reference."""') + if has_kind: + lines.append(" return cls(tid=tid, kind=2)") + else: + lines.append(" return cls(tid=tid, is_vid=False)") + + lines.append("") + return lines + + +def _generate_dataclass(table: FBSTable) -> List[str]: + """Generate a Python @dataclass from a parsed FBS table. + + Handles field ordering (required/defaulted before optional), skips + _is_set sentinel fields, and emits proper type annotations with defaults. + """ + lines = ["@dataclass", f"class {table.name}:"] + fields = [f for f in table.fields if not f.name.endswith("_is_set")] + if not fields: + lines.append(" pass") + else: + required_fields = [f for f in fields if f.required or f.default is not None] + optional_fields = [f for f in fields if not f.required and f.default is None] + + for fld in required_fields: + py_type = _fbs_type_to_python(fld.type_str, True) + default = _fbs_default_to_python(fld.default, fld.type_str) + if default is not None: + lines.append(f" {fld.name}: {py_type} = {default}") + else: + lines.append(f" {fld.name}: {py_type}") + + for fld in optional_fields: + py_type = _fbs_type_to_python(fld.type_str, fld.required) + lines.append(f" {fld.name}: {py_type} = None") + + lines.extend(["", ""]) + return lines + + +def generate_python_schema(schema: FBSSchema) -> str: # noqa: C901 + """Generate mlx_graph_schema.py from parsed FBS.""" + lines = _file_header("#") + lines.extend( + [ + "", + "from __future__ import annotations", + "", + "from dataclasses import dataclass, field", + "from enum import IntEnum", + "from typing import List, Optional, Union", + "", + "", + *_section_header("#", "Enums"), + ] + ) + + # Generate enums + for enum in schema.enums: + lines.append(f"class {enum.name}(IntEnum):") + val = 0 + for name, explicit_val in enum.values: + if explicit_val is not None: + val = explicit_val + lines.append(f" {name} = {val}") + val += 1 + lines.append("") + lines.append("") + + lines.extend(_section_header("#", "Core types")) + + # Generate structs (Tid, Vid) + for struct in schema.structs: + lines.append("@dataclass") + lines.append(f"class {struct.name}:") + for fld in struct.fields: + py_type = _fbs_type_to_python(fld.type_str, fld.required) + default = _fbs_default_to_python(fld.default, fld.type_str) + if default: + lines.append(f" {fld.name}: {py_type} = {default}") + else: + lines.append(f" {fld.name}: {py_type}") + lines.append("") + lines.append("") + + # Generate compound types (IntOrVid, FloatOrVid, TidOrVid) from schema + for type_name in sorted(FBS_COMPOUND_TYPES): + table = next((t for t in schema.tables if t.name == type_name), None) + if table: + lines.extend(_generate_compound_type(table)) + lines.append("") + + # Generate ShapeDim, SlotVariant, NamedSlot, TensorMeta (but not Instruction/MLXGraph yet - they reference OpNode) + other_tables = ["ShapeDim", "SlotVariant", "NamedSlot", "TensorMeta"] + for table_name in other_tables: + table = next((t for t in schema.tables if t.name == table_name), None) + if table: + lines.extend(_generate_dataclass(table)) + + lines.extend(_section_header("#", "Op nodes")) + + # Generate op node dataclasses + op_nodes = schema.get_op_nodes() + for table in op_nodes: + lines.extend(_generate_dataclass(table)) + + # Generate OpNodeUnion type alias + op_names = [t.name for t in op_nodes] + lines.append("# Union of all op types") + lines.append("OpNodeUnion = Union[") + for name in op_names: + lines.append(f" {name},") + lines.append("]") + lines.append("") + + # Generate Instruction and MLXGraph (these reference OpNode so must come after) + lines.extend( + [ + *_section_header("#", "Container types (reference OpNodeUnion)"), + "@dataclass", + "class Instruction:", + " op: OpNodeUnion", + "", + "", + "@dataclass", + "class InstructionChain:", + " instructions: List[Instruction]", + "", + "", + "@dataclass", + "class MLXGraph:", + " instruction_chains: List[InstructionChain]", + " version: Optional[str] = None", + " num_constant_tensors: int = 0", + " num_input_tensors: int = 0", + " num_output_tensors: int = 0", + " num_mutable_buffer_tensors: int = 0", + " num_temp_tensors: int = 0", + " num_values: int = 0", + " main_chain_idx: int = 0", + " init_chain_idx: int = -1", + " input_map: Optional[List[SlotVariant]] = None", + " output_map: Optional[List[SlotVariant]] = None", + " mutable_buffer_map: Optional[List[SlotVariant]] = None", + " named_slots: Optional[List[NamedSlot]] = None", + " tensor_meta: Optional[List[TensorMeta]] = None", + "", + ] + ) + + return "\n".join(lines) + + +def _fbs_type_to_python(fbs_type: str, required: bool) -> str: + """Convert FBS type to Python type annotation. + + When required=False, the result is wrapped in Optional[…] for all types + (scalars, lists, and reference types alike). + """ + # Handle arrays + if fbs_type.startswith("[") and fbs_type.endswith("]"): + inner = fbs_type[1:-1] + inner_py = _fbs_type_to_python(inner, True) + base = f"List[{inner_py}]" + return base if required else f"Optional[{base}]" + + py_type = FBS_TO_PYTHON.get(fbs_type, fbs_type) + + if not required: + return f"Optional[{py_type}]" + + return py_type + + +def _fbs_default_to_python(default: Optional[str], fbs_type: str) -> Optional[str]: + """Convert FBS default value to Python.""" + if default is None: + return None + + if default == "false": + return "False" + if default == "true": + return "True" + if default == "null": + return "None" + + # Handle enum defaults like 'TensorSlot' + if fbs_type == "SlotType": + return f"SlotType.{default}" + + # Numeric defaults + return default + + +def generate_python_serializers(schema: FBSSchema) -> str: + """Generate _generated_serializers.py from parsed FBS.""" + op_nodes = schema.get_op_nodes() + op_union = next((u for u in schema.unions if u.name == "OpNode"), None) + + header = _file_header( + "#", + "This file contains auto-generated serializer methods for all op types.", + ) + + # Imports and module-level code + op_imports = ",\n".join(f" {t.name}" for t in op_nodes) + lines = [ + *header, + "", + "from __future__ import annotations", + "", + "from typing import List, Tuple, Dict", + "", + "import flatbuffers", + "", + ] + + # Generate op type names dict from union order + lines.append( + "# FlatBuffer union indices: 0 = NONE, then 1-indexed from union order" + ) + lines.append("MLX_OP_TYPE_NAMES = {") + lines.append(' 0: "NONE",') + if op_union: + for i, type_name in enumerate(op_union.types, start=1): + lines.append(f' {i}: "{type_name}",') + lines.append("}") + lines.append("") + + lines.extend( + [ + "from executorch.backends.mlx.serialization.mlx_graph_schema import (", + f"{op_imports},", + " IntOrVid,", + " FloatOrVid,", + " VidOrTid,", + " IntOrVidOrTid,", + " Tid,", + " Vid,", + ")", + "", + "", + "def _build_int_vector(builder: flatbuffers.Builder, vec: List[int]) -> int:", + ' """Build a vector of int32."""', + " builder.StartVector(4, len(vec), 4)", + " for v in reversed(vec):", + " builder.PrependInt32(v)", + " return builder.EndVector()", + "", + "", + "class GeneratedOpBuilders:", + ' """Mixin class with auto-generated op builder methods."""', + "", + " def _build_int_or_vid(self, builder: flatbuffers.Builder, iov: IntOrVid) -> int:", + ' """Build an IntOrVid table."""', + " from executorch.backends.mlx.serialization._generated.mlx_delegate import IntOrVid as FBIntOrVidModule", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", + "", + " FBIntOrVidModule.Start(builder)", + " FBIntOrVidModule.AddLiteral(builder, iov.literal)", + " FBIntOrVidModule.AddIsVid(builder, iov.is_vid)", + " if iov.vid is not None:", + " # Vid is an inline struct - must be added last for proper FlatBuffer layout", + " FBIntOrVidModule.AddVid(builder, CreateVid(builder, iov.vid.idx))", + " return FBIntOrVidModule.End(builder)", + "", + " def _build_float_or_vid(self, builder: flatbuffers.Builder, fov: FloatOrVid) -> int:", + ' """Build a FloatOrVid table."""', + " from executorch.backends.mlx.serialization._generated.mlx_delegate import FloatOrVid as FBFloatOrVidModule", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", + "", + " FBFloatOrVidModule.Start(builder)", + " FBFloatOrVidModule.AddLiteral(builder, fov.literal)", + " FBFloatOrVidModule.AddIsVid(builder, fov.is_vid)", + " if fov.vid is not None:", + " FBFloatOrVidModule.AddVid(builder, CreateVid(builder, fov.vid.idx))", + " return FBFloatOrVidModule.End(builder)", + "", + " def _build_vid_or_tid(self, builder: flatbuffers.Builder, vot: VidOrTid) -> int:", + ' """Build a TidOrVid table."""', + " from executorch.backends.mlx.serialization._generated.mlx_delegate import VidOrTid as FBVidOrTidModule", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", + "", + " FBVidOrTidModule.Start(builder)", + " FBVidOrTidModule.AddIsVid(builder, vot.is_vid)", + " if vot.tid is not None:", + " FBVidOrTidModule.AddTid(builder, CreateTid(builder, vot.tid.idx))", + " if vot.vid is not None:", + " FBVidOrTidModule.AddVid(builder, CreateVid(builder, vot.vid.idx))", + " return FBVidOrTidModule.End(builder)", + "", + " def _build_int_or_vid_or_tid(self, builder: flatbuffers.Builder, ivt: IntOrVidOrTid) -> int:", + ' """Build an IntOrVidOrTid table."""', + " from executorch.backends.mlx.serialization._generated.mlx_delegate import IntOrVidOrTid as FBIntOrVidOrTidModule", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", + "", + " FBIntOrVidOrTidModule.Start(builder)", + " FBIntOrVidOrTidModule.AddLiteral(builder, ivt.literal)", + " FBIntOrVidOrTidModule.AddKind(builder, ivt.kind)", + " if ivt.tid is not None:", + " FBIntOrVidOrTidModule.AddTid(builder, CreateTid(builder, ivt.tid.idx))", + " if ivt.vid is not None:", + " FBIntOrVidOrTidModule.AddVid(builder, CreateVid(builder, ivt.vid.idx))", + " return FBIntOrVidOrTidModule.End(builder)", + "", + " def _build_int_or_vid_vector(", + " self, builder: flatbuffers.Builder, vec: List[IntOrVid]", + " ) -> int:", + ' """Build a vector of IntOrVid tables."""', + " offsets = []", + " for iov in vec:", + " offsets.append(self._build_int_or_vid(builder, iov))", + " builder.StartVector(4, len(offsets), 4)", + " for off in reversed(offsets):", + " builder.PrependUOffsetTRelative(off)", + " return builder.EndVector()", + "", + " def _build_tid_vector(", + " self, builder: flatbuffers.Builder, vec: List[Tid]", + " ) -> int:", + ' """Build a vector of Tid structs."""', + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid", + "", + " # For vectors of structs, we need to build the vector differently", + " # Each Tid struct is 4 bytes (uint32), so we manually write them", + " builder.StartVector(4, len(vec), 4)", + " for tid in reversed(vec):", + " builder.Prep(4, 0) # Align for struct", + " builder.PrependUint32(tid.idx)", + " return builder.EndVector()", + "", + ] + ) + + # Generate builder methods for each op + for table in op_nodes: + lines.append(_generate_op_builder_method(table)) + + return "\n".join(lines) + + +def _generate_op_builder_method(table: FBSTable) -> str: + """Generate a _build_XxxNode method for the serializer class.""" + class_name = table.name + fb_module_name = f"FB{class_name}Module" + + lines = [ + f" def _build_{class_name}(", + f" self, builder: flatbuffers.Builder, op: {class_name}", + " ) -> Tuple[int, int]:", + f' """Auto-generated builder for {class_name}."""', + " # Import the MODULE (not class) to access builder functions like Start(), Add*(), End()", + f" from executorch.backends.mlx.serialization._generated.mlx_delegate import {class_name} as {fb_module_name}", + " from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", + "", + ] + + # Pre-build any strings or vectors (must be done before Start) + prebuild_lines = [] + for fld in table.fields: + if fld.name.endswith("_is_set"): + continue + kind = _get_field_kind(fld, table) + pb = _emit_py_prebuild(kind, fld) + if pb: + prebuild_lines.extend(pb) + + if prebuild_lines: + lines.extend(prebuild_lines) + lines.append("") + + # Start the FlatBuffer table + lines.append(f" {fb_module_name}.Start(builder)") + + # Add each field + for fld in table.fields: + if fld.name.endswith("_is_set"): + continue + fb_field_name = _to_pascal_case(fld.name) + kind = _get_field_kind(fld, table) + add_lines = _emit_py_add(kind, fld, fb_module_name, fb_field_name) + if add_lines is None: + raise ValueError( + f"Unhandled field kind '{kind}' for field '{fld.name}' in table '{table.name}'. " + f"Add a handler in _emit_py_add()." + ) + lines.extend(add_lines) + + # End the FlatBuffer table and return offset + union type + lines.append(f" offset = {fb_module_name}.End(builder)") + lines.append(f" return offset, FBOpNodeModule.OpNode.{class_name}") + lines.append("") + + return "\n".join(lines) + + +# Prebuild emitters: return list of lines or None if no prebuild needed. +# These build offsets/vectors that must be created before FlatBuffer Start(). + +_PY_PREBUILD_VECTOR = { + "list_int": "_build_int_vector(builder, op.{name})", + "list_int_or_vid": "self._build_int_or_vid_vector(builder, op.{name})", + "list_tid": "self._build_tid_vector(builder, op.{name})", +} + +_PY_PREBUILD_OFFSET = { + "str": "builder.CreateString(op.{name})", + "int_or_vid": "self._build_int_or_vid(builder, op.{name})", + "float_or_vid": "self._build_float_or_vid(builder, op.{name})", + "vid_or_tid": "self._build_vid_or_tid(builder, op.{name})", + "int_or_vid_or_tid": "self._build_int_or_vid_or_tid(builder, op.{name})", + "optional_str": "builder.CreateString(op.{name}) if op.{name} is not None else None", +} + + +def _emit_py_prebuild(kind: str, fld: FBSField) -> List[str]: + """Emit prebuild lines for a field kind, or empty list if none needed.""" + n = fld.name + if kind in _PY_PREBUILD_VECTOR: + expr = _PY_PREBUILD_VECTOR[kind].format(name=n) + if fld.required: + return [f" {n}_vec = {expr}"] + else: + return [f" {n}_vec = {expr} if op.{n} is not None else None"] + if kind in _PY_PREBUILD_OFFSET: + suffix = "_off" + expr = _PY_PREBUILD_OFFSET[kind].format(name=n) + return [f" {n}{suffix} = {expr}"] + return [] + + +# Maps struct kinds to their Python Create function name +_PY_STRUCT_CREATOR = {"tid": "CreateTid", "vid": "CreateVid"} + + +def _emit_py_add( + kind: str, fld: FBSField, mod: str, fb_name: str +) -> "List[str] | None": + """Emit Add lines for a field kind, or None if kind is unrecognized.""" + n = fld.name + add = f"{mod}.Add{fb_name}" + + # Required struct via inline Create call + if kind in _PY_STRUCT_CREATOR: + creator = _PY_STRUCT_CREATOR[kind] + return [f" {add}(builder, {creator}(builder, op.{n}.idx))"] + # Scalars (direct value) + if kind in ("int", "float", "bool"): + return [f" {add}(builder, op.{n})"] + # Pre-built offsets (string, compound types) + if kind in ("str", "int_or_vid", "float_or_vid", "vid_or_tid", "int_or_vid_or_tid"): + return [f" {add}(builder, {n}_off)"] + # Pre-built vectors (required vs optional) + if kind in ("list_int", "list_int_or_vid", "list_tid"): + if fld.required: + return [f" {add}(builder, {n}_vec)"] + return [ + f" if {n}_vec is not None:", + f" {add}(builder, {n}_vec)", + ] + # Optional struct via inline Create call + if kind in ("optional_tid", "optional_vid"): + creator = _PY_STRUCT_CREATOR[kind.removeprefix("optional_")] + return [ + f" if op.{n} is not None:", + f" {add}(builder, {creator}(builder, op.{n}.idx))", + ] + # Optional scalars + if kind in ("optional_float", "optional_int"): + return [ + f" if op.{n} is not None:", + f" {add}(builder, op.{n})", + ] + # Optional string offset + if kind == "optional_str": + return [ + f" if {n}_off is not None:", + f" {add}(builder, {n}_off)", + ] + return None + + +def _get_field_kind(fld: FBSField, table: FBSTable) -> str: # noqa: C901 + """Classify a field into a canonical kind string. + + This is the single source of truth for field classification, used by all + generators (Python builder, C++ loader, and inspector via _INSPECTOR_KIND_MAP). + """ + t = fld.type_str + + # Handle arrays + if t.startswith("[") and t.endswith("]"): + inner = t[1:-1] + if inner in FBS_INTEGER_TYPES: + return "list_int" + if inner == "IntOrVid": + return "list_int_or_vid" + if inner == "Tid": + return "list_tid" + raise ValueError( + f"Unrecognized array element type '{inner}' for field '{fld.name}' in table '{table.name}'. " + f"Add a handler in _get_field_kind()." + ) + + # Handle basic types + if t == "Tid": + return "optional_tid" if not fld.required else "tid" + if t == "Vid": + return "optional_vid" if not fld.required else "vid" + if t == "IntOrVid": + return "int_or_vid" + if t == "FloatOrVid": + return "float_or_vid" + if t == "VidOrTid": + return "vid_or_tid" + if t == "IntOrVidOrTid": + return "int_or_vid_or_tid" + if t in FBS_INTEGER_TYPES: + if fld.default == "null": + return "optional_int" + return "int" + if t in FBS_FLOAT_TYPES: + # Check if this is optional (has = null default) + if fld.default == "null": + return "optional_float" + return "float" + if t == "bool": + return "bool" + if t == "string": + return "optional_str" if not fld.required else "str" + + raise ValueError( + f"Unrecognized field type '{t}' for field '{fld.name}' in table '{table.name}'. " + f"Add a handler in _get_field_kind()." + ) + + +def _to_pascal_case(name: str) -> str: + """Convert snake_case to PascalCase.""" + # Handle special cases + if name == "table_": + return "Table_" + parts = name.split("_") + return "".join(p.capitalize() for p in parts) + + +def generate_cpp_loader_h(schema: FBSSchema) -> str: + """Generate MLXLoader.h from parsed FBS using template.""" + op_nodes = schema.get_op_nodes() + + struct_lines = [] + for table in op_nodes: + struct_lines.append(f"struct {table.name} {{") + if not table.fields: + struct_lines.append("};") + else: + for fld in table.fields: + if fld.name.endswith("_is_set"): + continue + cpp_type = _fbs_type_to_cpp(fld.type_str, fld.required, table, fld) + struct_lines.append(f" {cpp_type} {fld.name};") + struct_lines.append("};") + struct_lines.append("") + + enum_lines = [] + for table in op_nodes: + enum_lines.append(f" {_table_name_to_opcode(table.name)},") + + name_lines = [] + for table in op_nodes: + op_code = _table_name_to_opcode(table.name) + name_lines.append(f" case OpCode::{op_code}:") + name_lines.append(f' return "{op_code}";') + + variant_lines = [] + for i, table in enumerate(op_nodes): + comma = "," if i < len(op_nodes) - 1 else "" + variant_lines.append(f" {table.name}{comma}") + + # Read template and fill placeholders + header = "\n".join(_file_header("//")) + "\n//\n" + tmpl = LOADER_H_TMPL.read_text() + result = tmpl.replace("{{OP_NODE_STRUCTS}}", "\n".join(struct_lines)) + result = result.replace("{{OPCODE_ENUM_VALUES}}", "\n".join(enum_lines)) + result = result.replace("{{OP_NAME_CASES}}", "\n".join(name_lines)) + result = result.replace("{{NODE_VARIANT_TYPES}}", "\n".join(variant_lines)) + return header + result + + +def _fbs_type_to_cpp( + fbs_type: str, + required: bool, + table: Optional["FBSTable"] = None, + fld: Optional["FBSField"] = None, +) -> str: + """Convert FBS type to C++ type. + + Args: + fbs_type: The FlatBuffer type string + required: Whether the field is required + table: Optional table context for type inference + fld: Optional field context for the current field + + Note: Most scalar types (float, int, etc.) are never optional in C++. + The Python serialization layer is responsible for ensuring scalar fields + have values (using defaults if user doesn't provide them). + Reference types (Tid, Vid) and DTypeId with '= null' default can be optional. + """ + # Handle arrays + if fbs_type.startswith("[") and fbs_type.endswith("]"): + inner = fbs_type[1:-1] + inner_cpp = _fbs_type_to_cpp(inner, True) + return f"std::vector<{inner_cpp}>" + + cpp_type = FBS_TO_CPP.get(fbs_type, fbs_type) + + # Handle optional types + if not required: + if fbs_type == "Tid": + return "std::optional" + if fbs_type == "Vid": + return "std::optional" + if fld is not None and fld.default == "null" and fbs_type in FBS_TO_CPP: + return f"std::optional<{cpp_type}>" + + return cpp_type + + +_OPCODE_OVERRIDES = { + "ARange": "ARANGE", + "AsType": "ASTYPE", + "Conv1D": "CONV1D", + "Conv2D": "CONV2D", + "Conv3D": "CONV3D", + "ConvTranspose1D": "CONV_TRANSPOSE1D", + "ConvTranspose2D": "CONV_TRANSPOSE2D", + "ConvTranspose3D": "CONV_TRANSPOSE3D", +} + + +def _table_name_to_opcode(name: str) -> str: + """Convert table name like 'LinearNode' to opcode like 'LINEAR'. + + Uses regex-based camelCase → UPPER_SNAKE_CASE conversion with a small + override dict for names whose conventional opcode doesn't follow the + normal camelCase splitting rules (e.g. Conv1D → CONV1D, not CONV1_D). + """ + name = name.removesuffix("Node") + if name in _OPCODE_OVERRIDES: + return _OPCODE_OVERRIDES[name] + s = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", name) + s = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1_\2", s) + return s.upper() + + +def generate_cpp_loader_cpp(schema: FBSSchema) -> str: + """Generate MLXLoader.cpp from parsed FBS using template.""" + op_nodes = schema.get_op_nodes() + + case_lines = [] + for table in op_nodes: + case_lines.extend(_generate_loader_case(table)) + + # Read template and fill placeholders + header = "\n".join(_file_header("//")) + "\n" + tmpl = LOADER_CPP_TMPL.read_text() + result = tmpl.replace("{{LOAD_INSTRUCTION_CASES}}", "\n".join(case_lines)) + return header + result + + +def _generate_loader_case(table: FBSTable) -> List[str]: + """Generate a switch case for loading an op node.""" + class_name = table.name + op_code = _table_name_to_opcode(class_name) + + lines = [ + f" case mlx_delegate::OpNode_{class_name}: {{", + ] + + if not table.fields: + # NoopNode case + lines.extend( + [ + f" instr.op = OpCode::{op_code};", + f" instr.node = {class_name}{{}};", + " break;", + " }", + "", + ] + ) + return lines + + lines.append(f" auto fb = fb_instr->op_as_{class_name}();") + lines.append(" if (!fb) {{") + lines.append( + ' throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}");' + ) + lines.append(" }}") + lines.append(f" {class_name} node;") + + for fld in table.fields: + if fld.name.endswith("_is_set"): + continue + + fb_field_name = fld.name + kind = _get_field_kind(fld, table) + load_lines = _emit_cpp_load(kind, fld.name, fb_field_name) + if load_lines is None: + raise ValueError( + f"Unhandled field kind '{kind}' for field '{fld.name}' in table '{table.name}'. " + f"Add a handler in _emit_cpp_load()." + ) + lines.extend(load_lines) + + lines.extend( + [ + f" instr.op = OpCode::{op_code};", + " instr.node = std::move(node);", + " break;", + " }", + "", + ] + ) + + return lines + + +# Maps kinds to their C++ converter function name +_CPP_CONVERTER = { + "tid": "convert_tid", + "vid": "convert_vid", + "int_or_vid": "convert_int_or_vid", + "float_or_vid": "convert_float_or_vid", + "vid_or_tid": "convert_vid_or_tid", + "int_or_vid_or_tid": "convert_int_or_vid_or_tid", +} + + +def _emit_cpp_load(kind: str, name: str, fb_name: str) -> "List[str] | None": + """Emit C++ load lines for a field kind, or None if kind is unrecognized.""" + # Required struct / compound via converter + if kind in _CPP_CONVERTER: + conv = _CPP_CONVERTER[kind] + return [f" node.{name} = {conv}(fb->{fb_name}());"] + # Scalars (direct value) + if kind in ("int", "float", "bool"): + return [f" node.{name} = fb->{fb_name}();"] + # Required string + if kind == "str": + return [f' node.{name} = fb->{fb_name}() ? fb->{fb_name}()->str() : "";'] + # Optional struct / compound via guarded converter + base_kind = kind.removeprefix("optional_") + if kind.startswith("optional_") and base_kind in _CPP_CONVERTER: + conv = _CPP_CONVERTER[base_kind] + return [ + f" if (fb->{fb_name}()) {{", + f" node.{name} = {conv}(fb->{fb_name}());", + " }", + ] + # Optional scalar (FlatBuffers returns flatbuffers::Optional) + if kind in ("optional_float", "optional_int"): + return [ + f" auto {fb_name}_opt = fb->{fb_name}();", + f" if ({fb_name}_opt.has_value()) {{", + f" node.{name} = {fb_name}_opt.value();", + " }", + ] + # Optional string + if kind == "optional_str": + return [ + f" if (fb->{fb_name}()) {{", + f" node.{name} = fb->{fb_name}()->str();", + " }", + ] + # Integer/bool vector via to_vector + if kind == "list_int": + return [f" node.{name} = to_vector(fb->{fb_name}());"] + # Int-or-vid vector (indexed access) + if kind == "list_int_or_vid": + return [ + f" if (fb->{fb_name}()) {{", + f" for (size_t i = 0; i < fb->{fb_name}()->size(); ++i) {{", + f" node.{name}.push_back(convert_int_or_vid(fb->{fb_name}()->Get(static_cast(i))));", + " }", + " }", + ] + # Tid vector (range-based iteration) + if kind == "list_tid": + return [ + f" if (fb->{fb_name}()) {{", + f" for (auto fb_tid : *fb->{fb_name}()) {{", + f" node.{name}.push_back(convert_tid(fb_tid));", + " }", + " }", + ] + return None + + +def run_flatc(flatc_path: str = "flatc") -> bool: + """Run flatc to generate Python and C++ bindings.""" + print(f"Running flatc on {SCHEMA_FBS}...") + + # Create output directories + GENERATED_DIR.mkdir(parents=True, exist_ok=True) + + success = True + + # Generate Python bindings + cmd_py = [ + flatc_path, + "--python", + "-o", + str(GENERATED_DIR), + str(SCHEMA_FBS), + ] + try: + result = subprocess.run(cmd_py, capture_output=True, text=True) + if result.returncode != 0: + print(f"flatc (Python) failed: {result.stderr}") + success = False + else: + print(f"Generated FlatBuffer Python bindings in {GENERATED_DIR}") + except FileNotFoundError: + print(f"flatc not found at '{flatc_path}'. Skipping FlatBuffer generation.") + success = False + + # Generate C++ bindings + cmd_cpp = [ + flatc_path, + "--cpp", + "-o", + str(RUNTIME_DIR), + str(SCHEMA_FBS), + ] + try: + result = subprocess.run(cmd_cpp, capture_output=True, text=True) + if result.returncode != 0: + print(f"flatc (C++) failed: {result.stderr}") + success = False + else: + print(f"Generated FlatBuffer C++ bindings in {RUNTIME_DIR}") + except FileNotFoundError: + success = False + + return success + + +_FLATC_IMPORT_PREFIX = "executorch.backends.mlx.serialization._generated." + + +def _fixup_flatc_imports() -> None: + """Rewrite bare ``from mlx_delegate.X`` imports in generated FlatBuffer code. + + ``flatc --python`` emits lazy imports like ``from mlx_delegate.Tid import Tid`` + inside accessor methods. These only resolve if the ``_generated/`` directory is + on ``sys.path``. We rewrite them to fully-qualified imports so no ``sys.path`` + manipulation is needed at runtime. + """ + fb_dir = GENERATED_DIR / "mlx_delegate" + if not fb_dir.exists(): + return + + count = 0 + for py_file in fb_dir.glob("*.py"): + content = py_file.read_text() + if "from mlx_delegate." not in content: + continue + new_content = content.replace( + "from mlx_delegate.", f"from {_FLATC_IMPORT_PREFIX}mlx_delegate." + ) + if new_content != content: + py_file.write_text(new_content) + count += 1 + + if count: + print(f"Fixed bare imports in {count} generated FlatBuffer file(s)") + + +# Mapping from fine-grained field kinds (from _get_field_kind) to inspector +# display kinds. The inspector uses coarser categories: optional/required +# distinctions collapse, and int/float/bool all map to "scalar". +_INSPECTOR_KIND_MAP = { + "tid": "tid", + "optional_tid": "tid", + "vid": "vid", + "optional_vid": "vid", + "int_or_vid": "int_or_vid", + "float_or_vid": "float_or_vid", + "vid_or_tid": "vid_or_tid", + "int_or_vid_or_tid": "int_or_vid_or_tid", + "list_int": "int_list", + "list_int_or_vid": "int_or_vid_list", + "list_tid": "tid_list", + "int": "scalar", + "optional_int": "scalar", + "float": "scalar", + "optional_float": "scalar", + "bool": "scalar", + "str": "string", + "optional_str": "string", +} + + +def generate_inspector(schema: "Schema") -> str: # noqa: F821 + """Generate the inspector field mappings file.""" + lines = _file_header("#") + lines.extend( + [ + "", + '"""', + "Auto-generated inspector field mappings for MLX delegate.", + "", + "This module provides field metadata for each op node type, enabling", + "the pte_inspector to parse FlatBuffer op nodes without manually", + "maintaining field mappings.", + '"""', + "", + "from __future__ import annotations", + "", + "from typing import Dict, List, Tuple", + "", + "", + "# Field kinds and their extractors", + "# Each field is a tuple of (display_name, accessor_name, kind)", + "# where kind is one of: 'tid', 'vid', 'int_or_vid', 'float_or_vid',", + "# 'int_list', 'int_or_vid_list', 'tid_list', 'scalar', 'string'", + "", + "FieldSpec = Tuple[str, str, str] # (display_name, accessor_name, kind)", + "", + "", + "# Mapping from op node name to list of field specs", + "OP_NODE_FIELDS: Dict[str, List[FieldSpec]] = {", + ] + ) + + op_nodes = schema.get_op_nodes() + + for table in op_nodes: + lines.append(f' "{table.name}": [') + for fld in table.fields: + # Skip fields ending in _is_set (legacy pattern) + if fld.name.endswith("_is_set"): + continue + + kind = _get_field_kind(fld, table) + inspector_kind = _INSPECTOR_KIND_MAP.get(kind) + if inspector_kind is None: + raise ValueError( + f"No inspector mapping for field kind '{kind}' " + f"(field '{fld.name}' in table '{table.name}'). " + f"Add a mapping in _INSPECTOR_KIND_MAP." + ) + accessor = _to_pascal_case(fld.name) + lines.append(f' ("{fld.name}", "{accessor}", "{inspector_kind}"),') + lines.append(" ],") + + lines.append("}") + lines.append("") + lines.append("") + + # Add the list of op node names for import generation + lines.append("# List of all op node names (for dynamic imports)") + lines.append("OP_NODE_NAMES: List[str] = [") + for table in op_nodes: + lines.append(f' "{table.name}",') + lines.append("]") + lines.append("") + + return "\n".join(lines) + + +def main(): # noqa: C901 + parser = argparse.ArgumentParser( + description="Generate MLX delegate code from schema.fbs" + ) + parser.add_argument( + "--flatc", + default="flatc", + help="Path to flatc compiler", + ) + parser.add_argument( + "--skip-flatc", + action="store_true", + help="Skip running flatc (use existing generated files)", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Print what would be generated without writing files", + ) + args = parser.parse_args() + + print(f"Parsing {SCHEMA_FBS}...") + schema = parse_fbs(SCHEMA_FBS) + print( + f" Found {len(schema.enums)} enums, {len(schema.structs)} structs, " + f"{len(schema.tables)} tables, {len(schema.unions)} unions" + ) + print(f" Op nodes: {len(schema.get_op_nodes())}") + + # Run flatc + if not args.skip_flatc: + run_flatc(args.flatc) + _fixup_flatc_imports() + + # Generate all code files + generators = [ + (generate_python_schema, GENERATED_SCHEMA_PY, "mlx_graph_schema.py"), + ( + generate_python_serializers, + GENERATED_SERIALIZERS, + "_generated_serializers.py", + ), + (generate_cpp_loader_h, LOADER_H, "MLXLoader.h"), + (generate_cpp_loader_cpp, LOADER_CPP, "MLXLoader.cpp"), + (generate_inspector, GENERATED_INSPECTOR, "_generated_inspector.py"), + ] + for gen_fn, output_path, label in generators: + print(f"Generating {output_path}...") + content = gen_fn(schema) + if args.dry_run: + print(f"--- {label} (first 50 lines) ---") + print("\n".join(content.split("\n")[:50])) + else: + with open(output_path, "w") as f: + f.write(content) + + # Create __init__.py for _generated package that re-exports from mlx_delegate + init_file = GENERATED_DIR / "__init__.py" + if not args.dry_run: + init_file.parent.mkdir(parents=True, exist_ok=True) + + # Get all the exports from mlx_delegate (tables, enums, structs, and unions) + exports = [] + for table in schema.tables: + exports.append(table.name) + for enum in schema.enums: + exports.append(enum.name) + for struct in schema.structs: + exports.append(struct.name) + for union in schema.unions: + exports.append(union.name) + + # Create __init__.py with re-exports + init_content = """# Auto-generated FlatBuffer bindings +# Re-exports from mlx_delegate namespace for convenient imports + +""" + # Add imports from mlx_delegate + for export in sorted(exports): + init_content += f"from executorch.backends.mlx.serialization._generated.mlx_delegate.{export} import {export}\n" + + init_content += f"\n__all__ = {sorted(exports)!r}\n" + init_file.write_text(init_content) + + print("Done!") + print("") + print("Generated files:") + print(f" - {GENERATED_SCHEMA_PY}") + print(f" - {GENERATED_SERIALIZERS}") + print(f" - {GENERATED_INSPECTOR}") + print(f" - {LOADER_H}") + print(f" - {LOADER_CPP}") + if not args.skip_flatc: + print(f" - {GENERATED_DIR}/ (FlatBuffer Python bindings)") + print(f" - {RUNTIME_DIR}/schema_generated.h (FlatBuffer C++ bindings)") + + +if __name__ == "__main__": + main() diff --git a/backends/mlx/serialization/mlx_graph_serialize.py b/backends/mlx/serialization/mlx_graph_serialize.py new file mode 100644 index 00000000000..db5acc9048f --- /dev/null +++ b/backends/mlx/serialization/mlx_graph_serialize.py @@ -0,0 +1,416 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +Serialization utilities for MLX delegate. + +Converts MLXGraph dataclasses to FlatBuffer binary format. + +Constants are NOT embedded in the delegate payload - they are provided by +ExecuTorch via named_data_map at runtime. + +Layout: + [Header: 24 bytes] + - Padding: 4 bytes (zeros) + - Magic: 4 bytes ("MLX0") + - Reserved: 16 bytes (zeros, for future use) + [FlatBuffer payload] +""" + +from __future__ import annotations + +import struct +from typing import Any, List, Tuple + +import flatbuffers + +# Import auto-generated serializers +from executorch.backends.mlx.serialization._generated_serializers import ( + GeneratedOpBuilders, +) +from executorch.backends.mlx.serialization.mlx_graph_schema import ( # noqa: F401 + FloatOrVid, + Instruction, + IntOrVid, + MLXGraph, + NamedSlot, + OpNodeUnion, + SlotType, + SlotVariant, + TensorMeta, + Tid, + Vid, +) +from executorch.exir._serialize._program import Cord + +HEADER_LENGTH = 24 +MAGIC = b"MLX0" +ALIGNMENT = 16 + + +def _padding_required(offset: int, alignment: int) -> int: + remainder = offset % alignment + return (alignment - remainder) % alignment + + +def _build_tid(builder: flatbuffers.Builder, tid: Tid) -> int: + return tid.idx + + +def _build_vid(builder: flatbuffers.Builder, vid: Vid) -> int: + return vid.idx + + +def _build_int_or_vid(builder: flatbuffers.Builder, iov: IntOrVid) -> int: + # Import the MODULE (not class) to access builder functions + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + IntOrVid as FBIntOrVidModule, + ) + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import ( + CreateVid, + ) + + FBIntOrVidModule.Start(builder) + FBIntOrVidModule.AddLiteral(builder, iov.literal) + FBIntOrVidModule.AddIsVid(builder, iov.is_vid) + if iov.vid is not None: + # Vid is an inline struct - must be added last for proper FlatBuffer layout + FBIntOrVidModule.AddVid(builder, CreateVid(builder, iov.vid.idx)) + return FBIntOrVidModule.End(builder) + + +def _build_string(builder: flatbuffers.Builder, s: str) -> int: + return builder.CreateString(s) + + +def _build_int_vector(builder: flatbuffers.Builder, vec: List[int]) -> int: + # FlatBuffers vectors must be created before the table that contains them + builder.StartVector(4, len(vec), 4) # elem_size=4, num_elems, alignment + for v in reversed(vec): + builder.PrependInt32(v) + return builder.EndVector() + + +class MLXGraphSerializer(GeneratedOpBuilders): + """ + Serializes MLXGraph to bytes with separate constant data segment. + + Inherits auto-generated op builders from GeneratedOpBuilders mixin. + """ + + def __init__(self, graph: MLXGraph, constant_data: bytes = b""): + self.graph = graph + self.constant_data = constant_data + + def serialize(self) -> bytes: + """ + Serialize the graph to bytes. + + Returns: + Complete serialized payload with header, flatbuffer, and data segment. + """ + # Build FlatBuffer + fb_bytes = self._build_flatbuffer() + + # Calculate offsets + data_segment_offset = HEADER_LENGTH + len(fb_bytes) + padding_len = _padding_required(data_segment_offset, ALIGNMENT) + data_segment_offset += padding_len + data_segment_size = len(self.constant_data) + + # Build header + header = ( + b"\x00\x00\x00\x00" # 4 bytes padding + + MAGIC # 4 bytes magic + + struct.pack(" 0: + result.append(b"\x00" * padding_len) + result.append(self.constant_data) + + return bytes(result) + + def _build_flatbuffer(self) -> bytes: + builder = flatbuffers.Builder(4096) + + # Build all components bottom-up (FlatBuffers requirement) + + # 1. Build instruction chains + chain_offsets = [] + for chain in self.graph.instruction_chains: + instr_offsets = [] + for instr in chain.instructions: + instr_offsets.append(self._build_instruction(builder, instr)) + instr_vec = self._build_offset_vector(builder, instr_offsets) + + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + InstructionChain as FBInstructionChainModule, + ) + + FBInstructionChainModule.Start(builder) + FBInstructionChainModule.AddInstructions(builder, instr_vec) + chain_offsets.append(FBInstructionChainModule.End(builder)) + + chains_vec = self._build_offset_vector(builder, chain_offsets) + + # 2. Build I/O maps + input_map_vec = self._build_slot_variant_vector(builder, self.graph.input_map) + output_map_vec = self._build_slot_variant_vector(builder, self.graph.output_map) + mutable_buffer_map_vec = self._build_slot_variant_vector( + builder, self.graph.mutable_buffer_map + ) + + # 3. Build named slots + named_slots_offsets = [] + for ns in self.graph.named_slots: + named_slots_offsets.append(self._build_named_slot(builder, ns)) + named_slots_vec = self._build_offset_vector(builder, named_slots_offsets) + + # 4. Build tensor metadata + tensor_meta_offsets = [] + for tm in self.graph.tensor_meta: + if tm is not None: + tensor_meta_offsets.append(self._build_tensor_meta(builder, tm)) + else: + tensor_meta_offsets.append(0) # null + tensor_meta_vec = self._build_offset_vector(builder, tensor_meta_offsets) + + # 5. Build version string (must be created before the table that uses it) + version_off = builder.CreateString(self.graph.version) + + # 6. Build the root MLXGraph table + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + MLXGraph as FBMLXGraphModule, + ) + + FBMLXGraphModule.Start(builder) + FBMLXGraphModule.AddVersion(builder, version_off) + FBMLXGraphModule.AddNumConstantTensors(builder, self.graph.num_constant_tensors) + FBMLXGraphModule.AddNumInputTensors(builder, self.graph.num_input_tensors) + FBMLXGraphModule.AddNumOutputTensors(builder, self.graph.num_output_tensors) + FBMLXGraphModule.AddNumMutableBufferTensors( + builder, self.graph.num_mutable_buffer_tensors + ) + FBMLXGraphModule.AddNumTempTensors(builder, self.graph.num_temp_tensors) + FBMLXGraphModule.AddNumValues(builder, self.graph.num_values) + FBMLXGraphModule.AddInstructionChains(builder, chains_vec) + FBMLXGraphModule.AddMainChainIdx(builder, self.graph.main_chain_idx) + FBMLXGraphModule.AddInitChainIdx(builder, self.graph.init_chain_idx) + FBMLXGraphModule.AddInputMap(builder, input_map_vec) + FBMLXGraphModule.AddOutputMap(builder, output_map_vec) + FBMLXGraphModule.AddMutableBufferMap(builder, mutable_buffer_map_vec) + FBMLXGraphModule.AddNamedSlots(builder, named_slots_vec) + FBMLXGraphModule.AddTensorMeta(builder, tensor_meta_vec) + root = FBMLXGraphModule.End(builder) + + builder.Finish(root) + return bytes(builder.Output()) + + def _build_instruction( + self, builder: flatbuffers.Builder, instr: Instruction + ) -> int: + op_offset, op_type = self._build_op_node(builder, instr.op) + + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + Instruction as FBInstructionModule, + ) + + FBInstructionModule.Start(builder) + FBInstructionModule.AddOpType(builder, op_type) + FBInstructionModule.AddOp(builder, op_offset) + return FBInstructionModule.End(builder) + + def _build_op_node( + self, builder: flatbuffers.Builder, op: OpNodeUnion + ) -> Tuple[int, int]: + """ + Build an op node and return (offset, union_type). + + This is the main dispatch for all op types. + """ + # Map Python class to FlatBuffer union type and builder + # This would ideally be auto-generated + + op_type = type(op).__name__ + builder_method = getattr(self, f"_build_{op_type}", None) + + if builder_method is None: + raise NotImplementedError(f"No builder for op type: {op_type}") + + return builder_method(builder, op) + + def _build_offset_vector( + self, builder: flatbuffers.Builder, offsets: List[int] + ) -> int: + builder.StartVector(4, len(offsets), 4) + for off in reversed(offsets): + builder.PrependUOffsetTRelative(off) + return builder.EndVector() + + def _build_slot_variant_vector( + self, builder: flatbuffers.Builder, slots: List[SlotVariant] + ) -> int: + offsets = [] + for slot in slots: + offsets.append(self._build_slot_variant(builder, slot)) + return self._build_offset_vector(builder, offsets) + + def _build_slot_variant( + self, builder: flatbuffers.Builder, slot: SlotVariant + ) -> int: + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + SlotVariant as FBSlotVariantModule, + ) + + FBSlotVariantModule.Start(builder) + FBSlotVariantModule.AddIdx(builder, slot.idx) + FBSlotVariantModule.AddSlotType(builder, slot.slot_type) + return FBSlotVariantModule.End(builder) + + def _build_named_slot(self, builder: flatbuffers.Builder, ns: NamedSlot) -> int: + name_off = builder.CreateString(ns.name) + slot_off = self._build_slot_variant(builder, ns.slot) + + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + NamedSlot as FBNamedSlotModule, + ) + + FBNamedSlotModule.Start(builder) + FBNamedSlotModule.AddName(builder, name_off) + FBNamedSlotModule.AddSlot(builder, slot_off) + return FBNamedSlotModule.End(builder) + + def _build_tensor_meta(self, builder: flatbuffers.Builder, tm: TensorMeta) -> int: + # Shape is a vector of ShapeDim tables + shape_offsets = [] + for dim in tm.shape: + shape_offsets.append(self._build_shape_dim(builder, dim)) + shape_vec = self._build_offset_vector(builder, shape_offsets) + + # Build dim_order vector (uint8) + dim_order_vec = 0 + if tm.dim_order: + builder.StartVector(1, len(tm.dim_order), 1) # elem_size=1 for uint8 + for d in reversed(tm.dim_order): + builder.PrependUint8(d) + dim_order_vec = builder.EndVector() + + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + TensorMeta as FBTensorMetaModule, + ) + + FBTensorMetaModule.Start(builder) + FBTensorMetaModule.AddShape(builder, shape_vec) + if tm.scalar_type is not None: + FBTensorMetaModule.AddScalarType(builder, tm.scalar_type) + if dim_order_vec: + FBTensorMetaModule.AddDimOrder(builder, dim_order_vec) + return FBTensorMetaModule.End(builder) + + def _build_shape_dim(self, builder: flatbuffers.Builder, dim) -> int: + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + ShapeDim as FBShapeDimModule, + ) + + FBShapeDimModule.Start(builder) + FBShapeDimModule.AddValue(builder, dim.value) + FBShapeDimModule.AddMinValue(builder, dim.min_value) + FBShapeDimModule.AddMaxValue(builder, dim.max_value) + return FBShapeDimModule.End(builder) + + +def serialize_mlx_graph(graph: MLXGraph, constant_data: bytes = b"") -> bytes: + """ + Serialize an MLXGraph to bytes. + + Args: + graph: The MLXGraph to serialize. + constant_data: Raw bytes for constant tensors. + + Returns: + Serialized bytes with header, flatbuffer, and data segment. + """ + serializer = MLXGraphSerializer(graph, constant_data) + return serializer.serialize() + + +def parse_header(data: bytes) -> Tuple[int, int, int, int]: + """ + Parse the MLX delegate header. + + Returns: + (flatbuffer_offset, flatbuffer_size, data_segment_offset, data_segment_size) + """ + if len(data) < HEADER_LENGTH: + raise ValueError(f"Data too short: {len(data)} < {HEADER_LENGTH}") + + magic = data[4:8] + if magic != MAGIC: + raise ValueError(f"Invalid magic: {magic!r} (expected {MAGIC!r})") + + data_segment_offset = struct.unpack(" dict: + """ + Deserialize MLX delegate payload to a JSON-compatible dict. + + Useful for debugging - extracts the FlatBuffer and dumps it as JSON. + """ + fb_off, fb_size, ds_off, ds_size = parse_header(data) + + # Extract FlatBuffer portion + fb_data = data[fb_off : fb_off + fb_size] + + # Parse using generated FlatBuffer code + from executorch.backends.mlx.serialization._generated.mlx_delegate.MLXGraph import ( + MLXGraph as FBMLXGraphClass, + ) + + graph = FBMLXGraphClass.GetRootAs(fb_data, 0) + + # Convert to dict (recursive) + result = _fb_to_dict(graph) + result["_constant_segment_size"] = ds_size + + return result + + +def _fb_to_dict(obj: Any) -> Any: + if obj is None: + return None + if isinstance(obj, (int, float, str, bool, bytes)): + return obj + if isinstance(obj, (list, tuple)): + return [_fb_to_dict(item) for item in obj] + + # FlatBuffer object - extract fields + result = {} + for attr in dir(obj): + if attr.startswith("_") or attr[0].islower(): + continue + try: + value = getattr(obj, attr)() + result[attr] = _fb_to_dict(value) + except (TypeError, AttributeError): + pass + + return result diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs new file mode 100644 index 00000000000..945186ebef8 --- /dev/null +++ b/backends/mlx/serialization/schema.fbs @@ -0,0 +1,192 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// FlatBuffer schema for MLX delegate - THIS IS THE SOURCE OF TRUTH +// Defines the IR that gets serialized into the .pte file and executed by MLX runtime +// +// After editing this file, regenerate dependent files with: +// python backends/mlx/serialization/generate.py +// +// BACKWARD COMPATIBILITY RULES: +// - New fields in tables: APPEND ONLY (add at the end, with a default value) +// - New union members: APPEND ONLY (add at the end of the union) +// - New tables: Safe to add freely +// - New enum values: APPEND ONLY +// - NEVER remove, reorder, or change the type of existing fields/members + +namespace mlx_delegate; + +// ============================================================================= +// Core types +// ============================================================================= + +// We use ET's ScalarType (int8) directly. +// See runtime/core/portable_type/scalar_type.h for ScalarType values. + +// Tensor slot identifier - indexes into tensors array +struct Tid { + idx: uint32; +} + +// Value slot identifier - indexes into values array +// Values are stored as variant at runtime +struct Vid { + idx: uint32; +} + +// NOTE: These compound types use tables with manual discriminators rather than +// FlatBuffers unions because IntOrVid is used in vectors ([IntOrVid]), and +// FlatBuffers does not support vectors of unions. + +// For fields that can be either a literal int or a runtime Vid +table IntOrVid { + literal: int64; // widened to int64 for future-proofing + vid: Vid; + is_vid: bool = false; +} + +// For fields that can be either a literal float or a runtime Vid +table FloatOrVid { + literal: double; // widened to double for future-proofing + vid: Vid; + is_vid: bool = false; +} + +// For fields that can be either a tensor (Tid) or a scalar value (Vid) +table VidOrTid { + vid: Vid; + tid: Tid; + is_vid: bool = false; // false = use tid, true = use vid +} + +// For fields that can be a literal int, a runtime Vid, or a tensor (Tid) +table IntOrVidOrTid { + literal: int64; + vid: Vid; + tid: Tid; + kind: uint8 = 0; // 0 = literal int, 1 = vid, 2 = tid +} + +// ============================================================================= +// Op nodes - mirrors ops_schema.py dataclasses +// ============================================================================= + +table NoopNode {} + +table AddmmNode { + mat1: Tid (required); // First matrix + mat2: Tid (required); // Second matrix + out: Tid (required); + bias: Tid; // optional - added to result + alpha: float = 1.0; // Scalar multiplier for mat1 @ mat2 + beta: float = 1.0; // Scalar multiplier for bias +} + +// ============================================================================= +// Union of all op types +// ============================================================================= + +// BC: APPEND ONLY — new op nodes must be added at the end of this union. +// Reordering or removing members changes numeric type IDs and breaks existing .pte files. +union OpNode { + NoopNode, + AddmmNode + // BC: Add new op nodes here (append only) +} + +// ============================================================================= +// Instruction wrapper +// ============================================================================= + +table Instruction { + op: OpNode (required); +} + +// ============================================================================= +// Instruction chain (basic block of sequential instructions) +// ============================================================================= + +table InstructionChain { + instructions: [Instruction] (required); + // BC: New fields must be appended here with a default value +} + +// ============================================================================= +// Tensor metadata +// ============================================================================= + +// Shape dimension: static value, or dynamic with optional bounds +table ShapeDim { + value: int32 = -1; // Static dim (>= 0), or -1 for dynamic + min_value: int32 = 0; // Lower bound (only when value == -1) + max_value: int32 = -1; // Upper bound (-1 = unbounded, only when value == -1) +} + +table TensorMeta { + shape: [ShapeDim] (required); // Dimension info with static/dynamic distinction + scalar_type: int8; // ET ScalarType value (see runtime/core/portable_type/scalar_type.h) + dim_order: [uint8]; // Memory layout order (matches TensorLayout.dim_order, DimOrderType = uint8_t) +} + +// ============================================================================= +// Slot variant for I/O mapping +// ============================================================================= + +enum SlotType : byte { + TensorSlot = 0, + IntValueSlot = 1, + FloatValueSlot = 2, + BoolValueSlot = 3 +} + +table SlotVariant { + idx: uint32; + slot_type: SlotType = TensorSlot; +} + +// ============================================================================= +// Name to slot mapping entry +// ============================================================================= + +table NamedSlot { + name: string (required); + slot: SlotVariant (required); +} + +// ============================================================================= +// Root type: MLX Graph +// ============================================================================= + +// BC: New fields must be appended at the end of this table with a default value. +table MLXGraph { + // Version for compatibility + version: string; + + // Tensor slot counts + + num_constant_tensors: uint32; + num_input_tensors: uint32; + num_output_tensors: uint32; + num_mutable_buffer_tensors: uint32; + num_temp_tensors: uint32; + num_values: uint32; + + // Instruction chains (basic blocks of sequential instructions) + instruction_chains: [InstructionChain] (required); + main_chain_idx: uint32 = 0; // Chain to run every execute() call + init_chain_idx: int32 = -1; // Chain to run once at init(), -1 = none + + // I/O mappings + input_map: [SlotVariant]; + output_map: [SlotVariant]; + mutable_buffer_map: [SlotVariant]; + + // Name to slot lookup (used for constant/mutable buffer keys in named_data_map) + named_slots: [NamedSlot]; + + // Tensor metadata (for non-temp tensors), indexed by Tid + tensor_meta: [TensorMeta]; + + // BC: New fields must be appended here with a default value +} + +root_type MLXGraph; diff --git a/backends/mlx/test/CMakeLists.txt b/backends/mlx/test/CMakeLists.txt new file mode 100644 index 00000000000..2a709a63412 --- /dev/null +++ b/backends/mlx/test/CMakeLists.txt @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# MLX backend tests + +# Strict compiler flags for the test runner — mlxdelegate uses PRIVATE so these +# don't propagate to downstream consumers +set(_mlx_test_compile_options -Wall -Werror -Wconversion -Wsign-conversion + -Wshorten-64-to-32 +) + +# Sanitizers are inherited from parent via EXECUTORCH_MLX_ENABLE_SANITIZERS +if(EXECUTORCH_MLX_ENABLE_SANITIZERS) + list(APPEND _mlx_test_compile_options -fsanitize=address,undefined + -fno-omit-frame-pointer + ) +endif() + +# Op test runner - generic test binary for testing individual ops +add_executable(op_test_runner op_test_runner.cpp) + +target_compile_options(op_test_runner PRIVATE ${_mlx_test_compile_options}) +if(EXECUTORCH_MLX_ENABLE_SANITIZERS) + target_link_options(op_test_runner PRIVATE ${_mlx_sanitizer_link_options}) +endif() + +target_link_libraries( + op_test_runner PRIVATE extension_module extension_tensor executorch + mlxdelegate +) + +# -------------------------------------------------------------------------- +# Compile-only strict warnings test for delegate headers +# +# Verifies MLXExecutor.h, MLXInterpreter.h, MLXLoader.h compile cleanly under +# -Wconversion -Wsign-conversion -Wshorten-64-to-32 -Werror. ExecuTorch and MLX +# headers are suppressed via pragma in the source file. This target is never +# linked or run — a successful compile is the test. +# -------------------------------------------------------------------------- +add_library(strict_compile_test OBJECT strict_compile_test.cpp) +target_compile_options(strict_compile_test PRIVATE ${_mlx_test_compile_options}) +target_include_directories( + strict_compile_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../runtime +) +target_link_libraries( + strict_compile_test PRIVATE mlx_schema executorch_core mlx +) +add_dependencies(op_test_runner strict_compile_test) diff --git a/backends/mlx/test/README.md b/backends/mlx/test/README.md new file mode 100644 index 00000000000..6d90d513fec --- /dev/null +++ b/backends/mlx/test/README.md @@ -0,0 +1,164 @@ +# MLX Backend Tests + +This directory contains end-to-end tests for the MLX backend. Each test verifies that a specific op or pattern is correctly lowered to MLX and produces matching outputs between PyTorch and the MLX runtime. + +## Setup + +### 1. Install ExecuTorch Python package (if not already installed) + +```bash +python install_executorch.py --editable +``` + +### 2. Configure CMake with MLX preset + +From the ExecuTorch root directory: + +```bash +cmake --preset mlx-release -DEXECUTORCH_BUILD_TESTS=ON +``` + +This configures the build with MLX delegate support and test targets. Build files are generated in `cmake-out/`. + +### 3. Build the test runner + +```bash +cmake --build cmake-out --target op_test_runner +``` + +This builds the `op_test_runner` binary that executes `.pte` models using the MLX runtime. + + + +## Prerequisites + +1. **Python environment**: Tests must be run in an environment where the `executorch` Python package is installed +2. **Built C++ runtime**: The `op_test_runner` binary must be built (see Setup above) + +## Running Tests + +### Run All Tests + +To run all registered tests: + +```bash +python -m executorch.backends.mlx.test.run_all_tests -j4 --clean-after +``` + +### Options + +| Flag | Description | +|------|-------------| +| `-j N` / `--parallel N` | Run tests in parallel with N workers | +| `--clean-after` | Clean up generated test files after running | +| `--clean` | Clean up generated test files and exit | +| `--rebuild` | Rebuild the C++ test runner before running | +| `--list` | List available tests and exit | +| `-v` / `--verbose` | Verbose output | +| `--timeout SECS` | Timeout per test in seconds (default: 300) | + +### Memory Management Options + +Running many tests can accumulate memory (torch/MLX/Metal allocations). These flags help manage memory: + +| Flag | Description | +|------|-------------| +| `--isolate` | Run each test in a separate subprocess (sequential mode only). Provides full memory isolation but is slower due to Python/torch import overhead per test. | +| `--max-tasks-per-worker N` | Recycle parallel workers after N tests (parallel mode only). Workers are terminated and replaced after completing N tests, releasing accumulated memory. | + +**Comparison:** + +| Mode | Memory Isolation | Speed | +|------|------------------|-------| +| `-j 4` | None (workers reused) | Fastest | +| `-j 4 --max-tasks-per-worker 10` | Bounded (recycled every 10 tests) | Fast | +| `-j 4 --max-tasks-per-worker 1` | Full (new process per test) | Slower | +| `--isolate` | Full (subprocess per test) | Slowest (sequential) | + +**Recommended for CI with memory constraints:** + +```bash +python -m executorch.backends.mlx.test.run_all_tests -j4 --max-tasks-per-worker 10 --clean-after +``` + +### Run a Specific Test + +To run a specific test by name (e.g., `linear`): + +```bash +python -m executorch.backends.mlx.test.run_all_tests linear +``` + +With verbose output: + +```bash +python -m executorch.backends.mlx.test.run_all_tests -v linear +``` + +### List Available Tests + +```bash +python -m executorch.backends.mlx.test.run_all_tests --list +``` + +## Test Architecture + +All tests are defined in `test_ops.py`. Each test follows a common pattern: + +1. **Define a model** - A simple `nn.Module` that uses the op being tested +2. **Create test inputs** - Generate random input tensors +3. **Export and lower** - Export the model and lower it to the MLX backend +4. **Run C++ binary** - Execute the lowered model using `op_test_runner` +5. **Compare outputs** - Verify PyTorch and MLX outputs match within tolerance + +### Test Class Structure + +Tests inherit from `OpTestCase` and implement: + +```python +@register_test +class MyTest(OpTestCase): + name = "my_test" # Test name (used for output directory) + rtol = 1e-5 # Relative tolerance for comparison + atol = 1e-5 # Absolute tolerance for comparison + + def create_model(self) -> nn.Module: + """Return the model to test.""" + ... + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + """Return input tensors for export.""" + ... + + def get_dynamic_shapes(self) -> Optional[Dict]: + """Return dynamic shape specs, or None for static shapes.""" + ... + + @classmethod + def get_test_configs(cls) -> List["MyTest"]: + """Return list of test configurations to run.""" + ... +``` + +## Test Output + +Test artifacts are saved to `op_tests//`: +- `model.pte` - Exported ExecuTorch model +- `input.bin` - Serialized input tensors +- `expected_output.bin` - PyTorch reference output +- `actual_output.bin` - MLX runtime output + +## Adding a New Test + +1. Add a new model class and `OpTestCase` subclass to `test_ops.py` +2. Use the `@register_test` decorator on the test class +3. Implement `create_model()`, `create_inputs()`, and `get_test_configs()` +4. Run the test to verify it works E2E + +## Test harness + +MLX also plugs into the ExecuTorch test harness for even more coverage. To run, use the following command from the ExecuTorch root directory: + +```bash +pytest -c /dev/null backends/test/suite/operators/ -m flow_mlx +``` diff --git a/backends/mlx/test/__init__.py b/backends/mlx/test/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/mlx/test/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/mlx/test/op_test_runner.cpp b/backends/mlx/test/op_test_runner.cpp new file mode 100644 index 00000000000..6bed13d7a56 --- /dev/null +++ b/backends/mlx/test/op_test_runner.cpp @@ -0,0 +1,395 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * Generic op test runner for MLX delegate. + * + * Loads a .pte file, reads inputs from .bin files, runs the model, + * and writes outputs to .bin files. + * + * Build: + * cd cmake-out-mlx && cmake --build . --target op_test_runner + * + * Usage: + * ./cmake-out-mlx/backends/mlx/test/op_test_runner \ + * --pte \ + * --input \ + * --output + * + * Binary file format: + * - 4 bytes: number of tensors (uint32_t) + * For each tensor: + * - 4 bytes: dtype (0=float32, 1=float16, 2=int32, 3=int64, 4=bfloat16, + * 5=bool) + * - 4 bytes: number of dimensions (uint32_t) + * - 4 bytes * ndim: shape (int32_t each) + * - N bytes: data (size = product of shape * sizeof(dtype)) + */ + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wconversion" +#pragma clang diagnostic ignored "-Wsign-conversion" +#pragma clang diagnostic ignored "-Wshorten-64-to-32" +#pragma clang diagnostic ignored "-Wimplicit-float-conversion" +#include +#include +#pragma clang diagnostic pop + +#include +#include +#include +#include +#include +#include +#include + +using namespace ::executorch::extension; +using namespace ::executorch::runtime; + +enum class DType : uint32_t { + Float32 = 0, + Float16 = 1, + Int32 = 2, + Int64 = 3, + BFloat16 = 4, + Bool = 5, +}; + +size_t dtype_size(DType dtype) { + switch (dtype) { + case DType::Float32: + return 4; + case DType::Float16: + return 2; + case DType::Int32: + return 4; + case DType::Int64: + return 8; + case DType::BFloat16: + return 2; + case DType::Bool: + return 1; + default: + return 4; + } +} + +exec_aten::ScalarType dtype_to_scalar_type(DType dtype) { + switch (dtype) { + case DType::Float32: + return exec_aten::ScalarType::Float; + case DType::Float16: + return exec_aten::ScalarType::Half; + case DType::Int32: + return exec_aten::ScalarType::Int; + case DType::Int64: + return exec_aten::ScalarType::Long; + case DType::BFloat16: + return exec_aten::ScalarType::BFloat16; + case DType::Bool: + return exec_aten::ScalarType::Bool; + default: + return exec_aten::ScalarType::Float; + } +} + +DType scalar_type_to_dtype(exec_aten::ScalarType stype) { + switch (stype) { + case exec_aten::ScalarType::Float: + return DType::Float32; + case exec_aten::ScalarType::Half: + return DType::Float16; + case exec_aten::ScalarType::Int: + return DType::Int32; + case exec_aten::ScalarType::Long: + return DType::Int64; + case exec_aten::ScalarType::BFloat16: + return DType::BFloat16; + case exec_aten::ScalarType::Bool: + return DType::Bool; + default: + return DType::Float32; + } +} + +struct TensorData { + DType dtype; + std::vector shape; + std::vector data; +}; + +std::vector read_tensors_from_bin(const std::string& path) { + std::ifstream file(path, std::ios::binary); + if (!file) { + throw std::runtime_error("Failed to open input file: " + path); + } + + uint32_t num_tensors; + file.read(reinterpret_cast(&num_tensors), sizeof(num_tensors)); + + std::vector tensors; + tensors.reserve(num_tensors); + + for (uint32_t i = 0; i < num_tensors; ++i) { + TensorData t; + + uint32_t dtype_val; + file.read(reinterpret_cast(&dtype_val), sizeof(dtype_val)); + t.dtype = static_cast(dtype_val); + + uint32_t ndim; + file.read(reinterpret_cast(&ndim), sizeof(ndim)); + + t.shape.resize(ndim); + file.read(reinterpret_cast(t.shape.data()), ndim * sizeof(int32_t)); + + size_t numel = 1; + for (int32_t s : t.shape) { + numel *= static_cast(s); + } + size_t data_size = numel * dtype_size(t.dtype); + + t.data.resize(data_size); + file.read( + reinterpret_cast(t.data.data()), + static_cast(data_size)); + + tensors.push_back(std::move(t)); + } + + return tensors; +} + +void write_tensors_to_bin( + const std::string& path, + const std::vector& tensors) { + std::ofstream file(path, std::ios::binary); + if (!file) { + throw std::runtime_error("Failed to open output file: " + path); + } + + uint32_t num_tensors = static_cast(tensors.size()); + file.write(reinterpret_cast(&num_tensors), sizeof(num_tensors)); + + for (const auto& t : tensors) { + uint32_t dtype_val = static_cast(t.dtype); + file.write(reinterpret_cast(&dtype_val), sizeof(dtype_val)); + + uint32_t ndim = static_cast(t.shape.size()); + file.write(reinterpret_cast(&ndim), sizeof(ndim)); + + file.write( + reinterpret_cast(t.shape.data()), ndim * sizeof(int32_t)); + + file.write( + reinterpret_cast(t.data.data()), + static_cast(t.data.size())); + } +} + +void print_usage(const char* prog_name) { + std::cerr << "Usage: " << prog_name << " [options]\n" + << "Options:\n" + << " --pte Path to .pte model file (required)\n" + << " --input Path to input .bin file (required)\n" + << " --output Path to output .bin file (required)\n" + << " --verbose Print verbose output\n" + << std::endl; +} + +int main(int argc, char* argv[]) { + std::string pte_path; + std::string input_path; + std::string output_path; + bool verbose = false; + + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg == "--pte" && i + 1 < argc) { + pte_path = argv[++i]; + } else if (arg == "--input" && i + 1 < argc) { + input_path = argv[++i]; + } else if (arg == "--output" && i + 1 < argc) { + output_path = argv[++i]; + } else if (arg == "--verbose") { + verbose = true; + } else if (arg == "--help" || arg == "-h") { + print_usage(argv[0]); + return 0; + } else { + std::cerr << "Unknown argument: " << arg << std::endl; + print_usage(argv[0]); + return 1; + } + } + + if (pte_path.empty() || input_path.empty() || output_path.empty()) { + std::cerr << "Error: --pte, --input, and --output are required\n"; + print_usage(argv[0]); + return 1; + } + + try { + if (verbose) { + std::cout << "Loading model from: " << pte_path << std::endl; + } + + Module module(pte_path); + auto load_error = module.load(); + if (load_error != Error::Ok) { + std::cerr << "Failed to load model: " << static_cast(load_error) + << std::endl; + return 1; + } + + if (verbose) { + std::cout << "Model loaded successfully" << std::endl; + } + + auto load_method_error = module.load_method("forward"); + if (load_method_error != Error::Ok) { + std::cerr << "Failed to load forward method: " + << static_cast(load_method_error) << std::endl; + return 1; + } + + if (verbose) { + std::cout << "Reading inputs from: " << input_path << std::endl; + } + + auto input_tensors = read_tensors_from_bin(input_path); + + if (verbose) { + std::cout << "Read " << input_tensors.size() << " input tensors" + << std::endl; + for (size_t i = 0; i < input_tensors.size(); ++i) { + std::cout << " Input " << i + << ": dtype=" << static_cast(input_tensors[i].dtype) + << ", shape=["; + for (size_t j = 0; j < input_tensors[i].shape.size(); ++j) { + std::cout << input_tensors[i].shape[j]; + if (j < input_tensors[i].shape.size() - 1) + std::cout << ", "; + } + std::cout << "]" << std::endl; + } + } + + std::vector tensor_ptrs; + std::vector inputs; + tensor_ptrs.reserve(input_tensors.size()); + inputs.reserve(input_tensors.size()); + + for (const auto& t : input_tensors) { + std::vector sizes(t.shape.begin(), t.shape.end()); + + TensorPtr tensor_ptr; + if (t.dtype == DType::Float32) { + std::vector data(t.data.size() / sizeof(float)); + std::memcpy(data.data(), t.data.data(), t.data.size()); + tensor_ptr = make_tensor_ptr(sizes, std::move(data)); + } else if (t.dtype == DType::Float16) { + std::vector data( + t.data.size() / sizeof(exec_aten::Half)); + std::memcpy(data.data(), t.data.data(), t.data.size()); + tensor_ptr = make_tensor_ptr(sizes, std::move(data)); + } else if (t.dtype == DType::BFloat16) { + std::vector data( + t.data.size() / sizeof(exec_aten::BFloat16)); + std::memcpy(data.data(), t.data.data(), t.data.size()); + tensor_ptr = make_tensor_ptr(sizes, std::move(data)); + } else if (t.dtype == DType::Int32) { + std::vector data(t.data.size() / sizeof(int32_t)); + std::memcpy(data.data(), t.data.data(), t.data.size()); + tensor_ptr = make_tensor_ptr(sizes, std::move(data)); + } else if (t.dtype == DType::Int64) { + std::vector data(t.data.size() / sizeof(int64_t)); + std::memcpy(data.data(), t.data.data(), t.data.size()); + tensor_ptr = make_tensor_ptr(sizes, std::move(data)); + } else if (t.dtype == DType::Bool) { + std::vector data(t.data.size()); + std::memcpy(data.data(), t.data.data(), t.data.size()); + tensor_ptr = make_tensor_ptr( + sizes, std::move(data), {}, {}, exec_aten::ScalarType::Bool); + } else { + std::cerr << "Unsupported dtype: " << static_cast(t.dtype) + << std::endl; + return 1; + } + + tensor_ptrs.push_back(tensor_ptr); + inputs.push_back(tensor_ptr); + } + + if (verbose) { + std::cout << "Executing forward..." << std::endl; + } + + auto result = module.forward(inputs); + if (result.error() != Error::Ok) { + std::cerr << "Execution failed: " << static_cast(result.error()) + << std::endl; + return 1; + } + + if (verbose) { + std::cout << "Execution succeeded, " << result->size() << " outputs" + << std::endl; + } + + std::vector output_tensors; + output_tensors.reserve(result->size()); + + for (size_t i = 0; i < result->size(); ++i) { + const auto& evalue = result->at(i); + if (!evalue.isTensor()) { + std::cerr << "Output " << i << " is not a tensor" << std::endl; + return 1; + } + + const auto& tensor = evalue.toTensor(); + TensorData t; + t.dtype = scalar_type_to_dtype(tensor.scalar_type()); + + t.shape.resize(static_cast(tensor.dim())); + for (size_t d = 0; d < static_cast(tensor.dim()); ++d) { + t.shape[d] = static_cast(tensor.size(static_cast(d))); + } + + size_t data_size = tensor.nbytes(); + t.data.resize(data_size); + std::memcpy(t.data.data(), tensor.const_data_ptr(), data_size); + + if (verbose) { + std::cout << " Output " << i << ": dtype=" << static_cast(t.dtype) + << ", shape=["; + for (size_t j = 0; j < t.shape.size(); ++j) { + std::cout << t.shape[j]; + if (j < t.shape.size() - 1) + std::cout << ", "; + } + std::cout << "]" << std::endl; + } + + output_tensors.push_back(std::move(t)); + } + + if (verbose) { + std::cout << "Writing outputs to: " << output_path << std::endl; + } + + write_tensors_to_bin(output_path, output_tensors); + + std::cout << "OK" << std::endl; + return 0; + + } catch (const std::exception& e) { + std::cerr << "Error: " << e.what() << std::endl; + return 1; + } +} diff --git a/backends/mlx/test/run_all_tests.py b/backends/mlx/test/run_all_tests.py new file mode 100644 index 00000000000..3cda35da275 --- /dev/null +++ b/backends/mlx/test/run_all_tests.py @@ -0,0 +1,496 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Run all MLX delegate op tests. + +Usage: + # Run all tests (all configurations): + python -m executorch.backends.mlx.test.run_all_tests + + # Run specific test (all its configurations): + python -m executorch.backends.mlx.test.run_all_tests add + + # Run specific test configuration: + python -m executorch.backends.mlx.test.run_all_tests add_scalar + + # List available tests: + python -m executorch.backends.mlx.test.run_all_tests --list + + # Rebuild C++ runner before running: + python -m executorch.backends.mlx.test.run_all_tests --rebuild + + # Run tests in parallel: + python -m executorch.backends.mlx.test.run_all_tests -j 4 + + # Run with custom timeout: + python -m executorch.backends.mlx.test.run_all_tests --timeout 60 +""" + +import argparse +import importlib +import multiprocessing +import subprocess +import sys +from multiprocessing import Pool +from typing import List, Optional, Tuple + +from .test_utils import ( + clean_test_outputs, + DEFAULT_TEST_TIMEOUT, + get_all_test_configs, + get_registered_tests, + get_test_output_size, + rebuild_op_test_runner, +) + + +def discover_and_import_tests(): + """ + Import test_ops.py module which contains all test definitions. + This triggers registration of all tests. + """ + importlib.import_module(".test_ops", package=__package__) + + +def _run_single_test( + test_class_name: str, + config_name: str, + config_kwargs: dict, + verbose: bool, + timeout: int, +) -> Tuple[str, bool, Optional[str]]: + """ + Run a single test configuration in a subprocess. + + Called via multiprocessing.Pool.starmap for parallel execution. + Recreates the test instance from the class name and kwargs. + + Args: + test_class_name: Name of the test class module.path + config_name: Name of this configuration + config_kwargs: Kwargs to recreate the test instance + verbose: Whether to print verbose output + timeout: Timeout in seconds + + Returns: + (config_name, passed, error_message) + """ + try: + # Re-discover and import tests in this subprocess + discover_and_import_tests() + + # Find the test config by name + all_configs = get_all_test_configs() + test_instance = None + for name, instance in all_configs: + if name == config_name: + test_instance = instance + break + + if test_instance is None: + return (config_name, False, f"Could not find test config: {config_name}") + + # Run the test + passed = test_instance.run_test(verbose=verbose, timeout=timeout) + return (config_name, passed, None) + + except Exception as e: + import traceback + + return (config_name, False, f"Exception: {e}\n{traceback.format_exc()}") + + +def run_tests_sequential( + configs_to_run: List[Tuple[str, object]], + verbose: bool = False, + timeout: int = DEFAULT_TEST_TIMEOUT, + clean_after_each: bool = False, + isolate: bool = False, +) -> Tuple[int, int, List[str]]: + """ + Run tests sequentially. + + Args: + configs_to_run: List of (config_name, test_instance) tuples. + verbose: Whether to print verbose output. + timeout: Timeout in seconds per test. + clean_after_each: Whether to clean up test outputs after each test. + isolate: Whether to run each test in a subprocess to prevent memory + accumulation across tests (torch/MLX/Metal allocations). + + Returns: + (passed_count, failed_count, failed_test_names) + """ + passed = 0 + failed = 0 + failed_tests = [] + + for config_name, test in configs_to_run: + if isolate: + test_passed = _run_test_in_subprocess( + config_name, verbose=verbose, timeout=timeout + ) + else: + try: + test_passed = test.run_test(verbose=verbose, timeout=timeout) + except Exception as e: + print(f"✗ FAILED: {config_name} - Exception: {e}") + import traceback + + traceback.print_exc() + test_passed = False + + if test_passed: + passed += 1 + else: + failed += 1 + failed_tests.append(config_name) + + if clean_after_each: + clean_test_outputs([config_name], verbose=False) + + return passed, failed, failed_tests + + +def _run_test_in_subprocess( + config_name: str, + verbose: bool = False, + timeout: int = DEFAULT_TEST_TIMEOUT, +) -> bool: + """ + Run a single test in an isolated subprocess. + + Each test gets its own Python interpreter so torch/MLX/Metal memory is + fully released between tests, preventing OOM on CI runners. + + Args: + config_name: Name of the test configuration to run. + verbose: Whether to print verbose output. + timeout: Timeout in seconds. + + Returns: + True if test passed, False otherwise. + """ + cmd = [ + sys.executable, + "-m", + "executorch.backends.mlx.test.test_utils", + config_name, + "run", + ] + if verbose: + cmd.append("--verbose") + + try: + result = subprocess.run( + cmd, + timeout=timeout, + capture_output=False, + ) + return result.returncode == 0 + except subprocess.TimeoutExpired: + print(f"✗ FAILED: {config_name} - Timeout after {timeout}s") + return False + except Exception as e: + print(f"✗ FAILED: {config_name} - Subprocess error: {e}") + return False + + +def run_tests_parallel( + configs_to_run: List[Tuple[str, object]], + num_workers: int, + verbose: bool = False, + timeout: int = DEFAULT_TEST_TIMEOUT, + max_tasks_per_worker: Optional[int] = None, +) -> Tuple[int, int, List[str]]: + """ + Run tests in parallel using multiprocessing.Pool. + + Args: + configs_to_run: List of (config_name, test_instance) tuples. + num_workers: Number of parallel workers. + verbose: Whether to print verbose output. + timeout: Timeout in seconds per test. + max_tasks_per_worker: Maximum tasks per worker before recycling. + When set, worker processes are terminated and replaced after + completing this many tests, which releases accumulated memory + (torch/MLX/Metal allocations). None means workers are never recycled. + + Returns: + (passed_count, failed_count, failed_test_names) + """ + passed = 0 + failed = 0 + failed_tests = [] + + # Prepare test args for parallel execution + # We pass config names and let subprocesses recreate the test instances + test_args = [("", name, {}, verbose, timeout) for name, _ in configs_to_run] + + recycle_msg = "" + if max_tasks_per_worker is not None: + recycle_msg = f", recycling workers every {max_tasks_per_worker} tests" + print( + f"\nRunning {len(test_args)} tests with {num_workers} workers{recycle_msg}...\n" + ) + + with Pool(processes=num_workers, maxtasksperchild=max_tasks_per_worker) as pool: + results = pool.starmap(_run_single_test, test_args) + + for result_name, result_passed, error_msg in results: + if result_passed: + print(f"✓ PASSED: {result_name}") + passed += 1 + else: + if error_msg: + print(f"✗ FAILED: {result_name} - {error_msg}") + else: + print(f"✗ FAILED: {result_name}") + failed += 1 + failed_tests.append(result_name) + + return passed, failed, failed_tests + + +def run_tests( + test_filter: List[str], + verbose: bool = False, + parallel: int = 1, + timeout: int = DEFAULT_TEST_TIMEOUT, + clean_after_each: bool = False, + isolate: bool = False, + max_tasks_per_worker: Optional[int] = None, +) -> Tuple[int, int, List[str]]: + """ + Run tests matching the filter. + + Args: + test_filter: List of test names/patterns to run. If empty, runs all tests. + Can match either base test name (e.g., "add") or config name (e.g., "add_scalar"). + verbose: Whether to print verbose output. + parallel: Number of parallel workers (1 = sequential). + timeout: Timeout in seconds per test. + clean_after_each: Whether to clean up test outputs after each test (sequential only). + isolate: Whether to run each test in a subprocess (sequential only). + max_tasks_per_worker: Maximum tasks per worker before recycling (parallel only). + + Returns: + (passed_count, failed_count, failed_test_names) + """ + all_configs = get_all_test_configs() + registry = get_registered_tests() + + # Determine which configs to run + configs_to_run = [] + if not test_filter: + # Run all + configs_to_run = all_configs + else: + for pattern in test_filter: + matched = False + + # Check if pattern matches a base test name + if pattern in registry: + configs_to_run.extend(registry[pattern]) + matched = True + else: + # Check if pattern matches a config name + for config_name, config in all_configs: + if config_name == pattern: + configs_to_run.append((config_name, config)) + matched = True + + if not matched: + print(f"Warning: No test matching '{pattern}', skipping") + + if not configs_to_run: + print("No tests to run.") + return 0, 0, [] + + # Run tests + if parallel > 1: + return run_tests_parallel( + configs_to_run, parallel, verbose, timeout, max_tasks_per_worker + ) + else: + return run_tests_sequential( + configs_to_run, verbose, timeout, clean_after_each, isolate + ) + + +def main(): # noqa: C901 + # Get CPU count for default parallel workers + cpu_count = multiprocessing.cpu_count() + + parser = argparse.ArgumentParser(description="Run all MLX delegate op tests") + parser.add_argument( + "tests", + nargs="*", + help="Specific tests to run (default: all). Can be base name (e.g., 'add') or config name (e.g., 'add_scalar')", + ) + parser.add_argument( + "--list", + action="store_true", + help="List available tests and exit", + ) + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Verbose output", + ) + parser.add_argument( + "--rebuild", + action="store_true", + help="Rebuild the C++ test runner before running", + ) + parser.add_argument( + "--clean", + action="store_true", + help="Clean up generated test files and exit", + ) + parser.add_argument( + "--clean-after", + action="store_true", + help="Clean up generated test files after running tests", + ) + parser.add_argument( + "--isolate", + action="store_true", + help="Run each test in a separate subprocess to prevent memory accumulation", + ) + parser.add_argument( + "-j", + "--parallel", + type=int, + default=1, + metavar="N", + help=f"Run tests in parallel with N workers (default: 1, max: {cpu_count})", + ) + parser.add_argument( + "--timeout", + type=int, + default=DEFAULT_TEST_TIMEOUT, + metavar="SECS", + help=f"Timeout per test in seconds (default: {DEFAULT_TEST_TIMEOUT})", + ) + parser.add_argument( + "--max-tasks-per-worker", + type=int, + default=None, + metavar="N", + help="Recycle parallel workers after N tests to release memory (default: no recycling)", + ) + args = parser.parse_args() + + # Validate parallel workers + if args.parallel < 1: + args.parallel = 1 + elif args.parallel > cpu_count: + print( + f"Warning: --parallel {args.parallel} exceeds CPU count ({cpu_count}), using {cpu_count}" + ) + args.parallel = cpu_count + + # Auto-discover and import all test modules + discover_and_import_tests() + + # Handle --clean flag + if args.clean: + # Determine which tests to clean + test_names = None + if args.tests: + # Get config names for the specified tests + registry = get_registered_tests() + test_names = [] + for pattern in args.tests: + if pattern in registry: + test_names.extend(cfg_name for cfg_name, _ in registry[pattern]) + else: + test_names.append(pattern) + + # Show current size + current_size = get_test_output_size(test_names) + if current_size > 0: + print(f"Current test output size: {current_size / 1024 / 1024:.2f} MB") + + # Clean + files_removed = clean_test_outputs(test_names, verbose=args.verbose) + if files_removed > 0: + print(f"Removed {files_removed} files") + else: + print("No files to clean") + sys.exit(0) + + # List tests + if args.list: + registry = get_registered_tests() + print("Available tests:") + for base_name in sorted(registry.keys()): + configs = registry[base_name] + if len(configs) == 1 and configs[0][0] == base_name: + # Single config with same name as base + print(f" {base_name}") + else: + # Multiple configs or different name + print(f" {base_name}:") + for config_name, _ in configs: + print(f" - {config_name}") + sys.exit(0) + + # Rebuild if requested + if args.rebuild: + if not rebuild_op_test_runner(verbose=args.verbose): + sys.exit(1) + + # Run tests + passed, failed, failed_tests = run_tests( + args.tests, + verbose=args.verbose, + parallel=args.parallel, + timeout=args.timeout, + clean_after_each=args.clean_after, + isolate=args.isolate, + max_tasks_per_worker=args.max_tasks_per_worker, + ) + + # Print summary + print("\n" + "=" * 60) + print("TEST SUMMARY") + print("=" * 60) + print(f"Passed: {passed}") + print(f"Failed: {failed}") + if failed_tests: + print(f"Failed tests: {', '.join(failed_tests)}") + print("=" * 60) + + # Clean up after tests if requested + if args.clean_after: + # Determine which tests to clean (same logic as --clean) + test_names = None + if args.tests: + registry = get_registered_tests() + test_names = [] + for pattern in args.tests: + if pattern in registry: + test_names.extend(cfg_name for cfg_name, _ in registry[pattern]) + else: + test_names.append(pattern) + + current_size = get_test_output_size(test_names) + files_removed = clean_test_outputs(test_names, verbose=args.verbose) + if files_removed > 0: + print( + f"\nCleaned up {files_removed} files ({current_size / 1024 / 1024:.2f} MB)" + ) + + sys.exit(0 if failed == 0 else 1) + + +if __name__ == "__main__": + main() diff --git a/backends/mlx/test/strict_compile_test.cpp b/backends/mlx/test/strict_compile_test.cpp new file mode 100644 index 00000000000..28df78a7d5a --- /dev/null +++ b/backends/mlx/test/strict_compile_test.cpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * Compile-only test to verify MLX delegate headers are clean under strict + * warnings (-Wconversion, -Wsign-conversion, -Wshorten-64-to-32, -Werror). + * + * This file includes the delegate headers and instantiates key types to ensure + * template code is also checked. It is never linked or executed — a successful + * compilation is the test. + */ + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wconversion" +#pragma clang diagnostic ignored "-Wsign-conversion" +#pragma clang diagnostic ignored "-Wshorten-64-to-32" +#include +#include +#include +#include +#pragma clang diagnostic pop + +// These are the headers we want to verify under strict warnings +#include "MLXExecutor.h" +#include "MLXInterpreter.h" +#include "MLXLoader.h" + +// Instantiate key types to ensure template code is checked +namespace { +[[maybe_unused]] void force_instantiation() { + using namespace executorch::backends::mlx; + + // Force safe_mul template instantiation + (void)safe_mul(0, 0, "test"); + + // Force check_allocation_bounded instantiation + ::mlx::core::Shape shape = {1, 2, 3}; + check_allocation_bounded(shape, ::mlx::core::float32, "test"); +} +} // namespace diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py new file mode 100644 index 00000000000..01286f75f16 --- /dev/null +++ b/backends/mlx/test/test_ops.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Consolidated op tests for the MLX delegate. + +This file contains all op tests organized by category. Each test class inherits +from OpTestCase and can be run via the run_all_tests.py script. + +Usage: + # Run all tests (with 4 parallel workers, cleanup after) + python -m executorch.backends.mlx.test.run_all_tests -j4 --clean-after + + # Run specific test + python -m executorch.backends.mlx.test.run_all_tests add + + # List available tests + python -m executorch.backends.mlx.test.run_all_tests --list + +See README.md in this directory for full documentation. +""" + +from typing import List, Tuple + +import torch +import torch.nn as nn + +# Import custom ops for RoPE and KV cache tests +from executorch.backends.mlx import ( # noqa: F401 - registers mlx ops # noqa: F401 - registers mlx.rope + custom_ops, + ops, +) + +from .test_utils import OpTestCase, register_test + + +class BmmModel(nn.Module): + """Model that performs batch matrix multiplication.""" + + def __init__(self, batch_size: int, n: int, m: int, p: int): + super().__init__() + self.weight = nn.Parameter(torch.randn(batch_size, m, p)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.bmm(x, self.weight) + + +@register_test +class BmmTest(OpTestCase): + """Test case for bmm (batch matrix multiplication).""" + + name = "bmm" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + batch_size: int = 4, + n: int = 8, + m: int = 16, + p: int = 32, + ): + self.batch_size = batch_size + self.n = n + self.m = m + self.p = p + self.name = f"bmm_{batch_size}x{n}x{m}x{p}" + + @classmethod + def get_test_configs(cls) -> List["BmmTest"]: + return [ + cls(batch_size=4, n=8, m=16, p=32), + cls(batch_size=2, n=64, m=64, p=32), + ] + + def create_model(self) -> nn.Module: + return BmmModel(self.batch_size, self.n, self.m, self.p) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.n, self.m) + return (x,) + + +class AddmmModel(nn.Module): + """Model that performs addmm: bias + (mat1 @ mat2).""" + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + alpha: float = 1.0, + beta: float = 1.0, + ): + super().__init__() + self.weight = nn.Parameter(torch.randn(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.randn(out_features)) + else: + self.bias = None + self.alpha = alpha + self.beta = beta + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.bias is not None: + return torch.addmm( + self.bias, x, self.weight.t(), beta=self.beta, alpha=self.alpha + ) + else: + return torch.mm(x, self.weight.t()) + + +@register_test +class AddmmTest(OpTestCase): + """Test case for addmm.""" + + name = "addmm" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + batch_size: int = 2, + in_features: int = 64, + out_features: int = 32, + bias: bool = True, + alpha: float = 1.0, + beta: float = 1.0, + ): + self.batch_size = batch_size + self.in_features = in_features + self.out_features = out_features + self.bias = bias + self.alpha = alpha + self.beta = beta + + # Build unique test name + if not bias: + name = f"addmm_{in_features}x{out_features}_no_bias" + elif alpha != 1.0 or beta != 1.0: + name = f"addmm_{in_features}x{out_features}_a{alpha}_b{beta}" + else: + name = f"addmm_{in_features}x{out_features}" + self.name = name + + @classmethod + def get_test_configs(cls) -> List["AddmmTest"]: + return [ + cls( + batch_size=2, in_features=64, out_features=32 + ), # with bias, default alpha/beta + cls( + batch_size=2, in_features=64, out_features=32, bias=False + ), # without bias + cls(batch_size=4, in_features=128, out_features=64), # larger size + cls( + batch_size=2, in_features=64, out_features=32, alpha=2.0, beta=0.5 + ), # custom alpha/beta + ] + + def create_model(self) -> nn.Module: + return AddmmModel( + self.in_features, + self.out_features, + bias=self.bias, + alpha=self.alpha, + beta=self.beta, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.in_features) + return (x,) diff --git a/backends/mlx/test/test_partitioner.py b/backends/mlx/test/test_partitioner.py new file mode 100644 index 00000000000..4a5833aa656 --- /dev/null +++ b/backends/mlx/test/test_partitioner.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for the MLX partitioner. +""" + +import unittest + +import torch +import torch.nn as nn +from executorch.backends.mlx.partitioner import MLXPartitioner +from executorch.exir import EdgeCompileConfig, to_edge +from torch.export import export + + +class TestMLXPartitionerRejectsToEdge(unittest.TestCase): + """MLXPartitioner must only be used via to_edge_transform_and_lower.""" + + def test_to_edge_then_to_backend_raises(self): + class M(nn.Module): + def forward(self, x): + return x + 1 + + ep = export(M(), (torch.randn(4),), strict=False) + edge = to_edge( + ep, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + ) + + with self.assertRaises(RuntimeError) as ctx: + edge.to_backend(MLXPartitioner()) + + self.assertIn("to_edge_transform_and_lower", str(ctx.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/mlx/test/test_passes.py b/backends/mlx/test/test_passes.py new file mode 100644 index 00000000000..a9fdb3b996b --- /dev/null +++ b/backends/mlx/test/test_passes.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/mlx/test/test_pattern_utils.py b/backends/mlx/test/test_pattern_utils.py new file mode 100644 index 00000000000..48495a469d7 --- /dev/null +++ b/backends/mlx/test/test_pattern_utils.py @@ -0,0 +1,592 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for pattern_utils.py - shared pattern matching utilities. +""" + +import unittest + +import torch +from torch.export import export + + +def get_exported_graph(module, example_inputs): + """Export a module and return the graph with ATen ops.""" + ep = export(module, example_inputs) + return ep.graph_module.graph + + +def find_node_by_target(graph, target_name): + """Find first call_function node whose target contains target_name.""" + for node in graph.nodes: + if node.op == "call_function" and target_name in str(node.target): + return node + return None + + +def find_all_nodes_by_target(graph, target_name): + """Find all call_function nodes whose target contains target_name.""" + return [ + node + for node in graph.nodes + if node.op == "call_function" and target_name in str(node.target) + ] + + +class TestMatchTarget(unittest.TestCase): + """Tests for match_target function.""" + + def test_match_target_basic(self): + """Test basic op matching.""" + from executorch.backends.mlx.pattern_utils import match_target + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + self.assertIsNotNone(rsqrt_node) + self.assertTrue(match_target(rsqrt_node, torch.ops.aten.rsqrt.default)) + self.assertFalse(match_target(rsqrt_node, torch.ops.aten.add.Tensor)) + + def test_match_target_non_call_function(self): + """Test that non-call_function nodes don't match.""" + from executorch.backends.mlx.pattern_utils import match_target + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + + # Find a placeholder node + placeholder_node = None + for node in graph.nodes: + if node.op == "placeholder": + placeholder_node = node + break + + self.assertIsNotNone(placeholder_node) + self.assertFalse(match_target(placeholder_node, torch.ops.aten.rsqrt.default)) + + +class TestHasSingleUser(unittest.TestCase): + """Tests for has_single_user function.""" + + def test_single_user(self): + """Test node with single user.""" + from executorch.backends.mlx.pattern_utils import has_single_user + + class SingleUserModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) # Single use + return y + 1 + + graph = get_exported_graph(SingleUserModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + + self.assertIsNotNone(neg_node) + self.assertTrue(has_single_user(neg_node)) + + def test_multiple_users(self): + """Test node with multiple users.""" + from executorch.backends.mlx.pattern_utils import has_single_user + + class MultiUserModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) # Used by both add and mul + a = y + 1 + b = y * 2 + return a + b + + graph = get_exported_graph(MultiUserModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + + self.assertIsNotNone(neg_node) + self.assertFalse(has_single_user(neg_node)) + + +class TestHasNoUsers(unittest.TestCase): + """Tests for has_no_users function.""" + + def test_has_users(self): + """Test node that has users.""" + from executorch.backends.mlx.pattern_utils import has_no_users + + class SimpleModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) + return y + 1 + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + + self.assertIsNotNone(neg_node) + self.assertFalse(has_no_users(neg_node)) + + def test_no_users_after_removal(self): + """Test has_no_users returns True for orphaned nodes.""" + from executorch.backends.mlx.pattern_utils import has_no_users + + class SimpleModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) + z = torch.rsqrt(y) + return z + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # Initially neg has a user (rsqrt) + self.assertFalse(has_no_users(neg_node)) + + # Replace rsqrt's input with placeholder to orphan neg + placeholder = None + for node in graph.nodes: + if node.op == "placeholder": + placeholder = node + break + rsqrt_node.replace_input_with(neg_node, placeholder) + + # Now neg has no users + self.assertTrue(has_no_users(neg_node)) + + +class TestOpStep(unittest.TestCase): + """Tests for OpStep dataclass.""" + + def test_matches_with_op(self): + """Test OpStep.matches with op field.""" + from executorch.backends.mlx.pattern_utils import OpStep + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + step = OpStep(op=torch.ops.aten.rsqrt.default) + self.assertTrue(step.matches(rsqrt_node)) + + step_wrong = OpStep(op=torch.ops.aten.neg.default) + self.assertFalse(step_wrong.matches(rsqrt_node)) + + def test_matches_with_predicate(self): + """Test OpStep.matches with predicate field.""" + from executorch.backends.mlx.pattern_utils import OpStep + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # Predicate that always returns True + step_true = OpStep(predicate=lambda n: True) + self.assertTrue(step_true.matches(rsqrt_node)) + + # Predicate that always returns False + step_false = OpStep(predicate=lambda n: False) + self.assertFalse(step_false.matches(rsqrt_node)) + + def test_matches_no_op_no_predicate(self): + """Test OpStep.matches returns False when neither op nor predicate set.""" + from executorch.backends.mlx.pattern_utils import OpStep + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + step_empty = OpStep() + self.assertFalse(step_empty.matches(rsqrt_node)) + + def test_matches_require_single_user_true(self): + """Test OpStep.matches with require_single_user=True (default).""" + from executorch.backends.mlx.pattern_utils import OpStep + + class MultiUserModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) # Used by both add and mul + a = y + 1 + b = y * 2 + return a + b + + graph = get_exported_graph(MultiUserModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + + # Default require_single_user=True, neg has multiple users + step = OpStep(op=torch.ops.aten.neg.default) + self.assertFalse(step.matches(neg_node)) + + def test_matches_require_single_user_false(self): + """Test OpStep.matches with require_single_user=False.""" + from executorch.backends.mlx.pattern_utils import OpStep + + class MultiUserModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) # Used by both add and mul + a = y + 1 + b = y * 2 + return a + b + + graph = get_exported_graph(MultiUserModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + + # With require_single_user=False, should match despite multiple users + step = OpStep(op=torch.ops.aten.neg.default, require_single_user=False) + self.assertTrue(step.matches(neg_node)) + + def test_matches_nargs_int(self): + """Test OpStep.matches with nargs as int (minimum).""" + from executorch.backends.mlx.pattern_utils import OpStep + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) # rsqrt has 1 arg + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # nargs=1 should match (rsqrt has 1 arg) + step = OpStep(op=torch.ops.aten.rsqrt.default, nargs=1) + self.assertTrue(step.matches(rsqrt_node)) + + # nargs=2 should fail (rsqrt only has 1 arg) + step_too_many = OpStep(op=torch.ops.aten.rsqrt.default, nargs=2) + self.assertFalse(step_too_many.matches(rsqrt_node)) + + def test_matches_nargs_tuple(self): + """Test OpStep.matches with nargs as tuple (range).""" + from executorch.backends.mlx.pattern_utils import OpStep + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) # rsqrt has 1 arg + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # nargs=(1, 3) should match (rsqrt has 1 arg, in range) + step = OpStep(op=torch.ops.aten.rsqrt.default, nargs=(1, 3)) + self.assertTrue(step.matches(rsqrt_node)) + + # nargs=(2, 4) should fail (rsqrt has 1 arg, not in range) + step_out_of_range = OpStep(op=torch.ops.aten.rsqrt.default, nargs=(2, 4)) + self.assertFalse(step_out_of_range.matches(rsqrt_node)) + + def test_matches_kwargs_empty(self): + """Test OpStep.matches with empty kwargs (node must have no kwargs).""" + from executorch.backends.mlx.pattern_utils import OpStep + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) # No kwargs + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # Empty kwargs set() means node must have no kwargs (default) + step = OpStep(op=torch.ops.aten.rsqrt.default, kwargs=set()) + self.assertTrue(step.matches(rsqrt_node)) + + # Default is also empty set (strict checking) + step_default = OpStep(op=torch.ops.aten.rsqrt.default) + self.assertTrue(step_default.matches(rsqrt_node)) + + def test_matches_kwargs_declared(self): + """Test OpStep.matches with declared kwargs.""" + from executorch.backends.mlx.pattern_utils import OpStep + + class CastModule(torch.nn.Module): + def forward(self, x): + return x.to(torch.float16) + + graph = get_exported_graph(CastModule(), (torch.randn(4, 4),)) + to_copy_node = find_node_by_target(graph, "_to_copy") + + if to_copy_node is not None: + # Check what kwargs exist + node_kwargs = set(to_copy_node.kwargs.keys()) + + # If we declare all kwargs, should match + step_all = OpStep( + op=torch.ops.aten._to_copy.default, + kwargs=node_kwargs, + ) + self.assertTrue(step_all.matches(to_copy_node)) + + # If we don't declare some kwargs, should fail + if node_kwargs: + step_missing = OpStep( + op=torch.ops.aten._to_copy.default, + kwargs=set(), # Empty, but node has kwargs + ) + self.assertFalse(step_missing.matches(to_copy_node)) + + def test_matches_arg_index(self): + """Test OpStep.matches validates arg_index is accessible.""" + from executorch.backends.mlx.pattern_utils import OpStep + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) # rsqrt has 1 arg + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # arg_index=0 should work (rsqrt has 1 arg) + step = OpStep(op=torch.ops.aten.rsqrt.default, arg_index=0) + self.assertTrue(step.matches(rsqrt_node)) + + # arg_index=1 should fail (rsqrt only has 1 arg, can't access args[1]) + step_bad_index = OpStep(op=torch.ops.aten.rsqrt.default, arg_index=1) + self.assertFalse(step_bad_index.matches(rsqrt_node)) + + +class TestWalkBack(unittest.TestCase): + """Tests for walk_back function.""" + + def test_walk_back_single_step(self): + """Test walk_back with a single step.""" + from executorch.backends.mlx.pattern_utils import OpStep, walk_back + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + result = walk_back(rsqrt_node, [OpStep(op=torch.ops.aten.rsqrt.default)]) + + self.assertIsNotNone(result) + base_node, entries = result + self.assertEqual(len(entries), 1) + self.assertEqual(entries[0], rsqrt_node) + # base_node should be the input to rsqrt + self.assertEqual(base_node.op, "placeholder") + + def test_walk_back_chain(self): + """Test walk_back with multiple steps in a chain.""" + from executorch.backends.mlx.pattern_utils import OpStep, walk_back + + class ChainModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) + z = torch.rsqrt(y) + return z + + graph = get_exported_graph(ChainModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # Match rsqrt -> neg chain + result = walk_back( + rsqrt_node, + [ + OpStep(op=torch.ops.aten.rsqrt.default), + OpStep(op=torch.ops.aten.neg.default), + ], + ) + + self.assertIsNotNone(result) + base_node, entries = result + self.assertEqual(len(entries), 2) + self.assertEqual(base_node.op, "placeholder") + + def test_walk_back_no_match(self): + """Test walk_back returns None when pattern doesn't match.""" + from executorch.backends.mlx.pattern_utils import OpStep, walk_back + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # Try to match neg which isn't there + result = walk_back(rsqrt_node, [OpStep(op=torch.ops.aten.neg.default)]) + + self.assertIsNone(result) + + def test_walk_back_optional_step(self): + """Test walk_back with optional step that doesn't match.""" + from executorch.backends.mlx.pattern_utils import OpStep, walk_back + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # Match rsqrt, skip optional neg (not present) + result = walk_back( + rsqrt_node, + [ + OpStep(op=torch.ops.aten.rsqrt.default), + OpStep(op=torch.ops.aten.neg.default, optional=True), + ], + ) + + self.assertIsNotNone(result) + base_node, entries = result + self.assertEqual(len(entries), 2) # One for each step + self.assertIsNotNone(entries[0]) # rsqrt matched + self.assertIsNone(entries[1]) # neg is None (optional, not matched) + + def test_walk_back_repeat_step(self): + """Test walk_back with repeat step.""" + from executorch.backends.mlx.pattern_utils import OpStep, walk_back + + class RepeatModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) + z = torch.neg(y) + w = torch.neg(z) + return w + + graph = get_exported_graph(RepeatModule(), (torch.randn(4, 4),)) + + # Find the last neg node (output of the chain) + neg_nodes = find_all_nodes_by_target(graph, "neg") + self.assertEqual(len(neg_nodes), 3) + last_neg = neg_nodes[-1] + + # Match chain of neg ops + result = walk_back( + last_neg, + [OpStep(op=torch.ops.aten.neg.default, repeat=True)], + ) + + self.assertIsNotNone(result) + base_node, entries = result + self.assertEqual(len(entries), 1) # One entry for the repeat step + self.assertIsInstance(entries[0], list) # Repeat returns list + self.assertEqual(len(entries[0]), 3) # Three neg nodes matched + self.assertEqual(base_node.op, "placeholder") + + def test_walk_back_repeat_zero_matches(self): + """Test walk_back with repeat step matching zero times then another step.""" + from executorch.backends.mlx.pattern_utils import OpStep, walk_back + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # Try to match neg (repeat, 0 matches) then rsqrt + # neg doesn't exist at rsqrt, so 0 matches, then we match rsqrt + result = walk_back( + rsqrt_node, + [ + OpStep(op=torch.ops.aten.neg.default, repeat=True), + OpStep(op=torch.ops.aten.rsqrt.default), + ], + ) + + # This should match: neg repeat matches 0 times, rsqrt matches + self.assertIsNotNone(result) + base_node, entries = result + self.assertEqual(len(entries), 2) # One for each step + self.assertIsInstance(entries[0], list) # Repeat returns list + self.assertEqual(len(entries[0]), 0) # Zero neg nodes matched + self.assertIsNotNone(entries[1]) # rsqrt matched + + def test_walk_back_arg_index(self): + """Test walk_back with arg_index to follow non-first argument.""" + from executorch.backends.mlx.pattern_utils import OpStep, walk_back + + class BinaryModule(torch.nn.Module): + def forward(self, x): + y = torch.rsqrt(x) + return x * y # mul(x, rsqrt(x)) + + graph = get_exported_graph(BinaryModule(), (torch.randn(4, 4),)) + mul_node = find_node_by_target(graph, "mul") + rsqrt_node = find_node_by_target(graph, "rsqrt") + + self.assertIsNotNone(mul_node) + self.assertIsNotNone(rsqrt_node) + + # Follow args[1] (rsqrt) instead of args[0] (placeholder) + result = walk_back( + mul_node, + [ + OpStep(op=torch.ops.aten.mul.Tensor, nargs=2, arg_index=1), + OpStep(op=torch.ops.aten.rsqrt.default), + ], + ) + + self.assertIsNotNone(result) + base_node, entries = result + self.assertEqual(len(entries), 2) # mul and rsqrt + self.assertEqual(entries[0], mul_node) + self.assertEqual(entries[1], rsqrt_node) + # base_node should be the input to rsqrt (placeholder) + self.assertEqual(base_node.op, "placeholder") + + +class TestPatternMatch(unittest.TestCase): + """Tests for PatternMatch base class.""" + + def test_all_nodes(self): + """Test all_nodes returns head + body.""" + from executorch.backends.mlx.pattern_utils import PatternMatch + + class ChainModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) + z = torch.rsqrt(y) + return z + + graph = get_exported_graph(ChainModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + rsqrt_node = find_node_by_target(graph, "rsqrt") + + match = PatternMatch(head=rsqrt_node, body=[neg_node]) + self.assertEqual(match.all_nodes(), [rsqrt_node, neg_node]) + + def test_remove_body_nodes(self): + """Test remove_body_nodes removes unused nodes.""" + from executorch.backends.mlx.pattern_utils import PatternMatch + + class ChainModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) + z = torch.rsqrt(y) + return z + + graph = get_exported_graph(ChainModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # To test remove_body_nodes, we'd need to first replace rsqrt's uses + # and then call remove_body_nodes. For this test, just verify the + # method exists and doesn't crash when nodes have users. + match = PatternMatch(head=rsqrt_node, body=[neg_node]) + + # This won't remove neg because it still has a user (rsqrt) + match.remove_body_nodes(graph) + + # neg should still exist because it has a user + self.assertIn(neg_node, graph.nodes) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/mlx/test/test_utils.py b/backends/mlx/test/test_utils.py new file mode 100644 index 00000000000..090bceabf08 --- /dev/null +++ b/backends/mlx/test/test_utils.py @@ -0,0 +1,1122 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utilities for MLX delegate op testing. + +This module provides functions to: +1. Save/load tensors to/from binary files (compatible with C++ op_test_runner) +2. Export simple models to .pte files +3. Compare expected vs actual outputs +4. Run the C++ op_test_runner binary +""" + +import json +import os +import struct +import subprocess +import tempfile +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch + + +DEFAULT_TEST_TIMEOUT = 300 # 5 minutes default timeout + + +class TestTimeoutError(Exception): + """Raised when a test exceeds its timeout.""" + + pass + + +# DType enum values matching C++ op_test_runner +DTYPE_FLOAT32 = 0 +DTYPE_FLOAT16 = 1 +DTYPE_INT32 = 2 +DTYPE_INT64 = 3 +DTYPE_BFLOAT16 = 4 +DTYPE_BOOL = 5 + + +# Default tolerance presets for different data types. +# These are based on the precision characteristics of each dtype: +# - FP32: ~7 decimal digits of precision +# - FP16: ~3-4 decimal digits of precision +# - BF16: ~2-3 decimal digits of precision (same exponent range as FP32) +TOLERANCE_PRESETS = { + torch.float32: {"rtol": 1e-5, "atol": 1e-5}, + torch.float16: {"rtol": 1e-3, "atol": 1e-3}, + torch.bfloat16: {"rtol": 1e-2, "atol": 1e-2}, + # Integer types should match exactly + torch.int32: {"rtol": 0, "atol": 0}, + torch.int64: {"rtol": 0, "atol": 0}, +} + + +def get_tolerance_for_dtype(dtype: torch.dtype) -> Tuple[float, float]: + """ + Get appropriate (rtol, atol) tolerances for a given dtype. + + Args: + dtype: The torch dtype to get tolerances for. + + Returns: + (rtol, atol) tuple with appropriate tolerances for the dtype. + """ + if dtype in TOLERANCE_PRESETS: + preset = TOLERANCE_PRESETS[dtype] + return preset["rtol"], preset["atol"] + # Default to FP32 tolerances for unknown types + return 1e-5, 1e-5 + + +def get_tolerance_for_dtypes(dtypes: List[torch.dtype]) -> Tuple[float, float]: + """ + Get tolerances that work for a list of dtypes (uses the loosest tolerances). + + Args: + dtypes: List of torch dtypes. + + Returns: + (rtol, atol) tuple with tolerances that accommodate all dtypes. + """ + if not dtypes: + return 1e-5, 1e-5 + + max_rtol = 0.0 + max_atol = 0.0 + for dtype in dtypes: + rtol, atol = get_tolerance_for_dtype(dtype) + max_rtol = max(max_rtol, rtol) + max_atol = max(max_atol, atol) + + return max_rtol, max_atol + + +def torch_dtype_to_bin_dtype(dtype: torch.dtype) -> int: + """Convert torch dtype to binary file dtype enum value.""" + mapping = { + torch.float32: DTYPE_FLOAT32, + torch.float16: DTYPE_FLOAT16, + torch.int32: DTYPE_INT32, + torch.int64: DTYPE_INT64, + torch.bfloat16: DTYPE_BFLOAT16, + torch.bool: DTYPE_BOOL, + } + if dtype not in mapping: + raise ValueError(f"Unsupported dtype: {dtype}") + return mapping[dtype] + + +def bin_dtype_to_torch_dtype(dtype_val: int) -> torch.dtype: + """Convert binary file dtype enum value to torch dtype.""" + mapping = { + DTYPE_FLOAT32: torch.float32, + DTYPE_FLOAT16: torch.float16, + DTYPE_INT32: torch.int32, + DTYPE_INT64: torch.int64, + DTYPE_BFLOAT16: torch.bfloat16, + DTYPE_BOOL: torch.bool, + } + if dtype_val not in mapping: + raise ValueError(f"Unknown dtype value: {dtype_val}") + return mapping[dtype_val] + + +def _atomic_write_binary(path: Path, data: bytes) -> None: + """ + Atomically write binary data to a file. + + Writes to a temporary file in the same directory, then atomically replaces + the target path. This prevents race conditions when multiple parallel + workers write to the same ``op_tests/`` tree. + """ + path.parent.mkdir(parents=True, exist_ok=True) + fd, tmp = tempfile.mkstemp(dir=path.parent, suffix=".tmp") + closed = False + try: + os.write(fd, data) + os.close(fd) + closed = True + os.replace(tmp, path) + except BaseException: + if not closed: + os.close(fd) + if os.path.exists(tmp): + os.unlink(tmp) + raise + + +def save_tensors_to_bin(tensors: List[torch.Tensor], path: Union[str, Path]) -> None: + """ + Save a list of tensors to a binary file. + + Binary format: + - 4 bytes: number of tensors (uint32) + For each tensor: + - 4 bytes: dtype enum (uint32) + - 4 bytes: number of dimensions (uint32) + - 4 bytes * ndim: shape (int32 each) + - N bytes: tensor data + """ + path = Path(path) + + buf = bytearray() + # Write number of tensors + buf += struct.pack("I", len(tensors)) + + for tensor in tensors: + # Ensure contiguous + tensor = tensor.contiguous() + + # Write dtype + dtype_val = torch_dtype_to_bin_dtype(tensor.dtype) + buf += struct.pack("I", dtype_val) + + # Write ndim + buf += struct.pack("I", tensor.dim()) + + # Write shape + for s in tensor.shape: + buf += struct.pack("i", s) + + # Write data - bf16 needs special handling since numpy doesn't support it + if tensor.dtype == torch.bfloat16: + # View bf16 as uint16 to preserve raw bytes + buf += tensor.view(torch.uint16).numpy().tobytes() + else: + buf += tensor.numpy().tobytes() + + _atomic_write_binary(path, bytes(buf)) + + +def load_tensors_from_bin(path: Union[str, Path]) -> List[torch.Tensor]: + path = Path(path) + + # Mapping from torch dtype to numpy dtype + np_dtype_map = { + torch.float32: np.float32, + torch.float16: np.float16, + torch.int32: np.int32, + torch.int64: np.int64, + torch.bool: np.bool_, + # bfloat16 needs special handling - read as uint16 + } + + # Element size for each dtype + elem_size_map = { + torch.float32: 4, + torch.float16: 2, + torch.int32: 4, + torch.int64: 8, + torch.bfloat16: 2, + torch.bool: 1, + } + + tensors = [] + with open(path, "rb") as f: + # Read number of tensors + num_tensors = struct.unpack("I", f.read(4))[0] + + for _ in range(num_tensors): + # Read dtype + dtype_val = struct.unpack("I", f.read(4))[0] + dtype = bin_dtype_to_torch_dtype(dtype_val) + + # Read ndim + ndim = struct.unpack("I", f.read(4))[0] + + # Read shape + shape = [] + for _ in range(ndim): + shape.append(struct.unpack("i", f.read(4))[0]) + + # Read data + numel = 1 + for s in shape: + numel *= s + + elem_size = elem_size_map[dtype] + data_bytes = f.read(numel * elem_size) + + # Convert to tensor + if dtype == torch.bfloat16: + # Read as uint16 and view as bfloat16 + arr = np.frombuffer(data_bytes, dtype=np.uint16).reshape(shape) + tensor = torch.tensor(arr).view(torch.bfloat16) + else: + arr = np.frombuffer(data_bytes, dtype=np_dtype_map[dtype]).reshape( + shape + ) + tensor = torch.from_numpy(arr.copy()) + + tensors.append(tensor) + + return tensors + + +def export_model_to_pte( + model: torch.nn.Module, + example_inputs: Tuple[torch.Tensor, ...], + output_path: Union[str, Path], + dynamic_shapes: Optional[Dict] = None, + verbose: bool = False, +) -> None: + """ + Export a PyTorch model to a .pte file using the MLX delegate. + + Args: + model: The PyTorch model to export. + example_inputs: Example inputs for tracing. + output_path: Path to save the .pte file. + dynamic_shapes: + dynamic_shapes: Optional dynamic shapes specification for torch.export. + Example: {0: {0: Dim("batch", min=1, max=32)}} for dynamic batch on first input. + verbose: Whether to print the exported program for debugging. + """ + import executorch.exir as exir + from executorch.backends.mlx import MLXPartitioner + from executorch.exir.capture._config import ExecutorchBackendConfig + from torch.export import export + + model = model.eval() + + # Export with torch.export + exported_program = export( + model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True + ) + + # Print exported program if verbose + if verbose: + print("\n" + "=" * 60) + print("EXPORTED PROGRAM (torch.export)") + print("=" * 60) + print(exported_program) + + # Lower to edge and delegate to MLX + edge_program = exir.to_edge_transform_and_lower( + exported_program, + partitioner=[MLXPartitioner()], + ) + + # Print edge program if verbose + if verbose: + print("\n" + "=" * 60) + print("EDGE PROGRAM (after decomposition)") + print("=" * 60) + print(edge_program.exported_program()) + + # Export to ExecuTorch + executorch_program = edge_program.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=True) + ) + + # Save to file + output_path = Path(output_path) + _atomic_write_binary(output_path, executorch_program.buffer) + + +def inspect_pte_file(pte_path: Union[str, Path]) -> Dict: + """ + Inspect a PTE file and return the MLX graph information. + + Returns: + Dictionary with MLX graph details + """ + from executorch.backends.mlx.pte_inspector import ( + extract_delegate_payload, + parse_mlx_payload, + ) + + pte_path = Path(pte_path) + pte_data = pte_path.read_bytes() + + # Extract MLX delegate payload + payload = extract_delegate_payload(pte_data, "MLXBackend") + if payload is None: + return {"error": "Could not extract MLX delegate payload"} + + # Parse the MLX payload + mlx_data = parse_mlx_payload(payload) + return mlx_data + + +def print_mlx_graph_summary(pte_path: Union[str, Path]) -> None: + """ + Print a human-readable summary of the MLX graph in a PTE file. + + This function uses the pte_inspector module to display the MLX graph. + """ + from executorch.backends.mlx.pte_inspector import show_mlx_instructions + + pte_path = Path(pte_path) + pte_data = pte_path.read_bytes() + show_mlx_instructions(pte_data) + + +def count_mlx_delegate_segments(pte_path: Union[str, Path]) -> int: + """ + Count the number of MLX delegate segments in a PTE file. + + Args: + pte_path: Path to the .pte file + + Returns: + Number of MLX delegate segments found + """ + from executorch.exir._serialize._flatbuffer import _program_flatbuffer_to_json + + pte_path = Path(pte_path) + pte_data = pte_path.read_bytes() + + try: + program_json = _program_flatbuffer_to_json(pte_data) + program_data = json.loads(program_json) + + # Count all MLX delegates across all execution plans + count = 0 + for plan in program_data.get("execution_plan", []): + for delegate in plan.get("delegates", []): + delegate_name = delegate.get("id", "") + # Match MLXBackend (case-insensitive) + if "mlx" in delegate_name.lower(): + count += 1 + + return count + except Exception as e: + print(f"Error counting MLX segments: {e}") + return 0 + + +def get_mlx_node_counts(pte_path: Union[str, Path]) -> Dict[str, int]: + """ + Get a count of each MLX op node type in a serialized .pte file. + + Args: + pte_path: Path to the .pte file + + Returns: + Dictionary mapping op name (e.g. "SdpaNode", "SliceUpdateNode") to count. + """ + data = inspect_pte_file(pte_path) + graph = data.get("graph", {}) + counts: Dict[str, int] = {} + for chain_info in graph.get("instruction_chains", []): + for instr in chain_info.get("instructions", []): + op_name = instr.get("op_name") + if op_name: + counts[op_name] = counts.get(op_name, 0) + 1 + return counts + + +def compare_outputs( + expected: List[torch.Tensor], + actual: List[torch.Tensor], + rtol: float = 1e-5, + atol: float = 1e-5, +) -> Tuple[bool, str]: + """ + Compare expected and actual outputs using torch.allclose. + + Returns: + (passed, message) tuple + """ + if len(expected) != len(actual): + return ( + False, + f"Output count mismatch: expected {len(expected)}, got {len(actual)}", + ) + + for i, (exp, act) in enumerate(zip(expected, actual)): + if exp.shape != act.shape: + return ( + False, + f"Output {i} shape mismatch: expected {exp.shape}, got {act.shape}", + ) + + if exp.dtype != act.dtype: + # Convert both to float32 for comparison + exp = exp.float() + act = act.float() + + # For bool tensors, use exact comparison + if exp.dtype == torch.bool: + if not torch.equal(exp, act): + mismatches = (exp != act).sum().item() + total = exp.numel() + return False, ( + f"Output {i} values do not match:\n" + f" {mismatches}/{total} elements differ\n" + f" expected[:5]={exp.flatten()[:5].tolist()}\n" + f" actual[:5]={act.flatten()[:5].tolist()}" + ) + elif not torch.allclose(exp, act, rtol=rtol, atol=atol): + diff = (exp - act).abs() + max_diff = diff.max().item() + mean_diff = diff.float().mean().item() + return False, ( + f"Output {i} values do not match:\n" + f" max_diff={max_diff:.6e}, mean_diff={mean_diff:.6e}\n" + f" rtol={rtol}, atol={atol}\n" + f" expected[:5]={exp.flatten()[:5].tolist()}\n" + f" actual[:5]={act.flatten()[:5].tolist()}" + ) + + return True, "All outputs match" + + +def find_executorch_root() -> Path: # noqa: C901 + """Find the executorch root directory.""" + test_dir = Path(__file__).parent + + # Walk up to find the executorch root (has CMakeLists.txt and backends dir at root) + executorch_root = test_dir + for _ in range(10): # Max 10 levels up + if (executorch_root / "CMakeLists.txt").exists() and ( + executorch_root / "backends" + ).exists(): + # Check if we're in src/executorch (editable install) + if ( + executorch_root.name == "executorch" + and executorch_root.parent.name == "src" + ): + executorch_root = executorch_root.parent.parent + break + executorch_root = executorch_root.parent + + # If we didn't find a valid root (e.g. running from a pip-installed + # site-packages), fall back to cwd which is typically the repo root. + if not (executorch_root / "CMakeLists.txt").exists(): + cwd = Path.cwd() + if (cwd / "CMakeLists.txt").exists() and (cwd / "backends").exists(): + executorch_root = cwd + + return executorch_root + + +def find_build_dir(): + """Find the cmake build directory containing op_test_runner.""" + executorch_root = find_executorch_root() + + # Check common build locations + candidates = [ + executorch_root / "cmake-out-mlx", + executorch_root / "cmake-out", + executorch_root / "build", + ] + + for candidate in candidates: + runner_path = candidate / "backends" / "mlx" / "test" / "op_test_runner" + if runner_path.exists(): + return candidate + + # Return first candidate that exists as a directory (for rebuild) + for candidate in candidates: + if candidate.is_dir(): + return candidate + + return None + + +def find_op_test_runner() -> Path: + """Find the op_test_runner binary.""" + executorch_root = find_executorch_root() + + # Check common build locations + candidates = [ + executorch_root + / "cmake-out-mlx" + / "backends" + / "mlx" + / "test" + / "op_test_runner", + executorch_root / "cmake-out" / "backends" / "mlx" / "test" / "op_test_runner", + executorch_root / "build" / "backends" / "mlx" / "test" / "op_test_runner", + ] + + for candidate in candidates: + if candidate.exists(): + return candidate + + raise FileNotFoundError( + "Could not find op_test_runner binary. Tried:\n" + + "\n".join(f" - {c}" for c in candidates) + + "\n\nBuild with:\n" + + " cmake --preset mlx-release -DEXECUTORCH_BUILD_TESTS=ON\n" + + " cmake --build cmake-out --target op_test_runner" + ) + + +def rebuild_op_test_runner(verbose: bool = False) -> bool: + """ + Rebuild the op_test_runner binary using cmake. + + Args: + verbose: Whether to print build output. + + Returns: + True if build succeeded, False otherwise. + """ + build_dir = find_build_dir() + if build_dir is None: + print("Error: Could not find cmake build directory.") + print("Make sure you have run cmake configuration first.") + return False + + print(f"Rebuilding op_test_runner in {build_dir}...") + + cmd = ["cmake", "--build", str(build_dir), "--target", "op_test_runner", "-j8"] + + if verbose: + print(f"Running: {' '.join(cmd)}") + + result = subprocess.run( + cmd, + capture_output=not verbose, + text=True, + ) + + if result.returncode != 0: + print(f"Build failed with exit code {result.returncode}") + if not verbose and result.stderr: + print(f"stderr: {result.stderr}") + if not verbose and result.stdout: + print(f"stdout: {result.stdout}") + return False + + print("Build succeeded.") + return True + + +def run_cpp_test_runner( + pte_path: Path, + input_path: Path, + output_path: Path, + verbose: bool = False, + timeout: Optional[int] = None, +) -> bool: + """ + Run the C++ op_test_runner binary. + + Args: + pte_path: Path to the .pte model file. + input_path: Path to input .bin file. + output_path: Path to write output .bin file. + verbose: Whether to print verbose output. + timeout: Timeout in seconds. None means use DEFAULT_TEST_TIMEOUT. + + Returns: + True if execution succeeded, False otherwise. + """ + if timeout is None: + timeout = DEFAULT_TEST_TIMEOUT + + runner = find_op_test_runner() + + cmd = [ + str(runner), + "--pte", + str(pte_path), + "--input", + str(input_path), + "--output", + str(output_path), + ] + if verbose: + cmd.append("--verbose") + + print(f"Running: {' '.join(cmd)}") + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=timeout, + ) + except subprocess.TimeoutExpired: + print(f"TIMEOUT: C++ runner exceeded {timeout}s timeout") + return False + + if result.returncode != 0: + print(f"FAILED: {result.stderr}") + print(f"stdout: {result.stdout}") + return False + + print(f"C++ binary output: {result.stdout.strip()}") + return True + + +# Files that are generated during tests and can be safely cleaned up +GENERATED_TEST_FILES = [ + "model.pte", + "input.bin", + "expected_output.bin", + "actual_output.bin", +] + + +def clean_test_outputs( + test_names: Optional[List[str]] = None, verbose: bool = False +) -> int: + """ + Clean up generated test output files. + + Args: + test_names: Optional list of test names to clean. If None, cleans all tests. + verbose: Whether to print verbose output. + + Returns: + Number of files removed. + """ + test_dir = Path(__file__).parent / "op_tests" + if not test_dir.exists(): + if verbose: + print(f"Test directory does not exist: {test_dir}") + return 0 + + files_removed = 0 + + # Get directories to clean + if test_names: + dirs_to_clean = [ + test_dir / name for name in test_names if (test_dir / name).exists() + ] + else: + dirs_to_clean = [d for d in test_dir.iterdir() if d.is_dir()] + + for subdir in dirs_to_clean: + for filename in GENERATED_TEST_FILES: + filepath = subdir / filename + if filepath.exists(): + if verbose: + print(f"Removing: {filepath}") + filepath.unlink() + files_removed += 1 + + # Remove empty directories + if subdir.exists() and not any(subdir.iterdir()): + if verbose: + print(f"Removing empty directory: {subdir}") + subdir.rmdir() + + return files_removed + + +def get_test_output_size(test_names: Optional[List[str]] = None) -> int: + """ + Get total size of generated test output files in bytes. + + Args: + test_names: Optional list of test names to check. If None, checks all tests. + + Returns: + Total size in bytes. + """ + test_dir = Path(__file__).parent / "op_tests" + if not test_dir.exists(): + return 0 + + total_size = 0 + + # Get directories to check + if test_names: + dirs_to_check = [ + test_dir / name for name in test_names if (test_dir / name).exists() + ] + else: + dirs_to_check = [d for d in test_dir.iterdir() if d.is_dir()] + + for subdir in dirs_to_check: + for filename in GENERATED_TEST_FILES: + filepath = subdir / filename + if filepath.exists(): + total_size += filepath.stat().st_size + + return total_size + + +# Global registry: maps base_name -> (test_class, get_test_configs method) +# Tests are instantiated lazily when actually run, not at import time +_TEST_REGISTRY: Dict[str, type] = {} + + +def register_test(test_class: type) -> type: + """ + Class decorator to register a test class. + + The test class must have: + - A class attribute `name` (str) - the base test name + - A class method `get_test_configs()` that returns a list of OpTestCase instances + + Test instances are created LAZILY when tests are actually run, not at import time. + This avoids creating random tensors at import time and keeps memory usage low. + + Example: + @register_test + class AddTest(OpTestCase): + name = "add" + + @classmethod + def get_test_configs(cls) -> List["OpTestCase"]: + return [ + cls(), # default config + cls(scalar=2.5), # scalar variant + ] + """ + if not hasattr(test_class, "name"): + raise ValueError( + f"Test class {test_class.__name__} must have a 'name' attribute" + ) + + base_name = test_class.name + _TEST_REGISTRY[base_name] = test_class + + return test_class + + +def get_registered_tests() -> Dict[str, List[Tuple[str, "OpTestCase"]]]: + """ + Get all registered tests with their configurations. + + Returns dict mapping base_name -> list of (config_name, test_instance). + Test instances are created fresh each time this is called. + """ + result = {} + for base_name, test_class in _TEST_REGISTRY.items(): + if hasattr(test_class, "get_test_configs"): + configs = test_class.get_test_configs() + else: + configs = [test_class()] + result[base_name] = [(cfg.name, cfg) for cfg in configs] + return result + + +def get_test_names() -> List[str]: + """Get list of registered base test names.""" + return list(_TEST_REGISTRY.keys()) + + +def get_all_test_configs() -> List[Tuple[str, "OpTestCase"]]: + """ + Get flat list of all (config_name, test_instance) tuples. + + Test instances are created fresh each time this is called. + """ + result = [] + for _base_name, test_class in _TEST_REGISTRY.items(): + if hasattr(test_class, "get_test_configs"): + configs = test_class.get_test_configs() + else: + configs = [test_class()] + result.extend((cfg.name, cfg) for cfg in configs) + return result + + +class OpTestCase: + """ + Base class for op test cases. + + Subclasses should implement: + - name: str - test name + - create_model() -> nn.Module + - create_inputs() -> Tuple[torch.Tensor, ...] + + Optionally override: + - get_dynamic_shapes() -> Optional[Dict] - for dynamic shape testing + - create_test_inputs() -> Tuple[torch.Tensor, ...] - inputs for testing (may differ from export inputs) + - expected_mlx_segments: int - expected number of MLX delegate segments (default: 1) + """ + + name: str = "base_test" + rtol: float = 1e-5 + atol: float = 1e-5 + seed: int = 42 # Default seed for reproducibility + timeout: int = DEFAULT_TEST_TIMEOUT # Timeout in seconds + skip_comparison: bool = False # Skip output comparison (for pattern-only tests) + skip_comparison_reason: str = "" # Reason for skipping comparison + expected_mlx_segments: int = 1 # Expected number of MLX delegate segments + expected_node_counts: Optional[Dict[str, int]] = ( + None # Expected serialized node counts + ) + + def _set_seed(self) -> None: + """Set random seed for reproducibility.""" + torch.manual_seed(self.seed) + + def create_model(self) -> torch.nn.Module: + raise NotImplementedError + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + """Create inputs for export (tracing).""" + raise NotImplementedError + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + """Create inputs for testing. Override for dynamic shape tests.""" + return self.create_inputs() + + def get_dynamic_shapes(self) -> Optional[Dict]: + """Return dynamic shapes specification for torch.export, or None for static shapes.""" + return None + + def get_test_dir(self) -> Path: + """Get the directory for this test's files.""" + test_dir = Path(__file__).parent / "op_tests" / self.name + test_dir.mkdir(parents=True, exist_ok=True) + return test_dir + + def generate_test_files(self, verbose: bool = False) -> Tuple[Path, Path, Path]: + """ + Generate .pte, input.bin, and expected_output.bin files. + + Args: + verbose: Whether to print the exported program for debugging. + + Returns: + (pte_path, input_path, expected_output_path) + """ + test_dir = self.get_test_dir() + + pte_path = test_dir / "model.pte" + input_path = test_dir / "input.bin" + expected_path = test_dir / "expected_output.bin" + + # Set seed for reproducibility + self._set_seed() + + # Create model and inputs + model = self.create_model() + export_inputs = self.create_inputs() + + # Set seed again before creating test inputs (in case they differ) + self._set_seed() + test_inputs = self.create_test_inputs() + + # Get expected outputs using test inputs + model.eval() + with torch.no_grad(): + if isinstance(test_inputs, torch.Tensor): + test_inputs = (test_inputs,) + expected_outputs = model(*test_inputs) + if isinstance(expected_outputs, torch.Tensor): + expected_outputs = [expected_outputs] + else: + expected_outputs = list(expected_outputs) + + # Export model with export inputs (and potentially dynamic shapes) + print(f"Exporting model to {pte_path}") + if isinstance(export_inputs, torch.Tensor): + export_inputs = (export_inputs,) + + dynamic_shapes = self.get_dynamic_shapes() + if dynamic_shapes: + print(f" Using dynamic shapes: {dynamic_shapes}") + + export_model_to_pte( + model, + export_inputs, + pte_path, + dynamic_shapes=dynamic_shapes, + verbose=verbose, + ) + + # Save test inputs + print(f"Saving inputs to {input_path}") + if isinstance(test_inputs, torch.Tensor): + test_inputs = [test_inputs] + else: + test_inputs = list(test_inputs) + save_tensors_to_bin(test_inputs, input_path) + + # Save expected outputs + print(f"Saving expected outputs to {expected_path}") + save_tensors_to_bin(expected_outputs, expected_path) + + return pte_path, input_path, expected_path + + def compare_with_actual( + self, actual_output_path: Union[str, Path], use_dtype_tolerances: bool = False + ) -> Tuple[bool, str]: + """ + Compare actual outputs with expected outputs. + + Args: + actual_output_path: Path to the actual output file. + use_dtype_tolerances: If True, uses tolerance presets based on output dtypes + instead of the test's rtol/atol values. + """ + test_dir = self.get_test_dir() + expected_path = test_dir / "expected_output.bin" + + expected = load_tensors_from_bin(expected_path) + actual = load_tensors_from_bin(actual_output_path) + + # Determine tolerances + if use_dtype_tolerances: + # Use dtype-based tolerances (loosest tolerance across all output dtypes) + output_dtypes = [t.dtype for t in expected] + rtol, atol = get_tolerance_for_dtypes(output_dtypes) + else: + rtol, atol = self.rtol, self.atol + + return compare_outputs(expected, actual, rtol=rtol, atol=atol) + + def run_test(self, verbose: bool = False, timeout: Optional[int] = None) -> bool: + """ + Run the full test: generate files, run C++, compare outputs. + + Args: + verbose: Whether to print verbose output. + timeout: Timeout in seconds. None means use self.timeout. + + Returns: + True if test passed, False otherwise. + """ + if timeout is None: + timeout = self.timeout + + print(f"\n{'='*60}") + print(f"Running test: {self.name}") + print(f"{'='*60}\n") + + # Generate test files + print("Step 1: Generating test files...") + pte_path, input_path, expected_path = self.generate_test_files(verbose=verbose) + + # Print MLX graph summary + print_mlx_graph_summary(pte_path) + + # Verify expected number of MLX delegate segments + print("\nStep 2: Verifying MLX delegation...") + actual_segments = count_mlx_delegate_segments(pte_path) + print(f" Expected MLX segments: {self.expected_mlx_segments}") + print(f" Actual MLX segments: {actual_segments}") + + if actual_segments != self.expected_mlx_segments: + print("✗ FAILED: MLX delegation mismatch!") + print( + f" Expected {self.expected_mlx_segments} segment(s), but found {actual_segments}" + ) + return False + print("✓ MLX delegation verified") + + # Verify expected node counts if specified + if self.expected_node_counts is not None: + print("\n Verifying serialized node counts...") + actual_counts = get_mlx_node_counts(pte_path) + for node_name, expected_count in self.expected_node_counts.items(): + actual_count = actual_counts.get(node_name, 0) + if actual_count != expected_count: + print(f"✗ FAILED: Node count mismatch for {node_name}!") + print(f" Expected {expected_count}, got {actual_count}") + print(f" All node counts: {actual_counts}") + return False + print(f" ✓ {node_name}: {actual_count}") + print(" ✓ All node counts verified") + + # Run C++ binary + print("\nStep 3: Running C++ binary...") + actual_path = self.get_test_dir() / "actual_output.bin" + if not run_cpp_test_runner( + pte_path, input_path, actual_path, verbose=verbose, timeout=timeout + ): + return False + + # Compare outputs (or skip if configured) + print("\nStep 4: Comparing outputs...") + if self.skip_comparison: + reason = self.skip_comparison_reason or "skip_comparison=True" + print(f"NOTE: Output comparison skipped ({reason})") + print("✓ PASSED (runtime execution succeeded)") + return True + + passed, message = self.compare_with_actual(actual_path) + + if passed: + print(f"✓ PASSED: {message}") + else: + print(f"✗ FAILED: {message}") + + return passed + + +def run_op_test_main( + test_factory, + description: str, + add_args_fn=None, +): + """ + Common main() function for op tests. + + This handles the common argparse setup, rebuild logic, and generate/compare/run + action handling that is shared across all op tests. + + Args: + test_factory: A callable that takes parsed args (argparse.Namespace) and + returns an OpTestCase instance. + description: Description for the argparse help message. + add_args_fn: Optional callable that takes a parser and adds test-specific + arguments. Signature: add_args_fn(parser) -> None + """ + import argparse + import sys + + parser = argparse.ArgumentParser(description=description) + parser.add_argument( + "action", + choices=["generate", "compare", "run"], + help="Action to perform: generate (create test files), compare (compare outputs), run (full test)", + ) + parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") + parser.add_argument( + "--rebuild", + action="store_true", + help="Rebuild the C++ test runner before running", + ) + + # Add test-specific arguments + if add_args_fn is not None: + add_args_fn(parser) + + args = parser.parse_args() + + # Rebuild if requested + if args.rebuild: + if not rebuild_op_test_runner(verbose=args.verbose): + sys.exit(1) + + # Create test case from factory + test = test_factory(args) + + if args.action == "generate": + pte_path, input_path, expected_path = test.generate_test_files( + verbose=args.verbose + ) + print("\nGenerated files:") + print(f" PTE: {pte_path}") + print(f" Input: {input_path}") + print(f" Expected: {expected_path}") + print_mlx_graph_summary(pte_path) + + elif args.action == "compare": + actual_path = test.get_test_dir() / "actual_output.bin" + if not actual_path.exists(): + print(f"Error: {actual_path} not found. Run the C++ binary first.") + sys.exit(1) + + passed, message = test.compare_with_actual(actual_path) + if passed: + print(f"✓ PASSED: {message}") + else: + print(f"✗ FAILED: {message}") + sys.exit(0 if passed else 1) + + elif args.action == "run": + passed = test.run_test(verbose=args.verbose) + sys.exit(0 if passed else 1) diff --git a/backends/mlx/test/tester.py b/backends/mlx/test/tester.py new file mode 100644 index 00000000000..7a929ea7c3b --- /dev/null +++ b/backends/mlx/test/tester.py @@ -0,0 +1,78 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import functools +from typing import Any, List, Optional, Tuple + +import executorch +import executorch.backends.test.harness.stages as BaseStages +import torch + +from executorch.backends.mlx.partitioner import MLXPartitioner +from executorch.backends.test.harness import Tester as TesterBase +from executorch.backends.test.harness.stages import StageType +from executorch.exir import EdgeCompileConfig +from executorch.exir.backend.backend_details import CompileSpec +from executorch.exir.backend.partitioner import Partitioner + + +def _create_default_partitioner( + compile_specs: List[CompileSpec] | None = None, +) -> MLXPartitioner: + return MLXPartitioner(compile_specs=compile_specs) + + +class Partition(BaseStages.Partition): + def __init__( + self, + partitioner: Optional[Partitioner] = None, + compile_specs: Optional[List[CompileSpec]] = None, + ): + super().__init__( + partitioner=partitioner or _create_default_partitioner(compile_specs), + ) + + +class ToEdgeTransformAndLower(BaseStages.ToEdgeTransformAndLower): + def __init__( + self, + partitioners: Optional[List[Partitioner]] = None, + edge_compile_config: Optional[EdgeCompileConfig] = None, + compile_specs: Optional[List[CompileSpec]] = None, + ): + super().__init__( + default_partitioner_cls=lambda: _create_default_partitioner(compile_specs), + partitioners=partitioners, + edge_compile_config=edge_compile_config, + ) + + +class MLXTester(TesterBase): + def __init__( + self, + module: torch.nn.Module, + example_inputs: Tuple[torch.Tensor], + dynamic_shapes: Optional[Tuple[Any]] = None, + compile_specs: Optional[List[CompileSpec]] = None, + ): + stage_classes = ( + executorch.backends.test.harness.Tester.default_stage_classes() + | { + StageType.PARTITION: functools.partial( + Partition, compile_specs=compile_specs + ), + StageType.TO_EDGE_TRANSFORM_AND_LOWER: functools.partial( + ToEdgeTransformAndLower, compile_specs=compile_specs + ), + } + ) + + super().__init__( + module=module, + stage_classes=stage_classes, + example_inputs=example_inputs, + dynamic_shapes=dynamic_shapes, + ) diff --git a/backends/mlx/third-party/mlx b/backends/mlx/third-party/mlx new file mode 160000 index 00000000000..72e94c81e16 --- /dev/null +++ b/backends/mlx/third-party/mlx @@ -0,0 +1 @@ +Subproject commit 72e94c81e1685c90679ef03532c4b8897010abf9 diff --git a/backends/test/suite/flow.py b/backends/test/suite/flow.py index ec6eca0490c..676903ec780 100644 --- a/backends/test/suite/flow.py +++ b/backends/test/suite/flow.py @@ -54,6 +54,7 @@ def __str__(self): return self.name +<<<<<<< HEAD def _register_flow( import_fn: Callable[[], list[TestFlow]], backend_name: str ) -> list[TestFlow]: @@ -62,6 +63,10 @@ def _register_flow( except Exception as e: logger.info(f"Skipping {backend_name} flow registration: {e}") return [] +======= +def all_flows() -> dict[str, TestFlow]: # noqa: C901 + flows = [] +>>>>>>> af6810bfa4 (up) def _load_xnnpack() -> list[TestFlow]: @@ -164,4 +169,13 @@ def all_flows() -> dict[str, TestFlow]: + _register_flow(_load_arm, "ARM") ) + try: + from executorch.backends.test.suite.flows.mlx import MLX_TEST_FLOW + + flows += [ + MLX_TEST_FLOW, + ] + except Exception as e: + logger.info(f"Skipping MLX flow registration: {e}") + return {f.name: f for f in flows if f is not None} diff --git a/backends/test/suite/flows/mlx.py b/backends/test/suite/flows/mlx.py new file mode 100644 index 00000000000..d70db46b73c --- /dev/null +++ b/backends/test/suite/flows/mlx.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.mlx.test.tester import MLXTester +from executorch.backends.test.suite.flow import TestFlow + +MLX_TEST_FLOW = TestFlow( + name="mlx", + backend="mlx", + tester_factory=MLXTester, +) diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py index c0a4f3b795a..4ab2a3572b4 100644 --- a/exir/_serialize/_program.py +++ b/exir/_serialize/_program.py @@ -768,3 +768,70 @@ def deserialize_pte_binary(program_data: bytes) -> PTEFile: ) return PTEFile(program=program, mutable_data=None, named_data=None) + + +def _extract_delegate_payload( + pte_data: bytes, backend_id: str, delegate_index: int = 0 +) -> Optional[bytes]: + """Extract a delegate payload from a serialized PTE file. + + Parses the PTE file structure, finds the delegate matching the given + backend ID, and returns its raw payload bytes. Handles both inline + delegate data and segment-based storage. + + Args: + pte_data: Raw bytes of the PTE file. + backend_id: ID substring to match (case-insensitive). + For example, 'mlx' matches 'MLXBackend'. + delegate_index: Which matching delegate to extract (0-based). + Defaults to 0 (first match). + + Returns: + Delegate payload bytes, or None if not found. + """ + # Parse the extended header + extended_header = _get_extended_header(pte_data) + + # Determine program size from header or use full data + if extended_header is not None: + program_size = extended_header.program_size + else: + program_size = len(pte_data) + + # Parse the program flatbuffer + program: Program = _json_to_program( + _program_flatbuffer_to_json(pte_data[:program_size]) + ) + + # Search for the matching delegate + match_count = 0 + for plan in program.execution_plan: + for delegate in plan.delegates: + if backend_id.lower() not in delegate.id.lower(): + continue + if match_count != delegate_index: + match_count += 1 + continue + + processed = delegate.processed + + # Inline data + if processed.location == DataLocation.INLINE: + inline_data = program.backend_delegate_data[processed.index] + if inline_data.data: + return bytes(inline_data.data) + return None + + # Segment data + if processed.location == DataLocation.SEGMENT: + if extended_header is None: + return None + + segment = program.segments[processed.index] + offset = extended_header.segment_base_offset + segment.offset + size = segment.size + return pte_data[offset : offset + size] + + return None + + return None diff --git a/setup.py b/setup.py index f05951012e3..d07736128c8 100644 --- a/setup.py +++ b/setup.py @@ -624,6 +624,26 @@ def run(self): # the input file is read-only. self.copy_file(src, dst, preserve_mode=False) + # Copy CMake-generated Python directories that setuptools missed. + # Setuptools discovers packages at configuration time, before CMake + # runs. Directories created by CMake during the build (e.g. by + # generate.py) are not in the package list and must be copied manually. + generated_dirs = [ + "backends/mlx/serialization/_generated", + ] + for rel_dir in generated_dirs: + src_dir = os.path.join("src/executorch", rel_dir) + if not os.path.isdir(src_dir): + continue + dst_dir = os.path.join(dst_root, rel_dir) + for dirpath, _dirnames, filenames in os.walk(src_dir): + for filename in filenames: + src_file = os.path.join(dirpath, filename) + rel_path = os.path.relpath(src_file, src_dir) + dst_file = os.path.join(dst_dir, rel_path) + self.mkpath(os.path.dirname(dst_file)) + self.copy_file(src_file, dst_file, preserve_mode=False) + class Buck2EnvironmentFixer(contextlib.AbstractContextManager): """Removes HOME from the environment when running as root. @@ -786,6 +806,9 @@ def run(self): # noqa C901 if cmake_cache.is_enabled("EXECUTORCH_BUILD_COREML"): cmake_build_args += ["--target", "executorchcoreml"] + if cmake_cache.is_enabled("EXECUTORCH_BUILD_MLX"): + cmake_build_args += ["--target", "mlxdelegate"] + if cmake_cache.is_enabled("EXECUTORCH_BUILD_KERNELS_LLM_AOT"): cmake_build_args += ["--target", "custom_ops_aot_lib"] cmake_build_args += ["--target", "quantized_ops_aot_lib"] @@ -846,6 +869,16 @@ def run(self): # noqa C901 modpath="executorch.extension.pybindings.data_loader", dependent_cmake_flags=["EXECUTORCH_BUILD_PYBIND"], ), + # MLX metallib (Metal GPU kernels) must be colocated with _portable_lib.so + # because MLX uses dladdr() to find the directory containing the library, + # then looks for mlx.metallib in that directory at runtime. + # After submodule migration, the path is backends/mlx/mlx/... + BuiltFile( + src_dir="%CMAKE_CACHE_DIR%/backends/mlx/mlx/mlx/backend/metal/kernels/", + src_name="mlx.metallib", + dst="executorch/extension/pybindings/", + dependent_cmake_flags=["EXECUTORCH_BUILD_MLX"], + ), BuiltExtension( src="extension/training/_training_lib.*", # @lint-ignore https://github.com/pytorch/executorch/blob/cb3eba0d7f630bc8cec0a9cc1df8ae2f17af3f7a/scripts/lint_xrefs.sh modpath="executorch.extension.training.pybindings._training_lib", diff --git a/tools/cmake/Utils.cmake b/tools/cmake/Utils.cmake index 74f2be78804..3295036663c 100644 --- a/tools/cmake/Utils.cmake +++ b/tools/cmake/Utils.cmake @@ -178,3 +178,36 @@ function(executorch_add_prefix_to_public_headers targetName prefix) TARGET "${targetName}" PROPERTY PUBLIC_HEADER ${FIXED_PUBLIC_HEADERS} ) endfunction() + +# ----------------------------------------------------------------------------- +# MLX metallib distribution helper +# ----------------------------------------------------------------------------- +# Copies mlx.metallib next to the target executable so MLX can find it at +# runtime. +# +# MLX uses dladdr() to find the directory containing the binary with MLX code, +# then looks for mlx.metallib in that directory. When MLX is statically linked +# into an executable or shared library, this function ensures the metallib is +# colocated with that binary. +# +# Usage: executorch_target_copy_mlx_metallib(my_executable) +# +function(executorch_target_copy_mlx_metallib target) + if(EXECUTORCH_BUILD_MLX) + if(DEFINED MLX_METALLIB_PATH AND EXISTS "${MLX_METALLIB_PATH}") + add_custom_command( + TARGET ${target} + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${MLX_METALLIB_PATH}" + "$/mlx.metallib" + COMMENT "Copying mlx.metallib for ${target}" + ) + elseif(DEFINED MLX_METALLIB_PATH) + message( + WARNING + "MLX_METALLIB_PATH is set to ${MLX_METALLIB_PATH} but file does not exist. " + "metallib will not be copied for ${target}." + ) + endif() + endif() +endfunction() diff --git a/tools/cmake/executorch-config.cmake b/tools/cmake/executorch-config.cmake index dc4d34d8701..524e1be36ec 100644 --- a/tools/cmake/executorch-config.cmake +++ b/tools/cmake/executorch-config.cmake @@ -63,6 +63,8 @@ set(optional_lib_list coreml_inmemoryfs coremldelegate mpsdelegate + mlxdelegate + mlx metal_backend neuron_backend qnn_executorch_backend @@ -118,3 +120,46 @@ set_property( TARGET executorch_core PROPERTY INTERFACE_LINK_LIBRARIES ${FIXED_EXECUTORCH_CORE_LINK_LIBRARIES} ) + +# Expose MLX library and metallib path for downstream consumers +if(TARGET mlxdelegate) + # Create imported target for mlx library if not already defined (mlx is built + # by MLX's CMake but we need to expose it for linking) + if(NOT TARGET mlx) + find_library( + _mlx_library mlx + HINTS "${_root}/lib" + CMAKE_FIND_ROOT_PATH_BOTH + ) + if(_mlx_library) + add_library(mlx STATIC IMPORTED) + set_target_properties(mlx PROPERTIES IMPORTED_LOCATION "${_mlx_library}") + # MLX requires Metal and Foundation frameworks on Apple platforms + if(APPLE) + find_library(METAL_FRAMEWORK Metal) + find_library(FOUNDATION_FRAMEWORK Foundation) + if(METAL_FRAMEWORK AND FOUNDATION_FRAMEWORK) + set_target_properties( + mlx PROPERTIES INTERFACE_LINK_LIBRARIES + "${METAL_FRAMEWORK};${FOUNDATION_FRAMEWORK}" + ) + endif() + endif() + message(STATUS "Found mlx library at: ${_mlx_library}") + endif() + endif() + + # Find metallib for runtime distribution + find_file( + _mlx_metallib mlx.metallib + HINTS "${_root}/lib" + CMAKE_FIND_ROOT_PATH_BOTH + ) + if(_mlx_metallib) + set(MLX_METALLIB_PATH + "${_mlx_metallib}" + CACHE FILEPATH "Path to mlx.metallib for runtime distribution" + ) + message(STATUS "Found mlx.metallib at: ${MLX_METALLIB_PATH}") + endif() +endif() diff --git a/tools/cmake/preset/default.cmake b/tools/cmake/preset/default.cmake index 37443d44c2b..423194776bc 100644 --- a/tools/cmake/preset/default.cmake +++ b/tools/cmake/preset/default.cmake @@ -121,6 +121,7 @@ define_overridable_option( EXECUTORCH_BUILD_EXTENSION_APPLE "Build the Apple extension" BOOL OFF ) define_overridable_option(EXECUTORCH_BUILD_MPS "Build the MPS backend" BOOL OFF) +define_overridable_option(EXECUTORCH_BUILD_MLX "Build the MLX backend" BOOL OFF) define_overridable_option( EXECUTORCH_BUILD_NEURON "Build the backends/mediatek directory" BOOL OFF ) diff --git a/tools/cmake/preset/pybind.cmake b/tools/cmake/preset/pybind.cmake index 1a7e08a9d60..2c5c2edc506 100644 --- a/tools/cmake/preset/pybind.cmake +++ b/tools/cmake/preset/pybind.cmake @@ -31,6 +31,24 @@ if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") set_overridable_option(EXECUTORCH_BUILD_EXTENSION_TRAINING ON) set_overridable_option(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER ON) set_overridable_option(EXECUTORCH_BUILD_EXTENSION_LLM ON) + # MLX requires Apple Silicon (ARM64) and the Metal compiler (xcrun -sdk macosx + # metal) which is only available with Xcode, not Command Line Tools + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + execute_process( + COMMAND xcrun -sdk macosx --find metal + RESULT_VARIABLE _metal_compiler_result + OUTPUT_QUIET ERROR_QUIET + ) + if(_metal_compiler_result EQUAL 0) + set_overridable_option(EXECUTORCH_BUILD_MLX ON) + set_overridable_option(ET_MLX_ENABLE_OP_LOGGING ON) + else() + message( + STATUS + "Metal compiler not found, disabling MLX backend. Install Xcode to enable MLX." + ) + endif() + endif() elseif(CMAKE_SYSTEM_NAME STREQUAL "Linux") set_overridable_option(EXECUTORCH_BUILD_COREML ON) set_overridable_option(EXECUTORCH_BUILD_EXTENSION_TRAINING ON) From cbb43c731761e62006f8620a5daa095addc40188 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 2 Mar 2026 18:06:21 -0800 Subject: [PATCH 23/34] up --- tools/cmake/preset/mlx.cmake | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 tools/cmake/preset/mlx.cmake diff --git a/tools/cmake/preset/mlx.cmake b/tools/cmake/preset/mlx.cmake new file mode 100644 index 00000000000..d8ea7fe237f --- /dev/null +++ b/tools/cmake/preset/mlx.cmake @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# MLX delegate preset - builds ExecuTorch with MLX backend for Apple Silicon + +# Core ExecuTorch options +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_MODULE ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_TENSOR ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_LLM ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_ASR_RUNNER ON) +set_overridable_option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED ON) + +# Build the MLX delegate +set_overridable_option(EXECUTORCH_BUILD_MLX ON) From 7e5abd4e4285f72787b8c06fb1f9ccaa80c770ae Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 2 Mar 2026 18:21:54 -0800 Subject: [PATCH 24/34] up --- backends/mlx/runtime/MLXInterpreter.h | 8 ++++++++ backends/mlx/serialization/schema.fbs | 6 ++++++ 2 files changed, 14 insertions(+) diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index f3b6e9b720f..bfd593c162b 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -98,6 +98,11 @@ inline std::vector infer_shape_with_minus_one( inline void exec_noop(const NoopNode&, ExecutionState&, StreamOrDevice) {} +inline void +exec_id_copy(const IdCopyNode& n, ExecutionState& st, StreamOrDevice) { + st.set_tensor(n.out, st.const_tensor_ref(n.x)); +} + inline void exec_addmm(const AddmmNode& n, ExecutionState& st, StreamOrDevice s) { const auto& mat1 = st.const_tensor_ref(n.mat1); @@ -154,6 +159,9 @@ class Interpreter { case OpCode::NOOP: ops::exec_noop(std::get(instr.node), st, s); break; + case OpCode::ID_COPY: + ops::exec_id_copy(std::get(instr.node), st, s); + break; case OpCode::ADDMM: ops::exec_addmm(std::get(instr.node), st, s); break; diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs index 945186ebef8..8b159314760 100644 --- a/backends/mlx/serialization/schema.fbs +++ b/backends/mlx/serialization/schema.fbs @@ -72,6 +72,11 @@ table IntOrVidOrTid { table NoopNode {} +table IdCopyNode { + x: Tid (required); + out: Tid (required); +} + table AddmmNode { mat1: Tid (required); // First matrix mat2: Tid (required); // Second matrix @@ -89,6 +94,7 @@ table AddmmNode { // Reordering or removing members changes numeric type IDs and breaks existing .pte files. union OpNode { NoopNode, + IdCopyNode, AddmmNode // BC: Add new op nodes here (append only) } From c534395d44ba1e7ecabd407d2398cda038df5cf8 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 2 Mar 2026 18:50:55 -0800 Subject: [PATCH 25/34] up --- backends/mlx/ops.py | 6 +++ backends/mlx/test/test_ops.py | 91 ----------------------------------- 2 files changed, 6 insertions(+), 91 deletions(-) diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index 6e8516e86b1..4c9e0d6f796 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -219,6 +219,12 @@ def normalize_reduction_dim( return dim, keepdim +@REGISTRY.register(target=["NOOP", torch.ops.aten._assert_scalar.default]) +def _noop_handler(P: MLXProgramBuilder, n: Node) -> None: + """No-op handler for nodes that don't emit any MLX instructions.""" + return None + + @REGISTRY.register(target=[torch.ops.aten.addmm.default]) def _addmm_handler(P: MLXProgramBuilder, n: Node) -> Slot: """Handle addmm: self + (mat1 @ mat2). diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index 01286f75f16..0ba98b532ad 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -83,94 +83,3 @@ def create_model(self) -> nn.Module: def create_inputs(self) -> Tuple[torch.Tensor, ...]: x = torch.randn(self.batch_size, self.n, self.m) return (x,) - - -class AddmmModel(nn.Module): - """Model that performs addmm: bias + (mat1 @ mat2).""" - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - alpha: float = 1.0, - beta: float = 1.0, - ): - super().__init__() - self.weight = nn.Parameter(torch.randn(out_features, in_features)) - if bias: - self.bias = nn.Parameter(torch.randn(out_features)) - else: - self.bias = None - self.alpha = alpha - self.beta = beta - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.bias is not None: - return torch.addmm( - self.bias, x, self.weight.t(), beta=self.beta, alpha=self.alpha - ) - else: - return torch.mm(x, self.weight.t()) - - -@register_test -class AddmmTest(OpTestCase): - """Test case for addmm.""" - - name = "addmm" - rtol = 1e-4 - atol = 1e-4 - - def __init__( - self, - batch_size: int = 2, - in_features: int = 64, - out_features: int = 32, - bias: bool = True, - alpha: float = 1.0, - beta: float = 1.0, - ): - self.batch_size = batch_size - self.in_features = in_features - self.out_features = out_features - self.bias = bias - self.alpha = alpha - self.beta = beta - - # Build unique test name - if not bias: - name = f"addmm_{in_features}x{out_features}_no_bias" - elif alpha != 1.0 or beta != 1.0: - name = f"addmm_{in_features}x{out_features}_a{alpha}_b{beta}" - else: - name = f"addmm_{in_features}x{out_features}" - self.name = name - - @classmethod - def get_test_configs(cls) -> List["AddmmTest"]: - return [ - cls( - batch_size=2, in_features=64, out_features=32 - ), # with bias, default alpha/beta - cls( - batch_size=2, in_features=64, out_features=32, bias=False - ), # without bias - cls(batch_size=4, in_features=128, out_features=64), # larger size - cls( - batch_size=2, in_features=64, out_features=32, alpha=2.0, beta=0.5 - ), # custom alpha/beta - ] - - def create_model(self) -> nn.Module: - return AddmmModel( - self.in_features, - self.out_features, - bias=self.bias, - alpha=self.alpha, - beta=self.beta, - ) - - def create_inputs(self) -> Tuple[torch.Tensor, ...]: - x = torch.randn(self.batch_size, self.in_features) - return (x,) From da81e8664569d9d6a9ab80447abe13413c1a01a8 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 4 Mar 2026 21:32:00 -0800 Subject: [PATCH 26/34] up --- backends/mlx/builder/program_builder.py | 19 +++++++++++++----- backends/mlx/runtime/MLXBackend.cpp | 26 ++++++++++++++++++++++--- backends/mlx/runtime/MLXExecutor.h | 20 ++++++++++++++++++- backends/mlx/test/test_utils.py | 10 +++++++++- 4 files changed, 65 insertions(+), 10 deletions(-) diff --git a/backends/mlx/builder/program_builder.py b/backends/mlx/builder/program_builder.py index 60d5ebbdbfe..2add4f1b7a3 100644 --- a/backends/mlx/builder/program_builder.py +++ b/backends/mlx/builder/program_builder.py @@ -27,7 +27,6 @@ from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Union import torch - from executorch.backends.mlx._logging import logger from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type from executorch.backends.mlx.builder.op_registry import ( @@ -132,7 +131,9 @@ class MLXProgramBuilder: def __init__(self, ep: ExportedProgram, named_data_key_prefix: str = ""): self.ep: ExportedProgram = ep - self._instrs: List[Instruction] = [] + self._chains: List[List[Instruction]] = [[]] # chain 0 = main + self._current_chain: int = 0 + self.init_chain_idx: int = -1 self.extra_constants: Dict[str, torch.Tensor] = {} self.slot_manager = SlotManager() self.node_info: DefaultDict[Node, NodeInfo] = defaultdict(NodeInfo) @@ -163,7 +164,13 @@ def _prefix_key(self, name: str) -> str: return name def emit(self, op: OpNodeUnion) -> None: - self._instrs.append(Instruction(op=op)) + self._chains[self._current_chain].append(Instruction(op=op)) + + def emit_init(self, op: OpNodeUnion) -> None: + if self.init_chain_idx == -1: + self.init_chain_idx = len(self._chains) + self._chains.append([]) + self._chains[self.init_chain_idx].append(Instruction(op=op)) def args(self, node: Node) -> Tuple[Any, ...]: return self.slot_map(node.args) @@ -934,9 +941,11 @@ def _build_mlx_graph(self) -> MLXGraph: num_mutable_buffer_tensors=num_tensors[IdSpace.MutableBuffer], num_temp_tensors=num_temp_tensors, num_values=num_values_count, - instruction_chains=[InstructionChain(instructions=self._instrs)], + instruction_chains=[ + InstructionChain(instructions=chain) for chain in self._chains + ], main_chain_idx=0, - init_chain_idx=-1, + init_chain_idx=self.init_chain_idx, input_map=input_map, output_map=output_map, mutable_buffer_map=mutable_buffer_map, diff --git a/backends/mlx/runtime/MLXBackend.cpp b/backends/mlx/runtime/MLXBackend.cpp index 38dff189935..99e20114ea7 100644 --- a/backends/mlx/runtime/MLXBackend.cpp +++ b/backends/mlx/runtime/MLXBackend.cpp @@ -219,10 +219,24 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { static_cast(processed->data()), processed->size()); // Validate schema version - if (handle->program.version != "1") { + int schema_version = 1; + if (!handle->program.version.empty()) { + try { + schema_version = std::stoi(handle->program.version); + } catch (...) { + throw std::runtime_error( + "Invalid MLX schema version '" + handle->program.version + + "' (expected integer)"); + } + } + constexpr int kMaxSupportedVersion = 1; + if (schema_version > kMaxSupportedVersion) { throw std::runtime_error( - "Unsupported MLX schema version '" + handle->program.version + - "' (expected '1'). Rebuild the .pte with a matching SDK version."); + "This .pte requires ExecuTorch MLX runtime version " + + std::to_string(schema_version) + + " but this runtime only supports up to version " + + std::to_string(kMaxSupportedVersion) + + ". Upgrade ExecuTorch to a newer version."); } // Load constants from named_data_map @@ -251,11 +265,17 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { // SAFETY: The >= 0 check ensures init_chain_idx is non-negative, so the // static_cast cannot produce UINT32_MAX from a -1 sentinel. if (handle->program.init_chain_idx >= 0) { + handle->state.is_init_chain = true; handle->interpreter.run_chain( handle->program, static_cast(handle->program.init_chain_idx), handle->state, handle->stream); + handle->state.is_init_chain = false; + + // Evaluate any constants written by the init chain so the first + // execute() doesn't pay the cost of materializing them. + eval(handle->constants.tensors); } } catch (const std::exception& e) { diff --git a/backends/mlx/runtime/MLXExecutor.h b/backends/mlx/runtime/MLXExecutor.h index 32d623790ab..978eaadabba 100644 --- a/backends/mlx/runtime/MLXExecutor.h +++ b/backends/mlx/runtime/MLXExecutor.h @@ -97,6 +97,13 @@ struct ConstantData { return tensors[id.idx]; } + inline void set(Tid id, Tensor t) { + if (id.idx >= tensors.size()) { + throw std::out_of_range("ConstantData::set: id out of range"); + } + tensors[id.idx] = std::move(t); + } + inline void add(Tensor t) { tensors.push_back(std::move(t)); } @@ -153,6 +160,9 @@ struct ExecutionState { // Non-constant values (SymInt, etc.) std::vector> values; + // Init chain flag: when true, set_tensor allows writing to constants + bool is_init_chain{false}; + // Logging context size_t current_op_idx{0}; const char* current_op_name{nullptr}; @@ -478,7 +488,15 @@ struct ExecutionState { throw std::runtime_error("set_tensor: Program not bound"); } if (id.idx < program->num_constant_tensors) { - throw std::runtime_error("set_tensor: cannot write to constant tensor"); + if (!is_init_chain) { + throw std::runtime_error("set_tensor: cannot write to constant tensor"); + } + // Init chain can write over constants + if (!constants) { + throw std::runtime_error("set_tensor: constants not bound"); + } + const_cast(constants)->set(id, std::move(arr)); + return; } // Route to mutable buffers or per-execution tensors if (is_mutable_buffer(id)) { diff --git a/backends/mlx/test/test_utils.py b/backends/mlx/test/test_utils.py index 090bceabf08..660968195b7 100644 --- a/backends/mlx/test/test_utils.py +++ b/backends/mlx/test/test_utils.py @@ -23,6 +23,7 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple, Union +import executorch.exir as exir import numpy as np import torch @@ -268,6 +269,7 @@ def export_model_to_pte( output_path: Union[str, Path], dynamic_shapes: Optional[Dict] = None, verbose: bool = False, + edge_compile_config: Optional[exir.EdgeCompileConfig] = None, ) -> None: """ Export a PyTorch model to a .pte file using the MLX delegate. @@ -281,7 +283,6 @@ def export_model_to_pte( Example: {0: {0: Dim("batch", min=1, max=32)}} for dynamic batch on first input. verbose: Whether to print the exported program for debugging. """ - import executorch.exir as exir from executorch.backends.mlx import MLXPartitioner from executorch.exir.capture._config import ExecutorchBackendConfig from torch.export import export @@ -301,9 +302,11 @@ def export_model_to_pte( print(exported_program) # Lower to edge and delegate to MLX + compile_config = edge_compile_config or exir.EdgeCompileConfig() edge_program = exir.to_edge_transform_and_lower( exported_program, partitioner=[MLXPartitioner()], + compile_config=compile_config, ) # Print edge program if verbose @@ -865,6 +868,10 @@ def get_dynamic_shapes(self) -> Optional[Dict]: """Return dynamic shapes specification for torch.export, or None for static shapes.""" return None + def get_edge_compile_config(self) -> Optional[exir.EdgeCompileConfig]: + """Return EdgeCompileConfig for export, or None for default.""" + return None + def get_test_dir(self) -> Path: """Get the directory for this test's files.""" test_dir = Path(__file__).parent / "op_tests" / self.name @@ -924,6 +931,7 @@ def generate_test_files(self, verbose: bool = False) -> Tuple[Path, Path, Path]: pte_path, dynamic_shapes=dynamic_shapes, verbose=verbose, + edge_compile_config=self.get_edge_compile_config(), ) # Save test inputs From 74a3f1bef5241782ff4d3f1dd0130f442d5ef376 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 4 Mar 2026 21:33:28 -0800 Subject: [PATCH 27/34] up --- backends/mlx/third-party/mlx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/mlx/third-party/mlx b/backends/mlx/third-party/mlx index 72e94c81e16..365d6f29b47 160000 --- a/backends/mlx/third-party/mlx +++ b/backends/mlx/third-party/mlx @@ -1 +1 @@ -Subproject commit 72e94c81e1685c90679ef03532c4b8897010abf9 +Subproject commit 365d6f29b47686a9f5401f6a9ec5825fee162d69 From 0b8b0af96ed2398e6188b4d9d5576995440ab227 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 4 Mar 2026 21:44:36 -0800 Subject: [PATCH 28/34] up --- .github/workflows/mlx.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 2e8ca7aa3b7..ea0bce96e1a 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -9,11 +9,13 @@ on: paths: - .github/workflows/mlx.yml - backends/mlx/** + - extension/llm/export/** + - extension/audio/** + - examples/models/parakeet/** + - examples/models/voxtral_realtime/** workflow_dispatch: -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} - cancel-in-progress: true +permissions: {} jobs: test-mlx: From afa912e9deb6aec43a027a5244c5a02d66848022 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 4 Mar 2026 21:50:26 -0800 Subject: [PATCH 29/34] up --- .github/workflows/mlx.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index ea0bce96e1a..cc83c90e23e 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -9,10 +9,6 @@ on: paths: - .github/workflows/mlx.yml - backends/mlx/** - - extension/llm/export/** - - extension/audio/** - - examples/models/parakeet/** - - examples/models/voxtral_realtime/** workflow_dispatch: permissions: {} From 6e924fedeef3b29c9b7bb75dabd7f4d02b444dfc Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Thu, 5 Mar 2026 11:23:22 -0800 Subject: [PATCH 30/34] up --- backends/mlx/README.md | 123 ++++++++++++++++--------- backends/mlx/serialization/generate.py | 2 +- 2 files changed, 82 insertions(+), 43 deletions(-) diff --git a/backends/mlx/README.md b/backends/mlx/README.md index ebab893385a..eea60fe2d00 100644 --- a/backends/mlx/README.md +++ b/backends/mlx/README.md @@ -193,7 +193,7 @@ ExportedProgram (subgraph) ## How to Add a New Op -This section walks through adding a new op end-to-end, using **`aten.linear`** +This section walks through adding a new op end-to-end, using **`aten.addmm`** as an example. ### Step 1: Add the Node to `schema.fbs` @@ -201,15 +201,15 @@ as an example. Add a new table in the "Op nodes" section and add it to the `OpNode` union: ```fbs -table LinearNode { - x: Tid (required); - weight: Tid (required); +table AddmmNode { + mat1: Tid (required); + mat2: Tid (required); out: Tid (required); bias: Tid; // optional } ``` -Then add `LinearNode` to the `union OpNode { ... }` list. +Then add `AddmmNode` to the `union OpNode { ... }` list. ### Step 2: Run the Code Generator @@ -219,34 +219,40 @@ python backends/mlx/serialization/generate.py This regenerates: -- `mlx_graph_schema.py` — adds `LinearNode` Python dataclass -- `_generated_serializers.py` — adds `_build_LinearNode` serializer -- `runtime/MLXLoader.h` — adds `LinearNode` C++ struct, `OpCode::LINEAR`, loader -- `runtime/MLXLoader.cpp` — adds FlatBuffer → `LinearNode` deserialization +- `mlx_graph_schema.py` — adds `AddmmNode` Python dataclass +- `_generated_serializers.py` — adds `_build_AddmmNode` serializer +- `runtime/MLXLoader.h` — adds `AddmmNode` C++ struct, `OpCode::ADDMM`, loader +- `runtime/MLXLoader.cpp` — adds FlatBuffer → `AddmmNode` deserialization - `runtime/schema_generated.h` — FlatBuffer C++ bindings ### Step 3: Add the Python Op Handler (`ops.py`) Register a handler that converts the ATen op to your new node. Make sure to -import `LinearNode` from `mlx_graph_schema`: +import `AddmmNode` from `mlx_graph_schema`: ```python -from executorch.backends.mlx.serialization.mlx_graph_schema import LinearNode +from executorch.backends.mlx.serialization.mlx_graph_schema import AddmmNode -@REGISTRY.register(target=[torch.ops.aten.linear.default]) -def _linear_handler(P: MLXProgramBuilder, n: Node) -> Slot: +@REGISTRY.register(target=[torch.ops.aten.addmm.default]) +def _addmm_handler(P: MLXProgramBuilder, n: Node) -> Slot: args = P.args(n) - require_args(args, 2, 3, "aten.linear") - require_kwargs(P.kwargs(n), set(), "aten.linear") - x, w = args[0], args[1] - b = args[2] if len(args) > 2 else None + kwargs = P.kwargs(n) + require_args(args, 3, 3, "aten.addmm") + require_kwargs(kwargs, {"beta", "alpha"}, "aten.addmm") + bias, mat1, mat2 = args[0], args[1], args[2] + + beta = kwargs.get("beta", 1) + alpha = kwargs.get("alpha", 1) + out = P.make_or_get_slot(n) P.emit( - LinearNode( - x=P.slot_to_tid(x), - weight=P.slot_to_tid(w), + AddmmNode( + mat1=P.slot_to_tid(mat1), + mat2=P.slot_to_tid(mat2), out=P.slot_to_tid(out), - bias=P.slot_to_tid(b) if b else None, + bias=P.slot_to_tid(bias), + alpha=float(alpha), + beta=float(beta), ) ) return out @@ -263,21 +269,28 @@ Key APIs: Add an `exec_*` function in the `ops` namespace: ```cpp -inline void exec_linear(const LinearNode& n, ExecutionState& st, StreamOrDevice s) { - const auto& X = st.const_tensor_ref(n.x); - auto W = transpose(st.const_tensor_ref(n.weight), {1, 0}, s); - array Y = n.bias - ? addmm(st.const_tensor_ref(*n.bias), X, W, 1.0f, 1.0f, s) - : matmul(X, W, s); +inline void exec_addmm(const AddmmNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& mat1 = st.const_tensor_ref(n.mat1); + const auto& mat2 = st.const_tensor_ref(n.mat2); + + array Y = n.bias ? addmm( + st.const_tensor_ref(*n.bias), + mat1, + mat2, + /*alpha=*/n.alpha, + /*beta=*/n.beta, + s) + : matmul(mat1, mat2, s); + st.set_tensor(n.out, std::move(Y)); } ``` -Then add the dispatch case in `Interpreter::execute_instruction()`: +Then add the dispatch case in `Interpreter::dispatch()`: ```cpp -case OpCode::LINEAR: - ops::exec_linear(std::get(instr.node), st, s); +case OpCode::ADDMM: + ops::exec_addmm(std::get(instr.node), st, s); break; ``` @@ -290,34 +303,60 @@ Each test follows a standard pattern: 3. **Decorate with `@register_test`** to register it with the test runner. ```python -class LinearModel(nn.Module): - def __init__(self, in_features=64, out_features=128, bias=True): +class AddmmModel(nn.Module): + """Model that performs addmm: bias + (mat1 @ mat2).""" + + def __init__(self, in_features, out_features, bias=True, alpha=1.0, beta=1.0): super().__init__() - self.linear = nn.Linear(in_features, out_features, bias=bias) + self.weight = nn.Parameter(torch.randn(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.randn(out_features)) + else: + self.bias = None + self.alpha = alpha + self.beta = beta def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear(x) + if self.bias is not None: + return torch.addmm( + self.bias, x, self.weight.t(), beta=self.beta, alpha=self.alpha + ) + else: + return torch.mm(x, self.weight.t()) @register_test -class LinearTest(OpTestCase): - name = "linear" +class AddmmTest(OpTestCase): + name = "addmm" rtol = 1e-4 atol = 1e-4 - def __init__(self, in_features=64, out_features=128, bias=True): + def __init__(self, batch_size=2, in_features=64, out_features=32, + bias=True, alpha=1.0, beta=1.0): + self.batch_size = batch_size self.in_features = in_features self.out_features = out_features self.bias = bias + self.alpha = alpha + self.beta = beta + self.name = f"addmm_{in_features}x{out_features}" @classmethod def get_test_configs(cls): - return [cls(), cls(bias=False)] + return [ + cls(batch_size=2, in_features=64, out_features=32), + cls(batch_size=2, in_features=64, out_features=32, bias=False), + cls(batch_size=4, in_features=128, out_features=64), + cls(batch_size=2, in_features=64, out_features=32, alpha=2.0, beta=0.5), + ] def create_model(self): - return LinearModel(self.in_features, self.out_features, bias=self.bias) + return AddmmModel( + self.in_features, self.out_features, + bias=self.bias, alpha=self.alpha, beta=self.beta, + ) def create_inputs(self): - return (torch.randn(2, 16, self.in_features),) + return (torch.randn(self.batch_size, self.in_features),) ``` ### Step 6: Run Tests @@ -327,7 +366,7 @@ outputs against PyTorch reference. Since adding a new op always involves C++ changes, use `--rebuild` to recompile the runtime: ```bash -python -m executorch.backends.mlx.test.run_all_tests --rebuild linear +python -m executorch.backends.mlx.test.run_all_tests --rebuild addmm ``` Run all tests in parallel: @@ -356,7 +395,7 @@ architecture, prerequisites, and the `OpTestCase` API. - [ ] Run `python backends/mlx/serialization/generate.py` - [ ] Add `@REGISTRY.register` handler in `ops.py` (and import the new node class) - [ ] Add `exec_*` function in `runtime/MLXInterpreter.h` -- [ ] Add `case OpCode::*` in `Interpreter::execute_instruction()` +- [ ] Add `case OpCode::*` in `Interpreter::dispatch()` - [ ] Add test model + `OpTestCase` in `test/test_ops.py` - [ ] Run `python -m executorch.backends.mlx.test.run_all_tests --rebuild ` diff --git a/backends/mlx/serialization/generate.py b/backends/mlx/serialization/generate.py index d12743906db..6f6ee11fe41 100755 --- a/backends/mlx/serialization/generate.py +++ b/backends/mlx/serialization/generate.py @@ -1006,7 +1006,7 @@ def _fbs_type_to_cpp( def _table_name_to_opcode(name: str) -> str: - """Convert table name like 'LinearNode' to opcode like 'LINEAR'. + """Convert table name like 'AddNode' to opcode like 'ADD'. Uses regex-based camelCase → UPPER_SNAKE_CASE conversion with a small override dict for names whose conventional opcode doesn't follow the From 8f047dd40e075a580c4c5f74adddd2c070770ef8 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Tue, 7 Apr 2026 14:34:40 -0700 Subject: [PATCH 31/34] up --- examples/models/parakeet/export_parakeet_tdt.py | 10 +++++++++- examples/models/voxtral_realtime/model.py | 13 +++++++------ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py index 587d7674bc4..c35e17eed59 100644 --- a/examples/models/parakeet/export_parakeet_tdt.py +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -640,7 +640,15 @@ def main(): "--backend", type=str, default="xnnpack", - choices=["portable", "xnnpack", "metal", "mlx", "cuda", "cuda-windows", "vulkan"], + choices=[ + "portable", + "xnnpack", + "metal", + "mlx", + "cuda", + "cuda-windows", + "vulkan", + ], help="Backend for acceleration (default: xnnpack)", ) parser.add_argument( diff --git a/examples/models/voxtral_realtime/model.py b/examples/models/voxtral_realtime/model.py index 06df2be0880..1227e9e8bea 100644 --- a/examples/models/voxtral_realtime/model.py +++ b/examples/models/voxtral_realtime/model.py @@ -700,7 +700,10 @@ def __init__(self, config: VoxtralRealtimeConfig): if self.backend == "mlx": cache_dtype = self.wq.weight.dtype self.kv_cache = MLXKVCache( - config.sliding_window, self.n_kv_heads, self.head_dim, dtype=cache_dtype + config.sliding_window, + self.n_kv_heads, + self.head_dim, + dtype=cache_dtype, ) self.sdpa = MLXSDPA(self.n_heads, self.head_dim) elif self.backend == "metal": @@ -1170,7 +1173,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750): elif config.backend == "metal": self.kv_caches = nn.ModuleList( [ - StandardEncoderRingKVCache( + StandardRingKVCache( max_enc_len, config.enc_n_heads, config.enc_head_dim ) for _ in range(config.enc_n_layers) @@ -1184,7 +1187,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750): elif config.backend == "cuda": self.kv_caches = nn.ModuleList( [ - StandardEncoderRingKVCache( + StandardRingKVCache( max_enc_len, config.enc_n_heads, config.enc_head_dim ) for _ in range(config.enc_n_layers) @@ -1198,9 +1201,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750): else: self.kv_caches = nn.ModuleList( [ - EncoderRingKVCache( - max_enc_len, config.enc_n_heads, config.enc_head_dim - ) + RingKVCache(max_enc_len, config.enc_n_heads, config.enc_head_dim) for _ in range(config.enc_n_layers) ] ) From 984de055ce9b69bdaf668ceb5a03c61719fb2809 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Tue, 7 Apr 2026 15:49:23 -0700 Subject: [PATCH 32/34] up --- .github/workflows/mlx.yml | 10 ++--- backends/mlx/examples/llm/README.md | 10 ++--- backends/mlx/examples/whisper/README.md | 4 ++ examples/models/voxtral_realtime/model.py | 48 +++++++++++++++-------- 4 files changed, 45 insertions(+), 27 deletions(-) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index fb68d1739af..78087738d4e 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -181,7 +181,7 @@ jobs: echo "::group::Install Voxtral requirements" ${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0" - ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN + ${CONDA_RUN} python -c "from huggingface_hub import login; login(token='$SECRET_EXECUTORCH_HF_TOKEN')" ${CONDA_RUN} pip install mistral_common librosa soundfile datasets OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) ${CONDA_RUN} pip install "optimum-executorch @ git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}" @@ -240,13 +240,13 @@ jobs: echo "::group::Install Voxtral Realtime requirements" ${CONDA_RUN} pip install -U "huggingface_hub[cli]" safetensors - ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN + ${CONDA_RUN} python -c "from huggingface_hub import login; login(token='$SECRET_EXECUTORCH_HF_TOKEN')" echo "::endgroup::" ${CONDA_RUN} pip list echo "::group::Download model" - ${CONDA_RUN} huggingface-cli download mistralai/Voxtral-Mini-4B-Realtime-2602 + ${CONDA_RUN} python -c "from huggingface_hub import snapshot_download; snapshot_download('mistralai/Voxtral-Mini-4B-Realtime-2602')" MODEL_PATH=$(${CONDA_RUN} python -c "from huggingface_hub import snapshot_download; print(snapshot_download('mistralai/Voxtral-Mini-4B-Realtime-2602'))") echo "Model path: ${MODEL_PATH}" echo "::endgroup::" @@ -313,7 +313,7 @@ jobs: echo "::group::Install Whisper requirements" ${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0" - ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN + ${CONDA_RUN} python -c "from huggingface_hub import login; login(token='$SECRET_EXECUTORCH_HF_TOKEN')" ${CONDA_RUN} pip install transformers soundfile datasets librosa echo "::endgroup::" @@ -447,7 +447,7 @@ jobs: echo "::group::Install LLM requirements" ${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0" - ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN + ${CONDA_RUN} python -c "from huggingface_hub import login; login(token='$SECRET_EXECUTORCH_HF_TOKEN')" OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) ${CONDA_RUN} pip install transformers "optimum-executorch @ git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}" echo "::endgroup::" diff --git a/backends/mlx/examples/llm/README.md b/backends/mlx/examples/llm/README.md index 7346efcef69..f860c4f1ce0 100644 --- a/backends/mlx/examples/llm/README.md +++ b/backends/mlx/examples/llm/README.md @@ -44,14 +44,14 @@ python -m executorch.backends.mlx.examples.llm.export_llm_hf \ --use-custom-sdpa \ --use-custom-kv-cache -# With INT4 quantization +# With 4-bit quantization python -m executorch.backends.mlx.examples.llm.export_llm_hf \ --model-id "unsloth/Llama-3.2-1B-Instruct" \ --output llama_hf_int4.pte \ --use-custom-sdpa \ --use-custom-kv-cache \ - --quantize-linear int4 \ - --quantize-embeddings int4 + --qlinear 4w \ + --qembedding 4w ``` ### Options @@ -62,8 +62,8 @@ python -m executorch.backends.mlx.examples.llm.export_llm_hf \ | `--output` | *(required)* | Output .pte file path | | `--max-seq-len` | `1024` | Maximum sequence length for KV cache | | `--dtype` | `bf16` | Model dtype (`fp32`, `fp16`, `bf16`) | -| `--quantize-linear` | None | Quantization for linear layers (`int4`, `int8`) | -| `--quantize-embeddings` | None | Quantization for embedding layers (`int4`, `int8`) | +| `--qlinear` | None | Quantization for linear layers (`4w`, `8w`, `nvfp4`) | +| `--qembedding` | None | Quantization for embedding layers (`4w`, `8w`, `nvfp4`) | | `--no-tie-word-embeddings` | `False` | Disable re-tying lm_head to embedding after quantization | | `--use-custom-sdpa` | `False` | Use MLX custom SDPA (`mlx::custom_sdpa`) | | `--use-custom-kv-cache` | `False` | Use MLX custom KV cache (`mlx::kv_cache_update`) | diff --git a/backends/mlx/examples/whisper/README.md b/backends/mlx/examples/whisper/README.md index 3e7749a3957..6487a22a3a5 100644 --- a/backends/mlx/examples/whisper/README.md +++ b/backends/mlx/examples/whisper/README.md @@ -34,6 +34,10 @@ python -m executorch.backends.mlx.examples.whisper.export_whisper \ | `--output-dir` | `whisper_mlx` | Output directory for `.pte` files | | `--max-decoder-seq-len` | `256` | Maximum decoder sequence length | | `--dtype` | `bf16` | Model dtype (`fp32`, `fp16`, `bf16`) | +| `--qlinear` | None | Quantization for linear layers (`4w`, `8w`, `nvfp4`) | +| `--qembedding` | None | Quantization for embedding layers (`4w`, `8w`, `nvfp4`) | +| `--qlinear-group-size` | auto | Group size for linear quantization | +| `--qembedding-group-size` | auto | Group size for embedding quantization | ## Run diff --git a/examples/models/voxtral_realtime/model.py b/examples/models/voxtral_realtime/model.py index 1227e9e8bea..e591445cc56 100644 --- a/examples/models/voxtral_realtime/model.py +++ b/examples/models/voxtral_realtime/model.py @@ -538,11 +538,12 @@ def forward( return y.view(bsz, seqlen, self.dim) -class MLXKVCache(nn.Module): - """Wrapper that adapts MLX BHSD KV cache for model's BSHD convention. +class MLXStaticKVCache(nn.Module): + """Wrapper that adapts MLX static KV cache for model's BSHD convention. - The model's QKV projections produce [B, S, H, D] tensors, but MLX's - KVCache expects [B, H, S, D]. This wrapper transposes on the way in. + For offline (non-streaming) mode. The model's QKV projections produce + [B, S, H, D] tensors, but MLX's KVCache expects [B, H, S, D]. + This wrapper transposes on the way in. """ def __init__( @@ -569,12 +570,13 @@ def update( return self.cache.update(input_pos, k_val, v_val) -class MLXEncoderRingKVCache(nn.Module): - """Wrapper that adapts MLX RingBufferKVCache for the encoder's BSHD convention. +class MLXRingKVCache(nn.Module): + """Wrapper that adapts MLX RingBufferKVCache for model's BSHD convention. - The encoder's QKV projections produce [B, S, H, D] tensors, but MLX's - RingBufferKVCache expects [B, H, S, D]. This wrapper transposes on the - way in and delegates ring buffer semantics to the MLX implementation. + For streaming mode (both encoder and decoder). The model's QKV projections + produce [B, S, H, D] tensors, but MLX's RingBufferKVCache expects + [B, H, S, D]. This wrapper transposes on the way in and delegates + ring buffer semantics to the MLX implementation. """ def __init__( @@ -603,7 +605,9 @@ def update( v_val = v_val.transpose(1, 2) return self.ring_cache.update(input_pos, k_val, v_val) - def create_causal_mask(self, start_pos, seq_len, bool_mask=False) -> torch.Tensor: + def create_causal_mask( + self, start_pos, seq_len, bool_mask=False, **kwargs + ) -> torch.Tensor: return self.ring_cache.create_sliding_window_mask(start_pos, seq_len) @@ -637,9 +641,10 @@ def forward( return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) -class MLXEncoderSDPA(nn.Module): - """SDPA for streaming encoder with MLX ring buffer KV cache. +class MLXMaskedSDPA(nn.Module): + """SDPA with explicit mask for MLX ring buffer KV cache. + Used with MLXRingKVCache for streaming mode (both encoder and decoder). Uses F.scaled_dot_product_attention with explicit attn_mask from the ring buffer. KV cache is in BHSD layout, queries are in BSHD. """ @@ -662,7 +667,7 @@ def forward( Args: input_pos: (seq_len,) position indices (unused, kept for interface). q: (B, seq_len, n_heads, head_dim) in BSHD layout. - k, v: (B, n_heads, buf_size, head_dim) in BHSD from MLXEncoderRingKVCache. + k, v: (B, n_heads, buf_size, head_dim) in BHSD from MLXRingKVCache. bsz, seqlen: batch size and query length. mask: (1, 1, seq_len, buf_size) additive attention mask from ring buffer. """ @@ -699,7 +704,7 @@ def __init__(self, config: VoxtralRealtimeConfig): # Ring buffer KV cache for unlimited streaming. if self.backend == "mlx": cache_dtype = self.wq.weight.dtype - self.kv_cache = MLXKVCache( + self.kv_cache = MLXRingKVCache( config.sliding_window, self.n_kv_heads, self.head_dim, @@ -723,7 +728,16 @@ def __init__(self, config: VoxtralRealtimeConfig): self.sdpa = SDPA(self.n_heads, self.head_dim) else: # Flat KV cache for offline mode (capped at max_seq_len). - if self.backend == "metal": + if self.backend == "mlx": + cache_dtype = self.wq.weight.dtype + self.kv_cache = MLXStaticKVCache( + config.max_seq_len, + self.n_kv_heads, + self.head_dim, + dtype=cache_dtype, + ) + self.sdpa = MLXSDPA(self.n_heads, self.head_dim) + elif self.backend == "metal": self.kv_cache = StaticKVCache( config.max_seq_len, self.n_kv_heads, self.head_dim ) @@ -1160,7 +1174,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750): cache_dtype = self.layers[0].attention.wq.weight.dtype self.kv_caches = nn.ModuleList( [ - MLXEncoderRingKVCache( + MLXRingKVCache( max_enc_len, config.enc_n_heads, config.enc_head_dim, @@ -1169,7 +1183,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750): for _ in range(config.enc_n_layers) ] ) - self.sdpa = MLXEncoderSDPA(config.enc_n_heads, config.enc_head_dim) + self.sdpa = MLXMaskedSDPA(config.enc_n_heads, config.enc_head_dim) elif config.backend == "metal": self.kv_caches = nn.ModuleList( [ From 6e5f3235b3795b1a5cd36d9f0fc6d4d4fc0983cc Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Tue, 7 Apr 2026 16:18:50 -0700 Subject: [PATCH 33/34] up --- .github/workflows/mlx.yml | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 78087738d4e..10632bf1650 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -181,8 +181,7 @@ jobs: echo "::group::Install Voxtral requirements" ${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0" - ${CONDA_RUN} python -c "from huggingface_hub import login; login(token='$SECRET_EXECUTORCH_HF_TOKEN')" - ${CONDA_RUN} pip install mistral_common librosa soundfile datasets + ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) ${CONDA_RUN} pip install "optimum-executorch @ git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}" echo "::endgroup::" @@ -239,14 +238,14 @@ jobs: echo "::endgroup::" echo "::group::Install Voxtral Realtime requirements" - ${CONDA_RUN} pip install -U "huggingface_hub[cli]" safetensors - ${CONDA_RUN} python -c "from huggingface_hub import login; login(token='$SECRET_EXECUTORCH_HF_TOKEN')" + ${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0" safetensors + ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN echo "::endgroup::" ${CONDA_RUN} pip list echo "::group::Download model" - ${CONDA_RUN} python -c "from huggingface_hub import snapshot_download; snapshot_download('mistralai/Voxtral-Mini-4B-Realtime-2602')" + ${CONDA_RUN} huggingface-cli download mistralai/Voxtral-Mini-4B-Realtime-2602 MODEL_PATH=$(${CONDA_RUN} python -c "from huggingface_hub import snapshot_download; print(snapshot_download('mistralai/Voxtral-Mini-4B-Realtime-2602'))") echo "Model path: ${MODEL_PATH}" echo "::endgroup::" @@ -313,7 +312,7 @@ jobs: echo "::group::Install Whisper requirements" ${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0" - ${CONDA_RUN} python -c "from huggingface_hub import login; login(token='$SECRET_EXECUTORCH_HF_TOKEN')" + ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN ${CONDA_RUN} pip install transformers soundfile datasets librosa echo "::endgroup::" @@ -447,7 +446,7 @@ jobs: echo "::group::Install LLM requirements" ${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0" - ${CONDA_RUN} python -c "from huggingface_hub import login; login(token='$SECRET_EXECUTORCH_HF_TOKEN')" + ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) ${CONDA_RUN} pip install transformers "optimum-executorch @ git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}" echo "::endgroup::" From 0d3a549fe8a7df06e6bfcc6be9f3b1525d1fc548 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Tue, 7 Apr 2026 16:40:55 -0700 Subject: [PATCH 34/34] up --- .github/workflows/mlx.yml | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 10632bf1650..e62e93b3a20 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -180,8 +180,7 @@ jobs: echo "::endgroup::" echo "::group::Install Voxtral requirements" - ${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0" - ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN + ${CONDA_RUN} pip install mistral_common librosa soundfile datasets OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) ${CONDA_RUN} pip install "optimum-executorch @ git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}" echo "::endgroup::" @@ -238,15 +237,14 @@ jobs: echo "::endgroup::" echo "::group::Install Voxtral Realtime requirements" - ${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0" safetensors - ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN + ${CONDA_RUN} pip install safetensors echo "::endgroup::" ${CONDA_RUN} pip list echo "::group::Download model" - ${CONDA_RUN} huggingface-cli download mistralai/Voxtral-Mini-4B-Realtime-2602 - MODEL_PATH=$(${CONDA_RUN} python -c "from huggingface_hub import snapshot_download; print(snapshot_download('mistralai/Voxtral-Mini-4B-Realtime-2602'))") + HF_TOKEN=$SECRET_EXECUTORCH_HF_TOKEN ${CONDA_RUN} python -c "from huggingface_hub import snapshot_download; snapshot_download('mistralai/Voxtral-Mini-4B-Realtime-2602')" + MODEL_PATH=$(HF_TOKEN=$SECRET_EXECUTORCH_HF_TOKEN ${CONDA_RUN} python -c "from huggingface_hub import snapshot_download; print(snapshot_download('mistralai/Voxtral-Mini-4B-Realtime-2602'))") echo "Model path: ${MODEL_PATH}" echo "::endgroup::"