diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 53c7b9360cd..e62e93b3a20 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -10,6 +10,9 @@ on: - .github/workflows/mlx.yml - backends/mlx/** - extension/llm/export/** + - extension/audio/** + - examples/models/parakeet/** + - examples/models/voxtral_realtime/** workflow_dispatch: permissions: {} @@ -104,3 +107,370 @@ jobs: echo "::error::Too many test failures: $FAILED > $MAX_FAILURES" exit 1 fi + + test-mlx-parakeet: + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + with: + job-name: test-mlx-parakeet + runner: macos-14-xlarge + python-version: "3.12" + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 90 + script: | + set -eux + + echo "::group::Install ExecuTorch" + ${CONDA_RUN} python install_executorch.py > /dev/null + echo "::endgroup::" + + echo "::group::Install Parakeet requirements" + ${CONDA_RUN} pip install -r examples/models/parakeet/install_requirements.txt + echo "::endgroup::" + + ${CONDA_RUN} pip list + + echo "::group::Export Parakeet" + ${CONDA_RUN} python -m executorch.examples.models.parakeet.export_parakeet_tdt \ + --backend mlx \ + --dtype bf16 \ + --qlinear_encoder 4w \ + --qlinear_encoder_group_size 128 \ + --qlinear 4w \ + --qlinear_group_size 128 \ + --output-dir /tmp/parakeet_mlx + echo "::endgroup::" + + echo "::group::Build Parakeet MLX runner" + ${CONDA_RUN} make parakeet-mlx + echo "::endgroup::" + + echo "::group::Run Parakeet MLX runner" + curl -L https://dldata-public.s3.us-east-2.amazonaws.com/2086-149220-0033.wav -o /tmp/test_audio.wav + OUTPUT=$(./cmake-out/examples/models/parakeet/parakeet_runner \ + --model_path /tmp/parakeet_mlx/model.pte \ + --audio_path /tmp/test_audio.wav \ + --tokenizer_path /tmp/parakeet_mlx/tokenizer.model 2>&1) + echo "Runner output:" + echo "$OUTPUT" + if echo "$OUTPUT" | grep -iq "Phoebe"; then + echo 
"Success: 'Phoebe' found in output" + else + echo "Failed: Expected 'Phoebe' not found in output" + exit 1 + fi + echo "::endgroup::" + + test-mlx-voxtral: + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + secrets: inherit + with: + job-name: test-mlx-voxtral + runner: macos-14-xlarge + python-version: "3.12" + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + secrets-env: EXECUTORCH_HF_TOKEN + timeout: 90 + script: | + set -eux + + echo "::group::Install ExecuTorch" + ${CONDA_RUN} python install_executorch.py > /dev/null + echo "::endgroup::" + + echo "::group::Install Voxtral requirements" + ${CONDA_RUN} pip install mistral_common librosa soundfile datasets + OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) + ${CONDA_RUN} pip install "optimum-executorch @ git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}" + echo "::endgroup::" + + ${CONDA_RUN} pip list + + echo "::group::Export Voxtral" + ${CONDA_RUN} python -m executorch.backends.mlx.examples.voxtral.export_voxtral_hf \ + --output-dir /tmp/voxtral_mlx \ + --dtype bf16 \ + --qlinear 4w + echo "::endgroup::" + + echo "::group::Build Voxtral MLX runner" + ${CONDA_RUN} make voxtral-mlx + echo "::endgroup::" + + echo "::group::Run Voxtral MLX runner" + curl -L https://huggingface.co/mistralai/Voxtral-Mini-3B-2507/resolve/main/tekken.json -o /tmp/tekken.json + curl -L https://github.com/voxserv/audio_quality_testing_samples/raw/refs/heads/master/testaudio/16000/test01_20s.wav -o /tmp/test_audio.wav + OUTPUT=$(./cmake-out/examples/models/voxtral/voxtral_runner \ + --model_path /tmp/voxtral_mlx/model.pte \ + --tokenizer_path /tmp/tekken.json \ + --audio_path /tmp/test_audio.wav \ + --processor_path /tmp/voxtral_mlx/preprocessor.pte \ + --prompt "What is happening in this audio?" 
\ + --temperature 0 2>&1) + echo "Runner output:" + echo "$OUTPUT" + if echo "$OUTPUT" | grep -iq "poem"; then + echo "Success: 'poem' found in output" + else + echo "Failed: Expected 'poem' not found in output" + exit 1 + fi + echo "::endgroup::" + + test-mlx-voxtral-realtime: + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + secrets: inherit + with: + job-name: test-mlx-voxtral-realtime + runner: macos-14-xlarge + python-version: "3.12" + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + secrets-env: EXECUTORCH_HF_TOKEN + timeout: 90 + script: | + set -eux + + echo "::group::Install ExecuTorch" + ${CONDA_RUN} python install_executorch.py > /dev/null + echo "::endgroup::" + + echo "::group::Install Voxtral Realtime requirements" + ${CONDA_RUN} pip install safetensors + echo "::endgroup::" + + ${CONDA_RUN} pip list + + echo "::group::Download model" + HF_TOKEN=$SECRET_EXECUTORCH_HF_TOKEN ${CONDA_RUN} python -c "from huggingface_hub import snapshot_download; snapshot_download('mistralai/Voxtral-Mini-4B-Realtime-2602')" + MODEL_PATH=$(HF_TOKEN=$SECRET_EXECUTORCH_HF_TOKEN ${CONDA_RUN} python -c "from huggingface_hub import snapshot_download; print(snapshot_download('mistralai/Voxtral-Mini-4B-Realtime-2602'))") + echo "Model path: ${MODEL_PATH}" + echo "::endgroup::" + + echo "::group::Export preprocessor" + mkdir -p /tmp/voxtral_rt_mlx && ${CONDA_RUN} python -m executorch.extension.audio.mel_spectrogram \ + --feature_size 128 \ + --streaming \ + --backend mlx \ + --output_file /tmp/voxtral_rt_mlx/preprocessor.pte + echo "::endgroup::" + + echo "::group::Export Voxtral Realtime (streaming)" + ${CONDA_RUN} python -m executorch.examples.models.voxtral_realtime.export_voxtral_rt \ + --model-path "${MODEL_PATH}" \ + --backend mlx \ + --streaming \ + --output-dir /tmp/voxtral_rt_mlx \ + --qlinear-encoder 4w \ + --qlinear 4w \ + --qembedding 8w \ + --qembedding-group-size 128 + echo "::endgroup::" + + echo 
"::group::Build Voxtral Realtime MLX runner" + ${CONDA_RUN} make voxtral_realtime-mlx + echo "::endgroup::" + + echo "::group::Run Voxtral Realtime MLX runner" + curl -L https://dldata-public.s3.us-east-2.amazonaws.com/2086-149220-0033.wav -o /tmp/test_audio.wav + OUTPUT=$(./cmake-out/examples/models/voxtral_realtime/voxtral_realtime_runner \ + --model_path /tmp/voxtral_rt_mlx/model.pte \ + --tokenizer_path "${MODEL_PATH}/tekken.json" \ + --preprocessor_path /tmp/voxtral_rt_mlx/preprocessor.pte \ + --audio_path /tmp/test_audio.wav \ + --streaming 2>&1) + echo "Runner output:" + echo "$OUTPUT" + if echo "$OUTPUT" | grep -iq "Phoebe"; then + echo "Success: 'Phoebe' found in output" + else + echo "Failed: Expected 'Phoebe' not found in output" + exit 1 + fi + echo "::endgroup::" + + test-mlx-whisper: + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + secrets: inherit + with: + job-name: test-mlx-whisper + runner: macos-14-xlarge + python-version: "3.12" + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + secrets-env: EXECUTORCH_HF_TOKEN + timeout: 90 + script: | + set -eux + + echo "::group::Install ExecuTorch and configure MLX build" + ${CONDA_RUN} python install_executorch.py > /dev/null + echo "::endgroup::" + + echo "::group::Install Whisper requirements" + ${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0" + ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN + ${CONDA_RUN} pip install transformers soundfile datasets librosa + echo "::endgroup::" + + ${CONDA_RUN} pip list + + echo "::group::Export Whisper" + ${CONDA_RUN} python -m executorch.backends.mlx.examples.whisper.export_whisper \ + --model-id "openai/whisper-tiny" \ + --output-dir /tmp/whisper_mlx \ + --dtype bf16 \ + --qlinear 4w + echo "::endgroup::" + + echo "::group::Run Whisper inference" + OUTPUT=$( ${CONDA_RUN} python -m executorch.backends.mlx.examples.whisper.run_whisper \ + 
--model-dir /tmp/whisper_mlx \ + --use-sample-audio 2>&1) + echo "$OUTPUT" + if echo "$OUTPUT" | grep -iq "Mr. Quilter"; then + echo "Success: 'Mr. Quilter' found in transcription" + else + echo "Failed: Expected 'Mr. Quilter' not found in transcription" + exit 1 + fi + echo "::endgroup::" + + + test-mlx-stories110m: + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + with: + job-name: test-mlx-stories110m + runner: macos-14-xlarge + python-version: "3.12" + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 90 + script: | + set -eux + + echo "::group::Install ExecuTorch" + ${CONDA_RUN} python install_executorch.py > /dev/null + echo "::endgroup::" + + echo "::group::Install Llama requirements" + ${CONDA_RUN} sh examples/models/llama/install_requirements.sh + echo "::endgroup::" + + ${CONDA_RUN} pip list + + echo "::group::Build ExecuTorch with MLX delegate" + ${CONDA_RUN} cmake --workflow --preset mlx-release + echo "::endgroup::" + + echo "::group::Build Llama runner with MLX" + pushd examples/models/llama + ${CONDA_RUN} cmake --workflow --preset llama-release + popd + echo "::endgroup::" + + echo "::group::Download stories110M artifacts" + curl -Ls "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt" --output stories110M.pt + curl -Ls "https://raw.githubusercontent.com/karpathy/llama2.c/master/tokenizer.model" --output tokenizer.model + echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json + echo "::endgroup::" + + echo "::group::Create tokenizer.bin" + ${CONDA_RUN} python -m pytorch_tokenizers.tools.llama2c.convert -t tokenizer.model -o tokenizer.bin + echo "::endgroup::" + + echo "::group::Export stories110M with MLX backend via export_llama_lib" + ${CONDA_RUN} python -m extension.llm.export.export_llm \ + base.checkpoint=stories110M.pt \ + base.params=params.json \ + 
model.use_kv_cache=true \ + model.dtype_override=fp32 \ + backend.mlx.enabled=true \ + quantization.qmode=4w \ + quantization.group_size=32 \ + export.output_name=/tmp/stories110m_mlx.pte + echo "::endgroup::" + + echo "::group::Run inference with C++ llama runner" + ./cmake-out/examples/models/llama/llama_main \ + --model_path=/tmp/stories110m_mlx.pte \ + --tokenizer_path=tokenizer.bin \ + --prompt="Once upon a time," \ + --temperature=0 \ + --seq_len=10 + echo "::endgroup::" + + test-mlx-llm: + strategy: + fail-fast: false + matrix: + model: + - id: "unsloth/Llama-3.2-1B-Instruct" + name: "llama-1b" + - id: "unsloth/Qwen3-0.6B" + name: "qwen3-0.6b" + - id: "unsloth/gemma-3-1b-it" + name: "gemma3-1b" + use-custom: [false, true] + qconfig: ["4w", "nvfp4"] + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + secrets: inherit + with: + job-name: test-mlx-llm-${{ matrix.model.name }}${{ matrix.use-custom && '-custom' || '' }}-${{ matrix.qconfig }} + runner: macos-14-xlarge + python-version: "3.12" + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + secrets-env: EXECUTORCH_HF_TOKEN + timeout: 90 + script: | + set -eux + + MODEL_ID="${{ matrix.model.id }}" + MODEL_NAME="${{ matrix.model.name }}" + USE_CUSTOM="${{ matrix.use-custom }}" + QCONFIG="${{ matrix.qconfig }}" + + CUSTOM_ARGS="" + if [ "${USE_CUSTOM}" = "true" ]; then + CUSTOM_ARGS="--use-custom-sdpa --use-custom-kv-cache" + fi + + echo "::group::Install ExecuTorch and configure MLX build" + ${CONDA_RUN} python install_executorch.py > /dev/null + ${CONDA_RUN} cmake --preset mlx-release + echo "::endgroup::" + + echo "::group::Install LLM requirements" + ${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0" + ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN + OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) + ${CONDA_RUN} pip install transformers "optimum-executorch @ 
git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}" + echo "::endgroup::" + + ${CONDA_RUN} pip list + + echo "::group::Export ${MODEL_NAME}" + ${CONDA_RUN} python -m executorch.backends.mlx.examples.llm.export_llm_hf \ + --model-id "${MODEL_ID}" \ + --output /tmp/${MODEL_NAME}.pte \ + --qlinear ${QCONFIG} \ + --qembedding ${QCONFIG} \ + ${CUSTOM_ARGS} + echo "::endgroup::" + + echo "::group::Run ${MODEL_NAME} inference" + OUTPUT=$(${CONDA_RUN} python -m executorch.backends.mlx.examples.llm.run_llm_hf \ + --pte /tmp/${MODEL_NAME}.pte \ + --model-id "${MODEL_ID}" \ + --prompt "What is the capital of France?" \ + --max-new-tokens 50 2>&1) + echo "$OUTPUT" + if echo "$OUTPUT" | grep -iq "Paris"; then + echo "Success: 'Paris' found in output" + else + echo "Failed: Expected 'Paris' not found in output" + exit 1 + fi + echo "::endgroup::" diff --git a/Makefile b/Makefile index 33f8d9bff4e..00ae1c4c0de 100644 --- a/Makefile +++ b/Makefile @@ -14,10 +14,10 @@ # # SUPPORTED MODELS: # ----------------- -# - voxtral: Multimodal voice + text model (CPU, CUDA, Metal) -# - voxtral_realtime: Realtime speech-to-text model (CPU, CUDA, Metal) +# - voxtral: Multimodal voice + text model (CPU, CUDA, Metal, MLX) +# - voxtral_realtime: Realtime speech-to-text model (CPU, CUDA, Metal, MLX) # - whisper: Speech recognition model (CPU, CUDA, Metal) -# - parakeet: Speech recognition model (CPU, CUDA, Metal) +# - parakeet: Speech recognition model (CPU, CUDA, Metal, MLX) # - sortformer: Speaker diarization model (CPU, CUDA) # - silero_vad: Voice activity detection model (CPU) # - llama: Text generation model (CPU) @@ -91,16 +91,18 @@ # # ============================================================================== -.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-vulkan 
dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu qwen3_5_moe-cuda clean help +.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu qwen3_5_moe-cuda clean help help: @echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make \`. Available targets:" @echo " voxtral-cuda - Build Voxtral runner with CUDA backend" @echo " voxtral-cpu - Build Voxtral runner with CPU backend" @echo " voxtral-metal - Build Voxtral runner with Metal backend (macOS only)" + @echo " voxtral-mlx - Build Voxtral runner with MLX backend" @echo " voxtral_realtime-cuda - Build Voxtral Realtime runner with CUDA backend" @echo " voxtral_realtime-cpu - Build Voxtral Realtime runner with CPU backend" @echo " voxtral_realtime-metal - Build Voxtral Realtime runner with Metal backend (macOS only)" + @echo " voxtral_realtime-mlx - Build Voxtral Realtime runner with MLX backend" @echo " whisper-cuda - Build Whisper runner with CUDA backend" @echo " whisper-cuda-debug - Build Whisper runner with CUDA backend (debug mode)" @echo " whisper-cpu - Build Whisper runner with CPU backend" @@ -109,6 +111,7 @@ help: @echo " parakeet-cuda-debug - Build Parakeet runner with CUDA backend (debug mode)" @echo " parakeet-cpu - Build Parakeet runner with CPU backend" @echo " parakeet-metal - Build Parakeet runner with Metal backend (macOS only)" + @echo " parakeet-mlx - Build Parakeet runner with MLX backend" @echo " parakeet-vulkan - Build Parakeet runner with Vulkan backend" @echo " dinov2-cuda 
- Build DINOv2 runner with CUDA backend" @echo " dinov2-cuda-debug - Build DINOv2 runner with CUDA backend (debug mode)" @@ -151,6 +154,15 @@ voxtral-metal: @echo "✓ Build complete!" @echo " Binary: cmake-out/examples/models/voxtral/voxtral_runner" +voxtral-mlx: + @echo "==> Building and installing ExecuTorch with MLX..." + cmake --workflow --preset mlx-release + @echo "==> Building Voxtral runner with MLX..." + cd examples/models/voxtral && cmake --workflow --preset voxtral-mlx + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/voxtral/voxtral_runner" + whisper-cuda: @echo "==> Building and installing ExecuTorch with CUDA..." cmake --workflow --preset llm-release-cuda @@ -223,6 +235,15 @@ parakeet-metal: @echo "✓ Build complete!" @echo " Binary: cmake-out/examples/models/parakeet/parakeet_runner" +parakeet-mlx: + @echo "==> Building and installing ExecuTorch with MLX..." + cmake --workflow --preset mlx-release + @echo "==> Building Parakeet runner with MLX..." + cd examples/models/parakeet && cmake --workflow --preset parakeet-mlx + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/parakeet/parakeet_runner" + parakeet-vulkan: @echo "==> Building and installing ExecuTorch with Vulkan..." cmake --workflow --preset llm-debug-vulkan @@ -258,7 +279,7 @@ sortformer-cuda: @echo "" @echo "✓ Build complete!" @echo " Binary: cmake-out/examples/models/sortformer/sortformer_runner" sortformer-cpu: @echo "==> Building and installing ExecuTorch..." cmake --workflow --preset llm-release @@ -295,6 +315,15 @@ voxtral_realtime-cuda: @echo "✓ Build complete!" @echo " Binary: cmake-out/examples/models/voxtral_realtime/voxtral_realtime_runner" +voxtral_realtime-mlx: + @echo "==> Building and installing ExecuTorch with MLX..." + cmake --workflow --preset mlx-release + @echo "==> Building Voxtral Realtime runner with MLX..." 
+ cd examples/models/voxtral_realtime && cmake --workflow --preset voxtral-realtime-mlx + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/voxtral_realtime/voxtral_realtime_runner" + silero-vad-cpu: @echo "==> Configuring and installing ExecuTorch (without LLM runner)..." cmake --preset llm-release -DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=OFF diff --git a/backends/mlx/examples/__init__.py b/backends/mlx/examples/__init__.py new file mode 100644 index 00000000000..f557ef26c5b --- /dev/null +++ b/backends/mlx/examples/__init__.py @@ -0,0 +1,6 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/mlx/examples/llm/README.md b/backends/mlx/examples/llm/README.md new file mode 100644 index 00000000000..f860c4f1ce0 --- /dev/null +++ b/backends/mlx/examples/llm/README.md @@ -0,0 +1,100 @@ +# LLM MLX Example + +This example demonstrates how to export and run LLMs using the MLX delegate for Apple Silicon. + +## Features + +- **Export**: Convert HuggingFace LLMs to ExecuTorch format with MLX delegate +- **Quantization**: Optional INT4/INT8 weight quantization via TorchAO +- **KV Cache**: Efficient KV cache implementation for autoregressive generation +- **Custom Ops**: Uses `mlx::custom_sdpa` and `mlx::kv_cache_update` for optimal execution on MLX +- **Pybindings**: Run inference using ExecuTorch Python bindings + +## Requirements + +```bash +pip install transformers optimum-executorch +``` + +## Scripts Overview + +| Script | Description | +|--------|-------------| +| `export_llm_hf` | Export LLMs using optimum-executorch pipeline, with optional custom MLX SDPA/KV cache | +| `run_llm_hf` | Run exported models with token-by-token generation | + +For exporting via the ExecuTorch LLM pipeline (e.g. `examples/models/llama`), use `--mlx` to enable the MLX delegate. 
+ +--- + +## `export_llm_hf` + +Uses optimum-executorch's `CausalLMExportableModule` by default. Optional flags enable custom MLX-optimized components (custom SDPA and/or KV cache). + +```bash +# Baseline export using optimum-executorch +python -m executorch.backends.mlx.examples.llm.export_llm_hf \ + --model-id "unsloth/Llama-3.2-1B-Instruct" \ + --output llama_hf.pte + +# With custom MLX components +python -m executorch.backends.mlx.examples.llm.export_llm_hf \ + --model-id "unsloth/Llama-3.2-1B-Instruct" \ + --output llama_hf_mlx.pte \ + --use-custom-sdpa \ + --use-custom-kv-cache + +# With 4-bit quantization +python -m executorch.backends.mlx.examples.llm.export_llm_hf \ + --model-id "unsloth/Llama-3.2-1B-Instruct" \ + --output llama_hf_int4.pte \ + --use-custom-sdpa \ + --use-custom-kv-cache \ + --qlinear 4w \ + --qembedding 4w +``` + +### Options + +| Option | Default | Description | +|--------|---------|-------------| +| `--model-id` | `unsloth/Llama-3.2-1B-Instruct` | HuggingFace model ID | +| `--output` | *(required)* | Output .pte file path | +| `--max-seq-len` | `1024` | Maximum sequence length for KV cache | +| `--dtype` | `bf16` | Model dtype (`fp32`, `fp16`, `bf16`) | +| `--qlinear` | None | Quantization for linear layers (`4w`, `8w`, `nvfp4`) | +| `--qembedding` | None | Quantization for embedding layers (`4w`, `8w`, `nvfp4`) | +| `--no-tie-word-embeddings` | `False` | Disable re-tying lm_head to embedding after quantization | +| `--use-custom-sdpa` | `False` | Use MLX custom SDPA (`mlx::custom_sdpa`) | +| `--use-custom-kv-cache` | `False` | Use MLX custom KV cache (`mlx::kv_cache_update`) | + +--- + +## `run_llm_hf` + +Run models exported with `export_llm_hf`. Supports both full-prompt prefill (dynamic seq len exports) and token-by-token prefill (fixed seq len exports). 
+ +```bash +python -m executorch.backends.mlx.examples.llm.run_llm_hf \ + --pte llama_hf.pte \ + --model-id unsloth/Llama-3.2-1B-Instruct \ + --prompt "Explain quantum computing in simple terms" +``` + +### Options + +| Option | Default | Description | +|--------|---------|-------------| +| `--pte` | `llama_hf.pte` | Path to .pte file | +| `--model-id` | `unsloth/Llama-3.2-1B-Instruct` | HuggingFace model ID (for tokenizer) | +| `--prompt` | `The quick brown fox` | Input prompt | +| `--max-new-tokens` | `50` | Maximum tokens to generate | + +--- + +## Architecture + +The `export_llm_hf` script uses optimum-executorch's `CausalLMExportableModule` by default. When custom flags are enabled, it uses `TorchExportableModuleWithStaticCache` from HuggingFace transformers, with optional MLX-specific replacements: + +- `--use-custom-sdpa`: Registers `mlx::custom_sdpa` attention implementation +- `--use-custom-kv-cache`: Replaces HF's `StaticCache` with `HFStaticCache` using `mlx::kv_cache_update` diff --git a/backends/mlx/examples/llm/__init__.py b/backends/mlx/examples/llm/__init__.py new file mode 100644 index 00000000000..f557ef26c5b --- /dev/null +++ b/backends/mlx/examples/llm/__init__.py @@ -0,0 +1,6 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/mlx/examples/llm/export_llm_hf.py b/backends/mlx/examples/llm/export_llm_hf.py new file mode 100644 index 00000000000..39f13e434be --- /dev/null +++ b/backends/mlx/examples/llm/export_llm_hf.py @@ -0,0 +1,464 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Export LLM model from HuggingFace to MLX backend. 
+ +By default, uses optimum-executorch's CausalLMExportableModule which provides +a proven export pipeline. Optional flags enable custom MLX-optimized components: + + --use-custom-sdpa Register MLX attention (mlx::custom_sdpa) which handles + K/V slicing and causal masking internally. + --use-custom-kv-cache Replace HF's StaticCache with HFStaticCache that uses + mlx::kv_cache_update for optimized cache updates. + +When neither flag is set, the script behaves identically to the original +optimum-executorch export pipeline. + +Usage: + # Baseline (optimum-executorch pipeline): + python -m executorch.backends.mlx.examples.llm.export_llm_hf \\ + --model-id "unsloth/Llama-3.2-1B-Instruct" \\ + --output llama_hf.pte + + # With custom MLX components: + python -m executorch.backends.mlx.examples.llm.export_llm_hf \\ + --model-id "unsloth/Llama-3.2-1B-Instruct" \\ + --output llama_hf_mlx.pte \\ + --use-custom-sdpa \\ + --use-custom-kv-cache + +Requirements: + pip install transformers torch optimum-executorch +""" + +import argparse +import logging +import os +from typing import Optional + +import torch + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +def _export_with_optimum( + model_id: str, + output_path: str, + max_seq_len: int, + dtype: str, + qlinear: Optional[str], + qembedding: Optional[str], + no_tie_word_embeddings: bool = False, + qlinear_group_size: Optional[int] = None, + qembedding_group_size: Optional[int] = None, +) -> None: + import executorch.exir as exir + from executorch.backends.mlx import MLXPartitioner + from executorch.backends.mlx.passes import get_default_passes + from executorch.exir import EdgeCompileConfig + from executorch.exir.capture._config import ExecutorchBackendConfig + from executorch.exir.passes import MemoryPlanningPass + from optimum.exporters.executorch.tasks.causal_lm import load_causal_lm_model + + 
dtype_map = {"fp32": "float32", "fp16": "float16", "bf16": "bfloat16"} + dtype_str = dtype_map.get(dtype, "bfloat16") + + logger.info(f"Loading model using optimum-executorch: {model_id}") + exportable = load_causal_lm_model( + model_id, + dtype=dtype_str, + max_seq_len=max_seq_len, + ) + + from executorch.backends.mlx.llm.quantization import quantize_model_ + + quantize_model_( + exportable.model, + qlinear_config=qlinear, + qlinear_group_size=qlinear_group_size, + qembedding_config=qembedding, + qembedding_group_size=qembedding_group_size, + tie_word_embeddings=getattr( + exportable.model.config, "tie_word_embeddings", False + ) + and not no_tie_word_embeddings, + ) + + logger.info("Exporting model with torch.export...") + exported_progs = exportable.export() + + logger.info("Delegating to MLX backend...") + edge_config = EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ) + + if len(exported_progs) == 1: + exported_progs = {"forward": next(iter(exported_progs.values()))} + + edge_program = exir.to_edge_transform_and_lower( + exported_progs, + transform_passes=get_default_passes(), + partitioner=[MLXPartitioner()], + compile_config=edge_config, + constant_methods=exportable.metadata, + ) + + logger.info("Exporting to ExecuTorch...") + executorch_program = edge_program.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ) + ) + + _save_program(executorch_program, output_path) + + +def _export_with_custom_components( + model_id: str, + output_path: str, + max_seq_len: int, + dtype: str, + qlinear: Optional[str], + qembedding: Optional[str], + use_custom_sdpa: bool, + use_custom_kv_cache: bool, + no_tie_word_embeddings: bool = False, + qlinear_group_size: Optional[int] = None, + qembedding_group_size: Optional[int] = None, +) -> None: + """ + Export using direct HF model with custom MLX components. 
+ + Used when --use-custom-sdpa and/or --use-custom-kv-cache are set. + """ + import executorch.exir as exir + from executorch.backends.mlx import MLXPartitioner + from executorch.backends.mlx.passes import get_default_passes + from executorch.exir import EdgeCompileConfig + from executorch.exir.capture._config import ExecutorchBackendConfig + from executorch.exir.passes import MemoryPlanningPass + from transformers import AutoModelForCausalLM + from transformers.integrations.executorch import ( + TorchExportableModuleWithStaticCache, + ) + + torch_dtype_map = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, + } + torch_dtype = torch_dtype_map.get(dtype, torch.bfloat16) + + if use_custom_sdpa: + from executorch.backends.mlx.llm.hf_attention import register_mlx_attention + + register_mlx_attention() + logger.info("Registered MLX custom SDPA attention") + + attn_implementation = "mlx" if use_custom_sdpa else None + + # Detect sliding window models (e.g., gemma) + sliding_window = None + + logger.info(f"Loading HuggingFace model: {model_id}") + load_kwargs = { + "torch_dtype": torch_dtype, + "low_cpu_mem_usage": True, + } + if attn_implementation: + load_kwargs["attn_implementation"] = attn_implementation + model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs) + + # Check if model uses sliding window attention + sliding_window = getattr(model.config, "sliding_window", None) + if sliding_window is not None: + logger.info(f"Model has sliding_window={sliding_window}") + # Cap max_seq_len to sliding window size for cache allocation + effective_cache_len = min(max_seq_len, sliding_window) + logger.info(f" Capping cache length to sliding window: {effective_cache_len}") + else: + effective_cache_len = max_seq_len + + model.generation_config.cache_implementation = "static" + model.generation_config.cache_config = { + "batch_size": 1, + "max_cache_len": effective_cache_len, + } + model.eval() + + # Use HybridCache wrapper for 
sliding window models (stores cache as .cache), + # StaticCache wrapper for non-sliding-window models (stores cache as .static_cache). + # This matters because the sliding window SDPA closure looks up the cache via + # exportable_module.cache, matching the optimum-executorch convention. + if sliding_window is not None: + from transformers.integrations.executorch import ( + TorchExportableModuleWithHybridCache, + ) + + logger.info("Creating TorchExportableModuleWithHybridCache wrapper...") + exportable = TorchExportableModuleWithHybridCache( + model=model, + batch_size=1, + max_cache_len=effective_cache_len, + ) + else: + logger.info("Creating TorchExportableModuleWithStaticCache wrapper...") + exportable = TorchExportableModuleWithStaticCache( + model=model, + batch_size=1, + max_cache_len=effective_cache_len, + ) + + if use_custom_kv_cache: + if sliding_window is not None: + # Use ring buffer cache for sliding window models + from executorch.backends.mlx.llm.source_transformation import ( + replace_hf_cache_with_mlx_ring_buffer, + ) + + logger.info( + f"Replacing StaticCache with RingBuffer KV cache " + f"(window_size={effective_cache_len})..." 
+ ) + replace_hf_cache_with_mlx_ring_buffer( + exportable, + model.config, + max_batch_size=1, + window_size=effective_cache_len, + dtype=torch_dtype, + ) + + if use_custom_sdpa: + # Re-register attention with sliding window closure + from executorch.backends.mlx.llm.hf_attention import ( + register_mlx_sliding_window_attention, + ) + + register_mlx_sliding_window_attention(exportable) + model.config._attn_implementation = "mlx_sliding_window" + logger.info( + " Registered sliding window attention (mlx_sliding_window)" + ) + + logger.info(" RingBuffer KV cache installed successfully") + else: + # Use standard linear cache for non-sliding-window models + from executorch.backends.mlx.llm.source_transformation import ( + replace_hf_cache_with_mlx, + ) + + logger.info("Replacing HuggingFace StaticCache with HFStaticCache...") + replace_hf_cache_with_mlx( + exportable, + model.config, + max_batch_size=1, + max_cache_len=effective_cache_len, + dtype=torch_dtype, + ) + logger.info(" HFStaticCache installed successfully") + + from executorch.backends.mlx.llm.quantization import quantize_model_ + + quantize_model_( + exportable.model, + qlinear_config=qlinear, + qlinear_group_size=qlinear_group_size, + qembedding_config=qembedding, + qembedding_group_size=qembedding_group_size, + tie_word_embeddings=getattr(model.config, "tie_word_embeddings", False) + and not no_tie_word_embeddings, + ) + + logger.info("Exporting model with torch.export...") + seq_length = 3 + example_input_ids = torch.zeros((1, seq_length), dtype=torch.long) + example_cache_position = torch.arange(seq_length, dtype=torch.long) + + seq_len_dim = torch.export.Dim("seq_length_dim", max=effective_cache_len - 1) + dynamic_shapes = { + "input_ids": {1: seq_len_dim}, + "cache_position": {0: seq_len_dim}, + } + + with torch.no_grad(): + exported_program = torch.export.export( + exportable, + args=(), + kwargs={ + "input_ids": example_input_ids, + "cache_position": example_cache_position, + }, + 
dynamic_shapes=dynamic_shapes, + strict=True, + ) + + logger.info("Export completed successfully") + for sym, constraint in exported_program.range_constraints.items(): + logger.info(f" Range constraint: {sym}: {constraint}") + + logger.info("Delegating to MLX backend...") + edge_config = EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ) + + edge_program = exir.to_edge_transform_and_lower( + {"forward": exported_program}, + transform_passes=get_default_passes(), + partitioner=[MLXPartitioner()], + compile_config=edge_config, + ) + + logger.info("Exporting to ExecuTorch...") + executorch_program = edge_program.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=True), + ) + ) + + _save_program(executorch_program, output_path) + + +def _save_program(executorch_program, output_path: str) -> None: + """Save the ExecuTorch program to disk.""" + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + with open(output_path, "wb") as f: + f.write(executorch_program.buffer) + + logger.info(f"Saved model to: {output_path}") + logger.info(f"Program size: {len(executorch_program.buffer) / 1024 / 1024:.2f} MB") + + +def export_llama_hf( + model_id: str, + output_path: str, + max_seq_len: int = 1024, + dtype: str = "bf16", + qlinear: Optional[str] = None, + qembedding: Optional[str] = None, + use_custom_sdpa: bool = False, + use_custom_kv_cache: bool = False, + no_tie_word_embeddings: bool = False, + qlinear_group_size: Optional[int] = None, + qembedding_group_size: Optional[int] = None, +) -> None: + """ + Export a HuggingFace Llama model to ExecuTorch with MLX backend. 
+ + Args: + model_id: HuggingFace model ID + output_path: Path to save the .pte file + max_seq_len: Maximum sequence length for KV cache + dtype: Model dtype ("fp32", "fp16", "bf16") + qlinear: Quantization for linear layers ("4w", "8w", "nvfp4", or None) + qembedding: Quantization for embeddings ("4w", "8w", "nvfp4", or None) + use_custom_sdpa: Use MLX custom SDPA (mlx::custom_sdpa) + use_custom_kv_cache: Use MLX custom KV cache (mlx::kv_cache_update) + """ + if use_custom_sdpa or use_custom_kv_cache: + logger.info( + f"Using custom components: sdpa={use_custom_sdpa}, " + f"kv_cache={use_custom_kv_cache}" + ) + _export_with_custom_components( + model_id=model_id, + output_path=output_path, + max_seq_len=max_seq_len, + dtype=dtype, + qlinear=qlinear, + qembedding=qembedding, + use_custom_sdpa=use_custom_sdpa, + use_custom_kv_cache=use_custom_kv_cache, + no_tie_word_embeddings=no_tie_word_embeddings, + qlinear_group_size=qlinear_group_size, + qembedding_group_size=qembedding_group_size, + ) + else: + logger.info("Using optimum-executorch pipeline (no custom components)") + _export_with_optimum( + model_id=model_id, + output_path=output_path, + max_seq_len=max_seq_len, + dtype=dtype, + qlinear=qlinear, + qembedding=qembedding, + no_tie_word_embeddings=no_tie_word_embeddings, + qlinear_group_size=qlinear_group_size, + qembedding_group_size=qembedding_group_size, + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Export HuggingFace Llama model to MLX backend" + ) + parser.add_argument( + "--model-id", + type=str, + default="unsloth/Llama-3.2-1B-Instruct", + help="HuggingFace model ID", + ) + parser.add_argument( + "--output", + type=str, + required=True, + help="Output .pte file path", + ) + parser.add_argument( + "--max-seq-len", + type=int, + default=1024, + help="Maximum sequence length for KV cache", + ) + parser.add_argument( + "--dtype", + type=str, + choices=["fp32", "fp16", "bf16"], + default="bf16", + help="Model dtype", + ) + from 
executorch.backends.mlx.llm.quantization import add_quantization_args + + add_quantization_args(parser) + parser.add_argument( + "--use-custom-sdpa", + action="store_true", + default=False, + help="Use MLX custom SDPA (mlx::custom_sdpa) for attention", + ) + parser.add_argument( + "--use-custom-kv-cache", + action="store_true", + default=False, + help="Use MLX custom KV cache (mlx::kv_cache_update)", + ) + + args = parser.parse_args() + + export_llama_hf( + model_id=args.model_id, + output_path=args.output, + max_seq_len=args.max_seq_len, + dtype=args.dtype, + qlinear=args.qlinear, + qembedding=args.qembedding, + use_custom_sdpa=args.use_custom_sdpa, + use_custom_kv_cache=args.use_custom_kv_cache, + no_tie_word_embeddings=args.no_tie_word_embeddings, + qlinear_group_size=args.qlinear_group_size, + qembedding_group_size=args.qembedding_group_size, + ) + + +if __name__ == "__main__": + main() diff --git a/backends/mlx/examples/llm/run_llm_hf.py b/backends/mlx/examples/llm/run_llm_hf.py new file mode 100644 index 00000000000..ca3d0468114 --- /dev/null +++ b/backends/mlx/examples/llm/run_llm_hf.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Run exported Llama model (from HuggingFace) using ExecuTorch pybindings. + +This script runs models exported using export_llm_hf.py. It loads the tokenizer +directly from HuggingFace using the same model ID used during export. + +Usage: + python -m executorch.backends.mlx.examples.llm.run_llm_hf \ + --pte llama_hf.pte \ + --model-id unsloth/Llama-3.2-1B-Instruct \ + --prompt "Hello, world!" 
+""" + +import argparse +import logging +import time + +import torch +from executorch.runtime import Runtime, Verification +from transformers import AutoTokenizer + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +def _get_max_input_seq_len(program) -> int: + """Inspect the .pte program metadata to determine the max input_ids seq len. + + Returns the static seq-len dimension of the first input tensor (input_ids). + For models exported with dynamic shapes this will be the upper-bound; for + models exported with a fixed (1,1) shape it will be 1. + """ + meta = program.metadata("forward") + input_ids_info = meta.input_tensor_meta(0) + sizes = input_ids_info.sizes() + # sizes is (batch, seq_len) + return sizes[1] if len(sizes) >= 2 else 1 + + +def run_inference( + pte_path: str, + model_id: str, + prompt: str, + max_new_tokens: int = 50, +) -> str: + """Run inference on the exported HuggingFace model.""" + logger.info(f"Loading tokenizer from HuggingFace: {model_id}...") + tokenizer = AutoTokenizer.from_pretrained(model_id) + + logger.info(f"Loading model from {pte_path}...") + et_runtime = Runtime.get() + program = et_runtime.load_program(pte_path, verification=Verification.Minimal) + + max_seq_len = _get_max_input_seq_len(program) + logger.info(f"Model input_ids max seq len: {max_seq_len}") + + forward = program.load_method("forward") + + logger.info(f"Encoding prompt: {prompt!r}") + messages = [{"role": "user", "content": prompt}] + formatted_prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + input_ids = tokenizer.encode(formatted_prompt, return_tensors="pt") + logger.info(f"Input shape: {input_ids.shape}") + + generated_tokens = input_ids[0].tolist() + seq_len = input_ids.shape[1] + + start_time = time.time() + + if max_seq_len == 1: + # Model was exported with fixed (1,1) input — 
token-by-token prefill + logger.info(f"Running token-by-token prefill ({seq_len} tokens)...") + for i in range(seq_len): + token_input = input_ids[:, i : i + 1] + cache_position = torch.tensor([i], dtype=torch.long) + outputs = forward.execute([token_input, cache_position]) + logits = outputs[0] + else: + # Model was exported with dynamic seq len — full-prompt prefill + logger.info(f"Running full-prompt prefill ({seq_len} tokens)...") + cache_position = torch.arange(seq_len, dtype=torch.long) + outputs = forward.execute([input_ids, cache_position]) + logits = outputs[0] + + prefill_time = time.time() - start_time + logger.info( + f"Prefill time: {prefill_time:.3f}s " + f"({seq_len / prefill_time:.1f} tokens/sec)" + ) + + # Get the next token from the last position + next_token_logits = logits[0, -1, :] + next_token = torch.argmax(next_token_logits).item() + generated_tokens.append(next_token) + + # Decode: generate tokens one at a time + logger.info(f"Generating up to {max_new_tokens} tokens...") + decode_start = time.time() + + for i in range(max_new_tokens - 1): + pos = len(generated_tokens) - 1 + cache_position = torch.tensor([pos], dtype=torch.long) + token_input = torch.tensor([[next_token]], dtype=torch.long) + + outputs = forward.execute([token_input, cache_position]) + logits = outputs[0] + + next_token_logits = logits[0, -1, :] + next_token = torch.argmax(next_token_logits).item() + generated_tokens.append(next_token) + + if next_token == tokenizer.eos_token_id: + logger.info(f"EOS token reached at position {i + 1}") + break + + decode_time = time.time() - decode_start + num_generated = len(generated_tokens) - seq_len + tokens_per_sec = num_generated / decode_time if decode_time > 0 else 0 + + print(f"\nPrefill time: {prefill_time:.3f}s ({seq_len / prefill_time:.1f} tok/s)") + print( + f"Decode time: {decode_time:.3f}s ({num_generated} tokens, {tokens_per_sec:.1f} tok/s)" + ) + + # Decode only the newly generated tokens (not the input prompt) + new_tokens 
= generated_tokens[seq_len:] + generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True) + return generated_text + + +def main(): + parser = argparse.ArgumentParser(description="Run exported HuggingFace Llama model") + parser.add_argument( + "--pte", + type=str, + default="llama_hf.pte", + help="Path to the .pte file", + ) + parser.add_argument( + "--model-id", + type=str, + default="unsloth/Llama-3.2-1B-Instruct", + help="HuggingFace model ID (used to load tokenizer)", + ) + parser.add_argument( + "--prompt", + type=str, + default="The quick brown fox", + help="Input prompt", + ) + parser.add_argument( + "--max-new-tokens", + type=int, + default=50, + help="Maximum number of new tokens to generate", + ) + + args = parser.parse_args() + + generated_text = run_inference( + pte_path=args.pte, + model_id=args.model_id, + prompt=args.prompt, + max_new_tokens=args.max_new_tokens, + ) + + print("\n" + "=" * 60) + print("Generated text:") + print("=" * 60) + print(generated_text) + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/backends/mlx/examples/voxtral/README.md b/backends/mlx/examples/voxtral/README.md new file mode 100644 index 00000000000..16d2384ed42 --- /dev/null +++ b/backends/mlx/examples/voxtral/README.md @@ -0,0 +1,69 @@ +# Voxtral MLX Export + +Export [mistralai/Voxtral-Mini-3B-2507](https://huggingface.co/mistralai/Voxtral-Mini-3B-2507) +multimodal audio-language model to ExecuTorch with the MLX backend. + +Uses [optimum-executorch](https://github.com/huggingface/optimum-executorch) for +the export pipeline. 
+ +## Prerequisites + +```bash +pip install transformers torch optimum-executorch mistral-common librosa +``` + +## Export + +Export with int4 weight quantization (recommended): + +```bash +python -m executorch.backends.mlx.examples.voxtral.export_voxtral_hf \ + --output-dir voxtral_mlx \ + --dtype bf16 \ + --qlinear 4w +``` + +This produces: +- `model.pte` — the main model (audio_encoder, token_embedding, text_decoder) +- `preprocessor.pte` — mel spectrogram preprocessor for raw audio + +### Export Options + +| Argument | Default | Description | +|----------|---------|-------------| +| `--model-id` | `mistralai/Voxtral-Mini-3B-2507` | HuggingFace model ID | +| `--output-dir` | `voxtral_mlx` | Output directory | +| `--dtype` | `bf16` | Model dtype (`fp32`, `fp16`, `bf16`) | +| `--max-seq-len` | `1024` | Maximum sequence length for KV cache | +| `--max-audio-len` | `300` | Maximum audio length in seconds | +| `--qlinear` | `4w` | Linear layer quantization (`4w`, `8w`, `nvfp4`, or None) | +| `--qlinear-group-size` | auto | Group size for linear quantization | + +### Quantization + +The `4w` config uses int4 weight-only quantization with the HQQ algorithm for +optimal scale selection. This typically reduces model size by ~4x with minimal +quality loss. + +## Run + +Requires the C++ voxtral runner. Build with: + +```bash +make voxtral-mlx +``` + +Run inference: + +```bash +./cmake-out/examples/models/voxtral/voxtral_runner \ + --model_path voxtral_mlx/model.pte \ + --processor_path voxtral_mlx/preprocessor.pte \ + --tokenizer_path /path/to/tekken.json \ + --audio_path /path/to/audio.wav \ + --prompt "What is happening in this audio?" \ + --temperature 0 +``` + +The `tekken.json` tokenizer is included in the model weights directory +downloaded from HuggingFace. 
diff --git a/backends/mlx/examples/voxtral/__init__.py b/backends/mlx/examples/voxtral/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/mlx/examples/voxtral/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/mlx/examples/voxtral/export_voxtral_hf.py b/backends/mlx/examples/voxtral/export_voxtral_hf.py new file mode 100644 index 00000000000..b9ed2bccf1c --- /dev/null +++ b/backends/mlx/examples/voxtral/export_voxtral_hf.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Export Voxtral model from HuggingFace using optimum-executorch, delegated to +the MLX backend. + +Voxtral is a multimodal audio-language model (mistralai/Voxtral-Mini-3B-2507). 
+The exported .pte contains three methods: + - audio_encoder : mel-spectrogram features → audio embeddings + - token_embedding : token ids → text embeddings + - text_decoder : embeddings + cache_position → next-token logits + +Usage: + python -m executorch.backends.mlx.examples.voxtral.export_voxtral_hf \ + --model-id "mistralai/Voxtral-Mini-3B-2507" \ + --output-dir voxtral_mlx + +Requirements: + pip install transformers torch optimum-executorch mistral-common librosa +""" + +import argparse +import logging +import os +from typing import Optional + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +def export_voxtral_hf( + model_id: str, + output_dir: str, + max_seq_len: int = 1024, + dtype: str = "bf16", + qlinear: Optional[str] = None, + qembedding: Optional[str] = None, + qlinear_group_size: Optional[int] = None, + qembedding_group_size: Optional[int] = None, + max_audio_len: int = 300, +) -> None: + """ + Export a HuggingFace Voxtral model using optimum-executorch, delegated to + the MLX backend. 
Outputs two files: + - model.pte: the main model (audio_encoder, token_embedding, text_decoder) + - preprocessor.pte: mel spectrogram preprocessor for raw audio + + Args: + model_id: HuggingFace model ID (e.g., "mistralai/Voxtral-Mini-3B-2507") + output_dir: Directory to save the .pte files + max_seq_len: Maximum sequence length for KV cache + dtype: Model dtype ("fp32", "fp16", "bf16") + qlinear: Quantization for linear layers ("4w", "8w", "nvfp4", or None) + qembedding: Quantization for embeddings ("4w", "8w", "nvfp4", or None) + qlinear_group_size: Group size for linear quantization (default: auto) + qembedding_group_size: Group size for embedding quantization (default: auto) + max_audio_len: Maximum audio length in seconds for preprocessor + """ + from optimum.exporters.executorch.tasks.multimodal_text_to_text import ( + load_multimodal_text_to_text_model, + ) + + os.makedirs(output_dir, exist_ok=True) + + # --- Export preprocessor --- + from executorch.extension.audio.mel_spectrogram import export_processor + + export_processor( + output_file=os.path.join(output_dir, "preprocessor.pte"), + backend="mlx", + feature_size=128, + max_audio_len=max_audio_len, + stack_output=True, + ) + + # --- Export model --- + logger.info(f"Loading model using optimum-executorch: {model_id}") + + dtype_map = {"fp32": "float32", "fp16": "float16", "bf16": "bfloat16"} + dtype_str = dtype_map.get(dtype, "bfloat16") + + exportable = load_multimodal_text_to_text_model( + model_id, + dtype=dtype_str, + max_seq_len=max_seq_len, + ) + + # Apply quantization if requested + from executorch.backends.mlx.llm.quantization import quantize_model_ + + quantize_model_( + exportable.model, + qlinear_config=qlinear, + qlinear_group_size=qlinear_group_size, + qembedding_config=qembedding, + qembedding_group_size=qembedding_group_size, + ) + + logger.info("Exporting model with torch.export...") + exported_progs = exportable.export() + logger.info(f"Exported methods: {list(exported_progs.keys())}") + 
+ logger.info("Delegating to MLX backend...") + import executorch.exir as exir + from executorch.backends.mlx import MLXPartitioner + from executorch.backends.mlx.passes import get_default_passes + from executorch.exir import EdgeCompileConfig + from executorch.exir.capture._config import ExecutorchBackendConfig + from executorch.exir.passes import MemoryPlanningPass + + edge_config = EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ) + + edge_program = exir.to_edge_transform_and_lower( + exported_progs, + transform_passes=get_default_passes(), + partitioner=[MLXPartitioner()], + compile_config=edge_config, + constant_methods=exportable.metadata, + ) + + logger.info("Exporting to ExecuTorch...") + executorch_program = edge_program.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ) + ) + + model_path = os.path.join(output_dir, "model.pte") + with open(model_path, "wb") as f: + f.write(executorch_program.buffer) + + logger.info(f"Saved model to: {model_path}") + logger.info(f"Model size: {len(executorch_program.buffer) / 1024 / 1024:.2f} MB") + + +def main(): + parser = argparse.ArgumentParser( + description="Export HuggingFace Voxtral model using optimum-executorch to MLX" + ) + parser.add_argument( + "--model-id", + type=str, + default="mistralai/Voxtral-Mini-3B-2507", + help="HuggingFace model ID", + ) + parser.add_argument( + "--output-dir", + type=str, + default="voxtral_mlx", + help="Output directory for model.pte and preprocessor.pte", + ) + parser.add_argument( + "--max-seq-len", + type=int, + default=1024, + help="Maximum sequence length for KV cache", + ) + parser.add_argument( + "--dtype", + type=str, + choices=["fp32", "fp16", "bf16"], + default="bf16", + help="Model dtype", + ) + from executorch.backends.mlx.llm.quantization import add_quantization_args + + add_quantization_args(parser) + parser.add_argument( + 
"--max-audio-len", + type=int, + default=300, + help="Maximum audio length in seconds for preprocessor", + ) + + args = parser.parse_args() + + export_voxtral_hf( + model_id=args.model_id, + output_dir=args.output_dir, + max_seq_len=args.max_seq_len, + dtype=args.dtype, + qlinear=args.qlinear, + qembedding=args.qembedding, + qlinear_group_size=args.qlinear_group_size, + qembedding_group_size=args.qembedding_group_size, + max_audio_len=args.max_audio_len, + ) + + +if __name__ == "__main__": + main() diff --git a/backends/mlx/examples/whisper/README.md b/backends/mlx/examples/whisper/README.md new file mode 100644 index 00000000000..6487a22a3a5 --- /dev/null +++ b/backends/mlx/examples/whisper/README.md @@ -0,0 +1,70 @@ +# Whisper MLX Export + +Export and run [OpenAI Whisper](https://huggingface.co/openai/whisper-tiny) speech-to-text models on the MLX backend. + +## Prerequisites + +```bash +pip install transformers torchao soundfile datasets +``` + +## Export + +The export script splits the model into three programs: + +- **encoder.pte** — audio features → encoder hidden states +- **cross_kv.pte** — encoder hidden states → per-layer cross-attention K/V +- **decoder.pte** — token-by-token generation with self-attention KV cache + +Export with int4 weight quantization: + +```bash +python -m executorch.backends.mlx.examples.whisper.export_whisper \ + --model-id openai/whisper-tiny \ + --output-dir /tmp/whisper_mlx \ + --dtype bf16 \ + --qlinear 4w +``` + +### Export Options + +| Option | Default | Description | +|--------|---------|-------------| +| `--model-id` | `openai/whisper-tiny` | HuggingFace model ID | +| `--output-dir` | `whisper_mlx` | Output directory for `.pte` files | +| `--max-decoder-seq-len` | `256` | Maximum decoder sequence length | +| `--dtype` | `bf16` | Model dtype (`fp32`, `fp16`, `bf16`) | +| `--qlinear` | None | Quantization for linear layers (`4w`, `8w`, `nvfp4`) | +| `--qembedding` | None | Quantization for embedding layers (`4w`, `8w`, 
`nvfp4`) | +| `--qlinear-group-size` | auto | Group size for linear quantization | +| `--qembedding-group-size` | auto | Group size for embedding quantization | + + +## Run + +```bash +python -m executorch.backends.mlx.examples.whisper.run_whisper \ + --model-dir /tmp/whisper_mlx \ + --use-sample-audio +``` + +Or with a custom audio file: + +```bash +python -m executorch.backends.mlx.examples.whisper.run_whisper \ + --model-dir /tmp/whisper_mlx \ + --audio-file /path/to/audio.wav +``` + +### Run Options + +| Option | Default | Description | +|--------|---------|-------------| +| `--model-dir` | `/tmp/whisper_mlx` | Directory containing exported `.pte` files | +| `--model-id` | `openai/whisper-tiny` | HuggingFace model ID (used to load processor) | +| `--audio-file` | None | Path to audio file (WAV, MP3, etc.) | +| `--use-sample-audio` | off | Use sample audio from HuggingFace datasets | +| `--max-new-tokens` | `256` | Maximum tokens to generate | +| `--language` | `en` | Language code | +| `--task` | `transcribe` | `transcribe` or `translate` | +| `--dtype` | `bf16` | Input dtype (must match export dtype) | diff --git a/backends/mlx/examples/whisper/__init__.py b/backends/mlx/examples/whisper/__init__.py new file mode 100644 index 00000000000..0adc14c3f18 --- /dev/null +++ b/backends/mlx/examples/whisper/__init__.py @@ -0,0 +1,7 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# diff --git a/backends/mlx/examples/whisper/args.py b/backends/mlx/examples/whisper/args.py new file mode 100644 index 00000000000..82ed7371926 --- /dev/null +++ b/backends/mlx/examples/whisper/args.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Shared argument definitions for Whisper export and run scripts. +""" + +import argparse +import logging +from typing import Optional + +import torch + +logger = logging.getLogger(__name__) + + +def add_export_args(parser: argparse.ArgumentParser) -> None: + """Add common export arguments for Whisper scripts.""" + parser.add_argument( + "--model-id", + type=str, + default="openai/whisper-tiny", + help="HuggingFace model ID", + ) + parser.add_argument( + "--max-decoder-seq-len", + type=int, + default=256, + help="Maximum decoder sequence length", + ) + parser.add_argument( + "--dtype", + type=str, + choices=["fp32", "fp16", "bf16"], + default="bf16", + help="Model dtype", + ) + from executorch.backends.mlx.llm.quantization import add_quantization_args + + add_quantization_args(parser) + + +def add_run_args(parser: argparse.ArgumentParser) -> None: + """Add common runtime arguments for Whisper scripts.""" + parser.add_argument( + "--model-id", + type=str, + default="openai/whisper-tiny", + help="HuggingFace model ID (used to load processor)", + ) + parser.add_argument( + "--audio-file", + type=str, + default=None, + help="Path to audio file (WAV, MP3, etc.)", + ) + parser.add_argument( + "--use-sample-audio", + action="store_true", + help="Use sample audio from HuggingFace datasets", + ) + parser.add_argument( + "--max-new-tokens", + type=int, + default=256, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--language", + type=str, + default="en", + help="Language code for transcription", + ) + parser.add_argument( + "--task", + type=str, + choices=["transcribe", "translate"], + default="transcribe", + help="Task: transcribe or translate", + ) + parser.add_argument( + "--dtype", + type=str, + choices=["fp32", "fp16", "bf16"], + default="bf16", + help="Input dtype (must match the dtype used during 
export)", + ) + + +def load_audio( + audio_path: Optional[str], + use_sample_audio: bool, + processor, +) -> torch.Tensor: + """Load and preprocess audio input. + + Returns: + input_features: [1, n_mels, n_frames] tensor + """ + if use_sample_audio: + logger.info("Loading sample audio from HuggingFace datasets...") + try: + from datasets import load_dataset + except ImportError: + logger.error("datasets not installed. Run: pip install datasets") + raise + + dataset = load_dataset( + "distil-whisper/librispeech_long", + "clean", + split="validation", + ) + sample = dataset[0]["audio"] + audio_array = sample["array"] + sampling_rate = sample["sampling_rate"] + else: + if audio_path is None: + raise ValueError( + "Either --audio-file or --use-sample-audio must be provided" + ) + + logger.info(f"Loading audio from: {audio_path}") + try: + import soundfile as sf + except ImportError: + logger.error("soundfile not installed. Run: pip install soundfile") + raise + + audio_array, sampling_rate = sf.read(audio_path) + + input_features = processor( + audio_array, + return_tensors="pt", + truncation=False, + sampling_rate=sampling_rate, + ).input_features + + # Truncate to 30 seconds (3000 frames at 100 frames/sec) + max_frames = 3000 + if input_features.shape[2] > max_frames: + logger.info( + f"Truncating audio from {input_features.shape[2]} to {max_frames} frames" + ) + input_features = input_features[:, :, :max_frames].contiguous() + + return input_features diff --git a/backends/mlx/examples/whisper/export_whisper.py b/backends/mlx/examples/whisper/export_whisper.py new file mode 100644 index 00000000000..97d3a22bc79 --- /dev/null +++ b/backends/mlx/examples/whisper/export_whisper.py @@ -0,0 +1,617 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
 + +""" +Export Whisper model to MLX delegate using ExecuTorch. + +Exports three separate programs: +- encoder.pte: Audio features → encoder hidden states +- cross_kv.pte: Encoder hidden states → per-layer cross-attention K/V +- decoder.pte: Token-by-token generation with self-attention KV cache + +The decoder uses: +- llama.update_cache for self-attention KV cache updates +- Pre-computed cross-attention K/V passed as inputs + +Usage: + python -m executorch.backends.mlx.examples.whisper.export_whisper \ + --model-id "openai/whisper-tiny" \ + --output-dir /tmp/whisper_mlx \ + --qlinear 4w + +Requirements: + pip install transformers torchao +""" + +import argparse +import logging +import os +from typing import Optional, Tuple, TYPE_CHECKING + +if TYPE_CHECKING: + from transformers import WhisperForConditionalGeneration + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# Import shared KV cache module +from executorch.backends.mlx.llm.cache import KVCache +from executorch.backends.mlx.passes import get_default_passes + +# Import custom ops +from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401 + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +class WhisperEncoderExportable(nn.Module): + """ + Wrapper around Whisper's encoder for export. + + forward(input_features) -> encoder_hidden_states + """ + + def __init__(self, encoder: nn.Module): + super().__init__() + self.encoder = encoder + + def forward(self, input_features: torch.Tensor) -> torch.Tensor: + return self.encoder(input_features=input_features).last_hidden_state + + +class WhisperSelfAttentionWithCache(nn.Module): + """ + Whisper self-attention layer with static KV cache. + + Uses llama.update_cache pattern for cache updates. 
+ """ + + def __init__( + self, + attn_module: nn.Module, + max_cache_len: int, + dtype: torch.dtype, + ): + super().__init__() + self.q_proj = attn_module.q_proj + self.k_proj = attn_module.k_proj + self.v_proj = attn_module.v_proj + self.out_proj = attn_module.out_proj + + self.num_heads = attn_module.num_heads + self.head_dim = attn_module.head_dim + self.scale = self.head_dim**-0.5 + self.max_cache_len = max_cache_len + + # Initialize KV cache module + self.kv_cache = KVCache( + max_batch_size=1, + max_context_length=max_cache_len, + n_heads=self.num_heads, + head_dim=self.head_dim, + enable_dynamic_shape=True, + dtype=dtype, + ) + + def forward( + self, + hidden_states: torch.Tensor, # [B, T, H*D] + pos_int: int, # Position as SymInt + ) -> torch.Tensor: + B, T, _ = hidden_states.shape + H, D = self.num_heads, self.head_dim + + # Linear projections + q = self.q_proj(hidden_states) # [B, T, H*D] + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # Reshape to [B, H, T, D] + q = q.view(B, T, H, D).transpose(1, 2) + k = k.view(B, T, H, D).transpose(1, 2) + v = v.view(B, T, H, D).transpose(1, 2) + + # Update KV cache + k_cache, v_cache = self.kv_cache.update(pos_int, k, v) + + # Explicit windowing: slice cache to valid positions + end_pos = pos_int + T + k_win = k_cache[:, :, :end_pos, :] + v_win = v_cache[:, :, :end_pos, :] + + # SDPA with causal mask + attn_out = F.scaled_dot_product_attention( + q, k_win, v_win, attn_mask=None, is_causal=True, scale=self.scale + ) + + # Reshape back + attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, H * D) + return self.out_proj(attn_out) + + +class WhisperCrossAttention(nn.Module): + """ + Whisper cross-attention layer. + + K/V are pre-computed from encoder output and passed as inputs. + No cache update needed - just uses the pre-computed K/V directly. 
+ """ + + def __init__(self, attn_module: nn.Module): + super().__init__() + self.q_proj = attn_module.q_proj + self.out_proj = attn_module.out_proj + + self.num_heads = attn_module.num_heads + self.head_dim = attn_module.head_dim + self.scale = self.head_dim**-0.5 + + def forward( + self, + hidden_states: torch.Tensor, # [B, T_dec, H*D] + cross_k: torch.Tensor, # [B, H, T_enc, D] - pre-computed + cross_v: torch.Tensor, # [B, H, T_enc, D] - pre-computed + ) -> torch.Tensor: + B, T, _ = hidden_states.shape + H, D = self.num_heads, self.head_dim + + # Query projection + q = self.q_proj(hidden_states) + q = q.view(B, T, H, D).transpose(1, 2) # [B, H, T_dec, D] + + # SDPA with pre-computed K/V (no causal mask for cross-attention) + attn_out = F.scaled_dot_product_attention( + q, cross_k, cross_v, attn_mask=None, is_causal=False, scale=self.scale + ) + + # Reshape back + attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, H * D) + return self.out_proj(attn_out) + + +class WhisperDecoderLayerWithCache(nn.Module): + """ + Wrapper for a single Whisper decoder layer with KV cache. 
+ """ + + def __init__( + self, + layer: nn.Module, + max_cache_len: int, + dtype: torch.dtype, + ): + super().__init__() + # Self-attention with cache + self.self_attn = WhisperSelfAttentionWithCache( + layer.self_attn, max_cache_len, dtype + ) + self.self_attn_layer_norm = layer.self_attn_layer_norm + + # Cross-attention (K/V passed as inputs) + self.encoder_attn = WhisperCrossAttention(layer.encoder_attn) + self.encoder_attn_layer_norm = layer.encoder_attn_layer_norm + + # FFN + self.fc1 = layer.fc1 + self.fc2 = layer.fc2 + self.final_layer_norm = layer.final_layer_norm + self.activation_fn = layer.activation_fn + + def forward( + self, + hidden_states: torch.Tensor, + pos_int: int, + cross_k: torch.Tensor, + cross_v: torch.Tensor, + ) -> torch.Tensor: + # Self-attention + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn(hidden_states, pos_int) + hidden_states = residual + hidden_states + + # Cross-attention + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + hidden_states = self.encoder_attn(hidden_states, cross_k, cross_v) + hidden_states = residual + hidden_states + + # FFN + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class WhisperDecoderWithCache(nn.Module): + """ + Whisper decoder wrapper with static KV cache. 
+ + Takes: + - decoder_input_ids: [B, T_dec] token IDs + - cache_position: [1] tensor with start position + - cross_k_tuple: tuple of num_layers tensors [B, H, T_enc, D] - pre-computed cross K + - cross_v_tuple: tuple of num_layers tensors [B, H, T_enc, D] - pre-computed cross V + + Returns: + - logits: [B, T_dec, vocab_size] + """ + + def __init__( + self, + model: "WhisperForConditionalGeneration", + max_decoder_seq_len: int, + ): + super().__init__() + + decoder = model.get_decoder() + dtype = decoder.embed_tokens.weight.dtype + + self.embed_tokens = decoder.embed_tokens + self.embed_positions = decoder.embed_positions + self.layer_norm = decoder.layer_norm + self.proj_out = model.proj_out + + # Wrap decoder layers with cache + self.layers = nn.ModuleList( + [ + WhisperDecoderLayerWithCache(layer, max_decoder_seq_len, dtype) + for layer in decoder.layers + ] + ) + + self.num_layers = len(self.layers) + self.max_decoder_seq_len = max_decoder_seq_len + + def forward( + self, + decoder_input_ids: torch.Tensor, # [B, T_dec] + cache_position: torch.Tensor, # [1] tensor + cross_k_tuple: Tuple[torch.Tensor, ...], # num_layers x [B, H, T_enc, D] + cross_v_tuple: Tuple[torch.Tensor, ...], # num_layers x [B, H, T_enc, D] + ) -> torch.Tensor: + B, T = decoder_input_ids.shape + + # Get position as SymInt + torch._check(cache_position.numel() == 1) + pos_int = cache_position.item() + torch._check(pos_int >= 0) + torch._check(pos_int + T <= self.max_decoder_seq_len) + + # Token + positional embeddings + # Whisper uses absolute positions [pos_int, pos_int + T) + # Use F.embedding to ensure proper lowering (not aten.index.Tensor) + positions = torch.arange( + pos_int, pos_int + T, device=decoder_input_ids.device, dtype=torch.long + ) + hidden_states = self.embed_tokens(decoder_input_ids) + pos_embed = F.embedding(positions, self.embed_positions.weight) + hidden_states = hidden_states + pos_embed + + # Decoder layers + for i, layer in enumerate(self.layers): + hidden_states = 
layer( + hidden_states, pos_int, cross_k_tuple[i], cross_v_tuple[i] + ) + + hidden_states = self.layer_norm(hidden_states) + logits = self.proj_out(hidden_states) + return logits + + +class WhisperCrossKVProjection(nn.Module): + """ + Compute cross-attention K/V projections from encoder hidden states. + + forward(encoder_hidden_states) -> (k_tuple, v_tuple) + """ + + def __init__(self, model: "WhisperForConditionalGeneration"): + super().__init__() + decoder = model.get_decoder() + + # Store K/V projections for each layer + self.k_projs = nn.ModuleList() + self.v_projs = nn.ModuleList() + self.num_heads_list = [] + self.head_dim_list = [] + + for layer in decoder.layers: + self.k_projs.append(layer.encoder_attn.k_proj) + self.v_projs.append(layer.encoder_attn.v_proj) + self.num_heads_list.append(layer.encoder_attn.num_heads) + self.head_dim_list.append(layer.encoder_attn.head_dim) + + def forward( + self, encoder_hidden_states: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]: + """ + Returns: + (k_tuple, v_tuple) where each is a tuple of num_layers tensors, + each with shape [B, H, T_enc, D] + """ + B, T_enc, _ = encoder_hidden_states.shape + + k_list = [] + v_list = [] + + for i, (k_proj, v_proj) in enumerate(zip(self.k_projs, self.v_projs)): + H = self.num_heads_list[i] + D = self.head_dim_list[i] + + k = k_proj(encoder_hidden_states) # [B, T_enc, H*D] + v = v_proj(encoder_hidden_states) + + # Reshape to [B, H, T_enc, D] + k = k.view(B, T_enc, H, D).transpose(1, 2) + v = v.view(B, T_enc, H, D).transpose(1, 2) + + k_list.append(k) + v_list.append(v) + + return tuple(k_list), tuple(v_list) + + +def export_whisper_to_mlx( + model_id: str, + output_dir: str, + max_decoder_seq_len: int = 256, + dtype: str = "bf16", + qlinear: Optional[str] = None, + qembedding: Optional[str] = None, + qlinear_group_size: Optional[int] = None, + qembedding_group_size: Optional[int] = None, +) -> None: + """ + Export Whisper model components to MLX 
delegate. + + Exports: + - encoder.pte: Audio encoder + - cross_kv.pte: Cross-attention K/V projection + - decoder.pte: Decoder with self-attention KV cache + + Args: + model_id: HuggingFace model ID + output_dir: Directory to save .pte files + max_decoder_seq_len: Maximum decoder sequence length + dtype: Model dtype ("fp32", "fp16", "bf16") + qlinear: Quantization config for linear layers ("4w", "8w", "nvfp4", or None) + qembedding: Quantization config for embedding layers ("4w", "8w", "nvfp4", or None) + qlinear_group_size: Group size for linear quantization (default: auto) + qembedding_group_size: Group size for embedding quantization (default: auto) + """ + from transformers import AutoProcessor, WhisperForConditionalGeneration + + # Map dtype string to torch dtype + dtype_map = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, + } + torch_dtype = dtype_map.get(dtype, torch.float32) + + logger.info(f"Loading model: {model_id} (dtype={dtype})") + processor = AutoProcessor.from_pretrained(model_id) + model = WhisperForConditionalGeneration.from_pretrained( + model_id, torch_dtype=torch_dtype + ) + model.eval() + + # Create output directory + os.makedirs(output_dir, exist_ok=True) + + # Get feature extractor info + fe = processor.feature_extractor + batch_size = 1 + + # Create example encoder input + encoder_input = torch.zeros( + (batch_size, fe.feature_size, fe.nb_max_frames), dtype=torch_dtype + ) + + # Create wrappers + logger.info("Creating model wrappers...") + encoder_wrapper = WhisperEncoderExportable(model.get_encoder()).eval() + cross_kv_wrapper = WhisperCrossKVProjection(model).eval() + + # Get encoder output shape for decoder + with torch.no_grad(): + encoder_hidden_states = encoder_wrapper(encoder_input) + encoder_seq_len = encoder_hidden_states.shape[1] + + decoder_wrapper = WhisperDecoderWithCache(model, max_decoder_seq_len).eval() + + # Apply quantization if requested + # Whisper has 3 separate wrappers to quantize, and 
embed_positions must be + # excluded from embedding quantization (accessed via indexing). + if qlinear or qembedding: + from executorch.extension.llm.export.quantize import quantize_model_ + + if qlinear: + logger.info(f"Quantizing linear layers with {qlinear}...") + for module in [encoder_wrapper, cross_kv_wrapper, decoder_wrapper]: + quantize_model_( + module, + qlinear_config=qlinear, + qlinear_group_size=qlinear_group_size, + skip_incompatible_shapes=True, + ) + + if qembedding: + # Custom filter: only embed_tokens, not embed_positions + from executorch.extension.llm.export.quantize import ( + _default_group_size, + _make_embedding_config, + ) + from torchao.quantization.quant_api import quantize_ + + gs = ( + qembedding_group_size + if qembedding_group_size is not None + else _default_group_size(qembedding) + ) + embed_config = _make_embedding_config(qembedding, gs) + logger.info( + f"Quantizing embedding layers with {qembedding} " + f"(group size {gs})..." + ) + quantize_( + decoder_wrapper, + embed_config, + lambda m, fqn: isinstance(m, nn.Embedding) and "embed_tokens" in fqn, + ) + + logger.info("Applied quantization successfully") + + logger.info("Exporting encoder...") + + with torch.no_grad(): + encoder_ep = torch.export.export( + encoder_wrapper, (encoder_input,), dynamic_shapes=None, strict=True + ) + encoder_ep = encoder_ep.run_decompositions({}) + + _save_to_pte(encoder_ep, os.path.join(output_dir, "encoder.pte"), "encoder") + + logger.info("Exporting cross-KV projection...") + + with torch.no_grad(): + example_cross_k, example_cross_v = cross_kv_wrapper(encoder_hidden_states) + example_cross_k = tuple(k.contiguous() for k in example_cross_k) + example_cross_v = tuple(v.contiguous() for v in example_cross_v) + + cross_kv_ep = torch.export.export( + cross_kv_wrapper, + (encoder_hidden_states,), + dynamic_shapes=None, + strict=True, + ) + cross_kv_ep = cross_kv_ep.run_decompositions({}) + + _save_to_pte(cross_kv_ep, os.path.join(output_dir, 
"cross_kv.pte"), "cross_kv") + + logger.info("Exporting decoder...") + + # Example inputs for decoder + start_id = getattr(model.config, "decoder_start_token_id", 0) + decoder_input_ids = torch.tensor([[start_id]], dtype=torch.long) + cache_position = torch.tensor([0], dtype=torch.long) + + with torch.no_grad(): + # Build dynamic shapes for all inputs + # decoder_input_ids: [B, T_dec] - T_dec is dynamic + # cache_position: [1] - static + # cross_k_tuple: tuple of num_layers tensors - static + # cross_v_tuple: tuple of num_layers tensors - static + seq_dim = torch.export.Dim.AUTO(min=1, max=max_decoder_seq_len) + num_layers = decoder_wrapper.num_layers + dynamic_shapes = ( + {1: seq_dim}, # decoder_input_ids + None, # cache_position + tuple(None for _ in range(num_layers)), # cross_k_tuple + tuple(None for _ in range(num_layers)), # cross_v_tuple + ) + + decoder_ep = torch.export.export( + decoder_wrapper, + (decoder_input_ids, cache_position, example_cross_k, example_cross_v), + dynamic_shapes=dynamic_shapes, + strict=True, + ) + decoder_ep = decoder_ep.run_decompositions({}) + + _save_to_pte(decoder_ep, os.path.join(output_dir, "decoder.pte"), "decoder") + + # Save processor for inference + processor_path = os.path.join(output_dir, "processor") + processor.save_pretrained(processor_path) + logger.info(f"Saved processor to: {processor_path}") + + # Save metadata + metadata = { + "model_id": model_id, + "dtype": dtype, + "quantize_linear": qlinear, + "quantize_embeddings": qembedding, + "max_decoder_seq_len": max_decoder_seq_len, + "encoder_seq_len": encoder_seq_len, + "num_decoder_layers": decoder_wrapper.num_layers, + } + import json + + with open(os.path.join(output_dir, "metadata.json"), "w") as f: + json.dump(metadata, f, indent=2) + logger.info(f"Saved metadata to: {os.path.join(output_dir, 'metadata.json')}") + + +def _save_to_pte(ep, output_path: str, name: str) -> None: + """Lower and save an ExportedProgram to a .pte file.""" + import executorch.exir as 
exir + from executorch.backends.mlx import MLXPartitioner + from executorch.exir import EdgeCompileConfig + from executorch.exir.capture._config import ExecutorchBackendConfig + + edge_config = EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ) + + edge_program = exir.to_edge_transform_and_lower( + ep, + transform_passes=get_default_passes(), + partitioner=[MLXPartitioner()], + compile_config=edge_config, + ) + + executorch_program = edge_program.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=True) + ) + + with open(output_path, "wb") as f: + f.write(executorch_program.buffer) + + logger.info( + f"Saved {name} to: {output_path} " + f"({len(executorch_program.buffer) / 1024 / 1024:.2f} MB)" + ) + + +def main(): + parser = argparse.ArgumentParser(description="Export Whisper model to MLX delegate") + from executorch.backends.mlx.examples.whisper.args import add_export_args + + add_export_args(parser) + parser.add_argument( + "--output-dir", + type=str, + default="whisper_mlx", + help="Output directory for .pte files", + ) + + args = parser.parse_args() + + export_whisper_to_mlx( + model_id=args.model_id, + output_dir=args.output_dir, + max_decoder_seq_len=args.max_decoder_seq_len, + dtype=args.dtype, + qlinear=args.qlinear, + qembedding=args.qembedding, + qlinear_group_size=args.qlinear_group_size, + qembedding_group_size=args.qembedding_group_size, + ) + + +if __name__ == "__main__": + main() diff --git a/backends/mlx/examples/whisper/run_whisper.py b/backends/mlx/examples/whisper/run_whisper.py new file mode 100644 index 00000000000..e20e7db6e2b --- /dev/null +++ b/backends/mlx/examples/whisper/run_whisper.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +""" +Run exported Whisper model using ExecuTorch pybindings. + +This script loads the three exported programs (encoder, cross_kv, decoder) +and performs speech-to-text transcription. + +Usage: + python -m executorch.backends.mlx.examples.whisper.run_whisper \ + --model-dir /tmp/whisper_mlx \ + --audio-file /path/to/audio.wav + + # Or use sample audio from HuggingFace: + python -m executorch.backends.mlx.examples.whisper.run_whisper \ + --model-dir /tmp/whisper_mlx \ + --use-sample-audio + +Requirements: + pip install transformers soundfile datasets +""" + +import argparse +import json +import logging +import os +import time +from typing import List, Optional + +import torch + +from executorch.backends.mlx.examples.whisper.args import load_audio + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +def run_whisper_inference( # noqa: C901 + model_dir: str, + audio_path: Optional[str] = None, + use_sample_audio: bool = False, + max_new_tokens: int = 256, + language: str = "en", + task: str = "transcribe", + dtype: str = "bf16", +) -> str: + """ + Run Whisper inference using exported ExecuTorch models. + + Args: + model_dir: Directory containing encoder.pte, cross_kv.pte, decoder.pte + audio_path: Path to audio file (WAV, MP3, etc.) 
+ use_sample_audio: If True, use sample audio from HuggingFace + max_new_tokens: Maximum number of tokens to generate + language: Language code for transcription + task: "transcribe" or "translate" + dtype: Input dtype (must match the dtype used during export) + + Returns: + Transcribed text + """ + from executorch.runtime import Runtime, Verification + from transformers import AutoProcessor + + # Load metadata (for structural info like num_decoder_layers) + metadata_path = os.path.join(model_dir, "metadata.json") + with open(metadata_path, "r") as f: + metadata = json.load(f) + + num_layers = metadata["num_decoder_layers"] + + # Load processor + processor_path = os.path.join(model_dir, "processor") + logger.info(f"Loading processor from: {processor_path}") + processor = AutoProcessor.from_pretrained(processor_path) + + # Load audio + input_features = load_audio(audio_path, use_sample_audio, processor) + logger.info(f"Input features shape: {input_features.shape}") + + # Cast to model dtype + dtype_map = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} + model_dtype = dtype_map.get(dtype, torch.float32) + input_features = input_features.to(model_dtype) + logger.info(f"Input dtype: {input_features.dtype}") + + # Load ExecuTorch programs + et_runtime = Runtime.get() + + logger.info("Loading encoder...") + encoder_path = os.path.join(model_dir, "encoder.pte") + encoder_program = et_runtime.load_program( + encoder_path, verification=Verification.Minimal + ) + encoder_forward = encoder_program.load_method("forward") + + logger.info("Loading cross_kv...") + cross_kv_path = os.path.join(model_dir, "cross_kv.pte") + cross_kv_program = et_runtime.load_program( + cross_kv_path, verification=Verification.Minimal + ) + cross_kv_forward = cross_kv_program.load_method("forward") + + logger.info("Loading decoder...") + decoder_path = os.path.join(model_dir, "decoder.pte") + decoder_program = et_runtime.load_program( + decoder_path, 
verification=Verification.Minimal + ) + decoder_forward = decoder_program.load_method("forward") + + logger.info("Running encoder...") + overall_start = time.time() + start_time = time.time() + + encoder_outputs = encoder_forward.execute([input_features]) + encoder_hidden_states = encoder_outputs[0] + + encoder_time = time.time() - start_time + logger.info(f"Encoder time: {encoder_time:.3f}s") + logger.info(f"Encoder output shape: {encoder_hidden_states.shape}") + + logger.info("Computing cross-attention K/V...") + start_time = time.time() + + cross_kv_outputs = cross_kv_forward.execute([encoder_hidden_states]) + # Output is (k_tuple, v_tuple) flattened: [k0, k1, ..., v0, v1, ...] + # Each k_i, v_i has shape [B, H, T_enc, D] + cross_k_tuple = tuple(cross_kv_outputs[:num_layers]) + cross_v_tuple = tuple(cross_kv_outputs[num_layers:]) + + cross_kv_time = time.time() - start_time + logger.info(f"Cross-KV time: {cross_kv_time:.3f}s") + logger.info(f"Cross K/V: {num_layers} layers, each shape {cross_k_tuple[0].shape}") + + # Get forced decoder IDs for language/task + forced_decoder_ids = processor.get_decoder_prompt_ids( + language=language, + task=task, + ) + # Build forced tokens dict: position -> token_id + forced_tokens_dict = {} + if forced_decoder_ids is not None: + for item in forced_decoder_ids: + if isinstance(item, (list, tuple)) and len(item) == 2: + pos, tok_id = item + if tok_id is not None: + forced_tokens_dict[pos] = int(tok_id) + + # Start with decoder_start_token_id (start-of-transcript) + # Get from processor.tokenizer if available, otherwise use common ID + try: + sot_id = processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>") + except Exception: + sot_id = 50258 # Common Whisper SOT token ID + + # Also get EOS token ID + try: + eos_id = processor.tokenizer.convert_tokens_to_ids("<|endoftext|>") + except Exception: + eos_id = 50257 # Common Whisper EOS token ID + + generated_tokens: List[int] = [sot_id] + + logger.info(f"Generating up to 
{max_new_tokens} tokens...") + decode_start = time.time() + + # Initial decoder input + decoder_input_ids = torch.tensor([[sot_id]], dtype=torch.long) + cache_position = torch.tensor([0], dtype=torch.long) + + # Prefill with initial token + decoder_inputs = ( + [decoder_input_ids, cache_position] + list(cross_k_tuple) + list(cross_v_tuple) + ) + decoder_outputs = decoder_forward.execute(decoder_inputs) + logits = decoder_outputs[0] + + # Update cache position + cache_position = cache_position + decoder_input_ids.shape[1] + + # Generation loop + for _step in range(max_new_tokens): + current_pos = cache_position.item() + + # Check for forced token at this position + if current_pos in forced_tokens_dict: + next_token_id = forced_tokens_dict[current_pos] + else: + next_token_id = torch.argmax(logits[0, -1, :]).item() + + generated_tokens.append(next_token_id) + + # Check for EOS + if next_token_id == eos_id: + break + + # Prepare next decoder input + decoder_input_ids = torch.tensor([[next_token_id]], dtype=torch.long) + + # Run decoder + decoder_inputs = ( + [decoder_input_ids, cache_position] + + list(cross_k_tuple) + + list(cross_v_tuple) + ) + decoder_outputs = decoder_forward.execute(decoder_inputs) + logits = decoder_outputs[0] + + # Update cache position + cache_position = cache_position + 1 + + decode_time = time.time() - decode_start + total_time = time.time() - overall_start + tokens_generated = len(generated_tokens) - 1 # Exclude initial SOT + tokens_per_sec = tokens_generated / decode_time if decode_time > 0 else 0 + + print(f"\nEncoder time: {encoder_time:.3f}s") + print(f"Cross-KV time: {cross_kv_time:.3f}s") + print( + f"Decode time: {decode_time:.3f}s ({tokens_generated} tokens, {tokens_per_sec:.1f} tok/s)" + ) + print(f"Total time: {total_time:.3f}s") + + # Decode to text + transcript = processor.tokenizer.decode( + generated_tokens, + skip_special_tokens=True, + ) + + return transcript + + +def main(): + parser = 
argparse.ArgumentParser(description="Run exported Whisper model")
+    from executorch.backends.mlx.examples.whisper.args import add_run_args
+
+    add_run_args(parser)
+    parser.add_argument(
+        "--model-dir",
+        type=str,
+        default="/tmp/whisper_mlx",
+        help="Directory containing exported .pte files",
+    )
+
+    args = parser.parse_args()
+
+    if not args.audio_file and not args.use_sample_audio:
+        logger.warning("No audio specified. Using --use-sample-audio")
+        args.use_sample_audio = True
+
+    transcript = run_whisper_inference(
+        model_dir=args.model_dir,
+        audio_path=args.audio_file,
+        use_sample_audio=args.use_sample_audio,
+        max_new_tokens=args.max_new_tokens,
+        language=args.language,
+        task=args.task,
+        dtype=args.dtype,
+    )
+
+    print("\n" + "=" * 60)
+    print("Transcript:")
+    print("=" * 60)
+    print(transcript)
+    print("=" * 60)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/backends/mlx/llm/et_attention.py b/backends/mlx/llm/et_attention.py
new file mode 100644
index 00000000000..10c758f94fe
--- /dev/null
+++ b/backends/mlx/llm/et_attention.py
@@ -0,0 +1,252 @@
+#!/usr/bin/env python3
+#
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+MLX-optimized attention for ExecuTorch's Llama attention registry.
+
+Registers an "mlx" attention type that uses mlx::kv_cache_update and
+mlx::custom_sdpa for efficient execution on Apple Silicon.
+
+Usage:
+    import executorch.backends.mlx.llm.et_attention  # noqa: F401
+
+    model_args = ModelArgs(attention_type="mlx", ...)
+ transformer = construct_transformer(model_args) +""" + +from typing import Any, Optional, Tuple, TYPE_CHECKING + +import executorch.backends.mlx.custom_ops as _mlx_custom_ops # noqa: F401 + +import torch +import torch.nn as nn +from executorch.backends.mlx.llm.cache import KVCache +from executorch.examples.models.llama.attention import ( + Attention, + ForwardOptions, + register_attention, +) +from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.norm import RMSNorm +from executorch.examples.models.llama.rope import Rope + +if TYPE_CHECKING: + from executorch.examples.models.llama.attention import AttentionMHA + + +@register_attention("mlx") +class MLXAttentionMHA(Attention): + """ + MLX-optimized attention using mlx::kv_cache_update and mlx::custom_sdpa. + + Supports MHA, GQA, KV caching, and optional QK normalization. + Follows the same interface as AttentionMHA. + """ + + def __init__( + self, + args: ModelArgs, + layer_id: int, + rope: Rope, + **_kwargs: Any, + ): + super().__init__() + if not args.use_kv_cache: + raise ValueError("MLXAttention requires use_kv_cache=True") + + self.use_kv_cache = True + self.n_heads = args.n_heads + self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads + assert self.n_heads % self.n_kv_heads == 0 + model_parallel_size = 1 + self.n_local_heads = self.n_heads // model_parallel_size + self.n_local_kv_heads = self.n_kv_heads // model_parallel_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.head_dim + self.max_batch_size = args.max_batch_size + self.max_context_len = args.max_context_len + self.dim = args.dim + self.attention_qkv_bias = args.attention_qkv_bias + self.use_qk_norm = args.use_qk_norm + self.qk_norm_before_rope = args.qk_norm_before_rope + self.enable_dynamic_shape = args.enable_dynamic_shape + + if self.use_qk_norm: + self.q_norm_fn = RMSNorm(self.head_dim, eps=args.norm_eps) + self.k_norm_fn = 
RMSNorm(self.head_dim, eps=args.norm_eps) + + self.wq = nn.Linear( + self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias + ) + self.wk = nn.Linear( + self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias + ) + self.wv = nn.Linear( + self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias + ) + self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) + + self.layer_id = layer_id + self.rope = rope + self.rope_base = rope.params.rope_freq_base + self.use_fused_rope = self._can_use_fused_rope(rope.params) + self.rope_traditional = not rope.params.use_hf_rope + self.rope_dims = int(self.head_dim * rope.params.partial_rotary_factor) + + self.kv_cache = KVCache( + max_batch_size=args.max_batch_size, + max_context_length=args.max_context_len, + n_heads=self.n_kv_heads, + head_dim=self.head_dim, + enable_dynamic_shape=args.enable_dynamic_shape, + ) + + @staticmethod + def _can_use_fused_rope(params: ModelArgs) -> bool: + if params.no_rope_layer_interval is not None: + return False + return True + + @classmethod + def from_attention_mha( + cls, other: "AttentionMHA", dtype: Optional[torch.dtype] = None + ) -> "MLXAttentionMHA": + """ + Create an MLXAttentionMHA from an existing AttentionMHA. + + Shares weight references (wq, wk, wv, wo, rope, norm) and creates + a fresh KVCache. 
+ """ + from executorch.examples.models.llama.attention import AttentionMHA + + assert isinstance(other, AttentionMHA) + + instance = cls.__new__(cls) + Attention.__init__(instance) + + # Copy all config attributes + instance.use_kv_cache = True + instance.n_heads = other.n_heads + instance.n_kv_heads = other.n_kv_heads + instance.n_local_heads = other.n_local_heads + instance.n_local_kv_heads = other.n_local_kv_heads + instance.n_rep = other.n_rep + instance.head_dim = other.head_dim + instance.max_batch_size = other.max_batch_size + instance.max_context_len = other.max_context_len + instance.dim = other.dim + instance.attention_qkv_bias = other.attention_qkv_bias + instance.use_qk_norm = other.use_qk_norm + instance.qk_norm_before_rope = other.qk_norm_before_rope + instance.enable_dynamic_shape = other.enable_dynamic_shape + + # Share weight references + instance.wq = other.wq + instance.wk = other.wk + instance.wv = other.wv + instance.wo = other.wo + instance.layer_id = other.layer_id + instance.rope = other.rope + instance.rope_base = other.rope.params.rope_freq_base + instance.use_fused_rope = cls._can_use_fused_rope(other.rope.params) + instance.rope_traditional = not other.rope.params.use_hf_rope + instance.rope_dims = int( + instance.head_dim * other.rope.params.partial_rotary_factor + ) + + if other.use_qk_norm: + instance.q_norm_fn = other.q_norm_fn + instance.k_norm_fn = other.k_norm_fn + + # Create fresh MLX KV cache + cache_dtype = dtype if dtype is not None else torch.float32 + if hasattr(other, "kv_cache") and hasattr(other.kv_cache, "k_cache"): + cache_dtype = dtype if dtype is not None else other.kv_cache.k_cache.dtype + instance.kv_cache = KVCache( + max_batch_size=other.max_batch_size, + max_context_length=other.max_context_len, + n_heads=instance.n_kv_heads, + head_dim=instance.head_dim, + enable_dynamic_shape=other.enable_dynamic_shape, + dtype=cache_dtype, + ) + + return instance + + def forward( + self, + x: torch.Tensor, + freqs_cos: 
torch.Tensor, + freqs_sin: torch.Tensor, + **kwargs: ForwardOptions, + ) -> Tuple[torch.Tensor, Optional[Any]]: + input_pos = kwargs.get("input_pos") + assert input_pos is not None + bsz, seqlen, _ = x.shape + + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + + if self.use_qk_norm and self.qk_norm_before_rope: + q = self.q_norm_fn(q) + k = self.k_norm_fn(k) + + if "start_pos" in kwargs: + start_pos = kwargs["start_pos"] + else: + start_pos = input_pos[0].item() + + if self.use_fused_rope: + # Transpose to BHSD first (mlx::rope expects BHSD) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + q = torch.ops.mlx.rope( + q, + self.rope_dims, + start_pos, + self.rope_traditional, + self.rope_base, + 1.0, + None, + ) + k = torch.ops.mlx.rope( + k, + self.rope_dims, + start_pos, + self.rope_traditional, + self.rope_base, + 1.0, + None, + ) + else: + # Fallback: upstream rope (handles scaled rope, partial rotary, etc.) + q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + if self.use_qk_norm and not self.qk_norm_before_rope: + q = self.q_norm_fn(q) + k = self.k_norm_fn(k) + k, v = self.kv_cache.update(start_pos, k, v) + + output = torch.ops.mlx.custom_sdpa( + q, + k, + v, + start_pos=start_pos, + is_causal=True, + scale=self.head_dim**-0.5, + ) + + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + return self.wo(output), None diff --git a/backends/mlx/llm/hf_attention.py b/backends/mlx/llm/hf_attention.py new file mode 100644 index 00000000000..9e3c864dce6 --- /dev/null +++ b/backends/mlx/llm/hf_attention.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +MLX-optimized attention for HuggingFace models. + +Registers a custom attention implementation ("mlx") with HuggingFace's +attention interface, following the same pattern as optimum-executorch's +custom_sdpa: + +1. Mask function returns None (custom op handles causal masking internally) +2. Attention function extracts start_pos from position_ids[0][0] +3. mlx::custom_sdpa receives full K/V cache + start_pos, slices K/V internally +4. MLX pattern handler serializes custom_sdpa as SliceNode(K), SliceNode(V), SdpaNode + +Usage: + from executorch.backends.mlx.llm.hf_attention import register_mlx_attention + + register_mlx_attention() + + model = AutoModelForCausalLM.from_pretrained( + model_id, + attn_implementation="mlx", + ) +""" + +from typing import Callable, Optional, Tuple, Union + +import executorch.backends.mlx.custom_ops as _mlx_custom_ops # noqa: F401 + +import torch + + +def mlx_sdpa_with_start_pos_forward( + module: torch.nn.Module, + query: torch.Tensor, # [B, num_heads, seq_len, head_dim] - BHSD + key: torch.Tensor, # [B, num_kv_heads, kv_len, head_dim] - BHSD (full cache) + value: torch.Tensor, # [B, num_kv_heads, kv_len, head_dim] - BHSD (full cache) + attention_mask: Union[torch.Tensor, "BlockMask"], # noqa: F821 + position_ids: Optional[torch.Tensor] = None, + scaling: Optional[float] = None, + softcap: Optional[float] = None, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, None]: + """ + MLX-optimized SDPA following optimum-executorch's custom_sdpa pattern. + + Extracts start_pos from position_ids, then delegates to mlx::custom_sdpa + which handles K/V cache slicing, GQA expansion, and causal masking. + + Returns (output, None) where output is [B, seq_len, num_heads, head_dim] (BSHD). 
+ """ + kwargs.pop("is_causal", None) + is_causal = getattr(module, "is_causal", True) + + if is_causal: + assert ( + position_ids is not None + ), "position_ids must be provided to find start position for causal attention" + start_pos = position_ids[0][0].item() + seq_len = query.shape[2] + torch._check(start_pos >= 0) + torch._check(start_pos + seq_len <= key.shape[2]) + attn_mask = None + else: + start_pos = 0 + attn_mask = attention_mask + + output = torch.ops.mlx.custom_sdpa( + query, + key, + value, + start_pos=start_pos, + attn_mask=attn_mask, + dropout_p=0.0, + is_causal=is_causal, + scale=scaling, + ) + + # Transpose BHSD → BSHD for HF + return output.transpose(1, 2).contiguous(), None + + +def sdpa_mask_passthrough( + batch_size: int, + cache_position: torch.Tensor, + kv_length: int, + kv_offset: int = 0, + mask_function: Optional[Callable] = None, + attention_mask: Optional[torch.Tensor] = None, + local_size: Optional[int] = None, + allow_is_causal_skip: bool = True, + allow_torch_fix: bool = True, + **kwargs, +) -> Optional[torch.Tensor]: + """Returns None — custom SDPA handles causal masking, avoiding bounded mask tensors.""" + return None + + +def register_mlx_attention(name: str = "mlx") -> None: + """ + Register MLX attention with HuggingFace's attention interfaces. + + After registration, models can use MLX attention via: + model = AutoModelForCausalLM.from_pretrained(..., attn_implementation="mlx") + """ + try: + from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + ALL_ATTENTION_FUNCTIONS.register(name, mlx_sdpa_with_start_pos_forward) + ALL_MASK_ATTENTION_FUNCTIONS.register(name, sdpa_mask_passthrough) + + except ImportError: + raise ImportError( + "transformers is not installed. 
Please install it: pip install transformers" + ) + + +def get_mlx_sliding_window_sdpa(exportable_module) -> Callable: + """ + Create a closure-based SDPA function for sliding window attention. + + Following optimum-executorch's pattern, the returned function captures + the model reference so it can access ring buffer caches at runtime to + create attention masks lazily — avoiding torch.export tracing issues. + + Args: + exportable_module: The model module containing .cache (HFStaticCache + or similar) with ring buffer layers accessible via .kv_cache[layer_idx]. + + Returns: + Attention function compatible with HuggingFace's attention interface. + """ + + def _sliding_window_sdpa_forward( + module: torch.nn.Module, + query: torch.Tensor, # [B, num_heads, seq_len, head_dim] - BHSD + key: torch.Tensor, # [B, num_kv_heads, window_size, head_dim] - BHSD + value: torch.Tensor, # [B, num_kv_heads, window_size, head_dim] - BHSD + attention_mask: Union[torch.Tensor, "BlockMask"], # noqa: F821 + position_ids: Optional[torch.Tensor] = None, + scaling: Optional[float] = None, + **kwargs, + ) -> Tuple[torch.Tensor, None]: + """ + MLX sliding window SDPA using ring buffer KV cache. + + Creates the attention mask lazily by reaching into the ring buffer + cache via the captured model reference. This keeps mask creation + in Python (not in the traced graph). + + Uses is_causal=False since the mask handles both causality and windowing. + """ + from executorch.backends.mlx.llm.cache import RingBufferKVCache + + layer_idx = getattr(module, "layer_idx", None) + seq_len = query.shape[2] + attn_mask = None + start_pos = 0 + + if layer_idx is not None and position_ids is not None: + start_pos = position_ids[0][0].item() + + # Reach into the model's cache to find the ring buffer for this layer. + # TorchExportableModuleWithHybridCache stores .cache (standard path). 
+ cache = getattr(exportable_module, "cache", None) + + if cache is not None: + layer_cache = cache.kv_cache[layer_idx] + if isinstance(layer_cache, RingBufferKVCache): + attn_mask = layer_cache.create_sliding_window_mask( + start_pos, seq_len + ) + # Override start_pos so custom_sdpa slices the full buffer: + # stop_pos = start_pos + seq_len = buffer_size + start_pos = layer_cache.buffer_size - seq_len + + if attn_mask is None: + raise RuntimeError( + f"Sliding window attention at layer {layer_idx} requires a " + f"RingBufferKVCache, but none was found. Ensure the model's " + f"cache is set up with RingBufferKVCache for sliding window layers." + ) + + output = torch.ops.mlx.custom_sdpa( + query, + key, + value, + start_pos=start_pos, + attn_mask=attn_mask, + dropout_p=0.0, + is_causal=False, + scale=scaling, + ) + + # Transpose BHSD → BSHD for HF + return output.transpose(1, 2).contiguous(), None + + return _sliding_window_sdpa_forward + + +def register_mlx_sliding_window_attention( + exportable_module, name: str = "mlx_sliding_window" +) -> None: + """Register MLX sliding window attention with HuggingFace's attention interfaces.""" + try: + from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + sdpa_fn = get_mlx_sliding_window_sdpa(exportable_module) + ALL_ATTENTION_FUNCTIONS.register(name, sdpa_fn) + ALL_MASK_ATTENTION_FUNCTIONS.register(name, sdpa_mask_passthrough) + + except ImportError: + raise ImportError( + "transformers is not installed. Please install it: pip install transformers" + ) diff --git a/backends/mlx/llm/quantization.py b/backends/mlx/llm/quantization.py new file mode 100644 index 00000000000..196e4a9ac1f --- /dev/null +++ b/backends/mlx/llm/quantization.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Quantization argument helpers for MLX LLM export scripts. + +Re-exports quantize_model_ from the shared ExecuTorch LLM export library +and provides add_quantization_args for MLX export CLI scripts. +""" + +import argparse + +from executorch.extension.llm.export.quantize import quantize_model_ + +__all__ = ["add_quantization_args", "quantize_model_"] + + +def add_quantization_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--qlinear", + type=str, + choices=["4w", "8w", "nvfp4"], + default=None, + help="Quantization config for linear layers", + ) + parser.add_argument( + "--qembedding", + type=str, + choices=["4w", "8w", "nvfp4"], + default=None, + help="Quantization config for embedding layers", + ) + parser.add_argument( + "--qlinear-group-size", + type=int, + choices=[32, 64, 128], + default=None, + help="Group size for linear layer quantization (default: 32)", + ) + parser.add_argument( + "--qembedding-group-size", + type=int, + choices=[32, 64, 128], + default=None, + help="Group size for embedding layer quantization (default: 128)", + ) + parser.add_argument( + "--no-tie-word-embeddings", + action="store_true", + default=False, + help="Disable tying lm_head weights to embedding after quantization, " + "even if the model config has tie_word_embeddings=True", + ) diff --git a/backends/mlx/llm/source_transformation.py b/backends/mlx/llm/source_transformation.py new file mode 100644 index 00000000000..d90073c633e --- /dev/null +++ b/backends/mlx/llm/source_transformation.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Source transformations for MLX backend export. 
+ +Provides transforms that replace standard model components with MLX-optimized +versions +""" + +import logging +from typing import Callable + +import torch +import torch.nn as nn + +from executorch.backends.mlx.llm.cache import HFStaticCache, KVCache, RingBufferKVCache + +logger = logging.getLogger(__name__) + + +def _replace_modules( + module: nn.Module, + target_type: type, + factory: Callable[[nn.Module], nn.Module], + label: str, +) -> nn.Module: + """Recursively replace all instances of target_type using factory.""" + + def _recurse(parent: nn.Module) -> int: + count = 0 + for name, child in list(parent.named_children()): + if isinstance(child, target_type): + setattr(parent, name, factory(child)) + count += 1 + else: + count += _recurse(child) + return count + + count = _recurse(module) + if count > 0: + logger.info(f"Replaced {count} {label}") + return module + + +def replace_et_kv_cache_with_mlx( + module: nn.Module, dtype: torch.dtype = None +) -> nn.Module: + """ + Replace ET's KVCache with MLX-optimized KVCache. + + Recursively finds all KVCache instances (from examples/models/llama/attention.py) + and replaces them with KVCache, which uses mlx::kv_cache_update instead of + unsupported index_put operations. + + Args: + module: Model to modify (in place) + dtype: Optional dtype for cache tensors. If None, uses original cache dtype. 
+ """ + try: + from executorch.examples.models.llama.attention import ( + KVCache as ETKVCache_Original, + ) + except ImportError: + return module + + def _make_mlx_cache(child): + cache_dtype = dtype if dtype is not None else child.k_cache.dtype + return KVCache( + max_batch_size=child.max_batch_size, + max_context_length=child.max_context_length, + n_heads=child.n_heads, + head_dim=child.head_dim, + enable_dynamic_shape=child.enable_dynamic_shape, + dtype=cache_dtype, + ) + + return _replace_modules( + module, + ETKVCache_Original, + _make_mlx_cache, + f"KVCache → KVCache (dtype={dtype})", + ) + + +def replace_hf_cache_with_mlx( + module: nn.Module, + config, + max_batch_size: int = 1, + max_cache_len: int | None = None, + dtype: torch.dtype = torch.float32, +) -> nn.Module: + """ + Replace HuggingFace's StaticCache with MLX-optimized HFStaticCache. + + Should be called on TorchExportableModuleWithStaticCache (from + transformers.integrations.executorch), NOT on CausalLMExportableModule + (from optimum-executorch). + + Args: + module: HF exportable module with static_cache or cache attribute + config: HF model config + max_batch_size: Maximum batch size (default: 1) + max_cache_len: Maximum cache length. 
If None, uses config.max_position_embeddings + dtype: Cache tensor dtype (default: torch.float32) + + Raises: + ValueError: If module has no recognized cache attribute + """ + from transformers.cache_utils import StaticCache + + mlx_cache = HFStaticCache( + config=config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + dtype=dtype, + ) + + def _install_cache(attr_name): + setattr(module, attr_name, mlx_cache) + for i, layer_cache in enumerate(mlx_cache.kv_cache): + setattr(module, f"key_cache_{i}", layer_cache.k_cache) + setattr(module, f"value_cache_{i}", layer_cache.v_cache) + + if hasattr(module, "static_cache"): + assert isinstance( + module.static_cache, StaticCache + ), f"Expected StaticCache, got {type(module.static_cache)}" + _install_cache("static_cache") + elif hasattr(module, "cache"): + if isinstance(module.cache, StaticCache): + _install_cache("cache") + else: + raise ValueError( + f"module.cache is not a StaticCache, got {type(module.cache)}" + ) + else: + raise ValueError("Module must have 'static_cache' or 'cache' attribute") + + return module + + +def replace_hf_cache_with_mlx_ring_buffer( + module: nn.Module, + config, + max_batch_size: int = 1, + window_size: int = 512, + dtype: torch.dtype = torch.float32, +) -> nn.Module: + """ + Replace HuggingFace's StaticCache with RingBufferKVCache for sliding window models. + + Creates a HFStaticCache-like structure where each layer uses a RingBufferKVCache + instead of a linear KVCache. This enables infinite-length generation for models + with sliding window attention (e.g., gemma). 
+ + Args: + module: HF exportable module with static_cache or cache attribute + config: HF model config + max_batch_size: Maximum batch size (default: 1) + window_size: Sliding window size (cache capacity per layer) + dtype: Cache tensor dtype + + Raises: + ValueError: If module has no recognized cache attribute + """ + from transformers.cache_utils import StaticCache + + num_layers = config.num_hidden_layers + num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) + head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + + # Create HFStaticCache with ring buffer layers + mlx_cache = HFStaticCache( + config=config, + max_batch_size=max_batch_size, + max_cache_len=window_size, + dtype=dtype, + ) + + # Replace each layer's KVCache with RingBufferKVCache + for i in range(num_layers): + ring_cache = RingBufferKVCache( + max_batch_size=max_batch_size, + max_context_length=window_size, + n_heads=num_kv_heads, + head_dim=head_dim, + dtype=dtype, + ) + mlx_cache.kv_cache[i] = ring_cache + + def _install_cache(attr_name): + setattr(module, attr_name, mlx_cache) + for i, layer_cache in enumerate(mlx_cache.kv_cache): + setattr(module, f"key_cache_{i}", layer_cache.k_cache) + setattr(module, f"value_cache_{i}", layer_cache.v_cache) + + if hasattr(module, "static_cache"): + assert isinstance( + module.static_cache, StaticCache + ), f"Expected StaticCache, got {type(module.static_cache)}" + _install_cache("static_cache") + elif hasattr(module, "cache"): + if isinstance(module.cache, StaticCache): + _install_cache("cache") + else: + raise ValueError( + f"module.cache is not a StaticCache, got {type(module.cache)}" + ) + else: + raise ValueError("Module must have 'static_cache' or 'cache' attribute") + + logger.info( + f"Installed RingBufferKVCache: {num_layers} layers, " + f"window_size={window_size}, heads={num_kv_heads}, head_dim={head_dim}" + ) + + return module + + +class MLXRope(nn.Module): + """ + 
MLX-optimized Rotary Position Embedding. + + Wraps ET's Rope, currently delegating to the original implementation. + Can be extended to use torch.ops.mlx.rope. + """ + + def __init__(self, original_rope: nn.Module): + super().__init__() + self.params = original_rope.params + self.precompute_freqs_cis = original_rope.precompute_freqs_cis + self.apply_rotary_emb = original_rope.apply_rotary_emb + self.register_buffer("freqs_cos", original_rope.freqs_cos, persistent=False) + self.register_buffer("freqs_sin", original_rope.freqs_sin, persistent=False) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + return self.apply_rotary_emb(q, k, freqs_cos, freqs_sin) + + def get_freqs(self, input_pos, seq_len: int): + if self.params.use_kv_cache: + assert input_pos is not None + if self.params.enable_dynamic_shape: + input_pos_item = input_pos[-1].item() + torch._check(input_pos_item >= 0) + torch._check(input_pos_item < self.params.max_context_len) + freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len) + freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len) + else: + freqs_cos = self.freqs_cos[input_pos] + freqs_sin = self.freqs_sin[input_pos] + else: + assert input_pos is None + freqs_cos = self.freqs_cos[:seq_len] + freqs_sin = self.freqs_sin[:seq_len] + return freqs_cos, freqs_sin + + +def transform_attention_mha_to_mlx( + module: nn.Module, dtype: torch.dtype = None +) -> nn.Module: + """ + Replace AttentionMHA with MLXAttentionMHA throughout the model. + + Shares weight references (wq, wk, wv, wo, rope, norm) from the original + and creates a fresh KVCache for each attention layer. + + Args: + module: Model to modify (in place) + dtype: Optional dtype for KV cache. If None, inferred from original. 
+ """ + from executorch.backends.mlx.llm.et_attention import MLXAttentionMHA + from executorch.examples.models.llama.attention import AttentionMHA + + _replace_modules( + module, + AttentionMHA, + lambda child: MLXAttentionMHA.from_attention_mha(child, dtype=dtype), + f"AttentionMHA → MLXAttentionMHA (cache dtype={dtype})", + ) + return module diff --git a/backends/mlx/passes.py b/backends/mlx/passes.py index c7efdf561de..ef4c768a2f8 100644 --- a/backends/mlx/passes.py +++ b/backends/mlx/passes.py @@ -8,13 +8,501 @@ Graph transformation passes for the MLX backend. """ -from typing import List +from dataclasses import dataclass +from typing import List, Optional -from executorch.exir.pass_base import ExportPass +import torch +from executorch.backends.mlx.pattern_utils import ( + extract_lifted_tensor_constant, + match_target, + OpStep, + PatternMatch, + walk_back, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes.cse_pass import CSEPass +from torch.fx import GraphModule, Node def get_default_passes() -> List[ExportPass]: """ Returns a list of passes that are enabled by default for the MLX backend. """ - return [] + return [ + FuseRMSNormPass(), + CanonicalizePermutePass(), + CollapseViewCopyPass(), + CollapsePermutePass(), + CollapseDtypeConversionPass(), + RemoveNoOpsPass(), + CSEPass(), + ] + + +@dataclass +class RMSNormMatch(PatternMatch): + """ + Matched RMSNorm pattern. 
+ + HuggingFace Llama's RMSNorm decomposes into: + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + eps) + return weight * hidden_states.to(input_dtype) + + Graph pattern: + _to_copy (to f32) [optional] + pow(x, 2) + mean_dim(pow_out, [-1], keepdim=True) + add(mean_out, eps_tensor) + rsqrt(add_out) + mul(to_copy_out, rsqrt_out) + _to_copy (back to original dtype) [optional] + mul(weight, to_copy_out) + """ + + input_node: Node = None # type: ignore[assignment] + weight_node: Node = None # type: ignore[assignment] + eps: float = 0.0 + + @classmethod + def maybe_create(cls, head: Node, **context) -> Optional["RMSNormMatch"]: + """Match RMSNorm pattern starting from final mul(weight, normalized).""" + # Head must be mul + if not match_target(head, torch.ops.aten.mul.Tensor): + return None + + if len(head.args) < 2: + return None + + # Try both orderings: mul(weight, normalized) or mul(normalized, weight) + for weight_idx, norm_idx in [(0, 1), (1, 0)]: + weight_node = head.args[weight_idx] + norm_node = head.args[norm_idx] + + if not isinstance(norm_node, Node): + continue + + # Match entire chain with single walk_back: + # [_to_copy] -> mul(input, rsqrt) -> rsqrt -> add -> mean -> pow -> [_to_copy] + # The mul follows arg_index=1 to get rsqrt (not input) + result = walk_back( + norm_node, + [ + OpStep( + op=torch.ops.aten._to_copy.default, + optional=True, + kwargs={ + "dtype", + "layout", + "device", + "pin_memory", + "non_blocking", + "memory_format", + }, + ), + OpStep(op=torch.ops.aten.mul.Tensor, nargs=2, arg_index=1), + OpStep(op=torch.ops.aten.rsqrt.default), + OpStep(op=torch.ops.aten.add.Tensor, nargs=2), + OpStep(op=torch.ops.aten.mean.dim, nargs=(2, 3), kwargs={"dtype"}), + OpStep(op=torch.ops.aten.pow.Tensor_Scalar, nargs=2), + OpStep( + op=torch.ops.aten._to_copy.default, + optional=True, + require_single_user=False, # _to_copy output used by both pow and mul + kwargs={ + "dtype", + "layout", 
+ "device", + "pin_memory", + "non_blocking", + "memory_format", + }, + ), + ], + ) + if result is None: + continue + + original_input, entries = result + to_copy_out, mul, rsqrt, add, mean, pow, to_copy_in = entries + + # If input _to_copy matched, verify it has exactly 2 users: pow and mul + if to_copy_in is not None: + users = set(to_copy_in.users.keys()) + expected_users = {pow, mul} + if users != expected_users: + continue + + # Validate pow exponent is 2 + if pow.args[1] != 2: + continue + + # Extract epsilon from add node (it's a lifted tensor constant) + eps_value = None + for arg in add.args: + eps_value = extract_lifted_tensor_constant(arg) + if eps_value is not None: + break + + if eps_value is None: + continue + + # Build body from non-None entries + body = [n for n in entries if n is not None] + + return cls( + head=head, + body=body, + input_node=original_input, + weight_node=weight_node, + eps=eps_value, + ) + + return None + + +class FuseRMSNormPass(ExportPass): + """ + Fuses decomposed RMSNorm operations into aten.rms_norm. + + This reduces ~7 ops to 1 fused op per RMSNorm layer. 
+ """ + + def call(self, graph_module: GraphModule) -> PassResult: + graph = graph_module.graph + modified = False + + for node in list(graph.nodes): + match = RMSNormMatch.maybe_create(node) + if match is None: + continue + + # Get input shape for normalized_shape + input_meta = match.input_node.meta.get("val") + if input_meta is None: + continue + + # Create fused rms_norm node + with graph.inserting_before(node): + normalized_shape = [input_meta.shape[-1]] + rms_norm_node = graph.call_function( + torch.ops.aten.rms_norm.default, + args=( + match.input_node, + normalized_shape, + match.weight_node, + match.eps, + ), + ) + rms_norm_node.meta = node.meta.copy() + + node.replace_all_uses_with(rms_norm_node) + match.remove_body_nodes(graph) + graph.erase_node(node) + modified = True + + if modified: + graph.eliminate_dead_code() + graph.lint() + + return PassResult(graph_module, modified) + + +class CanonicalizePermutePass(ExportPass): + """ + Converts transpose_copy to permute_copy in the edge dialect graph. + + transpose_copy(x, dim0, dim1) is equivalent to permute_copy(x, perm) + where perm is the identity permutation with dim0 and dim1 swapped. + This lets the backend handle a single permute op instead of both + transpose and permute. 
+ """ + + def call(self, graph_module: GraphModule) -> PassResult: + graph = graph_module.graph + modified = False + + for node in list(graph.nodes): + if ( + node.op != "call_function" + or node.target != exir_ops.edge.aten.transpose_copy.int + ): + continue + + input_node = node.args[0] + input_val = ( + input_node.meta.get("val") if isinstance(input_node, Node) else None + ) + if input_val is None: + continue + + ndim = input_val.dim() + dim0 = node.args[1] + dim1 = node.args[2] + + # Normalize negative dims + if dim0 < 0: + dim0 += ndim + if dim1 < 0: + dim1 += ndim + + # Build permutation: identity with dim0 and dim1 swapped + perm = list(range(ndim)) + perm[dim0], perm[dim1] = perm[dim1], perm[dim0] + + node.target = exir_ops.edge.aten.permute_copy.default + node.args = (input_node, perm) + modified = True + + if modified: + graph.lint() + + return PassResult(graph_module, modified) + + +class CollapseViewCopyPass(ExportPass): + """ + Collapses consecutive view_copy nodes into a single view_copy. + + view_copy(view_copy(x, shape1), shape2) → view_copy(x, shape2) + + Only the final shape matters, so intermediate view_copys can be removed. + """ + + def call(self, graph_module: GraphModule) -> PassResult: + graph = graph_module.graph + modified = False + view_copy_target = exir_ops.edge.aten.view_copy.default + + for node in list(graph.nodes): + if node.op != "call_function" or node.target != view_copy_target: + continue + + parent = node.args[0] + if ( + isinstance(parent, Node) + and parent.op == "call_function" + and parent.target == view_copy_target + and len(parent.users) == 1 + ): + original_input = parent.args[0] + target_shape = node.args[1] + + # Check if final shape matches original input shape (identity). + # Compare meta shapes (not args) so SymInt dims are handled. + # Use try/except because shapes may contain unbacked SymInts + # (e.g. from .item() calls) that can't be guarded on. 
+ original_val = ( + original_input.meta.get("val") + if isinstance(original_input, Node) + else None + ) + output_val = node.meta.get("val") + is_identity = False + if original_val is not None and output_val is not None: + try: + is_identity = original_val.shape == output_val.shape + except Exception: + is_identity = False + if is_identity: + # Identity — remove both view_copys + node.replace_all_uses_with(original_input) + graph.erase_node(node) + graph.erase_node(parent) + else: + # Collapse: view_copy(view_copy(x, s1), s2) → view_copy(x, s2) + node.args = (original_input, target_shape) + graph.erase_node(parent) + modified = True + + if modified: + graph.eliminate_dead_code() + graph.lint() + + return PassResult(graph_module, modified) + + +class CollapsePermutePass(ExportPass): + """ + Collapses consecutive permute_copy nodes into a single permute_copy. + + permute(permute(x, p1), p2) → permute(x, composed) + where composed[i] = p1[p2[i]]. + + If the composed permutation is the identity, the permute is removed entirely. + Must run after CanonicalizePermutePass so all transpose_copy nodes are permute_copy. 
+ """ + + def call(self, graph_module: GraphModule) -> PassResult: + graph = graph_module.graph + modified = False + permute_target = exir_ops.edge.aten.permute_copy.default + + for node in list(graph.nodes): + if node.op != "call_function" or node.target != permute_target: + continue + + parent = node.args[0] + if ( + isinstance(parent, Node) + and parent.op == "call_function" + and parent.target == permute_target + and len(parent.users) == 1 + ): + p1 = parent.args[1] + p2 = node.args[1] + composed = [p1[p2[i]] for i in range(len(p2))] + + if composed == list(range(len(composed))): + # Identity permutation — remove both permutes + node.replace_all_uses_with(parent.args[0]) + graph.erase_node(node) + graph.erase_node(parent) + else: + node.args = (parent.args[0], composed) + graph.erase_node(parent) + + modified = True + + if modified: + graph.eliminate_dead_code() + graph.lint() + + return PassResult(graph_module, modified) + + +def _is_pure_dtype_cast(kwargs: dict) -> bool: + """Check that _to_copy kwargs only specify dtype (no device/layout/memory_format).""" + for k, v in kwargs.items(): + if k == "dtype": + continue + if v is not None: + return False + return "dtype" in kwargs + + +class CollapseDtypeConversionPass(ExportPass): + """ + Collapses consecutive _to_copy (dtype conversion) nodes into a single one. + + _to_copy(dtype=bf16)(_to_copy(dtype=f32)(x)) → _to_copy(dtype=bf16)(x) + + Only the final dtype matters. Only collapses when both nodes are pure dtype + conversions (no device/layout/memory_format changes). 
+ """ + + def call(self, graph_module: GraphModule) -> PassResult: + graph = graph_module.graph + modified = False + to_copy_target = exir_ops.edge.aten._to_copy.default + + for node in list(graph.nodes): + if node.op != "call_function" or node.target != to_copy_target: + continue + + parent = node.args[0] + if not ( + isinstance(parent, Node) + and parent.op == "call_function" + and parent.target == to_copy_target + and len(parent.users) == 1 + ): + continue + + # Only collapse pure dtype conversions + node_kw = node.kwargs + parent_kw = parent.kwargs + if not _is_pure_dtype_cast(node_kw) or not _is_pure_dtype_cast(parent_kw): + continue + + # Rewrite: to_copy(to_copy(x, dtype=d1), dtype=d2) → to_copy(x, dtype=d2) + node.args = (parent.args[0],) + graph.erase_node(parent) + modified = True + + if modified: + graph.eliminate_dead_code() + graph.lint() + + return PassResult(graph_module, modified) + + +class RemoveNoOpsPass(ExportPass): + """ + Removes ops that are no-ops in the MLX backend. 
+ + - alias_copy(x): always a no-op + - clone(x): only when memory_format is contiguous or absent + - _to_copy(x, dtype=d): when x already has dtype d + - view_copy(x, shape): when shape matches input shape + - permute_copy(x, [0,1,...,n-1]): identity permutation + - slice_copy(x, ...): when output shape matches input shape (full slice) + """ + + def call(self, graph_module: GraphModule) -> PassResult: # noqa: C901 + graph = graph_module.graph + modified = False + + for node in list(graph.nodes): + if node.op != "call_function": + continue + + input_node = ( + node.args[0] if node.args and isinstance(node.args[0], Node) else None + ) + if input_node is None: + continue + + remove = False + + if node.target == exir_ops.edge.aten.alias_copy.default: + remove = True + + elif node.target == exir_ops.edge.aten.clone.default: + mem_fmt = node.kwargs.get("memory_format") + if mem_fmt is None or mem_fmt == torch.contiguous_format: + remove = True + + elif node.target == exir_ops.edge.aten._to_copy.default: + if _is_pure_dtype_cast(node.kwargs): + input_val = input_node.meta.get("val") + target_dtype = node.kwargs.get("dtype") + if input_val is not None and input_val.dtype == target_dtype: + remove = True + + elif node.target == exir_ops.edge.aten.view_copy.default: + input_val = input_node.meta.get("val") + output_val = node.meta.get("val") + if input_val is not None and output_val is not None: + try: + if input_val.shape == output_val.shape: + remove = True + except Exception: + pass + + elif node.target == exir_ops.edge.aten.permute_copy.default: + perm = node.args[1] + if list(perm) == list(range(len(perm))): + remove = True + + elif node.target == exir_ops.edge.aten.slice_copy.Tensor: + input_val = input_node.meta.get("val") + output_val = node.meta.get("val") + if input_val is not None and output_val is not None: + try: + if input_val.shape == output_val.shape: + remove = True + except Exception: + pass + + if remove: + node.replace_all_uses_with(input_node) + 
graph.erase_node(node) + modified = True + + if modified: + graph.eliminate_dead_code() + graph.lint() + + return PassResult(graph_module, modified) diff --git a/backends/mlx/test/test_passes.py b/backends/mlx/test/test_passes.py index a9fdb3b996b..97172c1411a 100644 --- a/backends/mlx/test/test_passes.py +++ b/backends/mlx/test/test_passes.py @@ -4,3 +4,662 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + +""" +Tests for graph transformation passes in the MLX backend. +""" + +import unittest + +import executorch.exir as exir +import torch +import torch.nn as nn +from executorch.backends.mlx.partitioner import MLXPartitioner +from executorch.backends.mlx.passes import ( + _is_pure_dtype_cast, + CanonicalizePermutePass, + CollapseDtypeConversionPass, + CollapsePermutePass, + CollapseViewCopyPass, + FuseRMSNormPass, + RemoveNoOpsPass, +) +from executorch.exir import EdgeCompileConfig +from executorch.exir.backend.partitioner import PartitionResult +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export import export + + +class _PreserveOpsPartitioner(MLXPartitioner): + """MLXPartitioner that preserves ops (via ops_to_not_decompose) but skips delegation. + + This gives tests a real edge-dialect graph with MLX-relevant ops like + ``item`` preserved, without delegating nodes to the MLX backend. 
+ """ + + def partition(self, edge_program): + return PartitionResult( + tagged_exported_program=edge_program, + partition_tags={}, + ) + + +def _to_edge_gm(module, example_inputs, dynamic_shapes=None): + """Export module and lower to edge dialect, returning the GraphModule.""" + ep = export(module, example_inputs, dynamic_shapes=dynamic_shapes, strict=False) + edge = exir.to_edge_transform_and_lower( + ep, + partitioner=[_PreserveOpsPartitioner()], + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + ) + return edge.exported_program().graph_module + + +def _count_ops(gm, target): + return sum( + 1 for n in gm.graph.nodes if n.op == "call_function" and n.target == target + ) + + +def _find_nodes(gm, target): + return [n for n in gm.graph.nodes if n.op == "call_function" and n.target == target] + + +def _has_op(gm, target): + return _count_ops(gm, target) > 0 + + +class TestIsPureDtypeCast(unittest.TestCase): + + def test_pure_dtype_only(self): + self.assertTrue(_is_pure_dtype_cast({"dtype": torch.float16})) + + def test_dtype_with_none_kwargs(self): + self.assertTrue( + _is_pure_dtype_cast( + { + "dtype": torch.float16, + "device": None, + "layout": None, + } + ) + ) + + def test_dtype_with_non_none_memory_format(self): + self.assertFalse( + _is_pure_dtype_cast( + { + "dtype": torch.float16, + "memory_format": torch.contiguous_format, + } + ) + ) + + def test_dtype_with_non_none_device(self): + self.assertFalse( + _is_pure_dtype_cast( + { + "dtype": torch.float16, + "device": torch.device("cpu"), + } + ) + ) + + def test_no_dtype_key(self): + self.assertFalse(_is_pure_dtype_cast({"device": None})) + + def test_empty_kwargs(self): + self.assertFalse(_is_pure_dtype_cast({})) + + +class TestCanonicalizePermutePass(unittest.TestCase): + + def test_transpose_becomes_permute(self): + class M(nn.Module): + def forward(self, x): + return x.transpose(0, 1) + + gm = _to_edge_gm(M(), (torch.randn(3, 4),)) + transpose_target = 
exir_ops.edge.aten.transpose_copy.int + + if not _has_op(gm, transpose_target): + self.skipTest("Edge lowering did not produce transpose_copy") + + result = CanonicalizePermutePass()(gm) + + self.assertTrue(result.modified) + self.assertFalse(_has_op(result.graph_module, transpose_target)) + self.assertTrue( + _has_op(result.graph_module, exir_ops.edge.aten.permute_copy.default) + ) + + nodes = _find_nodes( + result.graph_module, exir_ops.edge.aten.permute_copy.default + ) + self.assertEqual(nodes[0].args[1], [1, 0]) + + def test_negative_dims_normalized(self): + class M(nn.Module): + def forward(self, x): + return x.transpose(-2, -1) + + gm = _to_edge_gm(M(), (torch.randn(2, 3, 4),)) + result = CanonicalizePermutePass()(gm) + + nodes = _find_nodes( + result.graph_module, exir_ops.edge.aten.permute_copy.default + ) + self.assertEqual(len(nodes), 1) + # transpose(-2, -1) on 3D → [0, 2, 1] + self.assertEqual(nodes[0].args[1], [0, 2, 1]) + + def test_noop_when_no_transpose(self): + class M(nn.Module): + def forward(self, x): + return x + 1 + + gm = _to_edge_gm(M(), (torch.randn(3, 4),)) + result = CanonicalizePermutePass()(gm) + self.assertFalse(result.modified) + + +class TestCollapseViewCopyPass(unittest.TestCase): + + def test_consecutive_view_copys_collapsed(self): + """view_copy(view_copy(x, s1), s2) → view_copy(x, s2).""" + + class M(nn.Module): + def forward(self, x): + return x.view(2, 6).view(3, 4) + + gm = _to_edge_gm(M(), (torch.randn(12),)) + + target = exir_ops.edge.aten.view_copy.default + before = _count_ops(gm, target) + self.assertGreaterEqual(before, 2) + + result = CollapseViewCopyPass()(gm) + + self.assertTrue(result.modified) + self.assertEqual(_count_ops(result.graph_module, target), 1) + + def test_identity_view_copy_chain_removed(self): + """view_copy(view_copy(x, s1), original_shape) → removes both.""" + + class M(nn.Module): + def forward(self, x): + return x.view(12).view(3, 4) + + gm = _to_edge_gm(M(), (torch.randn(3, 4),)) + + result = 
CollapseViewCopyPass()(gm) + + self.assertTrue(result.modified) + self.assertEqual( + _count_ops(result.graph_module, exir_ops.edge.aten.view_copy.default), 0 + ) + + def test_single_view_copy_unchanged(self): + class M(nn.Module): + def forward(self, x): + return x.view(12) + + gm = _to_edge_gm(M(), (torch.randn(3, 4),)) + + result = CollapseViewCopyPass()(gm) + self.assertFalse(result.modified) + + def test_collapse_with_dynamic_batch(self): + """Consecutive view_copys with a dynamic leading dim should collapse.""" + from torch.export import Dim + + class M(nn.Module): + def forward(self, x): + return x.view(-1, 3, 4).view(-1, 2, 6) + + batch = Dim("batch", min=1, max=128) + gm = _to_edge_gm( + M(), + (torch.randn(4, 12),), + dynamic_shapes={"x": {0: batch}}, + ) + + target = exir_ops.edge.aten.view_copy.default + before = _count_ops(gm, target) + self.assertGreaterEqual(before, 2) + + result = CollapseViewCopyPass()(gm) + + self.assertTrue(result.modified) + self.assertEqual(_count_ops(result.graph_module, target), 1) + + def test_identity_chain_with_dynamic_batch(self): + """view_copy(view_copy(x, s1), original_shape) with dynamic dim → both removed.""" + from torch.export import Dim + + class M(nn.Module): + def forward(self, x): + return x.view(-1, 3, 4).view(-1, 12) + + batch = Dim("batch", min=1, max=128) + gm = _to_edge_gm( + M(), + (torch.randn(4, 12),), + dynamic_shapes={"x": {0: batch}}, + ) + + target = exir_ops.edge.aten.view_copy.default + before = _count_ops(gm, target) + self.assertGreaterEqual(before, 2) + + result = CollapseViewCopyPass()(gm) + self.assertTrue(result.modified) + # Meta-shape comparison resolves SymInt identity → both view_copys removed + self.assertEqual(_count_ops(result.graph_module, target), 0) + + +class TestCollapsePermutePass(unittest.TestCase): + + def test_inverse_permutations_removed(self): + """permute(permute(x, p), inverse(p)) → identity → removed.""" + + class M(nn.Module): + def forward(self, x): + return 
x.permute(2, 0, 1).permute(1, 2, 0) + + gm = _to_edge_gm(M(), (torch.randn(2, 3, 4),)) + target = exir_ops.edge.aten.permute_copy.default + self.assertEqual(_count_ops(gm, target), 2) + + result = CollapsePermutePass()(gm) + + self.assertTrue(result.modified) + self.assertEqual(_count_ops(result.graph_module, target), 0) + + def test_non_identity_composed(self): + """Non-identity composition yields a single permute.""" + + class M(nn.Module): + def forward(self, x): + return x.permute(1, 0, 2).permute(0, 2, 1) + + gm = _to_edge_gm(M(), (torch.randn(2, 3, 4),)) + target = exir_ops.edge.aten.permute_copy.default + self.assertEqual(_count_ops(gm, target), 2) + + result = CollapsePermutePass()(gm) + + self.assertTrue(result.modified) + self.assertEqual(_count_ops(result.graph_module, target), 1) + + # composed[i] = p1[p2[i]] where p1=[1,0,2], p2=[0,2,1] + # → [1, 2, 0] + nodes = _find_nodes(result.graph_module, target) + self.assertEqual(nodes[0].args[1], [1, 2, 0]) + + def test_single_permute_unchanged(self): + class M(nn.Module): + def forward(self, x): + return x.permute(1, 0, 2) + + gm = _to_edge_gm(M(), (torch.randn(2, 3, 4),)) + result = CollapsePermutePass()(gm) + self.assertFalse(result.modified) + + def test_multi_user_parent_not_collapsed(self): + """Don't collapse when the parent permute has multiple users.""" + + class M(nn.Module): + def forward(self, x): + y = x.permute(1, 0, 2) + a = y.permute(1, 0, 2) + b = y.sum() + return a + b + + gm = _to_edge_gm(M(), (torch.randn(2, 3, 4),)) + result = CollapsePermutePass()(gm) + # Parent permute has 2 users → should not be collapsed + self.assertFalse(result.modified) + + +class TestCollapseDtypeConversionPass(unittest.TestCase): + + def test_consecutive_casts_collapsed(self): + """_to_copy(f32→bf16→f16) → _to_copy(f32→f16).""" + + class M(nn.Module): + def forward(self, x): + return x.to(torch.bfloat16).to(torch.float16) + + gm = _to_edge_gm(M(), (torch.randn(4, 4),)) + target = 
exir_ops.edge.aten._to_copy.default + before = _count_ops(gm, target) + + if before < 2: + self.skipTest("Export optimized away double cast") + + result = CollapseDtypeConversionPass()(gm) + + self.assertTrue(result.modified) + self.assertEqual(_count_ops(result.graph_module, target), 1) + + # Remaining cast should be to float16 + nodes = _find_nodes(result.graph_module, target) + self.assertEqual(nodes[0].kwargs.get("dtype"), torch.float16) + + def test_single_cast_unchanged(self): + class M(nn.Module): + def forward(self, x): + return x.to(torch.float16) + + gm = _to_edge_gm(M(), (torch.randn(4, 4),)) + result = CollapseDtypeConversionPass()(gm) + self.assertFalse(result.modified) + + +class TestRemoveNoOpsPass(unittest.TestCase): + + def test_remove_clone(self): + class M(nn.Module): + def forward(self, x): + return x.clone() + + gm = _to_edge_gm(M(), (torch.randn(4, 4),)) + target = exir_ops.edge.aten.clone.default + + if not _has_op(gm, target): + self.skipTest("Export did not produce a clone op") + + result = RemoveNoOpsPass()(gm) + + self.assertTrue(result.modified) + self.assertFalse(_has_op(result.graph_module, target)) + + def test_remove_identity_view_copy(self): + """view_copy(x, same_shape) → removed.""" + + class M(nn.Module): + def forward(self, x): + return x.view(3, 4) + + gm = _to_edge_gm(M(), (torch.randn(3, 4),)) + target = exir_ops.edge.aten.view_copy.default + + if not _has_op(gm, target): + self.skipTest("Export optimized away identity view_copy") + + result = RemoveNoOpsPass()(gm) + + self.assertTrue(result.modified) + self.assertFalse(_has_op(result.graph_module, target)) + + def test_remove_identity_permute(self): + """permute_copy(x, [0, 1, ..., n-1]) → removed.""" + + class M(nn.Module): + def forward(self, x): + return x.permute(0, 1, 2) + + gm = _to_edge_gm(M(), (torch.randn(2, 3, 4),)) + target = exir_ops.edge.aten.permute_copy.default + + if not _has_op(gm, target): + self.skipTest("Export optimized away identity permute") + + result 
= RemoveNoOpsPass()(gm) + + self.assertTrue(result.modified) + self.assertFalse(_has_op(result.graph_module, target)) + + def test_identity_dtype_cast_removed_after_collapse(self): + """Chain: f32→f16→f32 collapses to f32→f32, then RemoveNoOps removes it.""" + + class M(nn.Module): + def forward(self, x): + return x.to(torch.float16).to(torch.float32) + + gm = _to_edge_gm(M(), (torch.randn(4, 4),)) + target = exir_ops.edge.aten._to_copy.default + + if _count_ops(gm, target) < 2: + self.skipTest("Export optimized away double cast") + + CollapseDtypeConversionPass()(gm) + result = RemoveNoOpsPass()(gm) + + self.assertTrue(result.modified) + self.assertEqual(_count_ops(result.graph_module, target), 0) + + def test_to_copy_with_memory_format_not_removed(self): + """_is_pure_dtype_cast rejects kwargs with non-None memory_format.""" + # Can't easily produce this through export, so test the guard directly + self.assertFalse( + _is_pure_dtype_cast( + { + "dtype": torch.float32, + "memory_format": torch.contiguous_format, + } + ) + ) + + def test_non_identity_view_copy_kept(self): + """view_copy to a different shape should NOT be removed.""" + + class M(nn.Module): + def forward(self, x): + return x.view(6, 2) + + gm = _to_edge_gm(M(), (torch.randn(3, 4),)) + + result = RemoveNoOpsPass()(gm) + self.assertFalse(result.modified) + + def test_noop_when_nothing_to_remove(self): + class M(nn.Module): + def forward(self, x): + return x + 1 + + gm = _to_edge_gm(M(), (torch.randn(3, 4),)) + result = RemoveNoOpsPass()(gm) + self.assertFalse(result.modified) + + def test_identity_view_copy_with_dynamic_batch(self): + """view_copy(x, same_shape) with a dynamic dim → removed via meta-shape comparison.""" + from torch.export import Dim + + class M(nn.Module): + def forward(self, x): + return x.view(-1, 4) + + batch = Dim("batch", min=1, max=128) + gm = _to_edge_gm( + M(), + (torch.randn(4, 4),), + dynamic_shapes={"x": {0: batch}}, + ) + + target = exir_ops.edge.aten.view_copy.default + 
if not _has_op(gm, target): + self.skipTest("Export optimized away identity view_copy") + + result = RemoveNoOpsPass()(gm) + self.assertTrue(result.modified) + self.assertFalse(_has_op(result.graph_module, target)) + + def test_non_identity_view_copy_with_dynamic_batch(self): + """view_copy(x, different_shape) with dynamic dim should be kept.""" + from torch.export import Dim + + class M(nn.Module): + def forward(self, x): + return x.view(-1, 2, 2) + + batch = Dim("batch", min=1, max=128) + gm = _to_edge_gm( + M(), + (torch.randn(4, 4),), + dynamic_shapes={"x": {0: batch}}, + ) + + target = exir_ops.edge.aten.view_copy.default + if not _has_op(gm, target): + self.skipTest("Export did not produce view_copy") + + result = RemoveNoOpsPass()(gm) + # Shape changes, so view_copy should be kept + self.assertFalse(result.modified) + + def test_full_slice_with_dynamic_batch(self): + """slice_copy shape comparison with dynamic dim should not crash.""" + from torch.export import Dim + + class M(nn.Module): + def forward(self, x): + a = x[:, :4] + b = x[:, 4:] + return torch.cat([b, a], dim=1) + + batch = Dim("batch", min=1, max=128) + gm = _to_edge_gm( + M(), + (torch.randn(4, 8),), + dynamic_shapes={"x": {0: batch}}, + ) + + target = exir_ops.edge.aten.slice_copy.Tensor + self.assertTrue(_has_op(gm, target), "Expected slice_copy in the graph") + + # Must not crash with symbolic shapes (input_val.shape has SymInt) + RemoveNoOpsPass()(gm) + + +class TestFuseRMSNormPass(unittest.TestCase): + + def test_rms_norm_fused(self): + """Decomposed RMSNorm should be fused into a single aten.rms_norm op.""" + + class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(dim)) + self.eps = eps + + def forward(self, x): + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + return self.weight * x + + model = RMSNorm(16) + model.eval() + gm = _to_edge_gm(model, (torch.randn(1, 4, 16),)) + + 
result = FuseRMSNormPass()(gm) + + self.assertTrue( + result.modified, "FuseRMSNormPass should fuse the RMSNorm pattern" + ) + + has_rms_norm = any( + n.op == "call_function" and "rms_norm" in str(n.target) + for n in result.graph_module.graph.nodes + ) + self.assertTrue(has_rms_norm) + + # Intermediate ops (pow, rsqrt, mean) should be removed + has_rsqrt = any( + n.op == "call_function" and "rsqrt" in str(n.target) + for n in result.graph_module.graph.nodes + ) + self.assertFalse(has_rsqrt) + + def test_noop_on_non_rms_norm(self): + class M(nn.Module): + def forward(self, x): + return x + 1 + + ep = export(M(), (torch.randn(4, 4),), strict=False) + result = FuseRMSNormPass()(ep.graph_module) + self.assertFalse(result.modified) + + +class TestPassComposition(unittest.TestCase): + + def test_collapse_view_copy(self): + class M(nn.Module): + def forward(self, x): + return x.view(2, 6).view(3, 4) + + gm = _to_edge_gm(M(), (torch.randn(12),)) + target = exir_ops.edge.aten.view_copy.default + + self.assertGreaterEqual(_count_ops(gm, target), 2) + + CollapseViewCopyPass()(gm) + self.assertEqual(_count_ops(gm, target), 1) + + def test_canonicalize_then_collapse_permute_identity(self): + """Double transpose = identity → both removed.""" + + class M(nn.Module): + def forward(self, x): + return x.transpose(0, 1).transpose(0, 1) + + gm = _to_edge_gm(M(), (torch.randn(3, 4),)) + target = exir_ops.edge.aten.permute_copy.default + + CanonicalizePermutePass()(gm) + self.assertEqual(_count_ops(gm, target), 2) + + CollapsePermutePass()(gm) + self.assertEqual(_count_ops(gm, target), 0) + + def test_full_pipeline_does_not_crash(self): + """Running the full default pass list should not crash.""" + from executorch.backends.mlx.passes import get_default_passes + + class M(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(16, 16) + + def forward(self, x): + return self.linear(x).to(torch.float16) + + gm = _to_edge_gm(M(), (torch.randn(1, 16),)) + + for p in 
get_default_passes(): + p(gm) + + gm.graph.lint() + + def test_correctness_after_all_passes(self): + """Output values should be preserved after running all passes.""" + from executorch.backends.mlx.passes import get_default_passes + + class M(nn.Module): + def forward(self, x): + y = x.reshape(12).reshape(3, 4) + return y.transpose(0, 1) + + module = M() + module.eval() + x = torch.randn(3, 4) + expected = module(x) + + gm = _to_edge_gm(module, (x,)) + + for p in get_default_passes(): + p(gm) + + actual = gm(x) + # Edge graph modules may return a tuple + if isinstance(actual, tuple): + actual = actual[0] + torch.testing.assert_close(actual, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/llama/CMakeLists.txt b/examples/models/llama/CMakeLists.txt index a61d43f626e..6d5b5cc2566 100644 --- a/examples/models/llama/CMakeLists.txt +++ b/examples/models/llama/CMakeLists.txt @@ -107,8 +107,13 @@ else() endif() # quantized_ops_lib: Register quantized op kernels into the runtime -executorch_target_link_options_shared_lib(quantized_ops_lib) -list(APPEND link_libraries quantized_kernels quantized_ops_lib) +if(TARGET quantized_ops_lib) + list(APPEND link_libraries quantized_kernels quantized_ops_lib) + get_target_property(_quantized_imported quantized_ops_lib IMPORTED) + if(NOT _quantized_imported) + executorch_target_link_options_shared_lib(quantized_ops_lib) + endif() +endif() if(TARGET custom_ops) executorch_target_link_options_shared_lib(custom_ops) @@ -198,6 +203,12 @@ if(TARGET mpsdelegate) executorch_target_link_options_shared_lib(mpsdelegate) endif() +# MLX backend +if(TARGET mlxdelegate) + list(APPEND link_libraries mlxdelegate mlx) + executorch_target_link_options_shared_lib(mlxdelegate) +endif() + # Openvino backend if(TARGET openvino_backend) find_package(OpenVINO REQUIRED) @@ -226,6 +237,11 @@ endif() add_executable(llama_main ${_srcs}) +# Copy MLX metallib for runtime if MLX delegate is enabled +if(TARGET mlxdelegate) + 
executorch_target_copy_mlx_metallib(llama_main) +endif() + # Only strip symbols for Release and MinSizeRel builds. if(CMAKE_BUILD_TYPE STREQUAL "Release" OR CMAKE_BUILD_TYPE STREQUAL "MinSizeRel" diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 7d6371add44..557a6f490e0 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -502,6 +502,12 @@ def build_args_parser() -> argparse.ArgumentParser: help="Specify the device for Openvino (CPU, GPU or NPU).", ) + parser.add_argument( + "--mlx", + action="store_true", + help="Delegate to MLX backend (Apple Silicon). Use with --use_kv_cache=True.", + ) + parser.add_argument( "--expand_rope_table", default=False, @@ -774,6 +780,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager: coreml=llm_config.backend.coreml.enabled, coreml_ios=llm_config.backend.coreml.ios, vulkan=llm_config.backend.vulkan.enabled, + mlx=llm_config.backend.mlx.enabled, use_qat=llm_config.quantization.use_qat, use_lora=llm_config.base.use_lora, preq_mode=( @@ -1052,6 +1059,34 @@ def _to_edge_and_lower_llama_arm( return builder.to_executorch(passes=additional_passes) +def _to_edge_and_lower_llama_mlx( + builder_exported, + modelname, + quantizers, + additional_passes, + verbose: bool = False, +) -> LLMEdgeManager: + """ + Lower Llama model to MLX backend using to_edge_transform_and_lower. 
+ """ + logging.info("Lowering model using MLX partitioner") + + from executorch.backends.mlx.partitioner import MLXPartitioner + from executorch.backends.mlx.passes import get_default_passes + + partitioners = [MLXPartitioner()] + + builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower( + partitioners, + transform_passes=get_default_passes(), + ) + + if verbose: + print_delegation_info(builder.edge_manager.exported_program().graph_module) + + return builder.to_executorch(passes=additional_passes) + + def _to_edge_and_lower_llama( # noqa: C901 builder_exported, modelname, @@ -1428,6 +1463,14 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 llm_config, verbose=llm_config.debug.verbose, ) + elif llm_config.backend.mlx.enabled: + builder = _to_edge_and_lower_llama_mlx( + builder_exported, + modelname, + quantizers, + additional_passes, + verbose=llm_config.debug.verbose, + ) else: builder = _to_edge_and_lower_llama( builder_exported, @@ -1612,6 +1655,7 @@ def _get_source_transforms( # noqa coreml: bool = False, coreml_ios: int = 15, vulkan: bool = False, + mlx: bool = False, use_qat: bool = False, use_lora: int = 0, preq_mode: Optional[str] = None, @@ -1789,6 +1833,19 @@ def _get_source_transforms( # noqa transforms.append(replace_sdpa_with_simple_sdpa) transforms.append(replace_kv_cache_with_coreml_kv_cache) + elif mlx: + from executorch.backends.mlx.llm.source_transformation import ( + replace_et_kv_cache_with_mlx, + transform_attention_mha_to_mlx, + ) + from executorch.examples.models.llama.source_transformation.rms_norm import ( + replace_rms_norm_with_native_rms_norm, + ) + + transforms.append(transform_attention_mha_to_mlx) + transforms.append(replace_et_kv_cache_with_mlx) + transforms.append(replace_rms_norm_with_native_rms_norm) + if local_global_attention: transforms.append( partial( diff --git a/examples/models/parakeet/CMakeLists.txt b/examples/models/parakeet/CMakeLists.txt index 
9354afe5f86..143c4ebe77d 100644 --- a/examples/models/parakeet/CMakeLists.txt +++ b/examples/models/parakeet/CMakeLists.txt @@ -23,6 +23,7 @@ find_package(gflags REQUIRED) # Find executorch libraries list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..) + find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH) get_target_property(_executorch_imported executorch IMPORTED) if(NOT _executorch_imported) @@ -42,9 +43,14 @@ endif() # CPU-only builds need quantized and custom ops if(NOT EXECUTORCH_BUILD_CUDA) - list(APPEND link_libraries quantized_ops_lib custom_ops) - executorch_target_link_options_shared_lib(quantized_ops_lib) - executorch_target_link_options_shared_lib(custom_ops) + if(TARGET quantized_ops_lib) + list(APPEND link_libraries quantized_ops_lib) + executorch_target_link_options_shared_lib(quantized_ops_lib) + endif() + if(TARGET custom_ops) + list(APPEND link_libraries custom_ops) + executorch_target_link_options_shared_lib(custom_ops) + endif() endif() # XNNPACK @@ -91,6 +97,12 @@ if(EXECUTORCH_BUILD_METAL) executorch_target_link_options_shared_lib(metal_backend) endif() +# Link MLX delegate +if(TARGET mlxdelegate) + list(APPEND link_libraries mlxdelegate mlx) + executorch_target_link_options_shared_lib(mlxdelegate) +endif() + if(EXECUTORCH_BUILD_VULKAN) list(APPEND link_libraries vulkan_backend) executorch_target_link_options_shared_lib(vulkan_backend) @@ -104,6 +116,11 @@ if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") endif() endif() +# Copy MLX metallib for runtime if MLX delegate is enabled +if(TARGET mlxdelegate) + executorch_target_copy_mlx_metallib(parakeet_runner) +endif() + target_include_directories( parakeet_runner PUBLIC ${_common_include_directories} ) diff --git a/examples/models/parakeet/CMakePresets.json b/examples/models/parakeet/CMakePresets.json index afcfd99491c..87ace61e315 100644 --- a/examples/models/parakeet/CMakePresets.json +++ b/examples/models/parakeet/CMakePresets.json @@ -56,6 +56,19 @@ "rhs": "Darwin" } }, + 
{ + "name": "parakeet-mlx", + "displayName": "Parakeet runner (MLX)", + "inherits": ["parakeet-base"], + "cacheVariables": { + "EXECUTORCH_BUILD_MLX": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Darwin" + } + }, { "name": "parakeet-vulkan", "displayName": "Parakeet runner (Vulkan)", @@ -99,6 +112,13 @@ "configuration": "Release", "targets": ["parakeet_runner"] }, + { + "name": "parakeet-mlx", + "displayName": "Build Parakeet runner (MLX)", + "configurePreset": "parakeet-mlx", + "configuration": "Release", + "targets": ["parakeet_runner"] + }, { "name": "parakeet-vulkan", "displayName": "Build Parakeet runner (Vulkan)", @@ -164,6 +184,20 @@ } ] }, + { + "name": "parakeet-mlx", + "displayName": "Configure and build Parakeet runner (MLX)", + "steps": [ + { + "type": "configure", + "name": "parakeet-mlx" + }, + { + "type": "build", + "name": "parakeet-mlx" + } + ] + }, { "name": "parakeet-vulkan", "displayName": "Configure and build Parakeet runner (Vulkan)", diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md index d893b8e18b0..512e2796e63 100644 --- a/examples/models/parakeet/README.md +++ b/examples/models/parakeet/README.md @@ -25,7 +25,7 @@ python export_parakeet_tdt.py --audio /path/to/audio.wav | Argument | Description | |----------|-------------| | `--output-dir` | Output directory for exports (default: `./parakeet_tdt_exports`) | -| `--backend` | Backend for acceleration: `portable`, `xnnpack`, `metal`, `cuda`, `cuda-windows` (default: `xnnpack`) | +| `--backend` | Backend for acceleration: `portable`, `xnnpack`, `metal`, `mlx`, `cuda`, `cuda-windows` (default: `xnnpack`) | | `--dtype` | Data type: `fp32`, `bf16`, `fp16` (default: `fp32`). Metal backend supports `fp32` and `bf16` only (no `fp16`). 
| | `--audio` | Path to audio file for transcription test | @@ -39,24 +39,25 @@ The export script supports quantizing encoder and decoder linear layers using [t | Argument | Description | |----------|-------------| -| `--qlinear_encoder` | Quantization config for encoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w` | -| `--qlinear_encoder_group_size` | Group size for encoder linear quantization (default: 32) | +| `--qlinear_encoder` | Quantization config for encoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w`, `nvfp4` | +| `--qlinear_encoder_group_size` | Group size for encoder linear quantization (default: auto) | | `--qlinear_encoder_packing_format` | Packing format for encoder: `tile_packed_to_4d` | -| `--qlinear` | Quantization config for decoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w` | -| `--qlinear_group_size` | Group size for decoder linear quantization (default: 32) | +| `--qlinear` | Quantization config for decoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w`, `nvfp4` | +| `--qlinear_group_size` | Group size for decoder linear quantization (default: auto) | | `--qlinear_packing_format` | Packing format for decoder: `tile_packed_to_4d` | -| `--qembedding` | Quantization config for decoder embedding layer: `4w`, `8w` | -| `--qembedding_group_size` | Group size for embedding quantization (default: 0 = per-axis) | +| `--qembedding` | Quantization config for decoder embedding layer: `4w`, `8w`, `nvfp4` | +| `--qembedding_group_size` | Group size for embedding quantization (default: auto) | #### Quantization Configs | Config | Description | Backends | |--------|-------------|----------| -| `4w` | 4-bit weight only quantization | CUDA | -| `8w` | 8-bit weight only quantization | CUDA | -| `8da4w` | 8-bit dynamic activation, 4-bit weight | CUDA | -| `8da8w` | 8-bit dynamic activation, 8-bit weight | CUDA | +| `4w` | 4-bit weight only quantization | CUDA, MLX, XNNPACK (embedding only) | +| `8w` | 8-bit weight only quantization 
| CUDA, MLX, XNNPACK (embedding only) | +| `8da4w` | 8-bit dynamic activation, 4-bit weight | XNNPACK | +| `8da8w` | 8-bit dynamic activation, 8-bit weight | XNNPACK | | `fpa4w` | Floating point activation, 4-bit weight | Metal | +| `nvfp4` | 4-bit weight only quantization using NVIDIA's FP4 dtype | MLX | #### Example: Dynamic Quantization for XNNPACK @@ -171,6 +172,36 @@ python export_parakeet_tdt.py --backend cuda-windows --output-dir ./parakeet_cud This generates: - `model.pte` - The compiled Parakeet TDT model - `aoti_cuda_blob.ptd` - CUDA kernel blob required at runtime + +### MLX Export + +Export with MLX backend (bf16, int4 quantized, group size 128): +```bash +python export_parakeet_tdt.py \ + --backend mlx \ + --dtype bf16 \ + --qlinear_encoder 4w \ + --qlinear_encoder_group_size 128 \ + --qlinear 4w \ + --qlinear_group_size 128 \ + --output-dir ./parakeet_mlx_4w +``` + +Export with MLX backend (bf16, NVFP4 quantized): +```bash +python export_parakeet_tdt.py \ + --backend mlx \ + --dtype bf16 \ + --qlinear_encoder nvfp4 \ + --qlinear nvfp4 \ + --qembedding 4w \ + --output-dir ./parakeet_mlx_nvfp4 +``` + +> **Note:** Although MLX supports NVFP4 embedding quantization, Parakeet's embedding layer has dimensions not divisible by 16, which is incompatible with NVFP4. Use `4w` for embeddings instead. 
+ +This generates: +- `model.pte` - The compiled model with MLX delegate (~470 MB) - `tokenizer.model` - SentencePiece tokenizer ## C++ Runner @@ -188,6 +219,9 @@ make parakeet-metal # CUDA build (Linux) make parakeet-cuda + +# MLX build (macOS) +make parakeet-mlx ``` On Windows (PowerShell), use CMake workflow presets directly: @@ -222,6 +256,12 @@ DYLD_LIBRARY_PATH=/usr/lib ./cmake-out/examples/models/parakeet/parakeet_runner --data_path examples/models/parakeet/parakeet_cuda/aoti_cuda_blob.ptd \ --audio_path /path/to/audio.wav \ --tokenizer_path examples/models/parakeet/parakeet_cuda/tokenizer.model + +# MLX +./cmake-out/examples/models/parakeet/parakeet_runner \ + --model_path examples/models/parakeet/parakeet_mlx_4w/model.pte \ + --audio_path /path/to/audio.wav \ + --tokenizer_path examples/models/parakeet/parakeet_mlx_4w/tokenizer.model ``` Windows (PowerShell): diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py index 8dd9accd866..c35e17eed59 100644 --- a/examples/models/parakeet/export_parakeet_tdt.py +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -303,15 +303,15 @@ def export_all( backend: Optional[str] = None, # Encoder quantization args qlinear_encoder: Optional[str] = None, - qlinear_encoder_group_size: int = 32, + qlinear_encoder_group_size: Optional[int] = None, qlinear_encoder_packing_format: Optional[str] = None, # Decoder quantization args qlinear: Optional[str] = None, - qlinear_group_size: int = 32, + qlinear_group_size: Optional[int] = None, qlinear_packing_format: Optional[str] = None, # Embedding quantization args (decoder has the embedding layer) qembedding: Optional[str] = None, - qembedding_group_size: int = 0, + qembedding_group_size: Optional[int] = None, ): """Export all model components. 
@@ -388,7 +388,6 @@ def export_all( qlinear_group_size=qlinear_encoder_group_size, qlinear_packing_format=qlinear_encoder_packing_format, ) - programs["encoder"] = export( encoder_with_proj, (), @@ -557,6 +556,19 @@ def _create_cuda_partitioners(programs, is_windows=False): return partitioner, updated_programs +def _create_mlx_partitioners(programs): + """Create MLX partitioners for all programs.""" + from executorch.backends.mlx.partitioner import MLXPartitioner + + print("\nLowering to ExecuTorch with MLX...") + + partitioner = {} + for key in programs.keys(): + partitioner[key] = [MLXPartitioner()] + + return partitioner, programs + + def _create_vulkan_partitioners(programs, vulkan_force_fp16=False): """Create Vulkan partitioners for all programs except preprocessor.""" from executorch.backends.vulkan.partitioner.vulkan_partitioner import ( @@ -580,6 +592,8 @@ def lower_to_executorch( partitioner, programs = _create_xnnpack_partitioners(programs) elif backend == "metal": partitioner, programs = _create_metal_partitioners(programs) + elif backend == "mlx": + partitioner, programs = _create_mlx_partitioners(programs) elif backend in ("cuda", "cuda-windows"): partitioner, programs = _create_cuda_partitioners( programs, is_windows=(backend == "cuda-windows") @@ -626,7 +640,15 @@ def main(): "--backend", type=str, default="xnnpack", - choices=["portable", "xnnpack", "metal", "cuda", "cuda-windows", "vulkan"], + choices=[ + "portable", + "xnnpack", + "metal", + "mlx", + "cuda", + "cuda-windows", + "vulkan", + ], help="Backend for acceleration (default: xnnpack)", ) parser.add_argument( @@ -641,14 +663,14 @@ def main(): parser.add_argument( "--qlinear", type=str, - choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"], + choices=["4w", "8w", "8da4w", "8da8w", "fpa4w", "nvfp4"], help="Quantization config for decoder linear layers", ) parser.add_argument( "--qlinear_group_size", type=int, - default=32, - help="Group size for decoder linear quantization (default: 32)", + 
default=None, + help="Group size for decoder linear quantization", ) parser.add_argument( "--qlinear_packing_format", @@ -661,14 +683,14 @@ def main(): parser.add_argument( "--qlinear_encoder", type=str, - choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"], + choices=["4w", "8w", "8da4w", "8da8w", "fpa4w", "nvfp4"], help="Quantization config for encoder linear layers", ) parser.add_argument( "--qlinear_encoder_group_size", type=int, - default=32, - help="Group size for encoder linear quantization (default: 32)", + default=None, + help="Group size for encoder linear quantization", ) parser.add_argument( "--qlinear_encoder_packing_format", @@ -681,14 +703,14 @@ def main(): parser.add_argument( "--qembedding", type=str, - choices=["4w", "8w"], + choices=["4w", "8w", "nvfp4"], help="Quantization config for decoder embedding layer", ) parser.add_argument( "--qembedding_group_size", type=int, - default=0, - help="Group size for embedding quantization (default: 0 = per-axis)", + default=None, + help="Group size for embedding quantization", ) parser.add_argument("--vulkan_force_fp16", action="store_true") diff --git a/examples/models/voxtral/CMakeLists.txt b/examples/models/voxtral/CMakeLists.txt index 036e6454efe..2e8ebb5c5e9 100644 --- a/examples/models/voxtral/CMakeLists.txt +++ b/examples/models/voxtral/CMakeLists.txt @@ -45,9 +45,14 @@ executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib) # CPU-only builds need quantized and custom ops if(NOT EXECUTORCH_BUILD_CUDA) - list(APPEND link_libraries quantized_ops_lib custom_ops) - executorch_target_link_options_shared_lib(quantized_ops_lib) - executorch_target_link_options_shared_lib(custom_ops) + if(TARGET quantized_ops_lib) + list(APPEND link_libraries quantized_ops_lib) + executorch_target_link_options_shared_lib(quantized_ops_lib) + endif() + if(TARGET custom_ops) + list(APPEND link_libraries custom_ops) + executorch_target_link_options_shared_lib(custom_ops) + endif() endif() # XNNPACK @@ -99,6 +104,12 
@@ if(EXECUTORCH_BUILD_METAL) executorch_target_link_options_shared_lib(metal_backend) endif() +# Link MLX delegate +if(TARGET mlxdelegate) + list(APPEND link_libraries mlxdelegate mlx) + executorch_target_link_options_shared_lib(mlxdelegate) +endif() + # Add tokenizers list(APPEND link_libraries tokenizers::tokenizers) @@ -120,6 +131,11 @@ if(WIN32) target_link_options(voxtral_runner PRIVATE "/STACK:8388608") endif() +# Copy MLX metallib for runtime if MLX delegate is enabled +if(TARGET mlxdelegate) + executorch_target_copy_mlx_metallib(voxtral_runner) +endif() + # On Windows, copy required DLLs to the executable directory if(MSVC AND EXECUTORCH_BUILD_CUDA) add_custom_command( diff --git a/examples/models/voxtral/CMakePresets.json b/examples/models/voxtral/CMakePresets.json index d9e0ba6af19..e853604c1a1 100644 --- a/examples/models/voxtral/CMakePresets.json +++ b/examples/models/voxtral/CMakePresets.json @@ -41,6 +41,19 @@ "type": "equals", "rhs": "Darwin" } + }, + { + "name": "voxtral-mlx", + "displayName": "Voxtral runner (MLX)", + "inherits": ["voxtral-base"], + "cacheVariables": { + "EXECUTORCH_BUILD_MLX": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Darwin" + } } ], "buildPresets": [ @@ -61,6 +74,12 @@ "displayName": "Build Voxtral runner (Metal)", "configurePreset": "voxtral-metal", "targets": ["voxtral_runner"] + }, + { + "name": "voxtral-mlx", + "displayName": "Build Voxtral runner (MLX)", + "configurePreset": "voxtral-mlx", + "targets": ["voxtral_runner"] } ], "workflowPresets": [ @@ -105,6 +124,20 @@ "name": "voxtral-metal" } ] + }, + { + "name": "voxtral-mlx", + "displayName": "Configure and build Voxtral runner (MLX)", + "steps": [ + { + "type": "configure", + "name": "voxtral-mlx" + }, + { + "type": "build", + "name": "voxtral-mlx" + } + ] } ] } diff --git a/examples/models/voxtral/README.md b/examples/models/voxtral/README.md index 72d21425648..1ffae876a99 100644 --- a/examples/models/voxtral/README.md +++ 
b/examples/models/voxtral/README.md @@ -140,6 +140,51 @@ This will generate: See the "Building the multimodal runner" section below for instructions on building with Metal support, and the "Running the model" section for runtime instructions. +## MLX Support (macOS) +On Apple Silicon, you can export and run Voxtral using the [MLX backend](../../../backends/mlx), which provides accelerated inference via Apple's MLX framework. + +### Exporting with MLX +The MLX export script produces two `.pte` files — the model and the audio preprocessor — both delegated to MLX: +``` +python -m executorch.backends.mlx.examples.voxtral.export_voxtral_hf \ + --output-dir mlx_voxtral_int4_bf16 \ + --dtype bf16 \ + --quantize-linear int4 +``` + +This will generate: +- `model.pte` - The exported model with MLX delegate (audio_encoder, token_embedding, text_decoder) +- `preprocessor.pte` - The mel spectrogram audio preprocessor with MLX delegate + +#### Export arguments + +| Argument | Description | +|----------|-------------| +| `--model-id` | HuggingFace model ID (default: `mistralai/Voxtral-Mini-3B-2507`) | +| `--output-dir` | Output directory for `.pte` files (default: `voxtral_mlx`) | +| `--dtype` | Model dtype: `fp32`, `fp16`, `bf16` (default: `bf16`) | +| `--max-seq-len` | Maximum sequence length for KV cache (default: `1024`) | +| `--quantize-linear` | Quantization for linear layers: `int4`, `int8` (default: none) | +| `--quantize-linear-group-size` | Group size for linear quantization (default: `32`) | +| `--max-audio-len` | Maximum audio length in seconds for preprocessor (default: `300`) | + +### Building for MLX +From the ExecuTorch root directory: +``` +make voxtral-mlx +``` + +### Running with MLX +``` +./cmake-out/examples/models/voxtral/voxtral_runner \ + --model_path mlx_voxtral_int4_bf16/model.pte \ + --tokenizer_path path/to/tekken.json \ + --prompt "What is happening in this audio?" 
\ + --audio_path path/to/audio.wav \ + --processor_path mlx_voxtral_int4_bf16/preprocessor.pte \ + --temperature 0 +``` + # Running the model To run the model, we will use the Voxtral runner, which utilizes ExecuTorch's MultiModal runner API. The Voxtral runner will do the following things: diff --git a/examples/models/voxtral_realtime/CMakeLists.txt b/examples/models/voxtral_realtime/CMakeLists.txt index 5d047df51c0..28545f407ca 100644 --- a/examples/models/voxtral_realtime/CMakeLists.txt +++ b/examples/models/voxtral_realtime/CMakeLists.txt @@ -33,7 +33,7 @@ list(APPEND link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas) executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib) # CPU-only builds need quantized and custom ops -if(NOT EXECUTORCH_BUILD_CUDA) +if(NOT EXECUTORCH_BUILD_CUDA AND NOT EXECUTORCH_BUILD_MLX) list(APPEND link_libraries quantized_ops_lib custom_ops) executorch_target_link_options_shared_lib(quantized_ops_lib) executorch_target_link_options_shared_lib(custom_ops) @@ -87,6 +87,12 @@ if(EXECUTORCH_BUILD_METAL) executorch_target_link_options_shared_lib(metal_backend) endif() +# Link MLX delegate +if(TARGET mlxdelegate) + list(APPEND link_libraries mlxdelegate mlx) + executorch_target_link_options_shared_lib(mlxdelegate) +endif() + # Tokenizer list(APPEND link_libraries tokenizers::tokenizers) @@ -106,6 +112,11 @@ target_compile_options( voxtral_realtime_runner PUBLIC ${_common_compile_options} ) +# Copy MLX metallib for runtime if MLX delegate is enabled +if(TARGET mlxdelegate) + executorch_target_copy_mlx_metallib(voxtral_realtime_runner) +endif() + # On Windows, copy required DLLs to the executable directory if(MSVC AND EXECUTORCH_BUILD_CUDA) add_custom_command( diff --git a/examples/models/voxtral_realtime/CMakePresets.json b/examples/models/voxtral_realtime/CMakePresets.json index 63411010994..12ba6f0e0b9 100644 --- a/examples/models/voxtral_realtime/CMakePresets.json +++ 
b/examples/models/voxtral_realtime/CMakePresets.json @@ -50,6 +50,19 @@ "Windows" ] } + }, + { + "name": "voxtral-realtime-mlx", + "displayName": "Voxtral Realtime runner (MLX)", + "inherits": ["voxtral-realtime-base"], + "cacheVariables": { + "EXECUTORCH_BUILD_MLX": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Darwin" + } } ], "buildPresets": [ @@ -79,6 +92,15 @@ "targets": [ "voxtral_realtime_runner" ] + }, + { + "name": "voxtral-realtime-mlx", + "displayName": "Build Voxtral Realtime runner (MLX)", + "configurePreset": "voxtral-realtime-mlx", + "configuration": "Release", + "targets": [ + "voxtral_realtime_runner" + ] } ], "workflowPresets": [ @@ -123,6 +145,20 @@ "name": "voxtral-realtime-cuda" } ] + }, + { + "name": "voxtral-realtime-mlx", + "displayName": "Configure and build Voxtral Realtime runner (MLX)", + "steps": [ + { + "type": "configure", + "name": "voxtral-realtime-mlx" + }, + { + "type": "build", + "name": "voxtral-realtime-mlx" + } + ] } ] } \ No newline at end of file diff --git a/examples/models/voxtral_realtime/README.md b/examples/models/voxtral_realtime/README.md index 7dca9c2c978..c6a09ea4c83 100644 --- a/examples/models/voxtral_realtime/README.md +++ b/examples/models/voxtral_realtime/README.md @@ -57,6 +57,36 @@ python -m executorch.extension.audio.mel_spectrogram \ --output_file ./voxtral_rt_exports/preprocessor.pte ``` +For MLX backend, use `--backend mlx`: + +```bash +python -m executorch.extension.audio.mel_spectrogram \ + --feature_size 128 \ + --max_audio_len 300 \ + --backend mlx \ + --output_file ./voxtral_rt_exports/preprocessor.pte +``` + +For streaming, use a separate preprocessor with `--streaming` (no audio +length limit): + +```bash +python -m executorch.extension.audio.mel_spectrogram \ + --feature_size 128 \ + --streaming \ + --output_file ./voxtral_streaming_exports/preprocessor.pte +``` + +For streaming with MLX backend: + +```bash +python -m 
executorch.extension.audio.mel_spectrogram \ + --feature_size 128 \ + --streaming \ + --backend mlx \ + --output_file ./voxtral_streaming_exports/preprocessor.pte +``` + ## Export Export produces a single `.pte` containing the audio encoder, text decoder, @@ -79,8 +109,66 @@ python export_voxtral_rt.py \ --qembedding 8w ``` -
-Metal +### Backend support + +| Backend | Offline | Streaming | Quantization | +|---------|---------|-----------|--------------| +| `xnnpack` | ✓ | ✓ | `4w`, `8w`, `8da4w`, `8da8w` | +| `metal` | ✓ | ✓ | none (fp32) or `fpa4w` (Metal-specific 4-bit) | +| `mlx` | ✓ | ✓ | `4w`, `8w`, `nvfp4` (NVIDIA FP4 dtype) | +| `cuda` | ✓ | ✓ | `4w`, `8w` | +| `cuda-windows` | ✓ | ✓ | `4w`, `8w` | + + +MLX and Metal backends provide Apple GPU acceleration. CUDA backend provides NVIDIA GPU acceleration via AOTInductor. + +#### CUDA export examples + +Offline with int4 quantization: + +```bash +python export_voxtral_rt.py \ + --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \ + --backend cuda \ + --dtype bf16 \ + --output-dir ./voxtral_rt_exports \ + --qlinear-encoder 4w \ + --qlinear-encoder-packing-format tile_packed_to_4d \ + --qlinear 4w \ + --qlinear-packing-format tile_packed_to_4d \ + --qembedding 8w +``` + +Streaming with int4 quantization: + +```bash +python export_voxtral_rt.py \ + --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \ + --backend cuda \ + --dtype bf16 \ + --streaming \ + --output-dir ./voxtral_rt_exports \ + --qlinear-encoder 4w \ + --qlinear-encoder-packing-format tile_packed_to_4d \ + --qlinear 4w \ + --qlinear-packing-format tile_packed_to_4d \ + --qembedding 8w +``` + +#### Metal export examples + +Offline: + +```bash +python export_voxtral_rt.py \ + --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \ + --backend metal \ + --output-dir ./voxtral_rt_exports \ + --qlinear-encoder fpa4w \ + --qlinear fpa4w +``` + +Streaming: ```bash python export_voxtral_rt.py \ @@ -104,30 +192,65 @@ USE_CPP=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 pip install . --no-build-isolation EXECUTORCH_BUILD_KERNELS_TORCHAO=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 ./install_executorch.sh ``` -
+#### MLX export examples -
-CUDA +MLX backend uses the MLX delegate for Apple Silicon GPU acceleration. +NVFP4 quantizes weights using NVIDIA's FP4 data type. + +Offline (NVFP4): ```bash python export_voxtral_rt.py \ --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \ - --backend cuda \ - --dtype bf16 \ + --backend mlx \ + --output-dir ./voxtral_rt_exports \ + --qlinear-encoder nvfp4 \ + --qlinear nvfp4 \ + --qembedding nvfp4 +``` + +Streaming (NVFP4): + +```bash +python export_voxtral_rt.py \ + --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \ + --backend mlx \ --streaming \ - --sliding-window 2048 \ + --output-dir ./voxtral_rt_exports \ + --qlinear-encoder nvfp4 \ + --qlinear nvfp4 \ + --qembedding nvfp4 +``` + +Offline (int4 linear + int8 embedding): + +```bash +python export_voxtral_rt.py \ + --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \ + --backend mlx \ --output-dir ./voxtral_rt_exports \ --qlinear-encoder 4w \ - --qlinear-encoder-packing-format tile_packed_to_4d \ --qlinear 4w \ - --qlinear-packing-format tile_packed_to_4d \ - --qembedding 8w + --qembedding 8w \ + --qembedding-group-size 128 ``` -
+Streaming (int4 linear + int8 embedding): + +```bash +python export_voxtral_rt.py \ + --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \ + --backend mlx \ + --streaming \ + --sliding-window 2048 \ + --output-dir ./voxtral_rt_exports \ + --qlinear-encoder 4w \ + --qlinear 4w \ + --qembedding 8w \ + --qembedding-group-size 128 +``` -
-CUDA-Windows +#### CUDA-Windows export examples Requires `x86_64-w64-mingw32-g++` on `PATH` (mingw-w64 cross-compiler) and `WINDOWS_CUDA_HOME` pointing to the extracted Windows CUDA package directory. @@ -150,8 +273,6 @@ python export_voxtral_rt.py \ --qembedding 8w ``` -
- > [!NOTE] > Omit `--streaming` from any export command above for offline mode. > CUDA and CUDA-Windows exports also produce an `aoti_cuda_blob.ptd` file alongside `model.pte`. @@ -161,18 +282,20 @@ python export_voxtral_rt.py \ | Flag | Default | Description | |------|---------|-------------| | `--model-path` | (required) | Directory with `params.json` + `consolidated.safetensors` | -| `--backend` | `xnnpack` | `xnnpack`, `metal`, `cuda`, `cuda-windows`, or `portable` | +| `--backend` | `xnnpack` | `xnnpack`, `mlx`, `metal`, `cuda`, `cuda-windows`, or `portable` | | `--dtype` | `fp32` | Model dtype: `fp32` or `bf16` | | `--output-dir` | `./voxtral_rt_exports` | Output directory | | `--max-seq-len` | `4096` | KV cache length (offline mode only; ignored with `--streaming`) | | `--delay-tokens` | `6` | Transcription delay in tokens (6 = 480ms) | -| `--qlinear` | (none) | Decoder linear layer quantization (`4w`, `8w`, `8da4w`, `8da8w`, `fpa4w`) | -| `--qlinear-group-size` | `32` | Group size for decoder linear quantization | + +| `--qlinear` | (none) | Decoder linear layer quantization (`4w`, `8w`, `8da4w`, `8da8w`, `fpa4w`, `nvfp4`) | +| `--qlinear-group-size` | auto | Group size for decoder linear quantization | | `--qlinear-packing-format` | (none) | Packing format for decoder 4w quantization (`tile_packed_to_4d` for CUDA) | -| `--qlinear-encoder` | (none) | Encoder linear layer quantization (`4w`, `8w`, `8da4w`, `8da8w`, `fpa4w`) | -| `--qlinear-encoder-group-size` | `32` | Group size for encoder linear quantization | +| `--qlinear-encoder` | (none) | Encoder linear layer quantization (`4w`, `8w`, `8da4w`, `8da8w`, `fpa4w`, `nvfp4`) | +| `--qlinear-encoder-group-size` | auto | Group size for encoder linear quantization | | `--qlinear-encoder-packing-format` | (none) | Packing format for encoder 4w quantization (`tile_packed_to_4d` for CUDA) | -| `--qembedding` | (none) | Embedding layer quantization (`8w`) | +| `--qembedding` | (none) | Embedding layer 
quantization (`4w`, `8w`, `nvfp4`) | +| `--qembedding-group-size` | auto | Group size for embedding quantization | | `--streaming` | off | Export streaming model with ring buffer KV caches (unlimited duration) | | `--max-enc-len` | `750` | Encoder sliding window size (streaming only) | | `--sliding-window` | from `params.json` | Decoder sliding window size (streaming only; ignored in offline mode). Smaller values reduce memory and improve decode speed but limit context | @@ -211,6 +334,27 @@ cmake --workflow --preset voxtral-realtime-cuda Pop-Location ``` +This builds ExecuTorch with CUDA backend support. The runner binary is at +the same path as above. Requires NVIDIA GPU with CUDA toolkit installed. + +### Metal (Apple GPU) + +```bash +make voxtral_realtime-metal +``` + +This builds ExecuTorch with Metal backend support. The runner binary is at +the same path as above. Metal exports can only run on macOS with Apple Silicon. + +### MLX (Apple GPU) + +```bash +make voxtral_realtime-mlx +``` + +This builds ExecuTorch with MLX backend support. MLX provides GPU acceleration +on Apple Silicon via the MLX delegate. 
+ ## Run The runner requires: diff --git a/examples/models/voxtral_realtime/export_voxtral_rt.py b/examples/models/voxtral_realtime/export_voxtral_rt.py index 2d47656ab2c..824c4485662 100644 --- a/examples/models/voxtral_realtime/export_voxtral_rt.py +++ b/examples/models/voxtral_realtime/export_voxtral_rt.py @@ -39,9 +39,7 @@ import torch import torch.nn as nn - from executorch.examples.models.voxtral_realtime.model import load_model - from executorch.exir import ( EdgeCompileConfig, ExecutorchBackendConfig, @@ -112,6 +110,7 @@ def _export_decoder_and_embedding( qlinear_group_size, qlinear_packing_format, qembedding, + qembedding_group_size, device="cpu", ): """Export text_decoder and token_embedding into programs dict.""" @@ -157,6 +156,7 @@ def _export_decoder_and_embedding( quantize_model_( tok_emb, qembedding_config=qembedding, + qembedding_group_size=qembedding_group_size, ) tok_seq_dim = Dim("tok_seq_len", min=1, max=max_seq_len) @@ -174,12 +174,13 @@ def export_all( model, max_seq_len, qlinear_encoder=None, - qlinear_encoder_group_size=32, + qlinear_encoder_group_size=None, qlinear_encoder_packing_format=None, qlinear=None, - qlinear_group_size=32, + qlinear_group_size=None, qlinear_packing_format=None, qembedding=None, + qembedding_group_size=None, backend="xnnpack", ): """Export all three model components with per-component quantization.""" @@ -236,6 +237,7 @@ def export_all( qlinear_group_size, qlinear_packing_format, qembedding, + qembedding_group_size, device, ) @@ -259,12 +261,13 @@ def export_streaming( max_seq_len, max_enc_len=750, qlinear_encoder=None, - qlinear_encoder_group_size=32, + qlinear_encoder_group_size=None, qlinear_encoder_packing_format=None, qlinear=None, - qlinear_group_size=32, + qlinear_group_size=None, qlinear_packing_format=None, qembedding=None, + qembedding_group_size=None, backend="xnnpack", ): """Export streaming model components with per-component quantization.""" @@ -317,6 +320,7 @@ def export_streaming( qlinear_group_size, 
qlinear_packing_format, qembedding, + qembedding_group_size, device, ) @@ -377,6 +381,8 @@ def _linear_bias_decomposition(input, weight, bias=None): def lower_to_executorch(programs, metadata, backend="xnnpack"): """Lower exported programs to ExecuTorch.""" + transform_passes = None + if backend == "xnnpack": from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( XnnpackDynamicallyQuantizedPartitioner, @@ -430,12 +436,20 @@ def lower_to_executorch(programs, metadata, backend="xnnpack"): if backend == "cuda-windows": compile_specs.append(CompileSpec("platform", b"windows")) partitioner[key] = [CudaPartitioner(compile_specs)] + elif backend == "mlx": + from executorch.backends.mlx.partitioner import MLXPartitioner + from executorch.backends.mlx.passes import get_default_passes + + print("\nLowering to ExecuTorch with MLX...") + partitioner = {key: [MLXPartitioner()] for key in programs} + transform_passes = get_default_passes() else: print("\nLowering to ExecuTorch (portable)...") partitioner = [] et_prog = to_edge_transform_and_lower( programs, + transform_passes=transform_passes, partitioner=partitioner, compile_config=EdgeCompileConfig( _check_ir_validity=False, @@ -470,7 +484,7 @@ def main(): parser.add_argument( "--backend", default="xnnpack", - choices=["portable", "xnnpack", "metal", "cuda", "cuda-windows"], + choices=["portable", "xnnpack", "mlx", "metal", "cuda", "cuda-windows"], help="Backend for acceleration (default: xnnpack)", ) parser.add_argument( @@ -493,13 +507,13 @@ def main(): parser.add_argument( "--qlinear", default=None, - choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"], + choices=["4w", "8w", "8da4w", "8da8w", "fpa4w", "nvfp4"], help="Quantize decoder linear layers.", ) parser.add_argument( "--qlinear-group-size", type=int, - default=32, + default=None, help="Group size for decoder linear quantization (default: 32).", ) parser.add_argument( @@ -511,13 +525,13 @@ def main(): parser.add_argument( "--qlinear-encoder", 
        default=None,
-        choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"],
+        choices=["4w", "8w", "8da4w", "8da8w", "fpa4w", "nvfp4"],
         help="Quantize encoder linear layers (separate from decoder).",
     )
     parser.add_argument(
         "--qlinear-encoder-group-size",
         type=int,
-        default=32,
-        help="Group size for encoder linear quantization (default: 32).",
+        default=None,
+        help="Group size for encoder linear quantization.",
     )
     parser.add_argument(
@@ -529,8 +543,14 @@
     parser.add_argument(
         "--qembedding",
         default=None,
-        choices=["8w"],
-        help="Quantize embedding layers (8-bit weight-only).",
+        choices=["4w", "8w", "nvfp4"],
+        help="Quantize embedding layers.",
+    )
+    parser.add_argument(
+        "--qembedding-group-size",
+        type=int,
+        default=None,
+        help="Group size for embedding quantization.",
     )
     parser.add_argument(
         "--streaming",
@@ -605,6 +625,7 @@ def main():
         "qlinear_group_size": args.qlinear_group_size,
         "qlinear_packing_format": args.qlinear_packing_format,
         "qembedding": args.qembedding,
+        "qembedding_group_size": args.qembedding_group_size,
         "backend": backend_for_export,
     }
     if args.streaming:
diff --git a/examples/models/voxtral_realtime/model.py b/examples/models/voxtral_realtime/model.py
index fdab67409c6..e591445cc56 100644
--- a/examples/models/voxtral_realtime/model.py
+++ b/examples/models/voxtral_realtime/model.py
@@ -15,7 +15,6 @@
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
 from executorch.extension.llm.custom_ops import custom_ops as _custom_ops  # noqa: F401
 
@@ -52,7 +51,7 @@ class VoxtralRealtimeConfig:
     max_seq_len: int = 4096
     sliding_window: int = 8192
     streaming: bool = False
-    backend: str = "xnnpack"  # "xnnpack", "metal", "cuda", or "portable"
+    backend: str = "xnnpack"  # "xnnpack", "mlx", "metal", "cuda", or "portable"
 
     @staticmethod
     def from_params_json(path: str) -> "VoxtralRealtimeConfig":
@@ -156,12 +155,14 @@ class EncoderAttention(nn.Module):
     """Multi-head attention with RoPE for the causal whisper encoder.
 
     Biases: wq yes, wk no, wv yes, wo yes.
+ Supports MLX backend for Apple Silicon GPU acceleration. """ - def __init__(self, dim: int, n_heads: int, head_dim: int): + def __init__(self, dim: int, n_heads: int, head_dim: int, backend: str = "xnnpack"): super().__init__() self.n_heads = n_heads self.head_dim = head_dim + self.backend = backend attn_dim = n_heads * head_dim self.wq = nn.Linear(dim, attn_dim, bias=True) self.wk = nn.Linear(dim, attn_dim, bias=False) @@ -180,7 +181,15 @@ def forward( v = self.wv(x).view(B, T, self.n_heads, self.head_dim) q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin) q, k, v = (t.transpose(1, 2) for t in (q, k, v)) - y = F.scaled_dot_product_attention(q, k, v, is_causal=True) + if self.backend == "mlx": + # Use MLX custom SDPA for Apple Silicon GPU + start_pos = 0 # Offline encoder always starts at 0 + scale = self.head_dim**-0.5 + y = torch.ops.mlx.custom_sdpa( + q, k, v, start_pos=start_pos, is_causal=True, scale=scale + ) + else: + y = F.scaled_dot_product_attention(q, k, v, is_causal=True) return self.wo(y.transpose(1, 2).contiguous().view(B, T, -1)) @@ -202,7 +211,10 @@ def __init__(self, config: VoxtralRealtimeConfig): super().__init__() self.attention_norm = RMSNorm(config.enc_dim, config.enc_norm_eps) self.attention = EncoderAttention( - config.enc_dim, config.enc_n_heads, config.enc_head_dim + config.enc_dim, + config.enc_n_heads, + config.enc_head_dim, + backend=config.backend, ) self.ffn_norm = RMSNorm(config.enc_dim, config.enc_norm_eps) self.feed_forward = EncoderSwiGLU(config.enc_dim, config.enc_hidden_dim) @@ -526,10 +538,151 @@ def forward( return y.view(bsz, seqlen, self.dim) +class MLXStaticKVCache(nn.Module): + """Wrapper that adapts MLX static KV cache for model's BSHD convention. + + For offline (non-streaming) mode. The model's QKV projections produce + [B, S, H, D] tensors, but MLX's KVCache expects [B, H, S, D]. + This wrapper transposes on the way in. 
+ """ + + def __init__( + self, max_seq_len: int, n_kv_heads: int, head_dim: int, dtype: torch.dtype + ): + super().__init__() + from executorch.backends.mlx.llm.cache import KVCache as MLXKVCacheImpl + + self.cache = MLXKVCacheImpl( + max_batch_size=1, + max_context_length=max_seq_len, + n_heads=n_kv_heads, + head_dim=head_dim, + enable_dynamic_shape=True, + dtype=dtype, + ) + + def update( + self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + # Transpose BSHD -> BHSD for MLX cache + k_val = k_val.transpose(1, 2) + v_val = v_val.transpose(1, 2) + return self.cache.update(input_pos, k_val, v_val) + + +class MLXRingKVCache(nn.Module): + """Wrapper that adapts MLX RingBufferKVCache for model's BSHD convention. + + For streaming mode (both encoder and decoder). The model's QKV projections + produce [B, S, H, D] tensors, but MLX's RingBufferKVCache expects + [B, H, S, D]. This wrapper transposes on the way in and delegates + ring buffer semantics to the MLX implementation. + """ + + def __init__( + self, + window_size: int, + n_heads: int, + head_dim: int, + dtype: torch.dtype, + ): + super().__init__() + from executorch.backends.mlx.llm.cache import RingBufferKVCache + + self.ring_cache = RingBufferKVCache( + max_batch_size=1, + max_context_length=window_size, + n_heads=n_heads, + head_dim=head_dim, + dtype=dtype, + ) + + def update( + self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + # Transpose BSHD -> BHSD for MLX ring buffer + k_val = k_val.transpose(1, 2) + v_val = v_val.transpose(1, 2) + return self.ring_cache.update(input_pos, k_val, v_val) + + def create_causal_mask( + self, start_pos, seq_len, bool_mask=False, **kwargs + ) -> torch.Tensor: + return self.ring_cache.create_sliding_window_mask(start_pos, seq_len) + + +class MLXSDPA(nn.Module): + """SDPA using MLX custom op for Apple Silicon GPU acceleration. 
+ + Uses torch.ops.mlx.custom_sdpa which handles GQA expansion and causal + masking internally. KV cache is in BHSD layout, queries are in BSHD. + """ + + def __init__(self, n_heads: int, head_dim: int): + super().__init__() + self.dim = n_heads * head_dim + self.scale = head_dim**-0.5 + + def forward( + self, + input_pos: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + bsz: int, + seqlen: int, + mask: torch.Tensor | None = None, + ) -> torch.Tensor: + start_pos = input_pos[0].item() + q = q.transpose(1, 2) # BSHD -> BHSD + y = torch.ops.mlx.custom_sdpa( + q, k, v, start_pos=start_pos, is_causal=True, scale=self.scale + ) + return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + +class MLXMaskedSDPA(nn.Module): + """SDPA with explicit mask for MLX ring buffer KV cache. + + Used with MLXRingKVCache for streaming mode (both encoder and decoder). + Uses F.scaled_dot_product_attention with explicit attn_mask from the + ring buffer. KV cache is in BHSD layout, queries are in BSHD. + """ + + def __init__(self, n_heads: int, head_dim: int): + super().__init__() + self.dim = n_heads * head_dim + + def forward( + self, + input_pos: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + bsz: int, + seqlen: int, + mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Args: + input_pos: (seq_len,) position indices (unused, kept for interface). + q: (B, seq_len, n_heads, head_dim) in BSHD layout. + k, v: (B, n_heads, buf_size, head_dim) in BHSD from MLXRingKVCache. + bsz, seqlen: batch size and query length. + mask: (1, 1, seq_len, buf_size) additive attention mask from ring buffer. + """ + q = q.transpose(1, 2) # BSHD -> BHSD + + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False) + + return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + class LMAttention(nn.Module): """GQA with RoPE, KV cache, and SDPA. No biases. 
- Supports both custom ops (for XNNPACK) and standard PyTorch ops (for Metal/AOTI). + Supports custom ops (for XNNPACK), standard PyTorch ops (for Metal/AOTI), + and MLX backend ops (for Apple Silicon GPU acceleration via MLX delegate). """ def __init__(self, config: VoxtralRealtimeConfig): @@ -539,17 +692,26 @@ def __init__(self, config: VoxtralRealtimeConfig): self.head_dim = config.head_dim self.dim = config.dim self.backend = config.backend + self.rope_theta = config.rope_theta self.wq = nn.Linear(config.dim, self.n_heads * self.head_dim, bias=False) self.wk = nn.Linear(config.dim, self.n_kv_heads * self.head_dim, bias=False) self.wv = nn.Linear(config.dim, self.n_kv_heads * self.head_dim, bias=False) self.wo = nn.Linear(self.n_heads * self.head_dim, config.dim, bias=False) + # Choose KV cache and SDPA based on backend and streaming mode if config.streaming: # Ring buffer KV cache for unlimited streaming. - if self.backend == "metal": - # StandardRingKVCache uses [B, H, S, D] — same layout as - # StaticKVCache, so no transpose needed in the SDPA. + if self.backend == "mlx": + cache_dtype = self.wq.weight.dtype + self.kv_cache = MLXRingKVCache( + config.sliding_window, + self.n_kv_heads, + self.head_dim, + dtype=cache_dtype, + ) + self.sdpa = MLXSDPA(self.n_heads, self.head_dim) + elif self.backend == "metal": self.kv_cache = StandardRingKVCache( config.sliding_window, self.n_kv_heads, self.head_dim ) @@ -566,7 +728,16 @@ def __init__(self, config: VoxtralRealtimeConfig): self.sdpa = SDPA(self.n_heads, self.head_dim) else: # Flat KV cache for offline mode (capped at max_seq_len). 
- if self.backend == "metal": + if self.backend == "mlx": + cache_dtype = self.wq.weight.dtype + self.kv_cache = MLXStaticKVCache( + config.max_seq_len, + self.n_kv_heads, + self.head_dim, + dtype=cache_dtype, + ) + self.sdpa = MLXSDPA(self.n_heads, self.head_dim) + elif self.backend == "metal": self.kv_cache = StaticKVCache( config.max_seq_len, self.n_kv_heads, self.head_dim ) @@ -595,7 +766,24 @@ def forward( k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim) v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim) - q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin) + if self.backend == "mlx": + start_pos = input_pos[0].item() + q = torch.ops.mlx.rope( + q.transpose(1, 2), + self.head_dim, + start_pos, + traditional=True, + base=self.rope_theta, + ).transpose(1, 2) + k = torch.ops.mlx.rope( + k.transpose(1, 2), + self.head_dim, + start_pos, + traditional=True, + base=self.rope_theta, + ).transpose(1, 2) + else: + q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin) k, v = self.kv_cache.update(input_pos, k, v) @@ -971,6 +1159,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750): self.n_heads = config.enc_n_heads self.head_dim = config.enc_head_dim self.bool_mask = config.backend == "cuda" + self.enc_rope_theta = config.enc_rope_theta # Register conv states as buffers (mutable state for streaming) self.register_buffer("conv1_state", torch.zeros(1, config.num_mel_bins, 2)) @@ -979,32 +1168,57 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750): # Ring buffer KV caches for unlimited streaming. # Window size = max_enc_len (encoder sliding window from params.json). # Buffer is 2x internally for safe wraparound. 
- # Choose cache implementation based on backend - cache_class = ( - StandardRingKVCache if config.backend in ("metal", "cuda") else RingKVCache - ) - self.kv_caches = nn.ModuleList( - [ - cache_class(max_enc_len, config.enc_n_heads, config.enc_head_dim) - for _ in range(config.enc_n_layers) - ] - ) - - # Choose SDPA based on backend - # StandardRingKVCache returns [B, H, S, D] — no transpose needed. - if config.backend == "metal": + # Choose cache and SDPA implementation based on backend + self.backend = config.backend + if config.backend == "mlx": + cache_dtype = self.layers[0].attention.wq.weight.dtype + self.kv_caches = nn.ModuleList( + [ + MLXRingKVCache( + max_enc_len, + config.enc_n_heads, + config.enc_head_dim, + dtype=cache_dtype, + ) + for _ in range(config.enc_n_layers) + ] + ) + self.sdpa = MLXMaskedSDPA(config.enc_n_heads, config.enc_head_dim) + elif config.backend == "metal": + self.kv_caches = nn.ModuleList( + [ + StandardRingKVCache( + max_enc_len, config.enc_n_heads, config.enc_head_dim + ) + for _ in range(config.enc_n_layers) + ] + ) self.sdpa = MetalSDPA( config.enc_n_heads, config.enc_n_heads, config.enc_head_dim, ) elif config.backend == "cuda": + self.kv_caches = nn.ModuleList( + [ + StandardRingKVCache( + max_enc_len, config.enc_n_heads, config.enc_head_dim + ) + for _ in range(config.enc_n_layers) + ] + ) self.sdpa = StandardSDPA( config.enc_n_heads, config.enc_n_heads, config.enc_head_dim, ) else: + self.kv_caches = nn.ModuleList( + [ + RingKVCache(max_enc_len, config.enc_n_heads, config.enc_head_dim) + for _ in range(config.enc_n_layers) + ] + ) self.sdpa = SDPA(config.enc_n_heads, config.enc_head_dim) # RoPE inverse frequencies for on-the-fly computation. 
@@ -1034,7 +1248,25 @@ def _streaming_encoder_layer( k = attn.wk(h).view(B, T, self.n_heads, self.head_dim) v = attn.wv(h).view(B, T, self.n_heads, self.head_dim) - q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin) + if self.backend == "mlx": + start_pos = input_pos[0].item() + q = torch.ops.mlx.rope( + q.transpose(1, 2), + self.head_dim, + start_pos, + traditional=True, + base=self.enc_rope_theta, + ).transpose(1, 2) + k = torch.ops.mlx.rope( + k.transpose(1, 2), + self.head_dim, + start_pos, + traditional=True, + base=self.enc_rope_theta, + ).transpose(1, 2) + else: + q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin) + k, v = self.kv_caches[layer_idx].update(input_pos, k, v) y = self.sdpa(input_pos, q, k, v, B, T, mask) @@ -1166,16 +1398,21 @@ def load_model( max_seq_len: Maximum sequence length for KV cache (offline mode). n_delay_tokens: Transcription delay in tokens (default 6 = 480ms). dtype: Weight dtype (default: float32). - backend: Backend for acceleration ("xnnpack", "metal", "cuda", or "portable"). + backend: Backend for acceleration ("xnnpack", "mlx", "metal", "cuda", or "portable"). streaming: If True, use ring buffer KV cache for unlimited streaming. sliding_window: Override decoder sliding window size (default: from params.json). """ - _VALID_BACKENDS = ("xnnpack", "metal", "cuda", "portable") + _VALID_BACKENDS = ("xnnpack", "mlx", "metal", "cuda", "portable") + if backend not in _VALID_BACKENDS: raise ValueError( f"Unknown backend '{backend}'. Must be one of {_VALID_BACKENDS}." 
) + # Import MLX custom ops for mlx backend + if backend == "mlx": + import executorch.backends.mlx.custom_ops as _mlx_custom_ops # noqa: F401 + from safetensors import safe_open model_dir = Path(model_path) diff --git a/extension/audio/mel_spectrogram.py b/extension/audio/mel_spectrogram.py index 4d7180854f1..a0ff0a3f020 100644 --- a/extension/audio/mel_spectrogram.py +++ b/extension/audio/mel_spectrogram.py @@ -6,11 +6,11 @@ import argparse import logging +import os import torch import torch.nn as nn import torch.nn.functional as F - from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner from executorch.exir import ( EdgeCompileConfig, @@ -188,9 +188,9 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor: return log_spec.unsqueeze(0) -def export_processor(model=None, output_file="whisper_preprocess.pte"): - if model is None: - model = WhisperAudioProcessor() +def _export_processor_model( + model, output_file="whisper_preprocess.pte", backend="xnnpack" +): if model.streaming: # Streaming processes small windows per step. 
2 seconds gives @@ -209,10 +209,17 @@ def export_processor(model=None, output_file="whisper_preprocess.pte"): ) logging.debug(ep) + if backend == "mlx": + from executorch.backends.mlx.partitioner import MLXPartitioner + + partitioner = [MLXPartitioner()] + else: + partitioner = [XnnpackPartitioner()] + # to edge edge: EdgeProgramManager = to_edge_transform_and_lower( ep, - partitioner=[XnnpackPartitioner()], + partitioner=partitioner, compile_config=EdgeCompileConfig( _check_ir_validity=False, ), @@ -221,12 +228,28 @@ def export_processor(model=None, output_file="whisper_preprocess.pte"): # to executorch exec_prog = edge.to_executorch() + os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True) with open(output_file, "wb") as file: exec_prog.write_to_file(file) logging.debug("Done") +def export_processor( + output_file="whisper_preprocess.pte", backend="xnnpack", **model_kwargs +): + """Export a WhisperAudioProcessor to a .pte file. + + Args: + output_file: Output .pte file path. + backend: Backend for partitioning ("xnnpack" or "mlx"). + **model_kwargs: Passed to WhisperAudioProcessor constructor + (e.g. feature_size, max_audio_len, stack_output, streaming). + """ + model = WhisperAudioProcessor(**model_kwargs) + _export_processor_model(model, output_file, backend) + + def main(): parser = argparse.ArgumentParser( description="Export WhisperAudioProcessor to ExecuTorch" @@ -281,9 +304,19 @@ def main(): help="Streaming mode: skip 30-second chunk padding, produce mel frames proportional to input length. 
For use with real-time audio input.", ) + parser.add_argument( + "--backend", + type=str, + default="xnnpack", + choices=["xnnpack", "mlx"], + help="Backend for partitioning (default: xnnpack)", + ) + args = parser.parse_args() - model = WhisperAudioProcessor( + export_processor( + output_file=args.output_file, + backend=args.backend, feature_size=args.feature_size, sampling_rate=args.sampling_rate, hop_length=args.hop_length, @@ -294,8 +327,6 @@ def main(): streaming=args.streaming, ) - export_processor(model, args.output_file) - if __name__ == "__main__": main() diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index e126ef54456..f4bdfbf1a0d 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -597,6 +597,15 @@ class VgfConfig: compiler_flags: List[str] = field(default_factory=list) +@dataclass +class MLXConfig: + """ + Configures the MLX backend for Apple Silicon. + """ + + enabled: bool = False + + @dataclass class BackendConfig: """ @@ -614,6 +623,7 @@ class BackendConfig: tosa: TosaConfig = field(default_factory=TosaConfig) ethosu: EthosUConfig = field(default_factory=EthosUConfig) vgf: VgfConfig = field(default_factory=VgfConfig) + mlx: MLXConfig = field(default_factory=MLXConfig) ################################################################################ @@ -784,6 +794,12 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901 if hasattr(args, "mps"): llm_config.backend.mps.enabled = args.mps + # MLX - auto-enable use_kv_cache when MLX is enabled + if hasattr(args, "mlx"): + llm_config.backend.mlx.enabled = args.mlx + if args.mlx: + llm_config.model.use_kv_cache = True + # Openvino if hasattr(args, "openvino"): llm_config.backend.openvino.enabled = args.openvino diff --git a/extension/llm/export/quantize.py b/extension/llm/export/quantize.py index b372bbb9db8..fb2678ff60f 100644 --- a/extension/llm/export/quantize.py +++ 
def _make_granularity(group_size: int):
    """Return a TorchAO granularity: PerAxis(0) when group_size == 0, else PerGroup.

    Args:
        group_size: 0 selects per-axis (axis 0) quantization; any positive
            value selects per-group quantization with that group size.
    """
    from torchao.quantization.granularity import PerAxis, PerGroup

    return PerAxis(0) if group_size == 0 else PerGroup(group_size)


def _make_linear_config(config_name: str, group_size: int, packing_format=None):
    """Build a TorchAO config for linear-layer quantization.

    Args:
        config_name: One of "nvfp4", "4w", "8w", "8da4w", "8da8w".
        group_size: Quantization group size (0 = per-axis). Must be 16 for nvfp4.
        packing_format: Optional int4 packing format (e.g. "tile_packed_to_4d");
            only honored for "4w".

    Raises:
        ValueError: If config_name is not supported.
    """
    from torchao.quantization.quant_api import (
        Int4WeightOnlyConfig,
        Int8DynamicActivationIntxWeightConfig,
        IntxWeightOnlyConfig,
    )

    if config_name == "nvfp4":
        from executorch.extension.llm.export.nvfp4 import ExportableNVFP4Config

        assert group_size == 16, "NVFP4 requires group_size=16"
        return ExportableNVFP4Config(use_per_tensor_scale=False)

    # Granularity only applies to the intx-based configs below; computed after
    # the nvfp4 early-return so the nvfp4 path never touches it.
    granularity = _make_granularity(group_size)

    if config_name == "4w":
        if packing_format:
            return Int4WeightOnlyConfig(
                group_size=group_size,
                int4_packing_format=packing_format,
                int4_choose_qparams_algorithm="hqq",
            )
        return IntxWeightOnlyConfig(
            weight_dtype=torch.int4,
            granularity=granularity,
            intx_choose_qparams_algorithm="hqq_scale_only",
        )
    elif config_name == "8w":
        return IntxWeightOnlyConfig(
            weight_dtype=torch.int8,
            granularity=granularity,
            intx_choose_qparams_algorithm="hqq_scale_only",
        )
    elif config_name == "8da4w":
        return Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=granularity,
            intx_choose_qparams_algorithm="hqq_scale_only",
        )
    elif config_name == "8da8w":
        return Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int8,
            weight_granularity=granularity,
            intx_choose_qparams_algorithm="hqq_scale_only",
        )
    else:
        raise ValueError(f"Unsupported qlinear_config: {config_name}")


def _make_embedding_config(config_name: str, group_size: int):
    """Build a TorchAO config for embedding-layer quantization.

    Args:
        config_name: One of "nvfp4", "4w", "8w".
        group_size: Quantization group size (0 = per-axis; otherwise must be
            even). Must be 16 for nvfp4.

    Raises:
        ValueError: If config_name is not supported.
    """
    from torchao.quantization.quant_api import IntxWeightOnlyConfig

    if config_name == "nvfp4":
        from executorch.extension.llm.export.nvfp4 import ExportableNVFP4Config

        # Keep parity with _make_linear_config: NVFP4 only supports 16-wide groups.
        assert group_size == 16, "NVFP4 requires group_size=16"
        return ExportableNVFP4Config(use_per_tensor_scale=False)

    if group_size != 0:
        assert group_size % 2 == 0, "Embedding group size must be a multiple of 2."

    granularity = _make_granularity(group_size)

    if config_name == "4w":
        return IntxWeightOnlyConfig(
            weight_dtype=torch.int4,
            granularity=granularity,
            intx_choose_qparams_algorithm="hqq_scale_only",
        )
    elif config_name == "8w":
        return IntxWeightOnlyConfig(
            weight_dtype=torch.int8,
            granularity=granularity,
            intx_choose_qparams_algorithm="hqq_scale_only",
        )
    else:
        raise ValueError(f"Unsupported qembedding_config: {config_name}")


def _check_shape_compatible(m, fqn, config_name, group_size, skip_incompatible_shapes):
    """Check that a layer's weight shape divides evenly by the group size.

    nvfp4 requires both trailing dims to be divisible by group_size; other
    grouped configs only constrain the innermost (K) dim; per-axis (0) always
    passes.

    Returns:
        True if compatible, False if the layer should be skipped.

    Raises:
        RuntimeError: If incompatible and skip_incompatible_shapes is False.
    """
    shape = m.weight.shape
    if config_name == "nvfp4":
        compatible = shape[-2] % group_size == 0 and shape[-1] % group_size == 0
    elif group_size != 0:
        compatible = shape[-1] % group_size == 0
    else:
        compatible = True

    if compatible:
        return True
    if not skip_incompatible_shapes:
        raise RuntimeError(
            f"Layer {fqn} has weight shape {shape} "
            f"incompatible with {config_name} (group_size={group_size}). "
            f"Use skip_incompatible_shapes=True to skip instead of failing."
        )
    print(
        f"  Skipping {fqn}: weight shape {shape} "
        f"incompatible with {config_name} (group_size={group_size})"
    )
    return False


def _make_linear_filter(
    config_name: str, group_size: int, skip_incompatible_shapes: bool = False
):
    """Create a filter_fn selecting nn.Linear layers with compatible shapes."""

    def linear_filter(m, fqn):
        if not isinstance(m, torch.nn.Linear):
            return False
        return _check_shape_compatible(
            m, fqn, config_name, group_size, skip_incompatible_shapes
        )

    return linear_filter


def _make_embedding_filter(
    config_name: str, group_size: int, skip_incompatible_shapes: bool = False
):
    """Create a filter_fn selecting nn.Embedding layers with compatible shapes."""

    def embedding_filter(m, fqn):
        if not isinstance(m, torch.nn.Embedding):
            return False
        return _check_shape_compatible(
            m, fqn, config_name, group_size, skip_incompatible_shapes
        )

    return embedding_filter


def _default_group_size(config_name: Optional[str]) -> int:
    """Return the default group size for a quantization config.

    16 for nvfp4 (hard requirement), 0 (per-axis) for the 8-bit weight
    configs, 32 otherwise (including None, where the value is unused).
    """
    if config_name == "nvfp4":
        return 16
    if config_name in ("8w", "8da8w"):
        return 0
    return 32


@experimental("quantize_model_ is experimental and may change without notice.")
def quantize_model_(
    module: torch.nn.Module,
    qlinear_config: Optional[str] = None,
    qlinear_group_size: Optional[int] = None,
    qlinear_packing_format: Optional[str] = None,
    qembedding_config: Optional[str] = None,
    qembedding_group_size: Optional[int] = None,
    tie_word_embeddings: bool = False,
    skip_incompatible_shapes: bool = False,
) -> None:
    """Quantize linear and embedding layers in a module in-place.

    Embedding layers are quantized before linear layers. The "fpa4w" config
    (Metal/MPS) takes an early-return path through torchao's experimental API
    and skips embedding quantization entirely.

    Args:
        module: The PyTorch module to quantize.
        qlinear_config: Quantization config for linear layers
            ("4w", "8w", "8da4w", "8da8w", "fpa4w", "nvfp4").
        qlinear_group_size: Group size for linear quantization.
            Defaults to 16 for nvfp4, 32 for 4w/8da4w, 0 (per-axis) for 8w/8da8w.
        qlinear_packing_format: Packing format for linear layers
            (e.g., "tile_packed_to_4d").
        qembedding_config: Quantization config for embedding layers
            ("4w", "8w", "nvfp4").
        qembedding_group_size: Group size for embedding quantization.
            Defaults to 16 for nvfp4, 32 for 4w, 0 (per-axis) for 8w.
        tie_word_embeddings: If True and both linear and embedding use the
            same quantization, re-tie lm_head.weight to embed_tokens.weight
            after quantization.
        skip_incompatible_shapes: If True, silently skip layers with
            incompatible weight shapes. If False (default), raise RuntimeError.
    """
    if not qlinear_config and not qembedding_config:
        return

    # Resolve per-config defaults for any group size the caller left unset.
    if qlinear_group_size is None:
        qlinear_group_size = _default_group_size(qlinear_config)
    if qembedding_group_size is None:
        qembedding_group_size = _default_group_size(qembedding_config)

    from torchao.quantization.quant_api import quantize_

    # Metal (MPS) quantization uses a separate API
    if qlinear_config == "fpa4w":
        import torchao.experimental.ops.mps  # noqa: F401
        from torchao.experimental.quant_api import UIntxWeightOnlyConfig

        config = UIntxWeightOnlyConfig(
            group_size=qlinear_group_size,
            bitwidth=4,
            uintx_choose_qparams_algorithm="hqq",
        )
        print(
            f"  Applying {qlinear_config} linear quantization "
            f"(group_size={qlinear_group_size})..."
        )
        quantize_(
            module,
            config,
            filter_fn=_make_linear_filter(
                "fpa4w", qlinear_group_size, skip_incompatible_shapes
            ),
        )
        return

    # Quantize embedding layers first
    if qembedding_config:
        config = _make_embedding_config(qembedding_config, qembedding_group_size)
        print(
            f"  Applying {qembedding_config} embedding quantization "
            f"(group_size={qembedding_group_size})..."
        )
        quantize_(
            module,
            config,
            filter_fn=_make_embedding_filter(
                qembedding_config, qembedding_group_size, skip_incompatible_shapes
            ),
        )

    # Quantize linear layers
    if qlinear_config:
        config = _make_linear_config(
            qlinear_config, qlinear_group_size, qlinear_packing_format
        )
        print(
            f"  Applying {qlinear_config} linear quantization "
            f"(group_size={qlinear_group_size}, packing={qlinear_packing_format})..."
        )
        quantize_(
            module,
            config,
            filter_fn=_make_linear_filter(
                qlinear_config, qlinear_group_size, skip_incompatible_shapes
            ),
        )

    # Re-tie word embeddings after quantization if both use the same config
    if (
        tie_word_embeddings
        and qlinear_config == qembedding_config
        and hasattr(module, "lm_head")
        and hasattr(module, "model")
    ):
        embed = getattr(module.model, "embed_tokens", None)
        if embed is not None:
            # NOTE(review): plain attribute assignment of the quantized weight;
            # presumably both layers share the same quantized tensor afterwards
            # — confirm downstream export handles the aliasing.
            module.lm_head.weight = embed.weight
            print("  Re-tied lm_head weights to embedding (tie_word_embeddings=True)")