diff --git a/backends/cortex_m/CMakeLists.txt b/backends/cortex_m/CMakeLists.txt index 1567b8b5e1c..bd12c7d8183 100644 --- a/backends/cortex_m/CMakeLists.txt +++ b/backends/cortex_m/CMakeLists.txt @@ -12,7 +12,7 @@ if(NOT CMAKE_CXX_STANDARD) set(CMAKE_CXX_STANDARD 17) endif() -# Source root directory for executorch. +# Source root directory for executorch if(NOT EXECUTORCH_ROOT) set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) endif() @@ -21,70 +21,90 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake) include(FetchContent) -# CMSIS-NN version to download +# CMSIS-NN configuration with dynamic path detection set(CMSIS_NN_VERSION - "v4.1.0" + "v7.0.0" CACHE STRING "CMSIS-NN version to download" ) - -# Declare CMSIS-NN as a FetchContent project -FetchContent_Declare( - cmsis_nn - GIT_REPOSITORY https://github.com/ARM-software/CMSIS-NN.git - GIT_TAG ${CMSIS_NN_VERSION} +set(CMSIS_NN_LOCAL_PATH + "" + CACHE PATH "Path to existing local CMSIS-NN installation" ) -# Download and make CMSIS-NN available -FetchContent_MakeAvailable(cmsis_nn) +# Try to find existing / local CMSIS-NN installation. This is useful for +# debugging and testing with local changes. This is not common, as the CMSIS-NN +# library is downloaded via FetchContent in the default/regular case. +if(CMSIS_NN_LOCAL_PATH AND EXISTS "${CMSIS_NN_LOCAL_PATH}") + message(STATUS "Using CMSIS-NN from specified path: ${CMSIS_NN_LOCAL_PATH}") + add_subdirectory(${CMSIS_NN_LOCAL_PATH} cmsis_nn_build) +else() + # Use FetchContent with automatic fallback + message(STATUS "Using CMSIS-NN via FetchContent") + + FetchContent_Declare( + cmsis_nn + GIT_REPOSITORY https://github.com/ARM-software/CMSIS-NN.git + GIT_TAG ${CMSIS_NN_VERSION} + GIT_SHALLOW TRUE + ) + + FetchContent_GetProperties(cmsis_nn) + if(NOT cmsis_nn_POPULATED) + FetchContent_Populate(cmsis_nn) + add_subdirectory(${cmsis_nn_SOURCE_DIR} ${cmsis_nn_BINARY_DIR}) + endif() +endif() -# Print paths for debugging -message(STATUS "CMSIS-NN source dir: ${cmsis_nn_SOURCE_DIR}") -message(STATUS "CMSIS-NN binary dir: ${cmsis_nn_BINARY_DIR}") +# Add MVEI define to cmsis-nn target +if(TARGET cmsis-nn) + target_compile_definitions(cmsis-nn PUBLIC ARM_MATH_MVEI=1) + get_target_property(CMSIS_NN_INCLUDES cmsis-nn INTERFACE_INCLUDE_DIRECTORIES) + message(STATUS "CMSIS-NN include dirs: ${CMSIS_NN_INCLUDES}") +else() + message( + FATAL_ERROR + "CMSIS-NN target not found. Check your CMSIS_NN_LOCAL_PATH or network connection." 
+ ) +endif() # Cortex-M ops kernel sources set(_cortex_m_kernels__srcs ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantize_per_tensor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_linear.cpp ) -# Generate C++ bindings to register kernels into Executorch (for runtime) +# Generate C++ bindings to register kernels into Executorch set(_yaml_file ${CMAKE_CURRENT_LIST_DIR}/ops/operators.yaml) gen_selected_ops(LIB_NAME "cortex_m_ops_lib" OPS_SCHEMA_YAML "${_yaml_file}") - generate_bindings_for_kernels( LIB_NAME "cortex_m_ops_lib" CUSTOM_OPS_YAML "${_yaml_file}" ) -message("Generated files ${gen_command_sources}") -# Build a library for cortex_m_kernels +# Build library for cortex_m_kernels add_library(cortex_m_kernels ${_cortex_m_kernels__srcs}) -target_compile_options(cortex_m_kernels PUBLIC ${_common_compile_options}) -# Include directories for cortex_m_kernels -target_include_directories( +# Use PRIVATE for implementation dependencies to avoid INTERFACE pollution +target_link_libraries( cortex_m_kernels - PRIVATE ${EXECUTORCH_ROOT}/.. - ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10 - ${cmsis_nn_SOURCE_DIR}/Include + PRIVATE cmsis-nn + PRIVATE executorch ) -# Link directly to the CMSIS-NN static library file -target_link_libraries( - cortex_m_kernels PUBLIC ${cmsis_nn_BINARY_DIR}/libcmsis-nn.a executorch +# Include directories for cortex_m_kernels +target_include_directories( + cortex_m_kernels PRIVATE ${EXECUTORCH_ROOT}/.. + ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10 ) -# Add dependency to ensure CMSIS-NN builds before we try to link. Use the actual -# CMSIS-NN target name (usually 'cmsis-nn') -add_dependencies(cortex_m_kernels cmsis-nn) - # cortex_m_ops_lib: Register Cortex-M ops kernels into Executorch runtime gen_operators_lib( LIB_NAME "cortex_m_ops_lib" KERNEL_LIBS cortex_m_kernels DEPS executorch ) install( - TARGETS cortex_m_kernels cortex_m_ops_lib + TARGETS cortex_m_kernels cortex_m_ops_lib cmsis-nn EXPORT ExecuTorchTargets DESTINATION lib PUBLIC_HEADER DESTINATION include/executorch/backends/cortex_m/ops/ diff --git a/backends/cortex_m/ops/cmsis_scratch_buffer_context.h b/backends/cortex_m/ops/cmsis_scratch_buffer_context.h new file mode 100644 index 00000000000..4b9fdaebdf7 --- /dev/null +++ b/backends/cortex_m/ops/cmsis_scratch_buffer_context.h @@ -0,0 +1,187 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include "cortex_m_ops_common.h" +extern "C" { +#include "arm_nnfunctions.h" +} + +namespace cortex_m { +namespace native { + +// During AOT phase, quantized_linear_fusion_pass allocates total buffer +// and passes in as 'Tensor'. 
(Total buffer = 8-byte header + x bytes) +// ┌─────────────────┬─────────────────────────────────────┐ +// │ KernelSum Header│ CMSIS Workspace │ +// │ (8 bytes) │ (x bytes) │ +// └─────────────────┴─────────────────────────────────────┘ +// │ │ +// │ └─> Passed to CMSIS API +// │ +// └─> State for kernel sum + +// C++ Runtime: +// ┌─────────────────┬─────────────────────────────────────┐ +// │ KernelSum Header│ CMSIS Workspace │ +// │ (8 bytes) │ (x bytes) │ +// └─────────────────┴─────────────────────────────────────┘ +// ^ ^ +// │ │ +// scratch_ptr cmsis_workspace_ptr +// │ │ +// ▼ ▼ +// arm_vector_sum_s8() writes kernel sums (with bias if avail): +// [sum₀+bias₀][sum₁+bias₁][sum₂+bias₂]...[sum_{n-1}+bias_{n-1}] +// (n * 4-byte int32_t values = x bytes) +// +// - n = out_features (number of output features) +// - x = n * 4 bytes (total CMSIS buffer size) +// - Total buffer = 8 + x bytes + +class CMSISScratchBufferContext final { + public: + CMSISScratchBufferContext( + Tensor& scratch_buffer, + const Tensor& weights, + const Tensor& weight_zero_point, + const torch::executor::optional& bias) + : scratch_ptr_(scratch_buffer.mutable_data_ptr()), + total_size_(scratch_buffer.size(0)), + base_ptr_(reinterpret_cast(scratch_ptr_)), + in_features_(weights.size(1)), + out_features_(weights.size(0)), + is_per_channel_(weight_zero_point.numel() > 1), + weight_data_offset_(calculate_offset(weights.const_data_ptr())), + weight_zp_data_offset_( + calculate_offset(weight_zero_point.const_data_ptr())), + bias_data_offset_( + bias.has_value() + ? calculate_offset(bias.value().const_data_ptr()) + : 0), + header_(reinterpret_cast(scratch_ptr_)), + cmsis_workspace_ptr_(scratch_ptr_ + KERNEL_SUM_HEADER_SIZE) { + cmsis_nn_dims filter_dims = {in_features_, 1, 1, out_features_}; + validate_size(filter_dims); + } + + cmsis_nn_context get_cmsis_ctx() const { + cmsis_nn_context ctx; + ET_CHECK_MSG( + reinterpret_cast(cmsis_workspace_ptr_) % 4 == 0, + "CMSIS workspace not 4-byte aligned"); + ctx.buf = cmsis_workspace_ptr_; + ctx.size = get_cmsis_workspace_size(); + return ctx; + } + + bool is_kernel_sum_updated() const { + return header_->updated; + } + + void compute_kernel_sums_if_needed() { + if (!header_->updated) { + arm_vector_sum_s8( + reinterpret_cast(cmsis_workspace_ptr_), + in_features_, + out_features_, + get_weight_data(), + get_weight_zp_data()[0], + 0, + get_bias_data()); + header_->updated = true; + ET_LOG( + Info, + "Computed kernel sums. [required_bytes : %d]", + header_->required_size); + } + } + + const int8_t* get_weight_data() const { + return reinterpret_cast(base_ptr_ + weight_data_offset_); + } + + const int32_t* get_weight_zp_data() const { + return reinterpret_cast(base_ptr_ + weight_zp_data_offset_); + } + + const int32_t* get_bias_data() const { + return bias_data_offset_ == 0 + ? 
nullptr + : reinterpret_cast(base_ptr_ + bias_data_offset_); + } + + bool is_per_channel_quant() const { + return is_per_channel_; + } + int32_t get_in_features() const { + return in_features_; + } + int32_t get_out_features() const { + return out_features_; + } + + private: + static constexpr size_t KERNEL_SUM_HEADER_SIZE = 8; + + // Header for kernel sum computation state only + struct KernelSumHeader { + bool updated = false; + int32_t required_size = 0; + }; + static_assert( + sizeof(KernelSumHeader) == KERNEL_SUM_HEADER_SIZE, + "KernelSumHeader must be exactly 8 bytes"); + + int8_t* scratch_ptr_; + size_t total_size_; + uint8_t* base_ptr_; + + // Context members + const int32_t in_features_; + const int32_t out_features_; + const bool is_per_channel_; + const uint32_t weight_data_offset_; + const uint32_t weight_zp_data_offset_; + const uint32_t bias_data_offset_; + + KernelSumHeader* header_; + int8_t* cmsis_workspace_ptr_; + + uint32_t calculate_offset(const void* ptr) const { + if (ptr == nullptr) + return 0; + + const uint8_t* ptr_bytes = reinterpret_cast(ptr); + ET_CHECK_MSG(ptr_bytes >= base_ptr_, "Pointer is before base address"); + + const std::ptrdiff_t offset = ptr_bytes - base_ptr_; + ET_CHECK_MSG( + offset >= 0 && offset <= UINT32_MAX, "Offset out of valid range"); + return static_cast(offset); + } + + size_t get_cmsis_workspace_size() const { + return total_size_ - KERNEL_SUM_HEADER_SIZE; + } + + void validate_size(const cmsis_nn_dims& filter_dims) const { + header_->required_size = + arm_fully_connected_s8_get_buffer_size(&filter_dims); + + ET_CHECK_MSG( + get_cmsis_workspace_size() >= + static_cast(header_->required_size), + "Scratch buffer size %zu insufficient for required size %d", + get_cmsis_workspace_size(), + header_->required_size); + } +}; + +} // namespace native +} // namespace cortex_m diff --git a/backends/cortex_m/ops/cortex_m_ops_common.h b/backends/cortex_m/ops/cortex_m_ops_common.h index 5ef2d9d4bf9..eaa7027e46c 100644 --- a/backends/cortex_m/ops/cortex_m_ops_common.h +++ b/backends/cortex_m/ops/cortex_m_ops_common.h @@ -22,6 +22,10 @@ using ScalarType = executorch::aten::ScalarType; using Scalar = torch::executor::Scalar; using Error = executorch::runtime::Error; +// From arm_nn_math_types.h +#define ARM_NN_Q31_MAX ((int32_t)(0x7FFFFFFFL)) +#define ARM_NN_Q31_MIN ((int32_t)(0x80000000L)) + // Basic tensor type / layout validation and dimension order checking inline void validate_cmsis_nn_tensor_requirements( const Tensor& input1, @@ -32,16 +36,19 @@ inline void validate_cmsis_nn_tensor_requirements( // Basic dtype validation ET_CHECK_MSG( input1.scalar_type() == expected_dtype, - "Input1 dtype must be %hhd", - expected_dtype); + "Input1 dtype must be %hhd, got %hhd", + expected_dtype, + input1.scalar_type()); ET_CHECK_MSG( input2.scalar_type() == expected_dtype, - "Input2 dtype must be %hhd", - expected_dtype); + "Input2 dtype must be %hhd, got %hhd", + expected_dtype, + input2.scalar_type()); ET_CHECK_MSG( output.scalar_type() == expected_dtype, - "Output dtype must be %hhd", - expected_dtype); + "Output dtype must be %hhd, got %hhd", + expected_dtype, + output.scalar_type()); // Dim order consistency ET_CHECK_MSG( @@ -114,6 +121,33 @@ inline void validate_quantization_params( "Single quant Output"); } +// Refer to CMSIS-NN 'arm_nn_requantize' implementation for details: +// https://github.com/ARM-software/CMSIS-NN/blob/main/Include/arm_nnsupportfunctions.h#L1625 +// multiplier: Range {ARM_NN_Q31_MIN + 1, Q32_MAX} +// shift : Range {-31, 30} +inline 
bool validate_per_channel_quant_params( + const int32_t* multipliers, + const int32_t* shifts, + int num_channels) { + for (int i = 0; i < num_channels; ++i) { + // Multiplier: {ARM_NN_Q31_MIN + 1, ARM_NN_Q31_MAX} + if (multipliers[i] <= ARM_NN_Q31_MIN || multipliers[i] > ARM_NN_Q31_MAX) { + ET_LOG( + Error, + "weight_multiplier[%d] out of CMSIS-NN range: %d", + i, + multipliers[i]); + return false; + } + // Shift: {-31, 30} for arm_nn_requantize + if (shifts[i] < -31 || shifts[i] > 30) { + ET_LOG(Error, "weight_shift[%d] out of range: %d", i, shifts[i]); + return false; + } + } + return true; +} + inline Error resize_to_broadcast_target_size( const Tensor& input1, const Tensor& input2, diff --git a/backends/cortex_m/ops/op_quantized_linear.cpp b/backends/cortex_m/ops/op_quantized_linear.cpp new file mode 100644 index 00000000000..d1ccb6d0d45 --- /dev/null +++ b/backends/cortex_m/ops/op_quantized_linear.cpp @@ -0,0 +1,171 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "cmsis_scratch_buffer_context.h" +#include "cortex_m_ops_common.h" + +extern "C" { +#include "arm_nnfunctions.h" +} + +namespace cortex_m { +namespace native { +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; + +Tensor& quantized_linear_out( + KernelRuntimeContext& context, + const Tensor& input, + const Scalar& input_zero_point, + const Scalar& input_multiplier, + const Scalar& input_shift, + const Tensor& weights, + const Tensor& weight_zero_point, + const Tensor& weight_multiplier, + const Tensor& weight_shift, + const torch::executor::optional& bias, + const Tensor& bias_multiplier, + const Tensor& bias_shift, + const Tensor& scratch_buffer, + const Scalar& output_zero_point, + const Scalar& in_features, + const Scalar& out_features, + Tensor& out) { + ET_LOG(Info, "quantized_linear_out: called"); + validate_cmsis_nn_tensor_requirements(input, weights, out); + + ET_CHECK_MSG( + scratch_buffer.scalar_type() == ScalarType::Char, + "Scratch buffer must be int8"); + + const int32_t batch_size = input.size(0); + const int32_t in_feat = static_cast(in_features.to()); + const int32_t out_feat = static_cast(out_features.to()); + const int32_t input_zp = static_cast(input_zero_point.to()); + const int32_t output_zp = + static_cast(output_zero_point.to()); + const bool is_per_channel = (weight_zero_point.numel() > 1); + + const int8_t* input_data = input.const_data_ptr(); + const int8_t* weight_data = weights.const_data_ptr(); + const int32_t* bias_data = + bias.has_value() ? 
bias.value().const_data_ptr() : nullptr; + int8_t* output_data = out.mutable_data_ptr(); + const int32_t* weight_zp_data = weight_zero_point.const_data_ptr(); + const int32_t* weight_mult_data = weight_multiplier.const_data_ptr(); + const int32_t* weight_shift_data = weight_shift.const_data_ptr(); + + if (!validate_per_channel_quant_params( + weight_mult_data, weight_shift_data, out_feat)) { + context.fail(Error::InvalidArgument); + return out; + } + + // Initialize scratch buffer context (validates early) + CMSISScratchBufferContext scratch_ctx( + const_cast(scratch_buffer), weights, weight_zero_point, bias); + + scratch_ctx.compute_kernel_sums_if_needed(); + cmsis_nn_context ctx = scratch_ctx.get_cmsis_ctx(); + + // Setup CMSIS-NN parameters + cmsis_nn_fc_params fc_params; + fc_params.input_offset = -input_zp; + fc_params.output_offset = output_zp; + fc_params.activation.min = std::numeric_limits::min(); + fc_params.activation.max = std::numeric_limits::max(); + + cmsis_nn_dims input_dims = {1, 1, 1, in_feat}; + cmsis_nn_dims filter_dims = {in_feat, 1, 1, out_feat}; + cmsis_nn_dims bias_dims = {1, 1, 1, out_feat}; + cmsis_nn_dims output_dims = {1, 1, 1, out_feat}; + + arm_cmsis_nn_status status; + for (int32_t b = 0; b < batch_size; b++) { + const int8_t* batch_input = input_data + b * in_feat; + int8_t* batch_output = output_data + b * out_feat; + + ET_CHECK_MSG( + batch_input != nullptr && weight_data != nullptr, + "Null input pointers"); + ET_CHECK_MSG(in_feat > 0 && out_feat > 0, "Invalid dimensions"); + + if (is_per_channel) { + cmsis_nn_per_channel_quant_params per_channel_quant_params; + per_channel_quant_params.multiplier = + const_cast(weight_mult_data); + per_channel_quant_params.shift = const_cast(weight_shift_data); + + status = arm_fully_connected_per_channel_s8( + &ctx, + &fc_params, + &per_channel_quant_params, + &input_dims, + batch_input, + &filter_dims, + weight_data, + &bias_dims, + bias_data, + &output_dims, + batch_output); + } else { + fc_params.filter_offset = -weight_zp_data[0]; + cmsis_nn_per_tensor_quant_params per_tensor_quant_params; + per_tensor_quant_params.multiplier = weight_mult_data[0]; + per_tensor_quant_params.shift = weight_shift_data[0]; + + status = arm_fully_connected_s8( + &ctx, + &fc_params, + &per_tensor_quant_params, + &input_dims, + batch_input, + &filter_dims, + weight_data, + &bias_dims, + bias_data, + &output_dims, + batch_output); + } + + if (status != ARM_CMSIS_NN_SUCCESS) { + ET_LOG( + Error, + "quantized_linear_out: CMSIS-NN failed with status [%d]", + status); + context.fail(Error::Internal); + return out; + } + } + return out; +} + +// Functional variant (stub, not used at runtime) +Tensor quantized_linear( + KernelRuntimeContext& context, + const Tensor& input, + const Scalar& input_zero_point, + const Scalar& input_multiplier, + const Scalar& input_shift, + const Tensor& weights, + const Tensor& weight_zero_point, + const Tensor& weight_multiplier, + const Tensor& weight_shift, + const torch::executor::optional& bias, + const Tensor& bias_multiplier, + const Tensor& bias_shift, + const Tensor& scratch_buffer, + const Scalar& output_zero_point, + const Scalar& in_features, + const Scalar& out_features) { + ET_LOG(Info, "quantized_linear: called"); + assert(false); + return const_cast(input); +} + +} // namespace native +} // namespace cortex_m diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index 926dcd85e4b..d642531e950 100644 --- a/backends/cortex_m/ops/operators.py +++ 
b/backends/cortex_m/ops/operators.py @@ -223,3 +223,216 @@ def quantized_add_out_impl( out.copy_(result_quantized) return out + + +# =================================================================== +# QUANTIZED LINEAR OPERATION DEFINITION +# =================================================================== + + +def _check_per_tensor_or_per_channel(param: torch.Tensor, out_channels: int, name: str): + assert param.numel() in [ + 1, + out_channels, + ], f"{name} must be per-tensor (1) or per-channel ({out_channels}), got {param.numel()}" + + +lib.define( + "quantized_linear.out(" + "Tensor input, Scalar input_zero_point, Scalar input_multiplier, Scalar input_shift, " + "Tensor weights, " + "Tensor weight_zero_point, Tensor weight_multiplier, Tensor weight_shift, " + "Tensor? bias, Tensor bias_multiplier, Tensor bias_shift, " + "Tensor scratch_buffer, Scalar output_zero_point, Scalar in_features, Scalar out_features, " + "*, Tensor(a!) out) -> Tensor(a!)" +) + +# Define functional variant (non-out version) +lib.define( + "quantized_linear(" + "Tensor input, Scalar input_zero_point, Scalar input_multiplier, Scalar input_shift, " + "Tensor weights, " + "Tensor weight_zero_point, Tensor weight_multiplier, Tensor weight_shift, " + "Tensor? bias, Tensor bias_multiplier, Tensor bias_shift, " + "Tensor scratch_buffer, Scalar output_zero_point, Scalar in_features, Scalar out_features" + ") -> Tensor" +) + + +# Fake meta function for shape inference (out variant) +@register_fake("cortex_m::quantized_linear.out") +def quantized_linear_out_meta( + input: torch.Tensor, + input_zero_point: int, + input_multiplier: int, + input_shift: int, + weights: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_multiplier: torch.Tensor, + weight_shift: torch.Tensor, + bias: torch.Tensor, + bias_multiplier: torch.Tensor, + bias_shift: torch.Tensor, + scratch_buffer: torch.Tensor, + output_zero_point: int, + in_features: int, + out_features: int, + out: torch.Tensor, +) -> torch.Tensor: + # Validate dimensions + batch_size = input.shape[0] + out_channels = weights.shape[0] + + # Validate weight quantization parameters dimensions + _check_per_tensor_or_per_channel( + weight_zero_point, out_channels, "weight_zero_point" + ) + _check_per_tensor_or_per_channel( + weight_multiplier, out_channels, "weight_multiplier" + ) + _check_per_tensor_or_per_channel(weight_shift, out_channels, "weight_shift") + + # Validate output shape + expected_shape = (batch_size, out_channels) + assert ( + out.shape == expected_shape + ), f"Output shape {out.shape} must be {expected_shape}" + + return out + + +# Fake meta function for shape inference (functional variant) +@register_fake("cortex_m::quantized_linear") +def quantized_linear_meta( + input: torch.Tensor, + input_zero_point: int, + input_multiplier: int, + input_shift: int, + weights: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_multiplier: torch.Tensor, + weight_shift: torch.Tensor, + bias: torch.Tensor, + bias_multiplier: torch.Tensor, + bias_shift: torch.Tensor, + scratch_buffer: torch.Tensor, + output_zero_point: int, + in_features: int, + out_features: int, +) -> torch.Tensor: + # Validate dimensions (same as out variant) + batch_size = input.shape[0] + out_channels = weights.shape[0] + + # Validate weight quantization parameters dimensions + _check_per_tensor_or_per_channel( + weight_zero_point, out_channels, "weight_zero_point" + ) + _check_per_tensor_or_per_channel( + weight_multiplier, out_channels, "weight_multiplier" + ) + 
_check_per_tensor_or_per_channel(weight_shift, out_channels, "weight_shift") + + # Calculate output shape for functional variant + output_shape = (batch_size, out_channels) + return torch.empty(output_shape, dtype=input.dtype, device=input.device) + + +@impl(lib, "quantized_linear.out", "CompositeExplicitAutograd") +def quantized_linear_out_impl( + input: torch.Tensor, + input_zero_point: int, + input_multiplier: int, + input_shift: int, + weights: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_multiplier: torch.Tensor, + weight_shift: torch.Tensor, + bias: torch.Tensor, + bias_multiplier: torch.Tensor, + bias_shift: torch.Tensor, + scratch_buffer: torch.Tensor, + output_zero_point: int, + in_features: int, + out_features: int, + *, + out: torch.Tensor, +) -> torch.Tensor: + """ + Fallback implementation for meta/testing + Note: This won't be called at runtime, only during compilation + """ + + # Per-channel dequantization + input_scale = input_multiplier * (2.0 ** (-input_shift)) + input_fp = (input.float() - input_zero_point) * input_scale + if weight_zero_point.numel() == 1: + # Per-tensor + weight_scale = weight_multiplier.item() * (2.0 ** (-weight_shift.item())) + weights_fp = (weights.float() - weight_zero_point.item()) * weight_scale + else: + # Per-channel + weight_scales = weight_multiplier.float() * (2.0 ** (-weight_shift.float())) + weights_fp = ( + weights.float() - weight_zero_point.float().unsqueeze(1) + ) * weight_scales.unsqueeze(1) + bias_fp = None + if bias is not None: + bias_scales = bias_multiplier.float() * (2.0 ** (-bias_shift.float())) + bias_fp = bias.float() * bias_scales + + result_fp = torch.nn.functional.linear(input_fp, weights_fp, bias_fp) + else: + result_fp = torch.nn.functional.linear(input_fp, weights_fp) + result_quantized = torch.clamp( + torch.round(result_fp + output_zero_point), -128, 127 + ).to(torch.int8) + out.copy_(result_quantized) + return out + + +# Functional variant implementation +@impl(lib, "quantized_linear", "CompositeExplicitAutograd") +def quantized_linear_impl( + input: torch.Tensor, + input_zero_point: int, + input_multiplier: int, + input_shift: int, + weights: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_multiplier: torch.Tensor, + weight_shift: torch.Tensor, + bias: torch.Tensor, + bias_multiplier: torch.Tensor, + bias_shift: torch.Tensor, + scratch_buffer: torch.Tensor, + output_zero_point: int, + in_features: int, + out_features: int, +) -> torch.Tensor: + """ + Functional variant - creates output tensor and calls out variant + """ + # Create output tensor + batch_size = input.shape[0] + output = torch.empty( + (batch_size, out_features), dtype=torch.int8, device=input.device + ) + return quantized_linear_out_impl( + input, + input_zero_point, + input_multiplier, + input_shift, + weights, + weight_zero_point, + weight_multiplier, + weight_shift, + bias, + bias_multiplier, + bias_shift, + scratch_buffer, + output_zero_point, + in_features, + out_features, + out=output, + ) diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml index f2615a1f525..b41c0c68fa5 100644 --- a/backends/cortex_m/ops/operators.yaml +++ b/backends/cortex_m/ops/operators.yaml @@ -27,3 +27,15 @@ kernels: - arg_meta: null kernel_name: cortex_m::quantized_add_out + +- func: cortex_m::quantized_linear(Tensor input, Scalar input_zero_point, Scalar input_multiplier, Scalar input_shift, Tensor weights, Tensor weight_zero_point, Tensor weight_multiplier, Tensor weight_shift, Tensor? 
bias, Tensor bias_multiplier, Tensor bias_shift, Tensor scratch_buffer, Scalar output_zero_point, Scalar in_features, Scalar out_features) -> Tensor + variants: function + kernels: + - arg_meta: null + kernel_name: cortex_m::quantized_linear + +- func: cortex_m::quantized_linear.out(Tensor input, Scalar input_zero_point, Scalar input_multiplier, Scalar input_shift, Tensor weights, Tensor weight_zero_point, Tensor weight_multiplier, Tensor weight_shift, Tensor? bias, Tensor bias_multiplier, Tensor bias_shift, Tensor scratch_buffer, Scalar output_zero_point, Scalar in_features, Scalar out_features, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: cortex_m::quantized_linear_out diff --git a/backends/cortex_m/passes/passes_utils.py b/backends/cortex_m/passes/passes_utils.py index 3f6e05fc4de..7155f997bf4 100644 --- a/backends/cortex_m/passes/passes_utils.py +++ b/backends/cortex_m/passes/passes_utils.py @@ -8,6 +8,10 @@ import torch +from executorch.exir.dialects._ops import ops as exir_ops + +from torch.fx import Node + def dequantize_per_tensor_cmsis( qtensor: torch.Tensor, zero_point: int, multiplier: int, shift: int @@ -92,3 +96,58 @@ def quantize_multiplier_aot(scale: float) -> tuple[int, int]: def cleanup_erased_nodes(graph_module: torch.fx.GraphModule): # Placeholder for any additional cleanup if needed pass + + +def transfer_metadata( + new_node: Node, source_node: Node, pass_name: str = "QuantizedPass" +) -> None: + """Transfer metadata with proper provenance tracking.""" + if hasattr(source_node, "meta") and source_node.meta: + new_node.meta = source_node.meta.copy() + if "from_node" in new_node.meta: + from_node_list = new_node.meta.get("from_node", []).copy() + from_node_list.append( + {"source": source_node.name, "pass": pass_name, "op": "fuse"} + ) + new_node.meta["from_node"] = from_node_list + for field in ["tensor_meta", "stack_trace"]: + if field in source_node.meta: + new_node.meta[field] = source_node.meta[field] + + +def is_dequant_node(node: Node) -> bool: + """Check if node is a dequantize operation.""" + dequant_targets = { + exir_ops.edge.cortex_m.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + } + return node.op == "call_function" and node.target in dequant_targets + + +def is_quant_node(node: Node) -> bool: + """Check if node is a quantize operation.""" + quant_targets = { + exir_ops.edge.cortex_m.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + } + return node.op == "call_function" and node.target in quant_targets + + +def cleanup_nodes(nodes_to_erase, graph): + """Clean up marked nodes from graph.""" + failed_nodes = [] + + for node in reversed(nodes_to_erase): + if node in graph.nodes and len(node.users) == 0: + try: + graph.erase_node(node) + except Exception as e: + print(f"Warning: Failed to erase node {node}: {e}") + failed_nodes.append(node) + continue + + if failed_nodes: + print(f"Warning: {len(failed_nodes)} nodes could not be erased") + + return failed_nodes diff --git a/backends/cortex_m/passes/quantized_linear_fusion_pass.py b/backends/cortex_m/passes/quantized_linear_fusion_pass.py new file mode 100644 index 00000000000..8f8a90eec2f --- /dev/null +++ b/backends/cortex_m/passes/quantized_linear_fusion_pass.py @@ -0,0 +1,645 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Optional + +import executorch.backends.cortex_m.ops.operators # noqa +import torch +import torch.fx + +from executorch.backends.cortex_m.passes.passes_utils import ( + cleanup_nodes, + is_dequant_node, + quantize_multiplier_aot, + transfer_metadata, +) + +from executorch.backends.transforms.utils import create_mutable_buffer, get_param_tensor +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass +from torch.fx import Node +from torch.fx.passes.infra.pass_manager import PassResult + +logger = logging.getLogger("quantized_linear_fusion_pass") +logger.setLevel(logging.INFO) + + +class QuantizedLinearFusionPass(ExportPass): + """ + Cortex-M backend pass that fuses quantized linear-like patterns. + Fuses: dequantize -> [linear/addmm/fc_ops] -> quantize + Into: cortex_m.quantized_linear.default with direct parameters. + """ + + SUPPORTED_OPS_MAPPING = { + exir_ops.edge.aten.addmm.default: exir_ops.edge.cortex_m.quantized_linear.default, + exir_ops.edge.aten.mm.default: exir_ops.edge.cortex_m.quantized_linear.default, + } + + requires_exported_program = True + + def __init__(self, exported_program: ExportedProgram): + super().__init__() + self._exported_program = exported_program + self.nodes_to_erase = [] + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + logger.info("Starting QuantizedLinearFusionPass") + assert id(self._exported_program.graph_module.graph) == id( + graph_module.graph + ), "QuantizedLinearFusionPass requires same graph instance" + + try: + fusion_count = self._fuse_quantized_linear_patterns(graph_module) + if fusion_count > 0: + graph_module.graph.eliminate_dead_code() + graph_module.graph.lint() + graph_module.recompile() + logger.info(f"Linear fusion completed: {fusion_count} patterns fused") + return PassResult(graph_module, fusion_count > 0) + except Exception as e: + logger.error(f"Error in QuantizedLinearFusionPass: {e}") + raise e + + def _extract_linear_pattern(self, quantize_node: Node): + if not quantize_node.args: + return None + fc_node = quantize_node.args[0] + if not ( + fc_node.op == "call_function" + and fc_node.target in self.SUPPORTED_OPS_MAPPING + ): + return None + + op_name = str(fc_node.target).split(".")[-1] + + if "addmm" in str(fc_node.target): + input_dq_node = fc_node.args[1] + else: + input_dq_node = fc_node.args[0] + if not is_dequant_node(input_dq_node): + logger.info("input_dq_node is not a dequant node") + return None + weight_dq_node, bias_dq_node = self._extract_weight_bias_from_fc_op(fc_node) + if not weight_dq_node: + logger.info("No weight, bias dequantize node found") + return None + return ( + quantize_node, + fc_node, + input_dq_node, + weight_dq_node, + bias_dq_node, + op_name, + ) + + def _extract_weight_bias_from_fc_op(self, fc_node: Node): + """Generic extraction for FC-like operations.""" + + if "addmm" in str(fc_node.target): + if len(fc_node.args) >= 3: + bias_arg = fc_node.args[0] + weight_arg = fc_node.args[2] + weight_dq_node = self._trace_to_dequantize(weight_arg) + logger.info( + f"weight_arg: {weight_arg}, traced weight_dq_node: {weight_dq_node}" + ) + + if weight_dq_node is None: + logger.info("No weight dequantize node found ") + + # For bias, try to trace to dequantize but allow None (no-bias case) + bias_dq_node = 
self._trace_to_dequantize(bias_arg) + if bias_dq_node is None: + logger.info("No bias dequantize node found - likely no-bias linear") + return weight_dq_node, bias_dq_node + elif any(op in str(fc_node.target) for op in ["linear", "mm"]): + if len(fc_node.args) >= 2: + weight_arg = fc_node.args[1] + bias_arg = fc_node.args[2] if len(fc_node.args) > 2 else None + weight_dq_node = self._trace_to_dequantize(weight_arg) + bias_dq_node = self._trace_to_dequantize(bias_arg) if bias_arg else None + return weight_dq_node, bias_dq_node + return None, None + + def _extract_input_quantization_parameters( + self, input_dq_node: Node + ) -> Optional[dict]: + """Extract input quantization parameters from dequantize node.""" + try: + # Find the quantize operation that produces the int8 tensor + input_quantize_node = None + if hasattr(input_dq_node, "args") and input_dq_node.args: + quantize_candidate = input_dq_node.args[0] + if getattr( + quantize_candidate, "op", None + ) == "call_function" and "quantize" in str( + getattr(quantize_candidate, "target", "") + ): + input_quantize_node = quantize_candidate + + if not input_quantize_node: + logger.error("Could not find quantize node for input!") + return None + + # Extract input quantization parameters + input_scale = self._extract_param_value(input_dq_node.args[1]) + input_zero_point = int(self._extract_param_value(input_dq_node.args[2])) + input_multiplier, input_shift = quantize_multiplier_aot(input_scale) + + return { + "input_scale": input_scale, + "input_zero_point": input_zero_point, + "input_multiplier": input_multiplier, + "input_shift": input_shift, + "input_tensor": input_quantize_node, + } + except Exception as e: + logger.error(f"Failed to extract input quantization parameters: {e}") + return None + + def _extract_output_quantization_parameters( + self, quantize_node: Node + ) -> Optional[dict]: + """Extract output quantization parameters from quantize node.""" + try: + output_scale = self._extract_param_value(quantize_node.args[1]) + output_zero_point = int(self._extract_param_value(quantize_node.args[2])) + + return { + "output_scale": output_scale, + "output_zero_point": output_zero_point, + } + except Exception as e: + logger.error(f"Failed to extract output quantization parameters: {e}") + return None + + def _create_constant_parameter_buffer( + self, graph, quantize_node: Node, data: torch.Tensor, name: str + ): + """Create a parameter buffer""" + buffer_name = f"{name}_{id(quantize_node)}" + + setattr(graph.owning_module, buffer_name, data) + + # Create a get_attr node + with graph.inserting_before(quantize_node): + buffer_node = graph.create_node( + op="get_attr", target=buffer_name, name=buffer_name + ) + + # Set metadata + buffer_node.meta["val"] = data + + return buffer_node + + def _extract_weight_parameters(self, weight_dq_node: Node) -> Optional[dict]: + try: + weight_tensor = weight_dq_node.args[0] + weight_scale = weight_dq_node.args[1] + weight_zero_point = ( + weight_dq_node.args[2] if len(weight_dq_node.args) > 2 else None + ) + + weight_scale_data = self._extract_param_value(weight_scale) + weight_zp_data = ( + self._extract_param_value(weight_zero_point) + if weight_zero_point + else None + ) + + # Get actual tensor data to determine output features + weight_tensor_data = get_param_tensor(self._exported_program, weight_tensor) + out_features = weight_tensor_data.shape[0] + + # Handle both per-tensor and per-channel + if ( + isinstance(weight_scale_data, torch.Tensor) + and weight_scale_data.numel() > 1 + ): + # Per-channel: 
ensure we have the right number of elements + assert ( + weight_scale_data.numel() == out_features + ), f"Scale size {weight_scale_data.numel()} != out_features {out_features}" + + multipliers = [] + shifts = [] + for scale in weight_scale_data: + mult, shift = quantize_multiplier_aot(scale.item()) + multipliers.append(mult) + shifts.append(shift) + + weight_multiplier = torch.tensor(multipliers, dtype=torch.int32) + weight_shift = torch.tensor(shifts, dtype=torch.int32) + weight_zp_tensor = ( + weight_zp_data.int() + if weight_zp_data is not None + else torch.zeros(out_features, dtype=torch.int32) + ) + else: + # Per-tensor: create tensors with correct size for output features + scale_val = ( + weight_scale_data.item() + if isinstance(weight_scale_data, torch.Tensor) + else weight_scale_data + ) + mult, shift = quantize_multiplier_aot(scale_val) + + # Create tensors sized for out_features (not single element) + weight_multiplier = torch.full((out_features,), mult, dtype=torch.int32) + weight_shift = torch.full((out_features,), shift, dtype=torch.int32) + weight_zp_tensor = torch.full( + (out_features,), + weight_zp_data if weight_zp_data else 0, + dtype=torch.int32, + ) + + # Validate multipliers + for i, mult in enumerate(weight_multiplier): + if mult < (1 << 30) or mult > ((1 << 31) - 1): + logger.error( + f"Invalid multiplier[{i}]: {mult}, scale was: {weight_scale_data}" + ) + return None + + return { + "weight_tensor": weight_tensor, + "weight_zero_point_data": weight_zp_tensor, + "weight_multiplier_data": weight_multiplier, + "weight_shift_data": weight_shift, + } + except Exception as e: + logger.error(f"Failed to extract weight parameters: {e}") + return None + + def _extract_bias_parameters(self, bias_dq_node: Optional[Node]) -> Optional[dict]: + """ + Extract bias parameters for quantized linear fusion. + Handles both dequantized bias nodes and constant bias tensors. + Returns a dict with bias_tensor, bias_multiplier, and bias_shift. 
+ """ + if not bias_dq_node: + # No bias present + return None + try: + # Case 1: Bias is a dequantize node + if hasattr(bias_dq_node, "op") and is_dequant_node(bias_dq_node): + bias_tensor = bias_dq_node.args[0] + bias_scale = bias_dq_node.args[1] + + bias_scale_data = self._extract_param_value(bias_scale) + + if ( + isinstance(bias_scale_data, torch.Tensor) + and bias_scale_data.numel() > 1 + ): + # Per-channel bias + bias_multipliers = [] + bias_shifts = [] + for scale_val in bias_scale_data.tolist(): + mult, shift = quantize_multiplier_aot(scale_val) + bias_multipliers.append(mult) + bias_shifts.append(shift) + return { + "bias_tensor": bias_tensor, + "bias_multiplier": bias_multipliers, + "bias_shift": bias_shifts, + } + else: + # Per-tensor bias + bias_scale_val = ( + bias_scale_data.item() + if isinstance(bias_scale_data, torch.Tensor) + else bias_scale_data + ) + bias_multiplier, bias_shift = quantize_multiplier_aot( + bias_scale_val + ) + return { + "bias_tensor": bias_tensor, + "bias_multiplier": bias_multiplier, + "bias_shift": bias_shift, + } + else: + # Case 2: Bias is a constant tensor (not dequantized) + # This can happen if bias is not quantized in the model + bias_tensor = bias_dq_node + # Use default multiplier/shift for unquantized bias + bias_multiplier = 1 + bias_shift = 0 + return { + "bias_tensor": bias_tensor, + "bias_multiplier": bias_multiplier, + "bias_shift": bias_shift, + } + except Exception as e: + logger.error(f"Failed to extract bias parameters: {e}") + return None + + def _prepare_bias_tensors( + self, bias_params: Optional[dict], out_features: int + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Prepare bias multiplier and shift tensors for kernel call. + Returns (bias_multiplier_tensor, bias_shift_tensor) both sized [out_features]. 
+ """ + if bias_params: + bias_multiplier = bias_params["bias_multiplier"] + bias_shift = bias_params["bias_shift"] + + # Convert to tensors of the right size + if isinstance(bias_multiplier, int): + bias_multiplier_tensor = torch.full( + [out_features], bias_multiplier, dtype=torch.int32 + ) + elif isinstance(bias_multiplier, list): + assert ( + len(bias_multiplier) == out_features + ), f"Bias multiplier size {len(bias_multiplier)} != out_features {out_features}" + bias_multiplier_tensor = torch.tensor( + bias_multiplier, dtype=torch.int32 + ) + elif isinstance(bias_multiplier, torch.Tensor): + assert ( + bias_multiplier.numel() == out_features + ), f"Bias multiplier size {bias_multiplier.numel()} != out_features {out_features}" + bias_multiplier_tensor = bias_multiplier + else: + raise TypeError( + f"Unsupported bias_multiplier type: {type(bias_multiplier)}" + ) + + if isinstance(bias_shift, int): + bias_shift_tensor = torch.full( + [out_features], bias_shift, dtype=torch.int32 + ) + elif isinstance(bias_shift, list): + assert ( + len(bias_shift) == out_features + ), f"Bias shift size {len(bias_shift)} != out_features {out_features}" + bias_shift_tensor = torch.tensor(bias_shift, dtype=torch.int32) + elif isinstance(bias_shift, torch.Tensor): + assert ( + bias_shift.numel() == out_features + ), f"Bias shift size {bias_shift.numel()} != out_features {out_features}" + bias_shift_tensor = bias_shift + else: + raise TypeError(f"Unsupported bias_shift type: {type(bias_shift)}") + + return bias_multiplier_tensor, bias_shift_tensor + else: + # No bias: return zero tensors of correct shape + return ( + torch.zeros([out_features], dtype=torch.int32), + torch.zeros([out_features], dtype=torch.int32), + ) + + def _extract_param_value(self, node_or_value): + """ + Extract a scalar value from a Node or a direct float/int. + """ + if isinstance(node_or_value, (float, int)): + return node_or_value + # If it's a tensor, get its scalar value if possible + if isinstance(node_or_value, torch.Tensor): + return node_or_value.item() if node_or_value.numel() == 1 else node_or_value + # If it's a Node, use get_param_tensor + if hasattr(node_or_value, "op"): + tensor = get_param_tensor(self._exported_program, node_or_value) + return tensor.item() if tensor.numel() == 1 else tensor + raise TypeError(f"Unsupported parameter type: {type(node_or_value)}") + + def _calculate_cmsis_scratch_size(self, weight_tensor) -> int: + """Calculate CMSIS-NN scratch buffer size for quantized linear operations. + + Source: CMSIS-NN arm_fully_connected_s8_get_buffer_size() returns filter_dims->w * sizeof(int32_t). + This buffer stores pre-computed kernel sums (weight row sums) - one int32_t per output feature. + Same buffer size applies to both per-tensor and per-channel quantization paths since both use + identical kernel sum optimization in the underlying matrix multiplication. 
+ """ + try: + print(f"weight_tensor type: {type(weight_tensor)}, value: {weight_tensor}") + weight_shape = get_param_tensor(self._exported_program, weight_tensor).shape + out_features = weight_shape[0] # filter_dims->w in CMSIS terms + + # CMSIS-NN implementation expects the following size + cmsis_buffer_size = out_features * 4 # sizeof(int32_t) + return cmsis_buffer_size + except Exception as e: + logger.error(f"Failed to calculate CMSIS scratch size: {e}") + return 2048 # Fallback + + def _create_scratch_buffer(self, graph, quantize_node: Node, weight_tensor): + cmsis_scratch = self._calculate_cmsis_scratch_size(weight_tensor) + + kernel_sum_header = 8 # sizeof(KernelSumHeader) + total_size = kernel_sum_header + cmsis_scratch + + logger.info( + f"Kernel sum header: {kernel_sum_header}, CMSIS buffer: {cmsis_scratch}, total: {total_size}" + ) + + return create_mutable_buffer( + self._exported_program, + name=f"b_cmsis_linear_scratch_{id(quantize_node)}", + data=torch.zeros((total_size,), dtype=torch.int8), + ) + + def _create_fused_node( + self, + graph, + quantize_node: Node, + quant_params: dict, + weight_params: dict, + bias_params: Optional[dict], + quantized_target, + ) -> Node: + """Generic fused node creation for any FC-like operation.""" + # Extract all parameters + input_tensor = quant_params["input_tensor"] + input_zp = quant_params["input_zero_point"] + input_multiplier = quant_params["input_multiplier"] + input_shift = quant_params["input_shift"] + weight_tensor = weight_params["weight_tensor"] + + weight_zp_node = self._create_constant_parameter_buffer( + graph, quantize_node, weight_params["weight_zero_point_data"], "weight_zp" + ) + weight_mult_node = self._create_constant_parameter_buffer( + graph, quantize_node, weight_params["weight_multiplier_data"], "weight_mult" + ) + weight_shift_node = self._create_constant_parameter_buffer( + graph, quantize_node, weight_params["weight_shift_data"], "weight_shift" + ) + # Get dimensions + weight_shape = get_param_tensor(self._exported_program, weight_tensor).shape + assert ( + len(weight_shape) == 2 + ), f"Weight tensor must be 2D, got shape {weight_shape}" + in_features = weight_shape[1] + out_features = weight_shape[0] + + # Handle bias + bias_tensor = bias_params["bias_tensor"] if bias_params else None + bias_multiplier, bias_shift = self._prepare_bias_tensors( + bias_params, out_features + ) + output_zp = quant_params["output_zero_point"] + + scratch_buffer = self._create_scratch_buffer( + graph, quantize_node, weight_tensor + ) + + with graph.inserting_after(quantize_node): + fused = graph.create_node( + "call_function", + target=quantized_target, + args=( + input_tensor, + input_zp, + input_multiplier, + input_shift, + weight_tensor, + weight_zp_node, + weight_mult_node, + weight_shift_node, + bias_tensor, + bias_multiplier, + bias_shift, + scratch_buffer, + output_zp, + in_features, + out_features, + ), + kwargs={}, + ) + + transfer_metadata(fused, quantize_node, "QuantizedLinearFusionPass") + return fused + + def _mark_for_cleanup(self, nodes): + for node in nodes: + if node is not None: + self.nodes_to_erase.append(node) + + def _cleanup_nodes(self, graph): + cleanup_nodes(self.nodes_to_erase, graph) + self.nodes_to_erase.clear() + + def _extract_linear_pattern_with_validation(self, quantize_node: Node): + pattern_info = self._extract_linear_pattern(quantize_node) + if not pattern_info: + return None + # Optionally add more validation here if needed + return pattern_info + + def _trace_to_dequantize(self, node: 
Optional[Node], max_depth=3) -> Optional[Node]: + """Trace through transformations to find dequantize node.""" + current_node = node + depth = 0 + while current_node and depth < max_depth: + if is_dequant_node(current_node): + return current_node + if current_node.op == "call_function" and current_node.target in { + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.view_copy.default, + }: + if current_node.args: + current_node = current_node.args[0] + depth += 1 + continue + break + return None + + def _fuse_quantized_linear_patterns( + self, graph_module: torch.fx.GraphModule + ) -> int: + fusion_count = 0 + graph = graph_module.graph + for node in list(graph.nodes): + if not ( + node.op == "call_function" and "quantize_per_tensor" in str(node.target) + ): + continue + pattern_info = self._extract_linear_pattern_with_validation(node) + if not pattern_info: + continue + + ( + quantize_node, + fc_node, + input_dq_node, + weight_dq_node, + bias_dq_node, + op_name, + ) = pattern_info + + # Get quantized target for this FC operation + quantized_target = self.SUPPORTED_OPS_MAPPING.get(fc_node.target) + if not quantized_target: + logger.warning(f"No quantized target found for {fc_node.target}") + continue + + logger.info(f"✅ Found complete cortex_m Q/DQ + {op_name} pattern!") + + try: + input_params = self._extract_input_quantization_parameters( + input_dq_node + ) + if not input_params: + logger.error( + "Quantization parameter extraction failed for node: %s", node + ) + return None + output_params = self._extract_output_quantization_parameters( + quantize_node + ) + if not output_params: + logger.error( + "Output quantization parameter extraction failed for node: %s", + node, + ) + return None + quant_params = {**input_params, **output_params} + logger.info(f"Quantization parameters: {quant_params}") + + weight_params = self._extract_weight_parameters(weight_dq_node) + if not weight_params: + continue + bias_params = self._extract_bias_parameters(bias_dq_node) + if bias_dq_node and not bias_params: + continue + fused_node = self._create_fused_node( + graph, + quantize_node, + quant_params, + weight_params, + bias_params, + quantized_target, + ) + logger.info(f"Created fused {op_name} node: {fused_node}") + + quantize_node.replace_all_uses_with(fused_node) + self._mark_for_cleanup( + [ + quantize_node, + fc_node, + input_dq_node, + weight_dq_node, + bias_dq_node, + ] + ) + fusion_count += 1 + logger.info(f"✅ Successfully fused {op_name} operation {fusion_count}") + except Exception as e: + logger.error( + f"Failed to fuse {op_name} pattern for {fc_node.name}: {e}" + ) + continue + self._cleanup_nodes(graph) + return fusion_count diff --git a/backends/cortex_m/passes/quantized_op_fusion_pass.py b/backends/cortex_m/passes/quantized_op_fusion_pass.py index ca6d8b97795..eebf6866d83 100644 --- a/backends/cortex_m/passes/quantized_op_fusion_pass.py +++ b/backends/cortex_m/passes/quantized_op_fusion_pass.py @@ -36,7 +36,7 @@ class QuantizedOpFusionPass(ExportPass): # Generic operation mapping SUPPORTED_OPS_MAPPING = { exir_ops.edge.aten.add.Tensor: exir_ops.edge.cortex_m.quantized_add.default, - # Future ops to be added here: + # Future binary ops to be added here: } def __init__(self): diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 106ab35363c..5513529509e 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -38,6 +38,10 @@ from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner # To use Cortex-M 
backend +from executorch.backends.cortex_m.passes.quantized_linear_fusion_pass import ( + QuantizedLinearFusionPass, +) + from executorch.backends.cortex_m.passes.quantized_op_fusion_pass import ( QuantizedOpFusionPass, ) @@ -55,6 +59,7 @@ ExecutorchBackendConfig, to_edge_transform_and_lower, ) + from executorch.extension.export_util.utils import save_pte_program from tabulate import tabulate from torch.utils.data import DataLoader @@ -148,7 +153,8 @@ def quantize( evaluator_name: str | None, evaluator_config: Dict[str, Any] | None, ) -> torch.nn.Module: - """This is the official recommended flow for quantization in pytorch 2.0 export""" + """This is the official recommended flow for quantization in pytorch 2.0 + export""" logging.info("Quantizing Model...") logging.debug(f"Original model: {model}") quantizer = None @@ -605,7 +611,7 @@ def get_args(): parser.add_argument( "--enable_qdq_fusion_pass", action="store_true", - help="Enable the QuantizedOpFusionPass fusion step", + help="Enable the Quantized qdq fusion Op passes", ) parser.add_argument( "--enable_debug_mode", @@ -806,22 +812,24 @@ def to_edge_no_delegate(exported_program, args, model: torch.nn.Module, example_ return model_int8, edge -def transform_for_cortex_m_backend(edge, args): +def transform_for_cortex_m_backend(edge_program_manager, args): # Let's make sure we are using optimized Cortex M backend # NB: If we can't find and replace ops those are expected to be replaced, # bad things will happen at runtime, like "missing operator" errors! # Instantiate the mandatory ReplaceQuantNodesPass - passes = [ReplaceQuantNodesPass()] - - # Conditionally add the QuantizedOpFusionPass + passes = [ReplaceQuantNodesPass] if args.enable_qdq_fusion_pass: - passes.append(QuantizedOpFusionPass()) - - # Apply the passes - edge = edge.transform(passes) - - return edge + passes += [QuantizedLinearFusionPass, QuantizedOpFusionPass] + current_edge = edge_program_manager + for pass_cls in passes: + transform_pass = ( + pass_cls(current_edge.exported_program()) + if pass_cls.__name__ == "QuantizedLinearFusionPass" + else pass_cls() + ) + current_edge = current_edge.transform([transform_pass]) + return current_edge if __name__ == "__main__": # noqa: C901
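
Note on the scratch-buffer contract: the sizing rule documented in `cmsis_scratch_buffer_context.h` and in `QuantizedLinearFusionPass._calculate_cmsis_scratch_size()` / `_create_scratch_buffer()` can be summarized with a minimal standalone sketch. This is an illustration only, not part of the patch; the function name below is invented for the example. It assumes, as the patch states, that `arm_fully_connected_s8_get_buffer_size()` returns one `int32_t` of kernel-sum storage per output feature and that the AOT pass prepends an 8-byte `KernelSumHeader`.

```python
# Illustrative sketch (not part of the patch): mirrors the buffer layout
# documented in cmsis_scratch_buffer_context.h and the size computed by
# QuantizedLinearFusionPass._calculate_cmsis_scratch_size().

KERNEL_SUM_HEADER_SIZE = 8  # sizeof(KernelSumHeader): bool + padding + int32_t


def cortex_m_linear_scratch_size(out_features: int) -> int:
    """Total size in bytes of the scratch tensor passed to cortex_m::quantized_linear."""
    # CMSIS-NN's arm_fully_connected_s8_get_buffer_size() returns
    # filter_dims->w * sizeof(int32_t): one pre-computed kernel sum per
    # output feature, for both per-tensor and per-channel quantization.
    cmsis_workspace = out_features * 4
    return KERNEL_SUM_HEADER_SIZE + cmsis_workspace


if __name__ == "__main__":
    # e.g. a 10-class classifier head: 8-byte header + 10 * 4 bytes = 48 bytes
    assert cortex_m_linear_scratch_size(10) == 48
```

At runtime, `CMSISScratchBufferContext` consumes the same buffer from the other end: it reads the 8-byte header to decide whether kernel sums are already populated, hands the remaining `out_features * 4` bytes to CMSIS-NN as `cmsis_nn_context.buf`, and checks that this remainder is at least `arm_fully_connected_s8_get_buffer_size()`.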