🐛 Describe the bug
While exporting the whisper-tiny.en model to the ExecuTorch format, I ran into the following problem: one of the matrix multiplication operations is not delegated to the XNNPACK backend, which slows inference down significantly (the operation accounts for up to 50% of total inference time).
The operation is located in the forward method of the decoder module from the openai-whisper package:
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
    """
    x : torch.LongTensor, shape = (batch_size, <= n_ctx)
        the text tokens
    xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
        the encoded audio features to be attended on
    """
    offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
    x = (
        self.token_embedding(x)
        + self.positional_embedding[offset : offset + x.shape[-1]]
    )
    x = x.to(xa.dtype)

    for block in self.blocks:
        x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

    x = self.ln(x)
    logits = (
        # ISSUE: the below matrix multiplication is not being delegated
        x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
    ).float()

    return logits
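For isolation, here is a minimal sketch that reproduces just this tied-weight projection pattern outside of whisper. The TiedProjection class and its dimensions are illustrative (384 matches the tiny.en hidden size; the vocabulary size is approximate), not part of the original code:

# Minimal, self-contained sketch (illustrative class and dimensions) that
# isolates the tied-weight output projection x @ W.T from the decoder above.
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

class TiedProjection(torch.nn.Module):
    def __init__(self, n_vocab: int = 51864, n_state: int = 384):
        super().__init__()
        self.token_embedding = torch.nn.Embedding(n_vocab, n_state)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same pattern as the decoder's output projection.
        return (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()

module = TiedProjection().eval()
example_inputs = (torch.randn(1, 128, 384),)
program = to_edge_transform_and_lower(
    torch.export.export(module, example_inputs),
    partitioner=[XnnpackPartitioner()],
).to_executorch()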
Steps to reproduce
I use the following code to export the module with static shapes:
import torch
import whisper
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

model = whisper.load_model("tiny.en").decoder
model.eval()

inputs = (
    torch.randint(1, 100, (1, 128), dtype=torch.int),
    torch.randn(1, 1500, 384, dtype=torch.float32),
)

# Export
exported_program = torch.export.export(model, inputs)
executorch_program = to_edge_transform_and_lower(
    exported_program,
    partitioner=[XnnpackPartitioner()],
).to_executorch()

with open("whisper/exported/decoder.pte", "wb") as file:
    executorch_program.write_to_file(file)
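As a quick way to confirm which operators end up outside the XNNPACK partitions (before full on-device profiling), the delegation summary can be printed from the edge program. This is a sketch that assumes the executorch.devtools.backend_debug.get_delegation_info helper is available in this ExecuTorch build, and it reuses exported_program from the snippet above:

# Sketch: print delegation stats for the lowered program (assumes the
# executorch.devtools.backend_debug helper exists in this ExecuTorch build).
from executorch.devtools.backend_debug import get_delegation_info

edge_manager = to_edge_transform_and_lower(
    exported_program,
    partitioner=[XnnpackPartitioner()],
)
delegation_info = get_delegation_info(edge_manager.exported_program().graph_module)
print(delegation_info.get_summary())                        # delegated vs. non-delegated node counts
print(delegation_info.get_operator_delegation_dataframe())  # per-operator breakdown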
To profile the exported model, I follow the instructions from the ExecuTorch docs.
Actual behavior
A single native_call_mm.out operation not delegated to the XNNPACK backend is responsible for approximately 50% of the inference time.
The profiling results are available here.
What I tried
- Using other matrix multiplication functions (matmul() and mm()) - no effect.
- Replacing the matrix multiplication with an equivalent nn.Linear call with static shapes - this resolves the delegation issue and reduces inference time from approximately 145 ms to 80 ms (~45% speedup); see the sketch after this list.
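For concreteness, here is a minimal sketch of the nn.Linear workaround. The DecoderWithLinearHead wrapper name is hypothetical; it assumes export with static shapes and no kv_cache, and ties the projection weight to token_embedding so no extra parameters are introduced:

# Sketch of the workaround: express the output projection as an nn.Linear
# whose weight is tied to the token embedding, so the backend sees a
# standard linear op instead of a matmul with a transposed parameter.
import torch

class DecoderWithLinearHead(torch.nn.Module):
    def __init__(self, decoder: torch.nn.Module):
        super().__init__()
        self.decoder = decoder
        n_vocab, n_state = decoder.token_embedding.weight.shape
        self.output_proj = torch.nn.Linear(n_state, n_vocab, bias=False)
        # Tie the projection weight to the embedding weight (no copy).
        self.output_proj.weight = decoder.token_embedding.weight

    def forward(self, x: torch.Tensor, xa: torch.Tensor) -> torch.Tensor:
        offset = 0  # static-shape export path: no kv_cache
        h = (
            self.decoder.token_embedding(x)
            + self.decoder.positional_embedding[offset : offset + x.shape[-1]]
        )
        h = h.to(xa.dtype)
        for block in self.decoder.blocks:
            h = block(h, xa, mask=self.decoder.mask)
        h = self.decoder.ln(h)
        return self.output_proj(h).float()

The wrapper can then be exported with the same code as above, e.g. torch.export.export(DecoderWithLinearHead(model), inputs).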
Versions
Collecting environment information...
PyTorch version: 2.10.0.dev20250916
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.7.1 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.3.19.1)
CMake version: version 4.1.2
Libc version: N/A
Python version: 3.12.11 (main, Jun 3 2025, 15:41:47) [Clang 17.0.0 (clang-1700.0.13.3)] (64-bit runtime)
Python platform: macOS-15.7.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M1 Pro
Versions of relevant libraries:
[pip3] executorch==1.0.0.dev20250916
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.3.4
[pip3] optimum-executorch==0.2.0.dev0
[pip3] pytorch_tokenizers==0.1.0
[pip3] torch==2.10.0.dev20250916
[pip3] torchao==0.14.0.dev20250916+cpu
[pip3] torchaudio==2.8.0.dev20250916
[pip3] torchvision==0.25.0.dev20250916
[pip3] openai-whisper==20250625