From 82798dff5e79962e745b51194b9c7aa68b805ab3 Mon Sep 17 00:00:00 2001
From: shewu-quic
Date: Fri, 4 Oct 2024 14:07:42 +0800
Subject: [PATCH] Qualcomm AI Engine Direct - Fixed layer norm quantization annotation for 16bit

- Fixed quantization annotation for layer norm in 16bit.
- Added a unit test for 16a4w layer norm.
---
 backends/qualcomm/quantizer/utils.py         | 20 ++++++++++++++------
 backends/qualcomm/tests/test_qnn_delegate.py | 10 ++++++++++
 2 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py
index ed9afd70ce6..c54b2bc3ebb 100644
--- a/backends/qualcomm/quantizer/utils.py
+++ b/backends/qualcomm/quantizer/utils.py
@@ -1055,17 +1055,25 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) ->
 
     if _is_annotated([node]):
         return
+    input_act_qspec = quantization_config.input_activation
 
     _annotate_input_qspec_map(
         node,
         act_node,
-        quantization_config.input_activation,
-    )
-    _annotate_input_qspec_map(
-        node,
-        weight_node,
-        quantization_config.input_activation,
+        input_act_qspec,
     )
+    if input_act_qspec.dtype == torch.int32:
+        _annotate_input_qspec_map(
+            node,
+            weight_node,
+            get_default_16bit_qnn_ptq_config().weight,
+        )
+    else:
+        _annotate_input_qspec_map(
+            node,
+            weight_node,
+            input_act_qspec,
+        )
     nodes_to_mark_annotated = [node, weight_node]
     if bias_node:
         _annotate_input_qspec_map(
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 8abed68c630..f9fc2ea12e1 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -634,6 +634,16 @@ def test_qnn_backend_16a4w_conv2d(self):
         )
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_16a4w_layer_norm(self):
+        module = LayerNorm()  # noqa: F405
+        sample_input = (torch.randn(196, 768),)
+        module = self.get_qdq_module(
+            module,
+            sample_input,
+            quant_dtype=QuantDtype.use_16a4w,
+        )
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_16a4w_linear(self):
         module = Linear()  # noqa: F405
         sample_input = (torch.randn([3, 4]),)
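
Note (not part of the patch): the new `input_act_qspec.dtype == torch.int32` check works because this quantizer expresses 16-bit activations as torch.int32 specs clamped to an int16 quant range, so the dtype identifies "16-bit activation mode" regardless of the weight bit width. Below is a minimal sketch of such a spec using the PT2E QuantizationSpec API; act_16bit_qspec is a hypothetical name for illustration, not the backend's actual config object, and the observer choice is an assumption:

    import torch
    from torch.ao.quantization.observer import MinMaxObserver
    from torch.ao.quantization.quantizer import QuantizationSpec

    # Hypothetical 16-bit activation spec: dtype is torch.int32, while
    # quant_min/quant_max clamp values to the int16 range. The patched
    # annotate_layer_norm keys on this dtype to pick the 16-bit weight
    # spec instead of reusing the activation spec for the weight tensor.
    act_16bit_qspec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int16).min,  # -32768
        quant_max=torch.iinfo(torch.int16).max,  # 32767
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=MinMaxObserver,
    )
    assert act_16bit_qspec.dtype == torch.int32  # would take the 16-bit weight branch

Branching on the spec's dtype, rather than on which quant config produced it, presumably keeps annotate_layer_norm independent of the specific 16-bit variant (16a16w, 16a8w, 16a4w) selected by the caller.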