diff --git a/backends/cortex_m/ops/op_quantized_add.cpp b/backends/cortex_m/ops/op_quantized_add.cpp index b4bbfdaffce..f607977aa48 100644 --- a/backends/cortex_m/ops/op_quantized_add.cpp +++ b/backends/cortex_m/ops/op_quantized_add.cpp @@ -26,6 +26,8 @@ Tensor& quantized_add_out( const int64_t output_zero_point, const int64_t output_multiplier, const int64_t output_shift, + const int64_t activation_min, + const int64_t activation_max, Tensor& out) { // Validate tensor types and dim order bool channel_broadcast = is_channel_broadcast(input1_int8, input2_int8); @@ -69,8 +71,8 @@ Tensor& quantized_add_out( // Left shift to maximize precision const int32_t left_shift = 20; - const int32_t activation_min = std::numeric_limits<int8_t>::min(); - const int32_t activation_max = std::numeric_limits<int8_t>::max(); + const int32_t act_min = static_cast<int32_t>(activation_min); + const int32_t act_max = static_cast<int32_t>(activation_max); ET_LOG( Debug, @@ -121,8 +123,8 @@ Tensor& quantized_add_out( static_cast<int32_t>(out_zp), output_mult, output_shift_val, - activation_min, - activation_max, + act_min, + act_max, adds_per_loop); if (status != ARM_CMSIS_NN_SUCCESS) { diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index 623db8cd648..4de702b47a9 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -123,15 +123,16 @@ def dequantize_per_tensor_impl( "quantized_add(" "Tensor self, int self_zero_point, int self_multiplier, int self_shift, " "Tensor other, int other_zero_point, int other_multiplier, int other_shift, " - "int output_zero_point, int output_multiplier, int output_shift) -> Tensor" + "int output_zero_point, int output_multiplier, int output_shift, " + "int activation_min, int activation_max) -> Tensor" ) -# Define the operator schema with multipliers and shifts (11 args + out tensor) lib.define( "quantized_add.out(" "Tensor self, int self_zero_point, int self_multiplier, int self_shift, " "Tensor other, int other_zero_point, int 
other_multiplier, int other_shift, " "int output_zero_point, int output_multiplier, int output_shift, " + "int activation_min, int activation_max, " "*, Tensor(a!) out) -> Tensor(a!)" ) @@ -149,6 +150,8 @@ def quantized_add_meta( output_zero_point: int, output_multiplier: int, output_shift: int, + activation_min: int, + activation_max: int, ) -> torch.Tensor: assert self.shape == other.shape or is_channel_broadcast(self, other), ( "Cortex-M quantized_add: broadcasting is not yet supported except for channel dim — " @@ -174,6 +177,8 @@ def quantized_add_impl( output_zero_point: int, output_multiplier: int, output_shift: int, + activation_min: int, + activation_max: int, ) -> torch.Tensor: assert self.shape == other.shape or is_channel_broadcast(self, other), ( "Cortex-M quantized_add: broadcasting is not yet supported except for channel dim — " @@ -187,7 +192,9 @@ def quantized_add_impl( result_fp = self_fp + other_fp result_quantized = requantize_cmsis(result_fp, output_multiplier, output_shift) - result = torch.clamp(result_quantized + output_zero_point, -128, 127).to(torch.int8) + result = torch.clamp( + result_quantized + output_zero_point, activation_min, activation_max + ).to(torch.int8) return result diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml index 0f8f764c1f3..e0ebbfab868 100644 --- a/backends/cortex_m/ops/operators.yaml +++ b/backends/cortex_m/ops/operators.yaml @@ -17,7 +17,7 @@ - arg_meta: null kernel_name: cortex_m::dequantize_per_tensor_out -- func: cortex_m::quantized_add.out(Tensor self, int self_zero_point, int self_multiplier, int self_shift, Tensor other, int other_zero_point, int other_multiplier, int other_shift, int output_zero_point, int output_multiplier, int output_shift, *, Tensor(a!) out) -> Tensor(a!) 
+- func: cortex_m::quantized_add.out(Tensor self, int self_zero_point, int self_multiplier, int self_shift, Tensor other, int other_zero_point, int other_multiplier, int other_shift, int output_zero_point, int output_multiplier, int output_shift, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null diff --git a/backends/cortex_m/passes/activation_fusion_pass.py b/backends/cortex_m/passes/activation_fusion_pass.py index 864f9e47ec8..a53c065aaa4 100644 --- a/backends/cortex_m/passes/activation_fusion_pass.py +++ b/backends/cortex_m/passes/activation_fusion_pass.py @@ -40,6 +40,7 @@ class ActivationFusionPass(ExportPass): FUSE_OPS = { exir_ops.edge.aten.linear.default, exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.add.Tensor, } def _get_validated_qparams(self, node, input_node): @@ -85,7 +86,7 @@ def _get_validated_qparams(self, node, input_node): else qmax ) case _: - raise RuntimeError("Unexpected target {node.target}.") + raise RuntimeError(f"Unexpected target {node.target}.") # If the minimal quantized value is larger than the qmin, it means that the quantized range contains # invalid values [qmin, ..., quantized_min_val-1], indicating bad quantization parameters. 
diff --git a/backends/cortex_m/passes/quantized_op_fusion_pass.py b/backends/cortex_m/passes/quantized_op_fusion_pass.py index e56a266f73e..0eb696fcc03 100644 --- a/backends/cortex_m/passes/quantized_op_fusion_pass.py +++ b/backends/cortex_m/passes/quantized_op_fusion_pass.py @@ -61,6 +61,9 @@ def _get_add_replacement(self, args, meta): max_scale_2x / (output_scale * (1 << SHIFT_INT8)) ) + activation_min = meta["output_qparams"][0].qmin + activation_max = meta["output_qparams"][0].qmax + args = ( args[0], zero_point1, @@ -73,6 +76,8 @@ def _get_add_replacement(self, args, meta): output_zero_point, output_mult, output_shift, + activation_min, + activation_max, ) return exir_ops.edge.cortex_m.quantized_add.default, args diff --git a/backends/cortex_m/quantizer/quantizer_support.py b/backends/cortex_m/quantizer/quantizer_support.py index 348e7bf87f1..2cf0483f74b 100644 --- a/backends/cortex_m/quantizer/quantizer_support.py +++ b/backends/cortex_m/quantizer/quantizer_support.py @@ -17,6 +17,12 @@ BINARY_OP_PATTERNS = { (torch.ops.aten.add.Tensor,): CortexMAddMulCheck, + (torch.ops.aten.add.Tensor, torch.ops.aten.relu.default): CortexMAddMulCheck, + (torch.ops.aten.add.Tensor, torch.ops.aten.relu_.default): CortexMAddMulCheck, + (torch.ops.aten.add.Tensor, torch.ops.aten.hardtanh.default): CortexMAddMulCheck, + (torch.ops.aten.add.Tensor, torch.ops.aten.hardtanh_.default): CortexMAddMulCheck, + (torch.ops.aten.add.Tensor, torch.ops.aten.clamp.default): CortexMAddMulCheck, + (torch.ops.aten.add.Tensor, torch.ops.aten.clamp_.default): CortexMAddMulCheck, (torch.ops.aten.add_.Tensor,): CortexMAddMulCheck, (torch.ops.aten.mul.Tensor,): CortexMAddMulCheck, (torch.ops.aten.mul_.Tensor,): CortexMAddMulCheck, diff --git a/backends/cortex_m/test/models/test_nn_modules.py b/backends/cortex_m/test/models/test_nn_modules.py index 77aa07e04c9..4a92fd578ff 100644 --- a/backends/cortex_m/test/models/test_nn_modules.py +++ b/backends/cortex_m/test/models/test_nn_modules.py @@ -188,9 
+188,7 @@ def forward(self, x): ), } -xfails: dict[str, xfail_type] = { - "conv_add_relu": "Activation fusion does not support relu after add", -} +xfails: dict[str, xfail_type] = {} @parametrize("test_case", test_cases, xfails=xfails, strict=False) diff --git a/backends/cortex_m/test/ops/test_add.py b/backends/cortex_m/test/ops/test_add.py index 5adc77ce4aa..43a76149670 100644 --- a/backends/cortex_m/test/ops/test_add.py +++ b/backends/cortex_m/test/ops/test_add.py @@ -73,6 +73,50 @@ class CortexMAlphaAdd(ModelAlpha): } +class CortexMAddReLU(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, x, y): + return self.relu(x + y) + + +class CortexMAddHardtanh(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, min_val=-0.5, 
max_val=0.5): + super().__init__() + self.act = torch.nn.Hardtanh(min_val=min_val, max_val=max_val) + + def forward(self, x, y): + return self.act(x + y) + + test_cases = { "self_rank_1": McuTestCase( CortexMSelfAdd(), @@ -149,6 +193,34 @@ class CortexMAlphaAdd(ModelAlpha): ramp_tensor(-20, 20, (4, 5)), ), ), + "add_relu": McuTestCase( + CortexMAddReLU(), + ( + ramp_tensor(-5, 5, (2, 4)), + ramp_tensor(-3, 3, (2, 4)), + ), + ), + "add_relu_channels_last": McuTestCase( + CortexMAddReLU(), + ( + ramp_tensor(-5, 5, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ramp_tensor(-3, 3, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "add_hardtanh": McuTestCase( + CortexMAddHardtanh(min_val=-0.5, max_val=0.5), + ( + ramp_tensor(-2, 2, (2, 4)), + ramp_tensor(-1, 1, (2, 4)), + ), + ), + "add_hardtanh_channels_last": McuTestCase( + CortexMAddHardtanh(min_val=-1.0, max_val=1.0), + ( + ramp_tensor(-3, 3, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ramp_tensor(-2, 2, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), }