diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/annotators.py similarity index 68% rename from backends/qualcomm/quantizer/utils.py rename to backends/qualcomm/quantizer/annotators.py index dc3d2a6841b..275da567e8f 100644 --- a/backends/qualcomm/quantizer/utils.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -5,29 +5,16 @@ # LICENSE file in the root directory of this source tree. import numbers import operator -from dataclasses import dataclass from functools import partial -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple +from typing import Callable, Dict, List, Sequence, Tuple import torch - -from torch import Tensor from torch._ops import OpOverload -from torch._subclasses import FakeTensor - -from torch.ao.quantization.fake_quantize import ( - default_fake_quant, - FusedMovingAvgObsFakeQuantize, -) -from torch.ao.quantization.observer import ( - FixedQParamsObserver, - MinMaxObserver, - MovingAverageMinMaxObserver, - PerChannelMinMaxObserver, - UniformQuantizationObserverBase, -) +from torch._subclasses import FakeTensor +from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize +from torch.ao.quantization.observer import FixedQParamsObserver from torch.ao.quantization.quantizer import ( DerivedQuantizationSpec, QuantizationAnnotation, @@ -40,397 +27,12 @@ ) from torch.fx import Node - -class ParamObserver(UniformQuantizationObserverBase): - def __init__( - self, - ch_axis=0, - use_mse=True, - steps=100, - dtype=torch.int8, - qscheme=torch.per_channel_symmetric, - reduce_range=False, - quant_min=None, - quant_max=None, - factory_kwargs=None, - eps=torch.finfo(torch.float32).eps, # noqa: B008 - is_dynamic=False, - **kwargs, - ) -> None: - super().__init__( - dtype=dtype, - qscheme=qscheme, - reduce_range=reduce_range, - quant_min=quant_min, - quant_max=quant_max, - factory_kwargs=factory_kwargs, - eps=eps, - is_dynamic=is_dynamic, - **kwargs, - ) - - factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) - self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) - self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) - self.ch_axis = ch_axis - self.use_mse = use_mse - self.steps = steps - self.calibrated = False - - def to_ch_axis(self, x): - axis_order = list(range(len(x.size()))) - axis_order[self.ch_axis], axis_order[0] = 0, self.ch_axis - return torch.flatten(x.permute(axis_order), start_dim=1) - - def mse(self, pred, expect): - loss = (pred - expect).abs().pow(2) - return self.to_ch_axis(loss).mean(1) - - def cosine(self, pred, expect): - target = torch.ones(pred.shape[self.ch_axis]) - pred_n = self.to_ch_axis(pred).reshape(pred.shape[0], -1) - expect_n = self.to_ch_axis(expect).reshape(expect.shape[0], -1) - return torch.nn.CosineEmbeddingLoss()(pred_n, expect_n, target) - - def loss_fn(self, x, new_min, new_max): - scale, offset = self._calculate_qparams(new_min, new_max) - x_q = torch.fake_quantize_per_channel_affine( - x, - scale.data, - offset.data.int(), - self.ch_axis, - self.quant_min, - self.quant_max, - ) - return self.mse(x_q, x) if self.use_mse else self.cosine(x_q, x) - - def line_search(self, x): - x_min, x_max = torch.aminmax(self.to_ch_axis(x), dim=1) - x_range = torch.max(x_min.abs(), x_max) - optimal_loss = torch.zeros_like(x_min) + 1e9 - - # check which clip range could produce smallest loss - for i in range(1, self.steps + 1): - thres = x_range / self.steps * i - current_loss = self.loss_fn(x, -thres, thres) - x_min = 
torch.where(current_loss < optimal_loss, -thres, x_min) - x_max = torch.where(current_loss < optimal_loss, thres, x_max) - optimal_loss = torch.min(current_loss, optimal_loss) - - return x_min, x_max - - def forward(self, x_orig): - # since params are static, one calibration is enough - if not self.calibrated: - x = x_orig.detach().to(self.min_val.dtype) - self.min_val, self.max_val = self.line_search(x) - self.calibrated = True - - # return fake-quant result for saturating outliers - scale, zero_point = self._calculate_qparams(self.min_val, self.max_val) - return torch.fake_quantize_per_channel_affine( - x_orig, - scale.data, - zero_point.data.int(), - self.ch_axis, - self.quant_min, - self.quant_max, - ) - - @torch.jit.export - def calculate_qparams(self): - return self._calculate_qparams(self.min_val, self.max_val) - - -@dataclass(eq=True, frozen=True) -class QuantizationConfig: - input_activation: Optional[QuantizationSpec] - output_activation: Optional[QuantizationSpec] - weight: Optional[QuantizationSpec] - bias: Optional[QuantizationSpec | Callable] - - -def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec: - def _derive_bias_qparams_fn( - obs_or_fqs: List, - ) -> Tuple[Tensor, Tensor]: - assert ( - len(obs_or_fqs) == 2 - ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" - act_obs_or_fq = obs_or_fqs[0] - weight_obs_or_fq = obs_or_fqs[1] - weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams() - act_scale, act_zp = act_obs_or_fq.calculate_qparams() - (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors( - act_scale, weight_scale - ) - derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32) - derived_zero = torch.zeros(derived_scale.size()).to(torch.int32) - return (derived_scale, derived_zero) - - input_act = node.args[0] - assert isinstance(input_act, Node) - weight = node.args[1] - assert isinstance(weight, Node) - - return DerivedQuantizationSpec( - derived_from=[(input_act, node), (weight, node)], - derive_qparams_fn=_derive_bias_qparams_fn, - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - ch_axis=0, - qscheme=torch.per_channel_symmetric, - ) - - -def get_default_8bit_qat_proto(act_symmetric: bool = False) -> QuantizationConfig: - - act_quantization_spec = QuantizationSpec( - dtype=torch.uint8, - qscheme=( - torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine - ), - ch_axis=0, - observer_or_fake_quant_ctr=default_fake_quant, - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=torch.iinfo(torch.int8).min + 1, - quant_max=torch.iinfo(torch.int8).max, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args( - observer=MovingAverageMinMaxObserver - ), - ) - - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=default_fake_quant, - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_default_8bit_qnn_ptq_config( - act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-12} 
- - act_quantization_spec = QuantizationSpec( - dtype=torch.uint8, - qscheme=( - torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine - ), - ch_axis=0, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=torch.iinfo(torch.int8).min + 1, - quant_max=torch.iinfo(torch.int8).max, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -# 4 bits quantization only supports specific ops. -def get_16a4w_qnn_ptq_config( - act_observer=MovingAverageMinMaxObserver, -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} - act_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=-7, - quant_max=7, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_16a8w_qnn_ptq_config( - act_observer=MovingAverageMinMaxObserver, -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} - act_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.uint8, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_default_16bit_qnn_ptq_config( - act_observer=MovingAverageMinMaxObserver, -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} - act_quantization_spec = QuantizationSpec( - dtype=torch.int32, - 
quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int16, - quant_min=torch.iinfo(torch.int16).min + 1, - quant_max=torch.iinfo(torch.int16).max, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - # torch does not support uint16 quantization, use int32 to bypass - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_ptq_per_channel_quant_config( - act_dtype=torch.uint8, weight_dtype=torch.int8 -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-12} - - supported_act_types = { - torch.uint8, - torch.uint16, - torch.int8, - torch.int16, - } - # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype - supported_weight_dtypes = {"int4", torch.int8, torch.int16} - assert ( - act_dtype in supported_act_types - ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" - - assert ( - weight_dtype in supported_weight_dtypes - ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" - - # torch do not support uint16 quantization, use int32 to bypass - act_quantization_spec = QuantizationSpec( - dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, - quant_min=torch.iinfo(act_dtype).min, - quant_max=torch.iinfo(act_dtype).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, - quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, - quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, - qscheme=torch.per_channel_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = _derived_bias_quant_spec - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config +from .qconfig import ( + get_16a16w_qnn_ptq_config, + get_16a4w_qnn_qat_config, + get_8a8w_qnn_qat_config, + QuantizationConfig, +) QUANT_ANNOTATION_KEY = "quantization_annotation" @@ -901,19 +503,34 @@ def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> Non scale = 1 / (q_max - q_min + 1) - # make sigmoid map to the range between 0~1 - out_act_quantization_spec = QuantizationSpec( + bias_obs_ctr = observer = FixedQParamsObserver.with_args( + scale=scale, + zero_point=0, dtype=quantization_config.output_activation.dtype, + qscheme=torch.torch.per_tensor_affine, quant_max=q_max, quant_min=q_min, - observer_or_fake_quant_ctr=FixedQParamsObserver.with_args( + ) + if quantization_config in ( + get_8a8w_qnn_qat_config(), + 
get_16a4w_qnn_qat_config(), + ): + bias_obs_ctr = FixedQParamsFakeQuantize.with_args( + observer=observer, scale=scale, zero_point=0, dtype=quantization_config.output_activation.dtype, qscheme=torch.torch.per_tensor_affine, quant_max=q_max, quant_min=q_min, - ), + ) + + # make sigmoid map to the range between 0~1 + out_act_quantization_spec = QuantizationSpec( + dtype=quantization_config.output_activation.dtype, + quant_max=q_max, + quant_min=q_min, + observer_or_fake_quant_ctr=bias_obs_ctr, qscheme=torch.torch.per_tensor_affine, ) @@ -1086,7 +703,7 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig) -> None # In matmul, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized. if input_act_qspec.dtype == torch.int32: # we should use int16 for mm / bmm instead of int4 - input_qspec_map[input_act1] = get_default_16bit_qnn_ptq_config().weight + input_qspec_map[input_act1] = get_16a16w_qnn_ptq_config().weight else: input_qspec_map[input_act1] = input_act_qspec @@ -1115,7 +732,7 @@ def annotate_bmm(node: Node, quantization_config: QuantizationConfig) -> None: # In bmm, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized. if input_act_qspec.dtype == torch.int32: # we should use int16 for mm / bmm instead of int4 - input_qspec_map[input_act1] = get_default_16bit_qnn_ptq_config().weight + input_qspec_map[input_act1] = get_16a16w_qnn_ptq_config().weight else: input_qspec_map[input_act1] = input_act_qspec @@ -1258,7 +875,7 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> _annotate_input_qspec_map( node, weight_node, - get_default_16bit_qnn_ptq_config().weight, + get_16a16w_qnn_ptq_config().weight, ) else: _annotate_input_qspec_map( diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index db82172a9e2..9d6dea8a97b 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -6,12 +6,12 @@ from typing import Sequence import torch +from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY from executorch.backends.qualcomm.quantizer.quantizer import ( get_16a8w_qnn_ptq_config, - get_default_8bit_qnn_ptq_config, + get_8a8w_qnn_ptq_config, QuantizationConfig, ) -from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY from executorch.exir.dialects._ops import ops as exir_ops from torch.ao.quantization.quantizer import ( QuantizationAnnotation, @@ -110,7 +110,7 @@ def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig): # Annotate 16a8w for matmul op to get better performance quantization_config_16a8w = get_16a8w_qnn_ptq_config() # Annotate 8a8w for second input of matmul until past_kv_cache - quantization_config_8a8w = get_default_8bit_qnn_ptq_config(act_symmetric=True) + quantization_config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True) for node in gm.graph.nodes: if node.op == "call_function" and node.target == torch.ops.aten.matmul.default: if "nn_module_stack" in node.meta: diff --git a/backends/qualcomm/quantizer/observers/per_channel_param_observer.py b/backends/qualcomm/quantizer/observers/per_channel_param_observer.py new file mode 100644 index 00000000000..d556dfa4ba3 --- /dev/null +++ b/backends/qualcomm/quantizer/observers/per_channel_param_observer.py @@ -0,0 +1,104 @@ +import torch +from torch.ao.quantization.observer import 
UniformQuantizationObserverBase + + +# TODO move to torch/ao/quantization/observer.py. +class PerChannelParamObserver(UniformQuantizationObserverBase): + def __init__( + self, + ch_axis=0, + use_mse=True, + steps=100, + dtype=torch.int8, + qscheme=torch.per_channel_symmetric, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, # noqa: B008 + is_dynamic=False, + **kwargs, + ) -> None: + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) + self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) + self.ch_axis = ch_axis + self.use_mse = use_mse + self.steps = steps + self.calibrated = False + + def to_ch_axis(self, x): + axis_order = list(range(len(x.size()))) + axis_order[self.ch_axis], axis_order[0] = 0, self.ch_axis + return torch.flatten(x.permute(axis_order), start_dim=1) + + def mse(self, pred, expect): + loss = (pred - expect).abs().pow(2) + return self.to_ch_axis(loss).mean(1) + + def cosine(self, pred, expect): + target = torch.ones(pred.shape[self.ch_axis]) + pred_n = self.to_ch_axis(pred).reshape(pred.shape[0], -1) + expect_n = self.to_ch_axis(expect).reshape(expect.shape[0], -1) + return torch.nn.CosineEmbeddingLoss()(pred_n, expect_n, target) + + def loss_fn(self, x, new_min, new_max): + scale, offset = self._calculate_qparams(new_min, new_max) + x_q = torch.fake_quantize_per_channel_affine( + x, + scale.data, + offset.data.int(), + self.ch_axis, + self.quant_min, + self.quant_max, + ) + return self.mse(x_q, x) if self.use_mse else self.cosine(x_q, x) + + def line_search(self, x): + x_min, x_max = torch.aminmax(self.to_ch_axis(x), dim=1) + x_range = torch.max(x_min.abs(), x_max) + optimal_loss = torch.zeros_like(x_min) + 1e9 + + # check which clip range could produce smallest loss + for i in range(1, self.steps + 1): + thres = x_range / self.steps * i + current_loss = self.loss_fn(x, -thres, thres) + x_min = torch.where(current_loss < optimal_loss, -thres, x_min) + x_max = torch.where(current_loss < optimal_loss, thres, x_max) + optimal_loss = torch.min(current_loss, optimal_loss) + + return x_min, x_max + + def forward(self, x_orig): + # since params are static, one calibration is enough + if not self.calibrated: + x = x_orig.detach().to(self.min_val.dtype) + self.min_val, self.max_val = self.line_search(x) + self.calibrated = True + + # return fake-quant result for saturating outliers + scale, zero_point = self._calculate_qparams(self.min_val, self.max_val) + return torch.fake_quantize_per_channel_affine( + x_orig, + scale.data, + zero_point.data.int(), + self.ch_axis, + self.quant_min, + self.quant_max, + ) + + @torch.jit.export + def calculate_qparams(self): + return self._calculate_qparams(self.min_val, self.max_val) diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py new file mode 100644 index 00000000000..e07ca24d90f --- /dev/null +++ b/backends/qualcomm/quantizer/qconfig.py @@ -0,0 +1,464 @@ +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +from torch import Tensor +from torch.ao.quantization.fake_quantize import ( + FakeQuantize, + FusedMovingAvgObsFakeQuantize, +) +from 
torch.ao.quantization.observer import ( + MinMaxObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, + PerChannelMinMaxObserver, +) +from torch.ao.quantization.quantizer import DerivedQuantizationSpec, QuantizationSpec +from torch.fx import Node + + +@dataclass(eq=True, frozen=True) +class QuantizationConfig: + input_activation: Optional[QuantizationSpec] + output_activation: Optional[QuantizationSpec] + weight: Optional[QuantizationSpec] + bias: Optional[QuantizationSpec | Callable] + + +def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec: + def _derive_bias_qparams_fn( + obs_or_fqs: List, + ) -> Tuple[Tensor, Tensor]: + assert ( + len(obs_or_fqs) == 2 + ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" + act_obs_or_fq = obs_or_fqs[0] + weight_obs_or_fq = obs_or_fqs[1] + weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams() + act_scale, act_zp = act_obs_or_fq.calculate_qparams() + (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors( + act_scale, weight_scale + ) + derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32) + derived_zero = torch.zeros(derived_scale.size()).to(torch.int32) + return (derived_scale, derived_zero) + + input_act = node.args[0] + assert isinstance(input_act, Node) + weight = node.args[1] + assert isinstance(weight, Node) + + return DerivedQuantizationSpec( + derived_from=[(input_act, node), (weight, node)], + derive_qparams_fn=_derive_bias_qparams_fn, + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + ch_axis=0, + qscheme=torch.per_channel_symmetric, + ) + + +def get_8a8w_qnn_ptq_config( + act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-12} + + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=( + torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine + ), + ch_axis=0, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=torch.iinfo(torch.int8).min + 1, + quant_max=torch.iinfo(torch.int8).max, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +# 4 bits quantization only supports specific ops. 
+def get_16a4w_qnn_ptq_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-7, + quant_max=7, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_16a8w_qnn_ptq_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_16a16w_qnn_ptq_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int16, + quant_min=torch.iinfo(torch.int16).min + 1, + quant_max=torch.iinfo(torch.int16).max, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + # torch does not support uint16 quantization, use int32 to bypass + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_ptq_per_channel_quant_config( + act_dtype=torch.uint8, + 
weight_dtype=torch.int8, + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-12} + + supported_act_types = { + torch.uint8, + torch.uint16, + torch.int8, + torch.int16, + } + # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype + supported_weight_dtypes = {"int4", torch.int8, torch.int16} + assert ( + act_dtype in supported_act_types + ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" + + assert ( + weight_dtype in supported_weight_dtypes + ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" + + # torch do not support uint16 quantization, use int32 to bypass + act_quantization_spec = QuantizationSpec( + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + quant_min=torch.iinfo(act_dtype).min, + quant_max=torch.iinfo(act_dtype).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, + quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, + quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = _derived_bias_quant_spec + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +# TODO merge qat and ptq to a fucntion, and use a bool flag to control it +def get_8a8w_qnn_qat_config( + act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver +) -> QuantizationConfig: + act_fake_quant_ctr = FakeQuantize.with_args( + dtype=torch.uint8, + qscheme=( + torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine + ), + reduce_range=True, + observer=act_observer, + ) + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=( + torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine + ), + ch_axis=0, + observer_or_fake_quant_ctr=act_fake_quant_ctr, + ) + + weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( + dtype=torch.int8, + quant_min=torch.iinfo(torch.int8).min + 1, + quant_max=torch.iinfo(torch.int8).max, + qscheme=torch.per_tensor_symmetric, + reduce_range=True, + observer=MovingAverageMinMaxObserver, + ) + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=torch.iinfo(torch.int8).min + 1, + quant_max=torch.iinfo(torch.int8).max, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=weight_fake_quant_ctr, + ) + + bias_fake_quant_ctr = FakeQuantize.with_args( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + reduce_range=True, + observer=MovingAverageMinMaxObserver, + ) + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=bias_fake_quant_ctr, + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + 
weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_16a4w_qnn_qat_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + act_fake_quant_ctr = FakeQuantize.with_args( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + reduce_range=True, + observer=act_observer, + ) + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_fake_quant_ctr, + ) + + weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( + dtype=torch.int8, + quant_min=-7, + quant_max=7, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + reduce_range=True, + observer=MovingAverageMinMaxObserver, + ) + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-7, + quant_max=7, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=weight_fake_quant_ctr, + ) + + bias_fake_quant_ctr = FakeQuantize.with_args( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + reduce_range=True, + observer=MovingAverageMinMaxObserver, + ) + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=bias_fake_quant_ctr, + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_qat_per_channel_quant_config( + act_dtype=torch.uint8, + weight_dtype=torch.int8, + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + supported_act_types = { + torch.uint8, + torch.uint16, + torch.int8, + torch.int16, + } + # TODO accept "int4" temporally. 
Remove "int4" when torch support torch.int4 dtype + supported_weight_dtypes = {"int4", torch.int8, torch.int16} + assert ( + act_dtype in supported_act_types + ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" + + assert ( + weight_dtype in supported_weight_dtypes + ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" + + # torch do not support uint16 quantization, use int32 to bypass + act_fake_quant_ctr = FakeQuantize.with_args( + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + quant_min=torch.iinfo(act_dtype).min, + quant_max=torch.iinfo(act_dtype).max, + qscheme=torch.per_tensor_affine, + reduce_range=True, + observer=act_observer, + ) + act_quantization_spec = QuantizationSpec( + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + quant_min=torch.iinfo(act_dtype).min, + quant_max=torch.iinfo(act_dtype).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_fake_quant_ctr, + ) + + weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( + dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, + quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, + quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + observer=MovingAveragePerChannelMinMaxObserver, + ) + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, + quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, + quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=weight_fake_quant_ctr, + ) + + bias_quantization_spec = _derived_bias_quant_spec + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 9e5aaf782a7..50ed07788fd 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
from enum import IntEnum, unique -from typing import Callable, Dict, Optional, Sequence, Set +from typing import Callable, Optional, Sequence, Set import torch from executorch.backends.qualcomm._passes.decompose_einsum import DecomposeEinsum @@ -22,14 +22,17 @@ from torch.ao.quantization.quantizer import Quantizer from torch.fx import GraphModule -from .utils import ( +from .annotators import OP_ANNOTATOR + +from .qconfig import ( + get_16a16w_qnn_ptq_config, get_16a4w_qnn_ptq_config, + get_16a4w_qnn_qat_config, get_16a8w_qnn_ptq_config, - get_default_16bit_qnn_ptq_config, - get_default_8bit_qat_proto, - get_default_8bit_qnn_ptq_config, + get_8a8w_qnn_ptq_config, + get_8a8w_qnn_qat_config, get_ptq_per_channel_quant_config, - OP_ANNOTATOR, + get_qat_per_channel_quant_config, QuantizationConfig, ) @@ -38,9 +41,10 @@ "QuantDtype", "get_16a4w_qnn_ptq_config", "get_16a8w_qnn_ptq_config", - "get_default_16bit_qnn_ptq_config", - "get_default_8bit_qnn_ptq_config", - "get_default_8bit_qat_proto", + "get_16a16w_qnn_ptq_config", + "get_8a8w_qnn_ptq_config", + "get_8a8w_qnn_qat_config", + "get_16a4w_qnn_qat_config", ] @@ -51,8 +55,39 @@ class QuantDtype(IntEnum): """ use_16a16w = 0 - use_16a4w = 1 - use_8a8w = 2 + use_16a8w = 1 + use_16a4w = 2 + use_8a8w = 3 + + +quant_config_dict = { + # PTQ + (QuantDtype.use_16a16w, False): ( + get_16a16w_qnn_ptq_config, + get_ptq_per_channel_quant_config(torch.uint16, torch.int16), + ), + (QuantDtype.use_16a8w, False): ( + get_16a8w_qnn_ptq_config, + get_ptq_per_channel_quant_config(torch.uint16, torch.int8), + ), + (QuantDtype.use_16a4w, False): ( + get_16a4w_qnn_ptq_config, + get_ptq_per_channel_quant_config(torch.uint16, "int4"), + ), + (QuantDtype.use_8a8w, False): ( + get_8a8w_qnn_ptq_config, + get_ptq_per_channel_quant_config(), + ), + # QAT, + (QuantDtype.use_16a4w, True): ( + get_16a4w_qnn_qat_config, + get_qat_per_channel_quant_config(torch.uint16, "int4"), + ), + (QuantDtype.use_8a8w, True): ( + get_8a8w_qnn_qat_config, + get_qat_per_channel_quant_config(), + ), +} class QnnQuantizer(Quantizer): @@ -60,23 +95,17 @@ class QnnQuantizer(Quantizer): def __init__(self): super().__init__() - self.bit8_quant_config: QuantizationConfig = get_default_8bit_qnn_ptq_config() - self.bit16_quant_config: QuantizationConfig = get_default_16bit_qnn_ptq_config() + self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy() - self.bit8_quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy() - self.bit16_quant_ops: Set[OpOverload] = set() + self.is_qat = False + self.quant_dtype = QuantDtype.use_8a8w + self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config() + self.per_channel_quant_config = get_ptq_per_channel_quant_config() + self.use_per_channel_weight_quant_ops: Set[OpOverload] = set() self.custom_quant_annotations: Sequence[Callable] = [] self.discard_nodes: Set[str] = set() - self.use_per_channel_weight_quant_ops: Set[OpOverload] = set() - # the weight quantized for activation 8 bits and 16 bits - self.per_channel_weight_dtype: Dict = { - "8bit_act": torch.int8, - "16bit_act": torch.int16, - } - self.per_channel_quant_config = None - def _annotate(self, gm: GraphModule) -> None: for node in gm.graph.nodes: if node.name in self.discard_nodes: @@ -94,29 +123,16 @@ def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig """ Priority: 1. is one of use_per_channel_weight_quant_ops - 2. int8 / int16 config + 2. 
quant config """ if isinstance(op, str): return if op in self.use_per_channel_weight_quant_ops: - if self.per_channel_quant_config is None: - if op in self.bit16_quant_ops: - return get_ptq_per_channel_quant_config( - act_dtype=torch.uint16, - weight_dtype=self.per_channel_weight_dtype["16bit_act"], - ) - return get_ptq_per_channel_quant_config( - act_dtype=torch.uint8, - weight_dtype=self.per_channel_weight_dtype["8bit_act"], - ) return self.per_channel_quant_config - if op in self.bit8_quant_ops: - return self.bit8_quant_config - - if op in self.bit16_quant_ops: - return self.bit16_quant_config + if op in self.quant_ops: + return self.quant_config print(f"No quant config is implemented for op, {op}") @@ -126,15 +142,6 @@ def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: boo else: self.use_per_channel_weight_quant_ops.difference_update(ops) - def add_16bit_quant_ops(self, ops: Set[OpOverload]) -> None: - for op in ops: - assert ( - op in self.SUPPORTED_OPS - ), f"The annotation of op {op} is not implemented" - - self.bit8_quant_ops.remove(op) - self.bit16_quant_ops.add(op) - def add_custom_quant_annotations( self, custom_quant_annotations: Sequence[Callable] ) -> None: @@ -145,10 +152,7 @@ def add_discard_nodes(self, nodes: Sequence[str]) -> None: def add_discard_ops(self, ops: Sequence[OpOverload]) -> None: for op in ops: - if op in self.bit8_quant_ops: - self.bit8_quant_ops.remove(op) - if op in self.bit16_quant_ops: - self.bit16_quant_ops.remove(op) + self.quant_ops.remove(op) def annotate(self, model: GraphModule) -> GraphModule: self._annotate(model) @@ -159,24 +163,22 @@ def annotate(self, model: GraphModule) -> GraphModule: def get_supported_ops(self) -> Set[OpOverload]: return self.SUPPORTED_OPS - def set_bit16_op_quant_config( - self, quantization_config: QuantizationConfig - ) -> None: - self.bit16_quant_config = quantization_config - - def set_bit8_op_quant_config(self, quantization_config: QuantizationConfig) -> None: - self.bit8_quant_config = quantization_config - - def set_per_channel_weight_dtype( - self, - weight_dtype_for_8bit_act: Optional[str | torch.dtype] = None, - weight_dtype_for_16bit_act: Optional[str | torch.dtype] = None, + def set_quant_config( + self, quant_dtype: QuantDtype, is_qat=False, act_observer=None ) -> None: - # TODO accept temporally str type. 
Remove it when torch support torch.int4 dtype
-        if weight_dtype_for_8bit_act:
-            self.per_channel_weight_dtype["8bit_act"] = weight_dtype_for_8bit_act
-        if weight_dtype_for_16bit_act:
-            self.per_channel_weight_dtype["16bit_act"] = weight_dtype_for_16bit_act
+        self.quant_dtype = quant_dtype
+        self.is_qat = is_qat
+        if (quant_dtype, is_qat) not in quant_config_dict:
+            raise RuntimeError(
+                f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not supported"
+            )
+
+        quant_config_func, self.per_channel_quant_config = quant_config_dict[
+            (quant_dtype, is_qat)
+        ]
+        self.quant_config = (
+            quant_config_func(act_observer) if act_observer else quant_config_func()
+        )
 
     def set_per_channel_conv_quant(self, enable: bool) -> None:
         conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 4bfdedcd4b4..64b0490d461 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -698,6 +698,17 @@ def test_qnn_backend_16a4w_conv2d(self):
         )
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_16a4w_conv2d_qat(self):
+        modules = [Conv2dSingle(), Conv2dSingle(bias=False)]  # noqa: F405
+        sample_input = (torch.randn([1, 1, 3, 3]),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                prepared = self.get_prepared_qat_module(module, sample_input)
+                converted = self.get_converted_sgd_trained_module(
+                    module, prepared, sample_input
+                )
+                self.lower_module_and_test_output(converted, sample_input)
+
     def test_qnn_backend_16a4w_layer_norm(self):
         module = LayerNorm()  # noqa: F405
         sample_input = (torch.randn(196, 768),)
@@ -1063,18 +1074,8 @@ def test_qnn_backend_linear_qat(self):
         """
         module = Linear()  # noqa: F405
         sample_input = (torch.randn([3, 4]),)
-
-        module = self.get_prepared_qat_module(module, sample_input)
-
-        optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
-        criterion = torch.nn.CrossEntropyLoss()
-        output = module(*sample_input)
-        loss = criterion(output, module(*sample_input))
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-
-        module = torch.ao.quantization.quantize_pt2e.convert_pt2e(module)
+        prepared = self.get_prepared_qat_module(module, sample_input)
+        module = self.get_converted_sgd_trained_module(module, prepared, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_log_softmax(self):
diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py
index 114493c7d2f..d2a3e7c2417 100644
--- a/backends/qualcomm/tests/utils.py
+++ b/backends/qualcomm/tests/utils.py
@@ -17,13 +17,7 @@
 from executorch import exir
 from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
 from executorch.backends.qualcomm.qnn_preprocess import QnnBackend
-from executorch.backends.qualcomm.quantizer.quantizer import (
-    get_16a4w_qnn_ptq_config,
-    get_default_16bit_qnn_ptq_config,
-    get_default_8bit_qat_proto,
-    QnnQuantizer,
-    QuantDtype,
-)
+from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype
 from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import (
     QcomChipset,
 )
@@ -405,18 +399,7 @@ def get_qdq_module(
         quantizer.add_custom_quant_annotations(custom_quant_annotations)
         quantizer.set_per_channel_conv_quant(is_conv_per_channel)
         quantizer.set_per_channel_linear_quant(is_linear_per_channel)
-
-        if quant_dtype == QuantDtype.use_8a8w:
-            pass  # default 
setting - elif quant_dtype == QuantDtype.use_16a16w: - quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) - quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config()) - elif quant_dtype == QuantDtype.use_16a4w: - quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) - quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config()) - quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") - else: - raise AssertionError(f"No support for QuantDtype {quant_dtype}.") + quantizer.set_quant_config(quant_dtype) prepared = prepare_pt2e(m, quantizer) prepared(*inputs) @@ -448,13 +431,28 @@ def get_prepared_qat_module( quantizer.set_per_channel_linear_quant(is_linear_per_channel) if quant_dtype == QuantDtype.use_8a8w: - quantizer.set_bit8_op_quant_config(get_default_8bit_qat_proto()) + quantizer.set_quant_config(quant_dtype, is_qat=True) else: raise RuntimeError("Shuld not be here") prepared = prepare_qat_pt2e(m, quantizer) return torch.ao.quantization.move_exported_model_to_train(prepared) + def get_converted_sgd_trained_module( + self, + ori_module: torch.nn.Module, + prepared: torch.nn.Module, + inputs: Tuple[torch.Tensor], + ) -> torch.fx.GraphModule: + optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001) + criterion = torch.nn.CrossEntropyLoss() + output = prepared(*inputs) + loss = criterion(output, ori_module(*inputs)) + optimizer.zero_grad() + loss.backward() + optimizer.step() + return torch.ao.quantization.quantize_pt2e.convert_pt2e(prepared) + def split_graph(self, graph_module: torch.fx.GraphModule, division: int): class SplitGraph(ExportPass): """ diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 30e04750b58..0eff77fdce8 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -331,7 +331,7 @@ def _transform( def capture_program( module: torch.nn.Module, inputs: Tuple[torch.Tensor], - custom_pass_config: Set[str] = None, + custom_pass_config: Set[str] = frozenset(), ) -> exir.ExirExportedProgram: ep = torch.export.export(module, inputs) decomposed_ep = ep.run_decompositions(get_decomp_table()) diff --git a/examples/qualcomm/oss_scripts/fastvit.py b/examples/qualcomm/oss_scripts/fastvit.py index 30fe74f35b5..0e2c695ab34 100644 --- a/examples/qualcomm/oss_scripts/fastvit.py +++ b/examples/qualcomm/oss_scripts/fastvit.py @@ -10,15 +10,19 @@ import numpy as np import torch - -from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype -from executorch.backends.qualcomm.quantizer.utils import ( - _derived_bias_quant_spec, - MovingAverageMinMaxObserver, - ParamObserver, +from executorch.backends.qualcomm.quantizer.annotators import ( QuantizationConfig, QuantizationSpec, ) +from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import ( + PerChannelParamObserver, +) +from executorch.backends.qualcomm.quantizer.qconfig import ( + _derived_bias_quant_spec, + MovingAverageMinMaxObserver, +) + +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.utils.constants import ( QCOM_PASS_EXPAND_BROADCAST_SHAPE, ) @@ -87,7 +91,7 @@ def main(args): quant_max=torch.iinfo(torch.int8).max, qscheme=torch.per_channel_symmetric, ch_axis=0, - observer_or_fake_quant_ctr=ParamObserver.with_args( + observer_or_fake_quant_ctr=PerChannelParamObserver.with_args( **{"steps": 200, "use_mse": True} ), ) diff --git a/examples/qualcomm/oss_scripts/llama2/llama.py b/examples/qualcomm/oss_scripts/llama2/llama.py index 
04569df5c92..9f7198a3447 100644 --- a/examples/qualcomm/oss_scripts/llama2/llama.py +++ b/examples/qualcomm/oss_scripts/llama2/llama.py @@ -56,12 +56,12 @@ def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None: This function is specific for matmul op 16a8w. """ + from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY from executorch.backends.qualcomm.quantizer.quantizer import ( get_16a8w_qnn_ptq_config, - get_default_8bit_qnn_ptq_config, + get_8a8w_qnn_ptq_config, QuantizationConfig, ) - from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY from torch.ao.quantization.quantizer import ( QuantizationAnnotation, SharedQuantizationSpec, @@ -119,7 +119,7 @@ def annotate_single_in_single_out( ) def annotate_matmul_input1(node: Node): - quantization_config_8a8w = get_default_8bit_qnn_ptq_config(act_symmetric=True) + quantization_config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True) while isinstance(node, Node) and node.op == "call_function": if node.target in [ torch.ops.aten.permute.default, @@ -142,11 +142,11 @@ def annotate_matmul_input1(node: Node): def annotate_linear_16a8w_in_affine_layer(gm: torch.fx.GraphModule) -> None: + from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY from executorch.backends.qualcomm.quantizer.quantizer import ( get_ptq_per_channel_quant_config, QuantizationConfig, ) - from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY from torch.ao.quantization.quantizer import QuantizationAnnotation from torch.fx import Node diff --git a/examples/qualcomm/scripts/export_example.py b/examples/qualcomm/scripts/export_example.py index 2e49a2344b8..56169e39a2e 100644 --- a/examples/qualcomm/scripts/export_example.py +++ b/examples/qualcomm/scripts/export_example.py @@ -4,10 +4,7 @@ import torch from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner -from executorch.backends.qualcomm.quantizer.quantizer import ( - get_default_8bit_qnn_ptq_config, - QnnQuantizer, -) +from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( QcomChipset, ) @@ -64,8 +61,6 @@ def main() -> None: # Get quantizer quantizer = QnnQuantizer() - quant_config = get_default_8bit_qnn_ptq_config() - quantizer.set_bit8_op_quant_config(quant_config) # Typical pytorch 2.0 quantization flow m = torch.export.export(model.eval(), example_inputs).module() diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index 06225be2d1c..100008e91ca 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -16,13 +16,7 @@ import torch from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner -from executorch.backends.qualcomm.quantizer.quantizer import ( - get_16a4w_qnn_ptq_config, - get_default_16bit_qnn_ptq_config, - get_default_8bit_qnn_ptq_config, - QnnQuantizer, - QuantDtype, -) +from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( QcomChipset, ) @@ -37,7 +31,11 @@ from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from torch.ao.quantization.observer import MovingAverageMinMaxObserver -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torch.ao.quantization.quantize_pt2e import ( + convert_pt2e, + 
prepare_pt2e, + prepare_qat_pt2e, +) class SimpleADB: @@ -187,36 +185,58 @@ def pull_debug_output(self, etdump_path, debug_ouput_path, callback=None): callback() +def ptq_calibrate(captured_model, quantizer, dataset): + annotated_model = prepare_pt2e(captured_model, quantizer) + print("Quantizing(PTQ) the model...") + # calibration + if callable(dataset): + dataset(annotated_model) + else: + for data in dataset: + annotated_model(*data) + return annotated_model + + +def qat_train(ori_model, captured_model, quantizer, dataset): + data, targets = dataset + annotated_model = torch.ao.quantization.move_exported_model_to_train( + prepare_qat_pt2e(captured_model, quantizer) + ) + optimizer = torch.optim.SGD(annotated_model.parameters(), lr=0.00001) + criterion = torch.nn.CrossEntropyLoss() + for i, d in enumerate(data): + print(f"Epoch {i}") + if i > 3: + # Freeze quantizer parameters + annotated_model.apply(torch.ao.quantization.disable_observer) + if i > 2: + # Freeze batch norm mean and variance estimates + annotated_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats) + + output = annotated_model(*d) + loss = criterion(output, targets[i]) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + return torch.ao.quantization.quantize_pt2e.convert_pt2e( + torch.ao.quantization.move_exported_model_to_eval(annotated_model) + ) + + def make_quantizer( - quant_dtype: Optional[QuantDtype], + quant_dtype: Optional[QuantDtype] = QuantDtype.use_8a8w, custom_annotations=(), per_channel_conv=True, per_channel_linear=False, act_observer=MovingAverageMinMaxObserver, + is_qat=False, ): quantizer = QnnQuantizer() quantizer.add_custom_quant_annotations(custom_annotations) quantizer.set_per_channel_conv_quant(per_channel_conv) quantizer.set_per_channel_linear_quant(per_channel_linear) - - if quant_dtype == QuantDtype.use_8a8w: - quantizer.set_bit8_op_quant_config( - get_default_8bit_qnn_ptq_config(act_observer=act_observer) - ) - elif quant_dtype == QuantDtype.use_16a16w: - quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) - quantizer.set_bit16_op_quant_config( - get_default_16bit_qnn_ptq_config(act_observer=act_observer) - ) - elif quant_dtype == QuantDtype.use_16a4w: - quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) - quantizer.set_bit16_op_quant_config( - get_16a4w_qnn_ptq_config(act_observer=act_observer) - ) - quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") - else: - raise AssertionError(f"No support for QuantDtype {quant_dtype}.") - + quantizer.set_quant_config(quant_dtype, is_qat, act_observer) return quantizer @@ -235,18 +255,22 @@ def build_executorch_binary( metadata=None, dump_intermediate_outputs=False, custom_pass_config=frozenset(), + qat_training_data=None, ): if quant_dtype is not None: - quantizer = custom_quantizer or make_quantizer(quant_dtype=quant_dtype) captured_model = torch.export.export(model, inputs).module() - annotated_model = prepare_pt2e(captured_model, quantizer) - print("Quantizing the model...") - # calibration - if callable(dataset): - dataset(annotated_model) + if qat_training_data: + quantizer = custom_quantizer or make_quantizer( + quant_dtype=quant_dtype, is_qat=True + ) + # qat training + annotated_model = qat_train( + model, captured_model, quantizer, qat_training_data + ) else: - for data in dataset: - annotated_model(*data) + quantizer = custom_quantizer or make_quantizer(quant_dtype=quant_dtype) + # ptq calibration + annotated_model = ptq_calibrate(captured_model, quantizer, dataset) quantized_model = 
convert_pt2e(annotated_model) edge_prog = capture_program(quantized_model, inputs, custom_pass_config) diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index fd368d73f1f..ba281864a9f 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -144,6 +144,7 @@ def check_embedding_byte_registered(): def get_qnn_quantizer( pt2e_quantize: str, quantization_mode: Optional[str] = None, + is_qat: bool = False, ): try: from executorch.backends.qualcomm.quantizer.custom_annotation import ( # pyre-fixme[21] @@ -152,8 +153,6 @@ def get_qnn_quantizer( # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.quantizer.quantizer` from executorch.backends.qualcomm.quantizer.quantizer import ( - get_16a4w_qnn_ptq_config, - get_default_16bit_qnn_ptq_config, QnnQuantizer, QuantDtype, ) @@ -175,6 +174,7 @@ def get_qnn_quantizer( custom_annotations = () if quant_config == "8a8w": quant_dtype = QuantDtype.use_8a8w # pyre-fixme[16] + qnn_quantizer.set_quant_config(quant_dtype, is_qat=is_qat) elif quant_config == "16a16w": quant_dtype = QuantDtype.use_16a16w # pyre-fixme[16] # Due to the error with 16a16w in Qnn Htp, we need to disable per channel linear quantization when use 16a16w @@ -184,20 +184,17 @@ def get_qnn_quantizer( ) qnn_quantizer.set_per_channel_conv_quant(enable=False) qnn_quantizer.set_per_channel_linear_quant(enable=False) - qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS) - qnn_quantizer.set_bit16_op_quant_config( - # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. - get_default_16bit_qnn_ptq_config(act_observer=MinMaxObserver) + # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. + qnn_quantizer.set_quant_config( + quant_dtype, is_qat=is_qat, act_observer=MinMaxObserver ) elif quant_config == "16a4w": # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. quant_dtype = QuantDtype.use_16a4w - qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS) - qnn_quantizer.set_bit16_op_quant_config( - # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. - get_16a4w_qnn_ptq_config(act_observer=MinMaxObserver) + # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. + qnn_quantizer.set_quant_config( + quant_dtype, is_qat=is_qat, act_observer=MinMaxObserver ) - qnn_quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. custom_annotations = (custom_annotate_llama_matmul_16a8w,) else:
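
Example (not part of the patch): a minimal PTQ sketch of the unified set_quant_config() API introduced above, replacing the old set_bit8_op_quant_config() / set_bit16_op_quant_config() / add_16bit_quant_ops() calls. The toy model, inputs, and chosen precision are placeholders; the quantizer and capture calls mirror make_quantizer() and get_qdq_module() from this patch.

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype

# Placeholder model and calibration input.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

quantizer = QnnQuantizer()
# One call now selects both the per-tensor and the per-channel configs for the precision.
quantizer.set_quant_config(QuantDtype.use_16a4w, is_qat=False)
quantizer.set_per_channel_conv_quant(True)

captured = torch.export.export(model, example_inputs).module()
prepared = prepare_pt2e(captured, quantizer)
prepared(*example_inputs)  # calibration
quantized = convert_pt2e(prepared)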
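The QAT path added here has the same capture/prepare/convert shape; a rough sketch following qat_train() and get_converted_sgd_trained_module() from this patch, with the module, loss, learning rate, and step count as illustrative placeholders.

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e

from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype

model = torch.nn.Conv2d(1, 4, 3)  # placeholder module
example_inputs = (torch.randn(1, 1, 8, 8),)

quantizer = QnnQuantizer()
quantizer.set_quant_config(QuantDtype.use_8a8w, is_qat=True)

captured = torch.export.export(model, example_inputs).module()
prepared = torch.ao.quantization.move_exported_model_to_train(
    prepare_qat_pt2e(captured, quantizer)
)

optimizer = torch.optim.SGD(prepared.parameters(), lr=1e-4)
for _ in range(3):  # placeholder number of training steps
    out = prepared(*example_inputs)
    loss = torch.nn.functional.mse_loss(out, torch.zeros_like(out))  # placeholder loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

quantized = convert_pt2e(torch.ao.quantization.move_exported_model_to_eval(prepared))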