diff --git a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py index 77def9e7cd3..786117e6457 100644 --- a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py +++ b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py @@ -14,7 +14,7 @@ get_first_fake_tensor, insert_q_dq_pair, ) -from executorch.backends.arm.tosa_quant_utils import dq_op, q_op +from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -42,6 +42,9 @@ def _transpose_impl(*args, **kwargs): return args[0] +register_passable_op(torch.ops.passthrough_to_tosa._transpose) + + class AnnotateChannelsLastDimOrder(ExportPass): """ Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order diff --git a/backends/arm/_passes/insert_squeeze_after_sum_pass.py b/backends/arm/_passes/insert_squeeze_after_sum_pass.py index 152d5c95f6d..adf2b4f491c 100644 --- a/backends/arm/_passes/insert_squeeze_after_sum_pass.py +++ b/backends/arm/_passes/insert_squeeze_after_sum_pass.py @@ -8,9 +8,7 @@ import torch import torch.fx -from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_q_dq_pair - -from executorch.backends.arm.tosa_quant_utils import get_quant_node_args, is_quant_node +from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -28,8 +26,6 @@ class InsertSqueezeAfterSumPass(ExportPass): sum(dims, keep_dim = False) After pass: sum(dims, keep_dim = True) - (q) - (dq) squeeze(dim = dims) """ @@ -45,12 +41,6 @@ def call(self, graph_module: torch.fx.GraphModule): continue dim_list = cast(list[int], sum_node.args[1]) - quantized = is_quant_node(sum_node) - if quantized: - qparams = get_quant_node_args(sum_node.all_input_nodes[0]) - qparams = qparams + (torch.int8,) - else: - qparams = None # Add keep_dim = True arg to sum node. sum_node.args = sum_node.args[0:2] + (True,) @@ -61,8 +51,6 @@ def call(self, graph_module: torch.fx.GraphModule): ) sum_node.replace_all_uses_with(squeeze_node) squeeze_node.args = (sum_node, dim_list) - if quantized: - sum_node = insert_q_dq_pair(graph_module.graph, sum_node, qparams) graph_module.graph.eliminate_dead_code() graph_module.recompile() graph_module = super().call(graph_module).graph_module diff --git a/backends/arm/_passes/size_adjust_conv2d_pass.py b/backends/arm/_passes/size_adjust_conv2d_pass.py index 980ab09e597..c7bd27dcce0 100644 --- a/backends/arm/_passes/size_adjust_conv2d_pass.py +++ b/backends/arm/_passes/size_adjust_conv2d_pass.py @@ -9,7 +9,7 @@ from typing import cast, Optional import torch.fx -from executorch.backends.arm.tosa_quant_utils import is_quant_node +from executorch.backends.arm.tosa_quant_utils import is_node_quantized from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch._ops import OpOverload @@ -113,7 +113,7 @@ def call(self, graph_module: torch.fx.GraphModule): slice_node = graph.create_node( "call_function", self.slice_op, (last_node,) + args ) - if is_quant_node(last_node): + if is_node_quantized(last_node): q_params = last_node.args[1:] dq_node = insert_q_dq_pair( graph_module.graph, slice_node, q_params diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py index 161b5d22396..8c9bd7ac2a6 100644 --- a/backends/arm/operators/op_bmm.py +++ b/backends/arm/operators/op_bmm.py @@ -14,7 +14,11 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args +from executorch.backends.arm.tosa_quant_utils import ( + build_rescale, + get_quant_arg_downstream, + get_quant_arg_upstream, +) from executorch.backends.arm.tosa_utils import get_two_inputs from serializer.tosa_serializer import TosaOp @@ -42,8 +46,10 @@ def define_node( # For INT8, we need to get the zero points and add an intermediate tensor # for a later rescale. if is_quant_node: - input0_zp = get_quant_node_args(input0).zp - input1_zp = get_quant_node_args(input1).zp + input0_q_params = get_quant_arg_upstream(input0) + input1_q_params = get_quant_arg_upstream(input1) + input0_zp = input0_q_params.zp + input1_zp = input1_q_params.zp bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) bmm_output_name = bmm_result.name else: @@ -63,9 +69,7 @@ def define_node( # As INT8 accumulates into INT32, we need to rescale it back to INT8 if is_quant_node: - input0_q_params = get_quant_node_args(input0) - input1_q_params = get_quant_node_args(input1) - output_q_params = get_quant_node_args(list(node.users)[0]) + output_q_params = get_quant_arg_downstream(list(node.users)[0]) final_output_scale = ( input0_q_params.scale * input1_q_params.scale diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index 64cde0724f5..ffbeee7306d 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import cast, List +from typing import List import serializer.tosa_serializer as ts import torch @@ -15,9 +15,10 @@ from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( build_rescale_conv_output, - get_quant_node_args, + get_quant_arg_downstream, + get_quant_arg_upstream, ) -from executorch.backends.arm.tosa_utils import build_reshape, getNodeArgs, tosa_shape +from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape from serializer.tosa_serializer import TosaOp @@ -82,7 +83,7 @@ def define_node( ) input_zp = ( - get_quant_node_args(node.all_input_nodes[0]).zp if is_quant_node else 0 + get_quant_arg_upstream(node.all_input_nodes[0]).zp if is_quant_node else 0 ) attr.ConvAttribute( @@ -158,9 +159,10 @@ def define_node( # integer value domain of the next op. Otherwise return float32 output. if is_quant_node: # Get scale_factor from input, weight, and output. - _, input_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[0])) - _, weight_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[1])) - _, output_scale, output_zp, _, _, _ = getNodeArgs(list(node.users)[0]) + input_scale = get_quant_arg_upstream(node.all_input_nodes[0]).scale + weight_scale = get_quant_arg_upstream(node.all_input_nodes[1]).scale + output_qargs = get_quant_arg_downstream(list(node.users)[0]) + build_rescale_conv_output( tosa_graph, # pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined. @@ -169,6 +171,6 @@ def define_node( actual_out_type, input_scale, weight_scale, - output_scale, - output_zp, + output_qargs.scale, + output_qargs.zp, ) diff --git a/backends/arm/operators/op_exp.py b/backends/arm/operators/op_exp.py index 0e0a75dcc47..7a0b4e104f3 100644 --- a/backends/arm/operators/op_exp.py +++ b/backends/arm/operators/op_exp.py @@ -17,7 +17,8 @@ from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, + get_quant_arg_downstream, + get_quant_arg_upstream, QuantArgs, quantize_value, ) @@ -48,9 +49,9 @@ def define_node( # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = get_quant_node_args(input_node) + in_quantargs = get_quant_arg_upstream(input_node) output_node = list(node.users)[0] - out_quantargs = get_quant_node_args(output_node) + out_quantargs = get_quant_arg_downstream(output_node) table = exp_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_full.py b/backends/arm/operators/op_full.py index cf67975e0d9..d2bc1377ce7 100644 --- a/backends/arm/operators/op_full.py +++ b/backends/arm/operators/op_full.py @@ -14,7 +14,10 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import get_quant_node_args +from executorch.backends.arm.tosa_quant_utils import ( + get_quant_arg_downstream, + quantize_value, +) from executorch.backends.arm.tosa_utils import tosa_shape from torch.fx import Node @@ -39,10 +42,8 @@ def define_node( value = inputs[1].number if is_quant_node: - qargs = get_quant_node_args(list(node.users)[0]) - qvalue = np.clip( - np.round(value / qargs.scale) + qargs.zp, qargs.qmin, qargs.qmax - ) + qargs = get_quant_arg_downstream(list(node.users)[0]) + qvalue = quantize_value(value, qargs) dtype = ts.DType.INT8 data = np.full(shape, qvalue, dtype=np.int8) else: diff --git a/backends/arm/operators/op_hardtanh.py b/backends/arm/operators/op_hardtanh.py index 62c0a27f05f..e7260282060 100644 --- a/backends/arm/operators/op_hardtanh.py +++ b/backends/arm/operators/op_hardtanh.py @@ -14,7 +14,10 @@ ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import get_quant_node_args +from executorch.backends.arm.tosa_quant_utils import ( + get_quant_arg_upstream, + quantize_value, +) from serializer.tosa_serializer import TosaOp @@ -37,12 +40,10 @@ def define_node( if is_quant_node: # Get quant parameters - scale, zp, qmin, qmax = get_quant_node_args(node.all_input_nodes[0]) + qargs = get_quant_arg_upstream(node.all_input_nodes[0]) # Convert to quantized representation - clamp_min_qs = round((inputs[1].number / scale) + zp) - clamp_min_qs = max(clamp_min_qs, qmin) - clamp_max_qs = round((inputs[2].number / scale) + zp) - clamp_max_qs = min(clamp_max_qs, qmax) + clamp_min_qs = quantize_value(inputs[1].number, qargs) + clamp_max_qs = quantize_value(inputs[2].number, qargs) # Set fp values to 0.0 since they are not used clamp_min_fp = 0.0 clamp_max_fp = 0.0 diff --git a/backends/arm/operators/op_log.py b/backends/arm/operators/op_log.py index 5276173efa3..76adc2325e4 100644 --- a/backends/arm/operators/op_log.py +++ b/backends/arm/operators/op_log.py @@ -17,7 +17,8 @@ from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, + get_quant_arg_downstream, + get_quant_arg_upstream, QuantArgs, quantize_value, ) @@ -49,9 +50,9 @@ def define_node( # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = get_quant_node_args(input_node) + in_quantargs = get_quant_arg_upstream(input_node) output_node = list(node.users)[0] - out_quantargs = get_quant_node_args(output_node) + out_quantargs = get_quant_arg_downstream(output_node) table = log_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index a0b868f684d..74e33ddb02c 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -13,7 +13,10 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_utils import get_quant_node_args +from executorch.backends.arm.tosa_utils import ( + get_quant_arg_downstream, + get_quant_arg_upstream, +) from serializer.tosa_serializer import TosaOp @@ -54,8 +57,8 @@ def define_node( output_zp = 0 if is_quant_node: - input_zp = get_quant_node_args(node.all_input_nodes[0]).zp - output_zp = get_quant_node_args(list(node.users)[0]).zp + input_zp = get_quant_arg_upstream(node.all_input_nodes[0]).zp + output_zp = get_quant_arg_downstream(list(node.users)[0]).zp attr = ts.TosaSerializerAttribute() attr.PoolAttribute( diff --git a/backends/arm/operators/op_mm.py b/backends/arm/operators/op_mm.py index ebddb3a40e2..81334de16cb 100644 --- a/backends/arm/operators/op_mm.py +++ b/backends/arm/operators/op_mm.py @@ -14,7 +14,11 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args +from executorch.backends.arm.tosa_quant_utils import ( + build_rescale, + get_quant_arg_downstream, + get_quant_arg_upstream, +) from executorch.backends.arm.tosa_utils import ( build_reshape, expand_dims, @@ -54,8 +58,8 @@ def define_node( # For INT8, we need to get the zero point, otherwise it is 0 input0_zp, input1_zp = 0, 0 if is_quant_node: - input0_zp = get_quant_node_args(input0).zp - input1_zp = get_quant_node_args(input1).zp + input0_zp = get_quant_arg_upstream(input0).zp + input1_zp = get_quant_arg_upstream(input1).zp mat_mul_result = tosa_graph.addIntermediate( output_new_shape, ts.DType.INT32 if is_quant_node else output.dtype @@ -86,9 +90,9 @@ def define_node( # As INT8 accumulates into INT32, we need to rescale it back to INT8 if is_quant_node: - input0_q_params = get_quant_node_args(input0) - input1_q_params = get_quant_node_args(input1) - output_q_params = get_quant_node_args(list(node.users)[0]) + input0_q_params = get_quant_arg_upstream(input0) + input1_q_params = get_quant_arg_upstream(input1) + output_q_params = get_quant_arg_downstream(list(node.users)[0]) final_output_scale = ( input0_q_params.scale * input1_q_params.scale diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index c152e8759ef..ad578aa1f06 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -37,10 +37,10 @@ def define_node( if is_quant_node: input_A = inputs[0] input_B = inputs[1] - input_A_qargs = tqutils.get_quant_node_args( + input_A_qargs = tqutils.get_quant_arg_upstream( cast(torch.fx.Node, node.args[0]) ) - input_B_qargs = tqutils.get_quant_node_args( + input_B_qargs = tqutils.get_quant_arg_upstream( cast(torch.fx.Node, node.args[1]) ) diff --git a/backends/arm/operators/op_placeholder.py b/backends/arm/operators/op_placeholder.py index 950d4636d27..d466a13e385 100644 --- a/backends/arm/operators/op_placeholder.py +++ b/backends/arm/operators/op_placeholder.py @@ -10,13 +10,14 @@ import torch.fx from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( - get_quant_arg_dtype, - get_quant_node_args, - is_quant_arg, + get_quant_arg_upstream, + get_quantized_node_output_dtype, + is_node_quantized, ) from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import ( is_bias_node_for_quantized_conv, + map_dtype, tosa_shape, ) from torch.export.exported_program import ExportedProgram @@ -41,7 +42,11 @@ def process_inputs( tensor = ts.TosaSerializerTensor( inputs[0].name, tosa_shape(input_shape, input_dim_order), - get_quant_arg_dtype(node) if is_quant_arg(node) else inputs[0].dtype, + ( + map_dtype(get_quantized_node_output_dtype(node)) + if is_node_quantized(node) + else inputs[0].dtype + ), data=None, placeholderFilename=inputs[0].name + ".npy", ) @@ -63,8 +68,8 @@ def process_quantized_bias( _, ) = consumer_node.all_input_nodes - input_node_scale = get_quant_node_args(input_node).scale - weight_node_scale = get_quant_node_args(weight_node).scale + input_node_scale = get_quant_arg_upstream(input_node).scale + weight_node_scale = get_quant_arg_upstream(weight_node).scale bias_values_quantized = ( (parameter_values / (input_node_scale * weight_node_scale)) .round() diff --git a/backends/arm/operators/op_reciprocal.py b/backends/arm/operators/op_reciprocal.py index 3d43fd8f7da..774c4d94b19 100644 --- a/backends/arm/operators/op_reciprocal.py +++ b/backends/arm/operators/op_reciprocal.py @@ -15,7 +15,8 @@ from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, + get_quant_arg_downstream, + get_quant_arg_upstream, QuantArgs, quantize_value, ) @@ -41,8 +42,8 @@ def define_node( if is_quant_node: input = inputs[0] - input_qargs = get_quant_node_args(node.all_input_nodes[0]) - output_qargs = get_quant_node_args(list(node.users)[0]) + input_qargs = get_quant_arg_upstream(node.all_input_nodes[0]) + output_qargs = get_quant_arg_downstream(list(node.users)[0]) div_table = div_table_8bit(input_qargs, output_qargs) diff --git a/backends/arm/operators/op_relu.py b/backends/arm/operators/op_relu.py index 20bba3f6545..a3a7c82ab8e 100644 --- a/backends/arm/operators/op_relu.py +++ b/backends/arm/operators/op_relu.py @@ -38,7 +38,7 @@ def define_node( clamp_min_qs = 0 clamp_max_qs = 0 if is_quant_node: - out_qargs = tqutils.get_quant_node_args(list(node.users)[0]) + out_qargs = tqutils.get_quant_arg_downstream(list(node.users)[0]) clamp_min_qs = tqutils.quantize_value(0, out_qargs) clamp_max_qs = tqutils.quantize_value(float("inf"), out_qargs) diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py index 9225c7d938f..b503a323b18 100644 --- a/backends/arm/operators/op_rsqrt.py +++ b/backends/arm/operators/op_rsqrt.py @@ -16,7 +16,8 @@ from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, + get_quant_arg_downstream, + get_quant_arg_upstream, QuantArgs, quantize_value, ) @@ -39,9 +40,9 @@ def define_node( # Assume quantized input is 8 bit. # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = get_quant_node_args(input_node) + in_quantargs = get_quant_arg_upstream(input_node) output_node = list(node.users)[0] - out_quantargs = get_quant_node_args(output_node) + out_quantargs = get_quant_arg_downstream(output_node) table = rsqrt_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() table_attr.TableAttribute(table) diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_sigmoid.py index 0087b1f7a81..e299e99b43c 100644 --- a/backends/arm/operators/op_sigmoid.py +++ b/backends/arm/operators/op_sigmoid.py @@ -17,7 +17,8 @@ from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, + get_quant_arg_downstream, + get_quant_arg_upstream, QuantArgs, quantize_value, ) @@ -49,9 +50,9 @@ def define_node( # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = get_quant_node_args(input_node) + in_quantargs = get_quant_arg_upstream(input_node) output_node = list(node.users)[0] - out_quantargs = get_quant_node_args(output_node) + out_quantargs = get_quant_arg_downstream(output_node) table = sigmoid_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_tanh.py b/backends/arm/operators/op_tanh.py index 20f343a7f1b..2c84580edce 100644 --- a/backends/arm/operators/op_tanh.py +++ b/backends/arm/operators/op_tanh.py @@ -17,7 +17,8 @@ from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, + get_quant_arg_downstream, + get_quant_arg_upstream, QuantArgs, quantize_value, ) @@ -49,9 +50,9 @@ def define_node( # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = get_quant_node_args(input_node) + in_quantargs = get_quant_arg_upstream(input_node) output_node = list(node.users)[0] - out_quantargs = get_quant_node_args(output_node) + out_quantargs = get_quant_arg_downstream(output_node) table = tanh_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index e61fbc5bbee..511aeda1ac9 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -75,7 +75,7 @@ def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPattern [torch.nn.AdaptiveAvgPool2d], [F.adaptive_avg_pool2d], ], - "mul": [torch.mul], + "mul": [[torch.mul]], "sub": [[torch.sub]], } return copy.deepcopy(supported_operators) diff --git a/backends/arm/quantizer/quantization_annotation/generic_annotator.py b/backends/arm/quantizer/quantization_annotation/generic_annotator.py index 126051f158f..b093eec8083 100644 --- a/backends/arm/quantizer/quantization_annotation/generic_annotator.py +++ b/backends/arm/quantizer/quantization_annotation/generic_annotator.py @@ -29,6 +29,9 @@ torch.ops.aten.unsqueeze.default, torch.ops.aten.unsqueeze_copy.default, torch.ops.aten.reshape.default, + torch.ops.aten.repeat.default, + torch.ops.aten.expand_copy.default, + torch.ops.aten.expand.default, # Disabling these as there seems to be an issue with support for complex # datatypes in torch: # torch.ops.aten.view_as_complex.default, diff --git a/backends/arm/quantizer/quantization_annotation/mm_annotator.py b/backends/arm/quantizer/quantization_annotation/mm_annotator.py index b48c6d59905..60d9adb1c3c 100644 --- a/backends/arm/quantizer/quantization_annotation/mm_annotator.py +++ b/backends/arm/quantizer/quantization_annotation/mm_annotator.py @@ -24,7 +24,9 @@ def _annotate_mm( quantization_config: QuantizationConfig, filter_fn: Optional[Callable[[Node], bool]] = None, ) -> Optional[List[List[Node]]]: - mm_partitions = get_source_partitions(gm.graph, [torch.mm, torch.bmm], filter_fn) + mm_partitions = get_source_partitions( + gm.graph, [torch.mm, torch.bmm, torch.matmul], filter_fn + ) mm_partitions = list(itertools.chain.from_iterable(mm_partitions.values())) annotated_partitions = [] for mm_partition in mm_partitions: diff --git a/backends/arm/test/ops/test_bmm.py b/backends/arm/test/ops/test_bmm.py index e5e9508e250..62466571209 100644 --- a/backends/arm/test/ops/test_bmm.py +++ b/backends/arm/test/ops/test_bmm.py @@ -32,6 +32,12 @@ class BMM(torch.nn.Module): def forward(self, x, y): return torch.bmm(x, y) + class MatMul(torch.nn.Module): + test_parameters = [(torch.rand(2, 3, 5), torch.rand(2, 5, 2))] + + def forward(self, x, y): + return torch.matmul(x, y) + class BMMSingleInput(torch.nn.Module): test_parameters = [ (torch.rand(20, 3, 3),), @@ -53,9 +59,9 @@ def _test_bmm_tosa_MI_pipeline( compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() - .check_count({"torch.ops.aten.bmm.default": 1}) .check_not(["torch.ops.quantized_decomposed"]) .to_edge() + .check_count({"executorch_exir_dialects_edge__ops_aten_bmm_default": 1}) .partition() .check_not(["executorch_exir_dialects_edge__ops_aten_bmm_default"]) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) @@ -74,9 +80,9 @@ def _test_bmm_tosa_BI_pipeline( ) .quantize() .export() - .check_count({"torch.ops.aten.bmm.default": 1}) .check(["torch.ops.quantized_decomposed"]) .to_edge() + .check_count({"executorch_exir_dialects_edge__ops_aten_bmm_default": 1}) .partition() .check_not(["executorch_exir_dialects_edge__ops_aten_bmm_default"]) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) @@ -116,6 +122,16 @@ def test_bmm_single_input_tosa_MI(self, operand1: torch.Tensor): test_data = (operand1,) self._test_bmm_tosa_MI_pipeline(self.BMMSingleInput(), test_data) + @parameterized.expand(MatMul.test_parameters) + def test_matmul_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + self._test_bmm_tosa_MI_pipeline(self.MatMul(), test_data) + + @parameterized.expand(MatMul.test_parameters) + def test_matmul_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + self._test_bmm_tosa_BI_pipeline(self.MatMul(), test_data) + @parameterized.expand(BMM.test_parameters) def test_bmm_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): test_data = (operand1, operand2) diff --git a/backends/arm/test/ops/test_linear.py b/backends/arm/test/ops/test_linear.py index c7a475035dc..30d4b2890a2 100644 --- a/backends/arm/test/ops/test_linear.py +++ b/backends/arm/test/ops/test_linear.py @@ -165,7 +165,7 @@ def _test_linear_tosa_BI_pipeline( .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() - .run_method_and_compare_outputs(inputs=test_data, qtol=True) + .run_method_and_compare_outputs(inputs=test_data, qtol=1) ) def _test_linear_tosa_ethosu_BI_pipeline( diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index fe408e41b3a..19397fe6b21 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -8,21 +8,38 @@ # Utiliy functions for TOSA quantized lowerings import math -from typing import NamedTuple, Sequence +from typing import Callable, cast, NamedTuple, Sequence import numpy as np import serializer.tosa_serializer as ts import torch.fx import tosa.Op as TosaOp -from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg +from executorch.backends.arm.tosa_mapping import TosaArg from executorch.exir.dialects._ops import ops as exir_ops from serializer.tosa_serializer import TosaSerializerTensor from torch.fx import Node + q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default -dq_q_ops = [q_op, dq_op] +dq_q_ops = (q_op, dq_op) +passable_ops = [ + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.squeeze_copy.dims, + exir_ops.edge.aten.unsqueeze_copy.default, + exir_ops.edge.aten.split_with_sizes_copy.default, + exir_ops.edge.aten.repeat.default, + exir_ops.edge.aten.clone.default, + exir_ops.edge.aten.slice_copy.Tensor, + exir_ops.edge.aten.cat.default, +] + + +def register_passable_op(op): + """We need to be able to add custom ops such as tosa_transpose to the passable_op list after they have been created""" + passable_ops.append(op) class QuantArgs(NamedTuple): @@ -30,6 +47,19 @@ class QuantArgs(NamedTuple): zp: int qmin: int qmax: int + dtype: torch.dtype + + def quantize_value(self, x): + if not isinstance(x, torch.Tensor): + x = torch.Tensor([x]) + return torch.clip( + torch.round(x / self.scale) + self.zp, + self.qmin, + self.qmax, + ).to(self.dtype) + + def dequantize_value(self, qx: int) -> float: + return (qx - self.zp) * self.scale def quantize_value(x, qargs: QuantArgs, dtype=np.int8): @@ -44,81 +74,159 @@ def dequantize_value(qx, qargs: QuantArgs): return (qx - qargs.zp) * qargs.scale -def is_quant_node(node: torch.fx.Node): +def qargs_from_qnode(node: torch.fx.Node): + assert node.target in dq_q_ops, f"Op {node} is not a quant node." - consumer_node_condition = False - if len(list(node.users)) > 0: - consumer_node = list(node.users)[0] + return QuantArgs( + scale=cast(float, node.args[1]), + zp=cast(int, node.args[2]), + qmin=cast(int, node.args[3]), + qmax=cast(int, node.args[4]), + dtype=cast(torch.dtype, node.args[5]), + ) - # For Rank > 2 Linear layers, the quant node is after the view_copy - if ( - node.target == exir_ops.edge.aten.addmm.default - and consumer_node.target == exir_ops.edge.aten.view_copy.default - ): - consumer_consumer_node = list(consumer_node.users)[0] - return True if consumer_consumer_node.target == q_op else False - consumer_node_condition = consumer_node.target == q_op - input_node_condition = False - if len(node.all_input_nodes) > 0: - input = node.all_input_nodes[0] - input_node_condition = input.target in dq_q_ops +def get_neighbour_quant_args( + node: torch.fx.Node, +) -> tuple[list[QuantArgs], list[QuantArgs]]: + user_q_args = [] - return node.target in dq_q_ops or consumer_node_condition or input_node_condition + for user in node.users: + q_args = search_quant_arg_downstream(user) + if q_args: + user_q_args.append(q_args) + input_q_nodes = [] + for input_node in node.all_input_nodes: + q_args = search_quant_arg_upstream(input_node) + if q_args: + input_q_nodes.append(q_args) + return user_q_args, input_q_nodes -def get_quant_node_dtype(node: torch.fx.Node): - # pyre-ignore[16]: Undefined attribute. - if "tosa" in node.target.__name__: - return node.meta["val"].dtype - if node.target in dq_q_ops: - return node.args[5] +def all_q_args_equal(q_arg_list: list[QuantArgs]) -> bool: + first_q_arg = q_arg_list[0] + for q_arg in q_arg_list: + if q_arg != first_q_arg: + return False + return True - # if not a tosa node, nor a q/dq op, walk the graph until we find a q op - consumer_node = list(node.users)[0] - while True: - if consumer_node.target in dq_q_ops: - return consumer_node.args[5] - # Try to move on to the next node - if len(consumer_node.users) == 0: - raise RuntimeError(f"No quantized node found in graph for node {node}") - consumer_node = list(consumer_node.users)[0] +def is_node_quantized(node: torch.fx.Node) -> bool: + if node.target in dq_q_ops: + return True + user_q_args, input_q_args = get_neighbour_quant_args(node) -def is_quant_arg(arg): - consumer_node = list(arg.users)[0] - return consumer_node.target == q_op + # If we did not find any neighbouring quant nodes, we are not quantized. + if len(input_q_args) == 0 and len(user_q_args) == 0: + return False + if node.target in passable_ops: + assert all_q_args_equal( + user_q_args + input_q_args + ), f"Node {node} needs same quantization parameters on all inputs and outputs." -def get_quant_arg_dtype(node: torch.fx.Node): - consumer_node = list(node.users)[0] + return True - # Get type of quant node, args differ from per_tensor and per_channel. - if consumer_node.target == q_op: - if is_quant_arg(node): - return map_dtype(consumer_node.args[5]) - else: - raise RuntimeError("Quantization argument not found") + +def search_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs | None: + """ + Iterates downward in the graph passing through 'passable_ops' to find and return a quantization node, + starting with 'node'. + If a passable node with multiple consumers is encountered, + find QuantArgs for all consumers and assert that they are equal. + If a node not in passable_ops is encountered, return None. + If a node without consumers is encountered, return None. + """ + if node.target in dq_q_ops: + return qargs_from_qnode(node) + if node.target not in passable_ops: + return None + consumer_nodes = list(node.users) + if len(consumer_nodes) == 0: + return None + elif len(consumer_nodes) == 1: + return search_quant_arg_downstream(consumer_nodes[0]) + else: + consumer_qargs: list[QuantArgs] = [] + for input in consumer_nodes: + quant_args = search_quant_arg_downstream(input) + if quant_args: + consumer_qargs.append(quant_args) + if len(consumer_qargs) == 0: + return None + assert all_q_args_equal( + consumer_qargs + ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different consumers." + return consumer_qargs[0] + + +def get_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs: + """Calls search_quant_arg_downstream and asserts that QuantArgs are found, + meaning return value can't be None. + """ + qargs = search_quant_arg_downstream(node) + assert qargs, f"Did not find QuantArgs downstream for node {node}" + return qargs -def get_quant_node_args(node: torch.fx.Node): +def search_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs | None: + """ + Iterates upward in the graph passing through 'passable_ops' to find and return a quantization node, + starting with 'node'. + If a passable node with multiple inputs is encountered, + find QuantArgs for all inputs and assert that they are equal. + If a node not in passable_ops is encountered, return None. + If a node without inputs is encountered, return None. """ - Get the quantization parameters from a quant node. - Args: - node: The quant node. - Returns: - QuantArgs: scale, zp, qmin, qmax + if node.target in dq_q_ops: + return qargs_from_qnode(node) + if node.target not in passable_ops: + return None + input_nodes = list(node.all_input_nodes) + if len(input_nodes) == 0: + return None + elif len(input_nodes) == 1: + return search_quant_arg_upstream(input_nodes[0]) + else: + input_qargs: list[QuantArgs] = [] + for input in input_nodes: + quant_args = search_quant_arg_upstream(input) + if quant_args: + input_qargs.append(quant_args) + if len(input_qargs) == 0: + return None + assert all_q_args_equal( + input_qargs + ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different inputs." + return input_qargs[0] + + +def get_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs: + """Calls search_quant_arg_upstream and asserts that QuantArgs are found, + meaning return value can't be None. """ - quant_args = [TosaArg(arg) for arg in node.args] - return QuantArgs( - quant_args[1].number, - quant_args[2].number, - quant_args[3].number, - quant_args[4].number, - ) + qargs = search_quant_arg_upstream(node) + assert qargs, f"Did not find QuantArgs upstream for node {node}" + return qargs + + +def get_quantized_node_output_dtype(node: torch.fx.Node) -> torch.dtype: + if isinstance(node.target, Callable) and "tosa" in node.target.__name__: + return node.meta["val"].dtype + if node.target in dq_q_ops: + return cast(torch.dtype, node.args[5]) + + # if not a tosa node, nor a q/dq op, walk the graph until we find a q op + user_q_args, input_q_args = get_neighbour_quant_args(node) + if len(user_q_args) > 0: + return user_q_args[0].dtype + elif node.target in passable_ops and len(input_q_args) > 0: + return input_q_args[0].dtype + else: + raise RuntimeError("No quantized node found in graph") # Check if scale32 mode is used for given output element type @@ -267,14 +375,14 @@ def rescale_nodes_to_int32( needed by rescale_node_back_to_int8. """ - tensors = [TosaArg(node.args[0]) for node in nodes] + tensors = [TosaArg(node) for node in nodes] # Reshape tensor according to tosa dim order for tensor in tensors: dim_order = tensor.dim_order tensor.shape = [tensor.shape[i] for i in dim_order] - qargs = [get_quant_node_args(node) for node in nodes] + qargs = [get_quant_arg_upstream(node) for node in nodes] # Scale the int8 quantized input to a common scale in the integer # domain @@ -307,7 +415,7 @@ def rescale_node_back_to_int8( scale: the scaling factor used to rescale to int32, from the function 'rescale_nodes_to_int32' tosa_graph: the tosa_graph to manipulate. """ - qargs_out = get_quant_node_args(list(node.users)[0]) + qargs_out = get_quant_arg_downstream(list(node.users)[0]) output_rescale_scale = scale / qargs_out.scale # Rescale Back to INT8 @@ -334,7 +442,7 @@ def build_rescale_conv_output( output_zp, ): # TODO add check to verify if this is a Per-channel quantization. - post_conv2d_scale = (input_scale.number * weight_scale.number) / output_scale.number + post_conv2d_scale = (input_scale * weight_scale) / output_scale # Since we assume the input tensor that is being rescaled is int32 date type, zero point must be 0. build_rescale( @@ -345,6 +453,6 @@ def build_rescale_conv_output( output_type, op.shape, 0, - output_zp.number, + output_zp, ) return diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index c91d89b1b96..b61b27853ac 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -16,9 +16,10 @@ from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg from executorch.backends.arm.tosa_quant_utils import ( - get_quant_node_args, - get_quant_node_dtype, - is_quant_node, + get_quant_arg_downstream, + get_quant_arg_upstream, + get_quantized_node_output_dtype, + is_node_quantized, q_op, ) from executorch.backends.arm.tosa_specification import TosaSpecification @@ -183,8 +184,8 @@ def build_avg_pool_2d_common( output_zp = 0 if is_quant_node: - input_zp = get_quant_node_args(cast(torch.fx.Node, node.args[0])).zp - output_zp = get_quant_node_args(list(node.users)[0]).zp + input_zp = get_quant_arg_upstream(cast(torch.fx.Node, node.args[0])).zp + output_zp = get_quant_arg_downstream(list(node.users)[0]).zp attr = ts.TosaSerializerAttribute() attr.PoolAttribute( @@ -244,10 +245,15 @@ def process_call_function( # Convert output (this node itself) output = TosaArg(node) + is_quant_node = is_node_quantized(node) + if is_quant_node: + output_dtype = map_dtype(get_quantized_node_output_dtype(node)) + else: + output_dtype = output.dtype tosa_graph.currRegion.currBasicBlock.addTensor( output.name, (tosa_shape(output.shape, output.dim_order)), - map_dtype(get_quant_node_dtype(node)) if is_quant_node(node) else output.dtype, + output_dtype, ) # Visiting each Node @@ -259,7 +265,7 @@ def process_call_function( tosa_graph, inputs, output, - is_quant_node(node), + is_quant_node, ) else: raise RuntimeError(f"Unknown operator {node.target} for TOSA : {tosa_spec}") diff --git a/backends/arm/util/arm_model_evaluator.py b/backends/arm/util/arm_model_evaluator.py index 4ffb80c2f0b..b348f10722c 100644 --- a/backends/arm/util/arm_model_evaluator.py +++ b/backends/arm/util/arm_model_evaluator.py @@ -7,7 +7,7 @@ import os import tempfile import zipfile -from typing import Optional, Tuple, Union +from typing import Any, Optional, Tuple import torch @@ -32,7 +32,7 @@ def __init__( else: self.tosa_output_path = None - def get_model_error(self) -> Union[float, float, float, float]: + def get_model_error(self) -> tuple[float, float, float, float]: """ Returns the following metrics between the outputs of the FP32 and INT8 model: - Maximum error @@ -51,7 +51,12 @@ def get_model_error(self) -> Union[float, float, float, float]: max_percentage_error = torch.max(percentage_error).item() mean_absolute_error = torch.mean(torch.abs(difference).float()).item() - return max_error, max_absolute_error, max_percentage_error, mean_absolute_error + return ( + float(max_error), + float(max_absolute_error), + float(max_percentage_error), + float(mean_absolute_error), + ) def get_compression_ratio(self) -> float: """Compute the compression ratio of the outputted TOSA flatbuffer.""" @@ -67,7 +72,7 @@ def get_compression_ratio(self) -> float: return compression_ratio - def evaluate(self) -> dict[any]: + def evaluate(self) -> dict[str, Any]: max_error, max_absolute_error, max_percent_error, mean_absolute_error = ( self.get_model_error() ) @@ -82,6 +87,8 @@ def evaluate(self) -> dict[any]: } if self.tosa_output_path: + # We know output_metrics["metrics"] is list since we just defined it, safe to ignore. + # pyre-ignore[16] output_metrics["metrics"][ "compression_ratio" ] = self.get_compression_ratio()