From 9514e2837ba460c0e10718b345f6f45e73477a6b Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Tue, 18 Nov 2025 11:14:24 +0100 Subject: [PATCH] Arm backend: Move support_extension to base class Move support_extension to TosaSpecification base class to avoid having to check whether the TosaSpecification is an instance of TosaSpecification_1_00. Signed-off-by: Oscar Andersson Change-Id: Iccd28879ad4156d70fc5113d6f7e21870a30efa4 --- .../decompose_int16_activation_conv2d_pass.py | 6 ++---- backends/arm/tosa/specification.py | 12 ++++++++++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py b/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py index d12904bbcb9..2f160474c5b 100644 --- a/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py +++ b/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py @@ -10,7 +10,7 @@ from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.quant_args import QuantArgs -from executorch.backends.arm.tosa.specification import get_context_spec, Tosa_1_00 +from executorch.backends.arm.tosa.specification import get_context_spec from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -40,9 +40,7 @@ def call_operator(self, op, args, kwargs, meta): if args[0].data.dtype == torch.int8: return super().call_operator(op, args, kwargs, meta) elif args[0].data.dtype == torch.int16: - if isinstance(tosa_spec, Tosa_1_00) and not tosa_spec.support_extension( - "int16" - ): + if not tosa_spec.support_extension("int16"): raise ValueError( "int16 activation for convolution requires TOSA int16 extension" ) diff --git a/backends/arm/tosa/specification.py b/backends/arm/tosa/specification.py index 6fca2163d41..c6c79f9ad9a 100644 --- a/backends/arm/tosa/specification.py +++ b/backends/arm/tosa/specification.py @@ -105,6 +105,18 @@ def support_float(self) -> bool: """Return True if floating-point operations are supported.""" raise NotImplementedError + def support_extension(self, extension: str) -> bool: + """Return True if an extension is supported and enabled. + + Args: + extension (str): Extension name (for example ``int4``, ``bf16``). + + Returns: + bool: True if the extension is valid for the active profiles and selected. + + """ + raise NotImplementedError + def __init__(self, version: Version, extras: List[str]): """Initialize the base specification.