[quant] Enable quantization for wav2letter (#109830)
Summary:
Also adds annotation support for conv1d and conv1d_relu in XNNPACKQuantizer. The quantized results still match the fx quant path (which did not quantize conv1d), so the tests are not disabled.

Test Plan: with-proxy buck2 run executorch/examples/quantization:example -- -m=w2l --verify

Differential Revision: D49479546

Pull Request resolved: #109830
Approved by: https://github.com/kimishpatel
jerryzh168 authored and pytorchmergebot committed Sep 29, 2023
1 parent ce8b4f5 commit c9b8e06
Showing 8 changed files with 115 additions and 41 deletions.
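To illustrate the conv1d quantization flow this change enables, here is a minimal sketch of the PT2E flow with XNNPACKQuantizer. It is not the Test Plan script above: the model, input shape, and the capture step via torch._export.capture_pre_autograd_graph are assumptions for illustration; XNNPACKQuantizer, get_symmetric_quantization_config, and set_global are the APIs exercised by the tests in this diff, and prepare_pt2e/convert_pt2e are the standard PT2E entry points of this timeframe.

import torch
from torch._export import capture_pre_autograd_graph  # capture step; name as of this era
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class TinyConv1d(torch.nn.Module):
    # Hypothetical stand-in for a wav2letter-style conv1d block.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(3, 3, 3)

    def forward(self, x):
        return torch.nn.functional.relu(self.conv(x))

example_inputs = (torch.randn(1, 3, 5),)
m = capture_pre_autograd_graph(TinyConv1d().eval(), example_inputs)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

m = prepare_pt2e(m, quantizer)  # conv1d + relu is now annotated and observed
m(*example_inputs)              # calibration pass
m = convert_pt2e(m)             # rewrites to quantize/dequantize decomposed ops around conv1d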
1 change: 0 additions & 1 deletion .lintrunner.toml
@@ -1452,7 +1452,6 @@ exclude_patterns = [
'test/quantization/jit/test_quantize_jit.py',
'test/quantization/pt2e/test_graph_utils.py',
'test/quantization/pt2e/test_quantize_pt2e.py',
'test/quantization/pt2e/test_quantize_pt2e_fx.py',
'test/quantization/pt2e/test_x86inductor_quantizer.py',
'test/scripts/cuda_memcheck_common.py',
'test/scripts/run_cuda_memcheck.py',
8 changes: 4 additions & 4 deletions test/quantization/pt2e/test_duplicate_dq.py
@@ -133,7 +133,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
is_per_channel=True
)
OP_TO_ANNOTATOR["linear"](gm, quantization_config)
OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
OP_TO_ANNOTATOR["conv"](gm, quantization_config)
OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config)

def validate(self, model: torch.fx.GraphModule) -> None:
@@ -169,7 +169,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
is_per_channel=True
)
OP_TO_ANNOTATOR["linear"](gm, quantization_config)
OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
OP_TO_ANNOTATOR["conv"](gm, quantization_config)
OP_TO_ANNOTATOR["add"](gm, quantization_config)

def validate(self, model: torch.fx.GraphModule) -> None:
@@ -206,7 +206,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
is_per_channel=True
)
OP_TO_ANNOTATOR["linear"](gm, quantization_config)
OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
OP_TO_ANNOTATOR["conv"](gm, quantization_config)

def validate(self, model: torch.fx.GraphModule) -> None:
pass
@@ -288,7 +288,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
is_per_channel=True
)
avgpool_qconfig = _get_uint8_quantization_config()
OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
OP_TO_ANNOTATOR["conv"](gm, quantization_config)
OP_TO_ANNOTATOR["add"](gm, quantization_config)
for n in gm.graph.nodes:
if n.op == "call_function" and n.target == torch.ops.aten.mean.dim:
14 changes: 4 additions & 10 deletions test/quantization/pt2e/test_metadata_porting.py
@@ -128,9 +128,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
gm, quantization_config
)
_tag_partitions(backend_string, "linear", annotated_partitions)
annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
gm, quantization_config
)
annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
_tag_partitions(backend_string, "conv2d", annotated_partitions)
annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
gm, quantization_config
@@ -189,9 +187,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
gm, quantization_config
)
_tag_partitions(backend_string, "linear", annotated_partitions)
annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
gm, quantization_config
)
annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
_tag_partitions(backend_string, "conv2d", annotated_partitions)

def validate(self, model: torch.fx.GraphModule) -> None:
@@ -232,9 +228,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
quantization_config = get_symmetric_quantization_config(
is_per_channel=True
)
annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
gm, quantization_config
)
annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
_tag_partitions(backend_string, "conv2d", annotated_partitions)
annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
gm, quantization_config
@@ -305,7 +299,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
quantization_config_dynamic = get_symmetric_quantization_config(
is_per_channel=True, is_dynamic=True
)
annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
annotated_partitions = OP_TO_ANNOTATOR["conv"](
gm, quantization_config_dynamic
)
_tag_partitions(backend_string, "conv2d_dynamic", annotated_partitions)
57 changes: 55 additions & 2 deletions test/quantization/pt2e/test_quantize_pt2e.py
@@ -849,7 +849,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
bias=None,
output_activation=int16_qspec,
)
OP_TO_ANNOTATOR["conv2d"](model, quantization_config)
OP_TO_ANNOTATOR["conv"](model, quantization_config)

def validate(self, model: torch.fx.GraphModule) -> None:
pass
@@ -1075,7 +1075,32 @@ def test_mul_and_inplace_mul(self):
node_list,
)

def test_xnnpack_quantizer_conv(self):
def test_xnnpack_quantizer_conv1d(self):
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(quantization_config)
example_inputs = (torch.randn(1, 3, 5),)
node_occurrence = {
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.conv1d.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
]
self._test_quantizer(
TestHelperModules.ConvWithBNRelu(dim=1, relu=False, bn=False),
example_inputs,
quantizer,
node_occurrence,
node_list,
)

def test_xnnpack_quantizer_conv2d(self):
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(quantization_config)
@@ -1101,6 +1126,34 @@ def test_xnnpack_quantizer_conv(self):
node_list,
)

def test_xnnpack_quantizer_conv1d_with_conv2d(self):
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(quantization_config)
example_inputs = (torch.randn(1, 3, 5, 5),)
node_occurrence = {
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
node_list = [
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.conv2d.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.conv1d.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
]
self._test_quantizer(
TestHelperModules.Conv1dWithConv2d(),
example_inputs,
quantizer,
node_occurrence,
node_list,
)

def test_xnnpack_quantizer_linear(self):
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(is_per_channel=True)
1 change: 1 addition & 0 deletions torch/ao/quantization/pt2e/graph_utils.py
@@ -18,6 +18,7 @@
]

_EQUIVALENT_TYPES: List[Set] = [
{torch.nn.Conv1d, torch.nn.functional.conv1d},
{torch.nn.Conv2d, torch.nn.functional.conv2d},
{torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d},
{torch.nn.ReLU, torch.nn.functional.relu, torch.nn.functional.relu_},
4 changes: 2 additions & 2 deletions torch/ao/quantization/quantizer/xnnpack_quantizer.py
@@ -246,8 +246,8 @@ class XNNPACKQuantizer(Quantizer):
# static quantization ops (both PTQ and QAT)
STATIC_OPS = [
"linear",
"conv2d_relu",
"conv2d",
"conv_relu",
"conv",
"adaptive_avg_pool2d",
# TODO: move this to BoltNNQuantizer?
"gru_io_only",
50 changes: 31 additions & 19 deletions torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
@@ -208,27 +208,33 @@ def _annotate_linear(
return annotated_partitions


@register_annotator("conv2d")
def _annotate_conv2d(
@register_annotator("conv")
def _annotate_conv(
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
conv_partitions = get_source_partitions(
gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d], filter_fn
gm.graph,
[
torch.nn.Conv2d,
torch.nn.functional.conv2d,
torch.nn.Conv1d,
torch.nn.functional.conv1d,
],
filter_fn,
)
conv_partitions = list(itertools.chain(*conv_partitions.values()))
annotated_partitions = []
for conv_partition in conv_partitions:
annotated_partitions.append(conv_partition.nodes)
if len(conv_partition.output_nodes) > 1:
raise ValueError("conv partition has more than one output node")
conv_node = conv_partition.output_nodes[0]
if (
conv_node.op != "call_function"
or conv_node.target != torch.ops.aten.conv2d.default
):
raise ValueError(f"{conv_node} is not an aten conv2d operator")
if conv_node.op != "call_function" or conv_node.target not in [
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
]:
raise ValueError(f"{conv_node} is not an aten conv1d or conv2d operator")
# skip annotation if it is already annotated
if _is_annotated([conv_node]):
continue
@@ -251,22 +257,26 @@ def _annotate_conv2d(
output_qspec=get_output_act_qspec(quantization_config),
_annotated=True,
)
_mark_nodes_as_annotated(conv_partition.nodes)
annotated_partitions.append(conv_partition.nodes)
return annotated_partitions


@register_annotator("conv2d_relu")
def _annotate_conv2d_relu(
@register_annotator("conv_relu")
def _annotate_conv_relu(
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
fused_partitions = find_sequential_partitions(
fused_partitions1 = find_sequential_partitions(
gm, [torch.nn.Conv2d, torch.nn.ReLU], filter_fn
)
fused_partitions2 = find_sequential_partitions(
gm, [torch.nn.Conv1d, torch.nn.ReLU], filter_fn
)
annotated_partitions = []
for fused_partition in fused_partitions:
for fused_partition in fused_partitions1 + fused_partitions2:
conv_partition, relu_partition = fused_partition
annotated_partitions.append(conv_partition.nodes + relu_partition.nodes)
if len(relu_partition.output_nodes) > 1:
raise ValueError("Relu partition has more than one output node")
relu_node = relu_partition.output_nodes[0]
@@ -276,11 +286,11 @@ def _annotate_conv2d_relu(

if not isinstance(conv_node, Node):
raise ValueError(f"{conv_node} is not a Node")
if (
conv_node.op != "call_function"
or conv_node.target != torch.ops.aten.conv2d.default
):
raise ValueError(f"{conv_node} is not an aten conv2d operator")
if conv_node.op != "call_function" or conv_node.target not in [
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
]:
raise ValueError(f"{conv_node} is not an aten conv1d or conv2d operator")
if relu_node.op != "call_function" or relu_node.target not in [
torch.ops.aten.relu.default,
torch.ops.aten.relu_.default,
@@ -310,6 +320,8 @@ def _annotate_conv2d_relu(
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True,
)
_mark_nodes_as_annotated(conv_partition.nodes + relu_partition.nodes)
annotated_partitions.append(conv_partition.nodes + relu_partition.nodes)
return annotated_partitions


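For context, a minimal sketch of how a custom backend quantizer would call the renamed annotators ("conv" / "conv_relu" instead of "conv2d" / "conv2d_relu"). BackendAQuantizer is a hypothetical name modeled on the test quantizers above; the sketch assumes OP_TO_ANNOTATOR is importable from xnnpack_quantizer_utils, as the tests in this diff do.

import torch
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR

class BackendAQuantizer(Quantizer):
    # Hypothetical quantizer, modeled on the test quantizers above.
    def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
        # "conv_relu" annotates fused conv1d/conv2d + relu patterns first,
        # then "conv" picks up the remaining standalone conv1d/conv2d nodes.
        OP_TO_ANNOTATOR["conv_relu"](gm, quantization_config)
        OP_TO_ANNOTATOR["conv"](gm, quantization_config)
        return gm

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass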
21 changes: 18 additions & 3 deletions torch/testing/_internal/common_quantization.py
@@ -2527,11 +2527,14 @@ def forward(self, x):
return x

class ConvWithBNRelu(torch.nn.Module):
def __init__(self, relu, bn=True, bias=True):
def __init__(self, relu, dim=2, bn=True, bias=True):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3, bias=bias)
convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d}
bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d}
self.conv = convs[dim](3, 3, 3, bias=bias)

if bn:
self.bn = torch.nn.BatchNorm2d(3)
self.bn = bns[dim](3)
else:
self.bn = torch.nn.Identity()
if relu:
@@ -2544,6 +2547,18 @@ def forward(self, x):
x = self.bn(x)
return self.relu(x)

class Conv1dWithConv2d(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1d = torch.nn.Conv1d(3, 3, 3)
self.conv2d = torch.nn.Conv2d(3, 3, 3)

def forward(self, x):
x = self.conv2d(x)
x = x.squeeze(0)
x = self.conv1d(x)
return x

class Conv2dWithCat(torch.nn.Module):
def __init__(self):
super().__init__()
