diff --git a/backends/vulkan/runtime/graph/ops/glsl/common.glslh b/backends/vulkan/runtime/graph/ops/glsl/common.glslh index 00a053612f5..95cdf70679b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/common.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/common.glslh @@ -46,6 +46,14 @@ int extract_8bit_from_packed_int_le(const int packed, const int i) { return byte; } +ivec4 unpack_int8x4(const int packed) { + return ivec4( + extract_8bit_from_packed_int_le(packed, 0), + extract_8bit_from_packed_int_le(packed, 1), + extract_8bit_from_packed_int_le(packed, 2), + extract_8bit_from_packed_int_le(packed, 3)); +} + int pack_4xqint_into_int32( const int val0, const int val1, @@ -57,6 +65,13 @@ int pack_4xqint_into_int32( return packed; } +int pack_into_int32(const ivec4 quant_vals) { + int packed = ((quant_vals[0] & 0xFF) << 0) | ((quant_vals[1] & 0xFF) << 8) | + ((quant_vals[2] & 0xFF) << 16) | ((quant_vals[3] & 0xFF) << 24); + + return packed; +} + #ifdef DEBUG_MODE #extension GL_EXT_debug_printf : require diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh index 929f3da299e..6f460d1398c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh @@ -61,6 +61,18 @@ Conv2dBlockExtents make_block_extents(const ivec4 tensor_sizes) { return block_sizes; } +Conv2dBlockIndex linear_idx_to_block_idx( + const int idx, const Conv2dBlockExtents block_extents) { + Conv2dBlockIndex block_idx; + block_idx.data.z = idx % block_extents.data.z; + + const int row = idx / block_extents.data.z; + block_idx.data.x = row % block_extents.data.x; + block_idx.data.y = row / block_extents.data.x; + + return block_idx; +} + bool block_idx_out_of_bounds( const Conv2dBlockIndex block_idx, const Conv2dBlockExtents block_extents) { diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8_utils.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8_utils.glslh new file mode 100644 index 00000000000..f1d90aa83cb --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8_utils.glslh @@ -0,0 +1,214 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#ifndef CONV2D_DW_Q8_UTILS_GLSLH +#define CONV2D_DW_Q8_UTILS_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +struct InputWindow1D { + vec4[MAX_WINDOW_WIDTH] data; + int len; +}; + +InputWindow1D initial_input_window() { + InputWindow1D input_window; + for (int i = 0; i < MAX_WINDOW_WIDTH; ++i) { + input_window.data[i] = vec4(0); + } + input_window.len = 0; + return input_window; +} + +vec4 dequantize(const int packed_texel, const float scale, const int zp) { + return vec4(unpack_int8x4(packed_texel) - zp) * scale; +} + +vec4 dequantize(const int packed_texel, const vec4 scales) { + return vec4(unpack_int8x4(packed_texel)) * scales; +} + +bool in_bounds( + const int block_w, + const int block_h, + const int block_c4, + const Conv2dBlockExtents block_extents) { + ivec3 idx = ivec3(block_w, block_h, block_c4); + if (any(lessThan(idx, ivec3(0)))) { + return false; + } + if (any(greaterThanEqual(idx, block_extents.data))) { + return false; + } + + return true; +} + +InputWindow1D load_input_window( + const int w_start, + const int w_end, + const int h, + const int c4, + const Conv2dBlockExtents block_extents, + const float input_scale, + const int input_zp, + const ivec4 input_zps) { + InputWindow1D input_window = initial_input_window(); + + const int block_w_start = div_4(w_start); + const int block_w_end = div_4(w_end); + + int window_i = 0; + for (int block_w = block_w_start; block_w <= block_w_end; ++block_w) { + ivec4 input_block = input_zps; + + if (in_bounds(block_w, h, c4, block_extents)) { +#ifdef PACKED_INT8_INPUT_BUFFER + const int buffer_idx = + h * block_extents.data_xz + block_w * block_extents.data.z + c4; + input_block = t_packed_int8_input[buffer_idx]; +#else + input_block = texelFetch(t_packed_int8_input, ivec3(block_w, h, c4), 0); +#endif + } + + const int loaded_w_start = mul_4(block_w); + for (int row = 0; row < 4; ++row) { + if (loaded_w_start + row >= w_start && loaded_w_start + row <= w_end) { + input_window.data[window_i++] = + dequantize(input_block[row], input_scale, input_zp); + } + } + } + input_window.len = window_i; + return input_window; +} + +struct WeightRow { + vec4[MAX_KERNEL_WIDTH] data; + int len; +}; + +WeightRow initial_weight_row() { + WeightRow weight_row; + for (int i = 0; i < MAX_KERNEL_WIDTH; ++i) { + weight_row.data[i] = vec4(0); + } + weight_row.len = 0; + return weight_row; +} + +WeightRow load_weight_row( + const int oc4, + const int ky, + const int OC4, + const int Kw, + const int Kw4, + const vec4 weight_scales) { + WeightRow weight_row = initial_weight_row(); + + int k4 = ky * Kw4; + int row_idx = 0; + for (int w = 0; w < Kw; w += 4) { +#ifdef WEIGHT_BUFFER + const ivec4 weight_block = t_packed_int8_weight[k4 * OC4 + oc4]; +#else + const ivec4 weight_block = texelFetch( + t_packed_int8_weight, ivec2(oc4, k4), 0); +#endif + + for (int row = 0; row < 4; ++row) { + if (w + row < Kw) { + weight_row.data[row_idx++] = dequantize(weight_block[row], weight_scales); + } + } + k4++; + } + weight_row.len = row_idx; + return weight_row; +} + +struct FPOutBlock { + vec4[4] data; +}; + +void perform_conv1d( + inout FPOutBlock out_block, + const InputWindow1D input_window, + const WeightRow weight_row) { + for (int out_w = 0; out_w < 4; ++out_w) { + [[unroll]] for (int kx = 0; kx < weight_row.len; ++kx) { + const int in_w = out_w * conv2d_params.stride.x; + out_block.data[out_w] = fma( + input_window.data[in_w + kx], + weight_row.data[kx], + out_block.data[out_w]); + } + } +} + +ivec4 quantize( + const vec4 texel, const float 
inv_scale, const int zp) { + vec4 quantized = round(texel * inv_scale) + zp; + return clamp(ivec4(quantized), -128, 127); +} + +ivec4 quantize_and_pack( + FPOutBlock out_block, const float inv_scale, const int zp) { + ivec4 packed_block; + for (int row = 0; row < 4; ++row) { + ivec4 quantized_texel = quantize(out_block.data[row], inv_scale, zp); + packed_block[row] = pack_into_int32(quantized_texel); + } + return packed_block; +} + +#ifdef DEBUG_MODE + +void printInputWindow1D(const InputWindow1D input_window) { + debugPrintfEXT("InputWindow1D contents (len = %d): \\n", input_window.len); + for (int i = 0; i < min(input_window.len, MAX_WINDOW_WIDTH); ++i) { + debugPrintfEXT( + " [%d]: (%.3f, %.3f, %.3f, %.3f) \\n", + i, + input_window.data[i].x, + input_window.data[i].y, + input_window.data[i].z, + input_window.data[i].w); + } +} + +void printWeightRow(const WeightRow weight_row) { + debugPrintfEXT("WeightRow contents (len = %d): \\n", weight_row.len); + for (int i = 0; i < min(weight_row.len, MAX_KERNEL_WIDTH); ++i) { + debugPrintfEXT( + " [%d]: (%.3f, %.3f, %.3f, %.3f) \\n", + i, + weight_row.data[i].x, + weight_row.data[i].y, + weight_row.data[i].z, + weight_row.data[i].w); + } +} + +void printFPOutBlock(const FPOutBlock out_block) { + debugPrintfEXT("FPOutBlock contents: \\n"); + for (int i = 0; i < 4; ++i) { + debugPrintfEXT( + " [%d]: (%.3f, %.3f, %.3f, %.3f) \\n", + i, + out_block.data[i].x, + out_block.data[i].y, + out_block.data[i].z, + out_block.data[i].w); + } + } + +#endif // DEBUG_MODE + +#endif // CONV2D_DW_Q8_UTILS_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.glsl new file mode 100644 index 00000000000..8994ced3acb --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.glsl @@ -0,0 +1,121 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T ${texel_load_component_type(DTYPE, "buffer")} + +$if IO_STORAGE == "buffer": + #define PACKED_INT8_OUTPUT_BUFFER + #define PACKED_INT8_INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define MAX_WINDOW_WIDTH 12 +#define MAX_KERNEL_WIDTH 5 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "apply_bias", "1")} + +#include "conv2d_dw_q8_utils.glslh" + +void main() { + const int tid = int(gl_GlobalInvocationID.x); + Conv2dBlockExtents out_block_extents = make_block_extents(output_sizes); + + Conv2dBlockIndex out_block_idx = linear_idx_to_block_idx( + tid, out_block_extents); + + if (block_idx_out_of_bounds(out_block_idx, out_block_extents)) { + return; + } + + const int out_w = mul_4(out_block_idx.data.x); + const int w_start = + (out_w * conv2d_params.stride.x) - conv2d_params.padding.x; + const int w_end = ((out_w + 3) * conv2d_params.stride.x) - + conv2d_params.padding.x + + (conv2d_params.kernel_size.x - 1) * conv2d_params.dilation.x; + + Conv2dBlockExtents in_block_extents = make_block_extents(input_sizes); + + const ivec4 input_zps = ivec4(pack_into_int32(ivec4(input_zp))); + const vec4 weight_scales = vec4(t_weight_scales[out_block_idx.data.z]); + + const int Kw4 = div_up_4(conv2d_params.kernel_size.x); + + FPOutBlock out_block; + for (int ky = 0; ky < conv2d_params.kernel_size.y; ky++) { + const int out_h = out_block_idx.data.y; + const int h = out_h * conv2d_params.stride.y - conv2d_params.padding.y + + ky * conv2d_params.dilation.y; + + InputWindow1D input_window = load_input_window( + w_start, + w_end, + h, + out_block_idx.data.z, + in_block_extents, + input_scale, + input_zp, + input_zps); + + WeightRow weight_row = load_weight_row( + out_block_idx.data.z, + ky, + out_block_extents.data.z, + conv2d_params.kernel_size.x, + Kw4, + weight_scales); + + perform_conv1d(out_block, input_window, weight_row); + } + + if (apply_bias > 0) { + const vec4 bias = vec4(t_bias[out_block_idx.data.z]); + for (int row = 0; row < 4; row++) { + out_block.data[row] += bias; + } + } + + const ivec4 packed_out_block = quantize_and_pack( + out_block, output_inv_scale, output_zp); + +#ifdef PACKED_INT8_OUTPUT_BUFFER + t_packed_int8_output[tid] = packed_out_block; +#else + imageStore(t_packed_int8_output, out_block_idx.data, packed_out_block); +#endif +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.yaml 
b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.yaml new file mode 100644 index 00000000000..77f801668a4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.yaml @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv2d_dw_q8ta_q8csw_q8to: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, WEIGHT_STORAGE] + combos: + - parameter_values: [buffer, texture2d] + DTYPE: + - VALUE: float + shader_variants: + - NAME: conv2d_dw_q8ta_q8csw_q8to diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_input_block_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_input_block_load.glslh new file mode 100644 index 00000000000..44c226f6891 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_input_block_load.glslh @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CONV2D_INT8_INPUT_BLOCK_LOAD +#define CONV2D_INT8_INPUT_BLOCK_LOAD + +#extension GL_EXT_control_flow_attributes : require + +#include "conv2d_common.glslh" +#include "conv2d_int8_activation_block.glslh" + +void store_packed_int8_input_block( + const Conv2dBlockIndex block_idx, + const Conv2dBlockExtents block_extents, + const Int8ActivationBlock packed_int8_block) { +#ifdef OUTPUT_BUFFER + const int buffer_idx = block_idx.data.y * block_extents.data_xz + + block_idx.data.x * block_extents.data.z + block_idx.data.z; + t_packed_int8_input[buffer_idx] = packed_int8_block.data; +#else + imageStore(t_packed_int8_input, block_idx.data, packed_int8_block.data); +#endif +} + +#endif // CONV2D_INT8_INPUT_BLOCK_LOAD diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_input_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_input_tile_load.glslh new file mode 100644 index 00000000000..44aa09912ec --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_input_tile_load.glslh @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#ifndef CONV2D_INT8_INPUT_TILE_LOAD +#define CONV2D_INT8_INPUT_TILE_LOAD + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_int8_input_tile.glslh" + +struct Int8InputTileIndex { +#ifdef PACKED_INT8_INPUT_BUFFER + int data; +#else + ivec3 data; +#endif +}; + +Int8InputTileIndex make_initial_int8_input_tile_index( + const Conv2dBlockIndex block_idx, + const Conv2dBlockExtents block_extents) { + Int8InputTileIndex idx; +#ifdef PACKED_INT8_INPUT_BUFFER + idx.data = block_idx.data.y * block_extents.data_xz + + block_idx.data.x * block_extents.data.z; +#else + idx.data = ivec3(block_idx.data.x, block_idx.data.y, 0); +#endif + return idx; +} + +Int8InputTileIndex make_initial_int8_input_tile_index( + const Conv2dBlockIndex block_idx, + const Conv2dBlockExtents block_extents, + const int group_k4_offset) { + Int8InputTileIndex idx; +#ifdef PACKED_INT8_INPUT_BUFFER + idx.data = block_idx.data.y * block_extents.data_xz + + block_idx.data.x * block_extents.data.z + group_k4_offset; +#else + idx.data = ivec3(block_idx.data.x, block_idx.data.y, group_k4_offset); +#endif + return idx; +} + +void load_packed_int8_input_tile( + out Int8InputTile int8_tile, + const Int8InputTileIndex idx) { +#ifdef PACKED_INT8_INPUT_BUFFER + int8_tile.data[0][0] = t_packed_int8_input[idx.data]; +#else + int8_tile.data[0][0] = texelFetch(t_packed_int8_input, idx.data, 0); +#endif + + // Guard against unsupported tile sizes +#if TILE_M4 != 1 || TILE_K4 != 1 + not_implemented; +#endif +} + +void increment_k4(inout Int8InputTileIndex idx) { +#ifdef PACKED_INT8_INPUT_BUFFER + idx.data += 1; +#else + idx.data.z += 1; +#endif +} + +#endif // CONV2D_INT8_INPUT_TILE_LOAD diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_output_tile_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_output_tile_store.glslh new file mode 100644 index 00000000000..0a490360f98 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_output_tile_store.glslh @@ -0,0 +1,45 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#ifndef CONV2D_INT8_OUTPUT_TILE_STORE +#define CONV2D_INT8_OUTPUT_TILE_STORE + +#extension GL_EXT_control_flow_attributes : require + +#include "conv2d_common.glslh" +#include "linear_int8_output_tile.glslh" + +void store_packed_int8_output_tile( + const Int8OutTile int8_tile, + const Conv2dBlockIndex block_idx, + const Conv2dBlockExtents block_extents) { +#ifdef PACKED_INT8_OUTPUT_BUFFER + [[unroll]] for (int m4 = 0; m4 < TILE_M4; m4++) { + int buffer_idx = block_idx.data.y * block_extents.data_xz + + (block_idx.data.x + m4) * block_extents.data.z + block_idx.data.z; + [[unroll]] for (int n4 = 0; n4 < TILE_N4; n4++) { + if (block_idx.data.x + m4 < block_extents.data.x && + block_idx.data.z + n4 < block_extents.data.z) { + t_packed_int8_output[buffer_idx++] = int8_tile.data[m4][n4]; + } + } + } +#else + [[unroll]] for (int m4 = 0; m4 < TILE_M4; m4++) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; n4++) { + if (block_idx.data.x + m4 < block_extents.data.x && + block_idx.data.z + n4 < block_extents.data.z) { + imageStore( + t_packed_int8_output, block_idx.data, int8_tile.data[m4][n4]); + } + } + } +#endif +} + +#endif // CONV2D_INT8_OUTPUT_TILE_STORE diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.glsl new file mode 100644 index 00000000000..16c12b3ee5a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.glsl @@ -0,0 +1,144 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T ${texel_load_component_type(DTYPE, "buffer")} + +$if IO_STORAGE == "buffer": + #define PACKED_INT8_OUTPUT_BUFFER + #define PACKED_INT8_INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +// corresponds to input/output width dim +#define TILE_M4 1 +// corresponds to input channels dim +#define TILE_K4 1 +// corresponds to output channels dim +#define TILE_N4 2 + +#define TILE_M 4 +#define TILE_K 4 +#define TILE_N 8 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "apply_bias", "1")} + +#include "conv2d_int8_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_output_tile_int8_int8_compute.glslh" 
+#include "linear_int_weight_sums_load.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_fp_bias_load.glslh" +#include "linear_int8_output_tile_compute.glslh" +#include "conv2d_int8_output_tile_store.glslh" + +void main() { + Conv2dBlockIndex output_block_idx; + output_block_idx.data.z = int(gl_GlobalInvocationID.x) * TILE_N4; + output_block_idx.data.x = int(gl_GlobalInvocationID.y) * TILE_M4; + output_block_idx.data.y = int(gl_GlobalInvocationID.z); + + Conv2dBlockExtents output_block_extents = make_block_extents(output_sizes); + if (block_idx_out_of_bounds(output_block_idx, output_block_extents)) { + return; + } + + Conv2dBlockExtents input_block_extents = make_block_extents(input_sizes); + + Int32Accum out_accum; + initialize(out_accum); + + Int8InputTile int8_input_tile; + Int8WeightTile int8_weight_tile; + + Int8InputTileIndex input_idx = make_initial_int8_input_tile_index( + output_block_idx, input_block_extents); + + for (int k4 = 0; k4 < conv2d_params.K4_per_group; k4++) { + load_packed_int8_input_tile(int8_input_tile, input_idx); + + load_int8_weight_tile( + int8_weight_tile, + output_block_idx.data.z, + k4, + output_block_extents.data.z); + + int_accumulate_with_int8_weight( + out_accum, int8_input_tile, int8_weight_tile); + + increment_k4(input_idx); + } + + FPPerOutChannelParams weight_scales_tile; + load_weight_scales_tile(weight_scales_tile, output_block_idx.data.z); + + IntPerOutChannelParams weight_sums_tile; + load_weight_sums_tile(weight_sums_tile, output_block_idx.data.z); + + Int8OutTile int8_out_tile; + initialize(int8_out_tile); + + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, output_block_idx.data.z); + + compute_int8_out_tile_with_int32_accum( + int8_out_tile, + out_accum, + input_scale, + input_zp, + output_inv_scale, + output_zp, + weight_sums_tile, + weight_scales_tile, + bias_tile); + } + else { + compute_int8_out_tile_with_int32_accum( + int8_out_tile, + out_accum, + input_scale, + input_zp, + output_inv_scale, + output_zp, + weight_sums_tile, + weight_scales_tile); + } + + store_packed_int8_output_tile( + int8_out_tile, output_block_idx, output_block_extents); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.yaml new file mode 100644 index 00000000000..23803dc6da1 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.yaml @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv2d_pw_q8ta_q8csw_q8to_tiled: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, WEIGHT_STORAGE] + combos: + - parameter_values: [buffer, texture2d] + DTYPE: + - VALUE: float + shader_variants: + - NAME: conv2d_pw_q8ta_q8csw_q8to_tiled diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8_utils.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8_utils.glslh new file mode 100644 index 00000000000..279f4f17f13 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8_utils.glslh @@ -0,0 +1,151 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CONV2D_Q8_UTILS_GLSLH +#define CONV2D_Q8_UTILS_GLSLH + +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_integer_dot_product : require + +#include "linear_int_accumulator.glslh" + +struct Int8InputWindow1D { + int[MAX_WINDOW_WIDTH] data; + int len; +}; + +Int8InputWindow1D initial_input_window() { + Int8InputWindow1D input_window; + for (int i = 0; i < MAX_WINDOW_WIDTH; ++i) { + input_window.data[i] = 0; + } + input_window.len = 0; + return input_window; +} + +bool in_bounds( + const int block_w, + const int block_h, + const int block_c4, + const Conv2dBlockExtents block_extents) { + ivec3 idx = ivec3(block_w, block_h, block_c4); + if (any(lessThan(idx, ivec3(0)))) { + return false; + } + if (any(greaterThanEqual(idx, block_extents.data))) { + return false; + } + + return true; +} + +Int8InputWindow1D load_input_window( + const int w_start, + const int w_end, + const int h, + const int c4, + const Conv2dBlockExtents block_extents, + const ivec4 input_zps) { + Int8InputWindow1D input_window = initial_input_window(); + + const int block_w_start = div_4(w_start); + const int block_w_end = div_4(w_end); + + int window_i = 0; + for (int block_w = block_w_start; block_w <= block_w_end; ++block_w) { + ivec4 input_block = input_zps; + + if (in_bounds(block_w, h, c4, block_extents)) { +#ifdef PACKED_INT8_INPUT_BUFFER + const int buffer_idx = + h * block_extents.data_xz + block_w * block_extents.data.z + c4; + input_block = t_packed_int8_input[buffer_idx]; +#else + input_block = texelFetch(t_packed_int8_input, ivec3(block_w, h, c4), 0); +#endif + } + + const int loaded_w_start = mul_4(block_w); + for (int row = 0; row < 4; ++row) { + if (loaded_w_start + row >= w_start && loaded_w_start + row <= w_end) { + input_window.data[window_i++] = input_block[row]; + } + } + } + input_window.len = window_i; + return input_window; +} + +ivec4 load_weight_block( + const int ic4, + const int kx, + const int ky, + const int oc4, + const int IC4, + const int Kw, + const int Kh, + const int OC4) { +#ifdef PACKED_INT8_WEIGHTS_BUFFER + const int block_x = oc4 * Kw + kx; + const int block_y = ky * IC4 + ic4; + return t_packed_int8_weight[block_y * (Kw * OC4) + block_x]; +#else + return texelFetch( + t_packed_int8_weight, ivec2(oc4 * Kw + kx, ky * IC4 + ic4), 0); +#endif +} + +void perform_conv1d( + inout Int32Accum accum, + const Int8InputWindow1D input_window, + const ivec4 weight_block, + const int kx) { + [[unroll]] for (int out_w = 0; out_w < 4; ++out_w) { + const int window_i = out_w * conv2d_params.stride.x + kx; + [[unroll]] for (int out_c = 0; out_c < 4; ++out_c) { + accum.data[out_w][0][out_c] = dotPacked4x8AccSatEXT( + input_window.data[window_i], + weight_block[out_c], + accum.data[out_w][0][out_c]); + } + } +} + +#ifdef DEBUG_MODE + +void printInt8InputWindow1D(const Int8InputWindow1D input_window) { + debugPrintfEXT("Int8InputWindow1D contents (len = %d): \\n", input_window.len); + for (int i = 0; i < min(input_window.len, MAX_WINDOW_WIDTH); ++i) { + ivec4 unpacked = unpack_int8x4(input_window.data[i]); + debugPrintfEXT( + " [%d]: (%d, %d, %d, %d) \\n", + i, + unpacked.x, + unpacked.y, + unpacked.z, + unpacked.w); + } +} + +void printWeightBlock(const ivec4 weight_block) { + debugPrintfEXT("WeightBlock contents: \\n"); + for (int i = 0; i < 4; ++i) { + ivec4 unpacked = unpack_int8x4(weight_block[i]); + debugPrintfEXT( + " [%d]: (%d, 
%d, %d, %d) \\n", + i, + unpacked.x, + unpacked.y, + unpacked.z, + unpacked.w); + } +} + +#endif // DEBUG_MODE + +#endif // CONV2D_Q8_UTILS_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.glsl new file mode 100644 index 00000000000..5839b13aeaa --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.glsl @@ -0,0 +1,173 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T ${texel_load_component_type(DTYPE, "buffer")} + +$if IO_STORAGE == "buffer": + #define PACKED_INT8_OUTPUT_BUFFER + #define PACKED_INT8_INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define MAX_WINDOW_WIDTH 16 + +// corresponds to input/output width dim +#define TILE_M4 1 +// corresponds to input channels dim +#define TILE_K4 1 +// corresponds to output channels dim +#define TILE_N4 1 + +#define TILE_M 4 +#define TILE_K 4 +#define TILE_N 4 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "apply_bias", "1")} + +#include "im2col_packed_int8_utils.glslh" +#include "conv2d_int8_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_output_tile_int8_int8_compute.glslh" +#include "linear_int_weight_sums_load.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_fp_bias_load.glslh" +#include "linear_int8_output_tile_compute.glslh" +#include "conv2d_int8_output_tile_store.glslh" + +#include "conv2d_q8_utils.glslh" + +void main() { + Conv2dBlockIndex out_block_idx; + out_block_idx.data.z = int(gl_GlobalInvocationID.x) * TILE_N4; + out_block_idx.data.x = int(gl_GlobalInvocationID.y) * TILE_M4; + out_block_idx.data.y = int(gl_GlobalInvocationID.z); + + Conv2dBlockExtents out_block_extents = make_block_extents(output_sizes); + if (block_idx_out_of_bounds(out_block_idx, out_block_extents)) { + return; + } + + const int out_w = mul_4(out_block_idx.data.x); + const int w_start = + (out_w * conv2d_params.stride.x) - conv2d_params.padding.x; + const int w_end = ((out_w + 3) * conv2d_params.stride.x) - + conv2d_params.padding.x + + (conv2d_params.kernel_size.x - 1) * conv2d_params.dilation.x; + + Conv2dBlockExtents 
in_block_extents = make_block_extents(input_sizes); + + const ivec4 input_zps = ivec4(pack_into_int32(ivec4(input_zp))); + const vec4 weight_scales = vec4(t_weight_scales[out_block_idx.data.z]); + + Int32Accum out_accum; + initialize(out_accum); + + const int IC4_per_group = div_up_4(conv2d_params.in_channels_per_group); + + const int n = mul_4(out_block_idx.data.z); + const int group_idx = n / conv2d_params.out_channels_per_group; + const int group_ic4_offset = group_idx * IC4_per_group; + + for (int ky = 0; ky < conv2d_params.kernel_size.y; ky++) { + const int h = out_block_idx.data.y * conv2d_params.stride.y - + conv2d_params.padding.y + ky * conv2d_params.dilation.y; + + for (int ic4 = 0; ic4 < IC4_per_group; ic4++) { + Int8InputWindow1D int8_input_window = load_input_window( + w_start, + w_end, + h, + group_ic4_offset + ic4, + in_block_extents, + input_zps); + + for (int kx = 0; kx < conv2d_params.kernel_size.x; kx++) { + const ivec4 weight_block = load_weight_block( + ic4, + kx, + ky, + out_block_idx.data.z, + IC4_per_group, + conv2d_params.kernel_size.x, + conv2d_params.kernel_size.y, + out_block_extents.data.z); + + perform_conv1d(out_accum, int8_input_window, weight_block, kx); + } + } + } + + FPPerOutChannelParams weight_scales_tile; + load_weight_scales_tile(weight_scales_tile, out_block_idx.data.z); + + IntPerOutChannelParams weight_sums_tile; + load_weight_sums_tile(weight_sums_tile, out_block_idx.data.z); + + Int8OutTile int8_out_tile; + initialize(int8_out_tile); + + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, out_block_idx.data.z); + + compute_int8_out_tile_with_int32_accum( + int8_out_tile, + out_accum, + input_scale, + input_zp, + output_inv_scale, + output_zp, + weight_sums_tile, + weight_scales_tile, + bias_tile); + } + else { + compute_int8_out_tile_with_int32_accum( + int8_out_tile, + out_accum, + input_scale, + input_zp, + output_inv_scale, + output_zp, + weight_sums_tile, + weight_scales_tile); + } + + store_packed_int8_output_tile( + int8_out_tile, out_block_idx, out_block_extents); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.yaml new file mode 100644 index 00000000000..5da9cc14584 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.yaml @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv2d_q8ta_q8csw_q8to: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, WEIGHT_STORAGE] + combos: + - parameter_values: [buffer, texture2d] + DTYPE: + - VALUE: float + shader_variants: + - NAME: conv2d_q8ta_q8csw_q8to diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to_linear_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to_linear_tiled.glsl new file mode 100644 index 00000000000..b44e37766fc --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to_linear_tiled.glsl @@ -0,0 +1,149 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T ${texel_load_component_type(DTYPE, "buffer")} + +$if IO_STORAGE == "buffer": + #define PACKED_INT8_OUTPUT_BUFFER + #define PACKED_INT8_INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +// corresponds to input/output width dim +#define TILE_M4 1 +// corresponds to input channels dim +#define TILE_K4 1 +// corresponds to output channels dim +#define TILE_N4 2 + +#define TILE_M 4 +#define TILE_K 4 +#define TILE_N 8 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "im2col_sizes")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "apply_bias", "1")} + +#include "conv2d_int8_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_output_tile_int8_int8_compute.glslh" +#include "linear_int_weight_sums_load.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_fp_bias_load.glslh" +#include "linear_int8_output_tile_compute.glslh" +#include "conv2d_int8_output_tile_store.glslh" + +void main() { + Conv2dBlockIndex output_block_idx; + output_block_idx.data.z = int(gl_GlobalInvocationID.x) * TILE_N4; + output_block_idx.data.x = int(gl_GlobalInvocationID.y) * TILE_M4; + output_block_idx.data.y = int(gl_GlobalInvocationID.z); + + Conv2dBlockExtents output_block_extents = make_block_extents(output_sizes); + if (block_idx_out_of_bounds(output_block_idx, output_block_extents)) { + return; + } + + const int n = mul_4(output_block_idx.data.z); + + const int group_idx = n / conv2d_params.out_channels_per_group; + const int group_k4_offset = group_idx * conv2d_params.K4_per_group; + + Conv2dBlockExtents input_block_extents = make_block_extents(im2col_sizes); + + Int32Accum out_accum; + initialize(out_accum); + + Int8InputTile int8_input_tile; + Int8WeightTile int8_weight_tile; + + Int8InputTileIndex input_idx = make_initial_int8_input_tile_index( + output_block_idx, input_block_extents, group_k4_offset); + + for (int k4 = 0; k4 < conv2d_params.K4_per_group; k4++) { + load_packed_int8_input_tile(int8_input_tile, input_idx); + + load_int8_weight_tile( + int8_weight_tile, + output_block_idx.data.z, + k4, + output_block_extents.data.z); + + int_accumulate_with_int8_weight( + out_accum, int8_input_tile, int8_weight_tile); + + increment_k4(input_idx); + } + + FPPerOutChannelParams weight_scales_tile; + load_weight_scales_tile(weight_scales_tile, output_block_idx.data.z); + + IntPerOutChannelParams weight_sums_tile; + load_weight_sums_tile(weight_sums_tile, 
output_block_idx.data.z); + + Int8OutTile int8_out_tile; + initialize(int8_out_tile); + + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, output_block_idx.data.z); + + compute_int8_out_tile_with_int32_accum( + int8_out_tile, + out_accum, + input_scale, + input_zp, + output_inv_scale, + output_zp, + weight_sums_tile, + weight_scales_tile, + bias_tile); + } + else { + compute_int8_out_tile_with_int32_accum( + int8_out_tile, + out_accum, + input_scale, + input_zp, + output_inv_scale, + output_zp, + weight_sums_tile, + weight_scales_tile); + } + + store_packed_int8_output_tile( + int8_out_tile, output_block_idx, output_block_extents); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to_linear_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to_linear_tiled.yaml new file mode 100644 index 00000000000..fa92481f5ef --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to_linear_tiled.yaml @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv2d_q8ta_q8csw_q8to_linear_tiled: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, WEIGHT_STORAGE] + combos: + - parameter_values: [buffer, texture2d] + DTYPE: + - VALUE: float + shader_variants: + - NAME: conv2d_q8ta_q8csw_q8to_linear_tiled diff --git a/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8.glsl b/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8.glsl new file mode 100644 index 00000000000..3ecaa597ecc --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8.glsl @@ -0,0 +1,73 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +$if STORAGE == "buffer": + #define PACKED_INT8_OUTPUT_BUFFER + #define PACKED_INT8_INPUT_BUFFER + +#define TILE_M4 1 +#define TILE_N4 1 +#define TILE_K4 1 + +#define TILE_M 4 +#define TILE_N 4 +#define TILE_K 4 + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", STORAGE, is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "im2col_sizes")} +// Sizes of the output image +${layout_declare_ubo(B, "ivec4", "output_sizes")} +// Sizes of the input image +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float inv_scale; + int zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "conv2d_int8_output_tile_store.glslh" +#include "im2col_packed_int8_utils.glslh" + +void main() { + const int out_buf_idx = int(gl_GlobalInvocationID.x); + Conv2dBlockExtents im2col_block_extents = make_block_extents(im2col_sizes); + + Conv2dBlockIndex im2col_block_idx = linear_idx_to_block_idx( + out_buf_idx, im2col_block_extents); + + if (block_idx_out_of_bounds(im2col_block_idx, im2col_block_extents)) { + return; + } + + Im2ColBlockLoadIndices load_ixs = im2col_block_idx_to_load_ixs( + im2col_block_idx); + + Conv2dBlockExtents input_block_extents = make_block_extents(input_sizes); + + const ivec4 input_zps = ivec4(pack_into_int32(ivec4(zp))); + Int8OutTile int8_im2col_tile; + int8_im2col_tile.data[0][0] = load_im2col_block( + load_ixs, input_block_extents, zp, input_zps); + + store_packed_int8_output_tile( + int8_im2col_tile, im2col_block_idx, im2col_block_extents); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8.yaml b/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8.yaml new file mode 100644 index 00000000000..1c14f1fdc5a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8.yaml @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +im2col_packed_int8: + parameter_names_with_default_values: + STORAGE: buffer + generate_variant_forall: + STORAGE: + - VALUE: buffer + shader_variants: + - NAME: im2col_packed_int8 diff --git a/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8_utils.glslh b/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8_utils.glslh new file mode 100644 index 00000000000..2b1870c493d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8_utils.glslh @@ -0,0 +1,287 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#ifndef IM2COL_PACKED_INT8_GLSLH +#define IM2COL_PACKED_INT8_GLSLH + +#include "common.glslh" + +struct Conv2dBlockElementIndex { + int x4; + int y; + int z4; + + int row; + int col; +}; + +struct Im2ColBlockLoadIndices { + bool block_aligned; + bool cols_aligned; + bool rows_contiguous; + + int im2col_w_start; + int im2col_h; + int k_in_group_start; + int group_idx; + + Conv2dBlockElementIndex block_idx_start; +}; + +Conv2dBlockElementIndex tidx_to_block_elem_idx(const TensorIndex4D tidx) { + Conv2dBlockElementIndex block_idx; + block_idx.x4 = div_4(tidx.data.x); + block_idx.row = mod_4(tidx.data.x); + + block_idx.y = tidx.data.y; + + block_idx.z4 = div_4(tidx.data.z); + block_idx.col = mod_4(tidx.data.z); + + return block_idx; +} + +TensorIndex4D get_input_tensor_tidx( + const int w, + const int h, + const int k_in_group, + const int group_idx) { + TensorIndex4D tidx; + tidx.data.w = 0; + + const int c_in_group = k_in_group % conv2d_params.in_channels_per_group; + const int row = k_in_group / conv2d_params.in_channels_per_group; + const int kernel_x = row % conv2d_params.kernel_size.x; + const int kernel_y = row / conv2d_params.kernel_size.x; + + tidx.data.z = group_idx * conv2d_params.in_channels_per_group + c_in_group; + + tidx.data.x = (w * conv2d_params.stride.x) - conv2d_params.padding.x + + (kernel_x * conv2d_params.dilation.x); + tidx.data.y = (h * conv2d_params.stride.y) - conv2d_params.padding.y + + (kernel_y * conv2d_params.dilation.y); + + return tidx; +} + +Im2ColBlockLoadIndices im2col_block_idx_to_load_ixs( + Conv2dBlockIndex im2col_block_idx) { + const int im2col_w = mul_4(im2col_block_idx.data.x); + const int im2col_h = im2col_block_idx.data.y; + const int im2col_k = mul_4(im2col_block_idx.data.z); + + const int group_idx = im2col_k / conv2d_params.K_per_group; + const int k_in_group = im2col_k % conv2d_params.K_per_group; + + TensorIndex4D input_tidx = + get_input_tensor_tidx(im2col_w, im2col_h, k_in_group, group_idx); + + bool cols_aligned = (mod_4(input_tidx.data.z) == 0) && + (input_tidx.data.z + 3 < conv2d_params.in_channels_per_group); + + bool rows_aligned = mod_4(input_tidx.data.x) == 0; + bool rows_contiguous = conv2d_params.stride.x == 1; + + Im2ColBlockLoadIndices load_ixs; + load_ixs.block_aligned = cols_aligned && rows_aligned && rows_contiguous; + load_ixs.cols_aligned = cols_aligned; + load_ixs.rows_contiguous = rows_contiguous; + + load_ixs.im2col_w_start = im2col_w; + load_ixs.im2col_h = im2col_h; + load_ixs.k_in_group_start = k_in_group; + load_ixs.group_idx = group_idx; + + load_ixs.block_idx_start = tidx_to_block_elem_idx(input_tidx); + + return load_ixs; +} + +bool is_block_elem_idx_in_bounds( + const Conv2dBlockElementIndex idx, + const Conv2dBlockExtents block_extents) { + const ivec3 block_idx = ivec3(idx.x4, idx.y, idx.z4); + if (any(lessThan(block_idx, ivec3(0))) || + any(greaterThanEqual(block_idx, block_extents.data))) { + return false; + } + return true; +} + +int load_packed_int8_input_element( + const Conv2dBlockElementIndex idx, + const Conv2dBlockExtents block_extents, + const int input_zp) { + // bounds checking + if (!is_block_elem_idx_in_bounds(idx, block_extents)) { + return input_zp; + } +#ifdef PACKED_INT8_INPUT_BUFFER + const int buf_idx = + idx.y * block_extents.data_xz + idx.x4 * block_extents.data.z + idx.z4; + const ivec4 tile = t_packed_int8_input[buf_idx]; +#else + const ivec4 tile = + texelFetch(t_packed_int8_input, ivec3(idx.x4, idx.y, idx.z4), 0); +#endif + return extract_8bit_from_packed_int_le(tile[idx.row], 
idx.col); +} + +Conv2dBlockElementIndex get_packed_int8_input_element_idx( + const int im2col_w, + const int im2col_h, + const int k_in_group, + const int group_idx) { + TensorIndex4D input_tidx = + get_input_tensor_tidx(im2col_w, im2col_h, k_in_group, group_idx); + + return tidx_to_block_elem_idx(input_tidx); +} + +ivec4 load_im2col_block_aligned( + const Im2ColBlockLoadIndices load_ixs, + const Conv2dBlockExtents block_extents) { +#ifdef PACKED_INT8_INPUT_BUFFER + const int buf_idx = load_ixs.block_idx_start.y * block_extents.data_xz + + load_ixs.block_idx_start.x4 * block_extents.data.z + + load_ixs.block_idx_start.z4; + return t_packed_int8_input[buf_idx]; +#else + return texelFetch( + t_packed_int8_input, + ivec3( + load_ixs.block_idx_start.x4, + load_ixs.block_idx_start.y, + load_ixs.block_idx_start.z4), + 0); +#endif +} + +ivec4 load_im2col_block_c_aligned_w_contiguous( + const Im2ColBlockLoadIndices load_ixs, + const Conv2dBlockExtents block_extents, + const ivec4 input_zps) { + ivec4 im2col_block; + Conv2dBlockElementIndex block_elem_idx = load_ixs.block_idx_start; + +#ifdef PACKED_INT8_INPUT_BUFFER + int buf_idx = load_ixs.block_idx_start.y * block_extents.data_xz + + load_ixs.block_idx_start.x4 * block_extents.data.z + + load_ixs.block_idx_start.z4; +#endif + + ivec4 in_block = input_zps; + if (is_block_elem_idx_in_bounds(block_elem_idx, block_extents)) { +#ifdef PACKED_INT8_INPUT_BUFFER + in_block = t_packed_int8_input[buf_idx]; +#else + in_block = texelFetch( + t_packed_int8_input, + ivec3(block_elem_idx.x4, block_elem_idx.y, block_elem_idx.z4), + 0); +#endif + } + + int current_row = 0; + int r_limit = min(4 - block_elem_idx.row, 4); + for (int r = 0; r < r_limit; r++) { + im2col_block[current_row++] = in_block[r + block_elem_idx.row]; + } + + in_block = input_zps; + block_elem_idx.x4++; +#ifdef PACKED_INT8_INPUT_BUFFER + buf_idx += block_extents.data.z; +#endif + + if (is_block_elem_idx_in_bounds(block_elem_idx, block_extents)) { +#ifdef PACKED_INT8_INPUT_BUFFER + in_block = t_packed_int8_input[buf_idx]; +#else + in_block = texelFetch( + t_packed_int8_input, + ivec3(block_elem_idx.x4, block_elem_idx.y, block_elem_idx.z4), + 0); +#endif + } + + for (int r = 0; current_row < 4; ++r) { + im2col_block[current_row++] = in_block[r]; + } + + return im2col_block; +} + +ivec4 load_im2col_block_no_alignment( + const Im2ColBlockLoadIndices load_ixs, + const Conv2dBlockExtents block_extents, + const int input_zp) { + ivec4 im2col_block; + + for (int r = 0; r < 4; r++) { + const int im2col_w = load_ixs.im2col_w_start + r; + ivec4 row_values; + for (int c = 0; c < 4; c++) { + const int k_in_group = load_ixs.k_in_group_start + c; + + if (k_in_group >= conv2d_params.logical_K_per_group) { + row_values[c] = input_zp; + continue; + } + + Conv2dBlockElementIndex block_idx = get_packed_int8_input_element_idx( + im2col_w, load_ixs.im2col_h, k_in_group, load_ixs.group_idx); + + row_values[c] = + load_packed_int8_input_element(block_idx, block_extents, input_zp); + } + + im2col_block[r] = pack_into_int32(row_values); + } + return im2col_block; +} + +ivec4 load_im2col_block( + const Im2ColBlockLoadIndices load_ixs, + const Conv2dBlockExtents block_extents, + const int input_zp, + const ivec4 input_zps) { + if (load_ixs.cols_aligned && load_ixs.rows_contiguous) { + return load_im2col_block_c_aligned_w_contiguous( + load_ixs, block_extents, input_zps); + } + return load_im2col_block_no_alignment(load_ixs, block_extents, input_zp); +} + +#ifdef DEBUG_MODE + +void printLoadIndices(const 
Im2ColBlockLoadIndices load_ixs) { + debugPrintfEXT("LoadIndices: \\n"); + + if (load_ixs.block_aligned) { + debugPrintfEXT(" block_aligned \\n"); + } + if (load_ixs.cols_aligned) { + debugPrintfEXT(" cols_aligned \\n"); + } + if (load_ixs.rows_contiguous) { + debugPrintfEXT(" rows_contiguous \\n"); + } + + debugPrintfEXT( + " block_idx_start: %d %d %d || %d %d \\n", + load_ixs.block_idx_start.x4, + load_ixs.block_idx_start.y, + load_ixs.block_idx_start.z4, + load_ixs.block_idx_start.row, + load_ixs.block_idx_start.col); +} + +#endif + +#endif // IM2COL_PACKED_INT8_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh index a6dbd7e78a2..8f19418cd19 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh @@ -43,13 +43,6 @@ ivec4 quantize( return clamp(ivec4(quantized), -128, 127); } -int pack_into_int32(const ivec4 quant_vals) { - int packed = ((quant_vals[0] & 0xFF) << 0) | ((quant_vals[1] & 0xFF) << 8) | - ((quant_vals[2] & 0xFF) << 16) | ((quant_vals[3] & 0xFF) << 24); - - return packed; -} - void quantize_and_pack( out Int8InputBlock packed, const FPInputTile in_block, diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_output_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_output_tile.glslh new file mode 100644 index 00000000000..14aa6558bfc --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_output_tile.glslh @@ -0,0 +1,67 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +/* + * Macro Settings: + * - TILE_M + * - TILE_N4 + */ + +#ifndef LINEAR_INT8_OUTPUT_TILE_GLSLH +#define LINEAR_INT8_OUTPUT_TILE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +struct Int8OutTile { + ivec4 data[TILE_M4][TILE_N4]; +}; + +void initialize(out Int8OutTile tile) { + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + tile.data[m4][n4] = ivec4(0); + } + } +} + +#ifdef DEBUG_MODE + +#include "linear_common.glslh" + +void printInt8OutTile(const Int8OutTile tile) { + debugPrintfEXT( + "Int8InputTile [TILE_M4=%d][TILE_N4=%d]:\\n", TILE_M4, TILE_N4); + + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT(" tile[%d][%d] (ivec4): ", m4, n4); + + // Each ivec4 contains 4 packed integers, each integer contains 4 8-bit + // values + [[unroll]] for (int vec_idx = 0; vec_idx < 4; ++vec_idx) { + int packed_int = tile.data[m4][n4][vec_idx]; + debugPrintfEXT("packed_int[%d]=%d -> [", vec_idx, packed_int); + + // Extract 4 8-bit values from this packed integer + [[unroll]] for (int byte_idx = 0; byte_idx < 4; ++byte_idx) { + int val = extract_8bit_from_packed_int_le(packed_int, byte_idx); + if (byte_idx < 3) { + debugPrintfEXT("%d, ", val); + } else { + debugPrintfEXT("%d] ", val); + } + } + } + debugPrintfEXT("\\n"); + } + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_INT8_OUTPUT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_output_tile_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_output_tile_compute.glslh new file mode 100644 index 00000000000..1251ca60b87 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_output_tile_compute.glslh @@ -0,0 +1,93 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to compute a FPOutTile using int8 input and weight tiles. + * + * Settings: + * - TILE_M: The number of rows in the output tile. + * - TILE_N4: The number of (groups of 4) columns in the output tile. 
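+ *
+ * The requantization performed below is, per output element (row m, output
+ * channel n):
+ *
+ *   acc_adj = accum[m][n] - input_zp * weight_sum[n]
+ *   out     = clamp(round((acc_adj * weight_scale[n] * input_scale + bias[n])
+ *                 * output_inv_scale) + output_zp, -128, 127)
+ *
+ * with bias[n] omitted in the variant that takes no bias parameter.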
+ */ + +#ifndef LINEAR_INT8_OUTPUT_TILE_INT8_INT8_COMPUTE_GLSLH +#define LINEAR_INT8_OUTPUT_TILE_INT8_INT8_COMPUTE_GLSLH + +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_integer_dot_product : require + +#include "linear_fp_per_out_channel_params.glslh" +#include "linear_int8_output_tile.glslh" +#include "linear_int_accumulator.glslh" +#include "linear_int_per_out_channel_params.glslh" + +void compute_int8_out_tile_with_int32_accum( + out Int8OutTile out_tile, + const Int32Accum accum, + const float input_q_scale, + const int input_q_zp, + const float output_q_inv_scale, + const int output_q_zp, + const IntPerOutChannelParams weight_sums, + const FPPerOutChannelParams weight_scales) { + ivec4 input_zp_vec = ivec4(-input_q_zp); + ivec4 output_zp_vec = ivec4(-output_q_zp); + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + [[unroll]] for (int m4i = 0; m4i < 4; ++m4i) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + const int m = mul_4(m4) + m4i; + // Compute floating point output values + ivec4 accum_adjusted = + input_zp_vec * weight_sums.data[n4] + accum.data[m][n4]; + vec4 float_out_texel = + vec4(accum_adjusted) * vec4(weight_scales.data[n4] * input_q_scale); + // Requantize to int8 + float_out_texel = + round(float_out_texel * output_q_inv_scale) + output_q_zp; + ivec4 quantized_out_texel = clamp(ivec4(float_out_texel), -128, 127); + + out_tile.data[m4][n4][m4i] = pack_into_int32(quantized_out_texel); + } + } + } +} + +void compute_int8_out_tile_with_int32_accum( + out Int8OutTile out_tile, + const Int32Accum accum, + const float input_q_scale, + const int input_q_zp, + const float output_q_inv_scale, + const int output_q_zp, + const IntPerOutChannelParams weight_sums, + const FPPerOutChannelParams weight_scales, + const FPPerOutChannelParams bias) { + ivec4 input_zp_vec = ivec4(-input_q_zp); + ivec4 output_zp_vec = ivec4(-output_q_zp); + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + [[unroll]] for (int m4i = 0; m4i < 4; ++m4i) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + const int m = mul_4(m4) + m4i; + // Compute floating point output values + ivec4 accum_adjusted = + input_zp_vec * weight_sums.data[n4] + accum.data[m][n4]; + vec4 float_out_texel = + fma(vec4(accum_adjusted), + vec4(weight_scales.data[n4]) * input_q_scale, + vec4(bias.data[n4])); + // Requantize to int8 + float_out_texel = + round(float_out_texel * output_q_inv_scale) + output_q_zp; + ivec4 quantized_out_texel = clamp(ivec4(float_out_texel), -128, 127); + + out_tile.data[m4][n4][m4i] = pack_into_int32(quantized_out_texel); + } + } + } +} + +#endif // LINEAR_INT8_OUTPUT_TILE_INT8_INT8_COMPUTE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl index 0ad91643219..878821d4189 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl @@ -76,9 +76,6 @@ void main() { const int N4 = div_up_4(output_sizes.x); // number of texels in each row const int N8 = div_up_8(output_sizes.x); // number of texels in each row - bool should_print = (n8 == 0) && (m4 == 0); - should_print = false; - // VEC4_T out_texels[4][2]; FPOutTile out_tile; initialize(out_tile); diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_dw_weights.glsl b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_dw_weights.glsl new file mode 100644 index 00000000000..da4162b6e58 --- /dev/null +++ 
b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_dw_weights.glsl @@ -0,0 +1,72 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +${define_active_storage_type(STORAGE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_packed_int8_weight", "int", STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_int8_weight", "int", "buffer")} + +layout(push_constant) uniform restrict Block { + ivec4 qmat2_sizes; + ivec3 orig_sizes; // [K_h, aligned_K_w, OC] +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "common.glslh" + +void main() { + // The size of the source weight tensor is [K_h, aligned_K_w, OC] for depthwise conv. + // Each shader invocation processes a 4x4 block of weights for a group of output channels. + const int oc4 = int(gl_GlobalInvocationID.x); + const int k4 = int(gl_GlobalInvocationID.y); + const int k = mul_4(k4); + + const int H = orig_sizes.x; + const int orig_W = orig_sizes.y; + const int W4 = div_up_4(orig_W); + const int OC = orig_sizes.z; + + const int h = k4 / W4; + const int w4 = k4 % W4; + const int w = mul_4(w4); + + // Determine the total number of blocks and check bounds + const int OC4 = div_up_4(OC); + const int K4 = H * W4; + + if (oc4 >= OC4 || k4 >= K4) { + return; + } + + ivec4 packed_block; + + int buf_idx = (h * orig_W + w) * OC4 + oc4; + int r_limit = min(4, orig_W - w); + [[unroll]] for (int r = 0; r < r_limit; r++) { + packed_block[r] = t_int8_weight[buf_idx]; + buf_idx += OC4; + } + [[unroll]] for (int r = r_limit; r < 4; r++) { + packed_block[r] = 0; + } + +#ifdef USING_BUFFER + t_packed_int8_weight[k4 * OC4 + oc4] = packed_block; +#else + imageStore(t_packed_int8_weight, ivec2(oc4, k4), packed_block); +#endif +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_dw_weights.yaml b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_dw_weights.yaml new file mode 100644 index 00000000000..9cfa3108ff0 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_dw_weights.yaml @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +pack_q8_conv2d_dw_weights: + parameter_names_with_default_values: + STORAGE: buffer + generate_variant_forall: + STORAGE: + - VALUE: buffer + - VALUE: texture2d + shader_variants: + - NAME: pack_q8_conv2d_dw_weights diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_weights.glsl b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_weights.glsl new file mode 100644 index 00000000000..e9982a8273d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_weights.glsl @@ -0,0 +1,82 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +${define_active_storage_type(STORAGE)} + +#extension GL_EXT_control_flow_attributes : require + +${define_required_extensions("int8")} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_packed_int8_weight", "int", STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_int8_weight", "int8", "buffer")} + +layout(push_constant) uniform restrict Block { + ivec4 qmat2_sizes; + ivec4 orig_sizes; // [OC, K_h, K_w, IC] +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "common.glslh" + +void main() { + const int block_x = int(gl_GlobalInvocationID.x); + const int block_y = int(gl_GlobalInvocationID.y); + + const int kx = block_x % orig_sizes.z; + const int oc4 = block_x / orig_sizes.z; + + const int OC4 = div_up_4(orig_sizes.x); + const int IC4 = div_up_4(orig_sizes.w); + + const int nblocks_x = orig_sizes.z * OC4; + const int nblocks_y = IC4 * orig_sizes.y; + + const int ic4 = block_y % IC4; + const int ky = block_y / IC4; + + if (block_x >= nblocks_x || block_y >= nblocks_y) { + return; + } + + const int oc = mul_4(oc4); + const int ic = mul_4(ic4); + + const int oc_stride = align_up_4(orig_sizes.y * orig_sizes.z * orig_sizes.w); + const int oc_offset = oc * oc_stride; + const int ky_offset = ky * (orig_sizes.z * orig_sizes.w); + const int kx_offset = kx * orig_sizes.w; + int buf_idx = oc_offset + ky_offset + kx_offset + ic; + + ivec4 packed_block = ivec4(0); + for (int row = 0; row < 4; row++) { + if (oc + row < orig_sizes.x) { + ivec4 weight_vals = ivec4(0); + for (int col = 0; col < 4; col++) { + if (ic + col < orig_sizes.w) { + weight_vals[col] = int(t_int8_weight[buf_idx + col]); + } + } + packed_block[row] = pack_into_int32(weight_vals); + } + buf_idx += oc_stride; + } + +#ifdef USING_BUFFER + const int out_buf_idx = block_y * (nblocks_x) + block_x; + t_packed_int8_weight[out_buf_idx] = packed_block; +#else + imageStore(t_packed_int8_weight, ivec2(block_x, block_y), packed_block); +#endif +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_weights.yaml b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_weights.yaml new file mode 100644 index 00000000000..9331de6e758 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_weights.yaml @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +pack_q8_conv2d_weights: + parameter_names_with_default_values: + STORAGE: buffer + generate_variant_forall: + STORAGE: + - VALUE: buffer + - VALUE: texture2d + shader_variants: + - NAME: pack_q8_conv2d_weights diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh index 03132db1348..1880397181d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh @@ -44,7 +44,6 @@ void load_k_cache_tile_no_checks( const int context_len, const int C, const int KV_H) { - bool should_print = d4_start == 0 && c_start == 0 && kv_h == 0; [[unroll]] for (int c = 0; c < TILE_N; ++c) { const int c4 = div_4(c); const int c4i = mod_4(c); diff --git a/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl b/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl index ed7dd25421a..798366b523a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl @@ -28,7 +28,6 @@ ${define_required_extensions(DTYPE)} layout(std430) buffer; -#define DEBUG_MODE #include "conv2d_common.glslh" ${layout_declare_tensor(B, "w", "t_fp_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} diff --git a/backends/vulkan/runtime/graph/ops/impl/Common.cpp b/backends/vulkan/runtime/graph/ops/impl/Common.cpp index 6c701224f7f..71690ffc604 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Common.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Common.cpp @@ -56,4 +56,27 @@ utils::uvec3 pick_hw_square_wg_size( return {16u, 4u, 1u}; } +utils::uvec3 pick_wc_square_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)graph; + (void)shader; + (void)args; + (void)resize_args; + // Some inactive invocations are okay; set 6 as the threshold to use the + // a square wg size. + if (global_workgroup_size[0u] >= 6 && global_workgroup_size[2u] >= 6) { + return {8u, 1u, 8u}; + } + // If channels dim is sufficiently small, then bias towards width dim to + // reduce the number of inactive invocations. 
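+  // For example (hypothetical global sizes): {32, 1, 16} meets both
+  // thresholds and maps to {8, 1, 8}, while {128, 1, 1} has a channels
+  // extent below 2 and maps to {64, 1, 1}; other shapes fall back to
+  // {16, 1, 4}.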
+  if (global_workgroup_size[2u] < 2u) {
+    return {64u, 1u, 1u};
+  }
+  return {16u, 1u, 4u};
+}
+
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/Common.h b/backends/vulkan/runtime/graph/ops/impl/Common.h
index 1831ab2a845..b412f737c13 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Common.h
+++ b/backends/vulkan/runtime/graph/ops/impl/Common.h
@@ -54,4 +54,11 @@ utils::uvec3 pick_hw_square_wg_size(
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& resize_args);
 
+utils::uvec3 pick_wc_square_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const utils::uvec3& global_workgroup_size,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args);
+
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp
index f6eee4ba12e..75bbb3892df 100644
--- a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp
@@ -19,6 +19,86 @@ namespace vkcompute {
 // Utility functions
 //
 
+bool is_pointwise(ComputeGraph* graph, const ValueRef& kernel_size) {
+  const auto kernel_size_list = graph->get_int_list(kernel_size);
+  return kernel_size_list->at(0) == 1 && kernel_size_list->at(1) == 1;
+}
+
+bool is_s1p1d1(
+    ComputeGraph* graph,
+    const ValueRef& stride,
+    const ValueRef& padding,
+    const ValueRef& dilation) {
+  const auto stride_list = graph->get_int_list(stride);
+  const auto padding_list = graph->get_int_list(padding);
+  const auto dilation_list = graph->get_int_list(dilation);
+  if (stride_list->at(0) != 1 || stride_list->at(1) != 1) {
+    return false;
+  }
+  if (padding_list->at(0) != 1 || padding_list->at(1) != 1) {
+    return false;
+  }
+  if (dilation_list->at(0) != 1 || dilation_list->at(1) != 1) {
+    return false;
+  }
+  return true;
+}
+
+bool is_s1p0d1_pointwise(
+    ComputeGraph* graph,
+    const ValueRef& kernel_size,
+    const ValueRef& stride,
+    const ValueRef& padding,
+    const ValueRef& dilation) {
+  if (is_pointwise(graph, kernel_size)) {
+    const auto stride_list = graph->get_int_list(stride);
+    const auto padding_list = graph->get_int_list(padding);
+    const auto dilation_list = graph->get_int_list(dilation);
+    if (stride_list->at(0) != 1 || stride_list->at(1) != 1) {
+      return false;
+    }
+    if (padding_list->at(0) != 0 || padding_list->at(1) != 0) {
+      return false;
+    }
+    if (dilation_list->at(0) != 1 || dilation_list->at(1) != 1) {
+      return false;
+    }
+    return true;
+  }
+  return false;
+}
+
+bool should_use_im2col(
+    ComputeGraph* graph,
+    const ValueRef kernel_size,
+    const ValueRef groups) {
+  const auto kernel_size_list = graph->get_int_list(kernel_size);
+
+  // Always use im2col for pointwise convolutions
+  if (kernel_size_list->at(0) * kernel_size_list->at(1) == 1) {
+    return true;
+  }
+
+  // For large kernel sizes, the im2col matrix will be too big. Not only will
+  // this result in a larger footprint for the im2col matrix, but the cost of
+  // performing the im2col procedure will also become prohibitive. In these
+  // cases it is faster to just compute convolution directly without going
+  // through im2col. Empirically, im2col works well for 3x3 convolution and
+  // not for 5x5 convolution, so set the limit at 10.
+  if (kernel_size_list->at(0) * kernel_size_list->at(1) > 10) {
+    return false;
+  }
+
+  // Only use im2col for non-grouped convolutions; manual experimentation shows
+  // that im2col becomes very slow when dealing with grouped convolutions. The
+  // reason for this is likely that memory access in the im2col shader becomes
+  // too non-linear due to needing to keep convolution groups contiguous in
+  // memory. This means that the channels of the input tensor (which are
+  // originally contiguous in memory) will be split up during the im2col
+  // procedure.
+  return graph->get_int(groups) == 1;
+}
+
 struct Conv2DParams {
   utils::ivec2 kernel_size;
   utils::ivec2 stride;
@@ -135,6 +215,43 @@ std::vector<int64_t> calculate_input_im2col_sizes(
   return {M, K};
 }
 
+std::vector<int64_t> calculate_packed_int8_input_im2col_sizes(
+    ComputeGraph* graph,
+    const ValueRef& input,
+    const ValueRef& output,
+    const ValueRef& kernel_size,
+    const ValueRef& groups) {
+  std::vector<int64_t> in_sizes = graph->sizes_of(input);
+  const int64_t in_channels = utils::val_at(-3, in_sizes);
+
+  std::vector<int64_t> out_sizes = graph->sizes_of(output);
+  const int64_t out_height = utils::val_at(-2, out_sizes);
+  const int64_t out_width = utils::val_at(-1, out_sizes);
+
+  // Represents the number of channel groups
+  const int64_t groups_val = graph->extract_scalar<int64_t>(groups);
+  // No need to div_up because in_channels % groups_val = 0
+  const int64_t in_channels_per_group = in_channels / groups_val;
+
+  const auto kernel_size_list = graph->get_int_list(kernel_size);
+
+  // Align to the next multiple of 4 to ensure that data loads align nicely
+  // with texel boundaries. We want to ensure that the first data element of
+  // each group is at the start of its texel.
+  const int64_t flattened_kernel_len = utils::align_up_4(
+      in_channels_per_group * kernel_size_list->at(0) *
+      kernel_size_list->at(1));
+
+  // K -> flattened convolution window (repeated for each group)
+  const int64_t K = flattened_kernel_len * groups_val;
+  // H, W -> the 2D output plane. W is aligned to the next multiple of 4 since
+  // the im2col shader operates on 4x4 blocks.
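+  // Worked example with hypothetical sizes: 16 input channels, groups = 2,
+  // and a 3x3 kernel gives in_channels_per_group = 8, so
+  // flattened_kernel_len = align_up_4(8 * 3 * 3) = 72 and K = 72 * 2 = 144;
+  // an output width of 57 would be padded up to W = 60.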
+ const int64_t W = utils::align_up_4(out_width); + const int64_t H = out_height; + + return {K, H, W}; +} + std::vector calculate_output_im2col_sizes( ComputeGraph* graph, const ValueRef& output) { @@ -212,6 +329,33 @@ utils::uvec3 im2col_global_wg_size( return {K4, M4, 1}; } +utils::uvec3 im2col_packed_int8_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef input_im2col = args.at(0).refs.at(0); + + std::vector im2col_sizes = graph->sizes_of(input_im2col); + const uint32_t K = utils::safe_downcast(im2col_sizes[0]); + const uint32_t H = utils::safe_downcast(im2col_sizes[1]); + const uint32_t W = utils::safe_downcast(im2col_sizes[2]); + + const uint32_t K4 = utils::div_up(K, 4u); + const uint32_t W4 = utils::div_up(W, 4u); + + return {K4 * W4 * H, 1, 1}; +} + +utils::uvec3 im2col_packed_int8_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + return {64, 1, 1}; +} + utils::uvec3 col2im_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, @@ -231,6 +375,229 @@ utils::uvec3 col2im_global_wg_size( return {N4, M4, 1}; } +utils::uvec3 pick_static_quantized_conv2d_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef packed_int8_output = args.at(0).refs.at(0); + + const uint32_t W = graph->size_at(-1, packed_int8_output); + const uint32_t H = graph->size_at(-2, packed_int8_output); + const uint32_t C = graph->size_at(-3, packed_int8_output); + + uint32_t C_per_tile = 4; + uint32_t W_per_tile = 4; + + if (shader.kernel_name.find("linear") != std::string::npos) { + C_per_tile = 8; + } + + const uint32_t num_W_tiles = utils::div_up(W, W_per_tile); + const uint32_t num_C_tiles = utils::div_up(C, C_per_tile); + + return {num_C_tiles, num_W_tiles, H}; +} + +utils::uvec3 pick_static_quantized_conv2d_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + return pick_hw_square_wg_size( + graph, shader, global_workgroup_size, args, resize_args); +} + +utils::uvec3 int8_conv2d_dw_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef packed_int8_output = args.at(0).refs.at(0); + + const uint32_t W = graph->size_at(-1, packed_int8_output); + const uint32_t H = graph->size_at(-2, packed_int8_output); + const uint32_t C = graph->size_at(-3, packed_int8_output); + + const uint32_t W4 = utils::div_up_4(W); + const uint32_t C4 = utils::div_up_4(C); + + return {C4 * W4 * H, 1, 1}; +} + +// +// Prepack nodes +// + +ValueRef prepack_quantized_conv2d_weight( + ComputeGraph& graph, + const QuantizationConfig& weight_quant_config, + const ValueRef weight_data, + const ValueRef input, + const ValueRef output, + const ValueRef groups, + const ValueRef kernel_size) { + VK_CHECK_COND(weight_quant_config.nbits == 8); + VK_CHECK_COND(weight_quant_config.is_symmetric); + + const int32_t groups_val = graph.get_int(groups); + + const int64_t OC = graph.size_at(-3, output); + const int64_t IC = graph.size_at(-3, input) / groups_val; + + int64_t K_h; + int64_t K_w; + + { + const auto kernel_size_list = graph.get_int_list(kernel_size); + K_h = 
kernel_size_list->at(0); + K_w = kernel_size_list->at(1); + } + + const int64_t num_blocks_OC = utils::div_up_4(OC); + const int64_t num_blocks_IC = utils::div_up_4(IC); + + const int64_t num_blocks_y = num_blocks_IC * K_h; + const int64_t num_blocks_x = K_w * num_blocks_OC; + + // The packed tensor arranges blocks as [OC_blocks * K_total, IC_blocks] + const int64_t output_height = num_blocks_y; + const int64_t output_width = num_blocks_x * 4; + + // Store the original sizes of the weight data to pass to the shader + utils::ivec4 orig_sizes = { + utils::safe_downcast(OC), + utils::safe_downcast(K_h), + utils::safe_downcast(K_w), + utils::safe_downcast(IC)}; + + std::vector packed_weight_sizes{output_height, output_width}; + + utils::StorageType storage_type = utils::kTexture2D; + uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); + if (output_width > max_extent * 4 || output_height > max_extent) { + storage_type = utils::kBuffer; + } + + ValueRef packed_weight = graph.add_tensor( + packed_weight_sizes, + vkcompute::vkapi::kInt, + storage_type, + utils::kWidthPacked); + + utils::uvec3 global_wg_size = { + utils::safe_downcast(num_blocks_x), + utils::safe_downcast(num_blocks_y), + 1u}; + + std::string kernel_name = "pack_q8_conv2d_weights"; + add_storage_type_suffix(kernel_name, storage_type); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + graph.create_local_wg_size(global_wg_size), + // Inputs and Outputs + weight_data, + packed_weight, + // UBOs + {}, + // Specialization Constants + {}, + // Push Constants + {graph.sizes_pc_of(packed_weight), + PushConstantDataInfo(&orig_sizes, sizeof(utils::ivec4))})); + + return packed_weight; +} + +ValueRef prepack_quantized_conv2d_dw_weight( + ComputeGraph& graph, + const QuantizationConfig& weight_quant_config, + const ValueRef weight_data, + const ValueRef kernel_size) { + VK_CHECK_COND(weight_quant_config.nbits == 8); + VK_CHECK_COND(weight_quant_config.is_symmetric); + + std::vector weight_orig_sizes = graph.sizes_of(weight_data); + const int64_t ndim = graph.dim_of(weight_data); + + // For depthwise convolution, expect weight layout [K_h, aligned_K_w, OC] + VK_CHECK_COND(ndim == 3); + int64_t K_h = weight_orig_sizes.at(0); + int64_t K_w = weight_orig_sizes.at(1); + int64_t aligned_K_w = utils::align_up_4(K_w); + int64_t OC = weight_orig_sizes.at(2); + + // The packing format packs the weight tensor into blocks of 4 output channels + // (OC) and 4 kernel elements (K_h * aligned_K_w) + int64_t OC_per_block = 4; + int64_t K_per_block = 4; + + // To figure out the size of the output tensor, determine the number of blocks + // along each dimension. + const int64_t total_K_elements = K_h * aligned_K_w; + const int64_t num_blocks_K = utils::div_up(total_K_elements, K_per_block); + const int64_t num_blocks_OC = utils::div_up(OC, OC_per_block); + + // The blocks are arranged in a transposed manner, such that the transposed + // weight block is indexed like packed_weights[k4][oc4] - this is to allow for + // optimal memory coalescing when computing the depthwise convolution. + int64_t output_height = num_blocks_K; + // The base dtype of the packed tensor is int32 (each int32 contains 4x 8bit + // values) and each block is represented as a ivec4. Therefore the width dim + // of the packed tensor is multiplied by 4. 
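+  // Worked example with hypothetical sizes: K_h = 3, K_w = 3, OC = 32 gives
+  // aligned_K_w = 4, total_K_elements = 12, num_blocks_K = 3, and
+  // num_blocks_OC = 8, so the packed tensor is 3 rows by 8 * 4 = 32 ints.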
+ int64_t output_width = num_blocks_OC * 4; + + // Store the original sizes of the weight data to pass to the shader + utils::ivec3 orig_sizes = { + utils::safe_downcast(K_h), + utils::safe_downcast(K_w), + utils::safe_downcast(OC)}; + + std::vector packed_weight_sizes{output_height, output_width}; + + utils::StorageType storage_type = utils::kTexture2D; + uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); + if (output_width > max_extent * 4 || output_height > max_extent) { + storage_type = utils::kBuffer; + } + + ValueRef packed_weight = graph.add_tensor( + packed_weight_sizes, + vkcompute::vkapi::kInt, + storage_type, + utils::kWidthPacked); + + utils::uvec3 global_wg_size = { + utils::safe_downcast(num_blocks_OC), + utils::safe_downcast(num_blocks_K), + 1u}; + + std::string kernel_name = "pack_q8_conv2d_dw_weights"; + add_storage_type_suffix(kernel_name, storage_type); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + graph.create_local_wg_size(global_wg_size), + // Inputs and Outputs + weight_data, + packed_weight, + // UBOs + {}, + // Specialization Constants + {}, + // Push Constants + {graph.sizes_pc_of(packed_weight), + PushConstantDataInfo(&orig_sizes, sizeof(utils::ivec3))})); + + return packed_weight; +} + // // Dispatch nodes // @@ -285,6 +652,57 @@ void add_input_im2col_node( nullptr)); } +void add_input_im2col_packed_int8_node( + ComputeGraph& graph, + const ValueRef input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const ValueRef output, + const ValueRef input_im2col) { + Conv2DParams conv_params = create_conv2d_params( + graph, input, output, kernel_size, stride, padding, dilation, groups); + + float inv_scale = 1.0f / graph.extract_scalar(input_scale); + int32_t zp = graph.extract_scalar(input_zp); + + std::string kernel_name = "im2col_packed_int8"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(input_im2col)); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(input_im2col), + graph.sizes_ubo(output), + graph.sizes_ubo(input), + graph.create_params_buffer(conv_params)}; + + std::vector push_constants = { + PushConstantDataInfo(&inv_scale, sizeof(inv_scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + im2col_packed_int8_global_wg_size, + im2col_packed_int8_local_wg_size, + // Inputs and Outputs + {{input_im2col, vkapi::kWrite}, {input, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize args + {}, + // Resizing Logic + nullptr)); +} + void add_quantize_and_pack_q8ta_conv2d_input_node( ComputeGraph& graph, const ValueRef fp_input, @@ -314,7 +732,7 @@ void add_quantize_and_pack_q8ta_conv2d_input_node( graph, VK_KERNEL_FROM_STR(kernel_name), pick_quantize_and_pack_conv2d_input_global_wg_size, - default_pick_local_wg_size, + pick_wc_square_wg_size, // Inputs and Outputs {{packed_int8_input, vkapi::kWrite}, {fp_input, vkapi::kRead}}, // Shader params buffers @@ -590,54 +1008,229 @@ void add_conv2d_q8ta_q8csw_linear_node( nullptr)); } -// -// High level operator impl -// - -void quantized_conv2d_impl( +void add_conv2d_q8ta_q8csw_q8to_node( ComputeGraph& graph, - const QuantizationConfig& input_quant_config, - 
const QuantizationConfig& weight_quant_config, - const ValueRef input_image, + const ValueRef packed_int8_input, + const ValueRef packed_int8_input_im2col, const ValueRef input_scale, const ValueRef input_zp, - const ValueRef weight_data, - const ValueRef weight_sums_data, - const ValueRef weight_scales_data, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef output_scale, + const ValueRef output_zp, const ValueRef bias_data, + const ValueRef packed_bias, const ValueRef kernel_size, const ValueRef stride, const ValueRef padding, const ValueRef dilation, const ValueRef groups, - const ValueRef output_image) { - VK_CHECK_COND(weight_quant_config.granularity == kPerChannel); - VK_CHECK_COND(weight_quant_config.nbits == 8); - VK_CHECK_COND(weight_quant_config.is_symmetric); + const ValueRef packed_int8_output) { + Conv2DParams conv_params = create_conv2d_params( + graph, + packed_int8_input, + packed_int8_output, + kernel_size, + stride, + padding, + dilation, + groups); - const ValueRef packed_weight = - prepack_quantized_linear_weight(graph, weight_quant_config, weight_data); - ValueRef packed_weight_scales = prepack_standard( - graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + const bool use_im2col = should_use_im2col(&graph, kernel_size, groups); - // Create a dummy tensor to fill the binding slot of the bias tensor if it is - // not provided. This helps simplify dispatch logic and makes it so that - // fewer shader variants need to be generated. - TmpTensor dummy_bias( - &graph, - {}, - graph.dtype_of(output_image), - utils::kBuffer, - utils::kWidthPacked); + float input_scale_val = graph.extract_scalar(input_scale); + int32_t input_zp_val = graph.extract_scalar(input_zp); - ValueRef packed_bias = dummy_bias.vref; - if (!graph.val_is_none(bias_data)) { - packed_bias = - prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked); - } + float output_inv_scale_val = 1.0f / graph.extract_scalar(output_scale); + int32_t output_zp_val = graph.extract_scalar(output_zp); - std::vector input_im2col_sizes = calculate_input_im2col_sizes( - &graph, input_image, output_image, kernel_size, groups); + std::string kernel_name = use_im2col ? 
"conv2d_q8ta_q8csw_q8to_linear_tiled" + : "conv2d_q8ta_q8csw_q8to"; + add_storage_type_suffix( + kernel_name, graph.storage_type_of(packed_int8_output)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight)); + add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales)); + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(packed_int8_output), + graph.sizes_ubo(packed_int8_input_im2col), + graph.create_params_buffer(conv_params)}; + + std::vector push_constants = { + PushConstantDataInfo(&input_scale_val, sizeof(input_scale_val)), + PushConstantDataInfo(&input_zp_val, sizeof(input_zp_val)), + PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)), + PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)), + }; + + uint32_t apply_bias = 1; + if (graph.val_is_none(bias_data)) { + apply_bias = 0; + } + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_static_quantized_conv2d_global_wg_size, + pick_static_quantized_conv2d_local_wg_size, + // Inputs and Outputs + {{packed_int8_output, vkapi::kWrite}, + {{packed_int8_input_im2col, + packed_weight, + packed_weight_sums, + packed_weight_scales, + packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {apply_bias}, + // Resize args + {}, + // Resizing Logic + nullptr)); +} + +void add_conv2d_dw_q8ta_q8csw_q8to_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef bias_data, + const ValueRef packed_bias, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const ValueRef packed_int8_output) { + Conv2DParams conv_params = create_conv2d_params( + graph, + packed_int8_input, + packed_int8_output, + kernel_size, + stride, + padding, + dilation, + groups); + + // Verify this is actually a depthwise convolution + const int64_t groups_val = graph.extract_scalar(groups); + const int64_t in_channels = graph.size_at(-3, packed_int8_input); + VK_CHECK_COND(groups_val == in_channels); + + float input_scale_val = graph.extract_scalar(input_scale); + int32_t input_zp_val = graph.extract_scalar(input_zp); + + float output_inv_scale_val = 1.0f / graph.extract_scalar(output_scale); + int32_t output_zp_val = graph.extract_scalar(output_zp); + + std::string kernel_name = "conv2d_dw_q8ta_q8csw_q8to"; + add_storage_type_suffix( + kernel_name, graph.storage_type_of(packed_int8_output)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight)); + add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales)); + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(packed_int8_output), + graph.sizes_ubo(packed_int8_input), + graph.create_params_buffer(conv_params)}; + + std::vector push_constants = { + PushConstantDataInfo(&input_scale_val, sizeof(input_scale_val)), + PushConstantDataInfo(&input_zp_val, sizeof(input_zp_val)), + PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)), + PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)), + }; + + uint32_t 
apply_bias = 1; + if (graph.val_is_none(bias_data)) { + apply_bias = 0; + } + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + int8_conv2d_dw_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{packed_int8_output, vkapi::kWrite}, + {{packed_int8_input, + packed_weight, + packed_weight_sums, + packed_weight_scales, + packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {apply_bias}, + // Resize args + {}, + // Resizing Logic + nullptr)); +} + +// +// High level operator impl +// + +void quantized_conv2d_impl( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const QuantizationConfig& weight_quant_config, + const ValueRef input_image, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef weight_data, + const ValueRef weight_sums_data, + const ValueRef weight_scales_data, + const ValueRef bias_data, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const ValueRef output_image) { + VK_CHECK_COND(weight_quant_config.granularity == kPerChannel); + VK_CHECK_COND(weight_quant_config.nbits == 8); + VK_CHECK_COND(weight_quant_config.is_symmetric); + + const ValueRef packed_weight = + prepack_quantized_linear_weight(graph, weight_quant_config, weight_data); + ValueRef packed_weight_scales = prepack_standard( + graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + + // Create a dummy tensor to fill the binding slot of the bias tensor if it is + // not provided. This helps simplify dispatch logic and makes it so that + // fewer shader variants need to be generated. + TmpTensor dummy_bias( + &graph, + {}, + graph.dtype_of(output_image), + utils::kBuffer, + utils::kWidthPacked); + + ValueRef packed_bias = dummy_bias.vref; + if (!graph.val_is_none(bias_data)) { + packed_bias = + prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked); + } + + std::vector input_im2col_sizes = calculate_input_im2col_sizes( + &graph, input_image, output_image, kernel_size, groups); // Use weight only quantized conv2d if at least one is true: // 1. Device does not support int8 dot product @@ -805,10 +1398,244 @@ void conv2d_q8csw(ComputeGraph& graph, const std::vector& args) { output_image); } +// Implementation for statically quantized conv2d, which expects input, weight, +// and output tensors to all have packed int8 dtype/memory layout. 
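+// As a rough sketch of the intended dataflow (mirroring the test operator
+// registered at the bottom of this file), the packed int8 tensors are
+// produced and consumed by the companion quantize/dequantize ops:
+//   quantize_q8ta_for_conv2d:    fp input -> packed int8 input
+//   conv2d_q8ta_q8csw_q8to:      packed int8 input -> packed int8 output
+//   dequantize_q8to_from_conv2d: packed int8 output -> fp output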
+void static_quantized_conv2d_impl( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const QuantizationConfig& weight_quant_config, + const QuantizationConfig& output_quant_config, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef weight_data, + const ValueRef weight_sums_data, + const ValueRef weight_scales_data, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef bias_data, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const ValueRef packed_int8_output) { + // Currently, only certain quantization configs are supported + VK_CHECK_COND(input_quant_config.granularity == kPerTensor); + VK_CHECK_COND(input_quant_config.nbits == 8); + + VK_CHECK_COND(weight_quant_config.granularity == kPerChannel); + VK_CHECK_COND(weight_quant_config.nbits == 8); + VK_CHECK_COND(weight_quant_config.is_symmetric); + + VK_CHECK_COND(output_quant_config.granularity == kPerTensor); + VK_CHECK_COND(output_quant_config.nbits == 8); + + // Check for depthwise conv + const int64_t groups_val = graph.extract_scalar(groups); + const int64_t in_channels = graph.size_at(-3, packed_int8_input); + + // Depthwise convs have a specialized implementation, since the regular conv + // implementations requires that the number of input and output channels per + // groups is a multiple of 4. This is so that all values that are part of the + // same 4Wx4C block have the same group index. + const bool is_depthwise = (groups_val == in_channels); + + const bool use_im2col = should_use_im2col(&graph, kernel_size, groups); + // For pointwise convolution with stride = 1, padding = 0, dilation = 1, the + // input tensor is already equivalent to its im2col representation. In this + // case we can skip the im2col procedure and pass in the input image to the + // convolution_as_matmul implementation directly. 
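+  // For example, a 1x1 kernel with stride 1, padding 0, and dilation 1 reads
+  // only input position (h, w) to produce output position (h, w), so the
+  // packed int8 input already has the layout expected by the im2col-based
+  // matmul shader.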
+ const bool is_optimizable_pw = + is_s1p0d1_pointwise(&graph, kernel_size, stride, padding, dilation); + + ValueRef packed_weight; + if (is_depthwise) { + packed_weight = prepack_quantized_conv2d_dw_weight( + graph, weight_quant_config, weight_data, kernel_size); + } else if (use_im2col) { + packed_weight = prepack_quantized_linear_weight( + graph, weight_quant_config, weight_data); + } else { + packed_weight = prepack_quantized_conv2d_weight( + graph, + weight_quant_config, + weight_data, + packed_int8_input, + packed_int8_output, + groups, + kernel_size); + } + + ValueRef packed_weight_sums = prepack_standard( + graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked); + + ValueRef packed_weight_scales = prepack_standard( + graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + + // See quantized_conv2d_impl for why this is needed + TmpTensor dummy_bias( + &graph, + {}, + graph.dtype_of(weight_scales_data), + utils::kBuffer, + utils::kWidthPacked); + + ValueRef packed_bias = dummy_bias.vref; + if (graph.val_is_not_none(bias_data)) { + packed_bias = + prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked); + } + + // Depthwise conv path + if (is_depthwise) { + add_conv2d_dw_q8ta_q8csw_q8to_node( + graph, + packed_int8_input, + input_scale, + input_zp, + packed_weight, + packed_weight_sums, + packed_weight_scales, + output_scale, + output_zp, + bias_data, + packed_bias, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output); + return; + } + + std::vector input_im2col_sizes = + calculate_packed_int8_input_im2col_sizes( + &graph, packed_int8_input, packed_int8_output, kernel_size, groups); + + ValueRef packed_int8_input_im2col = packed_int8_input; + if (use_im2col && !is_optimizable_pw) { + TmpTensor packed_int8_input_im2col_tensor( + &graph, + input_im2col_sizes, + vkapi::kInt8x4, + utils::kBuffer, + utils::kPackedInt8_4W4C); + + packed_int8_input_im2col = packed_int8_input_im2col_tensor.vref; + + add_input_im2col_packed_int8_node( + graph, + packed_int8_input, + input_scale, + input_zp, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output, + packed_int8_input_im2col); + } + + add_conv2d_q8ta_q8csw_q8to_node( + graph, + packed_int8_input, + packed_int8_input_im2col, + input_scale, + input_zp, + packed_weight, + packed_weight_sums, + packed_weight_scales, + output_scale, + output_zp, + bias_data, + packed_bias, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output); +} + +void conv2d_q8ta_q8csw_q8to( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef packed_int8_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef kernel_size = args.at(idx++); + const ValueRef stride = args.at(idx++); + const ValueRef padding = args.at(idx++); + const ValueRef dilation = args.at(idx++); + const ValueRef groups = args.at(idx++); + const ValueRef packed_int8_output = args.at(idx++); + + QuantizationConfig input_quant_config(8, kPerTensor, {}); + QuantizationConfig weight_quant_config(8, kPerChannel, {}); + QuantizationConfig output_quant_config(8, kPerTensor, {}); + + static_quantized_conv2d_impl( + 
graph, + input_quant_config, + weight_quant_config, + output_quant_config, + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output); +} + // // Quantize and dequantize operators // +void quantize_q8ta_for_conv2d( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef scale = args.at(idx++); + const ValueRef zero_point = args.at(idx++); + const ValueRef packed_int8_input = args.at(idx++); + + add_quantize_and_pack_q8ta_conv2d_input_node( + graph, fp_input, scale, zero_point, packed_int8_input); +} + +void dequantize_q8to_from_conv2d( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef packed_int8_output = args.at(idx++); + const ValueRef scale = args.at(idx++); + const ValueRef zero_point = args.at(idx++); + const ValueRef fp_output = args.at(idx++); + + add_unpack_and_dequantize_q8ta_conv2d_output_node( + graph, packed_int8_output, scale, zero_point, fp_output); +} + void qdq8ta_conv2d_input( ComputeGraph& graph, const std::vector& args) { @@ -832,10 +1659,82 @@ void qdq8ta_conv2d_input( graph, packed_int8_input, scale, zero_point, fp_output); } +// +// Test operators +// + +void conv2d_q8ta_q8csw_q8to_test( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef kernel_size = args.at(idx++); + const ValueRef stride = args.at(idx++); + const ValueRef padding = args.at(idx++); + const ValueRef dilation = args.at(idx++); + const ValueRef groups = args.at(idx++); + const ValueRef fp_output = args.at(idx++); + + TmpTensor packed_int8_input( + &graph, + graph.sizes_of(fp_input), + vkapi::kInt8x4, + utils::kBuffer, + utils::kPackedInt8_4W4C); + + TmpTensor packed_int8_output( + &graph, + graph.sizes_of(fp_output), + vkapi::kInt8x4, + utils::kBuffer, + utils::kPackedInt8_4W4C); + + add_quantize_and_pack_q8ta_conv2d_input_node( + graph, fp_input, input_scale, input_zp, packed_int8_input); + + std::vector conv2d_args = { + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output}; + + conv2d_q8ta_q8csw_q8to(graph, conv2d_args); + + add_unpack_and_dequantize_q8ta_conv2d_output_node( + graph, packed_int8_output, output_scale, output_zp, fp_output); +} + REGISTER_OPERATORS { VK_REGISTER_OP(et_vk.conv2d_q8ta_q8csw.default, conv2d_q8ta_q8csw); VK_REGISTER_OP(et_vk.conv2d_q8csw.default, conv2d_q8csw); VK_REGISTER_OP(etvk.qdq8ta_conv2d_input.default, qdq8ta_conv2d_input); + VK_REGISTER_OP(etvk.conv2d_q8ta_q8csw_q8to.test, conv2d_q8ta_q8csw_q8to_test); + VK_REGISTER_OP( + et_vk.quantize_q8ta_for_conv2d.default, quantize_q8ta_for_conv2d); + VK_REGISTER_OP( + et_vk.dequantize_q8to_from_conv2d.default, dequantize_q8to_from_conv2d); + VK_REGISTER_OP(et_vk.conv2d_q8ta_q8csw_q8to.default, conv2d_q8ta_q8csw_q8to); + 
VK_REGISTER_OP( + et_vk.conv2d_q8ta_q8csw_q8to_dw.default, conv2d_q8ta_q8csw_q8to); } } // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/CMakeLists.txt b/backends/vulkan/test/custom_ops/CMakeLists.txt index fe36de3047e..348eeded962 100644 --- a/backends/vulkan/test/custom_ops/CMakeLists.txt +++ b/backends/vulkan/test/custom_ops/CMakeLists.txt @@ -96,4 +96,6 @@ if(TARGET vulkan_backend) add_operator_prototype(q4gsw_linear) add_operator_prototype(choose_qparams_per_row) add_operator_prototype(qdq8ta_conv2d_activations) + add_operator_prototype(q8ta_q8csw_q8to_conv2d) + add_operator_prototype(q8ta_q8csw_q8to_conv2d_dw) endif() diff --git a/backends/vulkan/test/custom_ops/conv2d_utils.cpp b/backends/vulkan/test/custom_ops/conv2d_utils.cpp new file mode 100644 index 00000000000..74c26cef5a1 --- /dev/null +++ b/backends/vulkan/test/custom_ops/conv2d_utils.cpp @@ -0,0 +1,10 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "conv2d_utils.h" + +// Implementation file for conv2d utilities. +// Currently all functionality is implemented inline in the header. diff --git a/backends/vulkan/test/custom_ops/conv2d_utils.h b/backends/vulkan/test/custom_ops/conv2d_utils.h new file mode 100644 index 00000000000..cad52219062 --- /dev/null +++ b/backends/vulkan/test/custom_ops/conv2d_utils.h @@ -0,0 +1,88 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include + +namespace executorch { +namespace vulkan { +namespace prototyping { + +// Component structs for better readability +struct KernelSize { + int32_t h; + int32_t w; + + KernelSize(int32_t height, int32_t width) : h(height), w(width) {} +}; + +struct Stride { + int32_t h; + int32_t w; + + Stride(int32_t height, int32_t width) : h(height), w(width) {} +}; + +struct Padding { + int32_t h; + int32_t w; + + Padding(int32_t height, int32_t width) : h(height), w(width) {} +}; + +struct Dilation { + int32_t h; + int32_t w; + + Dilation(int32_t height = 1, int32_t width = 1) : h(height), w(width) {} +}; + +struct OutInChannels { + int32_t out; + int32_t in; + + OutInChannels(int32_t out_channels, int32_t in_channels) + : out(out_channels), in(in_channels) {} +}; + +struct InputSize2D { + int32_t h; + int32_t w; + + InputSize2D(int32_t height, int32_t width) : h(height), w(width) {} +}; + +// Conv2d configuration struct +struct Conv2dConfig { + OutInChannels channels; + InputSize2D input_size; + KernelSize kernel; + Stride stride; + Padding padding; + Dilation dilation; + int32_t groups; // Number of groups for grouped convolution + std::string test_case_name = "placeholder"; + std::string op_name = "conv2d"; + + // Calculate output dimensions + int64_t get_output_height() const { + return (input_size.h + 2 * padding.h - dilation.h * (kernel.h - 1) - 1) / + stride.h + + 1; + } + + int64_t get_output_width() const { + return (input_size.w + 2 * padding.w - dilation.w * (kernel.w - 1) - 1) / + stride.w + + 1; + } +}; + +} // namespace prototyping +} // namespace vulkan +} // namespace executorch diff --git a/backends/vulkan/test/custom_ops/q8csw_conv2d.cpp b/backends/vulkan/test/custom_ops/q8csw_conv2d.cpp index d566e5b2646..219bccb04c3 100644 --- 
a/backends/vulkan/test/custom_ops/q8csw_conv2d.cpp +++ b/backends/vulkan/test/custom_ops/q8csw_conv2d.cpp @@ -8,6 +8,7 @@ #include #include #include +#include "conv2d_utils.h" #include "utils.h" #include @@ -18,76 +19,6 @@ using namespace vkcompute; static constexpr int64_t kRefDimSizeLimit = 100; -// Component structs for better readability -struct KernelSize { - int32_t h; - int32_t w; - - KernelSize(int32_t height, int32_t width) : h(height), w(width) {} -}; - -struct Stride { - int32_t h; - int32_t w; - - Stride(int32_t height, int32_t width) : h(height), w(width) {} -}; - -struct Padding { - int32_t h; - int32_t w; - - Padding(int32_t height, int32_t width) : h(height), w(width) {} -}; - -struct Dilation { - int32_t h; - int32_t w; - - Dilation(int32_t height = 1, int32_t width = 1) : h(height), w(width) {} -}; - -struct OutInChannels { - int32_t out; - int32_t in; - - OutInChannels(int32_t out_channels, int32_t in_channels) - : out(out_channels), in(in_channels) {} -}; - -struct InputSize2D { - int32_t h; - int32_t w; - - InputSize2D(int32_t height, int32_t width) : h(height), w(width) {} -}; - -// Conv2d configuration struct -struct Conv2dConfig { - OutInChannels channels; - InputSize2D input_size; - KernelSize kernel; - Stride stride; - Padding padding; - Dilation dilation; - int32_t groups; // Number of groups for grouped convolution - std::string test_case_name = "placeholder"; - std::string op_name = "conv2d_q8ta_q8csw"; - - // Calculate output dimensions - int64_t get_output_height() const { - return (input_size.h + 2 * padding.h - dilation.h * (kernel.h - 1) - 1) / - stride.h + - 1; - } - - int64_t get_output_width() const { - return (input_size.w + 2 * padding.w - dilation.w * (kernel.w - 1) - 1) / - stride.w + - 1; - } -}; - // Utility function to create a test case from a Conv2dConfig TestCase create_test_case_from_config( const Conv2dConfig& config, @@ -366,13 +297,20 @@ std::vector generate_quantized_conv2d_test_cases() { Stride(1, 1), Padding(1, 1), Dilation(1, 1), - 8}, + 1}, {OutInChannels(128, 64), InputSize2D(128, 128), KernelSize(3, 3), Stride(1, 1), Padding(1, 1), Dilation(1, 1), + 1}, + {OutInChannels(128, 1024), + InputSize2D(128, 128), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), 1}}; // Test with different storage types and data types @@ -394,6 +332,7 @@ std::vector generate_quantized_conv2d_test_cases() { std::to_string(config.kernel.h) + "/" + std::to_string(config.kernel.w); + config.op_name = "conv2d_q8ta_q8csw"; config.test_case_name = prefix + suffix; // The default operator tested is activation + weight quantized conv2d; // however, only test this if the int8 dot product extension is supported @@ -763,7 +702,7 @@ int64_t quantized_conv2d_flop_calculator(const TestCase& test_case) { int main(int argc, char* argv[]) { set_debugging(false); set_print_output(false); - set_print_latencies(false); + set_print_latencies(true); set_use_gpu_timestamps(true); print_performance_header(); diff --git a/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp b/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp new file mode 100644 index 00000000000..8762fe4c0d1 --- /dev/null +++ b/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp @@ -0,0 +1,628 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
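+
+// Prototype benchmark/correctness test for the statically quantized conv2d
+// operator (etvk.conv2d_q8ta_q8csw_q8to.test). Test cases are built from
+// Conv2dConfig and, for small enough shapes, validated against a float
+// reference implementation.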
+ +#include +#include +#include +#include +#include "conv2d_utils.h" +#include "utils.h" + +#include + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 100; + +// Utility function to create a test case from a Conv2dConfig +TestCase create_test_case_from_config( + const Conv2dConfig& config, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype) { + TestCase test_case; + + // Create a descriptive name for the test case + std::string storage_str = + (storage_type == utils::kTexture3D) ? "Texture3D" : "Buffer"; + std::string dtype_str = (input_dtype == vkapi::kFloat) ? "Float" : "Half"; + + std::string test_name = + config.test_case_name + "_" + storage_str + "_" + dtype_str; + test_case.set_name(test_name); + + // Set the operator name for the test case + std::string operator_name = "etvk." + config.op_name + ".test"; + test_case.set_operator_name(operator_name); + + // Calculate output dimensions + int64_t H_out = config.get_output_height(); + int64_t W_out = config.get_output_width(); + + // Input tensor (float/half) - [1, C_in, H_in, W_in] (batch size always 1) + std::vector input_size = { + 1, config.channels.in, config.input_size.h, config.input_size.w}; + + ValueSpec input_tensor( + input_size, + input_dtype, + storage_type, + utils::kChannelsPacked, + DataGenType::RANDOM); + + if (debugging()) { + print_valuespec_data(input_tensor, "input_tensor"); + } + + float input_scale_val = 0.008123; + ValueSpec input_scale(input_scale_val); + + int32_t input_zero_point_val = 2; + ValueSpec input_zero_point(input_zero_point_val); + + // Quantized weight tensor (int8) - [C_out, C_in_per_group * K_h * K_w] + // Memory layout: height, width, then channels - in_c is innermost (stride 1) + // in the second dimension + const int64_t in_channels_per_group = config.channels.in / config.groups; + const int64_t in_features = utils::align_up_4( + in_channels_per_group * config.kernel.h * config.kernel.w); + std::vector weight_size = {config.channels.out, in_features}; + ValueSpec quantized_weight( + weight_size, + vkapi::kChar, // int8 for quantized weights + storage_type, + utils::kWidthPacked, + DataGenType::RANDINT8); + quantized_weight.set_constant(true); + + if (debugging()) { + print_valuespec_data(quantized_weight, "weight_tensor"); + } + + const int64_t aligned_out_channels = utils::align_up_4(config.channels.out); + + // Weight quantization scales (float/half, per-channel) + ValueSpec weight_scales( + {aligned_out_channels}, // Per output channel + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM_SCALES); + weight_scales.set_constant(true); + + ValueSpec weight_sums( + {aligned_out_channels}, // Per output channel + vkapi::kInt, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + weight_sums.set_constant(true); + + // Compute weight_sums data based on quantized weights + compute_weight_sums( + weight_sums, quantized_weight, config.channels.out, in_features); + + // Bias (optional, float/half) - [C_out] + ValueSpec bias( + {aligned_out_channels}, // Per output channel + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + bias.set_constant(true); + + // Output quantization parameters + // float output_scale_val = 0.01432; + float output_scale_val = 0.05314; + ValueSpec output_scale(output_scale_val); + + int32_t output_zero_point_val = -1; + ValueSpec output_zero_point(output_zero_point_val); + + // Stride and padding parameters + ValueSpec 
stride({config.stride.h, config.stride.w}); + ValueSpec padding({config.padding.h, config.padding.w}); + + // Dilation and groups parameters + ValueSpec dilation({config.dilation.h, config.dilation.w}); + ValueSpec groups(config.groups); + + // Kernel size parameters + ValueSpec kernel_size({config.kernel.h, config.kernel.w}); + + // Output tensor (float/half) - [1, C_out, H_out, W_out] (batch size always 1) + ValueSpec output( + {1, config.channels.out, H_out, W_out}, + input_dtype, + storage_type, + utils::kChannelsPacked, + DataGenType::ZEROS); + + // Add all specs to test case for q8ta_q8csw_q8to operation + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zero_point); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_sums); + test_case.add_input_spec(weight_scales); + test_case.add_input_spec(output_scale); + test_case.add_input_spec(output_zero_point); + test_case.add_input_spec(bias); + test_case.add_input_spec(kernel_size); + test_case.add_input_spec(stride); + test_case.add_input_spec(padding); + test_case.add_input_spec(dilation); + test_case.add_input_spec(groups); + + test_case.add_output_spec(output); + + test_case.set_abs_tolerance(output_scale_val + 1e-4f); + + return test_case; +} + +// Generate easy test cases for quantized conv2d operation (for debugging) +std::vector generate_quantized_conv2d_easy_cases() { + std::vector test_cases; + + // Single simple configuration for debugging + Conv2dConfig config = { + OutInChannels(16, 8), // channels (out, in) + InputSize2D(21, 17), // input_size (h, w) + KernelSize(3, 3), // kernel + Stride(1, 1), // stride + Padding(1, 1), // padding + Dilation(1, 1), // dilation + 2, // groups + }; + config.op_name = "conv2d_q8ta_q8csw_q8to"; + + // Test with both storage types and data types for completeness + std::vector storage_types = {utils::kTexture3D}; + std::vector float_types = {vkapi::kFloat}; + + // Generate test cases for each combination + for (const auto& storage_type : storage_types) { + for (const auto& input_dtype : float_types) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, input_dtype)); + } + } + + return test_cases; +} + +// Generate test cases for quantized conv2d operation +std::vector generate_quantized_conv2d_test_cases() { + std::vector test_cases; + + std::vector configs = { + // Pointwise convolutions: kernel size 1x1 + {OutInChannels(32, 3), + InputSize2D(64, 64), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + {OutInChannels(64, 32), + InputSize2D(32, 32), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + {OutInChannels(96, 64), + InputSize2D(16, 16), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + {OutInChannels(13, 7), + InputSize2D(57, 33), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + // General 2D convolutions + {OutInChannels(32, 3), + InputSize2D(64, 64), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(32, 3), + InputSize2D(64, 64), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(64, 32), + InputSize2D(8, 8), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(64, 32), + InputSize2D(64, 64), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(64, 32), + InputSize2D(64, 64), + 
KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(16, 32), + InputSize2D(77, 77), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + // Grouped convolutions + {OutInChannels(64, 32), + InputSize2D(64, 64), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 2}, + {OutInChannels(96, 96), + InputSize2D(81, 81), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 3}, + {OutInChannels(96, 96), + InputSize2D(64, 64), + KernelSize(5, 5), + Stride(2, 2), + Padding(2, 2), + Dilation(1, 1), + 4}, + // Performance cases (pointwise) + {OutInChannels(128, 128), + InputSize2D(128, 128), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + {OutInChannels(128, 128), + InputSize2D(128, 128), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + // Performance cases (general 2d convs) + {OutInChannels(32, 3), + InputSize2D(256, 256), + KernelSize(3, 3), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + {OutInChannels(64, 32), + InputSize2D(128, 128), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(64, 64), + InputSize2D(128, 128), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(128, 128), + InputSize2D(128, 128), + KernelSize(5, 5), + Stride(2, 2), + Padding(2, 2), + Dilation(1, 1), + 4}}; + + // Test with different storage types and data types + std::vector storage_types = {utils::kTexture3D}; + + // Generate test cases for each combination + for (auto& config : configs) { + for (const auto& storage_type : storage_types) { + // Generate test case name programmatically + bool is_performance = config.channels.out > kRefDimSizeLimit || + config.channels.in > kRefDimSizeLimit || + config.input_size.h > kRefDimSizeLimit || + config.input_size.w > kRefDimSizeLimit; + std::string prefix = is_performance ? 
"performance_" : "correctness_"; + std::string suffix = std::to_string(config.channels.out) + "/" + + std::to_string(config.channels.in) + "_" + + std::to_string(config.input_size.h) + "/" + + std::to_string(config.input_size.w) + "_" + + std::to_string(config.kernel.h) + "/" + + std::to_string(config.kernel.w); + + config.op_name = "conv2d_q8ta_q8csw_q8to"; + config.test_case_name = prefix + suffix; + + // Only test q8ta_q8csw_q8to if the int8 dot product extension is + // supported + if (vkcompute::api::context() + ->adapter_ptr() + ->supports_int8_dot_product()) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, vkapi::kFloat)); + } + } + } + + return test_cases; +} + +// Reference implementation for activation, weight, and output quantized conv2d +void conv2d_q8ta_q8csw_q8to_reference_impl(TestCase& test_case) { + // Extract input specifications + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_sums_spec = test_case.inputs()[idx++]; + (void)weight_sums_spec; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& output_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& output_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + const ValueSpec& kernel_size_spec = test_case.inputs()[idx++]; + const ValueSpec& stride_spec = test_case.inputs()[idx++]; + const ValueSpec& padding_spec = test_case.inputs()[idx++]; + const ValueSpec& dilation_spec = test_case.inputs()[idx++]; + const ValueSpec& groups_spec = test_case.inputs()[idx++]; + + // Extract output specification (mutable reference) + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions + auto input_sizes = input_spec.get_tensor_sizes(); // [N, C_in, H_in, W_in] + auto weight_sizes = + weight_spec.get_tensor_sizes(); // [C_out, C_in_per_group * K_h * K_w] + auto output_sizes = + output_spec.get_tensor_sizes(); // [N, C_out, H_out, W_out] + + int64_t N = input_sizes[0]; + int64_t C_in = input_sizes[1]; + int64_t H_in = input_sizes[2]; + int64_t W_in = input_sizes[3]; + int64_t C_out = output_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + // Get kernel dimensions from kernel_size ValueSpec + auto kernel_size_data = kernel_size_spec.get_int32_data(); + int64_t K_h = kernel_size_data[0]; + int64_t K_w = kernel_size_data[1]; + + // Get stride, padding, dilation, and groups + auto stride_data = stride_spec.get_int32_data(); + auto padding_data = padding_spec.get_int32_data(); + auto dilation_data = dilation_spec.get_int32_data(); + int64_t stride_h = stride_data[0]; + int64_t stride_w = stride_data[1]; + int64_t pad_h = padding_data[0]; + int64_t pad_w = padding_data[1]; + int64_t dilation_h = dilation_data[0]; + int64_t dilation_w = dilation_data[1]; + int64_t groups = groups_spec.get_int_value(); + + // Skip for large tensors since computation time will be extremely slow + if (N > kRefDimSizeLimit || C_in > kRefDimSizeLimit || + H_in > kRefDimSizeLimit || W_in > kRefDimSizeLimit || + C_out > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } 
+ + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + const float input_scale = input_scale_spec.get_float_value(); + const int32_t input_zero_point = input_zeros_spec.get_int_value(); + + auto& weight_data = weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + const float output_scale = output_scale_spec.get_float_value(); + const int32_t output_zero_point = output_zeros_spec.get_int_value(); + + // Calculate channels per group for grouped convolution + int64_t C_in_per_group = C_in / groups; + int64_t C_out_per_group = C_out / groups; + + // Calculate number of output elements + int64_t num_output_elements = N * C_out * H_out * W_out; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + const int in_features = utils::align_up_4(C_in_per_group * K_h * K_w); + + // Perform activation, weight, and output quantized conv2d operation + for (int64_t n = 0; n < N; ++n) { + for (int64_t out_c = 0; out_c < C_out; ++out_c) { + for (int64_t out_h = 0; out_h < H_out; ++out_h) { + for (int64_t out_w = 0; out_w < W_out; ++out_w) { + int32_t int_sum = 0; + int32_t weight_sum = 0; // Track weight sum on the fly + + // Determine which group this output channel belongs to + int64_t group_idx = out_c / C_out_per_group; + int64_t in_c_start = group_idx * C_in_per_group; + int64_t in_c_end = (group_idx + 1) * C_in_per_group; + + // Convolution operation with integer accumulation + for (int64_t in_c = in_c_start; in_c < in_c_end; ++in_c) { + for (int64_t kh = 0; kh < K_h; ++kh) { + for (int64_t kw = 0; kw < K_w; ++kw) { + // Calculate input position with dilation + int64_t in_h = out_h * stride_h - pad_h + kh * dilation_h; + int64_t in_w = out_w * stride_w - pad_w + kw * dilation_w; + + // Check bounds (zero padding) + if (in_h >= 0 && in_h < H_in && in_w >= 0 && in_w < W_in) { + // Get input value and quantize to int8 + int64_t input_idx = n * (C_in * H_in * W_in) + + in_c * (H_in * W_in) + in_h * W_in + in_w; + + float quant_input_f = + std::round(input_data[input_idx] / input_scale) + + input_zero_point; + quant_input_f = + std::min(std::max(quant_input_f, -128.0f), 127.0f); + int8_t quantized_input = static_cast<int8_t>(quant_input_f); + + // Get quantized weight (already int8) + // Weight layout: [C_out, C_in_per_group * K_h * K_w] + int64_t weight_idx = out_c * in_features + + (kh * (K_w * C_in_per_group) + kw * C_in_per_group + + (in_c % C_in_per_group)); + int8_t quantized_weight = weight_data[weight_idx]; + + // Integer multiplication and accumulation + int_sum += static_cast<int32_t>(quantized_input) * + static_cast<int32_t>(quantized_weight); + + // Track weight sum for this output channel on the fly + weight_sum += static_cast<int32_t>(quantized_weight); + } else { + // For zero padding, we still need to account for the weight + // in weight_sum when input is effectively 0 (but quantized 0 + // is input_zero_point) + int64_t weight_idx = out_c * in_features + + (kh * (K_w * C_in_per_group) + kw * C_in_per_group + + (in_c % C_in_per_group)); + int8_t quantized_weight = weight_data[weight_idx]; + + // Add contribution from zero-padded input (quantized zero = + // input_zero_point) + int_sum += static_cast<int32_t>(input_zero_point) * + static_cast<int32_t>(quantized_weight); + + // Track weight sum for this output channel on the fly + weight_sum += static_cast<int32_t>(quantized_weight); + } + } + } + } + + // Convert accumulated integer result to float and apply scales + // Final result =
(int_sum - zero_point_correction) * input_scale * + // weight_scale + bias zero_point_correction = input_zero_point * + // sum_of_weights_for_this_output_channel + int32_t zero_point_correction = input_zero_point * weight_sum; + int32_t accum_adjusted = int_sum - zero_point_correction; + float float_result = + accum_adjusted * input_scale * weight_scales_data[out_c]; + + // Add bias and store result + float_result += bias_data[out_c]; + + // Quantize the output to int8 + float quant_output_f = + std::round(float_result / output_scale) + output_zero_point; + quant_output_f = std::min(std::max(quant_output_f, -128.0f), 127.0f); + int8_t quantized_output = static_cast(quant_output_f); + + // Dequantize back to float + float dequant_output = + (static_cast(quantized_output) - output_zero_point) * + output_scale; + + int64_t output_idx = n * (C_out * H_out * W_out) + + out_c * (H_out * W_out) + out_h * W_out + out_w; + ref_data[output_idx] = dequant_output; + } + } + } + } +} + +void reference_impl(TestCase& test_case) { + conv2d_q8ta_q8csw_q8to_reference_impl(test_case); +} + +// Custom FLOP calculator for quantized conv2d operation +int64_t quantized_conv2d_flop_calculator(const TestCase& test_case) { + int kernel_idx = 9; // kernel_size is at index 9 for q8ta_q8csw_q8to + + // Get input and weight dimensions + const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); + const auto& output_sizes = test_case.outputs()[0].get_tensor_sizes(); + + const auto& kernel_sizes = test_case.inputs()[kernel_idx].get_int32_data(); + + int64_t N = input_sizes[0]; + int64_t C_in = input_sizes[1]; + int64_t C_out = output_sizes[1]; + int64_t K_h = kernel_sizes[0]; + int64_t K_w = kernel_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + // Calculate FLOPs for quantized conv2d operation + // Each output element requires: + // - C_in * K_h * K_w multiply-accumulate operations + // - Additional operations for quantization/dequantization + int64_t output_elements = N * C_out * H_out * W_out; + int64_t ops_per_output = C_in * K_h * K_w; + + int64_t flop = output_elements * (ops_per_output); + + return flop; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout + << "Quantized Conv2d Operation with Output Quantization Prototyping Framework" + << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = reference_impl; + + // Execute test cases using the new framework with custom FLOP calculator + auto results = execute_test_cases( + generate_quantized_conv2d_test_cases, + quantized_conv2d_flop_calculator, + "QuantizedConv2dQ8ToQ8To", + 0, + 10, + ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d_dw.cpp b/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d_dw.cpp new file mode 100644 index 00000000000..c259b45de06 --- /dev/null +++ b/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d_dw.cpp @@ -0,0 +1,592 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
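The reference implementation above, and the depthwise variant in the file that begins here, both fold the activation zero point out of the integer accumulator through weight_sum instead of subtracting it per tap. A minimal standalone check of that algebra, with made-up values for illustration:

#include <cassert>
#include <cstdint>

int main() {
  // Toy example: 3 taps of int8 activations/weights, activation zero point 2.
  const int32_t zp_x = 2;
  const int8_t q_x[3] = {5, -3, 7};
  const int8_t q_w[3] = {-2, 4, 1};

  // What the float math "means": sum_i (q_x[i] - zp_x) * q_w[i].
  int32_t exact = 0;
  // What the reference loops accumulate: raw dot product plus a weight sum.
  int32_t int_sum = 0;
  int32_t weight_sum = 0;
  for (int i = 0; i < 3; ++i) {
    exact += (static_cast<int32_t>(q_x[i]) - zp_x) * q_w[i];
    int_sum += static_cast<int32_t>(q_x[i]) * q_w[i];
    weight_sum += static_cast<int32_t>(q_w[i]);
  }

  // Subtracting zp_x * weight_sum once reproduces the per-tap correction.
  assert(exact == int_sum - zp_x * weight_sum);
  return 0;
}

This identity is presumably also why the harness passes precomputed per-output-channel weight_sums to the operator as a constant input, even though the reference code recomputes the sums on the fly.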
+ +#include +#include +#include +#include +#include "conv2d_utils.h" +#include "utils.h" + +#include + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 100; + +// Utility function to create a test case from a Conv2dConfig for depthwise +// convolution +TestCase create_test_case_from_config( + const Conv2dConfig& config, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype) { + TestCase test_case; + + // Create a descriptive name for the test case + std::string storage_str = + (storage_type == utils::kTexture3D) ? "Texture3D" : "Buffer"; + std::string dtype_str = (input_dtype == vkapi::kFloat) ? "Float" : "Half"; + + std::string test_name = + config.test_case_name + "_" + storage_str + "_" + dtype_str; + test_case.set_name(test_name); + + // Set the operator name for the test case + std::string operator_name = "etvk." + config.op_name + ".test"; + test_case.set_operator_name(operator_name); + + // Calculate output dimensions + int64_t H_out = config.get_output_height(); + int64_t W_out = config.get_output_width(); + + // Input tensor (float/half) - [1, C_in, H_in, W_in] (batch size always 1) + std::vector input_size = { + 1, config.channels.in, config.input_size.h, config.input_size.w}; + + ValueSpec input_tensor( + input_size, + input_dtype, + storage_type, + utils::kChannelsPacked, + DataGenType::RANDOM); + + if (debugging()) { + print_valuespec_data(input_tensor, "input_tensor", false, 64); + } + + float input_scale_val = 0.008123; + ValueSpec input_scale(input_scale_val); + + int32_t input_zero_point_val = 2; + ValueSpec input_zero_point(input_zero_point_val); + + // Quantized weight tensor (int8) for depthwise convolution + // Memory layout: [K_h, K_w, OC] + // For depthwise conv: groups = channels.out, in_channels_per_group = 1 + std::vector weight_size = { + config.kernel.h, config.kernel.w, config.channels.out}; + ValueSpec quantized_weight( + weight_size, + vkapi::kChar, // int8 for quantized weights + storage_type, + utils::kWidthPacked, + DataGenType::RANDINT8); + quantized_weight.set_constant(true); + + if (debugging()) { + print_valuespec_data(quantized_weight, "weight_tensor", false, 64); + } + + // Weight quantization scales (float/half, per-channel) + ValueSpec weight_scales( + {config.channels.out}, // Per output channel + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM_SCALES); + weight_scales.set_constant(true); + + ValueSpec weight_sums( + {config.channels.out}, // Per output channel + vkapi::kInt, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + weight_sums.set_constant(true); + + // Compute weight_sums data based on quantized weights for depthwise layout + // For depthwise conv: each output channel has K_h * K_w weights + // Custom computation for depthwise layout [K_h, K_w, OC] + auto& weight_sums_data = weight_sums.get_int32_data(); + auto& quantized_weight_data = quantized_weight.get_int8_data(); + + weight_sums_data.resize(config.channels.out); + + for (int64_t out_c = 0; out_c < config.channels.out; ++out_c) { + int32_t sum = 0; + for (int64_t kh = 0; kh < config.kernel.h; ++kh) { + for (int64_t kw = 0; kw < config.kernel.w; ++kw) { + // Weight indexing for depthwise layout [K_h, K_w, OC] + int64_t weight_idx = kh * (config.kernel.w * config.channels.out) + + kw * config.channels.out + out_c; + sum += static_cast(quantized_weight_data[weight_idx]); + } + } + weight_sums_data[out_c] = sum; + } + + // Bias (optional, float/half) - 
[C_out] + ValueSpec bias( + {config.channels.out}, // Per output channel + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM); + bias.set_constant(true); + + // Output quantization parameters + float output_scale_val = 0.05314; + ValueSpec output_scale(output_scale_val); + + int32_t output_zero_point_val = -1; + ValueSpec output_zero_point(output_zero_point_val); + + // Stride and padding parameters + ValueSpec stride({config.stride.h, config.stride.w}); + ValueSpec padding({config.padding.h, config.padding.w}); + + // Dilation and groups parameters + ValueSpec dilation({config.dilation.h, config.dilation.w}); + ValueSpec groups(config.groups); + + // Kernel size parameters + ValueSpec kernel_size({config.kernel.h, config.kernel.w}); + + // Output tensor (float/half) - [1, C_out, H_out, W_out] (batch size always 1) + ValueSpec output( + {1, config.channels.out, H_out, W_out}, + input_dtype, + storage_type, + utils::kChannelsPacked, + DataGenType::ZEROS); + + // Add all specs to test case for q8ta_q8csw_q8to operation + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zero_point); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_sums); + test_case.add_input_spec(weight_scales); + test_case.add_input_spec(output_scale); + test_case.add_input_spec(output_zero_point); + test_case.add_input_spec(bias); + test_case.add_input_spec(kernel_size); + test_case.add_input_spec(stride); + test_case.add_input_spec(padding); + test_case.add_input_spec(dilation); + test_case.add_input_spec(groups); + + test_case.add_output_spec(output); + + test_case.set_abs_tolerance(output_scale_val + 1e-4f); + + return test_case; +} + +// Generate easy test cases for quantized depthwise conv2d operation (for +// debugging) +std::vector generate_quantized_conv2d_dw_easy_cases() { + std::vector test_cases; + + // Single simple configuration for debugging - depthwise convolution + Conv2dConfig config = { + OutInChannels(8, 8), // channels (out, in) - equal for depthwise + InputSize2D(8, 8), // input_size (h, w) + KernelSize(3, 3), // kernel + Stride(2, 2), // stride + Padding(1, 1), // padding + Dilation(1, 1), // dilation + 8, // groups = channels.out for depthwise + }; + config.op_name = "conv2d_q8ta_q8csw_q8to"; + + // Test with both storage types and data types for completeness + std::vector storage_types = {utils::kTexture3D}; + std::vector float_types = {vkapi::kFloat}; + + // Generate test cases for each combination + for (const auto& storage_type : storage_types) { + for (const auto& input_dtype : float_types) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, input_dtype)); + } + } + + return test_cases; +} + +// Generate test cases for quantized depthwise conv2d operation +std::vector generate_quantized_conv2d_dw_test_cases() { + std::vector test_cases; + + std::vector configs = { + // Depthwise convolutions: groups = channels.out, channels.in = + // channels.out + {OutInChannels(32, 32), + InputSize2D(64, 64), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 32}, + {OutInChannels(64, 64), + InputSize2D(32, 32), + KernelSize(3, 3), + Stride(2, 2), + Padding(2, 2), + Dilation(1, 1), + 64}, + {OutInChannels(64, 64), + InputSize2D(32, 32), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 64}, + {OutInChannels(80, 80), + InputSize2D(16, 16), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 80}, + 
{OutInChannels(16, 16), + InputSize2D(57, 33), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 16}, + // Different kernel sizes for depthwise + {OutInChannels(32, 32), + InputSize2D(64, 64), + KernelSize(5, 5), + Stride(1, 1), + Padding(2, 2), + Dilation(1, 1), + 32}, + {OutInChannels(96, 96), + InputSize2D(64, 64), + KernelSize(5, 5), + Stride(2, 2), + Padding(2, 2), + Dilation(1, 1), + 96}, + // Performance cases + {OutInChannels(128, 128), + InputSize2D(128, 128), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 128}, + {OutInChannels(64, 64), + InputSize2D(256, 256), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 64}, + {OutInChannels(288, 288), + InputSize2D(16, 16), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 288}, + {OutInChannels(32, 32), + InputSize2D(128, 128), + KernelSize(3, 3), + Stride(1, 1), + Padding(2, 2), + Dilation(1, 1), + 32}}; + + // Test with different storage types and data types + std::vector storage_types = {utils::kTexture3D}; + + // Generate test cases for each combination + for (auto& config : configs) { + for (const auto& storage_type : storage_types) { + // Generate test case name programmatically + bool is_performance = config.channels.out > kRefDimSizeLimit || + config.channels.in > kRefDimSizeLimit || + config.input_size.h > kRefDimSizeLimit || + config.input_size.w > kRefDimSizeLimit; + std::string prefix = + is_performance ? "performance_dw_" : "correctness_dw_"; + std::string suffix = std::to_string(config.channels.out) + "/" + + std::to_string(config.channels.in) + "_" + + std::to_string(config.input_size.h) + "/" + + std::to_string(config.input_size.w) + "_" + + std::to_string(config.kernel.h) + "/" + + std::to_string(config.kernel.w); + + config.op_name = "conv2d_q8ta_q8csw_q8to"; + config.test_case_name = prefix + suffix; + + // Only test q8ta_q8csw_q8to if the int8 dot product extension is + // supported + if (vkcompute::api::context() + ->adapter_ptr() + ->supports_int8_dot_product()) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, vkapi::kFloat)); + } + } + } + + return test_cases; +} + +// Reference implementation for activation, weight, and output quantized +// depthwise conv2d +void conv2d_q8ta_q8csw_q8to_dw_reference_impl(TestCase& test_case) { + // Extract input specifications + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_sums_spec = test_case.inputs()[idx++]; + (void)weight_sums_spec; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& output_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& output_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + const ValueSpec& kernel_size_spec = test_case.inputs()[idx++]; + const ValueSpec& stride_spec = test_case.inputs()[idx++]; + const ValueSpec& padding_spec = test_case.inputs()[idx++]; + const ValueSpec& dilation_spec = test_case.inputs()[idx++]; + const ValueSpec& groups_spec = test_case.inputs()[idx++]; + + // Extract output specification (mutable reference) + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions + auto input_sizes = input_spec.get_tensor_sizes(); // [N, C_in, H_in, W_in] + auto 
weight_sizes = + weight_spec.get_tensor_sizes(); // [K_h, align_up_4(K_w), OC] + auto output_sizes = + output_spec.get_tensor_sizes(); // [N, C_out, H_out, W_out] + + int64_t N = input_sizes[0]; + int64_t C_in = input_sizes[1]; + int64_t H_in = input_sizes[2]; + int64_t W_in = input_sizes[3]; + int64_t C_out = output_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + // Get kernel dimensions from kernel_size ValueSpec + auto kernel_size_data = kernel_size_spec.get_int32_data(); + int64_t K_h = kernel_size_data[0]; + int64_t K_w = kernel_size_data[1]; + + // Get stride, padding, dilation, and groups + auto stride_data = stride_spec.get_int32_data(); + auto padding_data = padding_spec.get_int32_data(); + auto dilation_data = dilation_spec.get_int32_data(); + int64_t stride_h = stride_data[0]; + int64_t stride_w = stride_data[1]; + int64_t pad_h = padding_data[0]; + int64_t pad_w = padding_data[1]; + int64_t dilation_h = dilation_data[0]; + int64_t dilation_w = dilation_data[1]; + int64_t groups = groups_spec.get_int_value(); + + // Skip for large tensors since computation time will be extremely slow + if (N > kRefDimSizeLimit || C_in > kRefDimSizeLimit || + H_in > kRefDimSizeLimit || W_in > kRefDimSizeLimit || + C_out > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Verify this is a depthwise convolution + if (groups != C_out || C_in != C_out) { + throw std::invalid_argument( + "This is not a depthwise convolution configuration"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + const float input_scale = input_scale_spec.get_float_value(); + const int32_t input_zero_point = input_zeros_spec.get_int_value(); + + auto& weight_data = weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + const float output_scale = output_scale_spec.get_float_value(); + const int32_t output_zero_point = output_zeros_spec.get_int_value(); + + // Calculate number of output elements + int64_t num_output_elements = N * C_out * H_out * W_out; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + // Perform activation, weight, and output quantized depthwise conv2d operation + for (int64_t n = 0; n < N; ++n) { + for (int64_t out_c = 0; out_c < C_out; ++out_c) { + for (int64_t out_h = 0; out_h < H_out; ++out_h) { + for (int64_t out_w = 0; out_w < W_out; ++out_w) { + int32_t int_sum = 0; + int32_t weight_sum = 0; // Track weight sum on the fly + + // For depthwise convolution, each output channel corresponds to one + // input channel + int64_t in_c = out_c; + + // Convolution operation with integer accumulation + for (int64_t kh = 0; kh < K_h; ++kh) { + for (int64_t kw = 0; kw < K_w; ++kw) { + // Calculate input position with dilation + int64_t in_h = out_h * stride_h - pad_h + kh * dilation_h; + int64_t in_w = out_w * stride_w - pad_w + kw * dilation_w; + + // Check bounds (zero padding) + if (in_h >= 0 && in_h < H_in && in_w >= 0 && in_w < W_in) { + // Get input value and quantize to int8 + int64_t input_idx = n * (C_in * H_in * W_in) + + in_c * (H_in * W_in) + in_h * W_in + in_w; + + float quant_input_f = + std::round(input_data[input_idx] / input_scale) + + input_zero_point; + quant_input_f = + 
std::min(std::max(quant_input_f, -128.0f), 127.0f); + int8_t quantized_input = static_cast(quant_input_f); + + // Get quantized weight using depthwise layout [K_h, K_w, OC] + int64_t weight_idx = kh * (K_w * C_out) + kw * C_out + out_c; + int8_t quantized_weight = weight_data[weight_idx]; + + if (false && in_w == 0 && in_h == 0 && out_c == 0) { + std::cout << "input: " << input_data[input_idx] << std::endl; + std::cout << "quantized_input: " << (int)quantized_input + << std::endl; + std::cout << "quantized_weight: " << (int)quantized_weight + << std::endl; + } + // Integer multiplication and accumulation + int_sum += static_cast(quantized_input) * + static_cast(quantized_weight); + + // Track weight sum for this output channel on the fly + weight_sum += static_cast(quantized_weight); + } else { + // For zero padding, we still need to account for the weight + // in weight_sum when input is effectively 0 (but quantized 0 + // is input_zero_point) + int64_t weight_idx = kh * (K_w * C_out) + kw * C_out + out_c; + int8_t quantized_weight = weight_data[weight_idx]; + + // Add contribution from zero-padded input (quantized zero = + // input_zero_point) + int_sum += static_cast(input_zero_point) * + static_cast(quantized_weight); + + // Track weight sum for this output channel on the fly + weight_sum += static_cast(quantized_weight); + } + } + } + + // Convert accumulated integer result to float and apply scales + // Final result = (int_sum - zero_point_correction) * input_scale * + // weight_scale + bias zero_point_correction = input_zero_point * + // sum_of_weights_for_this_output_channel + int32_t zero_point_correction = input_zero_point * weight_sum; + int32_t accum_adjusted = int_sum - zero_point_correction; + float float_result = + accum_adjusted * input_scale * weight_scales_data[out_c]; + + // Add bias and store result + float_result += bias_data[out_c]; + + // Quantize the output to int8 + float quant_output_f = + std::round(float_result / output_scale) + output_zero_point; + quant_output_f = std::min(std::max(quant_output_f, -128.0f), 127.0f); + int8_t quantized_output = static_cast(quant_output_f); + + if (false && out_c < 4 && out_h < 1 && out_w < 4) { + std::cout << "int_sum[" << out_c << ", " << out_h << ", " << out_w + << "] = " << int_sum << ", " << float_result << ", " + << output_scale << ", " << quant_output_f << std::endl; + } + + // Dequantize back to float + float dequant_output = + (static_cast(quantized_output) - output_zero_point) * + output_scale; + + int64_t output_idx = n * (C_out * H_out * W_out) + + out_c * (H_out * W_out) + out_h * W_out + out_w; + ref_data[output_idx] = dequant_output; + } + } + } + } +} + +void reference_impl(TestCase& test_case) { + conv2d_q8ta_q8csw_q8to_dw_reference_impl(test_case); +} + +// Custom FLOP calculator for quantized depthwise conv2d operation +int64_t quantized_conv2d_dw_flop_calculator(const TestCase& test_case) { + int kernel_idx = 9; // kernel_size is at index 9 for q8ta_q8csw_q8to + + // Get input and weight dimensions + const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); + const auto& output_sizes = test_case.outputs()[0].get_tensor_sizes(); + + const auto& kernel_sizes = test_case.inputs()[kernel_idx].get_int32_data(); + + int64_t N = input_sizes[0]; + int64_t C_out = output_sizes[1]; + int64_t K_h = kernel_sizes[0]; + int64_t K_w = kernel_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + // Calculate FLOPs for quantized depthwise conv2d operation + // Each output element 
requires: + // - K_h * K_w multiply-accumulate operations (only one input channel per + // output channel) + // - Additional operations for quantization/dequantization + int64_t output_elements = N * C_out * H_out * W_out; + int64_t ops_per_output = K_h * K_w; + + int64_t flop = output_elements * ops_per_output; + + return flop; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout + << "Quantized Depthwise Conv2d Operation with Output Quantization Prototyping Framework" + << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = reference_impl; + + // Execute test cases using the new framework with custom FLOP calculator + auto results = execute_test_cases( + generate_quantized_conv2d_dw_test_cases, + quantized_conv2d_dw_flop_calculator, + "QuantizedDepthwiseInt8Conv2d", + 0, + 1, + ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl index 1d1b1fe79bd..959e013981c 100644 --- a/backends/vulkan/test/custom_ops/targets.bzl +++ b/backends/vulkan/test/custom_ops/targets.bzl @@ -60,9 +60,11 @@ def define_common_targets(is_fbcode = False): ], headers = [ "utils.h", + "conv2d_utils.h", ], exported_headers = [ "utils.h", + "conv2d_utils.h", ], platforms = get_platforms(), deps = [ @@ -98,3 +100,5 @@ def define_common_targets(is_fbcode = False): define_custom_op_test_binary("choose_qparams_per_row") define_custom_op_test_binary("q4gsw_linear") define_custom_op_test_binary("qdq8ta_conv2d_activations") + define_custom_op_test_binary("q8ta_q8csw_q8to_conv2d") + define_custom_op_test_binary("q8ta_q8csw_q8to_conv2d_dw") diff --git a/backends/vulkan/test/custom_ops/utils.cpp b/backends/vulkan/test/custom_ops/utils.cpp index 2aa827a4d5a..4de6c32ac25 100644 --- a/backends/vulkan/test/custom_ops/utils.cpp +++ b/backends/vulkan/test/custom_ops/utils.cpp @@ -661,7 +661,12 @@ float collect_gpu_timing_us(ComputeGraph& graph) { float total_duration_us = 0.0f; for (const auto& shader_result : results) { if (shader_result.kernel_name.find("nchw_to") == std::string::npos && - shader_result.kernel_name.find("to_nchw") == std::string::npos) { + shader_result.kernel_name.find("to_nchw") == std::string::npos && + shader_result.kernel_name.find( + "quantize_and_pack_q8ta_conv2d_input") == std::string::npos && + shader_result.kernel_name.find( + "unpack_and_dequantize_q8ta_conv2d_output") == + std::string::npos) { // Calculate duration from start and end times, convert from ns to μs uint64_t duration_ns = shader_result.end_time_ns - shader_result.start_time_ns; @@ -1715,6 +1720,41 @@ void compute_weight_sums( } } +// Compute weight sums for 4D quantized conv2d operations +// Weight layout: [C_out, K_h, K_w, align_up_4(C_in_per_group)] +void compute_weight_sums_4d( + ValueSpec& weight_sums, + const ValueSpec& quantized_weight, + int64_t out_channels, + int64_t kernel_h, + int64_t kernel_w, + int64_t aligned_in_channels) { + auto& weight_sums_data = weight_sums.get_int32_data(); + auto& quantized_weight_data = quantized_weight.get_int8_data(); + + weight_sums_data.resize(out_channels); + + // For each output channel, compute the sum of quantized weights + for (int64_t out_c = 0; out_c < out_channels; ++out_c) { + int32_t sum = 0; + + for (int64_t kh = 0; kh < kernel_h; ++kh) { + for (int64_t kw = 0; kw < kernel_w; ++kw) { + for (int64_t in_c = 0; in_c < aligned_in_channels; ++in_c) { + // Weight 
indexing: [out_c, kh, kw, in_c] + int64_t weight_idx = + out_c * (kernel_h * kernel_w * aligned_in_channels) + + kh * (kernel_w * aligned_in_channels) + kw * aligned_in_channels + + in_c; + sum += static_cast<int32_t>(quantized_weight_data[weight_idx]); + } + } + } + + weight_sums_data[out_c] = sum; + } +} + + // Helper function to unpack 4-bit values from uint8 (same as in // q4gsw_linear.cpp) std::pair unpack_4bit_utils(uint8_t packed) { diff --git a/backends/vulkan/test/custom_ops/utils.h b/backends/vulkan/test/custom_ops/utils.h index f1736f1d144..b80f28639e8 100644 --- a/backends/vulkan/test/custom_ops/utils.h +++ b/backends/vulkan/test/custom_ops/utils.h @@ -653,6 +653,16 @@ void compute_weight_sums( int64_t out_features, int64_t elements_per_output_feature); +// Compute weight sums for 4D quantized conv2d operations +// Weight layout: [C_out, K_h, K_w, align_up_4(C_in_per_group)] +void compute_weight_sums_4d( + ValueSpec& weight_sums, + const ValueSpec& quantized_weight, + int64_t out_channels, + int64_t kernel_h, + int64_t kernel_w, + int64_t aligned_in_channels); + // Compute weight sums for 4-bit group symmetric quantized weights void compute_weight_sums_4bit_grouped( ValueSpec& weight_sums,
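Since the header carries only the declaration, a minimal standalone sketch of the indexing implied by the [C_out, K_h, K_w, align_up_4(C_in_per_group)] layout may help. The plain-vector helper below is illustrative only (the real function operates on ValueSpec buffers, as implemented above), and it assumes the padded tail lanes of each kernel position are zero-filled so they do not perturb the per-channel sums:

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative stand-in for compute_weight_sums_4d over a raw int8 buffer
// laid out as [C_out, K_h, K_w, aligned_C_in], one int32 sum per out channel.
std::vector<int32_t> weight_sums_4d_sketch(
    const std::vector<int8_t>& weights,
    int64_t C_out,
    int64_t K_h,
    int64_t K_w,
    int64_t aligned_C_in) {
  std::vector<int32_t> sums(C_out, 0);
  for (int64_t oc = 0; oc < C_out; ++oc) {
    for (int64_t kh = 0; kh < K_h; ++kh) {
      for (int64_t kw = 0; kw < K_w; ++kw) {
        for (int64_t ic = 0; ic < aligned_C_in; ++ic) {
          const int64_t idx = oc * (K_h * K_w * aligned_C_in) +
              kh * (K_w * aligned_C_in) + kw * aligned_C_in + ic;
          sums[oc] += static_cast<int32_t>(weights[idx]);
        }
      }
    }
  }
  return sums;
}

int main() {
  // One output channel, 1x1 kernel, 3 real input channels padded up to 4.
  // The zero-filled padding lane contributes nothing to the sum.
  const std::vector<int8_t> weights = {3, -5, 4, 0};
  const auto sums = weight_sums_4d_sketch(weights, 1, 1, 1, 4);
  assert(sums[0] == 2);
  return 0;
}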