diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py
index 65140baf540..9d139c68242 100644
--- a/backends/arm/operators/op_mul.py
+++ b/backends/arm/operators/op_mul.py
@@ -34,6 +34,7 @@ class MulVisitor_INT(NodeVisitor):

     tosa_specs = [
         TosaSpecification.create_from_string("TOSA-1.0+INT"),
+        TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
     ]

     def define_node(
@@ -51,11 +52,11 @@ def define_node(
         validate_valid_dtype(
             self.target,
             [*inputs, output],
-            [ts.DType.INT8, ts.DType.INT32],
+            [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
             output.tosa_spec,
         )

-        if inputs[0].dtype == ts.DType.INT8:
+        if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16:
             input_A = inputs[0]
             input_B = inputs[1]
             input_qparams = get_input_qparams(node)
@@ -80,15 +81,15 @@ def define_node(
                 tosa_spec=self.tosa_spec,
             )
         else:
-            # input[0].dtype == ts.DType.INT32
+            # input[0].dtype == ts.DType.INT32 (non-quantized)
             # Non quantized input, natively support by TOSA.MUL
             input_A_rescaled, input_B_rescaled = inputs[0], inputs[1]

-        if output.dtype == ts.DType.INT8:
+        if output.dtype == ts.DType.INT8 or output.dtype == ts.DType.INT16:
             output_shape = tutils.tosa_shape(output.shape, output.dim_order)
             mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)
         else:
-            # output.dtype == ts.DType.INT32
+            # output.dtype == ts.DType.INT32 (non-quantized)
             mul_output = output

         # Do the INT32 Mul
@@ -110,6 +111,15 @@ def define_node(
             tqutils.insert_rescale_op_to_int8(
                 tosa_graph, mul_output, output_scale, node, self.tosa_spec
             )
+        elif output.dtype == ts.DType.INT16:
+            # Scale the INT32 product back to 16 bits
+            output_scale = (
+                input_A_qargs.get_scale_per_tensor()  # type: ignore[possibly-undefined]
+                * input_B_qargs.get_scale_per_tensor()  # type: ignore[possibly-undefined]
+            )
+            tqutils.insert_rescale_op_to_int16(
+                tosa_graph, mul_output, output_scale, node, tosa_spec=self.tosa_spec
+            )


 @register_node_visitor
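The arithmetic behind the new INT16 path is the same as the existing INT8 one: both quantized operands are widened to INT32, multiplied exactly, and the product is rescaled onto the output grid using the product of the input scales. A minimal standalone sketch of that bookkeeping (illustrative only: the helper below is hypothetical, and the real lowering expresses the final step as a TOSA RESCALE rather than float math):

    # Hypothetical model of the lowering's scale bookkeeping, not backend code.
    # q_out ~= (q_a * q_b) * (s_a * s_b) / s_out, clamped to the int16 range.
    def simulate_int16_mul(q_a: int, s_a: float, q_b: int, s_b: float, s_out: float) -> int:
        acc = q_a * q_b                            # exact product of raw quantized values
        q_out = round(acc * (s_a * s_b) / s_out)   # factor folded into the final RESCALE
        return max(-32768, min(32767, q_out))

    # 1.5 * 2.0 == 3.0; with every scale at 0.01 the integer math agrees:
    assert simulate_int16_mul(150, 0.01, 200, 0.01, 0.01) == 300  # 300 * 0.01 == 3.0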
diff --git a/backends/arm/test/ops/test_mul.py b/backends/arm/test/ops/test_mul.py
index acd6d86c788..b2db55d90fd 100644
--- a/backends/arm/test/ops/test_mul.py
+++ b/backends/arm/test/ops/test_mul.py
@@ -8,9 +8,14 @@

 from typing import Tuple

+import pytest
 import torch
+
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    TOSAQuantizer,
+)

-from executorch.backends.arm.test import common
+from executorch.backends.arm.test import common, conftest
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
     EthosU85PipelineINT,
@@ -18,6 +23,9 @@
     TosaPipelineINT,
     VgfPipeline,
 )
+from executorch.backends.arm.tosa.specification import TosaSpecification
+
+from executorch.backends.xnnpack.test.tester import Quantize

 input_t1 = Tuple[torch.Tensor, torch.Tensor]  # Input x
 aten_op = "torch.ops.aten.mul.Tensor"
@@ -284,3 +292,105 @@ def test_mul_tensor_vgf_INT_int32(test_data: torch.Tensor):
     )
     pipeline.pop_stage("check.quant_nodes")
     pipeline.run()
+
+
+def get_symmetric_a16w8_mul_quantizer(per_channel_quantization=False):
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
+    }
+
+    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+    quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
+    )
+
+    return Quantize(
+        quantizer,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+
+@common.parametrize("test_data", test_data_suite)
+@pytest.mark.xfail(
+    reason="missing int16 mul ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13947"
+)
+def test_mul_tensor_16a8w_tosa_INT(test_data: input_t1):
+    """Test mul operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = TosaPipelineINT[input_t1](
+        Mul(),
+        test_data(),
+        aten_op,
+        exir_op=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        tosa_extensions=["int16"],
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_mul_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone300
+@pytest.mark.xfail(
+    reason="Vela compilation fails with 'Invalid arguments' for int16 mul operations. See: https://github.com/pytorch/executorch/issues/13947"
+)
+def test_mul_tensor_16a8w_u55_INT16(test_data: input_t1):
+    """Test mul operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = EthosU55PipelineINT[input_t1](
+        Mul(),
+        test_data(),
+        aten_op,
+        exir_ops=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_mul_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone320
+@pytest.mark.xfail(
+    reason="Vela compilation fails with 'Invalid arguments' for int16 mul operations. See: https://github.com/pytorch/executorch/issues/13947"
+)
+def test_mul_tensor_16a8w_u85_INT16(test_data: input_t1):
+    """Test mul operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = EthosU85PipelineINT[input_t1](
+        Mul(),
+        test_data(),
+        aten_op,
+        exir_ops=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_mul_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl
index 405f1bbf081..b438e556cca 100644
--- a/backends/arm/test/targets.bzl
+++ b/backends/arm/test/targets.bzl
@@ -16,6 +16,7 @@ def define_arm_tests():
         "ops/test_add.py",
         "ops/test_avg_pool2d.py",
         "ops/test_linear.py",
+        "ops/test_mul.py",
         "ops/test_slice.py",
         "ops/test_sigmoid.py",
         "ops/test_tanh.py",
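For context on what the 16A8W tests exercise: symmetric 16-bit activation quantization keeps the zero point at 0 and maps the tensor's largest magnitude to 32767, so activations get roughly 256 times finer resolution than int8 while weights stay 8-bit. A rough sketch of that scheme (assumed from the configuration's name; this is not the quantizer's implementation):

    import torch

    # Sketch only: assumed symmetric per-tensor int16 scheme, zero point 0.
    def quantize_activation_int16(x: torch.Tensor):
        scale = x.abs().max().item() / 32767.0
        q = torch.clamp(torch.round(x / scale), -32768, 32767).to(torch.int16)
        return q, scale

    x = torch.tensor([-2.0, 0.5, 1.0])
    q, s = quantize_activation_int16(x)
    # q.float() * s reconstructs x to within one quantization step (~6.1e-5 here)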
diff --git a/backends/arm/tosa/quant_utils.py b/backends/arm/tosa/quant_utils.py
index 65b5977358e..86e8e5bad8b 100644
--- a/backends/arm/tosa/quant_utils.py
+++ b/backends/arm/tosa/quant_utils.py
@@ -5,7 +5,7 @@

 # pyre-unsafe

-# Utiliy functions for TOSA quantized lowerings
+# Utility functions for TOSA quantized lowerings

 import math

@@ -27,11 +27,11 @@ def insert_rescale_ops_to_int32_maxscale(
     tosa_graph: Any, inputs: list[TosaArg], node: Node, tosa_spec=None
 ) -> tuple[list[Any], float]:
     """For ADD and SUB, we rescale to int32 using a different common scale(2*max(left scale,right scale))
-    compared to all the other cases. We also multply the left and right scales by 1<<20 giving us extra precision
+    compared to all the other cases. We also multiply the left and right scales by 1<<20, giving us extra precision
     for the computation without overflowing.

     Returns a list of the rescaled nodes and the scale factor used,
-    needed by rescale_node_back_to_int8.
+    needed by insert_rescale_op_to_int8.
     """

     if len(inputs) > 2:
@@ -86,7 +86,7 @@
     The scales are adjusted using the smallest scale of all 'nodes'.

     Returns a list of the rescaled nodes and the scale factor used,
-    needed by rescale_node_back_to_int8.
+    needed by insert_rescale_op_to_int8.

     This functions is used in serialization to TOSA for target ops that are
     handled by the DQ/D folding pass, which stores the quantization parameters
@@ -134,7 +134,59 @@ def insert_rescale_op_to_int8(
     Parameters:
     node: The original node that is being handled by the rescales.
     last_tensor:the tosa tensor to rescale back.
-    scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_op_to_int32'
+    scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32'
+    compute_rescale: boolean indicating whether 'scale' still needs to be divided by the output quantization scale.
+    tosa_graph: the tosa_graph to manipulate.
+
+    This function is used in serialization to TOSA for target ops that are
+    handled by the DQ/D folding pass, which stores the quantization parameters
+    in the node meta dict.
+    """
+    _insert_rescale_op_to_dtype(
+        tosa_graph, last_tensor, scale, node, ts.DType.INT8, compute_rescale, tosa_spec
+    )
+
+
+def insert_rescale_op_to_int16(
+    tosa_graph: Any,
+    last_tensor: TosaArg,
+    scale: float,
+    node: Node,
+    compute_rescale=True,
+    tosa_spec=None,
+) -> None:
+    """Rescales the node back to int16, adding a suitable RESCALE op to 'tosa_graph'.
+    Parameters:
+    node: The original node that is being handled by the rescales.
+    last_tensor: the tosa tensor to rescale back.
+    scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32'
+    compute_rescale: boolean indicating whether 'scale' still needs to be divided by the output quantization scale.
+    tosa_graph: the tosa_graph to manipulate.
+
+    This function is used in serialization to TOSA for target ops that are
+    handled by the DQ/D folding pass, which stores the quantization parameters
+    in the node meta dict.
+    """
+    _insert_rescale_op_to_dtype(
+        tosa_graph, last_tensor, scale, node, ts.DType.INT16, compute_rescale, tosa_spec
+    )
+
+
+def _insert_rescale_op_to_dtype(
+    tosa_graph: Any,
+    last_tensor: TosaArg,
+    scale: float,
+    node: Node,
+    output_dtype: Any,
+    compute_rescale=True,
+    tosa_spec=None,
+) -> None:
+    """Common implementation for rescaling nodes back to a specific dtype.
+    Parameters:
+    node: The original node that is being handled by the rescales.
+    last_tensor: the tosa tensor to rescale back.
+    scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32'
+    output_dtype: The target dtype (ts.DType.INT8 or ts.DType.INT16)
     compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
     tosa_graph: the tosa_graph to manipulate.
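To pin down the compute_rescale flag described in these docstrings: 'scale' is the factor accumulated while rescaling the inputs up to INT32, and when the flag is set the final RESCALE presumably also divides by the output tensor's quantization scale (the else branch in the next hunk uses 'scale' unchanged). A hedged sketch of that assumed relationship, not the helper itself:

    # Hypothetical illustration of the assumed final rescale factor.
    def final_rescale_factor(scale: float, output_qscale: float, compute_rescale: bool = True) -> float:
        # e.g. for mul: scale == s_a * s_b, so the factor becomes s_a * s_b / s_out,
        # which maps the exact INT32 product onto the int8/int16 output grid.
        return scale / output_qscale if compute_rescale else scale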
@@ -156,20 +208,21 @@ def insert_rescale_op_to_int8(
     else:
         output_rescale_scale = scale

-    # Rescale Back to INT8
-    build_rescale_from_int32(
+    # Rescale back to the specified dtype
+    build_rescale_from_int32_to_dtype(
         tosa_graph,
         last_tensor,
         node.name,
         qargs_out.get_zp_per_tensor(),
         output_rescale_scale,
+        output_dtype,
         tosa_spec=tosa_spec,
     )


 # TOSA uses the RESCALE operation to scale between values with differing precision.
 # The RESCALE operator is defined using an integer multiply, add, and shift.
-# This utility function is for calculating the multier and shift given a scale.
+# This utility function is for calculating the multiplier and shift given a scale.
 # Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
 def compute_multiplier_and_shift(
     scales: list[float], scaleWidth: int = 32
@@ -214,7 +267,7 @@ def compute_multiplier_and_shift(
     return multipliers, shifts


-# For TOSA spec v1.0 RESCALE operator requires multipler, shifts, input_zp and output_zp to be
+# For TOSA spec v1.0 the RESCALE operator requires multipliers, shifts, input_zp and output_zp to be
 # const inputs. Create constant operators from the data already initialized.
 def create_const_ops_for_rescale(
     tosa_fb,
@@ -335,6 +388,47 @@ def build_rescale_from_int32(
     per_channel: bool = False,
     tosa_spec=None,
 ) -> None:
+    # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
+    # to the RESCALE op, see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
+    build_rescale_from_int32_to_dtype(
+        tosa_fb,
+        input_node,
+        output_name,
+        output_zp,
+        rescale_scale,
+        ts.DType.INT8,
+        is_scale32,
+        is_double_round,
+        per_channel,
+        tosa_spec,
+    )
+
+    return
+
+
+def build_rescale_from_int32_to_dtype(
+    tosa_fb: Any,
+    input_node: TosaArg,
+    output_name: str,
+    output_zp: int,
+    rescale_scale: float,
+    output_dtype: Any,
+    is_scale32: bool = True,
+    is_double_round: bool = False,
+    per_channel: bool = False,
+    tosa_spec=None,
+) -> None:
+    """Common implementation for rescaling from INT32 to a specific dtype (INT8 or INT16).
+
+    Parameters:
+    tosa_fb: The TOSA serializer
+    input_node: Input tensor (should be INT32)
+    output_name: Name for the output tensor
+    output_zp: Output zero point
+    rescale_scale: Rescaling factor
+    output_dtype: Target dtype (ts.DType.INT8 or ts.DType.INT16)
+    Other parameters: Standard rescale parameters
+    """
     # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
     # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
     build_rescale(
@@ -342,7 +436,7 @@ def build_rescale_from_int32(
         [rescale_scale],
         input_node,
         output_name=output_name,
-        output_type=ts.DType.INT8,
+        output_type=output_dtype,
         input_zp=[0],
         output_zp=[output_zp],
         rounding_mode=RoundingMode.SINGLE_ROUND,
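A closing note on the precision-scaling machinery the new code paths reuse: compute_multiplier_and_shift approximates a real-valued scale as an integer multiplier plus a right shift, so that scale ~= multiplier * 2**-shift. A rough standalone equivalent for the 32-bit case (illustrative only; this is not the helper in quant_utils.py, which operates on lists of scales):

    import math

    # Rough equivalent of the multiplier/shift decomposition, for intuition only.
    def multiplier_and_shift(scale: float, scale_width: int = 32):
        mantissa, exponent = math.frexp(scale)                   # scale == mantissa * 2**exponent
        multiplier = round(mantissa * (1 << (scale_width - 1)))  # 31-bit normalized mantissa
        shift = (scale_width - 1) - exponent                     # scale ~= multiplier * 2**-shift
        return multiplier, shift

    m, s = multiplier_and_shift(0.0001)
    assert abs(m * 2.0 ** -s - 0.0001) < 1e-9  # m == 1759218604, s == 44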