Quantization folding pass #7240
Merged
Commits (7):
6967ade  Add functions for usage with DQ/Q folding pass
86777b1  Introduce a quantization folding pass with annotations
0386b23  Add lowering of TOSA.MIN and TOSA.MAX
a8daea5  Add ADD to qdq pass handling
2cbf05a  Add test for fold qdq pass annotation
ed236c3  Add helper functions for Q/DQ folding pass
2a03d6f  Update Q/DQ Folding pass test to sequence of ops
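As a quick orientation before the diffs, the sketch below shows one way the new pass could be invoked on an edge-dialect graph. This is an illustration, not code from the PR: the targeted-op list and the `fold_qdq` wrapper are assumptions; only `FoldAndAnnotateQParamsPass` and `exir_ops` come from the changes shown here.

```python
# Illustrative only: running FoldAndAnnotateQParamsPass on a graph module.
# The targeted-op list below is an assumed example; the PR itself only shows
# that ADD, MIN and MAX lowering consume the annotated qparams.
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
    FoldAndAnnotateQParamsPass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx import GraphModule

targeted_ops = [
    exir_ops.edge.aten.add.Tensor,       # assumed example target
    exir_ops.edge.aten.maximum.default,  # assumed example target
    exir_ops.edge.aten.minimum.default,  # assumed example target
]

def fold_qdq(graph_module: GraphModule) -> GraphModule:
    # ExportPass instances are callable; the PassResult carries the rewritten graph.
    result = FoldAndAnnotateQParamsPass(targeted_ops)(graph_module)
    return result.graph_module
```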
backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py (new file: 131 additions, 0 deletions):
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy

from typing import Callable, cast, Iterable

from executorch.backends.arm.tosa_quant_utils import QuantArgs

from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule, Node


def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
    """
    Get the input quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'.
    Raises a ValueError if the node doesn't have any parameters set.
    """
    if "input_qparams" not in node.meta.keys():
        raise ValueError(f"No input quantization parameter found in node {node}")
    input_qparams = cast(dict[int, QuantArgs], node.meta["input_qparams"])
    if len(input_qparams) == 0:
        raise ValueError(f"No input quantization parameter found in node {node}")
    return input_qparams


def get_output_qparams(node: Node) -> dict[int, QuantArgs]:
    """
    Get the output quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'.
    Raises a ValueError if the node doesn't have any parameters set.
    """
    if "output_qparams" not in node.meta.keys():
        raise ValueError(f"No output quantization parameter found in node {node}")
    output_qparams = cast(dict[int, QuantArgs], node.meta["output_qparams"])
    if len(output_qparams) == 0:
        raise ValueError(f"No output quantization parameter found in node {node}")
    return output_qparams


class FoldAndAnnotateQParamsPass(ExportPass):
    """
    A pass that walks the graph and removes any DQ and Q nodes before and after the target
    node in the supplied list of operators.
    The quantization parameters from the DQ/Q nodes are stored as meta values to be
    accessible for later lowering and serialization passes.
    The assumption is that the quantization annotation adds DQ nodes for all tensor
    inputs to the target node and one Q node to the output.

    Example ('executorch_exir_dialects_edge__ops_' prefix removed from operators for readability):

        x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)

        x_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(x_q, 0.05487706884741783, -128, -128, 127, torch.int8)
        aten_add_tensor: "f32[5]" = ops_aten_add_Tensor(x_dq, x_dq)
        aten_add_tensor_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8)

        output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor_q, 0.05487706884741783, -128, -128, 127, torch.int8)

    Becomes:

        x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)

        aten_add_tensor: "i8[5]" = aten_add_Tensor(x_q, x_q)

        output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8)

    The quantization parameters for x_dq and aten_add_tensor_q are stored in meta for the aten_add_tensor node.
    """

    def __init__(self, targeted_ops: Iterable[Callable]):
        super().__init__()
        self.targeted_ops = targeted_ops

    def call(self, graph_module: GraphModule) -> PassResult:
        q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
        dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default

        # Loop over the graph nodes and find any node in the 'targeted_ops' list.
        for n in graph_module.graph.nodes:
            n = cast(Node, n)
            if n.op != "call_function" or n.target not in self.targeted_ops:
                continue

            # Make sure we haven't already set qparams meta information on the node.
            assert "input_qparams" not in n.meta.keys()
            assert "output_qparams" not in n.meta.keys()

            # For the inputs and outputs, search the graph for quantization info and
            # store the information in a dict keyed by the index of each _tensor_ input,
            # ignoring any other arguments to the target node.
            n.meta["input_qparams"] = {}
            n.meta["output_qparams"] = {}
            for i, arg in enumerate(n.args):
                if not isinstance(arg, Node):
                    continue
                if arg.target != dq_op:
                    continue

                # arg.target for argument i is a dequant node, extract the information.
                n.meta["input_qparams"][i] = QuantArgs.from_operator(
                    arg.target, arg.args
                )

                # arg.args[0] is the tensor input, replace the input usage.
                n.replace_input_with(arg, arg.args[0])
                graph_module.graph.erase_node(arg)

            # Copy the users, since we are modifying them during iteration.
            users_copy = copy.copy(n.users)
            for i, user in enumerate(users_copy):
                if user.target != q_op:
                    continue

                # Quantization node found here, store the quantization parameters in a meta value.
                n.meta["output_qparams"][i] = QuantArgs.from_operator(
                    user.target, user.args
                )

                user.replace_all_uses_with(n)
                graph_module.graph.erase_node(user)

        # Retrace the graph to update the fake tensor types.
        graph_module = super().call(graph_module).graph_module

        graph_module.recompile()
        return PassResult(graph_module, True)
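A short, hedged sketch of how a downstream lowering step might read the annotations this pass leaves in `node.meta`. Only `get_input_qparams` and `get_output_qparams` come from the file above; the traversal and printing are illustrative.

```python
# Illustrative consumer of the annotations left by FoldAndAnnotateQParamsPass.
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
    get_input_qparams,
    get_output_qparams,
)
from torch.fx import GraphModule

def dump_annotated_qparams(graph_module: GraphModule) -> None:
    for node in graph_module.graph.nodes:
        if "input_qparams" not in node.meta:
            continue  # node was not targeted by the folding pass
        # Keys are the indices of the tensor inputs whose DQ nodes were folded.
        print(node.name, "inputs:", get_input_qparams(node))
        # get_output_qparams raises ValueError when no Q user was folded,
        # so guard the call instead of letting it throw.
        if node.meta.get("output_qparams"):
            print(node.name, "outputs:", get_output_qparams(node))
```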
Operator import list (2 additions; the file path is not captured in this view):

@@ -19,7 +19,9 @@
     op_get_item,
     op_hardtanh,
     op_log,
+    op_max,
     op_max_pool2d,
+    op_min,
     op_mm,
     op_mul,
     op_permute,
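These two new imports matter because node visitors register themselves when their module is imported. Below is a skeletal, hypothetical example of that pattern; the class name and target are placeholders, not code from this PR, and the real MaxVisitor follows in the next file.

```python
# Hypothetical skeleton of the registration pattern the import list relies on.
# MyOpVisitor and "aten.my_op.default" are placeholders for illustration.
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)

@register_node_visitor
class MyOpVisitor(NodeVisitor):
    target = "aten.my_op.default"  # placeholder target

    def define_node(self, node, tosa_graph, inputs, output, is_quant_node) -> None:
        # A real visitor emits the corresponding TOSA operator(s) here.
        raise NotImplementedError("illustrative skeleton only")
```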
New file, 74 additions — MaxVisitor lowering aten.maximum.default to TOSA MAXIMUM (file path not captured in this view):

# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import List

import executorch.backends.arm.tosa_quant_utils as tqutils
import serializer.tosa_serializer as ts
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
    get_input_qparams,
)
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import tosa_shape

from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class MaxVisitor(NodeVisitor):
    target = "aten.maximum.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        assert inputs[0].dtype == inputs[1].dtype

        max_output = output
        if inputs[0].dtype == ts.DType.INT8:
            input_qparams = get_input_qparams(node)
            assert (
                len(input_qparams) == 2
            ), f"Both inputs need to have quantization information for {node}"
            assert (
                input_qparams[0] == input_qparams[1]
            ), "Both inputs must have the same quantization for MAX"

            # Insert RESCALEs to int32.
            operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
                tosa_graph, inputs, node
            )

            output.shape = tosa_shape(output.shape, output.dim_order)
            max_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
        else:
            operand_inputs = inputs

        tosa_graph.addOperator(
            TosaOp.Op().MAXIMUM,
            [
                operand_inputs[0].name,
                operand_inputs[1].name,
            ],
            [max_output.name],
        )

        if output.dtype == ts.DType.INT8:
            # Insert RESCALE from int32 back to int8.
            tqutils.insert_rescale_op_to_int8(tosa_graph, max_output, scale_back, node)
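The matching op_min lowering (commit 0386b23 mentions TOSA.MIN as well) is not captured in this view. A sketch of what it plausibly contains is shown below, mirroring the MaxVisitor above and swapping MAXIMUM for MINIMUM; every call used here appears in the MaxVisitor code, but the sketch itself is an assumption, not the PR's actual file.

```python
# Sketch only: a MinVisitor mirroring MaxVisitor, not the PR's actual op_min.py.
from typing import List

import executorch.backends.arm.tosa_quant_utils as tqutils
import serializer.tosa_serializer as ts
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
    get_input_qparams,
)
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import tosa_shape
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class MinVisitor(NodeVisitor):
    target = "aten.minimum.default"

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        assert inputs[0].dtype == inputs[1].dtype

        min_output = output
        if inputs[0].dtype == ts.DType.INT8:
            input_qparams = get_input_qparams(node)
            assert input_qparams[0] == input_qparams[1]
            # Rescale both int8 operands to int32 before the elementwise op.
            operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
                tosa_graph, inputs, node
            )
            output.shape = tosa_shape(output.shape, output.dim_order)
            min_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
        else:
            operand_inputs = inputs

        tosa_graph.addOperator(
            TosaOp.Op().MINIMUM,
            [operand_inputs[0].name, operand_inputs[1].name],
            [min_output.name],
        )

        if output.dtype == ts.DType.INT8:
            # Rescale the int32 result back to int8 using the folded qparams.
            tqutils.insert_rescale_op_to_int8(tosa_graph, min_output, scale_back, node)
```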