20 changes: 15 additions & 5 deletions backends/arm/operators/op_mul.py
@@ -34,6 +34,7 @@ class MulVisitor_INT(NodeVisitor):

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
]

def define_node(
@@ -51,11 +52,11 @@ def define_node(
validate_valid_dtype(
self.target,
[*inputs, output],
[ts.DType.INT8, ts.DType.INT32],
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
output.tosa_spec,
)

if inputs[0].dtype == ts.DType.INT8:
if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16:
input_A = inputs[0]
input_B = inputs[1]
input_qparams = get_input_qparams(node)
@@ -80,15 +81,15 @@
tosa_spec=self.tosa_spec,
)
else:
# input[0].dtype == ts.DType.INT32
# input[0].dtype == ts.DType.INT16 or ts.DType.INT32
# Non quantized input, natively supported by TOSA.MUL
input_A_rescaled, input_B_rescaled = inputs[0], inputs[1]

if output.dtype == ts.DType.INT8:
if output.dtype == ts.DType.INT8 or output.dtype == ts.DType.INT16:
output_shape = tutils.tosa_shape(output.shape, output.dim_order)
mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)
else:
# output.dtype == ts.DType.INT32
# output.dtype == ts.DType.INT32 (non-quantized)
mul_output = output

# Do the INT32 Mul
@@ -110,6 +111,15 @@
tqutils.insert_rescale_op_to_int8(
tosa_graph, mul_output, output_scale, node, self.tosa_spec
)
elif output.dtype == ts.DType.INT16:
# Scale output back to 16 bit
output_scale = (
input_A_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
* input_B_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
)
tqutils.insert_rescale_op_to_int16(
tosa_graph, mul_output, output_scale, node, self.tosa_spec
)


@register_node_visitor
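For intuition, here is a minimal, self-contained sketch (plain NumPy, not the backend code) of the arithmetic the INT16 path above encodes: the int16 operands are widened and multiplied in int32, then rescaled back to int16 with the combined factor s_a * s_b / s_out. In the visitor itself the division by the output scale happens inside insert_rescale_op_to_int16 (compute_rescale=True); the scales and values below are invented for illustration.

import numpy as np

s_a, s_b, s_out = 0.002, 0.003, 0.002              # made-up per-tensor scales

a = np.array([1000, -2000, 30000], dtype=np.int16)
b = np.array([500, 400, -300], dtype=np.int16)

acc = a.astype(np.int32) * b.astype(np.int32)      # the INT32 MUL
rescale = (s_a * s_b) / s_out                      # folded into a RESCALE op
out = np.clip(np.round(acc * rescale), -32768, 32767).astype(np.int16)

# The dequantized result matches the float product up to rounding:
assert np.allclose(out * s_out, (a * s_a) * (b * s_b), atol=s_out)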
112 changes: 111 additions & 1 deletion backends/arm/test/ops/test_mul.py
@@ -8,16 +8,24 @@

from typing import Tuple

import pytest
import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
get_symmetric_a16w8_quantization_config,
TOSAQuantizer,
)

from executorch.backends.arm.test import common
from executorch.backends.arm.test import common, conftest
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineINT,
EthosU85PipelineINT,
TosaPipelineFP,
TosaPipelineINT,
VgfPipeline,
)
from executorch.backends.arm.tosa.specification import TosaSpecification

from executorch.backends.xnnpack.test.tester import Quantize

input_t1 = Tuple[torch.Tensor, torch.Tensor] # Input x
aten_op = "torch.ops.aten.mul.Tensor"
@@ -284,3 +292,105 @@ def test_mul_tensor_vgf_INT_int32(test_data: torch.Tensor):
)
pipeline.pop_stage("check.quant_nodes")
pipeline.run()


def get_symmetric_a16w8_mul_quantizer(per_channel_quantization=False):
tosa_version = conftest.get_option("tosa_version")
tosa_profiles = {
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
}

quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
quantizer.set_global(
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
)

return Quantize(
quantizer,
get_symmetric_a16w8_quantization_config(
is_per_channel=per_channel_quantization
),
)
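For reference, "16A8W" means symmetric int16 activations alongside symmetric int8 weights. The sketch below illustrates the kind of quantization parameters such a config implies; it is an illustration only, not what get_symmetric_a16w8_quantization_config actually does internally (the real helper configures observers, qmin/qmax and per-channel handling inside arm_quantizer).

import torch

def symmetric_scale(x: torch.Tensor, qmax: int) -> float:
    # Symmetric quantization: zero point is 0, scale maps max |x| to qmax.
    return (x.abs().max() / qmax).item()

act = torch.randn(8) * 4.0
act_scale = symmetric_scale(act, qmax=32767)       # 16-bit activations
q_act = torch.clamp(torch.round(act / act_scale), -32768, 32767).to(torch.int16)

w = torch.randn(8)
w_scale = symmetric_scale(w, qmax=127)             # 8-bit weights
q_w = torch.clamp(torch.round(w / w_scale), -128, 127).to(torch.int8)

# Round-trip error stays well within one quantization step:
assert torch.allclose(q_act.float() * act_scale, act, atol=act_scale)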


@common.parametrize("test_data", test_data_suite)
@pytest.mark.xfail(
reason="missing int16 mul ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13947"
)
def test_mul_tensor_16a8w_tosa_INT(test_data: input_t1):
"""Test mul operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
per_channel_quantization = False

pipeline = TosaPipelineINT[input_t1](
Mul(),
test_data(),
aten_op,
exir_op=[],
per_channel_quantization=per_channel_quantization,
use_to_edge_transform_and_lower=True,
tosa_extensions=["int16"],
)

pipeline.change_args(
"quantize",
get_symmetric_a16w8_mul_quantizer(
per_channel_quantization=per_channel_quantization
),
)
pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone300
@pytest.mark.xfail(
reason="Vela compilation fails with 'Invalid arguments' for int16 mul operations. See: https://github.com/pytorch/executorch/issues/13947"
)
def test_mul_tensor_16a8w_u55_INT16(test_data: input_t1):
"""Test mul operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
per_channel_quantization = False

pipeline = EthosU55PipelineINT[input_t1](
Mul(),
test_data(),
aten_op,
exir_ops=[],
per_channel_quantization=per_channel_quantization,
use_to_edge_transform_and_lower=True,
run_on_fvp=True,
)

pipeline.change_args(
"quantize",
get_symmetric_a16w8_mul_quantizer(
per_channel_quantization=per_channel_quantization
),
)
pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone320
@pytest.mark.xfail(
reason="Vela compilation fails with 'Invalid arguments' for int16 mul operations. See: https://github.com/pytorch/executorch/issues/13947"
)
def test_mul_tensor_16a8w_u85_INT16(test_data: input_t1):
"""Test mul operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
per_channel_quantization = False

pipeline = EthosU85PipelineINT[input_t1](
Mul(),
test_data(),
aten_op,
exir_ops=[],
per_channel_quantization=per_channel_quantization,
use_to_edge_transform_and_lower=True,
run_on_fvp=True,
)

pipeline.change_args(
"quantize",
get_symmetric_a16w8_mul_quantizer(
per_channel_quantization=per_channel_quantization
),
)
pipeline.run()
1 change: 1 addition & 0 deletions backends/arm/test/targets.bzl
@@ -16,6 +16,7 @@ def define_arm_tests():
"ops/test_add.py",
"ops/test_avg_pool2d.py",
"ops/test_linear.py",
"ops/test_mul.py",
"ops/test_slice.py",
"ops/test_sigmoid.py",
"ops/test_tanh.py",
114 changes: 104 additions & 10 deletions backends/arm/tosa/quant_utils.py
@@ -5,7 +5,7 @@

# pyre-unsafe

# Utiliy functions for TOSA quantized lowerings
# Utility functions for TOSA quantized lowerings

import math

@@ -27,11 +27,11 @@ def insert_rescale_ops_to_int32_maxscale(
tosa_graph: Any, inputs: list[TosaArg], node: Node, tosa_spec=None
) -> tuple[list[Any], float]:
"""For ADD and SUB, we rescale to int32 using a different common scale(2*max(left scale,right scale))
compared to all the other cases. We also multply the left and right scales by 1<<20 giving us extra precision
compared to all the other cases. We also multiply the left and right scales by 1<<20 giving us extra precision
for the computation without overflowing.

Returns a list of the rescaled nodes and the scale factor used,
needed by rescale_node_back_to_int8.
needed by insert_rescale_op_to_int8.
"""

if len(inputs) > 2:
Expand Down Expand Up @@ -86,7 +86,7 @@ def insert_rescale_ops_to_int32(
The scales are adjusted using the smallest scale of all 'nodes'.

Returns a list of the rescaled nodes and the scale factor used,
needed by rescale_node_back_to_int8.
needed by insert_rescale_op_to_int8.

This function is used in serialization to TOSA for target ops that are
handled by the DQ/D folding pass, which stores the quantization parameters
@@ -134,7 +134,59 @@ def insert_rescale_op_to_int8(
Parameters:
node: The original node that is being handled by the rescales.
last_tensor: the tosa tensor to rescale back.
scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_op_to_int32'
scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32'
compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
tosa_graph: the tosa_graph to manipulate.

This function is used in serialization to TOSA for target ops that are
handled by the DQ/D folding pass, which stores the quantization parameters
in the node meta dict.
"""
_insert_rescale_op_to_dtype(
tosa_graph, last_tensor, scale, node, ts.DType.INT8, compute_rescale, tosa_spec
)


def insert_rescale_op_to_int16(
tosa_graph: Any,
last_tensor: TosaArg,
scale: float,
node: Node,
compute_rescale=True,
tosa_spec=None,
) -> None:
"""Rescales the node back to int16, adding a suitable RESCALE op to 'tosa_graph'.
Parameters:
node: The original node that is being handled by the rescales.
last_tensor: the tosa tensor to rescale back.
scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32'
compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
tosa_graph: the tosa_graph to manipulate.

This function is used in serialization to TOSA for target ops that are
handled by the DQ/D folding pass, which stores the quantization parameters
in the node meta dict.
"""
_insert_rescale_op_to_dtype(
tosa_graph, last_tensor, scale, node, ts.DType.INT16, compute_rescale, tosa_spec
)


def _insert_rescale_op_to_dtype(
tosa_graph: Any,
last_tensor: TosaArg,
scale: float,
node: Node,
output_dtype: Any,
compute_rescale=True,
tosa_spec=None,
) -> None:
"""Common implementation for rescaling nodes back to a specific dtype.
Parameters:
node: The original node that is being handled by the rescales.
last_tensor: the tosa tensor to rescale back.
scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32'
output_dtype: The target dtype (ts.DType.INT8 or ts.DType.INT16)
compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
tosa_graph: the tosa_graph to manipulate.

@@ -156,20 +208,21 @@ def insert_rescale_op_to_int8(
else:
output_rescale_scale = scale

# Rescale Back to INT8
build_rescale_from_int32(
# Rescale Back to the specified dtype
build_rescale_from_int32_to_dtype(
tosa_graph,
last_tensor,
node.name,
qargs_out.get_zp_per_tensor(),
output_rescale_scale,
output_dtype,
tosa_spec=tosa_spec,
)


# TOSA uses the RESCALE operation to scale between values with differing precision.
# The RESCALE operator is defined using an integer multiply, add, and shift.
# This utility function is for calculating the multier and shift given a scale.
# This utility function is for calculating the multiplier and shift given a scale.
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
def compute_multiplier_and_shift(
scales: list[float], scaleWidth: int = 32
Expand Down Expand Up @@ -214,7 +267,7 @@ def compute_multiplier_and_shift(
return multipliers, shifts
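compute_multiplier_and_shift above turns a float scale into an integer multiplier and shift. As a rough illustration, the standard frexp-based derivation for the scale32 case looks like this; a sketch only, since the actual implementation may differ in details such as shift clamping and per-channel handling.

import math

def multiplier_and_shift(scale: float) -> tuple[int, int]:
    # scale = mantissa * 2**exponent, with mantissa in [0.5, 1)
    mantissa, exponent = math.frexp(scale)
    multiplier = round(mantissa * (1 << 31))       # 31 fractional bits (scale32)
    shift = 31 - exponent
    if multiplier == (1 << 31):                    # rounding overflowed the mantissa
        multiplier //= 2
        shift -= 1
    return multiplier, shift

m, s = multiplier_and_shift(0.003)                 # m = 1649267442, s = 39
assert abs(m * 2.0**-s - 0.003) < 1e-9

# Applying the encoded scale with integer-only arithmetic (round half up):
acc = 12345
assert (acc * m + (1 << (s - 1))) >> s == round(acc * 0.003)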


# For TOSA spec v1.0 RESCALE operator requires multipler, shifts, input_zp and output_zp to be
# For TOSA spec v1.0 RESCALE operator requires multiplier, shifts, input_zp and output_zp to be
# const inputs. Create constant operators from the data already initialized.
def create_const_ops_for_rescale(
tosa_fb,
Expand Down Expand Up @@ -335,14 +388,55 @@ def build_rescale_from_int32(
per_channel: bool = False,
tosa_spec=None,
) -> None:
# For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
# to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
build_rescale_from_int32_to_dtype(
tosa_fb,
input_node,
output_name,
output_zp,
rescale_scale,
ts.DType.INT8,
is_scale32,
is_double_round,
per_channel,
tosa_spec,
)

return


def build_rescale_from_int32_to_dtype(
tosa_fb: Any,
input_node: TosaArg,
output_name: str,
output_zp: int,
rescale_scale: float,
output_dtype: Any,
is_scale32: bool = True,
is_double_round: bool = False,
per_channel: bool = False,
tosa_spec=None,
) -> None:
"""Common implementation for rescaling from INT32 to a specific dtype (INT8 or INT16).

Parameters:
tosa_fb: The TOSA serializer
input_node: Input tensor (should be INT32)
output_name: Name for the output tensor
output_zp: Output zero point
rescale_scale: Rescaling factor
output_dtype: Target dtype (ts.DType.INT8 or ts.DType.INT16)
Other parameters: Standard rescale parameters
"""
# For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
# to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
build_rescale(
tosa_fb,
[rescale_scale],
input_node,
output_name=output_name,
output_type=ts.DType.INT8,
output_type=output_dtype,
input_zp=[0],
output_zp=[output_zp],
rounding_mode=RoundingMode.SINGLE_ROUND,