Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
FoldAndAnnotateQParamsPass,
QuantizeFullArgument,
QuantizeOperatorArguments,
RetraceFoldedDtypesPass,
)
from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( # type: ignore[import-not-found]
Expand Down Expand Up @@ -92,7 +92,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(ConvertMeanDimToAveragePoolPass())

self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeFullArgument())
self.add_pass(QuantizeOperatorArguments())
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
self.add_pass(RetraceFoldedDtypesPass())
self.add_pass(InsertTableOpsPass(exported_program))
Expand Down Expand Up @@ -128,7 +128,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(DecomposeSoftmaxesPass())

self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeFullArgument())
self.add_pass(QuantizeOperatorArguments())
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
self.add_pass(RetraceFoldedDtypesPass())
self.add_pass(InsertTableOpsPass(exported_program))
Expand Down
46 changes: 34 additions & 12 deletions backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,19 +167,25 @@ def call(self, graph_module: GraphModule) -> PassResult:
return PassResult(graph_module, True)


class QuantizeFullArgument(ExportPass):
class QuantizeOperatorArguments(ExportPass):
"""
Make sure the fill_value for full.default is quantized. This pass needs to be run before
the folding pass above to make sure that the retraced output of the full.default op is
the right dtype.
This pass makes sure that the arguments to full.default and clamp.default are quantized correctly.
More specifically, this pass:
- Makes sure the fill_value for full.default is quantized. This pass needs to be run before
the folding pass above to make sure that the retraced output of the full.default op is
the right dtype.
- Makes sure the min and max values to clamp.default are quantized, if it's a quantized operator.
"""

def call(self, graph_module: GraphModule) -> PassResult:
modified = False
# Loop over the graph nodes and find full.default nodes.
for n in graph_module.graph.nodes:
n = cast(Node, n)
if n.target != exir_ops.edge.aten.full.default:
if n.target not in {
exir_ops.edge.aten.clamp.default,
exir_ops.edge.aten.full.default,
}:
continue

# Make sure we have a quantized operator
Expand All @@ -188,13 +194,29 @@ def call(self, graph_module: GraphModule) -> PassResult:
continue

qargs = QuantArgs.from_operator(user.target, user.args)
if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
# replace the node arg with a quantized dito and also set dtype
# to get the right output according to the Edge IR specification:
# exir/dialects/edge/edge.yaml:3596
quantized_full_value = qargs.quantize_value(n.args[1]).item()
n.update_arg(1, quantized_full_value)
n.update_kwarg("dtype", qargs.dtype)

if n.target == exir_ops.edge.aten.full.default:
if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
# replace the node arg with a quantized ditto and also set dtype
# to get the right output according to the Edge IR specification:
# exir/dialects/edge/edge.yaml:3596
quantized_full_value = qargs.quantize_value(n.args[1]).item()
n.update_arg(1, quantized_full_value)
n.update_kwarg("dtype", qargs.dtype)
modified = True
elif n.target == exir_ops.edge.aten.clamp.default:
# Quantize the min and max arguments of clamp, if they are not None
min_val = n.args[1]
max_val = None if len(n.args) <= 2 else n.args[2]

if min_val is not None:
quantized_min_val = qargs.quantize_value(min_val).item()
n.update_arg(1, quantized_min_val)

if max_val is not None:
quantized_max_val = qargs.quantize_value(max_val).item()
n.update_arg(2, quantized_max_val)

modified = True

return PassResult(graph_module, modified)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.expand_copy.default,
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.clamp.default,
exir_ops.edge.aten.bmm.default,
exir_ops.edge.aten.permute_copy.default,
exir_ops.edge.aten.hardtanh.default,
Expand Down
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
op_batch_norm,
op_bmm,
op_cat,
op_clamp,
op_conv2d,
op_eq,
op_exp,
Expand Down
144 changes: 144 additions & 0 deletions backends/arm/operators/op_clamp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree

from typing import Any, List, Tuple

import serializer.tosa_serializer as ts # type: ignore

import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)

