diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 9fa15568cc4..92ef5be5781 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -145,6 +145,86 @@ def get_symmetric_quantization_config( return quantization_config +@functools.lru_cache +def get_symmetric_a16w8_quantization_config( + is_per_channel: bool = True, + is_qat: bool = False, + is_dynamic: bool = False, + weight_qmin: int = -127, + weight_qmax: int = 127, +): + """ + 16A8W quantization config: 16-bit activations, 8-bit weights. + + This configuration provides better accuracy than 8A8W while maintaining + reasonable memory usage through 8-bit weights. + + Args: + is_per_channel: Whether to use per-channel quantization for weights + is_qat: Whether this is for Quantization Aware Training + is_dynamic: Whether to use dynamic quantization + weight_qmin: Minimum quantization value for weights + weight_qmax: Maximum quantization value for weights + + Returns: + QuantizationConfig with 16-bit activations and 8-bit weights + """ + extra_args: Dict[str, Any] = {"eps": 2**-12} + + # Setup observer/fake-quant for 16-bit activations + if is_qat: + if is_dynamic: + act_observer_or_fake_quant_ctr = FakeQuantize + dynamic_quant_observer = MovingAverageMinMaxObserver.with_args( + averaging_constant=1 + ) + extra_args["observer"] = dynamic_quant_observer + else: + act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize # type: ignore[assignment] + else: + if is_dynamic: + act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment] + else: + # HistogramObserver works well for 16-bit range + act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment] + + # 16-bit activation quantization spec + act_quantization_spec = QuantizationSpec( + dtype=torch.int16, + quant_min=torch.iinfo(torch.int16).min, # -32768 + quant_max=torch.iinfo(torch.int16).max, # 32767 + qscheme=torch.per_tensor_symmetric, + is_dynamic=is_dynamic, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args( + **extra_args, + ), + ) + + # Instead of reconstructing quantization_config, just clone and update as needed + # Clone the quantization_config from get_symmetric_quantization_config and update activation spec + base_config = get_symmetric_quantization_config( + is_per_channel=is_per_channel, + is_qat=is_qat, + is_dynamic=is_dynamic, + ) + # Replace activation quantization spec with 16-bit version + if is_dynamic: + quantization_config = QuantizationConfig( + act_quantization_spec, # 16-bit input activations + None, + base_config.weight, # 8-bit weights from base config + None, + ) + else: + quantization_config = QuantizationConfig( + act_quantization_spec, # 16-bit input activations + act_quantization_spec, # 16-bit output activations + base_config.weight, # 8-bit weights from base config + None, + ) + return quantization_config + + NodeFilterType = Callable[[Node], bool] """Type for a Node Filter used by annotators. A Node filter is a function that takes a Node and returns whether the node should be annotated or not.