diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index c20065654ca..94e2ae74a7a 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -38,7 +38,6 @@ HistogramObserver, MinMaxObserver, MovingAverageMinMaxObserver, - MovingAveragePerChannelMinMaxObserver, ObserverOrFakeQuantizeConstructor, PerChannelMinMaxObserver, PlaceholderObserver, @@ -95,24 +94,26 @@ def get_symmetric_quantization_config( **extra_args, ), ) + + # Setup quantization config for weights weight_qscheme = ( torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric ) weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = ( MinMaxObserver ) + # Determine the right observer/fake-quant constructor if is_qat: - # TODO: qat + per channel? - weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize - elif is_per_channel: - weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver + # Set plain fake-quant with true min/max + weight_observer_or_fake_quant_ctr = FakeQuantize + else: + # PTQ: set min/max observer + weight_observer_or_fake_quant_ctr = ( + PerChannelMinMaxObserver if is_per_channel else MinMaxObserver + ) + + extra_args = {"eps": 2**-12} - extra_args: Dict[str, Any] = {"eps": 2**-12} - if is_qat: - if weight_qscheme == torch.per_tensor_symmetric: - extra_args["observer"] = MovingAverageMinMaxObserver - else: - extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item] weight_quantization_spec = QuantizationSpec( dtype=torch.int8, quant_min=weight_qmin,