diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py
index e34630538d0..7063ab73aa3 100644
--- a/backends/qualcomm/quantizer/custom_annotation.py
+++ b/backends/qualcomm/quantizer/custom_annotation.py
@@ -11,6 +11,7 @@
     _is_float_tensor,
     Q_ANNOTATION_KEY,
 )
+from enum import Enum, unique
 from executorch.backends.qualcomm.quantizer.quantizer import (
     get_16a8w_qnn_ptq_config,
     get_16a8w_qnn_qat_config,
@@ -125,6 +126,100 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
                     _annotated=True,
                 )
 
+
+def annotate_output_16a8w(gm: torch.fx.GraphModule, is_qat: bool = False) -> None:
+    """
+    This function is intended for static LLM models.
+    It annotates the last conv (the lm_head linear) as 16a8w.
+    """
+
+    def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
+        input_qspec_map = {}
+        input_act = node.args[0]
+        input_spec = quantization_config.input_activation
+        input_qspec_map[input_act] = input_spec
+
+        weight = node.args[1]
+        input_qspec_map[weight] = quantization_config.weight
+
+        if len(node.args) > 2 and isinstance(node.args[2], Node):
+            input_qspec_map[node.args[2]] = quantization_config.bias(node)
+
+        node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=quantization_config.output_activation,
+            _annotated=True,
+        )
+
+    if is_qat:
+        quantization_config_16a8w_per_channel = get_qat_per_channel_quant_config(
+            torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
+        )
+    else:
+        quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
+            torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
+        )
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
+            if "nn_module_stack" in node.meta:
+                module_values_list = list(node.meta["nn_module_stack"].values())
+                full_qualified_name = module_values_list[-1][0]
+                if full_qualified_name == "output.conv":
+                    annotate_conv2d(
+                        node, quantization_config=quantization_config_16a8w_per_channel
+                    )
+
+
+@unique
+class StaticLLMQuantConfig(Enum):
+    """
+    Layer namespace configuration for Qualcomm's static LLaMA quantization.
+    """
+
+    wq_sha = "wq_sha"  # Query weight (single head)
+    wk_sha = "wk_sha"  # Key weight (single head)
+    wv_sha = "wv_sha"  # Value weight (single head)
+
+
+def annotate_qkv_proj_sha(
+    gm: torch.fx.GraphModule,
+    quantization_config: QuantizationConfig,
+    qkv_tags: set[StaticLLMQuantConfig],
+):
+    """
+    Annotates QKV projection layers in a GraphModule for quantization,
+    specifically layers defined in StaticLLMQuantConfig.
+
+    Args:
+        qkv_tags (set[StaticLLMQuantConfig]): A set of enum tags indicating which QKV
+            layers (e.g., wq, wk, wv) should be annotated for quantization. Only tags
+            defined in StaticLLMQuantConfig are allowed.
+
+    Raises:
+        ValueError: If any tag in `qkv_tags` is not among the allowed enum members.
+    """
+
+    # Get all valid tags from the StaticLLMQuantConfig enum
+    allowed_tags = set(StaticLLMQuantConfig)
+    invalid_tags = qkv_tags - allowed_tags
+    if invalid_tags:
+        raise ValueError(
+            f"Invalid qkv tags: {invalid_tags}. Allowed tags are: {allowed_tags}"
+        )
+
+    for node in gm.graph.nodes:
+        if node.target == torch.ops.aten.conv2d.default and any(
+            tag.value in node.meta["stack_trace"] for tag in qkv_tags
+        ):
+            input_qspec_map = {}
+            input_qspec_map[node.args[0]] = quantization_config.input_activation
+            input_qspec_map[node.args[1]] = quantization_config.weight
+            if len(node.args) > 2 and isinstance(node.args[2], Node):
+                input_qspec_map[node.args[2]] = quantization_config.bias(node)
+            node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                output_qspec=quantization_config.output_activation,
+                _annotated=True,
+            )
+
 
 def annotate_kv_8bit(  # noqa: C901
     gm: torch.fx.GraphModule,
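
For context, a minimal usage sketch of how the new annotators could be handed to the Qualcomm quantizer from an export script. It assumes QnnQuantizer exposes an add_custom_quant_annotations hook taking GraphModule callables, and it mirrors the helper import paths used in this file (get_ptq_per_channel_quant_config); verify both against quantizer.py before relying on them.

# Hypothetical wiring of the new annotation callables; not part of this patch.
from functools import partial

import torch
from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_output_16a8w,
    annotate_qkv_proj_sha,
    StaticLLMQuantConfig,
)
from executorch.backends.qualcomm.quantizer.quantizer import (
    get_ptq_per_channel_quant_config,
    QnnQuantizer,
)
from torch.ao.quantization.observer import MinMaxObserver

# Same 16a8w per-channel config shape that the patch builds internally.
qkv_config = get_ptq_per_channel_quant_config(
    torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
)

custom_annotations = (
    annotate_output_16a8w,  # lm_head conv annotated as 16a8w
    partial(
        annotate_qkv_proj_sha,
        quantization_config=qkv_config,
        qkv_tags={
            StaticLLMQuantConfig.wq_sha,
            StaticLLMQuantConfig.wk_sha,
            StaticLLMQuantConfig.wv_sha,
        },
    ),
)

quantizer = QnnQuantizer()
quantizer.add_custom_quant_annotations(custom_annotations)
# prepare_pt2e / calibration / convert_pt2e then proceed as usual.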