From ea4eeef920810d5ca2304490382c7704b801c656 Mon Sep 17 00:00:00 2001
From: Scott Roy <161522778+metascroy@users.noreply.github.com>
Date: Thu, 9 Oct 2025 17:36:44 -0700
Subject: [PATCH 1/3] Make HQQ default PTQ quantization in ExecuTorch

Differential Revision: D84020605

Pull Request resolved: https://github.com/pytorch/executorch/pull/14834

(cherry picked from commit d39992f6d971e3548ee3ffe943d9224f63979126)
---
 examples/models/llama/export_llama_lib.py     | 22 +++++++++----
 .../llama/source_transformation/quantize.py   | 33 ++++++++++++++-----
 2 files changed, 39 insertions(+), 16 deletions(-)

diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 7192204a141..2307057f5ba 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -1221,12 +1221,15 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")
 
-    model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
-        EagerModelFactory.create_model(
-            module_name,
-            model_class_name,
-            llm_config=llm_config,
-        )
+    (
+        model,
+        example_inputs,
+        example_kwarg_inputs,
+        dynamic_shapes,
+    ) = EagerModelFactory.create_model(
+        module_name,
+        model_class_name,
+        llm_config=llm_config,
     )
     # Convert dtype override string to actual type.
     dtype_override = DType[llm_config.model.dtype_override.value]
@@ -1303,6 +1306,7 @@ def _get_source_transforms(  # noqa
     preq_group_size: Optional[int] = None,
     preq_embedding_quantize: Optional[str] = None,
     local_global_attention: Optional[List[int]] = None,
+    quantize_with_hqq: bool = True,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.
@@ -1372,7 +1376,10 @@ def _get_source_transforms(  # noqa
         """
         transforms.append(
             get_quant_embedding_transform(
-                embedding_quantize, use_shared_embedding, checkpoint_dtype
+                embedding_quantize,
+                use_shared_embedding,
+                checkpoint_dtype,
+                quantize_with_hqq,
             )
         )
 
@@ -1403,6 +1410,7 @@ def _get_source_transforms(  # noqa
                 calibration_tasks=calibration_tasks,
                 calibration_limit=calibration_limit,
                 calibration_seq_length=calibration_seq_length,
+                quantize_with_hqq=quantize_with_hqq,
             )
         )
 
diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index 7cb65833f98..9e49f9e4e15 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -49,6 +49,7 @@ def quantize(  # noqa C901
     blocksize: int = 128,
     tokenizer_path: Optional[Path] = None,
     verbose: bool = False,
+    quantize_with_hqq: bool = True,
 ) -> torch.nn.Module:
     """
     Quantizes a model by converting all weights to int8.
@@ -119,7 +120,6 @@ def quantize(  # noqa C901
         from torchao.quantization.granularity import PerAxis, PerGroup
         from torchao.quantization.quant_api import (
             Int8DynamicActivationIntxWeightConfig,
-            MappingType,
             quantize_,
         )
         from torchao.utils import unwrap_tensor_subclass
@@ -134,9 +134,12 @@ def quantize(  # noqa C901
                 weight_granularity=(
                     PerAxis(0) if group_size == 0 else PerGroup(group_size)
                 ),
-                weight_mapping_type=MappingType.SYMMETRIC,
                 # pyre-ignore[6]
                 intx_packing_format="opaque_torchao_auto",
+                # pyre-ignore[6]
+                intx_choose_qparams_algorithm=(
+                    "hqq_scale_only" if quantize_with_hqq else "affine"
+                ),
             ),
         )
         model = unwrap_tensor_subclass(model)
@@ -170,6 +173,10 @@ def filter_fn(m, fqn):
                 # pyre-ignore[16]
                 weight_dtype=torch.int4,
                 weight_granularity=PerGroup(group_size),
+                # pyre-ignore[6]
+                intx_choose_qparams_algorithm=(
+                    "hqq_scale_only" if quantize_with_hqq else "affine"
+                ),
             ),
             filter_fn=filter_fn,
         )
@@ -191,6 +198,10 @@ def filter_fn(m, fqn):
             # pyre-ignore[16]
             weight_dtype=torch.int4,
             granularity=PerGroup(q_group_size),
+            # pyre-ignore[6]
+            intx_choose_qparams_algorithm=(
+                "hqq_scale_only" if quantize_with_hqq else "affine"
+            ),
         )
         quantize_(model, q_config)
         model = unwrap_tensor_subclass(model)
@@ -580,6 +591,7 @@ def __init__(
         group_size: Optional[int] = None,
         packed=False,
         precision: Optional[torch.dtype] = None,
+        quantize_with_hqq: bool = True,
     ):
         if isinstance(packed, str):
             packed = packed == "True"
@@ -592,15 +604,12 @@ def __init__(
         self.precision = precision
         if (bitwidth not in [2, 4]) and packed:
             raise RuntimeError("pack only works with bitsize 2, 4")
+        self.quantize_with_hqq = quantize_with_hqq
 
     @torch.no_grad()
     def create_quantized_state_dict(self, packed=False) -> Dict:
         from torchao.quantization.granularity import PerAxis, PerGroup
-        from torchao.quantization.quant_api import (
-            IntxWeightOnlyConfig,
-            MappingType,
-            quantize_,
-        )
+        from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
 
         cur_state_dict = self.mod.state_dict()
 
@@ -627,7 +636,10 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
                         if (self.group_size is None or self.group_size == 0)
                         else PerGroup(self.group_size)
                     ),
-                    mapping_type=MappingType.SYMMETRIC,
+                    # pyre-ignore[6]
+                    intx_choose_qparams_algorithm=(
+                        "hqq_scale_only" if self.quantize_with_hqq else "affine"
+                    ),
                 )
                 quantize_(tmp_model, config, lambda m, fqn: isinstance(m, nn.Embedding))
                 weight = tmp_model.weight.qdata  # pyre-ignore[16]
@@ -765,6 +777,7 @@ def get_quant_embedding_transform(
     embedding_quantize: str,
     use_shared_embedding: bool = False,
     dtype_override: Optional[DType] = None,
+    quantize_with_hqq: bool = True,
 ):
     if embedding_quantize.startswith("torchao:"):
         from torchao.prototype.quantization.embedding.api import (
@@ -825,6 +838,7 @@ def _torchao_embedding_quantizer(model):
         group_size=group_size,
         packed=(bitwidth in [2, 4]),
         precision=torch_dtype,
+        quantize_with_hqq=quantize_with_hqq,
     ).quantized_model()
 
 
@@ -838,6 +852,7 @@ def get_quant_weight_transform(
     calibration_tasks: Optional[list] = None,
     calibration_limit: Optional[int] = None,
    calibration_seq_length: Optional[int] = None,
+    quantize_with_hqq: bool = True,
 ):
     return partial(
         quantize,
@@ -850,6 +865,7 @@ def get_quant_weight_transform(
         calibration_limit=calibration_limit,
         calibration_seq_length=calibration_seq_length,
         tokenizer_path=(Path(path) if (path := tokenizer_path) is not None else None),
+        quantize_with_hqq=quantize_with_hqq,
     )
 
 
@@ -877,7 +893,6 @@ def _load_torchao_aten_lib(libname):
 def set_8da4w_computation_dtype(
     module: nn.Module, computation_dtype: torch.dtype
 ) -> nn.Module:
-
     from torchao.quantization.linear_quant_modules import Int8DynActInt4WeightLinear
 
     def _set_8da4w_computation_dtype(module: nn.Module, dtype: torch.dtype) -> None:

From 5cd1731ee19cdf97105084caf34041b0efb59965 Mon Sep 17 00:00:00 2001
From: Scott Roy
Date: Fri, 10 Oct 2025 10:08:30 -0700
Subject: [PATCH 2/3] bump torchao commit to release/0.14

---
 third-party/ao | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/third-party/ao b/third-party/ao
index b99904b34c0..c40417e1996 160000
--- a/third-party/ao
+++ b/third-party/ao
@@ -1 +1 @@
-Subproject commit b99904b34c0fd98f8a63ec57cbc1dc4993f74793
+Subproject commit c40417e1996a560a17001d663c36ed622007b52e

From 720cbe217e496cbce994dd3f6105cc808b0ae458 Mon Sep 17 00:00:00 2001
From: Scott Roy
Date: Fri, 10 Oct 2025 12:38:45 -0700
Subject: [PATCH 3/3] up

---
 .ci/scripts/test_llava.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.ci/scripts/test_llava.sh b/.ci/scripts/test_llava.sh
index afed3c54123..f3e87b0b2b6 100644
--- a/.ci/scripts/test_llava.sh
+++ b/.ci/scripts/test_llava.sh
@@ -149,7 +149,7 @@ run_and_verify() {
 
   # verify result.txt
   RESULT=$(cat result.txt)
-  EXPECTED_PREFIX="ASSISTANT: image captures a basketball game in progress, with"
+  EXPECTED_PREFIX="ASSISTANT: image is a black and white photo of a basketball game in progress"
 
   if [[ "${RESULT}" == *"${EXPECTED_PREFIX}"* ]]; then
     echo "Expected result prefix: ${EXPECTED_PREFIX}"
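
Note (illustrative, not part of the patches above): PATCH 1/3 threads a new quantize_with_hqq flag, defaulting to True, from export_llama_lib.py into the torchao quantization configs, where it selects the qparams algorithm. Below is a minimal sketch of the resulting torchao call, assuming the torchao release/0.14 API shown in the diff; the toy model and the group size of 32 are placeholder values, not taken from the patch.

    # Sketch only: how quantize_with_hqq maps onto the torchao config used in the diff.
    import torch
    import torch.nn as nn
    from torchao.quantization.granularity import PerGroup
    from torchao.quantization.quant_api import (
        Int8DynamicActivationIntxWeightConfig,
        quantize_,
    )

    # Toy model standing in for the Llama linear layers that get quantized.
    model = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256))

    quantize_with_hqq = True  # new default introduced by PATCH 1/3

    config = Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(32),  # placeholder group size
        # HQQ-style scale selection when enabled, plain affine qparams otherwise.
        intx_choose_qparams_algorithm=(
            "hqq_scale_only" if quantize_with_hqq else "affine"
        ),
    )
    quantize_(model, config)

Passing quantize_with_hqq=False recovers the previous affine qparams behavior that the removed MappingType.SYMMETRIC default provided.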