From d8bae5bfb17dbf14a7b9b6518f2c78b199eedd97 Mon Sep 17 00:00:00 2001 From: Abhijit Ramesh Date: Sun, 30 Nov 2025 16:16:59 -0800 Subject: [PATCH] ggml webgpu: fix xielu parameter passing The XIELU operation was incorrectly using static_cast to convert float parameters to uint32_t, which converted numeric values instead of preserving IEEE 754 bit patterns. This caused incorrect values to be interpreted by the GPU shader. * Use reinterpret_cast to preserve float bit patterns when passing through uint32_t params buffer * Update WGSL shader parameter types from u32 to f32 * Re-enable XIELU support (was disabled due to numerical issues) Fixes NMSE test failures for XIELU operation on WebGPU backend. --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 20 ++++++++++++------- .../ggml-webgpu/wgsl-shaders/unary_op.wgsl | 8 ++++---- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index e684db9e210..56dc24681f8 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1005,12 +1005,19 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s }; switch (unary_op) { - case GGML_UNARY_OP_XIELU: - params.push_back(static_cast<uint32_t>(ggml_get_op_params_f32(dst, 1))); // alpha_n - params.push_back(static_cast<uint32_t>(ggml_get_op_params_f32(dst, 2))); // alpha_p - params.push_back(static_cast<uint32_t>(ggml_get_op_params_f32(dst, 3))); // beta - params.push_back(static_cast<uint32_t>(ggml_get_op_params_f32(dst, 4))); // eps + case GGML_UNARY_OP_XIELU: { + // Get float parameters and reinterpret their bit patterns as uint32_t + // for passing through the params buffer + float alpha_n = ggml_get_op_params_f32(dst, 1); + float alpha_p = ggml_get_op_params_f32(dst, 2); + float beta = ggml_get_op_params_f32(dst, 3); + float eps = ggml_get_op_params_f32(dst, 4); + params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_n)); + params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_p)); + 
params.push_back(*reinterpret_cast<const uint32_t *>(&beta)); + params.push_back(*reinterpret_cast<const uint32_t *>(&eps)); break; + } default: break; } @@ -2519,8 +2526,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_GELU_ERF: - // TODO: Investigate XIELU numerical issues - //case GGML_UNARY_OP_XIELU: + case GGML_UNARY_OP_XIELU: supports_op = supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl index 54ddcbd37fd..d474ab107b4 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl @@ -320,22 +320,22 @@ { "SHADER_NAME": "xielu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32", "MUTATE": "dst" }, + "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" }, "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "xielu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32", "MUTATE": "dst" }, + "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" }, "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "xielu_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32", "MUTATE": "src" }, + "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" }, "DECLS": ["INPLACE"] }, { "SHADER_NAME": "xielu_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32", "MUTATE": "src" }, + "REPLS": { "TYPE": "f16", "FUNC": 
"XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" }, "DECLS": ["INPLACE"] }, {