26 changes: 22 additions & 4 deletions backends/vulkan/runtime/graph/ops/glsl/common.glslh
@@ -51,10 +51,10 @@ int extract_8bit_from_packed_int_le(const int packed, const int i) {
 
 ivec4 unpack_int8x4(const int packed) {
   return ivec4(
-      extract_8bit_from_packed_int_le(packed, 0),
-      extract_8bit_from_packed_int_le(packed, 1),
-      extract_8bit_from_packed_int_le(packed, 2),
-      extract_8bit_from_packed_int_le(packed, 3));
+      bitfieldExtract(packed, 0, 8),
+      bitfieldExtract(packed, 8, 8),
+      bitfieldExtract(packed, 16, 8),
+      bitfieldExtract(packed, 24, 8));
 }
 
 int pack_4xqint_into_int32(
@@ -89,6 +89,24 @@ int quantize_and_pack(const vec4 vals, const float inv_scale, const int zp) {
   return pack_into_int32(quantized);
 }
 
+// Software fallback for dotPacked4x8AccSatEXT when GL_EXT_integer_dot_product
+// is unavailable. Saturation is omitted: for typical neural network inputs,
+// int32 overflow does not occur in practice.
+int dotPacked4x8Acc_fallback(const int a, const int b, const int acc) {
+  const vec4 fa = vec4(unpack_int8x4(a));
+  const vec4 fb = vec4(unpack_int8x4(b));
+  return acc + int(dot(fa, fb));
+}
+
+// Dispatch macro resolved at GLSL compile time by USE_INT8_DOT_PRODUCT_EXT.
+// When USE_INT8_DOT_PRODUCT_EXT == 0, uses the software fallback.
+// All other cases (flag=1 or undefined) use the hardware intrinsic.
+#if defined(USE_INT8_DOT_PRODUCT_EXT) && USE_INT8_DOT_PRODUCT_EXT == 0
+#define dotPacked4x8AccSat(a, b, acc) dotPacked4x8Acc_fallback(a, b, acc)
+#else
+#define dotPacked4x8AccSat(a, b, acc) dotPacked4x8AccSatEXT(a, b, acc)
+#endif
+
 #ifdef DEBUG_MODE
 
 #define printf debugPrintfEXT
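The fallback above is a plain (non-saturating) multiply-accumulate over four sign-extended 8-bit lanes. A minimal host-side C++ sketch of the same lane arithmetic, handy for sanity-checking the semantics (the helper names are illustrative and not part of this change; assumes two's-complement conversions as guaranteed since C++20):

```cpp
#include <cstdint>
#include <cstdio>

// Sign-extend the i-th 8-bit lane of a packed little-endian word,
// mirroring GLSL's bitfieldExtract(packed, 8 * i, 8) on a signed int.
int32_t extract_lane(uint32_t packed, int i) {
  return static_cast<int8_t>((packed >> (8 * i)) & 0xFFu);
}

// Mirrors dotPacked4x8Acc_fallback: acc + dot(unpack(a), unpack(b)).
// No saturation; a single call adds at most 4 * 128 * 128 = 65536,
// so int32 headroom covers tens of thousands of accumulation steps.
int32_t dot_packed_4x8_acc(uint32_t a, uint32_t b, int32_t acc) {
  for (int i = 0; i < 4; ++i) {
    acc += extract_lane(a, i) * extract_lane(b, i);
  }
  return acc;
}

// Pack four int8 values into the little-endian lane layout used above.
uint32_t pack4(int8_t x0, int8_t x1, int8_t x2, int8_t x3) {
  return static_cast<uint8_t>(x0) | (static_cast<uint8_t>(x1) << 8) |
      (static_cast<uint8_t>(x2) << 16) |
      (static_cast<uint32_t>(static_cast<uint8_t>(x3)) << 24);
}

int main() {
  // dot((1, -2, 3, -4), (5, 6, 7, 8)) = 5 - 12 + 21 - 32 = -18
  printf("%d\n", dot_packed_4x8_acc(pack4(1, -2, 3, -4), pack4(5, 6, 7, 8), 0));
  return 0;
}
```

Compiled with `g++ -std=c++20`, this prints `-18`, matching the unpack-then-dot definition the hardware intrinsic implements.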
backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh
@@ -18,7 +18,9 @@
 #define LINEAR_FP_OUTPUT_TILE_INT8_INT8_COMPUTE_GLSLH
 
 #extension GL_EXT_control_flow_attributes : require
+#if !defined(USE_INT8_DOT_PRODUCT_EXT) || USE_INT8_DOT_PRODUCT_EXT != 0
 #extension GL_EXT_integer_dot_product : require
+#endif
 
 #include "linear_common.glslh"
 #include "linear_fp_output_tile.glslh"
@@ -50,7 +52,7 @@ void int_accumulate_with_int8_weight(
     const int n4 = div_4(n);
     const int n4i = mod_4(n);
     [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) {
-      accum.data[m][n4][n4i] = dotPacked4x8AccSatEXT(
+      accum.data[m][n4][n4i] = dotPacked4x8AccSat(
          in_tile.data[m4][k4][m4i],
          w_tile.data[k4][n4][n4i],
          accum.data[m][n4][n4i]);
7 changes: 5 additions & 2 deletions backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl
@@ -10,8 +10,11 @@
 
 ${define_required_extensions("buffer", DTYPE)}
 
+#define USE_INT8_DOT_PRODUCT_EXT ${USE_INT8_DOT_PRODUCT_EXT}
+
 #extension GL_EXT_control_flow_attributes : require
-#extension GL_EXT_integer_dot_product : require
+$if USE_INT8_DOT_PRODUCT_EXT == 1:
+  #extension GL_EXT_integer_dot_product : require
 
 #define PRECISION ${PRECISION}
 #define VEC4_T ${texel_load_type(DTYPE, "buffer")}
@@ -177,7 +180,7 @@ void main() {
     // Accumulate using packed int8 dot product for each output channel
     // dotPacked4x8AccSatEXT computes: acc + dot(unpack(a), unpack(b))
     [[unroll]] for (int oc_offset = 0; oc_offset < 4; ++oc_offset) {
-      acc[subtile_w][oc_offset] = dotPacked4x8AccSatEXT(
+      acc[subtile_w][oc_offset] = dotPacked4x8AccSat(
          packed_input,
          weight_block[oc_offset],
          acc[subtile_w][oc_offset]);
3 changes: 3 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.yaml
@@ -7,8 +7,11 @@
 q8ta_conv2d:
   parameter_names_with_default_values:
     DTYPE: float
+    USE_INT8_DOT_PRODUCT_EXT: 1
   generate_variant_forall:
     DTYPE:
       - VALUE: float
   shader_variants:
     - NAME: q8ta_conv2d
+    - NAME: q8ta_conv2d_fallback
+      USE_INT8_DOT_PRODUCT_EXT: 0
7 changes: 5 additions & 2 deletions backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl
@@ -10,8 +10,11 @@
 
 ${define_required_extensions("buffer", DTYPE)}
 
+#define USE_INT8_DOT_PRODUCT_EXT ${USE_INT8_DOT_PRODUCT_EXT}
+
 #extension GL_EXT_control_flow_attributes : require
-#extension GL_EXT_integer_dot_product : require
+$if USE_INT8_DOT_PRODUCT_EXT == 1:
+  #extension GL_EXT_integer_dot_product : require
 
 #define PRECISION ${PRECISION}
 #define VEC4_T ${texel_load_type(DTYPE, "buffer")}
@@ -146,7 +149,7 @@ void main() {
   [[unroll]] for (int m = 0; m < TILE_M; ++m) {
     [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
       [[unroll]] for (int n4i = 0; n4i < 4; ++n4i) {
-        out_accum[m][n4][n4i] = dotPacked4x8AccSatEXT(
+        out_accum[m][n4][n4i] = dotPacked4x8AccSat(
            int8_input_tile[m],
            int8_weight_tile[n4][n4i],
            out_accum[m][n4][n4i]);
3 changes: 3 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.yaml
@@ -7,8 +7,11 @@
 q8ta_conv2d_pw:
   parameter_names_with_default_values:
     DTYPE: float
+    USE_INT8_DOT_PRODUCT_EXT: 1
   generate_variant_forall:
     DTYPE:
       - VALUE: float
   shader_variants:
     - NAME: q8ta_conv2d_pw
+    - NAME: q8ta_conv2d_pw_fallback
+      USE_INT8_DOT_PRODUCT_EXT: 0
5 changes: 4 additions & 1 deletion backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.glsl
@@ -10,8 +10,11 @@
 
 ${define_required_extensions("buffer", DTYPE)}
 
+#define USE_INT8_DOT_PRODUCT_EXT ${USE_INT8_DOT_PRODUCT_EXT}
+
 #extension GL_EXT_control_flow_attributes : require
-#extension GL_EXT_integer_dot_product : require
+$if USE_INT8_DOT_PRODUCT_EXT == 1:
+  #extension GL_EXT_integer_dot_product : require
 
 #define PRECISION ${PRECISION}
 #define VEC4_T ${texel_load_type(DTYPE, "buffer")}
3 changes: 3 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.yaml
@@ -11,8 +11,11 @@ q8ta_linear:
     TILE_M4: 1
     TILE_N4: 2
     TILE_K4: 1
+    USE_INT8_DOT_PRODUCT_EXT: 1
   generate_variant_forall:
     DTYPE:
       - VALUE: float
   shader_variants:
     - NAME: q8ta_linear
+    - NAME: q8ta_linear_fallback
+      USE_INT8_DOT_PRODUCT_EXT: 0
backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl
@@ -10,8 +10,11 @@
 
 ${define_required_extensions("buffer", DTYPE)}
 
+#define USE_INT8_DOT_PRODUCT_EXT ${USE_INT8_DOT_PRODUCT_EXT}
+
 #extension GL_EXT_control_flow_attributes : require
-#extension GL_EXT_integer_dot_product : require
+$if USE_INT8_DOT_PRODUCT_EXT == 1:
+  #extension GL_EXT_integer_dot_product : require
 
 #define PRECISION ${PRECISION}
 #define VEC4_T ${texel_load_type(DTYPE, "buffer")}
@@ -94,7 +97,7 @@ void main() {
   [[unroll]] for (int n = 0; n < TILE_N; ++n) {
     const int tile_n4 = div_4(n);
     const int n4i = mod_4(n);
-    out_accum.data[0][tile_n4][n4i] = dotPacked4x8AccSatEXT(
+    out_accum.data[0][tile_n4][n4i] = dotPacked4x8AccSat(
        packed_input,
        int8_weight_tile.data[0][tile_n4][n4i],
        out_accum.data[0][tile_n4][n4i]);
backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.yaml
@@ -11,8 +11,11 @@ q8ta_linear_gemv:
     TILE_K4: 1
     TILE_N4: 2
     WGS: 64
+    USE_INT8_DOT_PRODUCT_EXT: 1
   generate_variant_forall:
     DTYPE:
       - VALUE: float
   shader_variants:
     - NAME: q8ta_linear_gemv
+    - NAME: q8ta_linear_gemv_fallback
+      USE_INT8_DOT_PRODUCT_EXT: 0
5 changes: 3 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp
@@ -288,8 +288,9 @@ void add_q8ta_conv2d_node(
       PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)),
   };
 
-  // Select shader based on layout
-  std::string kernel_name = "q8ta_conv2d";
+  const bool use_hw_dot =
+      graph.context()->adapter_ptr()->supports_int8_dot_product();
+  std::string kernel_name = use_hw_dot ? "q8ta_conv2d" : "q8ta_conv2d_fallback";
   add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales));
 
   // Pass metadata for both output and input tensors
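The runtime switch hinges on the adapter's `supports_int8_dot_product()`. As a hedged illustration of what such a capability check can query through the standard Vulkan API (this sketches `VK_KHR_shader_integer_dot_product` directly and is not ExecuTorch's actual adapter code):

```cpp
#include <vulkan/vulkan.h>

// Query the feature bit behind GLSL's dotPacked4x8AccSatEXT:
// VK_KHR_shader_integer_dot_product (promoted to core in Vulkan 1.3).
bool query_int8_dot_product_support(VkPhysicalDevice physical_device) {
  VkPhysicalDeviceShaderIntegerDotProductFeatures dot_features{};
  dot_features.sType =
      VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES;

  VkPhysicalDeviceFeatures2 features2{};
  features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
  features2.pNext = &dot_features;

  vkGetPhysicalDeviceFeatures2(physical_device, &features2);
  return dot_features.shaderIntegerDotProduct == VK_TRUE;
}
```

Note that the feature bit only reports support; whether a given packed format is hardware-accelerated is exposed separately through `VkPhysicalDeviceShaderIntegerDotProductProperties` (e.g. `integerDotProduct4x8BitPackedSignedAccelerated`).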
5 changes: 4 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp
@@ -235,7 +235,10 @@ void add_q8ta_conv2d_pw_node(
       PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)),
   };
 
-  std::string kernel_name = "q8ta_conv2d_pw";
+  const bool use_hw_dot =
+      graph.context()->adapter_ptr()->supports_int8_dot_product();
+  std::string kernel_name =
+      use_hw_dot ? "q8ta_conv2d_pw" : "q8ta_conv2d_pw_fallback";
   add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales));
 
   // Pass metadata for both output and input tensors
4 changes: 3 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/Q8taLinear.cpp
@@ -98,7 +98,9 @@ void add_q8ta_linear_node(
     apply_bias = 0;
   }
 
-  std::string kernel_name = "q8ta_linear";
+  const bool use_hw_dot =
+      graph.context()->adapter_ptr()->supports_int8_dot_product();
+  std::string kernel_name = use_hw_dot ? "q8ta_linear" : "q8ta_linear_fallback";
   add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales));
 
   vkapi::ParamsBindList param_buffers = {
5 changes: 4 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/Q8taLinearGemv.cpp
@@ -101,7 +101,10 @@ void add_q8ta_linear_gemv_node(
     apply_bias = 0;
   }
 
-  std::string kernel_name = "q8ta_linear_gemv";
+  const bool use_hw_dot =
+      graph.context()->adapter_ptr()->supports_int8_dot_product();
+  std::string kernel_name =
+      use_hw_dot ? "q8ta_linear_gemv" : "q8ta_linear_gemv_fallback";
   add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales));
 
   vkapi::ParamsBindList param_buffers = {
3 changes: 0 additions & 3 deletions backends/vulkan/test/custom_ops/test_q8ta_linear.cpp
@@ -161,9 +161,6 @@ static TestCase create_test_case_from_config(
 // Generate test cases for q8ta_linear operation
 static std::vector<TestCase> generate_q8ta_linear_test_cases() {
   std::vector<TestCase> test_cases;
-  if (!vkcompute::api::context()->adapter_ptr()->supports_int8_dot_product()) {
-    return test_cases;
-  }
 
   std::vector<LinearConfig> configs = {
       // Batch size 1 cases (test both tiled and gemv)