Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions backends/arm/operators/op_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def define_node(
validate_valid_dtype(
self.target,
[*inputs, output],
[ts.DType.INT8, ts.DType.INT32],
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
output.tosa_spec,
)

Expand All @@ -59,12 +59,18 @@ def define_node(
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
tosa_graph, inputs, node, self.tosa_spec
)
elif inputs[0].dtype == ts.DType.INT16:
rescaled_inputs, scale_back = (
tqutils.insert_rescale_ops_int16_to_int32_maxscale(
tosa_graph, inputs, node, self.tosa_spec
)
)
else:
# input[0].dtype == ts.DType.INT32
# Non quantized input, natively support by TOSA.SUB
rescaled_inputs = inputs

if output.dtype == ts.DType.INT8:
if output.dtype in [ts.DType.INT8, ts.DType.INT16]:
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
else:
Expand Down Expand Up @@ -95,6 +101,15 @@ def define_node(
compute_rescale=False,
tosa_spec=self.tosa_spec,
) # type: ignore[possibly-undefined]
elif output.dtype == ts.DType.INT16:
tqutils.insert_rescale_op_to_int16(
tosa_graph,
sub_output,
scale_back,
node,
compute_rescale=False,
tosa_spec=self.tosa_spec,
) # type: ignore[possibly-undefined]


@register_node_visitor
Expand Down
101 changes: 100 additions & 1 deletion backends/arm/test/ops/test_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,21 @@
from typing import Tuple

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
get_symmetric_a16w8_quantization_config,
TOSAQuantizer,
)

from executorch.backends.arm.test import common
from executorch.backends.arm.test import common, conftest
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineINT,
EthosU85PipelineINT,
TosaPipelineFP,
TosaPipelineINT,
VgfPipeline,
)
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.xnnpack.test.tester import Quantize

aten_op = "torch.ops.aten.sub.Tensor"
exir_op = "executorch_exir_dialects_edge__ops_aten_sub_Tensor"
Expand Down Expand Up @@ -242,3 +248,96 @@ def test_sub_tensor_vgf_INT_2(test_data: Tuple[torch.Tensor, torch.Tensor]):
tosa_version="TOSA-1.0+INT",
)
pipeline.run()


def get_symmetric_a16w8_sub_quantizer(per_channel_quantization=False):
tosa_version = conftest.get_option("tosa_version")
tosa_profiles = {
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
}

quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
quantizer.set_global(
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
)

return Quantize(
quantizer,
get_symmetric_a16w8_quantization_config(
is_per_channel=per_channel_quantization
),
)


@common.parametrize("test_data", sub_test_data)
def test_sub_tensor_16a8w_tosa_INT(test_data: input_t1):
"""Test sub operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
per_channel_quantization = False

pipeline = TosaPipelineINT[input_t1](
Sub(),
test_data(),
aten_op,
exir_op=[],
per_channel_quantization=per_channel_quantization,
use_to_edge_transform_and_lower=True,
tosa_extensions=["int16"],
)

pipeline.change_args(
"quantize",
get_symmetric_a16w8_sub_quantizer(
per_channel_quantization=per_channel_quantization
),
)
pipeline.run()


@common.parametrize("test_data", sub_test_data)
@common.XfailIfNoCorstone300
def test_sub_tensor_16a8w_u55_INT16(test_data: input_t1):
"""Test sub operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
per_channel_quantization = False

pipeline = EthosU55PipelineINT[input_t1](
Sub(),
test_data(),
aten_op,
exir_op,
per_channel_quantization=per_channel_quantization,
use_to_edge_transform_and_lower=True,
run_on_fvp=True,
)

pipeline.change_args(
"quantize",
get_symmetric_a16w8_sub_quantizer(
per_channel_quantization=per_channel_quantization
),
)
pipeline.run()


@common.parametrize("test_data", sub_test_data)
@common.XfailIfNoCorstone320
def test_sub_tensor_16a8w_u85_INT16(test_data: input_t1):
"""Test sub operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
per_channel_quantization = False

pipeline = EthosU85PipelineINT[input_t1](
Sub(),
test_data(),
aten_op,
exir_op,
per_channel_quantization=per_channel_quantization,
use_to_edge_transform_and_lower=True,
run_on_fvp=True,
)

pipeline.change_args(
"quantize",
get_symmetric_a16w8_sub_quantizer(
per_channel_quantization=per_channel_quantization
),
)
pipeline.run()
1 change: 1 addition & 0 deletions backends/arm/test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def define_arm_tests():
"ops/test_mul.py",
"ops/test_slice.py",
"ops/test_sigmoid.py",
"ops/test_sub.py",
"ops/test_tanh.py",
"ops/test_view.py",
"ops/test_cos.py",
Expand Down
Loading