From 958e7bbd953335558de3f62f59c68cca6b0dbb6c Mon Sep 17 00:00:00 2001
From: ssjia
Date: Thu, 14 Aug 2025 22:36:51 -0700
Subject: [PATCH] [ET-VK] Move rotary embedding custom op to be handled via graph pass instead of source transform

## Motivation

Enable testing Vulkan lowering via optimum-executorch.

## Context

Currently, ET-VK implements rotary embeddings via a custom op. This op is inserted into Transformer models by a source transform that replaces Rotary Embedding modules with a custom module that calls the custom op.

The source transform approach makes it cumbersome to lower LLMs to Vulkan, since the export logic must apply the source transform before calling `torch.export()`. This in turn makes it difficult to integrate Vulkan lowering into optimum-executorch, which uses a common export + lowering flow for all lowering paths.

As an alternative, leverage `SubgraphMatcher` to detect fusable patterns and fuse the rotary embedding graph pattern into the custom op as part of the Vulkan delegate's graph passes. This removes the need to apply a Vulkan-specific source transform.

## Changes

* Introduce the `backends/vulkan/patterns` folder to store fusable graph patterns
* Introduce a fusable graph pattern for rotary positional embeddings
* Update partitioner logic to automatically include nodes that are part of a fusable graph pattern
* Introduce a pass to fuse known patterns into custom ops / custom op sequences

Differential Revision: [D80293301](https://our.internmc.facebook.com/intern/diff/D80293301/)

[ghstack-poisoned]
---
 backends/vulkan/_passes/TARGETS               |  17 +++
 backends/vulkan/_passes/__init__.py           |   2 +
 backends/vulkan/_passes/fuse_patterns.py      | 131 ++++++++++++++++++
 backends/vulkan/custom_ops_lib.py             |  36 +----
 backends/vulkan/op_registry.py                |   1 +
 backends/vulkan/partitioner/TARGETS           |   1 +
 .../vulkan/partitioner/vulkan_partitioner.py  |  32 +++++
 backends/vulkan/patterns/TARGETS              |  21 +++
 backends/vulkan/patterns/__init__.py          |  16 +++
 backends/vulkan/patterns/rope.py              |  96 +++++++++++++
 .../graph/ops/glsl/linear_qga4w_coop.glsl     |   1 -
 backends/vulkan/targets.bzl                   |   1 +
 backends/vulkan/test/test_vulkan_passes.py    | 105 ++++++++++++++
 backends/vulkan/vulkan_preprocess.py          |   2 +
 examples/models/llama/TARGETS                 |   1 -
 examples/models/llama/export_llama_lib.py     |   4 -
 .../source_transformation/vulkan_rope.py      |  36 -----
 17 files changed, 428 insertions(+), 75 deletions(-)
 create mode 100644 backends/vulkan/_passes/fuse_patterns.py
 create mode 100644 backends/vulkan/patterns/TARGETS
 create mode 100644 backends/vulkan/patterns/__init__.py
 create mode 100644 backends/vulkan/patterns/rope.py
 delete mode 100644 examples/models/llama/source_transformation/vulkan_rope.py

diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS
index cfe20892994..3263d273b72 100644
--- a/backends/vulkan/_passes/TARGETS
+++ b/backends/vulkan/_passes/TARGETS
@@ -118,6 +118,22 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "fuse_patterns",
+    srcs = ["fuse_patterns.py"],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/vulkan/patterns:vulkan_patterns",
+        "//executorch/exir:lib",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+    ],
+    typing = True,
+)
+
 runtime.python_library(
     name = "vulkan_passes",
     srcs = [
@@ -128,6 +144,7 @@ runtime.python_library(
         "//executorch/examples/...",
    ],
     deps = [
+        ":fuse_patterns",
         ":fuse_quantized_ops",
         ":insert_prepack_nodes",
":int4_weight_only_quantizer", diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py index 7ff93a6ee38..ccf15fd2c7f 100644 --- a/backends/vulkan/_passes/__init__.py +++ b/backends/vulkan/_passes/__init__.py @@ -6,6 +6,7 @@ # pyre-strict +from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass from executorch.backends.vulkan._passes.fuse_quantized_ops import ( FuseQuantizedOpsTransform, ) @@ -29,6 +30,7 @@ from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass __all__ = [ + "FusePatternsPass", "FuseQuantizedOpsTransform", "insert_prepack_nodes", "VkInt4WeightOnlyQuantizer", diff --git a/backends/vulkan/_passes/fuse_patterns.py b/backends/vulkan/_passes/fuse_patterns.py new file mode 100644 index 00000000000..b320dc973a0 --- /dev/null +++ b/backends/vulkan/_passes/fuse_patterns.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import operator +from typing import Callable, List, Optional + +import executorch.backends.vulkan.patterns as vk_patterns + +import torch + +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + +from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher + + +def fuse_pattern( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + patterns: List[torch.fx.GraphModule], + create_replacement_func: Callable, +) -> int: + total_replaced = 0 + + for pattern in patterns: + sm = SubgraphMatcher(pattern.graph, ignore_literals=True) + matches = list(sm.match(graph_module.graph)) + + for partition_to_replace in matches: + create_replacement_func(ep, graph_module, partition_to_replace) + total_replaced += 1 + # Remove dead code so they won't be matched again + graph_module.graph.eliminate_dead_code() + + return total_replaced + + +## +## Rotary Embedding +## + + +def identify_rotary_emb_io_nodes( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: InternalMatch, +) -> Optional[List[torch.fx.Node]]: + # Get the input placeholders (xq, xk, freqs_cos, freqs_sin) + placeholder_nodes = match.placeholder_nodes + if len(placeholder_nodes) != 4: + return None + + xq, xk, freqs_cos, freqs_sin = placeholder_nodes + + output_nodes = match.returning_nodes + if len(output_nodes) != 2: + return None + + xq_out, xk_out = output_nodes + + return [xq, xk, freqs_cos, freqs_sin, xq_out, xk_out] + + +def create_rotary_emb_custom_op( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: InternalMatch, +): + io_nodes = identify_rotary_emb_io_nodes(ep, graph_module, match) + if io_nodes is None: + return + + assert len(io_nodes) == 6 + xq, xk, freqs_cos, freqs_sin, xq_out, xk_out = io_nodes + + # Create the custom op node + with graph_module.graph.inserting_before(xq_out): + rotary_emb_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.et_vk.apply_rotary_emb.default, + args=(xq, xk, freqs_cos, freqs_sin), + ) + + # The custom op returns a tuple (xq_out, xk_out) + # We need to extract the individual outputs + with graph_module.graph.inserting_after(rotary_emb_node): + getitem_0 = graph_module.graph.create_node( + "call_function", + operator.getitem, + args=(rotary_emb_node, 0), + ) + getitem_1 = graph_module.graph.create_node( + 
"call_function", + operator.getitem, + args=(rotary_emb_node, 1), + ) + + if hasattr(xq_out, "meta") and "val" in xq_out.meta: + getitem_0.meta["val"] = xq_out.meta["val"] + if hasattr(xk_out, "meta") and "val" in xk_out.meta: + getitem_1.meta["val"] = xk_out.meta["val"] + + xq_out.replace_all_uses_with(getitem_0) + xk_out.replace_all_uses_with(getitem_1) + + +class FusePatternsPass(ExportPass): + def __init__(self, exported_program: ExportedProgram) -> None: + super().__init__() + self.program = exported_program + + def call(self, graph_module: torch.fx.GraphModule): + total_replaced = 0 + + total_replaced += fuse_pattern( + self.program, + graph_module, + vk_patterns.get_rope_graphs(), + create_rotary_emb_custom_op, + ) + + if total_replaced > 0: + graph_module.recompile() + # Re-trace the graph + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, total_replaced > 0) diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index c9b884e5b86..bc61b44ce78 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import executorch.backends.vulkan.patterns as vk_patterns import torch.library namespace = "et_vk" @@ -325,42 +326,11 @@ def linear_qta8a_qga4w( ###################### -# Note that this implementation is copied from executorch.examples.models.llama.rope -# but it is copied here to avoid introducing a dependency on the llama code. def apply_rotary_emb_impl( xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor ): - def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): - ndim = x.ndim - freqs_cis_ndim = freqs_cis.ndim - if freqs_cis_ndim == 3: - # freqs_cis: (seq_len, n_heads, head_dim // 2) - assert freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1]) - shape = [ - d if (i == ndim - 3 or i == ndim - 2 or i == ndim - 1) else 1 - for i, d in enumerate(x.shape) - ] - else: - # freqs_cis: (seq_len, head_dim // 2) - assert freqs_cis.shape == (x.shape[1], x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(shape) - - xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1) - xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1) - - freqs_cos = reshape_for_broadcast(freqs_cos, xq_r) - freqs_sin = reshape_for_broadcast(freqs_sin, xq_r) - - xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin - xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos - xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin - xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos - - xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3) - xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3) - - return xq_out.type_as(xq), xk_out.type_as(xk) + pattern = vk_patterns.RotaryEmbeddingPattern() + return pattern.forward(xq, xk, freqs_cos, freqs_sin) name = "apply_rotary_emb" diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 675143cd7fd..b7f8f3de955 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -125,6 +125,7 @@ def update_features_impl(op: OpKey): operator.gt, operator.ge, operator.le, + operator.eq, # Guard and assert ops torch.ops.aten._assert_scalar.default, torch.ops.aten.sym_constrain_range_for_size.default, diff --git a/backends/vulkan/partitioner/TARGETS 
b/backends/vulkan/partitioner/TARGETS index 1d1d29f6fb0..40e1f36349a 100644 --- a/backends/vulkan/partitioner/TARGETS +++ b/backends/vulkan/partitioner/TARGETS @@ -15,6 +15,7 @@ runtime.python_library( "//executorch/backends/vulkan:op_registry", "//executorch/backends/vulkan:utils_lib", "//executorch/backends/vulkan:vulkan_preprocess", + "//executorch/backends/vulkan/patterns:vulkan_patterns", "//executorch/exir:delegate", "//executorch/exir:lib", "//executorch/exir/backend:partitioner", diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 302b9af83e2..fa0cd107a3b 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -9,6 +9,7 @@ import logging from typing import Any, Callable, Dict, final, List, Mapping, Optional, Set, Tuple +import executorch.backends.vulkan.patterns as vk_patterns import executorch.backends.vulkan.utils as utils import torch @@ -40,6 +41,7 @@ from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupportBase +from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher # pyre-ignore ops_not_to_decompose = [ @@ -58,6 +60,7 @@ def __init__( require_dynamic_shape: bool = False, operator_blocklist: Optional[Set[OpKey]] = None, operator_allowlist: Optional[Set[OpKey]] = None, + fusable_subgraphs: Optional[List[InternalMatch]] = None, ) -> None: super().__init__() self.texture_limits: utils.ImageExtents = texture_limits @@ -67,6 +70,13 @@ def __init__( operator_blocklist if operator_blocklist is not None else set() ) self.operator_allowlist = operator_allowlist + self.fusable_subgraphs: List[InternalMatch] = ( + fusable_subgraphs if fusable_subgraphs is not None else [] + ) + # Create a set of all nodes that are part of fusable subgraphs for quick lookup + self.fusable_nodes: Set[torch.fx.Node] = set() + for match in self.fusable_subgraphs: + self.fusable_nodes.update(match.nodes_map.values()) def op_node_is_compatible( # noqa: C901: Function is too complex self, node: torch.fx.Node, features: Optional[OpFeatures] = None @@ -204,6 +214,10 @@ def is_node_supported( return r def _is_node_supported(self, node: torch.fx.Node) -> bool: + # Check if this node is part of a fusable subgraph + if node.op == "call_function" and node in self.fusable_nodes: + return True + target = node.target if node.target == torch.ops.higher_order.auto_functionalized: first_arg = node.args[0] @@ -290,6 +304,20 @@ def parse_compile_options(compile_options: Dict[str, Any]) -> List[CompileSpec]: return compile_specs +def get_fusable_subgraphs(graph_module: torch.fx.GraphModule) -> List[InternalMatch]: + fusable_subgraphs = [] + + fuse_patterns = [] + fuse_patterns.extend(vk_patterns.get_rope_graphs()) + + for pattern in fuse_patterns: + sm = SubgraphMatcher(pattern.graph, ignore_literals=True) + matches = list(sm.match(graph_module.graph)) + fusable_subgraphs.extend(matches) + + return fusable_subgraphs + + @final class VulkanPartitioner(Partitioner): def __init__( @@ -330,6 +358,9 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: # subgraphs containing the nodes with the tags partition_tags = {} + # Get all fusable subgraphs from fuse_patterns + fusable_subgraphs = get_fusable_subgraphs(exported_program.graph_module) + texture_limits: utils.ImageExtents = self.options.get( "texture_limits", utils.DEFAULT_TEXTURE_LIMITS ) @@ -342,6 +373,7 @@ def 
partition(self, exported_program: ExportedProgram) -> PartitionResult: require_dynamic_shape=self.options.get("require_dynamic_shapes", False), operator_blocklist=self.operator_blocklist, operator_allowlist=self.operator_allowlist, + fusable_subgraphs=fusable_subgraphs, ), allows_single_node_partition=True, ) diff --git a/backends/vulkan/patterns/TARGETS b/backends/vulkan/patterns/TARGETS new file mode 100644 index 00000000000..7068799d02e --- /dev/null +++ b/backends/vulkan/patterns/TARGETS @@ -0,0 +1,21 @@ +load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +oncall("executorch") + +runtime.python_library( + name = "vulkan_patterns", + srcs = [ + "__init__.py", + "rope.py", + ], + visibility = [ + "//executorch/backends/...", + "//executorch/examples/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:lib", + ], + typing = True, +) diff --git a/backends/vulkan/patterns/__init__.py b/backends/vulkan/patterns/__init__.py new file mode 100644 index 00000000000..189f01d67a6 --- /dev/null +++ b/backends/vulkan/patterns/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.vulkan.patterns.rope import ( + get_rope_graphs, + RotaryEmbeddingPattern, +) + + +__all__ = [ + "get_rope_graphs", + "RotaryEmbeddingPattern", +] diff --git a/backends/vulkan/patterns/rope.py b/backends/vulkan/patterns/rope.py new file mode 100644 index 00000000000..56e3d9cc60b --- /dev/null +++ b/backends/vulkan/patterns/rope.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from functools import lru_cache +from typing import List + +import torch +from executorch.exir import EdgeCompileConfig, to_edge +from torch.export import export + + +class RotaryEmbeddingPattern(torch.nn.Module): + """ + Implementation of rotary embedding pattern that matches the one + in examples/model/llama/rope.py + """ + + def __init__(self): + super().__init__() + + def forward( + self, + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + # This implementation matches the apply_rotary_emb function in rope.py + # Split into real and imaginary parts + xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1) + xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1) + + # Reshape frequencies for broadcasting + freqs_cos = self._reshape_for_broadcast(freqs_cos, xq_r) + freqs_sin = self._reshape_for_broadcast(freqs_sin, xq_r) + + # Apply rotary embedding + xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin + xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos + xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin + xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos + + # Recombine real and imaginary parts + xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3) + xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3) + + return xq_out.type_as(xq), xk_out.type_as(xk) + + def _reshape_for_broadcast(self, freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + freqs_cis_ndim = freqs_cis.ndim + if freqs_cis_ndim == 3: + # freqs_cis: (seq_len, n_heads, head_dim // 2) + assert freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1]) + shape = [ + d if (i == ndim - 3 or i == ndim - 2 or i == ndim - 1) else 1 + for i, d in enumerate(x.shape) + ] + else: + # freqs_cis: (seq_len, head_dim // 2) + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(shape) + + +@lru_cache(maxsize=None) +def get_rope_graphs() -> List[torch.fx.GraphModule]: + batch_size = 1 + seq_len = 1 + n_heads = 4 + n_kv_heads = 2 + head_dim = 32 + + graphs = [] + dtype = torch.float32 + + xq = torch.randn(batch_size, seq_len, n_heads, head_dim, dtype=dtype) + xk = torch.randn(batch_size, seq_len, n_kv_heads, head_dim, dtype=dtype) + freqs_cos = torch.randn(seq_len, head_dim // 2, dtype=dtype) + freqs_sin = torch.randn(seq_len, head_dim // 2, dtype=dtype) + + edge = to_edge( + export( + RotaryEmbeddingPattern(), + (xq, xk, freqs_cos, freqs_sin), + strict=True, + ), + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + gm = edge.exported_program().graph_module + graphs.append(gm) + + return graphs diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl index 81d2a5f0aed..150efeef1ad 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl @@ -16,7 +16,6 @@ #define WGS ${WGS} ${define_required_extensions(DTYPE)} -${define_required_extensions("uint8")} layout(std430) buffer; diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index 590e76e1486..ac26d202fe1 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -344,6 +344,7 @@ def define_common_targets(is_fbcode = False): ], deps = [ "//caffe2:torch", + "//executorch/backends/vulkan/patterns:vulkan_patterns", ] ) diff --git a/backends/vulkan/test/test_vulkan_passes.py 
b/backends/vulkan/test/test_vulkan_passes.py index 6b05890c3c7..b277dff2a76 100644 --- a/backends/vulkan/test/test_vulkan_passes.py +++ b/backends/vulkan/test/test_vulkan_passes.py @@ -5,6 +5,7 @@ from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform from executorch.backends.vulkan._passes import FuseQuantizedOpsTransform +from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass from executorch.backends.vulkan.quantizer.vulkan_quantizer import ( get_symmetric_quantization_config, @@ -210,3 +211,107 @@ def test_fuse_linear_qta8a_qga4w(self): self.assertEqual(op_node_count(gm, "quantize_per_token.default"), 0) self.assertEqual(op_node_count(gm, "dequantize_per_channel.default"), 0) self.assertEqual(op_node_count(gm, "linear.default"), 0) + + def test_fuse_rotary_emb(self): + """Test conversion of rotary embedding pattern to et_vk.apply_rotary_emb custom op.""" + + class RotaryEmbeddingModel(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + # This implementation matches the apply_rotary_emb function in rope.py + # Split into real and imaginary parts + xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1) + xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1) + + # Reshape frequencies for broadcasting + freqs_cos = self._reshape_for_broadcast(freqs_cos, xq_r) + freqs_sin = self._reshape_for_broadcast(freqs_sin, xq_r) + + # Apply rotary embedding + xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin + xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos + xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin + xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos + + # Recombine real and imaginary parts + xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3) + xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3) + + return xq_out.type_as(xq), xk_out.type_as(xk) + + def _reshape_for_broadcast(self, freqs_cis: torch.Tensor, x: torch.Tensor): + """Helper function to reshape frequencies for broadcasting""" + ndim = x.ndim + freqs_cis_ndim = freqs_cis.ndim + if freqs_cis_ndim == 3: + # freqs_cis: (seq_len, n_heads, head_dim // 2) + shape = [ + d if (i == ndim - 3 or i == ndim - 2 or i == ndim - 1) else 1 + for i, d in enumerate(x.shape) + ] + else: + # freqs_cis: (seq_len, head_dim // 2) + shape = [ + d if i == 1 or i == ndim - 1 else 1 + for i, d in enumerate(x.shape) + ] + return freqs_cis.view(shape) + + # Create sample inputs based on the test file + batch_size = 1 + seq_len = 5 + n_heads = 32 + n_kv_heads = 8 + head_dim = 2048 + + xq = torch.randn(batch_size, seq_len, n_heads, head_dim, dtype=torch.float) + xk = torch.randn(batch_size, seq_len, n_kv_heads, head_dim, dtype=torch.float) + freqs_cos = torch.randn(seq_len, head_dim // 2, dtype=torch.float) + freqs_sin = torch.randn(seq_len, head_dim // 2, dtype=torch.float) + + sample_inputs = (xq, xk, freqs_cos, freqs_sin) + + model = RotaryEmbeddingModel() + + # Export the model + edge_compile_config = EdgeCompileConfig( + _skip_dim_order=False, + _check_ir_validity=False, + ) + + program = torch.export.export(model, sample_inputs, strict=True) + + edge_manager = to_edge( + program, + compile_config=edge_compile_config, + ) + + # Apply the rotary embedding pass + ep = edge_manager._edge_programs["forward"] + rotary_pass = FusePatternsPass(ep) + result = rotary_pass.call(ep.graph_module) + + # Verify that the pass was successful + 
self.assertTrue(result.modified) + + # Check that the custom op was created + gm = ep.graph_module + custom_op_count = 0 + for node in gm.graph.nodes: + if ( + node.op == "call_function" + and hasattr(node.target, "__name__") + and "apply_rotary_emb" in str(node.target) + ): + custom_op_count += 1 + + # We expect at least one custom op to be created + self.assertGreater(custom_op_count, 0) diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index a6d5737dbb8..8c1165a89df 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -29,6 +29,7 @@ SqueezeUnsqueezeInputs, TagMemoryMetaPass, ) +from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass from executorch.backends.vulkan._passes.remove_asserts import RemoveAssertsTransform from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder @@ -154,6 +155,7 @@ def preprocess( # noqa: C901 program = apply_passes( program, [ + FusePatternsPass(program), RemoveRedundantOpsTransform(), AddmmToLinearTransform(), FuseQuantizedOpsTransform(program), diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS index 62c33c6a245..b081fe68a2d 100644 --- a/examples/models/llama/TARGETS +++ b/examples/models/llama/TARGETS @@ -116,7 +116,6 @@ runtime.python_library( "source_transformation/rope.py", "source_transformation/sdpa.py", "source_transformation/spin_quant.py", - "source_transformation/vulkan_rope.py", "source_transformation/attention_sink.py", ], ) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 3a1801f063c..18700acade2 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -85,7 +85,6 @@ replace_sdpa_with_quantized_sdpa, replace_sdpa_with_simple_sdpa, ) -from .source_transformation.vulkan_rope import replace_with_vulkan_rotary_emb IS_FBCODE = True # os.environ.get("FBCODE_PLATFORM", False) FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" @@ -1469,9 +1468,6 @@ def _get_source_transforms( # noqa transforms.append(replace_sdpa_with_simple_sdpa) transforms.append(replace_kv_cache_with_coreml_kv_cache) - if vulkan: - transforms.append(replace_with_vulkan_rotary_emb) - if local_global_attention: transforms.append( partial( diff --git a/examples/models/llama/source_transformation/vulkan_rope.py b/examples/models/llama/source_transformation/vulkan_rope.py deleted file mode 100644 index cdaf6f0baa7..00000000000 --- a/examples/models/llama/source_transformation/vulkan_rope.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import executorch.backends.vulkan.custom_ops_lib # noqa -import torch - -from executorch.examples.models.llama.rope import RotaryEmbedding - - -class VkRotaryEmbedding(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward( - self, - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cos: torch.Tensor, - freqs_sin: torch.Tensor, - ): - xq_out, xk_out = torch.ops.et_vk.apply_rotary_emb(xq, xk, freqs_cos, freqs_sin) - return xq_out, xk_out - - -def replace_with_vulkan_rotary_emb(module: torch.nn.Module): - for name, child in module.named_children(): - if isinstance(child, RotaryEmbedding): - new_module = VkRotaryEmbedding() - setattr(module, name, new_module) - else: - replace_with_vulkan_rotary_emb(child) - - return module
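For reference, the new fusion path can be exercised end to end with a minimal sketch along the following lines. The standalone use of `RotaryEmbeddingPattern` as a stand-in model and the tensor shapes are illustrative only; the pass and pattern helpers are the ones added in this diff, and this mirrors what `test_fuse_rotary_emb` checks.

```python
# Illustrative sketch: export a module containing the RoPE graph pattern,
# lower it to edge dialect, and run FusePatternsPass over the resulting graph.
import torch
from torch.export import export

import executorch.backends.vulkan.custom_ops_lib  # noqa: F401  # registers et_vk ops
import executorch.backends.vulkan.patterns as vk_patterns
from executorch.backends.vulkan._passes import FusePatternsPass
from executorch.exir import EdgeCompileConfig, to_edge

# Arbitrary example shapes: (batch, seq_len, n_heads, head_dim) for xq/xk,
# (seq_len, head_dim // 2) for the frequency tensors.
xq = torch.randn(1, 5, 32, 128)
xk = torch.randn(1, 5, 8, 128)
freqs_cos = torch.randn(5, 64)
freqs_sin = torch.randn(5, 64)

edge = to_edge(
    export(
        vk_patterns.RotaryEmbeddingPattern(),
        (xq, xk, freqs_cos, freqs_sin),
        strict=True,
    ),
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
)

ep = edge.exported_program()
result = FusePatternsPass(ep).call(ep.graph_module)

# If the pattern matched, the graph now calls et_vk.apply_rotary_emb instead of
# the decomposed mul/stack/flatten sequence.
assert result.modified
```

In the normal lowering flow no manual invocation is needed: `vulkan_preprocess.py` runs `FusePatternsPass` as the first delegate graph pass, and the partitioner pulls the matched nodes into the Vulkan partition via `get_fusable_subgraphs`.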