diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index c53070007dc..acab688ab64 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -665,6 +665,7 @@ def quantize_with_submodules( model: GraphModule, calibration_samples: list[tuple], is_qat: bool = False, + fold_quantize: bool = True, ): """Quantizes a GraphModule in a way such that conditional submodules are handled properly. @@ -680,6 +681,8 @@ def quantize_with_submodules( model with submodules, at least one sample per code path is needed. is_qat (bool): Whether to do quantization aware training or not. + fold_quantize (bool): Enables or disables constant folding when quantization + is completed. Returns: GraphModule: The quantized model. @@ -694,8 +697,11 @@ def quantize_with_submodules( prepared(*inp) for name, submodule, _ in self._get_submodules_not_handled_by_torchao(prepared): - prepared.set_submodule(name, convert_pt2e(submodule), strict=True) - converted = convert_pt2e(prepared) + prepared.set_submodule( + name, convert_pt2e(submodule, fold_quantize=fold_quantize), strict=True + ) + converted = convert_pt2e(prepared, fold_quantize=fold_quantize) + return converted diff --git a/backends/arm/test/tester/quantize.py b/backends/arm/test/tester/quantize.py index 18ecd401efe..88223d4483c 100644 --- a/backends/arm/test/tester/quantize.py +++ b/backends/arm/test/tester/quantize.py @@ -1,9 +1,9 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional, Tuple +from typing import Any, Optional, Sequence, Tuple import torch from executorch.backends.arm.quantizer import TOSAQuantizer @@ -14,9 +14,29 @@ ) from torch.export import export +from torchao.quantization.pt2e.quantizer import Quantizer class ArmQuantize(Quantize): + def __init__( + self, + quantizer: Optional[Quantizer] = None, + quantization_config: Optional[Any] = None, + calibrate: bool = True, + calibration_samples: Optional[Sequence[Any]] = None, + is_qat: Optional[bool] = False, + set_global: bool = True, + fold_quantize: bool = True, + ): + super().__init__( + quantizer, + quantization_config, + calibrate, + calibration_samples, + is_qat, + set_global, + ) + self.fold_quantize = fold_quantize def run( self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]] @@ -31,11 +51,11 @@ def run( if self.calibration_samples is not None: converted = self.quantizer.quantize_with_submodules( - captured_graph, self.calibration_samples, bool(self.is_qat) # type: ignore + captured_graph, self.calibration_samples, bool(self.is_qat), self.fold_quantize # type: ignore ) else: converted = self.quantizer.quantize_with_submodules( - captured_graph, [inputs], bool(self.is_qat) + captured_graph, [inputs], bool(self.is_qat), self.fold_quantize ) DuplicateDynamicQuantChainPass()(converted) diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index 4e060919738..366687c6c1e 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -425,6 +425,7 @@ def __init__( tosa_version: Optional[str] = "1.0", tosa_extensions: Optional[List[str]] = None, epsilon: float = 2**-16, + fold_quantize: bool = True, ): if tosa_extensions is None: tosa_extensions = [] @@ -450,7 +451,9 @@ def __init__( ) if symmetric_io_quantization: quantizer.set_io(quantization_config) - quant_stage = Quantize(quantizer, quantization_config) + quant_stage = Quantize( + quantizer, quantization_config, fold_quantize=fold_quantize + ) super().__init__( module, @@ -622,6 +625,7 @@ def __init__( rtol: float = 1e-03, qtol: int = 1, epsilon: float = 2**-12, + fold_quantize: bool = True, ): super().__init__( module, @@ -644,7 +648,9 @@ def __init__( ) if symmetric_io_quantization: quantizer.set_io(quantization_config) - quant_stage = Quantize(quantizer, quantization_config) + quant_stage = Quantize( + quantizer, quantization_config, fold_quantize=fold_quantize + ) self.add_stage(self.tester.quantize, quant_stage, pos=0) @@ -720,6 +726,7 @@ def __init__( rtol: float = 1e-03, qtol: int = 1, epsilon: float = 2**-12, + fold_quantize: bool = True, ): compile_spec = common.get_u55_compile_spec( custom_path=custom_path, @@ -740,6 +747,7 @@ def __init__( rtol=rtol, qtol=qtol, epsilon=epsilon, + fold_quantize=fold_quantize, ) @@ -777,6 +785,7 @@ def __init__( rtol: float = 1e-03, qtol: int = 1, epsilon: float = 2**-12, + fold_quantize: bool = True, ): compile_spec = common.get_u85_compile_spec( custom_path=custom_path, @@ -797,6 +806,7 @@ def __init__( rtol=rtol, qtol=qtol, epsilon=epsilon, + fold_quantize=fold_quantize, ) @@ -982,6 +992,7 @@ def __init__( input_qspecs: Optional[Dict[QuantizationSpec | None, int]] = None, output_qspecs: Optional[Dict[QuantizationSpec | None, int]] = None, custom_path: Optional[str] = None, + fold_quantize: bool = True, ): tosa_spec = quantizer.tosa_spec compile_spec = common.get_tosa_compile_spec(tosa_spec, custom_path=custom_path) @@ -994,7 +1005,7 @@ def __init__( use_to_edge_transform_and_lower=True, ) # TODO sort out typing - quant_stage = Quantize(quantizer, quantization_config=quantizer.global_config) # type: ignore[arg-type] + quant_stage = Quantize(quantizer, quantization_config=quantizer.global_config, fold_quantize=fold_quantize) # type: ignore[arg-type] self.add_stage(self.tester.quantize, quant_stage, pos=0) # Delete most of the pipeline @@ -1126,6 +1137,7 @@ def __init__( tosa_version: str | None = None, tosa_extensions: Optional[List[str]] = None, tosa_spec: TosaSpecification | str | None = None, + fold_quantize: bool = True, ): if tosa_spec is None: if tosa_version is None: @@ -1169,7 +1181,9 @@ def __init__( ) if symmetric_io_quantization: quantizer.set_io(quantization_config) - quant_stage = Quantize(quantizer, quantization_config) + quant_stage = Quantize( + quantizer, quantization_config, fold_quantize=fold_quantize + ) self.add_stage(self.tester.quantize, quant_stage, pos=0)