1 change: 1 addition & 0 deletions backends/cortex_m/CMakeLists.txt
@@ -57,6 +57,7 @@ set(_cortex_m_kernels__srcs
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_linear.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_mul.cpp
)

# Generate C++ bindings to register kernels into Executorch
9 changes: 9 additions & 0 deletions backends/cortex_m/ops/op_quantized_add.cpp
@@ -78,6 +78,15 @@ Tensor& quantized_add_out(
output_mult,
output_shift_val);

// Note 1: The CMSIS-NN kernel implementation uses offsets which are always
// added to the data, whereas zero_points are subtracted when dequantizing
// (for the inputs) and added when quantizing (for the output). Hence the
// negative signs required here.

// Note 2: The rewrite used for mul is not possible for addition. To
// preserve precision when rescaling the inputs, they are first upscaled as
// much as possible; hence the left_shift parameter required here.

// Call CMSIS-NN kernel with precomputed parameters
arm_cmsis_nn_status status = arm_elementwise_add_s8(
input1_int8.const_data_ptr<int8_t>(),
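To make Note 2 concrete, here is a minimal Python sketch of the add rescaling scheme the comment describes. It is an illustration, not the kernel: the helper names, the left_shift value of 20, and the assumption that the output multiplier/shift already compensate for the upscaling are assumptions made for the example; the exact CMSIS-NN rounding is modelled by requantize_cmsis in passes_utils.py further down.

def _requantize(v: int, multiplier: int, shift: int) -> int:
    # Simplified arm_nn_requantize: Q31 rounding doubling-high multiply,
    # then a rounding right shift (tie handling simplified for readability).
    v <<= max(shift, 0)
    v = (v * multiplier + (1 << 30)) >> 31
    right = max(-shift, 0)
    if right:
        v = (v + (1 << (right - 1))) >> right
    return v


def quantized_add_sketch(xq1, zp1, mult1, shift1,
                         xq2, zp2, mult2, shift2,
                         out_zp, out_mult, out_shift,
                         left_shift=20):
    # Upscale the offset-corrected inputs first so the per-input rescaling
    # keeps fractional precision before the two values are summed.
    a = _requantize((xq1 - zp1) << left_shift, mult1, shift1)
    b = _requantize((xq2 - zp2) << left_shift, mult2, shift2)
    # out_mult / out_shift are assumed to fold in the 1 / 2**left_shift factor.
    y = _requantize(a + b, out_mult, out_shift) + out_zp
    return max(-128, min(127, y))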
102 changes: 102 additions & 0 deletions backends/cortex_m/ops/op_quantized_mul.cpp
@@ -0,0 +1,102 @@
/*
* Copyright 2025 Arm Limited and/or its affiliates.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "cortex_m_ops_common.h"

// Include CMSIS-NN headers with C linkage
extern "C" {
#include "arm_nnfunctions.h"
}

namespace cortex_m {
namespace native {
namespace {

constexpr int32_t kInt8ActivationMin = std::numeric_limits<int8_t>::min();
constexpr int32_t kInt8ActivationMax = std::numeric_limits<int8_t>::max();

} // namespace

using KernelRuntimeContext = torch::executor::KernelRuntimeContext;

Tensor& quantized_mul_out(
KernelRuntimeContext& context,
const Tensor& input1_int8,
const Scalar& input1_zero_point,
const Tensor& input2_int8,
const Scalar& input2_zero_point,
const Scalar& output_zero_point,
const Scalar& output_multiplier,
const Scalar& output_shift,
Tensor& out) {
// Validate tensor types and quantization parameters
validate_cmsis_nn_tensor_requirements(input1_int8, input2_int8, out);

const Scalar kIdentityMultiplier(/*value=*/1);
const Scalar kZeroShift(/*value=*/0);
validate_quantization_params(
input1_zero_point,
kIdentityMultiplier,
kZeroShift,
input2_zero_point,
kIdentityMultiplier,
kZeroShift,
output_zero_point,
output_multiplier,
output_shift,
out);

// Extract quantization parameters
const int32_t zp1 = extractScalarToInt32(input1_zero_point);
const int32_t zp2 = extractScalarToInt32(input2_zero_point);
const int32_t out_zp = extractScalarToInt32(output_zero_point);
const int32_t output_mult = extractScalarToInt32(output_multiplier);
const int32_t output_shift_val = extractScalarToInt32(output_shift);

// Note 1: The CMSIS-NN kernel implementation uses offsets which are always
// added to the data, whereas zero_points are subtracted when dequantizing
// (for the inputs) and added when quantizing (for the output). Hence the
// negative signs required here.

// Note 2: The following rewrite is used
// yq = y / scale_out + zp_out
// y = x_1*x_2
// x_i = scale_in_i * (xq_i - zp_in_i), i = 1, 2
// ==>
// yq = (xq_1 - zp_in_1) * (xq_2 - zp_in_2) * effective_scale + zp_out
// where
// effective_scale = (scale_in1 * scale_in2 / scale_out)
// Hence no input quantization params required here.

// Call CMSIS-NN elementwise multiply kernel
arm_cmsis_nn_status status = arm_elementwise_mul_s8(
input1_int8.const_data_ptr<int8_t>(),
input2_int8.const_data_ptr<int8_t>(),
-static_cast<int32_t>(zp1),
-static_cast<int32_t>(zp2),
out.mutable_data_ptr<int8_t>(),
static_cast<int32_t>(out_zp),
output_mult,
output_shift_val,
kInt8ActivationMin,
kInt8ActivationMax,
static_cast<int32_t>(out.numel()));

if (status != ARM_CMSIS_NN_SUCCESS) {
ET_LOG(
Error,
"quantized_mul_out: arm_elementwise_mul_s8 failed with status [%d]",
status);
context.fail(Error::Internal);
return out;
}

return out;
}

} // namespace native
} // namespace cortex_m
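A small numeric illustration of the rewrite in Note 2, worked in plain Python with made-up quantization parameters. The frexp-based decomposition mirrors what an AOT multiplier/shift computation such as quantize_multiplier_aot is assumed to produce; it is not quoted from that function.

import math

s1, zp1 = 0.02, 3        # input 1 scale / zero point (illustrative)
s2, zp2 = 0.05, -8       # input 2 scale / zero point (illustrative)
s_out, zp_out = 0.1, 1   # output scale / zero point (illustrative)

effective_scale = (s1 * s2) / s_out          # 0.01

# Decompose into a Q31 multiplier and a power-of-two shift for the kernel.
mantissa, exponent = math.frexp(effective_scale)
multiplier = round(mantissa * (1 << 31))     # ~0.64 * 2**31 = 1374389535
shift = exponent                             # -6, i.e. a right shift of 6 at runtime

xq1, xq2 = 57, -23

# Float reference: dequantize, multiply, quantize.
y = (s1 * (xq1 - zp1)) * (s2 * (xq2 - zp2))
yq_ref = round(y / s_out) + zp_out           # -7

# Integer rewrite used by the kernel: only the output scale is needed.
acc = (xq1 - zp1) * (xq2 - zp2)              # -810
yq = round(acc * effective_scale) + zp_out   # -7
assert yq == yq_ref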
70 changes: 70 additions & 0 deletions backends/cortex_m/ops/operators.py
@@ -138,6 +138,10 @@ def quantized_add_meta(
output_multiplier: int,
output_shift: int,
) -> torch.Tensor:
assert self.shape == other.shape, (
"Cortex-M quantized_mul: broadcasting is not yet supported — "
f"got self.shape={self.shape}, other.shape={other.shape}"
)
broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)

@@ -156,6 +160,10 @@ def quantized_add_impl(
output_multiplier: int,
output_shift: int,
) -> torch.Tensor:
assert self.shape == other.shape, (
"Cortex-M quantized_mul: broadcasting is not yet supported — "
f"got self.shape={self.shape}, other.shape={other.shape}"
)
self_shifted = (self.to(torch.int32) - self_zero_point) << SHIFT_INT8
self_fp = requantize_cmsis(self_shifted, self_multiplier, self_shift)

@@ -168,6 +176,68 @@ def quantized_add_impl(
return result


# ===================================================================
# QUANTIZED MUL OPERATION DEFINITION
# ===================================================================
lib.define(
"quantized_mul("
"Tensor self, Scalar self_zero_point, "
"Tensor other, Scalar other_zero_point, "
"Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor"
)
lib.define(
"quantized_mul.out("
"Tensor self, Scalar self_zero_point, "
"Tensor other, Scalar other_zero_point, "
"Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, "
"*, Tensor(a!) out) -> Tensor(a!)"
)


@register_fake("cortex_m::quantized_mul")
def quantized_mul_meta(
self: torch.Tensor,
self_zero_point: int,
other: torch.Tensor,
other_zero_point: int,
output_zero_point: int,
output_multiplier: int,
output_shift: int,
) -> torch.Tensor:
# Broadcast to output shape
assert self.shape == other.shape, (
"Cortex-M quantized_mul: broadcasting is not yet supported — "
f"got self.shape={self.shape}, other.shape={other.shape}"
)
broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)


@impl(lib, "quantized_mul", "CompositeExplicitAutograd")
def quantized_mul_impl(
self: torch.Tensor,
self_zero_point: int,
other: torch.Tensor,
other_zero_point: int,
output_zero_point: int,
output_multiplier: int,
output_shift: int,
) -> torch.Tensor:
# CMSIS-NN kernel multiplies raw int8 tensors (after zero-point offset) and
# only uses the output multiplier/shift for rescaling. Mirror that here to
# keep the composite implementation numerically aligned with the backend.
assert self.shape == other.shape, (
"Cortex-M quantized_mul: broadcasting is not yet supported — "
f"got self.shape={self.shape}, other.shape={other.shape}"
)
self_int = self.to(torch.int32) - self_zero_point
other_int = other.to(torch.int32) - other_zero_point
result_fp = self_int * other_int
result_quantized = requantize_cmsis(result_fp, output_multiplier, output_shift)
result = torch.clamp(result_quantized + output_zero_point, -128, 127).to(torch.int8)
return result


# ===================================================================
# QUANTIZED LINEAR OPERATION DEFINITION
# ===================================================================
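For a quick interactive check of the composite implementation above (a sketch only: the module import path is assumed, and the multiplier/shift pair is an illustrative value encoding an effective scale of roughly 0.01):

import torch
import executorch.backends.cortex_m.ops.operators  # noqa: F401  (registers the cortex_m library)

a = torch.tensor([10, -20, 30, 127], dtype=torch.int8)
b = torch.tensor([5, 5, -5, -5], dtype=torch.int8)

# Zero points set to 0; output multiplier/shift approximate an effective scale of 0.01.
out = torch.ops.cortex_m.quantized_mul(a, 0, b, 0, 0, 1374389535, -6)
print(out)  # int8 tensor, roughly round(a.int() * b.int() * 0.01)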
8 changes: 7 additions & 1 deletion backends/cortex_m/ops/operators.yaml
@@ -23,8 +23,14 @@
- arg_meta: null
kernel_name: cortex_m::quantized_add_out

- func: cortex_m::quantized_mul.out(Tensor self, Scalar self_zero_point, Tensor other, Scalar other_zero_point, Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, *, Tensor(a!) out) -> Tensor(a!)
variants: function
kernels:
- arg_meta: null
kernel_name: cortex_m::quantized_mul_out

- func: cortex_m::quantized_linear.out(Tensor input, Tensor weights, Tensor? bias, Tensor? kernel_sum, Scalar input_offset, Scalar filter_offset, Scalar output_offset, int[] requantize_multipliers, int[] requantize_shifts, Scalar activation_max, Scalar activation_min, *, Tensor(a!) out) -> Tensor(a!)
variants: function
kernels:
- arg_meta: null
kernel_name: cortex_m::quantized_linear_out
kernel_name: cortex_m::quantized_linear_out
2 changes: 0 additions & 2 deletions backends/cortex_m/passes/cortex_m_pass_manager.py
@@ -5,7 +5,6 @@


from executorch.backends.arm._passes import (
DecorateFp32toInt32CastingPass,
FoldAndAnnotateQParamsPass,
ScalarsToAttributePass,
)
@@ -29,7 +28,6 @@ class CortexMPassManager(XNNPACKPassManager):
ReplaceQuantNodesPass,
QuantizedOpFusionPass,
QuantizedLinearFusionPass,
DecorateFp32toInt32CastingPass,
]

pass_list_transform_for_annotation: list[ExportPass] = [
34 changes: 26 additions & 8 deletions backends/cortex_m/passes/passes_utils.py
@@ -50,14 +50,32 @@ def requantize_cmsis(
multiplier: int,
shift: int,
) -> torch.Tensor:
"""
Simulate CMSIS-NN fixed-point requantization:
result = round(tensor * multiplier / (2 ^ shift))
with double rounding
"""
multiplied = torch.round(tensor.to(torch.int64) * multiplier)
shifted = torch.round(multiplied / (2 ** (31 - shift)))
return shifted.to(torch.int32)
"""Simulate CMSIS-NN's arm_nn_requantize helper."""

tensor_64 = tensor.to(torch.int64)
left_shift = max(shift, 0)
right_shift = max(-shift, 0)

# Equivalent to val * (1 << LEFT_SHIFT(shift))
value = tensor_64 << left_shift

# arm_nn_doubling_high_mult_no_sat(value, multiplier)
product = value * int(multiplier)
product = product + (1 << 30)
result = product >> 31

if right_shift:
remainder_mask = (1 << right_shift) - 1
remainder = torch.bitwise_and(result, remainder_mask)
result = result >> right_shift
threshold = remainder_mask >> 1
threshold_tensor = torch.full_like(result, threshold, dtype=torch.int64)
threshold_tensor = torch.where(
result < 0, threshold_tensor + 1, threshold_tensor
)
result = result + torch.where(remainder > threshold_tensor, 1, 0)

return result.to(torch.int32)


def extract_scalar_value(node_arg) -> float:
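A hand-worked check of the new helper (the import path is assumed; the multiplier/shift pair again encodes an effective scale of roughly 0.01):

import torch
from executorch.backends.cortex_m.passes.passes_utils import requantize_cmsis

vals = torch.tensor([-810, 50, 2500], dtype=torch.int32)
# multiplier ~= 0.64 * 2**31, shift = -6  =>  roughly value * 0.01 with CMSIS rounding.
print(requantize_cmsis(vals, 1374389535, -6))
# Expected (hand-computed): tensor([-8, 1, 25], dtype=torch.int32)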
27 changes: 27 additions & 0 deletions backends/cortex_m/passes/quantized_op_fusion_pass.py
@@ -64,6 +64,31 @@ def _get_add_replacement(self, args, meta):

return exir_ops.edge.cortex_m.quantized_add.default, args

def _get_mul_replacement(self, args, meta):

# Extract values
scale1 = meta["input_qparams"][0].scale
zero_point1 = meta["input_qparams"][0].zp
scale2 = meta["input_qparams"][1].scale
zero_point2 = meta["input_qparams"][1].zp
output_scale = meta["output_qparams"][0].scale
output_zero_point = meta["output_qparams"][0].zp

scale_factor = (scale1 * scale2) / output_scale
output_mult, output_shift = quantize_multiplier_aot(scale_factor)

args = (
args[0],
zero_point1,
args[1],
zero_point2,
output_zero_point,
output_mult,
output_shift,
)

return exir_ops.edge.cortex_m.quantized_mul.default, args

def call_operator(
self,
op: EdgeOpOverload,
@@ -80,6 +105,8 @@ def call_operator(
match op:
case exir_ops.edge.aten.add.Tensor:
op, args = self._get_add_replacement(args, meta)
case exir_ops.edge.aten.mul.Tensor:
op, args = self._get_mul_replacement(args, meta)
case _:
pass

1 change: 1 addition & 0 deletions backends/cortex_m/quantizer/operator_configs.py
@@ -17,6 +17,7 @@
# ----------------- OPERATOR PATTERN PRESETS -----------------
BINARY_OP_PATTERNS = [
[torch.ops.aten.add.Tensor],
[torch.ops.aten.mul.Tensor],
]

LINEAR_OP_PATTERNS = [
5 changes: 2 additions & 3 deletions backends/cortex_m/quantizer/quantizer.py
@@ -6,13 +6,12 @@

from typing import Callable, List, Optional

import torch

from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor

from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager
from executorch.backends.cortex_m.quantizer.operator_configs import (
BINARY_OP_PATTERNS,
INT8_BINARY_OPS_OPERATOR_CONFIG,
INT8_LINEAR_OPERATOR_CONFIG,
)
@@ -37,7 +36,7 @@ def broadcasting_filter(self, node: Optional[Node]) -> bool:
"""
if node is None:
return False
if node.target not in [torch.ops.aten.add.Tensor]:
if [node.target] not in BINARY_OP_PATTERNS:
return False

if len(node.all_input_nodes) == 2: