diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py
index a8f0c3fe14d..81b415363ea 100644
--- a/backends/arm/operators/op_add.py
+++ b/backends/arm/operators/op_add.py
@@ -64,12 +64,20 @@ def define_node(
             rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
                 tosa_graph, inputs, node, self.tosa_spec
             )
+        elif inputs[0].dtype == ts.DType.INT16:
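+            # INT16 inputs are also accumulated in INT32; the int16-specific
+            # max-scale helper rescales them with extra headroom (see quant_utils).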
+            rescaled_inputs, scale_back = (
+                tqutils.insert_rescale_ops_int16_to_int32_maxscale(
+                    tosa_graph, inputs, node, self.tosa_spec
+                )
+            )
         else:
-            # input[0].dtype == ts.DType.INT16 or ts.DType.INT32
-            # Non quantized input, natively support by TOSA.ADD
+            # input[0].dtype == ts.DType.INT32
+            # Non-quantized input, natively supported by TOSA.ADD
             rescaled_inputs = inputs
 
-        if output.dtype == ts.DType.INT8:
+        if output.dtype in [ts.DType.INT8, ts.DType.INT16]:
             broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
             add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
         else:
@@ -99,6 +107,15 @@ def define_node(
                 compute_rescale=False,
                 tosa_spec=self.tosa_spec,
             )  # type: ignore[possibly-undefined]
+        elif output.dtype == ts.DType.INT16:
+            tqutils.insert_rescale_op_to_int16(
+                tosa_graph,
+                add_output,
+                scale_back,
+                node,
+                compute_rescale=False,
+                tosa_spec=self.tosa_spec,
+            )  # type: ignore[possibly-undefined]
 
 
 @register_node_visitor
diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py
index bb690d89f59..19a3ba1a718 100644
--- a/backends/arm/test/ops/test_add.py
+++ b/backends/arm/test/ops/test_add.py
@@ -276,9 +276,6 @@ def test_add_tensor_16a8w_tosa_INT(test_data: input_t1):
 
 @common.parametrize("test_data", Add.test_data)
 @common.XfailIfNoCorstone300
-@pytest.mark.xfail(
-    reason="Vela compilation fails with 'Invalid arguments' for int16 add operations. See: https://github.com/pytorch/executorch/issues/13730"
-)
 def test_add_tensor_16a8w_u55_INT16(test_data: input_t1):
     """Test add operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
     per_channel_quantization = False
@@ -304,9 +301,6 @@ def test_add_tensor_16a8w_u55_INT16(test_data: input_t1):
 
 @common.parametrize("test_data", Add.test_data)
 @common.XfailIfNoCorstone320
-@pytest.mark.xfail(
-    reason="Vela compilation fails with 'Invalid arguments' for int16 add operations. See: https://github.com/pytorch/executorch/issues/13730"
-)
 def test_add_tensor_16a8w_u85_INT16(test_data: input_t1):
     """Test add operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
     per_channel_quantization = False
diff --git a/backends/arm/test/ops/test_to_copy.py b/backends/arm/test/ops/test_to_copy.py
index 42d136c52c1..1fdc4619131 100644
--- a/backends/arm/test/ops/test_to_copy.py
+++ b/backends/arm/test/ops/test_to_copy.py
@@ -192,20 +192,15 @@ def test_to_vgf_INT(test_data: Tuple):
     ),
 }
 
-redundant_xfails_FP = {
+redundant_xfails = {
     "rand_fp16_fp16": "FP16 is not supported",
     "rand_int8_int8": "Tracing graph with quantized input is not supported.",
     "rand_int16_int16": "Tracing graph with quantized input is not supported.",
 }
 
-redundant_xfails_INT = {
-    "rand_fp16_fp16": "FP16 is not supported",
-    "rand_int8_int8": "Tracing graph with quantized input is not supported.",
-}
-
 
 @common.parametrize(
-    "test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, xfails=redundant_xfails_FP
+    "test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, xfails=redundant_xfails
 )
 def test_to_tosa_FP_REDUNDANT_CAST(test_data: Tuple):
     test_tensor, new_dtype = test_data()
@@ -220,7 +215,7 @@ def test_to_tosa_FP_REDUNDANT_CAST(test_data: Tuple):
 
 
 @common.parametrize(
-    "test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, xfails=redundant_xfails_INT
+    "test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, xfails=redundant_xfails
 )
 def test_to_tosa_INT_REDUNDANT_CAST(test_data: Tuple):
     test_tensor, new_dtype = test_data()
diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl
index 00ec87f928e..093268edef7 100644
--- a/backends/arm/test/targets.bzl
+++ b/backends/arm/test/targets.bzl
@@ -25,6 +25,7 @@ def define_arm_tests():
         "ops/test_tanh.py",
         "ops/test_view.py",
         "ops/test_cos.py",
+        "ops/test_to_copy.py",
     ]
 
     # Quantization
diff --git a/backends/arm/tosa/quant_utils.py b/backends/arm/tosa/quant_utils.py
index 68ceec8d97c..562c77e30da 100644
--- a/backends/arm/tosa/quant_utils.py
+++ b/backends/arm/tosa/quant_utils.py
@@ -77,6 +77,65 @@ def insert_rescale_ops_to_int32_maxscale(
     return [rescaled_lhs, rescaled_rhs], back_scale
 
 
+def insert_rescale_ops_int16_to_int32_maxscale(
+    tosa_graph: Any, inputs: list[TosaArg], node: Node, tosa_spec=None
+) -> tuple[list[Any], float]:
+    """For ADD and SUB with int16 inputs, we rescale to int32 using a different
+    common scale (2 * max(lhs_scale, rhs_scale)) than in the other cases. We
+    multiply the left and right scales by 1 << 12, giving us extra precision
+    for the computation without overflowing.
+
+    Returns a list of the rescaled nodes and the scale factor used,
+    needed by insert_rescale_op_to_int16.
+    """
+
+    if len(inputs) > 2:
+        raise ValueError("More than two inputs not supported")
+
+    tensors = inputs.copy()
+    # Reshape tensors according to the TOSA dim order
+    for tensor in tensors:
+        dim_order = tensor.dim_order
+        tensor.shape = [tensor.shape[i] for i in dim_order]
+
+    input_qparams = get_input_qparams(node)
+    lhs_qparams, rhs_qparams = input_qparams.values()
+    lhs_scale = lhs_qparams.get_scale_per_tensor()
+    rhs_scale = rhs_qparams.get_scale_per_tensor()
+    # Common scale for the two operands
+    max_scale_2x = 2 * max(lhs_scale, rhs_scale)
+    SHIFT_INT16 = 12
+    # We are adding two int16 numbers. If the zero point is non-zero, the result
+    # lies in [-131070, 131070], so it needs 18 bits. With a 32-bit accumulator
+    # we can therefore shift left by 12 bits without overflowing. Since we also
+    # divide by 2 * max(lhs_scale, rhs_scale), the effective shift is at most 11.
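+    # Worked example with hypothetical scales: lhs_scale = 0.004 and
+    # rhs_scale = 0.002 give max_scale_2x = 0.008, so lhs_factor =
+    # 4096 * 0.004 / 0.008 = 2048 (= 1 << 11) and rhs_factor = 1024;
+    # a worst-case operand of 65535 * 2048 ~= 1.3e8 fits easily in int32.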
+    lhs_factor = (1 << SHIFT_INT16) * lhs_scale / max_scale_2x
+    rhs_factor = (1 << SHIFT_INT16) * rhs_scale / max_scale_2x
+    rescaled_lhs = build_rescale_to_int32(
+        tosa_graph,
+        tensors[0],
+        lhs_qparams.get_zp_per_tensor(),
+        lhs_factor,
+        tosa_spec=tosa_spec,
+    )
+    rescaled_rhs = build_rescale_to_int32(
+        tosa_graph,
+        tensors[1],
+        rhs_qparams.get_zp_per_tensor(),
+        rhs_factor,
+        tosa_spec=tosa_spec,
+    )
+    out_qparam = get_output_qparams(node)[0]
+    out_scale = out_qparam.get_scale_per_tensor()
+    back_scale = max_scale_2x / (out_scale * (1 << SHIFT_INT16))
+
+    return [rescaled_lhs, rescaled_rhs], back_scale
+
+
 def insert_rescale_ops_to_int32(
     tosa_graph: Any,
     inputs: list[TosaArg],