From 70ef2b2fbaffd69abb1b75987d7654259378eb23 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Thu, 21 Sep 2023 15:00:17 -0700
Subject: [PATCH] [quant] Enable quantization for wav2letter

Summary: Also added annotation support for conv1d_relu and conv1d in
XNNPACKQuantizer. The quantized results still match the fx quant path (which
did not quantize conv1d), so the tests are not disabled.

Test Plan: with-proxy buck2 run executorch/examples/quantization:example -- -m=w2l --verify

Differential Revision: D49479546
---
 test/quantization/pt2e/test_duplicate_dq.py   |  8 +--
 .../pt2e/test_metadata_porting.py             |  8 +--
 test/quantization/pt2e/test_quantize_pt2e.py  | 55 ++++++++++++++++++-
 torch/ao/quantization/pt2e/graph_utils.py     |  1 +
 .../quantizer/xnnpack_quantizer.py            |  4 +-
 .../quantizer/xnnpack_quantizer_utils.py      | 42 +++++++++-----
 .../testing/_internal/common_quantization.py  | 21 ++++++-
 7 files changed, 112 insertions(+), 27 deletions(-)

diff --git a/test/quantization/pt2e/test_duplicate_dq.py b/test/quantization/pt2e/test_duplicate_dq.py
index bf6be1c9ace86..bd75c2f32e940 100644
--- a/test/quantization/pt2e/test_duplicate_dq.py
+++ b/test/quantization/pt2e/test_duplicate_dq.py
@@ -134,7 +134,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                     is_per_channel=True
                 )
                 OP_TO_ANNOTATOR["linear"](gm, quantization_config)
-                OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
+                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                 OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config)
 
             def validate(self, model: torch.fx.GraphModule) -> None:
@@ -170,7 +170,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                     is_per_channel=True
                 )
                 OP_TO_ANNOTATOR["linear"](gm, quantization_config)
-                OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
+                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                 OP_TO_ANNOTATOR["add"](gm, quantization_config)
 
             def validate(self, model: torch.fx.GraphModule) -> None:
@@ -207,7 +207,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                     is_per_channel=True
                 )
                 OP_TO_ANNOTATOR["linear"](gm, quantization_config)
-                OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
+                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
 
             def validate(self, model: torch.fx.GraphModule) -> None:
                 pass
@@ -289,7 +289,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                     is_per_channel=True
                 )
                 avgpool_qconfig = _get_uint8_quantization_config()
-                OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
+                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                 OP_TO_ANNOTATOR["add"](gm, quantization_config)
                 for n in gm.graph.nodes:
                     if n.op == "call_function" and n.target == torch.ops.aten.mean.dim:
diff --git a/test/quantization/pt2e/test_metadata_porting.py b/test/quantization/pt2e/test_metadata_porting.py
index 900e60733f433..c3fce74f92719 100644
--- a/test/quantization/pt2e/test_metadata_porting.py
+++ b/test/quantization/pt2e/test_metadata_porting.py
@@ -115,7 +115,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                     gm, quantization_config
                 )
                 _tag_partitions(backend_string, "linear", annotated_partitions)
-                annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
+                annotated_partitions = OP_TO_ANNOTATOR["conv"](
                     gm, quantization_config
                 )
                 _tag_partitions(backend_string, "conv2d", annotated_partitions)
@@ -171,7 +171,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                     gm, quantization_config
                 )
                 _tag_partitions(backend_string, "linear", annotated_partitions)
-                annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
+                annotated_partitions = OP_TO_ANNOTATOR["conv"](
                     gm, quantization_config
                 )
                 _tag_partitions(backend_string, "conv2d", annotated_partitions)
@@ -212,7 +212,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                 quantization_config = get_symmetric_quantization_config(
                     is_per_channel=True
                 )
-                annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
+                annotated_partitions = OP_TO_ANNOTATOR["conv"](
                     gm, quantization_config
                 )
                 _tag_partitions(backend_string, "conv2d", annotated_partitions)
@@ -282,7 +282,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                 quantization_config_dynamic = get_symmetric_quantization_config(
                     is_per_channel=True, is_dynamic=True
                 )
-                annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
+                annotated_partitions = OP_TO_ANNOTATOR["conv"](
                     gm, quantization_config_dynamic
                 )
                 _tag_partitions(backend_string, "conv2d_dynamic", annotated_partitions)
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index 59baeaf162cb8..5b6aafdcb5d51 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -982,7 +982,32 @@ def test_mul_and_inplace_mul(self):
             node_list,
         )
 
-    def test_xnnpack_quantizer_conv(self):
+    def test_xnnpack_quantizer_conv1d(self):
+        quantizer = XNNPACKQuantizer()
+        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
+        quantizer.set_global(quantization_config)
+        example_inputs = (torch.randn(1, 3, 5),)
+        node_occurrence = {
+            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+        }
+        node_list = [
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv1d.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        ]
+        self._test_quantizer(
+            TestHelperModules.ConvWithBNRelu(dim=1, relu=False, bn=False),
+            example_inputs,
+            quantizer,
+            node_occurrence,
+            node_list,
+        )
+
+    def test_xnnpack_quantizer_conv2d(self):
         quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config(is_per_channel=True)
         quantizer.set_global(quantization_config)
@@ -1007,6 +1032,34 @@ def test_xnnpack_quantizer_conv(self):
             node_list,
         )
 
+    def test_xnnpack_quantizer_conv1d_with_conv2d(self):
+        quantizer = XNNPACKQuantizer()
+        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
+        quantizer.set_global(quantization_config)
+        example_inputs = (torch.randn(1, 3, 5, 5),)
+        node_occurrence = {
+            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
+            torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
+        }
+        node_list = [
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv1d.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        ]
+        self._test_quantizer(
+            TestHelperModules.Conv1dWithConv2d(),
+            example_inputs,
+            quantizer,
+            node_occurrence,
+            node_list,
+        )
+
     def test_xnnpack_quantizer_linear(self):
         quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config(is_per_channel=True)
diff --git a/torch/ao/quantization/pt2e/graph_utils.py b/torch/ao/quantization/pt2e/graph_utils.py
index e9b31e83f6ef5..2d4e73dee3577 100644
--- a/torch/ao/quantization/pt2e/graph_utils.py
+++ b/torch/ao/quantization/pt2e/graph_utils.py
@@ -18,6 +18,7 @@
 ]
 
 _EQUIVALENT_TYPES: List[Set] = [
+    {torch.nn.Conv1d, torch.nn.functional.conv1d},
     {torch.nn.Conv2d, torch.nn.functional.conv2d},
     {torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d},
     {torch.nn.ReLU, torch.nn.functional.relu, torch.nn.functional.relu_},
diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py
index f1b9937d0fcd6..eb6cf1b59688c 100644
--- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py
+++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py
@@ -245,8 +245,8 @@ class XNNPACKQuantizer(Quantizer):
     # static quantization ops (both PTQ and QAT)
     STATIC_OPS = [
         "linear",
-        "conv2d_relu",
-        "conv2d",
+        "conv_relu",
+        "conv",
         "adaptive_avg_pool2d",
         # TODO: move this to BoltNNQuantizer?
         "gru_io_only",
diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
index ab078191a29ef..bbfde41552b06 100644
--- a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
+++ b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
@@ -208,27 +208,34 @@ def _annotate_linear(
     return annotated_partitions
 
 
-@register_annotator("conv2d")
-def _annotate_conv2d(
+@register_annotator("conv")
+def _annotate_conv(
     gm: torch.fx.GraphModule,
     quantization_config: Optional[QuantizationConfig],
     filter_fn: Optional[Callable[[Node], bool]] = None,
 ) -> Optional[List[List[Node]]]:
     conv_partitions = get_source_partitions(
-        gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d], filter_fn
+        gm.graph,
+        [torch.nn.Conv2d,
+         torch.nn.functional.conv2d,
+         torch.nn.Conv1d,
+         torch.nn.functional.conv1d],
+        filter_fn
     )
     conv_partitions = list(itertools.chain(*conv_partitions.values()))
     annotated_partitions = []
     for conv_partition in conv_partitions:
-        annotated_partitions.append(conv_partition.nodes)
         if len(conv_partition.output_nodes) > 1:
             raise ValueError("conv partition has more than one output node")
         conv_node = conv_partition.output_nodes[0]
         if (
             conv_node.op != "call_function"
-            or conv_node.target != torch.ops.aten.conv2d.default
+            or conv_node.target not in [
+                torch.ops.aten.conv1d.default,
+                torch.ops.aten.conv2d.default,
+            ]
         ):
-            raise ValueError(f"{conv_node} is not an aten conv2d operator")
+            raise ValueError(f"{conv_node} is not an aten conv1d or conv2d operator")
         # skip annotation if it is already annotated
         if _is_annotated([conv_node]):
             continue
@@ -251,22 +258,26 @@ def _annotate_conv2d(
             output_qspec=get_output_act_qspec(quantization_config),
             _annotated=True,
         )
+        _mark_nodes_as_annotated(conv_partition.nodes)
+        annotated_partitions.append(conv_partition.nodes)
     return annotated_partitions
 
 
-@register_annotator("conv2d_relu")
-def _annotate_conv2d_relu(
+@register_annotator("conv_relu")
+def _annotate_conv_relu(
     gm: torch.fx.GraphModule,
     quantization_config: Optional[QuantizationConfig],
     filter_fn: Optional[Callable[[Node], bool]] = None,
 ) -> Optional[List[List[Node]]]:
-    fused_partitions = find_sequential_partitions(
+    fused_partitions1 = find_sequential_partitions(
         gm, [torch.nn.Conv2d, torch.nn.ReLU], filter_fn
     )
+    fused_partitions2 = find_sequential_partitions(
+        gm, [torch.nn.Conv1d, torch.nn.ReLU], filter_fn
+    )
     annotated_partitions = []
-    for fused_partition in fused_partitions:
+    for fused_partition in (fused_partitions1 + fused_partitions2):
         conv_partition, relu_partition = fused_partition
-        annotated_partitions.append(conv_partition.nodes + relu_partition.nodes)
         if len(relu_partition.output_nodes) > 1:
             raise ValueError("Relu partition has more than one output node")
         relu_node = relu_partition.output_nodes[0]
@@ -278,9 +289,12 @@ def _annotate_conv2d_relu(
             raise ValueError(f"{conv_node} is not a Node")
         if (
             conv_node.op != "call_function"
-            or conv_node.target != torch.ops.aten.conv2d.default
+            or conv_node.target not in [
+                torch.ops.aten.conv1d.default,
+                torch.ops.aten.conv2d.default,
+            ]
         ):
-            raise ValueError(f"{conv_node} is not an aten conv2d operator")
+            raise ValueError(f"{conv_node} is not an aten conv1d or conv2d operator")
         if relu_node.op != "call_function" or relu_node.target not in [
             torch.ops.aten.relu.default,
             torch.ops.aten.relu_.default,
@@ -310,6 +324,8 @@ def _annotate_conv2d_relu(
             output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
             _annotated=True,
         )
+        _mark_nodes_as_annotated(conv_partition.nodes + relu_partition.nodes)
+        annotated_partitions.append(conv_partition.nodes + relu_partition.nodes)
     return annotated_partitions
 
 
diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py
index c28ffcc64abda..553154709fe7b 100644
--- a/torch/testing/_internal/common_quantization.py
+++ b/torch/testing/_internal/common_quantization.py
@@ -2527,11 +2527,14 @@ def forward(self, x):
             return x
 
     class ConvWithBNRelu(torch.nn.Module):
-        def __init__(self, relu, bn=True, bias=True):
+        def __init__(self, relu, dim=2, bn=True, bias=True):
             super().__init__()
-            self.conv = torch.nn.Conv2d(3, 3, 3, bias=bias)
+            convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d}
+            bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d}
+            self.conv = convs[dim](3, 3, 3, bias=bias)
+
             if bn:
-                self.bn = torch.nn.BatchNorm2d(3)
+                self.bn = bns[dim](3)
             else:
                 self.bn = torch.nn.Identity()
             if relu:
@@ -2544,6 +2547,18 @@ def forward(self, x):
             x = self.bn(x)
             return self.relu(x)
 
+    class Conv1dWithConv2d(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1d = torch.nn.Conv1d(3, 3, 3)
+            self.conv2d = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            x = self.conv2d(x)
+            x = x.squeeze(0)
+            x = self.conv1d(x)
+            return x
+
     class Conv2dWithCat(torch.nn.Module):
         def __init__(self):
             super().__init__()
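
For reference, a minimal sketch of exercising the new conv1d path end to end with the PT2E flow, mirroring the added tests. This is not part of the diff; the export entry point has moved between releases (capture_pre_autograd_graph is assumed here), so adjust it to your PyTorch version.

    import torch
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(3, 3, 3)

        def forward(self, x):
            # conv1d followed by relu, so both the "conv" and "conv_relu"
            # annotators added in this patch apply
            return torch.nn.functional.relu(self.conv(x))

    m = M().eval()
    example_inputs = (torch.randn(1, 3, 5),)

    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

    m = capture_pre_autograd_graph(m, example_inputs)  # export to an ATen graph
    m = prepare_pt2e(m, quantizer)                     # insert observers
    m(*example_inputs)                                 # calibrate
    m = convert_pt2e(m)                                # lower to quantize/dequantize ops
    print(m)  # conv1d should now appear wrapped in q/dq ops, as in the new tests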
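
Out-of-tree quantizers built on OP_TO_ANNOTATOR, like the BackendAQuantizer helpers in the updated tests, now look up "conv" and "conv_relu" instead of "conv2d" and "conv2d_relu". A rough sketch of such a caller (everything other than the annotator keys is illustrative):

    import torch
    from torch.ao.quantization.quantizer import Quantizer
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        get_symmetric_quantization_config,
    )
    from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR

    class BackendAQuantizer(Quantizer):
        def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
            config = get_symmetric_quantization_config(is_per_channel=True)
            # run "conv_relu" before "conv", matching the STATIC_OPS order above,
            # so fused conv+relu patterns are claimed before standalone convs
            OP_TO_ANNOTATOR["conv_relu"](gm, config)  # was "conv2d_relu"
            OP_TO_ANNOTATOR["conv"](gm, config)       # was "conv2d"
            return gm

        def validate(self, model: torch.fx.GraphModule) -> None:
            pass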