diff --git a/backends/vulkan/runtime/graph/ops/glsl/common.glslh b/backends/vulkan/runtime/graph/ops/glsl/common.glslh index a752f23d3ed..f556ba0a705 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/common.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/common.glslh @@ -51,10 +51,10 @@ int extract_8bit_from_packed_int_le(const int packed, const int i) { ivec4 unpack_int8x4(const int packed) { return ivec4( - extract_8bit_from_packed_int_le(packed, 0), - extract_8bit_from_packed_int_le(packed, 1), - extract_8bit_from_packed_int_le(packed, 2), - extract_8bit_from_packed_int_le(packed, 3)); + bitfieldExtract(packed, 0, 8), + bitfieldExtract(packed, 8, 8), + bitfieldExtract(packed, 16, 8), + bitfieldExtract(packed, 24, 8)); } int pack_4xqint_into_int32( @@ -89,6 +89,24 @@ int quantize_and_pack(const vec4 vals, const float inv_scale, const int zp) { return pack_into_int32(quantized); } +// Software fallback for dotPacked4x8AccSatEXT when GL_EXT_integer_dot_product +// is unavailable. Saturation is omitted: for typical neural network inputs, +// int32 overflow does not occur in practice. +int dotPacked4x8Acc_fallback(const int a, const int b, const int acc) { + const vec4 fa = vec4(unpack_int8x4(a)); + const vec4 fb = vec4(unpack_int8x4(b)); + return acc + int(dot(fa, fb)); +} + +// Dispatch macro resolved at GLSL compile time by USE_INT8_DOT_PRODUCT_EXT. +// When USE_INT8_DOT_PRODUCT_EXT == 0, uses the software fallback. +// All other cases (flag=1 or undefined) use the hardware intrinsic.
+#if defined(USE_INT8_DOT_PRODUCT_EXT) && USE_INT8_DOT_PRODUCT_EXT == 0 + #define dotPacked4x8AccSat(a, b, acc) dotPacked4x8Acc_fallback(a, b, acc) +#else + #define dotPacked4x8AccSat(a, b, acc) dotPacked4x8AccSatEXT(a, b, acc) +#endif + #ifdef DEBUG_MODE #define printf debugPrintfEXT diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh index 850dc7943c0..838db2c09db 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh @@ -18,7 +18,9 @@ #define LINEAR_FP_OUTPUT_TILE_INT8_INT8_COMPUTE_GLSLH #extension GL_EXT_control_flow_attributes : require +#if !defined(USE_INT8_DOT_PRODUCT_EXT) || USE_INT8_DOT_PRODUCT_EXT != 0 #extension GL_EXT_integer_dot_product : require +#endif #include "linear_common.glslh" #include "linear_fp_output_tile.glslh" @@ -50,7 +52,7 @@ void int_accumulate_with_int8_weight( const int n4 = div_4(n); const int n4i = mod_4(n); [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { - accum.data[m][n4][n4i] = dotPacked4x8AccSatEXT( + accum.data[m][n4][n4i] = dotPacked4x8AccSat( in_tile.data[m4][k4][m4i], w_tile.data[k4][n4][n4i], accum.data[m][n4][n4i]); diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl index d693acbab3f..821f7f79b0e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl @@ -10,8 +10,11 @@ ${define_required_extensions("buffer", DTYPE)} +#define USE_INT8_DOT_PRODUCT_EXT ${USE_INT8_DOT_PRODUCT_EXT} + #extension GL_EXT_control_flow_attributes : require -#extension GL_EXT_integer_dot_product : require +$if USE_INT8_DOT_PRODUCT_EXT == 1: + #extension GL_EXT_integer_dot_product : require #define PRECISION ${PRECISION} #define VEC4_T 
${texel_load_type(DTYPE, "buffer")} @@ -177,7 +180,7 @@ void main() { // Accumulate using packed int8 dot product for each output channel // dotPacked4x8AccSatEXT computes: acc + dot(unpack(a), unpack(b)) [[unroll]] for (int oc_offset = 0; oc_offset < 4; ++oc_offset) { - acc[subtile_w][oc_offset] = dotPacked4x8AccSatEXT( + acc[subtile_w][oc_offset] = dotPacked4x8AccSat( packed_input, weight_block[oc_offset], acc[subtile_w][oc_offset]); diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.yaml index dc21e6da0c5..6ced1c16ebb 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.yaml @@ -7,8 +7,11 @@ q8ta_conv2d: parameter_names_with_default_values: DTYPE: float + USE_INT8_DOT_PRODUCT_EXT: 1 generate_variant_forall: DTYPE: - VALUE: float shader_variants: - NAME: q8ta_conv2d + - NAME: q8ta_conv2d_fallback + USE_INT8_DOT_PRODUCT_EXT: 0 diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl index ec41d933114..d408b7ca9b8 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl @@ -10,8 +10,11 @@ ${define_required_extensions("buffer", DTYPE)} +#define USE_INT8_DOT_PRODUCT_EXT ${USE_INT8_DOT_PRODUCT_EXT} + #extension GL_EXT_control_flow_attributes : require -#extension GL_EXT_integer_dot_product : require +$if USE_INT8_DOT_PRODUCT_EXT == 1: + #extension GL_EXT_integer_dot_product : require #define PRECISION ${PRECISION} #define VEC4_T ${texel_load_type(DTYPE, "buffer")} @@ -146,7 +149,7 @@ void main() { [[unroll]] for (int m = 0; m < TILE_M; ++m) { [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { [[unroll]] for (int n4i = 0; n4i < 4; ++n4i) { - out_accum[m][n4][n4i] = dotPacked4x8AccSatEXT( + out_accum[m][n4][n4i] = dotPacked4x8AccSat( int8_input_tile[m], 
int8_weight_tile[n4][n4i], out_accum[m][n4][n4i]); diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.yaml index b7b8c42bf14..46670b8d2aa 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.yaml @@ -7,8 +7,11 @@ q8ta_conv2d_pw: parameter_names_with_default_values: DTYPE: float + USE_INT8_DOT_PRODUCT_EXT: 1 generate_variant_forall: DTYPE: - VALUE: float shader_variants: - NAME: q8ta_conv2d_pw + - NAME: q8ta_conv2d_pw_fallback + USE_INT8_DOT_PRODUCT_EXT: 0 diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.glsl index 87a3d539297..98b0895d759 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.glsl @@ -10,8 +10,11 @@ ${define_required_extensions("buffer", DTYPE)} +#define USE_INT8_DOT_PRODUCT_EXT ${USE_INT8_DOT_PRODUCT_EXT} + #extension GL_EXT_control_flow_attributes : require -#extension GL_EXT_integer_dot_product : require +$if USE_INT8_DOT_PRODUCT_EXT == 1: + #extension GL_EXT_integer_dot_product : require #define PRECISION ${PRECISION} #define VEC4_T ${texel_load_type(DTYPE, "buffer")} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.yaml index c7836c60477..7a324b8a338 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.yaml @@ -11,8 +11,11 @@ q8ta_linear: TILE_M4: 1 TILE_N4: 2 TILE_K4: 1 + USE_INT8_DOT_PRODUCT_EXT: 1 generate_variant_forall: DTYPE: - VALUE: float shader_variants: - NAME: q8ta_linear + - NAME: q8ta_linear_fallback + USE_INT8_DOT_PRODUCT_EXT: 0 diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl index 
aa0837c4a6e..241fc1845bf 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl @@ -10,8 +10,11 @@ ${define_required_extensions("buffer", DTYPE)} +#define USE_INT8_DOT_PRODUCT_EXT ${USE_INT8_DOT_PRODUCT_EXT} + #extension GL_EXT_control_flow_attributes : require -#extension GL_EXT_integer_dot_product : require +$if USE_INT8_DOT_PRODUCT_EXT == 1: + #extension GL_EXT_integer_dot_product : require #define PRECISION ${PRECISION} #define VEC4_T ${texel_load_type(DTYPE, "buffer")} @@ -94,7 +97,7 @@ void main() { [[unroll]] for (int n = 0; n < TILE_N; ++n) { const int tile_n4 = div_4(n); const int n4i = mod_4(n); - out_accum.data[0][tile_n4][n4i] = dotPacked4x8AccSatEXT( + out_accum.data[0][tile_n4][n4i] = dotPacked4x8AccSat( packed_input, int8_weight_tile.data[0][tile_n4][n4i], out_accum.data[0][tile_n4][n4i]); diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.yaml index beae1eddf3e..b2ca7037162 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.yaml @@ -11,8 +11,11 @@ q8ta_linear_gemv: TILE_K4: 1 TILE_N4: 2 WGS: 64 + USE_INT8_DOT_PRODUCT_EXT: 1 generate_variant_forall: DTYPE: - VALUE: float shader_variants: - NAME: q8ta_linear_gemv + - NAME: q8ta_linear_gemv_fallback + USE_INT8_DOT_PRODUCT_EXT: 0 diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp index 8273df6a07e..d1a4840fbba 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp @@ -288,8 +288,9 @@ void add_q8ta_conv2d_node( PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)), }; - // Select shader based on layout - std::string kernel_name = "q8ta_conv2d"; + const bool use_hw_dot = + 
graph.context()->adapter_ptr()->supports_int8_dot_product(); + std::string kernel_name = use_hw_dot ? "q8ta_conv2d" : "q8ta_conv2d_fallback"; add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales)); // Pass metadata for both output and input tensors diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp index b72f5b78f53..1872e8796de 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp @@ -235,7 +235,10 @@ void add_q8ta_conv2d_pw_node( PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)), }; - std::string kernel_name = "q8ta_conv2d_pw"; + const bool use_hw_dot = + graph.context()->adapter_ptr()->supports_int8_dot_product(); + std::string kernel_name = + use_hw_dot ? "q8ta_conv2d_pw" : "q8ta_conv2d_pw_fallback"; add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales)); // Pass metadata for both output and input tensors diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taLinear.cpp index 45366fbf044..210bd0cd78b 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taLinear.cpp @@ -98,7 +98,9 @@ void add_q8ta_linear_node( apply_bias = 0; } - std::string kernel_name = "q8ta_linear"; + const bool use_hw_dot = + graph.context()->adapter_ptr()->supports_int8_dot_product(); + std::string kernel_name = use_hw_dot ? 
"q8ta_linear" : "q8ta_linear_fallback"; add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales)); vkapi::ParamsBindList param_buffers = { diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.cpp index 120df6b0256..2885ad86f35 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.cpp @@ -101,7 +101,10 @@ void add_q8ta_linear_gemv_node( apply_bias = 0; } - std::string kernel_name = "q8ta_linear_gemv"; + const bool use_hw_dot = + graph.context()->adapter_ptr()->supports_int8_dot_product(); + std::string kernel_name = + use_hw_dot ? "q8ta_linear_gemv" : "q8ta_linear_gemv_fallback"; add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales)); vkapi::ParamsBindList param_buffers = { diff --git a/backends/vulkan/test/custom_ops/test_q8ta_linear.cpp b/backends/vulkan/test/custom_ops/test_q8ta_linear.cpp index 707a8695171..1eb254d0dff 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_linear.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_linear.cpp @@ -161,9 +161,6 @@ static TestCase create_test_case_from_config( // Generate test cases for q8ta_linear operation static std::vector generate_q8ta_linear_test_cases() { std::vector test_cases; - if (!vkcompute::api::context()->adapter_ptr()->supports_int8_dot_product()) { - return test_cases; - } std::vector configs = { // Batch size 1 cases (test both tiled and gemv)