13 changes: 13 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -29,6 +29,9 @@
DecomposeSoftmaxesPass,
)
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
FoldAndAnnotateQParamsPass,
)
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
KeepDimsFalseToSqueezePass,
)
@@ -50,6 +53,7 @@
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
from executorch.exir import ExportedProgram
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_manager import PassManager


@@ -80,6 +84,15 @@ def transform_to_backend_pipeline(
self.add_pass(Conv1dUnsqueezePass(exported_program))
self.add_pass(DecomposeSoftmaxesPass())
self.add_pass(DecomposeLinearPass())
self.add_pass(
FoldAndAnnotateQParamsPass(
[
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.add.Tensor,
]
)
)
for spec in compile_spec:
if spec.key == "permute_memory_format":
memory_format = spec.value.decode()
131 changes: 131 additions & 0 deletions backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
@@ -0,0 +1,131 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy

from typing import Callable, cast, Iterable

from executorch.backends.arm.tosa_quant_utils import QuantArgs

from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule, Node


def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
"""
Get the input quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'.
Raises a ValueError if the node doesn't have any parameters set.
"""
if "input_qparams" not in node.meta.keys():
raise ValueError(f"No input quantization parameter found in node {node}")
input_qparams = cast(dict[int, QuantArgs], node.meta["input_qparams"])
if len(input_qparams) == 0:
raise ValueError(f"No input quantization parameter found in node {node}")
return input_qparams


def get_output_qparams(node: Node) -> dict[int, QuantArgs]:
"""
Get the output quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'.
Raises a ValueError if the node doesn't have any parameters set.
"""
if "output_qparams" not in node.meta.keys():
raise ValueError(f"No output quantization parameter found in node {node}")
input_qparams = cast(dict[int, QuantArgs], node.meta["output_qparams"])
if len(input_qparams) == 0:
raise ValueError(f"No output quantization parameter found in node {node}")
return input_qparams


class FoldAndAnnotateQParamsPass(ExportPass):
"""
A pass that walks the graph and removes the DQ and Q nodes around any node whose
target is in the supplied list of operators.
The quantization parameters from the DQ/Q nodes are stored as meta values to be
accessible for later lowering and serialization passes.
The assumption is that the quantization annotation adds DQ nodes for all tensor
inputs to the target node and one Q node to the output.

Example ('executorch_exir_dialects_edge__ops_' prefix removed from operators for readability):

x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)

x_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(x_q, 0.05487706884741783, -128, -128, 127, torch.int8)
aten_add_tensor: "f32[5]" = ops_aten_add_Tensor(x_dq, x_dq)
aten_add_tensor_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8)

output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor_q, 0.05487706884741783, -128, -128, 127, torch.int8)

Becomes:
x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)

aten_add_tensor: "i8[5]" = aten_add_Tensor(x_q, x_q)

output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor_q, 0.05487706884741783, -128, -128, 127, torch.int8)

The quantization parameters for x_dq and aten_add_tensor_q are stored in meta for the aten_add_tensor node.

"""

def __init__(self, targeted_ops: Iterable[Callable]):
super().__init__()
self.targeted_ops = targeted_ops

def call(self, graph_module: GraphModule) -> PassResult:
q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default

# Loop over the graph nodes and find any node in the 'targeted_ops' list.
for n in graph_module.graph.nodes:
n = cast(Node, n)
if n.op != "call_function" or n.target not in self.targeted_ops:
continue

# Make sure we haven't already set qparams meta information on the node
assert "input_qparams" not in n.meta.keys()
assert "output_qparams" not in n.meta.keys()

# For the inputs and outputs, search the graph for quantization info and
# store the information in a dict keyed by the order of the _tensor_ inputs,
# ignoring any other arguments to the target node.
n.meta["input_qparams"] = {}
n.meta["output_qparams"] = {}
for i, arg in enumerate(n.args):
if not isinstance(arg, Node):
continue
if arg.target != dq_op:
continue

# Argument i is produced by a dequant node, extract its quantization parameters
n.meta["input_qparams"][i] = QuantArgs.from_operator(
arg.target, arg.args
)

# arg.args[0] is the tensor input, replace the input usage
n.replace_input_with(arg, arg.args[0])
graph_module.graph.erase_node(arg)

# Copy the users, since we modify the collection while iterating over it.
users_copy = copy.copy(n.users)
for i, user in enumerate(users_copy):
if user.target != q_op:
continue

# Quantization node found here, store the quantization parameters in the node's meta
n.meta["output_qparams"][i] = QuantArgs.from_operator(
user.target, user.args
)

user.replace_all_uses_with(n)
graph_module.graph.erase_node(user)

# retrace the graph to update the fake tensor types
graph_module = super().call(graph_module).graph_module

graph_module.recompile()
return PassResult(graph_module, True)
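
A minimal usage sketch (illustration only, not part of this change) of how a later lowering step can read the annotations written by this pass. get_input_qparams, get_output_qparams and QuantArgs are the symbols added above; the helper name lower_quantized_binary_op is made up for this example.

from torch.fx import Node

from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
    get_input_qparams,
    get_output_qparams,
)


def lower_quantized_binary_op(node: Node) -> None:
    # Both helpers raise ValueError if FoldAndAnnotateQParamsPass has not
    # annotated this node, so a missing annotation surfaces early in lowering.
    input_qparams = get_input_qparams(node)    # {tensor-arg index: QuantArgs}
    output_qparams = get_output_qparams(node)  # {user index: QuantArgs}
    # The QuantArgs entries carry the scale/zero-point data needed when
    # emitting TOSA RESCALE operators around the integer op.
    assert input_qparams[0] == input_qparams[1], "inputs must share qparams"
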
2 changes: 2 additions & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -94,6 +94,8 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.mm.default,
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.repeat.default,
exir_ops.edge.aten.reciprocal.default,
exir_ops.edge.aten.relu.default,
2 changes: 2 additions & 0 deletions backends/arm/operators/__init__.py
@@ -19,7 +19,9 @@
op_get_item,
op_hardtanh,
op_log,
op_max,
op_max_pool2d,
op_min,
op_mm,
op_mul,
op_permute,
51 changes: 26 additions & 25 deletions backends/arm/operators/op_add.py
@@ -11,7 +11,6 @@
import executorch.backends.arm.tosa_utils as tutils

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
@@ -41,33 +40,27 @@ def define_node(
output: TosaArg,
is_quant_node: bool,
) -> None:
input_nodes = tutils.get_two_inputs(node)

if not is_quant_node and not all(
tensor.meta["val"].dtype in (torch.int8, torch.int32)
for tensor in input_nodes
):
raise RuntimeError(
f"Unexpected non quantized {AddVisitor_080_BI.target} node."
)

needs_rescale = not (
all(tensor.meta["val"].dtype == torch.int32 for tensor in input_nodes)
and node.meta["val"].dtype == torch.int32
)

if needs_rescale:
# Rescale inputs to 32 bit
rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
input_nodes, tosa_graph
# Specification (0.80.0) states that input and output types
# should all be the same
assert inputs[0].dtype == inputs[1].dtype == output.dtype
# Handle int8 (quantized) and int32
assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]

if inputs[0].dtype == ts.DType.INT8:
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
tosa_graph, inputs, node
)
else:
# inputs[0].dtype == ts.DType.INT32
# Non-quantized input, natively supported by TOSA.ADD
rescaled_inputs = inputs

# Prepare add output tensor
if output.dtype == ts.DType.INT8:
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
else:
# output.dtype == ts.DType.INT32
add_output = output
rescaled_inputs = inputs

# Do the INT32 Add
tosa_graph.addOperator(
@@ -80,10 +73,10 @@
None,
)

if needs_rescale:
if output.dtype == ts.DType.INT8:
# Scale output back to 8 bit
# pyre-ignore
tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph)
tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node)


@register_node_visitor
@@ -105,11 +98,19 @@
output: TosaArg,
is_quant_node: bool,
) -> None:
if is_quant_node:
# Specification (0.80.0) states that input and output types
# should all be the same
assert inputs[0].dtype == inputs[1].dtype == output.dtype

if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
# Call the inherited define_node for handling integers
super().define_node(node, tosa_graph, inputs, output, is_quant_node)
else:
# FP32 Add lowering
assert inputs[0].dtype == ts.DType.FP32
assert output.dtype == ts.DType.FP32

# MI lowering
tosa_graph.addOperator(
TosaOp.Op().ADD,
[inputs[0].name, inputs[1].name],
74 changes: 74 additions & 0 deletions backends/arm/operators/op_max.py
@@ -0,0 +1,74 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import List

import executorch.backends.arm.tosa_quant_utils as tqutils
import serializer.tosa_serializer as ts
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
)
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import tosa_shape

from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class MaxVisitor(NodeVisitor):
target = "aten.maximum.default"

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
) -> None:
assert inputs[0].dtype == inputs[1].dtype

max_output = output
if inputs[0].dtype == ts.DType.INT8:
input_qparams = get_input_qparams(node)
assert (
len(input_qparams) == 2
), f"Both inputs needs to have quantization information for {node}"
# insert RESCALEs to int32
assert (
input_qparams[0] == input_qparams[1]
), "Both inputs must have same quantization for MAX"

operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
tosa_graph, inputs, node
)

output.shape = tosa_shape(output.shape, output.dim_order)
max_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
else:
operand_inputs = inputs

tosa_graph.addOperator(
TosaOp.Op().MAXIMUM,
[
operand_inputs[0].name,
operand_inputs[1].name,
],
[max_output.name],
)

if output.dtype == ts.DType.INT8:
# insert RESCALE from int32 back to int8
tqutils.insert_rescale_op_to_int8(tosa_graph, max_output, scale_back, node)
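
For orientation, a hypothetical eager-mode module (not part of this diff) whose torch.maximum, torch.minimum and + calls map to aten.maximum.default, aten.minimum.default and aten.add.Tensor, the three edge ops targeted by FoldAndAnnotateQParamsPass above:

import torch


class ClampLike(torch.nn.Module):
    def forward(self, x: torch.Tensor, lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
        # Element-wise clamp built from maximum/minimum, plus an add so all
        # three targeted operators appear in the exported graph.
        return torch.maximum(torch.minimum(x, hi), lo) + x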