From f572b923cb572f98f72bdac672c93f7a49d40445 Mon Sep 17 00:00:00 2001
From: Rohan Joshi
Date: Tue, 2 Sep 2025 16:40:55 -0700
Subject: [PATCH] Enforce range setting only used with per-channel

Differential Revision: D81537565
---
 .../qualcomm/oss_scripts/llama/range_setting_pt2e.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py b/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py
index 4ef3e8cfe94..d5c68606f60 100644
--- a/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py
+++ b/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py
@@ -210,12 +210,13 @@ def make_custom_quantizer(
         per_channel_linear=True,
         act_observer=MinMaxObserver,
     )
-    if range_setting in ("mse_weight_only", "mse_with_act_loss", "na"):
-        if range_setting == "na":
-            observer = PerChannelMinMaxObserver
-        elif range_setting == "mse_weight_only":
+    if range_setting in ("mse_weight_only", "mse_with_act_loss"):
+        assert (
+            quant_dtype != QuantDtype.use_16a4w_block
+        ), "Range setting only supported for per-channel quantization"
+        if range_setting == "mse_weight_only":
             observer = PerChannelMSEObserver.with_args(
-                **{"steps": 200, "use_mse": True}
+                **{"steps": 1600, "use_mse": True}
             )
         else:
             observer = PerChannelFixedQParamsObserver.with_args(**{"eps": 2**-12})