From 8b709dd6e81556f8b9d8467dc7aba0896b196f71 Mon Sep 17 00:00:00 2001 From: ssjia Date: Sat, 28 Mar 2026 18:59:32 -0700 Subject: [PATCH] [ET-VK] Fix embedding_q4gsw out-of-bounds access with dynamic shapes The embedding_q4gsw shader used push constants for num_indices, out_height, and embed_dim that were captured at graph build time and never updated when input tensors were dynamically resized. This caused out-of-bounds GPU memory reads when the actual input was smaller than the initial allocation, resulting in VK_ERROR_DEVICE_LOST on Mali GPUs. The fix derives all shape-dependent values (embed_dim, out_height, num_indices) from the output tensor's sizes UBO, which is automatically updated on resize. Only truly constant values (group_size, is_linear_weight) remain as push constants. Root cause: With a 7-token input on a graph built for 256 tokens, the local workgroup rounding created an extra thread (y=7) that passed the stale bounds check (7 >= 256 == false) and read past the 7-element indices buffer. 
Differential Revision: [D98642319](https://our.internmc.facebook.com/intern/diff/D98642319/) ghstack-source-id: 359350851 Pull Request resolved: https://github.com/pytorch/executorch/pull/18558 --- .../graph/ops/glsl/embedding_q4gsw.glsl | 15 ++++++--- .../runtime/graph/ops/impl/EmbeddingQ4gsw.cpp | 33 +++++++------------ 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding_q4gsw.glsl b/backends/vulkan/runtime/graph/ops/glsl/embedding_q4gsw.glsl index ce6779a0c9b..f1b369c2cdf 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/embedding_q4gsw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/embedding_q4gsw.glsl @@ -45,11 +45,11 @@ $else: // Scales are ALWAYS buffer, loaded as scalar ${layout_declare_tensor(B, "r", "t_scales", SCALES_DTYPE, "buffer")} +// Output sizes in WHCN order +${layout_declare_ubo(B, "ivec4", "out_sizes")} + layout(push_constant) uniform PushConstants { int group_size; - int embed_dim; - int num_indices; - int out_height; int is_linear_weight; }; @@ -66,6 +66,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; VEC4_T load_embedding_weights( const int embedding_idx, const int dim, + const int embed_dim, const float scale) { const int n8 = embedding_idx >> 3; const int n_local = embedding_idx & 7; @@ -96,6 +97,7 @@ VEC4_T load_embedding_weights( VEC4_T load_embedding_weights( const int embedding_idx, const int dim, + const int embed_dim, const float scale) { const int blocks_per_row = embed_dim >> 5; const int block_in_row = dim >> 5; @@ -124,7 +126,12 @@ void main() { const int y_idx = int(gl_GlobalInvocationID.y); const int z_idx = int(gl_GlobalInvocationID.z); + // out_sizes is in WHCN order: x=W(embed_dim), y=H, z=C, w=N + const int embed_dim = out_sizes.x; const int blocks_per_row = embed_dim >> 5; + const int out_height = out_sizes.y; + const int num_indices = out_sizes.y * out_sizes.z * out_sizes.w; + const int indices_idx = z_idx * out_height + y_idx; 
if (block_in_row >= blocks_per_row || indices_idx >= num_indices) { return; @@ -147,7 +154,7 @@ void main() { float(t_scales[embedding_idx * groups_per_row + dim / group_size]); const VEC4_T vals = - load_embedding_weights(embedding_idx, dim, scale); + load_embedding_weights(embedding_idx, dim, embed_dim, scale); #ifdef OUTPUT_BUFFER const int out_base = indices_idx * embed_dim + dim; diff --git a/backends/vulkan/runtime/graph/ops/impl/EmbeddingQ4gsw.cpp b/backends/vulkan/runtime/graph/ops/impl/EmbeddingQ4gsw.cpp index c5051392074..46a65c9284c 100644 --- a/backends/vulkan/runtime/graph/ops/impl/EmbeddingQ4gsw.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/EmbeddingQ4gsw.cpp @@ -64,14 +64,14 @@ void add_embedding_q4gsw_node( const ValueRef weight, const ValueRef weight_scales, const int32_t group_size, - const int32_t embed_dim, - const int32_t num_indices, - const int32_t out_height, const int32_t is_linear_weight, - const ValueRef out) { + const ValueRef out, + const ValueRef embed_dim_ref) { VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim); VK_CHECK_COND(graph.packed_dim_of(indices) == WHCN::kWidthDim); - VK_CHECK_COND(embed_dim % 32 == 0, "embed_dim must be a multiple of 32"); + VK_CHECK_COND( + graph.get_int(embed_dim_ref) % 32 == 0, + "embed_dim must be a multiple of 32"); std::string kernel_name = "embedding_q4gsw"; kernel_name.reserve(kShaderNameReserve); @@ -91,13 +91,10 @@ void add_embedding_q4gsw_node( std::vector<PushConstantDataInfo> push_constants = { PushConstantDataInfo(&group_size, sizeof(group_size)), - PushConstantDataInfo(&embed_dim, sizeof(embed_dim)), - PushConstantDataInfo(&num_indices, sizeof(num_indices)), - PushConstantDataInfo(&out_height, sizeof(out_height)), PushConstantDataInfo(&is_linear_weight, sizeof(is_linear_weight)), }; - ValueRef embed_dim_ref = graph.add_scalar(embed_dim); + vkapi::ParamsBindList param_ubos = {graph.sizes_ubo(out)}; graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, @@ -105,7 +102,7 @@ void 
add_embedding_q4gsw_node( pick_embedding_q4gsw_global_wg_size, default_pick_local_wg_size, {{out, vkapi::kWrite}, {{indices, weight, weight_scales}, vkapi::kRead}}, - {}, + param_ubos, push_constants, {}, {embed_dim_ref}, @@ -125,14 +122,8 @@ void embedding_q4gsw(ComputeGraph& graph, const std::vector<ValueRef>& args) { graph.extract_scalar<bool>(is_linear_weight_ref) ? 1 : 0; const std::vector<int64_t> weight_sizes = graph.sizes_of(weight_data); - int32_t embed_dim = static_cast<int32_t>(weight_sizes.back() * 2); - - const std::vector<int64_t> indices_sizes = graph.sizes_of(indices); - int32_t num_indices = 1; - for (auto s : indices_sizes) { - num_indices *= static_cast<int32_t>(s); - } - int32_t out_height = static_cast<int32_t>(indices_sizes.back()); + int64_t embed_dim = weight_sizes.back() * 2; + ValueRef embed_dim_ref = graph.add_scalar(embed_dim); ValueRef weight; if (is_linear_weight) { @@ -152,11 +143,9 @@ void embedding_q4gsw(ComputeGraph& graph, const std::vector<ValueRef>& args) { weight, weight_scales, group_size, - embed_dim, - num_indices, - out_height, is_linear_weight, - out); + out, + embed_dim_ref); } REGISTER_OPERATORS {