[quant][pt2] Update special qspecs after QAT rewrite
Summary:
Special qspecs like `SharedQuantizationSpec` and
`DerivedQuantizationSpec` refer to other nodes in the graph.
However, after subgraph rewriting in QAT, the nodes referred
to in these special qspecs may be replaced by new nodes.
This could lead to the following error when inserting
observers according to these qspecs:

```
AssertionError: please make sure only refer to edge or node
that has observer/fake_quant inserted: 'getitem' not in
dict_keys([(arg0, convolution_default_1), (mul_tensor, convolution_default_1), getitem_3])
```
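
For intuition, the stale reference can be reproduced with a minimal, torch-free sketch (the names mirror the error above, but the set and strings are hypothetical stand-ins for `torch.fx.Node` objects and the observer map):

```
# Nodes that ended up with observers/fake_quants after rewriting (stand-ins).
obs_or_fq_keys = {"arg0", "mul_tensor", "getitem_3"}

# A SharedQuantizationSpec recorded before rewriting still points at the old
# "getitem" node, which the subgraph rewriter replaced with "getitem_3".
stale_reference = "getitem"

# The lookup performed during observer insertion no longer finds the node,
# which is exactly the AssertionError shown above.
assert stale_reference not in obs_or_fq_keys
```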

This commit fixes this by keeping track of the nodes that
are replaced during subgraph rewriting in QAT, and using
this mapping to update the dangling references used in these
special qspecs.
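
A rough sketch of that remapping (plain strings stand in for `torch.fx.Node` objects; `remap_edge_or_node` is a hypothetical name, the real implementation is `_update_special_qspecs_after_replacement` in the diff below):

```
# Mapping recorded during subgraph rewriting: replaced node -> replacement.
original_to_replacement = {"getitem": "getitem_3"}

def remap_edge_or_node(edge_or_node):
    """Rewrite a single node, or an (input, user) edge of two nodes."""
    if isinstance(edge_or_node, tuple):
        src, dest = edge_or_node
        return (
            original_to_replacement.get(src, src),
            original_to_replacement.get(dest, dest),
        )
    return original_to_replacement.get(edge_or_node, edge_or_node)

# A node referenced by a SharedQuantizationSpec is rewritten directly...
assert remap_edge_or_node("getitem") == "getitem_3"
# ...and each endpoint of an edge is rewritten independently.
assert remap_edge_or_node(("conv", "getitem")) == ("conv", "getitem_3")
```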

Test Plan: python test/test_quantization.py TestQuantizePT2E.test_qat_update_shared_qspec

Reviewed By: jerryzh168

Differential Revision: D46606614

fbshipit-source-id: 70e937c5aede94530073c729aaa6994b88141ed3
andrewor14 authored and facebook-github-bot committed Jun 21, 2023
1 parent 678ce61 commit b647905
Showing 3 changed files with 123 additions and 25 deletions.
27 changes: 27 additions & 0 deletions test/quantization/pt2e/test_quantize_pt2e.py
@@ -1572,6 +1572,33 @@ def test_qat_conv_bn_relu_numerics(self):
            m, example_inputs, is_per_channel=True, verify_convert=True,
        )

    def test_qat_update_shared_qspec(self):
        """
        Test the case where nodes used in SharedQuantizationSpec were replaced
        during QAT subgraph rewriting.
        """
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.bn = torch.nn.BatchNorm2d(3)
                self.hardtanh = torch.nn.Hardtanh()

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                x = self.hardtanh(x)
                return x

        example_inputs = (torch.randn(1, 3, 5, 5),)
        self._verify_symmetric_qnnpack_qat_numerics(
            M(), example_inputs, is_per_channel=False, verify_convert=True,
        )
        self._verify_symmetric_qnnpack_qat_numerics(
            M(), example_inputs, is_per_channel=True, verify_convert=True,
        )


@skipIfNoQNNPACK
class TestQuantizePT2EOps(QuantizationTestCase):
    def test_gru(self):
Expand Down
119 changes: 95 additions & 24 deletions torch/ao/quantization/_pt2e/qat_utils.py
@@ -1,4 +1,5 @@
import copy
import dataclasses
import itertools
import operator
from typing import Any, Callable, Dict, List, Tuple
@@ -8,6 +9,12 @@
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
import torch.nn.functional as F
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from .quantizer import (
    DerivedQuantizationSpec,
    EdgeOrNode,
    SharedQuantizationSpec,
    QuantizationSpecBase,
)
from .utils import _fold_bn_weights_into_conv_node

# Example inputs for `_conv2d_bn_pattern`, `_qat_conv2d_bn_pattern`, and `_qat_conv2d_bn_pattern_no_bias`
@@ -427,6 +434,75 @@ def _copy_over_literal_conv_args(original_node: Node, new_node: Node):
    # x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups]
    new_node.args = new_node.args[:3] + original_node.args[3:]

def _update_conv_input_qspec_map_after_replacement(original_node: Node, replacement_node: Node):
    """
    Update the `input_qspec_map` in the annotation after subgraph rewriting.
    The original annotation referred to the nodes in the original graph,
    so the keys in the `input_qspec_map` will need to be updated to reflect
    the corresponding nodes in the replacement graph.
    """
    assert original_node.target == torch.ops.aten.convolution.default
    assert replacement_node.target == torch.ops.aten.convolution.default
    if "quantization_annotation" not in original_node.meta:
        return
    original_input_qspec_map = original_node.meta["quantization_annotation"].input_qspec_map
    input_qspec_map = {}
    # get the list of configs, it should be ordered as input, weight, bias
    # note: this is really hacky, we need a better solution, hopefully
    # in subgraph_rewriter, issue tracking the problem: https://github.com/pytorch/pytorch/issues/101820
    all_configs = list(original_input_qspec_map.items())
    # input activation
    input_qspec_map[replacement_node.args[0]] = all_configs[0][1]
    # weight
    input_qspec_map[replacement_node.args[1]] = all_configs[1][1]
    # bias
    if len(replacement_node.args) > 2 and len(all_configs) > 2:
        input_qspec_map[replacement_node.args[2]] = all_configs[2][1]
    replacement_node.meta["quantization_annotation"].input_qspec_map = input_qspec_map

def _update_special_qspecs_after_replacement(
    node: Node,
    original_to_replacement_node: Dict[Node, Node],
):
    """
    Update the `SharedQuantizationSpec`s and `DerivedQuantizationSpec`s
    used in `node`'s quantization annotation after subgraph rewriting.
    The original annotation referred to the nodes in the original graph,
    so the nodes used in these special quantization specs will need to
    be updated to the corresponding nodes in the replacement graph.
    """
    def _get_new_edge_or_node(edge_or_node: EdgeOrNode):
        if isinstance(edge_or_node, Node):
            _node = edge_or_node
            return original_to_replacement_node.get(_node, _node)
        elif isinstance(edge_or_node, tuple) and len(edge_or_node) == 2:
            # `isinstance` cannot be used with subscripted generics like
            # `Tuple[Node, Node]`, so check for a plain 2-tuple instead
            src, dest = edge_or_node
            return (
                original_to_replacement_node.get(src, src),
                original_to_replacement_node.get(dest, dest),
            )
        else:
            raise ValueError(f"unexpected type for edge_or_node: {type(edge_or_node)}")

    def _get_new_qspec(qspec: QuantizationSpecBase):
        if isinstance(qspec, SharedQuantizationSpec):
            new_edge_or_node = _get_new_edge_or_node(qspec.edge_or_node)
            return SharedQuantizationSpec(new_edge_or_node)
        elif isinstance(qspec, DerivedQuantizationSpec):
            new_derived_from = [_get_new_edge_or_node(x) for x in qspec.derived_from]
            return dataclasses.replace(qspec, derived_from=new_derived_from)
        else:
            return qspec

    if "quantization_annotation" not in node.meta:
        return
    annotation = node.meta["quantization_annotation"]
    for input_node, qspec in annotation.input_qspec_map.items():
        annotation.input_qspec_map[input_node] = _get_new_qspec(qspec)
    annotation.output_qspec = _get_new_qspec(annotation.output_qspec)

def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
"""
Given a graph of decomposed aten ops, replace the (conv + bn) pattern with
@@ -480,49 +556,44 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
    # Due to limited functionality in the subgraph rewriter, here we manually
    # update the replacement graph as follows:
    #
    #   (a) Copy over metadata from original subgraph. This ensures the stack traces
    #       and annotations are preserved in the new subgraph
    #
    #   (b) Copy over literal args for conv from the original subgraph
    #       TODO: do this for literal args for batchnorm as well
    #
    #   (c) Update all references of the old nodes in the original subgraph to refer
    #       to the corresponding nodes in the new subgraph in the annotations
    #
    # In the future, we should try to push as much of this functionality into the
    # subgraph rewriter as possible, so we don't have to manually copy anything over.
    # For more detail, see https://github.com/pytorch/pytorch/issues/100419.

    original_to_replacement_node = {}
    for r in replacements_with_conv_bias + replacements_no_conv_bias:
        (replacement_conv_node, replacement_bn_node, replacement_getitem_node) = \
            _get_conv_bn_getitem_nodes(r.replacements)

        # Step (a): Copy over metadata for all three nodes in [conv - bn - getitem]
        for match_pattern_node, original_node in _filter_nodes_map(r.nodes_map).items():
            if original_node.target == torch.ops.aten.convolution.default:
                replacement_conv_node.meta = original_node.meta
                original_to_replacement_node[original_node] = replacement_conv_node
                # Step (b): Copy over conv literal args
                _copy_over_literal_conv_args(original_node, replacement_conv_node)
                # Step (c): Update old references in the conv node's input_qspec_map
                _update_conv_input_qspec_map_after_replacement(original_node, replacement_conv_node)
            if original_node.target == torch.ops.aten._native_batch_norm_legit.default:
                replacement_bn_node.meta = original_node.meta
                original_to_replacement_node[original_node] = replacement_bn_node
            if original_node.target == operator.getitem:
                replacement_getitem_node.meta = original_node.meta
                original_to_replacement_node[original_node] = replacement_getitem_node

    # Step (c): Update old references in the special qspecs for all nodes in the graph
    for n in m.graph.nodes:
        _update_special_qspecs_after_replacement(n, original_to_replacement_node)

    return m

def _duplicate_dequantize_node(m: GraphModule):
2 changes: 1 addition & 1 deletion torch/ao/quantization/fx/prepare.py
@@ -178,7 +178,7 @@ def _create_obs_or_fq_from_qspec(
        edge_or_node = quantization_spec.edge_or_node
        assert edge_or_node in obs_or_fq_map, \
            "please make sure only refer to edge or node that has " \
            "observer/fake_quant inserted: '{}' not in\n{}".format(edge_or_node, obs_or_fq_map.keys())
        return obs_or_fq_map[edge_or_node]
    elif isinstance(quantization_spec, DerivedQuantizationSpec):
        # can't use asdict, so not calling get_observer_kwargs here
