50 changes: 50 additions & 0 deletions backends/vulkan/_passes/custom_ops_defs.py
@@ -183,3 +183,53 @@ def linear_weight_int4_impl(
)
lib.impl(name, linear_weight_int4_impl, "CompositeExplicitAutograd")
linear_weight_int4_op = getattr(getattr(torch.ops, namespace), name)

######################
## apply_rotary_emb ##
######################


# Note that this implementation is copied from executorch.examples.models.llama.rope;
# it is duplicated here to avoid introducing a dependency on the llama code.
def apply_rotary_emb_impl(
xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
):
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
freqs_cis_ndim = freqs_cis.ndim
if freqs_cis_ndim == 3:
# freqs_cis: (seq_len, n_heads, head_dim // 2)
assert freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1])
shape = [
d if (i == ndim - 3 or i == ndim - 2 or i == ndim - 1) else 1
for i, d in enumerate(x.shape)
]
else:
# freqs_cis: (seq_len, head_dim // 2)
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(shape)

xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

return xq_out.type_as(xq), xk_out.type_as(xk)


name = "apply_rotary_emb"
lib.define(
f"{name}(Tensor xq, Tensor xk, Tensor freqs_cos, Tensor freqs_sin) -> (Tensor, Tensor)"
)
lib.impl(name, apply_rotary_emb_impl, "CompositeExplicitAutograd")
apply_rotary_emb_op = getattr(getattr(torch.ops, namespace), name)
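
For context, a minimal usage sketch of the op registered above. The tensor shapes below are illustrative assumptions; only the op's signature and the et_vk namespace (as used in vulkan_rope.py further down) are taken from this PR.

import torch

from executorch.backends.vulkan._passes.custom_ops_defs import (  # noqa
    apply_rotary_emb_op,
)

# Illustrative shapes: (batch, seq_len, n_heads, head_dim); head_dim must be even
# because the implementation splits it into interleaved real/imaginary halves.
xq = torch.randn(1, 8, 4, 32)
xk = torch.randn(1, 8, 4, 32)
freqs_cos = torch.randn(8, 16)  # (seq_len, head_dim // 2)
freqs_sin = torch.randn(8, 16)

xq_out, xk_out = torch.ops.et_vk.apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
assert xq_out.shape == xq.shape and xk_out.shape == xk.shape
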
3 changes: 2 additions & 1 deletion backends/vulkan/partitioner/supported_ops.py
@@ -94,8 +94,9 @@ def __contains__(self, op):
# Convolution
exir_ops.edge.aten.convolution.default,
exir_ops.edge.et_vk.conv_with_clamp.default,
# Custom ops
# Llama ops
"llama::sdpa_with_kv_cache",
exir_ops.edge.et_vk.apply_rotary_emb.default,
]

NO_DYNAMIC_SHAPE = [
1 change: 1 addition & 0 deletions examples/models/llama/TARGETS
@@ -91,6 +91,7 @@ runtime.python_library(
"source_transformation/rope.py",
"source_transformation/sdpa.py",
"source_transformation/spin_quant.py",
"source_transformation/vulkan_rope.py",
],
_is_external_target = True,
base_module = "executorch.examples.models.llama",
4 changes: 4 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -69,6 +69,7 @@
replace_sdpa_with_flex_sdpa,
replace_sdpa_with_simple_sdpa,
)
from .source_transformation.vulkan_rope import replace_with_vulkan_rotary_emb

IS_FBCODE = True # os.environ.get("FBCODE_PLATFORM", False)
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -943,4 +944,7 @@ def _get_source_transforms( # noqa
transforms.append(replace_sdpa_with_simple_sdpa)
transforms.append(replace_kv_cache_with_coreml_kv_cache)

if args.vulkan:
transforms.append(replace_with_vulkan_rotary_emb)

return transforms
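
As a rough sketch of how the transform list returned here gets used: each entry takes an eager nn.Module and returns the (possibly rewritten) module, matching the signature of replace_with_vulkan_rotary_emb. The application loop below is illustrative only, not code from export_llama_lib.py.

import torch


def apply_source_transforms(model: torch.nn.Module, transforms) -> torch.nn.Module:
    # Run each source transform in order; every transform returns the module,
    # so the chain composes cleanly (e.g. SDPA swap, then Vulkan rotary swap).
    for transform in transforms:
        model = transform(model)
    return model
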
39 changes: 39 additions & 0 deletions examples/models/llama/source_transformation/vulkan_rope.py
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from executorch.backends.vulkan._passes.custom_ops_defs import ( # noqa
apply_rotary_emb_op,
)

from executorch.examples.models.llama.rope import RotaryEmbedding


class VkRotaryEmbedding(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(
self,
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
):
xq_out, xk_out = torch.ops.et_vk.apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
return xq_out, xk_out


def replace_with_vulkan_rotary_emb(module: torch.nn.Module):
for name, child in module.named_children():
if isinstance(child, RotaryEmbedding):
new_module = VkRotaryEmbedding()
setattr(module, name, new_module)
else:
replace_with_vulkan_rotary_emb(child)

return module
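
A minimal usage sketch for the transform above, assuming model is an eager Llama nn.Module that contains RotaryEmbedding submodules (the construction of model is not shown here).

from executorch.examples.models.llama.rope import RotaryEmbedding
from executorch.examples.models.llama.source_transformation.vulkan_rope import (
    replace_with_vulkan_rotary_emb,
)

model = replace_with_vulkan_rotary_emb(model)  # `model` assumed to exist already

# Every RotaryEmbedding should now be a VkRotaryEmbedding, so rotary embeddings
# go through the et_vk.apply_rotary_emb op that the Vulkan partitioner can delegate.
assert not any(isinstance(m, RotaryEmbedding) for m in model.modules())
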