Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 110 additions & 1 deletion backends/arm/test/ops/test_sigmoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,22 @@

from typing import Tuple

import pytest
import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.quantizer.arm_quantizer import (
get_symmetric_a16w8_quantization_config,
TOSAQuantizer,
)
from executorch.backends.arm.test import common, conftest
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineINT,
EthosU85PipelineINT,
TosaPipelineFP,
TosaPipelineINT,
VgfPipeline,
)
from executorch.backends.arm.tosa.specification import TosaSpecification
from executorch.backends.xnnpack.test.tester import Quantize

aten_op = "torch.ops.aten.sigmoid.default" # Used for checking that we do not have softmax in the graph after decompose
exir_op = "executorch_exir_dialects_edge__ops_aten_sigmoid_default"
Expand Down Expand Up @@ -253,3 +260,105 @@ def test_sigmoid_vgf_INT_add_3():
tosa_version="TOSA-1.0+INT",
)
pipeline.run()


def get_symmetric_a16w8_sigmoid_quantizer(per_channel_quantization=False):
tosa_version = conftest.get_option("tosa_version")
tosa_profiles = {
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
}

quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
quantizer.set_global(
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
)

return Quantize(
quantizer,
get_symmetric_a16w8_quantization_config(
is_per_channel=per_channel_quantization
),
)


@common.parametrize("test_data", test_data_suite)
@pytest.mark.xfail(
reason="missing int16 sigmoid ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13974"
)
def test_sigmoid_16a8w_tosa_INT(test_data: torch.Tensor):
"""Test sigmoid operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
per_channel_quantization = False

pipeline = TosaPipelineINT[input_t1](
Sigmoid(),
(test_data(),),
aten_op,
exir_op=[],
per_channel_quantization=per_channel_quantization,
use_to_edge_transform_and_lower=True,
tosa_extensions=["int16"],
)

pipeline.change_args(
"quantize",
get_symmetric_a16w8_sigmoid_quantizer(
per_channel_quantization=per_channel_quantization
),
)
pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone300
@pytest.mark.xfail(
reason="Vela compilation fails with 'Invalid arguments' for int16 sigmoid operations"
)
def test_sigmoid_16a8w_u55_INT16(test_data: torch.Tensor):
"""Test sigmoid operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
per_channel_quantization = False

pipeline = EthosU55PipelineINT[input_t1](
Sigmoid(),
(test_data(),),
aten_op,
exir_op,
per_channel_quantization=per_channel_quantization,
use_to_edge_transform_and_lower=True,
run_on_fvp=True,
)

pipeline.change_args(
"quantize",
get_symmetric_a16w8_sigmoid_quantizer(
per_channel_quantization=per_channel_quantization
),
)
pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone320
@pytest.mark.xfail(
reason="Vela compilation fails with 'Invalid arguments' for int16 sigmoid operations"
)
def test_sigmoid_16a8w_u85_INT16(test_data: torch.Tensor):
"""Test sigmoid operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
per_channel_quantization = False

pipeline = EthosU85PipelineINT[input_t1](
Sigmoid(),
(test_data(),),
aten_op,
exir_op,
per_channel_quantization=per_channel_quantization,
use_to_edge_transform_and_lower=True,
run_on_fvp=True,
)

pipeline.change_args(
"quantize",
get_symmetric_a16w8_sigmoid_quantizer(
per_channel_quantization=per_channel_quantization
),
)
pipeline.run()
Loading