5 changes: 3 additions & 2 deletions backends/qualcomm/tests/models.py
@@ -1530,12 +1530,13 @@ def forward(self, x):
 
 
 class ScaledDotProductAttention(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, scale=None):
         super().__init__()
+        self.scale = scale
 
     def forward(self, query_layer, key_layer, value_layer, attn_mask):
         attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_layer, key_layer, value_layer, attn_mask
+            query_layer, key_layer, value_layer, attn_mask, scale=self.scale
         )
         return attn_output
 
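For reference: when `scale` is left as `None`, `torch.nn.functional.scaled_dot_product_attention` falls back to its default scaling of 1/sqrt(head_dim), so the new constructor argument only changes the math when an explicit value is passed. A minimal standalone sketch of that equivalence (assumes PyTorch >= 2.1, where the `scale` keyword exists; shapes are illustrative and not taken from this PR):

# Standalone sketch: an explicit scale of 1 / sqrt(head_dim) matches the default (scale=None).
import math

import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 4, 100, 64) for _ in range(3))

default = F.scaled_dot_product_attention(q, k, v)  # scale=None -> 1 / sqrt(64)
explicit = F.scaled_dot_product_attention(q, k, v, scale=1 / math.sqrt(q.size(-1)))

assert torch.allclose(default, explicit, atol=1e-6)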
46 changes: 32 additions & 14 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -1008,7 +1008,11 @@ def test_qnn_backend_rsqrt(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_sdpa(self):
-        module = ScaledDotProductAttention()  # noqa: F405
+        modules = [
+            ScaledDotProductAttention(),  # noqa: F405
+            ScaledDotProductAttention(scale=0.5),  # noqa: F405
+            ScaledDotProductAttention(scale=1.0),  # noqa: F405
+        ]
         mask = torch.tril(torch.randn(1, 1, 100, 100))
         mask[mask == 0] = float("-inf")
         sample_input = (
@@ -1017,7 +1021,9 @@ def test_qnn_backend_sdpa(self):
             torch.randn(1, 4, 100, 64),
             mask,
         )
-        self.lower_module_and_test_output(module, sample_input)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_sigmoid(self):
         module = Sigmoid()  # noqa: F405
@@ -2414,7 +2420,11 @@ def test_qnn_backend_rsqrt(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_sdpa(self):
-        module = ScaledDotProductAttention()  # noqa: F405
+        modules = [
+            ScaledDotProductAttention(),  # noqa: F405
+            ScaledDotProductAttention(scale=0.5),  # noqa: F405
+            ScaledDotProductAttention(scale=1.0),  # noqa: F405
+        ]
         mask = torch.tril(torch.randn(1, 1, 100, 100))
         mask[mask == 0] = torch.finfo(torch.float32).min
         sample_input = (
@@ -2423,8 +2433,12 @@ def test_qnn_backend_sdpa(self):
             torch.randn(1, 4, 100, 64),
             mask,
         )
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(
+                    module, sample_input, quant_dtype=QuantDtype.use_16a8w
+                )
+                self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_select_copy(self):
         module = SelectCopy()  # noqa: F405
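Two details in the quantized variant above are worth noting. `QuantDtype.use_16a8w` follows the Qualcomm quantizer's naming for 16-bit activations with 8-bit weights, and the mask is filled with `torch.finfo(torch.float32).min` rather than `float("-inf")`, presumably because an infinite value leaves no finite min/max range from which to derive quantization parameters. A small standalone sketch of that difference (an illustration, not code from this PR):

# Sketch: why a finite fill value is friendlier to quantization than -inf.
import torch

mask = torch.tril(torch.randn(1, 1, 100, 100))

inf_mask = mask.clone()
inf_mask[inf_mask == 0] = float("-inf")  # fp16 test path
finite_mask = mask.clone()
finite_mask[finite_mask == 0] = torch.finfo(torch.float32).min  # quantized test path

print(inf_mask.min())     # -inf: a min/max-based scale would be infinite
print(finite_mask.min())  # large but finite: a usable quantization range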
@@ -4949,13 +4963,14 @@ def test_gMLP(self):
                 self.assertGreaterEqual(msg["top_1"], 60)
                 self.assertGreaterEqual(msg["top_5"], 85)
 
-    def test_mobilevit_v1(self):
+    @unittest.skip("Only outputs good accuracy in QNN 2.29")
+    def test_mobilevit_v2(self):
         if not self.required_envs([self.image_dataset]):
             self.skipTest("missing required envs")
 
         cmds = [
             "python",
-            f"{self.executorch_root}/examples/qualcomm/oss_scripts/mobilevit_v1.py"
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/mobilevit_v2.py",
             "--dataset",
             self.image_dataset,
             "--artifact",
@@ -4973,6 +4988,8 @@ def test_mobilevit_v1(self):
         ]
         if self.host:
             cmds.extend(["--host", self.host])
+        if self.shared_buffer:
+            cmds.extend(["--shared_buffer"])
 
         p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
         with Listener((self.ip, self.port)) as listener:
@@ -4982,17 +4999,16 @@ def test_mobilevit_v1(self):
             if "Error" in msg:
                 self.fail(msg["Error"])
             else:
-                self.assertGreaterEqual(msg["top_1"], 70)
+                self.assertGreaterEqual(msg["top_1"], 50)
                 self.assertGreaterEqual(msg["top_5"], 85)
 
-    @unittest.skip("Only outputs good accuracy in QNN 2.29")
-    def test_mobilevit_v2(self):
+    def test_mobilevit1(self):
         if not self.required_envs([self.image_dataset]):
             self.skipTest("missing required envs")
 
         cmds = [
             "python",
-            f"{self.executorch_root}/examples/qualcomm/oss_scripts/mobilevit_v2.py",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/mobilevit1.py",
             "--dataset",
             self.image_dataset,
             "--artifact",
@@ -5010,8 +5026,6 @@ def test_mobilevit_v2(self):
         ]
         if self.host:
             cmds.extend(["--host", self.host])
-        if self.shared_buffer:
-            cmds.extend(["--shared_buffer"])
 
         p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
         with Listener((self.ip, self.port)) as listener:
@@ -5021,7 +5035,7 @@ def test_mobilevit_v2(self):
             if "Error" in msg:
                 self.fail(msg["Error"])
             else:
-                self.assertGreaterEqual(msg["top_1"], 50)
+                self.assertGreaterEqual(msg["top_1"], 70)
                 self.assertGreaterEqual(msg["top_5"], 85)
 
     def test_pvt(self):
@@ -5031,7 +5045,11 @@ def test_pvt(self):
         cmds = [
             "python",
             f"{self.executorch_root}/examples/qualcomm/oss_scripts/pvt.py",
+            "--dataset",
+            self.image_dataset,
+            "--artifact",
+            self.artifact_dir,
             "--build_folder",
             self.build_folder,
             "--device",
             self.device,
13 changes: 13 additions & 0 deletions backends/transforms/decompose_sdpa.py
@@ -6,6 +6,8 @@
 
 # pyre-strict
 
+import math
+
 import torch
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch._decomp import get_decompositions
@@ -30,6 +32,7 @@ def call(
         for node in graph.nodes:
             if node.target == torch.ops.aten.scaled_dot_product_attention.default:
                 input_tensors = (arg.meta["val"] for arg in node.args)
+                scale = node.kwargs.get("scale", None)
 
                 # refer to pytorch/test/test_decomp.py
                 decomposed_module = make_fx(
@@ -81,6 +84,16 @@ def call(
                        )
                        continue
 
+                    if scale is not None and decomposed_node.target in [
+                        torch.ops.aten.mul.Scalar
+                    ]:
+                        new_args = list(decomposed_node.args)
+                        # Based on the implementation of _scaled_dot_product_attention_math,
+                        # the scale is applied to q and k before matmul.
+                        # refer to pytorch/aten/src/ATen/native/transformers/attention.cpp#L873
+                        new_args[1] = math.sqrt(scale)
+                        decomposed_node.args = tuple(new_args)
+
                    subgraph_node = graph.node_copy(
                        decomposed_node,
                        arg_transform=lambda x: decomposed_node_to_subgraph_node[  # noqa: B023
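The `sqrt(scale)` rewrite above relies on the decomposition multiplying both q and k by the same scalar (as `_scaled_dot_product_attention_math` does), so scaling each operand by sqrt(scale) reproduces a factor of `scale` on the attention logits. A standalone check of that identity (illustrative only, independent of the pass):

# (q * sqrt(s)) @ (k * sqrt(s)).T  ==  s * (q @ k.T)
import math

import torch

q = torch.randn(1, 4, 100, 64)
k = torch.randn(1, 4, 100, 64)
s = 0.5

scaled_qk = (q * math.sqrt(s)) @ (k * math.sqrt(s)).transpose(-1, -2)
scaled_logits = s * (q @ k.transpose(-1, -2))

assert torch.allclose(scaled_qk, scaled_logits, atol=1e-5)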
17 changes: 9 additions & 8 deletions examples/qualcomm/oss_scripts/whisper/whisper.py
@@ -36,8 +36,8 @@
 
 from executorch.devtools.backend_debug import print_delegation_info
 from executorch.examples.qualcomm.oss_scripts.whisper.whisper_model import (
-    Seq2SeqLMDecoderExportableModuleWithStaticCache,
-    Seq2SeqLMEncoderExportableModule,
+    QnnSeq2SeqLMDecoderExportableModuleWithStaticCache,
+    QnnSeq2SeqLMEncoderExportableModule,
 )
 
 from executorch.examples.qualcomm.utils import (
@@ -169,14 +169,14 @@ def __init__(
         )
 
         self.whisper_encoder = (
-            Seq2SeqLMEncoderExportableModule(whisper_model.get_encoder())
+            QnnSeq2SeqLMEncoderExportableModule(whisper_model.get_encoder())
             .to("cpu")
             .eval()
         )
         self.encoder_passes_job = get_capture_program_passes()
 
         self.whisper_decoder = (
-            Seq2SeqLMDecoderExportableModuleWithStaticCache(
+            QnnSeq2SeqLMDecoderExportableModuleWithStaticCache(
                 whisper_model=whisper_model,
                 max_cache_length=self.max_seq_length,
                 batch_size=batch_size,
@@ -190,20 +190,21 @@ def __init__(
         self.exported_whisper_encoder = None
         self.exported_whisper_decoder = None
         self.has_quant_io = False
+        self.kv_shape = {
+            (self.max_seq_length, self.head_dim),
+        }
 
     def _tag_ios(self, node, fixed_point_type):
         if not self.has_quant_io:
             return
 
         quant_io_type = None
-        if node.op == "placeholder" and "static_cache_" in node.name:
+        if node.op == "placeholder" and node.meta["val"].size()[-2:] in self.kv_shape:
             quant_io_type = fixed_point_type
 
         if is_graph_output(node):
             # shape of k caches and v caches
-            if node.meta["val"].size()[-2:] in {
-                (self.max_seq_length, self.head_dim),
-            }:
+            if node.meta["val"].size()[-2:] in self.kv_shape:
                 quant_io_type = fixed_point_type
 
         return quant_io_type
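With this change the quantization IO tagging keys off tensor shape rather than the `static_cache_` substring in placeholder names, which presumably no longer appears once the decoder wraps its caches in an `EncoderDecoderCache` (see whisper_model.py below). A minimal sketch of the shape-based check, with illustrative values for the cache geometry:

# Sketch: match KV-cache tensors by their trailing (max_seq_length, head_dim) shape.
max_seq_length, head_dim = 1024, 64  # illustrative values, not taken from the PR
kv_shape = {(max_seq_length, head_dim)}


def is_kv_cache_tensor(node) -> bool:
    # node is assumed to be a torch.fx.Node whose meta["val"] holds a (fake) tensor
    val = node.meta.get("val")
    return val is not None and tuple(val.size()[-2:]) in kv_shape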
15 changes: 6 additions & 9 deletions examples/qualcomm/oss_scripts/whisper/whisper_model.py
@@ -6,10 +6,11 @@
 
 
 import torch
-from transformers import StaticCache, WhisperForConditionalGeneration
+from transformers.cache_utils import DynamicCache, EncoderDecoderCache, StaticCache
+from transformers.models.whisper.modeling_whisper import WhisperForConditionalGeneration
 
 
-class Seq2SeqLMEncoderExportableModule(torch.nn.Module):
+class QnnSeq2SeqLMEncoderExportableModule(torch.nn.Module):
     """
     A wrapper module designed to make a Seq2Seq LM encoder exportable with `torch.export`.
     This module ensures that the exported encoder model is compatible with ExecuTorch.
@@ -29,7 +30,7 @@ def get_metadata(self):
         return {}
 
 
-class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module):
+class QnnSeq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module):
     """
     A wrapper module designed to make a Seq2Seq LM decoder exportable with `torch.export`,
     specifically for use with static caching. This module ensures the exported decoder
@@ -57,11 +58,7 @@ def __init__(self, whisper_model, max_cache_length, batch_size):
             device="cpu",
             dtype=torch.float32,
         )
-
-        # Register cache buffers to make them exportable
-        for i in range(len(self.static_cache.key_cache)):
-            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i])
-            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i])
+        self.cache = EncoderDecoderCache(self.static_cache, DynamicCache())
 
     def forward(
         self, decoder_input_ids, attention_mask, encoder_hidden_states, cache_position
@@ -71,7 +68,7 @@ def forward(
             input_ids=decoder_input_ids,
             attention_mask=attention_mask,
             encoder_hidden_states=encoder_hidden_states,
-            past_key_values=self.static_cache,
+            past_key_values=self.cache,
             use_cache=True,
             cache_position=cache_position,
         )
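The decoder now feeds the model an `EncoderDecoderCache` that pairs the preallocated `StaticCache` (decoder self-attention) with a `DynamicCache` (cross-attention, whose keys and values come from the fixed encoder output), which is the cache layout recent transformers releases expect for Whisper. A rough standalone sketch of that construction (the model id is illustrative, and `StaticCache` keyword names vary across transformers versions):

# Sketch: build the combined cache the exported decoder consumes.
import torch
from transformers import WhisperForConditionalGeneration
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, StaticCache

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").eval()

static_cache = StaticCache(
    config=model.config,
    max_batch_size=1,      # may be `batch_size` on older transformers releases
    max_cache_len=128,
    device="cpu",
    dtype=torch.float32,
)
# Self-attention reuses the preallocated static cache; cross-attention only needs a
# DynamicCache since its keys/values are computed once from the encoder states.
cache = EncoderDecoderCache(static_cache, DynamicCache())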