
call_module is not supported error with _guards_fn during quantization with Arm backend #16248

@ofirgo

Description


🐛 Describe the bug

When attempting to quantize a ViT-B model (from torchvision) using ExecuTorch's Arm backend quantizer, the following error occurs during the prepare_pt2e step:

executorch.exir.pass_base.ExportPassBaseError: call_module is not supported. While executing %_guards_fn : [num_users=0] = call_module[target=_guards_fn](args = (%x,), kwargs = {})

The error is raised from the DecomposeRoundPass in the Arm pass manager's transform pipeline.

Reproduction Steps

import torch
from executorch.backends.arm.ethosu import EthosUCompileSpec
from executorch.backends.arm.quantizer import EthosUQuantizer, get_symmetric_quantization_config
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

batch_size = 4
example_input = torch.randn(batch_size, 3, 224, 224)

# Pretrained ViT-B/16 from torchvision
weights = ViT_B_16_Weights.IMAGENET1K_V1
model = vit_b_16(weights=weights)
model.eval()

# Export to an FX graph module
exported_program = torch.export.export(model, (example_input,))
graph_model = exported_program.module()

# Ethos-U55 compile spec and matching quantizer
compile_spec = EthosUCompileSpec(
    target="ethos-u55-128",
    system_config="Ethos_U55_High_End_Embedded",
    memory_mode="Shared_Sram",
)

quantizer = EthosUQuantizer(compile_spec)
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)

# Error occurs here
prepared_model = prepare_pt2e(graph_model, quantizer)
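
A possible workaround might be to erase the guard node from the FX graph before calling prepare_pt2e. This is only an untested sketch: it assumes the _guards_fn call_module node is safe to remove, which the traceback suggests, since it reports num_users=0.

# Untested sketch: erase the `_guards_fn` call_module node that
# torch.export inserts for input guards (num_users=0 per the traceback),
# then drop the now-unreferenced submodule.
for node in list(graph_model.graph.nodes):
    if node.op == "call_module" and node.target == "_guards_fn":
        graph_model.graph.erase_node(node)
if hasattr(graph_model, "_guards_fn"):
    graph_model.delete_submodule("_guards_fn")
graph_model.graph.lint()
graph_model.recompile()

prepared_model = prepare_pt2e(graph_model, quantizer)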

Alternatively, when attempting to disable guards by replacing the exported module with:

exported_program = torch.export.export(model, (example_input,))
graph_model = exported_program.module(check_guards=False)

I then encountered a different problem when trying to move the model to eval mode after converting:

from torchao.quantization.pt2e import move_exported_model_to_eval

prepared_model = prepare_pt2e(graph_model, quantizer)
# Calibrate on the example input
with torch.no_grad():
    prepared_model(example_input)

quantized_graph_model = convert_pt2e(prepared_model, fold_quantize=True)
quantized_graph_model = move_exported_model_to_eval(quantized_graph_model)  # fails here

Error:
KeyError: _guards_fn
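
If the KeyError comes from move_exported_model_to_eval looking the _guards_fn submodule up by name (an assumption; I have not traced it), the same cleanup sketched above, applied to the converted model, might avoid it:

# Untested sketch: drop any leftover `_guards_fn` node/submodule before
# move_exported_model_to_eval walks the graph.
for node in list(quantized_graph_model.graph.nodes):
    if node.op == "call_module" and node.target == "_guards_fn":
        quantized_graph_model.graph.erase_node(node)
if hasattr(quantized_graph_model, "_guards_fn"):
    quantized_graph_model.delete_submodule("_guards_fn")
quantized_graph_model.recompile()

quantized_graph_model = move_exported_model_to_eval(quantized_graph_model)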

Versions

PyTorch version: 2.9.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.5 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.0.13.5)
CMake version: Could not collect
Libc version: N/A

Python version: 3.12.10 (v3.12.10:0cc81280367, Apr 8 2025, 08:46:59) [Clang 13.0.0 (clang-1300.0.29.30)] (64-bit runtime)
Python platform: macOS-15.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Apple M4 Max

Versions of relevant libraries:
[pip3] executorch==1.0.1
[pip3] numpy==2.3.5
[pip3] pytorch_tokenizers==1.0.1
[pip3] torch==2.9.0
[pip3] torchao==0.14.0
[pip3] torchvision==0.24.0
[conda] Could not collect

cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai
