Skip to content

Commit 08759d6

Browse files
authored
Merge branch 'main' into dev1/winskuo/mimi_stage2
2 parents 4501ff2 + e29a4b5 commit 08759d6

21 files changed

+583
-194
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .decompose_select import DecomposeSelectPass # noqa
2828
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
2929
from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa
30+
from .decompose_sqrt_pass import DecomposeSqrtPass # noqa
3031
from .decompose_var_pass import DecomposeVarPass # noqa
3132
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
3233
FoldAndAnnotateQParamsPass,

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
DecomposeSelectPass,
3333
DecomposeSoftmaxPass,
3434
DecomposeSoftmaxUnstablePass,
35+
DecomposeSqrtPass,
3536
DecomposeVarPass,
3637
FoldAndAnnotateQParamsPass,
3738
FuseBatchnorm2DPass,
@@ -115,6 +116,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
115116
return self._transform(exported_program.graph_module)
116117

117118
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
119+
self.add_pass(DecomposeSqrtPass())
118120
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
119121
self.add_pass(FuseQuantizedActivationPass())
120122
self.add_pass(RemoveGetItemPass())
@@ -181,6 +183,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
181183
self.add_pass(DecomposeMeanDimPass())
182184
self.add_pass(DecomposeDivPass())
183185
self.add_pass(DecomposeLeakyReLUPass())
186+
self.add_pass(DecomposeSqrtPass())
184187

185188
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
186189
# Numerically stable softmax uses amax which is not supported on Ethos-U55
backends/arm/_passes/decompose_sqrt_pass.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.pass_base import ExportPass
10+
11+
# sqrt overloads recognised by DecomposeSqrtPass, grouped by dialect so the
# matching pow overload can be picked from the same dialect.
edge_sqrt_ops = (exir_ops.edge.aten.sqrt.default,)
aten_sqrt_ops = (torch.ops.aten.sqrt.default, torch.ops.aten.sqrt_.default)
16+
17+
18+
def get_sqrt_decomposition(op):
    """Return the pow operator that replaces the given sqrt operator.

    The result comes from the same dialect as ``op``: an op listed in
    ``edge_sqrt_ops`` maps to the edge ``pow.Tensor_Scalar`` and an op listed
    in ``aten_sqrt_ops`` maps to the ATen ``pow.Tensor_Scalar``.

    Args:
        op: The operator overload to decompose.

    Returns:
        A single ``pow.Tensor_Scalar`` operator overload. Note that this is
        one operator, not a tuple of operators.

    Raises:
        RuntimeError: If ``op`` is not one of the known sqrt overloads.
    """
    # TODO : "MLETORCH-863 : Replace current sqrt -> pow.Tensor_Scalar workaround with pow.Tensor_Tensor"
    if op in edge_sqrt_ops:
        return exir_ops.edge.aten.pow.Tensor_Scalar
    if op in aten_sqrt_ops:
        return torch.ops.aten.pow.Tensor_Scalar
    raise RuntimeError(f"Can't get sqrt decomposition for op {op}")
25+
26+
27+
class DecomposeSqrtPass(ExportPass):
    """Export pass that rewrites sqrt operators as pow with exponent 0.5."""

    def call_operator(self, op, args, kwargs, meta):
        """
        Replace `sqrt(x)` with `pow(x, 0.5)`; forward every other op untouched.
        """
        is_sqrt = op in edge_sqrt_ops or op in aten_sqrt_ops
        if is_sqrt:
            # Pick the pow overload from the same dialect as the sqrt op.
            replacement = get_sqrt_decomposition(op)
            # Only the input tensor is carried over; the exponent is fixed
            # at 0.5 and no keyword arguments are passed on.
            return super().call_operator(replacement, (args[0], 0.5), {}, meta)

        return super().call_operator(op, args, kwargs, meta)

backends/arm/_passes/match_arg_ranks_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ def __init__(self, exported_program):
4848
exir_ops.edge.aten.bitwise_right_shift.Tensor,
4949
exir_ops.edge.aten.bitwise_left_shift.Tensor,
5050
exir_ops.edge.aten.eq.Tensor,
51+
exir_ops.edge.aten.gt.Tensor,
52+
exir_ops.edge.aten.lt.Tensor,
5153
exir_ops.edge.aten.pow.Tensor_Tensor,
5254
exir_ops.edge.aten.where.self,
5355
]

backends/arm/_passes/replace_scalar_with_tensor_pass.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,17 @@
2626
exir_ops.edge.aten.__rshift__.Scalar: exir_ops.edge.aten.bitwise_right_shift.Tensor,
2727
exir_ops.edge.aten.__lshift__.Scalar: exir_ops.edge.aten.bitwise_left_shift.Tensor,
2828
exir_ops.edge.aten.eq.Scalar: exir_ops.edge.aten.eq.Tensor,
29+
exir_ops.edge.aten.gt.Scalar: exir_ops.edge.aten.gt.Tensor,
30+
exir_ops.edge.aten.lt.Scalar: exir_ops.edge.aten.lt.Tensor,
2931
torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
3032
torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
3133
torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
3234
torch.ops.aten.div.Scalar: torch.ops.aten.div.Tensor,
3335
torch.ops.aten.__rshift__.Scalar: torch.ops.aten.bitwise_right_shift.Tensor,
3436
torch.ops.aten.__lshift__.Scalar: torch.ops.aten.bitwise_left_shift.Tensor,
3537
torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor,
38+
torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor,
39+
torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor,
3640
}
3741

3842

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,10 @@ class EthosU55NotSupported(OperatorSupportBase):
135135
exir_ops.edge.aten.eq.Scalar,
136136
exir_ops.edge.aten.ge.Tensor,
137137
exir_ops.edge.aten.gt.Tensor,
138+
exir_ops.edge.aten.gt.Scalar,
138139
exir_ops.edge.aten.le.Tensor,
139140
exir_ops.edge.aten.lt.Tensor,
141+
exir_ops.edge.aten.lt.Scalar,
140142
exir_ops.edge.aten.flip.default, # REVERSE
141143
exir_ops.edge.aten.grid_sampler_2d, # GATHER
142144
exir_ops.edge.aten.scatter.src,

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,10 @@ def is_node_supported(
176176
exir_ops.edge.aten.full_like.default,
177177
exir_ops.edge.aten.ge.Tensor,
178178
exir_ops.edge.aten.gt.Tensor,
179+
exir_ops.edge.aten.gt.Scalar,
179180
exir_ops.edge.aten.le.Tensor,
180181
exir_ops.edge.aten.lt.Tensor,
182+
exir_ops.edge.aten.lt.Scalar,
181183
exir_ops.edge.aten.mul.Tensor,
182184
exir_ops.edge.aten.add.Scalar,
183185
exir_ops.edge.aten.sub.Scalar,
@@ -194,6 +196,7 @@ def is_node_supported(
194196
exir_ops.edge.aten.reciprocal.default,
195197
exir_ops.edge.aten.relu.default,
196198
exir_ops.edge.aten.leaky_relu.default,
199+
exir_ops.edge.aten.sqrt.default,
197200
exir_ops.edge.aten.rsqrt.default,
198201
exir_ops.edge.aten._softmax.default,
199202
exir_ops.edge.aten.select_copy.int,
@@ -256,6 +259,7 @@ def is_node_supported(
256259
exir_ops.edge.aten.var.correction,
257260
exir_ops.edge.aten.var.dim,
258261
exir_ops.edge.aten.add.Scalar,
262+
exir_ops.edge.aten.sqrt.default,
259263
exir_ops.edge.aten.sub.Scalar,
260264
exir_ops.edge.aten.mul.Scalar,
261265
exir_ops.edge.aten.div.Scalar,

backends/arm/operators/op_clamp.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def cast_type(value: Any) -> int | float:
6363
# Attempt to cast to float
6464
return float(value)
6565

66-
assert 2 <= len(node.args) <= 3
66+
if len(node.args) != 2 and len(node.args) != 3:
67+
raise ValueError(f"Expected len(node.args) to be 2 or 3, got {node.args}")
6768

6869
min_arg = dtype_min
6970
max_arg = dtype_max
@@ -84,7 +85,10 @@ def define_node(
8485
inputs: List[TosaArg],
8586
output: TosaArg,
8687
) -> None:
87-
assert len(node.all_input_nodes) == 1
88+
if len(node.all_input_nodes) != 1:
89+
raise ValueError(
90+
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
91+
)
8892

8993
min_int8, max_int8 = self._get_min_max_arguments(
9094
node,
@@ -122,7 +126,10 @@ def define_node(
122126
inputs: List[TosaArg],
123127
output: TosaArg,
124128
) -> None:
125-
assert len(node.all_input_nodes) == 1
129+
if len(node.all_input_nodes) != 1:
130+
raise ValueError(
131+
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
132+
)
126133

127134
if inputs[0].dtype == ts.DType.INT8:
128135
# Call the inherited define_node for handling integers

backends/arm/operators/op_maximum.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,27 @@ def define_node(
3636
inputs: List[TosaArg],
3737
output: TosaArg,
3838
) -> None:
39-
assert inputs[0].dtype == inputs[1].dtype
39+
if inputs[0].dtype != inputs[1].dtype and inputs[0].dtype != output.dtype:
40+
raise TypeError(
41+
f"Data type of inputs and output must be the same. Got input 0 dtype: "
42+
f"{inputs[0].dtype}, input 1 dtype: {inputs[1].dtype} and output "
43+
f"dtype: {output.dtype}"
44+
)
4045

4146
scale_back = 1.0
4247
max_output = output
4348
if inputs[0].dtype == ts.DType.INT8:
4449
input_qparams = get_input_qparams(node)
45-
assert (
46-
len(input_qparams) == 2
47-
), f"Both inputs needs to have quantization information for {node}"
48-
# insert RESCALEs to int32
49-
assert (
50-
input_qparams[0] == input_qparams[1]
51-
), "Both inputs must have same quantization for MAX"
50+
if len(input_qparams) != 2:
51+
raise ValueError(
52+
f"Both inputs need to have quantization information for {node}"
53+
)
54+
if input_qparams[0] != input_qparams[1]:
55+
raise ValueError(
56+
"Both inputs must have the same quantization parameters for MAX"
57+
)
5258

59+
# insert RESCALEs to int32
5360
operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
5461
tosa_graph, inputs, node
5562
)

backends/arm/operators/op_minimum.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,27 @@ def define_node(
3737
inputs: List[TosaArg],
3838
output: TosaArg,
3939
) -> None:
40-
assert inputs[0].dtype == inputs[1].dtype
40+
if inputs[0].dtype != inputs[1].dtype and inputs[0].dtype != output.dtype:
41+
raise TypeError(
42+
f"Data type of inputs and output must be the same. Got input 0 dtype: "
43+
f"{inputs[0].dtype}, input 1 dtype: {inputs[1].dtype} and output "
44+
f"dtype: {output.dtype}"
45+
)
4146

4247
scale_back = 1.0
4348
min_output = output
4449
if inputs[0].dtype == ts.DType.INT8:
4550
input_qparams = get_input_qparams(node)
46-
assert (
47-
len(input_qparams) == 2
48-
), f"Both inputs needs to have quantization information for {node}"
49-
# insert RESCALEs to int32
50-
assert (
51-
input_qparams[0] == input_qparams[1]
52-
), "Both inputs must have same quantization for MIN"
51+
if len(input_qparams) != 2:
52+
raise ValueError(
53+
f"Both inputs need to have quantization information for {node}"
54+
)
55+
if input_qparams[0] != input_qparams[1]:
56+
raise ValueError(
57+
"Both inputs must have the same quantization parameters for MIN"
58+
)
5359

60+
# insert RESCALEs to int32
5461
operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
5562
tosa_graph, inputs, node
5663
)

0 commit comments

Comments
 (0)