127 changes: 90 additions & 37 deletions backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl
@@ -12,6 +12,9 @@

#define PRECISION ${PRECISION}

#define FOUR 4

#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
#define FLOAT_T ${buffer_scalar_type(DTYPE)}

${define_active_storage_type(STORAGE)}
@@ -26,12 +29,17 @@ ${layout_declare_tensor(1, "r", "t_mat1", DTYPE, STORAGE)}
${layout_declare_tensor(2, "r", "t_mat2", "int8", STORAGE)}
${layout_declare_tensor(3, "r", "t_scales_and_zeros", DTYPE, STORAGE)}

${layout_declare_ubo(4, "ivec4", "out_sizes")}
${layout_declare_ubo(5, "ivec4", "out_strides")}
${layout_declare_ubo(6, "ivec4", "mat1_strides")}
${layout_declare_ubo(7, "ivec4", "mat2_sizes")}
${layout_declare_ubo(8, "ivec4", "mat2_strides")}
${layout_declare_ubo(9, "ivec4", "scales_strides")}
$if STORAGE == "texture3d":
${layout_declare_ubo(4, "ivec4", "out_sizes")}
${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
${layout_declare_ubo(6, "ivec4", "scales_strides")}
$else:
${layout_declare_ubo(4, "ivec4", "out_sizes")}
${layout_declare_ubo(5, "ivec4", "out_strides")}
${layout_declare_ubo(6, "ivec4", "mat1_sizes")}
${layout_declare_ubo(7, "ivec4", "mat1_strides")}
${layout_declare_ubo(8, "ivec4", "mat2_strides")}
${layout_declare_ubo(9, "ivec4", "scales_strides")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

@@ -49,45 +57,90 @@ void main() {
return;
}

const uint K = mat2_sizes.x * 2;
const uint N = mat2_sizes.y;
const uint K = mat1_sizes.x;
const uint n = out_pos.x;
const uint m = out_pos.y;
const uint k_block = (K + group_size - 1) / group_size;
const uint mask = uint(0x0f);
ivec4 mat1_pos = ivec4(0, m, out_pos.z, out_pos.w);
ivec4 mat2_pos = ivec4(0, n, out_pos.z, out_pos.w);
ivec4 scale_pos = ivec4(0, n, 0, out_pos.w);
ivec4 zero_pos = ivec4(0, n, 1, out_pos.w);

float rc = 0.0;
int k = 0;

for (int kb = 0; kb < k_block; kb++) {
scale_pos.x = kb;
const int scale_id = to_buffer_id(scale_pos, scales_strides);
const float scale = float(t_scales_and_zeros[scale_id]);

zero_pos.x = kb;
const int zero_id = to_buffer_id(zero_pos, scales_strides);
const float zero = float(t_scales_and_zeros[zero_id]) - scale * 8.0;

for(uint idx = 0; idx < group_size && k < K; idx++, k++) {
mat1_pos.x = k;
const int mat1_id = to_buffer_id(mat1_pos, mat1_strides);
const float mat1_val = float(t_mat1[mat1_id]);

mat2_pos.x = k / 2;
const int mat2_id = to_buffer_id(mat2_pos, mat2_strides);
// Bitwise op treats sign bit from int8 as a value bit instead,
// since there is no uint8_t datatype
uint mat2_val = (t_mat2[mat2_id] & 0xFF);
mat2_val = (k & 1) == 0 ? mat2_val & mask : (mat2_val >> 4);
#ifdef USING_BUFFER
const uint k_block = (K + group_size - 1) / group_size;
ivec4 mat1_pos = ivec4(0, m, out_pos.z, out_pos.w);
ivec4 mat2_pos = ivec4(0, n, out_pos.z, out_pos.w);
ivec4 scale_pos = ivec4(0, n, 0, out_pos.w);
ivec4 zero_pos = ivec4(0, n, 1, out_pos.w);

for (int kb = 0; kb < k_block; kb++) {
scale_pos.x = kb;
const int scale_id = to_buffer_id(scale_pos, scales_strides);
const float scale = float(t_scales_and_zeros[scale_id]);

zero_pos.x = kb;
const int zero_id = to_buffer_id(zero_pos, scales_strides);
const float zero = float(t_scales_and_zeros[zero_id]) - scale * 8.0;

for(uint idx = 0; idx < group_size && k < K; idx++, k++) {
mat1_pos.x = k;
const int mat1_id = to_buffer_id(mat1_pos, mat1_strides);
const float mat1_val = float(t_mat1[mat1_id]);

mat2_pos.x = k / 2;
const int mat2_id = to_buffer_id(mat2_pos, mat2_strides);
// Bitwise op treats sign bit from int8 as a value bit instead,
// since there is no uint8_t datatype
uint mat2_val = (t_mat2[mat2_id] & 0xFF);
mat2_val = (k & 1) == 0 ? mat2_val & mask : (mat2_val >> 4);

rc += mat1_val * (scale * float(mat2_val) + zero);
}
}

rc += mat1_val * (scale * float(mat2_val) + zero);
const int out_id = to_buffer_id(out_pos, out_strides);
t_out[out_id] = FLOAT_T(rc);

#else // Using texture
const uint texel_group_size = group_size / FOUR;
const uint k_block = (K + texel_group_size - 1) / texel_group_size;
ivec3 mat1_pos = ivec3(0, m, out_pos.z);
ivec3 mat2_pos = ivec3(0, n, out_pos.z);
ivec3 scale_pos = ivec3(0, n, 0);
ivec3 zero_pos = ivec3(0, n, 1);

for (int kb = 0; kb < k_block; kb++) {
const int texel_kb = kb / FOUR;
const int kb_offset = kb % FOUR;

scale_pos.x = texel_kb;
const VEC4_T scale_texel = load_texel(t_scales_and_zeros, scale_pos);
const float scale = float(scale_texel[kb_offset]);

zero_pos.x = texel_kb;
const VEC4_T zero_texel = load_texel(t_scales_and_zeros, zero_pos);
const float zero = float(zero_texel[kb_offset]) - scale * 8.0;

for(uint idx = 0; idx < texel_group_size && k < K; idx++, k++) {
mat1_pos.x = k;
const VEC4_T mat1_tex = load_texel(t_mat1, mat1_pos);

mat2_pos.x = k / 2;
const i8vec4 mat2_tex = i8vec4(load_texel(t_mat2, mat2_pos));

// Every two texels of mat1 correspond to one texel of mat2
// Even mat1 indices correspond to the first half of the mat2 texel and
// odd indices correspond to the second half
const int mat2_offset = (k & 1) == 0 ? 0 : 2;
for (int texel_idx = 0; texel_idx < FOUR; texel_idx++){
// Bitwise op treats sign bit from int8 as a value bit instead,
// since there is no uint8_t datatype
uint mat2_val = (mat2_tex[mat2_offset + texel_idx / 2] & 0xFF);
mat2_val = (texel_idx & 1) == 0 ? mat2_val & mask : (mat2_val >> 4);
rc += mat1_tex[texel_idx] * (scale * float(mat2_val) + zero);
}
}
}
}
write_texel(t_out, out_pos.xyz, vec4(rc, 0, 0, 0));

const int out_id = to_buffer_id(out_pos, out_strides);
t_out[out_id] = FLOAT_T(rc);
#endif
}
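
The buffer and texture3d branches above share the same per-element math: two 4-bit weights live in each int8 byte (even k in the low nibble, odd k in the high nibble), and the -8 offset of the unsigned nibble is folded into the zero point, so each accumulated term is scale * q + (stored_zero - 8 * scale). A minimal standalone C++ sketch of that unpacking, assuming nothing beyond what the shader shows (the helper name and sample values are illustrative, not part of the PR):

#include <cstdint>
#include <cstdio>

// Two 4-bit weights per byte: element 2*i in the low nibble, 2*i + 1 in the high nibble.
float dequant_4bit(int8_t packed_byte, int k, float scale, float stored_zero) {
  // Mask to 8 bits first so the int8 sign bit is treated as a value bit,
  // mirroring `t_mat2[...] & 0xFF` in the shader.
  const uint32_t byte = static_cast<uint32_t>(packed_byte) & 0xFF;
  const uint32_t q = (k & 1) == 0 ? (byte & 0x0F) : (byte >> 4);
  // The shader folds the -8 offset into the zero point:
  //   zero = stored_zero - scale * 8, so scale * q + zero == scale * (q - 8) + stored_zero.
  const float zero = stored_zero - scale * 8.0f;
  return scale * static_cast<float>(q) + zero;
}

int main() {
  // 0x9A packs q = 10 (low nibble, even k) and q = 9 (high nibble, odd k).
  const int8_t b = static_cast<int8_t>(0x9A);
  printf("%f %f\n", dequant_4bit(b, 0, 0.5f, 0.0f), dequant_4bit(b, 1, 0.5f, 0.0f));
  // Prints 1.000000 and 0.500000: scale * (q - 8) with q = 10 and 9.
  return 0;
}
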
3 changes: 3 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml
@@ -12,5 +12,8 @@ q_4w_linear:
DTYPE:
- VALUE: float
- VALUE: half
STORAGE:
- VALUE: buffer
- VALUE: texture3d
shader_variants:
- NAME: q_4w_linear
42 changes: 33 additions & 9 deletions backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp
@@ -30,7 +30,16 @@ void check_q_matmul_args(
VK_CHECK_COND(mat1_sizes.size() == 2);
VK_CHECK_COND(mat1_sizes.size() == mat2_sizes.size());

VK_CHECK_COND(graph.memory_layout_of(mat1) == graph.memory_layout_of(out));
VK_CHECK_COND(graph.memory_layout_of(mat1) == utils::kWidthPacked);
VK_CHECK_COND(graph.memory_layout_of(mat2_data) == utils::kWidthPacked);
VK_CHECK_COND(
graph.memory_layout_of(scales_and_zeros) == utils::kWidthPacked);

if (graph.storage_type_of(out) == utils::kBuffer) {
VK_CHECK_COND(graph.memory_layout_of(out) == utils::kWidthPacked);
} else {
VK_CHECK_COND(graph.memory_layout_of(out) == utils::kChannelsPacked);
}

const int mat1_K = utils::val_at(-1, mat1_sizes);
const int mat2_K = utils::val_at(-1, mat2_sizes) * 2;
@@ -95,24 +104,39 @@ void add_q_matmul_node(
const ValueRef group_size,
const ValueRef scales_and_zeros_data,
const ValueRef out) {
ValueRef mat2 =
prepack_buffer_if_tensor_ref(graph, mat2_data, utils::kWidthPacked);
auto storage_type = graph.storage_type_of(out);

ValueRef mat2;

if (storage_type == utils::kBuffer) {
mat2 = prepack_buffer_if_tensor_ref(graph, mat2_data, utils::kWidthPacked);
} else {
mat2 = prepack_if_tensor_ref(graph, mat2_data, utils::kWidthPacked);
}

ValueRef scales_and_zeros =
prepack_if_tensor_ref(graph, scales_and_zeros_data, utils::kWidthPacked);

std::string kernel_name = "q_4w_linear";

add_dtype_suffix(kernel_name, graph.dtype_of(out));
add_storage_type_suffix(kernel_name, storage_type);

const uint32_t group_size_val = graph.extract_scalar<uint32_t>(group_size);

vkapi::ParamsBindList ubos({});
ubos.append(graph.sizes_ubo(out));
ubos.append(graph.strides_ubo(out));
ubos.append(graph.strides_ubo(mat1));
ubos.append(graph.sizes_ubo(mat2));
ubos.append(graph.strides_ubo(mat2));
ubos.append(graph.strides_ubo(scales_and_zeros));
if (storage_type == utils::kBuffer) {
ubos.append(graph.sizes_ubo(out));
ubos.append(graph.strides_ubo(out));
ubos.append(graph.sizes_ubo(mat1));
ubos.append(graph.strides_ubo(mat1));
ubos.append(graph.strides_ubo(mat2));
ubos.append(graph.strides_ubo(scales_and_zeros));
} else {
ubos.append(graph.sizes_ubo(out));
ubos.append(graph.sizes_ubo(mat1));
ubos.append(graph.strides_ubo(scales_and_zeros));
}

auto out_sizes = graph.sizes_of(out);
uint32_t N = utils::val_at(-1, out_sizes);
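
With the STORAGE axis added to the YAML above, one q_4w_linear variant is generated per (DTYPE, STORAGE) pair, and add_q_matmul_node picks the one matching the output tensor's storage by appending suffixes to the base name. A rough sketch of that name construction, with the caveat that the exact suffix spelling is an assumption about add_dtype_suffix / add_storage_type_suffix rather than something shown in this diff:

#include <iostream>
#include <string>

int main() {
  const std::string base = "q_4w_linear";
  for (const std::string dtype : {"float", "half"}) {
    for (const std::string storage : {"buffer", "texture3d"}) {
      // Assumed order: dtype suffix first, then storage suffix, matching the
      // call order in add_q_matmul_node above.
      std::cout << base + "_" + dtype + "_" + storage << "\n";
    }
  }
  return 0;
}
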
46 changes: 28 additions & 18 deletions backends/vulkan/test/vulkan_compute_api_test.cpp
@@ -2742,7 +2742,10 @@ TEST(VulkanComputeGraphOpsTest, grid_priors_test) {
/*data_out_expected = */ {4, 4, 12, 4, 20, 4, 4, 12, 12, 12, 20, 12});
}

void test_int4pack_mm(std::vector<uint32_t> MKN, uint32_t group_size) {
void test_int4pack_mm(
std::vector<uint32_t> MKN,
uint32_t group_size,
utils::StorageType storage_type) {
GraphConfig config;
ComputeGraph graph(config);

@@ -2756,8 +2759,7 @@ void test_int4pack_mm(std::vector<uint32_t> MKN, uint32_t group_size) {
const std::vector<int64_t> out_size = {M, N};

std::vector<float> A_data = create_random_float_buffer(M * K);
IOValueRef A =
graph.add_input_tensor(mat1_size, vkapi::kFloat, utils::kBuffer);
IOValueRef A = graph.add_input_tensor(mat1_size, vkapi::kFloat, storage_type);
graph.copy_into_staging(A.staging, A_data.data(), A_data.size());

// Quantized but un-packed weights
@@ -2768,7 +2770,7 @@
int4mm_pack_weights(mat2_size, B_quant_data.data());

IOValueRef B_int4 =
graph.add_input_tensor(mat2_q_size, vkapi::kQInt8, utils::kBuffer);
graph.add_input_tensor(mat2_q_size, vkapi::kQInt8, storage_type);
graph.copy_into_staging(
B_int4.staging, B_int4_data.data(), B_int4_data.size());

@@ -2777,7 +2779,7 @@
// Random scales and zeroes. Keep scales small to avoid overflow and zeroes in
// int4 range
IOValueRef scales_and_zeros =
graph.add_input_tensor({2, N, k_groups}, vkapi::kFloat, utils::kBuffer);
graph.add_input_tensor({2, N, k_groups}, vkapi::kFloat, storage_type);
std::vector<float> s_data(graph.numel_of(scales_and_zeros.value));
const int zeros_stride = s_data.size() / 2;
for (size_t i = 0; i < zeros_stride; i++) {
@@ -2789,7 +2791,13 @@
scales_and_zeros.staging, s_data.data(), s_data.size());

IOValueRef out_int4;
out_int4.value = graph.add_tensor(out_size, vkapi::kFloat, utils::kBuffer);

if (storage_type == utils::kBuffer) {
out_int4.value = graph.add_tensor(out_size, vkapi::kFloat, utils::kBuffer);
} else {
out_int4.value =
graph.add_tensor(out_size, vkapi::kFloat, utils::kChannelsPacked);
}

VK_GET_OP_FN("aten._weight_int4pack_mm.default")
(graph,
@@ -2803,13 +2811,13 @@

// Dequantized matmul for comparison
IOValueRef B_deq =
graph.add_input_tensor(mat2_size, vkapi::kFloat, utils::kBuffer);
graph.add_input_tensor(mat2_size, vkapi::kFloat, storage_type);
std::vector<float> B_deq_data = int4mm_dequantize_weights(
mat2_size, B_quant_data.data(), group_size, s_data.data());
graph.copy_into_staging(B_deq.staging, B_deq_data.data(), B_deq_data.size());

IOValueRef out_deq;
out_deq.value = graph.add_tensor(out_size, vkapi::kFloat, utils::kBuffer);
out_deq.value = graph.add_tensor(out_size, vkapi::kFloat, storage_type);

VK_GET_OP_FN("aten.mm.default")
(graph, {A.value, B_deq.value, out_deq.value});
@@ -2842,18 +2850,20 @@ TEST(VulkanComputeGraphOpsTest, int4pack_mm_test) {
GTEST_SKIP();
}

// Vector multiplication, single group per row
test_int4pack_mm({1, 32, 1}, 32);
for (auto storage_type : {utils::kBuffer, utils::kTexture3D}) {
// Vector multiplication, single group per row
test_int4pack_mm({1, 32, 1}, 32, storage_type);

// Vector multiplication, multiple groups per row
test_int4pack_mm({1, 256, 1}, 64);
// Vector multiplication, multiple groups per row
test_int4pack_mm({1, 256, 1}, 64, storage_type);

// Square matrices, single group per row
test_int4pack_mm({32, 32, 32}, 32);
// Square matrices, single group per row
test_int4pack_mm({32, 32, 32}, 32, storage_type);

// Irregular matrices, single group per row
test_int4pack_mm({37, 32, 19}, 32);
// Irregular matrices, single group per row
test_int4pack_mm({37, 32, 19}, 32, storage_type);

// Irregular matrices, multiple groups per row
test_int4pack_mm({37, 256, 19}, 64);
// Irregular matrices, multiple groups per row
test_int4pack_mm({37, 256, 19}, 64, storage_type);
}
}
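
The test compares the quantized path against a plain float matmul over dequantized weights, so correctness hinges on the pack / dequantize round trip staying consistent with the shader's scale * (q - 8) + zero convention. A hypothetical sketch of that round trip, assuming the low-nibble-first layout the shader reads; the real helpers (int4mm_pack_weights, int4mm_dequantize_weights) live in the test utilities and may differ in layout details:

#include <cstdint>
#include <vector>

// Pack two 4-bit values per byte: element 2*i goes in the low nibble,
// element 2*i + 1 in the high nibble (the layout the shader expects).
std::vector<uint8_t> pack_int4(const std::vector<uint8_t>& q) {
  std::vector<uint8_t> packed(q.size() / 2);
  for (size_t i = 0; i < packed.size(); ++i) {
    packed[i] = (q[2 * i] & 0x0F) | ((q[2 * i + 1] & 0x0F) << 4);
  }
  return packed;
}

// Dequantize one group of a row with a single (scale, zero) pair,
// mirroring scale * (q - 8) + zero from the shader.
std::vector<float> dequant_group(
    const std::vector<uint8_t>& q, float scale, float zero) {
  std::vector<float> out(q.size());
  for (size_t i = 0; i < q.size(); ++i) {
    out[i] = scale * (static_cast<float>(q[i]) - 8.0f) + zero;
  }
  return out;
}
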