diff --git a/backends/cortex_m/ops/cortex_m_ops_common.h b/backends/cortex_m/ops/cortex_m_ops_common.h index c7e6cc8a389..10fad9b6a0b 100644 --- a/backends/cortex_m/ops/cortex_m_ops_common.h +++ b/backends/cortex_m/ops/cortex_m_ops_common.h @@ -22,6 +22,7 @@ using Tensor = torch::executor::Tensor; using ScalarType = executorch::aten::ScalarType; using Scalar = torch::executor::Scalar; using Error = executorch::runtime::Error; +using IntArrayRef = executorch::aten::ArrayRef; // From arm_nn_math_types.h #define ARM_NN_Q31_MAX ((int32_t)(0x7FFFFFFFL)) diff --git a/backends/cortex_m/ops/op_quantized_linear.cpp b/backends/cortex_m/ops/op_quantized_linear.cpp index d1ccb6d0d45..015fa805134 100644 --- a/backends/cortex_m/ops/op_quantized_linear.cpp +++ b/backends/cortex_m/ops/op_quantized_linear.cpp @@ -1,12 +1,12 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ -#include "cmsis_scratch_buffer_context.h" #include "cortex_m_ops_common.h" extern "C" { @@ -20,151 +20,90 @@ using KernelRuntimeContext = torch::executor::KernelRuntimeContext; Tensor& quantized_linear_out( KernelRuntimeContext& context, const Tensor& input, - const Scalar& input_zero_point, - const Scalar& input_multiplier, - const Scalar& input_shift, const Tensor& weights, - const Tensor& weight_zero_point, - const Tensor& weight_multiplier, - const Tensor& weight_shift, const torch::executor::optional& bias, - const Tensor& bias_multiplier, - const Tensor& bias_shift, - const Tensor& scratch_buffer, - const Scalar& output_zero_point, - const Scalar& in_features, - const Scalar& out_features, + const torch::executor::optional& kernel_sum, + const Scalar& input_offset, + const Scalar& filter_offset, + const Scalar& output_offset, + const IntArrayRef requantize_multipliers, + const IntArrayRef requantize_shifts, + const Scalar& activation_max, + const Scalar& activation_min, Tensor& out) { ET_LOG(Info, "quantized_linear_out: called"); - validate_cmsis_nn_tensor_requirements(input, weights, out); - - ET_CHECK_MSG( - scratch_buffer.scalar_type() == ScalarType::Char, - "Scratch buffer must be int8"); - - const int32_t batch_size = input.size(0); - const int32_t in_feat = static_cast(in_features.to()); - const int32_t out_feat = static_cast(out_features.to()); - const int32_t input_zp = static_cast(input_zero_point.to()); - const int32_t output_zp = - static_cast(output_zero_point.to()); - const bool is_per_channel = (weight_zero_point.numel() > 1); const int8_t* input_data = input.const_data_ptr(); const int8_t* weight_data = weights.const_data_ptr(); const int32_t* bias_data = bias.has_value() ? bias.value().const_data_ptr() : nullptr; + int32_t* kernel_sum_data = + kernel_sum.has_value() ? 
kernel_sum.value().data_ptr() : nullptr; int8_t* output_data = out.mutable_data_ptr(); - const int32_t* weight_zp_data = weight_zero_point.const_data_ptr(); - const int32_t* weight_mult_data = weight_multiplier.const_data_ptr(); - const int32_t* weight_shift_data = weight_shift.const_data_ptr(); - - if (!validate_per_channel_quant_params( - weight_mult_data, weight_shift_data, out_feat)) { - context.fail(Error::InvalidArgument); - return out; - } - - // Initialize scratch buffer context (validates early) - CMSISScratchBufferContext scratch_ctx( - const_cast(scratch_buffer), weights, weight_zero_point, bias); - scratch_ctx.compute_kernel_sums_if_needed(); - cmsis_nn_context ctx = scratch_ctx.get_cmsis_ctx(); + cmsis_nn_context ctx; + ctx.size = 0; // Not used in CMSIS-NN + ctx.buf = kernel_sum_data; // Setup CMSIS-NN parameters cmsis_nn_fc_params fc_params; - fc_params.input_offset = -input_zp; - fc_params.output_offset = output_zp; - fc_params.activation.min = std::numeric_limits::min(); - fc_params.activation.max = std::numeric_limits::max(); - - cmsis_nn_dims input_dims = {1, 1, 1, in_feat}; + fc_params.input_offset = static_cast(input_offset.to()); + fc_params.filter_offset = static_cast(filter_offset.to()); + fc_params.output_offset = static_cast(output_offset.to()); + fc_params.activation.min = static_cast(activation_min.to()); + fc_params.activation.max = static_cast(activation_max.to()); + + cmsis_nn_per_tensor_quant_params per_tensor_quant_params; + per_tensor_quant_params.multiplier = + static_cast(requantize_multipliers.at(0)); + per_tensor_quant_params.shift = static_cast(requantize_shifts.at(0)); + + auto in_feat = input.size(input.dim() - 1); + auto out_feat = out.size(out.dim() - 1); + auto batches = 1; + for (size_t i = 0; i < input.dim() - 1; i++) { + batches *= input.size(i); + } + ET_LOG( + Info, + "in features: %d, out_features: %d, batches: %d, kernel_sum_size: %d", + in_feat, + out_feat, + batches, + kernel_sum.has_value() ? kernel_sum.value().numel() : 0); + ET_LOG( + Info, + "kernel_sum[0]: %d, kernel_sum[1]: %d", + kernel_sum_data != nullptr ? kernel_sum_data[0] : -1, + kernel_sum_data != nullptr ? 
kernel_sum_data[1] : -1); + cmsis_nn_dims input_dims = {batches, 1, 1, in_feat}; cmsis_nn_dims filter_dims = {in_feat, 1, 1, out_feat}; cmsis_nn_dims bias_dims = {1, 1, 1, out_feat}; - cmsis_nn_dims output_dims = {1, 1, 1, out_feat}; - - arm_cmsis_nn_status status; - for (int32_t b = 0; b < batch_size; b++) { - const int8_t* batch_input = input_data + b * in_feat; - int8_t* batch_output = output_data + b * out_feat; - - ET_CHECK_MSG( - batch_input != nullptr && weight_data != nullptr, - "Null input pointers"); - ET_CHECK_MSG(in_feat > 0 && out_feat > 0, "Invalid dimensions"); - - if (is_per_channel) { - cmsis_nn_per_channel_quant_params per_channel_quant_params; - per_channel_quant_params.multiplier = - const_cast(weight_mult_data); - per_channel_quant_params.shift = const_cast(weight_shift_data); - - status = arm_fully_connected_per_channel_s8( - &ctx, - &fc_params, - &per_channel_quant_params, - &input_dims, - batch_input, - &filter_dims, - weight_data, - &bias_dims, - bias_data, - &output_dims, - batch_output); - } else { - fc_params.filter_offset = -weight_zp_data[0]; - cmsis_nn_per_tensor_quant_params per_tensor_quant_params; - per_tensor_quant_params.multiplier = weight_mult_data[0]; - per_tensor_quant_params.shift = weight_shift_data[0]; - - status = arm_fully_connected_s8( - &ctx, - &fc_params, - &per_tensor_quant_params, - &input_dims, - batch_input, - &filter_dims, - weight_data, - &bias_dims, - bias_data, - &output_dims, - batch_output); - } - - if (status != ARM_CMSIS_NN_SUCCESS) { - ET_LOG( - Error, - "quantized_linear_out: CMSIS-NN failed with status [%d]", - status); - context.fail(Error::Internal); - return out; - } + cmsis_nn_dims output_dims = {batches, 1, 1, out_feat}; + + arm_cmsis_nn_status status = arm_fully_connected_s8( + &ctx, + &fc_params, + &per_tensor_quant_params, + &input_dims, + input_data, + &filter_dims, + weight_data, + &bias_dims, + bias_data, + &output_dims, + output_data); + + if (status != ARM_CMSIS_NN_SUCCESS) { + ET_LOG( + Error, + "quantized_linear_out: CMSIS-NN failed with status [%d]", + status); + context.fail(Error::Internal); + return out; } - return out; -} -// Functional variant (stub, not used at runtime) -Tensor quantized_linear( - KernelRuntimeContext& context, - const Tensor& input, - const Scalar& input_zero_point, - const Scalar& input_multiplier, - const Scalar& input_shift, - const Tensor& weights, - const Tensor& weight_zero_point, - const Tensor& weight_multiplier, - const Tensor& weight_shift, - const torch::executor::optional& bias, - const Tensor& bias_multiplier, - const Tensor& bias_shift, - const Tensor& scratch_buffer, - const Scalar& output_zero_point, - const Scalar& in_features, - const Scalar& out_features) { - ET_LOG(Info, "quantized_linear: called"); - assert(false); - return const_cast(input); + return out; } } // namespace native diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index 286f938ccc9..b8abfb9bde4 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from math import prod + import torch from executorch.backends.cortex_m.passes.passes_utils import ( requantize_cmsis, @@ -170,210 +172,110 @@ def quantized_add_impl( # QUANTIZED LINEAR OPERATION DEFINITION # =================================================================== - -def _check_per_tensor_or_per_channel(param: torch.Tensor, out_channels: int, name: str): - assert param.numel() in [ - 1, - out_channels, - ], f"{name} must be per-tensor (1) or per-channel ({out_channels}), got {param.numel()}" - - lib.define( "quantized_linear.out(" - "Tensor input, Scalar input_zero_point, Scalar input_multiplier, Scalar input_shift, " + "Tensor input, " "Tensor weights, " - "Tensor weight_zero_point, Tensor weight_multiplier, Tensor weight_shift, " - "Tensor? bias, Tensor bias_multiplier, Tensor bias_shift, " - "Tensor scratch_buffer, Scalar output_zero_point, Scalar in_features, Scalar out_features, " - "*, Tensor(a!) out) -> Tensor(a!)" + "Tensor? bias, " + "Tensor? kernel_sum, " + "Scalar input_offset, " + "Scalar filter_offset, " + "Scalar output_offset, " + "int[] requantize_multipliers, " + "int[] requantize_shifts, " + "Scalar activation_max, " + "Scalar activation_min, " + "*, Tensor(a!) out" + ") -> Tensor(a!)" ) # Define functional variant (non-out version) lib.define( "quantized_linear(" - "Tensor input, Scalar input_zero_point, Scalar input_multiplier, Scalar input_shift, " + "Tensor input, " "Tensor weights, " - "Tensor weight_zero_point, Tensor weight_multiplier, Tensor weight_shift, " - "Tensor? bias, Tensor bias_multiplier, Tensor bias_shift, " - "Tensor scratch_buffer, Scalar output_zero_point, Scalar in_features, Scalar out_features" + "Tensor? bias, " + "Tensor? kernel_sum, " + "Scalar input_offset, " + "Scalar filter_offset, " + "Scalar output_offset, " + "int[] requantize_multipliers, " + "int[] requantize_shifts, " + "Scalar activation_max, " + "Scalar activation_min" ") -> Tensor" ) -# Fake meta function for shape inference (out variant) -@register_fake("cortex_m::quantized_linear.out") -def quantized_linear_out_meta( - input: torch.Tensor, - input_zero_point: int, - input_multiplier: int, - input_shift: int, - weights: torch.Tensor, - weight_zero_point: torch.Tensor, - weight_multiplier: torch.Tensor, - weight_shift: torch.Tensor, - bias: torch.Tensor, - bias_multiplier: torch.Tensor, - bias_shift: torch.Tensor, - scratch_buffer: torch.Tensor, - output_zero_point: int, - in_features: int, - out_features: int, - out: torch.Tensor, -) -> torch.Tensor: - # Validate dimensions - batch_size = input.shape[0] - out_channels = weights.shape[0] - - # Validate weight quantization parameters dimensions - _check_per_tensor_or_per_channel( - weight_zero_point, out_channels, "weight_zero_point" - ) - _check_per_tensor_or_per_channel( - weight_multiplier, out_channels, "weight_multiplier" - ) - _check_per_tensor_or_per_channel(weight_shift, out_channels, "weight_shift") - - # Validate output shape - expected_shape = (batch_size, out_channels) - assert ( - out.shape == expected_shape - ), f"Output shape {out.shape} must be {expected_shape}" - - return out - - # Fake meta function for shape inference (functional variant) @register_fake("cortex_m::quantized_linear") def quantized_linear_meta( - input: torch.Tensor, - input_zero_point: int, - input_multiplier: int, - input_shift: int, - weights: torch.Tensor, - weight_zero_point: torch.Tensor, - weight_multiplier: torch.Tensor, - weight_shift: torch.Tensor, - bias: torch.Tensor, - bias_multiplier: torch.Tensor, - bias_shift: torch.Tensor, 
- scratch_buffer: torch.Tensor, - output_zero_point: int, - in_features: int, - out_features: int, -) -> torch.Tensor: - # Validate dimensions (same as out variant) - batch_size = input.shape[0] - out_channels = weights.shape[0] - - # Validate weight quantization parameters dimensions - _check_per_tensor_or_per_channel( - weight_zero_point, out_channels, "weight_zero_point" - ) - _check_per_tensor_or_per_channel( - weight_multiplier, out_channels, "weight_multiplier" - ) - _check_per_tensor_or_per_channel(weight_shift, out_channels, "weight_shift") - - # Calculate output shape for functional variant - output_shape = (batch_size, out_channels) - return torch.empty(output_shape, dtype=input.dtype, device=input.device) - - -@impl(lib, "quantized_linear.out", "CompositeExplicitAutograd") -def quantized_linear_out_impl( - input: torch.Tensor, - input_zero_point: int, - input_multiplier: int, - input_shift: int, - weights: torch.Tensor, - weight_zero_point: torch.Tensor, - weight_multiplier: torch.Tensor, - weight_shift: torch.Tensor, - bias: torch.Tensor, - bias_multiplier: torch.Tensor, - bias_shift: torch.Tensor, - scratch_buffer: torch.Tensor, - output_zero_point: int, - in_features: int, - out_features: int, - *, - out: torch.Tensor, + input, + weights, + bias, + kernel_sum, + input_offset, + filter_offset, + output_offset, + requantize_multipliers, + requantize_shifts, + activation_max, + activation_min, ) -> torch.Tensor: - """ - Fallback implementation for meta/testing - Note: This won't be called at runtime, only during compilation - """ - # Per-channel dequantization - input_scale = input_multiplier * (2.0 ** (-input_shift)) - input_fp = (input.float() - input_zero_point) * input_scale - if weight_zero_point.numel() == 1: - # Per-tensor - weight_scale = weight_multiplier.item() * (2.0 ** (-weight_shift.item())) - weights_fp = (weights.float() - weight_zero_point.item()) * weight_scale - else: - # Per-channel - weight_scales = weight_multiplier.float() * (2.0 ** (-weight_shift.float())) - weights_fp = ( - weights.float() - weight_zero_point.float().unsqueeze(1) - ) * weight_scales.unsqueeze(1) - bias_fp = None - if bias is not None: - bias_scales = bias_multiplier.float() * (2.0 ** (-bias_shift.float())) - bias_fp = bias.float() * bias_scales - - result_fp = torch.nn.functional.linear(input_fp, weights_fp, bias_fp) - else: - result_fp = torch.nn.functional.linear(input_fp, weights_fp) - result_quantized = torch.clamp( - torch.round(result_fp + output_zero_point), -128, 127 - ).to(torch.int8) - out.copy_(result_quantized) - return out + shape = (*input.shape[:-1], weights.shape[0]) + return torch.empty(shape, dtype=input.dtype, device=input.device) # Functional variant implementation @impl(lib, "quantized_linear", "CompositeExplicitAutograd") def quantized_linear_impl( input: torch.Tensor, - input_zero_point: int, - input_multiplier: int, - input_shift: int, weights: torch.Tensor, - weight_zero_point: torch.Tensor, - weight_multiplier: torch.Tensor, - weight_shift: torch.Tensor, bias: torch.Tensor, - bias_multiplier: torch.Tensor, - bias_shift: torch.Tensor, - scratch_buffer: torch.Tensor, - output_zero_point: int, - in_features: int, - out_features: int, + kernel_sum: torch.Tensor, + input_offset: int, + filter_offset: int, + output_offset: int, + requantize_multipliers: torch.Tensor, + requantize_shifts: torch.Tensor, + activation_max: int, + activation_min: int, ) -> torch.Tensor: """ Functional variant - creates output tensor and calls out variant """ - # Create output tensor - 
batch_size = input.shape[0] - output = torch.empty( - (batch_size, out_features), dtype=torch.int8, device=input.device - ) - return quantized_linear_out_impl( - input, - input_zero_point, - input_multiplier, - input_shift, - weights, - weight_zero_point, - weight_multiplier, - weight_shift, - bias, - bias_multiplier, - bias_shift, - scratch_buffer, - output_zero_point, - in_features, - out_features, - out=output, + + # Leaving both implementations for debugging purposes. + compute_using_kernel_sum = True + + if compute_using_kernel_sum: + weights_int32 = weights.to(torch.int32) + + input_int32 = input.to(torch.int32) + new_shape = (prod(input.shape[:-1]), input.shape[-1]) + input_reshaped = input_int32.reshape(new_shape) + + lhs_sum = torch.sum(input_reshaped, dim=-1, keepdim=True) * filter_offset + output = torch.mm(input_reshaped, weights_int32.T) + lhs_sum + kernel_sum + output_shape = (*input.shape[:-1], output.shape[-1]) + output_reshaped = output.reshape(output_shape) + else: + weights_int32 = weights.to(torch.int32) + filter_offset + + input_int32 = input.to(torch.int32) + input_offset + new_shape = (prod(input.shape[:-1]), input.shape[-1]) + input_reshaped = input_int32.reshape(new_shape) + + output = torch.mm(input_reshaped, weights_int32.T) + if bias is not None: + output = output + bias + output_shape = (*input.shape[:-1], output.shape[-1]) + output_reshaped = output.reshape(output_shape) + + output = requantize_cmsis( + output_reshaped, requantize_multipliers[0], requantize_shifts[0] ) + output += output_offset + output = torch.clamp(output, activation_min, activation_max).to(torch.int8) + return output diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml index 81ebeafc778..98d8df8797e 100644 --- a/backends/cortex_m/ops/operators.yaml +++ b/backends/cortex_m/ops/operators.yaml @@ -23,14 +23,8 @@ - arg_meta: null kernel_name: cortex_m::quantized_add_out -- func: cortex_m::quantized_linear(Tensor input, Scalar input_zero_point, Scalar input_multiplier, Scalar input_shift, Tensor weights, Tensor weight_zero_point, Tensor weight_multiplier, Tensor weight_shift, Tensor? bias, Tensor bias_multiplier, Tensor bias_shift, Tensor scratch_buffer, Scalar output_zero_point, Scalar in_features, Scalar out_features) -> Tensor +- func: cortex_m::quantized_linear.out(Tensor input, Tensor weights, Tensor? bias, Tensor? kernel_sum, Scalar input_offset, Scalar filter_offset, Scalar output_offset, int[] requantize_multipliers, int[] requantize_shifts, Scalar activation_max, Scalar activation_min, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null - kernel_name: cortex_m::quantized_linear - -- func: cortex_m::quantized_linear.out(Tensor input, Scalar input_zero_point, Scalar input_multiplier, Scalar input_shift, Tensor weights, Tensor weight_zero_point, Tensor weight_multiplier, Tensor weight_shift, Tensor? bias, Tensor bias_multiplier, Tensor bias_shift, Tensor scratch_buffer, Scalar output_zero_point, Scalar in_features, Scalar out_features, *, Tensor(a!) out) -> Tensor(a!) 
- variants: function - kernels: - - arg_meta: null - kernel_name: cortex_m::quantized_linear_out + kernel_name: cortex_m::quantized_linear_out \ No newline at end of file diff --git a/backends/cortex_m/passes/cortex_m_pass_manager.py b/backends/cortex_m/passes/cortex_m_pass_manager.py index 02429cc68e0..10fb358c70e 100644 --- a/backends/cortex_m/passes/cortex_m_pass_manager.py +++ b/backends/cortex_m/passes/cortex_m_pass_manager.py @@ -4,7 +4,11 @@ # LICENSE file in the root directory of this source tree. -from executorch.backends.arm._passes import ScalarsToAttributePass +from executorch.backends.arm._passes import ( + DecorateFp32toInt32CastingPass, + FoldAndAnnotateQParamsPass, + ScalarsToAttributePass, +) from executorch.backends.cortex_m.passes import ( QuantizedLinearFusionPass, QuantizedOpFusionPass, @@ -20,10 +24,12 @@ class CortexMPassManager(XNNPACKPassManager): pass_list: list[ExportPass] = [ + FoldAndAnnotateQParamsPass, ReplaceScalarWithTensorArgPass, ReplaceQuantNodesPass, QuantizedOpFusionPass, QuantizedLinearFusionPass, + DecorateFp32toInt32CastingPass, ] pass_list_transform_for_annotation: list[ExportPass] = [ diff --git a/backends/cortex_m/passes/quantized_linear_fusion_pass.py b/backends/cortex_m/passes/quantized_linear_fusion_pass.py index 11a49beb2f4..f921f5ce621 100644 --- a/backends/cortex_m/passes/quantized_linear_fusion_pass.py +++ b/backends/cortex_m/passes/quantized_linear_fusion_pass.py @@ -5,642 +5,147 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import logging -from typing import Optional import executorch.backends.cortex_m.ops.operators # noqa + import torch import torch.fx +from executorch.backends.cortex_m.passes.passes_utils import quantize_multiplier_aot -from executorch.backends.cortex_m.passes.passes_utils import ( - cleanup_nodes, - is_dequant_node, - quantize_multiplier_aot, - transfer_metadata, +from executorch.backends.transforms.utils import ( + create_constant_placeholder, + get_param_tensor, ) -from executorch.backends.transforms.utils import create_mutable_buffer, get_param_tensor - from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass -from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops -from torch.fx import Node +from torch.export.graph_signature import InputKind from torch.fx.passes.infra.pass_manager import PassResult -logger = logging.getLogger("quantized_linear_fusion_pass") -logger.setLevel(logging.INFO) - class QuantizedLinearFusionPass(XNNPACKPass): """ Cortex-M backend pass that fuses quantized linear-like patterns. Fuses: dequantize -> [linear/addmm/fc_ops] -> quantize Into: cortex_m.quantized_linear.default with direct parameters. 
- """ - - SUPPORTED_OPS_MAPPING = { - exir_ops.edge.aten.addmm.default: exir_ops.edge.cortex_m.quantized_linear.default, - exir_ops.edge.aten.mm.default: exir_ops.edge.cortex_m.quantized_linear.default, - } - - requires_exported_program = True - - def __init__(self, exported_program: ExportedProgram): - super().__init__(exported_program) - self.nodes_to_erase = [] - - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - logger.info("Starting QuantizedLinearFusionPass") - assert id(self._exported_program.graph_module.graph) == id( - graph_module.graph - ), "QuantizedLinearFusionPass requires same graph instance" - - try: - fusion_count = self._fuse_quantized_linear_patterns(graph_module) - if fusion_count > 0: - graph_module.graph.eliminate_dead_code() - graph_module.graph.lint() - graph_module.recompile() - logger.info(f"Linear fusion completed: {fusion_count} patterns fused") - return PassResult(graph_module, fusion_count > 0) - except Exception as e: - logger.error(f"Error in QuantizedLinearFusionPass: {e}") - raise e - - def _extract_linear_pattern(self, quantize_node: Node): - if not quantize_node.args: - return None - fc_node = quantize_node.args[0] - if not ( - fc_node.op == "call_function" - and fc_node.target in self.SUPPORTED_OPS_MAPPING - ): - return None - - op_name = str(fc_node.target).split(".")[-1] - - if "addmm" in str(fc_node.target): - input_dq_node = fc_node.args[1] - else: - input_dq_node = fc_node.args[0] - if not is_dequant_node(input_dq_node): - logger.info("input_dq_node is not a dequant node") - return None - weight_dq_node, bias_dq_node = self._extract_weight_bias_from_fc_op(fc_node) - if not weight_dq_node: - logger.info("No weight, bias dequantize node found") - return None - return ( - quantize_node, - fc_node, - input_dq_node, - weight_dq_node, - bias_dq_node, - op_name, - ) - - def _extract_weight_bias_from_fc_op(self, fc_node: Node): - """Generic extraction for FC-like operations.""" - - if "addmm" in str(fc_node.target): - if len(fc_node.args) >= 3: - bias_arg = fc_node.args[0] - weight_arg = fc_node.args[2] - weight_dq_node = self._trace_to_dequantize(weight_arg) - logger.info( - f"weight_arg: {weight_arg}, traced weight_dq_node: {weight_dq_node}" - ) - - if weight_dq_node is None: - logger.info("No weight dequantize node found ") - - # For bias, try to trace to dequantize but allow None (no-bias case) - bias_dq_node = self._trace_to_dequantize(bias_arg) - if bias_dq_node is None: - logger.info("No bias dequantize node found - likely no-bias linear") - return weight_dq_node, bias_dq_node - elif any(op in str(fc_node.target) for op in ["linear", "mm"]): - if len(fc_node.args) >= 2: - weight_arg = fc_node.args[1] - bias_arg = fc_node.args[2] if len(fc_node.args) > 2 else None - weight_dq_node = self._trace_to_dequantize(weight_arg) - bias_dq_node = self._trace_to_dequantize(bias_arg) if bias_arg else None - return weight_dq_node, bias_dq_node - return None, None - - def _extract_input_quantization_parameters( - self, input_dq_node: Node - ) -> Optional[dict]: - """Extract input quantization parameters from dequantize node.""" - try: - # Find the quantize operation that produces the int8 tensor - input_quantize_node = None - if hasattr(input_dq_node, "args") and input_dq_node.args: - quantize_candidate = input_dq_node.args[0] - if getattr( - quantize_candidate, "op", None - ) == "call_function" and "quantize" in str( - getattr(quantize_candidate, "target", "") - ): - input_quantize_node = quantize_candidate - - if not input_quantize_node: - 
logger.error("Could not find quantize node for input!") - return None - - # Extract input quantization parameters - input_scale = self._extract_param_value(input_dq_node.args[1]) - input_zero_point = int(self._extract_param_value(input_dq_node.args[2])) - input_multiplier, input_shift = quantize_multiplier_aot(input_scale) - - return { - "input_scale": input_scale, - "input_zero_point": input_zero_point, - "input_multiplier": input_multiplier, - "input_shift": input_shift, - "input_tensor": input_quantize_node, - } - except Exception as e: - logger.error(f"Failed to extract input quantization parameters: {e}") - return None - - def _extract_output_quantization_parameters( - self, quantize_node: Node - ) -> Optional[dict]: - """Extract output quantization parameters from quantize node.""" - try: - output_scale = self._extract_param_value(quantize_node.args[1]) - output_zero_point = int(self._extract_param_value(quantize_node.args[2])) - return { - "output_scale": output_scale, - "output_zero_point": output_zero_point, - } - except Exception as e: - logger.error(f"Failed to extract output quantization parameters: {e}") - return None + Note that the optimzed implementation makes use of the following rewrite: - def _create_constant_parameter_buffer( - self, graph, quantize_node: Node, data: torch.Tensor, name: str - ): - """Create a parameter buffer""" - buffer_name = f"{name}_{id(quantize_node)}" + Let + - yi be the output activations (y1, ... yn) + - xj be the input activations (x1, ... xm) + - wij be the weights (w11, ... wnm) + - a be the input offset + - b be the weight offset + - ci be the bias - setattr(graph.owning_module, buffer_name, data) - - # Create a get_attr node - with graph.inserting_before(quantize_node): - buffer_node = graph.create_node( - op="get_attr", target=buffer_name, name=buffer_name - ) - - # Set metadata - buffer_node.meta["val"] = data - - return buffer_node - - def _extract_weight_parameters(self, weight_dq_node: Node) -> Optional[dict]: - try: - weight_tensor = weight_dq_node.args[0] - weight_scale = weight_dq_node.args[1] - weight_zero_point = ( - weight_dq_node.args[2] if len(weight_dq_node.args) > 2 else None - ) - - weight_scale_data = self._extract_param_value(weight_scale) - weight_zp_data = ( - self._extract_param_value(weight_zero_point) - if weight_zero_point - else None - ) - - # Get actual tensor data to determine output features - weight_tensor_data = get_param_tensor(self._exported_program, weight_tensor) - out_features = weight_tensor_data.shape[0] - - # Handle both per-tensor and per-channel - if ( - isinstance(weight_scale_data, torch.Tensor) - and weight_scale_data.numel() > 1 - ): - # Per-channel: ensure we have the right number of elements - assert ( - weight_scale_data.numel() == out_features - ), f"Scale size {weight_scale_data.numel()} != out_features {out_features}" - - multipliers = [] - shifts = [] - for scale in weight_scale_data: - mult, shift = quantize_multiplier_aot(scale.item()) - multipliers.append(mult) - shifts.append(shift) - - weight_multiplier = torch.tensor(multipliers, dtype=torch.int32) - weight_shift = torch.tensor(shifts, dtype=torch.int32) - weight_zp_tensor = ( - weight_zp_data.int() - if weight_zp_data is not None - else torch.zeros(out_features, dtype=torch.int32) - ) - else: - # Per-tensor: create tensors with correct size for output features - scale_val = ( - weight_scale_data.item() - if isinstance(weight_scale_data, torch.Tensor) - else weight_scale_data - ) - mult, shift = quantize_multiplier_aot(scale_val) - - 
# Create tensors sized for out_features (not single element) - weight_multiplier = torch.full((out_features,), mult, dtype=torch.int32) - weight_shift = torch.full((out_features,), shift, dtype=torch.int32) - weight_zp_tensor = torch.full( - (out_features,), - weight_zp_data if weight_zp_data else 0, - dtype=torch.int32, - ) - - # Validate multipliers - for i, mult in enumerate(weight_multiplier): - if mult < (1 << 30) or mult > ((1 << 31) - 1): - logger.error( - f"Invalid multiplier[{i}]: {mult}, scale was: {weight_scale_data}" - ) - return None - - return { - "weight_tensor": weight_tensor, - "weight_zero_point_data": weight_zp_tensor, - "weight_multiplier_data": weight_multiplier, - "weight_shift_data": weight_shift, - } - except Exception as e: - logger.error(f"Failed to extract weight parameters: {e}") - return None - - def _extract_bias_parameters(self, bias_dq_node: Optional[Node]) -> Optional[dict]: - """ - Extract bias parameters for quantized linear fusion. - Handles both dequantized bias nodes and constant bias tensors. - Returns a dict with bias_tensor, bias_multiplier, and bias_shift. - """ - if not bias_dq_node: - # No bias present - return None - try: - # Case 1: Bias is a dequantize node - if hasattr(bias_dq_node, "op") and is_dequant_node(bias_dq_node): - bias_tensor = bias_dq_node.args[0] - bias_scale = bias_dq_node.args[1] + Then the linear operation can be written as: + yi = sum_j((xj + a) * (wij + b)) + ci + = sum_j(xj*wij + xj*b + a*wij + a*b) + ci + = sum_j(xj*wij) + sum_j(xj)*b + (a * sum_j(wij + b) + ci) + = sum_j(xj*wij) + sum_j(xj)*b + kernel_sum - bias_scale_data = self._extract_param_value(bias_scale) - - if ( - isinstance(bias_scale_data, torch.Tensor) - and bias_scale_data.numel() > 1 - ): - # Per-channel bias - bias_multipliers = [] - bias_shifts = [] - for scale_val in bias_scale_data.tolist(): - mult, shift = quantize_multiplier_aot(scale_val) - bias_multipliers.append(mult) - bias_shifts.append(shift) - return { - "bias_tensor": bias_tensor, - "bias_multiplier": bias_multipliers, - "bias_shift": bias_shifts, - } - else: - # Per-tensor bias - bias_scale_val = ( - bias_scale_data.item() - if isinstance(bias_scale_data, torch.Tensor) - else bias_scale_data - ) - bias_multiplier, bias_shift = quantize_multiplier_aot( - bias_scale_val - ) - return { - "bias_tensor": bias_tensor, - "bias_multiplier": bias_multiplier, - "bias_shift": bias_shift, - } - else: - # Case 2: Bias is a constant tensor (not dequantized) - # This can happen if bias is not quantized in the model - bias_tensor = bias_dq_node - # Use default multiplier/shift for unquantized bias - bias_multiplier = 1 - bias_shift = 0 - return { - "bias_tensor": bias_tensor, - "bias_multiplier": bias_multiplier, - "bias_shift": bias_shift, - } - except Exception as e: - logger.error(f"Failed to extract bias parameters: {e}") - return None - - def _prepare_bias_tensors( - self, bias_params: Optional[dict], out_features: int - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Prepare bias multiplier and shift tensors for kernel call. - Returns (bias_multiplier_tensor, bias_shift_tensor) both sized [out_features]. 
- """ - if bias_params: - bias_multiplier = bias_params["bias_multiplier"] - bias_shift = bias_params["bias_shift"] - - # Convert to tensors of the right size - if isinstance(bias_multiplier, int): - bias_multiplier_tensor = torch.full( - [out_features], bias_multiplier, dtype=torch.int32 - ) - elif isinstance(bias_multiplier, list): - assert ( - len(bias_multiplier) == out_features - ), f"Bias multiplier size {len(bias_multiplier)} != out_features {out_features}" - bias_multiplier_tensor = torch.tensor( - bias_multiplier, dtype=torch.int32 - ) - elif isinstance(bias_multiplier, torch.Tensor): - assert ( - bias_multiplier.numel() == out_features - ), f"Bias multiplier size {bias_multiplier.numel()} != out_features {out_features}" - bias_multiplier_tensor = bias_multiplier - else: - raise TypeError( - f"Unsupported bias_multiplier type: {type(bias_multiplier)}" - ) - - if isinstance(bias_shift, int): - bias_shift_tensor = torch.full( - [out_features], bias_shift, dtype=torch.int32 - ) - elif isinstance(bias_shift, list): - assert ( - len(bias_shift) == out_features - ), f"Bias shift size {len(bias_shift)} != out_features {out_features}" - bias_shift_tensor = torch.tensor(bias_shift, dtype=torch.int32) - elif isinstance(bias_shift, torch.Tensor): - assert ( - bias_shift.numel() == out_features - ), f"Bias shift size {bias_shift.numel()} != out_features {out_features}" - bias_shift_tensor = bias_shift - else: - raise TypeError(f"Unsupported bias_shift type: {type(bias_shift)}") - - return bias_multiplier_tensor, bias_shift_tensor - else: - # No bias: return zero tensors of correct shape - return ( - torch.zeros([out_features], dtype=torch.int32), - torch.zeros([out_features], dtype=torch.int32), - ) + where kernel_sum is precomputed aot. + """ - def _extract_param_value(self, node_or_value): - """ - Extract a scalar value from a Node or a direct float/int. + def _compute_kernel_sum(self, weights, bias, input_offset, weight_offset): """ - if isinstance(node_or_value, (float, int)): - return node_or_value - # If it's a tensor, get its scalar value if possible - if isinstance(node_or_value, torch.Tensor): - return node_or_value.item() if node_or_value.numel() == 1 else node_or_value - # If it's a Node, use get_param_tensor - if hasattr(node_or_value, "op"): - tensor = get_param_tensor(self._exported_program, node_or_value) - return tensor.item() if tensor.numel() == 1 else tensor - raise TypeError(f"Unsupported parameter type: {type(node_or_value)}") - - def _calculate_cmsis_scratch_size(self, weight_tensor) -> int: - """Calculate CMSIS-NN scratch buffer size for quantized linear operations. + Computes the precomputed kernel sum term (bias optional) + a * sum_j(wij + b) + ci - Source: CMSIS-NN arm_fully_connected_s8_get_buffer_size() returns filter_dims->w * sizeof(int32_t). - This buffer stores pre-computed kernel sums (weight row sums) - one int32_t per output feature. - Same buffer size applies to both per-tensor and per-channel quantization paths since both use - identical kernel sum optimization in the underlying matrix multiplication. + as defined above, for i = (1, ..., n) where j indexes the input activations. 
""" - try: - print(f"weight_tensor type: {type(weight_tensor)}, value: {weight_tensor}") - weight_shape = get_param_tensor(self._exported_program, weight_tensor).shape - out_features = weight_shape[0] # filter_dims->w in CMSIS terms - - # CMSIS-NN implementation expects the following size - cmsis_buffer_size = out_features * 4 # sizeof(int32_t) - return cmsis_buffer_size - except Exception as e: - logger.error(f"Failed to calculate CMSIS scratch size: {e}") - return 2048 # Fallback - - def _create_scratch_buffer(self, graph, quantize_node: Node, weight_tensor): - cmsis_scratch = self._calculate_cmsis_scratch_size(weight_tensor) - - kernel_sum_header = 8 # sizeof(KernelSumHeader) - total_size = kernel_sum_header + cmsis_scratch - - logger.info( - f"Kernel sum header: {kernel_sum_header}, CMSIS buffer: {cmsis_scratch}, total: {total_size}" - ) - - return create_mutable_buffer( - self._exported_program, - name=f"b_cmsis_linear_scratch_{id(quantize_node)}", - data=torch.zeros((total_size,), dtype=torch.int8), - ) - - def _create_fused_node( - self, - graph, - quantize_node: Node, - quant_params: dict, - weight_params: dict, - bias_params: Optional[dict], - quantized_target, - ) -> Node: - """Generic fused node creation for any FC-like operation.""" - # Extract all parameters - input_tensor = quant_params["input_tensor"] - input_zp = quant_params["input_zero_point"] - input_multiplier = quant_params["input_multiplier"] - input_shift = quant_params["input_shift"] - weight_tensor = weight_params["weight_tensor"] - - weight_zp_node = self._create_constant_parameter_buffer( - graph, quantize_node, weight_params["weight_zero_point_data"], "weight_zp" - ) - weight_mult_node = self._create_constant_parameter_buffer( - graph, quantize_node, weight_params["weight_multiplier_data"], "weight_mult" - ) - weight_shift_node = self._create_constant_parameter_buffer( - graph, quantize_node, weight_params["weight_shift_data"], "weight_shift" + weights_transposed = weights.T + weights_int32 = weights_transposed.to(torch.int32) + offset_weights = weights_int32 + weight_offset + kernel_sum = torch.sum(offset_weights, dim=0, keepdim=True, dtype=torch.int32) + kernel_sum_offset = kernel_sum * input_offset + + if bias is not None: + kernel_sum_offset += bias + + return kernel_sum_offset + + def _get_linear_replacement(self, args, meta, node): + input_scale = meta["input_qparams"][0].scale + input_zp = meta["input_qparams"][0].zp + weight_scale = meta["input_qparams"][1].scale + weight_zp = meta["input_qparams"][1].zp + output_scale = meta["output_qparams"][0].scale + output_zp = meta["output_qparams"][0].zp + output_min = meta["output_qparams"][0].qmin + output_max = meta["output_qparams"][0].qmax + + quantized_multiplier, quantized_shift = quantize_multiplier_aot( + (input_scale * weight_scale) / output_scale ) - # Get dimensions - weight_shape = get_param_tensor(self._exported_program, weight_tensor).shape - assert ( - len(weight_shape) == 2 - ), f"Weight tensor must be 2D, got shape {weight_shape}" - in_features = weight_shape[1] - out_features = weight_shape[0] - # Handle bias - bias_tensor = bias_params["bias_tensor"] if bias_params else None - bias_multiplier, bias_shift = self._prepare_bias_tensors( - bias_params, out_features + # TODO: Add support for configuring the backend to support other extensions. + # Kernel sum is only used in the CMSIS-NN implementation for the MVE extension, + # so this should be optional. 
+ weights = args[1] + weights_tensor = get_param_tensor(self.exported_program, weights) + bias_tensor = ( + get_param_tensor(self.exported_program, args[2]) if len(args) > 2 else None ) - output_zp = quant_params["output_zero_point"] - - scratch_buffer = self._create_scratch_buffer( - graph, quantize_node, weight_tensor + kernel_sum_tensor = self._compute_kernel_sum( + weights_tensor, bias_tensor, -input_zp, -weight_zp ) - - with graph.inserting_after(quantize_node): - fused = graph.create_node( - "call_function", - target=quantized_target, - args=( - input_tensor, - input_zp, - input_multiplier, - input_shift, - weight_tensor, - weight_zp_node, - weight_mult_node, - weight_shift_node, - bias_tensor, - bias_multiplier, - bias_shift, - scratch_buffer, - output_zp, - in_features, - out_features, - ), - kwargs={}, + with node.graph.inserting_after(weights): + kernel_sum = create_constant_placeholder( + self.exported_program, + node.graph, + node.name + "_kernel_sum", + InputKind.PARAMETER, + kernel_sum_tensor, ) - transfer_metadata(fused, quantize_node, "QuantizedLinearFusionPass") - return fused - - def _mark_for_cleanup(self, nodes): - for node in nodes: - if node is not None: - self.nodes_to_erase.append(node) - - def _cleanup_nodes(self, graph): - cleanup_nodes(self.nodes_to_erase, graph) - self.nodes_to_erase.clear() - - def _extract_linear_pattern_with_validation(self, quantize_node: Node): - pattern_info = self._extract_linear_pattern(quantize_node) - if not pattern_info: - return None - # Optionally add more validation here if needed - return pattern_info + args = ( + args[0], + weights, + None, + kernel_sum, + -input_zp, + -weight_zp, + output_zp, + [quantized_multiplier], + [quantized_shift], + output_max, + output_min, + ) - def _trace_to_dequantize(self, node: Optional[Node], max_depth=3) -> Optional[Node]: - """Trace through transformations to find dequantize node.""" - current_node = node - depth = 0 - while current_node and depth < max_depth: - if is_dequant_node(current_node): - return current_node - if current_node.op == "call_function" and current_node.target in { - exir_ops.edge.aten.permute_copy.default, - exir_ops.edge.aten.view_copy.default, - }: - if current_node.args: - current_node = current_node.args[0] - depth += 1 - continue - break - return None + return args - def _fuse_quantized_linear_patterns( - self, graph_module: torch.fx.GraphModule - ) -> int: - fusion_count = 0 - graph = graph_module.graph - for node in list(graph.nodes): - if not ( - node.op == "call_function" and "quantize_per_tensor" in str(node.target) - ): + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False + for node in graph_module.graph.nodes: + if node.op != "call_function": continue - pattern_info = self._extract_linear_pattern_with_validation(node) - if not pattern_info: + if node.target != exir_ops.edge.aten.linear.default: continue - - ( - quantize_node, - fc_node, - input_dq_node, - weight_dq_node, - bias_dq_node, - op_name, - ) = pattern_info - - # Get quantized target for this FC operation - quantized_target = self.SUPPORTED_OPS_MAPPING.get(fc_node.target) - if not quantized_target: - logger.warning(f"No quantized target found for {fc_node.target}") + if ( + node.meta.get("input_qparams", {}) == {} + or node.meta.get("output_qparams", {}) == {} + ): continue - logger.info(f"✅ Found complete cortex_m Q/DQ + {op_name} pattern!") - - try: - input_params = self._extract_input_quantization_parameters( - input_dq_node - ) - if not input_params: - logger.error( 
- "Quantization parameter extraction failed for node: %s", node - ) - return None - output_params = self._extract_output_quantization_parameters( - quantize_node + args = self._get_linear_replacement(node.args, node.meta, node) + with graph_module.graph.inserting_before(node): + cortex_m_linear = graph_module.graph.create_node( + "call_function", + target=exir_ops.edge.cortex_m.quantized_linear.default, + args=args, + kwargs={}, ) - if not output_params: - logger.error( - "Output quantization parameter extraction failed for node: %s", - node, - ) - return None - quant_params = {**input_params, **output_params} - logger.info(f"Quantization parameters: {quant_params}") - weight_params = self._extract_weight_parameters(weight_dq_node) - if not weight_params: - continue - bias_params = self._extract_bias_parameters(bias_dq_node) - if bias_dq_node and not bias_params: - continue - fused_node = self._create_fused_node( - graph, - quantize_node, - quant_params, - weight_params, - bias_params, - quantized_target, - ) - logger.info(f"Created fused {op_name} node: {fused_node}") + node.replace_all_uses_with(cortex_m_linear) + graph_module.graph.erase_node(node) - quantize_node.replace_all_uses_with(fused_node) - self._mark_for_cleanup( - [ - quantize_node, - fc_node, - input_dq_node, - weight_dq_node, - bias_dq_node, - ] - ) - fusion_count += 1 - logger.info(f"✅ Successfully fused {op_name} operation {fusion_count}") - except Exception as e: - logger.error( - f"Failed to fuse {op_name} pattern for {fc_node.name}: {e}" - ) - continue - self._cleanup_nodes(graph) - return fusion_count + modified = True + + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, modified) diff --git a/backends/cortex_m/passes/quantized_op_fusion_pass.py b/backends/cortex_m/passes/quantized_op_fusion_pass.py index 888155dcfd0..df35c8d626a 100644 --- a/backends/cortex_m/passes/quantized_op_fusion_pass.py +++ b/backends/cortex_m/passes/quantized_op_fusion_pass.py @@ -5,23 +5,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import logging -from typing import Set - -import executorch.backends.cortex_m.ops.operators # noqa -import torch +from typing import Dict from executorch.backends.cortex_m.passes.passes_utils import ( - extract_scalar_value, quantize_multiplier_aot, SHIFT_INT8, ) -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass -from torch.fx.passes.infra.pass_manager import PassResult -logger = logging.getLogger("quant_op_fusion_pass") -logger.setLevel(logging.INFO) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue +from torch.fx.node import Argument class QuantizedOpFusionPass(ExportPass): @@ -35,234 +29,58 @@ class QuantizedOpFusionPass(ExportPass): Supports multiple binary operations with backward compatibility for add. 
""" - # Generic operation mapping - SUPPORTED_OPS_MAPPING = { - exir_ops.edge.aten.add.Tensor: exir_ops.edge.cortex_m.quantized_add.default, - # Future binary ops to be added here: - } + def _get_add_replacement(self, args, meta): - def __init__(self): - super().__init__() + # Extract values + scale1 = meta["input_qparams"][0].scale + zero_point1 = meta["input_qparams"][0].zp + scale2 = meta["input_qparams"][1].scale + zero_point2 = meta["input_qparams"][1].zp + output_scale = meta["output_qparams"][0].scale + output_zero_point = meta["output_qparams"][0].zp - def _get_dequant_targets(self) -> Set: - """Support both decomposed and cortex_m dequant targets for flexible pass ordering.""" - return { - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - exir_ops.edge.cortex_m.dequantize_per_tensor.default, - } + # AoT COMPUTATION: Calculate multipliers and shifts + max_scale_2x = 2 * max(scale1, scale2) - def _get_quant_targets(self) -> Set: - """Support both decomposed and cortex_m quant targets for flexible pass ordering.""" - return { - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.cortex_m.quantize_per_tensor.default, - } - - def _is_supported_binary_op(self, node: torch.fx.Node) -> bool: - """Check if node is a supported binary operation.""" - is_supported = ( - node.op == "call_function" and node.target in self.SUPPORTED_OPS_MAPPING + input1_mult, input1_shift = quantize_multiplier_aot(scale1 / max_scale_2x) + input2_mult, input2_shift = quantize_multiplier_aot(scale2 / max_scale_2x) + output_mult, output_shift = quantize_multiplier_aot( + max_scale_2x / (output_scale * (1 << SHIFT_INT8)) ) - if not is_supported: - return False - - shape1 = node.args[0].meta["val"].shape - shape2 = node.args[1].meta["val"].shape - is_broadcast = shape1 != shape2 - return not is_broadcast - def _is_dequant_node(self, node: torch.fx.Node) -> bool: - """Check if node is a dequantize operation.""" - return ( - hasattr(node, "op") - and node.op == "call_function" - and node.target in self._get_dequant_targets() + args = ( + args[0], + zero_point1, + input1_mult, + input1_shift, + args[1], + zero_point2, + input2_mult, + input2_shift, + output_zero_point, + output_mult, + output_shift, ) - def _is_quant_node(self, node: torch.fx.Node) -> bool: - """Check if node is a quantize operation.""" - return ( - hasattr(node, "op") - and node.op == "call_function" - and node.target in self._get_quant_targets() - ) + return exir_ops.edge.cortex_m.quantized_add.default, args - def _transfer_metadata( + def call_operator( self, - new_node: torch.fx.Node, - source_node: torch.fx.Node, - pass_name: str = "QuantizedOpFusionPass", - ) -> None: - """Metadata transfer with proper provenance tracking.""" - if hasattr(source_node, "meta") and source_node.meta: - new_node.meta = source_node.meta.copy() - - if "from_node" in new_node.meta: - from_node_list = new_node.meta.get("from_node", []).copy() - from_node_list.append( - {"source": source_node.name, "pass": pass_name, "op": "fuse"} - ) - new_node.meta["from_node"] = from_node_list - - # Copy essential fields - for field in ["tensor_meta", "stack_trace"]: - if field in source_node.meta: - new_node.meta[field] = source_node.meta[field] - - def _normalize_to_cortex_m_targets(self, graph_module: torch.fx.GraphModule) -> int: - """Convert decomposed targets to cortex_m equivalents for consistent handling.""" - target_mapping = { - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 
exir_ops.edge.cortex_m.dequantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: exir_ops.edge.cortex_m.quantize_per_tensor.default, - } - - normalization_count = 0 - for node in list(graph_module.graph.nodes): - if node.op == "call_function" and node.target in target_mapping: - logger.info(f"Normalizing {node.target} to cortex_m equivalent") - node.target = target_mapping[node.target] - normalization_count += 1 - - return normalization_count - - def _fuse_quantized_binary_patterns( - self, graph_module: torch.fx.GraphModule - ) -> int: - """Generic fusion for quantized binary operation patterns.""" - fusion_count = 0 - nodes_to_erase = [] - - for node in list(graph_module.graph.nodes): - if not self._is_quant_node(node): - continue - - quantize_node = node - if not quantize_node.args: - continue - - binary_op_node = quantize_node.args[0] - if not self._is_supported_binary_op(binary_op_node): - continue - - if len(binary_op_node.args) < 2: - continue - - dequant_node1, dequant_node2 = binary_op_node.args[:2] - if not ( - self._is_dequant_node(dequant_node1) - and self._is_dequant_node(dequant_node2) - ): - continue - - # Get the target quantized operation - quantized_target = self.SUPPORTED_OPS_MAPPING[binary_op_node.target] - # Extract op name (e.g., 'Tensor' -> 'add') - op_name = str(binary_op_node.target).split(".")[-1] - logger.info(f"✅ Found complete cortex_m Q/DQ + {op_name} pattern!") - - try: - # Extract values - int8_tensor1, scale1, zero_point1 = dequant_node1.args[:3] - int8_tensor2, scale2, zero_point2 = dequant_node2.args[:3] - output_scale, output_zero_point = quantize_node.args[1:3] - - # Convert to Python floats - scale1_val = extract_scalar_value(scale1) - scale2_val = extract_scalar_value(scale2) - output_scale_val = extract_scalar_value(output_scale) - zp1_val = int(extract_scalar_value(zero_point1)) - zp2_val = int(extract_scalar_value(zero_point2)) - output_zp_val = int(extract_scalar_value(output_zero_point)) - - max_scale_2x = 2 * max(scale1_val, scale2_val) - # AoT COMPUTATION: Calculate multipliers and shifts - - input1_mult, input1_shift = quantize_multiplier_aot( - scale1_val / max_scale_2x - ) - input2_mult, input2_shift = quantize_multiplier_aot( - scale2_val / max_scale_2x - ) - output_mult, output_shift = quantize_multiplier_aot( - max_scale_2x / (output_scale_val * (1 << SHIFT_INT8)) - ) - - logger.info("AoT computed parameters:") - logger.info(f" Input1: mult={input1_mult}, shift={input1_shift}") - logger.info(f" Input2: mult={input2_mult}, shift={input2_shift}") - logger.info(f" Output: mult={output_mult}, shift={output_shift}") - - with graph_module.graph.inserting_after(quantize_node): - fused = graph_module.graph.create_node( - "call_function", - target=quantized_target, - args=( - int8_tensor1, - zp1_val, - input1_mult, - input1_shift, - int8_tensor2, - zp2_val, - input2_mult, - input2_shift, - output_zp_val, - output_mult, - output_shift, - ), - kwargs={}, - ) - - # metadata transfer - self._transfer_metadata(fused, quantize_node) - - logger.info(f"✅ Created fused quantized_{op_name} node: {fused}") - - # Replace all uses - quantize_node.replace_all_uses_with(fused) - binary_op_node.replace_all_uses_with(fused) - dequant_node1.replace_all_uses_with(fused) - dequant_node2.replace_all_uses_with(fused) - - nodes_to_erase.extend( - [quantize_node, binary_op_node, dequant_node1, dequant_node2] - ) - fusion_count += 1 - logger.info(f"Pattern fused, total so far: {fusion_count}") - - except Exception as e: - 
logger.info(f"❌ Error during AoT computation: {e}") - logger.info(" Skipping fusion for this pattern") - continue - - for old_node in reversed(nodes_to_erase): - if old_node in graph_module.graph.nodes and len(old_node.users) == 0: - logger.info(f"🗑️ Erasing node: {old_node}") - graph_module.graph.erase_node(old_node) - - return fusion_count - - def call(self, graph_module: torch.fx.GraphModule): - logger.info("QuantizedOpFusionPass.call() started") - - # Normalize targets for flexible pass ordering - normalization_count = self._normalize_to_cortex_m_targets(graph_module) - - # Generic fusion for supported binary operations - fusion_count = self._fuse_quantized_binary_patterns(graph_module) - - total_changes = normalization_count + fusion_count - logger.info(f"Total changes: {total_changes}") - - if total_changes > 0: - graph_module.graph.eliminate_dead_code() - graph_module.graph.lint() - graph_module.recompile() - - logger.debug("=== AFTER FUSION: All nodes in the graph ===") - for i, node in enumerate(graph_module.graph.nodes): - logger.debug(f"Node {i}: op={node.op}, target={node.target}") - if "quantized_" in str(node.target) and "add" in str(node.target): - logger.debug(" ⭐ FOUND QUANTIZED BINARY OP NODE! ⭐") - logger.debug("=== END DEBUG ===") - - return PassResult(graph_module, total_changes > 0) + op: EdgeOpOverload, + args: tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if ( + meta.data.get("input_qparams", {}) == {} + or meta.data.get("output_qparams", {}) == {} + ): + return super().call_operator(op, args, {}, meta) + + match op: + case exir_ops.edge.aten.add.Tensor: + op, args = self._get_add_replacement(args, meta) + case _: + pass + + return super().call_operator(op, args, {}, meta) diff --git a/backends/cortex_m/test/ops/test_add.py b/backends/cortex_m/test/ops/test_add.py index 4389b463076..458d5361347 100644 --- a/backends/cortex_m/test/ops/test_add.py +++ b/backends/cortex_m/test/ops/test_add.py @@ -59,17 +59,6 @@ class CortexMTensorAdd(Model): } -class CortexMTensorAddBroadcast(Model): - # TODO: Quantize and accelerate broadcasted adds - ops_before_transforms = { - "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, - } - - ops_after_transforms = { - "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, - } - - class CortexMAlphaAdd(ModelAlpha): ops_before_transforms = { "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, @@ -126,15 +115,15 @@ class CortexMAlphaAdd(ModelAlpha): (torch.rand(2, 2) * 10, torch.rand(2, 2)), ), "broadcast_1": McuTestCase( - CortexMTensorAddBroadcast(), + CortexMTensorAdd(), (torch.ones(1), torch.ones(2, 2, 2, 2)), ), "broadcast_2": McuTestCase( - CortexMTensorAddBroadcast(), + CortexMTensorAdd(), (torch.ones((2, 1, 1, 1)), torch.ones(1)), ), "broadcast_3": McuTestCase( - CortexMTensorAddBroadcast(), + CortexMTensorAdd(), ( ramp_tensor(-2, 2, (2, 1, 2, 1)), ramp_tensor(-5, 5, (1, 2, 1, 2)), @@ -183,6 +172,18 @@ def test_dialect_add(test_case): "'float' object has not attribute 'fake_mode' - scalar only ops not supported.", AttributeError, ), + "broadcast_1": ( + " assert failed (input1.sizes() == input2.sizes()): Input1 and Input2 must have the same sizes.", + RuntimeError, + ), + "broadcast_2": ( + " assert failed (input1.sizes() == input2.sizes()): Input1 and Input2 must have the same sizes.", + RuntimeError, + ), + "broadcast_3": ( + " assert failed (input1.sizes() == input2.sizes()): Input1 and Input2 must have the same sizes.", + RuntimeError, + ), "alpha": ( "Expecting kwargs 
for aten op IR to be empty - alpha arg not supported.", AssertionError, diff --git a/backends/cortex_m/test/ops/test_linear.py b/backends/cortex_m/test/ops/test_linear.py index 4ab5ca99f15..e81daa7e83e 100644 --- a/backends/cortex_m/test/ops/test_linear.py +++ b/backends/cortex_m/test/ops/test_linear.py @@ -4,8 +4,8 @@ # LICENSE file in the root directory of this source tree. -import pytest import torch +from executorch.backends.arm.test.common import parametrize from executorch.backends.cortex_m.test.tester import ( CortexMTester, McuTestCase, @@ -13,12 +13,9 @@ ) -class CortexMMm(torch.nn.Module): - def forward(self, x, y): - return torch.mm(x, y) - +class CortexMLinear(torch.nn.Module): ops_before_transforms = { - "executorch_exir_dialects_edge__ops_aten_mm_default": 1, + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, } @@ -29,32 +26,45 @@ def forward(self, x, y): "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, } + def __init__(self, *args, **kwargs): + super().__init__() + self.linear = torch.nn.Linear(*args, bias=False) + self.linear.weight.data.fill_(1.0) + + def forward(self, x): + return self.linear(x) -class CortexMBmm(torch.nn.Module): - def forward(self, x, y): - return torch.bmm(x, y) +class CortexMLinearX3(torch.nn.Module): ops_before_transforms = { - "executorch_exir_dialects_edge__ops_aten_bmm_default": 1, - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_aten_linear_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 4, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 7, } ops_after_transforms = { - "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 3, "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, } + def __init__(self, *args, **kwargs): + super().__init__() + self.linear = torch.nn.Linear(*args, bias=False) + self.linear.weight.data.fill_(1.0) + + def forward(self, x): + x = self.linear(x) + x = self.linear(x) + x = self.linear(x) + return x -class CortexMAddmm(torch.nn.Module): - def forward(self, x, y, z, alpha=None, beta=None): - return torch.addmm(beta, x, alpha, y, z) +class CortexMLinearBias(torch.nn.Module): ops_before_transforms = { - "executorch_exir_dialects_edge__ops_aten_addmm_default": 1, + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 4, } ops_after_transforms = { @@ -63,90 +73,23 @@ def forward(self, x, y, z, alpha=None, beta=None): "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, } - -class CortexMAt(CortexMMm): - def forward(self, x, y): - return x @ y - - -class CortexMMatmul(CortexMMm): - def forward(self, x, y): - return torch.matmul(x, y) - - 
-class CortexMLinear(CortexMMatmul): - def __init__(self, *args, **kwargs): - super().__init__() - self.linear = torch.nn.Linear(*args, bias=False) - - def forward(self, x): - return self.linear(x) - - -class CortexMLinearBias(CortexMAddmm): def __init__(self, *args, **kwargs): super().__init__() self.linear = torch.nn.Linear(*args, bias=True) self.relu = torch.nn.ReLU() def forward(self, x): - return self.relu(self.linear(x)) + return self.linear(x) test_cases = { - "mm": McuTestCase( - model=CortexMMm(), - example_inputs=( - ramp_tensor(0, 10, (1, 16)), - ramp_tensor(0, 10, (16, 16)), - ), - ), - "bmm": McuTestCase( - model=CortexMBmm(), - example_inputs=( - ramp_tensor(0, 10, (1, 16, 16)), - ramp_tensor(0, 10, (1, 16, 16)), - ), - ), - "addmm": McuTestCase( - model=CortexMAddmm(), - example_inputs=( - ramp_tensor(0, 10, (1, 16)), - ramp_tensor(0, 10, (16, 16)), - ramp_tensor(0, 10, (16, 16)), - 2, - 4, - ), - ), - "addmm_scalars": McuTestCase( - model=CortexMAddmm(), - example_inputs=( - ramp_tensor(0, 10, (1, 16)), - ramp_tensor(0, 10, (16, 16)), - ramp_tensor(0, 10, (16, 16)), - ), - ), - "@-operator": McuTestCase( - model=CortexMAt(), - example_inputs=( - ramp_tensor(0, 10, (1, 16)), - ramp_tensor(0, 10, (16, 16)), - ), - ), - "matmul": McuTestCase( - model=CortexMMatmul(), - example_inputs=( - ramp_tensor(0, 10, (1, 16)), - ramp_tensor(0, 10, (16, 16)), - ), - ), "linear_rank1": McuTestCase( - model=CortexMLinear(2, 3), - example_inputs=(ramp_tensor(-1, 1, (2,)),), + model=CortexMLinear(1, 2), + example_inputs=(torch.Tensor([1]),), ), "linear_rank2_pos": McuTestCase( - model=CortexMLinear(8, 3), - example_inputs=(ramp_tensor(0, 10, (2, 8)),), + model=CortexMLinear(1, 2), + example_inputs=(ramp_tensor(-1, 1, (1, 1)),), ), "linear_rank3_neg": McuTestCase( model=CortexMLinear(5, 3), @@ -164,22 +107,24 @@ def forward(self, x): model=CortexMLinearBias(61, 37), example_inputs=(ramp_tensor(0, 10, (8, 61)),), ), + "linear_x3": McuTestCase( + model=CortexMLinearX3(4, 4), + example_inputs=(ramp_tensor(0, 10, (2, 4)),), + ), } -@pytest.mark.skip( - reason="Skipping until the quantized_linear_fusion_pass is updated to work with non decomposed linear ops." -) +@parametrize("test_case", test_cases) def test_dialect_linear(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) tester.test_dialect( - test_case.model.ops_before_transforms, test_case.model.ops_after_transforms + test_case.model.ops_before_transforms, + test_case.model.ops_after_transforms, + qtol=1, ) -@pytest.mark.skip( - reason="Skipping until the quantized_linear_fusion_pass is updated to work with non decomposed linear ops." -) +@parametrize("test_case", test_cases) def test_implementation_linear(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) - tester.test_implementation() + tester.test_implementation(qtol=1)
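
The rewrite documented in QuantizedLinearFusionPass (and mirrored by the functional quantized_linear fallback) can be checked in isolation. The standalone sketch below is illustrative only, not part of the test suite; the shapes, offsets and helper names are made up. It verifies that the naive offset formulation sum_j((xj + a)(wij + b)) + ci matches the kernel-sum formulation that the C++ kernel consumes at runtime through cmsis_nn_context::buf.

import torch


def naive_accumulator(x, w, bias, input_offset, weight_offset):
    # Direct evaluation of sum_j((x_j + a) * (w_ij + b)) + c_i in int32.
    x32 = x.to(torch.int32) + input_offset
    w32 = w.to(torch.int32) + weight_offset
    acc = torch.mm(x32, w32.T)
    return acc if bias is None else acc + bias


def kernel_sum_accumulator(x, w, bias, input_offset, weight_offset):
    # sum_j(x_j * w_ij) + sum_j(x_j) * b + kernel_sum, where
    # kernel_sum = a * sum_j(w_ij + b) + c_i is precomputed ahead of time.
    kernel_sum = (w.to(torch.int32) + weight_offset).sum(dim=1) * input_offset
    if bias is not None:
        kernel_sum = kernel_sum + bias
    acc = torch.mm(x.to(torch.int32), w.to(torch.int32).T)
    acc = acc + x.to(torch.int32).sum(dim=-1, keepdim=True) * weight_offset
    return acc + kernel_sum


x = torch.randint(-128, 128, (4, 16), dtype=torch.int8)
w = torch.randint(-128, 128, (8, 16), dtype=torch.int8)
bias = torch.randint(-1000, 1000, (8,), dtype=torch.int32)
input_offset, weight_offset = 7, -3  # i.e. the negated zero points
assert torch.equal(
    naive_accumulator(x, w, bias, input_offset, weight_offset),
    kernel_sum_accumulator(x, w, bias, input_offset, weight_offset),
)

This identity is also why the pass hands bias=None to the fused node: the bias term only survives inside the precomputed kernel_sum.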
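
Both fusion passes resolve requantization ahead of time with quantize_multiplier_aot, which is why the removed per-channel validation expected multipliers in the range [2^30, 2^31). As a rough mental model only (the real helpers live in passes_utils and in CMSIS-NN's arm_nn_requantize, and their rounding details may differ), a float rescale factor is split into a Q31 multiplier plus a power-of-two shift that the kernel applies with integer multiplies and shifts. A minimal sketch, assuming the effective scale is below 1 so the net shift is a right shift:

import math

Q31_ONE = 1 << 31


def encode_multiplier(scale: float) -> tuple[int, int]:
    # Split scale into (multiplier, shift) so that
    # scale ~= multiplier / 2**31 * 2**shift, with multiplier in [2**30, 2**31).
    mantissa, shift = math.frexp(scale)  # scale = mantissa * 2**shift, 0.5 <= mantissa < 1
    multiplier = round(mantissa * Q31_ONE)
    if multiplier == Q31_ONE:            # rounding pushed the mantissa up to 1.0
        multiplier //= 2
        shift += 1
    return multiplier, shift


def apply_multiplier(acc: int, multiplier: int, shift: int) -> int:
    # Requantize an int32 accumulator with round-to-nearest; assumes 31 - shift > 0.
    total_shift = 31 - shift
    return (acc * multiplier + (1 << (total_shift - 1))) >> total_shift


mult, shift = encode_multiplier(0.0123)
print(4096 * 0.0123, apply_multiplier(4096, mult, shift))  # ~50.38 vs 50

The fused add derives its three multiplier/shift pairs the same way, from scale1 / max_scale_2x, scale2 / max_scale_2x and max_scale_2x / (output_scale * (1 << SHIFT_INT8)), so no floating-point arithmetic is left for the kernel at runtime.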