
Commit d95e62f

navsud authored and facebook-github-bot committed
Use FusedMovingAvgObsFakeQuantize instead of FakeQuantize for faster QAT (#14740)
Summary: FusedMovingAvgObsFakeQuantize speeds up QAT by fusing FakeQuantize and MovingAverageMinMaxObserver into a single CUDA op, so switching to it should give good speedups. This change updates the QAT qconfigs accordingly. Tested on a llama model on HTP and observed a ~4x QAT speedup.

Reviewed By: billmguo

Differential Revision: D83583655
1 parent c997fe4 commit d95e62f
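
A minimal sketch of the pattern this change switches to (illustrative only; the dtype, range, and qscheme values below are placeholders, not the exact settings in backends/qualcomm/quantizer/qconfig.py). The activation fake-quant constructor is built through the same `with_args` partial-constructor API, just pointed at the fused module:

```python
# Illustrative sketch, not the repo's actual qconfig code: the fused module
# replaces the FakeQuantize + MovingAverageMinMaxObserver pair with one op.
import torch
from torch.ao.quantization import (
    FusedMovingAvgObsFakeQuantize,
    MovingAverageMinMaxObserver,
)

# Before: act_fake_quant_ctr = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, ...)
# After:  same with_args pattern, but observation and fake-quantization are fused.
act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    dtype=torch.quint8,  # placeholder dtype/range, not the qconfig's values
    quant_min=0,
    quant_max=255,
    qscheme=torch.per_tensor_affine,
)

# The quantizer instantiates the partial later, when it attaches fake-quant
# modules to activations for QAT.
act_fake_quant = act_fake_quant_ctr()
```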

File tree

1 file changed: +5, -5 lines


backends/qualcomm/quantizer/qconfig.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -200,7 +200,7 @@ def get_16a8w_qnn_qat_config(
     act_observer=MovingAverageMinMaxObserver,
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-20}
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.int32,
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
@@ -398,7 +398,7 @@ def get_ptq_per_block_quant_config(
 def get_8a8w_qnn_qat_config(
     act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
 ) -> QuantizationConfig:
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.uint8,
         qscheme=(
             torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
@@ -458,7 +458,7 @@ def get_8a8w_qnn_qat_config(
 def get_16a4w_qnn_qat_config(
     act_observer=MovingAverageMinMaxObserver,
 ) -> QuantizationConfig:
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.int32,
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
@@ -541,7 +541,7 @@ def get_qat_per_channel_quant_config(
         # If zero_point is 128, htp can do optimizations.
         # If we keep quant_min and quant_max none, observer will default use 128 as zero_point.
         # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired.
-        act_fake_quant_ctr = FakeQuantize.with_args(
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
             dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
             qscheme=torch.per_tensor_symmetric,
             observer=act_observer,
@@ -553,7 +553,7 @@ def get_qat_per_channel_quant_config(
             observer_or_fake_quant_ctr=act_fake_quant_ctr,
         )
     else:
-        act_fake_quant_ctr = FakeQuantize.with_args(
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
             dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
             quant_min=torch.iinfo(act_dtype).min,
             quant_max=torch.iinfo(act_dtype).max,
```
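
For context, a standalone sketch of how the fused module behaves during QAT (generic torch.ao.quantization usage assumed here, not ExecuTorch's Qualcomm quantizer flow or the HTP/llama setup from the summary): a single forward call both updates the moving-average min/max statistics and fake-quantizes the tensor, which is where the speedup over the separate observer + FakeQuantize pair comes from.

```python
# Standalone illustration of the fused observe + fake-quantize step; generic
# PyTorch QAT usage, not the ExecuTorch/HTP pipeline.
import torch
from torch.ao.quantization import (
    FusedMovingAvgObsFakeQuantize,
    MovingAverageMinMaxObserver,
)

fq = FusedMovingAvgObsFakeQuantize(
    observer=MovingAverageMinMaxObserver,
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
)

x = torch.randn(4, 8, requires_grad=True)
y = fq(x)            # one fused op: update running min/max and fake-quantize
y.sum().backward()   # gradients flow through the straight-through estimator
scale, zero_point = fq.calculate_qparams()
print(scale, zero_point)
```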
