diff --git a/backends/vulkan/_passes/custom_ops_defs.py b/backends/vulkan/_passes/custom_ops_defs.py
index 4da2a31fc44..0275239a86a 100644
--- a/backends/vulkan/_passes/custom_ops_defs.py
+++ b/backends/vulkan/_passes/custom_ops_defs.py
@@ -183,3 +183,53 @@ def linear_weight_int4_impl(
     )
 lib.impl(name, linear_weight_int4_impl, "CompositeExplicitAutograd")
 linear_weight_int4_op = getattr(getattr(torch.ops, namespace), name)
+
+######################
+## apply_rotary_emb ##
+######################
+
+
+# Note that this implementation is copied from executorch.examples.models.llama.rope
+# but it is copied here to avoid introducing a dependency on the llama code.
+def apply_rotary_emb_impl(
+    xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
+):
+    def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+        ndim = x.ndim
+        freqs_cis_ndim = freqs_cis.ndim
+        if freqs_cis_ndim == 3:
+            # freqs_cis: (seq_len, n_heads, head_dim // 2)
+            assert freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1])
+            shape = [
+                d if (i == ndim - 3 or i == ndim - 2 or i == ndim - 1) else 1
+                for i, d in enumerate(x.shape)
+            ]
+        else:
+            # freqs_cis: (seq_len, head_dim // 2)
+            assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+        return freqs_cis.view(shape)
+
+    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
+    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
+
+    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
+    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
+
+    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
+    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
+    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
+    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
+
+    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
+    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
+
+    return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+name = "apply_rotary_emb"
+lib.define(
+    f"{name}(Tensor xq, Tensor xk, Tensor freqs_cos, Tensor freqs_sin) -> (Tensor, Tensor)"
+)
+lib.impl(name, apply_rotary_emb_impl, "CompositeExplicitAutograd")
+apply_rotary_emb_op = getattr(getattr(torch.ops, namespace), name)
diff --git a/backends/vulkan/partitioner/supported_ops.py b/backends/vulkan/partitioner/supported_ops.py
index 83dfb3b7686..92699be0f80 100644
--- a/backends/vulkan/partitioner/supported_ops.py
+++ b/backends/vulkan/partitioner/supported_ops.py
@@ -94,8 +94,9 @@ def __contains__(self, op):
     # Convolution
     exir_ops.edge.aten.convolution.default,
     exir_ops.edge.et_vk.conv_with_clamp.default,
-    # Custom ops
+    # Llama ops
     "llama::sdpa_with_kv_cache",
+    exir_ops.edge.et_vk.apply_rotary_emb.default,
 ]
 
 NO_DYNAMIC_SHAPE = [
diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS
index 751c61da977..fafb69d878b 100644
--- a/examples/models/llama/TARGETS
+++ b/examples/models/llama/TARGETS
@@ -91,6 +91,7 @@ runtime.python_library(
         "source_transformation/rope.py",
         "source_transformation/sdpa.py",
         "source_transformation/spin_quant.py",
+        "source_transformation/vulkan_rope.py",
     ],
     _is_external_target = True,
     base_module = "executorch.examples.models.llama",
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 940bcaecbc7..04bd5bddaaf 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -69,6 +69,7 @@
     replace_sdpa_with_flex_sdpa,
     replace_sdpa_with_simple_sdpa,
 )
+from .source_transformation.vulkan_rope import replace_with_vulkan_rotary_emb
 
 IS_FBCODE = True  # os.environ.get("FBCODE_PLATFORM", False)
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -943,4 +944,7 @@ def _get_source_transforms(  # noqa
             transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_kv_cache_with_coreml_kv_cache)
 
+    if args.vulkan:
+        transforms.append(replace_with_vulkan_rotary_emb)
+
     return transforms
diff --git a/examples/models/llama/source_transformation/vulkan_rope.py b/examples/models/llama/source_transformation/vulkan_rope.py
new file mode 100644
index 00000000000..0dce6aeb448
--- /dev/null
+++ b/examples/models/llama/source_transformation/vulkan_rope.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from executorch.backends.vulkan._passes.custom_ops_defs import (  # noqa
+    apply_rotary_emb_op,
+)
+
+from executorch.examples.models.llama.rope import RotaryEmbedding
+
+
+class VkRotaryEmbedding(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self,
+        xq: torch.Tensor,
+        xk: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
+    ):
+        xq_out, xk_out = torch.ops.et_vk.apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
+        return xq_out, xk_out
+
+
+def replace_with_vulkan_rotary_emb(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, RotaryEmbedding):
+            new_module = VkRotaryEmbedding()
+            setattr(module, name, new_module)
+        else:
+            replace_with_vulkan_rotary_emb(child)
+
+    return module
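
For reference, a minimal usage sketch (not part of the diff above) showing how the newly registered et_vk::apply_rotary_emb op could be exercised in eager mode; the tensor shapes are hypothetical, chosen only so the reference implementation's broadcasting assertions hold, and `model` in the trailing comment is likewise a hypothetical llama Transformer instance.

import torch

# Importing this module registers et_vk::apply_rotary_emb with the dispatcher.
from executorch.backends.vulkan._passes.custom_ops_defs import (  # noqa
    apply_rotary_emb_op,
)

# Hypothetical shapes: (bsz=1, seq_len=8, n_heads=4, head_dim=16);
# freqs_cos / freqs_sin are (seq_len, head_dim // 2).
xq = torch.randn(1, 8, 4, 16)
xk = torch.randn(1, 8, 4, 16)
freqs_cos = torch.randn(8, 8)
freqs_sin = torch.randn(8, 8)

# Dispatches to the CompositeExplicitAutograd reference implementation above.
xq_out, xk_out = torch.ops.et_vk.apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
assert xq_out.shape == xq.shape and xk_out.shape == xk.shape

# On a llama Transformer (hypothetical `model`), the source transform swaps
# every RotaryEmbedding child for VkRotaryEmbedding in place:
# model = replace_with_vulkan_rotary_emb(model)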