diff --git a/test/quantization/pt2e/test_xnnpack_quantizer.py b/test/quantization/pt2e/test_xnnpack_quantizer.py
index b1720d20ffa56..2bab0aa69052d 100644
--- a/test/quantization/pt2e/test_xnnpack_quantizer.py
+++ b/test/quantization/pt2e/test_xnnpack_quantizer.py
@@ -365,12 +365,16 @@ def test_propagate_annotation(self):
 
         m = prepare_pt2e(m, quantizer)
         m(*example_inputs)
-        self.assertEqual(
-            id(m.activation_post_process_2), id(m.activation_post_process_3)
-        )
-        self.assertEqual(
-            id(m.activation_post_process_3), id(m.activation_post_process_4)
-        )
+        act_post_processes_pairs = []
+        for n in m.graph.nodes:
+            if n.target in [
+                torch.ops.aten.view.default,
+                torch.ops.aten.hardtanh.default,
+            ]:
+                input_act = getattr(m, n.args[0].target)
+                output_act = getattr(m, list(n.users)[0].target)
+                self.assertEqual(id(input_act), id(output_act))
+
         m = convert_pt2e(m, fold_quantize=True)
         node_occurrence = {
             # input and output are using quantize_per_tensor and weight is using quantize_per_channel
diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py
index 2b0a3fe5642a8..885f27aef1658 100644
--- a/torch/ao/quantization/pt2e/prepare.py
+++ b/torch/ao/quantization/pt2e/prepare.py
@@ -1,9 +1,6 @@
 import torch
 from torch._subclasses import FakeTensor
 from torch.ao.quantization.fx.prepare import (
-    _get_arg_as_input_act_obs_or_fq,
-    _get_output_act_obs_or_fq,
-    _get_dtype_and_is_dynamic,
     _insert_obs_or_fq,
     _save_state,
     _is_activation_post_process_node,
@@ -21,7 +18,6 @@
 from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
 from typing import Dict, Tuple, Union, Any, Optional
 from torch.ao.quantization.quantizer import (
-    QuantizationAnnotation,
     EdgeOrNode,
     SharedQuantizationSpec,
     QuantizationSpecBase,
@@ -260,70 +256,56 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
     # default (no observer)
     new_arg = arg
-    quantization_annotation = node.meta.get("quantization_annotation", QuantizationAnnotation())
-    arg_as_input_act_obs_or_fq = _get_arg_as_input_act_obs_or_fq(arg, node, named_modules, obs_or_fq_map, is_qat)
-    arg_as_input_target_dtype, arg_as_input_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq)
-
-    arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq(arg, named_modules, obs_or_fq_map, is_qat)
-    arg_as_output_target_dtype, arg_as_output_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq)
-
-    if arg_as_input_target_is_dynamic or arg_as_input_target_dtype not in [torch.float, None]:
-        if arg_as_input_target_dtype == arg_as_output_target_dtype and \
-                arg_as_input_target_is_dynamic == arg_as_output_target_is_dynamic:
-            assert _is_activation_post_process_node(arg, named_modules)
-            assert arg_as_input_act_obs_or_fq is not None
-            observed_arg = arg.args[0]
-            assert isinstance(observed_arg, Node), f"expect observed argument to be a Node, but got: {type(observed_arg)}"
-            assert observed_arg in obs_or_fq_map, \
-                f"can't find a sharing group for node: {observed_arg}"
-            # reuse the existing obs/fq
-            arg_as_input_act_obs_or_fq = obs_or_fq_map[observed_arg]
-            # we don't need to insert new observer node
-            new_arg = arg
-        else:
-            # skip inserting new observers if there is an observer inserted for the arg before
-            # that has the same dtype that we want to insert here
-            # alternatively we could have a dedup pass after we insert all observers to deduplicate
-            # observers
-            # Example:
-            # arg -> existing_obs -> conv1
-            #                    \ -> conv2
-            #
-            # instead of inserting new observers we will have:
-            # arg -> existing_obs -> conv1
-            #                    \ -> conv2
-            existing_obs_node = None
-            for maybe_obs_node in arg.users.keys():
-                if maybe_obs_node.op == 'call_module':
-                    maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
-                    if (
-                        type(maybe_obs_mod) == type(arg_as_input_act_obs_or_fq) and
-                        maybe_obs_mod.dtype == arg_as_input_target_dtype
-                    ):
-                        arg_as_input_act_obs_or_fq = maybe_obs_mod  # type: ignore[assignment]
-                        existing_obs_node = maybe_obs_node
-                        break
-
-            assert arg_as_input_act_obs_or_fq is not None
-            if existing_obs_node is None:
-                maybe_observed_arg = arg
-                # When quantizing two layers with different configs we can have
-                # conv2d (int8) -> avgpool(uint8)
-                # In this case observer insertion for avgpool will come here but the input
-                # to avgpool will be output observer of conv2d
-                # Now the obs map that we update must correspond to the original input of
-                # avgpool and not the output obs of conv2d
-                # This is because when referring to the edge, quantizer would refer to
-                # original input and not the observed one.
-                while _is_activation_post_process_node(arg, named_modules):
-                    arg = arg.args[0]  # type: ignore[assignment]
-                arg_as_input_act_obs_or_fq = obs_or_fq_map[(arg, node)]
-                new_obs_node = _insert_obs_or_fq(
-                    maybe_observed_arg, arg_as_input_act_obs_or_fq, model, named_modules, model.graph)
-                # override this arg to be the observed arg
-                new_arg = new_obs_node
-            else:
-                new_arg = existing_obs_node
+    # find the original `arg` node to the current node, skipping inserted observer/fake_quant nodes
+    original_arg = arg
+    while _is_activation_post_process_node(original_arg, named_modules):
+        original_arg = original_arg.args[0]  # type: ignore[assignment]
+    assert isinstance(original_arg, Node), f"expect original argument to be a Node, but got: {type(original_arg)}"
+
+    input_edge = (original_arg, node)
+    if input_edge not in obs_or_fq_map:
+        return new_arg
+    # input_edge needs to be observed
+    input_edge_obs_or_fq = obs_or_fq_map[input_edge]
+    if input_edge_obs_or_fq is None:
+        return new_arg
+
+    arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None)
+    # the arg is observed as the output and is using the same instance as the input_edge
+    # we'll reuse the inserted observer/fake_quant
+    if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id(input_edge_obs_or_fq):
+        return new_arg
+
+    # otherwise, we'll insert a new observer/fake_quant node
+
+    existing_obs_node = None
+    # skip inserting new observers if there is an observer inserted for the arg before
+    # that has the same dtype that we want to insert here
+    # alternatively we could have a dedup pass after we insert all observers to deduplicate
+    # observers
+    # Example:
+    # conv1 -> obs1 -> existing_obs -> conv2
+    #                              \ -> conv3
+    #
+    # instead of inserting new observers we will have:
+    # conv1 -> obs1 -> existing_obs -> conv2
+    #                              \ -> conv3
+    for maybe_obs_node in arg.users.keys():
+        if not _is_activation_post_process_node(maybe_obs_node, named_modules):
+            continue
+        maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
+        if (
+            type(maybe_obs_mod) == type(input_edge_obs_or_fq) and
+            maybe_obs_mod.dtype == input_edge_obs_or_fq.dtype
+        ):
+            input_edge_obs_or_fq = maybe_obs_mod  # type: ignore[assignment]
+            existing_obs_node = maybe_obs_node
+            break
+
+    if existing_obs_node is None:
+        new_arg = _insert_obs_or_fq(arg, input_edge_obs_or_fq, model, named_modules, model.graph)
+    else:
+        new_arg = existing_obs_node
     return new_arg
diff --git a/torch/ao/quantization/quantizer/quantizer.py b/torch/ao/quantization/quantizer/quantizer.py
index 607e1b47a3bd3..2a2de76a0c6ee 100644
--- a/torch/ao/quantization/quantizer/quantizer.py
+++ b/torch/ao/quantization/quantizer/quantizer.py
@@ -138,7 +138,9 @@ class QuantizationAnnotation:
     """
 
     # a map from torch.fx.Node to a type of QuantizationSpecBase
-    input_qspec_map: Dict[Node, QuantizationSpecBase] = field(default_factory=dict)
+    input_qspec_map: Dict[Node, Optional[QuantizationSpecBase]] = field(
+        default_factory=dict
+    )
 
     # How the output of this node is quantized, expressed as QuantizationSpec
     # TODO: change the value to QuantizationSpec in a separate PR
diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py
index 40eb238d09193..d8b5d9a712a3b 100644
--- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py
+++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py
@@ -154,12 +154,7 @@ def get_symmetric_quantization_config(
         ),
     )
-    bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
-        PlaceholderObserver
-    )
-    bias_quantization_spec = QuantizationSpec(
-        dtype=torch.float, observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
-    )
+    bias_quantization_spec = None
    if is_dynamic:
         quantization_config = QuantizationConfig(
             act_quantization_spec,
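For reviewers, a minimal Python sketch (not part of the diff) of the per-edge decision that the rewritten `_maybe_insert_input_observer_for_arg_or_kwarg` now makes for each input edge `(arg, node)`: skip edges that are not annotated in `obs_or_fq_map`, reuse the arg's output observer when it is the same instance as the edge's observer, and otherwise dedup against an equivalent existing observer before a new one would be inserted. `FakeObserver` and `decide_observer_for_edge` are hypothetical stand-ins, not the torch API.

```python
class FakeObserver:
    """Stand-in for an observer/fake_quant module; only identity and dtype matter here."""
    def __init__(self, dtype):
        self.dtype = dtype


def decide_observer_for_edge(edge, obs_or_fq_map, existing_input_observers):
    """Return the observer to use for `edge = (arg, node)`, or None if the edge is unobserved.

    Mirrors the three cases in the diff:
      1. edge not annotated (or annotated with None) -> nothing to insert
      2. arg's output observer is the same instance -> reuse it (shared observer)
      3. an equivalent observer already exists for arg -> dedup, else a fresh one is used
    """
    arg, _node = edge
    if edge not in obs_or_fq_map or obs_or_fq_map[edge] is None:
        return None  # case 1: no quantization requested for this input edge
    edge_obs = obs_or_fq_map[edge]
    out_obs = obs_or_fq_map.get(arg)
    if out_obs is not None and out_obs is edge_obs:
        return out_obs  # case 2: output observer is shared with this edge
    for existing in existing_input_observers.get(arg, []):
        if type(existing) is type(edge_obs) and existing.dtype == edge_obs.dtype:
            return existing  # case 3: dedup against an equivalent existing observer
    return edge_obs  # fall through: a new observer would be inserted for this edge


# Usage example: two consumers of the same producer share one observer instance.
shared = FakeObserver("int8")
obs_map = {
    ("conv1_out", "conv2"): shared,
    ("conv1_out", "conv3"): shared,
    "conv1_out": shared,
}
print(decide_observer_for_edge(("conv1_out", "conv2"), obs_map, {}) is shared)  # True
```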