diff --git a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl new file mode 100644 index 00000000000..30375728921 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl @@ -0,0 +1,123 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "xqout", DTYPE, STORAGE)} +${layout_declare_tensor(B, "w", "xkout", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "xq", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "xk", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "freqs_cos", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "freqs_sin", DTYPE, STORAGE)} +${layout_declare_ubo(B, "ivec3", "xqout_limits")} +${layout_declare_ubo(B, "ivec3", "xkout_limits")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int packed_dim = 0; + +#include "indexing_utils.h" + +/* + * This shader computes rotary positional embeddings which are used in the Llama + * model architecture. There are 4 input tensors with the following shapes. + * Note that head_dim = embedding_dim / num_heads + * + * 1. xq (batch_size, sequence_len, num_heads, head_dim) + * 2. xk (batch_size, sequence_len, num_kv_heads, head_dim) + * 3. freqs_cos (sequence_len, head_dim / 2) + * 4. freqs_cos (sequence_len, head_dim / 2) + * + * Two output tensors are produced, with the same shapes as xq and xk + * respectively. + * + * The computation of rotary positional embeddings can be summarized with the + * following equations: + * + * xq_out[2i] = xq[2i] * freqs_cos[i] - xq[2i + 1] * freqs_sin[i] + * xq_out[2i + 1] = xq[2i] * freqs_sin[i] + xq[2i + 1] * freqs_cos[i] + * + * Essentially, taking each row along head_dim of the xq and xk tensors, each + * row is split into even and odd elements (xq[2i] and xq[2i + 1] respectively). + * The even components of the output multiply the even components of the inputs + * with the freqs_cos tensor, and the odd components of the inputs with the + * freqs_sin tensor. The odd components of the output swap this. Throughout the + * implementation the even components have the _r suffix and the odd components + * have the _i suffix; this is a reference to complex numbers which can be used + * to represent rotations. + * + * Note that this implementation assumes that all input tensors have the width + * dim as the packed dim. + */ +void main() { + // Each thread will write to two output locations to maximize data re-use. + // One texel loaded from the freqs_cos/freqs_sin tensors can be used to + // calculate two output texels. + const ivec3 x_pos_1 = ivec3( + gl_GlobalInvocationID.x * 2, gl_GlobalInvocationID.yz); + const ivec3 x_pos_2 = ivec3(x_pos_1.x + 1, x_pos_1.yz); + + if (any(greaterThanEqual(x_pos_2, xqout_limits))) { + return; + } + + const ivec3 freqs_pos = ivec3(gl_GlobalInvocationID.xz, 0); + + VEC4_T cos_tex = load_texel(freqs_cos, freqs_pos); + VEC4_T sin_tex = load_texel(freqs_sin, freqs_pos); + + // Compute xqout + + VEC4_T x_tex_1 = load_texel(xq, x_pos_1); + VEC4_T x_tex_2 = load_texel(xq, x_pos_2); + + // Separate into even and odd elements + VEC4_T x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz); + VEC4_T x_i = VEC4_T(x_tex_1.yw, x_tex_2.yw); + + VEC4_T xout_r = x_r * cos_tex - x_i * sin_tex; + VEC4_T xout_i = x_r * sin_tex + x_i * cos_tex; + + VEC4_T xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y); + VEC4_T xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w); + + write_texel(xqout, x_pos_1, xout_tex_1); + write_texel(xqout, x_pos_2, xout_tex_2); + + // n_heads will be greater than or equal to n_kv_heads, therefore xq and xqout + // may have a larger height dim than xk and xkout. Only compute xkout if this + // invocation is still within bounds. + if (any(greaterThanEqual(x_pos_2, xkout_limits))) { + return; + } + + // Compute xkout + + x_tex_1 = load_texel(xk, x_pos_1); + x_tex_2 = load_texel(xk, x_pos_2); + + x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz); + x_i = VEC4_T(x_tex_1.yw, x_tex_2.yw); + + xout_r = x_r * cos_tex - x_i * sin_tex; + xout_i = x_r * sin_tex + x_i * cos_tex; + + xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y); + xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w); + + write_texel(xkout, x_pos_1, xout_tex_1); + write_texel(xkout, x_pos_2, xout_tex_2); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml new file mode 100644 index 00000000000..a81fd564d10 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml @@ -0,0 +1,10 @@ +rotary_embedding: + parameter_names_with_default_values: + DTYPE: float + STORAGE: texture3d + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: rotary_embedding diff --git a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp new file mode 100644 index 00000000000..859a3d98aac --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp @@ -0,0 +1,89 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace vkcompute { + +void resize_rotary_embedding_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + vTensorPtr out = graph->get_tensor(args[0].refs[0]); + vTensorPtr in = graph->get_tensor(args[1].refs[0]); + + std::vector in_sizes = in->sizes(); + // UNCOMMENT BELOW IF NEEDED + // out->virtual_resize(in_sizes); +} + +void add_rotary_embedding_node( + ComputeGraph& graph, + const ValueRef xq, + const ValueRef xk, + const ValueRef freqs_cos, + const ValueRef freqs_sin, + const ValueRef xq_out, + const ValueRef xk_out) { + VK_CHECK_COND(graph.size_at(-1, xq) == graph.size_at(-1, xk)); + VK_CHECK_COND(graph.size_at(-3, xq) == graph.size_at(-3, xk)); + VK_CHECK_COND( + graph.size_at(-1, xq) == graph.size_at(-1, freqs_cos) * 2); + VK_CHECK_COND(graph.sizes_of(freqs_cos) == graph.sizes_of(freqs_sin)); + + VK_CHECK_COND(graph.packed_dim_of(xq) == WHCN::kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(xk) == WHCN::kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(freqs_cos) == WHCN::kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(freqs_sin) == WHCN::kWidthDim); + VK_CHECK_COND(graph.has_standard_axis_map(xq)); + VK_CHECK_COND(graph.has_standard_axis_map(xk)); + VK_CHECK_COND(graph.has_standard_axis_map(freqs_cos)); + VK_CHECK_COND(graph.has_standard_axis_map(freqs_sin)); + + std::string kernel_name = "rotary_embedding"; + add_dtype_suffix(kernel_name, graph.dtype_of(xq_out)); + + utils::uvec3 global_wg_size = graph.logical_limits_of(xq_out); + global_wg_size[0] /= 2; + const utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); + + graph.execute_nodes().emplace_back(new DispatchNode( + graph, + // Shader + VK_KERNEL_FROM_STR(kernel_name), + // Workgroup sizes + global_wg_size, + local_wg_size, + // Inputs and Outputs + {{{xq_out, xk_out}, vkapi::kWrite}, + {{xq, xk, freqs_cos, freqs_sin}, vkapi::kRead}}, + // Parameter buffers + {graph.logical_limits_ubo(xq_out), graph.logical_limits_ubo(xk_out)}, + // Specialization Constants + {}, + // Resizing Logic + resize_rotary_embedding_node)); +} + +void apply_rotary_emb(ComputeGraph& graph, const std::vector& args) { + const ValueListPtr out_tuple = graph.get_value_list(args[4]); + const ValueRef xq_out = out_tuple->at(0); + const ValueRef xk_out = out_tuple->at(1); + + add_rotary_embedding_node( + graph, args[0], args[1], args[2], args[3], xq_out, xk_out); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.apply_rotary_emb.default, apply_rotary_emb); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/rotary_embedding_test.cpp b/backends/vulkan/test/op_tests/rotary_embedding_test.cpp new file mode 100644 index 00000000000..534bb577e7a --- /dev/null +++ b/backends/vulkan/test/op_tests/rotary_embedding_test.cpp @@ -0,0 +1,180 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include + +#include + +// +// Reference Implementations +// + +std::pair rotary_embedding_impl( + const at::Tensor& xq, + const at::Tensor& xk, + const at::Tensor& freqs_cos, + const at::Tensor& freqs_sin) { + std::vector xq_even_odd = at::unbind( + xq.reshape({xq.size(0), xq.size(1), xq.size(2), xq.size(3) / 2, 2}), -1); + at::Tensor& xq_r = xq_even_odd[0]; + at::Tensor& xq_i = xq_even_odd[1]; + + std::vector xk_even_odd = at::unbind( + xk.reshape({xk.size(0), xk.size(1), xk.size(2), xk.size(3) / 2, 2}), -1); + at::Tensor& xk_r = xk_even_odd[0]; + at::Tensor& xk_i = xk_even_odd[1]; + + at::Tensor freqs_cos_reshape = + freqs_cos.reshape({1, freqs_cos.size(0), 1, freqs_cos.size(1)}); + at::Tensor freqs_sin_reshape = + freqs_sin.reshape({1, freqs_sin.size(0), 1, freqs_sin.size(1)}); + + at::Tensor xq_out_r = xq_r * freqs_cos_reshape - xq_i * freqs_sin_reshape; + at::Tensor xq_out_i = xq_r * freqs_sin_reshape + xq_i * freqs_cos_reshape; + at::Tensor xk_out_r = xk_r * freqs_cos_reshape - xk_i * freqs_sin_reshape; + at::Tensor xk_out_i = xk_r * freqs_sin_reshape + xk_i * freqs_cos_reshape; + + at::Tensor xq_out = at::flatten(at::stack({xq_out_r, xq_out_i}, -1), 3); + at::Tensor xk_out = at::flatten(at::stack({xk_out_r, xk_out_i}, -1), 3); + + return std::make_pair(xq_out, xk_out); +} + +// +// Test functions +// + +vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { + using namespace vkcompute; + switch (at_scalartype) { + case c10::kFloat: + return vkapi::kFloat; + case c10::kHalf: + return vkapi::kHalf; + case c10::kInt: + return vkapi::kInt; + case c10::kLong: + return vkapi::kInt; + case c10::kChar: + return vkapi::kChar; + case c10::kByte: + return vkapi::kByte; + default: + VK_THROW("Unsupported at::ScalarType!"); + } +} + +void test_reference( + const int n_heads = 4, + const int n_kv_heads = 2, + const int dim = 32, + const int seq_len = 1) { + const int head_dim = dim / n_heads; + + at::Tensor xq = at::rand( + {1, seq_len, n_heads, head_dim}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor xk = at::rand( + {1, seq_len, n_kv_heads, head_dim}, + at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor freqs_cos = + at::rand({seq_len, head_dim / 2}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor freqs_sin = + at::rand({seq_len, head_dim / 2}, at::device(at::kCPU).dtype(at::kFloat)); + + std::pair outs = + rotary_embedding_impl(xq, xk, freqs_cos, freqs_sin); + at::Tensor& xq_out = outs.first; + at::Tensor& xk_out = outs.second; + + // Build Vulkan graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(utils::kTexture3D); + ComputeGraph graph(config); + +#define MAKE_INPUT_FOR(x) \ + IOValueRef r_##x = graph.add_input_tensor( \ + x.sizes().vec(), from_at_scalartype(x.scalar_type())); + + MAKE_INPUT_FOR(xq); + MAKE_INPUT_FOR(xk); + MAKE_INPUT_FOR(freqs_cos); + MAKE_INPUT_FOR(freqs_sin); + + const ValueRef r_xq_out = graph.add_tensor( + xq_out.sizes().vec(), from_at_scalartype(xq_out.scalar_type())); + const ValueRef r_xk_out = graph.add_tensor( + xk_out.sizes().vec(), from_at_scalartype(xk_out.scalar_type())); + + VK_GET_OP_FN("et_vk.apply_rotary_emb.default") + (graph, + {r_xq.value, + r_xk.value, + r_freqs_cos.value, + r_freqs_sin.value, + graph.add_value_list({r_xq_out, r_xk_out})}); + + ValueRef staging_xq_out = graph.set_output_tensor(r_xq_out); + ValueRef staging_xk_out = graph.set_output_tensor(r_xk_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // + // Run model + // + + graph.propagate_resize(); + graph.copy_into_staging(r_xq.staging, xq.const_data_ptr(), xq.numel()); + graph.copy_into_staging(r_xk.staging, xk.const_data_ptr(), xk.numel()); + graph.copy_into_staging( + r_freqs_cos.staging, freqs_cos.const_data_ptr(), freqs_cos.numel()); + graph.copy_into_staging( + r_freqs_sin.staging, freqs_sin.const_data_ptr(), freqs_sin.numel()); + + graph.execute(); + + at::Tensor vk_xq_out = at::empty_like(xq_out); + graph.copy_from_staging( + staging_xq_out, vk_xq_out.mutable_data_ptr(), vk_xq_out.numel()); + + at::Tensor vk_xk_out = at::empty_like(xk_out); + graph.copy_from_staging( + staging_xk_out, vk_xk_out.mutable_data_ptr(), vk_xk_out.numel()); + + EXPECT_TRUE(at::allclose(xq_out, vk_xq_out, 1e-4, 1e-4)); + EXPECT_TRUE(at::allclose(xk_out, vk_xk_out, 1e-4, 1e-4)); +} + +TEST(VulkanRotaryEmbeddingTest, rotary_embedding_test) { + test_reference(); +} + +TEST(VulkanRotaryEmbeddingTest, rotary_embedding_llama3_params_test) { + test_reference( + /*n_heads=*/32, + /*n_kv_heads=*/8, + /*dim=*/2048); +} + +TEST(VulkanRotaryEmbeddingTest, rotary_embedding_llama3_params_test_seq_len_3) { + test_reference( + /*n_heads=*/32, + /*n_kv_heads=*/8, + /*dim=*/2048, + /*seq_len=*/3); +} diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl index 270e1b768a8..4770dc5708a 100644 --- a/backends/vulkan/test/op_tests/targets.bzl +++ b/backends/vulkan/test/op_tests/targets.bzl @@ -224,3 +224,40 @@ def define_common_targets(is_fbcode = False): runtime.external_dep_location("libtorch"), ], ) + + runtime.cxx_binary( + name = "rotary_embedding_test_bin", + srcs = [ + "rotary_embedding_test.cpp", + ], + compiler_flags = [ + "-Wno-unused-variable", + ], + define_static_target = False, + deps = [ + "//third-party/googletest:gtest_main", + "//executorch/backends/vulkan:vulkan_graph_runtime", + runtime.external_dep_location("libtorch"), + ], + ) + + runtime.cxx_test( + name = "rotary_embedding_test", + srcs = [ + "rotary_embedding_test.cpp", + ], + contacts = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"], + fbandroid_additional_loaded_sonames = [ + "torch-code-gen", + "vulkan_graph_runtime", + "vulkan_graph_runtime_shaderlib", + ], + platforms = [ANDROID], + use_instrumentation_test = True, + deps = [ + "//third-party/googletest:gtest_main", + "//executorch/backends/vulkan:vulkan_graph_runtime", + "//executorch/extension/tensor:tensor", + runtime.external_dep_location("libtorch"), + ], + )