diff --git a/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py b/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py index 4ef3e8cfe94..d5c68606f60 100644 --- a/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py +++ b/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py @@ -210,12 +210,13 @@ def make_custom_quantizer( per_channel_linear=True, act_observer=MinMaxObserver, ) - if range_setting in ("mse_weight_only", "mse_with_act_loss", "na"): - if range_setting == "na": - observer = PerChannelMinMaxObserver - elif range_setting == "mse_weight_only": + if range_setting in ("mse_weight_only", "mse_with_act_loss"): + assert ( + quant_dtype != QuantDtype.use_16a4w_block + ), "Range setting only supported for per-channel quantization" + if range_setting == "mse_weight_only": observer = PerChannelMSEObserver.with_args( - **{"steps": 200, "use_mse": True} + **{"steps": 1600, "use_mse": True} ) else: observer = PerChannelFixedQParamsObserver.with_args(**{"eps": 2**-12})