diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py
index 729056ea2c8..8f106a815ac 100644
--- a/backends/cadence/aot/quantizer/fusion_pass.py
+++ b/backends/cadence/aot/quantizer/fusion_pass.py
@@ -15,7 +15,11 @@
     BmmPattern,
     CatPattern,
     Conv1dPattern,
+    Conv1dReluPattern0,
+    Conv1dReluPattern1,
     Conv2dPattern,
+    Conv2dReluPattern0,
+    Conv2dReluPattern1,
     LayerNormPattern,
     LinearPattern,
     MatmulPattern,
@@ -23,6 +27,7 @@
     ReluPattern1,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
+    check_out_zero_point_is_min_range,
     create_zero_bias_int32,
     find_sequential_partitions_aten,
     get_conv_args,
@@ -41,6 +46,13 @@
 
 # Use this part for patterns with multiple aten ops
 ReluPatterns = (ReluPattern0, ReluPattern1)
+ConvPatterns = (Conv1dPattern, Conv2dPattern)
+ConvReluPatterns = (
+    Conv1dReluPattern0,
+    Conv1dReluPattern1,
+    Conv2dReluPattern0,
+    Conv2dReluPattern1,
+)
 
 
 def get_args_and_kwargs_add(
@@ -432,12 +444,12 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                 other_inputs = [node.args[idx] for node, idx in anchors.others]
 
                 # The node is the first index of the list and first of the tuple
-                op_node = anchors.output[0][0]
+                anchor_output_node = anchors.output[0][0]
 
-                assert len(op_node.users) == 1
-                quant_node = list(op_node.users.keys())[0]
+                assert len(anchor_output_node.users) == 1
+                quant_node = list(anchor_output_node.users.keys())[0]
 
-                with graph_module.graph.inserting_after(op_node):
+                with graph_module.graph.inserting_after(anchor_output_node):
                     args = tuple(
                         inputs_inputs + weights_inputs + other_inputs + bias_inputs
                     )
@@ -451,9 +463,29 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         )
                     elif isinstance(pattern, CatPattern):
                         args, kwargs = get_args_and_kwargs_cat(
-                            inputs_inputs, other_inputs, op_node
+                            inputs_inputs, other_inputs, anchor_output_node
+                        )
+                    elif isinstance(pattern, ConvReluPatterns):
+                        # For Conv+ReLU fusion, the op we want the replacement
+                        # args and kwargs for is the *conv* op, which is the
+                        # anchor input, NOT the anchor output (which is the
+                        # ReLU). The fusion is only valid when the output zero
+                        # point sits at the minimum of the quantized range.
+                        assert check_out_zero_point_is_min_range(
+                            quant_node.args[2], quant_node.args[5]
+                        )
+                        anchor_input_node = anchors.inputs[0][0]
+                        args, kwargs = get_args_and_kwargs_conv(
+                            graph_module,
+                            inputs_inputs,
+                            dequants_inputs,
+                            weights_inputs,
+                            dequants_weights,
+                            bias_inputs,
+                            quant_node,
+                            anchor_input_node,
                         )
-                    elif isinstance(pattern, (Conv1dPattern, Conv2dPattern)):
+                    elif isinstance(pattern, ConvPatterns):
                         args, kwargs = get_args_and_kwargs_conv(
                             graph_module,
                             inputs_inputs,
@@ -462,7 +494,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                             dequants_weights,
                             bias_inputs,
                             quant_node,
-                            op_node,
+                            anchor_output_node,
                         )
                     elif isinstance(pattern, LinearPattern):
                         args, kwargs = get_args_and_kwargs_linear(
diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py
index 74987f8b38d..b653be27e8f 100644
--- a/backends/cadence/aot/quantizer/patterns.py
+++ b/backends/cadence/aot/quantizer/patterns.py
@@ -417,3 +417,71 @@ def partition_types(self) -> List[OpOverload]:
 class ReluPattern1(ReluBasePattern):
     def partition_types(self) -> List[OpOverload]:
         return [torch.ops.aten.relu_.default]
+
+
+# This is a base class for Conv+ReLU fusion, since it can be used with two different relu aten ops
+class ConvReluBasePattern(QuantizationPattern):
+    @abstractmethod
+    def partition_types(self) -> List[OpOverload]:
+        pass
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
+        # The first partition holds the conv node, the second holds the relu node
+        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
+        conv_node = fused_partition[0].nodes[-1]  # Last node of the conv partition
+        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
+        relu_node = fused_partition[1].nodes[-1]  # Last node of the relu partition
+
+        bias_qspec = DerivedQuantizationSpec(
+            derived_from=[
+                (conv_node.args[0], conv_node),
+                (conv_node.args[1], conv_node),
+            ],
+            derive_qparams_fn=get_bias_qparams,
+            dtype=torch.int32,
+            quant_min=-(2**31),
+            quant_max=2**31 - 1,
+            qscheme=torch.per_tensor_affine,
+        )
+
+        # Keep bias empty if not supplied
+        bias = []
+        if len(conv_node.args) > 2 and conv_node.args[2] is not None:
+            bias = [(conv_node, 2, bias_qspec)]
+
+        return PartitionAnchors(
+            inputs=[(conv_node, 0)],
+            weights=[(conv_node, 1)],
+            # pyre-fixme[6]: Incompatible parameter type
+            biases=bias,
+            output=[(relu_node,)],  # Output is taken from the relu node
+        )
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.cadence.quantized_conv_nchw.default
+
+
+# Conv1d + regular relu op fusion
+class Conv1dReluPattern0(ConvReluBasePattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv1d.default, torch.ops.aten.relu.default]
+
+
+# Conv1d + in-place relu (relu_) op fusion
+class Conv1dReluPattern1(ConvReluBasePattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv1d.default, torch.ops.aten.relu_.default]
+
+
+# Conv2d + regular relu op fusion
+class Conv2dReluPattern0(ConvReluBasePattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv2d.default, torch.ops.aten.relu.default]
+
+
+# Conv2d + in-place relu (relu_) op fusion
+class Conv2dReluPattern1(ConvReluBasePattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv2d.default, torch.ops.aten.relu_.default]
diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py
index 8c78ac87e58..cce7c207a6b 100644
--- a/backends/cadence/aot/quantizer/quantizer.py
+++ b/backends/cadence/aot/quantizer/quantizer.py
@@ -16,7 +16,11 @@
     BmmPattern,
     CatPattern,
     Conv1dPattern,
+    Conv1dReluPattern0,
+    Conv1dReluPattern1,
     Conv2dPattern,
+    Conv2dReluPattern0,
+    Conv2dReluPattern1,
     LayerNormPattern,
     LinearPattern,
     MatmulPattern,
@@ -260,3 +264,22 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
         quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
         quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
         super().__init__(quantizers)
+
+
+class CadenceFusedConvReluQuantizer(CadenceQuantizer):
+    """
+    Quantizer that uses fused conv+relu patterns, in addition to add and cat.
+    """
+
+    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+        if quantizers is None:
+            quantizers = []
+        # Order matters here: match the "fused" patterns first
+        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern0(), qconfig_A8W8sym))
+        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern1(), qconfig_A8W8sym))
+        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern0(), qconfig_A8W8sym))
+        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern1(), qconfig_A8W8sym))
+        quantizers = quantizers + get_cadence_default_quantizers()
+        quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
+        quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
+        super().__init__(quantizers)
diff --git a/backends/cadence/aot/quantizer/utils.py b/backends/cadence/aot/quantizer/utils.py
index beacd1b9e86..68fc6740cb4 100644
--- a/backends/cadence/aot/quantizer/utils.py
+++ b/backends/cadence/aot/quantizer/utils.py
@@ -234,3 +234,19 @@ def find_sequential_partitions_aten(
         if _partitions_sequential(candidate):
             fused_partitions.append(candidate)
     return fused_partitions
+
+
+def check_out_zero_point_is_min_range(
+    out_zero_point: int,
+    out_dtype: torch.dtype,
+) -> bool:
+    """
+    Checks whether out_zero_point is the minimum of the quantized dtype's range.
+    """
+    if out_dtype == torch.int8:
+        return out_zero_point == -128
+    elif out_dtype == torch.int16:
+        return out_zero_point == -32768
+    elif out_dtype in (torch.uint8, torch.uint16):
+        return out_zero_point == 0
+    return False
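
Usage sketch (for reviewers): the snippet below is a minimal, hypothetical example, not part of this patch, showing how the new `CadenceFusedConvReluQuantizer` could be exercised through the standard PT2E flow so that a `conv1d` + `relu` pair is annotated by `Conv1dReluPattern0`. The toy `SmallConvRelu` module and the `torch.export.export_for_training` entry point are assumptions; the exact export/prepare APIs vary across PyTorch/ExecuTorch versions.

```python
# Hypothetical usage sketch; not part of this patch.
import torch
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceFusedConvReluQuantizer,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class SmallConvRelu(torch.nn.Module):
    """Toy model whose conv1d + relu should be matched by Conv1dReluPattern0."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv1d(4, 8, kernel_size=3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.conv(x))


example_inputs = (torch.randn(1, 4, 16),)

# Export to an aten-level graph (the export entry point depends on the
# PyTorch version in use).
exported = torch.export.export_for_training(SmallConvRelu(), example_inputs)
graph_module = exported.module()

# Annotate with the fused conv+relu quantizer, calibrate, then convert.
quantizer = CadenceFusedConvReluQuantizer()
prepared = prepare_pt2e(graph_module, quantizer)
prepared(*example_inputs)  # single calibration pass with random data
converted = convert_pt2e(prepared)
```

The fusion itself, replacing the dequant -> conv -> relu -> quant chain with `cadence.quantized_conv_nchw` (the `replacement_op` of `ConvReluBasePattern`), is performed later by the fusion pass touched above as part of the Cadence AOT pipeline, so it is not shown here.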