17 changes: 16 additions & 1 deletion backends/arm/operators/op_add.py
@@ -64,12 +64,18 @@ def define_node(
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
tosa_graph, inputs, node, self.tosa_spec
)
elif inputs[0].dtype == ts.DType.INT16:
rescaled_inputs, scale_back = (
tqutils.insert_rescale_ops_int16_to_int32_maxscale(
tosa_graph, inputs, node, self.tosa_spec
)
)
else:
# input[0].dtype == ts.DType.INT32
# Non-quantized input, natively supported by TOSA.ADD
rescaled_inputs = inputs

if output.dtype == ts.DType.INT8:
if output.dtype in [ts.DType.INT8, ts.DType.INT16]:
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
else:
@@ -99,6 +105,15 @@ def define_node(
compute_rescale=False,
tosa_spec=self.tosa_spec,
) # type: ignore[possibly-undefined]
elif output.dtype == ts.DType.INT16:
tqutils.insert_rescale_op_to_int16(
tosa_graph,
add_output,
scale_back,
node,
compute_rescale=False,
tosa_spec=self.tosa_spec,
) # type: ignore[possibly-undefined]


@register_node_visitor
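To make the op_add.py hunk above concrete, here is a minimal, standalone sketch (illustrative only, not backend code) of the data flow it lowers for int16 ADD: both int16 operands are rescaled into a shared int32 domain (what `insert_rescale_ops_int16_to_int32_maxscale` emits), added in int32, then rescaled back to int16 (what `insert_rescale_op_to_int16` emits). The scales and tensor values are assumed, and the zero-points-of-0 simplification (symmetric int16 quantization) is mine.

```python
import numpy as np

SHIFT_INT16 = 12  # same extra-precision shift as in quant_utils.py

def simulate_int16_add(lhs, rhs, lhs_scale, rhs_scale, out_scale):
    """Emulate the lowered TOSA graph: rescale to int32, ADD, rescale back."""
    max_scale_2x = 2 * max(lhs_scale, rhs_scale)
    lhs_factor = (1 << SHIFT_INT16) * lhs_scale / max_scale_2x
    rhs_factor = (1 << SHIFT_INT16) * rhs_scale / max_scale_2x
    # Rescale both operands into the shared int32 domain (zero points taken
    # as 0 here, matching symmetric int16 quantization).
    acc = np.round(lhs * lhs_factor) + np.round(rhs * rhs_factor)
    # Rescale the int32 accumulator back to the int16 output domain.
    back_scale = max_scale_2x / (out_scale * (1 << SHIFT_INT16))
    return np.clip(np.round(acc * back_scale), -32768, 32767).astype(np.int16)

lhs = np.array([1000, -2000], dtype=np.int16)
rhs = np.array([3000, 4000], dtype=np.int16)
# Float reference: (1000*0.001 + 3000*0.002) / 0.004 = 1750, etc.
print(simulate_int16_add(lhs, rhs, 0.001, 0.002, 0.004))  # [1750 1500]
```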
6 changes: 0 additions & 6 deletions backends/arm/test/ops/test_add.py
@@ -276,9 +276,6 @@ def test_add_tensor_16a8w_tosa_INT(test_data: input_t1):

@common.parametrize("test_data", Add.test_data)
@common.XfailIfNoCorstone300
@pytest.mark.xfail(
reason="Vela compilation fails with 'Invalid arguments' for int16 add operations. See: https://github.com/pytorch/executorch/issues/13730"
)
def test_add_tensor_16a8w_u55_INT16(test_data: input_t1):
"""Test add operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
per_channel_quantization = False
@@ -304,9 +301,6 @@ def test_add_tensor_16a8w_u55_INT16(test_data: input_t1):

@common.parametrize("test_data", Add.test_data)
@common.XfailIfNoCorstone320
@pytest.mark.xfail(
reason="Vela compilation fails with 'Invalid arguments' for int16 add operations. See: https://github.com/pytorch/executorch/issues/13730"
)
def test_add_tensor_16a8w_u85_INT16(test_data: input_t1):
"""Test add operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
per_channel_quantization = False
11 changes: 3 additions & 8 deletions backends/arm/test/ops/test_to_copy.py
@@ -192,20 +192,15 @@ def test_to_vgf_INT(test_data: Tuple):
),
}

redundant_xfails_FP = {
redundant_xfails = {
"rand_fp16_fp16": "FP16 is not supported",
"rand_int8_int8": "Tracing graph with quantized input is not supported.",
"rand_int16_int16": "Tracing graph with quantized input is not supported.",
}

redundant_xfails_INT = {
"rand_fp16_fp16": "FP16 is not supported",
"rand_int8_int8": "Tracing graph with quantized input is not supported.",
}


@common.parametrize(
"test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, xfails=redundant_xfails_FP
"test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, xfails=redundant_xfails
)
def test_to_tosa_FP_REDUNDANT_CAST(test_data: Tuple):
test_tensor, new_dtype = test_data()
@@ -220,7 +215,7 @@ def test_to_tosa_FP_REDUNDANT_CAST(test_data: Tuple):


@common.parametrize(
"test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, xfails=redundant_xfails_INT
"test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, xfails=redundant_xfails
)
def test_to_tosa_INT_REDUNDANT_CAST(test_data: Tuple):
test_tensor, new_dtype = test_data()
1 change: 1 addition & 0 deletions backends/arm/test/targets.bzl
@@ -25,6 +25,7 @@ def define_arm_tests():
"ops/test_tanh.py",
"ops/test_view.py",
"ops/test_cos.py",
"ops/test_to_copy.py",
]

# Quantization
53 changes: 53 additions & 0 deletions backends/arm/tosa/quant_utils.py
@@ -77,6 +77,59 @@ def insert_rescale_ops_to_int32_maxscale(
return [rescaled_lhs, rescaled_rhs], back_scale


def insert_rescale_ops_int16_to_int32_maxscale(
tosa_graph: Any, inputs: list[TosaArg], node: Node, tosa_spec=None
) -> tuple[list[Any], float]:
"""For ADD and SUB with int16 inputs, we rescale to int32 using a different common scale(2*max(left scale,right scale))
compared to all the other cases. We multiply the left and right scales by 1<<12 giving us extra precision
for the computation without overflowing.

Returns a list of the rescaled nodes and the scale factor used,
needed by insert_rescale_op_to_int16.
"""

if len(inputs) > 2:
raise ValueError("More than two inputs not supported")

tensors = inputs.copy()
# Reshape tensor according to TOSA dim order
for tensor in tensors:
dim_order = tensor.dim_order
tensor.shape = [tensor.shape[i] for i in dim_order]

input_qparams = get_input_qparams(node)
lhs_qparams, rhs_qparams = input_qparams.values()
lhs_scale = lhs_qparams.get_scale_per_tensor()
rhs_scale = rhs_qparams.get_scale_per_tensor()
# Common scale for the two numbers
max_scale_2x = 2 * max(lhs_scale, rhs_scale)
SHIFT_INT16 = 12
# We are adding two int16 numbers. If the zero point is non-zero, each operand lies in
# [-65535, 65535] after zero-point removal, so the sum lies in [-131070, 131070] and
# needs 18 bits. With a 32-bit accumulator we can shift left by 12 bits without
# overflowing. In practice, because we divide by 2 * max(lhs_scale, rhs_scale), the
# effective shift is 11.
lhs_factor = (1 << SHIFT_INT16) * lhs_scale / max_scale_2x
rhs_factor = (1 << SHIFT_INT16) * rhs_scale / max_scale_2x
rescaled_lhs = build_rescale_to_int32(
tosa_graph,
tensors[0],
lhs_qparams.get_zp_per_tensor(),
lhs_factor,
tosa_spec=tosa_spec,
)
rescaled_rhs = build_rescale_to_int32(
tosa_graph,
tensors[1],
rhs_qparams.get_zp_per_tensor(),
rhs_factor,
tosa_spec=tosa_spec,
)
out_qparam = get_output_qparams(node)[0]
out_scale = out_qparam.get_scale_per_tensor()
back_scale = max_scale_2x / (out_scale * (1 << SHIFT_INT16))

return [rescaled_lhs, rescaled_rhs], back_scale


def insert_rescale_ops_to_int32(
tosa_graph: Any,
inputs: list[TosaArg],
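As a quick sanity check on the headroom comment inside `insert_rescale_ops_int16_to_int32_maxscale`: the worst-case int32 accumulator value stays well in range. This is an illustrative verification, not part of the PR.

```python
# Worst case: each int16 operand spans [-65535, 65535] after zero-point
# removal, and the rescale factor is at most 1 << 11, because
# scale / (2 * max(lhs_scale, rhs_scale)) <= 0.5.
INT32_MAX = 2**31 - 1
SHIFT_INT16 = 12

worst_operand = 32767 - (-32768)      # 65535: max magnitude after zp removal
max_factor = (1 << SHIFT_INT16) // 2  # 2048 == 1 << 11, the effective shift
worst_sum = 2 * worst_operand * max_factor
assert worst_sum <= INT32_MAX         # 268,431,360 << 2,147,483,647
print(f"worst-case accumulator: {worst_sum:,} (int32 max: {INT32_MAX:,})")
```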