Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1005,12 +1005,19 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s
};

switch (unary_op) {
case GGML_UNARY_OP_XIELU:
params.push_back(static_cast<uint32_t>(ggml_get_op_params_f32(dst, 1))); // alpha_n
params.push_back(static_cast<uint32_t>(ggml_get_op_params_f32(dst, 2))); // alpha_p
params.push_back(static_cast<uint32_t>(ggml_get_op_params_f32(dst, 3))); // beta
params.push_back(static_cast<uint32_t>(ggml_get_op_params_f32(dst, 4))); // eps
case GGML_UNARY_OP_XIELU: {
// Get float parameters and reinterpret their bit patterns as uint32_t
// for passing through the params buffer
float alpha_n = ggml_get_op_params_f32(dst, 1);
float alpha_p = ggml_get_op_params_f32(dst, 2);
float beta = ggml_get_op_params_f32(dst, 3);
float eps = ggml_get_op_params_f32(dst, 4);
params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_n));
params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_p));
params.push_back(*reinterpret_cast<const uint32_t *>(&beta));
params.push_back(*reinterpret_cast<const uint32_t *>(&eps));
break;
}
default:
break;
}
Expand Down Expand Up @@ -2519,8 +2526,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_GELU_ERF:
// TODO: Investigate XIELU numerical issues
//case GGML_UNARY_OP_XIELU:
case GGML_UNARY_OP_XIELU:
supports_op = supports_op =
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
break;
Expand Down
8 changes: 4 additions & 4 deletions ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -320,22 +320,22 @@

{
"SHADER_NAME": "xielu_f32",
"REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32", "MUTATE": "dst" },
"REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "xielu_f16",
"REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32", "MUTATE": "dst" },
"REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "xielu_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32", "MUTATE": "src" },
"REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "xielu_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32", "MUTATE": "src" },
"REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
Expand Down
Loading