diff --git a/test/quantization/pt2e/test_duplicate_dq.py b/test/quantization/pt2e/test_duplicate_dq.py
index d51a9b4e1e91d..0eb69bd2c84a9 100644
--- a/test/quantization/pt2e/test_duplicate_dq.py
+++ b/test/quantization/pt2e/test_duplicate_dq.py
@@ -133,7 +133,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                     is_per_channel=True
                 )
                 OP_TO_ANNOTATOR["linear"](gm, quantization_config)
-                OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
+                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                 OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config)
 
             def validate(self, model: torch.fx.GraphModule) -> None:
@@ -169,7 +169,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                     is_per_channel=True
                 )
                 OP_TO_ANNOTATOR["linear"](gm, quantization_config)
-                OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
+                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                 OP_TO_ANNOTATOR["add"](gm, quantization_config)
 
             def validate(self, model: torch.fx.GraphModule) -> None:
@@ -206,7 +206,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                     is_per_channel=True
                 )
                 OP_TO_ANNOTATOR["linear"](gm, quantization_config)
-                OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
+                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
 
             def validate(self, model: torch.fx.GraphModule) -> None:
                 pass
@@ -288,7 +288,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                     is_per_channel=True
                 )
                 avgpool_qconfig = _get_uint8_quantization_config()
-                OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
+                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                 OP_TO_ANNOTATOR["add"](gm, quantization_config)
                 for n in gm.graph.nodes:
                     if n.op == "call_function" and n.target == torch.ops.aten.mean.dim:
diff --git a/test/quantization/pt2e/test_metadata_porting.py b/test/quantization/pt2e/test_metadata_porting.py
index c3641047626e1..59823fbf7ea80 100644
--- a/test/quantization/pt2e/test_metadata_porting.py
+++ b/test/quantization/pt2e/test_metadata_porting.py
@@ -128,9 +128,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                     gm, quantization_config
                 )
                 _tag_partitions(backend_string, "linear", annotated_partitions)
-                annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
-                    gm, quantization_config
-                )
+                annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                 _tag_partitions(backend_string, "conv2d", annotated_partitions)
                 annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
                     gm, quantization_config
@@ -189,9 +187,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                     gm, quantization_config
                 )
                 _tag_partitions(backend_string, "linear", annotated_partitions)
-                annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
-                    gm, quantization_config
-                )
+                annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                 _tag_partitions(backend_string, "conv2d", annotated_partitions)
 
             def validate(self, model: torch.fx.GraphModule) -> None:
@@ -232,9 +228,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                 quantization_config = get_symmetric_quantization_config(
                     is_per_channel=True
                 )
-                annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
-                    gm, quantization_config
-                )
+                annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                 _tag_partitions(backend_string, "conv2d", annotated_partitions)
                 annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
                     gm, quantization_config
@@ -305,7 +299,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                 quantization_config_dynamic = get_symmetric_quantization_config(
                     is_per_channel=True, is_dynamic=True
                 )
-                annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
+                annotated_partitions = OP_TO_ANNOTATOR["conv"](
                     gm, quantization_config_dynamic
                 )
                 _tag_partitions(backend_string, "conv2d_dynamic", annotated_partitions)
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index d526ea3a750d7..eb6c2d4b7ba66 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -849,7 +849,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                     bias=None,
                     output_activation=int16_qspec,
                 )
-                OP_TO_ANNOTATOR["conv2d"](model, quantization_config)
+                OP_TO_ANNOTATOR["conv"](model, quantization_config)
 
             def validate(self, model: torch.fx.GraphModule) -> None:
                 pass
@@ -1075,7 +1075,32 @@ def test_mul_and_inplace_mul(self):
             node_list,
         )
 
