@@ -40,8 +40,6 @@ layout(push_constant) uniform PushConstants {
   int quant_max;
 };
 
-#extension GL_EXT_debug_printf : enable
-
 // Shared memory for cooperative min/max finding
 shared T shared_min[NUM_OUTPUTS_PER_WG][NUM_WORKERS_PER_OUTPUT];
 shared T shared_max[NUM_OUTPUTS_PER_WG][NUM_WORKERS_PER_OUTPUT];

2 changes: 0 additions & 2 deletions backends/vulkan/runtime/graph/ops/glsl/im2col.glsl
@@ -8,8 +8,6 @@
 
 #version 450 core
 
-#extension GL_EXT_debug_printf : enable
-
 #define PRECISION ${PRECISION}
 #define VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)}
 #define T ${texel_load_component_type(DTYPE, INPUT_STORAGE)}

@@ -110,6 +110,7 @@ void main() {
   IntPerInChannelParams int8_input_sums_tile;
 
   const int num_groups = K4 / K4_per_group;
+  const int group_size = mul_4(K4_per_group);
 
   for (int group_i = 0; group_i < num_groups; ++group_i) {
     // Reset int accumulator
@@ -119,7 +120,6 @@
 
     load_int8_input_tile(int8_in_tile, k4, m4, K4);
     load_int4_weight_tile(int4_weight_tile, k4, n8, K4);
-    // load_int4_weight_tile(int4_weight_tile, n8, k4, N8);
 
     int_accumulate_with_int4_weight(
         out_accum, int8_in_tile, int4_weight_tile);
@@ -129,13 +129,6 @@
     load_weight_sums_tile_for_group(weight_sums_tile, n4, group_i, N4);
     load_int8_input_sums_tile_for_group(int8_input_sums_tile, m4, group_i, M4);
 
-    const int group_size = mul_4(K4_per_group);
-
-    // // Update output tile with accumulated values
-    // accumulate_out_tile_with_int_accum_from_int4_weights_test(
-    //     out_tile,
-    //     out_accum);
-
     accumulate_out_tile_with_int_accum_from_int4_weights(
         out_tile,
         out_accum,

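Note: the hunks above hoist group_size out of the per-group loop and delete the commented-out debug calls. As a sketch of the grouping arithmetic, assuming (from the naming) that K4 counts 4-wide texels along the K dimension and that mul_4 multiplies by 4:

    int main() {
      // Assumed packing: the K dimension is loaded as K4 = K / 4 four-wide texels.
      constexpr int K = 2048;          // input-channel count from the perf tests below
      constexpr int group_size = 128;  // quantization group size, in elements
      constexpr int K4 = K / 4;                     // 512 texels
      constexpr int K4_per_group = group_size / 4;  // 32 texels per group
      constexpr int num_groups = K4 / K4_per_group; // matches num_groups in the shader
      static_assert(num_groups == 16, "16 groups of 128 elements cover K = 2048");
      static_assert(4 * K4_per_group == group_size, "what mul_4(K4_per_group) recovers");
      return 0;
    }
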
@@ -63,9 +63,12 @@ void accumulate_out_tile_with_int_accum_from_int4_weights(
     const FPPerOutChannelParams weight_scales,
     const int group_size) {
   [[unroll]] for (int m = 0; m < TILE_M; ++m) {
-    float input_scale_m = input_scales.data[0][m];
-    int input_zp_m = input_zps.data[0][m];
-    int input_sum_m = input_sums.data[0][m];
+    const int m4 = div_4(m);
+    const int m4i = mod_4(m);
+
+    float input_scale_m = input_scales.data[m4][m4i];
+    int input_zp_m = input_zps.data[m4][m4i];
+    int input_sum_m = input_sums.data[m4][m4i];
 
     [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
       ivec4 accum_adjusted = accum.data[m][n4] -

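Note: the fix above changes how the per-row input quantization parameters are fetched. The params appear to be packed four rows per vec4, so row m lives at data[m / 4][m % 4]; the old data[0][m] indexing was only correct for the first four rows of a tile. A minimal C++ sketch of the assumed layout and lookup (names and sizes are illustrative, not the actual shader types):

    #include <array>

    // Illustrative stand-in for the shader's packed per-row parameters:
    // four rows per vec4, so row m lives at data[m / 4][m % 4].
    struct PackedParams {
      std::array<std::array<float, 4>, 2> data; // covers a tile of up to 8 rows
    };

    int div_4(int x) { return x >> 2; } // presumed semantics of the GLSL helpers
    int mod_4(int x) { return x & 3; }

    // Old code read params.data[0][m], which is only valid while m < 4;
    // the fixed version indexes the containing vec4, then the lane within it.
    float load_param(const PackedParams& params, int m) {
      return params.data[div_4(m)][mod_4(m)];
    }
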
@@ -71,8 +71,8 @@ ${layout_declare_spec_const(C, "int", "K4_per_group", "0")}
 shared FPOutTile partial_sums[WGS];
 
 void main() {
-  const int lid = int(gl_LocalInvocationID.x);
-  const int n8 = int(gl_GlobalInvocationID.y);
+  const int lid = int(gl_LocalInvocationID.z);
+  const int n8 = int(gl_GlobalInvocationID.x);
 
   // The output tensor will have a shape of [n, 1, 1, 1]. Each thread computes
   // 8 output elements, so each thread will write to 8 elements starting at the

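Note: this pairs with the local workgroup size change in QuantizedLinear.cpp below, which moves the 64 cooperating lanes from the x axis to the z axis. A small sketch of the assumed invocation-ID mapping (illustrative C++, not the actual GLSL built-ins):

    #include <cstdint>

    struct uvec3 { uint32_t x, y, z; };

    // With the new {1, 1, 64} local size (see the QuantizedLinear.cpp hunk below),
    // the 64 cooperating lanes vary along z and each workgroup owns one n8 tile
    // along x; in GLSL these values come from the built-in invocation IDs.
    uint32_t lane_id(uvec3 local_invocation_id) {
      return local_invocation_id.z; // previously read from .x
    }
    uint32_t n8_tile(uvec3 global_invocation_id) {
      return global_invocation_id.x; // previously read from .y
    }
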
@@ -60,8 +60,6 @@ shared FPOutTile partial_sums[NUM_WORKERS_PER_OUT];
  * the entire work group co-operates to compute one reduction output.
  */
 
-#extension GL_EXT_debug_printf : enable
-
 void main() {
   const int worker_id = int(gl_LocalInvocationID.y);
 

@@ -73,8 +73,6 @@ ${layout_declare_spec_const(C, "float", "inv_scale", "1.0")}
  *
  */
 
-#extension GL_EXT_debug_printf : enable
-
 void main() {
   const int tile_idx_x = int(gl_GlobalInvocationID.x);
   const int tile_idx_y = int(gl_GlobalInvocationID.y);

@@ -60,8 +60,6 @@ shared FPOutTile partial_sums[NUM_WORKERS_PER_OUT];
  * the entire work group co-operates to compute one reduction output.
  */
 
-#extension GL_EXT_debug_printf : enable
-
 void main() {
   const int worker_id = int(gl_LocalInvocationID.y);
 

@@ -56,8 +56,6 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
  * output has shape (batches, seq_len, num_q_heads, head_dim)
  */
 
-#extension GL_EXT_debug_printf : enable
-
 void main() {
   const int tile_idx_x = int(gl_GlobalInvocationID.x);
   const int tile_idx_y = int(gl_GlobalInvocationID.y);

30 changes: 14 additions & 16 deletions backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
@@ -61,29 +61,27 @@ utils::uvec3 quantized_linear_global_wg_size(
   const ValueRef out = args.at(0).refs.at(0);
 
   std::vector<int64_t> out_sizes = graph->sizes_of(out);
-  // height
-  const uint32_t M = utils::val_at(-2, out_sizes);
   // width
   const uint32_t N = utils::val_at(-1, out_sizes);
+  // height
+  const uint32_t M = utils::val_at(-2, out_sizes);
 
-  const uint32_t M4 = utils::div_up(M, 4u);
-  const uint32_t N4 = utils::div_up(N, 4u);
+  uint32_t N_per_tile = 4;
+  uint32_t M_per_tile = 4;
 
-  // For 4-bit weights, each output tile contains 8 columns and 4 rows
+  // For 4-bit weights, each output tile contains 8 columns
   if (shader.kernel_name.find("q4") != std::string::npos) {
-    const uint32_t N8 = utils::div_up(N, 8u);
-    const bool using_coop_algorithm =
-        shader.kernel_name.find("_coop") != std::string::npos;
-    // TODO: explain
-    if (using_coop_algorithm) {
-      return {64, N8, M};
-    }
-    return {N8, M4, 1};
+    N_per_tile = 8;
   }
+  if (shader.kernel_name.find("coop") != std::string::npos) {
+    M_per_tile = 1;
+  }
 
-  // Otherwise, each output tile contains 4 columns and 4 rows
-  return {N4, M4, 1};
+  const uint32_t num_N_tiles = utils::div_up(N, N_per_tile);
+  const uint32_t num_M_tiles = utils::div_up(M, M_per_tile);
+
+  return {num_N_tiles, num_M_tiles, 1};
 }
 
 utils::uvec3 quantized_linear_local_wg_size(
@@ -96,7 +94,7 @@ utils::uvec3 quantized_linear_local_wg_size(
       shader.kernel_name.find("_coop") != std::string::npos;
 
   if (use_coop_algorithm) {
-    return {64, 1, 1};
+    return {1, 1, 64};
   } else {
     return pick_hw_square_wg_size(
         graph, shader, global_workgroup_size, args, resize_args);

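Note: worked through for one of the performance test shapes below, the new tile math reproduces the coop dispatch without the old special case. A sketch, assuming utils::div_up is round-up division as its name suggests:

    #include <cstdint>
    #include <cstdio>

    uint32_t div_up(uint32_t n, uint32_t d) { return (n + d - 1) / d; }

    int main() {
      // q4 coop kernel on an [M = 1, N = 2048] output, a GEMV-style case:
      const uint32_t N = 2048, M = 1;
      const uint32_t N_per_tile = 8; // q4: each output tile covers 8 columns
      const uint32_t M_per_tile = 1; // coop: one output row per workgroup
      const uint32_t num_N_tiles = div_up(N, N_per_tile);
      const uint32_t num_M_tiles = div_up(M, M_per_tile);
      // Combined with the {1, 1, 64} local size, this dispatches one 64-lane
      // workgroup per (n8, m) tile, replacing the old special case {64, N8, M}.
      std::printf("global wg size: {%u, %u, 1}\n", num_N_tiles, num_M_tiles); // {256, 1, 1}
      return 0;
    }
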
6 changes: 3 additions & 3 deletions backends/vulkan/test/custom_ops/q4gsw_linear.cpp
@@ -242,9 +242,9 @@ std::vector<TestCase> generate_quantized_linear_test_cases() {
       {32, 256, 128, 64, false},
       // Performance test cases
       {1, 2048, 2048, 128},
-      {128, 2048, 2048, 256},
-      {256, 2048, 2048, 256},
-      {1024, 2048, 2048, 256},
+      {128, 2048, 2048, 128},
+      {256, 2048, 2048, 128},
+      {1024, 2048, 2048, 128},
   };
 
   // Test with different storage types and data types