From 6a636f65cb9815765c9047f536ff98cdd06e31a0 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 12 Apr 2023 18:07:35 -0700
Subject: [PATCH 1/2] [quant][pt2e][improvement] Remove the need to annotate all nodes with default annotation

Summary:
This PR changes prepare to use a default observer/fq constructor when
"target_dtype_info" is not set, so users no longer have to initialize every
node with the default observer/fq constructor.

Note that we may still need to annotate intermediate nodes after this PR;
a follow-up PR will allow users to annotate only the things they want to
quantize.

Test Plan:
python test/test_quantization.py TestQuantizePT2E
python test/test_quantization.py TestQuantizePT2EModels

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 test/quantization/fx/test_quantize_pt2e.py                | 2 --
 .../ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py  | 8 --------
 torch/ao/quantization/fx/prepare.py                       | 9 +++++++--
 3 files changed, 7 insertions(+), 12 deletions(-)

diff --git a/test/quantization/fx/test_quantize_pt2e.py b/test/quantization/fx/test_quantize_pt2e.py
index 7456969a7fbb..bf350c388692 100644
--- a/test/quantization/fx/test_quantize_pt2e.py
+++ b/test/quantization/fx/test_quantize_pt2e.py
@@ -265,10 +265,8 @@ def forward(self, x):
         )
         m = prepare_pt2e_quantizer(m, quantizer)
-        print("after prepare:", m)
         m(*example_inputs)
         m = convert_pt2e(m)
-        print("m:", m)
         node_occurrence = {
             # input and output are using quantize_per_tensor and weight is using quantize_per_channel
             ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 5,
diff --git a/torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py b/torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py
index 7b51833910cd..2f5352ba9cdc 100644
--- a/torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py
+++ b/torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py
@@ -258,14 +258,6 @@ def set_spec_for_operator_type(
 
     def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
         """ just handling global spec for now """
-        # initialize default target_dtype_info
-        _DEFAULT_TARGET_DTYPE_INFO = {
-            "input_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
-            "output_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
-        }
-        for node in model.graph.nodes:
-            node.meta["target_dtype_info"] = copy.deepcopy(_DEFAULT_TARGET_DTYPE_INFO)
-
         global_spec = self.operator_spec_config.global_spec
         ops = self.get_supported_operator_for_operator_spec(global_spec)
         # annotate the nodes from last to first since the matching is in the reversed order
diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py
index 64c3972c3820..a4c2273e9897 100644
--- a/torch/ao/quantization/fx/prepare.py
+++ b/torch/ao/quantization/fx/prepare.py
@@ -115,6 +115,8 @@
 # list of dtypes to not add observers to
 _DO_NOT_OBS_DTYPE_LIST = [int, float, torch.bool, None]
 
+_DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float)
+
 # note: the following default target dtype info dicts are temporary,
 # should be moved to the new programmable API class soon
 _DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO = {
@@ -497,8 +499,11 @@ def _get_arg_target_dtype_as_output(
         assert isinstance(observed_arg, Node), "Currently we only support observing Node"
         output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
     else:
-        output_act_obs_or_fq_ctr = \
-            arg.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
+        if "target_dtype_info" in arg.meta:
+            output_act_obs_or_fq_ctr = \
+                arg.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
+        else:
+            output_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR
     output_act_dtype, _ = _get_dtype_and_is_dynamic(output_act_obs_or_fq_ctr)
     # TODO: should support is_dynamic here as well
     return output_act_dtype

From dee5289531fc9edc696da1ef0f5134cf3bfe81d5 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 12 Apr 2023 18:14:08 -0700
Subject: [PATCH 2/2] Update on "[quant][pt2e][improvement] Remove the need to annotate all nodes with default annotation"

Summary:
This PR changes prepare to use a default observer/fq constructor when
"target_dtype_info" is not set, so users no longer have to initialize every
node with the default observer/fq constructor.

Note that we may still need to annotate intermediate nodes after this PR;
a follow-up PR will allow users to annotate only the things they want to
quantize.

Test Plan:
python test/test_quantization.py TestQuantizePT2E
python test/test_quantization.py TestQuantizePT2EModels

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 torch/ao/quantization/fx/prepare.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py
index a4c2273e9897..d57f315a0a84 100644
--- a/torch/ao/quantization/fx/prepare.py
+++ b/torch/ao/quantization/fx/prepare.py
@@ -501,7 +501,7 @@ def _get_arg_target_dtype_as_output(
     else:
         if "target_dtype_info" in arg.meta:
             output_act_obs_or_fq_ctr = \
-                arg.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
+                arg.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
         else:
             output_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR
     output_act_dtype, _ = _get_dtype_and_is_dynamic(output_act_obs_or_fq_ctr)
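
For illustration, below is a minimal, self-contained sketch (not part of the patches above) of the fallback behavior these changes introduce: when a node's meta carries no "target_dtype_info" annotation, the lookup falls back to a fp32 PlaceholderObserver constructor instead of requiring every node to be pre-annotated. The helper name example_fallback_lookup, the constant DEFAULT_FP32_OBS_OR_FQ_CTR, and the toy module M are illustrative only, not the patched prepare.py code.

import torch
from torch.ao.quantization.observer import PlaceholderObserver

# Stand-in for the _DEFAULT_FP32_OBS_OR_FQ_CTR constant added to prepare.py.
DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float)

def example_fallback_lookup(node: torch.fx.Node):
    # Return the output observer/fake-quant constructor annotated on `node`,
    # falling back to the fp32 default when the node was never annotated.
    target_dtype_info = node.meta.get("target_dtype_info", {})
    return target_dtype_info.get("output_act_obs_or_fq_ctr", DEFAULT_FP32_OBS_OR_FQ_CTR)

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

gm = torch.fx.symbolic_trace(M())
for n in gm.graph.nodes:
    # No quantizer has annotated these nodes, so every lookup uses the fp32 default.
    print(n.name, example_fallback_lookup(n))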