backends/arm/_passes/__init__.py (1 addition & 1 deletion)
@@ -81,7 +81,7 @@
from .insert_int32_casts_after_int64_placeholders import ( # noqa
InsertInt32CastsAfterInt64PlaceholdersPass,
)
from .insert_rescales_pass import InsertRescalePass # noqa
from .insert_rescales_pass import InsertRescaleInt32Pass, InsertRescalePass # noqa
from .insert_table_ops import InsertTableOpsPass # noqa
from .match_arg_dtype_pass import MatchArgDtypePass # noqa
from .match_arg_ranks_pass import MatchArgRanksPass # noqa
backends/arm/_passes/arm_pass_manager.py (2 additions & 0 deletions)
@@ -81,6 +81,7 @@
FuseEqualPlaceholdersPass,
FuseQuantizedActivationPass,
InsertInt32CastsAfterInt64PlaceholdersPass,
InsertRescaleInt32Pass,
InsertRescalePass,
InsertTableOpsPass,
MatchArgDtypePass,
@@ -214,6 +215,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(ToTosaMemoryFormatPass(exported_program))
self.add_pass(RemoveNoopPass())
self.add_pass(InsertRescalePass())
self.add_pass(InsertRescaleInt32Pass())

self.validate_constraints_mandatory()
return self._transform(exported_program.graph_module)
backends/arm/_passes/insert_rescales_pass.py (238 additions & 2 deletions)
@@ -4,9 +4,14 @@
# LICENSE file in the root directory of this source tree.

from copy import copy
from typing import cast, Set, Type
from typing import cast, Dict, Optional, Set, Tuple, Type

from executorch.backends.arm._passes.arm_pass_utils import create_node
import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node, set_node_arg
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_output_qparams,
)
from executorch.backends.arm._passes.quant_args import QuantArgs
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
from executorch.exir.dialects._ops import ops as exir_ops
@@ -65,3 +70,234 @@ def call(self, graph_module: GraphModule) -> PassResult:
graph_module = super().call(graph_module).graph_module
graph_module.recompile()
return PassResult(graph_module, modified)


class InsertRescaleInt32Pass(ArmPass):
"""
Numerous TOSA ops require their inputs and outputs to be 32-bit integers in
their quantized implementations. This pass handles such operator nodes by
inserting RESCALE ops before and after them where needed. Extra logic that
handles the scales and zero points must be in place because the affected
TOSA ops have naive implementations that do not account for the quantization
parameters.
"""

_passes_required_after: Set[Type[ExportPass]] = set()

included_targets = [
exir_ops.edge.aten.abs.default,
exir_ops.edge.aten.eq.Tensor,
exir_ops.edge.aten.ge.Tensor,
exir_ops.edge.aten.gt.Tensor,
exir_ops.edge.aten.le.Tensor,
exir_ops.edge.aten.lt.Tensor,
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.minimum.default,
]

def _int32_qargs(self, s):
"""Helper creator function for INT32-based QuantArgs"""

return QuantArgs(
scale=s,
zp=0,
qmin=torch.iinfo(torch.int32).min,
qmax=torch.iinfo(torch.int32).max,
dtype=torch.int32,
)

def _get_inputs_rescaled_qparams(
self, target, input_qparams: Dict[int, QuantArgs]
) -> Dict[int, QuantArgs]:
"""Get the qparams for the INT32 operands to the op ``target``

Inputs to the INT32-based operator must be rescaled from INT8 to INT32.
This function computes the ``QuantArgs`` for each operand and returns them
as a dict mapping tensor index to ``QuantArgs``.
"""

if target in [
[Review thread]
digantdesai (Contributor): Maybe this boat has sailed already, but from a
logical-grouping perspective I would rather keep the util call in the
respective ops instead of in a pass. But that's just me; not a strong
objection.

Author (Collaborator): @digantdesai Not sure what you mean fully. Can you
elaborate a bit?

digantdesai (Contributor): Instead of inserting rescale ops into the TOSA
graph through this pass, I would have a mild preference for keeping things
the way they were, i.e. an op lowering (op_.py) would inject the op when
needed. This is mainly to keep things together logically and to keep the
TOSA graph 'sound' after each op lowering. That said, I understand the
code-duplication aspect, but that could be tackled by keeping the
insert-rescale logic as util functions. Again, not a strong objection.
Stamping to unblock you, and leaving it up to you to decide. Thanks!

exir_ops.edge.aten.abs.default,
exir_ops.edge.aten.eq.Tensor,
exir_ops.edge.aten.ge.Tensor,
exir_ops.edge.aten.gt.Tensor,
exir_ops.edge.aten.le.Tensor,
exir_ops.edge.aten.lt.Tensor,
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.maximum.default,
]:
# For these ops, use the smallest scale among the INT8 operands.
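# Rescaling every operand to the smallest scale makes each rescale
# multiplier >= 1, so no quantized information is lost, and it puts all
# operands on a common grid so the naive int32 op can compare raw values
# directly.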
min_scale = min(
[qp.get_scale_per_tensor() for qp in input_qparams.values()]
)
qparams = {
i: self._int32_qargs(min_scale) for i in range(len(input_qparams))
}
else:
raise ValueError(f"Not a valid target: {target}")

return qparams

def _get_output_qparams(
self, target, inputs_qparams: Dict[int, QuantArgs]
) -> Optional[QuantArgs]:
"""Given an op ``target`` and the ``QuantArgs`` for each of its inputs, compute
the scale of the output based on how the operator itself affects it."""

if target in [
exir_ops.edge.aten.abs.default,
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.minimum.default,
]:
# The op has not altered the scale; the output scale is equal to
# the operands' scales.
return self._int32_qargs(inputs_qparams[0].get_scale_per_tensor())
elif target in [
exir_ops.edge.aten.eq.Tensor,
exir_ops.edge.aten.ge.Tensor,
exir_ops.edge.aten.gt.Tensor,
exir_ops.edge.aten.le.Tensor,
exir_ops.edge.aten.lt.Tensor,
]:
# Output is bool for these ops and thus no qparams are present
return None
else:
raise ValueError(f"Not a valid target: {target}")

def _get_rescale_qparams(
self, target, input_qparams: Dict[int, QuantArgs]
) -> Tuple[Dict[int, QuantArgs], Optional[QuantArgs]]:
"""
Get the quantization parameters of the INT32 inputs/outputs that will
surround the node after the new RESCALE ops have been inserted.
"""

inputs_rescaled_qparams = self._get_inputs_rescaled_qparams(
target, input_qparams
)
output_qparams = self._get_output_qparams(target, inputs_rescaled_qparams)

return (inputs_rescaled_qparams, output_qparams)

def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> bool:
qargs = node.meta["input_qparams"]

args_copy = list(node.args)
seen_args = set()
modified = False
for i in qargs:
qp = qargs[i]
if qp.dtype != torch.int8:
continue

arg_node = args_copy[i]
if arg_node in seen_args:
continue
seen_args.add(arg_node)

with graph.inserting_after(arg_node):
rescale_node = create_node(
graph,
exir_ops.backend.tosa.RESCALE.default,
(
arg_node,
torch.int32,
qp.get_scale_per_tensor()
/ rescale_qargs[
i
].get_scale_per_tensor(), # Old scale / new scale
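# e.g. (assumed numbers): old scale 0.02 and new scale
# 0.01 give multiplier 2.0, i.e. the RESCALE computes
# q_int32 = (q_int8 - zp_old) * 2 + zp_new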
qp.get_zp_per_tensor(), # Old zero point
rescale_qargs[i].get_zp_per_tensor(), # New zero point
),
from_node=node,
)

node.replace_input_with(arg_node, rescale_node)
modified = True

return modified

def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> bool:
if "output_qparams" not in node.meta or len(node.meta["output_qparams"]) == 0:
return False

qargs = get_output_qparams(node)
assert len(qargs) == 1
assert rescale_qargs is not None

qarg = qargs[0]
if qarg.dtype != torch.int8:
return False

users_copy = list(node.users)

with graph.inserting_after(node):
rescale_node = create_node(
graph,
exir_ops.backend.tosa.RESCALE.default,
(
node,
torch.int8,
rescale_qargs.get_scale_per_tensor()
/ qarg.get_scale_per_tensor(), # Old scale / new scale
rescale_qargs.get_zp_per_tensor(), # Old zero point
qarg.get_zp_per_tensor(), # New zero point
),
from_node=node,
)

for user in users_copy:
user.replace_input_with(node, rescale_node)

return True

def call(self, graph_module: GraphModule) -> PassResult:
graph = graph_module.graph

modified = False
for node in list(graph.nodes):
node = cast(Node, node)

if node.op != "call_function" or node.target not in self.included_targets:
continue

if "input_qparams" not in node.meta or len(node.meta["input_qparams"]) == 0:
continue
input_qparams = node.meta["input_qparams"]

inputs_rescale_qargs, output_rescale_qargs = self._get_rescale_qparams(
node.target, input_qparams
)

inputs_was_rescaled = self._rescale_inputs(
graph, node, inputs_rescale_qargs
)
outputs_was_rescaled = False
if inputs_was_rescaled:
outputs_was_rescaled = self._rescale_outputs(
graph, node, output_rescale_qargs
)
modified = True

# Update node metadata

if inputs_was_rescaled:
assert len(inputs_rescale_qargs) == len(node.meta["input_qparams"])
node.meta["input_qparams"] = inputs_rescale_qargs

if outputs_was_rescaled:
assert len(node.meta["output_qparams"]) == 1
node.meta["output_qparams"] = {0: output_rescale_qargs}

# If the output type is specified in the node, change it such
# that it matches the subsequent rescale node(s) that this node
# now has output edges to.
if "dtype" in node.kwargs:
set_node_arg(node, "dtype", torch.int32)

if modified:
# Retrace the graph to update the fake tensor types
graph_module = super().call(graph_module).graph_module
graph_module.recompile()

return PassResult(graph_module, modified)
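
For orientation, here is a minimal, self-contained sketch of the rescale
arithmetic the pass relies on (plain PyTorch, no ExecuTorch dependencies; the
helper name and all numbers are illustrative, not part of this PR):

import torch

def rescale(q, old_scale, new_scale, old_zp, new_zp, dtype):
    # Requantize: the real value (q - old_zp) * old_scale is re-expressed on
    # the new grid as q_new = real / new_scale + new_zp, clamped to dtype.
    info = torch.iinfo(dtype)
    q_new = torch.round(
        (q.to(torch.int64) - old_zp) * (old_scale / new_scale)
    ) + new_zp
    return q_new.clamp(info.min, info.max).to(dtype)

# Two int8 operands with different scales, rescaled to the common (smallest)
# scale before a naive int32 MAXIMUM, then rescaled back to int8:
a = torch.tensor([10, -20], dtype=torch.int8)  # scale 0.02
b = torch.tensor([5, 7], dtype=torch.int8)     # scale 0.01
common = min(0.02, 0.01)
a32 = rescale(a, 0.02, common, 0, 0, torch.int32)  # multiplier 2.0
b32 = rescale(b, 0.01, common, 0, 0, torch.int32)  # multiplier 1.0
out32 = torch.maximum(a32, b32)  # int32 max on a common grid
out8 = rescale(out32, common, 0.02, 0, 0, torch.int8)  # back to int8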