-    def test_xnnpack_quantizer_conv(self):
+    def test_xnnpack_quantizer_conv1d(self):
+        quantizer = XNNPACKQuantizer()
+        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
+        quantizer.set_global(quantization_config)
+        example_inputs = (torch.randn(1, 3, 5),)
+        node_occurrence = {
+            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+        }
+        node_list = [
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv1d.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        ]
+        self._test_quantizer(
+            TestHelperModules.ConvWithBNRelu(dim=1, relu=False, bn=False),
+            example_inputs,
+            quantizer,
+            node_occurrence,
+            node_list,
+        )
+
+    def test_xnnpack_quantizer_conv2d(self):
         quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config(is_per_channel=True)
         quantizer.set_global(quantization_config)
@@ -1101,6 +1126,34 @@ def test_xnnpack_quantizer_conv(self):
             node_list,
         )
 
+    def test_xnnpack_quantizer_conv1d_with_conv2d(self):
+        quantizer = XNNPACKQuantizer()
+        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
+        quantizer.set_global(quantization_config)
+        example_inputs = (torch.randn(1, 3, 5, 5),)
+        node_occurrence = {
+            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
+            torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
+        }
+        node_list = [
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv1d.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        ]
+        self._test_quantizer(
+            TestHelperModules.Conv1dWithConv2d(),
+            example_inputs,
+            quantizer,
+            node_occurrence,
+            node_list,
+        )
+
     def test_xnnpack_quantizer_linear(self):
         quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config(is_per_channel=True)
