From 936997d734adc73b3dcca435b8ab245cde40d58b Mon Sep 17 00:00:00 2001
From: shewu-quic
Date: Wed, 30 Jul 2025 11:33:35 +0800
Subject: [PATCH] Qualcomm AI Engine Direct - Fix the regression of whisper model

Summary:
- Resolve the Whisper model accuracy issue caused by upgrading Transformers.
- Modify decompose_sdpa.py to support the "scale" kwarg.
- Fix internal CI.
---
 backends/qualcomm/tests/models.py            |  5 +-
 backends/qualcomm/tests/test_qnn_delegate.py | 46 +++++++++++++------
 backends/transforms/decompose_sdpa.py        | 13 ++++++
 .../qualcomm/oss_scripts/whisper/whisper.py  | 17 +++----
 .../oss_scripts/whisper/whisper_model.py     | 15 +++---
 5 files changed, 63 insertions(+), 33 deletions(-)

diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 988665c6583..01ed37f80a3 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -1530,12 +1530,13 @@ def forward(self, x):
 
 
 class ScaledDotProductAttention(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, scale=None):
         super().__init__()
+        self.scale = scale
 
     def forward(self, query_layer, key_layer, value_layer, attn_mask):
         attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_layer, key_layer, value_layer, attn_mask
+            query_layer, key_layer, value_layer, attn_mask, scale=self.scale
         )
         return attn_output
 
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 85b9c869739..157aff397e5 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -1008,7 +1008,11 @@ def test_qnn_backend_rsqrt(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_sdpa(self):
-        module = ScaledDotProductAttention()  # noqa: F405
+        modules = [
+            ScaledDotProductAttention(),  # noqa: F405
+            ScaledDotProductAttention(scale=0.5),  # noqa: F405
+            ScaledDotProductAttention(scale=1.0),  # noqa: F405
+        ]
         mask = torch.tril(torch.randn(1, 1, 100, 100))
         mask[mask == 0] = float("-inf")
         sample_input = (
@@ -1017,7 +1021,9 @@ def test_qnn_backend_sdpa(self):
             torch.randn(1, 4, 100, 64),
             mask,
         )
-        self.lower_module_and_test_output(module, sample_input)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_sigmoid(self):
         module = Sigmoid()  # noqa: F405
@@ -2414,7 +2420,11 @@ def test_qnn_backend_rsqrt(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_sdpa(self):
-        module = ScaledDotProductAttention()  # noqa: F405
+        modules = [
+            ScaledDotProductAttention(),  # noqa: F405
+            ScaledDotProductAttention(scale=0.5),  # noqa: F405
+            ScaledDotProductAttention(scale=1.0),  # noqa: F405
+        ]
         mask = torch.tril(torch.randn(1, 1, 100, 100))
         mask[mask == 0] = torch.finfo(torch.float32).min
         sample_input = (
@@ -2423,8 +2433,12 @@ def test_qnn_backend_sdpa(self):
             torch.randn(1, 4, 100, 64),
             mask,
         )
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(
+                    module, sample_input, quant_dtype=QuantDtype.use_16a8w
+                )
+                self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_select_copy(self):
         module = SelectCopy()  # noqa: F405
@@ -4949,13 +4963,14 @@ def test_gMLP(self):
             self.assertGreaterEqual(msg["top_1"], 60)
             self.assertGreaterEqual(msg["top_5"], 85)
 
-    def test_mobilevit_v1(self):
+    @unittest.skip("Only outputs good accuracy in QNN 2.29")
+    def test_mobilevit_v2(self):
         if not self.required_envs([self.image_dataset]):
             self.skipTest("missing required envs")
 
         cmds = [
             "python",
-            f"{self.executorch_root}/examples/qualcomm/oss_scripts/mobilevit_v1.py",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/mobilevit_v2.py",
             "--dataset",
             self.image_dataset,
             "--artifact",
@@ -4973,6 +4988,8 @@ def test_mobilevit_v1(self):
         ]
         if self.host:
             cmds.extend(["--host", self.host])
+        if self.shared_buffer:
+            cmds.extend(["--shared_buffer"])
 
         p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
         with Listener((self.ip, self.port)) as listener:
@@ -4982,17 +4999,16 @@ def test_mobilevit_v1(self):
             if "Error" in msg:
                 self.fail(msg["Error"])
             else:
-                self.assertGreaterEqual(msg["top_1"], 70)
+                self.assertGreaterEqual(msg["top_1"], 50)
                 self.assertGreaterEqual(msg["top_5"], 85)
 
-    @unittest.skip("Only outputs good accuracy in QNN 2.29")
-    def test_mobilevit_v2(self):
+    def test_mobilevit1(self):
         if not self.required_envs([self.image_dataset]):
             self.skipTest("missing required envs")
 
         cmds = [
             "python",
-            f"{self.executorch_root}/examples/qualcomm/oss_scripts/mobilevit_v2.py",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/mobilevit1.py",
             "--dataset",
             self.image_dataset,
             "--artifact",
@@ -5010,8 +5026,6 @@ def test_mobilevit_v2(self):
         ]
         if self.host:
             cmds.extend(["--host", self.host])
-        if self.shared_buffer:
-            cmds.extend(["--shared_buffer"])
 
         p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
         with Listener((self.ip, self.port)) as listener:
@@ -5021,7 +5035,7 @@ def test_mobilevit_v2(self):
             if "Error" in msg:
                 self.fail(msg["Error"])
             else:
-                self.assertGreaterEqual(msg["top_1"], 50)
+                self.assertGreaterEqual(msg["top_1"], 70)
                 self.assertGreaterEqual(msg["top_5"], 85)
 
     def test_pvt(self):
@@ -5031,7 +5045,11 @@ def test_pvt(self):
         cmds = [
             "python",
             f"{self.executorch_root}/examples/qualcomm/oss_scripts/pvt.py",
+            "--dataset",
             self.image_dataset,
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
             self.build_folder,
             "--device",
             self.device,
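
Note for reviewers (illustrative sketch, not part of the patch): the extra ScaledDotProductAttention variants matter because an explicit scale changes the numerics the delegate must reproduce. With head_dim=64 the implicit default is 1/sqrt(64)=0.125, so scale=1.0 exercises a genuinely different graph, while passing the default value explicitly should be a no-op. A quick eager-mode check; shapes are copied from the test, the tolerance is an assumption:

    import math

    import torch
    import torch.nn.functional as F

    q, k, v = (torch.randn(1, 4, 100, 64) for _ in range(3))
    mask = torch.tril(torch.randn(1, 1, 100, 100))
    mask[mask == 0] = float("-inf")

    default_out = F.scaled_dot_product_attention(q, k, v, mask)  # implicit scale = 1/sqrt(64)
    explicit_out = F.scaled_dot_product_attention(q, k, v, mask, scale=1 / math.sqrt(64))
    override_out = F.scaled_dot_product_attention(q, k, v, mask, scale=1.0)

    print(torch.allclose(default_out, explicit_out, atol=1e-6))  # True: same scale
    print(torch.allclose(default_out, override_out, atol=1e-6))  # False: different scale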
diff --git a/backends/transforms/decompose_sdpa.py b/backends/transforms/decompose_sdpa.py
index 73e9d986c3d..d49e0da0c9b 100644
--- a/backends/transforms/decompose_sdpa.py
+++ b/backends/transforms/decompose_sdpa.py
@@ -6,6 +6,8 @@
 
 # pyre-strict
 
+import math
+
 import torch
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch._decomp import get_decompositions
@@ -30,6 +32,7 @@ def call(
         for node in graph.nodes:
             if node.target == torch.ops.aten.scaled_dot_product_attention.default:
                 input_tensors = (arg.meta["val"] for arg in node.args)
+                scale = node.kwargs.get("scale", None)
 
                 # refer to pytorch/test/test_decomp.py
                 decomposed_module = make_fx(
@@ -81,6 +84,16 @@ def call(
                    )
                    continue
 
+                if scale is not None and decomposed_node.target in [
+                    torch.ops.aten.mul.Scalar
+                ]:
+                    new_args = list(decomposed_node.args)
+                    # Based on the implementation of _scaled_dot_product_attention_math,
+                    # the scale is applied to q and k before matmul.
+                    # refer to pytorch/aten/src/ATen/native/transformers/attention.cpp#L873
+                    new_args[1] = math.sqrt(scale)
+                    decomposed_node.args = tuple(new_args)
+
                 subgraph_node = graph.node_copy(
                     decomposed_node,
                     arg_transform=lambda x: decomposed_node_to_subgraph_node[  # noqa: B023
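
Note (illustrative sketch, not part of the patch): make_fx decomposes SDPA following _scaled_dot_product_attention_math, which multiplies q and k each by sqrt(scale) before the matmul, so overriding the aten.mul.Scalar operand with math.sqrt(scale) reproduces a caller-supplied scale. A minimal eager-mode sanity check of that identity; shapes and tolerance are arbitrary:

    import math

    import torch
    import torch.nn.functional as F

    q, k, v = (torch.randn(1, 4, 32, 16) for _ in range(3))
    scale = 0.5

    ref = F.scaled_dot_product_attention(q, k, v, scale=scale)

    # Scaling q and k by sqrt(scale) makes the q @ k^T product carry the full factor.
    qs, ks = q * math.sqrt(scale), k * math.sqrt(scale)
    out = torch.softmax(qs @ ks.transpose(-2, -1), dim=-1) @ v

    print(torch.allclose(ref, out, atol=1e-5))  # True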
diff --git a/examples/qualcomm/oss_scripts/whisper/whisper.py b/examples/qualcomm/oss_scripts/whisper/whisper.py
index 4b0d681f6ec..a9f666e5f54 100644
--- a/examples/qualcomm/oss_scripts/whisper/whisper.py
+++ b/examples/qualcomm/oss_scripts/whisper/whisper.py
@@ -36,8 +36,8 @@
 from executorch.devtools.backend_debug import print_delegation_info
 
 from executorch.examples.qualcomm.oss_scripts.whisper.whisper_model import (
-    Seq2SeqLMDecoderExportableModuleWithStaticCache,
-    Seq2SeqLMEncoderExportableModule,
+    QnnSeq2SeqLMDecoderExportableModuleWithStaticCache,
+    QnnSeq2SeqLMEncoderExportableModule,
 )
 
 from executorch.examples.qualcomm.utils import (
@@ -169,14 +169,14 @@ def __init__(
         )
 
         self.whisper_encoder = (
-            Seq2SeqLMEncoderExportableModule(whisper_model.get_encoder())
+            QnnSeq2SeqLMEncoderExportableModule(whisper_model.get_encoder())
             .to("cpu")
             .eval()
         )
         self.encoder_passes_job = get_capture_program_passes()
 
         self.whisper_decoder = (
-            Seq2SeqLMDecoderExportableModuleWithStaticCache(
+            QnnSeq2SeqLMDecoderExportableModuleWithStaticCache(
                 whisper_model=whisper_model,
                 max_cache_length=self.max_seq_length,
                 batch_size=batch_size,
@@ -190,20 +190,21 @@ def __init__(
         self.exported_whisper_encoder = None
         self.exported_whisper_decoder = None
         self.has_quant_io = False
+        self.kv_shape = {
+            (self.max_seq_length, self.head_dim),
+        }
 
     def _tag_ios(self, node, fixed_point_type):
         if not self.has_quant_io:
             return
 
         quant_io_type = None
-        if node.op == "placeholder" and "static_cache_" in node.name:
+        if node.op == "placeholder" and node.meta["val"].size()[-2:] in self.kv_shape:
             quant_io_type = fixed_point_type
 
         if is_graph_output(node):
             # shape of k caches and v caches
-            if node.meta["val"].size()[-2:] in {
-                (self.max_seq_length, self.head_dim),
-            }:
+            if node.meta["val"].size()[-2:] in self.kv_shape:
                 quant_io_type = fixed_point_type
 
         return quant_io_type
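
Note (illustrative sketch, not part of the patch): with the removal of the register_buffer loop in the decoder wrapper, placeholder names no longer carry the old "static_cache_" prefix, so the name-based check in _tag_ios stops finding the KV caches; the pass now keys on tensor shape instead, tagging any input or output whose trailing dimensions match (max_seq_length, head_dim). A toy sketch of that shape matching on an exported graph; the module, names, and shapes are made up:

    import torch

    MAX_SEQ_LEN, HEAD_DIM = 128, 64
    kv_shapes = {(MAX_SEQ_LEN, HEAD_DIM)}  # plays the role of self.kv_shape


    class ToyCacheUpdate(torch.nn.Module):
        def forward(self, k_cache, new_k):
            # Stand-in for a cache update; the output keeps the (seq, head_dim) tail shape.
            return k_cache + new_k.mean()


    ep = torch.export.export(
        ToyCacheUpdate(),
        (torch.randn(1, 8, MAX_SEQ_LEN, HEAD_DIM), torch.randn(1, 8, 1, HEAD_DIM)),
    )
    for node in ep.graph.nodes:
        val = node.meta.get("val")
        if isinstance(val, torch.Tensor) and tuple(val.size()[-2:]) in kv_shapes:
            print(node.op, node.name, tuple(val.shape))  # the k_cache input and the cache-shaped output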
diff --git a/examples/qualcomm/oss_scripts/whisper/whisper_model.py b/examples/qualcomm/oss_scripts/whisper/whisper_model.py
index ec0e96cae12..22437c51044 100644
--- a/examples/qualcomm/oss_scripts/whisper/whisper_model.py
+++ b/examples/qualcomm/oss_scripts/whisper/whisper_model.py
@@ -6,10 +6,11 @@
 
 import torch
 
-from transformers import StaticCache, WhisperForConditionalGeneration
+from transformers.cache_utils import DynamicCache, EncoderDecoderCache, StaticCache
+from transformers.models.whisper.modeling_whisper import WhisperForConditionalGeneration
 
 
-class Seq2SeqLMEncoderExportableModule(torch.nn.Module):
+class QnnSeq2SeqLMEncoderExportableModule(torch.nn.Module):
     """
     A wrapper module designed to make a Seq2Seq LM encoder exportable with `torch.export`.
     This module ensures that the exported encoder model is compatible with ExecuTorch.
@@ -29,7 +30,7 @@ def get_metadata(self):
         return {}
 
 
-class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module):
+class QnnSeq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module):
     """
     A wrapper module designed to make a Seq2Seq LM decoder exportable with `torch.export`,
     specifically for use with static caching. This module ensures the exported decoder
     is compatible with ExecuTorch.
     """
@@ -57,11 +58,7 @@ def __init__(self, whisper_model, max_cache_length, batch_size):
             device="cpu",
             dtype=torch.float32,
         )
-
-        # Register cache buffers to make them exportable
-        for i in range(len(self.static_cache.key_cache)):
-            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i])
-            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i])
+        self.cache = EncoderDecoderCache(self.static_cache, DynamicCache())
 
     def forward(
         self, decoder_input_ids, attention_mask, encoder_hidden_states, cache_position
@@ -71,7 +68,7 @@ def forward(
             input_ids=decoder_input_ids,
             attention_mask=attention_mask,
             encoder_hidden_states=encoder_hidden_states,
-            past_key_values=self.static_cache,
+            past_key_values=self.cache,
             use_cache=True,
             cache_position=cache_position,
         )
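
Note (illustrative sketch, not part of the patch): recent Transformers releases expect Whisper's decoder past_key_values to be an EncoderDecoderCache (a self-attention StaticCache plus a cross-attention DynamicCache) rather than a bare StaticCache, which is what the new self.cache wrapper provides. A standalone sketch of driving the decoder one step that way; the checkpoint, cache length, and shapes are assumptions, and StaticCache constructor argument names vary a little between Transformers versions:

    import torch
    from transformers import WhisperForConditionalGeneration
    from transformers.cache_utils import DynamicCache, EncoderDecoderCache, StaticCache

    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").eval()
    decoder = model.get_decoder()

    # Fixed-size self-attention cache so an exported graph keeps static shapes;
    # cross-attention K/V are derived from the encoder output and can stay dynamic.
    self_attn_cache = StaticCache(
        config=model.config,
        max_batch_size=1,
        max_cache_len=128,
        device="cpu",
        dtype=torch.float32,
    )
    cache = EncoderDecoderCache(self_attn_cache, DynamicCache())

    encoder_hidden_states = torch.randn(1, 1500, model.config.d_model)  # stand-in encoder output
    out = decoder(
        input_ids=torch.tensor([[model.config.decoder_start_token_id]]),
        encoder_hidden_states=encoder_hidden_states,
        past_key_values=cache,
        use_cache=True,
        cache_position=torch.tensor([0]),
    )
    print(out.last_hidden_state.shape)  # one decode step; step-0 K/V now live in the caches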