[quant][pt2e][be] Cleanup observer insertion logic
Summary:
As titled. After the SharedQuantizationSpec bug fix, the required checks are already performed up front, so the logic for inserting observers can be simplified.
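
The sketch below is a standalone toy illustration of the simplified flow, not the actual torch.fx-based implementation: it assumes an obs_or_fq_map that is fully populated before insertion, keyed by output nodes and by input edges (arg, node), with edges that share quantization parameters pointing at the same observer instance, so insertion reduces to a lookup plus reuse. The FakeObserver class and the node names are made up for the example.

# Toy model of the simplified insertion decision (illustrative only).
class FakeObserver:
    def __init__(self, dtype: str) -> None:
        self.dtype = dtype

# Decided before insertion (e.g. while resolving SharedQuantizationSpec):
# output nodes and input edges (arg, node) map to observer instances, and
# edges that share quantization parameters point at the *same* instance.
shared_obs = FakeObserver("int8")
obs_or_fq_map = {
    "conv": shared_obs,                            # observer for conv's output
    ("conv", "hardtanh"): shared_obs,              # hardtanh input shares it
    ("hardtanh", "linear"): FakeObserver("int8"),  # separate instance
}

inserted_for_output = {"conv": shared_obs}  # observers already attached as outputs

def maybe_insert_input_observer(arg, node):
    obs = obs_or_fq_map.get((arg, node))
    if obs is None:
        return arg  # edge is not observed: keep the argument as-is
    if inserted_for_output.get(arg) is obs:
        return inserted_for_output[arg]  # same instance already inserted: reuse it
    return obs  # otherwise a new (or deduplicated) observer is attached

print(maybe_insert_input_observer("conv", "hardtanh") is shared_obs)    # True: reused
print(maybe_insert_input_observer("hardtanh", "linear") is shared_obs)  # False: new instance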

Test Plan:
python test/test_quantization.py TestQuantizePT2E

CIs

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 10558a74d77ed81875be444cb6b55d143c70c0ea
Pull Request resolved: #111828
jerryzh168 committed Oct 23, 2023
1 parent 45d26af commit 24e93ee
Showing 4 changed files with 64 additions and 81 deletions.
16 changes: 10 additions & 6 deletions test/quantization/pt2e/test_xnnpack_quantizer.py
@@ -365,12 +365,16 @@ def test_propagate_annotation(self):

        m = prepare_pt2e(m, quantizer)
        m(*example_inputs)
        self.assertEqual(
            id(m.activation_post_process_2), id(m.activation_post_process_3)
        )
        self.assertEqual(
            id(m.activation_post_process_3), id(m.activation_post_process_4)
        )
        act_post_processes_pairs = []
        for n in m.graph.nodes:
            if n.target in [
                torch.ops.aten.view.default,
                torch.ops.aten.hardtanh.default,
            ]:
                input_act = getattr(m, n.args[0].target)
                output_act = getattr(m, list(n.users)[0].target)
                self.assertEqual(id(input_act), id(output_act))

        m = convert_pt2e(m, fold_quantize=True)
        node_occurrence = {
            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
118 changes: 50 additions & 68 deletions torch/ao/quantization/pt2e/prepare.py
@@ -1,9 +1,6 @@
import torch
from torch._subclasses import FakeTensor
from torch.ao.quantization.fx.prepare import (
    _get_arg_as_input_act_obs_or_fq,
    _get_output_act_obs_or_fq,
    _get_dtype_and_is_dynamic,
    _insert_obs_or_fq,
    _save_state,
    _is_activation_post_process_node,
@@ -21,7 +18,6 @@
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from typing import Dict, Tuple, Union, Any, Optional
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    EdgeOrNode,
    SharedQuantizationSpec,
    QuantizationSpecBase,
@@ -260,70 +256,56 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
    # default (no observer)
    new_arg = arg

    quantization_annotation = node.meta.get("quantization_annotation", QuantizationAnnotation())
    arg_as_input_act_obs_or_fq = _get_arg_as_input_act_obs_or_fq(arg, node, named_modules, obs_or_fq_map, is_qat)
    arg_as_input_target_dtype, arg_as_input_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq)

    arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq(arg, named_modules, obs_or_fq_map, is_qat)
    arg_as_output_target_dtype, arg_as_output_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq)

    if arg_as_input_target_is_dynamic or arg_as_input_target_dtype not in [torch.float, None]:
        if arg_as_input_target_dtype == arg_as_output_target_dtype and \
                arg_as_input_target_is_dynamic == arg_as_output_target_is_dynamic:
            assert _is_activation_post_process_node(arg, named_modules)
            assert arg_as_input_act_obs_or_fq is not None
            observed_arg = arg.args[0]
            assert isinstance(observed_arg, Node), f"expect observed argument to be a Node, but got: {type(observed_arg)}"
            assert observed_arg in obs_or_fq_map, \
                f"can't find a sharing group for node: {observed_arg}"
            # reuse the existing obs/fq
            arg_as_input_act_obs_or_fq = obs_or_fq_map[observed_arg]
            # we don't need to insert new observer node
            new_arg = arg
        else:
            # skip inserting new observers if there is an observer inserted for the arg before
            # that has the same dtype that we want to insert here
            # alternatively we could have a dedup pass after we insert all observers to deduplicate
            # observers
            # Example:
            # arg -> existing_obs -> conv1
            #      \ -> conv2
            #
            # instead of inserting new observers we will have:
            # arg -> existing_obs -> conv1
            #      \ -> conv2
            existing_obs_node = None
            for maybe_obs_node in arg.users.keys():
                if maybe_obs_node.op == 'call_module':
                    maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
                    if (
                        type(maybe_obs_mod) == type(arg_as_input_act_obs_or_fq) and
                        maybe_obs_mod.dtype == arg_as_input_target_dtype
                    ):
                        arg_as_input_act_obs_or_fq = maybe_obs_mod  # type: ignore[assignment]
                        existing_obs_node = maybe_obs_node
                        break

            assert arg_as_input_act_obs_or_fq is not None
            if existing_obs_node is None:
                maybe_observed_arg = arg
                # When quantizing two layers with different configs we can have
                # conv2d (int8) -> avgpool(uint8)
                # In this case observer insertion for avgpool will come here but the input
                # to avgpool will be output observer of conv2d
                # Now the obs map that we update must correspond to the original input of
                # avgpool and not the output obs of conv2d
                # This is because when referring to the edge, quantizer would refer to
                # original input and not the observed one.
                while _is_activation_post_process_node(arg, named_modules):
                    arg = arg.args[0]  # type: ignore[assignment]
                arg_as_input_act_obs_or_fq = obs_or_fq_map[(arg, node)]
                new_obs_node = _insert_obs_or_fq(
                    maybe_observed_arg, arg_as_input_act_obs_or_fq, model, named_modules, model.graph)
                # override this arg to be the observed arg
                new_arg = new_obs_node
            else:
                new_arg = existing_obs_node
    # find the original `arg` node to the current node, skipping inserted observer/fake_quant nodes
    original_arg = arg
    while _is_activation_post_process_node(original_arg, named_modules):
        original_arg = original_arg.args[0]  # type: ignore[assignment]
    assert isinstance(original_arg, Node), f"expect original argument to be a Node, but got: {type(original_arg)}"

    input_edge = (original_arg, node)
    if input_edge not in obs_or_fq_map:
        return new_arg
    # input_edge needs to be observed
    input_edge_obs_or_fq = obs_or_fq_map[input_edge]
    if input_edge_obs_or_fq is None:
        return new_arg

    arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None)
    # the arg is observed as the output and is using the same instance as the input_edge
    # we'll reuse the inserted observer/fake_quant
    if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id(input_edge_obs_or_fq):
        return new_arg

    # otherwise, we'll insert a new observer/fake_quant node

    existing_obs_node = None
    # skip inserting new observers if there is an observer inserted for the arg before
    # that has the same dtype that we want to insert here
    # alternatively we could have a dedup pass after we insert all observers to deduplicate
    # observers
    # Example:
    # conv1 -> obs1 -> existing_obs -> conv2
    #                \ -> conv3
    #
    # instead of inserting new observers we will have:
    # conv1 -> obs1 -> existing_obs -> conv2
    #                \ -> conv3
    for maybe_obs_node in arg.users.keys():
        if not _is_activation_post_process_node(maybe_obs_node, named_modules):
            continue
        maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
        if (
            type(maybe_obs_mod) == type(input_edge_obs_or_fq) and
            maybe_obs_mod.dtype == input_edge_obs_or_fq.dtype
        ):
            input_edge_obs_or_fq = maybe_obs_mod  # type: ignore[assignment]
            existing_obs_node = maybe_obs_node
            break

    if existing_obs_node is None:
        new_arg = _insert_obs_or_fq(arg, input_edge_obs_or_fq, model, named_modules, model.graph)
    else:
        new_arg = existing_obs_node

    return new_arg

4 changes: 3 additions & 1 deletion torch/ao/quantization/quantizer/quantizer.py
@@ -138,7 +138,9 @@ class QuantizationAnnotation:
"""

# a map from torch.fx.Node to a type of QuantizationSpecBase
input_qspec_map: Dict[Node, QuantizationSpecBase] = field(default_factory=dict)
input_qspec_map: Dict[Node, Optional[QuantizationSpecBase]] = field(
default_factory=dict
)

# How the output of this node is quantized, expressed as QuantizationSpec
# TODO: change the value to QuantizationSpec in a separate PR
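
As a usage sketch of the widened typing: a quantizer can now record an explicit None for an input it does not want observed, for example the bias of a conv. The helper annotate_conv and the qspec arguments below are illustrative names, not part of this PR; only QuantizationAnnotation and its fields come from this file.

# Sketch only: annotate_conv, act_qspec and weight_qspec are illustrative.
from typing import Dict, Optional

from torch.fx import Node
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    QuantizationSpecBase,
)

def annotate_conv(
    conv_node: Node,
    act_qspec: QuantizationSpecBase,
    weight_qspec: QuantizationSpecBase,
) -> None:
    input_act, weight, bias = conv_node.args[:3]
    input_qspec_map: Dict[Node, Optional[QuantizationSpecBase]] = {
        input_act: act_qspec,    # observe the input activation
        weight: weight_qspec,    # observe the weight
    }
    if isinstance(bias, Node):
        input_qspec_map[bias] = None  # valid with the new typing: bias stays float, no observer
    conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=act_qspec,
        _annotated=True,
    )
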
7 changes: 1 addition & 6 deletions torch/ao/quantization/quantizer/xnnpack_quantizer.py
@@ -154,12 +154,7 @@ def get_symmetric_quantization_config(
        ),
    )

    bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
        PlaceholderObserver
    )
    bias_quantization_spec = QuantizationSpec(
        dtype=torch.float, observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
    )
    bias_quantization_spec = None
    if is_dynamic:
        quantization_config = QuantizationConfig(
            act_quantization_spec,
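
For completeness, a minimal end-to-end sketch of the PT2E flow that exercises the code paths touched in this PR, using the APIs available at the time of this commit; the toy module and input shape are made up for illustration.

# Illustrative only: toy module, input shape, and variable names are not from the PR.
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)

    def forward(self, x):
        return torch.nn.functional.hardtanh(self.conv(x))

example_inputs = (torch.randn(1, 3, 16, 16),)
m = capture_pre_autograd_graph(M().eval(), example_inputs)

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)   # observer insertion logic cleaned up in this PR
m(*example_inputs)               # calibration
m = convert_pt2e(m, fold_quantize=True)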
