[quant] Enable quantization for wav2letter
Summary:
Also added annotation support for conv1d_relu and conv1d in XNNPACKQuantizer; the quantized results still
match the fx quant path (which didn't quantize conv1d), so the tests are not disabled.

Test Plan: with-proxy buck2 run executorch/examples/quantization:example -- -m=w2l --verify

Differential Revision: D49479546
jerryzh168 authored and facebook-github-bot committed Sep 21, 2023
1 parent 255d1a7 commit 70ef2b2
Showing 7 changed files with 112 additions and 27 deletions.
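For context, the end-to-end PT2E flow that exercises the new conv1d / conv1d_relu annotations looks roughly like the sketch below. The `TinyW2L` module and the `capture_pre_autograd_graph` export step are illustrative assumptions (the capture API differs across PyTorch versions); the quantizer setup mirrors the tests in this commit.

```python
# Sketch only: a stand-in conv1d model quantized through the PT2E flow.
import torch
from torch._export import capture_pre_autograd_graph  # version-dependent export API
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class TinyW2L(torch.nn.Module):  # hypothetical stand-in for the real wav2letter model
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(3, 3, 3)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

model = TinyW2L().eval()
example_inputs = (torch.randn(1, 3, 5),)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

m = capture_pre_autograd_graph(model, example_inputs)  # export to an FX graph
m = prepare_pt2e(m, quantizer)   # conv1d / conv1d+relu patterns get annotated
m(*example_inputs)               # calibrate observers with representative data
m = convert_pt2e(m)              # replace observers with quantize/dequantize ops
```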
8 changes: 4 additions & 4 deletions test/quantization/pt2e/test_duplicate_dq.py
@@ -134,7 +134,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
is_per_channel=True
)
OP_TO_ANNOTATOR["linear"](gm, quantization_config)
OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
OP_TO_ANNOTATOR["conv"](gm, quantization_config)
OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config)

def validate(self, model: torch.fx.GraphModule) -> None:
@@ -170,7 +170,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
is_per_channel=True
)
OP_TO_ANNOTATOR["linear"](gm, quantization_config)
OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
OP_TO_ANNOTATOR["conv"](gm, quantization_config)
OP_TO_ANNOTATOR["add"](gm, quantization_config)

def validate(self, model: torch.fx.GraphModule) -> None:
@@ -207,7 +207,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
is_per_channel=True
)
OP_TO_ANNOTATOR["linear"](gm, quantization_config)
OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
OP_TO_ANNOTATOR["conv"](gm, quantization_config)

def validate(self, model: torch.fx.GraphModule) -> None:
pass
@@ -289,7 +289,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
is_per_channel=True
)
avgpool_qconfig = _get_uint8_quantization_config()
OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
OP_TO_ANNOTATOR["conv"](gm, quantization_config)
OP_TO_ANNOTATOR["add"](gm, quantization_config)
for n in gm.graph.nodes:
if n.op == "call_function" and n.target == torch.ops.aten.mean.dim:
8 changes: 4 additions & 4 deletions test/quantization/pt2e/test_metadata_porting.py
@@ -115,7 +115,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
gm, quantization_config
)
_tag_partitions(backend_string, "linear", annotated_partitions)
- annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
+ annotated_partitions = OP_TO_ANNOTATOR["conv"](
gm, quantization_config
)
_tag_partitions(backend_string, "conv2d", annotated_partitions)
@@ -171,7 +171,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
gm, quantization_config
)
_tag_partitions(backend_string, "linear", annotated_partitions)
- annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
+ annotated_partitions = OP_TO_ANNOTATOR["conv"](
gm, quantization_config
)
_tag_partitions(backend_string, "conv2d", annotated_partitions)
@@ -212,7 +212,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
quantization_config = get_symmetric_quantization_config(
is_per_channel=True
)
- annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
+ annotated_partitions = OP_TO_ANNOTATOR["conv"](
gm, quantization_config
)
_tag_partitions(backend_string, "conv2d", annotated_partitions)
@@ -282,7 +282,7 @@ def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
quantization_config_dynamic = get_symmetric_quantization_config(
is_per_channel=True, is_dynamic=True
)
- annotated_partitions = OP_TO_ANNOTATOR["conv2d"](
+ annotated_partitions = OP_TO_ANNOTATOR["conv"](
gm, quantization_config_dynamic
)
_tag_partitions(backend_string, "conv2d_dynamic", annotated_partitions)
55 changes: 54 additions & 1 deletion test/quantization/pt2e/test_quantize_pt2e.py
@@ -982,7 +982,32 @@ def test_mul_and_inplace_mul(self):
node_list,
)

- def test_xnnpack_quantizer_conv(self):
+ def test_xnnpack_quantizer_conv1d(self):
+ quantizer = XNNPACKQuantizer()
+ quantization_config = get_symmetric_quantization_config(is_per_channel=True)
+ quantizer.set_global(quantization_config)
+ example_inputs = (torch.randn(1, 3, 5),)
+ node_occurrence = {
+ # input and output are using quantize_per_tensor and weight is using quantize_per_channel
+ torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+ torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
+ torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+ }
+ node_list = [
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+ torch.ops.aten.conv1d.default,
+ torch.ops.quantized_decomposed.quantize_per_tensor.default,
+ ]
+ self._test_quantizer(
+ TestHelperModules.ConvWithBNRelu(dim=1, relu=False, bn=False),
+ example_inputs,
+ quantizer,
+ node_occurrence,
+ node_list,
+ )
+
+ def test_xnnpack_quantizer_conv2d(self):
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(quantization_config)
@@ -1007,6 +1032,34 @@ def test_xnnpack_quantizer_conv(self):
node_list,
)

+ def test_xnnpack_quantizer_conv1d_with_conv2d(self):
+ quantizer = XNNPACKQuantizer()
+ quantization_config = get_symmetric_quantization_config(is_per_channel=True)
+ quantizer.set_global(quantization_config)
+ example_inputs = (torch.randn(1, 3, 5, 5),)
+ node_occurrence = {
+ # input and output are using quantize_per_tensor and weight is using quantize_per_channel
+ torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
+ torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
+ torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
+ }
+ node_list = [
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+ torch.ops.aten.conv2d.default,
+ torch.ops.quantized_decomposed.quantize_per_tensor.default,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+ torch.ops.aten.conv1d.default,
+ torch.ops.quantized_decomposed.quantize_per_tensor.default,
+ ]
+ self._test_quantizer(
+ TestHelperModules.Conv1dWithConv2d(),
+ example_inputs,
+ quantizer,
+ node_occurrence,
+ node_list,
+ )
+
def test_xnnpack_quantizer_linear(self):
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(is_per_channel=True)
1 change: 1 addition & 0 deletions torch/ao/quantization/pt2e/graph_utils.py
@@ -18,6 +18,7 @@
]

_EQUIVALENT_TYPES: List[Set] = [
+ {torch.nn.Conv1d, torch.nn.functional.conv1d},
{torch.nn.Conv2d, torch.nn.functional.conv2d},
{torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d},
{torch.nn.ReLU, torch.nn.functional.relu, torch.nn.functional.relu_},
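The new `{Conv1d, F.conv1d}` entry means partition matching treats the module and functional forms of conv1d as equivalent. A hedged sketch (not part of this commit) of a module that should now be matched by the conv_relu annotator even though it uses the functional form:

```python
# Sketch, assuming find_sequential_partitions folds F.conv1d into the
# torch.nn.Conv1d equivalence class introduced above.
import torch
import torch.nn.functional as F

class FunctionalConv1dRelu(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(3, 3, 3))
        self.bias = torch.nn.Parameter(torch.randn(3))

    def forward(self, x):
        x = F.conv1d(x, self.weight, self.bias)  # {Conv1d, F.conv1d} class
        return F.relu(x)                         # {ReLU, F.relu, F.relu_} class
```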
4 changes: 2 additions & 2 deletions torch/ao/quantization/quantizer/xnnpack_quantizer.py
@@ -245,8 +245,8 @@ class XNNPACKQuantizer(Quantizer):
# static quantization ops (both PTQ and QAT)
STATIC_OPS = [
"linear",
"conv2d_relu",
"conv2d",
"conv_relu",
"conv",
"adaptive_avg_pool2d",
# TODO: move this to BoltNNQuantizer?
"gru_io_only",
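Renaming the entries to `conv_relu` / `conv` keeps the ordering guarantee: fused conv + relu partitions are annotated before plain convs, so a conv followed by relu is not annotated twice (the `conv` pass skips already-annotated nodes). A simplified sketch of how a global config could be dispatched over this list; this is an assumption, since the real quantizer also handles per-module-type / per-module-name configs and filter functions:

```python
# Simplified dispatch sketch; OP_TO_ANNOTATOR is the registry populated by
# @register_annotator in xnnpack_quantizer_utils.py.
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR

STATIC_OPS = ["linear", "conv_relu", "conv", "adaptive_avg_pool2d", "gru_io_only"]

def annotate_static_ops(gm, quantization_config):
    for op in STATIC_OPS:
        # "conv_relu" runs before "conv": fused patterns claim their nodes
        # first, and already-annotated convs are skipped by the "conv" pass.
        OP_TO_ANNOTATOR[op](gm, quantization_config)
```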
42 changes: 29 additions & 13 deletions torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
@@ -208,27 +208,34 @@ def _annotate_linear(
return annotated_partitions


@register_annotator("conv2d")
def _annotate_conv2d(
@register_annotator("conv")
def _annotate_conv(
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
conv_partitions = get_source_partitions(
- gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d], filter_fn
+ gm.graph,
+ [torch.nn.Conv2d,
+ torch.nn.functional.conv2d,
+ torch.nn.Conv1d,
+ torch.nn.functional.conv1d],
+ filter_fn
)
conv_partitions = list(itertools.chain(*conv_partitions.values()))
annotated_partitions = []
for conv_partition in conv_partitions:
- annotated_partitions.append(conv_partition.nodes)
if len(conv_partition.output_nodes) > 1:
raise ValueError("conv partition has more than one output node")
conv_node = conv_partition.output_nodes[0]
if (
conv_node.op != "call_function"
- or conv_node.target != torch.ops.aten.conv2d.default
+ or conv_node.target not in [
+ torch.ops.aten.conv1d.default,
+ torch.ops.aten.conv2d.default,
+ ]
):
raise ValueError(f"{conv_node} is not an aten conv2d operator")
raise ValueError(f"{conv_node} is not an aten conv1d or conv2d operator")
# skip annotation if it is already annotated
if _is_annotated([conv_node]):
continue
@@ -251,22 +258,26 @@ def _annotate_conv2d(
output_qspec=get_output_act_qspec(quantization_config),
_annotated=True,
)
+ _mark_nodes_as_annotated(conv_partition.nodes)
+ annotated_partitions.append(conv_partition.nodes)
return annotated_partitions


@register_annotator("conv2d_relu")
def _annotate_conv2d_relu(
@register_annotator("conv_relu")
def _annotate_conv_relu(
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
- fused_partitions = find_sequential_partitions(
+ fused_partitions1 = find_sequential_partitions(
gm, [torch.nn.Conv2d, torch.nn.ReLU], filter_fn
)
+ fused_partitions2 = find_sequential_partitions(
+ gm, [torch.nn.Conv1d, torch.nn.ReLU], filter_fn
+ )
annotated_partitions = []
- for fused_partition in fused_partitions:
+ for fused_partition in (fused_partitions1 + fused_partitions2):
conv_partition, relu_partition = fused_partition
- annotated_partitions.append(conv_partition.nodes + relu_partition.nodes)
if len(relu_partition.output_nodes) > 1:
raise ValueError("Relu partition has more than one output node")
relu_node = relu_partition.output_nodes[0]
@@ -278,9 +289,12 @@ def _annotate_conv2d_relu(
raise ValueError(f"{conv_node} is not a Node")
if (
conv_node.op != "call_function"
- or conv_node.target != torch.ops.aten.conv2d.default
+ or conv_node.target not in [
+ torch.ops.aten.conv1d.default,
+ torch.ops.aten.conv2d.default,
+ ]
):
- raise ValueError(f"{conv_node} is not an aten conv2d operator")
+ raise ValueError(f"{conv_node} is not an aten conv1d or conv2d operator")
if relu_node.op != "call_function" or relu_node.target not in [
torch.ops.aten.relu.default,
torch.ops.aten.relu_.default,
@@ -310,6 +324,8 @@ def _annotate_conv2d_relu(
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True,
)
+ _mark_nodes_as_annotated(conv_partition.nodes + relu_partition.nodes)
+ annotated_partitions.append(conv_partition.nodes + relu_partition.nodes)
return annotated_partitions


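With the rename, a custom quantizer (as in the updated tests above) requests conv annotation through the `"conv"` key and gets both conv1d and conv2d coverage. A minimal sketch of that usage; the `BackendAQuantizer` name is illustrative only:

```python
import torch
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR

class BackendAQuantizer(Quantizer):  # illustrative name
    def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
        # one key now covers conv1d, conv2d, and their functional forms
        OP_TO_ANNOTATOR["conv"](gm, quantization_config)
        return gm

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass
```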
21 changes: 18 additions & 3 deletions torch/testing/_internal/common_quantization.py
@@ -2527,11 +2527,14 @@ def forward(self, x):
return x

class ConvWithBNRelu(torch.nn.Module):
- def __init__(self, relu, bn=True, bias=True):
+ def __init__(self, relu, dim=2, bn=True, bias=True):
super().__init__()
- self.conv = torch.nn.Conv2d(3, 3, 3, bias=bias)
+ convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d}
+ bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d}
+ self.conv = convs[dim](3, 3, 3, bias=bias)
+
if bn:
- self.bn = torch.nn.BatchNorm2d(3)
+ self.bn = bns[dim](3)
else:
self.bn = torch.nn.Identity()
if relu:
@@ -2544,6 +2547,18 @@ def forward(self, x):
x = self.bn(x)
return self.relu(x)

+ class Conv1dWithConv2d(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv1d = torch.nn.Conv1d(3, 3, 3)
+ self.conv2d = torch.nn.Conv2d(3, 3, 3)
+
+ def forward(self, x):
+ x = self.conv2d(x)
+ x = x.squeeze(0)
+ x = self.conv1d(x)
+ return x
+
class Conv2dWithCat(torch.nn.Module):
def __init__(self):
super().__init__()
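The updated helpers can be smoke-tested directly; the input shapes below follow the new tests (`torch.randn(1, 3, 5)` for the conv1d variant, `torch.randn(1, 3, 5, 5)` for the mixed model):

```python
import torch
from torch.testing._internal.common_quantization import TestHelperModules

# conv1d variant of ConvWithBNRelu via the new dim argument
m1 = TestHelperModules.ConvWithBNRelu(dim=1, relu=False, bn=False)
out1 = m1(torch.randn(1, 3, 5))      # Conv1d over (N, C, L)

# conv2d followed by conv1d in a single graph
m2 = TestHelperModules.Conv1dWithConv2d()
out2 = m2(torch.randn(1, 3, 5, 5))   # conv2d -> squeeze(0) -> conv1d
```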