from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class ClampVisitor_080_BI(NodeVisitor):
    """Node visitor emitting a TOSA CLAMP operator for ``aten.clamp.default``.

    Registered for the ``TOSA-0.80+BI`` specification. Only the integer
    bounds of the clamp attribute are populated; the fp32 bounds are set to 0.
    """

    target = "aten.clamp.default"

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80+BI"),
    ]

    def __init__(self, *args):
        super().__init__(*args)

    def _create_clamp_node(
        self,
        tosa_graph: ts.TosaSerializer,
        input_name: str,
        output_name: str,
        min_int: int,
        max_int: int,
        min_fp32: float,
        max_fp32: float,
    ) -> None:
        """Append a single CLAMP operator to ``tosa_graph``.

        Args:
            tosa_graph: Serializer the operator is added to.
            input_name: Name of the tensor being clamped.
            output_name: Name of the tensor receiving the result.
            min_int: Lower bound stored in the integer attribute field.
            max_int: Upper bound stored in the integer attribute field.
            min_fp32: Lower bound stored in the floating-point attribute field.
            max_fp32: Upper bound stored in the floating-point attribute field.
        """
        clamp_attr = ts.TosaSerializerAttribute()
        clamp_attr.ClampAttribute(
            tosa_graph.builder, min_int, max_int, min_fp32, max_fp32
        )
        tosa_graph.addOperator(
            TosaOp.Op().CLAMP,
            [input_name],
            [output_name],
            clamp_attr,
        )

    def _get_min_max_arguments(
        self, node: Node, dtype_min: int | float, dtype_max: int | float
    ) -> Tuple[int | float, int | float]:
        """Read the clamp bounds from ``node.args``.

        Args:
            node: Clamp node; ``args[1]`` is the minimum and the optional
                ``args[2]`` is the maximum.
            dtype_min: Value used when the minimum is ``None``.
            dtype_max: Value used when the maximum is ``None`` or absent.

        Returns:
            A ``(min, max)`` tuple. Integer bounds are returned unchanged,
            anything else is converted with ``float()``.
        """

        def _as_number(raw: Any) -> int | float:
            # Integers pass through untouched; everything else is coerced to float.
            return raw if isinstance(raw, int) else float(raw)

        num_args = len(node.args)
        assert 2 <= num_args <= 3

        lower = node.args[1]
        upper = node.args[2] if num_args > 2 else None

        min_arg = dtype_min if lower is None else _as_number(lower)
        max_arg = dtype_max if upper is None else _as_number(upper)
        return min_arg, max_arg

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Emit an integer CLAMP for ``node``, defaulting to the int8 range.

        Args:
            node: The clamp node being lowered; must have exactly one input node.
            tosa_graph: Serializer receiving the operator.
            inputs: Lowered arguments; only ``inputs[0]`` (the tensor) is used.
            output: Lowered output tensor.
        """
        assert len(node.all_input_nodes) == 1

        int8_limits = torch.iinfo(torch.int8)
        lower, upper = self._get_min_max_arguments(
            node, int8_limits.min, int8_limits.max
        )

        # NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
        self._create_clamp_node(
            tosa_graph,
            inputs[0].name,
            output.name,
            int(lower),
            int(upper),
            0,
            0,
        )


@register_node_visitor
class ClampVisitor_080_MI(ClampVisitor_080_BI):
    """Clamp visitor registered for the ``TOSA-0.80+MI`` specification.

    INT8 inputs are delegated to the parent class; any other input dtype is
    lowered using the floating-point bounds of the clamp attribute.
    """

    # inheriting 'target' from BI class

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80+MI"),
    ]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Emit a CLAMP for ``node``, picking the integer or fp32 path by dtype.

        Args:
            node: The clamp node being lowered; must have exactly one input node.
            tosa_graph: Serializer receiving the operator.
            inputs: Lowered arguments; ``inputs[0]`` is the tensor to clamp.
            output: Lowered output tensor.
        """
        assert len(node.all_input_nodes) == 1

        if inputs[0].dtype == ts.DType.INT8:
            # Call the inherited define_node for handling integers
            super().define_node(node, tosa_graph, inputs, output)
            return

        fp32_limits = torch.finfo(torch.float32)
        lower, upper = self._get_min_max_arguments(
            node, fp32_limits.min, fp32_limits.max
        )

        # Integer fields are unused on this path and set to 0.
        self._create_clamp_node(
            tosa_graph,
            inputs[0].name,
            output.name,
            0,
            0,
            lower,
            upper,
        )
2 changes: 2 additions & 0 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ def _match_pattern(
torch.ops.aten.full.default,
torch.ops.aten.flatten.using_ints,
torch.ops.aten.dropout.default,
torch.ops.aten.clamp.default,
torch.ops.aten.clamp.Tensor,
operator.getitem,
]

Expand Down
Loading
Loading