Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def _match_pattern(
torch.ops.aten.max_pool2d.default,
torch.ops.aten.full.default,
torch.ops.aten.full,
torch.ops.aten.fill_.Scalar,
torch.ops.aten.flatten.using_ints,
torch.ops.aten.dropout.default,
torch.ops.aten.dropout_.default,
Expand Down Expand Up @@ -625,6 +626,7 @@ def annotate_graph( # type: ignore[return]
torch.ops.aten.full_like.default,
torch.ops.aten.full.default,
torch.ops.aten.full,
torch.ops.aten.fill_.Scalar,
torch.ops.aten.scalar_tensor.default,
]:
node.kwargs = {}
108 changes: 108 additions & 0 deletions backends/arm/test/ops/test_fill_scalar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineINT,
EthosU85PipelineINT,
TosaPipelineFP,
TosaPipelineINT,
VgfPipeline,
)

aten_op = "torch.ops.aten.fill_.Scalar"
exir_op = "executorch_exir_dialects_edge__ops_aten_full_like_default"

input_t1 = Tuple[torch.Tensor]

test_data_suite = {
"ones_float": [torch.ones(2, 3), 5.0],
"ones_int": [torch.ones(2, 3), -3],
}


class FillScalar(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, y: torch.Tensor, fill_value: int | float):
mask = torch.full_like(y, 0)
mask.fill_(fill_value)
return mask * y


@common.parametrize("test_data", test_data_suite)
def test_fill_scalar_tosa_FP(test_data: Tuple):
pipeline = TosaPipelineFP[input_t1](
FillScalar(),
(*test_data,),
aten_op=aten_op,
exir_op=exir_op,
)
pipeline.run()


@common.parametrize("test_data", test_data_suite)
def test_fill_scalar_tosa_INT(test_data: Tuple):
pipeline = TosaPipelineINT[input_t1](
FillScalar(),
(*test_data,),
aten_op=aten_op,
exir_op=exir_op,
)
pipeline.run()


@common.XfailIfNoCorstone300
@common.parametrize("test_data", test_data_suite)
def test_fill_scalar_u55_INT(test_data: Tuple):
pipeline = EthosU55PipelineINT[input_t1](
FillScalar(),
(*test_data,),
aten_ops=[aten_op],
exir_ops=exir_op,
)
pipeline.run()


@common.XfailIfNoCorstone320
@common.parametrize("test_data", test_data_suite)
def test_fill_scalar_u85_INT(test_data: Tuple):
pipeline = EthosU85PipelineINT[input_t1](
FillScalar(),
(*test_data,),
aten_ops=[aten_op],
exir_ops=exir_op,
)
pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.SkipIfNoModelConverter
def test_fill_scalar_vgf_FP(test_data: input_t1):
pipeline = VgfPipeline[input_t1](
FillScalar(),
(*test_data,),
aten_op,
exir_op,
tosa_version="TOSA-1.0+FP",
)
pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.SkipIfNoModelConverter
def test_fill_scalar_vgf_INT(test_data: input_t1):
pipeline = VgfPipeline[input_t1](
FillScalar(),
(*test_data,),
aten_op,
exir_op,
tosa_version="TOSA-1.0+INT",
)
pipeline.run()
Loading