diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS
index cfe20892994..3263d273b72 100644
--- a/backends/vulkan/_passes/TARGETS
+++ b/backends/vulkan/_passes/TARGETS
@@ -118,6 +118,22 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "fuse_patterns",
+    srcs = ["fuse_patterns.py"],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/vulkan/patterns:vulkan_patterns",
+        "//executorch/exir:lib",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+    ],
+    typing = True,
+)
+
 runtime.python_library(
     name = "vulkan_passes",
     srcs = [
@@ -128,6 +144,7 @@ runtime.python_library(
         "//executorch/examples/...",
     ],
     deps = [
+        ":fuse_patterns",
         ":fuse_quantized_ops",
         ":insert_prepack_nodes",
         ":int4_weight_only_quantizer",
diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py
index 7ff93a6ee38..ccf15fd2c7f 100644
--- a/backends/vulkan/_passes/__init__.py
+++ b/backends/vulkan/_passes/__init__.py
@@ -6,6 +6,7 @@
 
 # pyre-strict
 
+from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass
 from executorch.backends.vulkan._passes.fuse_quantized_ops import (
     FuseQuantizedOpsTransform,
 )
@@ -29,6 +30,7 @@
 from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass
 
 __all__ = [
+    "FusePatternsPass",
     "FuseQuantizedOpsTransform",
     "insert_prepack_nodes",
     "VkInt4WeightOnlyQuantizer",
diff --git a/backends/vulkan/_passes/fuse_patterns.py b/backends/vulkan/_passes/fuse_patterns.py
new file mode 100644
index 00000000000..6ced1f32a7c
--- /dev/null
+++ b/backends/vulkan/_passes/fuse_patterns.py
@@ -0,0 +1,30 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import executorch.backends.vulkan.patterns as vk_patterns
+
+import torch
+
+from executorch.exir import ExportedProgram
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class FusePatternsPass(ExportPass):
+    def __init__(self, exported_program: ExportedProgram) -> None:
+        super().__init__()
+        self.program = exported_program
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        total_replaced = vk_patterns.replace_all_fusable_subgraphs(
+            self.program, graph_module
+        )
+
+        if total_replaced > 0:
+            graph_module.recompile()
+            # Re-trace the graph
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, total_replaced > 0)
diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py
index c9b884e5b86..bc61b44ce78 100644
--- a/backends/vulkan/custom_ops_lib.py
+++ b/backends/vulkan/custom_ops_lib.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import executorch.backends.vulkan.patterns as vk_patterns
 import torch.library
 
 namespace = "et_vk"
@@ -325,42 +326,11 @@ def linear_qta8a_qga4w(
 ######################
 
 
-# Note that this implementation is copied from executorch.examples.models.llama.rope
-# but it is copied here to avoid introducing a dependency on the llama code.
 def apply_rotary_emb_impl(
     xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
 ):
-    def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
-        ndim = x.ndim
-        freqs_cis_ndim = freqs_cis.ndim
-        if freqs_cis_ndim == 3:
-            # freqs_cis: (seq_len, n_heads, head_dim // 2)
-            assert freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1])
-            shape = [
-                d if (i == ndim - 3 or i == ndim - 2 or i == ndim - 1) else 1
-                for i, d in enumerate(x.shape)
-            ]
-        else:
-            # freqs_cis: (seq_len, head_dim // 2)
-            assert freqs_cis.shape == (x.shape[1], x.shape[-1])
-            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-        return freqs_cis.view(shape)
-
-    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
-    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
-
-    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
-    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
-
-    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
-    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
-    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
-    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
-
-    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
-    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
-
-    return xq_out.type_as(xq), xk_out.type_as(xk)
+    pattern = vk_patterns.RotaryEmbeddingPattern()
+    return pattern.forward(xq, xk, freqs_cos, freqs_sin)
 
 
 name = "apply_rotary_emb"
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 675143cd7fd..b7f8f3de955 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -125,6 +125,7 @@ def update_features_impl(op: OpKey):
        operator.gt,
        operator.ge,
        operator.le,
+        operator.eq,
        # Guard and assert ops
        torch.ops.aten._assert_scalar.default,
        torch.ops.aten.sym_constrain_range_for_size.default,
diff --git a/backends/vulkan/partitioner/TARGETS b/backends/vulkan/partitioner/TARGETS
index 1d1d29f6fb0..40e1f36349a 100644
--- a/backends/vulkan/partitioner/TARGETS
+++ b/backends/vulkan/partitioner/TARGETS
@@ -15,6 +15,7 @@ runtime.python_library(
        "//executorch/backends/vulkan:op_registry",
        "//executorch/backends/vulkan:utils_lib",
        "//executorch/backends/vulkan:vulkan_preprocess",
+        "//executorch/backends/vulkan/patterns:vulkan_patterns",
        "//executorch/exir:delegate",
        "//executorch/exir:lib",
        "//executorch/exir/backend:partitioner",
diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index 302b9af83e2..1b5ff0a44e4 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -9,6 +9,7 @@ import logging
 
 from typing import Any, Callable, Dict, final, List, Mapping, Optional, Set, Tuple
 
+import executorch.backends.vulkan.patterns as vk_patterns
 import executorch.backends.vulkan.utils as utils
 
 import torch
@@ -37,9 +38,10 @@ from executorch.exir.dialects._ops import ops as exir_ops
 
 from torch.export.exported_program import ExportedProgram
 
-from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
+from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
 from torch.fx.passes.operator_support import OperatorSupportBase
+from torch.fx.passes.utils.matcher_utils import InternalMatch
 
 
 # pyre-ignore
 ops_not_to_decompose = [
@@ -58,6 +60,7 @@ def __init__(
        require_dynamic_shape: bool = False,
        operator_blocklist: Optional[Set[OpKey]] = None,
        operator_allowlist: Optional[Set[OpKey]] = None,
+        fusable_subgraphs: Optional[List[InternalMatch]] = None,
    ) -> None:
        super().__init__()
        self.texture_limits: utils.ImageExtents = texture_limits
@@ -67,6 +70,13 @@ def __init__(
            operator_blocklist if operator_blocklist is not None else set()
        )
        self.operator_allowlist = operator_allowlist
+        self.fusable_subgraphs: List[InternalMatch] = (
+            fusable_subgraphs if fusable_subgraphs is not None else []
+        )
+        # Create a set of all nodes that are part of fusable subgraphs for quick lookup
+        self.fusable_nodes: Set[torch.fx.Node] = set()
+        for match in self.fusable_subgraphs:
+            self.fusable_nodes.update(match.nodes_map.values())
 
    def op_node_is_compatible(  # noqa: C901: Function is too complex
        self, node: torch.fx.Node, features: Optional[OpFeatures] = None
@@ -204,6 +214,10 @@ def is_node_supported(
        return r
 
    def _is_node_supported(self, node: torch.fx.Node) -> bool:
+        # Check if this node is part of a fusable subgraph
+        if node.op == "call_function" and node in self.fusable_nodes:
+            return True
+
        target = node.target
        if node.target == torch.ops.higher_order.auto_functionalized:
            first_arg = node.args[0]
@@ -330,6 +344,11 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        # subgraphs containing the nodes with the tags
        partition_tags = {}
 
+        # Get all fusable subgraphs from fuse_patterns
+        fusable_subgraphs = vk_patterns.get_all_fusable_subgraphs(
+            exported_program.graph_module
+        )
+
        texture_limits: utils.ImageExtents = self.options.get(
            "texture_limits", utils.DEFAULT_TEXTURE_LIMITS
        )
@@ -342,6 +361,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
                require_dynamic_shape=self.options.get("require_dynamic_shapes", False),
                operator_blocklist=self.operator_blocklist,
                operator_allowlist=self.operator_allowlist,
+                fusable_subgraphs=fusable_subgraphs,
            ),
            allows_single_node_partition=True,
        )
diff --git a/backends/vulkan/patterns/TARGETS b/backends/vulkan/patterns/TARGETS
new file mode 100644
index 00000000000..b9fe79685dd
--- /dev/null
+++ b/backends/vulkan/patterns/TARGETS
@@ -0,0 +1,24 @@
+load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+oncall("executorch")
+
+runtime.python_library(
+    name = "vulkan_patterns",
+    srcs = [
+        "__init__.py",
+        "pattern_registry.py",
+        "rope.py",
+    ],
+    visibility = [
+        "//executorch/backends/...",
+        "//executorch/examples/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:lib",
+        "//executorch/backends/transforms:utils",
+        "//executorch/backends/vulkan:utils_lib",
+    ],
+    typing = True,
+)
diff --git a/backends/vulkan/patterns/__init__.py b/backends/vulkan/patterns/__init__.py
new file mode 100644
index 00000000000..bb6a4d07dc5
--- /dev/null
+++ b/backends/vulkan/patterns/__init__.py
@@ -0,0 +1,98 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List
+
+import executorch.backends.vulkan.patterns.rope  # noqa
+
+import torch
+
+from executorch.backends.vulkan.patterns.pattern_registry import (
+    CreateReplacementFn,
+    fusable_patterns,
+    GetGraphFn,
+    register_pattern_graph,
+    register_pattern_replacement,
+)
+
+from executorch.backends.vulkan.patterns.rope import RotaryEmbeddingPattern
+
+from executorch.exir import ExportedProgram
+
+from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher
+
+
+__all__ = [
+    "GetGraphFn",
+    "CreateReplacementFn",
+    "RotaryEmbeddingPattern",
+    "fusable_patterns",
+    "register_pattern_graph",
+    "register_pattern_replacement",
+]
+
+
+def all_fusable_graph_patterns() -> List[torch.fx.GraphModule]:
+    all_patterns = []
+    for entry in fusable_patterns.values():
+        if entry.get_graphs_fn is not None:
+            all_patterns.extend(entry.get_graphs_fn())
+
+    return all_patterns
+
+
+def get_all_fusable_subgraphs(
+    graph_module: torch.fx.GraphModule,
+) -> List[InternalMatch]:
+    fusable_subgraphs = []
+
+    fuse_patterns = all_fusable_graph_patterns()
+    for pattern in fuse_patterns:
+        sm = SubgraphMatcher(pattern.graph, ignore_literals=True)
+        matches = list(sm.match(graph_module.graph))
+        fusable_subgraphs.extend(matches)
+
+    return fusable_subgraphs
+
+
+def create_replacement_for_pattern(
+    ep: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    patterns: List[torch.fx.GraphModule],
+    create_replacement_func: CreateReplacementFn,
+) -> int:
+    total_replaced = 0
+
+    for pattern in patterns:
+        sm = SubgraphMatcher(pattern.graph, ignore_literals=True)
+        matches = list(sm.match(graph_module.graph))
+
+        for partition_to_replace in matches:
+            create_replacement_func(ep, graph_module, partition_to_replace)
+            total_replaced += 1
+            # Remove dead code so they won't be matched again
+            graph_module.graph.eliminate_dead_code()
+
+    return total_replaced
+
+
+def replace_all_fusable_subgraphs(
+    ep: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+) -> int:
+    total_replaced = 0
+
+    for entry in fusable_patterns.values():
+        if entry.get_graphs_fn is not None and entry.create_replacement_fn is not None:
+            total_replaced += create_replacement_for_pattern(
+                ep,
+                graph_module,
+                entry.get_graphs_fn(),
+                # pyre-ignore[6]
+                entry.create_replacement_fn,
+            )
+
+    return total_replaced
diff --git a/backends/vulkan/patterns/pattern_registry.py b/backends/vulkan/patterns/pattern_registry.py
new file mode 100644
index 00000000000..37fa0bcca8c
--- /dev/null
+++ b/backends/vulkan/patterns/pattern_registry.py
@@ -0,0 +1,56 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Dict, List, Optional
+
+import torch
+
+from executorch.exir import ExportedProgram
+
+from torch.fx.passes.utils.matcher_utils import InternalMatch
+
+GetGraphFn = Callable[[], List[torch.fx.GraphModule]]
+CreateReplacementFn = Callable[
+    [ExportedProgram, torch.fx.GraphModule, InternalMatch], None
+]
+
+
+class PatternEntry:
+    def __init__(
+        self,
+        get_graphs_fn: Optional[GetGraphFn] = None,
+        create_replacement_fn: Optional[CreateReplacementFn] = None,
+    ):
+        self.get_graphs_fn = get_graphs_fn
+        self.create_replacement_fn = create_replacement_fn
+
+    def is_valid(self):
+        return self.get_graphs_fn is not None and self.create_replacement_fn is not None
+
+
+fusable_patterns: Dict[str, PatternEntry] = {}
+
+
+def register_pattern_graph(pattern_name: str):
+    def decorator(fn: GetGraphFn):
+        if pattern_name not in fusable_patterns:
+            fusable_patterns[pattern_name] = PatternEntry()
+
+        fusable_patterns[pattern_name].get_graphs_fn = fn
+        return fn
+
+    return decorator
+
+
+def register_pattern_replacement(pattern_name: str):
+    def decorator(fn: CreateReplacementFn):
+        if pattern_name not in fusable_patterns:
+            fusable_patterns[pattern_name] = PatternEntry()
+
+        fusable_patterns[pattern_name].create_replacement_fn = fn
+        return fn
+
+    return decorator
diff --git a/backends/vulkan/patterns/rope.py b/backends/vulkan/patterns/rope.py
new file mode 100644
index 00000000000..e0c2e4c5501
--- /dev/null
+++ b/backends/vulkan/patterns/rope.py
@@ -0,0 +1,173 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import operator
+
+from functools import lru_cache
+from typing import List, Optional
+
+import torch
+
+from executorch.backends.vulkan.patterns.pattern_registry import (
+    register_pattern_graph,
+    register_pattern_replacement,
+)
+
+from executorch.exir import EdgeCompileConfig, ExportedProgram, to_edge
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from torch.export import export
+from torch.fx.passes.utils.matcher_utils import InternalMatch
+
+
+class RotaryEmbeddingPattern(torch.nn.Module):
+    """
+    Implementation of rotary embedding pattern that matches the one
+    in examples/model/llama/rope.py
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self,
+        xq: torch.Tensor,
+        xk: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
+    ):
+        # This implementation matches the apply_rotary_emb function in rope.py
+        # Split into real and imaginary parts
+        xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
+        xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
+
+        # Reshape frequencies for broadcasting
+        freqs_cos = self._reshape_for_broadcast(freqs_cos, xq_r)
+        freqs_sin = self._reshape_for_broadcast(freqs_sin, xq_r)
+
+        # Apply rotary embedding
+        xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
+        xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
+        xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
+        xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
+
+        # Recombine real and imaginary parts
+        xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
+        xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
+
+        return xq_out.type_as(xq), xk_out.type_as(xk)
+
+    def _reshape_for_broadcast(self, freqs_cis: torch.Tensor, x: torch.Tensor):
+        ndim = x.ndim
+        freqs_cis_ndim = freqs_cis.ndim
+        if freqs_cis_ndim == 3:
+            # freqs_cis: (seq_len, n_heads, head_dim // 2)
+            assert freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1])
+            shape = [
+                d if (i == ndim - 3 or i == ndim - 2 or i == ndim - 1) else 1
+                for i, d in enumerate(x.shape)
+            ]
+        else:
+            # freqs_cis: (seq_len, head_dim // 2)
+            assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+        return freqs_cis.view(shape)
+
+
+@lru_cache(maxsize=2)
+@register_pattern_graph("export_llama_rope")
+def get_rope_graphs() -> List[torch.fx.GraphModule]:
+    batch_size = 1
+    seq_len = 1
+    n_heads = 4
+    n_kv_heads = 2
+    head_dim = 32
+
+    graphs = []
+    dtype = torch.float32
+
+    xq = torch.randn(batch_size, seq_len, n_heads, head_dim, dtype=dtype)
+    xk = torch.randn(batch_size, seq_len, n_kv_heads, head_dim, dtype=dtype)
+    freqs_cos = torch.randn(seq_len, head_dim // 2, dtype=dtype)
+    freqs_sin = torch.randn(seq_len, head_dim // 2, dtype=dtype)
+
+    edge = to_edge(
+        export(
+            RotaryEmbeddingPattern(),
+            (xq, xk, freqs_cos, freqs_sin),
+            strict=True,
+        ),
+        compile_config=EdgeCompileConfig(_check_ir_validity=False),
+    )
+    gm = edge.exported_program().graph_module
+    graphs.append(gm)
+
+    return graphs
+
+
+def identify_rotary_emb_io_nodes(
+    ep: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    match: InternalMatch,
+) -> Optional[List[torch.fx.Node]]:
+    # Get the input placeholders (xq, xk, freqs_cos, freqs_sin)
+    placeholder_nodes = match.placeholder_nodes
+    if len(placeholder_nodes) != 4:
+        return None
+
+    xq, xk, freqs_cos, freqs_sin = placeholder_nodes
+
+    output_nodes = match.returning_nodes
+    if len(output_nodes) != 2:
+        return None
+
+    xq_out, xk_out = output_nodes
+
+    return [xq, xk, freqs_cos, freqs_sin, xq_out, xk_out]
+
+
+@register_pattern_replacement("export_llama_rope")
+def create_rotary_emb_custom_op(
+    ep: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    match: InternalMatch,
+):
+    io_nodes = identify_rotary_emb_io_nodes(ep, graph_module, match)
+    if io_nodes is None:
+        return
+
+    assert len(io_nodes) == 6
+    xq, xk, freqs_cos, freqs_sin, xq_out, xk_out = io_nodes
+
+    # Create the custom op node
+    with graph_module.graph.inserting_before(xq_out):
+        rotary_emb_node = graph_module.graph.create_node(
+            "call_function",
+            exir_ops.edge.et_vk.apply_rotary_emb.default,
+            args=(xq, xk, freqs_cos, freqs_sin),
+        )
+
+    # The custom op returns a tuple (xq_out, xk_out)
+    # We need to extract the individual outputs
+    with graph_module.graph.inserting_after(rotary_emb_node):
+        getitem_0 = graph_module.graph.create_node(
+            "call_function",
+            operator.getitem,
+            args=(rotary_emb_node, 0),
+        )
+        getitem_1 = graph_module.graph.create_node(
+            "call_function",
+            operator.getitem,
+            args=(rotary_emb_node, 1),
+        )
+
+    if hasattr(xq_out, "meta") and "val" in xq_out.meta:
+        getitem_0.meta["val"] = xq_out.meta["val"]
+    if hasattr(xk_out, "meta") and "val" in xk_out.meta:
+        getitem_1.meta["val"] = xk_out.meta["val"]
+
+    xq_out.replace_all_uses_with(getitem_0)
+    xk_out.replace_all_uses_with(getitem_1)
diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl
index 81d2a5f0aed..150efeef1ad 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl
@@ -16,7 +16,6 @@
 #define WGS ${WGS}
 
 ${define_required_extensions(DTYPE)}
-${define_required_extensions("uint8")}
 
 layout(std430) buffer;
 
diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl
index 590e76e1486..ac26d202fe1 100644
--- a/backends/vulkan/targets.bzl
+++ b/backends/vulkan/targets.bzl
@@ -344,6 +344,7 @@ def define_common_targets(is_fbcode = False):
            ],
            deps = [
                "//caffe2:torch",
+                "//executorch/backends/vulkan/patterns:vulkan_patterns",
            ]
        )
 
diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py
index 6b05890c3c7..b277dff2a76 100644
--- a/backends/vulkan/test/test_vulkan_passes.py
+++ b/backends/vulkan/test/test_vulkan_passes.py
@@ -5,6 +5,7 @@
 from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform
 
 from executorch.backends.vulkan._passes import FuseQuantizedOpsTransform
+from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass
 
 from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
     get_symmetric_quantization_config,
@@ -210,3 +211,107 @@ def test_fuse_linear_qta8a_qga4w(self):
        self.assertEqual(op_node_count(gm, "quantize_per_token.default"), 0)
        self.assertEqual(op_node_count(gm, "dequantize_per_channel.default"), 0)
        self.assertEqual(op_node_count(gm, "linear.default"), 0)
+
+    def test_fuse_rotary_emb(self):
+        """Test conversion of rotary embedding pattern to et_vk.apply_rotary_emb custom op."""
+
+        class RotaryEmbeddingModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(
+                self,
+                xq: torch.Tensor,
+                xk: torch.Tensor,
+                freqs_cos: torch.Tensor,
+                freqs_sin: torch.Tensor,
+            ):
+                # This implementation matches the apply_rotary_emb function in rope.py
+                # Split into real and imaginary parts
+                xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
+                xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
+
+                # Reshape frequencies for broadcasting
+                freqs_cos = self._reshape_for_broadcast(freqs_cos, xq_r)
+                freqs_sin = self._reshape_for_broadcast(freqs_sin, xq_r)
+
+                # Apply rotary embedding
+                xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
+                xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
+                xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
+                xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
+
+                # Recombine real and imaginary parts
+                xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
+                xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
+
+                return xq_out.type_as(xq), xk_out.type_as(xk)
+
+            def _reshape_for_broadcast(self, freqs_cis: torch.Tensor, x: torch.Tensor):
+                """Helper function to reshape frequencies for broadcasting"""
+                ndim = x.ndim
+                freqs_cis_ndim = freqs_cis.ndim
+                if freqs_cis_ndim == 3:
+                    # freqs_cis: (seq_len, n_heads, head_dim // 2)
+                    shape = [
+                        d if (i == ndim - 3 or i == ndim - 2 or i == ndim - 1) else 1
+                        for i, d in enumerate(x.shape)
+                    ]
+                else:
+                    # freqs_cis: (seq_len, head_dim // 2)
+                    shape = [
+                        d if i == 1 or i == ndim - 1 else 1
+                        for i, d in enumerate(x.shape)
+                    ]
+                return freqs_cis.view(shape)
+
+        # Create sample inputs based on the test file
+        batch_size = 1
+        seq_len = 5
+        n_heads = 32
+        n_kv_heads = 8
+        head_dim = 2048
+
+        xq = torch.randn(batch_size, seq_len, n_heads, head_dim, dtype=torch.float)
+        xk = torch.randn(batch_size, seq_len, n_kv_heads, head_dim, dtype=torch.float)
+        freqs_cos = torch.randn(seq_len, head_dim // 2, dtype=torch.float)
+        freqs_sin = torch.randn(seq_len, head_dim // 2, dtype=torch.float)
+
+        sample_inputs = (xq, xk, freqs_cos, freqs_sin)
+
+        model = RotaryEmbeddingModel()
+
+        # Export the model
+        edge_compile_config = EdgeCompileConfig(
+            _skip_dim_order=False,
+            _check_ir_validity=False,
+        )
+
+        program = torch.export.export(model, sample_inputs, strict=True)
+
+        edge_manager = to_edge(
+            program,
+            compile_config=edge_compile_config,
+        )
+
+        # Apply the rotary embedding pass
+        ep = edge_manager._edge_programs["forward"]
+        rotary_pass = FusePatternsPass(ep)
+        result = rotary_pass.call(ep.graph_module)
+
+        # Verify that the pass was successful
+        self.assertTrue(result.modified)
+
+        # Check that the custom op was created
+        gm = ep.graph_module
+        custom_op_count = 0
+        for node in gm.graph.nodes:
+            if (
+                node.op == "call_function"
+                and hasattr(node.target, "__name__")
+                and "apply_rotary_emb" in str(node.target)
+            ):
+                custom_op_count += 1
+
+        # We expect at least one custom op to be created
+        self.assertGreater(custom_op_count, 0)
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index a6d5737dbb8..8c1165a89df 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -29,6 +29,7 @@
     SqueezeUnsqueezeInputs,
     TagMemoryMetaPass,
 )
+from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass
 from executorch.backends.vulkan._passes.remove_asserts import RemoveAssertsTransform
 
 from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
@@ -154,6 +155,7 @@ def preprocess(  # noqa: C901
        program = apply_passes(
            program,
            [
+                FusePatternsPass(program),
                RemoveRedundantOpsTransform(),
                AddmmToLinearTransform(),
                FuseQuantizedOpsTransform(program),
diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS
index 62c33c6a245..b081fe68a2d 100644
--- a/examples/models/llama/TARGETS
+++ b/examples/models/llama/TARGETS
@@ -116,7 +116,6 @@ runtime.python_library(
        "source_transformation/rope.py",
        "source_transformation/sdpa.py",
        "source_transformation/spin_quant.py",
-        "source_transformation/vulkan_rope.py",
        "source_transformation/attention_sink.py",
    ],
 )
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 3a1801f063c..18700acade2 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -85,7 +85,6 @@
     replace_sdpa_with_quantized_sdpa,
     replace_sdpa_with_simple_sdpa,
 )
-from .source_transformation.vulkan_rope import replace_with_vulkan_rotary_emb
 
 IS_FBCODE = True  #  os.environ.get("FBCODE_PLATFORM", False)
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -1469,9 +1468,6 @@ def _get_source_transforms(  # noqa
            transforms.append(replace_sdpa_with_simple_sdpa)
            transforms.append(replace_kv_cache_with_coreml_kv_cache)
 
-        if vulkan:
-            transforms.append(replace_with_vulkan_rotary_emb)
-
        if local_global_attention:
            transforms.append(
                partial(
diff --git a/examples/models/llama/source_transformation/vulkan_rope.py b/examples/models/llama/source_transformation/vulkan_rope.py
deleted file mode 100644
index cdaf6f0baa7..00000000000
--- a/examples/models/llama/source_transformation/vulkan_rope.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import executorch.backends.vulkan.custom_ops_lib  # noqa
-import torch
-
-from executorch.examples.models.llama.rope import RotaryEmbedding
-
-
-class VkRotaryEmbedding(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(
-        self,
-        xq: torch.Tensor,
-        xk: torch.Tensor,
-        freqs_cos: torch.Tensor,
-        freqs_sin: torch.Tensor,
-    ):
-        xq_out, xk_out = torch.ops.et_vk.apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
-        return xq_out, xk_out
-
-
-def replace_with_vulkan_rotary_emb(module: torch.nn.Module):
-    for name, child in module.named_children():
-        if isinstance(child, RotaryEmbedding):
-            new_module = VkRotaryEmbedding()
-            setattr(module, name, new_module)
-        else:
-            replace_with_vulkan_rotary_emb(child)
-
-    return module
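Below is a minimal usage sketch (not part of the patch) showing how the pieces introduced above fit together: the RoPE pattern module is exported and lowered to Edge, and FusePatternsPass then collapses the matched subgraph into the et_vk.apply_rotary_emb custom op. It mirrors the new test_fuse_rotary_emb test; the shapes are illustrative only.

import torch

from executorch.backends.vulkan._passes import FusePatternsPass
from executorch.backends.vulkan.patterns import RotaryEmbeddingPattern
from executorch.exir import EdgeCompileConfig, to_edge

# Illustrative shapes: (batch, seq_len, n_heads, head_dim), with head_dim // 2 frequencies.
xq = torch.randn(1, 5, 4, 32)
xk = torch.randn(1, 5, 2, 32)
freqs_cos = torch.randn(5, 16)
freqs_sin = torch.randn(5, 16)

edge = to_edge(
    torch.export.export(
        RotaryEmbeddingPattern(), (xq, xk, freqs_cos, freqs_sin), strict=True
    ),
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
ep = edge.exported_program()

# Running the pass replaces the matched RoPE subgraph with the fused custom op.
result = FusePatternsPass(ep).call(ep.graph_module)
assert result.modified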
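For a sense of how additional patterns would plug into the new registry, here is a skeleton (hypothetical, not part of the patch): one function decorated with register_pattern_graph returns the exported graphs to match, and a companion decorated with register_pattern_replacement rewrites each InternalMatch, following the structure of rope.py. The MulAddPattern name and the "example_mul_add" key are made up for illustration, and the replacement body is left as a stub.

from typing import List

import torch

from executorch.backends.vulkan.patterns.pattern_registry import (
    register_pattern_graph,
    register_pattern_replacement,
)
from executorch.exir import EdgeCompileConfig, ExportedProgram, to_edge
from torch.export import export
from torch.fx.passes.utils.matcher_utils import InternalMatch


class MulAddPattern(torch.nn.Module):
    # Hypothetical pattern: y = a * b + c
    def forward(self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
        return a * b + c


@register_pattern_graph("example_mul_add")
def get_mul_add_graphs() -> List[torch.fx.GraphModule]:
    # Export the pattern module once; the SubgraphMatcher in patterns/__init__.py
    # matches this graph against the model graph with literals ignored.
    inputs = (torch.randn(4), torch.randn(4), torch.randn(4))
    edge = to_edge(
        export(MulAddPattern(), inputs, strict=True),
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    return [edge.exported_program().graph_module]


@register_pattern_replacement("example_mul_add")
def replace_mul_add(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: InternalMatch,
) -> None:
    # Stub: a real replacement would insert a fused custom op and rewire the
    # matched outputs, as create_rotary_emb_custom_op does for RoPE in rope.py.
    pass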