From aa00bbd8992b65f5bddc0b1c97fb478a459861bf Mon Sep 17 00:00:00 2001
From: sidart
Date: Thu, 17 Jul 2025 16:49:08 -0700
Subject: [PATCH] Summary: Initial CMSIS-NN integration for Quantized Add Op

Test Plan:
a) Set up the Arm FVP and run 'examples/arm/run.sh' (check no regressions in e2e test scenarios)
b) Then add to run.sh another iteration with qadd with only the --quantize flag and see that the quantized add op is called
c) cd backends/cortex_m/test/; python test_quantize_op_fusion_pass.py
----------------------------------------------------------------------
Ran 9 tests in 11.128s

OK

Reviewers:

Subscribers:

Tasks:

Tags:
---
 backends/cortex_m/CMakeLists.txt              |  51 ++-
 backends/cortex_m/ops/TARGETS                 |   1 +
 backends/cortex_m/ops/cortex_m_ops_common.h   | 141 +++++++
 backends/cortex_m/ops/op_quantized_add.cpp    | 149 +++++++
 backends/cortex_m/ops/operators.py            | 133 ++++++-
 backends/cortex_m/ops/operators.yaml          |  12 +
 backends/cortex_m/passes/TARGETS              |  22 +-
 backends/cortex_m/passes/passes_utils.py      |  94 +++++
 .../passes/quantized_op_fusion_pass.py        | 255 ++++++++++++
 backends/cortex_m/test/TARGETS                |  14 +-
 .../test/test_helpers_passes_utils.py         | 103 +++++
 .../test/test_quantize_op_fusion_pass.py      | 369 ++++++++++++++++++
 .../cortex_m/test/test_replace_quant_nodes.py |  91 +----
 examples/arm/aot_arm_compiler.py              |  11 +-
 examples/arm/run.sh                           |  18 +-
 15 files changed, 1348 insertions(+), 116 deletions(-)
 create mode 100644 backends/cortex_m/ops/cortex_m_ops_common.h
 create mode 100644 backends/cortex_m/ops/op_quantized_add.cpp
 create mode 100644 backends/cortex_m/passes/passes_utils.py
 create mode 100644 backends/cortex_m/passes/quantized_op_fusion_pass.py
 create mode 100644 backends/cortex_m/test/test_helpers_passes_utils.py
 create mode 100644 backends/cortex_m/test/test_quantize_op_fusion_pass.py

diff --git a/backends/cortex_m/CMakeLists.txt b/backends/cortex_m/CMakeLists.txt
index b198be09ee2..1567b8b5e1c 100644
--- a/backends/cortex_m/CMakeLists.txt
+++ b/backends/cortex_m/CMakeLists.txt
@@ -5,11 +5,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# Kernel library for Cortex-M operators. Please keep this file formatted by
-# running:
-# ~~~
-# cmake-format -i CMakeLists.txt
-# ~~~
 cmake_minimum_required(VERSION 3.19)
 
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
@@ -24,29 +19,65 @@ endif()
 include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
 include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake)
+include(FetchContent)
+
+# CMSIS-NN version to download
+set(CMSIS_NN_VERSION
+    "v4.1.0"
+    CACHE STRING "CMSIS-NN version to download"
+)
+
+# Declare CMSIS-NN as a FetchContent project
+FetchContent_Declare(
+  cmsis_nn
+  GIT_REPOSITORY https://github.com/ARM-software/CMSIS-NN.git
+  GIT_TAG ${CMSIS_NN_VERSION}
+)
+
+# Download and make CMSIS-NN available
+FetchContent_MakeAvailable(cmsis_nn)
+
+# Print paths for debugging
+message(STATUS "CMSIS-NN source dir: ${cmsis_nn_SOURCE_DIR}")
+message(STATUS "CMSIS-NN binary dir: ${cmsis_nn_BINARY_DIR}")
 
 # Cortex-M ops kernel sources
 set(_cortex_m_kernels__srcs
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantize_per_tensor.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp
 )
 
-# Generate C++ bindings to register kernels into Executorch (for runtime).
Here -# select all ops in operators.yaml +# Generate C++ bindings to register kernels into Executorch (for runtime) set(_yaml_file ${CMAKE_CURRENT_LIST_DIR}/ops/operators.yaml) gen_selected_ops(LIB_NAME "cortex_m_ops_lib" OPS_SCHEMA_YAML "${_yaml_file}") -# Generate bindings for the kernels generate_bindings_for_kernels( LIB_NAME "cortex_m_ops_lib" CUSTOM_OPS_YAML "${_yaml_file}" ) message("Generated files ${gen_command_sources}") -# Build a library for _cortex_m_kernels_srcs +# Build a library for cortex_m_kernels add_library(cortex_m_kernels ${_cortex_m_kernels__srcs}) -target_link_libraries(cortex_m_kernels PRIVATE executorch) target_compile_options(cortex_m_kernels PUBLIC ${_common_compile_options}) +# Include directories for cortex_m_kernels +target_include_directories( + cortex_m_kernels + PRIVATE ${EXECUTORCH_ROOT}/.. + ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10 + ${cmsis_nn_SOURCE_DIR}/Include +) + +# Link directly to the CMSIS-NN static library file +target_link_libraries( + cortex_m_kernels PUBLIC ${cmsis_nn_BINARY_DIR}/libcmsis-nn.a executorch +) + +# Add dependency to ensure CMSIS-NN builds before we try to link. Use the actual +# CMSIS-NN target name (usually 'cmsis-nn') +add_dependencies(cortex_m_kernels cmsis-nn) + # cortex_m_ops_lib: Register Cortex-M ops kernels into Executorch runtime gen_operators_lib( LIB_NAME "cortex_m_ops_lib" KERNEL_LIBS cortex_m_kernels DEPS executorch diff --git a/backends/cortex_m/ops/TARGETS b/backends/cortex_m/ops/TARGETS index e02f096fd83..4d12fa196cf 100644 --- a/backends/cortex_m/ops/TARGETS +++ b/backends/cortex_m/ops/TARGETS @@ -16,6 +16,7 @@ python_library( ], deps = [ "fbcode//caffe2:torch", + "//executorch/backends/cortex_m/passes:passes_utils", ], ) diff --git a/backends/cortex_m/ops/cortex_m_ops_common.h b/backends/cortex_m/ops/cortex_m_ops_common.h new file mode 100644 index 00000000000..0bde2ddff17 --- /dev/null +++ b/backends/cortex_m/ops/cortex_m_ops_common.h @@ -0,0 +1,141 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */
+
+#pragma once
+
+#include <cstdint>
+#include <limits>
+
+#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+#include <executorch/runtime/platform/assert.h>
+
+// Include CMSIS-NN headers with C linkage
+extern "C" {
+#include "arm_nnfunctions.h"
+}
+
+using Tensor = torch::executor::Tensor;
+using ScalarType = executorch::aten::ScalarType;
+using Scalar = torch::executor::Scalar;
+using Error = executorch::runtime::Error;
+
+// Basic tensor type / layout validation and dimension order checking
+inline void validate_cmsis_nn_tensor_requirements(
+    const Tensor& input1,
+    const Tensor& input2,
+    Tensor& output,
+    ScalarType expected_dtype = ScalarType::Char,
+    bool require_channels_last = false) {
+  // Basic dtype validation
+  ET_CHECK_MSG(
+      input1.scalar_type() == expected_dtype,
+      "Input1 dtype must be %hhd",
+      static_cast<int8_t>(expected_dtype));
+  ET_CHECK_MSG(
+      input2.scalar_type() == expected_dtype,
+      "Input2 dtype must be %hhd",
+      static_cast<int8_t>(expected_dtype));
+  ET_CHECK_MSG(
+      output.scalar_type() == expected_dtype,
+      "Output dtype must be %hhd",
+      static_cast<int8_t>(expected_dtype));
+
+  // Dim order consistency
+  ET_CHECK_MSG(
+      executorch::runtime::tensors_have_same_dim_order(input1, input2, output),
+      "Tensors must have same dimension order");
+
+  // TBD: Validate memory alignment (CMSIS-NN requirement)
+}
+
+inline void validate_single_quant_params(
+    const Scalar& zero_point,
+    const Scalar& multiplier,
+    const Scalar& shift,
+    const char* param_name) {
+  int64_t zp_val = zero_point.to<int64_t>();
+  int64_t mult_val = multiplier.to<int64_t>();
+  int64_t shift_val = shift.to<int64_t>();
+
+  ET_CHECK_MSG(
+      zp_val >= std::numeric_limits<int8_t>::min() &&
+          zp_val <= std::numeric_limits<int8_t>::max(),
+      "%s zero point must be in int8 range [Value: %lld]",
+      param_name,
+      static_cast<long long>(zp_val));
+
+  ET_CHECK_MSG(
+      mult_val >= std::numeric_limits<int32_t>::min() &&
+          mult_val <= std::numeric_limits<int32_t>::max(),
+      "%s multiplier must be in int32 range [Value: %lld]",
+      param_name,
+      static_cast<long long>(mult_val));
+
+  ET_CHECK_MSG(
+      shift_val >= -31 && shift_val <= 31,
+      "%s shift must be in range [-31, 31] [Value: %lld]",
+      param_name,
+      static_cast<long long>(shift_val));
+}
+
+/**
+ * Validate quantization parameters for inputs and output.
+ *
+ * Checks that zero points fit in int8 range, multipliers fit in int32 range,
+ * and shifts are within the valid bit-shift range [-31, 31].
+ *
+ * Ensures parameters comply with Ahead-Of-Time (AOT) quantization requirements
+ * and CMSIS-NN kernel expectations.
+ *
+ * Raises errors via ET_CHECK_MSG if any check fails.
+ */
+inline void validate_quantization_params(
+    const Scalar& zero_point1,
+    const Scalar& multiplier1,
+    const Scalar& shift1,
+    const Scalar& zero_point2,
+    const Scalar& multiplier2,
+    const Scalar& shift2,
+    const Scalar& output_zero_point,
+    const Scalar& output_multiplier,
+    const Scalar& output_shift,
+    Tensor& output) {
+  validate_single_quant_params(
+      zero_point1, multiplier1, shift1, "Single quant Input1");
+  validate_single_quant_params(
+      zero_point2, multiplier2, shift2, "Single quant Input2");
+  validate_single_quant_params(
+      output_zero_point,
+      output_multiplier,
+      output_shift,
+      "Single quant Output");
+}
+
+inline Error resize_to_broadcast_target_size(
+    const Tensor& input1,
+    const Tensor& input2,
+    Tensor& output) {
+  static constexpr int kTensorDimensionLimit = 5;
+  Tensor::SizesType expected_output_size[kTensorDimensionLimit];
+  size_t expected_output_dim = 0;
+  auto err = torch::executor::get_broadcast_target_size(
+      input1,
+      input2,
+      expected_output_size,
+      kTensorDimensionLimit,
+      &expected_output_dim);
+
+  if (err != Error::Ok)
+    return err;
+
+  return executorch::runtime::resize_tensor(
+      output, {expected_output_size, expected_output_dim});
+}
diff --git a/backends/cortex_m/ops/op_quantized_add.cpp b/backends/cortex_m/ops/op_quantized_add.cpp
new file mode 100644
index 00000000000..47f6df6bfc5
--- /dev/null
+++ b/backends/cortex_m/ops/op_quantized_add.cpp
@@ -0,0 +1,149 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "cortex_m_ops_common.h"
+
+namespace cortex_m {
+namespace native {
+
+using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
+
+Tensor& quantized_add_out(
+    KernelRuntimeContext& context,
+    const Tensor& input1_int8,
+    const Scalar& input1_zero_point,
+    const Scalar& input1_multiplier,
+    const Scalar& input1_shift,
+    const Tensor& input2_int8,
+    const Scalar& input2_zero_point,
+    const Scalar& input2_multiplier,
+    const Scalar& input2_shift,
+    const Scalar& output_zero_point,
+    const Scalar& output_multiplier,
+    const Scalar& output_shift,
+    Tensor& out) {
+  // Validate tensor types and dim order
+  validate_cmsis_nn_tensor_requirements(input1_int8, input2_int8, out);
+
+  // Validate quantization parameters
+  validate_quantization_params(
+      input1_zero_point,
+      input1_multiplier,
+      input1_shift,
+      input2_zero_point,
+      input2_multiplier,
+      input2_shift,
+      output_zero_point,
+      output_multiplier,
+      output_shift,
+      out);
+
+  // Broadcast if needed
+  auto result = resize_to_broadcast_target_size(input1_int8, input2_int8, out);
+  ET_CHECK_MSG(
+      (result == Error::Ok),
+      "Failed to resize output tensor. Status: [%d]",
+      static_cast<int>(result));
+
+  ET_LOG(
+      Info,
+      "quantized_add_out: input1_int8.sizes() = %zu",
+      input1_int8.sizes().size());
+
+  // Read the AoT-computed parameters via Scalar::to<int64_t>() and narrow
+  // them to the integer types expected by CMSIS-NN.
+  int32_t zp1 = static_cast<int32_t>(input1_zero_point.to<int64_t>());
+  int32_t input1_mult = static_cast<int32_t>(input1_multiplier.to<int64_t>());
+  int input1_shift_val = static_cast<int>(input1_shift.to<int64_t>());
+
+  int32_t zp2 = static_cast<int32_t>(input2_zero_point.to<int64_t>());
+  int32_t input2_mult = static_cast<int32_t>(input2_multiplier.to<int64_t>());
+  int input2_shift_val = static_cast<int>(input2_shift.to<int64_t>());
+
+  int32_t out_zp = static_cast<int32_t>(output_zero_point.to<int64_t>());
+  int32_t output_mult = static_cast<int32_t>(output_multiplier.to<int64_t>());
+  int output_shift_val = static_cast<int>(output_shift.to<int64_t>());
+
+  // Left shift to maximize precision (tune as needed)
+  const int32_t left_shift = 20;
+  const int32_t activation_min = std::numeric_limits<int8_t>::min();
+  const int32_t activation_max = std::numeric_limits<int8_t>::max();
+
+  ET_LOG(
+      Info,
+      "Using AoT-computed parameters: input1[mult=%d, shift=%d], input2[mult=%d, shift=%d], output[mult=%d, shift=%d]",
+      input1_mult,
+      input1_shift_val,
+      input2_mult,
+      input2_shift_val,
+      output_mult,
+      output_shift_val);
+
+  // Call CMSIS-NN kernel with precomputed parameters; the element count
+  // (block_size) is the last argument of arm_elementwise_add_s8.
+  arm_cmsis_nn_status status = arm_elementwise_add_s8(
+      input1_int8.const_data_ptr<int8_t>(),
+      input2_int8.const_data_ptr<int8_t>(),
+      zp1,
+      input1_mult,
+      input1_shift_val,
+      zp2,
+      input2_mult,
+      input2_shift_val,
+      left_shift,
+      out.mutable_data_ptr<int8_t>(),
+      out_zp,
+      output_mult,
+      output_shift_val,
+      activation_min,
+      activation_max,
+      static_cast<int32_t>(out.numel()));
+
+  if (status != ARM_CMSIS_NN_SUCCESS) {
+    ET_LOG(
+        Error,
+        "quantized_add_out: arm_elementwise_add_s8 failed with status [%d]",
+        status);
+
+    context.fail(Error::Internal); // Fail the execution context
+    return out;
+  }
+  ET_LOG(
+      Info,
+      "quantized_add_out: Successfully completed with AoT-computed parameters!");
+
+  return out;
+}
+
+// Stub implementation: non-out (functional) variant for compatibility.
+// EXIR/ExecuTorch runs an out-variant pass that converts .default
+// operations to .out variants before memory planning. The fusion pass
+// emits quantized_add's default variant, but ExecuTorch's kernel dispatch
+// mechanism will end up calling the out variant. This stub only exists so
+// the compiler does not complain about a missing symbol.
+Tensor quantized_add(
+    KernelRuntimeContext& context,
+    const Tensor& input1_int8,
+    const Scalar& input1_zero_point,
+    const Scalar& input1_multiplier,
+    const Scalar& input1_shift,
+    const Tensor& input2_int8,
+    const Scalar& input2_zero_point,
+    const Scalar& input2_multiplier,
+    const Scalar& input2_shift,
+    const Scalar& output_zero_point,
+    const Scalar& output_multiplier,
+    const Scalar& output_shift) {
+  ET_LOG(
+      Info,
+      "quantized_add: input1_int8.sizes() = %zu",
+      input1_int8.sizes().size());
+
+  // Crash on Debug builds if invoked
+  assert(false);
+  // Unreachable; keeps the compiler happy about the return value.
+  return const_cast<Tensor&>(input1_int8);
+}
+
+} // namespace native
+} // namespace cortex_m
diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py
index b5ba3c6ddc8..926dcd85e4b 100644
--- a/backends/cortex_m/ops/operators.py
+++ b/backends/cortex_m/ops/operators.py
@@ -5,11 +5,16 @@
 # LICENSE file in the root directory of this source tree.
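Note on the requantization parameters used above and in the Python pieces below: a float scale is encoded AoT as an int32 multiplier plus a power-of-two shift, following the CMSIS-NN fixed-point convention, so that scale is approximately multiplier * 2^(-shift) / 2^31. The sketch below mirrors quantize_multiplier_aot from passes_utils.py; the reconstruct_scale helper is purely illustrative and not part of this patch.

    import math

    def quantize_multiplier_aot(scale: float) -> tuple[int, int]:
        # Same scheme as passes_utils.quantize_multiplier_aot:
        # scale = mantissa * 2**exponent, with 0.5 <= |mantissa| < 1
        if scale == 0.0:
            return 0, 0
        mantissa, exponent = math.frexp(scale)
        shift = -exponent
        q_fixed = int(round(mantissa * (1 << 31)))
        if q_fixed == (1 << 31):
            q_fixed //= 2
            shift -= 1
        multiplier = max(-2147483648, min(2147483647, q_fixed))
        return multiplier, shift

    def reconstruct_scale(multiplier: int, shift: int) -> float:
        # Illustrative inverse: the real scale the (multiplier, shift) pair encodes.
        return multiplier * 2.0 ** (-shift) / (1 << 31)

    # e.g. scale 1.0 -> multiplier 2**30, shift -1, matching the unit test expectations
    mult, shift = quantize_multiplier_aot(1.0)
    assert (mult, shift) == (1 << 30, -1)
    assert abs(reconstruct_scale(mult, shift) - 1.0) < 1e-9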
import torch -from executorch.exir.dialects._ops import ( - ops as exir_ops, -) # To provide the implementation of the operators +from executorch.backends.cortex_m.passes.passes_utils import ( + dequantize_per_tensor_cmsis, + quantize_per_tensor_cmsis, +) +from executorch.exir.dialects._ops import ops as exir_ops + +# To provide the implementation of the operators from torch.library import impl, Library, register_fake + # New operator library with a custom namespace to allow fusion etc. lib = Library("cortex_m", "DEF") @@ -96,3 +101,125 @@ def dequantize_per_tensor_impl( return exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default( input, scale, zero_point, quant_min, quant_max, dtype ) + + +# Define the operator schema with multipliers and shifts (11 args) +lib.define( + "quantized_add(" + "Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, " + "Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, " + "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor" +) + + +@register_fake("cortex_m::quantized_add") +def quantized_add_meta( + self: torch.Tensor, + self_zero_point: int, + self_multiplier: int, + self_shift: int, + other: torch.Tensor, + other_zero_point: int, + other_multiplier: int, + other_shift: int, + output_zero_point: int, + output_multiplier: int, + output_shift: int, +) -> torch.Tensor: + broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape) + return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device) + + +@impl(lib, "quantized_add", "CompositeExplicitAutograd") +def quantized_add_impl( + self: torch.Tensor, + self_zero_point: int, + self_multiplier: int, + self_shift: int, + other: torch.Tensor, + other_zero_point: int, + other_multiplier: int, + other_shift: int, + output_zero_point: int, + output_multiplier: int, + output_shift: int, +) -> torch.Tensor: + self_fp = dequantize_per_tensor_cmsis( + self, self_zero_point, self_multiplier, self_shift + ) + other_fp = dequantize_per_tensor_cmsis( + other, other_zero_point, other_multiplier, other_shift + ) + result_fp = self_fp + other_fp + result_quantized = quantize_per_tensor_cmsis( + result_fp, output_zero_point, output_multiplier, output_shift + ) + return result_quantized + + +# Define the operator schema with multipliers and shifts (11 args + out tensor) +lib.define( + "quantized_add.out(" + "Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, " + "Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, " + "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, " + "*, Tensor(a!) 
out) -> Tensor(a!)" +) + + +# Fake meta function for shape and dtype inference during compilation +@register_fake("cortex_m::quantized_add.out") +def quantized_add_out_meta( + self: torch.Tensor, + self_zero_point: int, + self_multiplier: int, + self_shift: int, + other: torch.Tensor, + other_zero_point: int, + other_multiplier: int, + other_shift: int, + output_zero_point: int, + output_multiplier: int, + output_shift: int, + out: torch.Tensor, +) -> torch.Tensor: + # Validate against correct broadcasted shape + expected_shape = torch.broadcast_shapes(self.shape, other.shape) + assert ( + out.shape == expected_shape + ), f"Output shape {out.shape} must match broadcasted shape {expected_shape}" + return out + + +# Actual implementation delegating to backend or custom kernel +@impl(lib, "quantized_add.out", "CompositeExplicitAutograd") +def quantized_add_out_impl( + self: torch.Tensor, + self_zero_point: int, + self_multiplier: int, + self_shift: int, + other: torch.Tensor, + other_zero_point: int, + other_multiplier: int, + other_shift: int, + output_zero_point: int, + output_multiplier: int, + output_shift: int, + *, + out: torch.Tensor, +) -> torch.Tensor: + self_fp = dequantize_per_tensor_cmsis( + self, self_zero_point, self_multiplier, self_shift + ) + other_fp = dequantize_per_tensor_cmsis( + other, other_zero_point, other_multiplier, other_shift + ) + result_fp = self_fp + other_fp + result_quantized = quantize_per_tensor_cmsis( + result_fp, output_zero_point, output_multiplier, output_shift + ) + + # Write into the provided output tensor + out.copy_(result_quantized) + + return out diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml index 0cc248effaa..f2615a1f525 100644 --- a/backends/cortex_m/ops/operators.yaml +++ b/backends/cortex_m/ops/operators.yaml @@ -15,3 +15,15 @@ kernels: - arg_meta: null kernel_name: cortex_m::dequantize_per_tensor_out + +- func: cortex_m::quantized_add(Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor + variants: function + kernels: + - arg_meta: null + kernel_name: cortex_m::quantized_add + +- func: cortex_m::quantized_add.out(Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, *, Tensor(a!) out) -> Tensor(a!) 
+ variants: function + kernels: + - arg_meta: null + kernel_name: cortex_m::quantized_add_out diff --git a/backends/cortex_m/passes/TARGETS b/backends/cortex_m/passes/TARGETS index e5d1fc6149d..dde3ad4f068 100644 --- a/backends/cortex_m/passes/TARGETS +++ b/backends/cortex_m/passes/TARGETS @@ -9,13 +9,27 @@ load("@fbcode_macros//build_defs:python_library.bzl", "python_library") oncall("executorch") python_library( - name = "replace_quant_nodes_pass", - srcs = ["replace_quant_nodes_pass.py"], - deps = [ + name="replace_quant_nodes_pass", + srcs=[ + "replace_quant_nodes_pass.py", + "quantized_op_fusion_pass.py", + ], + deps=[ "//caffe2:torch", "//executorch/exir:lib", "//executorch/exir:pass_base", "//executorch/exir/dialects:lib", "//executorch/backends/cortex_m/ops:ops", - ] + "//executorch/backends/cortex_m/passes:passes_utils", + ], +) + +python_library( + name="passes_utils", + srcs=[ + "passes_utils.py", + ], + deps=[ + "fbcode//caffe2:torch", + ], ) diff --git a/backends/cortex_m/passes/passes_utils.py b/backends/cortex_m/passes/passes_utils.py new file mode 100644 index 00000000000..3f6e05fc4de --- /dev/null +++ b/backends/cortex_m/passes/passes_utils.py @@ -0,0 +1,94 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch + + +def dequantize_per_tensor_cmsis( + qtensor: torch.Tensor, zero_point: int, multiplier: int, shift: int +) -> torch.Tensor: + """ + Simulate CMSIS-NN fixed-point dequantization: + result = (qtensor - zero_point) * multiplier * 2^shift / 2^31 + """ + scale = multiplier * (2**shift) / (1 << 31) + return (qtensor.float() - zero_point) * scale + + +def quantize_per_tensor_cmsis( + tensor: torch.Tensor, + zero_point: int, + multiplier: int, + shift: int, + qmin=-128, + qmax=127, +) -> torch.Tensor: + """ + Simulate CMSIS-NN fixed-point quantization: + result = round(tensor / scale) + zero_point, clamped to [qmin, qmax] + """ + scale = multiplier * (2**shift) / (1 << 31) + quantized = torch.round(tensor / scale) + zero_point + return quantized.clamp(qmin, qmax).to(torch.int8) + + +def extract_scalar_value(node_arg) -> float: + """ + Extract scalar value from various PyTorch scalar representations. 
+ """ + if hasattr(node_arg, "op") and node_arg.op == "get_attr": + # Handle case where scalar is a graph attribute + return float(node_arg.target) + elif isinstance(node_arg, (int, float)): + return float(node_arg) + elif hasattr(node_arg, "item"): + return float(node_arg.item()) + else: + # Try to extract from meta if available + if hasattr(node_arg, "meta") and "val" in node_arg.meta: + val = node_arg.meta["val"] + if hasattr(val, "item"): + return float(val.item()) + return float(val) + raise ValueError( + f"Cannot extract scalar value from {type(node_arg)}: {node_arg}" + ) + + +def is_qualified_int8_node(args) -> bool: + try: + if len(args) < 6: + return False + qmin = int(args[3]) + qmax = int(args[4]) + dtype_str = str(args[5]) + is_int8_range = ( + qmin >= torch.iinfo(torch.int8).min and qmax <= torch.iinfo(torch.int8).max + ) + is_int8_dtype = "int8" in dtype_str.lower() + return is_int8_range and is_int8_dtype + except (IndexError, ValueError, TypeError): + return False + + +def quantize_multiplier_aot(scale: float) -> tuple[int, int]: + if scale == 0.0: + return 0, 0 + mantissa, exponent = math.frexp(scale) + shift = -exponent + q_fixed = int(round(mantissa * (1 << 31))) + if q_fixed == (1 << 31): + q_fixed //= 2 + shift -= 1 + multiplier = max(-2147483648, min(2147483647, q_fixed)) + return multiplier, shift + + +def cleanup_erased_nodes(graph_module: torch.fx.GraphModule): + # Placeholder for any additional cleanup if needed + pass diff --git a/backends/cortex_m/passes/quantized_op_fusion_pass.py b/backends/cortex_m/passes/quantized_op_fusion_pass.py new file mode 100644 index 00000000000..ca6d8b97795 --- /dev/null +++ b/backends/cortex_m/passes/quantized_op_fusion_pass.py @@ -0,0 +1,255 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Set + +import executorch.backends.cortex_m.ops.operators # noqa +import torch + +from executorch.backends.cortex_m.passes.passes_utils import ( + extract_scalar_value, + quantize_multiplier_aot, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass +from torch.fx.passes.infra.pass_manager import PassResult + +logger = logging.getLogger("quant_op_fusion_pass") +logger.setLevel(logging.INFO) + + +class QuantizedOpFusionPass(ExportPass): + """ + Generic ExportPass that: + 1. Replaces certain ops with cortex_m variants based on qualifiers. + 2. Fuses patterns: dequantize_per_tensor -> [binary_op] -> quantize_per_tensor + into cortex_m.quantized_[op].default with AoT computed multipliers/shifts. + + + Supports multiple binary operations with backward compatibility for add. 
+ """ + + # Generic operation mapping + SUPPORTED_OPS_MAPPING = { + exir_ops.edge.aten.add.Tensor: exir_ops.edge.cortex_m.quantized_add.default, + # Future ops to be added here: + } + + def __init__(self): + super().__init__() + + def _get_dequant_targets(self) -> Set: + """Support both decomposed and cortex_m dequant targets for flexible pass ordering.""" + return { + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.cortex_m.dequantize_per_tensor.default, + } + + def _get_quant_targets(self) -> Set: + """Support both decomposed and cortex_m quant targets for flexible pass ordering.""" + return { + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.cortex_m.quantize_per_tensor.default, + } + + def _is_supported_binary_op(self, node: torch.fx.Node) -> bool: + """Check if node is a supported binary operation.""" + return node.op == "call_function" and node.target in self.SUPPORTED_OPS_MAPPING + + def _is_dequant_node(self, node: torch.fx.Node) -> bool: + """Check if node is a dequantize operation.""" + return ( + hasattr(node, "op") + and node.op == "call_function" + and node.target in self._get_dequant_targets() + ) + + def _is_quant_node(self, node: torch.fx.Node) -> bool: + """Check if node is a quantize operation.""" + return ( + hasattr(node, "op") + and node.op == "call_function" + and node.target in self._get_quant_targets() + ) + + def _transfer_metadata( + self, + new_node: torch.fx.Node, + source_node: torch.fx.Node, + pass_name: str = "QuantizedOpFusionPass", + ) -> None: + """Metadata transfer with proper provenance tracking.""" + if hasattr(source_node, "meta") and source_node.meta: + new_node.meta = source_node.meta.copy() + + if "from_node" in new_node.meta: + from_node_list = new_node.meta.get("from_node", []).copy() + from_node_list.append( + {"source": source_node.name, "pass": pass_name, "op": "fuse"} + ) + new_node.meta["from_node"] = from_node_list + + # Copy essential fields + for field in ["tensor_meta", "stack_trace"]: + if field in source_node.meta: + new_node.meta[field] = source_node.meta[field] + + def _normalize_to_cortex_m_targets(self, graph_module: torch.fx.GraphModule) -> int: + """Convert decomposed targets to cortex_m equivalents for consistent handling.""" + target_mapping = { + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: exir_ops.edge.cortex_m.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: exir_ops.edge.cortex_m.quantize_per_tensor.default, + } + + normalization_count = 0 + for node in list(graph_module.graph.nodes): + if node.op == "call_function" and node.target in target_mapping: + logger.info(f"Normalizing {node.target} to cortex_m equivalent") + node.target = target_mapping[node.target] + normalization_count += 1 + + return normalization_count + + def _fuse_quantized_binary_patterns( + self, graph_module: torch.fx.GraphModule + ) -> int: + """Generic fusion for quantized binary operation patterns.""" + fusion_count = 0 + nodes_to_erase = [] + + for node in list(graph_module.graph.nodes): + if not self._is_quant_node(node): + continue + + quantize_node = node + if not quantize_node.args: + continue + + binary_op_node = quantize_node.args[0] + if not self._is_supported_binary_op(binary_op_node): + continue + + if len(binary_op_node.args) < 2: + continue + + dequant_node1, dequant_node2 = binary_op_node.args[:2] + if not ( + self._is_dequant_node(dequant_node1) + and self._is_dequant_node(dequant_node2) + ): + continue + 
+ # Get the target quantized operation + quantized_target = self.SUPPORTED_OPS_MAPPING[binary_op_node.target] + # Extract op name (e.g., 'Tensor' -> 'add') + op_name = str(binary_op_node.target).split(".")[-1] + logger.info(f"✅ Found complete cortex_m Q/DQ + {op_name} pattern!") + + try: + # Extract values + int8_tensor1, scale1, zero_point1 = dequant_node1.args[:3] + int8_tensor2, scale2, zero_point2 = dequant_node2.args[:3] + output_scale, output_zero_point = quantize_node.args[1:3] + + # Convert to Python floats + scale1_val = extract_scalar_value(scale1) + scale2_val = extract_scalar_value(scale2) + output_scale_val = extract_scalar_value(output_scale) + zp1_val = int(extract_scalar_value(zero_point1)) + zp2_val = int(extract_scalar_value(zero_point2)) + output_zp_val = int(extract_scalar_value(output_zero_point)) + + # AoT COMPUTATION: Calculate multipliers and shifts + input1_mult, input1_shift = quantize_multiplier_aot( + scale1_val / output_scale_val + ) + input2_mult, input2_shift = quantize_multiplier_aot( + scale2_val / output_scale_val + ) + output_mult, output_shift = quantize_multiplier_aot( + 1.0 + ) # Output multiplier is 1 + + logger.info("AoT computed parameters:") + logger.info(f" Input1: mult={input1_mult}, shift={input1_shift}") + logger.info(f" Input2: mult={input2_mult}, shift={input2_shift}") + logger.info(f" Output: mult={output_mult}, shift={output_shift}") + + with graph_module.graph.inserting_after(quantize_node): + fused = graph_module.graph.create_node( + "call_function", + target=quantized_target, + args=( + int8_tensor1, + zp1_val, + input1_mult, + input1_shift, + int8_tensor2, + zp2_val, + input2_mult, + input2_shift, + output_zp_val, + output_mult, + output_shift, + ), + kwargs={}, + ) + + # metadata transfer + self._transfer_metadata(fused, quantize_node) + + logger.info(f"✅ Created fused quantized_{op_name} node: {fused}") + + # Replace all uses + quantize_node.replace_all_uses_with(fused) + binary_op_node.replace_all_uses_with(fused) + dequant_node1.replace_all_uses_with(fused) + dequant_node2.replace_all_uses_with(fused) + + nodes_to_erase.extend( + [quantize_node, binary_op_node, dequant_node1, dequant_node2] + ) + fusion_count += 1 + logger.info(f"Pattern fused, total so far: {fusion_count}") + + except Exception as e: + logger.info(f"❌ Error during AoT computation: {e}") + logger.info(" Skipping fusion for this pattern") + continue + + for old_node in reversed(nodes_to_erase): + if old_node in graph_module.graph.nodes and len(old_node.users) == 0: + logger.info(f"🗑️ Erasing node: {old_node}") + graph_module.graph.erase_node(old_node) + + return fusion_count + + def call(self, graph_module: torch.fx.GraphModule): + logger.info("QuantizedOpFusionPass.call() started") + + # Normalize targets for flexible pass ordering + normalization_count = self._normalize_to_cortex_m_targets(graph_module) + + # Generic fusion for supported binary operations + fusion_count = self._fuse_quantized_binary_patterns(graph_module) + + total_changes = normalization_count + fusion_count + logger.info(f"Total changes: {total_changes}") + + if total_changes > 0: + graph_module.graph.eliminate_dead_code() + graph_module.graph.lint() + graph_module.recompile() + + logger.debug("=== AFTER FUSION: All nodes in the graph ===") + for i, node in enumerate(graph_module.graph.nodes): + logger.debug(f"Node {i}: op={node.op}, target={node.target}") + if "quantized_" in str(node.target) and "add" in str(node.target): + logger.debug(" ⭐ FOUND QUANTIZED BINARY OP NODE! 
⭐") + logger.debug("=== END DEBUG ===") + + return PassResult(graph_module, total_changes > 0) diff --git a/backends/cortex_m/test/TARGETS b/backends/cortex_m/test/TARGETS index d381011b648..b7a04f3efab 100644 --- a/backends/cortex_m/test/TARGETS +++ b/backends/cortex_m/test/TARGETS @@ -10,14 +10,18 @@ load("targets.bzl", "define_common_targets") oncall("executorch") python_unittest( - name = "test_replace_quant_nodes", - srcs = ["test_replace_quant_nodes.py"], - deps = [ + name="test_replace_quant_nodes", + srcs=[ + "test_helpers_passes_utils.py", + "test_replace_quant_nodes.py", + "test_quantize_op_fusion_pass.py", + ], + deps=[ "//pytorch/ao:torchao", # @manual - "//caffe2:torch", + "//caffe2:torch", "//executorch/backends/cortex_m/passes:replace_quant_nodes_pass", "//executorch/backends/cortex_m/ops:ops", ], -) +) define_common_targets() diff --git a/backends/cortex_m/test/test_helpers_passes_utils.py b/backends/cortex_m/test/test_helpers_passes_utils.py new file mode 100644 index 00000000000..ccaf8476867 --- /dev/null +++ b/backends/cortex_m/test/test_helpers_passes_utils.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import Optional + +import torch +from torch.fx import GraphModule +from torchao.quantization.pt2e.observer import HistogramObserver +from torchao.quantization.pt2e.quantizer import ( + QuantizationAnnotation, + QuantizationSpec, + Quantizer, +) +from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY + + +@dataclass(eq=True, frozen=True) +class QuantizationConfig: + input_activation: Optional[QuantizationSpec] + output_activation: Optional[QuantizationSpec] + + +class AddQuantizer(Quantizer): + def __init__(self): + super().__init__() + + @staticmethod + def _get_qspec(): + return QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_symmetric, + is_dynamic=False, + observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12), + ) + + @staticmethod + def _get_qconfig(): + qspec = AddQuantizer._get_qspec() + return QuantizationConfig( + input_activation=qspec, + output_activation=qspec, + ) + + def annotate(self, model: GraphModule): + config = self._get_qconfig() + annotated_partitions = [] + + for node in model.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.add.Tensor, + torch.ops.aten.add_.Tensor, + ]: + continue + + if Q_ANNOTATION_KEY in node.meta and node.meta[Q_ANNOTATION_KEY]._annotated: + continue + + input_qspec_map = { + node.args[0]: config.input_activation, + node.args[1]: config.input_activation, + } + + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=config.output_activation, + _annotated=True, + ) + annotated_partitions.append([node]) + + return annotated_partitions + + def validate(self, model: GraphModule) -> None: + pass + + +def check_count( + graph_module: GraphModule, op: torch.fx.node.Target, expected_count: int +): + actual_count = sum( + 1 + for node in graph_module.graph.nodes + if node.op == "call_function" and node.target == op + ) + + assert ( + actual_count == expected_count + ), f"Expected {expected_count} {op} nodes, got {actual_count}" + + +def get_node_args(graph_module: GraphModule, op: torch.fx.node.Target): + """Helper to get arguments of 
specific operator nodes""" + nodes = [ + node + for node in graph_module.graph.nodes + if node.op == "call_function" and node.target == op + ] + return [node.args for node in nodes] diff --git a/backends/cortex_m/test/test_quantize_op_fusion_pass.py b/backends/cortex_m/test/test_quantize_op_fusion_pass.py new file mode 100644 index 00000000000..3b3ec40d543 --- /dev/null +++ b/backends/cortex_m/test/test_quantize_op_fusion_pass.py @@ -0,0 +1,369 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import executorch +import executorch.backends.cortex_m.ops.operators # noqa + +import torch + +from executorch.backends.cortex_m.passes.quantized_op_fusion_pass import ( + QuantizedOpFusionPass, +) +from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import ( + ReplaceQuantNodesPass, +) +from executorch.exir.dialects._ops import ops as exir_ops +from test_helpers_passes_utils import AddQuantizer, check_count, get_node_args +from torch.export import export, export_for_training +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + + +class TestQuantizedOpFusionPass(unittest.TestCase): + """ + Test suite for the QuantizedOpFusionPass which fuses dequantize->add->quantize patterns + into a single quantized_add operation with AoT-computed parameters. + """ + + def setUp(self): + """Set up common test fixtures""" + self.example_inputs = (torch.randn(4, 8), torch.randn(4, 8)) + + def _prepare_quantized_model(self, model_class): + """Helper to prepare a quantized model for testing""" + model = model_class() + + # Export and quantize + exported_model = export_for_training( + model.eval(), self.example_inputs, strict=True + ).module() + prepared_model = prepare_pt2e(exported_model, AddQuantizer()) + quantized_model = convert_pt2e(prepared_model) + + # Export to EXIR Edge + exported = export(quantized_model, self.example_inputs, strict=True) + edge_program = executorch.exir.to_edge( + exported, + compile_config=executorch.exir.EdgeCompileConfig(_check_ir_validity=False), + ) + return edge_program + + def _apply_passes(self, edge_program): + """Apply both ReplaceQuantNodesPass and QuantizedOpFusionPass""" + passes = [QuantizedOpFusionPass(), ReplaceQuantNodesPass()] + final_program = edge_program.transform(passes) + return final_program + + def test_single_add_fusion(self): + """Single add with full Q/DQ pattern should fuse into one quantized_add node""" + + class SingleAddModel(torch.nn.Module): + def forward(self, x, y): + return x + y + + # Prepare model + edge_program = self._prepare_quantized_model(SingleAddModel) + edge_graph = edge_program.exported_program().graph_module + + # Get reference output + reference_output = edge_graph(*self.example_inputs) + + # Apply passes + transformed_program = self._apply_passes(edge_program) + transformed_graph = transformed_program.exported_program().graph_module + + # Verify fusion occurred + check_count( + transformed_graph, + exir_ops.edge.cortex_m.quantized_add.default, + 1, # Should have exactly 1 fused quantized_add + ) + + # Verify the following + # Before fusion: + # x --> quantize_per_tensor --> dequantize_per_tensor --> add --> quantize_per_tensor --> + # dequantize_per_tensor --> output y --> quantize_per_tensor --> dequantize_per_tensor --^ + # After fusion: + # x --> quantize_per_tensor --> quantized_add --> dequantize_per_tensor 
--> output + # y --> quantize_per_tensor --^ + check_count( + transformed_graph, exir_ops.edge.cortex_m.quantize_per_tensor.default, 2 + ) + check_count( + transformed_graph, exir_ops.edge.cortex_m.dequantize_per_tensor.default, 1 + ) + check_count(transformed_graph, exir_ops.edge.cortex_m.quantized_add.default, 1) + + # Verify numerical equivalence + fused_output = transformed_graph(*self.example_inputs) + torch.testing.assert_close(reference_output, fused_output, rtol=1e-3, atol=1e-3) + + def test_multiple_add_fusion(self): + """Multiple independent adds should create multiple quantized_add nodes""" + + class MultipleAddModel(torch.nn.Module): + def forward(self, x, y): + z1 = x + y # First add + z2 = x + z1 # Second add + return z2 + + # Prepare model + edge_program = self._prepare_quantized_model(MultipleAddModel) + edge_graph = edge_program.exported_program().graph_module + + # Get reference output + reference_output = edge_graph(*self.example_inputs) + + # Apply passes + transformed_program = self._apply_passes(edge_program) + transformed_graph = transformed_program.exported_program().graph_module + + # Verify multiple fusions occurred + check_count( + transformed_graph, + exir_ops.edge.cortex_m.quantized_add.default, + 2, # Should have 2 fused quantized_add nodes + ) + + # Verify numerical equivalence + fused_output = transformed_graph(*self.example_inputs) + torch.testing.assert_close(reference_output, fused_output, rtol=1e-3, atol=1e-3) + + def test_no_fusion_without_pattern(self): + """Add without proper Q/DQ pattern should not be fused""" + + class NonQuantizedAddModel(torch.nn.Module): + def forward(self, x, y): + # This will have add but not the full Q/DQ pattern after quantization + return torch.relu(x + y) # ReLU breaks the pattern + + # For this test, we'll create a model that doesn't have the complete pattern + # We need to manually construct a graph that has add without full Q/DQ + + model = NonQuantizedAddModel() + exported = export(model, self.example_inputs, strict=True) + edge_program = executorch.exir.to_edge( + exported, + compile_config=executorch.exir.EdgeCompileConfig(_check_ir_validity=False), + ) + # Apply passes + transformed_program = self._apply_passes(edge_program) + transformed_graph = transformed_program.exported_program().graph_module + + # Verify no fusion occurred + check_count( + transformed_graph, + exir_ops.edge.cortex_m.quantized_add.default, + 0, # Should have no fused quantized_add nodes + ) + + def test_precomputed_parameters(self): + """Fused node should have precomputed multipliers/shifts instead of scales""" + + class SingleAddModel(torch.nn.Module): + def forward(self, x, y): + return x + y + + # Prepare model + edge_program = self._prepare_quantized_model(SingleAddModel) + + # Apply passes + transformed_program = self._apply_passes(edge_program) + transformed_graph = transformed_program.exported_program().graph_module + + # Get arguments of the fused quantized_add node + quantized_add_args = get_node_args( + transformed_graph, exir_ops.edge.cortex_m.quantized_add.default + ) + + # Should have exactly one quantized_add node + self.assertEqual(len(quantized_add_args), 1) + args = quantized_add_args[0] + + # Verify argument structure: (tensor1, zp1, mult1, shift1, tensor2, zp2, mult2, shift2, out_zp, out_mult, out_shift) + self.assertEqual(len(args), 11, "quantized_add should have 11 arguments") + + # Check that multipliers and shifts are integers (not floats/scales) + # args[2], args[3] = input1 multiplier, shift + # args[6], args[7] = input2 
multiplier, shift + # args[9], args[10] = output multiplier, shift + for i in [2, 3, 6, 7, 9, 10]: # multiplier and shift positions + self.assertIsInstance( + args[i], int, f"Argument {i} should be an integer (precomputed)" + ) + + def test_mixed_fusion_pattern(self): + """Mixed pattern (some fusable, some not) should partially fuse""" + + class MixedModel(torch.nn.Module): + def forward(self, x, y): + z1 = x + y # This should fuse + z2 = torch.relu(z1) # ReLU breaks next fusion + z3 = z2 + x # This won't have full Q/DQ pattern + return z3 + + # Prepare model + edge_program = self._prepare_quantized_model(MixedModel) + + # Apply passes + transformed_program = self._apply_passes(edge_program) + transformed_graph = transformed_program.exported_program().graph_module + + # Should have partial fusion (at least 1, but not necessarily all adds) + quantized_add_count = sum( + 1 + for node in transformed_graph.graph.nodes + if node.op == "call_function" + and node.target == exir_ops.edge.cortex_m.quantized_add.default + ) + + self.assertGreaterEqual( + quantized_add_count, 1, "Should have at least 1 fused operation" + ) + + def test_different_tensor_shapes(self): + """Different tensor shapes should still fuse correctly""" + + class SingleAddModel(torch.nn.Module): + def forward(self, x, y): + return x + y + + # Test with different input shapes + for shape in [(2, 3), (10, 20, 30), (1,)]: + with self.subTest(shape=shape): + inputs = (torch.randn(shape), torch.randn(shape)) + + model = SingleAddModel() + exported_model = export_for_training( + model.eval(), inputs, strict=True + ).module() + prepared_model = prepare_pt2e(exported_model, AddQuantizer()) + quantized_model = convert_pt2e(prepared_model) + + exported = export(quantized_model, inputs, strict=True) + edge_program = executorch.exir.to_edge( + exported, + compile_config=executorch.exir.EdgeCompileConfig( + _check_ir_validity=False + ), + ) + + # Apply passes + transformed_program = self._apply_passes(edge_program) + transformed_graph = transformed_program.exported_program().graph_module + + # Verify fusion occurred regardless of shape + check_count( + transformed_graph, exir_ops.edge.cortex_m.quantized_add.default, 1 + ) + + def test_aot_parameter_computation_accuracy(self): + """Verify that AoT-computed parameters match runtime computation""" + + class SingleAddModel(torch.nn.Module): + def forward(self, x, y): + return x + y + + # Prepare model + edge_program = self._prepare_quantized_model(SingleAddModel) + + # Apply passes + transformed_program = self._apply_passes(edge_program) + transformed_graph = transformed_program.exported_program().graph_module + + # Get the fused node arguments + quantized_add_args = get_node_args( + transformed_graph, exir_ops.edge.cortex_m.quantized_add.default + )[0] + + # Extract the computed multipliers and shifts + input1_mult, input1_shift = quantized_add_args[2], quantized_add_args[3] + input2_mult, input2_shift = quantized_add_args[6], quantized_add_args[7] + output_mult, output_shift = quantized_add_args[9], quantized_add_args[10] + + # Verify they are reasonable values + # Multipliers should be in int32 range + self.assertTrue(-(2**31) <= input1_mult < 2**31) + self.assertTrue(-(2**31) <= input2_mult < 2**31) + self.assertTrue(-(2**31) <= output_mult < 2**31) + + # Shifts should be reasonable (typically -31 to 31) + self.assertTrue(-50 <= input1_shift <= 50) + self.assertTrue(-50 <= input2_shift <= 50) + self.assertTrue(-50 <= output_shift <= 50) + + # Output multiplier should be close to 2^30 (for 
1.0 scale) + self.assertAlmostEqual(output_mult, 2**30, delta=1000) + self.assertEqual(output_shift, -1) + + def test_executorch_program_generation(self): + """Verify ExecuTorch program generation with fused ops""" + + class SingleAddModel(torch.nn.Module): + def forward(self, x, y): + return x + y + + # Prepare model + edge_program = self._prepare_quantized_model(SingleAddModel) + + # Apply passes + transformed_program = self._apply_passes(edge_program) + + # Generate ExecutorTorch program + executorch_program = transformed_program.to_executorch() + + # Verify the program contains the expected fused operator + operator_names = [ + op.name + for op in executorch_program.executorch_program.execution_plan[0].operators + ] + + self.assertIn("cortex_m::quantized_add", operator_names) + self.assertIn("cortex_m::quantize_per_tensor", operator_names) + self.assertIn("cortex_m::dequantize_per_tensor", operator_names) + # quantize_per_tensor --> dequantize_per_tensor --> add --> quantize_per_tensor --> dequantize_per_tensor + # (input quant) (dequant) (fp32 add) (re-quant) (dequant) + # ↓ + # Fusion Pass detects pattern: + # dequantize_per_tensor --> quantized_add (Fused node) --> quantize_per_tensor + + def test_broadcastable_shapes(self): + """Verify that broadcastable shapes are supported""" + + class BroadcastAddModel(torch.nn.Module): + def forward(self, x, y): + return x + y + + # input broadcastable shapes + inputs = (torch.randn(4, 1), torch.randn(4, 8)) + print(inputs) + + # Prepare quantized model + edge_program = self._prepare_quantized_model(BroadcastAddModel) + + # Get unfused output + unfused_graph = edge_program.exported_program().graph_module + unfused_output = unfused_graph(*inputs) + if isinstance(unfused_output, tuple): + unfused_output = unfused_output[0] + + # Apply fusion pass + fused_program = self._apply_passes(edge_program) + fused_graph = fused_program.exported_program().graph_module + fused_output = fused_graph(*inputs) + if isinstance(fused_output, tuple): + fused_output = fused_output[0] + + # Check fusion occurred + check_count(fused_graph, exir_ops.edge.cortex_m.quantized_add.default, 1) + + # Compare fused vs unfused (both quantized) + torch.testing.assert_close(fused_output, unfused_output, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/cortex_m/test/test_replace_quant_nodes.py b/backends/cortex_m/test/test_replace_quant_nodes.py index 3853f7b5535..fb489956b6d 100644 --- a/backends/cortex_m/test/test_replace_quant_nodes.py +++ b/backends/cortex_m/test/test_replace_quant_nodes.py @@ -6,8 +6,6 @@ # LICENSE file in the root directory of this source tree. 
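The fusion tests above build the pipeline step by step; as a compact reference, the end-to-end AOT flow they exercise looks roughly like the sketch below. It only uses helpers introduced by this patch; the model and input shapes are placeholders, and the AddQuantizer import assumes it is run from backends/cortex_m/test/.

    import executorch.exir as exir
    import torch
    from executorch.backends.cortex_m.passes.quantized_op_fusion_pass import QuantizedOpFusionPass
    from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import ReplaceQuantNodesPass
    from test_helpers_passes_utils import AddQuantizer
    from torch.export import export, export_for_training
    from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

    class AddModel(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    example_inputs = (torch.randn(4, 8), torch.randn(4, 8))

    # 1. Quantize with the int8 add quantizer used by the tests
    exported = export_for_training(AddModel().eval(), example_inputs, strict=True).module()
    quantized = convert_pt2e(prepare_pt2e(exported, AddQuantizer()))

    # 2. Lower to EXIR Edge
    edge = exir.to_edge(
        export(quantized, example_inputs, strict=True),
        compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
    )

    # 3. Fuse dequantize -> add -> quantize into cortex_m::quantized_add and
    #    swap the remaining q/dq nodes to cortex_m variants
    edge = edge.transform([QuantizedOpFusionPass(), ReplaceQuantNodesPass()])

    # 4. Emit the ExecuTorch program (the fused op dispatches to the CMSIS-NN kernel at runtime)
    et_program = edge.to_executorch()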
import unittest -from dataclasses import dataclass -from typing import Optional import executorch.backends.cortex_m.ops.operators # noqa @@ -19,91 +17,12 @@ ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.program._program import _transform -from torch.export import export -from torch.fx import GraphModule -from torchao.quantization.pt2e.observer import HistogramObserver -from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e -from torchao.quantization.pt2e.quantizer import ( - QuantizationAnnotation, - QuantizationSpec, - Quantizer, -) -from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY - - -@dataclass(eq=True, frozen=True) -class QuantizationConfig: - input_activation: Optional[QuantizationSpec] - output_activation: Optional[QuantizationSpec] - - -class AddQuantizer(Quantizer): - def __init__(self): - super().__init__() - - @staticmethod - def _get_qspec(): - return QuantizationSpec( - dtype=torch.int8, - quant_min=-128, - quant_max=127, - qscheme=torch.per_tensor_symmetric, - is_dynamic=False, - observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12), - ) - - @staticmethod - def _get_qconfig(): - qspec = AddQuantizer._get_qspec() - return QuantizationConfig( - input_activation=qspec, - output_activation=qspec, - ) - def annotate(self, model: GraphModule): - config = self._get_qconfig() - annotated_partitions = [] +from test_helpers_passes_utils import AddQuantizer, check_count - for node in model.graph.nodes: - if node.op != "call_function" or node.target not in [ - torch.ops.aten.add.Tensor, - torch.ops.aten.add_.Tensor, - ]: - continue - - if Q_ANNOTATION_KEY in node.meta and node.meta[Q_ANNOTATION_KEY]._annotated: - continue - - input_qspec_map = { - node.args[0]: config.input_activation, - node.args[1]: config.input_activation, - } - - node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=config.output_activation, - _annotated=True, - ) - annotated_partitions.append([node]) - - return annotated_partitions - - def validate(self, model: GraphModule) -> None: - pass - - -def check_count( - graph_module: GraphModule, op: torch.fx.node.Target, expected_count: int -): - actual_count = sum( - 1 - for node in graph_module.graph.nodes - if node.op == "call_function" and node.target == op - ) +from torch.export import export - assert ( - actual_count == expected_count - ), f"Expected {expected_count} {op} nodes, got {actual_count}" +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e class TestReplaceQuantOps(unittest.TestCase): @@ -207,3 +126,7 @@ def forward(self, x): "cortex_m::quantize_per_tensor", "cortex_m::dequantize_per_tensor", ], f"Unexpected op {op.name}" + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index ec5f63e0590..72e91fc640d 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -42,9 +42,14 @@ from executorch.backends.arm.vgf_partitioner import VgfPartitioner # To use Cortex-M backend +from executorch.backends.cortex_m.passes.quantized_op_fusion_pass import ( + QuantizedOpFusionPass, +) + from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import ( ReplaceQuantNodesPass, ) + from executorch.devtools import generate_etrecord from executorch.devtools.backend_debug import get_delegation_info from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite @@ 
-790,7 +795,11 @@ def transform_for_cortex_m_backend(edge): # Let's make sure we are using optimized Cortex M backend # NB: If we can't find and replace ops those are expected to be replaced, # bad things will happen at runtime, like "missing operator" errors! - edge = edge.transform([ReplaceQuantNodesPass()]) + # Instantiate the pass + replace_quant_pass = ReplaceQuantNodesPass() + quantized_op_fusion_pass = QuantizedOpFusionPass() + edge = edge.transform([replace_quant_pass, quantized_op_fusion_pass]) + return edge diff --git a/examples/arm/run.sh b/examples/arm/run.sh index 9d576d97c5e..dcadafb01d0 100755 --- a/examples/arm/run.sh +++ b/examples/arm/run.sh @@ -51,7 +51,7 @@ function help() { echo " --no_delegate Do not delegate the model (can't override builtin models)" echo " --no_quantize Do not quantize the model (can't override builtin models)" echo " --portable_kernels= TO BE DEPRECATED: Alias to select_ops_list." - echo " --select_ops_list= Comma separated list of portable (non delagated) kernels to include Default: ${select_ops_list}" + echo " --select_ops_list= Comma separated list of portable (non delagated) kernels to include Default: ${select_ops_list}" echo " NOTE: This is used when select_ops_model is not possible to use, e.g. for semihosting or bundleio." echo " See https://docs.pytorch.org/executorch/stable/kernel-library-selective-build.html for more information." echo " --target= Target to build and run for Default: ${target}" @@ -104,7 +104,7 @@ ethos_u_scratch_dir=$(realpath ${ethos_u_scratch_dir}) setup_path_script=${ethos_u_scratch_dir}/setup_path.sh if [[ ${toolchain} == "arm-none-eabi-gcc" ]]; then toolchain_cmake=${et_root_dir}/examples/arm/ethos-u-setup/${toolchain}.cmake -elif [[ ${toolchain} == "arm-zephyr-eabi-gcc" ]]; then +elif [[ ${toolchain} == "arm-zephyr-eabi-gcc" ]]; then toolchain_cmake=${et_root_dir}/examples/zephyr/x86_64-linux-arm-zephyr-eabi-gcc.cmake else echo "Error: Invalid toolchain selection, provided: ${toolchain}" @@ -202,13 +202,13 @@ backends/arm/scripts/build_executorch.sh --et_build_root="${et_build_root}" --bu if [[ -z "$model_name" ]]; then # the test models run, and whether to delegate test_model=( - "softmax" # 0 - "add" # 1 - "add3" # 2 - "qadd" # 3 - "qadd2" # 4 - "qops" # 5 - "mv2" # 6 + "softmax" # 0 + "add" # 1 + "add3" # 2 + "qadd" # 3 + "qadd2" # 4 + "qops" # 5 + "mv2" # 6 ) model_compiler_flags=( "" # 0 softmax