From 4daab85e9c7d5b88a4ffc74de45d71578ea7b458 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Thu, 6 Nov 2025 15:06:30 -0800 Subject: [PATCH] Revert "Cortex_m backend: Simplify add + linear fusion passes (#15526)" This reverts commit 98432224460cf2162ae5b6fdfba398970e6caa1d. --- backends/cortex_m/ops/cortex_m_ops_common.h | 1 - backends/cortex_m/ops/op_quantized_linear.cpp | 195 +++-- backends/cortex_m/ops/operators.py | 254 +++++-- backends/cortex_m/ops/operators.yaml | 10 +- .../cortex_m/passes/cortex_m_pass_manager.py | 8 +- .../passes/quantized_linear_fusion_pass.py | 703 +++++++++++++++--- .../passes/quantized_op_fusion_pass.py | 284 +++++-- backends/cortex_m/test/ops/test_add.py | 29 +- backends/cortex_m/test/ops/test_linear.py | 141 ++-- 9 files changed, 1257 insertions(+), 368 deletions(-) diff --git a/backends/cortex_m/ops/cortex_m_ops_common.h b/backends/cortex_m/ops/cortex_m_ops_common.h index 10fad9b6a0b..c7e6cc8a389 100644 --- a/backends/cortex_m/ops/cortex_m_ops_common.h +++ b/backends/cortex_m/ops/cortex_m_ops_common.h @@ -22,7 +22,6 @@ using Tensor = torch::executor::Tensor; using ScalarType = executorch::aten::ScalarType; using Scalar = torch::executor::Scalar; using Error = executorch::runtime::Error; -using IntArrayRef = executorch::aten::ArrayRef; // From arm_nn_math_types.h #define ARM_NN_Q31_MAX ((int32_t)(0x7FFFFFFFL)) diff --git a/backends/cortex_m/ops/op_quantized_linear.cpp b/backends/cortex_m/ops/op_quantized_linear.cpp index 015fa805134..d1ccb6d0d45 100644 --- a/backends/cortex_m/ops/op_quantized_linear.cpp +++ b/backends/cortex_m/ops/op_quantized_linear.cpp @@ -1,12 +1,12 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. - * Copyright 2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
*/ +#include "cmsis_scratch_buffer_context.h" #include "cortex_m_ops_common.h" extern "C" { @@ -20,91 +20,152 @@ using KernelRuntimeContext = torch::executor::KernelRuntimeContext; Tensor& quantized_linear_out( KernelRuntimeContext& context, const Tensor& input, + const Scalar& input_zero_point, + const Scalar& input_multiplier, + const Scalar& input_shift, const Tensor& weights, + const Tensor& weight_zero_point, + const Tensor& weight_multiplier, + const Tensor& weight_shift, const torch::executor::optional& bias, - const torch::executor::optional& kernel_sum, - const Scalar& input_offset, - const Scalar& filter_offset, - const Scalar& output_offset, - const IntArrayRef requantize_multipliers, - const IntArrayRef requantize_shifts, - const Scalar& activation_max, - const Scalar& activation_min, + const Tensor& bias_multiplier, + const Tensor& bias_shift, + const Tensor& scratch_buffer, + const Scalar& output_zero_point, + const Scalar& in_features, + const Scalar& out_features, Tensor& out) { ET_LOG(Info, "quantized_linear_out: called"); + validate_cmsis_nn_tensor_requirements(input, weights, out); + + ET_CHECK_MSG( + scratch_buffer.scalar_type() == ScalarType::Char, + "Scratch buffer must be int8"); + + const int32_t batch_size = input.size(0); + const int32_t in_feat = static_cast(in_features.to()); + const int32_t out_feat = static_cast(out_features.to()); + const int32_t input_zp = static_cast(input_zero_point.to()); + const int32_t output_zp = + static_cast(output_zero_point.to()); + const bool is_per_channel = (weight_zero_point.numel() > 1); const int8_t* input_data = input.const_data_ptr(); const int8_t* weight_data = weights.const_data_ptr(); const int32_t* bias_data = bias.has_value() ? bias.value().const_data_ptr() : nullptr; - int32_t* kernel_sum_data = - kernel_sum.has_value() ? 
kernel_sum.value().data_ptr() : nullptr; int8_t* output_data = out.mutable_data_ptr(); + const int32_t* weight_zp_data = weight_zero_point.const_data_ptr(); + const int32_t* weight_mult_data = weight_multiplier.const_data_ptr(); + const int32_t* weight_shift_data = weight_shift.const_data_ptr(); + + if (!validate_per_channel_quant_params( + weight_mult_data, weight_shift_data, out_feat)) { + context.fail(Error::InvalidArgument); + return out; + } + + // Initialize scratch buffer context (validates early) + CMSISScratchBufferContext scratch_ctx( + const_cast(scratch_buffer), weights, weight_zero_point, bias); - cmsis_nn_context ctx; - ctx.size = 0; // Not used in CMSIS-NN - ctx.buf = kernel_sum_data; + scratch_ctx.compute_kernel_sums_if_needed(); + cmsis_nn_context ctx = scratch_ctx.get_cmsis_ctx(); // Setup CMSIS-NN parameters cmsis_nn_fc_params fc_params; - fc_params.input_offset = static_cast(input_offset.to()); - fc_params.filter_offset = static_cast(filter_offset.to()); - fc_params.output_offset = static_cast(output_offset.to()); - fc_params.activation.min = static_cast(activation_min.to()); - fc_params.activation.max = static_cast(activation_max.to()); - - cmsis_nn_per_tensor_quant_params per_tensor_quant_params; - per_tensor_quant_params.multiplier = - static_cast(requantize_multipliers.at(0)); - per_tensor_quant_params.shift = static_cast(requantize_shifts.at(0)); - - auto in_feat = input.size(input.dim() - 1); - auto out_feat = out.size(out.dim() - 1); - auto batches = 1; - for (size_t i = 0; i < input.dim() - 1; i++) { - batches *= input.size(i); - } - ET_LOG( - Info, - "in features: %d, out_features: %d, batches: %d, kernel_sum_size: %d", - in_feat, - out_feat, - batches, - kernel_sum.has_value() ? kernel_sum.value().numel() : 0); - ET_LOG( - Info, - "kernel_sum[0]: %d, kernel_sum[1]: %d", - kernel_sum_data != nullptr ? kernel_sum_data[0] : -1, - kernel_sum_data != nullptr ? 
kernel_sum_data[1] : -1); - cmsis_nn_dims input_dims = {batches, 1, 1, in_feat}; + fc_params.input_offset = -input_zp; + fc_params.output_offset = output_zp; + fc_params.activation.min = std::numeric_limits::min(); + fc_params.activation.max = std::numeric_limits::max(); + + cmsis_nn_dims input_dims = {1, 1, 1, in_feat}; cmsis_nn_dims filter_dims = {in_feat, 1, 1, out_feat}; cmsis_nn_dims bias_dims = {1, 1, 1, out_feat}; - cmsis_nn_dims output_dims = {batches, 1, 1, out_feat}; - - arm_cmsis_nn_status status = arm_fully_connected_s8( - &ctx, - &fc_params, - &per_tensor_quant_params, - &input_dims, - input_data, - &filter_dims, - weight_data, - &bias_dims, - bias_data, - &output_dims, - output_data); - - if (status != ARM_CMSIS_NN_SUCCESS) { - ET_LOG( - Error, - "quantized_linear_out: CMSIS-NN failed with status [%d]", - status); - context.fail(Error::Internal); - return out; - } + cmsis_nn_dims output_dims = {1, 1, 1, out_feat}; + + arm_cmsis_nn_status status; + for (int32_t b = 0; b < batch_size; b++) { + const int8_t* batch_input = input_data + b * in_feat; + int8_t* batch_output = output_data + b * out_feat; + ET_CHECK_MSG( + batch_input != nullptr && weight_data != nullptr, + "Null input pointers"); + ET_CHECK_MSG(in_feat > 0 && out_feat > 0, "Invalid dimensions"); + + if (is_per_channel) { + cmsis_nn_per_channel_quant_params per_channel_quant_params; + per_channel_quant_params.multiplier = + const_cast(weight_mult_data); + per_channel_quant_params.shift = const_cast(weight_shift_data); + + status = arm_fully_connected_per_channel_s8( + &ctx, + &fc_params, + &per_channel_quant_params, + &input_dims, + batch_input, + &filter_dims, + weight_data, + &bias_dims, + bias_data, + &output_dims, + batch_output); + } else { + fc_params.filter_offset = -weight_zp_data[0]; + cmsis_nn_per_tensor_quant_params per_tensor_quant_params; + per_tensor_quant_params.multiplier = weight_mult_data[0]; + per_tensor_quant_params.shift = weight_shift_data[0]; + + status = arm_fully_connected_s8( + &ctx, + &fc_params, + &per_tensor_quant_params, + &input_dims, + batch_input, + &filter_dims, + weight_data, + &bias_dims, + bias_data, + &output_dims, + batch_output); + } + + if (status != ARM_CMSIS_NN_SUCCESS) { + ET_LOG( + Error, + "quantized_linear_out: CMSIS-NN failed with status [%d]", + status); + context.fail(Error::Internal); + return out; + } + } return out; } +// Functional variant (stub, not used at runtime) +Tensor quantized_linear( + KernelRuntimeContext& context, + const Tensor& input, + const Scalar& input_zero_point, + const Scalar& input_multiplier, + const Scalar& input_shift, + const Tensor& weights, + const Tensor& weight_zero_point, + const Tensor& weight_multiplier, + const Tensor& weight_shift, + const torch::executor::optional& bias, + const Tensor& bias_multiplier, + const Tensor& bias_shift, + const Tensor& scratch_buffer, + const Scalar& output_zero_point, + const Scalar& in_features, + const Scalar& out_features) { + ET_LOG(Info, "quantized_linear: called"); + assert(false); + return const_cast(input); +} + } // namespace native } // namespace cortex_m diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index b8abfb9bde4..286f938ccc9 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -5,8 +5,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from math import prod - import torch from executorch.backends.cortex_m.passes.passes_utils import ( requantize_cmsis, @@ -172,110 +170,210 @@ def quantized_add_impl( # QUANTIZED LINEAR OPERATION DEFINITION # =================================================================== + +def _check_per_tensor_or_per_channel(param: torch.Tensor, out_channels: int, name: str): + assert param.numel() in [ + 1, + out_channels, + ], f"{name} must be per-tensor (1) or per-channel ({out_channels}), got {param.numel()}" + + lib.define( "quantized_linear.out(" - "Tensor input, " + "Tensor input, Scalar input_zero_point, Scalar input_multiplier, Scalar input_shift, " "Tensor weights, " - "Tensor? bias, " - "Tensor? kernel_sum, " - "Scalar input_offset, " - "Scalar filter_offset, " - "Scalar output_offset, " - "int[] requantize_multipliers, " - "int[] requantize_shifts, " - "Scalar activation_max, " - "Scalar activation_min, " - "*, Tensor(a!) out" - ") -> Tensor(a!)" + "Tensor weight_zero_point, Tensor weight_multiplier, Tensor weight_shift, " + "Tensor? bias, Tensor bias_multiplier, Tensor bias_shift, " + "Tensor scratch_buffer, Scalar output_zero_point, Scalar in_features, Scalar out_features, " + "*, Tensor(a!) out) -> Tensor(a!)" ) # Define functional variant (non-out version) lib.define( "quantized_linear(" - "Tensor input, " + "Tensor input, Scalar input_zero_point, Scalar input_multiplier, Scalar input_shift, " "Tensor weights, " - "Tensor? bias, " - "Tensor? kernel_sum, " - "Scalar input_offset, " - "Scalar filter_offset, " - "Scalar output_offset, " - "int[] requantize_multipliers, " - "int[] requantize_shifts, " - "Scalar activation_max, " - "Scalar activation_min" + "Tensor weight_zero_point, Tensor weight_multiplier, Tensor weight_shift, " + "Tensor? bias, Tensor bias_multiplier, Tensor bias_shift, " + "Tensor scratch_buffer, Scalar output_zero_point, Scalar in_features, Scalar out_features" ") -> Tensor" ) +# Fake meta function for shape inference (out variant) +@register_fake("cortex_m::quantized_linear.out") +def quantized_linear_out_meta( + input: torch.Tensor, + input_zero_point: int, + input_multiplier: int, + input_shift: int, + weights: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_multiplier: torch.Tensor, + weight_shift: torch.Tensor, + bias: torch.Tensor, + bias_multiplier: torch.Tensor, + bias_shift: torch.Tensor, + scratch_buffer: torch.Tensor, + output_zero_point: int, + in_features: int, + out_features: int, + out: torch.Tensor, +) -> torch.Tensor: + # Validate dimensions + batch_size = input.shape[0] + out_channels = weights.shape[0] + + # Validate weight quantization parameters dimensions + _check_per_tensor_or_per_channel( + weight_zero_point, out_channels, "weight_zero_point" + ) + _check_per_tensor_or_per_channel( + weight_multiplier, out_channels, "weight_multiplier" + ) + _check_per_tensor_or_per_channel(weight_shift, out_channels, "weight_shift") + + # Validate output shape + expected_shape = (batch_size, out_channels) + assert ( + out.shape == expected_shape + ), f"Output shape {out.shape} must be {expected_shape}" + + return out + + # Fake meta function for shape inference (functional variant) @register_fake("cortex_m::quantized_linear") def quantized_linear_meta( - input, - weights, - bias, - kernel_sum, - input_offset, - filter_offset, - output_offset, - requantize_multipliers, - requantize_shifts, - activation_max, - activation_min, + input: torch.Tensor, + input_zero_point: int, + input_multiplier: int, + input_shift: int, + weights: torch.Tensor, + 
weight_zero_point: torch.Tensor, + weight_multiplier: torch.Tensor, + weight_shift: torch.Tensor, + bias: torch.Tensor, + bias_multiplier: torch.Tensor, + bias_shift: torch.Tensor, + scratch_buffer: torch.Tensor, + output_zero_point: int, + in_features: int, + out_features: int, +) -> torch.Tensor: + # Validate dimensions (same as out variant) + batch_size = input.shape[0] + out_channels = weights.shape[0] + + # Validate weight quantization parameters dimensions + _check_per_tensor_or_per_channel( + weight_zero_point, out_channels, "weight_zero_point" + ) + _check_per_tensor_or_per_channel( + weight_multiplier, out_channels, "weight_multiplier" + ) + _check_per_tensor_or_per_channel(weight_shift, out_channels, "weight_shift") + + # Calculate output shape for functional variant + output_shape = (batch_size, out_channels) + return torch.empty(output_shape, dtype=input.dtype, device=input.device) + + +@impl(lib, "quantized_linear.out", "CompositeExplicitAutograd") +def quantized_linear_out_impl( + input: torch.Tensor, + input_zero_point: int, + input_multiplier: int, + input_shift: int, + weights: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_multiplier: torch.Tensor, + weight_shift: torch.Tensor, + bias: torch.Tensor, + bias_multiplier: torch.Tensor, + bias_shift: torch.Tensor, + scratch_buffer: torch.Tensor, + output_zero_point: int, + in_features: int, + out_features: int, + *, + out: torch.Tensor, ) -> torch.Tensor: + """ + Fallback implementation for meta/testing + Note: This won't be called at runtime, only during compilation + """ - shape = (*input.shape[:-1], weights.shape[0]) - return torch.empty(shape, dtype=input.dtype, device=input.device) + # Per-channel dequantization + input_scale = input_multiplier * (2.0 ** (-input_shift)) + input_fp = (input.float() - input_zero_point) * input_scale + if weight_zero_point.numel() == 1: + # Per-tensor + weight_scale = weight_multiplier.item() * (2.0 ** (-weight_shift.item())) + weights_fp = (weights.float() - weight_zero_point.item()) * weight_scale + else: + # Per-channel + weight_scales = weight_multiplier.float() * (2.0 ** (-weight_shift.float())) + weights_fp = ( + weights.float() - weight_zero_point.float().unsqueeze(1) + ) * weight_scales.unsqueeze(1) + bias_fp = None + if bias is not None: + bias_scales = bias_multiplier.float() * (2.0 ** (-bias_shift.float())) + bias_fp = bias.float() * bias_scales + + result_fp = torch.nn.functional.linear(input_fp, weights_fp, bias_fp) + else: + result_fp = torch.nn.functional.linear(input_fp, weights_fp) + result_quantized = torch.clamp( + torch.round(result_fp + output_zero_point), -128, 127 + ).to(torch.int8) + out.copy_(result_quantized) + return out # Functional variant implementation @impl(lib, "quantized_linear", "CompositeExplicitAutograd") def quantized_linear_impl( input: torch.Tensor, + input_zero_point: int, + input_multiplier: int, + input_shift: int, weights: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_multiplier: torch.Tensor, + weight_shift: torch.Tensor, bias: torch.Tensor, - kernel_sum: torch.Tensor, - input_offset: int, - filter_offset: int, - output_offset: int, - requantize_multipliers: torch.Tensor, - requantize_shifts: torch.Tensor, - activation_max: int, - activation_min: int, + bias_multiplier: torch.Tensor, + bias_shift: torch.Tensor, + scratch_buffer: torch.Tensor, + output_zero_point: int, + in_features: int, + out_features: int, ) -> torch.Tensor: """ Functional variant - creates output tensor and calls out variant """ - - # Leaving both 
implementations for debugging purposes. - compute_using_kernel_sum = True - - if compute_using_kernel_sum: - weights_int32 = weights.to(torch.int32) - - input_int32 = input.to(torch.int32) - new_shape = (prod(input.shape[:-1]), input.shape[-1]) - input_reshaped = input_int32.reshape(new_shape) - - lhs_sum = torch.sum(input_reshaped, dim=-1, keepdim=True) * filter_offset - output = torch.mm(input_reshaped, weights_int32.T) + lhs_sum + kernel_sum - output_shape = (*input.shape[:-1], output.shape[-1]) - output_reshaped = output.reshape(output_shape) - else: - weights_int32 = weights.to(torch.int32) + filter_offset - - input_int32 = input.to(torch.int32) + input_offset - new_shape = (prod(input.shape[:-1]), input.shape[-1]) - input_reshaped = input_int32.reshape(new_shape) - - output = torch.mm(input_reshaped, weights_int32.T) - if bias is not None: - output = output + bias - output_shape = (*input.shape[:-1], output.shape[-1]) - output_reshaped = output.reshape(output_shape) - - output = requantize_cmsis( - output_reshaped, requantize_multipliers[0], requantize_shifts[0] + # Create output tensor + batch_size = input.shape[0] + output = torch.empty( + (batch_size, out_features), dtype=torch.int8, device=input.device + ) + return quantized_linear_out_impl( + input, + input_zero_point, + input_multiplier, + input_shift, + weights, + weight_zero_point, + weight_multiplier, + weight_shift, + bias, + bias_multiplier, + bias_shift, + scratch_buffer, + output_zero_point, + in_features, + out_features, + out=output, ) - output += output_offset - output = torch.clamp(output, activation_min, activation_max).to(torch.int8) - return output diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml index 98d8df8797e..81ebeafc778 100644 --- a/backends/cortex_m/ops/operators.yaml +++ b/backends/cortex_m/ops/operators.yaml @@ -23,8 +23,14 @@ - arg_meta: null kernel_name: cortex_m::quantized_add_out -- func: cortex_m::quantized_linear.out(Tensor input, Tensor weights, Tensor? bias, Tensor? kernel_sum, Scalar input_offset, Scalar filter_offset, Scalar output_offset, int[] requantize_multipliers, int[] requantize_shifts, Scalar activation_max, Scalar activation_min, *, Tensor(a!) out) -> Tensor(a!) +- func: cortex_m::quantized_linear(Tensor input, Scalar input_zero_point, Scalar input_multiplier, Scalar input_shift, Tensor weights, Tensor weight_zero_point, Tensor weight_multiplier, Tensor weight_shift, Tensor? bias, Tensor bias_multiplier, Tensor bias_shift, Tensor scratch_buffer, Scalar output_zero_point, Scalar in_features, Scalar out_features) -> Tensor variants: function kernels: - arg_meta: null - kernel_name: cortex_m::quantized_linear_out \ No newline at end of file + kernel_name: cortex_m::quantized_linear + +- func: cortex_m::quantized_linear.out(Tensor input, Scalar input_zero_point, Scalar input_multiplier, Scalar input_shift, Tensor weights, Tensor weight_zero_point, Tensor weight_multiplier, Tensor weight_shift, Tensor? bias, Tensor bias_multiplier, Tensor bias_shift, Tensor scratch_buffer, Scalar output_zero_point, Scalar in_features, Scalar out_features, *, Tensor(a!) out) -> Tensor(a!) 
+ variants: function + kernels: + - arg_meta: null + kernel_name: cortex_m::quantized_linear_out diff --git a/backends/cortex_m/passes/cortex_m_pass_manager.py b/backends/cortex_m/passes/cortex_m_pass_manager.py index 10fb358c70e..02429cc68e0 100644 --- a/backends/cortex_m/passes/cortex_m_pass_manager.py +++ b/backends/cortex_m/passes/cortex_m_pass_manager.py @@ -4,11 +4,7 @@ # LICENSE file in the root directory of this source tree. -from executorch.backends.arm._passes import ( - DecorateFp32toInt32CastingPass, - FoldAndAnnotateQParamsPass, - ScalarsToAttributePass, -) +from executorch.backends.arm._passes import ScalarsToAttributePass from executorch.backends.cortex_m.passes import ( QuantizedLinearFusionPass, QuantizedOpFusionPass, @@ -24,12 +20,10 @@ class CortexMPassManager(XNNPACKPassManager): pass_list: list[ExportPass] = [ - FoldAndAnnotateQParamsPass, ReplaceScalarWithTensorArgPass, ReplaceQuantNodesPass, QuantizedOpFusionPass, QuantizedLinearFusionPass, - DecorateFp32toInt32CastingPass, ] pass_list_transform_for_annotation: list[ExportPass] = [ diff --git a/backends/cortex_m/passes/quantized_linear_fusion_pass.py b/backends/cortex_m/passes/quantized_linear_fusion_pass.py index f921f5ce621..11a49beb2f4 100644 --- a/backends/cortex_m/passes/quantized_linear_fusion_pass.py +++ b/backends/cortex_m/passes/quantized_linear_fusion_pass.py @@ -5,147 +5,642 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging +from typing import Optional import executorch.backends.cortex_m.ops.operators # noqa - import torch import torch.fx -from executorch.backends.cortex_m.passes.passes_utils import quantize_multiplier_aot -from executorch.backends.transforms.utils import ( - create_constant_placeholder, - get_param_tensor, +from executorch.backends.cortex_m.passes.passes_utils import ( + cleanup_nodes, + is_dequant_node, + quantize_multiplier_aot, + transfer_metadata, ) +from executorch.backends.transforms.utils import create_mutable_buffer, get_param_tensor + from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass +from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops -from torch.export.graph_signature import InputKind +from torch.fx import Node from torch.fx.passes.infra.pass_manager import PassResult +logger = logging.getLogger("quantized_linear_fusion_pass") +logger.setLevel(logging.INFO) + class QuantizedLinearFusionPass(XNNPACKPass): """ Cortex-M backend pass that fuses quantized linear-like patterns. Fuses: dequantize -> [linear/addmm/fc_ops] -> quantize Into: cortex_m.quantized_linear.default with direct parameters. + """ - Note that the optimzed implementation makes use of the following rewrite: + SUPPORTED_OPS_MAPPING = { + exir_ops.edge.aten.addmm.default: exir_ops.edge.cortex_m.quantized_linear.default, + exir_ops.edge.aten.mm.default: exir_ops.edge.cortex_m.quantized_linear.default, + } - Let - - yi be the output activations (y1, ... yn) - - xj be the input activations (x1, ... xm) - - wij be the weights (w11, ... 
wnm) - - a be the input offset - - b be the weight offset - - ci be the bias + requires_exported_program = True - Then the linear operation can be written as: - yi = sum_j((xj + a) * (wij + b)) + ci - = sum_j(xj*wij + xj*b + a*wij + a*b) + ci - = sum_j(xj*wij) + sum_j(xj)*b + (a * sum_j(wij + b) + ci) - = sum_j(xj*wij) + sum_j(xj)*b + kernel_sum + def __init__(self, exported_program: ExportedProgram): + super().__init__(exported_program) + self.nodes_to_erase = [] - where kernel_sum is precomputed aot. - """ + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + logger.info("Starting QuantizedLinearFusionPass") + assert id(self._exported_program.graph_module.graph) == id( + graph_module.graph + ), "QuantizedLinearFusionPass requires same graph instance" + + try: + fusion_count = self._fuse_quantized_linear_patterns(graph_module) + if fusion_count > 0: + graph_module.graph.eliminate_dead_code() + graph_module.graph.lint() + graph_module.recompile() + logger.info(f"Linear fusion completed: {fusion_count} patterns fused") + return PassResult(graph_module, fusion_count > 0) + except Exception as e: + logger.error(f"Error in QuantizedLinearFusionPass: {e}") + raise e + + def _extract_linear_pattern(self, quantize_node: Node): + if not quantize_node.args: + return None + fc_node = quantize_node.args[0] + if not ( + fc_node.op == "call_function" + and fc_node.target in self.SUPPORTED_OPS_MAPPING + ): + return None + + op_name = str(fc_node.target).split(".")[-1] + + if "addmm" in str(fc_node.target): + input_dq_node = fc_node.args[1] + else: + input_dq_node = fc_node.args[0] + if not is_dequant_node(input_dq_node): + logger.info("input_dq_node is not a dequant node") + return None + weight_dq_node, bias_dq_node = self._extract_weight_bias_from_fc_op(fc_node) + if not weight_dq_node: + logger.info("No weight, bias dequantize node found") + return None + return ( + quantize_node, + fc_node, + input_dq_node, + weight_dq_node, + bias_dq_node, + op_name, + ) + + def _extract_weight_bias_from_fc_op(self, fc_node: Node): + """Generic extraction for FC-like operations.""" + + if "addmm" in str(fc_node.target): + if len(fc_node.args) >= 3: + bias_arg = fc_node.args[0] + weight_arg = fc_node.args[2] + weight_dq_node = self._trace_to_dequantize(weight_arg) + logger.info( + f"weight_arg: {weight_arg}, traced weight_dq_node: {weight_dq_node}" + ) + + if weight_dq_node is None: + logger.info("No weight dequantize node found ") + + # For bias, try to trace to dequantize but allow None (no-bias case) + bias_dq_node = self._trace_to_dequantize(bias_arg) + if bias_dq_node is None: + logger.info("No bias dequantize node found - likely no-bias linear") + return weight_dq_node, bias_dq_node + elif any(op in str(fc_node.target) for op in ["linear", "mm"]): + if len(fc_node.args) >= 2: + weight_arg = fc_node.args[1] + bias_arg = fc_node.args[2] if len(fc_node.args) > 2 else None + weight_dq_node = self._trace_to_dequantize(weight_arg) + bias_dq_node = self._trace_to_dequantize(bias_arg) if bias_arg else None + return weight_dq_node, bias_dq_node + return None, None + + def _extract_input_quantization_parameters( + self, input_dq_node: Node + ) -> Optional[dict]: + """Extract input quantization parameters from dequantize node.""" + try: + # Find the quantize operation that produces the int8 tensor + input_quantize_node = None + if hasattr(input_dq_node, "args") and input_dq_node.args: + quantize_candidate = input_dq_node.args[0] + if getattr( + quantize_candidate, "op", None + ) == "call_function" and 
"quantize" in str( + getattr(quantize_candidate, "target", "") + ): + input_quantize_node = quantize_candidate + + if not input_quantize_node: + logger.error("Could not find quantize node for input!") + return None + + # Extract input quantization parameters + input_scale = self._extract_param_value(input_dq_node.args[1]) + input_zero_point = int(self._extract_param_value(input_dq_node.args[2])) + input_multiplier, input_shift = quantize_multiplier_aot(input_scale) + + return { + "input_scale": input_scale, + "input_zero_point": input_zero_point, + "input_multiplier": input_multiplier, + "input_shift": input_shift, + "input_tensor": input_quantize_node, + } + except Exception as e: + logger.error(f"Failed to extract input quantization parameters: {e}") + return None + + def _extract_output_quantization_parameters( + self, quantize_node: Node + ) -> Optional[dict]: + """Extract output quantization parameters from quantize node.""" + try: + output_scale = self._extract_param_value(quantize_node.args[1]) + output_zero_point = int(self._extract_param_value(quantize_node.args[2])) + + return { + "output_scale": output_scale, + "output_zero_point": output_zero_point, + } + except Exception as e: + logger.error(f"Failed to extract output quantization parameters: {e}") + return None + + def _create_constant_parameter_buffer( + self, graph, quantize_node: Node, data: torch.Tensor, name: str + ): + """Create a parameter buffer""" + buffer_name = f"{name}_{id(quantize_node)}" + + setattr(graph.owning_module, buffer_name, data) - def _compute_kernel_sum(self, weights, bias, input_offset, weight_offset): + # Create a get_attr node + with graph.inserting_before(quantize_node): + buffer_node = graph.create_node( + op="get_attr", target=buffer_name, name=buffer_name + ) + + # Set metadata + buffer_node.meta["val"] = data + + return buffer_node + + def _extract_weight_parameters(self, weight_dq_node: Node) -> Optional[dict]: + try: + weight_tensor = weight_dq_node.args[0] + weight_scale = weight_dq_node.args[1] + weight_zero_point = ( + weight_dq_node.args[2] if len(weight_dq_node.args) > 2 else None + ) + + weight_scale_data = self._extract_param_value(weight_scale) + weight_zp_data = ( + self._extract_param_value(weight_zero_point) + if weight_zero_point + else None + ) + + # Get actual tensor data to determine output features + weight_tensor_data = get_param_tensor(self._exported_program, weight_tensor) + out_features = weight_tensor_data.shape[0] + + # Handle both per-tensor and per-channel + if ( + isinstance(weight_scale_data, torch.Tensor) + and weight_scale_data.numel() > 1 + ): + # Per-channel: ensure we have the right number of elements + assert ( + weight_scale_data.numel() == out_features + ), f"Scale size {weight_scale_data.numel()} != out_features {out_features}" + + multipliers = [] + shifts = [] + for scale in weight_scale_data: + mult, shift = quantize_multiplier_aot(scale.item()) + multipliers.append(mult) + shifts.append(shift) + + weight_multiplier = torch.tensor(multipliers, dtype=torch.int32) + weight_shift = torch.tensor(shifts, dtype=torch.int32) + weight_zp_tensor = ( + weight_zp_data.int() + if weight_zp_data is not None + else torch.zeros(out_features, dtype=torch.int32) + ) + else: + # Per-tensor: create tensors with correct size for output features + scale_val = ( + weight_scale_data.item() + if isinstance(weight_scale_data, torch.Tensor) + else weight_scale_data + ) + mult, shift = quantize_multiplier_aot(scale_val) + + # Create tensors sized for out_features (not single 
element) + weight_multiplier = torch.full((out_features,), mult, dtype=torch.int32) + weight_shift = torch.full((out_features,), shift, dtype=torch.int32) + weight_zp_tensor = torch.full( + (out_features,), + weight_zp_data if weight_zp_data else 0, + dtype=torch.int32, + ) + + # Validate multipliers + for i, mult in enumerate(weight_multiplier): + if mult < (1 << 30) or mult > ((1 << 31) - 1): + logger.error( + f"Invalid multiplier[{i}]: {mult}, scale was: {weight_scale_data}" + ) + return None + + return { + "weight_tensor": weight_tensor, + "weight_zero_point_data": weight_zp_tensor, + "weight_multiplier_data": weight_multiplier, + "weight_shift_data": weight_shift, + } + except Exception as e: + logger.error(f"Failed to extract weight parameters: {e}") + return None + + def _extract_bias_parameters(self, bias_dq_node: Optional[Node]) -> Optional[dict]: + """ + Extract bias parameters for quantized linear fusion. + Handles both dequantized bias nodes and constant bias tensors. + Returns a dict with bias_tensor, bias_multiplier, and bias_shift. """ - Computes the precomputed kernel sum term (bias optional) - a * sum_j(wij + b) + ci + if not bias_dq_node: + # No bias present + return None + try: + # Case 1: Bias is a dequantize node + if hasattr(bias_dq_node, "op") and is_dequant_node(bias_dq_node): + bias_tensor = bias_dq_node.args[0] + bias_scale = bias_dq_node.args[1] + + bias_scale_data = self._extract_param_value(bias_scale) + + if ( + isinstance(bias_scale_data, torch.Tensor) + and bias_scale_data.numel() > 1 + ): + # Per-channel bias + bias_multipliers = [] + bias_shifts = [] + for scale_val in bias_scale_data.tolist(): + mult, shift = quantize_multiplier_aot(scale_val) + bias_multipliers.append(mult) + bias_shifts.append(shift) + return { + "bias_tensor": bias_tensor, + "bias_multiplier": bias_multipliers, + "bias_shift": bias_shifts, + } + else: + # Per-tensor bias + bias_scale_val = ( + bias_scale_data.item() + if isinstance(bias_scale_data, torch.Tensor) + else bias_scale_data + ) + bias_multiplier, bias_shift = quantize_multiplier_aot( + bias_scale_val + ) + return { + "bias_tensor": bias_tensor, + "bias_multiplier": bias_multiplier, + "bias_shift": bias_shift, + } + else: + # Case 2: Bias is a constant tensor (not dequantized) + # This can happen if bias is not quantized in the model + bias_tensor = bias_dq_node + # Use default multiplier/shift for unquantized bias + bias_multiplier = 1 + bias_shift = 0 + return { + "bias_tensor": bias_tensor, + "bias_multiplier": bias_multiplier, + "bias_shift": bias_shift, + } + except Exception as e: + logger.error(f"Failed to extract bias parameters: {e}") + return None - as defined above, for i = (1, ..., n) where j indexes the input activations. 
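The kernel-sum identity spelled out in the docstring above can be sanity-checked numerically. A minimal standalone sketch in plain torch (integer arithmetic; the names m, n, a, b, c follow the docstring, not the pass API):

import torch

# m input features, n output features; a = input offset, b = weight offset, c = bias.
m, n = 8, 4
x = torch.randint(-128, 127, (1, m), dtype=torch.int32)
w = torch.randint(-128, 127, (n, m), dtype=torch.int32)
c = torch.randint(-1000, 1000, (n,), dtype=torch.int32)
a, b = 3, -7

# Direct form: y_i = sum_j (x_j + a) * (w_ij + b) + c_i
direct = (x + a) @ (w + b).T + c

# Rewritten form: sum_j(x_j * w_ij) + b * sum_j(x_j) + kernel_sum,
# with kernel_sum = a * sum_j(w_ij + b) + c_i precomputed ahead of time.
kernel_sum = a * (w + b).sum(dim=1, dtype=torch.int32) + c
rewritten = x @ w.T + b * x.sum(dim=1, keepdim=True, dtype=torch.int32) + kernel_sum

assert torch.equal(direct, rewritten)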
+ def _prepare_bias_tensors( + self, bias_params: Optional[dict], out_features: int + ) -> tuple[torch.Tensor, torch.Tensor]: """ - weights_transposed = weights.T - weights_int32 = weights_transposed.to(torch.int32) - offset_weights = weights_int32 + weight_offset - kernel_sum = torch.sum(offset_weights, dim=0, keepdim=True, dtype=torch.int32) - kernel_sum_offset = kernel_sum * input_offset - - if bias is not None: - kernel_sum_offset += bias - - return kernel_sum_offset - - def _get_linear_replacement(self, args, meta, node): - input_scale = meta["input_qparams"][0].scale - input_zp = meta["input_qparams"][0].zp - weight_scale = meta["input_qparams"][1].scale - weight_zp = meta["input_qparams"][1].zp - output_scale = meta["output_qparams"][0].scale - output_zp = meta["output_qparams"][0].zp - output_min = meta["output_qparams"][0].qmin - output_max = meta["output_qparams"][0].qmax - - quantized_multiplier, quantized_shift = quantize_multiplier_aot( - (input_scale * weight_scale) / output_scale + Prepare bias multiplier and shift tensors for kernel call. + Returns (bias_multiplier_tensor, bias_shift_tensor) both sized [out_features]. + """ + if bias_params: + bias_multiplier = bias_params["bias_multiplier"] + bias_shift = bias_params["bias_shift"] + + # Convert to tensors of the right size + if isinstance(bias_multiplier, int): + bias_multiplier_tensor = torch.full( + [out_features], bias_multiplier, dtype=torch.int32 + ) + elif isinstance(bias_multiplier, list): + assert ( + len(bias_multiplier) == out_features + ), f"Bias multiplier size {len(bias_multiplier)} != out_features {out_features}" + bias_multiplier_tensor = torch.tensor( + bias_multiplier, dtype=torch.int32 + ) + elif isinstance(bias_multiplier, torch.Tensor): + assert ( + bias_multiplier.numel() == out_features + ), f"Bias multiplier size {bias_multiplier.numel()} != out_features {out_features}" + bias_multiplier_tensor = bias_multiplier + else: + raise TypeError( + f"Unsupported bias_multiplier type: {type(bias_multiplier)}" + ) + + if isinstance(bias_shift, int): + bias_shift_tensor = torch.full( + [out_features], bias_shift, dtype=torch.int32 + ) + elif isinstance(bias_shift, list): + assert ( + len(bias_shift) == out_features + ), f"Bias shift size {len(bias_shift)} != out_features {out_features}" + bias_shift_tensor = torch.tensor(bias_shift, dtype=torch.int32) + elif isinstance(bias_shift, torch.Tensor): + assert ( + bias_shift.numel() == out_features + ), f"Bias shift size {bias_shift.numel()} != out_features {out_features}" + bias_shift_tensor = bias_shift + else: + raise TypeError(f"Unsupported bias_shift type: {type(bias_shift)}") + + return bias_multiplier_tensor, bias_shift_tensor + else: + # No bias: return zero tensors of correct shape + return ( + torch.zeros([out_features], dtype=torch.int32), + torch.zeros([out_features], dtype=torch.int32), + ) + + def _extract_param_value(self, node_or_value): + """ + Extract a scalar value from a Node or a direct float/int. 
+ """ + if isinstance(node_or_value, (float, int)): + return node_or_value + # If it's a tensor, get its scalar value if possible + if isinstance(node_or_value, torch.Tensor): + return node_or_value.item() if node_or_value.numel() == 1 else node_or_value + # If it's a Node, use get_param_tensor + if hasattr(node_or_value, "op"): + tensor = get_param_tensor(self._exported_program, node_or_value) + return tensor.item() if tensor.numel() == 1 else tensor + raise TypeError(f"Unsupported parameter type: {type(node_or_value)}") + + def _calculate_cmsis_scratch_size(self, weight_tensor) -> int: + """Calculate CMSIS-NN scratch buffer size for quantized linear operations. + + Source: CMSIS-NN arm_fully_connected_s8_get_buffer_size() returns filter_dims->w * sizeof(int32_t). + This buffer stores pre-computed kernel sums (weight row sums) - one int32_t per output feature. + Same buffer size applies to both per-tensor and per-channel quantization paths since both use + identical kernel sum optimization in the underlying matrix multiplication. + """ + try: + print(f"weight_tensor type: {type(weight_tensor)}, value: {weight_tensor}") + weight_shape = get_param_tensor(self._exported_program, weight_tensor).shape + out_features = weight_shape[0] # filter_dims->w in CMSIS terms + + # CMSIS-NN implementation expects the following size + cmsis_buffer_size = out_features * 4 # sizeof(int32_t) + return cmsis_buffer_size + except Exception as e: + logger.error(f"Failed to calculate CMSIS scratch size: {e}") + return 2048 # Fallback + + def _create_scratch_buffer(self, graph, quantize_node: Node, weight_tensor): + cmsis_scratch = self._calculate_cmsis_scratch_size(weight_tensor) + + kernel_sum_header = 8 # sizeof(KernelSumHeader) + total_size = kernel_sum_header + cmsis_scratch + + logger.info( + f"Kernel sum header: {kernel_sum_header}, CMSIS buffer: {cmsis_scratch}, total: {total_size}" ) - # TODO: Add support for configuring the backend to support other extensions. - # Kernel sum is only used in the CMSIS-NN implementation for the MVE extension, - # so this should be optional. 
- weights = args[1] - weights_tensor = get_param_tensor(self.exported_program, weights) - bias_tensor = ( - get_param_tensor(self.exported_program, args[2]) if len(args) > 2 else None + return create_mutable_buffer( + self._exported_program, + name=f"b_cmsis_linear_scratch_{id(quantize_node)}", + data=torch.zeros((total_size,), dtype=torch.int8), ) - kernel_sum_tensor = self._compute_kernel_sum( - weights_tensor, bias_tensor, -input_zp, -weight_zp + + def _create_fused_node( + self, + graph, + quantize_node: Node, + quant_params: dict, + weight_params: dict, + bias_params: Optional[dict], + quantized_target, + ) -> Node: + """Generic fused node creation for any FC-like operation.""" + # Extract all parameters + input_tensor = quant_params["input_tensor"] + input_zp = quant_params["input_zero_point"] + input_multiplier = quant_params["input_multiplier"] + input_shift = quant_params["input_shift"] + weight_tensor = weight_params["weight_tensor"] + + weight_zp_node = self._create_constant_parameter_buffer( + graph, quantize_node, weight_params["weight_zero_point_data"], "weight_zp" ) - with node.graph.inserting_after(weights): - kernel_sum = create_constant_placeholder( - self.exported_program, - node.graph, - node.name + "_kernel_sum", - InputKind.PARAMETER, - kernel_sum_tensor, - ) + weight_mult_node = self._create_constant_parameter_buffer( + graph, quantize_node, weight_params["weight_multiplier_data"], "weight_mult" + ) + weight_shift_node = self._create_constant_parameter_buffer( + graph, quantize_node, weight_params["weight_shift_data"], "weight_shift" + ) + # Get dimensions + weight_shape = get_param_tensor(self._exported_program, weight_tensor).shape + assert ( + len(weight_shape) == 2 + ), f"Weight tensor must be 2D, got shape {weight_shape}" + in_features = weight_shape[1] + out_features = weight_shape[0] - args = ( - args[0], - weights, - None, - kernel_sum, - -input_zp, - -weight_zp, - output_zp, - [quantized_multiplier], - [quantized_shift], - output_max, - output_min, + # Handle bias + bias_tensor = bias_params["bias_tensor"] if bias_params else None + bias_multiplier, bias_shift = self._prepare_bias_tensors( + bias_params, out_features ) + output_zp = quant_params["output_zero_point"] - return args + scratch_buffer = self._create_scratch_buffer( + graph, quantize_node, weight_tensor + ) - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - modified = False - for node in graph_module.graph.nodes: - if node.op != "call_function": - continue - if node.target != exir_ops.edge.aten.linear.default: - continue - if ( - node.meta.get("input_qparams", {}) == {} - or node.meta.get("output_qparams", {}) == {} + with graph.inserting_after(quantize_node): + fused = graph.create_node( + "call_function", + target=quantized_target, + args=( + input_tensor, + input_zp, + input_multiplier, + input_shift, + weight_tensor, + weight_zp_node, + weight_mult_node, + weight_shift_node, + bias_tensor, + bias_multiplier, + bias_shift, + scratch_buffer, + output_zp, + in_features, + out_features, + ), + kwargs={}, + ) + + transfer_metadata(fused, quantize_node, "QuantizedLinearFusionPass") + return fused + + def _mark_for_cleanup(self, nodes): + for node in nodes: + if node is not None: + self.nodes_to_erase.append(node) + + def _cleanup_nodes(self, graph): + cleanup_nodes(self.nodes_to_erase, graph) + self.nodes_to_erase.clear() + + def _extract_linear_pattern_with_validation(self, quantize_node: Node): + pattern_info = self._extract_linear_pattern(quantize_node) + if not pattern_info: + 
return None + # Optionally add more validation here if needed + return pattern_info + + def _trace_to_dequantize(self, node: Optional[Node], max_depth=3) -> Optional[Node]: + """Trace through transformations to find dequantize node.""" + current_node = node + depth = 0 + while current_node and depth < max_depth: + if is_dequant_node(current_node): + return current_node + if current_node.op == "call_function" and current_node.target in { + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.view_copy.default, + }: + if current_node.args: + current_node = current_node.args[0] + depth += 1 + continue + break + return None + + def _fuse_quantized_linear_patterns( + self, graph_module: torch.fx.GraphModule + ) -> int: + fusion_count = 0 + graph = graph_module.graph + for node in list(graph.nodes): + if not ( + node.op == "call_function" and "quantize_per_tensor" in str(node.target) ): continue + pattern_info = self._extract_linear_pattern_with_validation(node) + if not pattern_info: + continue - args = self._get_linear_replacement(node.args, node.meta, node) - with graph_module.graph.inserting_before(node): - cortex_m_linear = graph_module.graph.create_node( - "call_function", - target=exir_ops.edge.cortex_m.quantized_linear.default, - args=args, - kwargs={}, - ) + ( + quantize_node, + fc_node, + input_dq_node, + weight_dq_node, + bias_dq_node, + op_name, + ) = pattern_info - node.replace_all_uses_with(cortex_m_linear) - graph_module.graph.erase_node(node) + # Get quantized target for this FC operation + quantized_target = self.SUPPORTED_OPS_MAPPING.get(fc_node.target) + if not quantized_target: + logger.warning(f"No quantized target found for {fc_node.target}") + continue - modified = True + logger.info(f"✅ Found complete cortex_m Q/DQ + {op_name} pattern!") - if modified: - graph_module.graph.eliminate_dead_code() - graph_module.recompile() - graph_module = super().call(graph_module).graph_module + try: + input_params = self._extract_input_quantization_parameters( + input_dq_node + ) + if not input_params: + logger.error( + "Quantization parameter extraction failed for node: %s", node + ) + return None + output_params = self._extract_output_quantization_parameters( + quantize_node + ) + if not output_params: + logger.error( + "Output quantization parameter extraction failed for node: %s", + node, + ) + return None + quant_params = {**input_params, **output_params} + logger.info(f"Quantization parameters: {quant_params}") - return PassResult(graph_module, modified) + weight_params = self._extract_weight_parameters(weight_dq_node) + if not weight_params: + continue + bias_params = self._extract_bias_parameters(bias_dq_node) + if bias_dq_node and not bias_params: + continue + fused_node = self._create_fused_node( + graph, + quantize_node, + quant_params, + weight_params, + bias_params, + quantized_target, + ) + logger.info(f"Created fused {op_name} node: {fused_node}") + + quantize_node.replace_all_uses_with(fused_node) + self._mark_for_cleanup( + [ + quantize_node, + fc_node, + input_dq_node, + weight_dq_node, + bias_dq_node, + ] + ) + fusion_count += 1 + logger.info(f"✅ Successfully fused {op_name} operation {fusion_count}") + except Exception as e: + logger.error( + f"Failed to fuse {op_name} pattern for {fc_node.name}: {e}" + ) + continue + self._cleanup_nodes(graph) + return fusion_count diff --git a/backends/cortex_m/passes/quantized_op_fusion_pass.py b/backends/cortex_m/passes/quantized_op_fusion_pass.py index df35c8d626a..888155dcfd0 100644 --- 
a/backends/cortex_m/passes/quantized_op_fusion_pass.py +++ b/backends/cortex_m/passes/quantized_op_fusion_pass.py @@ -5,17 +5,23 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict +import logging +from typing import Set + +import executorch.backends.cortex_m.ops.operators # noqa +import torch from executorch.backends.cortex_m.passes.passes_utils import ( + extract_scalar_value, quantize_multiplier_aot, SHIFT_INT8, ) - from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.dialects.edge._ops import EdgeOpOverload -from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue -from torch.fx.node import Argument +from executorch.exir.pass_base import ExportPass +from torch.fx.passes.infra.pass_manager import PassResult + +logger = logging.getLogger("quant_op_fusion_pass") +logger.setLevel(logging.INFO) class QuantizedOpFusionPass(ExportPass): @@ -29,58 +35,234 @@ class QuantizedOpFusionPass(ExportPass): Supports multiple binary operations with backward compatibility for add. """ - def _get_add_replacement(self, args, meta): + # Generic operation mapping + SUPPORTED_OPS_MAPPING = { + exir_ops.edge.aten.add.Tensor: exir_ops.edge.cortex_m.quantized_add.default, + # Future binary ops to be added here: + } - # Extract values - scale1 = meta["input_qparams"][0].scale - zero_point1 = meta["input_qparams"][0].zp - scale2 = meta["input_qparams"][1].scale - zero_point2 = meta["input_qparams"][1].zp - output_scale = meta["output_qparams"][0].scale - output_zero_point = meta["output_qparams"][0].zp + def __init__(self): + super().__init__() - # AoT COMPUTATION: Calculate multipliers and shifts - max_scale_2x = 2 * max(scale1, scale2) + def _get_dequant_targets(self) -> Set: + """Support both decomposed and cortex_m dequant targets for flexible pass ordering.""" + return { + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.cortex_m.dequantize_per_tensor.default, + } - input1_mult, input1_shift = quantize_multiplier_aot(scale1 / max_scale_2x) - input2_mult, input2_shift = quantize_multiplier_aot(scale2 / max_scale_2x) - output_mult, output_shift = quantize_multiplier_aot( - max_scale_2x / (output_scale * (1 << SHIFT_INT8)) + def _get_quant_targets(self) -> Set: + """Support both decomposed and cortex_m quant targets for flexible pass ordering.""" + return { + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.cortex_m.quantize_per_tensor.default, + } + + def _is_supported_binary_op(self, node: torch.fx.Node) -> bool: + """Check if node is a supported binary operation.""" + is_supported = ( + node.op == "call_function" and node.target in self.SUPPORTED_OPS_MAPPING ) + if not is_supported: + return False + + shape1 = node.args[0].meta["val"].shape + shape2 = node.args[1].meta["val"].shape + is_broadcast = shape1 != shape2 + return not is_broadcast - args = ( - args[0], - zero_point1, - input1_mult, - input1_shift, - args[1], - zero_point2, - input2_mult, - input2_shift, - output_zero_point, - output_mult, - output_shift, + def _is_dequant_node(self, node: torch.fx.Node) -> bool: + """Check if node is a dequantize operation.""" + return ( + hasattr(node, "op") + and node.op == "call_function" + and node.target in self._get_dequant_targets() ) - return exir_ops.edge.cortex_m.quantized_add.default, args + def _is_quant_node(self, node: torch.fx.Node) -> bool: + """Check if node is a quantize operation.""" + 
return ( + hasattr(node, "op") + and node.op == "call_function" + and node.target in self._get_quant_targets() + ) - def call_operator( + def _transfer_metadata( self, - op: EdgeOpOverload, - args: tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if ( - meta.data.get("input_qparams", {}) == {} - or meta.data.get("output_qparams", {}) == {} - ): - return super().call_operator(op, args, {}, meta) - - match op: - case exir_ops.edge.aten.add.Tensor: - op, args = self._get_add_replacement(args, meta) - case _: - pass - - return super().call_operator(op, args, {}, meta) + new_node: torch.fx.Node, + source_node: torch.fx.Node, + pass_name: str = "QuantizedOpFusionPass", + ) -> None: + """Metadata transfer with proper provenance tracking.""" + if hasattr(source_node, "meta") and source_node.meta: + new_node.meta = source_node.meta.copy() + + if "from_node" in new_node.meta: + from_node_list = new_node.meta.get("from_node", []).copy() + from_node_list.append( + {"source": source_node.name, "pass": pass_name, "op": "fuse"} + ) + new_node.meta["from_node"] = from_node_list + + # Copy essential fields + for field in ["tensor_meta", "stack_trace"]: + if field in source_node.meta: + new_node.meta[field] = source_node.meta[field] + + def _normalize_to_cortex_m_targets(self, graph_module: torch.fx.GraphModule) -> int: + """Convert decomposed targets to cortex_m equivalents for consistent handling.""" + target_mapping = { + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: exir_ops.edge.cortex_m.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: exir_ops.edge.cortex_m.quantize_per_tensor.default, + } + + normalization_count = 0 + for node in list(graph_module.graph.nodes): + if node.op == "call_function" and node.target in target_mapping: + logger.info(f"Normalizing {node.target} to cortex_m equivalent") + node.target = target_mapping[node.target] + normalization_count += 1 + + return normalization_count + + def _fuse_quantized_binary_patterns( + self, graph_module: torch.fx.GraphModule + ) -> int: + """Generic fusion for quantized binary operation patterns.""" + fusion_count = 0 + nodes_to_erase = [] + + for node in list(graph_module.graph.nodes): + if not self._is_quant_node(node): + continue + + quantize_node = node + if not quantize_node.args: + continue + + binary_op_node = quantize_node.args[0] + if not self._is_supported_binary_op(binary_op_node): + continue + + if len(binary_op_node.args) < 2: + continue + + dequant_node1, dequant_node2 = binary_op_node.args[:2] + if not ( + self._is_dequant_node(dequant_node1) + and self._is_dequant_node(dequant_node2) + ): + continue + + # Get the target quantized operation + quantized_target = self.SUPPORTED_OPS_MAPPING[binary_op_node.target] + # Extract op name (e.g., 'Tensor' -> 'add') + op_name = str(binary_op_node.target).split(".")[-1] + logger.info(f"✅ Found complete cortex_m Q/DQ + {op_name} pattern!") + + try: + # Extract values + int8_tensor1, scale1, zero_point1 = dequant_node1.args[:3] + int8_tensor2, scale2, zero_point2 = dequant_node2.args[:3] + output_scale, output_zero_point = quantize_node.args[1:3] + + # Convert to Python floats + scale1_val = extract_scalar_value(scale1) + scale2_val = extract_scalar_value(scale2) + output_scale_val = extract_scalar_value(output_scale) + zp1_val = int(extract_scalar_value(zero_point1)) + zp2_val = int(extract_scalar_value(zero_point2)) + output_zp_val = int(extract_scalar_value(output_zero_point)) + + 
max_scale_2x = 2 * max(scale1_val, scale2_val) + # AoT COMPUTATION: Calculate multipliers and shifts + + input1_mult, input1_shift = quantize_multiplier_aot( + scale1_val / max_scale_2x + ) + input2_mult, input2_shift = quantize_multiplier_aot( + scale2_val / max_scale_2x + ) + output_mult, output_shift = quantize_multiplier_aot( + max_scale_2x / (output_scale_val * (1 << SHIFT_INT8)) + ) + + logger.info("AoT computed parameters:") + logger.info(f" Input1: mult={input1_mult}, shift={input1_shift}") + logger.info(f" Input2: mult={input2_mult}, shift={input2_shift}") + logger.info(f" Output: mult={output_mult}, shift={output_shift}") + + with graph_module.graph.inserting_after(quantize_node): + fused = graph_module.graph.create_node( + "call_function", + target=quantized_target, + args=( + int8_tensor1, + zp1_val, + input1_mult, + input1_shift, + int8_tensor2, + zp2_val, + input2_mult, + input2_shift, + output_zp_val, + output_mult, + output_shift, + ), + kwargs={}, + ) + + # metadata transfer + self._transfer_metadata(fused, quantize_node) + + logger.info(f"✅ Created fused quantized_{op_name} node: {fused}") + + # Replace all uses + quantize_node.replace_all_uses_with(fused) + binary_op_node.replace_all_uses_with(fused) + dequant_node1.replace_all_uses_with(fused) + dequant_node2.replace_all_uses_with(fused) + + nodes_to_erase.extend( + [quantize_node, binary_op_node, dequant_node1, dequant_node2] + ) + fusion_count += 1 + logger.info(f"Pattern fused, total so far: {fusion_count}") + + except Exception as e: + logger.info(f"❌ Error during AoT computation: {e}") + logger.info(" Skipping fusion for this pattern") + continue + + for old_node in reversed(nodes_to_erase): + if old_node in graph_module.graph.nodes and len(old_node.users) == 0: + logger.info(f"🗑️ Erasing node: {old_node}") + graph_module.graph.erase_node(old_node) + + return fusion_count + + def call(self, graph_module: torch.fx.GraphModule): + logger.info("QuantizedOpFusionPass.call() started") + + # Normalize targets for flexible pass ordering + normalization_count = self._normalize_to_cortex_m_targets(graph_module) + + # Generic fusion for supported binary operations + fusion_count = self._fuse_quantized_binary_patterns(graph_module) + + total_changes = normalization_count + fusion_count + logger.info(f"Total changes: {total_changes}") + + if total_changes > 0: + graph_module.graph.eliminate_dead_code() + graph_module.graph.lint() + graph_module.recompile() + + logger.debug("=== AFTER FUSION: All nodes in the graph ===") + for i, node in enumerate(graph_module.graph.nodes): + logger.debug(f"Node {i}: op={node.op}, target={node.target}") + if "quantized_" in str(node.target) and "add" in str(node.target): + logger.debug(" ⭐ FOUND QUANTIZED BINARY OP NODE! 
⭐") + logger.debug("=== END DEBUG ===") + + return PassResult(graph_module, total_changes > 0) diff --git a/backends/cortex_m/test/ops/test_add.py b/backends/cortex_m/test/ops/test_add.py index 458d5361347..4389b463076 100644 --- a/backends/cortex_m/test/ops/test_add.py +++ b/backends/cortex_m/test/ops/test_add.py @@ -59,6 +59,17 @@ class CortexMTensorAdd(Model): } +class CortexMTensorAddBroadcast(Model): + # TODO: Quantize and accelerate broadcasted adds + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, + } + + class CortexMAlphaAdd(ModelAlpha): ops_before_transforms = { "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, @@ -115,15 +126,15 @@ class CortexMAlphaAdd(ModelAlpha): (torch.rand(2, 2) * 10, torch.rand(2, 2)), ), "broadcast_1": McuTestCase( - CortexMTensorAdd(), + CortexMTensorAddBroadcast(), (torch.ones(1), torch.ones(2, 2, 2, 2)), ), "broadcast_2": McuTestCase( - CortexMTensorAdd(), + CortexMTensorAddBroadcast(), (torch.ones((2, 1, 1, 1)), torch.ones(1)), ), "broadcast_3": McuTestCase( - CortexMTensorAdd(), + CortexMTensorAddBroadcast(), ( ramp_tensor(-2, 2, (2, 1, 2, 1)), ramp_tensor(-5, 5, (1, 2, 1, 2)), @@ -172,18 +183,6 @@ def test_dialect_add(test_case): "'float' object has not attribute 'fake_mode' - scalar only ops not supported.", AttributeError, ), - "broadcast_1": ( - " assert failed (input1.sizes() == input2.sizes()): Input1 and Input2 must have the same sizes.", - RuntimeError, - ), - "broadcast_2": ( - " assert failed (input1.sizes() == input2.sizes()): Input1 and Input2 must have the same sizes.", - RuntimeError, - ), - "broadcast_3": ( - " assert failed (input1.sizes() == input2.sizes()): Input1 and Input2 must have the same sizes.", - RuntimeError, - ), "alpha": ( "Expecting kwargs for aten op IR to be empty - alpha arg not supported.", AssertionError, diff --git a/backends/cortex_m/test/ops/test_linear.py b/backends/cortex_m/test/ops/test_linear.py index e81daa7e83e..4ab5ca99f15 100644 --- a/backends/cortex_m/test/ops/test_linear.py +++ b/backends/cortex_m/test/ops/test_linear.py @@ -4,8 +4,8 @@ # LICENSE file in the root directory of this source tree. 
+import pytest import torch -from executorch.backends.arm.test.common import parametrize from executorch.backends.cortex_m.test.tester import ( CortexMTester, McuTestCase, @@ -13,9 +13,12 @@ ) -class CortexMLinear(torch.nn.Module): +class CortexMMm(torch.nn.Module): + def forward(self, x, y): + return torch.mm(x, y) + ops_before_transforms = { - "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_mm_default": 1, "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, } @@ -26,45 +29,32 @@ class CortexMLinear(torch.nn.Module): "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, } - def __init__(self, *args, **kwargs): - super().__init__() - self.linear = torch.nn.Linear(*args, bias=False) - self.linear.weight.data.fill_(1.0) - - def forward(self, x): - return self.linear(x) +class CortexMBmm(torch.nn.Module): + def forward(self, x, y): + return torch.bmm(x, y) -class CortexMLinearX3(torch.nn.Module): ops_before_transforms = { - "executorch_exir_dialects_edge__ops_aten_linear_default": 3, - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 4, - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 7, + "executorch_exir_dialects_edge__ops_aten_bmm_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, } ops_after_transforms = { - "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 3, + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, } - def __init__(self, *args, **kwargs): - super().__init__() - self.linear = torch.nn.Linear(*args, bias=False) - self.linear.weight.data.fill_(1.0) - - def forward(self, x): - x = self.linear(x) - x = self.linear(x) - x = self.linear(x) - return x +class CortexMAddmm(torch.nn.Module): + def forward(self, x, y, z, alpha=None, beta=None): + return torch.addmm(beta, x, alpha, y, z) -class CortexMLinearBias(torch.nn.Module): ops_before_transforms = { - "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_addmm_default": 1, "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 4, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, } ops_after_transforms = { @@ -73,23 +63,90 @@ class CortexMLinearBias(torch.nn.Module): "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, } + +class CortexMAt(CortexMMm): + def forward(self, x, y): + return x @ y + + +class CortexMMatmul(CortexMMm): + def forward(self, x, y): + return torch.matmul(x, y) + + +class CortexMLinear(CortexMMatmul): + def __init__(self, *args, **kwargs): + super().__init__() + self.linear = torch.nn.Linear(*args, bias=False) + + def forward(self, x): + return self.linear(x) + + +class CortexMLinearBias(CortexMAddmm): def __init__(self, *args, **kwargs): super().__init__() self.linear = torch.nn.Linear(*args, bias=True) self.relu = torch.nn.ReLU() def 
forward(self, x): - return self.linear(x) + return self.relu(self.linear(x)) test_cases = { + "mm": McuTestCase( + model=CortexMMm(), + example_inputs=( + ramp_tensor(0, 10, (1, 16)), + ramp_tensor(0, 10, (16, 16)), + ), + ), + "bmm": McuTestCase( + model=CortexMBmm(), + example_inputs=( + ramp_tensor(0, 10, (1, 16, 16)), + ramp_tensor(0, 10, (1, 16, 16)), + ), + ), + "addmm": McuTestCase( + model=CortexMAddmm(), + example_inputs=( + ramp_tensor(0, 10, (1, 16)), + ramp_tensor(0, 10, (16, 16)), + ramp_tensor(0, 10, (16, 16)), + 2, + 4, + ), + ), + "addmm_scalars": McuTestCase( + model=CortexMAddmm(), + example_inputs=( + ramp_tensor(0, 10, (1, 16)), + ramp_tensor(0, 10, (16, 16)), + ramp_tensor(0, 10, (16, 16)), + ), + ), + "@-operator": McuTestCase( + model=CortexMAt(), + example_inputs=( + ramp_tensor(0, 10, (1, 16)), + ramp_tensor(0, 10, (16, 16)), + ), + ), + "matmul": McuTestCase( + model=CortexMMatmul(), + example_inputs=( + ramp_tensor(0, 10, (1, 16)), + ramp_tensor(0, 10, (16, 16)), + ), + ), "linear_rank1": McuTestCase( - model=CortexMLinear(1, 2), - example_inputs=(torch.Tensor([1]),), + model=CortexMLinear(2, 3), + example_inputs=(ramp_tensor(-1, 1, (2,)),), ), "linear_rank2_pos": McuTestCase( - model=CortexMLinear(1, 2), - example_inputs=(ramp_tensor(-1, 1, (1, 1)),), + model=CortexMLinear(8, 3), + example_inputs=(ramp_tensor(0, 10, (2, 8)),), ), "linear_rank3_neg": McuTestCase( model=CortexMLinear(5, 3), @@ -107,24 +164,22 @@ def forward(self, x): model=CortexMLinearBias(61, 37), example_inputs=(ramp_tensor(0, 10, (8, 61)),), ), - "linear_x3": McuTestCase( - model=CortexMLinearX3(4, 4), - example_inputs=(ramp_tensor(0, 10, (2, 4)),), - ), } -@parametrize("test_case", test_cases) +@pytest.mark.skip( + reason="Skipping until the quantized_linear_fusion_pass is updated to work with non decomposed linear ops." +) def test_dialect_linear(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) tester.test_dialect( - test_case.model.ops_before_transforms, - test_case.model.ops_after_transforms, - qtol=1, + test_case.model.ops_before_transforms, test_case.model.ops_after_transforms ) -@parametrize("test_case", test_cases) +@pytest.mark.skip( + reason="Skipping until the quantized_linear_fusion_pass is updated to work with non decomposed linear ops." +) def test_implementation_linear(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) - tester.test_implementation(qtol=1) + tester.test_implementation()
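As a side note on the add path: the ahead-of-time multiplier/shift computation in QuantizedOpFusionPass relies on the usual CMSIS-style fixed-point decomposition of a real multiplier. A hedged sketch of how such a decomposition is commonly done; the real quantize_multiplier_aot and SHIFT_INT8 live in passes_utils and may differ in details such as the sign convention of the shift, and the scales and SHIFT_INT8 value below are placeholders:

import math

def quantize_multiplier(real_multiplier: float) -> tuple[int, int]:
    # Decompose a positive real multiplier as q * 2**(-shift), with q a Q31 value in [2**30, 2**31).
    if real_multiplier == 0.0:
        return 0, 0
    mantissa, exponent = math.frexp(real_multiplier)  # real = mantissa * 2**exponent, mantissa in [0.5, 1)
    q = round(mantissa * (1 << 31))
    if q == (1 << 31):  # rounding can push the mantissa up to 1.0
        q //= 2
        exponent += 1
    return q, -exponent  # shift counted as right-shifts

# Illustrative values only; SHIFT_INT8 = 20 is a placeholder for the passes_utils constant.
scale1, scale2, output_scale, SHIFT_INT8 = 0.02, 0.03, 0.05, 20
max_scale_2x = 2 * max(scale1, scale2)
input1_mult, input1_shift = quantize_multiplier(scale1 / max_scale_2x)
input2_mult, input2_shift = quantize_multiplier(scale2 / max_scale_2x)
output_mult, output_shift = quantize_multiplier(max_scale_2x / (output_scale * (1 << SHIFT_INT8)))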