1 change: 1 addition & 0 deletions backends/arm/arm_partitioner.py
@@ -44,6 +44,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
            exir_ops.edge.aten.div.Tensor,
            exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
            exir_ops.edge.aten.avg_pool2d.default,
            exir_ops.edge.aten.sigmoid.default,
            exir_ops.edge.aten._softmax.default,
            exir_ops.edge.aten.view_copy.default,
            exir_ops.edge.aten.clone.default,
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
@@ -17,6 +17,7 @@
    op_mean_dim,
    op_permute,
    op_quant,
    op_sigmoid,
    op_softmax,
    op_view,
)
82 changes: 82 additions & 0 deletions backends/arm/operators/op_sigmoid.py
@@ -0,0 +1,82 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List

import numpy as np

import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg

from executorch.backends.arm.tosa_quant_utils import (
    dequantize_value,
    get_quant_node_args,
    QuantArgs,
    quantize_value,
)
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class SigmoidVisitor(NodeVisitor):
    target = "aten.sigmoid.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:

        assert len(node.all_input_nodes) == 1
        assert len(node.users) == 1

        if is_quant_node:
            # Assume quantized input is 8 bit.

            # Create attribute for 8 bit table lookup.
            input_node = node.all_input_nodes[0]
            in_quantargs = get_quant_node_args(input_node)
            output_node = list(node.users)[0]
            out_quantargs = get_quant_node_args(output_node)

            table = sigmoid_table_8bit(in_quantargs, out_quantargs)
            table_attr = ts.TosaSerializerAttribute()
            table_attr.TableAttribute(table)

            tosa_graph.addOperator(
                TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr
            )
        else:
            tosa_graph.addOperator(TosaOp.Op().SIGMOID, [inputs[0].name], [output.name])


def sigmoid_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
    """
    Returns a 256-entry table mapping the quantized input range [qmin, qmax]
    through the sigmoid function, for use with the TOSA TABLE operator.
    Reference: https://www.mlplatform.org/tosa/tosa_spec.html#_sigmoid
    """

    def sigmoid(x):
        # Convert quantized input to floating point sigmoid input space.
        v = dequantize_value(x, in_quantargs)
        # Compute sigmoid.
        v = 1.0 / (1.0 + np.exp(-v))
        # Convert sigmoid output back to quantized space.
        return quantize_value(v, out_quantargs)

    return [
        sigmoid(x)
        for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8)
    ]
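
As a sanity check outside the graph, the table can be exercised directly: quantize a float input, index the table, and compare the dequantized result against an exact float sigmoid. A minimal sketch, not part of the PR, assuming QuantArgs has exactly the fields (scale, zp, qmin, qmax) accessed by the helpers above:

import numpy as np

# Illustrative only: qspec values mirror the sigmoid annotator below.
in_q = QuantArgs(scale=6 / 128, zp=0, qmin=-128, qmax=127)
out_q = QuantArgs(scale=1 / 128, zp=0, qmin=-128, qmax=127)

table = sigmoid_table_8bit(in_q, out_q)
x = np.int8(32)  # quantized input; dequantizes to 32 * 6/128 = 1.5
approx = dequantize_value(table[int(x) - in_q.qmin], out_q)
exact = 1.0 / (1.0 + np.exp(-dequantize_value(x, in_q)))
assert abs(approx - exact) < out_q.scale  # error within one output quantization step
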
1 change: 1 addition & 0 deletions backends/arm/quantizer/arm_quantizer.py
@@ -255,6 +255,7 @@ class ArmQuantizer(Quantizer):
"max_pool2d",
"add",
"mul",
"sigmoid",
]

def __init__(self) -> None:
@@ -53,4 +53,5 @@ def decorator(annotator: AnnotatorType):
    linear_annotator,
    max_pool2d_annotator,
    mul_annotator,
    sigmoid_annotator,
)
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, List, Optional

import torch
from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.quantizer.utils import (
    _annotate_input_qspec_map,
    _annotate_output_qspec,
)
from torch.fx import Node


@register_annotator("sigmoid")
def _annotate_sigmoid(
    gm: torch.fx.GraphModule,
    quantization_config: QuantizationConfig,
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    annotated_partitions = []

    # The input/output range of sigmoid is always the same, so quantize with fixed qspecs.
    # Input: scale 6/128, zp 0 maps (-128, 127) -> (-6.0, ~5.95); outside these bounds sigmoid is ~constant.
    # Output: scale 1/128, zp 0 maps (-1, ~0.99) -> (-128, 127), which covers sigmoid's output range (0, 1).
    # Note that this exact choice is somewhat arbitrary.

    input_act_qspec = quantization_config.get_fixed_qspec(scale=6 / 128, zp=0)
    output_act_qspec = quantization_config.get_fixed_qspec(scale=1 / 128, zp=0)

    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target != torch.ops.aten.sigmoid.default:
            continue
        if filter_fn and not filter_fn(node):
            continue
        input_node = node.args[0]

        if not arm_quantizer_utils.is_annotated(node):
            _annotate_input_qspec_map(
                node,
                input_node,
                input_act_qspec,
            )
            _annotate_output_qspec(node, output_act_qspec)

            arm_quantizer_utils.mark_nodes_as_annotated([node])
            annotated_partitions.append([node])

    return annotated_partitions
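
To make the fixed-qspec comment concrete, the dequantization rule is value = (q - zp) * scale. A small illustration of the two qspecs above (plain arithmetic, not part of the PR):

# Input qspec: scale 6/128, zp 0.
scale_in, zp_in = 6 / 128, 0
assert (-128 - zp_in) * scale_in == -6.0             # most negative representable input
assert abs((127 - zp_in) * scale_in - 5.95) < 0.01   # most positive representable input

# Output qspec: scale 1/128, zp 0; sigmoid(0) = 0.5 lands at quantized value 64.
scale_out, zp_out = 1 / 128, 0
assert round(0.5 / scale_out) + zp_out == 64
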
23 changes: 22 additions & 1 deletion backends/arm/quantizer/quantization_config.py
@@ -8,7 +8,10 @@

import torch

from torch.ao.quantization.quantizer import (
    FixedQParamsQuantizationSpec,
    QuantizationSpec,
)


@dataclass(eq=True, frozen=True)
@@ -56,3 +59,21 @@ def get_bias_qspec(self) -> QuantizationSpec | None:
        assert (
            self.bias.dtype == torch.float
        ), "Only float dtype for bias is supported right now"
        return self.bias

    def get_fixed_qspec(
        self,
        scale: float,
        zp: int,
        dtype: torch.dtype = torch.int8,
        quant_min: int = -128,
        quant_max: int = 127,
    ) -> FixedQParamsQuantizationSpec:
        """Returns a new FixedQParamsQuantizationSpec with the given parameters."""
        return FixedQParamsQuantizationSpec(
            dtype=dtype,
            qscheme=torch.per_tensor_affine,
            scale=scale,
            zero_point=zp,
            quant_min=quant_min,
            quant_max=quant_max,
        )
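
A hypothetical usage sketch of the new helper: given any QuantizationConfig instance `config`, it yields an int8 fixed-qparams spec; the field names below follow torch.ao's FixedQParamsQuantizationSpec, and the values mirror the sigmoid annotator:

# Illustrative only; `config` stands for an existing QuantizationConfig instance.
input_qspec = config.get_fixed_qspec(scale=6 / 128, zp=0)
assert input_qspec.dtype == torch.int8
assert input_qspec.zero_point == 0
assert (input_qspec.quant_min, input_qspec.quant_max) == (-128, 127)
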
152 changes: 152 additions & 0 deletions backends/arm/test/ops/test_sigmoid.py
@@ -0,0 +1,152 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import unittest

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from parameterized import parameterized

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


test_data_suite = [
    # (test_name, test_data)
    ("zeros", torch.zeros(10, 10, 10, 10)),
    ("ones", torch.ones(10, 10, 10)),
    ("rand", torch.rand(10, 10) - 0.5),
    ("randn_pos", torch.randn(10) + 10),
    ("randn_neg", torch.randn(10) - 10),
    ("ramp", torch.arange(-16, 16, 0.2)),
]


class TestSigmoid(unittest.TestCase):
    class Sigmoid(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.sigmoid = torch.nn.Sigmoid()

        def forward(self, x):
            return self.sigmoid(x)

    class AddSigmoid(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.sigmoid = torch.nn.Sigmoid()

        def forward(self, x):
            return self.sigmoid(x + x)

    class SigmoidAdd(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.sigmoid = torch.nn.Sigmoid()

        def forward(self, x):
            return x + self.sigmoid(x)

    class SigmoidAddSigmoid(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.sigmoid = torch.nn.Sigmoid()

        def forward(self, x, y):
            return self.sigmoid(self.sigmoid(y) + self.sigmoid(x))

    def _test_sigmoid_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .export()
            .check(["torch.ops.aten.sigmoid.default"])
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_sigmoid_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .quantize()
            .export()
            .check(["torch.ops.aten.sigmoid.default"])
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_sigmoid_tosa_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_u55_compile_spec(),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.sigmoid.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    @parameterized.expand(test_data_suite)
    def test_sigmoid_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
    ):
        self._test_sigmoid_tosa_MI_pipeline(self.Sigmoid(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_sigmoid_tosa_BI(self, test_name: str, test_data: torch.Tensor):
        self._test_sigmoid_tosa_BI_pipeline(self.Sigmoid(), (test_data,))

    def test_add_sigmoid_tosa_BI(self):
        self._test_sigmoid_tosa_BI_pipeline(self.AddSigmoid(), (test_data_suite[0][1],))

    def test_sigmoid_add_tosa_BI(self):
        self._test_sigmoid_tosa_BI_pipeline(self.SigmoidAdd(), (test_data_suite[0][1],))

    def test_sigmoid_add_sigmoid_tosa_BI(self):
        self._test_sigmoid_tosa_BI_pipeline(
            self.SigmoidAddSigmoid(), (test_data_suite[4][1], test_data_suite[3][1])
        )

    # Fails because Vela diverges from the TOSA spec here; expected to work with Regor.
    @parameterized.expand(test_data_suite)
    @unittest.expectedFailure
    def test_sigmoid_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor):
        self._test_sigmoid_tosa_u55_BI_pipeline(self.Sigmoid(), (test_data,))
14 changes: 14 additions & 0 deletions backends/arm/tosa_quant_utils.py
@@ -8,6 +8,8 @@
import math
from typing import NamedTuple

import numpy as np

import serializer.tosa_serializer as ts
import torch.fx
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
@@ -26,6 +28,18 @@ class QuantArgs(NamedTuple):
    qmax: int


def quantize_value(x, qargs: QuantArgs, dtype=np.int8):
    return np.clip(
        np.round(x / qargs.scale) + qargs.zp,
        qargs.qmin,
        qargs.qmax,
    ).astype(dtype)


def dequantize_value(qx, qargs: QuantArgs):
    return (qx - qargs.zp) * qargs.scale


def is_quant_node(node: torch.fx.Node):
    consumer_node = list(node.users)[0]
    input = node.all_input_nodes[0]
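
A quick round-trip illustration of the two new helpers (illustrative only; it assumes QuantArgs has exactly the fields scale, zp, qmin and qmax used above):

import numpy as np
from executorch.backends.arm.tosa_quant_utils import (
    dequantize_value,
    quantize_value,
    QuantArgs,
)

qargs = QuantArgs(scale=6 / 128, zp=0, qmin=-128, qmax=127)

q = quantize_value(1.5, qargs)            # round(1.5 / (6/128)) + 0 = 32
assert q == 32
assert dequantize_value(q, qargs) == 1.5  # exact here, since 1.5 lies on the grid

# Out-of-range inputs clip to the quantized bounds.
assert quantize_value(100.0, qargs) == 127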