diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py
index 30af923781a..694fab3dc6b 100644
--- a/backends/qualcomm/quantizer/qconfig.py
+++ b/backends/qualcomm/quantizer/qconfig.py
@@ -200,7 +200,7 @@ def get_16a8w_qnn_qat_config(
     act_observer=MovingAverageMinMaxObserver,
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-20}
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.int32,
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
@@ -398,7 +398,7 @@ def get_ptq_per_block_quant_config(
 def get_8a8w_qnn_qat_config(
     act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
 ) -> QuantizationConfig:
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.uint8,
         qscheme=(
             torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
@@ -458,7 +458,7 @@ def get_8a8w_qnn_qat_config(
 def get_16a4w_qnn_qat_config(
     act_observer=MovingAverageMinMaxObserver,
 ) -> QuantizationConfig:
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.int32,
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
@@ -541,7 +541,7 @@ def get_qat_per_channel_quant_config(
         # If zero_point is 128, htp can do optimizations.
         # If we keep quant_min and quant_max none, observer will default use 128 as zero_point.
         # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired.
-        act_fake_quant_ctr = FakeQuantize.with_args(
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
             dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
             qscheme=torch.per_tensor_symmetric,
             observer=act_observer,
@@ -553,7 +553,7 @@
             observer_or_fake_quant_ctr=act_fake_quant_ctr,
         )
     else:
-        act_fake_quant_ctr = FakeQuantize.with_args(
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
             dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
             quant_min=torch.iinfo(act_dtype).min,
             quant_max=torch.iinfo(act_dtype).max,
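
For context on the swap, here is a minimal sketch of the two activation fake-quant constructors side by side, assuming the stock `torch.ao.quantization` APIs; the `torch.uint8` min/max values mirror `get_8a8w_qnn_qat_config` above and are illustrative only:

```python
import torch
from torch.ao.quantization import FakeQuantize, FusedMovingAvgObsFakeQuantize
from torch.ao.quantization.observer import MovingAverageMinMaxObserver

# Unfused: the observer update and the fake-quantize op run as separate
# graph nodes on every forward pass during QAT.
unfused_ctr = FakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    dtype=torch.uint8,
    quant_min=torch.iinfo(torch.uint8).min,
    quant_max=torch.iinfo(torch.uint8).max,
)

# Fused: a single op updates the moving-average min/max statistics and
# fake-quantizes the tensor in one kernel.
fused_ctr = FusedMovingAvgObsFakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    dtype=torch.uint8,
    quant_min=torch.iinfo(torch.uint8).min,
    quant_max=torch.iinfo(torch.uint8).max,
)

x = torch.randn(4, 8)
fq = fused_ctr()  # instantiate the fake-quant module from the partial ctor
y = fq(x)         # observes statistics and fake-quantizes in one fused op
```

`FusedMovingAvgObsFakeQuantize` subclasses `FakeQuantize` and lowers the forward pass to the single `torch.fused_moving_avg_obs_fake_quant` op, but it only accepts `MovingAverageMinMaxObserver` or `MovingAveragePerChannelMinMaxObserver`. Since every function touched here already defaults to `act_observer=MovingAverageMinMaxObserver`, the swap appears to be a drop-in change for these QAT configs.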