Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/common.glslh
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ int extract_8bit_from_packed_int_le(const int packed, const int i) {
return byte;
}

// Expands a 32-bit word into an ivec4 whose component i is the value returned
// by extract_8bit_from_packed_int_le(packed, i) for i = 0..3.
ivec4 unpack_int8x4(const int packed) {
  const int lane0 = extract_8bit_from_packed_int_le(packed, 0);
  const int lane1 = extract_8bit_from_packed_int_le(packed, 1);
  const int lane2 = extract_8bit_from_packed_int_le(packed, 2);
  const int lane3 = extract_8bit_from_packed_int_le(packed, 3);
  return ivec4(lane0, lane1, lane2, lane3);
}

int pack_4xqint_into_int32(
const int val0,
const int val1,
Expand All @@ -57,6 +65,13 @@ int pack_4xqint_into_int32(
return packed;
}

// Packs the low 8 bits of each component of quant_vals into one 32-bit word.
// Component 0 lands in bits 0-7, component 1 in bits 8-15, component 2 in
// bits 16-23 and component 3 in bits 24-31.
int pack_into_int32(const ivec4 quant_vals) {
  // Keep only the low byte of every component before shifting into place.
  const ivec4 low_bytes = quant_vals & 0xFF;
  return low_bytes.x | (low_bytes.y << 8) | (low_bytes.z << 16) |
      (low_bytes.w << 24);
}

#ifdef DEBUG_MODE

#extension GL_EXT_debug_printf : require
Expand Down
12 changes: 12 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ Conv2dBlockExtents make_block_extents(const ivec4 tensor_sizes) {
return block_sizes;
}

// Converts a flat index into a 3-component block index.
// The z coordinate varies fastest, then x, then y:
//   idx = (y * extents.x + x) * extents.z + z
Conv2dBlockIndex linear_idx_to_block_idx(
    const int idx, const Conv2dBlockExtents block_extents) {
  const int extent_x = block_extents.data.x;
  const int extent_z = block_extents.data.z;

  // Index of the (x, y) pair once the z coordinate has been stripped off.
  const int xy_idx = idx / extent_z;

  Conv2dBlockIndex result;
  result.data.x = xy_idx % extent_x;
  result.data.y = xy_idx / extent_x;
  result.data.z = idx % extent_z;
  return result;
}

bool block_idx_out_of_bounds(
const Conv2dBlockIndex block_idx,
const Conv2dBlockExtents block_extents) {
Expand Down
214 changes: 214 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8_utils.glslh
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#ifndef CONV2D_DW_Q8_UTILS_GLSLH
#define CONV2D_DW_Q8_UTILS_GLSLH

#extension GL_EXT_control_flow_attributes : require

// A fixed-capacity run of dequantized input texels along the width axis.
struct InputWindow1D {
  // Storage for up to MAX_WINDOW_WIDTH texels; entries at index >= len are
  // left at the value set by initial_input_window().
  vec4 data[MAX_WINDOW_WIDTH];
  // Number of entries of data that were filled by the loader.
  int len;
};

// Returns an empty InputWindow1D: len is 0 and every slot is a zero vec4.
InputWindow1D initial_input_window() {
  InputWindow1D window;
  window.len = 0;

  int slot = 0;
  while (slot < MAX_WINDOW_WIDTH) {
    window.data[slot] = vec4(0);
    ++slot;
  }
  return window;
}

// Dequantizes four packed 8-bit values with a single scale and zero point:
//   result[i] = (unpacked[i] - zp) * scale
vec4 dequantize(const int packed_texel, const float scale, const int zp) {
  const ivec4 centered = unpack_int8x4(packed_texel) - zp;
  return scale * vec4(centered);
}

// Dequantizes four packed 8-bit values with one scale per component and no
// zero-point offset:
//   result[i] = unpacked[i] * scales[i]
vec4 dequantize(const int packed_texel, const vec4 scales) {
  const vec4 raw_vals = vec4(unpack_int8x4(packed_texel));
  return raw_vals * scales;
}

// Returns true when (block_w, block_h, block_c4) is a valid block coordinate,
// i.e. every component is >= 0 and strictly less than the matching component
// of block_extents.data.
bool in_bounds(
    const int block_w,
    const int block_h,
    const int block_c4,
    const Conv2dBlockExtents block_extents) {
  const ivec3 pos = ivec3(block_w, block_h, block_c4);

  const bool below_zero = any(lessThan(pos, ivec3(0)));
  const bool past_end = any(greaterThanEqual(pos, block_extents.data));

  return !(below_zero || past_end);
}

// Loads and dequantizes the input texels at width positions
// [w_start, w_end] (inclusive) for row h and channel block c4.
//
// The input is read one block at a time; each block is an ivec4 whose
// component i is the packed texel at width position 4 * block_w + i. Blocks
// outside block_extents are not read and are replaced by input_zps, so they
// dequantize from whatever packed value the caller supplied there.
//
// The result stores the texels contiguously: data[i] is the texel at width
// w_start + i, and len is the number of texels stored.
//
// NOTE(review): nothing here checks that (w_end - w_start + 1) fits in
// MAX_WINDOW_WIDTH - presumably the caller guarantees it; confirm.
InputWindow1D load_input_window(
    const int w_start,
    const int w_end,
    const int h,
    const int c4,
    const Conv2dBlockExtents block_extents,
    const float input_scale,
    const int input_zp,
    const ivec4 input_zps) {
  InputWindow1D window = initial_input_window();

  const int first_block = div_4(w_start);
  const int last_block = div_4(w_end);

  int num_loaded = 0;
  for (int block_w = first_block; block_w <= last_block; ++block_w) {
    // Default to the padding value; overwritten when the block exists.
    ivec4 packed_texels = input_zps;

    if (in_bounds(block_w, h, c4, block_extents)) {
#ifdef PACKED_INT8_INPUT_BUFFER
      const int buffer_idx =
          h * block_extents.data_xz + block_w * block_extents.data.z + c4;
      packed_texels = t_packed_int8_input[buffer_idx];
#else
      packed_texels = texelFetch(t_packed_int8_input, ivec3(block_w, h, c4), 0);
#endif
    }

    // Keep only the texels of this block that fall inside the requested range.
    const int base_w = mul_4(block_w);
    for (int lane = 0; lane < 4; ++lane) {
      const int w = base_w + lane;
      if (w < w_start || w > w_end) {
        continue;
      }
      window.data[num_loaded] =
          dequantize(packed_texels[lane], input_scale, input_zp);
      num_loaded++;
    }
  }

  window.len = num_loaded;
  return window;
}

// One row of dequantized kernel weights along the kernel width axis.
struct WeightRow {
  // Storage for up to MAX_KERNEL_WIDTH weight texels; entries at index >= len
  // are left at the value set by initial_weight_row().
  vec4 data[MAX_KERNEL_WIDTH];
  // Number of entries of data that were filled by the loader.
  int len;
};

// Returns an empty WeightRow: len is 0 and every slot is a zero vec4.
WeightRow initial_weight_row() {
  WeightRow empty_row;
  empty_row.len = 0;

  int slot = 0;
  while (slot < MAX_KERNEL_WIDTH) {
    empty_row.data[slot] = vec4(0);
    ++slot;
  }
  return empty_row;
}

// Loads and dequantizes kernel row ky for output channel block oc4.
//
// Weights are read in ivec4 blocks, each holding up to four packed weight
// texels for consecutive kernel-width positions. Row ky starts at block index
// ky * Kw4 and spans one block per 4 kernel-width positions. Only the first Kw
// texels are kept; each is dequantized with weight_scales.
//
// NOTE(review): nothing here checks that Kw fits in MAX_KERNEL_WIDTH -
// presumably the caller guarantees it; confirm.
WeightRow load_weight_row(
    const int oc4,
    const int ky,
    const int OC4,
    const int Kw,
    const int Kw4,
    const vec4 weight_scales) {
  WeightRow loaded = initial_weight_row();

  int num_taps = 0;
  int k4 = ky * Kw4;
  for (int w = 0; w < Kw; w += 4, ++k4) {
#ifdef WEIGHT_BUFFER
    const ivec4 weight_block = t_packed_int8_weight[k4 * OC4 + oc4];
#else
    const ivec4 weight_block = texelFetch(
        t_packed_int8_weight, ivec2(oc4, k4), 0);
#endif

    // The last block of the row may be only partially populated.
    const int valid_lanes = min(4, Kw - w);
    for (int lane = 0; lane < valid_lanes; ++lane) {
      loaded.data[num_taps] = dequantize(weight_block[lane], weight_scales);
      num_taps++;
    }
  }

  loaded.len = num_taps;
  return loaded;
}

// Floating-point accumulator for one output block: one vec4 per output width
// position (4 positions per block).
struct FPOutBlock {
  vec4 data[4];
};

// Accumulates one kernel row of the convolution into out_block.
//
// For each of the 4 output width positions and each kernel tap kx:
//   out_block.data[out_w] += input_window.data[idx] * weight_row.data[kx]
// with idx = out_w * stride.x + kx * dilation.x.
//
// input_window is expected to hold consecutive input width positions, so the
// tap for kernel column kx sits kx * dilation.x entries after the first tap.
// The previous code stepped by 1 regardless of dilation, which is only correct
// when dilation.x == 1; for that case the result is unchanged.
//
// out_block must already contain the running sum (zero for the first row).
// Reads stride and dilation from the conv2d_params uniform declared by the
// including shader.
void perform_conv1d(
    inout FPOutBlock out_block,
    const InputWindow1D input_window,
    const WeightRow weight_row) {
  for (int out_w = 0; out_w < 4; ++out_w) {
    // Window index of the first tap for this output position.
    const int in_w = out_w * conv2d_params.stride.x;
    [[unroll]] for (int kx = 0; kx < weight_row.len; ++kx) {
      out_block.data[out_w] = fma(
          input_window.data[in_w + kx * conv2d_params.dilation.x],
          weight_row.data[kx],
          out_block.data[out_w]);
    }
  }
}

// Quantizes a float texel to the signed 8-bit range:
//   q = clamp(int(round(texel * inv_scale) + zp), -128, 127)
// The result is returned as an ivec4 (one int per component, not yet packed).
ivec4 quantize(
    const vec4 texel, const float inv_scale, const int zp) {
  const vec4 shifted = round(texel * inv_scale) + zp;
  const ivec4 as_int = ivec4(shifted);
  return min(max(as_int, ivec4(-128)), ivec4(127));
}

// Quantizes each of the 4 texels in out_block and packs every quantized texel
// into one 32-bit int. Component i of the result is the packed form of
// out_block.data[i].
ivec4 quantize_and_pack(
    FPOutBlock out_block, const float inv_scale, const int zp) {
  return ivec4(
      pack_into_int32(quantize(out_block.data[0], inv_scale, zp)),
      pack_into_int32(quantize(out_block.data[1], inv_scale, zp)),
      pack_into_int32(quantize(out_block.data[2], inv_scale, zp)),
      pack_into_int32(quantize(out_block.data[3], inv_scale, zp)));
}

#ifdef DEBUG_MODE

// Debug helper: prints the length of input_window followed by each populated
// entry (at most MAX_WINDOW_WIDTH of them).
void printInputWindow1D(const InputWindow1D input_window) {
  debugPrintfEXT("InputWindow1D contents (len = %d): \\n", input_window.len);

  const int num_entries = min(input_window.len, MAX_WINDOW_WIDTH);
  for (int entry = 0; entry < num_entries; ++entry) {
    const vec4 texel = input_window.data[entry];
    debugPrintfEXT(
        " [%d]: (%.3f, %.3f, %.3f, %.3f) \\n",
        entry,
        texel.x,
        texel.y,
        texel.z,
        texel.w);
  }
}

// Debug helper: prints the length of weight_row followed by each populated
// entry (at most MAX_KERNEL_WIDTH of them).
void printWeightRow(const WeightRow weight_row) {
  debugPrintfEXT("WeightRow contents (len = %d): \\n", weight_row.len);

  const int num_entries = min(weight_row.len, MAX_KERNEL_WIDTH);
  for (int entry = 0; entry < num_entries; ++entry) {
    const vec4 texel = weight_row.data[entry];
    debugPrintfEXT(
        " [%d]: (%.3f, %.3f, %.3f, %.3f) \\n",
        entry,
        texel.x,
        texel.y,
        texel.z,
        texel.w);
  }
}

// Debug helper: prints all 4 accumulator texels of out_block.
void printFPOutBlock(const FPOutBlock out_block) {
  debugPrintfEXT("FPOutBlock contents: \\n");

  for (int entry = 0; entry < 4; ++entry) {
    const vec4 texel = out_block.data[entry];
    debugPrintfEXT(
        " [%d]: (%.3f, %.3f, %.3f, %.3f) \\n",
        entry,
        texel.x,
        texel.y,
        texel.z,
        texel.w);
  }
}

#endif // DEBUG_MODE

#endif // CONV2D_DW_Q8_UTILS_GLSLH
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

// Macros whose values are template placeholders; presumably substituted by the
// shader code generator before this file is compiled - TODO confirm.
#define PRECISION ${PRECISION}
#define VEC4_T ${texel_load_type(DTYPE, "buffer")}
#define T ${texel_load_component_type(DTYPE, "buffer")}

// Storage-selection macros defined below: PACKED_INT8_INPUT_BUFFER and
// WEIGHT_BUFFER choose the buffer-indexing paths in conv2d_dw_q8_utils.glslh,
// and PACKED_INT8_OUTPUT_BUFFER chooses the buffer store at the end of main().
$if IO_STORAGE == "buffer":
#define PACKED_INT8_OUTPUT_BUFFER
#define PACKED_INT8_INPUT_BUFFER
$if WEIGHT_STORAGE == "buffer":
#define WEIGHT_BUFFER

// Capacities of the fixed-size arrays declared in conv2d_dw_q8_utils.glslh:
// MAX_WINDOW_WIDTH sizes InputWindow1D.data and MAX_KERNEL_WIDTH sizes
// WeightRow.data. They must be defined before that header is included.
#define MAX_WINDOW_WIDTH 12
#define MAX_KERNEL_WIDTH 5

${define_required_extensions(DTYPE)}

layout(std430) buffer;

#include "conv2d_common.glslh"

// Tensor bindings. main() writes t_packed_int8_output and reads
// t_packed_int8_input, t_packed_int8_weight, t_weight_scales and t_bias.
// NOTE(review): t_weight_sums is declared but not referenced by main() or by
// the helpers in conv2d_dw_q8_utils.glslh as far as visible here - confirm
// whether the binding is still required.
${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", IO_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", IO_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)}

// Uniforms: output/input tensor sizes (passed to make_block_extents) and the
// convolution parameters (stride, padding, dilation, kernel_size are read).
${layout_declare_ubo(B, "ivec4", "output_sizes")}
${layout_declare_ubo(B, "ivec4", "input_sizes")}
${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")}

// Push constants holding the quantization parameters used by main():
// input_scale / input_zp dequantize the input, and output_inv_scale /
// output_zp requantize the accumulated result.
layout(push_constant) uniform restrict Block {
float input_scale;
int input_zp;
float output_inv_scale;
int output_zp;
};

// Workgroup size is supplied through specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Specialization constant: main() adds t_bias to the result only when
// apply_bias > 0. The trailing "1" is presumably the default - TODO confirm.
${layout_declare_spec_const(C, "int", "apply_bias", "1")}

#include "conv2d_dw_q8_utils.glslh"

// Entry point. Each invocation produces one output block: 4 consecutive
// output width positions for one output row and one channel block, written
// as a single ivec4 of packed 8-bit values.
//
// Steps:
//   1. Map the linear invocation id to an output block index; exit if it is
//      out of bounds.
//   2. Compute the inclusive input width range [w_start, w_end] needed by the
//      4 output positions.
//   3. For every kernel row, load the dequantized input window and weight row
//      and accumulate their products into a float accumulator.
//   4. Optionally add the bias, then requantize, pack and store the block.
void main() {
  const int tid = int(gl_GlobalInvocationID.x);
  Conv2dBlockExtents out_block_extents = make_block_extents(output_sizes);

  Conv2dBlockIndex out_block_idx = linear_idx_to_block_idx(
      tid, out_block_extents);

  if (block_idx_out_of_bounds(out_block_idx, out_block_extents)) {
    return;
  }

  // First output width position covered by this block; the block spans
  // out_w .. out_w + 3.
  const int out_w = mul_4(out_block_idx.data.x);
  const int w_start =
      (out_w * conv2d_params.stride.x) - conv2d_params.padding.x;
  const int w_end = ((out_w + 3) * conv2d_params.stride.x) -
      conv2d_params.padding.x +
      (conv2d_params.kernel_size.x - 1) * conv2d_params.dilation.x;

  Conv2dBlockExtents in_block_extents = make_block_extents(input_sizes);

  // Packed block used in place of out-of-bounds input so that padding
  // dequantizes to (input_zp - input_zp) * input_scale.
  const ivec4 input_zps = ivec4(pack_into_int32(ivec4(input_zp)));
  const vec4 weight_scales = vec4(t_weight_scales[out_block_idx.data.z]);

  const int Kw4 = div_up_4(conv2d_params.kernel_size.x);

  // The accumulator must start at zero: perform_conv1d() adds into it with
  // fma(), and an uninitialized local has an undefined value in GLSL.
  FPOutBlock out_block;
  for (int row = 0; row < 4; ++row) {
    out_block.data[row] = vec4(0);
  }

  // Input row hit by kernel row 0; each further kernel row steps by dilation.y.
  const int out_h = out_block_idx.data.y;
  const int h_base =
      out_h * conv2d_params.stride.y - conv2d_params.padding.y;

  for (int ky = 0; ky < conv2d_params.kernel_size.y; ky++) {
    const int h = h_base + ky * conv2d_params.dilation.y;

    InputWindow1D input_window = load_input_window(
        w_start,
        w_end,
        h,
        out_block_idx.data.z,
        in_block_extents,
        input_scale,
        input_zp,
        input_zps);

    WeightRow weight_row = load_weight_row(
        out_block_idx.data.z,
        ky,
        out_block_extents.data.z,
        conv2d_params.kernel_size.x,
        Kw4,
        weight_scales);

    perform_conv1d(out_block, input_window, weight_row);
  }

  if (apply_bias > 0) {
    const vec4 bias = vec4(t_bias[out_block_idx.data.z]);
    for (int row = 0; row < 4; row++) {
      out_block.data[row] += bias;
    }
  }

  const ivec4 packed_out_block = quantize_and_pack(
      out_block, output_inv_scale, output_zp);

#ifdef PACKED_INT8_OUTPUT_BUFFER
  t_packed_int8_output[tid] = packed_out_block;
#else
  imageStore(t_packed_int8_output, out_block_idx.data, packed_out_block);
#endif
}
Loading
Loading