diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index c592ad64da6..e34630538d0 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -3,7 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from enum import Enum, unique + from typing import Sequence import torch @@ -17,7 +17,6 @@ get_8a8w_qnn_ptq_config, get_8a8w_qnn_qat_config, get_ptq_per_channel_quant_config, - get_qat_per_channel_quant_config, QuantizationConfig, ) from executorch.exir.dialects._ops import ops as exir_ops @@ -32,36 +31,6 @@ ) -def annotate_down_proj( - gm: torch.fx.GraphModule, quantization_config: QuantizationConfig -): - for node in gm.graph.nodes: - if ( - node.target == torch.ops.aten.conv2d.default - and any(s in node.meta["stack_trace"] for s in ["forward_feedfoward_conv"]) - and node.args[0].target == torch.ops.aten.mul.Tensor - ): - input_qspec_map = {} - input_qspec_map[node.args[0]] = quantization_config.input_activation - input_qspec_map[node.args[1]] = quantization_config.weight - node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=quantization_config.output_activation, - _annotated=True, - ) - - -@unique -class StaticLLMQuantConfig(Enum): - """ - Layer namespace configuration for Qualcomm's static LLaMA quantization. - """ - - wq_sha = "wq_sha" # Query weight (single head) - wk_sha = "wk_sha" # Key weight (single head) - wv_sha = "wv_sha" # Value weight (single head) - - def annotate_eurobert(gm: torch.fx.GraphModule): """ QNN does not support int32 -> signed 16bit quant @@ -123,49 +92,6 @@ def annotate_mimi_decoder(gm: torch.fx.GraphModule): break -def annotate_output_16a8w(gm: torch.fx.GraphModule, is_qat: bool = False) -> None: - """ - This function is for static LLM models. - This function will annotate the last conv(linear), which is the lm_head, as 16a8w. 
- """ - - def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None: - input_qspec_map = {} - input_act = node.args[0] - input_spec = quantization_config.input_activation - input_qspec_map[input_act] = input_spec - - weight = node.args[1] - input_qspec_map[weight] = quantization_config.weight - - if len(node.args) > 2 and isinstance(node.args[2], Node): - input_qspec_map[node.args[2]] = quantization_config.bias(node) - - node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=quantization_config.output_activation, - _annotated=True, - ) - - if is_qat: - quantization_config_16a8w_per_channel = get_qat_per_channel_quant_config( - torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver - ) - else: - quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config( - torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver - ) - for node in gm.graph.nodes: - if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default: - if "nn_module_stack" in node.meta: - module_values_list = list(node.meta["nn_module_stack"].values()) - full_qualified_name = module_values_list[-1][0] - if full_qualified_name == "output.conv": - annotate_conv2d( - node, quantization_config=quantization_config_16a8w_per_channel - ) - - def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict): for node in gm.graph.nodes: if node.op == "output": @@ -200,48 +126,6 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict): ) -def annotate_qkv_proj_sha( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - qkv_tags: set[StaticLLMQuantConfig], -): - """ - Annotates QKV projection layers in a GraphModule for quantization, - specifically layers defined in StaticLLMQuantConfig. - - Args: - qkv_tags (set[StaticLLMQuantConfig]): A set of enum tags indicating which QKV layers - (e.g., wq, wk, wv) should be annotated for quantization. Only tags defined in - StaticLLMQuantConfig are allowed. - - Raises: - ValueError: If any tag in `qkv_tags` is not among the allowed enum members. - """ - - # Get all valid tags from the StaticLLMQuantConfig enum - allowed_tags = set(StaticLLMQuantConfig) - invalid_tags = qkv_tags - allowed_tags - if invalid_tags: - raise ValueError( - f"Invalid qkv tags: {invalid_tags}. 
Allowed tags are: {allowed_tags}" - ) - - for node in gm.graph.nodes: - if node.target == torch.ops.aten.conv2d.default and any( - tag.value in node.meta["stack_trace"] for tag in qkv_tags - ): - input_qspec_map = {} - input_qspec_map[node.args[0]] = quantization_config.input_activation - input_qspec_map[node.args[1]] = quantization_config.weight - if len(node.args) > 2 and isinstance(node.args[2], Node): - input_qspec_map[node.args[2]] = quantization_config.bias(node) - node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=quantization_config.output_activation, - _annotated=True, - ) - - def annotate_kv_8bit( # noqa: C901 gm: torch.fx.GraphModule, is_qat=False, @@ -262,7 +146,6 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig): input_act = node.args[0] input_spec = quantization_config.input_activation input_qspec_map[input_act] = input_spec - input_act1 = node.args[1] input_spec1 = quantization_config.weight input_qspec_map[input_act1] = input_spec1 diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py index e22d5b30fa7..593eb77961a 100644 --- a/backends/qualcomm/quantizer/qconfig.py +++ b/backends/qualcomm/quantizer/qconfig.py @@ -136,6 +136,61 @@ def get_8a8w_qnn_ptq_config( return quantization_config +def get_8a4w_qnn_ptq_config( + act_symmetric: bool = True, act_observer=MovingAverageMinMaxObserver +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-12} + + if act_symmetric: + # If zero_point is 128, htp can do optimizations. + # If we keep quant_min and quant_max none, observer will default use 128 as zero_point. + # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired. + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + else: + # PyTorch will remove redundant observers based on attributes such as: + # dtype, quant_min, quant_max, ch_axis, etc. + # Providing values like quant_min and quant_max can help observers compare + # and further reduce the number of observers. + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + quant_min=torch.iinfo(torch.uint8).min, + quant_max=torch.iinfo(torch.uint8).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-7, + quant_max=7, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + # 4 bits quantization only supports specific ops. 
def get_16a4w_qnn_ptq_config( act_observer=MovingAverageMinMaxObserver, diff --git a/backends/qualcomm/quantizer/quant_recipe.py b/backends/qualcomm/quantizer/quant_recipe.py new file mode 100644 index 00000000000..92b9757e1fb --- /dev/null +++ b/backends/qualcomm/quantizer/quant_recipe.py @@ -0,0 +1,402 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import re +from abc import ABC, abstractmethod +from enum import IntEnum, unique +from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple + +import torch +from executorch.backends.qualcomm.quantizer.quantizer import ( + ModuleQConfig, + QnnQuantizer, + QuantDtype, + QuantizationConfig, +) +from tabulate import tabulate +from torch._ops import OpOverload +from torchao.quantization.pt2e import UniformQuantizationObserverBase + +from .annotators import OP_ANNOTATOR + + +def extract_node_metadata_mapping(node: torch.fx.Node): + deepest_module = None + + if node.op == "call_function" and "nn_module_stack" in node.meta: + deepest_module = list(node.meta["nn_module_stack"].values())[-1][0] + + return deepest_module + + +@unique +class QuantGranularity(IntEnum): + """ + Defines the quantization granularity levels: + - PER_TENSOR: single scale offset for entire tensor. + - PER_CHANNEL: independent scale/offset per channel within tensor. + - PER_BLOCK: independent scale/offset per block within tensor. + """ + + PER_TENSOR = 0 + PER_CHANNEL = 1 + PER_BLOCK = 2 + + +class QuantizationStrategy(ABC): + """ + Abstract base class for strategies that assign quantization config to FX graph nodes. + + Each strategy defines how to match nodes (e.g., by operator target, module stack pattern) + and provides a corresponding quantization config when a match occurs. + + Attributes: + quant_dtype (QuantDtype): Data type for quantization (e.g., 16a8w, 16a4w). + is_qat (bool): Whether the strategy applies QAT (True) or PTQ (False). + granularity (QuantGranularity): Quantization granularity (PER_TENSOR, PER_CHANNEL, PER_BLOCK). + act_observer (UniformQuantizationObserverBase): Observer class for activation quantization. + extra_kwargs (Dict): Additional configuration parameters (e.g., block size). + note (str): Developer notes or comments. + priority (int): Priority for resolving conflicts among multiple strategies. + + Abstract Methods: + _matches(node): Return True if the node matches this strategy's criteria. 
+ """ + + def __init__( + self, + quant_dtype: QuantDtype, + is_qat: bool, + granularity: QuantGranularity, + act_observer: UniformQuantizationObserverBase, + extra_kwargs: Dict, + note: str, + priority: int, + ): + self.quant_dtype = quant_dtype + self.is_qat = is_qat + self.granularity = granularity + self.act_observer = act_observer + self.extra_kwargs = extra_kwargs + self.note = note + self.priority = priority + + self.quant_config = ModuleQConfig( + quant_dtype=self.quant_dtype, + is_qat=self.is_qat, + is_conv_per_channel=True, + is_linear_per_channel=True, + act_observer=self.act_observer, + ) + + @abstractmethod + def _matches(self, node: torch.fx.Node) -> bool: + pass + + def get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]: + op: OpOverload = node.target + + if not self._matches(node): + return None + + if self.granularity == QuantGranularity.PER_TENSOR: + return self.quant_config.quant_config + elif self.granularity == QuantGranularity.PER_CHANNEL: + ch_axis = self.quant_config.use_per_channel_weight_quant_ops.get(op) + assert ( + ch_axis is not None + and len(self.quant_config.per_channel_quant_config_list) > ch_axis + ), f"Unsupported per channel quantization axis: {ch_axis}, please increase the range of per_channel_quant_config_list" + return self.quant_config.per_channel_quant_config_list[ch_axis] + elif self.granularity == QuantGranularity.PER_BLOCK: + ch_axis = self.quant_config.op_axis_dict.get(op) + assert ( + ch_axis is not None + and len(self.quant_config.per_block_quant_config_list) > ch_axis + ), f"Unsupported per block quantization axis: {ch_axis}, please increase the range of per_block_quant_config_list" + config = self.quant_config.per_block_quant_config_list[ch_axis] + config.block_size = self.extra_kwargs["block_size"] + return config + else: + raise ValueError( + f"Unsupported quantization granularity: {self.granularity}. " + f"Supported values: {[granularity.name for granularity in QuantGranularity]}" + ) + + +class ByNodeTarget(QuantizationStrategy): + """ + Strategy that assigns quantization config to nodes based on their op target. + Useful for applying quantization to specific operations such as `aten.conv2d` or `aten.linear`. + + Attributes: + targets (Set[OpOverload]): Set of op overloads to match against node targets. + """ + + def __init__( + self, + quant_dtype, + is_qat, + granularity, + act_observer, + extra_kwargs, + note, + priority, + targets: Set[OpOverload], + ): + super().__init__( + quant_dtype, + is_qat, + granularity, + act_observer, + extra_kwargs, + note, + priority, + ) + self.targets = targets + + def _matches(self, node: torch.fx.Node) -> bool: + # Matching: A node matches if its `node.target` is in the `targets` set. + return node.target in self.targets + + +class ByNameRegex(QuantizationStrategy): + """ + Strategy that assigns quantization config to nodes whose module stack matches given regex patterns. + Useful for targeting layers by name patterns (e.g., "layers.[0-3].feed_forward" or "layers.*.attention") in the module hierarchy. + + Attributes: + patterns (Set[str]): Set of regex patterns to match against module stack paths. 
+ """ + + def __init__( + self, + quant_dtype, + is_qat, + granularity, + act_observer, + extra_kwargs, + note, + priority, + patterns: Set[str], + ): + super().__init__( + quant_dtype, + is_qat, + granularity, + act_observer, + extra_kwargs, + note, + priority, + ) + self.patterns = patterns + + def _matches(self, node: torch.fx.Node) -> bool: + # Matching: A node matches if its `nn_module_stack` metadata contains a module path that matches any regex pattern. + if node.op == "call_function" and "nn_module_stack" in node.meta: + for module_stack, _ in list(node.meta["nn_module_stack"].values())[::-1]: + if module_stack and any( + re.search(p, module_stack) for p in self.patterns + ): + return True + return False + + +class QuantRecipe: + """ + A QuantRecipe builder for defining quantization strategies to an FX GraphModule. + + QuantRecipe manages a collection of strategies (e.g., by operator target or regex pattern) + and applies them to nodes in an FX graph to produce fine-grained quantization annotations. + + Attributes: + verbose (bool): If True, prints a summary after annotation. + custom_quant_annotations (Sequence[Callable]): Custom annotation functions applied after strategies. + + _strategies (List[QuantizationStrategy]): Registered quantization strategies. + _pending_annotate_nodes (Dict[torch.fx.Node, Tuple[QuantizationConfig, QuantizationStrategy]]): + Internal mapping of nodes to their resolved quantization config and strategy. + """ + + def __init__( + self, + quant_dtype, + is_qat, + act_observer: UniformQuantizationObserverBase, + granularity: QuantGranularity, + note: str = "", + extra_kwargs: Optional[dict] = None, + verbose: bool = False, + ): + """ + Initialize a QuantRecipe with a default quantization strategy. + + Args: + quant_dtype (QuantDtype): Data type for quantization (e.g., int8, int4). + is_qat (bool): Whether to apply QAT (True) or PTQ (False). + act_observer (UniformQuantizationObserverBase): Observer class for activation quantization. + granularity (QuantGranularity): Quantization granularity (PER_TENSOR, PER_CHANNEL, PER_BLOCK). + note (str): Optional description for the default strategy. + extra_kwargs (dict, optional): Additional parameters (e.g., block size, group size). + verbose (bool): If True, prints a summary table after annotation. + """ + + self.verbose = verbose + self.custom_quant_annotations: Sequence[Callable] = [] + + self._strategies: List[QuantizationStrategy] = [] + self._pending_annotate_nodes: Dict[ + torch.fx.Node, Tuple[QuantizationConfig, QuantizationStrategy] + ] = {} + self._default_strategy = ByNodeTarget( + quant_dtype, + is_qat, + granularity, + act_observer, + extra_kwargs or {}, + note, + priority=1, + targets=QnnQuantizer.SUPPORTED_OPS, + ) + + def _annotate_custom_annotation(self, gm: torch.fx.GraphModule) -> None: + for annotation_func in self.custom_quant_annotations: + annotation_func(gm) + + def annotate(self, graph_module: torch.fx.GraphModule): + # Sort node level strategies by (priority, insertion index). + # Higher priority value comes first; if priorities are equal, original insertion order is preserved. 
+ strategies: List[QuantizationStrategy] = [ + strategy + for _, strategy in sorted( + enumerate(self._strategies), + key=lambda x: (x[1].priority, x[0]), + reverse=True, + ) + ] + # Ensure the default strategy is appended last + strategies.append(self._default_strategy) + + for node in graph_module.graph.nodes: + for strategy in strategies: + if isinstance(node.target, str) or node in self._pending_annotate_nodes: + continue + + if quant_config := strategy.get_quant_config(node): + self._pending_annotate_nodes[node] = (quant_config, strategy) + + if self.verbose: + print(self.summary()) + + for node in graph_module.graph.nodes: + if isinstance(node.target, str): + continue + if node not in self._pending_annotate_nodes: + print(f"No quant config is implemented for op, {node.target}") + continue + + OP_ANNOTATOR[node.target](node, self._pending_annotate_nodes[node][0]) + + # custom annotation + self._annotate_custom_annotation(graph_module) + + def add_node_target( + self, + targets, + quant_dtype, + is_qat, + act_observer: UniformQuantizationObserverBase, + granularity: QuantGranularity, + note: str = "", + priority: int = 1, + extra_kwargs: Optional[dict] = None, + ): + self._strategies.append( + ByNodeTarget( + quant_dtype, + is_qat, + granularity, + act_observer, + extra_kwargs or {}, + note, + priority, + targets, + ), + ) + return self + + def add_regex( + self, + regex, + quant_dtype, + is_qat, + act_observer: UniformQuantizationObserverBase, + granularity: QuantGranularity, + note: str = "", + priority: int = 1, + extra_kwargs: Optional[dict] = None, + ): + """ + Add a quantization strategy targeting nodes whose module stack matches given regex patterns. + + Args: + regex (Iterable[str]): Regex patterns to match module stack paths. + quant_dtype (QuantDtype): Data type for quantization. + is_qat (bool): Whether to apply QAT or PTQ. + act_observer (UniformQuantizationObserverBase): Observer for activation quantization. + granularity (QuantGranularity): Tensor/channel/block granularity. + note (str): Optional description for the strategy. + priority (int): Strategy priority (higher value = higher precedence). + extra_kwargs (dict, optional): Additional parameters for the strategy. 
+ """ + self._strategies.append( + ByNameRegex( + quant_dtype, + is_qat, + granularity, + act_observer, + extra_kwargs or {}, + note, + priority, + regex, + ), + ) + return self + + def summary(self, max_rows: int = -1): + if not self._pending_annotate_nodes: + return None + + headers = [ + "module_stack", + "op_target", + "quantize", + "act_observer", + "granularity", + "note", + "extra_kwargs", + ] + rows = [] + for i, (node, (_, strategy)) in enumerate(self._pending_annotate_nodes.items()): + if max_rows > 0 and i >= max_rows: + break + + row = [ + extract_node_metadata_mapping(node), + node.target, + f"{strategy.quant_dtype.name}/{'QAT' if strategy.is_qat else 'PTQ'}", + strategy.act_observer.__name__, + strategy.granularity.name, + strategy.note, + strategy.extra_kwargs, + ] + rows.append(row) + + if max_rows > 0 and len(self._pending_annotate_nodes) > max_rows: + rows.append(["..."] * len(headers)) + + return tabulate(rows, headers=headers, tablefmt="grid") diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 4d0f1098a62..9ca9a7dad6c 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -24,6 +24,7 @@ get_16a4w_qnn_qat_config, get_16a8w_qnn_ptq_config, get_16a8w_qnn_qat_config, + get_8a4w_qnn_ptq_config, get_8a8w_qnn_ptq_config, get_8a8w_qnn_qat_config, get_ptq_per_block_quant_config, @@ -44,6 +45,7 @@ "get_16a16w_qnn_ptq_config", "get_8a8w_qnn_ptq_config", "get_8a8w_qnn_qat_config", + "get_8a4w_qnn_ptq_config", "get_16a4w_qnn_qat_config", "get_ptq_per_block_quant_config", ] @@ -60,6 +62,7 @@ class QuantDtype(IntEnum): use_16a4w = 2 use_16a4w_block = 3 use_8a8w = 4 + use_8a4w = 5 QUANT_CONFIG_DICT = { @@ -109,6 +112,15 @@ class QuantDtype(IntEnum): partial(get_ptq_per_channel_quant_config), None, ), + (QuantDtype.use_8a4w, False): ( + get_8a4w_qnn_ptq_config, + partial( + get_ptq_per_channel_quant_config, + act_dtype=torch.uint8, + weight_dtype=torch.int4, + ), + None, + ), # QAT, (QuantDtype.use_16a4w, True): ( get_16a4w_qnn_qat_config, @@ -242,10 +254,12 @@ def __init__(self): self.submodule_qconfig_list: List[ Tuple[Callable[[torch.fx.Node], bool], ModuleQConfig] ] = [] + self.block_size_map = {} self.custom_quant_annotations: Sequence[Callable] = [] self.discard_nodes: Set[str] = set() + self.recipe = None def _annotate(self, gm: GraphModule) -> None: """ @@ -348,14 +362,20 @@ def annotate(self, model: GraphModule) -> GraphModule: """ Annotates GraphModule during prepare_pt2e. + If a recipe is provided, it will be used to annotate the model. + Otherwise, fallback to the default annotation flow. + Args: model (GraphModule): The FX GraphModule to annotate. Returns: GraphModule: The annotated model. 
""" - self._annotate(model) - self._annotate_custom_annotation(model) + if self.recipe: + self.recipe.annotate(model) + else: + self._annotate(model) + self._annotate_custom_annotation(model) return model @@ -389,10 +409,10 @@ def set_default_quant_config( """ self.default_quant_config = ModuleQConfig( quant_dtype, - is_qat, - is_conv_per_channel, - is_linear_per_channel, - act_observer, + is_qat=is_qat, + is_conv_per_channel=is_conv_per_channel, + is_linear_per_channel=is_linear_per_channel, + act_observer=act_observer, ) def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None: diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 11a2f57a64f..b28465f1827 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -58,6 +58,7 @@ EdgeProgramManager, to_edge_transform_and_lower, ) +from tabulate import tabulate from torch._decomp import core_aten_decompositions, remove_decompositions from torch.export.exported_program import ExportedProgram from torch.fx import passes @@ -1124,6 +1125,35 @@ def get_soc_to_chipset_map(): } +def show_nn_module_stack_for_quant_recipe(gm: torch.fx.GraphModule, supported_ops): + """ + Print a quick preview of op targets and module stack. + + Use this to inspect the FX graph and identify module stack, which helps you craft regex or op-target for quantization recipe. + + """ + + module_metadata = {} + for node in gm.graph.nodes: + target = node.target + deepest_module = None + if node.op == "call_function" and "nn_module_stack" in node.meta: + deepest_module = list(node.meta["nn_module_stack"].values())[-1][0] + if node.target in supported_ops: + module_metadata.setdefault((target, deepest_module), []).append(node) + + table_rows = [] + for (target, module_stack), nodes in module_metadata.items(): + node_names = ", ".join([node.name for node in nodes]) + table_rows.append([str(target), module_stack, node_names]) + + print( + tabulate( + table_rows, headers=["Op Target", "Module Stack", "Nodes"], tablefmt="grid" + ) + ) + + def tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable): """ Tag io nodes which get/output quantized tensor. No need to insert q/dq in qnn_preprocess diff --git a/docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md b/docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md index ae1b4f15c99..1168c4c04a3 100644 --- a/docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md +++ b/docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md @@ -26,7 +26,7 @@ Deploying large language models like Llama 3 on-device presents the following ch To address these, we apply the following optimizations: -1. Quantization: Use `QuantDtype.use_16a4w_block` for post-training quantization to reduce model size and memory usage. +1. Quantization: Apply the `quant_recipe` when setting the quantization config to reduce model size and memory usage. 2. Mixed Precision Quantization: compresses KV cache tensors to 8-bit and applies `QuantDtype.use_16a8w` to the LM head. @@ -48,9 +48,6 @@ class Llama3_2_3B_Instruct(LLMModelConfig): instruct_model = False num_sharding = 4 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 32 # Group size used in block quantization for weight quantization. 
Will only be used when ptq = 16a4w_block masked_softmax = False # SeqMSE Quantization: optimizes the parameter encodings of each layer of a model individually to minimize the difference between the layer’s original and quantized outputs. (Implementation details: ./backends/qualcomm/_passes/seq_mse.py) In this configuration, we set `seq_mse_candidates` = 0, which means SeqMSE quantization is not applied. @@ -58,10 +55,8 @@ class Llama3_2_3B_Instruct(LLMModelConfig): r1 = False r2 = False r3 = False - custom_annotation = ( - annotate_kv_8bit, - annotate_output_16a8w, - ) + # quant recipe + quant_recipe = Llama3_3BQuantRecipe ``` diff --git a/examples/qualcomm/oss_scripts/llama/__init__.py b/examples/qualcomm/oss_scripts/llama/__init__.py index 2ca3b2bf931..fb6a5a3a3b0 100644 --- a/examples/qualcomm/oss_scripts/llama/__init__.py +++ b/examples/qualcomm/oss_scripts/llama/__init__.py @@ -9,24 +9,11 @@ from dataclasses import dataclass from enum import Enum from functools import partial -from typing import Callable, Dict, Tuple, Type - -import torch -from executorch.backends.qualcomm.quantizer.custom_annotation import ( - annotate_down_proj, - annotate_kv_8bit, - annotate_output_16a8w, - annotate_qkv_proj_sha, - StaticLLMQuantConfig, -) -from executorch.backends.qualcomm.quantizer.qconfig import ( - get_ptq_per_channel_quant_config, -) -from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from typing import Callable, Dict, Type + from executorch.examples.models.codegen import ( convert_weights as convert_codegen_weights, ) - from executorch.examples.models.gemma import convert_weights as convert_gemma_weights from executorch.examples.models.gemma3 import convert_weights as convert_gemma3_weights from executorch.examples.models.granite import ( @@ -52,8 +39,26 @@ from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import ( MultiScopeAwareLlamaModel, ) + +from executorch.examples.qualcomm.oss_scripts.llama.static_llm_quant_recipe import ( + CodegenQuantRecipe, + Gemma3QuantRecipe, + Gemma_2BQuantRecipe, + Granite_3_3_2B_InstructQuantRecipe, + Llama3_1BQuantRecipe, + Llama3_3BQuantRecipe, + LlamaStories110MQuantRecipe, + LlamaStories260KQuantRecipe, + Phi4MiniQuantRecipe, + Qwen2_5_0_5BQuantRecipe, + Qwen2_5_1_5BQuantRecipe, + Qwen3_0_6BQuantRecipe, + Qwen3_1_7BQuantRecipe, + Smollm2QuantRecipe, + Smollm3QuantRecipe, + StaticLLMQuantRecipe, +) from tabulate import tabulate -from torchao.quantization.pt2e import MinMaxObserver BASE_DIR = os.path.dirname(__file__) @@ -62,15 +67,6 @@ LLM_VARIANT_ARCHS = { "gemma3-1b": MultiScopeAwareLlamaModel, } -annotate_wqkv_sha = partial( - annotate_qkv_proj_sha, - qkv_tags={ - StaticLLMQuantConfig.wq_sha, - StaticLLMQuantConfig.wk_sha, - StaticLLMQuantConfig.wv_sha, - }, -) -annotate_wv_sha = partial(annotate_qkv_proj_sha, qkv_tags={StaticLLMQuantConfig.wv_sha}) @dataclass(init=False, frozen=True) @@ -86,8 +82,6 @@ class LLMModelConfig(ABC): transform_weight: Set to true to change Hugging Face weight to improve the performance of RoPE in HTP backend. instruct_model: True if the model uses chat templates. Check Hugging Face model card to ensure the model uses chat templates. num_sharding: Specify the number of splits by inserting the fallback custom op. The graph will be split evenly by layers. - ptq: Set to true to perform PTQ quantization. Support 16a16w, 16a8w, 16a4w, 16a4w_block, 8a8w. - group_size: Group size used in block quantization for weight quantization. 
Will only be used when ptq = 16a4w_block masked_softmax: The MaskedSoftmax feature is designed to optimize the LLMs accuracy and performance executed on HTP backend. MaskedSoftmax is used to replace the Softmax(Add(In, Mask)) structure in attention block in LLMs during backend optimization. For more details, please refer to QNN documents. Note that it is only supported starting from QNN 2.35. @@ -96,7 +90,7 @@ class LLMModelConfig(ABC): r1: Enable SpinQuant R1 quantization optimization. r2: Enable SpinQuant R2 quantization optimization. r3: Enable SpinQuant R3 quantization optimization. - custom_annotation: Custom annotation to use when setting quant configs for the model. + quant_recipe: Quantization recipe to use when setting quant configs for the model. """ repo_id: str @@ -107,14 +101,12 @@ class LLMModelConfig(ABC): transform_weight: bool instruct_model: bool num_sharding: int - ptq: QuantDtype - group_size: int masked_softmax: bool seq_mse_candidates: int r1: bool r2: bool r3: bool - custom_annotation: Tuple + quant_recipe: StaticLLMQuantRecipe def __str__(self): # noqa: C901 """ @@ -160,22 +152,6 @@ def format_value(v): table = [(k, v) for k, v in attrs.items()] return tabulate(table, headers=["Config", "Value"], tablefmt="grid") - def get_kv_io_bit_width(self) -> int: - if self.ptq is None: - return 32 - elif ( - self.ptq == QuantDtype.use_8a8w - or annotate_kv_8bit in self.custom_annotation - ): - return 8 - else: - # If quantized but not 8a8w or mix_quantization, it has to be 16bit kv io. - return 16 - - def get_logits_output_bit_width(self) -> int: - # We use 16bit logits for all quant config - return 32 if self.ptq is None else 16 - SUPPORTED_LLM_MODELS: Dict[str, LLMModelConfig] = {} @@ -197,27 +173,13 @@ class LlamaStories260K(LLMModelConfig): convert_weights = None transform_weight = True instruct_model = False - num_sharding = 1 - # quant config - ptq = QuantDtype.use_16a4w - group_size = None masked_softmax = False seq_mse_candidates = 0 r1 = False r2 = False r3 = False - quantization_config_wv_sha_8a4w = get_ptq_per_channel_quant_config( - act_dtype=torch.uint8, - weight_dtype=torch.int4, - act_observer=MinMaxObserver, - act_symmetric=True, - ) - custom_annotation = ( - annotate_kv_8bit, - annotate_output_16a8w, - partial(annotate_wv_sha, quantization_config=quantization_config_wv_sha_8a4w), - ) + quant_recipe = LlamaStories260KQuantRecipe @register_llm_model("stories110m") @@ -228,27 +190,13 @@ class LlamaStories110M(LLMModelConfig): convert_weights = None transform_weight = True instruct_model = False - num_sharding = 1 - # quant config - ptq = QuantDtype.use_16a4w - group_size = None masked_softmax = False seq_mse_candidates = 0 r1 = False r2 = False r3 = False - quantization_config_wv_sha_8a4w = get_ptq_per_channel_quant_config( - act_dtype=torch.uint8, - weight_dtype=torch.int4, - act_observer=MinMaxObserver, - act_symmetric=True, - ) - custom_annotation = ( - annotate_kv_8bit, - annotate_output_16a8w, - partial(annotate_wv_sha, quantization_config=quantization_config_wv_sha_8a4w), - ) + quant_recipe = LlamaStories110MQuantRecipe @register_llm_model("llama3_2-1b_instruct") @@ -260,26 +208,13 @@ class Llama3_2_1B_Instruct(LLMModelConfig): transform_weight = True # The Llama3_2 enabled should be instruct, however, Llama's tokenizer does not provide utility to apply chat template. 
instruct_model = False - num_sharding = 1 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 32 masked_softmax = False seq_mse_candidates = 1000 r1 = False r2 = False r3 = False - quantization_config_down_proj_16a8w = get_ptq_per_channel_quant_config( - torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver - ) - custom_annotation = ( - annotate_kv_8bit, - annotate_output_16a8w, - partial( - annotate_down_proj, quantization_config=quantization_config_down_proj_16a8w - ), - ) + quant_recipe = Llama3_1BQuantRecipe @register_llm_model("llama3_2-3b_instruct") @@ -291,72 +226,52 @@ class Llama3_2_3B_Instruct(LLMModelConfig): transform_weight = True # The Llama3_2 enabled should be instruct, however, Llama's tokenizer does not provide utility to apply chat template. instruct_model = False - num_sharding = 4 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 32 masked_softmax = False seq_mse_candidates = 0 r1 = False r2 = False r3 = False - custom_annotation = ( - annotate_kv_8bit, - annotate_output_16a8w, - ) + quant_recipe = Llama3_3BQuantRecipe -@register_llm_model("gemma-2b") +@register_llm_model("codegen2_1b") @dataclass(init=False, frozen=True) -class Gemma_2B(LLMModelConfig): - repo_id: str = "google/gemma-2b-it" +class Codegen(LLMModelConfig): + repo_id: str = "Salesforce/codegen2-1B_P" params_path: str = os.path.join( - BASE_DIR, "../../../models/gemma/config/2b_config.json" + BASE_DIR, "../../../models/codegen/config/config.json" ) - convert_weights = convert_gemma_weights - transform_weight = False - instruct_model = True - - num_sharding = 4 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 64 + convert_weights = convert_codegen_weights + transform_weight = True + instruct_model = False + num_sharding = 1 masked_softmax = True seq_mse_candidates = 0 r1 = False r2 = False r3 = False - quantization_config_wv_sha_16a8w = get_ptq_per_channel_quant_config( - torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver - ) - custom_annotation = ( - annotate_kv_8bit, - annotate_output_16a8w, - partial(annotate_wv_sha, quantization_config=quantization_config_wv_sha_16a8w), - ) + quant_recipe = CodegenQuantRecipe -@register_llm_model("codegen2_1b") +@register_llm_model("gemma-2b") @dataclass(init=False, frozen=True) -class Codegen(LLMModelConfig): - repo_id: str = "Salesforce/codegen2-1B_P" +class Gemma_2B(LLMModelConfig): + repo_id: str = "google/gemma-2b-it" params_path: str = os.path.join( - BASE_DIR, "../../../models/codegen/config/config.json" + BASE_DIR, "../../../models/gemma/config/2b_config.json" ) - convert_weights = convert_codegen_weights - transform_weight = True - instruct_model = False - num_sharding = 1 - # quant config - ptq = QuantDtype.use_16a8w - group_size = None + convert_weights = convert_gemma_weights + transform_weight = False + instruct_model = True + + num_sharding = 4 masked_softmax = True seq_mse_candidates = 0 r1 = False r2 = False r3 = False - custom_annotation = () + quant_recipe = Gemma_2BQuantRecipe @register_llm_model("gemma3-1b") @@ -369,23 +284,13 @@ class Gemma3(LLMModelConfig): convert_weights = convert_gemma3_weights transform_weight = False instruct_model = True - num_sharding = 1 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 64 masked_softmax = True seq_mse_candidates = 0 r1 = False r2 = False r3 = False - quantization_config_wv_sha_16a8w = get_ptq_per_channel_quant_config( - torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver - ) - custom_annotation 
= ( - annotate_kv_8bit, - partial(annotate_wv_sha, quantization_config=quantization_config_wv_sha_16a8w), - ) + quant_recipe = Gemma3QuantRecipe @register_llm_model("granite_3_3-2b_instruct") @@ -398,23 +303,13 @@ class Granite_3_3_2b_Instruct(LLMModelConfig): convert_weights = convert_granite_weights transform_weight = False instruct_model = True - num_sharding = 1 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 64 masked_softmax = True seq_mse_candidates = 0 r1 = False r2 = False r3 = False - quantization_config_wv_sha_16a8w = get_ptq_per_channel_quant_config( - torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver - ) - custom_annotation = ( - annotate_kv_8bit, - partial(annotate_wv_sha, quantization_config=quantization_config_wv_sha_16a8w), - ) + quant_recipe = Granite_3_3_2B_InstructQuantRecipe @register_llm_model("phi_4_mini") @@ -427,27 +322,13 @@ class Phi4Mini(LLMModelConfig): convert_weights = convert_phi_4_mini_weights transform_weight = False instruct_model = True - num_sharding = 8 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 16 masked_softmax = False seq_mse_candidates = 0 r1 = False r2 = False r3 = False - quantization_config_wv_sha_8a4w = get_ptq_per_channel_quant_config( - act_dtype=torch.uint8, - weight_dtype=torch.int4, - act_observer=MinMaxObserver, - act_symmetric=True, - ) - custom_annotation = ( - annotate_kv_8bit, - annotate_output_16a8w, - partial(annotate_wv_sha, quantization_config=quantization_config_wv_sha_8a4w), - ) + quant_recipe = Phi4MiniQuantRecipe @register_llm_model("qwen2_5-0_5b") @@ -460,17 +341,13 @@ class Qwen2_5_0_5B(LLMModelConfig): convert_weights = convert_qwen2_5_weights transform_weight = False instruct_model = False - num_sharding = 1 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 16 masked_softmax = True seq_mse_candidates = 0 r1 = False r2 = False r3 = True - custom_annotation = () + quant_recipe = Qwen2_5_0_5BQuantRecipe @register_llm_model("qwen2_5-1_5b") @@ -483,17 +360,13 @@ class Qwen2_5_1_5B(LLMModelConfig): convert_weights = convert_qwen2_5_weights transform_weight = False instruct_model = False - num_sharding = 1 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 16 masked_softmax = True seq_mse_candidates = 0 r1 = False r2 = False r3 = True - custom_annotation = (annotate_output_16a8w,) + quant_recipe = Qwen2_5_1_5BQuantRecipe @register_llm_model("qwen3-0_6b") @@ -506,24 +379,13 @@ class Qwen3_0_6B(LLMModelConfig): convert_weights = convert_qwen3_weights transform_weight = False instruct_model = True - num_sharding = 1 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 32 masked_softmax = True seq_mse_candidates = 1000 r1 = False r2 = False r3 = False - quantization_config_down_proj_16a8w = get_ptq_per_channel_quant_config( - torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver - ) - custom_annotation = ( - partial( - annotate_down_proj, quantization_config=quantization_config_down_proj_16a8w - ), - ) + quant_recipe = Qwen3_0_6BQuantRecipe @register_llm_model("qwen3-1_7b") @@ -536,20 +398,13 @@ class Qwen3_1_7B(LLMModelConfig): convert_weights = convert_qwen3_weights transform_weight = False instruct_model = True - num_sharding = 1 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 16 masked_softmax = True seq_mse_candidates = 0 r1 = False r2 = False r3 = True - custom_annotation = ( - annotate_kv_8bit, - annotate_output_16a8w, - ) + quant_recipe = Qwen3_1_7BQuantRecipe 
@register_llm_model("smollm2_135m") @@ -562,17 +417,13 @@ class Smollm2_135M(LLMModelConfig): convert_weights = convert_smollm2_weights transform_weight = True instruct_model = True - num_sharding = 1 - # quant config - ptq = QuantDtype.use_16a8w - group_size = None masked_softmax = False seq_mse_candidates = 0 r1 = False r2 = False r3 = False - custom_annotation = () + quant_recipe = Smollm2QuantRecipe @register_llm_model("smollm3-3b") @@ -583,23 +434,10 @@ class Smollm3_3B(LLMModelConfig): convert_weights = convert_smollm3_weights transform_weight = False instruct_model = True - num_sharding = 4 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 32 masked_softmax = True seq_mse_candidates = 0 r1 = False r2 = False r3 = False - quantization_config_wqkv_sha_16a8w = get_ptq_per_channel_quant_config( - torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver - ) - custom_annotation = ( - annotate_kv_8bit, - annotate_output_16a8w, - partial( - annotate_wqkv_sha, quantization_config=quantization_config_wqkv_sha_16a8w - ), - ) + quant_recipe = Smollm3QuantRecipe diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index feaf99fd81d..29212c7855b 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -85,6 +85,9 @@ set_scales, WrappedLlamaModel, ) +from executorch.examples.qualcomm.oss_scripts.llama.static_llm_quant_recipe import ( + StaticLLMQuantRecipe, +) from executorch.examples.qualcomm.utils import ( make_output_dir, @@ -220,36 +223,20 @@ def quantize( quant_dtype, args, tokenizer, - custom_annotations=(), + quant_recipe, scales_state_dict=None, chat_template=None, lookahead_config=None, ): self.quant_dtype = quant_dtype - quantizer = make_custom_quantizer( - quant_dtype, args.range_setting, custom_annotations - ) + quantizer = make_custom_quantizer(quant_dtype, args.range_setting, ()) self.has_quant_io = True fx_graph_module = None with torch.no_grad(): fx_graph_module = torch.export.export( self.llama_graph_module, self.inputs, strict=True ).module() - - if quant_dtype == QuantDtype.use_16a4w_block: - if self.decoder_model_config.group_size is None: - raise ValueError( - "Group size is required when use quant_dtype 16a4w_block" - ) - conv_nodes = [ - n for n in fx_graph_module.graph.nodes if "conv" in n.name - ] - block_size_map = { - n.name: (1, self.decoder_model_config.group_size, 1, 1) - for n in conv_nodes - } - quantizer.set_block_size_map(block_size_map) - + quantizer.recipe = quant_recipe fx_graph_module = prepare_pt2e(fx_graph_module, quantizer) logging.info("Quantizing the model...") @@ -439,12 +426,16 @@ def compile( with open(params_path) as f: kv_config = ModelArgs(**json.load(f)) + # get quant recipe + quant_recipe: StaticLLMQuantRecipe = decoder_model_config.quant_recipe(True) + # TODO: support batch inputs if necessary kv_config.max_batch_size = 1 kv_config.max_seq_len = args.max_seq_len kv_config.use_kv_cache = True kv_config.enable_r3 = decoder_model_config.r3 - kv_config.kv_io_bit_width = decoder_model_config.get_kv_io_bit_width() + kv_config.kv_io_bit_width = quant_recipe.get_kv_io_bit_width() + if decoder_model_config.masked_softmax: if is_qnn_sdk_version_less_than("2.35"): logging.warning( @@ -643,7 +634,7 @@ def permute(w, heads, partial_rotary_dim): QuantDtype.use_8a8w: (8, 8), QuantDtype.use_16a4w: (16, 4), QuantDtype.use_16a4w_block: (16, 4), - }[decoder_model_config.ptq] + }[quant_recipe.default_quant_dtype] scales_state_dict = 
compute_scales( wrapped_model, tokens, weight_bits, act_bits, 1600 ) @@ -661,24 +652,24 @@ def permute(w, heads, partial_rotary_dim): use_fp16 = True # "io_type" here refers to logits output and "kv_type" refers to kv_cache input/output. fixed_point_type = {"kv_type": torch.float32, "io_type": torch.float32} - if decoder_model_config.ptq: - if decoder_model_config.get_kv_io_bit_width() == 8: + if quant_recipe.default_quant_dtype: + if quant_recipe.get_kv_io_bit_width() == 8: fixed_point_type["kv_type"] = torch.uint8 - elif decoder_model_config.get_kv_io_bit_width() == 16: + elif quant_recipe.get_kv_io_bit_width() == 16: fixed_point_type["kv_type"] = torch.uint16 else: raise RuntimeError( - f"Unknown kv io bit width {decoder_model_config.get_kv_io_bit_width()}" + f"Unknown kv io bit width {quant_recipe.get_kv_io_bit_width()}" ) - if decoder_model_config.get_logits_output_bit_width() == 16: + if quant_recipe.get_logits_output_bit_width() == 16: fixed_point_type["io_type"] = torch.uint16 else: raise RuntimeError( - f"Unknown logits io bit width {decoder_model_config.get_logits_output_bit_width()}" + f"Unknown logits io bit width {quant_recipe.get_logits_output_bit_width()}" ) - quant_dtype = decoder_model_config.ptq + quant_dtype = quant_recipe.default_quant_dtype if args.dtype_override is not None: dtype_override = DType[args.dtype_override] @@ -701,9 +692,8 @@ def permute(w, heads, partial_rotary_dim): QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY ]["skip_node"] = {"tokens"} - if decoder_model_config.ptq: + if quant_recipe.default_quant_dtype: start_quantize_ts = time.time() - custom_annotations = decoder_model_config.custom_annotation kv_quant_attrs = {} for i, llama_instance in enumerate(llama_instance_list): lookahead_config = ( @@ -711,11 +701,12 @@ def permute(w, heads, partial_rotary_dim): if i == 0 and args.model_mode == "lookahead" else None ) + llama_instance.quantize( quant_dtype=quant_dtype, args=args, tokenizer=tokenizer, - custom_annotations=custom_annotations, + quant_recipe=quant_recipe, scales_state_dict=scales_state_dict, chat_template=chat_template, lookahead_config=lookahead_config, @@ -729,11 +720,11 @@ def permute(w, heads, partial_rotary_dim): kv_quant_attrs[output_indices] = output.args[1:] output_indices += 1 break - custom_annotations = custom_annotations + ( + quant_recipe.recipe.custom_quant_annotations.append( partial( annotate_prefill_kv_output, kv_quant_attrs=kv_quant_attrs, - ), + ) ) # temporarily remove annotate_prefill_kv_output llama_instance.passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True llama_instance.passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][ diff --git a/examples/qualcomm/oss_scripts/llama/static_llm_quant_recipe.py b/examples/qualcomm/oss_scripts/llama/static_llm_quant_recipe.py new file mode 100644 index 00000000000..fc2827cd895 --- /dev/null +++ b/examples/qualcomm/oss_scripts/llama/static_llm_quant_recipe.py @@ -0,0 +1,589 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Optional + +import torch +from executorch.backends.qualcomm.quantizer.custom_annotation import annotate_kv_8bit +from executorch.backends.qualcomm.quantizer.quant_recipe import ( + QuantGranularity, + QuantRecipe, +) +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from torchao.quantization.pt2e import MinMaxObserver + + +class StaticLLMQuantRecipe: + """ + Qualcomm's static LLaMA quantization recipe. + """ + + def __init__(self): + self.recipe: Optional[QuantRecipe] = None + + # For IO bitwidth + self.default_quant_dtype = getattr(self, "default_quant_dtype", None) + if self.default_quant_dtype is None: + raise ValueError("default_quant_dtype must be defined in the recipe.") + + def annotate(self, graph_module: torch.fx.GraphModule): + self.recipe.annotate(graph_module) + + def get_kv_io_bit_width(self) -> int: + if self.default_quant_dtype is None: + return 32 + elif ( + self.default_quant_dtype == QuantDtype.use_8a8w + or annotate_kv_8bit in self.recipe.custom_quant_annotations + ): + return 8 + else: + # If quantized but not 8a8w or mix_quantization, it has to be 16bit kv io. + return 16 + + def get_logits_output_bit_width(self) -> int: + # We use 16bit logits for all quant config + return 32 if self.default_quant_dtype is None else 16 + + +class LlamaStories260KQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + .add_regex( + {r"output\.conv"}, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + .add_regex( + {r"layers\..*\.attention\.wv.*"}, + QuantDtype.use_8a4w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit) + + +class LlamaStories110MQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + .add_regex( + {r"layers\..*\.attention\.wv.*"}, + QuantDtype.use_8a4w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + .add_regex( + {r"output\.conv"}, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit) + + +class Llama3_1BQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w_block + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + note="default with 16bit activation", + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + 
QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 32, 1, 1)}, + note="Annotate with 16a4w block quantization since these layers are not sensitive.", + ) + .add_regex( + { + r"output\.conv", + r"layers\.[0-3]\.feed_forward\.w2_conv", + }, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + note="Down proj layer is sensitive and should be annotated with 16a8w.", + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit) + + +class Llama3_3BQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w_block + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 32, 1, 1)}, + ) + .add_regex( + { + r"output\.conv", + r"layers\.2[1-7]\.feed_forward\.w2_conv", + }, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit) + + +class CodegenQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a8w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ).add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + + +class Gemma_2BQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 64, 1, 1)}, + ) + .add_regex( + { + r"layers\..*\.attention\.wv.*", + r"output\.conv", + }, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit) + + +class Gemma3QuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 64, 1, 1)}, + ) + .add_regex( + { + r"layers\..*\.attention\.wv.*", + }, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit) + + +class 
Granite_3_3_2B_InstructQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 64, 1, 1)}, + ) + .add_regex( + { + r"layers\..*\.attention\.wv.*", + }, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit) + + +class Phi4MiniQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 16, 1, 1)}, + ) + .add_regex( + {r"layers\..*\.attention\.wv.*"}, + QuantDtype.use_8a4w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + .add_regex( + {r"output\.conv"}, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit) + + +class Qwen2_5_0_5BQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ).add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 16, 1, 1)}, + ) + + +class Qwen2_5_1_5BQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 16, 1, 1)}, + ) + .add_regex( + {r"output\.conv"}, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + + +class Qwen3_0_6BQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 
32, 1, 1)}, + ) + .add_regex( + { + r"layers\..*\.feed_forward\.w2_conv", + }, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + + +class Qwen3_1_7BQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 16, 1, 1)}, + ) + .add_regex( + { + r"output\.conv", + }, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit) + + +class Smollm2QuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a8w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ).add_node_target( + { + torch.ops.aten.conv2d.default, + }, + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + + +class Smollm3QuantRecipe(StaticLLMQuantRecipe): + + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 32, 1, 1)}, + ) + .add_regex( + { + r"layers\..*\.attention\.wq.*", + r"layers\..*\.attention\.wk.*", + r"layers\..*\.attention\.wv.*", + }, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + .add_regex( + { + r"output\.conv", + }, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit)
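
For reference, below is a minimal sketch of the recipe-based flow this patch introduces. It assumes a float `model` and calibration `sample_inputs` (illustrative names, not part of the diff), mirrors the builder calls used by the recipes in `static_llm_quant_recipe.py`, and hedges on the `prepare_pt2e`/`convert_pt2e` import path, which may differ across torchao versions.

```python
# Minimal sketch of the recipe-based quantization flow introduced in this PR.
# `model` and `sample_inputs` are illustrative placeholders; the
# prepare_pt2e/convert_pt2e import path may vary with the torchao version.
import torch
from executorch.backends.qualcomm.quantizer.quant_recipe import (
    QuantGranularity,
    QuantRecipe,
)
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype
from torchao.quantization.pt2e import MinMaxObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

# Default: 16a4w PTQ with per-tensor activations, then override conv2d weights to
# blockwise 4-bit and pin the lm_head ("output.conv") to 16a8w per-channel.
recipe = (
    QuantRecipe(
        QuantDtype.use_16a4w,
        False,  # is_qat
        act_observer=MinMaxObserver,
        granularity=QuantGranularity.PER_TENSOR,
        verbose=True,  # print the per-node summary table after annotation
    )
    .add_node_target(
        {torch.ops.aten.conv2d.default},
        QuantDtype.use_16a4w_block,
        False,
        act_observer=MinMaxObserver,
        granularity=QuantGranularity.PER_BLOCK,
        extra_kwargs={"block_size": (1, 32, 1, 1)},
    )
    .add_regex(
        {r"output\.conv"},
        QuantDtype.use_16a8w,
        False,
        act_observer=MinMaxObserver,
        granularity=QuantGranularity.PER_CHANNEL,
        priority=2,  # checked before the priority-1 conv2d rule above
    )
)

exported = torch.export.export(model, sample_inputs, strict=True).module()
quantizer = QnnQuantizer()
quantizer.recipe = recipe  # QnnQuantizer.annotate() now routes through the recipe
prepared = prepare_pt2e(exported, quantizer)
prepared(*sample_inputs)  # calibration
quantized = convert_pt2e(prepared)
```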
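
To help craft the patterns passed to `add_regex`, the new `show_nn_module_stack_for_quant_recipe` helper previews op targets and module stacks. A short, self-contained sketch follows; `exported_gm` is an illustrative exported `torch.fx.GraphModule`, and the example module-stack string in the comment is hypothetical.

```python
# Sketch: preview op targets and nn_module_stack paths to help write add_regex patterns.
# `exported_gm` is an illustrative exported torch.fx.GraphModule.
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from executorch.backends.qualcomm.utils.utils import (
    show_nn_module_stack_for_quant_recipe,
)

show_nn_module_stack_for_quant_recipe(exported_gm, QnnQuantizer.SUPPORTED_OPS)
# Prints a grid with "Op Target", "Module Stack", and "Nodes" columns; a module-stack
# entry (e.g. something like "layers.0.attention.wv_sha") can then be turned into a
# pattern such as r"layers\..*\.attention\.wv.*" for QuantRecipe.add_regex().
```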