From 54d7deee4827c8273c2e05acd4cfe59b3827bb16 Mon Sep 17 00:00:00 2001 From: Ryan O'Shea Date: Mon, 17 Mar 2025 16:25:49 +0100 Subject: [PATCH] Arm Backend: Add New DecomposeSilu pass to arm_pass_manager * Adds DecomposeSilu pass * Adds Tests for DecomposeSilu Signed-off-by: Ryan O'Shea Change-Id: Ib9f15d04c4c06d92d38cc9e6297145980052e673 --- backends/arm/_passes/__init__.py | 1 + backends/arm/_passes/arm_pass_manager.py | 2 + backends/arm/_passes/decompose_silu_pass.py | 34 ++++++ backends/arm/quantizer/arm_quantizer.py | 4 +- backends/arm/test/ops/test_silu.py | 113 ++++++++++++++++++++ 5 files changed, 152 insertions(+), 2 deletions(-) create mode 100644 backends/arm/_passes/decompose_silu_pass.py create mode 100644 backends/arm/test/ops/test_silu.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 1a1719bf8ae..ddca8ea4a06 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -26,6 +26,7 @@ from .decompose_linear_pass import DecomposeLinearPass # noqa from .decompose_meandim_pass import DecomposeMeanDimPass # noqa from .decompose_select import DecomposeSelectPass # noqa +from .decompose_silu_pass import DecomposeSiluPass # noqa from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa from .decompose_sqrt_pass import DecomposeSqrtPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index dd9f8b7bed0..dd4ca7ad7bd 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -31,6 +31,7 @@ DecomposeLinearPass, DecomposeMeanDimPass, DecomposeSelectPass, + DecomposeSiluPass, DecomposeSoftmaxPass, DecomposeSoftmaxUnstablePass, DecomposeSqrtPass, @@ -196,6 +197,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeDivPass()) self.add_pass(DecomposeLeakyReLUPass()) self.add_pass(DecomposeSqrtPass()) + self.add_pass(DecomposeSiluPass()) if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset: # Numerically stable softmax uses amax which is not supported on Ethos-U55 diff --git a/backends/arm/_passes/decompose_silu_pass.py b/backends/arm/_passes/decompose_silu_pass.py new file mode 100644 index 00000000000..68ebb3f4515 --- /dev/null +++ b/backends/arm/_passes/decompose_silu_pass.py @@ -0,0 +1,34 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import torch +from executorch.exir.pass_base import ExportPass + +aten_silu_ops = (torch.ops.aten.silu.default, torch.ops.aten.silu_.default) + + +class DecomposeSiluPass(ExportPass): + """ + This pass decomposes silu into a mul and a sigmoid node. + + Example: + y = silu(a) + Becomes: + x = sigmoid(a) + y = mul(a,x) + """ + + def call_operator(self, op, args, kwargs, meta): + if op not in (aten_silu_ops): + return super().call_operator(op, args, kwargs, meta) + sigmoid_op = torch.ops.aten.sigmoid.default + mul_op = torch.ops.aten.mul.Tensor + + original = args[0] + sigmoid = super().call_operator(sigmoid_op, (original,), {}, meta) + + return super().call_operator(mul_op, (original, sigmoid), {}, meta) diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index e76ed5fb415..ee08f8e9eec 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -286,10 +286,10 @@ def _annotate_all_static_patterns( quantization_config: Optional[QuantizationConfig], filter_fn: Optional[Callable[[Node], bool]] = None, ) -> GraphModule: - """Loops over all STATIC_OPS and runs the corresponding registred annotator. + """Loops over all STATIC_OPS and runs the corresponding registered annotator. Args: model: The model to annotate statically. - quantization_config: Specifices the QuantizationSpecs for the model's + quantization_config: Specifies the QuantizationSpecs for the model's input activations, output activations, weights and biases. filter_fn: An optional filter function that takes a node and returns whether the node should be annotated. Returns: diff --git a/backends/arm/test/ops/test_silu.py b/backends/arm/test/ops/test_silu.py new file mode 100644 index 00000000000..51748b02450 --- /dev/null +++ b/backends/arm/test/ops/test_silu.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Optional, Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, + EthosU85PipelineBI, + TosaPipelineBI, + TosaPipelineMI, +) + + +input_t = Tuple[torch.Tensor] + + +class Silu(torch.nn.Module): + def forward( + self, + _input: torch.Tensor, + _inplace: Optional[bool] = False, + ): + return torch.nn.SiLU(inplace=_inplace)(_input) + + test_data: list[input_t] = { + "op_silu_rank1_ones": (torch.ones(5),), + "op_silu_rank1_negative_ones": (torch.ones(5) * (-1),), + "op_silu_rank1_rand": (torch.rand(5) * 5,), + "op_silu_rank4_ones": (torch.ones(1, 10, 25, 20),), + "op_silu_rank4_negative_ones": ((-1) * torch.ones(1, 10, 25, 20),), + "op_silu_rank4_large_rand": (200 * torch.rand(1, 10, 25, 20),), + "op_silu_rank4_negative_large_rand": ((-200) * torch.rand(1, 10, 25, 20),), + "op_silu_rank4_large_randn": (200 * torch.randn(1, 10, 25, 20) + 1,), + } + + aten_op_MI = "torch.ops.aten.silu.default" + aten_op_inplace_MI = "torch.ops.aten.silu_.default" + aten_op_BI = ["torch.ops.aten.sigmoid.default", "torch.ops.aten.mul.Tensor"] + + +@common.parametrize("test_data", Silu.test_data) +def test_silu_tosa_MI(test_data: input_t): + silu_data = (test_data[0], False) + pipeline = TosaPipelineMI[input_t](Silu(), silu_data, Silu.aten_op_MI) + pipeline.run() + + +@common.parametrize("test_data", Silu.test_data) +def test_silu_tosa_MI_inplace(test_data: input_t): + silu_data = (test_data[0], True) + pipeline = TosaPipelineMI[input_t](Silu(), silu_data, Silu.aten_op_inplace_MI) + pipeline.run() + + +@common.parametrize("test_data", Silu.test_data) +def test_silu_tosa_BI(test_data: input_t): + silu_data = (test_data[0], False) + pipeline = TosaPipelineBI[input_t](Silu(), silu_data, Silu.aten_op_BI) + pipeline.run() + + +@common.parametrize("test_data", Silu.test_data) +def test_silu_tosa_BI_inplace(test_data: input_t): + silu_data = (test_data[0], True) + pipeline = TosaPipelineBI[input_t](Silu(), silu_data, Silu.aten_op_BI) + pipeline.run() + + +@common.parametrize("test_data", Silu.test_data) +@common.XfailIfNoCorstone300 +def test_silu_u55_BI(test_data: input_t): + silu_data = (test_data[0], False) + pipeline = EthosU55PipelineBI[input_t]( + Silu(), silu_data, Silu.aten_op_BI, run_on_fvp=True + ) + pipeline.run() + + +@common.parametrize("test_data", Silu.test_data) +@common.XfailIfNoCorstone300 +def test_silu_u55_BI_inplace(test_data: input_t): + silu_data = (test_data[0], True) + pipeline = EthosU55PipelineBI[input_t]( + Silu(), silu_data, Silu.aten_op_BI, run_on_fvp=True + ) + pipeline.run() + + +@common.parametrize("test_data", Silu.test_data) +@common.XfailIfNoCorstone320 +def test_silu_u85_BI(test_data: input_t): + silu_data = (test_data[0], False) + pipeline = EthosU85PipelineBI[input_t]( + Silu(), silu_data, Silu.aten_op_BI, run_on_fvp=True + ) + pipeline.run() + + +@common.parametrize("test_data", Silu.test_data) +@common.XfailIfNoCorstone320 +def test_silu_u85_BI_inplace(test_data: input_t): + silu_data = (test_data[0], True) + pipeline = EthosU85PipelineBI[input_t]( + Silu(), silu_data, Silu.aten_op_BI, run_on_fvp=True + ) + pipeline.run()