diff --git a/torch/ao/quantization/pt2e/graph_utils.py b/torch/ao/quantization/pt2e/graph_utils.py
index e9b31e83f6ef5..2d4e73dee3577 100644
--- a/torch/ao/quantization/pt2e/graph_utils.py
+++ b/torch/ao/quantization/pt2e/graph_utils.py
@@ -18,6 +18,7 @@
 ]
 
 _EQUIVALENT_TYPES: List[Set] = [
+    {torch.nn.Conv1d, torch.nn.functional.conv1d},
     {torch.nn.Conv2d, torch.nn.functional.conv2d},
     {torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d},
     {torch.nn.ReLU, torch.nn.functional.relu, torch.nn.functional.relu_},
diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py
index 96c2f4c077504..40eb238d09193 100644
--- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py
+++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py
@@ -246,8 +246,8 @@ class XNNPACKQuantizer(Quantizer):
     # static quantization ops (both PTQ and QAT)
     STATIC_OPS = [
         "linear",
-        "conv2d_relu",
-        "conv2d",
+        "conv_relu",
+        "conv",
         "adaptive_avg_pool2d",
         # TODO: move this to BoltNNQuantizer?
         "gru_io_only",
diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
index ab078191a29ef..add533dd80d15 100644
--- a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
+++ b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
@@ -208,27 +208,33 @@ def _annotate_linear(
     return annotated_partitions
 
 
-@register_annotator("conv2d")
-def _annotate_conv2d(
+@register_annotator("conv")
+def _annotate_conv(
     gm: torch.fx.GraphModule,
     quantization_config: Optional[QuantizationConfig],
     filter_fn: Optional[Callable[[Node], bool]] = None,
 ) -> Optional[List[List[Node]]]:
     conv_partitions = get_source_partitions(
-        gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d], filter_fn
+        gm.graph,
+        [
+            torch.nn.Conv2d,
+            torch.nn.functional.conv2d,
+            torch.nn.Conv1d,
+            torch.nn.functional.conv1d,
+        ],
+        filter_fn,
     )
     conv_partitions = list(itertools.chain(*conv_partitions.values()))
     annotated_partitions = []
     for conv_partition in conv_partitions:
-        annotated_partitions.append(conv_partition.nodes)
         if len(conv_partition.output_nodes) > 1:
             raise ValueError("conv partition has more than one output node")
         conv_node = conv_partition.output_nodes[0]
-        if (
-            conv_node.op != "call_function"
-            or conv_node.target != torch.ops.aten.conv2d.default
-        ):
-            raise ValueError(f"{conv_node} is not an aten conv2d operator")
+        if conv_node.op != "call_function" or conv_node.target not in [
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.conv2d.default,
+        ]:
+            raise ValueError(f"{conv_node} is not an aten conv1d or conv2d operator")
         # skip annotation if it is already annotated
         if _is_annotated([conv_node]):
             continue
@@ -251,22 +257,26 @@ def _annotate_conv2d(
             output_qspec=get_output_act_qspec(quantization_config),
             _annotated=True,
         )
+        _mark_nodes_as_annotated(conv_partition.nodes)
+        annotated_partitions.append(conv_partition.nodes)
     return annotated_partitions
 
 
-@register_annotator("conv2d_relu")
-def _annotate_conv2d_relu(
+@register_annotator("conv_relu")
+def _annotate_conv_relu(
     gm: torch.fx.GraphModule,
     quantization_config: Optional[QuantizationConfig],
     filter_fn: Optional[Callable[[Node], bool]] = None,
 ) -> Optional[List[List[Node]]]:
-    fused_partitions = find_sequential_partitions(
+    fused_partitions1 = find_sequential_partitions(
         gm, [torch.nn.Conv2d, torch.nn.ReLU], filter_fn
     )
+    fused_partitions2 = find_sequential_partitions(
+        gm, [torch.nn.Conv1d, torch.nn.ReLU], filter_fn
+    )
     annotated_partitions = []
-    for fused_partition in fused_partitions:
+    for fused_partition in fused_partitions1 + fused_partitions2:
         conv_partition, relu_partition = fused_partition
-        annotated_partitions.append(conv_partition.nodes + relu_partition.nodes)
         if len(relu_partition.output_nodes) > 1:
             raise ValueError("Relu partition has more than one output node")
         relu_node = relu_partition.output_nodes[0]
@@ -276,11 +286,11 @@ def _annotate_conv2d_relu(
 
         if not isinstance(conv_node, Node):
             raise ValueError(f"{conv_node} is not a Node")
-        if (
-            conv_node.op != "call_function"
-            or conv_node.target != torch.ops.aten.conv2d.default
-        ):
-            raise ValueError(f"{conv_node} is not an aten conv2d operator")
+        if conv_node.op != "call_function" or conv_node.target not in [
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.conv2d.default,
+        ]:
+            raise ValueError(f"{conv_node} is not an aten conv1d or conv2d operator")
         if relu_node.op != "call_function" or relu_node.target not in [
             torch.ops.aten.relu.default,
             torch.ops.aten.relu_.default,
@@ -310,6 +320,8 @@ def _annotate_conv2d_relu(
             output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
             _annotated=True,
         )
+        _mark_nodes_as_annotated(conv_partition.nodes + relu_partition.nodes)
+        annotated_partitions.append(conv_partition.nodes + relu_partition.nodes)
     return annotated_partitions
 
 
diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py
index c28ffcc64abda..553154709fe7b 100644
--- a/torch/testing/_internal/common_quantization.py
+++ b/torch/testing/_internal/common_quantization.py
@@ -2527,11 +2527,14 @@ def forward(self, x):
             return x
 
     class ConvWithBNRelu(torch.nn.Module):
-        def __init__(self, relu, bn=True, bias=True):
+        def __init__(self, relu, dim=2, bn=True, bias=True):
             super().__init__()
-            self.conv = torch.nn.Conv2d(3, 3, 3, bias=bias)
+            convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d}
+            bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d}
+            self.conv = convs[dim](3, 3, 3, bias=bias)
+
             if bn:
-                self.bn = torch.nn.BatchNorm2d(3)
+                self.bn = bns[dim](3)
             else:
                 self.bn = torch.nn.Identity()
             if relu:
@@ -2544,6 +2547,18 @@ def forward(self, x):
             x = self.bn(x)
             return self.relu(x)
 
+    class Conv1dWithConv2d(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1d = torch.nn.Conv1d(3, 3, 3)
+            self.conv2d = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            x = self.conv2d(x)
+            x = x.squeeze(0)
+            x = self.conv1d(x)
+            return x
+
     class Conv2dWithCat(torch.nn.Module):
         def __init__(self):
             super().__init__()