diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index d584cd128ec..2649ed5b154 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -68,7 +68,7 @@ def _is_float_tensor(node: Node):
         or not isinstance(node.meta["val"], FakeTensor)
     ):
         return False
-    return node.meta["val"].dtype == torch.float32
+    return node.meta["val"].dtype in (torch.bfloat16, torch.float32)
 
 
 def _mark_nodes_as_annotated(nodes: List[Node]):
diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py
index 2f26cd27d31..30af923781a 100644
--- a/backends/qualcomm/quantizer/qconfig.py
+++ b/backends/qualcomm/quantizer/qconfig.py
@@ -205,7 +205,6 @@ def get_16a8w_qnn_qat_config(
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
         qscheme=torch.per_tensor_affine,
-        reduce_range=True,
         observer=act_observer.with_args(**extra_args),
     )
     act_quantization_spec = QuantizationSpec(
@@ -220,7 +219,6 @@ def get_16a8w_qnn_qat_config(
         quant_min=torch.iinfo(torch.int8).min + 1,
         quant_max=torch.iinfo(torch.int8).max,
         qscheme=torch.per_tensor_symmetric,
-        reduce_range=True,
         observer=MovingAverageMinMaxObserver,
     )
     weight_quantization_spec = QuantizationSpec(
@@ -421,7 +419,6 @@ def get_8a8w_qnn_qat_config(
         quant_min=torch.iinfo(torch.int8).min + 1,
         quant_max=torch.iinfo(torch.int8).max,
         qscheme=torch.per_tensor_symmetric,
-        reduce_range=True,
         observer=MovingAverageMinMaxObserver,
     )
     weight_quantization_spec = QuantizationSpec(
@@ -438,7 +435,6 @@ def get_8a8w_qnn_qat_config(
         quant_min=torch.iinfo(torch.int32).min,
         quant_max=torch.iinfo(torch.int32).max,
         qscheme=torch.per_tensor_symmetric,
-        reduce_range=True,
         observer=MovingAverageMinMaxObserver,
     )
     bias_quantization_spec = QuantizationSpec(
@@ -467,7 +463,6 @@ def get_16a4w_qnn_qat_config(
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
         qscheme=torch.per_tensor_affine,
-        reduce_range=True,
         observer=act_observer,
     )
     act_quantization_spec = QuantizationSpec(
@@ -484,7 +479,6 @@ def get_16a4w_qnn_qat_config(
         quant_max=7,
         qscheme=torch.per_tensor_symmetric,
         ch_axis=0,
-        reduce_range=True,
         observer=MovingAverageMinMaxObserver,
     )
     weight_quantization_spec = QuantizationSpec(
@@ -501,7 +495,6 @@ def get_16a4w_qnn_qat_config(
         quant_min=torch.iinfo(torch.int32).min,
         quant_max=torch.iinfo(torch.int32).max,
         qscheme=torch.per_tensor_symmetric,
-        reduce_range=True,
         observer=MovingAverageMinMaxObserver,
     )
     bias_quantization_spec = QuantizationSpec(
@@ -551,7 +544,6 @@ def get_qat_per_channel_quant_config(
         act_fake_quant_ctr = FakeQuantize.with_args(
             dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
             qscheme=torch.per_tensor_symmetric,
-            reduce_range=True,
             observer=act_observer,
         )
         act_quantization_spec = QuantizationSpec(
@@ -566,7 +558,6 @@ def get_qat_per_channel_quant_config(
             quant_min=torch.iinfo(act_dtype).min,
             quant_max=torch.iinfo(act_dtype).max,
             qscheme=torch.per_tensor_affine,
-            reduce_range=True,
             observer=act_observer,
         )
         act_quantization_spec = QuantizationSpec(
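The annotators.py hunk widens _is_float_tensor so that bfloat16 tensors are treated as float tensors for annotation, not just float32. A minimal sketch of the patched check in isolation, assuming only the dtype test matters here; the surrounding guards in annotators.py are elided, and the dict-based stand-in for node.meta is illustrative only (the real code operates on torch.fx.Node):

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

# Standalone copy of the patched predicate's core; the real code reads
# these values from torch.fx.Node.meta.
def is_float_tensor(meta: dict) -> bool:
    if "val" not in meta or not isinstance(meta["val"], FakeTensor):
        return False
    return meta["val"].dtype in (torch.bfloat16, torch.float32)

# Factory calls under FakeTensorMode produce FakeTensors, matching what
# export places in node.meta["val"].
with FakeTensorMode():
    bf16_val = torch.empty(4, dtype=torch.bfloat16)
    int_val = torch.empty(4, dtype=torch.int32)

print(is_float_tensor({"val": bf16_val}))  # True: newly accepted by this patch
print(is_float_tensor({"val": int_val}))   # False: unchanged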
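Every qconfig.py hunk removes reduce_range=True from a QAT FakeQuantize.with_args(...) call. In stock PyTorch observers, reduce_range shaves one bit off the quantization range, so leaving it set would presumably waste half of the grid these configs declare via quant_min/quant_max. A quick way to see the documented effect in isolation, using a plain MinMaxObserver rather than the QAT path itself:

import torch
from torch.ao.quantization.observer import MinMaxObserver

x = torch.linspace(-1.0, 1.0, steps=101)

# Full-range int8 symmetric observer: qparams span [-128, 127].
full = MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
full(x)  # observers record min/max on forward
scale_full, _ = full.calculate_qparams()

# reduce_range=True drops one bit, mapping the same data onto [-64, 63],
# so the scale roughly doubles and resolution is halved.
reduced = MinMaxObserver(
    dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=True
)
reduced(x)
scale_reduced, _ = reduced.calculate_qparams()

print(scale_full.item(), scale_reduced.item())  # second scale ≈ 2x the first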