From 665b117a2dc08e40beaf7f5bcfa932ef62597183 Mon Sep 17 00:00:00 2001
From: Teo Bergkvist
Date: Tue, 5 Aug 2025 16:07:15 +0200
Subject: [PATCH] Arm backend: Add support for fill_.Scalar

Adds support for fill_.Scalar. fill_.Scalar is decomposed to full_like,
which is already supported by the Arm backend.

Co-authored-by: Teo Bergkvist
Signed-off-by: Oscar Andersson
Change-Id: I3067136dec1d2762135158760a5368cc9103f74b
---
 .../arm/quantizer/quantization_annotator.py |   2 +
 backends/arm/test/ops/test_fill_scalar.py   | 108 ++++++++++++++++++
 2 files changed, 110 insertions(+)
 create mode 100644 backends/arm/test/ops/test_fill_scalar.py

diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py
index ff1ad50e517..bea8fe2eddc 100644
--- a/backends/arm/quantizer/quantization_annotator.py
+++ b/backends/arm/quantizer/quantization_annotator.py
@@ -360,6 +360,7 @@ def _match_pattern(
     torch.ops.aten.max_pool2d.default,
     torch.ops.aten.full.default,
     torch.ops.aten.full,
+    torch.ops.aten.fill_.Scalar,
     torch.ops.aten.flatten.using_ints,
     torch.ops.aten.dropout.default,
     torch.ops.aten.dropout_.default,
@@ -625,6 +626,7 @@ def annotate_graph(  # type: ignore[return]
             torch.ops.aten.full_like.default,
             torch.ops.aten.full.default,
             torch.ops.aten.full,
+            torch.ops.aten.fill_.Scalar,
             torch.ops.aten.scalar_tensor.default,
         ]:
             node.kwargs = {}
diff --git a/backends/arm/test/ops/test_fill_scalar.py b/backends/arm/test/ops/test_fill_scalar.py
new file mode 100644
index 00000000000..fb84d993575
--- /dev/null
+++ b/backends/arm/test/ops/test_fill_scalar.py
@@ -0,0 +1,108 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineINT,
+    EthosU85PipelineINT,
+    TosaPipelineFP,
+    TosaPipelineINT,
+    VgfPipeline,
+)
+
+aten_op = "torch.ops.aten.fill_.Scalar"
+exir_op = "executorch_exir_dialects_edge__ops_aten_full_like_default"
+
+input_t1 = Tuple[torch.Tensor]
+
+test_data_suite = {
+    "ones_float": [torch.ones(2, 3), 5.0],
+    "ones_int": [torch.ones(2, 3), -3],
+}
+
+
+class FillScalar(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, y: torch.Tensor, fill_value: int | float):
+        mask = torch.full_like(y, 0)
+        mask.fill_(fill_value)
+        return mask * y
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_fill_scalar_tosa_FP(test_data: Tuple):
+    pipeline = TosaPipelineFP[input_t1](
+        FillScalar(),
+        (*test_data,),
+        aten_op=aten_op,
+        exir_op=exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_fill_scalar_tosa_INT(test_data: Tuple):
+    pipeline = TosaPipelineINT[input_t1](
+        FillScalar(),
+        (*test_data,),
+        aten_op=aten_op,
+        exir_op=exir_op,
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone300
+@common.parametrize("test_data", test_data_suite)
+def test_fill_scalar_u55_INT(test_data: Tuple):
+    pipeline = EthosU55PipelineINT[input_t1](
+        FillScalar(),
+        (*test_data,),
+        aten_ops=[aten_op],
+        exir_ops=exir_op,
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone320
+@common.parametrize("test_data", test_data_suite)
+def test_fill_scalar_u85_INT(test_data: Tuple):
+    pipeline = EthosU85PipelineINT[input_t1](
+        FillScalar(),
+        (*test_data,),
+        aten_ops=[aten_op],
+        exir_ops=exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.SkipIfNoModelConverter
+def test_fill_scalar_vgf_FP(test_data: input_t1):
+    pipeline = VgfPipeline[input_t1](
+        FillScalar(),
+        (*test_data,),
+        aten_op,
+        exir_op,
+        tosa_version="TOSA-1.0+FP",
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.SkipIfNoModelConverter
+def test_fill_scalar_vgf_INT(test_data: input_t1):
+    pipeline = VgfPipeline[input_t1](
+        FillScalar(),
+        (*test_data,),
+        aten_op,
+        exir_op,
+        tosa_version="TOSA-1.0+INT",
+    )
+    pipeline.run()
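
For illustration only, outside the patch itself: a minimal sketch, assuming a stock
torch.export flow, of the decomposition the commit message relies on. FillExample and
the final check below are hypothetical names used only here; they do not appear in the
patch.

# Minimal sketch: export a module that calls Tensor.fill_ and check that the
# decomposed graph contains aten.full_like, which the Arm backend already
# supports. Mirrors the FillScalar pattern used in the new test file.
import torch


class FillExample(torch.nn.Module):  # hypothetical example module
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.full_like(x, 0)
        out.fill_(5.0)  # traced as torch.ops.aten.fill_.Scalar
        return out * x


ep = torch.export.export(FillExample(), (torch.ones(2, 3),))
ep = ep.run_decompositions()  # functionalizes fill_ and applies core decompositions
targets = {str(n.target) for n in ep.graph.nodes if n.op == "call_function"}
# Expected to be True given the decomposition described in the commit message.
print(any("full_like" in t for t in targets))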