From 04d02c7ace7c510587d4c2313d4a62a84ac9fce7 Mon Sep 17 00:00:00 2001
From: Chen Lai
Date: Tue, 9 Apr 2024 17:02:05 -0700
Subject: [PATCH] Skip annotate boolean input

Differential Revision: [D55946526](https://our.internmc.facebook.com/intern/diff/D55946526/)

[ghstack-poisoned]
---
 backends/qualcomm/quantizer/utils.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py
index 809b7298eba..e2cf9b3b106 100644
--- a/backends/qualcomm/quantizer/utils.py
+++ b/backends/qualcomm/quantizer/utils.py
@@ -9,6 +9,7 @@
 
 import torch
 from torch._ops import OpOverload
+from torch._subclasses import FakeTensor
 
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
@@ -41,6 +42,14 @@ def decorator(annotator: Callable):
     return decorator
 
 
+def _is_input_non_float_tensor(node: Node):
+    """Check if the input is not a float tensor, so that we can skip quantization for the node
+    since observers only work with float Tensors
+    """
+    if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
+        return True
+    return node.meta["val"].dtype != torch.float32
+
 
 def _is_annotated(nodes: List[Node]):
     """
@@ -123,11 +132,11 @@ def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None
     input_qspec_map = {}
 
     input_act0 = node.args[0]
-    if isinstance(input_act0, Node):
+    if isinstance(input_act0, Node) and not _is_input_non_float_tensor(input_act0):
         input_qspec_map[input_act0] = input_act_qspec
 
     input_act1 = node.args[1]
-    if isinstance(input_act1, Node):
+    if isinstance(input_act1, Node) and not _is_input_non_float_tensor(input_act1):
         input_qspec_map[input_act1] = input_act_qspec
 
     node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(