From 55c0c8efe8208026561d1c867545e44693c334b4 Mon Sep 17 00:00:00 2001
From: ssjia
Date: Wed, 5 Nov 2025 13:16:48 -0800
Subject: [PATCH] [ET-VK][ez] Fuse update_cache + custom_sdpa into sdpa_with_kv_cache

SDPA used to be handled by a single custom op, `sdpa_with_kv_cache`, but it was
eventually split (D62301837) into separate `update_cache` and `custom_sdpa` ops.
However, having a single fused op is useful for Vulkan, since it gives the
backend more control over how the cache tensors are stored and represented.
It makes the cache tensors easier to manage and opens up opportunities for
future optimizations.

This diff introduces a fusion pass that does two things:

1. Combine `update_cache` and `custom_sdpa` back into `sdpa_with_kv_cache`.
2. Ensure all references to the cache_pos symint go through the same node;
   this prevents the `select_at_dim_as_symint` op from being re-evaluated for
   every use.

Differential Revision: [D86340339](https://our.internmc.facebook.com/intern/diff/D86340339/)

[ghstack-poisoned]
---
 backends/vulkan/patterns/TARGETS     |   1 +
 backends/vulkan/patterns/__init__.py |   2 +
 backends/vulkan/patterns/sdpa.py     | 166 +++++++++++++++++++++++++++
 3 files changed, 169 insertions(+)
 create mode 100644 backends/vulkan/patterns/sdpa.py

diff --git a/backends/vulkan/patterns/TARGETS b/backends/vulkan/patterns/TARGETS
index ddc9cd77c04..3baf7c9e251 100644
--- a/backends/vulkan/patterns/TARGETS
+++ b/backends/vulkan/patterns/TARGETS
@@ -12,6 +12,7 @@ runtime.python_library(
         "quantized_linear.py",
         "quantized_convolution.py",
         "quantized_binary.py",
+        "sdpa.py",
         "select_as_symint.py",
     ],
     visibility = [
diff --git a/backends/vulkan/patterns/__init__.py b/backends/vulkan/patterns/__init__.py
index 9239416dc2d..9b875def944 100644
--- a/backends/vulkan/patterns/__init__.py
+++ b/backends/vulkan/patterns/__init__.py
@@ -14,6 +14,8 @@
 
 import executorch.backends.vulkan.patterns.rope  # noqa
 
+import executorch.backends.vulkan.patterns.sdpa  # noqa
+
 import executorch.backends.vulkan.patterns.select_as_symint  # noqa
 
 import torch
diff --git a/backends/vulkan/patterns/sdpa.py b/backends/vulkan/patterns/sdpa.py
new file mode 100644
index 00000000000..56799e1f7cc
--- /dev/null
+++ b/backends/vulkan/patterns/sdpa.py
@@ -0,0 +1,166 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Optional
+
+import torch
+
+from executorch.backends.vulkan.patterns.pattern_registry import (
+    PatternMatch,
+    register_pattern_detector,
+    register_pattern_replacement,
+)
+
+from executorch.exir import ExportedProgram
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+def is_update_cache_node(node: Any) -> bool:
+    if not hasattr(node, "target"):
+        return False
+
+    if isinstance(node.target, str):
+        return node.target == "llama::update_cache"
+    elif hasattr(node.target, "name"):
+        return node.target.name() == "llama::update_cache"
+    else:
+        return False
+
+
+def is_sdpa_with_kv_cache_node(node: Any) -> bool:
+    if not hasattr(node, "target"):
+        return False
+
+    if isinstance(node.target, str):
+        return "sdpa_with_kv_cache" in node.target
+    elif hasattr(node.target, "name"):
+        return "sdpa_with_kv_cache" in node.target.name()
+    else:
+        return False
+
+
+class CausalSDPAMatch(PatternMatch):
+    def __init__(self, custom_sdpa_node: torch.fx.Node) -> None:
+        self.anchor_node = custom_sdpa_node
+        self.match_found = False
+        self.all_nodes = [self.anchor_node]
+
+        # llama.custom_sdpa has signature:
+        # custom_sdpa(query, key_cache, value_cache, start_pos, attn_mask, dropout_p, is_causal, scale) -> output
+        if len(custom_sdpa_node.args) < 7:
+            return
+
+        self.query_node = custom_sdpa_node.args[0]
+        self.key_cache_node = custom_sdpa_node.args[1]
+        self.value_cache_node = custom_sdpa_node.args[2]
+        self.start_pos_node = custom_sdpa_node.args[3]
+        self.attn_mask_node = custom_sdpa_node.args[4]
+        self.dropout_p_node = custom_sdpa_node.args[5]
+        self.is_causal_node = custom_sdpa_node.args[6]
+        if len(custom_sdpa_node.args) > 7:
+            self.scale_node = custom_sdpa_node.args[7]
+        else:
+            self.scale_node = None
+
+        # try to find update key cache node
+        self.update_key_cache_node = None
+        for user in self.key_cache_node.users:
+            if is_update_cache_node(user):
+                self.update_key_cache_node = user
+                break
+
+        self.key_projection_node = None
+        if self.update_key_cache_node is not None:
+            self.key_projection_node = self.update_key_cache_node.args[0]
+
+        # find update value cache node
+        self.update_value_cache_node = None
+        for user in self.value_cache_node.users:
+            if is_update_cache_node(user):
+                self.update_value_cache_node = user
+                break
+
+        self.value_projection_node = None
+        if self.update_value_cache_node is not None:
+            self.value_projection_node = self.update_value_cache_node.args[0]
+
+        # We have additional optional arguments but we don't need to capture them
+        # since the new op doesn't use them
+
+        self.match_found = True
+
+
+@register_pattern_detector("causal_sdpa")
+def find_causal_sdpa_patterns(
+    node: torch.fx.Node,
+) -> Optional[CausalSDPAMatch]:
+    if node.target != exir_ops.edge.llama.custom_sdpa.default:
+        return None
+
+    matched_pattern = CausalSDPAMatch(node)
+    if matched_pattern.match_found:
+        return matched_pattern
+
+    return None
+
+
+##
+## Pattern Replacement
+##
+
+
+def find_singleton_start_pos_node(graph_module: torch.fx.GraphModule):
+    for node in graph_module.graph.nodes:
+        if is_update_cache_node(node):
+            return node.args[2]
+
+        if is_sdpa_with_kv_cache_node(node):
+            return node.args[5]
+
+    raise Exception(
+        "Could not find an instance of llama::update_cache or sdpa_with_kv_cache"
+    )
+
+
+@register_pattern_replacement("causal_sdpa")
+def replace_custom_sdpa_with_causal_sdpa(
+    ep: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    match: CausalSDPAMatch,
+):
+    assert match.update_key_cache_node is not None
+    assert match.key_projection_node is not None
+    assert match.update_value_cache_node is not None
+    assert match.value_projection_node is not None
+
+    singleton_start_pos_node = find_singleton_start_pos_node(graph_module)
+
+    with graph_module.graph.inserting_before(match.anchor_node):
+        new_node = graph_module.graph.create_node(
+            "call_function",
+            torch.ops.llama.sdpa_with_kv_cache.default,
+            args=(
+                match.query_node,
+                match.key_projection_node,
+                match.value_projection_node,
+                match.key_cache_node,
+                match.value_cache_node,
+                singleton_start_pos_node,
+                1,
+                match.attn_mask_node,
+                match.dropout_p_node,
+                match.is_causal_node,
+                match.scale_node,
+            ),
+        )
+
+    new_node.meta["val"] = match.anchor_node.meta["val"]
+    match.anchor_node.replace_all_uses_with(new_node)
+
+    # Manually erase the update_cache nodes; DCE will not remove them because
+    # they mutate their inputs (specifically, the cache args are modified)
+    graph_module.graph.erase_node(match.update_key_cache_node)
+    graph_module.graph.erase_node(match.update_value_cache_node)
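
To make the mechanics of the rewrite concrete, here is a minimal, self-contained `torch.fx` sketch of the same fusion shape. The `update_cache`, `custom_sdpa`, and `sdpa_with_kv_cache` functions below are plain Python stand-ins rather than the real `llama::` ops, the `ToyAttention` module is hypothetical, and the sketch covers only the fusion step (not the cache_pos symint deduplication):

```python
import torch
import torch.fx as fx


@fx.wrap
def update_cache(proj, cache, start_pos):
    # Stand-in for llama::update_cache: writes `proj` into `cache` at `start_pos`.
    return cache


@fx.wrap
def custom_sdpa(q, k_cache, v_cache, start_pos):
    # Stand-in for llama::custom_sdpa: attention computed against the caches.
    return q


@fx.wrap
def sdpa_with_kv_cache(q, k_proj, v_proj, k_cache, v_cache, start_pos, seq_len):
    # Stand-in for the fused llama::sdpa_with_kv_cache op.
    return q


class ToyAttention(torch.nn.Module):
    def forward(self, q, k_proj, v_proj, k_cache, v_cache, start_pos):
        update_cache(k_proj, k_cache, start_pos)
        update_cache(v_proj, v_cache, start_pos)
        return custom_sdpa(q, k_cache, v_cache, start_pos)


gm = fx.symbolic_trace(ToyAttention())
print(gm.graph)  # two update_cache calls feeding one custom_sdpa call

for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is custom_sdpa:
        q, k_cache, v_cache, start_pos = node.args
        # The update_cache node feeding each cache tensor carries the projected
        # key/value as its first argument.
        k_update = next(u for u in k_cache.users if u.target is update_cache)
        v_update = next(u for u in v_cache.users if u.target is update_cache)
        with gm.graph.inserting_before(node):
            fused = gm.graph.call_function(
                sdpa_with_kv_cache,
                args=(q, k_update.args[0], v_update.args[0],
                      k_cache, v_cache, start_pos, 1),
            )
        node.replace_all_uses_with(fused)
        gm.graph.erase_node(node)
        # The update nodes must be erased explicitly: DCE keeps them because
        # they mutate their (cache) inputs.
        gm.graph.erase_node(k_update)
        gm.graph.erase_node(v_update)

gm.graph.lint()
gm.recompile()
print(gm.graph)  # a single fused sdpa_with_kv_cache call
```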