From da35397017dbb16d71c6eddb4c3e1e0e4939bec6 Mon Sep 17 00:00:00 2001
From: winskuo-quic
Date: Mon, 10 Nov 2025 16:09:39 +0800
Subject: [PATCH] Qualcomm AI Engine Direct - VIT Optimization

---
 examples/qualcomm/scripts/torchvision_vit.py | 74 +++++++++++++++++---
 1 file changed, 64 insertions(+), 10 deletions(-)

diff --git a/examples/qualcomm/scripts/torchvision_vit.py b/examples/qualcomm/scripts/torchvision_vit.py
index 2a428683ec3..ed8dbb792c4 100755
--- a/examples/qualcomm/scripts/torchvision_vit.py
+++ b/examples/qualcomm/scripts/torchvision_vit.py
@@ -7,12 +7,14 @@
 import json
 import logging
 import os
+from contextlib import contextmanager
 from multiprocessing.connection import Client
 
 import numpy as np
 
 import torch
+import torch.nn.functional as F
 
 from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
 from executorch.examples.models.torchvision_vit.model import TorchVisionViTModel
 from executorch.examples.qualcomm.utils import (
@@ -25,6 +27,56 @@
 )
 
 
+# Copied from torch/nn/functional.py.
+# QNN has no 5D permute optimization, so the 5D unsqueeze(0).transpose(0, -2).squeeze(-2)
+# sequence is fused into a single 4D permute(2, 0, 1, 3).
+def _in_projection_packed_custom(q, k, v, w, b=None) -> list[torch.Tensor]:
+    from torch.nn.functional import linear
+
+    E = q.size(-1)
+    if k is v:
+        if q is k:
+            # self-attention
+            proj = linear(q, w, b)
+            # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
+            proj = proj.unflatten(-1, (3, E)).permute(2, 0, 1, 3).contiguous()
+            # pyrefly: ignore # bad-return
+            return proj[0], proj[1], proj[2]
+        else:
+            # encoder-decoder attention
+            w_q, w_kv = w.split([E, E * 2])
+            if b is None:
+                b_q = b_kv = None
+            else:
+                b_q, b_kv = b.split([E, E * 2])
+            q_proj = linear(q, w_q, b_q)
+            kv_proj = linear(k, w_kv, b_kv)
+            # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
+            kv_proj = kv_proj.unflatten(-1, (2, E)).permute(2, 0, 1, 3).contiguous()
+            # pyrefly: ignore # bad-return
+            return (q_proj, kv_proj[0], kv_proj[1])
+    else:
+        w_q, w_k, w_v = w.chunk(3)
+        if b is None:
+            b_q = b_k = b_v = None
+        else:
+            b_q, b_k, b_v = b.chunk(3)
+        # pyrefly: ignore # bad-return
+        return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
+
+
+# Context manager to apply the patch temporarily, so other callers of F._in_projection_packed are unaffected.
+@contextmanager
+def PermuteInProjectionPacked():
+    # Save the original function so it can be restored later
+    _original_in_projection_packed = F._in_projection_packed
+    F._in_projection_packed = _in_projection_packed_custom
+    try:
+        yield
+    finally:
+        F._in_projection_packed = _original_in_projection_packed
+
+
 def main(args):
     # ensure the working directory exist.
     os.makedirs(args.artifact, exist_ok=True)
@@ -44,16 +96,18 @@ def main(args):
     )
 
     pte_filename = "vit_qnn_q8"
-    instance = TorchVisionViTModel()
-    build_executorch_binary(
-        instance.get_eager_model().eval(),
-        instance.get_example_inputs(),
-        args.model,
-        f"{args.artifact}/{pte_filename}",
-        inputs,
-        quant_dtype=QuantDtype.use_8a8w,
-        shared_buffer=args.shared_buffer,
-    )
+    instance = TorchVisionViTModel().get_eager_model().eval()
+
+    with PermuteInProjectionPacked():
+        build_executorch_binary(
+            instance,
+            inputs[0],
+            args.model,
+            f"{args.artifact}/{pte_filename}",
+            inputs,
+            quant_dtype=QuantDtype.use_8a8w,
+            shared_buffer=args.shared_buffer,
+        )
 
     if args.compile_only:
         return
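
To sanity-check the rewrite locally, the following is a minimal sketch (not part
of the patch): it verifies that permute(2, 0, 1, 3) on the unflattened (L, N, 3, E)
projection matches the stock unsqueeze(0).transpose(0, -2).squeeze(-2) path, and
that nn.MultiheadAttention produces the same output while the patch is active.
The shapes (L=4, N=2, E=8, num_heads=2) are arbitrary illustrative values, and
PermuteInProjectionPacked is assumed to be importable from the patched script.

    import torch
    import torch.nn as nn

    L, N, E = 4, 2, 8  # arbitrary seq_len, batch, embed_dim
    proj = torch.randn(L, N, 3 * E)  # packed q/k/v projection output

    # Stock path from torch/nn/functional.py: goes through a 5D intermediate.
    ref = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
    # Patched path: a single 4D permute.
    out = proj.unflatten(-1, (3, E)).permute(2, 0, 1, 3).contiguous()
    assert torch.equal(ref, out)  # both are (3, L, N, E) with identical values

    # End to end: self-attention output is unchanged under the context manager.
    mha = nn.MultiheadAttention(embed_dim=E, num_heads=2).eval()
    x = torch.randn(L, N, E)
    expected, _ = mha(x, x, x)
    with PermuteInProjectionPacked():
        actual, _ = mha(x, x, x)
    assert torch.allclose(expected, actual)

Both asserts should pass because the two paths map elements identically; the
rewrite only changes how the (3, L, N, E) layout is expressed, which, per the
patch comments, lets QNN handle it as a single 4D permute instead of a 5D
transpose chain.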