[quant][pt2][be] Rewrite QAT annotations using subgraph matcher (#113709)

Summary: This is the recommended way to write quantizers according
to https://pytorch.org/tutorials/prototype/pt2e_quantizer.html#a-note-on-ir-for-pt2e-quantization-flow.
It is agnostic to changes in the aten IR and can be easily extended
to support conv1d-bn and conv3d-bn fusion patterns in the future.
This is the first step towards rewriting XNNPACKQuantizer using
this subgraph matcher.
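
A minimal sketch of the approach (not part of this diff): define a pattern function that returns both its output and a name-to-node dictionary, export it to an aten graph with get_aten_graph_module, and match it with SubgraphMatcherWithNameNodeMap. The helper and matcher names are taken from the changed files below; the wrapper find_conv_bn_matches is illustrative only.

import torch
import torch.nn.functional as F
from torch.ao.quantization.pt2e.utils import get_aten_graph_module
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
    SubgraphMatcherWithNameNodeMap,
)

def _conv_bn_pattern(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
    conv = F.conv2d(x, conv_weight, conv_bias)
    bn = F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True)
    # The dict maps logical names to pattern nodes so the annotator can look
    # them up after matching instead of walking the aten graph by hand
    return bn, {"input": x, "conv": conv, "weight": conv_weight, "bias": conv_bias, "output": bn}

# Shapes mirror _conv2d_bn_example_inputs in torch/ao/quantization/pt2e/utils.py
example_inputs = (
    torch.randn(1, 1, 3, 3),  # x
    torch.randn(1, 1, 1, 1),  # conv_weight
    torch.randn(1),           # conv_bias
    torch.randn(1),           # bn_weight
    torch.randn(1),           # bn_bias
    torch.randn(1),           # bn_running_mean
    torch.randn(1),           # bn_running_var
)

def find_conv_bn_matches(gm: torch.fx.GraphModule):
    # Illustrative wrapper: trace the pattern to aten IR once and match it
    # against the user graph; each returned match carries a name_node_map
    pattern = get_aten_graph_module(_conv_bn_pattern, example_inputs, False)  # is_cuda=False
    pattern.graph.eliminate_dead_code()
    pattern.recompile()
    matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
    return matcher.match(gm.graph)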

Test Plan:
python test/test_quantization.py TestQuantizePT2EQAT_ConvBn2d

Reviewers: jerryzh168, kimishpatel

Subscribers: jerryzh168, kimishpatel, supriyar

Differential Revision: [D51366525](https://our.internmc.facebook.com/intern/diff/D51366525)
Pull Request resolved: #113709
Approved by: https://github.com/jerryzh168
andrewor14 authored and pytorchmergebot committed Nov 16, 2023
1 parent 8efa6ad commit 8241fe6
Showing 4 changed files with 128 additions and 109 deletions.
14 changes: 2 additions & 12 deletions torch/ao/quantization/pt2e/qat_utils.py
@@ -18,6 +18,7 @@
QuantizationSpecBase,
)
from .utils import (
_conv2d_bn_example_inputs,
_is_supported_batch_norm_for_training,
fold_bn_weights_into_conv_node,
get_aten_graph_module,
@@ -27,17 +28,6 @@
__all__ = [] # type: ignore[var-annotated]


# Example inputs for `_conv2d_bn_pattern`, `_qat_conv2d_bn_pattern`, and `_qat_conv2d_bn_pattern_no_bias`
_conv2d_bn_pattern_example_inputs = (
torch.randn(1, 1, 3, 3), # x
torch.randn(1, 1, 1, 1), # conv_weight
torch.randn(1), # conv_bias
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
)

# Example inputs for both `_quantized_qat_conv2d_bn_pattern` and `_folded_quantized_qat_conv2d_bn_pattern`
_quantized_conv2d_bn_pattern_example_inputs = (
torch.randn(1, 1, 3, 3), # x
@@ -520,7 +510,7 @@ def _fuse_conv_bn_qat_helper(m: GraphModule, is_cuda: bool) -> GraphModule:
"""
m.graph.eliminate_dead_code()
m.recompile()
example_inputs = _conv2d_bn_pattern_example_inputs
example_inputs = _conv2d_bn_example_inputs
match_pattern = get_aten_graph_module(_conv2d_bn_pattern, example_inputs, is_cuda)

# Step (1): Replace patterns with conv bias
10 changes: 10 additions & 0 deletions torch/ao/quantization/pt2e/utils.py
@@ -36,6 +36,16 @@
torch.ops.quantized_decomposed.dequantize_per_channel.default,
]

# Example inputs for conv-bn2d patterns
_conv2d_bn_example_inputs = (
torch.randn(1, 1, 3, 3), # x
torch.randn(1, 1, 1, 1), # conv_weight
torch.randn(1), # conv_bias
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
)

def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool:
"""
4 changes: 2 additions & 2 deletions torch/ao/quantization/quantizer/xnnpack_quantizer.py
@@ -247,8 +247,8 @@ def not_module_type_or_name_filter(n: Node) -> bool:
class XNNPACKQuantizer(Quantizer):
supported_config_and_operators = _get_supported_config_and_operators()
STATIC_QAT_ONLY_OPS = [
"conv2d_bn_relu",
"conv2d_bn",
"conv_bn_relu",
"conv_bn",
]

# static quantization ops (both PTQ and QAT)
209 changes: 114 additions & 95 deletions torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
@@ -7,6 +7,10 @@
import torch.nn.functional as F
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.pt2e.utils import (
_conv2d_bn_example_inputs,
get_aten_graph_module,
)
from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
QuantizationSpec,
@@ -19,6 +23,9 @@
_annotate_output_qspec,
)
from torch.fx import Node
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
SubgraphMatcherWithNameNodeMap,
)
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


@@ -318,127 +325,139 @@ def _annotate_conv_relu(
return annotated_partitions


@register_annotator("conv2d_bn")
def _annotate_conv2d_bn(
@register_annotator("conv_bn")
def _annotate_conv_bn(
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
"""
Find Conv2d + batchnorm partitions
Find conv + batchnorm partitions
Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
"""
fused_partitions = find_sequential_partitions(
gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d], filter_fn
)
annotated_partitions = []
for fused_partition in fused_partitions:
conv_partition, bn_partition = fused_partition
annotated_partitions.append(conv_partition.nodes + bn_partition.nodes)
if len(conv_partition.output_nodes) > 1:
raise ValueError("conv partition has more than one output node")
conv_node = conv_partition.output_nodes[0]
conv_node_users = list(conv_node.users.keys())
if len(conv_node_users) > 1:
raise ValueError(
"Conv node must be consumed by BN only for it to be fusable."
)
if len(bn_partition.output_nodes) > 1:
raise ValueError("BatchNorm partition has more than one output node")
bn_output_node = bn_partition.output_nodes[0]

if _is_annotated([bn_output_node, conv_node]):
continue

input_qspec_map = {}
input_act = conv_node.args[0]
assert isinstance(input_act, Node)
input_qspec_map[input_act] = get_input_act_qspec(quantization_config)

weight = conv_node.args[1]
assert isinstance(weight, Node)
input_qspec_map[weight] = get_weight_qspec(quantization_config)

bias = conv_node.args[2] if len(conv_node.args) > 2 else None
if isinstance(bias, Node):
input_qspec_map[bias] = get_bias_qspec(quantization_config)
return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=False)

conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map, _annotated=True
)

bn_output_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True,
)
nodes_to_mark_annotated = list(conv_partition.nodes)
nodes_to_mark_annotated.extend(list(bn_partition.nodes))
_mark_nodes_as_annotated(nodes_to_mark_annotated)
return annotated_partitions


@register_annotator("conv2d_bn_relu")
def _annotate_conv2d_bn_relu(
@register_annotator("conv_bn_relu")
def _annotate_conv_bn_relu(
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
"""
Find Conv2d + batchnorm partitions
Find conv + batchnorm partitions
Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
"""
fused_partitions = find_sequential_partitions(
gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU], filter_fn
)
annotated_partitions = []
for fused_partition in fused_partitions:
conv_partition, bn_partition, relu_partition = fused_partition
annotated_partitions.append(
conv_partition.nodes + bn_partition.nodes + relu_partition.nodes
)
if len(relu_partition.output_nodes) > 1:
raise ValueError("Relu partition has more than one output node")
relu_node = relu_partition.output_nodes[0]
if len(conv_partition.output_nodes) > 1:
raise ValueError("conv partition has more than one output node")
conv_node = conv_partition.output_nodes[0]
conv_node_users = list(conv_node.users.keys())
if len(conv_node_users) > 1:
raise ValueError(
"Conv node must be consumed by BN only for it to be fusable."
)
if len(bn_partition.output_nodes) > 1:
raise ValueError("BatchNorm partition has more than one output node")
bn_output_node = bn_partition.output_nodes[0]
return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True)

if _is_annotated([relu_node, bn_output_node, conv_node]):
continue

input_qspec_map = {}
input_act = conv_node.args[0]
assert isinstance(input_act, Node)
input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
def _do_annotate_conv_bn(
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]],
has_relu: bool,
) -> List[List[Node]]:
"""
Given a function that takes in a `conv_fn` and returns a conv-bn[-relu] pattern,
return a list of annotated partitions.
weight = conv_node.args[1]
assert isinstance(weight, Node)
input_qspec_map[weight] = get_weight_qspec(quantization_config)
The output of the pattern must include a dictionary from string name to node
for the following names: "input", "conv", "weight", "bias", and "output".
"""

bias = conv_node.args[2] if len(conv_node.args) > 2 else None
if isinstance(bias, Node):
input_qspec_map[bias] = get_bias_qspec(quantization_config)
def get_pattern(conv_fn: Callable, relu_is_inplace: bool):
def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
conv = conv_fn(x, conv_weight, conv_bias)
bn = F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True)
if has_relu:
output = F.relu_(bn) if relu_is_inplace else F.relu(bn)
else:
output = bn
return output, {
"input": x,
"conv": conv,
"weight": conv_weight,
"bias": conv_bias,
"output": output,
}

return _conv_bn

# Needed for matching, otherwise the matches get filtered out due to unused
# nodes returned by batch norm
gm.graph.eliminate_dead_code()
gm.recompile()

matches = []
combinations = [
(F.conv2d, _conv2d_bn_example_inputs),
]

# Add `is_cuda` and `relu_is_inplace` dimensions
combinations = itertools.product(
combinations,
[True, False] if torch.cuda.is_available() else [False], # is_cuda
[True, False] if has_relu else [False], # relu_is_inplace
)

# Match against all conv dimensions and cuda variants
for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:
pattern = get_pattern(conv_fn, relu_is_inplace)
pattern = get_aten_graph_module(pattern, example_inputs, is_cuda)
pattern.graph.eliminate_dead_code()
pattern.recompile()
matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
matches.extend(matcher.match(gm.graph))

# Annotate nodes returned in the matches
annotated_partitions = []
for match in matches:
name_node_map = match.name_node_map
input_node = name_node_map["input"]
conv_node = name_node_map["conv"]
weight_node = name_node_map["weight"]
bias_node = name_node_map["bias"]
output_node = name_node_map["output"]

# TODO: annotate the uses of input, weight, and bias separately instead
# of assuming they come from a single conv node. This is not possible today
# because input may have multiple users, and we can't rely on the conv node
# always being the first user. This was the case in models with skip
# connections like resnet18

# Validate conv args
if conv_node.args[0] is not input_node:
raise ValueError("Conv arg did not contain input node ", input_node)
if conv_node.args[1] is not weight_node:
raise ValueError("Conv arg did not contain weight node ", weight_node)
if len(conv_node.args) > 2 and conv_node.args[2] is not bias_node:
raise ValueError("Conv arg did not contain bias node ", bias_node)

# Skip if the partition is already annotated or is filtered out by the user
partition = [conv_node, weight_node]
if bias_node is not None:
partition.append(bias_node)
if _is_annotated(partition):
continue
if filter_fn and any(not filter_fn(n) for n in partition):
continue

# Annotate conv inputs and pattern output
input_qspec_map = {}
input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
if bias_node is not None:
input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map, _annotated=True
input_qspec_map=input_qspec_map,
_annotated=True,
)

relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True,
)
nodes_to_mark_annotated = list(conv_partition.nodes)
nodes_to_mark_annotated.extend(list(bn_partition.nodes))
nodes_to_mark_annotated.extend(list(relu_partition.nodes))
_mark_nodes_as_annotated(nodes_to_mark_annotated)
_mark_nodes_as_annotated(partition)
annotated_partitions.append(partition)
return annotated_partitions


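For context, a hedged end-to-end sketch of a QAT prepare flow that would exercise the renamed "conv_bn" / "conv_bn_relu" annotators. The capture and prepare entry points (capture_pre_autograd_graph, prepare_qat_pt2e, get_symmetric_quantization_config) are the PT2E QAT APIs at the time of this commit and are assumptions about the surrounding workflow, not part of this diff.

import torch
import torch.nn.functional as F
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class ConvBnReLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)))

example_inputs = (torch.randn(1, 3, 8, 8),)
model = capture_pre_autograd_graph(ConvBnReLU(), example_inputs)

# prepare_qat_pt2e runs the quantizer's annotators; with is_qat=True the
# "conv_bn_relu" and "conv_bn" annotators above match and annotate the
# conv-bn[-relu] subgraphs before fake quantization is inserted
quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_qat=True)
)
model = prepare_qat_pt2e(model, quantizer)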
