Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class BaseTOSASupportList(OperatorSupportBase):

def is_node_supported(self, submodules, node: fx.Node) -> bool:
supported = node.op == "call_function" and node.target in [
exir_ops.edge.aten.abs.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.expand_copy.default,
exir_ops.edge.aten.cat.default,
Expand Down
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from . import ( # noqa
node_visitor,
op_abs,
op_add,
op_avg_pool2d,
op_bmm,
Expand Down
133 changes: 133 additions & 0 deletions backends/arm/operators/op_abs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils

import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification

from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class AbsVisitor_080_BI(NodeVisitor):
target = "aten.abs.default"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
# Specification (0.80) states that input and output types
# should all be the same
if not (inputs[0].dtype == output.dtype):
raise ValueError(
"All inputs and outputs need same dtype."
f"Got {inputs[0].dtype=}, {output.dtype=}"
)
# Handle int8 (quantized) and int32
if not (inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]):
raise ValueError(
"All inputs need to be INT8 or INT32." f"Got {inputs[0].dtype=}"
)

if inputs[0].dtype == ts.DType.INT8:
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
tosa_graph, inputs, node
)
else:
# input[0].dtype == ts.DType.INT32
# Non quantized input, natively support by TOSA.abs
rescaled_inputs = inputs

if output.dtype == ts.DType.INT8:
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
abs_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
else:
# output.dtype == ts.DType.INT32
abs_output = output

# Do the INT32 Abs
tosa_graph.addOperator(
TosaOp.Op().ABS,
[
rescaled_inputs[0].name,
],
[abs_output.name],
None,
)

if output.dtype == ts.DType.INT8:
# Scale output back to 8 bit
# pyre-ignore
tqutils.insert_rescale_op_to_int8(tosa_graph, abs_output, scale_back, node) # type: ignore[possibly-undefined]


@register_node_visitor
class AbsVisitor_080_MI(AbsVisitor_080_BI):
# inheriting 'target' from BI class

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+MI"),
]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
# Specification (0.80) states that input and output types
# should all be the same
if not (inputs[0].dtype == output.dtype):
raise ValueError(
"All inputs and output need same dtype."
f"Got {inputs[0].dtype=}, {output.dtype=}"
)

if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
# Call the inherited define_node for handling integers
super().define_node(node, tosa_graph, inputs, output)
else:
# FP32 Abs lowering

if not (inputs[0].dtype == ts.DType.FP32):
raise ValueError(
"All inputs need to be FP32." f"Got {inputs[0].dtype=}"
)

if not (output.dtype == ts.DType.FP32):
raise ValueError("All outputs need to be FP32." f"Got {output.dtype=}")

# MI lowering
tosa_graph.addOperator(
TosaOp.Op().ABS,
[inputs[0].name],
[output.name],
None,
)
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def _match_pattern(


_one_to_one = [
torch.ops.aten.abs.default,
torch.ops.aten.exp.default,
torch.ops.aten.log.default,
torch.ops.aten.reciprocal.default,
Expand Down
125 changes: 125 additions & 0 deletions backends/arm/test/ops/test_abs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple

import pytest

import torch
from executorch.backends.arm.test import common, conftest
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized


class TestAbs(unittest.TestCase):
class Abs(torch.nn.Module):
test_parameters = [
(torch.zeros(5),),
(torch.full((5,), -1, dtype=torch.float32),),
(torch.ones(5) * -1,),
(torch.randn(8),),
(torch.randn(2, 3, 4),),
(torch.randn(1, 2, 3, 4),),
(torch.normal(mean=0, std=10, size=(2, 3, 4)),),
]

def forward(self, x):
return torch.abs(x)

def _test_abs_tosa_MI_pipeline(
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
):
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
)
.export()
.check_count({"torch.ops.aten.abs.default": 1})
.check_not(["torch.ops.quantized_decomposed"])
.to_edge()
.partition()
.check_not(["torch.ops.aten.abs.default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=test_data)
)

def _test_abs_tosa_BI_pipeline(
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
):
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
)
.quantize()
.export()
.check_count({"torch.ops.aten.abs.default": 1})
.check(["torch.ops.quantized_decomposed"])
.to_edge()
.partition()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=test_data, qtol=1)
)

def _test_abs_ethosu_BI_pipeline(
self,
compile_spec: list[CompileSpec],
module: torch.nn.Module,
test_data: Tuple[torch.Tensor],
):
tester = (
ArmTester(
module,
example_inputs=test_data,
compile_spec=compile_spec,
)
.quantize()
.export()
.check_count({"torch.ops.aten.abs.default": 1})
.check(["torch.ops.quantized_decomposed"])
.to_edge()
.partition()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.serialize()
)
if conftest.is_option_enabled("corstone_fvp"):
tester.run_method_and_compare_outputs(qtol=1, inputs=test_data)

@parameterized.expand(Abs.test_parameters)
def test_abs_tosa_MI(self, test_data: torch.Tensor):
test_data = (test_data,)
self._test_abs_tosa_MI_pipeline(self.Abs(), test_data)

@parameterized.expand(Abs.test_parameters)
def test_abs_tosa_BI(self, test_data: torch.Tensor):
test_data = (test_data,)
self._test_abs_tosa_BI_pipeline(self.Abs(), test_data)

@parameterized.expand(Abs.test_parameters)
@pytest.mark.corstone_fvp
def test_abs_u55_BI(self, test_data: torch.Tensor):
test_data = (test_data,)
self._test_abs_ethosu_BI_pipeline(
common.get_u55_compile_spec(), self.Abs(), test_data
)

@parameterized.expand(Abs.test_parameters)
@pytest.mark.corstone_fvp
def test_abs_u85_BI(self, test_data: torch.Tensor):
test_data = (test_data,)
self._test_abs_ethosu_BI_pipeline(
common.get_u85_compile_spec(), self.Abs(), test_data
)
Loading