
mobilenetv2 doesn't work with Vulkan backend #6516

@sternezsl

Description

🐛 Describe the bug

I can successfully export the Vulkan .pte file. When I run the model with

./backends/vulkan/vulkan_executor_runner --model_path /scratch/models/vulkan_mobilenetv2.pte

I get the error:

I 00:00:00.001717 executorch:executor_runner.cpp:82] Model file /scratch/models/vulkan_mobilenetv2.pte is loaded.
I 00:00:00.001730 executorch:executor_runner.cpp:91] Using method forward
I 00:00:00.001732 executorch:executor_runner.cpp:138] Setting up planned buffer 0, size 606112.
libc++abi: terminating due to uncaught exception of type vkcompute::vkapi::Error: Exception raised from check_conv_args at /scratch/code/meta/executorch/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp:225: (check_packed_dim_is(in, WHCN::kChannelsDim)) is false!
[1] 544559 IOT instruction (core dumped) ./backends/vulkan/vulkan_executor_runner --model_path
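The failing assertion requires the convolution input to be channels-packed. The Vulkan backend numbers tensor dimensions in WHCN order (width, height, channels, batch), which is the reverse of PyTorch's NCHW order, so WHCN::kChannelsDim is index 2 while index 0 is the width dimension. The sketch below, assuming a 4-D NCHW tensor, maps a WHCN index to the matching PyTorch dimension. `WHCN_NAMES`, `whcn_to_nchw_dim`, and `describe_packing` are hypothetical illustration helpers, not ExecuTorch APIs.

```python
# Hypothetical helpers (not ExecuTorch APIs): illustrate how a WHCN dimension
# index, as used for the Vulkan backend's packed_dim, maps onto an NCHW shape.
WHCN_NAMES = ("width", "height", "channels", "batch")


def whcn_to_nchw_dim(whcn_dim: int, ndim: int = 4) -> int:
    """Return the NCHW (PyTorch) dim index for a WHCN dim index.

    WHCN counts from the innermost (last) PyTorch dim, so it is the
    reverse of the NCHW order.
    """
    if not 0 <= whcn_dim < ndim:
        raise ValueError(f"whcn_dim must be in [0, {ndim}), got {whcn_dim}")
    return ndim - 1 - whcn_dim


def describe_packing(packed_dim: int, shape: tuple) -> str:
    """Describe which dimension of an NCHW-shaped tensor is packed."""
    nchw_dim = whcn_to_nchw_dim(packed_dim, len(shape))
    return f"{WHCN_NAMES[packed_dim]}-packed (NCHW dim {nchw_dim}, size {shape[nchw_dim]})"


if __name__ == "__main__":
    shape = (1, 3, 224, 224)  # MobileNetV2 sample input, NCHW
    print(describe_packing(0, shape))  # -> width-packed (NCHW dim 3, size 224)
    print(describe_packing(2, shape))  # -> channels-packed (NCHW dim 1, size 3)
```

In other words, a packed_dim of 0 means texels are packed along the 224-wide width axis, while the convolution kernel expects them packed along the 3-element channel axis.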

I inspected the code and found that the input tensor's packed_dim is 0 (width) rather than 2 (WHCN::kChannelsDim). If I comment out the check_conv_args call, I run into a different error:

I 00:00:00.000351 executorch:executor_runner.cpp:82] Model file /scratch/models/vulkan_mobilenetv2.pte is loaded.
I 00:00:00.000358 executorch:executor_runner.cpp:91] Using method forward
I 00:00:00.000360 executorch:executor_runner.cpp:138] Setting up planned buffer 0, size 606112.
libc++abi: terminating due to uncaught exception of type vkcompute::vkapi::Error: Exception raised from mean at /scratch/code/meta/executorch/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp:112: (dims_list->size() == 1) is false!
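The second error comes from the mean operator, which in this build only accepts a single reduction dimension (dims_list->size() == 1). A likely, though unverified, source of a multi-dimension mean in MobileNetV2 is its global average pool: adaptive_avg_pool2d with a 1x1 output is equivalent to a mean over both height and width. Such a reduction can in principle be rewritten as two chained single-dimension means. The pure-Python sketch below shows only that equivalence on a single channel; `mean_hw`, `mean_last_dim`, and `mean_chained` are hypothetical names, and this is not a claim about how the Vulkan backend handles, or should handle, this case.

```python
# A minimal sketch (pure Python, no ExecuTorch/Vulkan APIs): a mean over two
# dims equals two chained single-dim means, assuming a rectangular H x W grid.
from typing import List

Matrix = List[List[float]]


def mean_hw(channel: Matrix) -> float:
    """Mean over height and width at once (like mean(dim=[-2, -1]))."""
    values = [v for row in channel for v in row]
    return sum(values) / len(values)


def mean_last_dim(channel: Matrix) -> List[float]:
    """Single-dim mean over width: one value per row."""
    return [sum(row) / len(row) for row in channel]


def mean_chained(channel: Matrix) -> float:
    """Mean over width first, then over height: two single-dim reductions."""
    row_means = mean_last_dim(channel)
    return sum(row_means) / len(row_means)


if __name__ == "__main__":
    # One 3x4 channel; a real feature map would be shaped [N, C, H, W].
    channel = [[1.0, 2.0, 3.0, 4.0],
               [5.0, 6.0, 7.0, 8.0],
               [9.0, 10.0, 11.0, 12.0]]
    print(mean_hw(channel))       # 6.5
    print(mean_chained(channel))  # 6.5
```

The equivalence relies on every row having the same width, so averaging the row means weights each element equally.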

I noticed that @SS-JIA replaced the MobileNet demo with an Add example in the tutorial a few days ago, so I guess you are aware of the problem. I would like to help fix it, but I don't know much about the Vulkan backend yet. Could you please give me some hints? I'll try to fix it.

The following is the model conversion script:

import torch
import torchvision.models as models

from torch.export import export, ExportedProgram
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.backends.vulkan import VulkanPartitioner
from executorch.exir import EdgeProgramManager, ExecutorchProgramManager, to_edge_transform_and_lower

# Pretrained MobileNetV2 in eval mode, with a single NCHW sample input.
mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224), )

# Export to an ATen graph, then lower to Edge dialect and delegate
# supported subgraphs to the Vulkan backend.
exported_program: ExportedProgram = export(mobilenet_v2, sample_inputs)
edge: EdgeProgramManager = to_edge_transform_and_lower(
    exported_program,
    partitioner=[VulkanPartitioner()],
)

# print(edge.exported_program().graph_module)

# Serialize to an ExecuTorch program and write the .pte file.
exec_prog: ExecutorchProgramManager = edge.to_executorch()

with open("vulkan_mobilenetv2.pte", "wb") as file:
    exec_prog.write_to_file(file)

Versions

PyTorch version: 2.5.0+git8a0ce38
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A
VULKAN used to build PyTorch: True

OS: Fedora Linux Asahi Remix 40 (KDE Plasma) (aarch64)
GCC version: (GCC) 14.2.1 20240912 (Red Hat 14.2.1-3)
Clang version: 18.1.7 (https://github.com/llvm/llvm-project 768118d1ad38bf13c545828f67bd6b474d61fc55)
CMake version: version 3.20.0
Libc version: glibc-2.39
Vulkan Driver Version: Mesa 24.3.0-devel (git-f05157f591),
Vulkan Instance Version: 1.3.290

Python version: 3.10.10 | packaged by conda-forge | (main, Mar 24 2023, 19:56:21) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-6.11.3-dimilar-4-1-edge-ARCH+-aarch64-with-glibc2.39
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU: Apple M1 Max

Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] flake8-breakpoint==1.1.0
[pip3] flake8-bugbear==23.9.16
[pip3] flake8-comprehensions==3.14.0
[pip3] flake8-plugin-utils==1.3.3
[pip3] flake8-pyi==23.5.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.25.0
[pip3] pytorch-sphinx-theme==0.0.19
[pip3] torchao==0.7.0+git6b529961
[pip3] torchvision==0.20.0a0+f851df1
[conda] numpy 1.25.0 pypi_0 pypi
[conda] pytorch-sphinx-theme 0.0.19 pypi_0 pypi
[conda] torchao 0.7.0+git6b529961 pypi_0 pypi
[conda] torchfix 0.5.0 pypi_0 pypi
[conda] torchvision 0.20.0a0+f851df1 pypi_0 pypi

cc @SS-JIA @manuelcandales

Labels: module: vulkan (Issues related to the Vulkan delegate and code under backends/vulkan/); triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Status: Done
