From aadf80e1d9a8ff43af089806a103eacb69eb1e66 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Wed, 18 Mar 2026 10:54:09 -0400 Subject: [PATCH] Voxtral Realtime: enable bf16 for Metal backend with quantization (#17845) The Metal AOTI backend already handles bf16 correctly (fp32 attention masks, fp32 RoPE upcast, dtype-agnostic KV caches and SDPA). Enable --dtype bf16 as the default recipe for Metal CI and update all documentation to recommend bf16 with fpa4w quantization. (cherry picked from commit 202c6af42128ea040e9dfffac20441e98ba7277e) --- .ci/scripts/export_model_artifact.sh | 1 + examples/models/voxtral_realtime/README.md | 1 + .../voxtral_realtime/export_voxtral_rt.py | 2 +- examples/models/voxtral_realtime/model.md | 24 +++++++++++-------- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh index 220af45a904..f60efeec888 100755 --- a/.ci/scripts/export_model_artifact.sh +++ b/.ci/scripts/export_model_artifact.sh @@ -308,6 +308,7 @@ if [ "$MODEL_NAME" = "voxtral_realtime" ]; then VR_QUANT_ARGS="--qlinear-encoder 8da4w --qlinear 8da4w --qlinear-group-size 32 --qembedding 8w" elif [ "$QUANT_NAME" = "quantized-int4-metal" ]; then VR_QUANT_ARGS="--qlinear-encoder fpa4w --qlinear fpa4w" + VR_DTYPE_ARGS="--dtype bf16" elif [ "$QUANT_NAME" = "quantized-int4-tile-packed" ]; then VR_QUANT_ARGS="--qlinear-encoder 4w --qlinear-encoder-packing-format tile_packed_to_4d --qlinear 4w --qlinear-packing-format tile_packed_to_4d --qembedding 8w" VR_DTYPE_ARGS="--dtype bf16" diff --git a/examples/models/voxtral_realtime/README.md b/examples/models/voxtral_realtime/README.md index 45699133ee4..e2a16f60192 100644 --- a/examples/models/voxtral_realtime/README.md +++ b/examples/models/voxtral_realtime/README.md @@ -84,6 +84,7 @@ python export_voxtral_rt.py \ python export_voxtral_rt.py \ --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \ --backend metal \ + --dtype bf16 \ --streaming \ --output-dir ./voxtral_rt_exports \ --qlinear-encoder fpa4w \ diff --git a/examples/models/voxtral_realtime/export_voxtral_rt.py b/examples/models/voxtral_realtime/export_voxtral_rt.py index 951f1f606d5..68190a1c8c0 100644 --- a/examples/models/voxtral_realtime/export_voxtral_rt.py +++ b/examples/models/voxtral_realtime/export_voxtral_rt.py @@ -30,7 +30,7 @@ Usage: python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --streaming - python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --backend metal + python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --backend metal --dtype bf16 --qlinear-encoder fpa4w --qlinear fpa4w python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --backend cuda --qlinear 4w """ diff --git a/examples/models/voxtral_realtime/model.md b/examples/models/voxtral_realtime/model.md index 987a18bbcb5..856524cc083 100644 --- a/examples/models/voxtral_realtime/model.md +++ b/examples/models/voxtral_realtime/model.md @@ -74,13 +74,15 @@ or masked-scatter like the original non-realtime Voxtral). ## Memory Footprint -Decoder KV cache: 26 layers × 2 (K, V) × 4096 × 8 × 128 × 4 bytes -≈ 832 MB. Encoder KV caches (streaming): 32 layers × 2 × 1500 × 32 × -64 × 4 bytes ≈ 786 MB. +Decoder KV cache: 26 layers × 2 (K, V) × 4096 × 8 × 128 × bytes_per_elem. +fp32: ≈ 832 MB, bf16: ≈ 416 MB. Encoder KV caches (streaming): +32 layers × 2 × 1500 × 32 × 64 × bytes_per_elem. fp32: ≈ 786 MB, +bf16: ≈ 393 MB. Runtime memory = model weights (from `.pte`) + KV caches + working -memory. Weight sizes depend on quantization: ~16 GB (fp32), ~4 GB -(8w), ~2 GB (4w/8da4w). +memory. Weight sizes depend on quantization: ~16 GB (fp32), ~8 GB +(bf16), ~4 GB (8w), ~2 GB (4w/8da4w). Metal and CUDA backends are recommended to use +bf16 (`--dtype bf16`) when quantization is enabled. ## Class Hierarchy @@ -163,8 +165,9 @@ fused kernel with causal masking via `start_pos` + `is_causal=True`. Handles GQA expansion internally and upcasts to float32. **Metal:** `MetalSDPA` uses `torch.ops.aten._scaled_dot_product_attention_math_for_mps` -which handles GQA natively via `gqa_factor`, avoiding the memory bandwidth -overhead of `repeat_interleave`. Uses explicit additive attention masks +which handles GQA natively (the kernel infers the group ratio from differing +Q vs K/V head counts), avoiding the memory bandwidth overhead of +`repeat_interleave`. Uses explicit additive attention masks that must match the Q/K/V dtype (the kernel reads masks as `device T*`). Used for both decoder (GQA, `transpose_kv=False`) and streaming encoder (no GQA, `transpose_kv=True`). @@ -280,7 +283,7 @@ enabling streaming of arbitrary length audio. 5-8, giving query 5 full access to its window. - Default `max_enc_len=750` (matching the model's trained sliding window). Configurable via `--max-enc-len`. -- Memory: 32 layers × 2 × 1500 × 32 × 64 × 4 bytes ≈ 786 MB (fp32) +- Memory: 32 layers × 2 × 1500 × 32 × 64 × bytes_per_elem ≈ 786 MB (fp32), 393 MB (bf16) - Duration: unlimited (ring buffer overwrites old entries, RoPE computed on-the-fly) **Naming note:** `max_enc_len` in `StreamingAudioEncoderExport` (default @@ -370,7 +373,7 @@ Parakeet pattern), allowing different configs for encoder vs decoder: --qlinear 8da4w # decoder linear layers --qembedding 8w # embedding layer -# Metal +# Metal (use --dtype bf16 for reduced memory and improved throughput) --qlinear-encoder fpa4w # encoder linear layers --qlinear fpa4w # decoder linear layers @@ -428,7 +431,8 @@ of ~34 GB for the full-size model): 1. **Meta device construction** — `with torch.device("meta"):` builds the model with zero-storage parameter tensors (shape/dtype metadata only). 2. **safetensors lazy access** — `safe_open` loads tensors on demand, cast - to the configured dtype (`--dtype`, default fp32; CUDA uses bf16). + to the configured dtype (`--dtype`, default fp32; bf16 recommended for + Metal and CUDA with quantization). 3. **`assign=True` state dict loading** — replaces meta tensors by reference instead of copying into pre-allocated storage. No duplication. 4. **Post-load fixups** — re-tie `output.weight = tok_embeddings.weight`