From 291121cb04c597e1ba80244c9e9af94cd5809b1f Mon Sep 17 00:00:00 2001
From: Naveen Suda
Date: Wed, 24 Sep 2025 15:04:27 -0700
Subject: [PATCH 1/2] Enable quantization for bf16 model (#14558)

Summary:
To save GPU memory, the `bfloat16` dtype is commonly used for training LLMs. Currently, the quantizer skips nodes whose dtype is not float32. This change enables quantization of bf16 nodes as well.

Differential Revision: D82866443
---
 backends/qualcomm/quantizer/annotators.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index d584cd128ec..2649ed5b154 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -68,7 +68,7 @@ def _is_float_tensor(node: Node):
         or not isinstance(node.meta["val"], FakeTensor)
     ):
         return False
-    return node.meta["val"].dtype == torch.float32
+    return node.meta["val"].dtype in (torch.bfloat16, torch.float32)
 
 
 def _mark_nodes_as_annotated(nodes: List[Node]):
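Note (not part of the patch): a minimal standalone sketch of the dtype gate the hunk above changes, assuming only that PyTorch is installed. The helper name accepts_quantizable_dtype and the demo loop are hypothetical; only the (torch.bfloat16, torch.float32) membership test mirrors the updated _is_float_tensor check.

import torch

def accepts_quantizable_dtype(dtype: torch.dtype) -> bool:
    # Before the patch only float32 passed this check, so bf16 graphs were
    # left unquantized; after the patch bf16 is accepted as well.
    return dtype in (torch.bfloat16, torch.float32)

if __name__ == "__main__":
    for dt in (torch.float32, torch.bfloat16, torch.float16, torch.int8):
        print(dt, accepts_quantizable_dtype(dt))
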
From 07e6a7317d837e7e51bcc0e32dae53019d1686db Mon Sep 17 00:00:00 2001
From: Naveen Suda
Date: Wed, 24 Sep 2025 15:04:27 -0700
Subject: [PATCH 2/2] Remove reduce_range as it is not relevant for HTP

Summary:
`reduce_range=True` reduces the available bit width by 1 in cases where quant_min and quant_max are not provided. It was originally intended for Intel `fbgemm` kernels, but I don't think this quantization setting is relevant for HTP. The PTQ quantization config also doesn't use it, so remove it from all the QAT configs. This helped improve the QAT model quality.

Differential Revision: D82867843
---
 backends/qualcomm/quantizer/qconfig.py | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py
index 2f26cd27d31..30af923781a 100644
--- a/backends/qualcomm/quantizer/qconfig.py
+++ b/backends/qualcomm/quantizer/qconfig.py
@@ -205,7 +205,6 @@ def get_16a8w_qnn_qat_config(
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
         qscheme=torch.per_tensor_affine,
-        reduce_range=True,
         observer=act_observer.with_args(**extra_args),
     )
     act_quantization_spec = QuantizationSpec(
@@ -220,7 +219,6 @@ def get_16a8w_qnn_qat_config(
         quant_min=torch.iinfo(torch.int8).min + 1,
         quant_max=torch.iinfo(torch.int8).max,
         qscheme=torch.per_tensor_symmetric,
-        reduce_range=True,
         observer=MovingAverageMinMaxObserver,
     )
     weight_quantization_spec = QuantizationSpec(
@@ -421,7 +419,6 @@ def get_8a8w_qnn_qat_config(
         quant_min=torch.iinfo(torch.int8).min + 1,
         quant_max=torch.iinfo(torch.int8).max,
         qscheme=torch.per_tensor_symmetric,
-        reduce_range=True,
         observer=MovingAverageMinMaxObserver,
     )
     weight_quantization_spec = QuantizationSpec(
@@ -438,7 +435,6 @@ def get_8a8w_qnn_qat_config(
         quant_min=torch.iinfo(torch.int32).min,
         quant_max=torch.iinfo(torch.int32).max,
         qscheme=torch.per_tensor_symmetric,
-        reduce_range=True,
         observer=MovingAverageMinMaxObserver,
     )
     bias_quantization_spec = QuantizationSpec(
@@ -467,7 +463,6 @@ def get_16a4w_qnn_qat_config(
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
         qscheme=torch.per_tensor_affine,
-        reduce_range=True,
         observer=act_observer,
     )
     act_quantization_spec = QuantizationSpec(
@@ -484,7 +479,6 @@ def get_16a4w_qnn_qat_config(
         quant_max=7,
         qscheme=torch.per_tensor_symmetric,
         ch_axis=0,
-        reduce_range=True,
         observer=MovingAverageMinMaxObserver,
     )
     weight_quantization_spec = QuantizationSpec(
@@ -501,7 +495,6 @@ def get_16a4w_qnn_qat_config(
         quant_min=torch.iinfo(torch.int32).min,
         quant_max=torch.iinfo(torch.int32).max,
         qscheme=torch.per_tensor_symmetric,
-        reduce_range=True,
         observer=MovingAverageMinMaxObserver,
     )
     bias_quantization_spec = QuantizationSpec(
@@ -551,7 +544,6 @@ def get_qat_per_channel_quant_config(
     act_fake_quant_ctr = FakeQuantize.with_args(
         dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
         qscheme=torch.per_tensor_symmetric,
-        reduce_range=True,
         observer=act_observer,
     )
     act_quantization_spec = QuantizationSpec(
@@ -566,7 +558,6 @@ def get_qat_per_channel_quant_config(
         quant_min=torch.iinfo(act_dtype).min,
         quant_max=torch.iinfo(act_dtype).max,
         qscheme=torch.per_tensor_affine,
-        reduce_range=True,
         observer=act_observer,
     )
     act_quantization_spec = QuantizationSpec(
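Note (not part of the patch): a small standalone sketch, assuming a recent PyTorch with torch.ao.quantization, of the behavior this patch removes. When no explicit quant_min/quant_max are supplied, reduce_range=True halves the observer's quantization range (e.g. 0..127 instead of 0..255 for quint8), which is the one-bit loss described in the summary.

import torch
from torch.ao.quantization.observer import MinMaxObserver

# Compare the observer's effective quantization range with and without
# reduce_range when no explicit quant_min/quant_max are supplied.
x = torch.randn(64)
for reduce_range in (False, True):
    obs = MinMaxObserver(
        dtype=torch.quint8,
        qscheme=torch.per_tensor_affine,
        reduce_range=reduce_range,
    )
    obs(x)  # observe some activations so qparams can be computed
    scale, zero_point = obs.calculate_qparams()
    print(
        f"reduce_range={reduce_range}: quant range [{obs.quant_min}, {obs.quant_max}], "
        f"scale={scale.item():.6f}"
    )

With reduce_range=True the computed scale roughly doubles for the same data, i.e. the quantization grid is coarser; dropping the flag lets the QAT configs use the full bit width, matching the PTQ configs.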