diff --git a/torch/ao/quantization/pt2e/qat_utils.py b/torch/ao/quantization/pt2e/qat_utils.py
index f53821a99815f..29e6d331dee46 100644
--- a/torch/ao/quantization/pt2e/qat_utils.py
+++ b/torch/ao/quantization/pt2e/qat_utils.py
@@ -18,6 +18,7 @@
     QuantizationSpecBase,
 )
 from .utils import (
+    _conv2d_bn_example_inputs,
     _is_supported_batch_norm_for_training,
     fold_bn_weights_into_conv_node,
     get_aten_graph_module,
@@ -27,17 +28,6 @@
 __all__ = []  # type: ignore[var-annotated]
 
-# Example inputs for `_conv2d_bn_pattern`, `_qat_conv2d_bn_pattern`, and `_qat_conv2d_bn_pattern_no_bias`
-_conv2d_bn_pattern_example_inputs = (
-    torch.randn(1, 1, 3, 3),  # x
-    torch.randn(1, 1, 1, 1),  # conv_weight
-    torch.randn(1),  # conv_bias
-    torch.randn(1),  # bn_weight
-    torch.randn(1),  # bn_bias
-    torch.randn(1),  # bn_running_mean
-    torch.randn(1),  # bn_running_var
-)
-
 # Example inputs for both `_quantized_qat_conv2d_bn_pattern` and `_folded_quantized_qat_conv2d_bn_pattern`
 _quantized_conv2d_bn_pattern_example_inputs = (
     torch.randn(1, 1, 3, 3),  # x
@@ -520,7 +510,7 @@ def _fuse_conv_bn_qat_helper(m: GraphModule, is_cuda: bool) -> GraphModule:
     """
     m.graph.eliminate_dead_code()
     m.recompile()
-    example_inputs = _conv2d_bn_pattern_example_inputs
+    example_inputs = _conv2d_bn_example_inputs
     match_pattern = get_aten_graph_module(_conv2d_bn_pattern, example_inputs, is_cuda)
 
     # Step (1): Replace patterns with conv bias
diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py
index 9fbee1f2697bd..d73218ef4a9c6 100644
--- a/torch/ao/quantization/pt2e/utils.py
+++ b/torch/ao/quantization/pt2e/utils.py
@@ -36,6 +36,16 @@
     torch.ops.quantized_decomposed.dequantize_per_channel.default,
 ]
 
+# Example inputs for conv-bn2d patterns
+_conv2d_bn_example_inputs = (
+    torch.randn(1, 1, 3, 3),  # x
+    torch.randn(1, 1, 1, 1),  # conv_weight
+    torch.randn(1),  # conv_bias
+    torch.randn(1),  # bn_weight
+    torch.randn(1),  # bn_bias
+    torch.randn(1),  # bn_running_mean
+    torch.randn(1),  # bn_running_var
+)
 
 def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool:
     """
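`_conv2d_bn_example_inputs` moves into `pt2e/utils.py` so that `qat_utils.py` (above) and `xnnpack_quantizer_utils.py` (below) can share a single definition. Roughly how both call sites consume it; the pattern body here is a simplified stand-in, not the exact `_conv2d_bn_pattern` from `qat_utils.py`:

```python
import torch.nn.functional as F
from torch.ao.quantization.pt2e.utils import (
    _conv2d_bn_example_inputs,
    get_aten_graph_module,
)

# Simplified stand-in for the conv-bn pattern functions traced in both files
def _conv_bn_pattern(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
    conv = F.conv2d(x, conv_weight, conv_bias)
    return F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True)

# Trace the pattern into an aten-level GraphModule ready for subgraph matching;
# the example inputs only pin down placeholder count, dtypes, and ranks
match_pattern = get_aten_graph_module(_conv_bn_pattern, _conv2d_bn_example_inputs, False)
```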
diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py
index 92d8771d8e49b..1968285a5d2ff 100644
--- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py
+++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py
@@ -247,8 +247,8 @@ def not_module_type_or_name_filter(n: Node) -> bool:
 
 class XNNPACKQuantizer(Quantizer):
     supported_config_and_operators = _get_supported_config_and_operators()
     STATIC_QAT_ONLY_OPS = [
-        "conv2d_bn_relu",
-        "conv2d_bn",
+        "conv_bn_relu",
+        "conv_bn",
     ]
 
     # static quantization ops (both PTQ and QAT)
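The annotator keys drop the `2d` suffix to match the renamed `conv_bn` / `conv_bn_relu` annotators below, which are no longer tied to `torch.nn.Conv2d` module partitions (the matching loop below is written to extend to other conv dimensions). For context, a minimal sketch of how these QAT-only ops are reached, using only the existing `XNNPACKQuantizer` API:

```python
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

quantizer = XNNPACKQuantizer()
# is_qat=True routes annotation through STATIC_QAT_ONLY_OPS, which now
# resolves to the "conv_bn" and "conv_bn_relu" annotators
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
```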
""" - fused_partitions = find_sequential_partitions( - gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU], filter_fn - ) - annotated_partitions = [] - for fused_partition in fused_partitions: - conv_partition, bn_partition, relu_partition = fused_partition - annotated_partitions.append( - conv_partition.nodes + bn_partition.nodes + relu_partition.nodes - ) - if len(relu_partition.output_nodes) > 1: - raise ValueError("Relu partition has more than one output node") - relu_node = relu_partition.output_nodes[0] - if len(conv_partition.output_nodes) > 1: - raise ValueError("conv partition has more than one output node") - conv_node = conv_partition.output_nodes[0] - conv_node_users = list(conv_node.users.keys()) - if len(conv_node_users) > 1: - raise ValueError( - "Conv node must be consumed by BN only for it to be fusable." - ) - if len(bn_partition.output_nodes) > 1: - raise ValueError("BatchNorm partition has more than one output node") - bn_output_node = bn_partition.output_nodes[0] + return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True) - if _is_annotated([relu_node, bn_output_node, conv_node]): - continue - input_qspec_map = {} - input_act = conv_node.args[0] - assert isinstance(input_act, Node) - input_qspec_map[input_act] = get_input_act_qspec(quantization_config) +def _do_annotate_conv_bn( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]], + has_relu: bool, +) -> List[List[Node]]: + """ + Given a function that takes in a `conv_fn` and returns a conv-bn[-relu] pattern, + return a list of annotated partitions. - weight = conv_node.args[1] - assert isinstance(weight, Node) - input_qspec_map[weight] = get_weight_qspec(quantization_config) + The output of the pattern must include a dictionary from string name to node + for the following names: "input", "conv", "weight", "bias", and "output". 
+ """ - bias = conv_node.args[2] if len(conv_node.args) > 2 else None - if isinstance(bias, Node): - input_qspec_map[bias] = get_bias_qspec(quantization_config) + def get_pattern(conv_fn: Callable, relu_is_inplace: bool): + def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv): + conv = conv_fn(x, conv_weight, conv_bias) + bn = F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True) + if has_relu: + output = F.relu_(bn) if relu_is_inplace else F.relu(bn) + else: + output = bn + return output, { + "input": x, + "conv": conv, + "weight": conv_weight, + "bias": conv_bias, + "output": output, + } + + return _conv_bn + + # Needed for matching, otherwise the matches gets filtered out due to unused + # nodes returned by batch norm + gm.graph.eliminate_dead_code() + gm.recompile() + + matches = [] + combinations = [ + (F.conv2d, _conv2d_bn_example_inputs), + ] + # Add `is_cuda` and `relu_is_inplace` dimensions + combinations = itertools.product( + combinations, + [True, False] if torch.cuda.is_available() else [False], # is_cuda + [True, False] if has_relu else [False], # relu_is_inplace + ) + + # Match against all conv dimensions and cuda variants + for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations: + pattern = get_pattern(conv_fn, relu_is_inplace) + pattern = get_aten_graph_module(pattern, example_inputs, is_cuda) + pattern.graph.eliminate_dead_code() + pattern.recompile() + matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True) + matches.extend(matcher.match(gm.graph)) + + # Annotate nodes returned in the matches + annotated_partitions = [] + for match in matches: + name_node_map = match.name_node_map + input_node = name_node_map["input"] + conv_node = name_node_map["conv"] + weight_node = name_node_map["weight"] + bias_node = name_node_map["bias"] + output_node = name_node_map["output"] + + # TODO: annotate the uses of input, weight, and bias separately instead + # of assuming they come from a single conv node. This is not possible today + # because input may have multiple users, and we can't rely on the conv node + # always being the first user. 
+
+    # Annotate nodes returned in the matches
+    annotated_partitions = []
+    for match in matches:
+        name_node_map = match.name_node_map
+        input_node = name_node_map["input"]
+        conv_node = name_node_map["conv"]
+        weight_node = name_node_map["weight"]
+        bias_node = name_node_map["bias"]
+        output_node = name_node_map["output"]
+
+        # TODO: annotate the uses of input, weight, and bias separately instead
+        # of assuming they come from a single conv node. This is not possible today
+        # because input may have multiple users, and we can't rely on the conv node
+        # always being the first user. This was the case in models with skip
+        # connections like resnet18
+
+        # Validate conv args
+        if conv_node.args[0] is not input_node:
+            raise ValueError("Conv arg did not contain input node ", input_node)
+        if conv_node.args[1] is not weight_node:
+            raise ValueError("Conv arg did not contain weight node ", weight_node)
+        if len(conv_node.args) > 2 and conv_node.args[2] is not bias_node:
+            raise ValueError("Conv arg did not contain bias node ", bias_node)
+
+        # Skip if the partition is already annotated or is filtered out by the user
+        partition = [conv_node, weight_node]
+        if bias_node is not None:
+            partition.append(bias_node)
+        if _is_annotated(partition):
+            continue
+        if filter_fn and any(not filter_fn(n) for n in partition):
+            continue
+
+        # Annotate conv inputs and pattern output
+        input_qspec_map = {}
+        input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
+        input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
+        if bias_node is not None:
+            input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
         conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            input_qspec_map=input_qspec_map, _annotated=True
+            input_qspec_map=input_qspec_map,
+            _annotated=True,
         )
-
-        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+        output_node.meta["quantization_annotation"] = QuantizationAnnotation(
             output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
             _annotated=True,
         )
-        nodes_to_mark_annotated = list(conv_partition.nodes)
-        nodes_to_mark_annotated.extend(list(bn_partition.nodes))
-        nodes_to_mark_annotated.extend(list(relu_partition.nodes))
-        _mark_nodes_as_annotated(nodes_to_mark_annotated)
+        _mark_nodes_as_annotated(partition)
+        annotated_partitions.append(partition)
     return annotated_partitions
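Finally, a sketch of the end-to-end QAT flow that now routes through `_do_annotate_conv_bn`. The toy module and shapes are assumptions; the capture and prepare calls are the existing pt2e entry points at the time of this PR:

```python
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return torch.nn.functional.relu(self.bn(self.conv(x)))

example_inputs = (torch.randn(1, 3, 5, 5),)
m = capture_pre_autograd_graph(M(), example_inputs)

quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_qat=True)
)
# The "conv_bn_relu" annotator now finds the captured conv-bn-relu subgraph
# via SubgraphMatcherWithNameNodeMap rather than find_sequential_partitions
m = prepare_qat_pt2e(m, quantizer)
```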