From 4b02da3d58494e0b799379b4025f3b7f482620e6 Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Tue, 11 Feb 2025 18:08:29 -0800 Subject: [PATCH] Use symmetric weights for convs and int8 in the default quantizer (#8344) Summary: As titled. int8 should give better performance with Cadence kernels, since they're not improving uint8 anymore. The upcoming (quantized) convolution kernel needs symmetric weights, so we make that change as well. Reviewed By: zonglinpeng Differential Revision: D69405797 --- backends/cadence/aot/export_example.py | 36 +--------- backends/cadence/aot/ops_registrations.py | 2 +- backends/cadence/aot/quantizer/quantizer.py | 69 +++++++++++-------- .../hifi/operators/op_quantized_relu_out.cpp | 35 +--------- 4 files changed, 48 insertions(+), 94 deletions(-) diff --git a/backends/cadence/aot/export_example.py b/backends/cadence/aot/export_example.py index 0345aa6e2ef..420f1760e32 100644 --- a/backends/cadence/aot/export_example.py +++ b/backends/cadence/aot/export_example.py @@ -6,11 +6,11 @@ # Example script for exporting simple models to flatbuffer +#pyre-unsafe + import logging import tempfile -import torch - from executorch.backends.cadence.aot.ops_registrations import * # noqa from typing import Any, Tuple @@ -23,13 +23,8 @@ from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer from executorch.backends.cadence.runtime import runtime from executorch.backends.cadence.runtime.executor import BundledProgramManager -from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import ( - QuantizationConfig, - QuantizationSpec, -) from executorch.exir import ExecutorchProgramManager from torch import nn -from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver from .utils import save_bpte_program, save_pte_program @@ -37,24 +32,6 @@ FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) -act_qspec = QuantizationSpec( - dtype=torch.int8, - quant_min=-128, - quant_max=127, - qscheme=torch.per_tensor_affine, - is_dynamic=False, - observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12), -) - -wgt_qspec = QuantizationSpec( - dtype=torch.int8, - quant_min=-128, - quant_max=127, - qscheme=torch.per_tensor_affine, - is_dynamic=False, - observer_or_fake_quant_ctr=MinMaxObserver, -) - def export_model( model: nn.Module, @@ -66,15 +43,8 @@ def export_model( working_dir = tempfile.mkdtemp(dir="/tmp") logging.debug(f"Created work directory {working_dir}") - qconfig = QuantizationConfig( - act_qspec, - act_qspec, - wgt_qspec, - None, - ) - # Instantiate the quantizer - quantizer = CadenceDefaultQuantizer(qconfig) + quantizer = CadenceDefaultQuantizer() # Convert the model converted_model = convert_pt2(model, example_inputs, quantizer) diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 1dcfcbcd8db..a8dd1315846 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -576,7 +576,7 @@ def quantized_relu_per_tensor_meta( out_multiplier: int, out_shift: int, ) -> torch.Tensor: - return input.new_empty(input.size(), dtype=torch.uint8) + return input.new_empty(input.size(), dtype=input.dtype) @register_fake("cadence::fully_connected") diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py index 936209d5e39..d6765d2ad30 100644 --- a/backends/cadence/aot/quantizer/quantizer.py +++ b/backends/cadence/aot/quantizer/quantizer.py @@ -40,30 +40,46 @@ from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer -act_qspec = QuantizationSpec( - dtype=torch.uint8, - quant_min=0, - quant_max=255, +act_qspec_asym8u = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, qscheme=torch.per_tensor_affine, is_dynamic=False, observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12), ) -wgt_qspec = QuantizationSpec( - dtype=torch.uint8, - quant_min=0, - quant_max=255, +wgt_qspec_asym8u = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, qscheme=torch.per_tensor_affine, is_dynamic=False, observer_or_fake_quant_ctr=MinMaxObserver, ) +wgt_qspec_asym8s = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_symmetric, + is_dynamic=False, + observer_or_fake_quant_ctr=MinMaxObserver, +) + bias_qspec: Optional[QuantizationSpec] = None -_default_qconfig = QuantizationConfig( - act_qspec, - act_qspec, - wgt_qspec, +qconfig_A8uW8u = QuantizationConfig( + act_qspec_asym8u, + act_qspec_asym8u, + wgt_qspec_asym8u, + None, +) + +qconfig_A8uW8s = QuantizationConfig( + act_qspec_asym8u, + act_qspec_asym8u, + wgt_qspec_asym8s, None, ) @@ -147,19 +163,17 @@ def get_supported_operators(cls) -> List[OperatorConfig]: return [] -def get_cadence_default_quantizer_list_with_config( - quantization_config: QuantizationConfig, -) -> List[Quantizer]: +def get_cadence_default_quantizers() -> List[Quantizer]: return [ - CadenceAtenQuantizer(AddmmPattern(), quantization_config), - CadenceAtenQuantizer(BmmPattern(), quantization_config), - CadenceAtenQuantizer(Conv1dPattern(), quantization_config), - CadenceAtenQuantizer(Conv2dPattern(), quantization_config), - CadenceAtenQuantizer(LayerNormPattern(), quantization_config), - CadenceAtenQuantizer(LinearPattern(), quantization_config), - CadenceAtenQuantizer(MatmulPattern(), quantization_config), - CadenceAtenQuantizer(ReluPattern0(), quantization_config), - CadenceAtenQuantizer(ReluPattern1(), quantization_config), + CadenceAtenQuantizer(AddmmPattern(), qconfig_A8uW8u), + CadenceAtenQuantizer(BmmPattern(), qconfig_A8uW8u), + CadenceAtenQuantizer(Conv1dPattern(), qconfig_A8uW8s), + CadenceAtenQuantizer(Conv2dPattern(), qconfig_A8uW8s), + CadenceAtenQuantizer(LayerNormPattern(), qconfig_A8uW8u), + CadenceAtenQuantizer(LinearPattern(), qconfig_A8uW8u), + CadenceAtenQuantizer(MatmulPattern(), qconfig_A8uW8u), + CadenceAtenQuantizer(ReluPattern0(), qconfig_A8uW8u), + CadenceAtenQuantizer(ReluPattern1(), qconfig_A8uW8u), ] @@ -178,10 +192,9 @@ class CadenceDefaultQuantizer(CadenceQuantizer): Default quantizer for Cadence backend. """ - def __init__(self, qconfig: Optional[QuantizationConfig] = None) -> None: - if qconfig is None: - qconfig = _default_qconfig - quantizers = get_cadence_default_quantizer_list_with_config(qconfig) + def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None: + if quantizers is None: + quantizers = get_cadence_default_quantizers() super().__init__(quantizers) diff --git a/backends/cadence/hifi/operators/op_quantized_relu_out.cpp b/backends/cadence/hifi/operators/op_quantized_relu_out.cpp index 06eb8aa3a5b..28227b7cc92 100644 --- a/backends/cadence/hifi/operators/op_quantized_relu_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_relu_out.cpp @@ -18,33 +18,6 @@ namespace impl { namespace HiFi { namespace native { -template -void quantized_relu_( - const Tensor& input, - const Tensor& in_zero_point, - const int64_t out_zero_point, - const Tensor& out_multiplier, - const Tensor& out_shift, - Tensor& output) { - T q_zero_point = in_zero_point.const_data_ptr()[0]; - const T* __restrict__ in = input.const_data_ptr(); - T* __restrict__ out = output.mutable_data_ptr(); - - const int32_t* __restrict__ out_multiplier_data = - out_multiplier.const_data_ptr(); - const int32_t* __restrict__ out_shift_data = - out_shift.const_data_ptr(); - - // Compute the out_scale from out_multiplier and out_shift - const float out_scale = - -out_multiplier_data[0] * 1.0 / (1 << 31) * pow(2, out_shift_data[0]); - - for (size_t i = 0, e = input.numel(); i < e; ++i) { - float temp = in[i] > q_zero_point ? (in[i] - q_zero_point) : 0; - out[i] = kernels::quantize(temp, out_scale, (int32_t)out_zero_point); - } -} - void quantized_relu_per_tensor_out( KernelRuntimeContext& ctx, const Tensor& input, @@ -68,7 +41,7 @@ void quantized_relu_per_tensor_out( _out_multiplier, _out_shift, _out_zero_point, - _out_zero_point, + 0, 255, input.numel()); @@ -85,7 +58,7 @@ void quantized_relu_per_tensor_out( _out_multiplier, _out_shift, _out_zero_point, - _out_zero_point, + -128, 127, input.numel()); @@ -107,9 +80,7 @@ void quantized_relu_per_tensor_out( const Tensor& out_multiplier, const Tensor& out_shift, Tensor& output) { - const uint8_t* p_in = input.const_data_ptr(); - uint8_t* p_out = output.mutable_data_ptr(); - uint8_t _in_zero_point = in_zero_point.const_data_ptr()[0]; + int8_t _in_zero_point = in_zero_point.const_data_ptr()[0]; int32_t _out_multiplier = out_multiplier.const_data_ptr()[0]; int32_t _out_shift = out_shift.const_data_ptr()[0];