Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions backends/vulkan/runtime/graph/ops/glsl/embedding_q4gsw.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ $else:
// Scales always use buffer storage and are read as individual scalars
${layout_declare_tensor(B, "r", "t_scales", SCALES_DTYPE, "buffer")}

// Output sizes in WHCN order
${layout_declare_ubo(B, "ivec4", "out_sizes")}

layout(push_constant) uniform PushConstants {
int group_size;
int embed_dim;
int num_indices;
int out_height;
int is_linear_weight;
};

Expand All @@ -66,6 +66,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
VEC4_T load_embedding_weights(
const int embedding_idx,
const int dim,
const int embed_dim,
const float scale) {
const int n8 = embedding_idx >> 3;
const int n_local = embedding_idx & 7;
Expand Down Expand Up @@ -96,6 +97,7 @@ VEC4_T load_embedding_weights(
VEC4_T load_embedding_weights(
const int embedding_idx,
const int dim,
const int embed_dim,
const float scale) {
const int blocks_per_row = embed_dim >> 5;
const int block_in_row = dim >> 5;
Expand Down Expand Up @@ -124,7 +126,12 @@ void main() {
const int y_idx = int(gl_GlobalInvocationID.y);
const int z_idx = int(gl_GlobalInvocationID.z);

// out_sizes is in WHCN order: x=W(embed_dim), y=H, z=C, w=N
const int embed_dim = out_sizes.x;
const int blocks_per_row = embed_dim >> 5;
const int out_height = out_sizes.y;
const int num_indices = out_sizes.y * out_sizes.z * out_sizes.w;

const int indices_idx = z_idx * out_height + y_idx;
if (block_in_row >= blocks_per_row || indices_idx >= num_indices) {
return;
Expand All @@ -147,7 +154,7 @@ void main() {
float(t_scales[embedding_idx * groups_per_row + dim / group_size]);

const VEC4_T vals =
load_embedding_weights(embedding_idx, dim, scale);
load_embedding_weights(embedding_idx, dim, embed_dim, scale);

#ifdef OUTPUT_BUFFER
const int out_base = indices_idx * embed_dim + dim;
Expand Down
33 changes: 11 additions & 22 deletions backends/vulkan/runtime/graph/ops/impl/EmbeddingQ4gsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ void add_embedding_q4gsw_node(
const ValueRef weight,
const ValueRef weight_scales,
const int32_t group_size,
const int32_t embed_dim,
const int32_t num_indices,
const int32_t out_height,
const int32_t is_linear_weight,
const ValueRef out) {
const ValueRef out,
const ValueRef embed_dim_ref) {
VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim);
VK_CHECK_COND(graph.packed_dim_of(indices) == WHCN::kWidthDim);
VK_CHECK_COND(embed_dim % 32 == 0, "embed_dim must be a multiple of 32");
VK_CHECK_COND(
graph.get_int(embed_dim_ref) % 32 == 0,
"embed_dim must be a multiple of 32");

std::string kernel_name = "embedding_q4gsw";
kernel_name.reserve(kShaderNameReserve);
Expand All @@ -91,21 +91,18 @@ void add_embedding_q4gsw_node(

std::vector<PushConstantDataInfo> push_constants = {
PushConstantDataInfo(&group_size, sizeof(group_size)),
PushConstantDataInfo(&embed_dim, sizeof(embed_dim)),
PushConstantDataInfo(&num_indices, sizeof(num_indices)),
PushConstantDataInfo(&out_height, sizeof(out_height)),
PushConstantDataInfo(&is_linear_weight, sizeof(is_linear_weight)),
};

ValueRef embed_dim_ref = graph.add_scalar<int64_t>(embed_dim);
vkapi::ParamsBindList param_ubos = {graph.sizes_ubo(out)};

graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
pick_embedding_q4gsw_global_wg_size,
default_pick_local_wg_size,
{{out, vkapi::kWrite}, {{indices, weight, weight_scales}, vkapi::kRead}},
{},
param_ubos,
push_constants,
{},
{embed_dim_ref},
Expand All @@ -125,14 +122,8 @@ void embedding_q4gsw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
graph.extract_scalar<bool>(is_linear_weight_ref) ? 1 : 0;

const std::vector<int64_t> weight_sizes = graph.sizes_of(weight_data);
int32_t embed_dim = static_cast<int32_t>(weight_sizes.back() * 2);

const std::vector<int64_t> indices_sizes = graph.sizes_of(indices);
int32_t num_indices = 1;
for (auto s : indices_sizes) {
num_indices *= static_cast<int32_t>(s);
}
int32_t out_height = static_cast<int32_t>(indices_sizes.back());
int64_t embed_dim = weight_sizes.back() * 2;
ValueRef embed_dim_ref = graph.add_scalar<int64_t>(embed_dim);

ValueRef weight;
if (is_linear_weight) {
Expand All @@ -152,11 +143,9 @@ void embedding_q4gsw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
weight,
weight_scales,
group_size,
embed_dim,
num_indices,
out_height,
is_linear_weight,
out);
out,
embed_dim_ref);
}

REGISTER_OPERATORS {
Expand Down
Loading