diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py
index ed9afd70ce6..c54b2bc3ebb 100644
--- a/backends/qualcomm/quantizer/utils.py
+++ b/backends/qualcomm/quantizer/utils.py
@@ -1055,17 +1055,25 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) ->
 
     if _is_annotated([node]):
         return
+    input_act_qspec = quantization_config.input_activation
 
     _annotate_input_qspec_map(
         node,
         act_node,
-        quantization_config.input_activation,
-    )
-    _annotate_input_qspec_map(
-        node,
-        weight_node,
-        quantization_config.input_activation,
+        input_act_qspec,
     )
+    if input_act_qspec.dtype == torch.int32:
+        _annotate_input_qspec_map(
+            node,
+            weight_node,
+            get_default_16bit_qnn_ptq_config().weight,
+        )
+    else:
+        _annotate_input_qspec_map(
+            node,
+            weight_node,
+            input_act_qspec,
+        )
     nodes_to_mark_annotated = [node, weight_node]
     if bias_node:
         _annotate_input_qspec_map(
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 8abed68c630..f9fc2ea12e1 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -634,6 +634,16 @@ def test_qnn_backend_16a4w_conv2d(self):
         )
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_16a4w_layer_norm(self):
+        module = LayerNorm()  # noqa: F405
+        sample_input = (torch.randn(196, 768),)
+        module = self.get_qdq_module(
+            module,
+            sample_input,
+            quant_dtype=QuantDtype.use_16a4w,
+        )
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_16a4w_linear(self):
         module = Linear()  # noqa: F405
         sample_input = (torch.randn([3, 4]),)
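
Context for reviewers (illustrative sketch, not part of the patch): the new branch in annotate_layer_norm keys off the dtype of the input activation qspec. Under a 16-bit activation config the qspec's observer dtype is torch.int32, so the layer_norm weight is annotated with the 16-bit weight qspec from get_default_16bit_qnn_ptq_config() rather than inheriting the activation qspec; for 8-bit activations the weight keeps reusing the activation qspec as before. The standalone Python sketch below mimics only that selection logic with stand-in objects; FakeQSpec and pick_layer_norm_weight_qspec are hypothetical names, not part of the quantizer API.

# Minimal sketch of the weight-qspec selection added to annotate_layer_norm.
# FakeQSpec stands in for the real QuantizationSpec; only the dtype field is
# relevant to the branch being illustrated.
from dataclasses import dataclass

import torch


@dataclass
class FakeQSpec:
    dtype: torch.dtype


def pick_layer_norm_weight_qspec(input_act_qspec, default_16bit_weight_qspec):
    # Mirrors the patch: 16-bit activations (tracked as torch.int32) pair the
    # weight with the 16-bit weight qspec; otherwise the weight reuses the
    # input activation qspec.
    if input_act_qspec.dtype == torch.int32:
        return default_16bit_weight_qspec
    return input_act_qspec


act_16bit = FakeQSpec(dtype=torch.int32)
act_8bit = FakeQSpec(dtype=torch.uint8)
weight_16bit = FakeQSpec(dtype=torch.int16)

assert pick_layer_norm_weight_qspec(act_16bit, weight_16bit) is weight_16bit
assert pick_layer_norm_weight_qspec(act_8bit, weight_16bit) is act_8bit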