[PT2][Quant] Use module partitions for conv2d and conv2d + relu #102395
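
For orientation, a hedged sketch of how these annotators get exercised. `QNNPackQuantizer` and `get_symmetric_quantization_config` are defined in the file modified below; the capture/prepare entry points are assumptions about the PT2E prototype flow of this era and are left as comments rather than asserted:

```python
import torch
from torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer import (
    QNNPackQuantizer,
    get_symmetric_quantization_config,
)

quantizer = QNNPackQuantizer()
quantizer.set_global(get_symmetric_quantization_config())
# m = <module exported to an aten-level torch.fx.GraphModule>
# m = prepare_pt2e_quantizer(m, quantizer)  # assumed prototype entry point:
#                                           # runs annotate(), i.e. the
#                                           # _annotate_conv2d* passes below
```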

torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py (142 changes: 81 additions & 61 deletions)
@@ -4,6 +4,7 @@
 import functools

 import operator
+import itertools
 from typing import Callable, Dict, List, Optional, Set, Any

 import torch
@@ -275,11 +276,11 @@ def annotate_symmetric_config(
         if config.is_qat:
             self._annotate_conv2d_bn_relu(model, config)
             self._annotate_conv2d_bn(model, config)
+        self._annotate_conv2d_relu(model, config)
+        self._annotate_conv2d(model, config)
         for node in reversed(model.graph.nodes):
             # one improvement is to register node annotators for each
             # supported op type.
-            self._annotate_conv2d_relu(node, config)
-            self._annotate_conv2d(node, config)
             self._annotate_maxpool2d(node, config)
             self._annotate_add_relu(node, config)
             self._annotate_add(node, config)
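
With this change the conv patterns are annotated once over the whole `model` before the per-node loop runs, and the remaining per-node annotators skip anything the partition passes already claimed. A minimal sketch of that guard, assuming `_is_annotated` (which is outside the shown hunks) simply checks the `quantization_annotation` entry the annotators stamp into `node.meta`:

```python
from typing import List

from torch.fx import Node

def _is_annotated(nodes: List[Node]) -> bool:
    # A node counts as annotated once some annotator has stamped its meta
    # with a QuantizationAnnotation whose _annotated flag is set.
    annotated = False
    for node in nodes:
        annotated = annotated or (
            "quantization_annotation" in node.meta
            and node.meta["quantization_annotation"]._annotated
        )
    return annotated
```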
@@ -399,77 +400,96 @@ def _annotate_conv2d_bn_relu(
         _mark_nodes_as_annotated(nodes_to_mark_annotated)

     def _annotate_conv2d_relu(
-        self, node: Node, quantization_config: QuantizationConfig
+        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
     ) -> None:
-        if node.op != "call_function" or node.target not in [
-            torch.ops.aten.relu_.default,
-            torch.ops.aten.relu.default,
-        ]:
-            return
-        relu_node = node
-        conv_node = relu_node.args[0]
-        assert isinstance(conv_node, Node)
-        if (
-            conv_node.op != "call_function"
-            or conv_node.target != torch.ops.aten.convolution.default
-        ):
-            return
-        if _is_annotated([relu_node, conv_node]):
-            return
-
-        input_qspec_map = {}
-        input_act = conv_node.args[0]
-        assert isinstance(input_act, Node)
-        input_qspec_map[input_act] = get_act_qspec(quantization_config)
-
-        weight = conv_node.args[1]
-        assert isinstance(weight, Node)
-        input_qspec_map[weight] = get_weight_qspec(quantization_config)
-
-        bias = conv_node.args[2]
-        if isinstance(bias, Node):
-            input_qspec_map[bias] = get_bias_qspec(quantization_config)
-
-        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            input_qspec_map=input_qspec_map,
-            _annotated=True
-        )
-        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            output_qspec=get_act_qspec(quantization_config),  # type: ignore[arg-type]
-            _annotated=True
-        )
+        fused_partitions = find_sequential_partitions(
+            gm, [torch.nn.Conv2d, torch.nn.ReLU]
+        )
+        for fused_partition in fused_partitions:
+            conv_partition, relu_partition = fused_partition
+            if len(relu_partition.output_nodes) > 1:
+                raise ValueError("Relu partition has more than one output node")
+            relu_node = relu_partition.output_nodes[0]
+            if len(conv_partition.output_nodes) > 1:
+                raise ValueError("conv partition has more than one output node")
+            conv_node = conv_partition.output_nodes[0]
+
+            if not isinstance(conv_node, Node):
+                raise ValueError(f"{conv_node} is not a Node")
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.convolution.default
+            ):
+                raise ValueError(f"{conv_node} is not an aten conv2d operator")
+            if (
+                relu_node.op != "call_function"
+                or relu_node.target not in [torch.ops.aten.relu.default, torch.ops.aten.relu_.default]
+            ):
+                raise ValueError(f"{relu_node} is not an aten relu operator")
+
+            if _is_annotated([relu_node, conv_node]):
+                continue
+
+            input_qspec_map = {}
+            input_act = conv_node.args[0]
+            assert isinstance(input_act, Node)
+            input_qspec_map[input_act] = get_act_qspec(quantization_config)
+
+            weight = conv_node.args[1]
+            assert isinstance(weight, Node)
+            input_qspec_map[weight] = get_weight_qspec(quantization_config)
+
+            bias = conv_node.args[2]
+            if isinstance(bias, Node):
+                input_qspec_map[bias] = get_bias_qspec(quantization_config)
+
+            conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                _annotated=True
+            )
+            relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                output_qspec=get_act_qspec(quantization_config),  # type: ignore[arg-type]
+                _annotated=True
+            )
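
For reference, `find_sequential_partitions` (its import is outside the shown hunks) is assumed to yield one `(conv_partition, relu_partition)` pair per matched Conv2d -> ReLU chain, where each element is roughly the `SourcePartition` record from `torch.fx.passes.utils.source_matcher_utils`:

```python
from dataclasses import dataclass, field
from typing import Any, List

from torch.fx import Node

@dataclass
class SourcePartition:
    nodes: List[Node]  # all aten nodes the matched module/function lowered to
    source: Any        # the matched source, e.g. torch.nn.Conv2d
    input_nodes: List[Node] = field(default_factory=list)   # partition inputs
    output_nodes: List[Node] = field(default_factory=list)  # usually exactly one
    params: List[str] = field(default_factory=list)         # e.g. weight/bias attrs
```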

     def _annotate_conv2d(
-        self, node: Node, quantization_config: QuantizationConfig
+        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
     ) -> None:
-        conv_node = node
-        if (
-            conv_node.op != "call_function"
-            or conv_node.target != torch.ops.aten.convolution.default
-        ):
-            return
-        # skip annotation if it is already annotated
-        if _is_annotated([conv_node]):
-            return
-
-        input_qspec_map = {}
-        input_act = conv_node.args[0]
-        assert isinstance(input_act, Node)
-        input_qspec_map[input_act] = get_act_qspec(quantization_config)
-
-        weight = conv_node.args[1]
-        assert isinstance(weight, Node)
-        input_qspec_map[weight] = get_weight_qspec(quantization_config)
-
-        bias = conv_node.args[2]
-        if isinstance(bias, Node):
-            input_qspec_map[bias] = get_bias_qspec(quantization_config)
-
-        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            input_qspec_map=input_qspec_map,
-            output_qspec=get_act_qspec(quantization_config),
-            _annotated=True
-        )
+        conv_partitions = get_source_partitions(
+            gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d]
+        )
+        conv_partitions = list(itertools.chain(*conv_partitions.values()))
+        for conv_partition in conv_partitions:
+            if len(conv_partition.output_nodes) > 1:
+                raise ValueError("conv partition has more than one output node")
+            conv_node = conv_partition.output_nodes[0]
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.convolution.default
+            ):
+                raise ValueError(f"{conv_node} is not an aten conv2d operator")
+            # skip annotation if it is already annotated
+            if _is_annotated([conv_node]):
+                continue
+
+            input_qspec_map = {}
+            input_act = conv_node.args[0]
+            assert isinstance(input_act, Node)
+            input_qspec_map[input_act] = get_act_qspec(quantization_config)
+
+            weight = conv_node.args[1]
+            assert isinstance(weight, Node)
+            input_qspec_map[weight] = get_weight_qspec(quantization_config)
+
+            bias = conv_node.args[2]
+            if isinstance(bias, Node):
+                input_qspec_map[bias] = get_bias_qspec(quantization_config)
+
+            conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                output_qspec=get_act_qspec(quantization_config),
+                _annotated=True
+            )

     def _annotate_linear(
         self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
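
A short usage sketch of the flattening step in the new `_annotate_conv2d`: `get_source_partitions` (from `torch.fx.passes.utils.source_matcher_utils`) groups partitions by the source they came from, and the pass merges those groups with `itertools.chain`, which is why the import was added at the top of the file. Assuming a captured `gm`:

```python
import itertools

import torch
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

def all_conv_partitions(gm: torch.fx.GraphModule):
    # {torch.nn.Conv2d: [p0, p1, ...], F.conv2d: [...]} -> [p0, p1, ...]
    by_source = get_source_partitions(
        gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d]
    )
    return list(itertools.chain(*by_source.values()))
```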