From d95e62f06b0b8b51ea94dfbdd3da9fdcf6ef355e Mon Sep 17 00:00:00 2001
From: Naveen Suda
Date: Thu, 2 Oct 2025 18:38:33 -0700
Subject: [PATCH] Use FusedMovingAvgObsFakeQuantize instead of FakeQuantize for
 faster QAT (#14740)

Summary:
FusedMovingAvgObsFakeQuantize speeds up QAT by fusing FakeQuantize and
MovingAverageMinMaxObserver into a single CUDA op, so using it should give a
good speedup. This change updates the QAT qconfigs accordingly.

Tested on a llama model on HTP and got a ~4x QAT speedup.

Reviewed By: billmguo

Differential Revision: D83583655
---
 backends/qualcomm/quantizer/qconfig.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py
index 30af923781a..694fab3dc6b 100644
--- a/backends/qualcomm/quantizer/qconfig.py
+++ b/backends/qualcomm/quantizer/qconfig.py
@@ -200,7 +200,7 @@ def get_16a8w_qnn_qat_config(
     act_observer=MovingAverageMinMaxObserver,
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-20}
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.int32,
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
@@ -398,7 +398,7 @@ def get_ptq_per_block_quant_config(
 def get_8a8w_qnn_qat_config(
     act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
 ) -> QuantizationConfig:
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.uint8,
         qscheme=(
             torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
@@ -458,7 +458,7 @@ def get_8a8w_qnn_qat_config(
 def get_16a4w_qnn_qat_config(
     act_observer=MovingAverageMinMaxObserver,
 ) -> QuantizationConfig:
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.int32,
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
@@ -541,7 +541,7 @@ def get_qat_per_channel_quant_config(
         # If zero_point is 128, htp can do optimizations.
         # If we keep quant_min and quant_max none, observer will default use 128 as zero_point.
         # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired.
-        act_fake_quant_ctr = FakeQuantize.with_args(
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
             dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
             qscheme=torch.per_tensor_symmetric,
             observer=act_observer,
@@ -553,7 +553,7 @@
             observer_or_fake_quant_ctr=act_fake_quant_ctr,
         )
     else:
-        act_fake_quant_ctr = FakeQuantize.with_args(
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
             dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
             quant_min=torch.iinfo(act_dtype).min,
             quant_max=torch.iinfo(act_dtype).max,
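
The swap is a drop-in change because FusedMovingAvgObsFakeQuantize is a
subclass of FakeQuantize and accepts the same constructor arguments; only the
forward path changes, from separate observer-update and fake-quantize ops to
the single fused torch.fused_moving_avg_obs_fake_quant kernel. Below is a
minimal sketch of the idea, assuming the stock torch.ao.quantization import
path and 8-bit per-tensor activation settings modeled loosely on
get_8a8w_qnn_qat_config; the names and values are illustrative, not taken from
the backend code.

    import torch
    from torch.ao.quantization import (
        FakeQuantize,
        FusedMovingAvgObsFakeQuantize,
        MovingAverageMinMaxObserver,
    )

    # Shared settings for both constructors. The backend configs use
    # torch.uint8/int32; torch.quint8 is used here because it works with the
    # stock eager-mode MovingAverageMinMaxObserver.
    common_args = dict(
        observer=MovingAverageMinMaxObserver,
        dtype=torch.quint8,
        quant_min=0,
        quant_max=255,
        qscheme=torch.per_tensor_affine,
    )

    # Unfused: observer update and fake-quantization run as separate ops.
    unfused_ctr = FakeQuantize.with_args(**common_args)
    # Fused: both steps execute inside one fused kernel, which is where the
    # QAT speedup comes from.
    fused_ctr = FusedMovingAvgObsFakeQuantize.with_args(**common_args)

    unfused_fq, fused_fq = unfused_ctr(), fused_ctr()
    unfused_fq.train()
    fused_fq.train()

    x = torch.randn(8, 16)
    # Both paths should produce essentially the same fake-quantized output.
    print(torch.allclose(unfused_fq(x), fused_fq(x), atol=1e-6))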