From 2f5c5173446ab7aae94479afe7a54be2499b5686 Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Fri, 10 Oct 2025 15:09:44 -0700 Subject: [PATCH 1/2] Support for batched matmul (#14956) Summary: Matmul was relying on linear infra which didn't support batched second argument. This adds support. Differential Revision: D84279595 --- backends/cadence/aot/ref_implementations.py | 32 ++++++++++--------- .../aot/tests/test_ref_implementations.py | 25 ++++++++++++++- 2 files changed, 41 insertions(+), 16 deletions(-) diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 6a13a4424da..ed9bb438a9e 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -62,7 +62,7 @@ def quantize_per_tensor( ] if dtype not in supported_quant_types: raise ValueError( - f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}" + f"Unsupported dtype to quantize to {dtype}. Supported dtypes must be one of {supported_quant_types}" ) return torch.ops.quantized_decomposed.quantize_per_tensor( @@ -264,7 +264,7 @@ def quantized_linear_common( supported_dtypes = [torch.int8, torch.uint8, torch.int32] if dtype not in supported_dtypes: raise ValueError( - f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_dtypes}" + f"Unsupported dtype to quantize to {dtype}. Supported dtypes must be one of {supported_dtypes}" ) out = torch.nn.functional.linear( @@ -427,25 +427,27 @@ def quantized_matmul( - out_multiplier (int): The multiplier used to scale the output - out_shift (int): The shift used to scale the output - out_zero_point (int): The quantized mapping of zero for the output - - transposed (bool): Whether to transpose the weight tensor + - transposed (bool): Whether Y is transposed. """ if bias is not None and not torch.all(bias == 0): raise ValueError("bias must be None or all zeros since unused in out variant") - # Looks weird, but quantized linear assumes weights are pre-transposed, - # hence we transpose only if `transposed` is False. 
- if not transposed: - Y = Y.T + if transposed: + Y = Y.transpose(-1, -2) - return quantized_linear_common( - X, - Y, - bias or torch.zeros(1, dtype=torch.int32), - X_zero_point, - Y_zero_point, - out_multiplier, - out_shift, + out_scale = 1.0 / (-out_multiplier * (1 / (1 << 31)) * (2**out_shift)) + + out = torch.matmul( + (X - X_zero_point).float(), + (Y - Y_zero_point).float(), + ) + return quantize_per_tensor( + out, + out_scale, out_zero_point, + torch.iinfo(X.dtype).min, + torch.iinfo(X.dtype).max, + X.dtype, ) diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index f679bae9485..259752f3893 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -350,6 +350,29 @@ def test_quantized_add( for (matmul, transposed_matmul) in ((True, False), (True, True)) for (per_tensor, dtype) in ((True, torch.int8),) ], + *[ + ( + torch.Size([2, 1, 2]), # src_shape: 1 sample, 2 input features + torch.Size( + [2, 2, 2] + ), # weight_shape: 2 output features, 2 input features + 2, # in_zero_point + torch.tensor([1, 1], dtype=dtype), # weight_zero_point + torch.tensor( + [268435456], dtype=torch.int32 + ), # out_multiplier (0.125 * 2^31) + torch.tensor( + [1], dtype=torch.int32 + ), # out_shift (shift=1, doubles the scale) + 1, # out_zero_point + torch.tensor([[[1, 2]], [[0, -1]]], dtype=dtype), # expected_output + per_tensor, + matmul, + transposed_matmul, + ) + for (matmul, transposed_matmul) in ((True, False), (True, True)) + for (per_tensor, dtype) in ((True, torch.int8),) + ], ] ) def test_quantized_linear( @@ -380,7 +403,7 @@ def test_quantized_linear( .to(expected_output.dtype) ) if matmul and not transposed_matmul: - weight = weight.T + weight = weight.transpose(-1, -2) if per_tensor: weight_zero_point = weight_zero_point[0] From c805660ed45cc57458a4ad5410402d05cb988cd3 Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Fri, 10 Oct 2025 15:09:44 -0700 Subject: [PATCH 2/2] Refactor quantizer: Only replace with per-tensor variants (#14974) Summary: In our previous flow, we would replace ops with default variants, have a special fusion pass which constructs singleton tensors for a variety of fused quantized ops, and then we would call a replace ops to turn them into per-tensor-variants. I confirmed this was for legacy reasons, so a cleanup was much due. This diff directly replaces ops with the per-tensor variants and removes the pass which replaces singleton tensors with scalars. 
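To make the new flow concrete, here is a minimal sketch of the graph rewrite for quantized_linear. The zero-point/multiplier values below are placeholders, and the snippet assumes the Cadence op library is loaded so that torch.ops.cadence.* resolves; it is illustrative, not a transcript of a real graph.

  import torch
  from torch import fx

  g = fx.Graph()
  x = g.placeholder("x")       # quantized int8 activation
  w = g.placeholder("weight")  # quantized int8 weight
  b = g.placeholder("bias")

  # Old flow: scalar quant params were first materialized as [1]-shaped tensors
  # via aten.full and fed to the .default overload ...
  w_zp = g.call_function(torch.ops.aten.full.default, ([1], 0), {"dtype": torch.int32})
  out_mult = g.call_function(torch.ops.aten.full.default, ([1], 268435456), {"dtype": torch.int32})
  out_shift = g.call_function(torch.ops.aten.full.default, ([1], 1), {"dtype": torch.int32})
  g.call_function(
      torch.ops.cadence.quantized_linear.default,
      (x, w, b, 0, w_zp, out_mult, out_shift, 0, None),
  )
  # ... and ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass later
  # rewrote that node into the per-tensor overload with plain scalars.

  # New flow: the fusion pass emits the per-tensor overload with scalars
  # directly, so neither the aten.full nodes nor the replace pass are needed.
  g.call_function(
      torch.ops.cadence.quantized_linear.per_tensor,
      (x, w, b, 0, 0, 268435456, 1, 0, None),
  )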
Reviewed By: zonglinpeng Differential Revision: D83873738 --- backends/cadence/aot/TARGETS | 1 + backends/cadence/aot/quantizer/fusion_pass.py | 177 ++--------- backends/cadence/aot/quantizer/patterns.py | 23 +- backends/cadence/aot/replace_ops.py | 131 +-------- .../aot/tests/test_replace_ops_passes.py | 275 +++++------------- 5 files changed, 138 insertions(+), 469 deletions(-) diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 94ab6de0e29..4497b425557 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -425,6 +425,7 @@ python_unittest( "//executorch/exir:pass_base", "//executorch/exir/dialects:lib", "//executorch/exir/passes:lib", + ":ref_implementations", ], ) diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py index c8bfa5cbac7..b77b14b3041 100644 --- a/backends/cadence/aot/quantizer/fusion_pass.py +++ b/backends/cadence/aot/quantizer/fusion_pass.py @@ -65,33 +65,18 @@ def get_args_and_kwargs_add( dequants_inputs: List[fx.Node], quant_node: fx.Node, ) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]: - X_scale_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], dequants_inputs[0].args[1]), - {"dtype": torch.float}, - ) - X_zero_point_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], dequants_inputs[0].args[2]), - {"dtype": torch.int32}, - ) - Y_scale_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], dequants_inputs[1].args[1]), - {"dtype": torch.float}, - ) - Y_zero_point_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], dequants_inputs[1].args[2]), - {"dtype": torch.int32}, - ) + X_scale = dequants_inputs[0].args[1] + + X_zero_point = dequants_inputs[0].args[2] + Y_scale = dequants_inputs[1].args[1] + Y_zero_point = dequants_inputs[1].args[2] args = ( inputs_inputs[0], - X_scale_, - X_zero_point_, + X_scale, + X_zero_point, inputs_inputs[1], - Y_scale_, - Y_zero_point_, + Y_scale, + Y_zero_point, quant_node.args[1], quant_node.args[2], ) @@ -129,31 +114,12 @@ def get_args_and_kwargs_linear( else: bias = bias_inputs[0] - # Create single element tensors for weight_zero_point, out_multiplier, out_shift. - # Note that the function expects int32_t, when it would default to int64_t, so - # we explicitly require that type. 
- weight_zero_point_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], dequants_weights[0].args[2]), - {"dtype": torch.int32}, - ) - out_multiplier_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], out_multiplier[0].item()), - {"dtype": torch.int32}, - ) - out_shift_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], out_shift[0].item()), - {"dtype": torch.int32}, - ) - args = tuple(inputs_inputs + weights_inputs + [bias]) kwargs = { "src_zero_point": dequants_inputs[0].args[2], - "weight_zero_point": weight_zero_point_, - "out_multiplier": out_multiplier_, - "out_shift": out_shift_, + "weight_zero_point": dequants_weights[0].args[2], + "out_multiplier": out_multiplier[0].item(), + "out_shift": out_shift[0].item(), "out_zero_point": quant_node.args[2], "offset": None, } @@ -178,22 +144,8 @@ def get_args_and_kwargs_layer_norm( ), "per-channel quantization is not supported for layer norm, both scale and zero_point should be scalars" # Make the scale and zero_point tensors - scale_tensor = graph_module.graph.call_function( - torch.ops.aten.full.default, - ( - [1], - dequants_inputs[0].args[1], - ), - {"dtype": torch.float32}, - ) - zero_point_tensor = graph_module.graph.call_function( - torch.ops.aten.full.default, - ( - [1], - dequants_inputs[0].args[2], - ), - {"dtype": torch.int32}, - ) + scale = dequants_inputs[0].args[1] + zero_point = dequants_inputs[0].args[2] weight = other_inputs[1] if len(other_inputs) > 1 else None @@ -220,7 +172,7 @@ def get_args_and_kwargs_layer_norm( ) # Make the args and kwargs for the replacement op - args = tuple(inputs_inputs + [scale_tensor] + [zero_point_tensor]) + args = tuple(inputs_inputs + [scale, zero_point]) kwargs = { "normalized_shape": other_inputs[0], "weight": weight, @@ -308,31 +260,6 @@ def get_args_and_kwargs_conv( (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t) - out_multiplier_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], out_multiplier[0].item()), - {"dtype": torch.int32}, - ) - out_shift_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], out_shift[0].item()), - {"dtype": torch.int32}, - ) - - # Create a single element tensor for the weight zero point - weight_zero_point_tensor = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], weight_zero_point), - {"dtype": torch.int32}, - ) - - # Create a single element tensor for the bias scale - bias_scale_tensor = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], bias_scale), - {"dtype": torch.float32}, - ) - # Make the args and kwargs for the replacement op args = tuple(inputs_inputs + weights_inputs + [bias]) kwargs = { @@ -341,12 +268,12 @@ def get_args_and_kwargs_conv( "dilation": dilation, "groups": groups, "input_zero_point": dequants_inputs[0].args[2], - "weight_zero_point": weight_zero_point_tensor, - "bias_scale": bias_scale_tensor, + "weight_zero_point": weight_zero_point, + "bias_scale": bias_scale, "out_scale": quant_node.args[1], "out_zero_point": quant_node.args[2], - "out_multiplier": out_multiplier_, - "out_shift": out_shift_, + "out_multiplier": out_multiplier[0].item(), + "out_shift": out_shift[0].item(), } return args, kwargs @@ -367,27 +294,11 @@ def get_args_and_kwargs_relu( # Make the args and kwargs for the replacement op args = tuple(inputs_inputs) - X_zero_point = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], dequants_inputs[0].args[2]), - 
{"dtype": torch.int32}, - ) - out_multiplier_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], out_multiplier[0].item()), - {"dtype": torch.int32}, - ) - out_shift_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], out_shift[0].item()), - {"dtype": torch.int32}, - ) - kwargs = { - "X_zero_point": X_zero_point, + "X_zero_point": dequants_inputs[0].args[2], "out_zero_point": quant_node.args[2], - "out_multiplier": out_multiplier_, - "out_shift": out_shift_, + "out_multiplier": out_multiplier[0].item(), + "out_shift": out_shift[0].item(), } return args, kwargs @@ -435,48 +346,20 @@ def get_args_and_kwargs_softmax( {"dtype": torch.int32}, ) # Make the scale and zero_point tensors - in_scale_tensor = graph_module.graph.call_function( - torch.ops.aten.full.default, - ( - [1], - dequants_inputs[0].args[1], - ), - {"dtype": torch.float32}, - ) - in_zero_point_tensor = graph_module.graph.call_function( - torch.ops.aten.full.default, - ( - [1], - dequants_inputs[0].args[2], - ), - {"dtype": torch.int32}, - ) - out_scale_tensor = graph_module.graph.call_function( - torch.ops.aten.full.default, - ( - [1], - quant_node.args[1], - ), - {"dtype": torch.float32}, - ) - out_zero_point_tensor = graph_module.graph.call_function( - torch.ops.aten.full.default, - ( - [1], - quant_node.args[2], - ), - {"dtype": torch.int32}, - ) + in_scale = dequants_inputs[0].args[1] + in_zero_point = dequants_inputs[0].args[2] + out_scale = quant_node.args[1] + out_zero_point = quant_node.args[2] # Make the args and kwargs for the replacement op args = ( inputs_inputs[0], mask_tensor, op_node.args[1], - in_scale_tensor, - in_zero_point_tensor, - out_scale_tensor, - out_zero_point_tensor, + in_scale, + in_zero_point, + out_scale, + out_zero_point, ) kwargs = {} diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py index 65389aaad37..8ba7f789595 100644 --- a/backends/cadence/aot/quantizer/patterns.py +++ b/backends/cadence/aot/quantizer/patterns.py @@ -112,7 +112,7 @@ def get_anchors( ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_linear.default + return torch.ops.cadence.quantized_linear.per_tensor class AddPattern(QuantizationPattern): @@ -150,7 +150,7 @@ def get_anchors( ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_add.default + return torch.ops.cadence.quantized_add.per_tensor class BmmPattern(QuantizationPattern): @@ -265,7 +265,7 @@ def get_anchors( ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_conv2d_nchw.default + return torch.ops.cadence.quantized_conv2d_nchw.per_tensor class Conv2dPattern(QuantizationPattern): @@ -307,7 +307,7 @@ def get_anchors( ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_conv2d_nchw.default + return torch.ops.cadence.quantized_conv2d_nchw.per_tensor class LayerNormPattern(QuantizationPattern): @@ -345,7 +345,7 @@ def get_anchors( ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_layer_norm.default + return torch.ops.cadence.quantized_layer_norm.per_tensor class LinearPattern(QuantizationPattern): @@ -387,7 +387,7 @@ def get_anchors( ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_linear.default + return torch.ops.cadence.quantized_linear.per_tensor class MatmulPattern(QuantizationPattern): @@ -411,6 +411,7 @@ def get_anchors( ) def replacement_op(self) -> OpOverload: + # TODO: T240804887 This is actually a 
per-tensor variant, we just need to change the name of the op return torch.ops.cadence.quantized_matmul.default @@ -437,7 +438,7 @@ def get_anchors( ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_relu.default + return torch.ops.cadence.quantized_relu.per_tensor # Regular relu op @@ -496,7 +497,7 @@ def get_anchors( ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_conv2d_nchw.default + return torch.ops.cadence.quantized_conv2d_nchw.per_tensor # Conv1d + regular relu op fusion @@ -544,7 +545,7 @@ def get_anchors( ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_softmax.default + return torch.ops.cadence.quantized_softmax.per_tensor class MixedW8A32LinearPattern(QuantizationPattern): @@ -598,7 +599,7 @@ def get_anchors( ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_w8a32_linear.default + return torch.ops.cadence.quantized_w8a32_linear.per_tensor class MixedW8A32ConvPattern(QuantizationPattern): @@ -660,4 +661,4 @@ def get_anchors( ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_w8a32_conv.default + return torch.ops.cadence.quantized_w8a32_conv.per_tensor diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 3cfc059e75b..e75c967d682 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -41,7 +41,7 @@ ReplaceScalarWithTensorArgPass, ) from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket +from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue from torch.fx.node import Argument @@ -762,8 +762,8 @@ class ReplaceTrivialConvWithLinear(ExportPass): trivial_conv_op_to_linear_op: Dict[EdgeOpOverload, EdgeOpOverload] = { exir_ops.edge.cadence.convolution.default: exir_ops.edge.aten.linear.default, - exir_ops.edge.cadence.quantized_conv2d_nchw.default: exir_ops.edge.cadence.quantized_linear.default, - exir_ops.edge.cadence.quantized_conv2d_nhwc.default: exir_ops.edge.cadence.quantized_linear.default, + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor, } def call_operator(self, op, args, kwargs, meta): @@ -775,8 +775,8 @@ def call_operator(self, op, args, kwargs, meta): # extra args holding at least the zero point and scale of input, weight, bias, # and output tensor. quantized_op = ( - op == exir_ops.edge.cadence.quantized_conv2d_nchw.default - or op == exir_ops.edge.cadence.quantized_conv2d_nhwc.default + op == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor + or op == exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor ) assert (len(args) == 8 and not quantized_op) or ( len(args) >= 12 and quantized_op @@ -934,18 +934,18 @@ def call_operator( ) -> ProxyValue: if op not in { exir_ops.edge.cadence.convolution.default, - exir_ops.edge.cadence.quantized_conv2d_nchw.default, + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, }: return super().call_operator(op, args, kwargs, meta) - quantized_op = op == exir_ops.edge.cadence.quantized_conv2d_nchw.default + quantized_op = op == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor if not quantized_op and len(args) == 8 and args[-1] is True: # Already in NHWC layout. 
return super().call_operator(op, args, kwargs, meta) new_op = ( - exir_ops.edge.cadence.quantized_conv2d_nhwc.default + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor if quantized_op else exir_ops.edge.cadence.convolution.default ) @@ -1022,8 +1022,8 @@ class ReplaceConvWithIm2RowAndLinear(ExportPass): # decompose to. conv_op_to_linear_op: Dict[EdgeOpOverload, EdgeOpOverload] = { exir_ops.edge.cadence.convolution.default: exir_ops.edge.aten.linear.default, - exir_ops.edge.cadence.quantized_conv2d_nchw.default: exir_ops.edge.cadence.quantized_linear.default, - exir_ops.edge.cadence.quantized_conv2d_nhwc.default: exir_ops.edge.cadence.quantized_linear.default, + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor, } def call_operator(self, op, args, kwargs, meta): @@ -1032,8 +1032,8 @@ def call_operator(self, op, args, kwargs, meta): # Get the relevant args from convolution node. quantized_op = ( - op == exir_ops.edge.cadence.quantized_conv2d_nchw.default - or op == exir_ops.edge.cadence.quantized_conv2d_nhwc.default + op == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor + or op == exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor ) assert (len(args) == 8 and not quantized_op) or ( len(args) >= 12 and quantized_op @@ -1063,7 +1063,7 @@ def call_operator(self, op, args, kwargs, meta): # channel_last layout is specified by the channel_last arg of conv # op, which is either the last argument (15th) or implicitely False # if the op is quantized, or the last argument if not. - channel_last = op == exir_ops.edge.cadence.quantized_conv2d_nhwc.default + channel_last = op == exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor # The weight tensor is [out_channels, in_channels, X] for NCHW layout, # and [out_channels, X, in_channels] for NHWC layout. Here, X is the # kernel_width for conv1d, and X = kernel_height * kernel_width for @@ -1072,21 +1072,8 @@ def call_operator(self, op, args, kwargs, meta): # If the convolution op was quantized, we need the input tensor's # zero_point for im2row. Otherwise in_zero_point defaults to a zero # tensor. - in_zero_point = ( - ( - super().call_operator( - exir_ops.edge.aten.full.default, - ( - [1], - args[7], - ), - {"dtype": torch.int32}, - meta, - ) - ) - if quantized_op - else torch.tensor(0, dtype=torch.int32) - ) + in_zero_point = args[7] if quantized_op else 0 + # im2row expects every kernel parameter to be 2d. So we extend the # parameters for conv1d by prepending their default values. stride = ([1] + stride) if len(stride) == 1 else stride @@ -1109,7 +1096,7 @@ def call_operator(self, op, args, kwargs, meta): channel_last, ) im2row = super().call_operator( - exir_ops.edge.cadence.im2row.default, + exir_ops.edge.cadence.im2row.per_tensor, im2row_args, kwargs, meta, @@ -1529,91 +1516,6 @@ def call_operator(self, op, args, kwargs, meta): return super().call_operator(op, tuple(new_args), kwargs, meta) -@register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass(ExportPass): - """ - Replace ops with single element arguments (size = [1]) with overloads that accept scalar ints/floats. - """ - - # Keep track of which operators and arguments are being replaced. 
- replaced_scalar_args: dict[ - EdgeOpOverloadPacket, tuple[EdgeOpOverload, Sequence[int]] - ] = { - exir_ops.edge.cadence.quantized_add.default: ( - exir_ops.edge.cadence.quantized_add.per_tensor, - [1, 2, 4, 5], - ), - exir_ops.edge.cadence.quantized_conv2d_nchw.default: ( - exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, - [8, 9, 12, 13], - ), - exir_ops.edge.cadence.quantized_conv2d_nhwc.default: ( - exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor, - [8, 9, 12, 13], - ), - exir_ops.edge.cadence.quantized_fully_connected.default: ( - exir_ops.edge.cadence.quantized_fully_connected.per_tensor, - [4, 5, 6], - ), - exir_ops.edge.cadence.quantized_layer_norm.default: ( - exir_ops.edge.cadence.quantized_layer_norm.per_tensor, - [1, 2], - ), - exir_ops.edge.cadence.quantized_linear.default: ( - exir_ops.edge.cadence.quantized_linear.per_tensor, - [4, 5, 6], - ), - exir_ops.edge.cadence.quantized_relu.default: ( - exir_ops.edge.cadence.quantized_relu.per_tensor, - [1, 3, 4], - ), - exir_ops.edge.cadence.im2row.default: ( - exir_ops.edge.cadence.im2row.per_tensor, - [5], - ), - exir_ops.edge.cadence.requantize.default: ( - exir_ops.edge.cadence.requantize.per_tensor, - [1, 2, 3, 4], - ), - } - - def call_operator(self, op, args, kwargs, meta): - if op not in self.replaced_scalar_args: - return super().call_operator(op, args, kwargs, meta) - - # Get all the args that need to be replaced. - new_op, args_to_be_replaced = self.replaced_scalar_args[op] - - if op == new_op: - return super().call_operator(op, args, kwargs, meta) - - updated_args = list(args) - for op_arg_index in args_to_be_replaced: - arg = args[op_arg_index] - if not isinstance(arg, ProxyValue) or not arg.is_tensor(): - return super().call_operator(op, args, kwargs, meta) - - if not isinstance(arg.node.target, EdgeOpOverload): - return super().call_operator(op, args, kwargs, meta) - - if get_edge_overload_packet(arg.node.target) != exir_ops.edge.aten.full: - # Only replace if arg generated by a full op. - return super().call_operator(op, args, kwargs, meta) - - if tuple(arg.node.args[0]) != (1,): - # Only replace if the size of the full op is [1]. 
- return super().call_operator(op, args, kwargs, meta) - - updated_args[op_arg_index] = arg.node.args[1] - - return super().call_operator( - new_op, - tuple(updated_args), - kwargs, - meta, - ) - - @register_cadence_pass(CadencePassAttribute(opt_level=0)) class ReplaceAtenAvgPoolWithCadenceAvgPoolPass(ExportPass): """ @@ -2260,7 +2162,6 @@ class CadenceReplaceOpsInGraph: ReplaceScalarTensorWithFullPass, ReplaceInfArgInFullWithValuePass, ReplaceLogicalNotBooleanWhereWithWherePass, - ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass, ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass, ReplaceAtenAvgPoolWithCadenceAvgPoolPass, ReplaceWhereWithFullArgsWithWhereScalar, diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index e2fbd516757..73964c6c4c4 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -6,10 +6,13 @@ # pyre-strict +import copy import operator import unittest from typing import cast, List, Optional, Sequence, Tuple, Union +import executorch.backends.cadence.aot.ref_implementations # noqa + import torch from executorch.backends.cadence.aot.graph_builder import ( GraphBuilder, @@ -42,7 +45,6 @@ ReplaceScalarTensorWithFullPass, ReplaceScalarWithTensorArgPass, ReplaceSelectWithViewOpPass, - ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass, ReplaceSplitWithSlicePass, ReplaceSqueezeAndUnsqueezeWithViewPass, ReplaceTorchQuantizedEmbeddingWithCadenceQuantizedEmbedding, @@ -54,11 +56,30 @@ from executorch.backends.cadence.aot.typing_stubs import expand from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, ProxyValue -from executorch.exir.passes import dead_code_elimination_pass from torch.fx.passes.infra.pass_base import PassResult from torch.utils import _pytree as pytree +def validate( + original: torch.fx.GraphModule, + modified: torch.fx.GraphModule, + inputs: tuple[torch.Tensor, ...] | list[torch.Tensor], + pass_name: str, +) -> None: + original.eval() + modified.eval() + with torch.no_grad(): + orig_out = original(*inputs) + mod_out = modified(*inputs) + + flat_orig_out, _ = pytree.tree_flatten(orig_out) + flat_mod_out, _ = pytree.tree_flatten(mod_out) + if not all(pytree.tree_map(torch.equal, flat_orig_out, flat_mod_out)): + raise AssertionError( + f"Pass validation failed with exact match for pass {pass_name}. 
Original graph {original} and modified graph {modified}" + ) + + class TestReplaceOpsPasses(unittest.TestCase): def assertTargetCountEqual( self, @@ -105,8 +126,10 @@ def test_replace_matmul_with_transposed_matmul( y_shape: Tuple[int], ) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(*x_shape, dtype=torch.float32)) - y = builder.placeholder("y", torch.randn(*y_shape, dtype=torch.float32)) + x_ = torch.randint(0, 100, x_shape, dtype=torch.int8) + x = builder.placeholder("x", x_) + y_ = torch.randint(0, 100, y_shape, dtype=torch.int8) + y = builder.placeholder("y", y_) matmul = builder.call_operator( op=exir_ops.edge.cadence.quantized_matmul.default, args=( @@ -135,6 +158,12 @@ def test_replace_matmul_with_transposed_matmul( ), 1, ) + validate( + original_gm, + graph_after_passes, + (x_, y_), + "ReplaceMatmulWithTransposedMatmulPass", + ) @expand( [ @@ -1000,152 +1029,6 @@ def test_replace_unsqueeze_with_view(self, shape: Tuple[int], dim: int) -> None: 0, ) - @torch.no_grad() - def test_replace_single_element_tensor_arguments_from_full_op_with_scalar( - self, - in_features: int = 16, - out_features: int = 16, - ) -> None: - src_zero_point = 0 - out_zero_point = 0 - builder = GraphBuilder() - x = builder.placeholder("x", torch.randn([1, in_features])) - weights = builder.placeholder( - "weights", torch.randn([in_features, out_features], dtype=torch.float32) - ) - bias = builder.placeholder( - "bias", torch.randn([out_features], dtype=torch.float32) - ) - quantized_input = builder.call_operator( - op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - args=(x, 0.01431146077811718, 57, -128, 127, torch.int8), - ) - weight_zero_point = builder.call_operator( - op=exir_ops.edge.aten.full.default, - args=([1], 0), - ) - out_multiplier = builder.call_operator( - op=exir_ops.edge.aten.full.default, - args=([1], 0), - ) - out_shift = builder.call_operator( - op=exir_ops.edge.aten.full.default, - args=([1], 0), - ) - output = builder.call_operator( - op=exir_ops.edge.cadence.quantized_linear.default, - args=( - quantized_input, - weights, - bias, - src_zero_point, - weight_zero_point, - out_multiplier, - out_shift, - out_zero_point, - None, - ), - ) - dequantized_output = builder.call_operator( - op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - args=(output, 0.010696045123040676, -31, -128, 127, torch.int8), - ) - builder.output([dequantized_output]) - original_gm = builder.get_graph_module() - p = ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module - self.assertIsNotNone(graph_after_passes) - gm = dead_code_elimination_pass(graph_after_passes).graph_module - # By default, the quantized linear op should have constant scalar attributes. - self.assertTargetCountsEqual( - gm, - [ - # No default quantized linear op. - (exir_ops.edge.cadence.quantized_linear.default, 0), - # The default quantized linear op will be replaced with quantized_linear.per_tensor. - (exir_ops.edge.cadence.quantized_linear.per_tensor, 1), - # No aten.full ops. 
- (exir_ops.edge.aten.full.default, 0), - ], - ) - - @torch.no_grad() - def test_replace_single_element_tensor_arguments_from_full_op_with_scalar_tuple_args( - self, - in_features: int = 16, - out_features: int = 16, - ) -> None: - src_zero_point = 0 - out_zero_point = 0 - builder = GraphBuilder() - x = builder.placeholder("x", torch.randn([1, in_features])) - weights = builder.placeholder( - "weights", torch.randn([in_features, out_features], dtype=torch.float32) - ) - bias = builder.placeholder( - "bias", torch.randn([out_features], dtype=torch.float32) - ) - quantized_input = builder.call_operator( - op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - args=(x, 0.01431146077811718, 57, -128, 127, torch.int8), - ) - weight_zero_point = builder.call_operator( - op=exir_ops.edge.aten.full.default, - args=([1], 0), - ) - out_multiplier = builder.call_operator( - op=exir_ops.edge.aten.full.default, - args=([1], 0), - ) - out_shift = builder.call_operator( - op=exir_ops.edge.aten.full.default, - args=([1], 0), - ) - output = builder.call_operator( - op=exir_ops.edge.cadence.quantized_linear.default, - args=( - quantized_input, - weights, - bias, - src_zero_point, - weight_zero_point, - out_multiplier, - out_shift, - out_zero_point, - None, - ), - ) - dequantized_output = builder.call_operator( - op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - args=(output, 0.010696045123040676, -31, -128, 127, torch.int8), - ) - builder.output([dequantized_output]) - original_gm = builder.get_graph_module() - - for node in original_gm.graph.nodes: - # Replace the `shape` argument for aten.full op with a tuple. - if node.target == exir_ops.edge.aten.full.default: - node.args = (tuple(node.args[0]), node.args[1]) - - # Apply replacement pass. - p = ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module - self.assertIsNotNone(graph_after_passes) - gm = dead_code_elimination_pass(graph_after_passes).graph_module - - # By default, the quantized linear op should have constant scalar attributes. - self.assertTargetCountsEqual( - gm, - [ - # No default quantized linear op. - (exir_ops.edge.cadence.quantized_linear.default, 0), - # The default quantized linear op will be replaced with quantized_linear.per_tensor. - (exir_ops.edge.cadence.quantized_linear.per_tensor, 1), - # No aten.full ops. - (exir_ops.edge.aten.full.default, 0), - ], - ) - @torch.no_grad() def test_replace_conv1d_with_linear(self) -> None: x = torch.randn(1, 96, 7) @@ -1231,7 +1114,7 @@ def test_replace_conv2d_with_im2row_and_linear(self) -> None: count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), 0 ) self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.cadence.im2row.default), 1 + count_node(graph_after_passes, exir_ops.edge.cadence.im2row.per_tensor), 1 ) self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.linear.default), 1 @@ -1799,33 +1682,33 @@ def test_no_transpose_if_already_channel_last(self) -> None: def create_quantized_convolution_graph_module( self, channels_last: Optional[bool] = None - ) -> torch.fx.GraphModule: + ) -> tuple[tuple[torch.Tensor, ...], torch.fx.GraphModule]: """Helper to create a quantized conv node. 
- quantized_conv( + quantized_conv_per_tensor( Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, - int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, - Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, - Tensor out_shift, bool channel_last=False) -> (Tensor Z)" + int[] dilation, int groups, int input_zero_point, int weight_zero_point, + Tensor bias_scale, float out_scale, int out_zero_point, int out_multiplier, + int out_shift, bool channel_last=False) -> (Tensor Z)" """ if channels_last: - x = torch.randn(1, 224, 56, 3) - w = torch.randn(16, 16, 16, 3) + x = torch.randint(0, 100, (1, 224, 56, 3), dtype=torch.int32) + w = torch.randint(0, 100, (16, 16, 16, 3), dtype=torch.int32) else: - x = torch.randn(1, 3, 224, 56) - w = torch.randn(16, 3, 16, 16) + x = torch.randint(0, 100, (1, 3, 224, 56), dtype=torch.int32) + w = torch.randint(0, 100, (16, 3, 16, 16), dtype=torch.int32) b = torch.randn(16) stride = (2, 2) padding = (0, 0) dilation = (1, 1) groups = 1 input_zero_point = 0 - w_zero_point = torch.randn(1) - b_scale = torch.randn(1) + w_zero_point = 100 + b_scale = 10 out_scale = 1 out_zero_point = 0 - out_multiplier = torch.randn(1) - out_shift = torch.randn(1) + out_multiplier = 5 + out_shift = 5 args = ( x, w, @@ -1843,50 +1726,35 @@ def create_quantized_convolution_graph_module( out_shift, ) if channels_last is not None: - return single_op_builder( - placeholders=( - x, - w, - b, - w_zero_point, - b_scale, - out_multiplier, - out_shift, - ), - op=exir_ops.edge.cadence.quantized_conv2d_nhwc.default, - args=args, - ) + op = exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor else: - return single_op_builder( - placeholders=( - x, - w, - b, - w_zero_point, - b_scale, - out_multiplier, - out_shift, - ), - op=exir_ops.edge.cadence.quantized_conv2d_nchw.default, - args=args, - ) + op = exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor + + placeholders = (x, w, b) + + return placeholders, single_op_builder( + placeholders=placeholders, + op=op, + args=args, + ) def test_quantized_convolution_default_channel_last(self) -> None: # Create a graph with a single convolution node. - gm = self.create_quantized_convolution_graph_module() + placeholders, gm = self.create_quantized_convolution_graph_module() self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_conv2d_nchw.default), 1 + count_node(gm, exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor), 1 ) self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0) # Apply replacement pass. p = ReplaceConvWithChannelLastConvPass() + original = copy.deepcopy(gm) gm_after_replacement = p.call(gm).graph_module # Check that no replacement was made. self.assertEqual( count_node( gm_after_replacement, - exir_ops.edge.cadence.quantized_conv2d_nhwc.default, + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor, ), 1, ) @@ -1895,14 +1763,23 @@ def test_quantized_convolution_default_channel_last(self) -> None: count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default), 3, ) + validate( + gm_after_replacement, + original, + placeholders, + "ReplaceConvWithChannelLastConvPass", + ) def test_no_transpose_if_already_quantized_conv_channel_last(self) -> None: # Create a graph with a single im2row node. - gm = self.create_quantized_convolution_graph_module(channels_last=True) + placeholders, gm = self.create_quantized_convolution_graph_module( + channels_last=True + ) # Check if graph module is valid by running exportpass on it. 
+ original = copy.deepcopy(gm) gm = ExportPass().call(gm).graph_module self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_conv2d_nhwc.default), 1 + count_node(gm, exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor), 1 ) # Apply replacement pass. @@ -1912,11 +1789,17 @@ def test_no_transpose_if_already_quantized_conv_channel_last(self) -> None: self.assertEqual( count_node( gm_after_replacement, - exir_ops.edge.cadence.quantized_conv2d_nhwc.default, + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor, ), 1, ) self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0) + validate( + gm_after_replacement, + original, + placeholders, + "ReplaceConvWithChannelLastConvPass", + ) class TestMakeSliceAndCatDimOutermostPass(unittest.TestCase):
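For reference, the batched semantics that the first patch in this series gives cadence::quantized_matmul can be sketched in a few lines of plain PyTorch. This is a simplified stand-in for the updated ref_implementations.py path, not the op itself: it inlines the final quantize_per_tensor call as round/clamp math, omits the bias and dtype validation, and uses arbitrary example shapes and quantization parameters.

  import torch

  def quantized_matmul_ref(X, X_zero_point, Y, Y_zero_point,
                           out_multiplier, out_shift, out_zero_point,
                           transposed=False):
      # When `transposed` is set, Y arrives as [..., N, K]; flip only the last
      # two dims (unlike .T, which would also permute the batch dims).
      if transposed:
          Y = Y.transpose(-1, -2)
      # Integer matmul carried out in float on zero-point-shifted operands.
      acc = torch.matmul((X - X_zero_point).float(), (Y - Y_zero_point).float())
      # Requantize; the multiplier/shift-to-scale formula (including the sign)
      # is copied from the updated reference implementation.
      out_scale = 1.0 / (-out_multiplier * (1 / (1 << 31)) * (2 ** out_shift))
      info = torch.iinfo(X.dtype)
      q = torch.round(acc / out_scale) + out_zero_point
      return q.clamp(info.min, info.max).to(X.dtype)

  # A batched second argument (the case the old linear-based path rejected)
  # now works:
  X = torch.randint(-8, 8, (2, 1, 2), dtype=torch.int8)  # [B, M, K]
  Y = torch.randint(-8, 8, (2, 2, 2), dtype=torch.int8)  # [B, K, N]
  out = quantized_matmul_ref(X, 2, Y, 1, 268435456, 1, 1)
  print(out.shape)  # torch.Size([2, 1, 2])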