From b918df8825719896b5c88fa61419bd2ae77b06e4 Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Fri, 12 Sep 2025 14:15:54 -0700 Subject: [PATCH] Add support for fusing Conv+ReLU (#14229) Summary: This diff adds support for *implicitly* fusing Conv2d+ReLU We add a new pattern which will capture this sequence of events and ensure the subgraph is treated as one node during calculation of qparam. Then, during fuse, we replace this subgraph with just the conv and drop the ReLU. Reviewed By: ethansfng, hsharma35, ivayloen Differential Revision: D79381533 --- backends/cadence/aot/quantizer/fusion_pass.py | 46 +++++++++++-- backends/cadence/aot/quantizer/patterns.py | 68 +++++++++++++++++++ backends/cadence/aot/quantizer/quantizer.py | 23 +++++++ backends/cadence/aot/quantizer/utils.py | 16 +++++ 4 files changed, 146 insertions(+), 7 deletions(-) diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py index 729056ea2c8..8f106a815ac 100644 --- a/backends/cadence/aot/quantizer/fusion_pass.py +++ b/backends/cadence/aot/quantizer/fusion_pass.py @@ -15,7 +15,11 @@ BmmPattern, CatPattern, Conv1dPattern, + Conv1dReluPattern0, + Conv1dReluPattern1, Conv2dPattern, + Conv2dReluPattern0, + Conv2dReluPattern1, LayerNormPattern, LinearPattern, MatmulPattern, @@ -23,6 +27,7 @@ ReluPattern1, ) from executorch.backends.cadence.aot.quantizer.utils import ( + check_out_zero_point_is_min_range, create_zero_bias_int32, find_sequential_partitions_aten, get_conv_args, @@ -41,6 +46,13 @@ # Use this part for patterns with multiple aten ops ReluPatterns = (ReluPattern0, ReluPattern1) +ConvPatterns = (Conv1dPattern, Conv2dPattern) +ConvReluPatterns = ( + Conv1dReluPattern0, + Conv1dReluPattern1, + Conv2dReluPattern0, + Conv2dReluPattern1, +) def get_args_and_kwargs_add( @@ -432,12 +444,12 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 other_inputs = [node.args[idx] for node, idx in anchors.others] # The node is the first index of the list and first of the tuple - op_node = anchors.output[0][0] + anchor_output_node = anchors.output[0][0] - assert len(op_node.users) == 1 - quant_node = list(op_node.users.keys())[0] + assert len(anchor_output_node.users) == 1 + quant_node = list(anchor_output_node.users.keys())[0] - with graph_module.graph.inserting_after(op_node): + with graph_module.graph.inserting_after(anchor_output_node): args = tuple( inputs_inputs + weights_inputs + other_inputs + bias_inputs ) @@ -451,9 +463,29 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 ) elif isinstance(pattern, CatPattern): args, kwargs = get_args_and_kwargs_cat( - inputs_inputs, other_inputs, op_node + inputs_inputs, other_inputs, anchor_output_node + ) + elif isinstance(pattern, ConvReluPatterns): + # For ConvReLU, we are fusing Conv+ReLU + # This means that the op we want to get + # the replacement args and kwargs for is the + # *conv* op, which is the anchor input, NOT + # the anchor output (which is the ReLU) + check_out_zero_point_is_min_range( + quant_node.args[2], quant_node.args[5] + ) + anchor_input_node = anchors.inputs[0][0] + args, kwargs = get_args_and_kwargs_conv( + graph_module, + inputs_inputs, + dequants_inputs, + weights_inputs, + dequants_weights, + bias_inputs, + quant_node, + anchor_input_node, ) - elif isinstance(pattern, (Conv1dPattern, Conv2dPattern)): + elif isinstance(pattern, ConvPatterns): args, kwargs = get_args_and_kwargs_conv( graph_module, inputs_inputs, @@ -462,7 +494,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 dequants_weights, bias_inputs, quant_node, - op_node, + anchor_output_node, ) elif isinstance(pattern, LinearPattern): args, kwargs = get_args_and_kwargs_linear( diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py index 74987f8b38d..b653be27e8f 100644 --- a/backends/cadence/aot/quantizer/patterns.py +++ b/backends/cadence/aot/quantizer/patterns.py @@ -417,3 +417,71 @@ def partition_types(self) -> List[OpOverload]: class ReluPattern1(ReluBasePattern): def partition_types(self) -> List[OpOverload]: return [torch.ops.aten.relu_.default] + + +# This is a base class for Conv+ReLU fusion, since it can be used with two different relu aten ops +class ConvReluBasePattern(QuantizationPattern): + @abstractmethod + def partition_types(self) -> List[OpOverload]: + pass + + def get_anchors( + self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + ) -> PartitionAnchors: + # The first node should be conv, the second should be relu + # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... + conv_node = fused_partition[0].nodes[-1] # Second to last node + # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... + relu_node = fused_partition[1].nodes[-1] # Last node + + bias_qspec = DerivedQuantizationSpec( + derived_from=[ + (conv_node.args[0], conv_node), + (conv_node.args[1], conv_node), + ], + derive_qparams_fn=get_bias_qparams, + dtype=torch.int32, + quant_min=-(2**31), + quant_max=2**31 - 1, + qscheme=torch.per_tensor_affine, + ) + + # Keep bias empty if not supplied + bias = [] + if len(conv_node.args) > 2 and conv_node.args[2] is not None: + bias = [(conv_node, 2, bias_qspec)] + + return PartitionAnchors( + inputs=[(conv_node, 0)], + weights=[(conv_node, 1)], + # pyre-fixme[6]: Incompatible parameter type + biases=bias, + output=[(relu_node,)], # Output is from the relu node + ) + + def replacement_op(self) -> OpOverload: + return torch.ops.cadence.quantized_conv_nchw.default + + +# Conv1d + regular relu op fusion +class Conv1dReluPattern0(ConvReluBasePattern): + def partition_types(self) -> List[OpOverload]: + return [torch.ops.aten.conv1d.default, torch.ops.aten.relu.default] + + +# Conv1d + alternate relu op fusion +class Conv1dReluPattern1(ConvReluBasePattern): + def partition_types(self) -> List[OpOverload]: + return [torch.ops.aten.conv1d.default, torch.ops.aten.relu_.default] + + +# Conv2d + regular relu op fusion +class Conv2dReluPattern0(ConvReluBasePattern): + def partition_types(self) -> List[OpOverload]: + return [torch.ops.aten.conv2d.default, torch.ops.aten.relu.default] + + +# Conv2d + alternate relu op fusion +class Conv2dReluPattern1(ConvReluBasePattern): + def partition_types(self) -> List[OpOverload]: + return [torch.ops.aten.conv2d.default, torch.ops.aten.relu_.default] diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py index 8c78ac87e58..cce7c207a6b 100644 --- a/backends/cadence/aot/quantizer/quantizer.py +++ b/backends/cadence/aot/quantizer/quantizer.py @@ -16,7 +16,11 @@ BmmPattern, CatPattern, Conv1dPattern, + Conv1dReluPattern0, + Conv1dReluPattern1, Conv2dPattern, + Conv2dReluPattern0, + Conv2dReluPattern1, LayerNormPattern, LinearPattern, MatmulPattern, @@ -260,3 +264,22 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None: quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8)) quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8)) super().__init__(quantizers) + + +class CadenceFusedConvReluQuantizer(CadenceQuantizer): + """ + Quantizer using fused conv+relu patterns, and including add and cat + """ + + def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None: + if quantizers is None: + quantizers = [] + # Order matters here, perform the "fused" patterns first + quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern0(), qconfig_A8W8sym)) + quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern1(), qconfig_A8W8sym)) + quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern0(), qconfig_A8W8sym)) + quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern1(), qconfig_A8W8sym)) + quantizers = quantizers + get_cadence_default_quantizers() + quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8)) + quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8)) + super().__init__(quantizers) diff --git a/backends/cadence/aot/quantizer/utils.py b/backends/cadence/aot/quantizer/utils.py index beacd1b9e86..68fc6740cb4 100644 --- a/backends/cadence/aot/quantizer/utils.py +++ b/backends/cadence/aot/quantizer/utils.py @@ -234,3 +234,19 @@ def find_sequential_partitions_aten( if _partitions_sequential(candidate): fused_partitions.append(candidate) return fused_partitions + + +def check_out_zero_point_is_min_range( + out_zero_point: int, + out_dtype: torch.dtype, +) -> bool: + """ + Checks if the out_zero_point is the minimum range of the quant type. + """ + if out_dtype == torch.int8: + return out_zero_point == -128 + elif out_dtype == torch.int16: + return out_zero_point == -32768 + elif out_dtype == torch.uint8 or torch.uint16: + return out_zero_point == 0 + return False