From 1348f0c37f4f375dce83c39bb6fc49a0dc0e0872 Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Fri, 3 Oct 2025 14:41:53 +0200 Subject: [PATCH 1/3] Cortex_m backend: Add mul op Signed-off-by: Adrian Lundell Change-Id: Ic116e5294d9362f3a43655629d2a3c0f338a2fd5 --- backends/cortex_m/CMakeLists.txt | 1 + backends/cortex_m/ops/op_quantized_mul.cpp | 87 +++++++++++++++++++ backends/cortex_m/ops/operators.py | 54 ++++++++++++ backends/cortex_m/ops/operators.yaml | 8 +- .../cortex_m/passes/cortex_m_pass_manager.py | 2 - backends/cortex_m/passes/passes_utils.py | 34 ++++++-- .../passes/quantized_op_fusion_pass.py | 27 ++++++ .../cortex_m/quantizer/operator_configs.py | 1 + backends/cortex_m/quantizer/quantizer.py | 5 +- backends/cortex_m/test/ops/test_mul.py | 39 ++++++--- backends/cortex_m/test/tester.py | 1 + 11 files changed, 233 insertions(+), 26 deletions(-) create mode 100644 backends/cortex_m/ops/op_quantized_mul.cpp diff --git a/backends/cortex_m/CMakeLists.txt b/backends/cortex_m/CMakeLists.txt index a728584e49c..c695ee615d5 100644 --- a/backends/cortex_m/CMakeLists.txt +++ b/backends/cortex_m/CMakeLists.txt @@ -57,6 +57,7 @@ set(_cortex_m_kernels__srcs ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_linear.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_mul.cpp ) # Generate C++ bindings to register kernels into Executorch diff --git a/backends/cortex_m/ops/op_quantized_mul.cpp b/backends/cortex_m/ops/op_quantized_mul.cpp new file mode 100644 index 00000000000..db56690e411 --- /dev/null +++ b/backends/cortex_m/ops/op_quantized_mul.cpp @@ -0,0 +1,87 @@ +/* + * Copyright 2025 Arm Limited and/or its affiliates. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */
+
+#include "cortex_m_ops_common.h"
+
+// Include CMSIS-NN headers with C linkage
+extern "C" {
+#include "arm_nnfunctions.h"
+}
+
+namespace cortex_m {
+namespace native {
+namespace {
+
+constexpr int32_t kInt8ActivationMin = std::numeric_limits<int8_t>::min();
+constexpr int32_t kInt8ActivationMax = std::numeric_limits<int8_t>::max();
+
+} // namespace
+
+using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
+
+Tensor& quantized_mul_out(
+    KernelRuntimeContext& context,
+    const Tensor& input1_int8,
+    const Scalar& input1_zero_point,
+    const Tensor& input2_int8,
+    const Scalar& input2_zero_point,
+    const Scalar& output_zero_point,
+    const Scalar& output_multiplier,
+    const Scalar& output_shift,
+    Tensor& out) {
+  // Validate tensor types and quantization parameters
+  validate_cmsis_nn_tensor_requirements(input1_int8, input2_int8, out);
+
+  const Scalar kIdentityMultiplier(/*value=*/1);
+  const Scalar kZeroShift(/*value=*/0);
+  validate_quantization_params(
+      input1_zero_point,
+      kIdentityMultiplier,
+      kZeroShift,
+      input2_zero_point,
+      kIdentityMultiplier,
+      kZeroShift,
+      output_zero_point,
+      output_multiplier,
+      output_shift,
+      out);
+
+  // Extract quantization parameters
+  const int32_t zp1 = extractScalarToInt32(input1_zero_point);
+  const int32_t zp2 = extractScalarToInt32(input2_zero_point);
+  const int32_t out_zp = extractScalarToInt32(output_zero_point);
+  const int32_t output_mult = extractScalarToInt32(output_multiplier);
+  const int32_t output_shift_val = extractScalarToInt(output_shift);
+
+  // Call CMSIS-NN elementwise multiply kernel
+  arm_cmsis_nn_status status = arm_elementwise_mul_s8(
+      input1_int8.const_data_ptr<int8_t>(),
+      input2_int8.const_data_ptr<int8_t>(),
+      -static_cast<int32_t>(zp1),
+      -static_cast<int32_t>(zp2),
+      out.mutable_data_ptr<int8_t>(),
+      static_cast<int32_t>(out_zp),
+      output_mult,
+      output_shift_val,
+      kInt8ActivationMin,
+      kInt8ActivationMax,
+      static_cast<int32_t>(out.numel()));
+
+  if (status != ARM_CMSIS_NN_SUCCESS) {
+    ET_LOG(
+        Error,
+        "quantized_mul_out: arm_elementwise_mul_s8 failed with status [%d]",
+        status);
+    context.fail(Error::Internal);
+    return out;
+  }
+
+  return out;
+}
+
+} // namespace native
+} // namespace cortex_m
diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py
index b8abfb9bde4..dea7c848b71 100644
--- a/backends/cortex_m/ops/operators.py
+++ b/backends/cortex_m/ops/operators.py
@@ -168,6 +168,60 @@ def quantized_add_impl(
     return result
 
 
+# ===================================================================
+# QUANTIZED MUL OPERATION DEFINITION
+# ===================================================================
+lib.define(
+    "quantized_mul("
+    "Tensor self, Scalar self_zero_point, "
+    "Tensor other, Scalar other_zero_point, "
+    "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor"
+)
+lib.define(
+    "quantized_mul.out("
+    "Tensor self, Scalar self_zero_point, "
+    "Tensor other, Scalar other_zero_point, "
+    "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, "
+    "*, Tensor(a!) 
out) -> Tensor(a!)" +) + + +@register_fake("cortex_m::quantized_mul") +def quantized_mul_meta( + self: torch.Tensor, + self_zero_point: int, + other: torch.Tensor, + other_zero_point: int, + output_zero_point: int, + output_multiplier: int, + output_shift: int, +) -> torch.Tensor: + # Broadcast to output shape + broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape) + return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device) + + +@impl(lib, "quantized_mul", "CompositeExplicitAutograd") +def quantized_mul_impl( + self: torch.Tensor, + self_zero_point: int, + other: torch.Tensor, + other_zero_point: int, + output_zero_point: int, + output_multiplier: int, + output_shift: int, +) -> torch.Tensor: + # CMSIS-NN kernel multiplies raw int8 tensors (after zero-point offset) and + # only uses the output multiplier/shift for rescaling. Mirror that here to + # keep the composite implementation numerically aligned with the backend. + self_int = self.to(torch.int32) - self_zero_point + other_int = other.to(torch.int32) - other_zero_point + result_fp = self_int * other_int + result_quantized = requantize_cmsis(result_fp, output_multiplier, output_shift) + result = torch.clamp(result_quantized + output_zero_point, -128, 127).to(torch.int8) + return result + + # =================================================================== # QUANTIZED LINEAR OPERATION DEFINITION # =================================================================== diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml index 98d8df8797e..08a8e1ad1b5 100644 --- a/backends/cortex_m/ops/operators.yaml +++ b/backends/cortex_m/ops/operators.yaml @@ -23,8 +23,14 @@ - arg_meta: null kernel_name: cortex_m::quantized_add_out +- func: cortex_m::quantized_mul.out(Tensor self, Scalar self_zero_point, Tensor other, Scalar other_zero_point, Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: cortex_m::quantized_mul_out + - func: cortex_m::quantized_linear.out(Tensor input, Tensor weights, Tensor? bias, Tensor? kernel_sum, Scalar input_offset, Scalar filter_offset, Scalar output_offset, int[] requantize_multipliers, int[] requantize_shifts, Scalar activation_max, Scalar activation_min, *, Tensor(a!) out) -> Tensor(a!) 
  variants: function
  kernels:
    - arg_meta: null
-      kernel_name: cortex_m::quantized_linear_out
\ No newline at end of file
+      kernel_name: cortex_m::quantized_linear_out
diff --git a/backends/cortex_m/passes/cortex_m_pass_manager.py b/backends/cortex_m/passes/cortex_m_pass_manager.py
index 10fb358c70e..2b880f5ed05 100644
--- a/backends/cortex_m/passes/cortex_m_pass_manager.py
+++ b/backends/cortex_m/passes/cortex_m_pass_manager.py
@@ -5,7 +5,6 @@
 
 
 from executorch.backends.arm._passes import (
-    DecorateFp32toInt32CastingPass,
     FoldAndAnnotateQParamsPass,
     ScalarsToAttributePass,
 )
@@ -29,7 +28,6 @@ class CortexMPassManager(XNNPACKPassManager):
         ReplaceQuantNodesPass,
         QuantizedOpFusionPass,
         QuantizedLinearFusionPass,
-        DecorateFp32toInt32CastingPass,
     ]
 
     pass_list_transform_for_annotation: list[ExportPass] = [
diff --git a/backends/cortex_m/passes/passes_utils.py b/backends/cortex_m/passes/passes_utils.py
index b045005d34d..de07db2443a 100644
--- a/backends/cortex_m/passes/passes_utils.py
+++ b/backends/cortex_m/passes/passes_utils.py
@@ -50,14 +50,32 @@ def requantize_cmsis(
     multiplier: int,
     shift: int,
 ) -> torch.Tensor:
-    """
-    Simulate CMSIS-NN fixed-point requantization:
-    result = round(tensor * multiplier / (2 ^ shift))
-    with double rounding
-    """
-    multiplied = torch.round(tensor.to(torch.int64) * multiplier)
-    shifted = torch.round(multiplied / (2 ** (31 - shift)))
-    return shifted.to(torch.int32)
+    """Simulate CMSIS-NN's arm_nn_requantize helper."""
+
+    tensor_64 = tensor.to(torch.int64)
+    left_shift = max(shift, 0)
+    right_shift = max(-shift, 0)
+
+    # Equivalent to val * (1 << LEFT_SHIFT(shift))
+    value = tensor_64 << left_shift
+
+    # arm_nn_doubling_high_mult_no_sat(value, multiplier)
+    product = value * int(multiplier)
+    product = product + (1 << 30)
+    result = product >> 31
+
+    if right_shift:
+        remainder_mask = (1 << right_shift) - 1
+        remainder = torch.bitwise_and(result, remainder_mask)
+        result = result >> right_shift
+        threshold = remainder_mask >> 1
+        threshold_tensor = torch.full_like(result, threshold, dtype=torch.int64)
+        threshold_tensor = torch.where(
+            result < 0, threshold_tensor + 1, threshold_tensor
+        )
+        result = result + torch.where(remainder > threshold_tensor, 1, 0)
+
+    return result.to(torch.int32)
 
 
 def extract_scalar_value(node_arg) -> float:
diff --git a/backends/cortex_m/passes/quantized_op_fusion_pass.py b/backends/cortex_m/passes/quantized_op_fusion_pass.py
index df35c8d626a..9d913adc411 100644
--- a/backends/cortex_m/passes/quantized_op_fusion_pass.py
+++ b/backends/cortex_m/passes/quantized_op_fusion_pass.py
@@ -64,6 +64,31 @@ def _get_add_replacement(self, args, meta):
 
         return exir_ops.edge.cortex_m.quantized_add.default, args
 
+    def _get_mul_replacement(self, args, meta):
+
+        # Extract values
+        scale1 = meta["input_qparams"][0].scale
+        zero_point1 = meta["input_qparams"][0].zp
+        scale2 = meta["input_qparams"][1].scale
+        zero_point2 = meta["input_qparams"][1].zp
+        output_scale = meta["output_qparams"][0].scale
+        output_zero_point = meta["output_qparams"][0].zp
+
+        scale_factor = (scale1 * scale2) / output_scale
+        output_mult, output_shift = quantize_multiplier_aot(scale_factor)
+
+        args = (
+            args[0],
+            zero_point1,
+            args[1],
+            zero_point2,
+            output_zero_point,
+            output_mult,
+            output_shift,
+        )
+
+        return exir_ops.edge.cortex_m.quantized_mul.default, args
+
     def call_operator(
         self,
         op: EdgeOpOverload,
@@ -80,6 +105,8 @@ def call_operator(
         match op:
             case exir_ops.edge.aten.add.Tensor:
                 op, args = self._get_add_replacement(args, meta)
+            case 
exir_ops.edge.aten.mul.Tensor: + op, args = self._get_mul_replacement(args, meta) case _: pass diff --git a/backends/cortex_m/quantizer/operator_configs.py b/backends/cortex_m/quantizer/operator_configs.py index 6ffc011df27..2936129819a 100644 --- a/backends/cortex_m/quantizer/operator_configs.py +++ b/backends/cortex_m/quantizer/operator_configs.py @@ -17,6 +17,7 @@ # ----------------- OPERATOR PATTERN PRESETS ----------------- BINARY_OP_PATTERNS = [ [torch.ops.aten.add.Tensor], + [torch.ops.aten.mul.Tensor], ] LINEAR_OP_PATTERNS = [ diff --git a/backends/cortex_m/quantizer/quantizer.py b/backends/cortex_m/quantizer/quantizer.py index 1309c463b75..8fc88fada56 100644 --- a/backends/cortex_m/quantizer/quantizer.py +++ b/backends/cortex_m/quantizer/quantizer.py @@ -6,13 +6,12 @@ from typing import Callable, List, Optional -import torch - from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager from executorch.backends.cortex_m.quantizer.operator_configs import ( + BINARY_OP_PATTERNS, INT8_BINARY_OPS_OPERATOR_CONFIG, INT8_LINEAR_OPERATOR_CONFIG, ) @@ -37,7 +36,7 @@ def broadcasting_filter(self, node: Optional[Node]) -> bool: """ if node is None: return False - if node.target not in [torch.ops.aten.add.Tensor]: + if [node.target] not in BINARY_OP_PATTERNS: return False if len(node.all_input_nodes) == 2: diff --git a/backends/cortex_m/test/ops/test_mul.py b/backends/cortex_m/test/ops/test_mul.py index a2f13760bf0..54a43121eaf 100644 --- a/backends/cortex_m/test/ops/test_mul.py +++ b/backends/cortex_m/test/ops/test_mul.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. 
-import pytest import torch from executorch.backends.arm.test.common import parametrize from executorch.backends.cortex_m.test.tester import ( @@ -60,6 +59,16 @@ class CortexMTensorMul(Model): } +class CortexMTensorMulBroadCast(Model): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, + } + + test_cases = { "self_scalar": McuTestCase( CortexMSelfMul(), @@ -91,22 +100,22 @@ class CortexMTensorMul(Model): ), "tensor_scalar": McuTestCase( CortexMScalarMul(), - (torch.ones(2, 2), 1.0), + (torch.ones(1), 1.0), ), "scalar_tensor": McuTestCase( CortexMScalarMul(), - (1000.0, torch.ones(2, 2)), + (1000.0, torch.ones(1)), ), "broadcast_1": McuTestCase( - CortexMTensorMul(), + CortexMTensorMulBroadCast(), (torch.ones(1), torch.ones(2, 2, 2, 2)), ), "broadcast_2": McuTestCase( - CortexMTensorMul(), + CortexMTensorMulBroadCast(), (torch.ones((2, 1, 1, 1)), torch.ones(1)), ), "broadcast_3": McuTestCase( - CortexMTensorMul(), + CortexMTensorMulBroadCast(), ( ramp_tensor(-2, 2, (2, 1, 2, 1)), ramp_tensor(-5, 5, (1, 2, 1, 2)), @@ -115,17 +124,23 @@ class CortexMTensorMul(Model): } -@pytest.mark.skip(reason="Not implemented yet") -@parametrize("test_case", test_cases) +xfail_cases = { + "self_scalar": "lift_constant_tensor_pass assumes fake tensors for scalars", + "scalar_scalar": "lift_constant_tensor_pass assumes fake tensors for scalars", +} + + +@parametrize("test_case", test_cases, xfails=xfail_cases) def test_dialect_mul(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) tester.test_dialect( - test_case.model.ops_before_transforms, test_case.model.ops_after_transforms + test_case.model.ops_before_transforms, + test_case.model.ops_after_transforms, + qtol=1, ) -@pytest.mark.skip(reason="Not implemented yet") -@parametrize("test_case", test_cases) +@parametrize("test_case", test_cases, xfails=xfail_cases) def test_implementation_mul(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) - tester.test_implementation() + tester.test_implementation(qtol=1) diff --git a/backends/cortex_m/test/tester.py b/backends/cortex_m/test/tester.py index 19de71444cd..010cc7e4ace 100644 --- a/backends/cortex_m/test/tester.py +++ b/backends/cortex_m/test/tester.py @@ -22,6 +22,7 @@ ToEdge, ToExecutorch, ) + from executorch.exir import EdgeCompileConfig From d0eea50d61fdf7d8f9d0cb869b6248a72283348a Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Fri, 7 Nov 2025 13:49:33 +0100 Subject: [PATCH 2/3] Update broadcast tests xfails Signed-off-by: Adrian Lundell Change-Id: I6224ca5a890556ca9cf0929440f7aea8da1b5028 --- backends/cortex_m/ops/operators.py | 12 +++++++++ backends/cortex_m/test/ops/test_add.py | 37 +++++--------------------- backends/cortex_m/test/ops/test_mul.py | 29 ++++++++++---------- 3 files changed, 32 insertions(+), 46 deletions(-) diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index dea7c848b71..d6d6a317505 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -138,6 +138,9 @@ def quantized_add_meta( output_multiplier: int, output_shift: int, ) -> torch.Tensor: + assert ( + self.shape == other.shape + ), "Broadcasting not yet supported in Cortex-M backend." 
broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape) return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device) @@ -156,6 +159,9 @@ def quantized_add_impl( output_multiplier: int, output_shift: int, ) -> torch.Tensor: + assert ( + self.shape == other.shape + ), "Broadcasting not yet supported in Cortex-M backend." self_shifted = (self.to(torch.int32) - self_zero_point) << SHIFT_INT8 self_fp = requantize_cmsis(self_shifted, self_multiplier, self_shift) @@ -197,6 +203,9 @@ def quantized_mul_meta( output_shift: int, ) -> torch.Tensor: # Broadcast to output shape + assert ( + self.shape == other.shape + ), "Broadcasting not yet supported in Cortex-M backend." broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape) return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device) @@ -214,6 +223,9 @@ def quantized_mul_impl( # CMSIS-NN kernel multiplies raw int8 tensors (after zero-point offset) and # only uses the output multiplier/shift for rescaling. Mirror that here to # keep the composite implementation numerically aligned with the backend. + assert ( + self.shape == other.shape + ), "Broadcasting not yet supported in Cortex-M backend." self_int = self.to(torch.int32) - self_zero_point other_int = other.to(torch.int32) - other_zero_point result_fp = self_int * other_int diff --git a/backends/cortex_m/test/ops/test_add.py b/backends/cortex_m/test/ops/test_add.py index 458d5361347..8c355fd2e39 100644 --- a/backends/cortex_m/test/ops/test_add.py +++ b/backends/cortex_m/test/ops/test_add.py @@ -139,7 +139,7 @@ class CortexMAlphaAdd(ModelAlpha): } -dialect_xfails = { +xfails = { "self_scalar": ( "'float' object has not attribute 'fake_mode' - scalar only ops not supported.", AttributeError, @@ -152,10 +152,13 @@ class CortexMAlphaAdd(ModelAlpha): "Expecting kwargs for aten op IR to be empty - alpha arg not supported.", AssertionError, ), + "broadcast_1": "Broadcasting not yet supported in Cortex-M backend", + "broadcast_2": "Broadcasting not yet supported in Cortex-M backend", + "broadcast_3": "Broadcasting not yet supported in Cortex-M backend", } -@parametrize("test_case", test_cases, xfails=dialect_xfails) +@parametrize("test_case", test_cases, xfails=xfails) def test_dialect_add(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) tester.test_dialect( @@ -163,35 +166,7 @@ def test_dialect_add(test_case): ) -implementation_xfails = { - "self_scalar": ( - "'float' object has not attribute 'fake_mode' - scalar only ops not supported.", - AttributeError, - ), - "scalar_scalar": ( - "'float' object has not attribute 'fake_mode' - scalar only ops not supported.", - AttributeError, - ), - "broadcast_1": ( - " assert failed (input1.sizes() == input2.sizes()): Input1 and Input2 must have the same sizes.", - RuntimeError, - ), - "broadcast_2": ( - " assert failed (input1.sizes() == input2.sizes()): Input1 and Input2 must have the same sizes.", - RuntimeError, - ), - "broadcast_3": ( - " assert failed (input1.sizes() == input2.sizes()): Input1 and Input2 must have the same sizes.", - RuntimeError, - ), - "alpha": ( - "Expecting kwargs for aten op IR to be empty - alpha arg not supported.", - AssertionError, - ), -} - - -@parametrize("test_case", test_cases, xfails=implementation_xfails) +@parametrize("test_case", test_cases, xfails=xfails) def test_implementation_add(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) tester.test_implementation() diff --git a/backends/cortex_m/test/ops/test_mul.py 
b/backends/cortex_m/test/ops/test_mul.py
index 54a43121eaf..35c958ce8d4 100644
--- a/backends/cortex_m/test/ops/test_mul.py
+++ b/backends/cortex_m/test/ops/test_mul.py
@@ -59,16 +59,6 @@ class CortexMTensorMul(Model):
     }
 
 
-class CortexMTensorMulBroadCast(Model):
-    ops_before_transforms = {
-        "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1,
-    }
-
-    ops_after_transforms = {
-        "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1,
-    }
-
-
 test_cases = {
     "self_scalar": McuTestCase(
         CortexMSelfMul(),
@@ -107,15 +97,15 @@ class CortexMTensorMulBroadCast(Model):
         (1000.0, torch.ones(1)),
     ),
     "broadcast_1": McuTestCase(
-        CortexMTensorMulBroadCast(),
+        CortexMTensorMul(),
         (torch.ones(1), torch.ones(2, 2, 2, 2)),
     ),
     "broadcast_2": McuTestCase(
-        CortexMTensorMulBroadCast(),
+        CortexMTensorMul(),
         (torch.ones((2, 1, 1, 1)), torch.ones(1)),
     ),
     "broadcast_3": McuTestCase(
-        CortexMTensorMulBroadCast(),
+        CortexMTensorMul(),
         (
             ramp_tensor(-2, 2, (2, 1, 2, 1)),
             ramp_tensor(-5, 5, (1, 2, 1, 2)),
@@ -125,8 +115,17 @@ class CortexMTensorMulBroadCast(Model):
 
 
 xfail_cases = {
-    "self_scalar": "lift_constant_tensor_pass assumes fake tensors for scalars",
-    "scalar_scalar": "lift_constant_tensor_pass assumes fake tensors for scalars",
+    "self_scalar": (
+        "'float' object has no attribute 'fake_mode' - scalar only ops not supported.",
+        AttributeError,
+    ),
+    "scalar_scalar": (
+        "'float' object has no attribute 'fake_mode' - scalar only ops not supported.",
+        AttributeError,
+    ),
+    "broadcast_1": "Broadcasting not yet supported in Cortex-M backend",
+    "broadcast_2": "Broadcasting not yet supported in Cortex-M backend",
+    "broadcast_3": "Broadcasting not yet supported in Cortex-M backend",
 }
 
 

From d11996976d9e661ffc2a3ad8da66e4a6c2193adf Mon Sep 17 00:00:00 2001
From: Adrian Lundell
Date: Mon, 10 Nov 2025 15:53:56 +0100
Subject: [PATCH 3/3] Address PR comments

Signed-off-by: Adrian Lundell
Change-Id: I58fbd5708b2e5866ec87e3b33e098ee21195a9d8
---
 backends/cortex_m/ops/op_quantized_add.cpp |  9 +++++++
 backends/cortex_m/ops/op_quantized_mul.cpp | 17 ++++++++++++-
 backends/cortex_m/ops/operators.py         | 28 ++++++++++++----------
 3 files changed, 41 insertions(+), 13 deletions(-)

diff --git a/backends/cortex_m/ops/op_quantized_add.cpp b/backends/cortex_m/ops/op_quantized_add.cpp
index 30be108ffcb..ddc4b4bb869 100644
--- a/backends/cortex_m/ops/op_quantized_add.cpp
+++ b/backends/cortex_m/ops/op_quantized_add.cpp
@@ -78,6 +78,15 @@ Tensor& quantized_add_out(
       output_mult,
       output_shift_val);
+  // Note 1: The CMSIS-NN kernel implementation uses offsets which are always
+  // added to the data, whereas zero_points are subtracted when dequantizing
+  // (for the inputs) and added when quantizing (for the output). Hence the
+  // negative signs required here.
+
+  // Note 2: It is not possible to perform the same rewrite for addition as
+  // for mul. To preserve precision when rescaling the inputs, they are first
+  // upscaled as much as possible. Hence the left_shift parameter required here.
+
   // Call CMSIS-NN kernel with precomputed parameters
   arm_cmsis_nn_status status = arm_elementwise_add_s8(
       input1_int8.const_data_ptr<int8_t>(),
       input2_int8.const_data_ptr<int8_t>(),
diff --git a/backends/cortex_m/ops/op_quantized_mul.cpp b/backends/cortex_m/ops/op_quantized_mul.cpp
index db56690e411..28af8406f87 100644
--- a/backends/cortex_m/ops/op_quantized_mul.cpp
+++ b/backends/cortex_m/ops/op_quantized_mul.cpp
@@ -55,7 +55,22 @@ Tensor& quantized_mul_out(
   const int32_t zp2 = extractScalarToInt32(input2_zero_point);
   const int32_t out_zp = extractScalarToInt32(output_zero_point);
   const int32_t output_mult = extractScalarToInt32(output_multiplier);
-  const int32_t output_shift_val = extractScalarToInt(output_shift);
+  const int32_t output_shift_val = extractScalarToInt32(output_shift);
+
+  // Note 1: The CMSIS-NN kernel implementation uses offsets which are always
+  // added to the data, whereas zero_points are subtracted when dequantizing
+  // (for the inputs) and added when quantizing (for the output). Hence the
+  // negative signs required here.
+
+  // Note 2: The following rewrite is used
+  // yq = y / scale_out + zp_out
+  // y = x_1 * x_2
+  // x_i = scale_in_i * (xq_i - zp_in_i), i = 1, 2
+  // ==>
+  // yq = (xq_1 - zp_in_1) * (xq_2 - zp_in_2) * effective_scale + zp_out
+  // where
+  // effective_scale = (scale_in_1 * scale_in_2 / scale_out)
+  // Hence no input quantization params required here.
 
   // Call CMSIS-NN elementwise multiply kernel
   arm_cmsis_nn_status status = arm_elementwise_mul_s8(
diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py
index d6d6a317505..e5dd12e82e0 100644
--- a/backends/cortex_m/ops/operators.py
+++ b/backends/cortex_m/ops/operators.py
@@ -138,9 +138,10 @@ def quantized_add_meta(
     output_multiplier: int,
     output_shift: int,
 ) -> torch.Tensor:
-    assert (
-        self.shape == other.shape
-    ), "Broadcasting not yet supported in Cortex-M backend."
+    assert self.shape == other.shape, (
+        "Cortex-M quantized_add: broadcasting is not yet supported — "
+        f"got self.shape={self.shape}, other.shape={other.shape}"
+    )
     broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
     return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)
 
@@ -159,9 +160,10 @@ def quantized_add_impl(
     output_multiplier: int,
     output_shift: int,
 ) -> torch.Tensor:
-    assert (
-        self.shape == other.shape
-    ), "Broadcasting not yet supported in Cortex-M backend."
+    assert self.shape == other.shape, (
+        "Cortex-M quantized_add: broadcasting is not yet supported — "
+        f"got self.shape={self.shape}, other.shape={other.shape}"
+    )
 
     self_shifted = (self.to(torch.int32) - self_zero_point) << SHIFT_INT8
     self_fp = requantize_cmsis(self_shifted, self_multiplier, self_shift)
@@ -203,9 +205,10 @@ def quantized_mul_meta(
     output_shift: int,
 ) -> torch.Tensor:
     # Broadcast to output shape
-    assert (
-        self.shape == other.shape
-    ), "Broadcasting not yet supported in Cortex-M backend."
+    assert self.shape == other.shape, (
+        "Cortex-M quantized_mul: broadcasting is not yet supported — "
+        f"got self.shape={self.shape}, other.shape={other.shape}"
+    )
     broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
     return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)
 
@@ -223,9 +226,10 @@ def quantized_mul_impl(
     # CMSIS-NN kernel multiplies raw int8 tensors (after zero-point offset) and
     # only uses the output multiplier/shift for rescaling. Mirror that here to
     # keep the composite implementation numerically aligned with the backend.
- assert ( - self.shape == other.shape - ), "Broadcasting not yet supported in Cortex-M backend." + assert self.shape == other.shape, ( + "Cortex-M quantized_mul: broadcasting is not yet supported — " + f"got self.shape={self.shape}, other.shape={other.shape}" + ) self_int = self.to(torch.int32) - self_zero_point other_int = other.to(torch.int32) - other_zero_point result_fp = self_int * other_int
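
Illustrative note (not part of the patches above): the fused quantized_mul needs no per-input rescaling because the fusion pass folds all three scales into one effective scale, scale_in_1 * scale_in_2 / scale_out, which quantize_multiplier_aot splits into a Q31 multiplier and a power-of-two shift that the CMSIS-NN kernel applies at runtime. The standalone Python sketch below mirrors that flow under stated assumptions: quantize_multiplier and requantize are hypothetical stand-ins for quantize_multiplier_aot and requantize_cmsis, the quantization parameters are made up, and the rounding is a simplified round-half-up rather than CMSIS-NN's exact scheme, so the result may differ from the backend by one LSB.

import math

import torch


def quantize_multiplier(scale: float) -> tuple[int, int]:
    # Decompose scale so that scale ~= multiplier * 2**(shift - 31),
    # with multiplier a Q31 value in [2**30, 2**31).
    mantissa, shift = math.frexp(scale)  # mantissa in [0.5, 1)
    multiplier = round(mantissa * (1 << 31))
    if multiplier == (1 << 31):  # rounding pushed the mantissa up to 1.0
        multiplier //= 2
        shift += 1
    return multiplier, shift


def requantize(acc: torch.Tensor, multiplier: int, shift: int) -> torch.Tensor:
    # Doubling-high-multiply followed by a rounding right shift; a simplified
    # stand-in for arm_nn_requantize / requantize_cmsis (positive shift = left).
    value = acc.to(torch.int64) << max(shift, 0)
    result = (value * multiplier + (1 << 30)) >> 31
    right = max(-shift, 0)
    if right:
        result = (result + (1 << (right - 1))) >> right
    return result.to(torch.int32)


# Hypothetical quantization parameters for the two inputs and the output.
s1, zp1 = 0.05, -3
s2, zp2 = 0.02, 5
s_out, zp_out = 0.1, 4

mult, shift = quantize_multiplier(s1 * s2 / s_out)

xq1 = torch.randint(-128, 128, (8,), dtype=torch.int8)
xq2 = torch.randint(-128, 128, (8,), dtype=torch.int8)

# Integer path: (xq1 - zp1) * (xq2 - zp2), rescaled by the effective scale.
acc = (xq1.to(torch.int32) - zp1) * (xq2.to(torch.int32) - zp2)
yq = torch.clamp(requantize(acc, mult, shift) + zp_out, -128, 127).to(torch.int8)

# Float reference path for comparison.
y = (s1 * (xq1.float() - zp1)) * (s2 * (xq2.float() - zp2))
yq_ref = torch.clamp(torch.round(y / s_out) + zp_out, -128, 127).to(torch.int8)
print(int((yq.to(torch.int32) - yq_ref.to(torch.int32)).abs().max()))  # typically 0, at most 1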