From 91f588285477f6759d68ddac595f0f027210d168 Mon Sep 17 00:00:00 2001 From: ssjia Date: Sun, 28 Sep 2025 11:26:42 -0700 Subject: [PATCH 1/2] [ET-VK] Statically quantized convolutions ## Changes This diff adds implementations for quantized convolution under the following quantization conditions: * activations statically quantized to 8-bit with per tensor scale and zero point * weights quantized to 8-bit with per channel scales * outputs statically quantized to 8-bit with per tensor scale and zero point 3 different implementations are added, which are selected between based on the input conditions. The first is an direct convolution shader which uses the quantized int8 input directly. The second is an im2col variant, which computes the convolution via a gemm like algorithm by first applying an im2col tranformation on the input tensor. Finally, a specialized implementation is added for depthwise convolutions. Differential Revision: [D83437827](https://our.internmc.facebook.com/intern/diff/D83437827/) [ghstack-poisoned] --- .../runtime/graph/ops/glsl/common.glslh | 15 + .../graph/ops/glsl/conv2d_common.glslh | 12 + .../graph/ops/glsl/conv2d_dw_q8_utils.glslh | 214 ++++ .../ops/glsl/conv2d_dw_q8ta_q8csw_q8to.glsl | 121 +++ .../ops/glsl/conv2d_dw_q8ta_q8csw_q8to.yaml | 20 + .../glsl/conv2d_int8_input_block_load.glslh | 30 + .../glsl/conv2d_int8_input_tile_load.glslh | 74 ++ .../glsl/conv2d_int8_output_tile_store.glslh | 45 + .../glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.glsl | 144 +++ .../glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.yaml | 20 + .../graph/ops/glsl/conv2d_q8_utils.glslh | 151 +++ .../ops/glsl/conv2d_q8ta_q8csw_q8to.glsl | 173 ++++ .../ops/glsl/conv2d_q8ta_q8csw_q8to.yaml | 20 + .../conv2d_q8ta_q8csw_q8to_linear_tiled.glsl | 149 +++ .../conv2d_q8ta_q8csw_q8to_linear_tiled.yaml | 20 + .../graph/ops/glsl/im2col_packed_int8.glsl | 73 ++ .../graph/ops/glsl/im2col_packed_int8.yaml | 14 + .../ops/glsl/im2col_packed_int8_utils.glslh | 287 ++++++ .../ops/glsl/linear_int8_input_block.glslh | 7 - .../ops/glsl/linear_int8_output_tile.glslh | 67 ++ .../linear_int8_output_tile_compute.glslh | 117 +++ .../graph/ops/glsl/linear_q4gsw_tiled.glsl | 3 - .../ops/glsl/pack_q8_conv2d_dw_weights.glsl | 72 ++ .../ops/glsl/pack_q8_conv2d_dw_weights.yaml | 15 + .../ops/glsl/pack_q8_conv2d_weights.glsl | 82 ++ .../ops/glsl/pack_q8_conv2d_weights.yaml | 15 + .../ops/glsl/sdpa_fp_k_cache_tile_load.glslh | 1 - ...ack_and_dequantize_q8ta_conv2d_output.glsl | 1 - .../vulkan/runtime/graph/ops/impl/Common.cpp | 23 + .../vulkan/runtime/graph/ops/impl/Common.h | 7 + .../graph/ops/impl/QuantizedConvolution.cpp | 973 +++++++++++++++++- .../vulkan/test/custom_ops/conv2d_utils.cpp | 10 + .../vulkan/test/custom_ops/conv2d_utils.h | 88 ++ .../vulkan/test/custom_ops/q8csw_conv2d.cpp | 83 +- .../custom_ops/q8ta_q8csw_q8to_conv2d.cpp | 628 +++++++++++ .../custom_ops/q8ta_q8csw_q8to_conv2d_dw.cpp | 592 +++++++++++ backends/vulkan/test/custom_ops/targets.bzl | 4 + backends/vulkan/test/custom_ops/utils.cpp | 42 +- backends/vulkan/test/custom_ops/utils.h | 10 + 39 files changed, 4301 insertions(+), 121 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8_utils.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_input_block_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_input_tile_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_output_tile_store.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_q8_utils.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to_linear_tiled.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to_linear_tiled.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8_utils.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_output_tile.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_output_tile_compute.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_dw_weights.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_dw_weights.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_weights.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_weights.yaml create mode 100644 backends/vulkan/test/custom_ops/conv2d_utils.cpp create mode 100644 backends/vulkan/test/custom_ops/conv2d_utils.h create mode 100644 backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp create mode 100644 backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d_dw.cpp diff --git a/backends/vulkan/runtime/graph/ops/glsl/common.glslh b/backends/vulkan/runtime/graph/ops/glsl/common.glslh index 00a053612f5..95cdf70679b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/common.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/common.glslh @@ -46,6 +46,14 @@ int extract_8bit_from_packed_int_le(const int packed, const int i) { return byte; } +ivec4 unpack_int8x4(const int packed) { + return ivec4( + extract_8bit_from_packed_int_le(packed, 0), + extract_8bit_from_packed_int_le(packed, 1), + extract_8bit_from_packed_int_le(packed, 2), + extract_8bit_from_packed_int_le(packed, 3)); +} + int pack_4xqint_into_int32( const int val0, const int val1, @@ -57,6 +65,13 @@ int pack_4xqint_into_int32( return packed; } +int pack_into_int32(const ivec4 quant_vals) { + int packed = ((quant_vals[0] & 0xFF) << 0) | ((quant_vals[1] & 0xFF) << 8) | + ((quant_vals[2] & 0xFF) << 16) | ((quant_vals[3] & 0xFF) << 24); + + return packed; +} + #ifdef DEBUG_MODE #extension GL_EXT_debug_printf : require diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh index 929f3da299e..6f460d1398c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh @@ -61,6 +61,18 @@ Conv2dBlockExtents make_block_extents(const ivec4 tensor_sizes) { return block_sizes; } +Conv2dBlockIndex linear_idx_to_block_idx( + const int idx, const Conv2dBlockExtents block_extents) { + Conv2dBlockIndex block_idx; + block_idx.data.z = idx % block_extents.data.z; + + const int row = idx / block_extents.data.z; + block_idx.data.x = row % block_extents.data.x; + block_idx.data.y = row / block_extents.data.x; + + return block_idx; +} + bool block_idx_out_of_bounds( const Conv2dBlockIndex block_idx, const Conv2dBlockExtents block_extents) { diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8_utils.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8_utils.glslh new file mode 100644 index 00000000000..f1d90aa83cb --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8_utils.glslh @@ -0,0 +1,214 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CONV2D_DW_Q8_UTILS_GLSLH +#define CONV2D_DW_Q8_UTILS_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +struct InputWindow1D { + vec4[MAX_WINDOW_WIDTH] data; + int len; +}; + +InputWindow1D initial_input_window() { + InputWindow1D input_window; + for (int i = 0; i < MAX_WINDOW_WIDTH; ++i) { + input_window.data[i] = vec4(0); + } + input_window.len = 0; + return input_window; +} + +vec4 dequantize(const int packed_texel, const float scale, const int zp) { + return vec4(unpack_int8x4(packed_texel) - zp) * scale; +} + +vec4 dequantize(const int packed_texel, const vec4 scales) { + return vec4(unpack_int8x4(packed_texel)) * scales; +} + +bool in_bounds( + const int block_w, + const int block_h, + const int block_c4, + const Conv2dBlockExtents block_extents) { + ivec3 idx = ivec3(block_w, block_h, block_c4); + if (any(lessThan(idx, ivec3(0)))) { + return false; + } + if (any(greaterThanEqual(idx, block_extents.data))) { + return false; + } + + return true; +} + +InputWindow1D load_input_window( + const int w_start, + const int w_end, + const int h, + const int c4, + const Conv2dBlockExtents block_extents, + const float input_scale, + const int input_zp, + const ivec4 input_zps) { + InputWindow1D input_window = initial_input_window(); + + const int block_w_start = div_4(w_start); + const int block_w_end = div_4(w_end); + + int window_i = 0; + for (int block_w = block_w_start; block_w <= block_w_end; ++block_w) { + ivec4 input_block = input_zps; + + if (in_bounds(block_w, h, c4, block_extents)) { +#ifdef PACKED_INT8_INPUT_BUFFER + const int buffer_idx = + h * block_extents.data_xz + block_w * block_extents.data.z + c4; + input_block = t_packed_int8_input[buffer_idx]; +#else + input_block = texelFetch(t_packed_int8_input, ivec3(block_w, h, c4), 0); +#endif + } + + const int loaded_w_start = mul_4(block_w); + for (int row = 0; row < 4; ++row) { + if (loaded_w_start + row >= w_start && loaded_w_start + row <= w_end) { + input_window.data[window_i++] = + dequantize(input_block[row], input_scale, input_zp); + } + } + } + input_window.len = window_i; + return input_window; +} + +struct WeightRow { + vec4[MAX_KERNEL_WIDTH] data; + int len; +}; + +WeightRow initial_weight_row() { + WeightRow weight_row; + for (int i = 0; i < MAX_KERNEL_WIDTH; ++i) { + weight_row.data[i] = vec4(0); + } + weight_row.len = 0; + return weight_row; +} + +WeightRow load_weight_row( + const int oc4, + const int ky, + const int OC4, + const int Kw, + const int Kw4, + const vec4 weight_scales) { + WeightRow weight_row = initial_weight_row(); + + int k4 = ky * Kw4; + int row_idx = 0; + for (int w = 0; w < Kw; w += 4) { +#ifdef WEIGHT_BUFFER + const ivec4 weight_block = t_packed_int8_weight[k4 * OC4 + oc4]; +#else + const ivec4 weight_block = texelFetch( + t_packed_int8_weight, ivec2(oc4, k4), 0); +#endif + + for (int row = 0; row < 4; ++row) { + if (w + row < Kw) { + weight_row.data[row_idx++] = dequantize(weight_block[row], weight_scales); + } + } + k4++; + } + weight_row.len = row_idx; + return weight_row; +} + +struct FPOutBlock { + vec4[4] data; +}; + +void perform_conv1d( + inout FPOutBlock out_block, + const InputWindow1D input_window, + const WeightRow weight_row) { + for (int out_w = 0; out_w < 4; ++out_w) { + [[unroll]] for (int kx = 0; kx < weight_row.len; ++kx) { + const int in_w = out_w * conv2d_params.stride.x; + out_block.data[out_w] = fma( + input_window.data[in_w + kx], + weight_row.data[kx], + out_block.data[out_w]); + } + } +} + +ivec4 quantize( + const vec4 texel, const float inv_scale, const int zp) { + vec4 quantized = round(texel * inv_scale) + zp; + return clamp(ivec4(quantized), -128, 127); +} + +ivec4 quantize_and_pack( + FPOutBlock out_block, const float inv_scale, const int zp) { + ivec4 packed_block; + for (int row = 0; row < 4; ++row) { + ivec4 quantized_texel = quantize(out_block.data[row], inv_scale, zp); + packed_block[row] = pack_into_int32(quantized_texel); + } + return packed_block; +} + +#ifdef DEBUG_MODE + +void printInputWindow1D(const InputWindow1D input_window) { + debugPrintfEXT("InputWindow1D contents (len = %d): \\n", input_window.len); + for (int i = 0; i < min(input_window.len, MAX_WINDOW_WIDTH); ++i) { + debugPrintfEXT( + " [%d]: (%.3f, %.3f, %.3f, %.3f) \\n", + i, + input_window.data[i].x, + input_window.data[i].y, + input_window.data[i].z, + input_window.data[i].w); + } +} + +void printWeightRow(const WeightRow weight_row) { + debugPrintfEXT("WeightRow contents (len = %d): \\n", weight_row.len); + for (int i = 0; i < min(weight_row.len, MAX_KERNEL_WIDTH); ++i) { + debugPrintfEXT( + " [%d]: (%.3f, %.3f, %.3f, %.3f) \\n", + i, + weight_row.data[i].x, + weight_row.data[i].y, + weight_row.data[i].z, + weight_row.data[i].w); + } +} + +void printFPOutBlock(const FPOutBlock out_block) { + debugPrintfEXT("FPOutBlock contents: \\n"); + for (int i = 0; i < 4; ++i) { + debugPrintfEXT( + " [%d]: (%.3f, %.3f, %.3f, %.3f) \\n", + i, + out_block.data[i].x, + out_block.data[i].y, + out_block.data[i].z, + out_block.data[i].w); + } + } + +#endif // DEBUG_MODE + +#endif // CONV2D_DW_Q8_UTILS_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.glsl new file mode 100644 index 00000000000..8994ced3acb --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.glsl @@ -0,0 +1,121 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T ${texel_load_component_type(DTYPE, "buffer")} + +$if IO_STORAGE == "buffer": + #define PACKED_INT8_OUTPUT_BUFFER + #define PACKED_INT8_INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define MAX_WINDOW_WIDTH 12 +#define MAX_KERNEL_WIDTH 5 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "apply_bias", "1")} + +#include "conv2d_dw_q8_utils.glslh" + +void main() { + const int tid = int(gl_GlobalInvocationID.x); + Conv2dBlockExtents out_block_extents = make_block_extents(output_sizes); + + Conv2dBlockIndex out_block_idx = linear_idx_to_block_idx( + tid, out_block_extents); + + if (block_idx_out_of_bounds(out_block_idx, out_block_extents)) { + return; + } + + const int out_w = mul_4(out_block_idx.data.x); + const int w_start = + (out_w * conv2d_params.stride.x) - conv2d_params.padding.x; + const int w_end = ((out_w + 3) * conv2d_params.stride.x) - + conv2d_params.padding.x + + (conv2d_params.kernel_size.x - 1) * conv2d_params.dilation.x; + + Conv2dBlockExtents in_block_extents = make_block_extents(input_sizes); + + const ivec4 input_zps = ivec4(pack_into_int32(ivec4(input_zp))); + const vec4 weight_scales = vec4(t_weight_scales[out_block_idx.data.z]); + + const int Kw4 = div_up_4(conv2d_params.kernel_size.x); + + FPOutBlock out_block; + for (int ky = 0; ky < conv2d_params.kernel_size.y; ky++) { + const int out_h = out_block_idx.data.y; + const int h = out_h * conv2d_params.stride.y - conv2d_params.padding.y + + ky * conv2d_params.dilation.y; + + InputWindow1D input_window = load_input_window( + w_start, + w_end, + h, + out_block_idx.data.z, + in_block_extents, + input_scale, + input_zp, + input_zps); + + WeightRow weight_row = load_weight_row( + out_block_idx.data.z, + ky, + out_block_extents.data.z, + conv2d_params.kernel_size.x, + Kw4, + weight_scales); + + perform_conv1d(out_block, input_window, weight_row); + } + + if (apply_bias > 0) { + const vec4 bias = vec4(t_bias[out_block_idx.data.z]); + for (int row = 0; row < 4; row++) { + out_block.data[row] += bias; + } + } + + const ivec4 packed_out_block = quantize_and_pack( + out_block, output_inv_scale, output_zp); + +#ifdef PACKED_INT8_OUTPUT_BUFFER + t_packed_int8_output[tid] = packed_out_block; +#else + imageStore(t_packed_int8_output, out_block_idx.data, packed_out_block); +#endif +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.yaml new file mode 100644 index 00000000000..77f801668a4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.yaml @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv2d_dw_q8ta_q8csw_q8to: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, WEIGHT_STORAGE] + combos: + - parameter_values: [buffer, texture2d] + DTYPE: + - VALUE: float + shader_variants: + - NAME: conv2d_dw_q8ta_q8csw_q8to diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_input_block_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_input_block_load.glslh new file mode 100644 index 00000000000..44c226f6891 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_input_block_load.glslh @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CONV2D_INT8_INPUT_BLOCK_LOAD +#define CONV2D_INT8_INPUT_BLOCK_LOAD + +#extension GL_EXT_control_flow_attributes : require + +#include "conv2d_common.glslh" +#include "conv2d_int8_activation_block.glslh" + +void store_packed_int8_input_block( + const Conv2dBlockIndex block_idx, + const Conv2dBlockExtents block_extents, + const Int8ActivationBlock packed_int8_block) { +#ifdef OUTPUT_BUFFER + const int buffer_idx = block_idx.data.y * block_extents.data_xz + + block_idx.data.x * block_extents.data.z + block_idx.data.z; + t_packed_int8_input[buffer_idx] = packed_int8_block.data; +#else + imageStore(t_packed_int8_input, block_idx.data, packed_int8_block.data); +#endif +} + +#endif // CONV2D_INT8_INPUT_BLOCK_LOAD diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_input_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_input_tile_load.glslh new file mode 100644 index 00000000000..44aa09912ec --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_input_tile_load.glslh @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CONV2D_INT8_INPUT_TILE_LOAD +#define CONV2D_INT8_INPUT_TILE_LOAD + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_int8_input_tile.glslh" + +struct Int8InputTileIndex { +#ifdef PACKED_INT8_INPUT_BUFFER + int data; +#else + ivec3 data; +#endif +}; + +Int8InputTileIndex make_initial_int8_input_tile_index( + const Conv2dBlockIndex block_idx, + const Conv2dBlockExtents block_extents) { + Int8InputTileIndex idx; +#ifdef PACKED_INT8_INPUT_BUFFER + idx.data = block_idx.data.y * block_extents.data_xz + + block_idx.data.x * block_extents.data.z; +#else + idx.data = ivec3(block_idx.data.x, block_idx.data.y, 0); +#endif + return idx; +} + +Int8InputTileIndex make_initial_int8_input_tile_index( + const Conv2dBlockIndex block_idx, + const Conv2dBlockExtents block_extents, + const int group_k4_offset) { + Int8InputTileIndex idx; +#ifdef PACKED_INT8_INPUT_BUFFER + idx.data = block_idx.data.y * block_extents.data_xz + + block_idx.data.x * block_extents.data.z + group_k4_offset; +#else + idx.data = ivec3(block_idx.data.x, block_idx.data.y, group_k4_offset); +#endif + return idx; +} + +void load_packed_int8_input_tile( + out Int8InputTile int8_tile, + const Int8InputTileIndex idx) { +#ifdef PACKED_INT8_INPUT_BUFFER + int8_tile.data[0][0] = t_packed_int8_input[idx.data]; +#else + int8_tile.data[0][0] = texelFetch(t_packed_int8_input, idx.data, 0); +#endif + + // Guard against unsupported tile sizes +#if TILE_M4 != 1 || TILE_K4 != 1 + not_implemented; +#endif +} + +void increment_k4(inout Int8InputTileIndex idx) { +#ifdef PACKED_INT8_INPUT_BUFFER + idx.data += 1; +#else + idx.data.z += 1; +#endif +} + +#endif // CONV2D_INT8_INPUT_TILE_LOAD diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_output_tile_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_output_tile_store.glslh new file mode 100644 index 00000000000..0a490360f98 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_output_tile_store.glslh @@ -0,0 +1,45 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CONV2D_INT8_OUTPUT_TILE_STORE +#define CONV2D_INT8_OUTPUT_TILE_STORE + +#extension GL_EXT_control_flow_attributes : require + +#include "conv2d_common.glslh" +#include "linear_int8_output_tile.glslh" + +void store_packed_int8_output_tile( + const Int8OutTile int8_tile, + const Conv2dBlockIndex block_idx, + const Conv2dBlockExtents block_extents) { +#ifdef PACKED_INT8_OUTPUT_BUFFER + [[unroll]] for (int m4 = 0; m4 < TILE_M4; m4++) { + int buffer_idx = block_idx.data.y * block_extents.data_xz + + (block_idx.data.x + m4) * block_extents.data.z + block_idx.data.z; + [[unroll]] for (int n4 = 0; n4 < TILE_N4; n4++) { + if (block_idx.data.x + m4 < block_extents.data.x && + block_idx.data.z + n4 < block_extents.data.z) { + t_packed_int8_output[buffer_idx++] = int8_tile.data[m4][n4]; + } + } + } +#else + [[unroll]] for (int m4 = 0; m4 < TILE_M4; m4++) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; n4++) { + if (block_idx.data.x + m4 < block_extents.data.x && + block_idx.data.z + n4 < block_extents.data.z) { + imageStore( + t_packed_int8_output, block_idx.data, int8_tile.data[m4][n4]); + } + } + } +#endif +} + +#endif // CONV2D_INT8_OUTPUT_TILE_STORE diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.glsl new file mode 100644 index 00000000000..16c12b3ee5a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.glsl @@ -0,0 +1,144 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T ${texel_load_component_type(DTYPE, "buffer")} + +$if IO_STORAGE == "buffer": + #define PACKED_INT8_OUTPUT_BUFFER + #define PACKED_INT8_INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +// corresponds to input/output width dim +#define TILE_M4 1 +// corresponds to input channels dim +#define TILE_K4 1 +// corresponds to output channels dim +#define TILE_N4 2 + +#define TILE_M 4 +#define TILE_K 4 +#define TILE_N 8 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "apply_bias", "1")} + +#include "conv2d_int8_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_output_tile_int8_int8_compute.glslh" +#include "linear_int_weight_sums_load.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_fp_bias_load.glslh" +#include "linear_int8_output_tile_compute.glslh" +#include "conv2d_int8_output_tile_store.glslh" + +void main() { + Conv2dBlockIndex output_block_idx; + output_block_idx.data.z = int(gl_GlobalInvocationID.x) * TILE_N4; + output_block_idx.data.x = int(gl_GlobalInvocationID.y) * TILE_M4; + output_block_idx.data.y = int(gl_GlobalInvocationID.z); + + Conv2dBlockExtents output_block_extents = make_block_extents(output_sizes); + if (block_idx_out_of_bounds(output_block_idx, output_block_extents)) { + return; + } + + Conv2dBlockExtents input_block_extents = make_block_extents(input_sizes); + + Int32Accum out_accum; + initialize(out_accum); + + Int8InputTile int8_input_tile; + Int8WeightTile int8_weight_tile; + + Int8InputTileIndex input_idx = make_initial_int8_input_tile_index( + output_block_idx, input_block_extents); + + for (int k4 = 0; k4 < conv2d_params.K4_per_group; k4++) { + load_packed_int8_input_tile(int8_input_tile, input_idx); + + load_int8_weight_tile( + int8_weight_tile, + output_block_idx.data.z, + k4, + output_block_extents.data.z); + + int_accumulate_with_int8_weight( + out_accum, int8_input_tile, int8_weight_tile); + + increment_k4(input_idx); + } + + FPPerOutChannelParams weight_scales_tile; + load_weight_scales_tile(weight_scales_tile, output_block_idx.data.z); + + IntPerOutChannelParams weight_sums_tile; + load_weight_sums_tile(weight_sums_tile, output_block_idx.data.z); + + Int8OutTile int8_out_tile; + initialize(int8_out_tile); + + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, output_block_idx.data.z); + + compute_int8_out_tile_with_int32_accum( + int8_out_tile, + out_accum, + input_scale, + input_zp, + output_inv_scale, + output_zp, + weight_sums_tile, + weight_scales_tile, + bias_tile); + } + else { + compute_int8_out_tile_with_int32_accum( + int8_out_tile, + out_accum, + input_scale, + input_zp, + output_inv_scale, + output_zp, + weight_sums_tile, + weight_scales_tile); + } + + store_packed_int8_output_tile( + int8_out_tile, output_block_idx, output_block_extents); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.yaml new file mode 100644 index 00000000000..23803dc6da1 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.yaml @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv2d_pw_q8ta_q8csw_q8to_tiled: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, WEIGHT_STORAGE] + combos: + - parameter_values: [buffer, texture2d] + DTYPE: + - VALUE: float + shader_variants: + - NAME: conv2d_pw_q8ta_q8csw_q8to_tiled diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8_utils.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8_utils.glslh new file mode 100644 index 00000000000..279f4f17f13 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8_utils.glslh @@ -0,0 +1,151 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CONV2D_Q8_UTILS_GLSLH +#define CONV2D_Q8_UTILS_GLSLH + +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_integer_dot_product : require + +#include "linear_int_accumulator.glslh" + +struct Int8InputWindow1D { + int[MAX_WINDOW_WIDTH] data; + int len; +}; + +Int8InputWindow1D initial_input_window() { + Int8InputWindow1D input_window; + for (int i = 0; i < MAX_WINDOW_WIDTH; ++i) { + input_window.data[i] = 0; + } + input_window.len = 0; + return input_window; +} + +bool in_bounds( + const int block_w, + const int block_h, + const int block_c4, + const Conv2dBlockExtents block_extents) { + ivec3 idx = ivec3(block_w, block_h, block_c4); + if (any(lessThan(idx, ivec3(0)))) { + return false; + } + if (any(greaterThanEqual(idx, block_extents.data))) { + return false; + } + + return true; +} + +Int8InputWindow1D load_input_window( + const int w_start, + const int w_end, + const int h, + const int c4, + const Conv2dBlockExtents block_extents, + const ivec4 input_zps) { + Int8InputWindow1D input_window = initial_input_window(); + + const int block_w_start = div_4(w_start); + const int block_w_end = div_4(w_end); + + int window_i = 0; + for (int block_w = block_w_start; block_w <= block_w_end; ++block_w) { + ivec4 input_block = input_zps; + + if (in_bounds(block_w, h, c4, block_extents)) { +#ifdef PACKED_INT8_INPUT_BUFFER + const int buffer_idx = + h * block_extents.data_xz + block_w * block_extents.data.z + c4; + input_block = t_packed_int8_input[buffer_idx]; +#else + input_block = texelFetch(t_packed_int8_input, ivec3(block_w, h, c4), 0); +#endif + } + + const int loaded_w_start = mul_4(block_w); + for (int row = 0; row < 4; ++row) { + if (loaded_w_start + row >= w_start && loaded_w_start + row <= w_end) { + input_window.data[window_i++] = input_block[row]; + } + } + } + input_window.len = window_i; + return input_window; +} + +ivec4 load_weight_block( + const int ic4, + const int kx, + const int ky, + const int oc4, + const int IC4, + const int Kw, + const int Kh, + const int OC4) { +#ifdef PACKED_INT8_WEIGHTS_BUFFER + const int block_x = oc4 * Kw + kx; + const int block_y = ky * IC4 + ic4; + return t_packed_int8_weight[block_y * (Kw * OC4) + block_x]; +#else + return texelFetch( + t_packed_int8_weight, ivec2(oc4 * Kw + kx, ky * IC4 + ic4), 0); +#endif +} + +void perform_conv1d( + inout Int32Accum accum, + const Int8InputWindow1D input_window, + const ivec4 weight_block, + const int kx) { + [[unroll]] for (int out_w = 0; out_w < 4; ++out_w) { + const int window_i = out_w * conv2d_params.stride.x + kx; + [[unroll]] for (int out_c = 0; out_c < 4; ++out_c) { + accum.data[out_w][0][out_c] = dotPacked4x8AccSatEXT( + input_window.data[window_i], + weight_block[out_c], + accum.data[out_w][0][out_c]); + } + } +} + +#ifdef DEBUG_MODE + +void printInt8InputWindow1D(const Int8InputWindow1D input_window) { + debugPrintfEXT("Int8InputWindow1D contents (len = %d): \\n", input_window.len); + for (int i = 0; i < min(input_window.len, MAX_WINDOW_WIDTH); ++i) { + ivec4 unpacked = unpack_int8x4(input_window.data[i]); + debugPrintfEXT( + " [%d]: (%d, %d, %d, %d) \\n", + i, + unpacked.x, + unpacked.y, + unpacked.z, + unpacked.w); + } +} + +void printWeightBlock(const ivec4 weight_block) { + debugPrintfEXT("WeightBlock contents: \\n"); + for (int i = 0; i < 4; ++i) { + ivec4 unpacked = unpack_int8x4(weight_block[i]); + debugPrintfEXT( + " [%d]: (%d, %d, %d, %d) \\n", + i, + unpacked.x, + unpacked.y, + unpacked.z, + unpacked.w); + } +} + +#endif // DEBUG_MODE + +#endif // CONV2D_Q8_UTILS_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.glsl new file mode 100644 index 00000000000..5839b13aeaa --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.glsl @@ -0,0 +1,173 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T ${texel_load_component_type(DTYPE, "buffer")} + +$if IO_STORAGE == "buffer": + #define PACKED_INT8_OUTPUT_BUFFER + #define PACKED_INT8_INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define MAX_WINDOW_WIDTH 16 + +// corresponds to input/output width dim +#define TILE_M4 1 +// corresponds to input channels dim +#define TILE_K4 1 +// corresponds to output channels dim +#define TILE_N4 1 + +#define TILE_M 4 +#define TILE_K 4 +#define TILE_N 4 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "apply_bias", "1")} + +#include "im2col_packed_int8_utils.glslh" +#include "conv2d_int8_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_output_tile_int8_int8_compute.glslh" +#include "linear_int_weight_sums_load.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_fp_bias_load.glslh" +#include "linear_int8_output_tile_compute.glslh" +#include "conv2d_int8_output_tile_store.glslh" + +#include "conv2d_q8_utils.glslh" + +void main() { + Conv2dBlockIndex out_block_idx; + out_block_idx.data.z = int(gl_GlobalInvocationID.x) * TILE_N4; + out_block_idx.data.x = int(gl_GlobalInvocationID.y) * TILE_M4; + out_block_idx.data.y = int(gl_GlobalInvocationID.z); + + Conv2dBlockExtents out_block_extents = make_block_extents(output_sizes); + if (block_idx_out_of_bounds(out_block_idx, out_block_extents)) { + return; + } + + const int out_w = mul_4(out_block_idx.data.x); + const int w_start = + (out_w * conv2d_params.stride.x) - conv2d_params.padding.x; + const int w_end = ((out_w + 3) * conv2d_params.stride.x) - + conv2d_params.padding.x + + (conv2d_params.kernel_size.x - 1) * conv2d_params.dilation.x; + + Conv2dBlockExtents in_block_extents = make_block_extents(input_sizes); + + const ivec4 input_zps = ivec4(pack_into_int32(ivec4(input_zp))); + const vec4 weight_scales = vec4(t_weight_scales[out_block_idx.data.z]); + + Int32Accum out_accum; + initialize(out_accum); + + const int IC4_per_group = div_up_4(conv2d_params.in_channels_per_group); + + const int n = mul_4(out_block_idx.data.z); + const int group_idx = n / conv2d_params.out_channels_per_group; + const int group_ic4_offset = group_idx * IC4_per_group; + + for (int ky = 0; ky < conv2d_params.kernel_size.y; ky++) { + const int h = out_block_idx.data.y * conv2d_params.stride.y - + conv2d_params.padding.y + ky * conv2d_params.dilation.y; + + for (int ic4 = 0; ic4 < IC4_per_group; ic4++) { + Int8InputWindow1D int8_input_window = load_input_window( + w_start, + w_end, + h, + group_ic4_offset + ic4, + in_block_extents, + input_zps); + + for (int kx = 0; kx < conv2d_params.kernel_size.x; kx++) { + const ivec4 weight_block = load_weight_block( + ic4, + kx, + ky, + out_block_idx.data.z, + IC4_per_group, + conv2d_params.kernel_size.x, + conv2d_params.kernel_size.y, + out_block_extents.data.z); + + perform_conv1d(out_accum, int8_input_window, weight_block, kx); + } + } + } + + FPPerOutChannelParams weight_scales_tile; + load_weight_scales_tile(weight_scales_tile, out_block_idx.data.z); + + IntPerOutChannelParams weight_sums_tile; + load_weight_sums_tile(weight_sums_tile, out_block_idx.data.z); + + Int8OutTile int8_out_tile; + initialize(int8_out_tile); + + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, out_block_idx.data.z); + + compute_int8_out_tile_with_int32_accum( + int8_out_tile, + out_accum, + input_scale, + input_zp, + output_inv_scale, + output_zp, + weight_sums_tile, + weight_scales_tile, + bias_tile); + } + else { + compute_int8_out_tile_with_int32_accum( + int8_out_tile, + out_accum, + input_scale, + input_zp, + output_inv_scale, + output_zp, + weight_sums_tile, + weight_scales_tile); + } + + store_packed_int8_output_tile( + int8_out_tile, out_block_idx, out_block_extents); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.yaml new file mode 100644 index 00000000000..5da9cc14584 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.yaml @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv2d_q8ta_q8csw_q8to: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, WEIGHT_STORAGE] + combos: + - parameter_values: [buffer, texture2d] + DTYPE: + - VALUE: float + shader_variants: + - NAME: conv2d_q8ta_q8csw_q8to diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to_linear_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to_linear_tiled.glsl new file mode 100644 index 00000000000..b44e37766fc --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to_linear_tiled.glsl @@ -0,0 +1,149 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T ${texel_load_component_type(DTYPE, "buffer")} + +$if IO_STORAGE == "buffer": + #define PACKED_INT8_OUTPUT_BUFFER + #define PACKED_INT8_INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +// corresponds to input/output width dim +#define TILE_M4 1 +// corresponds to input channels dim +#define TILE_K4 1 +// corresponds to output channels dim +#define TILE_N4 2 + +#define TILE_M 4 +#define TILE_K 4 +#define TILE_N 8 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "im2col_sizes")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "apply_bias", "1")} + +#include "conv2d_int8_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_output_tile_int8_int8_compute.glslh" +#include "linear_int_weight_sums_load.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_fp_bias_load.glslh" +#include "linear_int8_output_tile_compute.glslh" +#include "conv2d_int8_output_tile_store.glslh" + +void main() { + Conv2dBlockIndex output_block_idx; + output_block_idx.data.z = int(gl_GlobalInvocationID.x) * TILE_N4; + output_block_idx.data.x = int(gl_GlobalInvocationID.y) * TILE_M4; + output_block_idx.data.y = int(gl_GlobalInvocationID.z); + + Conv2dBlockExtents output_block_extents = make_block_extents(output_sizes); + if (block_idx_out_of_bounds(output_block_idx, output_block_extents)) { + return; + } + + const int n = mul_4(output_block_idx.data.z); + + const int group_idx = n / conv2d_params.out_channels_per_group; + const int group_k4_offset = group_idx * conv2d_params.K4_per_group; + + Conv2dBlockExtents input_block_extents = make_block_extents(im2col_sizes); + + Int32Accum out_accum; + initialize(out_accum); + + Int8InputTile int8_input_tile; + Int8WeightTile int8_weight_tile; + + Int8InputTileIndex input_idx = make_initial_int8_input_tile_index( + output_block_idx, input_block_extents, group_k4_offset); + + for (int k4 = 0; k4 < conv2d_params.K4_per_group; k4++) { + load_packed_int8_input_tile(int8_input_tile, input_idx); + + load_int8_weight_tile( + int8_weight_tile, + output_block_idx.data.z, + k4, + output_block_extents.data.z); + + int_accumulate_with_int8_weight( + out_accum, int8_input_tile, int8_weight_tile); + + increment_k4(input_idx); + } + + FPPerOutChannelParams weight_scales_tile; + load_weight_scales_tile(weight_scales_tile, output_block_idx.data.z); + + IntPerOutChannelParams weight_sums_tile; + load_weight_sums_tile(weight_sums_tile, output_block_idx.data.z); + + Int8OutTile int8_out_tile; + initialize(int8_out_tile); + + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, output_block_idx.data.z); + + compute_int8_out_tile_with_int32_accum( + int8_out_tile, + out_accum, + input_scale, + input_zp, + output_inv_scale, + output_zp, + weight_sums_tile, + weight_scales_tile, + bias_tile); + } + else { + compute_int8_out_tile_with_int32_accum( + int8_out_tile, + out_accum, + input_scale, + input_zp, + output_inv_scale, + output_zp, + weight_sums_tile, + weight_scales_tile); + } + + store_packed_int8_output_tile( + int8_out_tile, output_block_idx, output_block_extents); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to_linear_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to_linear_tiled.yaml new file mode 100644 index 00000000000..fa92481f5ef --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to_linear_tiled.yaml @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv2d_q8ta_q8csw_q8to_linear_tiled: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, WEIGHT_STORAGE] + combos: + - parameter_values: [buffer, texture2d] + DTYPE: + - VALUE: float + shader_variants: + - NAME: conv2d_q8ta_q8csw_q8to_linear_tiled diff --git a/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8.glsl b/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8.glsl new file mode 100644 index 00000000000..3ecaa597ecc --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8.glsl @@ -0,0 +1,73 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +$if STORAGE == "buffer": + #define PACKED_INT8_OUTPUT_BUFFER + #define PACKED_INT8_INPUT_BUFFER + +#define TILE_M4 1 +#define TILE_N4 1 +#define TILE_K4 1 + +#define TILE_M 4 +#define TILE_N 4 +#define TILE_K 4 + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", STORAGE, is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "im2col_sizes")} +// Sizes of the output image +${layout_declare_ubo(B, "ivec4", "output_sizes")} +// Sizes of the input image +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float inv_scale; + int zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "conv2d_int8_output_tile_store.glslh" +#include "im2col_packed_int8_utils.glslh" + +void main() { + const int out_buf_idx = int(gl_GlobalInvocationID.x); + Conv2dBlockExtents im2col_block_extents = make_block_extents(im2col_sizes); + + Conv2dBlockIndex im2col_block_idx = linear_idx_to_block_idx( + out_buf_idx, im2col_block_extents); + + if (block_idx_out_of_bounds(im2col_block_idx, im2col_block_extents)) { + return; + } + + Im2ColBlockLoadIndices load_ixs = im2col_block_idx_to_load_ixs( + im2col_block_idx); + + Conv2dBlockExtents input_block_extents = make_block_extents(input_sizes); + + const ivec4 input_zps = ivec4(pack_into_int32(ivec4(zp))); + Int8OutTile int8_im2col_tile; + int8_im2col_tile.data[0][0] = load_im2col_block( + load_ixs, input_block_extents, zp, input_zps); + + store_packed_int8_output_tile( + int8_im2col_tile, im2col_block_idx, im2col_block_extents); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8.yaml b/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8.yaml new file mode 100644 index 00000000000..1c14f1fdc5a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8.yaml @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +im2col_packed_int8: + parameter_names_with_default_values: + STORAGE: buffer + generate_variant_forall: + STORAGE: + - VALUE: buffer + shader_variants: + - NAME: im2col_packed_int8 diff --git a/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8_utils.glslh b/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8_utils.glslh new file mode 100644 index 00000000000..2b1870c493d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8_utils.glslh @@ -0,0 +1,287 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef IM2COL_PACKED_INT8_GLSLH +#define IM2COL_PACKED_INT8_GLSLH + +#include "common.glslh" + +struct Conv2dBlockElementIndex { + int x4; + int y; + int z4; + + int row; + int col; +}; + +struct Im2ColBlockLoadIndices { + bool block_aligned; + bool cols_aligned; + bool rows_contiguous; + + int im2col_w_start; + int im2col_h; + int k_in_group_start; + int group_idx; + + Conv2dBlockElementIndex block_idx_start; +}; + +Conv2dBlockElementIndex tidx_to_block_elem_idx(const TensorIndex4D tidx) { + Conv2dBlockElementIndex block_idx; + block_idx.x4 = div_4(tidx.data.x); + block_idx.row = mod_4(tidx.data.x); + + block_idx.y = tidx.data.y; + + block_idx.z4 = div_4(tidx.data.z); + block_idx.col = mod_4(tidx.data.z); + + return block_idx; +} + +TensorIndex4D get_input_tensor_tidx( + const int w, + const int h, + const int k_in_group, + const int group_idx) { + TensorIndex4D tidx; + tidx.data.w = 0; + + const int c_in_group = k_in_group % conv2d_params.in_channels_per_group; + const int row = k_in_group / conv2d_params.in_channels_per_group; + const int kernel_x = row % conv2d_params.kernel_size.x; + const int kernel_y = row / conv2d_params.kernel_size.x; + + tidx.data.z = group_idx * conv2d_params.in_channels_per_group + c_in_group; + + tidx.data.x = (w * conv2d_params.stride.x) - conv2d_params.padding.x + + (kernel_x * conv2d_params.dilation.x); + tidx.data.y = (h * conv2d_params.stride.y) - conv2d_params.padding.y + + (kernel_y * conv2d_params.dilation.y); + + return tidx; +} + +Im2ColBlockLoadIndices im2col_block_idx_to_load_ixs( + Conv2dBlockIndex im2col_block_idx) { + const int im2col_w = mul_4(im2col_block_idx.data.x); + const int im2col_h = im2col_block_idx.data.y; + const int im2col_k = mul_4(im2col_block_idx.data.z); + + const int group_idx = im2col_k / conv2d_params.K_per_group; + const int k_in_group = im2col_k % conv2d_params.K_per_group; + + TensorIndex4D input_tidx = + get_input_tensor_tidx(im2col_w, im2col_h, k_in_group, group_idx); + + bool cols_aligned = (mod_4(input_tidx.data.z) == 0) && + (input_tidx.data.z + 3 < conv2d_params.in_channels_per_group); + + bool rows_aligned = mod_4(input_tidx.data.x) == 0; + bool rows_contiguous = conv2d_params.stride.x == 1; + + Im2ColBlockLoadIndices load_ixs; + load_ixs.block_aligned = cols_aligned && rows_aligned && rows_contiguous; + load_ixs.cols_aligned = cols_aligned; + load_ixs.rows_contiguous = rows_contiguous; + + load_ixs.im2col_w_start = im2col_w; + load_ixs.im2col_h = im2col_h; + load_ixs.k_in_group_start = k_in_group; + load_ixs.group_idx = group_idx; + + load_ixs.block_idx_start = tidx_to_block_elem_idx(input_tidx); + + return load_ixs; +} + +bool is_block_elem_idx_in_bounds( + const Conv2dBlockElementIndex idx, + const Conv2dBlockExtents block_extents) { + const ivec3 block_idx = ivec3(idx.x4, idx.y, idx.z4); + if (any(lessThan(block_idx, ivec3(0))) || + any(greaterThanEqual(block_idx, block_extents.data))) { + return false; + } + return true; +} + +int load_packed_int8_input_element( + const Conv2dBlockElementIndex idx, + const Conv2dBlockExtents block_extents, + const int input_zp) { + // bounds checking + if (!is_block_elem_idx_in_bounds(idx, block_extents)) { + return input_zp; + } +#ifdef PACKED_INT8_INPUT_BUFFER + const int buf_idx = + idx.y * block_extents.data_xz + idx.x4 * block_extents.data.z + idx.z4; + const ivec4 tile = t_packed_int8_input[buf_idx]; +#else + const ivec4 tile = + texelFetch(t_packed_int8_input, ivec3(idx.x4, idx.y, idx.z4), 0); +#endif + return extract_8bit_from_packed_int_le(tile[idx.row], idx.col); +} + +Conv2dBlockElementIndex get_packed_int8_input_element_idx( + const int im2col_w, + const int im2col_h, + const int k_in_group, + const int group_idx) { + TensorIndex4D input_tidx = + get_input_tensor_tidx(im2col_w, im2col_h, k_in_group, group_idx); + + return tidx_to_block_elem_idx(input_tidx); +} + +ivec4 load_im2col_block_aligned( + const Im2ColBlockLoadIndices load_ixs, + const Conv2dBlockExtents block_extents) { +#ifdef PACKED_INT8_INPUT_BUFFER + const int buf_idx = load_ixs.block_idx_start.y * block_extents.data_xz + + load_ixs.block_idx_start.x4 * block_extents.data.z + + load_ixs.block_idx_start.z4; + return t_packed_int8_input[buf_idx]; +#else + return texelFetch( + t_packed_int8_input, + ivec3( + load_ixs.block_idx_start.x4, + load_ixs.block_idx_start.y, + load_ixs.block_idx_start.z4), + 0); +#endif +} + +ivec4 load_im2col_block_c_aligned_w_contiguous( + const Im2ColBlockLoadIndices load_ixs, + const Conv2dBlockExtents block_extents, + const ivec4 input_zps) { + ivec4 im2col_block; + Conv2dBlockElementIndex block_elem_idx = load_ixs.block_idx_start; + +#ifdef PACKED_INT8_INPUT_BUFFER + int buf_idx = load_ixs.block_idx_start.y * block_extents.data_xz + + load_ixs.block_idx_start.x4 * block_extents.data.z + + load_ixs.block_idx_start.z4; +#endif + + ivec4 in_block = input_zps; + if (is_block_elem_idx_in_bounds(block_elem_idx, block_extents)) { +#ifdef PACKED_INT8_INPUT_BUFFER + in_block = t_packed_int8_input[buf_idx]; +#else + in_block = texelFetch( + t_packed_int8_input, + ivec3(block_elem_idx.x4, block_elem_idx.y, block_elem_idx.z4), + 0); +#endif + } + + int current_row = 0; + int r_limit = min(4 - block_elem_idx.row, 4); + for (int r = 0; r < r_limit; r++) { + im2col_block[current_row++] = in_block[r + block_elem_idx.row]; + } + + in_block = input_zps; + block_elem_idx.x4++; +#ifdef PACKED_INT8_INPUT_BUFFER + buf_idx += block_extents.data.z; +#endif + + if (is_block_elem_idx_in_bounds(block_elem_idx, block_extents)) { +#ifdef PACKED_INT8_INPUT_BUFFER + in_block = t_packed_int8_input[buf_idx]; +#else + in_block = texelFetch( + t_packed_int8_input, + ivec3(block_elem_idx.x4, block_elem_idx.y, block_elem_idx.z4), + 0); +#endif + } + + for (int r = 0; current_row < 4; ++r) { + im2col_block[current_row++] = in_block[r]; + } + + return im2col_block; +} + +ivec4 load_im2col_block_no_alignment( + const Im2ColBlockLoadIndices load_ixs, + const Conv2dBlockExtents block_extents, + const int input_zp) { + ivec4 im2col_block; + + for (int r = 0; r < 4; r++) { + const int im2col_w = load_ixs.im2col_w_start + r; + ivec4 row_values; + for (int c = 0; c < 4; c++) { + const int k_in_group = load_ixs.k_in_group_start + c; + + if (k_in_group >= conv2d_params.logical_K_per_group) { + row_values[c] = input_zp; + continue; + } + + Conv2dBlockElementIndex block_idx = get_packed_int8_input_element_idx( + im2col_w, load_ixs.im2col_h, k_in_group, load_ixs.group_idx); + + row_values[c] = + load_packed_int8_input_element(block_idx, block_extents, input_zp); + } + + im2col_block[r] = pack_into_int32(row_values); + } + return im2col_block; +} + +ivec4 load_im2col_block( + const Im2ColBlockLoadIndices load_ixs, + const Conv2dBlockExtents block_extents, + const int input_zp, + const ivec4 input_zps) { + if (load_ixs.cols_aligned && load_ixs.rows_contiguous) { + return load_im2col_block_c_aligned_w_contiguous( + load_ixs, block_extents, input_zps); + } + return load_im2col_block_no_alignment(load_ixs, block_extents, input_zp); +} + +#ifdef DEBUG_MODE + +void printLoadIndices(const Im2ColBlockLoadIndices load_ixs) { + debugPrintfEXT("LoadIndices: \\n"); + + if (load_ixs.block_aligned) { + debugPrintfEXT(" block_aligned \\n"); + } + if (load_ixs.cols_aligned) { + debugPrintfEXT(" cols_aligned \\n"); + } + if (load_ixs.rows_contiguous) { + debugPrintfEXT(" rows_contiguous \\n"); + } + + debugPrintfEXT( + " block_idx_start: %d %d %d || %d %d \\n", + load_ixs.block_idx_start.x4, + load_ixs.block_idx_start.y, + load_ixs.block_idx_start.z4, + load_ixs.block_idx_start.row, + load_ixs.block_idx_start.col); +} + +#endif + +#endif // IM2COL_PACKED_INT8_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh index a6dbd7e78a2..8f19418cd19 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh @@ -43,13 +43,6 @@ ivec4 quantize( return clamp(ivec4(quantized), -128, 127); } -int pack_into_int32(const ivec4 quant_vals) { - int packed = ((quant_vals[0] & 0xFF) << 0) | ((quant_vals[1] & 0xFF) << 8) | - ((quant_vals[2] & 0xFF) << 16) | ((quant_vals[3] & 0xFF) << 24); - - return packed; -} - void quantize_and_pack( out Int8InputBlock packed, const FPInputTile in_block, diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_output_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_output_tile.glslh new file mode 100644 index 00000000000..14aa6558bfc --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_output_tile.glslh @@ -0,0 +1,67 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Macro Settings: + * - TILE_M + * - TILE_N4 + */ + +#ifndef LINEAR_INT8_OUTPUT_TILE_GLSLH +#define LINEAR_INT8_OUTPUT_TILE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +struct Int8OutTile { + ivec4 data[TILE_M4][TILE_N4]; +}; + +void initialize(out Int8OutTile tile) { + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + tile.data[m4][n4] = ivec4(0); + } + } +} + +#ifdef DEBUG_MODE + +#include "linear_common.glslh" + +void printInt8OutTile(const Int8OutTile tile) { + debugPrintfEXT( + "Int8InputTile [TILE_M4=%d][TILE_N4=%d]:\\n", TILE_M4, TILE_N4); + + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT(" tile[%d][%d] (ivec4): ", m4, n4); + + // Each ivec4 contains 4 packed integers, each integer contains 4 8-bit + // values + [[unroll]] for (int vec_idx = 0; vec_idx < 4; ++vec_idx) { + int packed_int = tile.data[m4][n4][vec_idx]; + debugPrintfEXT("packed_int[%d]=%d -> [", vec_idx, packed_int); + + // Extract 4 8-bit values from this packed integer + [[unroll]] for (int byte_idx = 0; byte_idx < 4; ++byte_idx) { + int val = extract_8bit_from_packed_int_le(packed_int, byte_idx); + if (byte_idx < 3) { + debugPrintfEXT("%d, ", val); + } else { + debugPrintfEXT("%d] ", val); + } + } + } + debugPrintfEXT("\\n"); + } + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_INT8_OUTPUT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_output_tile_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_output_tile_compute.glslh new file mode 100644 index 00000000000..f909675984d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_output_tile_compute.glslh @@ -0,0 +1,117 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to compute a FPOutTile using int8 input and weight tiles. + * + * Settings: + * - TILE_M: The number of rows in the output tile. + * - TILE_N4: The number of (groups of 4) columns in the output tile. + */ + +#ifndef LINEAR_INT8_OUTPUT_TILE_INT8_INT8_COMPUTE_GLSLH +#define LINEAR_INT8_OUTPUT_TILE_INT8_INT8_COMPUTE_GLSLH + +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_integer_dot_product : require + +#include "linear_fp_per_out_channel_params.glslh" +#include "linear_int8_output_tile.glslh" +#include "linear_int_accumulator.glslh" +#include "linear_int_per_out_channel_params.glslh" + +void compute_int8_out_tile_with_int32_accum( + out Int8OutTile out_tile, + const Int32Accum accum, + const float input_q_scale, + const int input_q_zp, + const float output_q_inv_scale, + const int output_q_zp, + const IntPerOutChannelParams weight_sums, + const FPPerOutChannelParams weight_scales) { + ivec4 input_zp_vec = ivec4(-input_q_zp); + ivec4 output_zp_vec = ivec4(-output_q_zp); + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + [[unroll]] for (int m4i = 0; m4i < 4; ++m4i) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + const int m = mul_4(m4) + m4i; + // Compute floating point output values + ivec4 accum_adjusted = + input_zp_vec * weight_sums.data[n4] + accum.data[m][n4]; + vec4 float_out_texel = + vec4(accum_adjusted) * vec4(weight_scales.data[n4] * input_q_scale); + // Requantize to int8 + float_out_texel = + round(float_out_texel * output_q_inv_scale) + output_q_zp; + ivec4 quantized_out_texel = clamp(ivec4(float_out_texel), -128, 127); + + out_tile.data[m4][n4][m4i] = pack_into_int32(quantized_out_texel); + } + } + } +} + +void compute_int8_out_tile_with_int32_accum( + out Int8OutTile out_tile, + const Int32Accum accum, + const float input_q_scale, + const int input_q_zp, + const float output_q_inv_scale, + const int output_q_zp, + const IntPerOutChannelParams weight_sums, + const FPPerOutChannelParams weight_scales, + const FPPerOutChannelParams bias) { + ivec4 input_zp_vec = ivec4(-input_q_zp); + ivec4 output_zp_vec = ivec4(-output_q_zp); + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + [[unroll]] for (int m4i = 0; m4i < 4; ++m4i) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + const int m = mul_4(m4) + m4i; + // Compute floating point output values + ivec4 accum_adjusted = + input_zp_vec * weight_sums.data[n4] + accum.data[m][n4]; + vec4 float_out_texel = + fma(vec4(accum_adjusted), + vec4(weight_scales.data[n4]) * input_q_scale, + vec4(bias.data[n4])); + // Requantize to int8 + float_out_texel = + round(float_out_texel * output_q_inv_scale) + output_q_zp; + ivec4 quantized_out_texel = clamp(ivec4(float_out_texel), -128, 127); + + out_tile.data[m4][n4][m4i] = pack_into_int32(quantized_out_texel); + } + } + } +} + +// // overload of the above but with bias +// void accumulate_out_tile_with_int_accum( +// inout FPOutTile out_tile, +// const Int32Accum accum, +// const float input_q_scale, +// const int input_q_zp, +// const IntPerOutChannelParams weight_sums, +// const FPPerOutChannelParams weight_scales, +// const FPPerOutChannelParams bias) { +// ivec4 input_zp_vec = ivec4(-input_q_zp); +// [[unroll]] for (int m = 0; m < TILE_M; ++m) { +// [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { +// // Apply scale and zero points to the int accumulator +// ivec4 accum_adjusted = +// input_zp_vec * weight_sums.data[n4] + accum.data[m][n4]; +// out_tile.data[m][n4] = +// fma(VEC4_T(accum_adjusted), +// VEC4_T(input_q_scale * weight_scales.data[n4]), +// out_tile.data[m][n4]); +// out_tile.data[m][n4] += bias.data[n4]; +// } +// } +// } + +#endif // LINEAR_INT8_OUTPUT_TILE_INT8_INT8_COMPUTE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl index 0ad91643219..878821d4189 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl @@ -76,9 +76,6 @@ void main() { const int N4 = div_up_4(output_sizes.x); // number of texels in each row const int N8 = div_up_8(output_sizes.x); // number of texels in each row - bool should_print = (n8 == 0) && (m4 == 0); - should_print = false; - // VEC4_T out_texels[4][2]; FPOutTile out_tile; initialize(out_tile); diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_dw_weights.glsl b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_dw_weights.glsl new file mode 100644 index 00000000000..da4162b6e58 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_dw_weights.glsl @@ -0,0 +1,72 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +${define_active_storage_type(STORAGE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_packed_int8_weight", "int", STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_int8_weight", "int", "buffer")} + +layout(push_constant) uniform restrict Block { + ivec4 qmat2_sizes; + ivec3 orig_sizes; // [K_h, aligned_K_w, OC] +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "common.glslh" + +void main() { + // The size of the source weight tensor is [K_h, aligned_K_w, OC] for depthwise conv. + // Each shader invocation processes a 4x4 block of weights for a group of output channels. + const int oc4 = int(gl_GlobalInvocationID.x); + const int k4 = int(gl_GlobalInvocationID.y); + const int k = mul_4(k4); + + const int H = orig_sizes.x; + const int orig_W = orig_sizes.y; + const int W4 = div_up_4(orig_W); + const int OC = orig_sizes.z; + + const int h = k4 / W4; + const int w4 = k4 % W4; + const int w = mul_4(w4); + + // Determine the total number of blocks and check bounds + const int OC4 = div_up_4(OC); + const int K4 = H * W4; + + if (oc4 >= OC4 || k4 >= K4) { + return; + } + + ivec4 packed_block; + + int buf_idx = (h * orig_W + w) * OC4 + oc4; + int r_limit = min(4, orig_W - w); + [[unroll]] for (int r = 0; r < r_limit; r++) { + packed_block[r] = t_int8_weight[buf_idx]; + buf_idx += OC4; + } + [[unroll]] for (int r = r_limit; r < 4; r++) { + packed_block[r] = 0; + } + +#ifdef USING_BUFFER + t_packed_int8_weight[k4 * OC4 + oc4] = packed_block; +#else + imageStore(t_packed_int8_weight, ivec2(oc4, k4), packed_block); +#endif +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_dw_weights.yaml b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_dw_weights.yaml new file mode 100644 index 00000000000..9cfa3108ff0 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_dw_weights.yaml @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +pack_q8_conv2d_dw_weights: + parameter_names_with_default_values: + STORAGE: buffer + generate_variant_forall: + STORAGE: + - VALUE: buffer + - VALUE: texture2d + shader_variants: + - NAME: pack_q8_conv2d_dw_weights diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_weights.glsl b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_weights.glsl new file mode 100644 index 00000000000..e9982a8273d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_weights.glsl @@ -0,0 +1,82 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +${define_active_storage_type(STORAGE)} + +#extension GL_EXT_control_flow_attributes : require + +${define_required_extensions("int8")} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_packed_int8_weight", "int", STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_int8_weight", "int8", "buffer")} + +layout(push_constant) uniform restrict Block { + ivec4 qmat2_sizes; + ivec4 orig_sizes; // [OC, K_h, K_w, IC] +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "common.glslh" + +void main() { + const int block_x = int(gl_GlobalInvocationID.x); + const int block_y = int(gl_GlobalInvocationID.y); + + const int kx = block_x % orig_sizes.z; + const int oc4 = block_x / orig_sizes.z; + + const int OC4 = div_up_4(orig_sizes.x); + const int IC4 = div_up_4(orig_sizes.w); + + const int nblocks_x = orig_sizes.z * OC4; + const int nblocks_y = IC4 * orig_sizes.y; + + const int ic4 = block_y % IC4; + const int ky = block_y / IC4; + + if (block_x >= nblocks_x || block_y >= nblocks_y) { + return; + } + + const int oc = mul_4(oc4); + const int ic = mul_4(ic4); + + const int oc_stride = align_up_4(orig_sizes.y * orig_sizes.z * orig_sizes.w); + const int oc_offset = oc * oc_stride; + const int ky_offset = ky * (orig_sizes.z * orig_sizes.w); + const int kx_offset = kx * orig_sizes.w; + int buf_idx = oc_offset + ky_offset + kx_offset + ic; + + ivec4 packed_block = ivec4(0); + for (int row = 0; row < 4; row++) { + if (oc + row < orig_sizes.x) { + ivec4 weight_vals = ivec4(0); + for (int col = 0; col < 4; col++) { + if (ic + col < orig_sizes.w) { + weight_vals[col] = int(t_int8_weight[buf_idx + col]); + } + } + packed_block[row] = pack_into_int32(weight_vals); + } + buf_idx += oc_stride; + } + +#ifdef USING_BUFFER + const int out_buf_idx = block_y * (nblocks_x) + block_x; + t_packed_int8_weight[out_buf_idx] = packed_block; +#else + imageStore(t_packed_int8_weight, ivec2(block_x, block_y), packed_block); +#endif +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_weights.yaml b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_weights.yaml new file mode 100644 index 00000000000..9331de6e758 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_weights.yaml @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +pack_q8_conv2d_weights: + parameter_names_with_default_values: + STORAGE: buffer + generate_variant_forall: + STORAGE: + - VALUE: buffer + - VALUE: texture2d + shader_variants: + - NAME: pack_q8_conv2d_weights diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh index 03132db1348..1880397181d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh @@ -44,7 +44,6 @@ void load_k_cache_tile_no_checks( const int context_len, const int C, const int KV_H) { - bool should_print = d4_start == 0 && c_start == 0 && kv_h == 0; [[unroll]] for (int c = 0; c < TILE_N; ++c) { const int c4 = div_4(c); const int c4i = mod_4(c); diff --git a/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl b/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl index ed7dd25421a..798366b523a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl @@ -28,7 +28,6 @@ ${define_required_extensions(DTYPE)} layout(std430) buffer; -#define DEBUG_MODE #include "conv2d_common.glslh" ${layout_declare_tensor(B, "w", "t_fp_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} diff --git a/backends/vulkan/runtime/graph/ops/impl/Common.cpp b/backends/vulkan/runtime/graph/ops/impl/Common.cpp index 6c701224f7f..71690ffc604 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Common.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Common.cpp @@ -56,4 +56,27 @@ utils::uvec3 pick_hw_square_wg_size( return {16u, 4u, 1u}; } +utils::uvec3 pick_wc_square_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)graph; + (void)shader; + (void)args; + (void)resize_args; + // Some inactive invocations are okay; set 6 as the threshold to use the + // a square wg size. + if (global_workgroup_size[0u] >= 6 && global_workgroup_size[2u] >= 6) { + return {8u, 1u, 8u}; + } + // If channels dim is sufficiently small, then bias towards width dim to + // reduce the number of inactive invocations. + if (global_workgroup_size[2u] < 2u) { + return {64u, 1u, 1u}; + } + return {16u, 1u, 4u}; +} + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Common.h b/backends/vulkan/runtime/graph/ops/impl/Common.h index 1831ab2a845..b412f737c13 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Common.h +++ b/backends/vulkan/runtime/graph/ops/impl/Common.h @@ -54,4 +54,11 @@ utils::uvec3 pick_hw_square_wg_size( const std::vector& args, const std::vector& resize_args); +utils::uvec3 pick_wc_square_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args); + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp index f6eee4ba12e..fb55822619f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp @@ -13,12 +13,93 @@ #include #include +#include + namespace vkcompute { // // Utility functions // +bool is_pointwise(ComputeGraph* graph, const ValueRef& kernel_size) { + const auto kernel_size_list = graph->get_int_list(kernel_size); + return kernel_size_list->at(0) == 1 && kernel_size_list->at(1) == 1; +} + +bool is_s1p1d1( + ComputeGraph* graph, + const ValueRef& stride, + const ValueRef& padding, + const ValueRef& dilation) { + const auto stride_list = graph->get_int_list(stride); + const auto padding_list = graph->get_int_list(padding); + const auto dilation_list = graph->get_int_list(dilation); + if (stride_list->at(0) != 1 && stride_list->at(1) != 1) { + return false; + } + if (padding_list->at(0) != 1 && padding_list->at(1) != 1) { + return false; + } + if (dilation_list->at(0) != 1 && dilation_list->at(1) != 1) { + return false; + } + return true; +} + +bool is_s1p0d1_pointwise( + ComputeGraph* graph, + const ValueRef& kernel_size, + const ValueRef& stride, + const ValueRef& padding, + const ValueRef& dilation) { + if (is_pointwise(graph, kernel_size)) { + const auto stride_list = graph->get_int_list(stride); + const auto padding_list = graph->get_int_list(padding); + const auto dilation_list = graph->get_int_list(dilation); + if (stride_list->at(0) != 1 && stride_list->at(1) != 1) { + return false; + } + if (padding_list->at(0) != 0 && padding_list->at(1) != 0) { + return false; + } + if (dilation_list->at(0) != 1 && dilation_list->at(1) != 1) { + return false; + } + return true; + } + return false; +} + +bool should_use_im2col( + ComputeGraph* graph, + const ValueRef kernel_size, + const ValueRef groups) { + const auto kernel_size_list = graph->get_int_list(kernel_size); + + // Always use im2col for pointwise convolutions + if (kernel_size_list->at(0) * kernel_size_list->at(1) == 1) { + return true; + } + + // For large kernel sizes, the im2col matrix will be too big. Not only will + // this result in a larger footprint for the im2col matrix, but the cost of + // performing the im2col procedure will also become prohibitive. In these + // cases it is faster to just compute convolution directly without going + // through im2col. + if (kernel_size_list->at(0) * kernel_size_list->at(1) <= 10) { + const int64_t groups_val = graph->get_int(groups); + // Do not use im2col for grouped convolutions; manual experimentation shows + // that im2col becomes very slow when dealing with grouped convolutions. + // The reason for this is likely that memory access in the im2col shader + // becomes too non-linear due to needed to keep convolution groups + // contiguous in memory. + if (groups_val == 1) { + return true; + } + } + return false; +} + struct Conv2DParams { utils::ivec2 kernel_size; utils::ivec2 stride; @@ -135,6 +216,44 @@ std::vector calculate_input_im2col_sizes( return {M, K}; } +std::vector calculate_packed_int8_input_im2col_sizes( + ComputeGraph* graph, + const ValueRef& input, + const ValueRef& output, + const ValueRef& kernel_size, + const ValueRef& groups) { + std::vector in_sizes = graph->sizes_of(input); + const int64_t in_channels = utils::val_at(-3, in_sizes); + + std::vector out_sizes = graph->sizes_of(output); + // const int64_t batches = utils::val_at(-4, out_sizes); + const int64_t out_height = utils::val_at(-2, out_sizes); + const int64_t out_width = utils::val_at(-1, out_sizes); + + // Represents the number of channel groups + const int64_t groups_val = graph->extract_scalar(groups); + // No need to div_up because in_channels % groups_val = 0 + const int64_t in_channels_per_group = in_channels / groups_val; + + const auto kernel_size_list = graph->get_int_list(kernel_size); + + // Align to the next multiple of 4 to ensure that data loads align nicely with + // texel boundaries. We want to ensure that the first data element of each + // group is at the start of its texel. + const int64_t flattened_kernel_len = utils::align_up_4( + in_channels_per_group * kernel_size_list->at(0) * + kernel_size_list->at(1)); + + // K -> flattened convolution window (repeated for each group) + const int64_t K = flattened_kernel_len * groups_val; + // M -> number of elements in 2D output plane. This is aligned to the next + // multiple of 4 since the im2col shader operates on 4x4 blocks. + const int64_t W = utils::align_up_4(out_width); + const int64_t H = out_height; + + return {K, H, W}; +} + std::vector calculate_output_im2col_sizes( ComputeGraph* graph, const ValueRef& output) { @@ -212,6 +331,33 @@ utils::uvec3 im2col_global_wg_size( return {K4, M4, 1}; } +utils::uvec3 im2col_packed_int8_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef input_im2col = args.at(0).refs.at(0); + + std::vector im2col_sizes = graph->sizes_of(input_im2col); + const uint32_t K = utils::safe_downcast(im2col_sizes[0]); + const uint32_t H = utils::safe_downcast(im2col_sizes[1]); + const uint32_t W = utils::safe_downcast(im2col_sizes[2]); + + const uint32_t K4 = utils::div_up(K, 4u); + const uint32_t W4 = utils::div_up(W, 4u); + + return {K4 * W4 * H, 1, 1}; +} + +utils::uvec3 im2col_packed_int8_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + return {64, 1, 1}; +} + utils::uvec3 col2im_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, @@ -231,6 +377,229 @@ utils::uvec3 col2im_global_wg_size( return {N4, M4, 1}; } +utils::uvec3 pick_static_quantized_conv2d_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef packed_int8_output = args.at(0).refs.at(0); + + const uint32_t W = graph->size_at(-1, packed_int8_output); + const uint32_t H = graph->size_at(-2, packed_int8_output); + const uint32_t C = graph->size_at(-3, packed_int8_output); + + uint32_t C_per_tile = 4; + uint32_t W_per_tile = 4; + + if (shader.kernel_name.find("linear") != std::string::npos) { + C_per_tile = 8; + } + + const uint32_t num_W_tiles = utils::div_up(W, W_per_tile); + const uint32_t num_C_tiles = utils::div_up(C, C_per_tile); + + return {num_C_tiles, num_W_tiles, H}; +} + +utils::uvec3 pick_static_quantized_conv2d_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + return pick_hw_square_wg_size( + graph, shader, global_workgroup_size, args, resize_args); +} + +utils::uvec3 int8_conv2d_dw_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef packed_int8_output = args.at(0).refs.at(0); + + const uint32_t W = graph->size_at(-1, packed_int8_output); + const uint32_t H = graph->size_at(-2, packed_int8_output); + const uint32_t C = graph->size_at(-3, packed_int8_output); + + const uint32_t W4 = utils::div_up_4(W); + const uint32_t C4 = utils::div_up_4(C); + + return {C4 * W4 * H, 1, 1}; +} + +// +// Prepack nodes +// + +ValueRef prepack_quantized_conv2d_weight( + ComputeGraph& graph, + const QuantizationConfig& weight_quant_config, + const ValueRef weight_data, + const ValueRef input, + const ValueRef output, + const ValueRef groups, + const ValueRef kernel_size) { + VK_CHECK_COND(weight_quant_config.nbits == 8); + VK_CHECK_COND(weight_quant_config.is_symmetric); + + const int32_t groups_val = graph.get_int(groups); + + const int64_t OC = graph.size_at(-3, output); + const int64_t IC = graph.size_at(-3, input) / groups_val; + + int64_t K_h; + int64_t K_w; + + { + const auto kernel_size_list = graph.get_int_list(kernel_size); + K_h = kernel_size_list->at(0); + K_w = kernel_size_list->at(1); + } + + const int64_t num_blocks_OC = utils::div_up_4(OC); + const int64_t num_blocks_IC = utils::div_up_4(IC); + + const int64_t num_blocks_y = num_blocks_IC * K_h; + const int64_t num_blocks_x = K_w * num_blocks_OC; + + // The packed tensor arranges blocks as [OC_blocks * K_total, IC_blocks] + const int64_t output_height = num_blocks_y; + const int64_t output_width = num_blocks_x * 4; + + // Store the original sizes of the weight data to pass to the shader + utils::ivec4 orig_sizes = { + utils::safe_downcast(OC), + utils::safe_downcast(K_h), + utils::safe_downcast(K_w), + utils::safe_downcast(IC)}; + + std::vector packed_weight_sizes{output_height, output_width}; + + utils::StorageType storage_type = utils::kTexture2D; + uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); + if (output_width > max_extent * 4 || output_height > max_extent) { + storage_type = utils::kBuffer; + } + + ValueRef packed_weight = graph.add_tensor( + packed_weight_sizes, + vkcompute::vkapi::kInt, + storage_type, + utils::kWidthPacked); + + utils::uvec3 global_wg_size = { + utils::safe_downcast(num_blocks_x), + utils::safe_downcast(num_blocks_y), + 1u}; + + std::string kernel_name = "pack_q8_conv2d_weights"; + add_storage_type_suffix(kernel_name, storage_type); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + graph.create_local_wg_size(global_wg_size), + // Inputs and Outputs + weight_data, + packed_weight, + // UBOs + {}, + // Specialization Constants + {}, + // Push Constants + {graph.sizes_pc_of(packed_weight), + PushConstantDataInfo(&orig_sizes, sizeof(utils::ivec4))})); + + return packed_weight; +} + +ValueRef prepack_quantized_conv2d_dw_weight( + ComputeGraph& graph, + const QuantizationConfig& weight_quant_config, + const ValueRef weight_data, + const ValueRef kernel_size) { + VK_CHECK_COND(weight_quant_config.nbits == 8); + VK_CHECK_COND(weight_quant_config.is_symmetric); + + std::vector weight_orig_sizes = graph.sizes_of(weight_data); + const int64_t ndim = graph.dim_of(weight_data); + + // For depthwise convolution, expect weight layout [K_h, aligned_K_w, OC] + VK_CHECK_COND(ndim == 3); + int64_t K_h = weight_orig_sizes.at(0); + int64_t K_w = weight_orig_sizes.at(1); + int64_t aligned_K_w = utils::align_up_4(K_w); + int64_t OC = weight_orig_sizes.at(2); + + // The packing format packs the weight tensor into blocks of 4 output channels + // (OC) and 4 kernel elements (K_h * aligned_K_w) + int64_t OC_per_block = 4; + int64_t K_per_block = 4; + + // To figure out the size of the output tensor, determine the number of blocks + // along each dimension. + const int64_t total_K_elements = K_h * aligned_K_w; + const int64_t num_blocks_K = utils::div_up(total_K_elements, K_per_block); + const int64_t num_blocks_OC = utils::div_up(OC, OC_per_block); + + // The blocks are arranged in a transposed manner, such that the transposed + // weight block is indexed like packed_weights[k4][oc4] - this is to allow for + // optimal memory coalescing when computing the depthwise convolution. + int64_t output_height = num_blocks_K; + // The base dtype of the packed tensor is int32 (each int32 contains 4x 8bit + // values) and each block is represented as a ivec4. Therefore the width dim + // of the packed tensor is multiplied by 4. + int64_t output_width = num_blocks_OC * 4; + + // Store the original sizes of the weight data to pass to the shader + utils::ivec3 orig_sizes = { + utils::safe_downcast(K_h), + utils::safe_downcast(K_w), + utils::safe_downcast(OC)}; + + std::vector packed_weight_sizes{output_height, output_width}; + + utils::StorageType storage_type = utils::kTexture2D; + uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); + if (output_width > max_extent * 4 || output_height > max_extent) { + storage_type = utils::kBuffer; + } + + ValueRef packed_weight = graph.add_tensor( + packed_weight_sizes, + vkcompute::vkapi::kInt, + storage_type, + utils::kWidthPacked); + + utils::uvec3 global_wg_size = { + utils::safe_downcast(num_blocks_OC), + utils::safe_downcast(num_blocks_K), + 1u}; + + std::string kernel_name = "pack_q8_conv2d_dw_weights"; + add_storage_type_suffix(kernel_name, storage_type); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + graph.create_local_wg_size(global_wg_size), + // Inputs and Outputs + weight_data, + packed_weight, + // UBOs + {}, + // Specialization Constants + {}, + // Push Constants + {graph.sizes_pc_of(packed_weight), + PushConstantDataInfo(&orig_sizes, sizeof(utils::ivec3))})); + + return packed_weight; +} + // // Dispatch nodes // @@ -285,6 +654,57 @@ void add_input_im2col_node( nullptr)); } +void add_input_im2col_packed_int8_node( + ComputeGraph& graph, + const ValueRef input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const ValueRef output, + const ValueRef input_im2col) { + Conv2DParams conv_params = create_conv2d_params( + graph, input, output, kernel_size, stride, padding, dilation, groups); + + float inv_scale = 1.0f / graph.extract_scalar(input_scale); + int32_t zp = graph.extract_scalar(input_zp); + + std::string kernel_name = "im2col_packed_int8"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(input_im2col)); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(input_im2col), + graph.sizes_ubo(output), + graph.sizes_ubo(input), + graph.create_params_buffer(conv_params)}; + + std::vector push_constants = { + PushConstantDataInfo(&inv_scale, sizeof(inv_scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + im2col_packed_int8_global_wg_size, + im2col_packed_int8_local_wg_size, + // Inputs and Outputs + {{input_im2col, vkapi::kWrite}, {input, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize args + {}, + // Resizing Logic + nullptr)); +} + void add_quantize_and_pack_q8ta_conv2d_input_node( ComputeGraph& graph, const ValueRef fp_input, @@ -314,7 +734,7 @@ void add_quantize_and_pack_q8ta_conv2d_input_node( graph, VK_KERNEL_FROM_STR(kernel_name), pick_quantize_and_pack_conv2d_input_global_wg_size, - default_pick_local_wg_size, + pick_wc_square_wg_size, // Inputs and Outputs {{packed_int8_input, vkapi::kWrite}, {fp_input, vkapi::kRead}}, // Shader params buffers @@ -590,54 +1010,229 @@ void add_conv2d_q8ta_q8csw_linear_node( nullptr)); } -// -// High level operator impl -// - -void quantized_conv2d_impl( +void add_conv2d_q8ta_q8csw_q8to_node( ComputeGraph& graph, - const QuantizationConfig& input_quant_config, - const QuantizationConfig& weight_quant_config, - const ValueRef input_image, + const ValueRef packed_int8_input, + const ValueRef packed_int8_input_im2col, const ValueRef input_scale, const ValueRef input_zp, - const ValueRef weight_data, - const ValueRef weight_sums_data, - const ValueRef weight_scales_data, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef output_scale, + const ValueRef output_zp, const ValueRef bias_data, + const ValueRef packed_bias, const ValueRef kernel_size, const ValueRef stride, const ValueRef padding, const ValueRef dilation, const ValueRef groups, - const ValueRef output_image) { - VK_CHECK_COND(weight_quant_config.granularity == kPerChannel); - VK_CHECK_COND(weight_quant_config.nbits == 8); - VK_CHECK_COND(weight_quant_config.is_symmetric); + const ValueRef packed_int8_output) { + Conv2DParams conv_params = create_conv2d_params( + graph, + packed_int8_input, + packed_int8_output, + kernel_size, + stride, + padding, + dilation, + groups); - const ValueRef packed_weight = - prepack_quantized_linear_weight(graph, weight_quant_config, weight_data); - ValueRef packed_weight_scales = prepack_standard( - graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + const bool use_im2col = should_use_im2col(&graph, kernel_size, groups); - // Create a dummy tensor to fill the binding slot of the bias tensor if it is - // not provided. This helps simplify dispatch logic and makes it so that - // fewer shader variants need to be generated. - TmpTensor dummy_bias( - &graph, - {}, - graph.dtype_of(output_image), - utils::kBuffer, - utils::kWidthPacked); + float input_scale_val = graph.extract_scalar(input_scale); + int32_t input_zp_val = graph.extract_scalar(input_zp); - ValueRef packed_bias = dummy_bias.vref; - if (!graph.val_is_none(bias_data)) { - packed_bias = - prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked); - } + float output_inv_scale_val = 1.0f / graph.extract_scalar(output_scale); + int32_t output_zp_val = graph.extract_scalar(output_zp); - std::vector input_im2col_sizes = calculate_input_im2col_sizes( - &graph, input_image, output_image, kernel_size, groups); + std::string kernel_name = use_im2col ? "conv2d_q8ta_q8csw_q8to_linear_tiled" + : "conv2d_q8ta_q8csw_q8to"; + add_storage_type_suffix( + kernel_name, graph.storage_type_of(packed_int8_output)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight)); + add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales)); + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(packed_int8_output), + graph.sizes_ubo(packed_int8_input_im2col), + graph.create_params_buffer(conv_params)}; + + std::vector push_constants = { + PushConstantDataInfo(&input_scale_val, sizeof(input_scale_val)), + PushConstantDataInfo(&input_zp_val, sizeof(input_zp_val)), + PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)), + PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)), + }; + + uint32_t apply_bias = 1; + if (graph.val_is_none(bias_data)) { + apply_bias = 0; + } + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_static_quantized_conv2d_global_wg_size, + pick_static_quantized_conv2d_local_wg_size, + // Inputs and Outputs + {{packed_int8_output, vkapi::kWrite}, + {{packed_int8_input_im2col, + packed_weight, + packed_weight_sums, + packed_weight_scales, + packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {apply_bias}, + // Resize args + {}, + // Resizing Logic + nullptr)); +} + +void add_conv2d_dw_q8ta_q8csw_q8to_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef bias_data, + const ValueRef packed_bias, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const ValueRef packed_int8_output) { + Conv2DParams conv_params = create_conv2d_params( + graph, + packed_int8_input, + packed_int8_output, + kernel_size, + stride, + padding, + dilation, + groups); + + // Verify this is actually a depthwise convolution + const int64_t groups_val = graph.extract_scalar(groups); + const int64_t in_channels = graph.size_at(-3, packed_int8_input); + VK_CHECK_COND(groups_val == in_channels); + + float input_scale_val = graph.extract_scalar(input_scale); + int32_t input_zp_val = graph.extract_scalar(input_zp); + + float output_inv_scale_val = 1.0f / graph.extract_scalar(output_scale); + int32_t output_zp_val = graph.extract_scalar(output_zp); + + std::string kernel_name = "conv2d_dw_q8ta_q8csw_q8to"; + add_storage_type_suffix( + kernel_name, graph.storage_type_of(packed_int8_output)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight)); + add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales)); + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(packed_int8_output), + graph.sizes_ubo(packed_int8_input), + graph.create_params_buffer(conv_params)}; + + std::vector push_constants = { + PushConstantDataInfo(&input_scale_val, sizeof(input_scale_val)), + PushConstantDataInfo(&input_zp_val, sizeof(input_zp_val)), + PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)), + PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)), + }; + + uint32_t apply_bias = 1; + if (graph.val_is_none(bias_data)) { + apply_bias = 0; + } + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + int8_conv2d_dw_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{packed_int8_output, vkapi::kWrite}, + {{packed_int8_input, + packed_weight, + packed_weight_sums, + packed_weight_scales, + packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {apply_bias}, + // Resize args + {}, + // Resizing Logic + nullptr)); +} + +// +// High level operator impl +// + +void quantized_conv2d_impl( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const QuantizationConfig& weight_quant_config, + const ValueRef input_image, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef weight_data, + const ValueRef weight_sums_data, + const ValueRef weight_scales_data, + const ValueRef bias_data, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const ValueRef output_image) { + VK_CHECK_COND(weight_quant_config.granularity == kPerChannel); + VK_CHECK_COND(weight_quant_config.nbits == 8); + VK_CHECK_COND(weight_quant_config.is_symmetric); + + const ValueRef packed_weight = + prepack_quantized_linear_weight(graph, weight_quant_config, weight_data); + ValueRef packed_weight_scales = prepack_standard( + graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + + // Create a dummy tensor to fill the binding slot of the bias tensor if it is + // not provided. This helps simplify dispatch logic and makes it so that + // fewer shader variants need to be generated. + TmpTensor dummy_bias( + &graph, + {}, + graph.dtype_of(output_image), + utils::kBuffer, + utils::kWidthPacked); + + ValueRef packed_bias = dummy_bias.vref; + if (!graph.val_is_none(bias_data)) { + packed_bias = + prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked); + } + + std::vector input_im2col_sizes = calculate_input_im2col_sizes( + &graph, input_image, output_image, kernel_size, groups); // Use weight only quantized conv2d if at least one is true: // 1. Device does not support int8 dot product @@ -805,10 +1400,244 @@ void conv2d_q8csw(ComputeGraph& graph, const std::vector& args) { output_image); } +// Implementation for statically quantized conv2d, which expects input, weight, +// and output tensors to all have packed int8 dtype/memory layout. +void static_quantized_conv2d_impl( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const QuantizationConfig& weight_quant_config, + const QuantizationConfig& output_quant_config, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef weight_data, + const ValueRef weight_sums_data, + const ValueRef weight_scales_data, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef bias_data, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const ValueRef packed_int8_output) { + // Currently, only certain quantization configs are supported + VK_CHECK_COND(input_quant_config.granularity == kPerTensor); + VK_CHECK_COND(input_quant_config.nbits == 8); + + VK_CHECK_COND(weight_quant_config.granularity == kPerChannel); + VK_CHECK_COND(weight_quant_config.nbits == 8); + VK_CHECK_COND(weight_quant_config.is_symmetric); + + VK_CHECK_COND(output_quant_config.granularity == kPerTensor); + VK_CHECK_COND(output_quant_config.nbits == 8); + + // Check for depthwise conv + const int64_t groups_val = graph.extract_scalar(groups); + const int64_t in_channels = graph.size_at(-3, packed_int8_input); + + // Depthwise convs have a specialized implementation, since the regular conv + // implementations requires that the number of input and output channels per + // groups is a multiple of 4. This is so that all values that are part of the + // same 4Wx4C block have the same group index. + const bool is_depthwise = (groups_val == in_channels); + + const bool use_im2col = should_use_im2col(&graph, kernel_size, groups); + // For pointwise convolution with stride = 1, padding = 0, dilation = 1, the + // input tensor is already equivalent to its im2col representation. In this + // case we can skip the im2col procedure and pass in the input image to the + // convolution_as_matmul implementation directly. + const bool is_optimizable_pw = + is_s1p0d1_pointwise(&graph, kernel_size, stride, padding, dilation); + + ValueRef packed_weight; + if (is_depthwise) { + packed_weight = prepack_quantized_conv2d_dw_weight( + graph, weight_quant_config, weight_data, kernel_size); + } else if (use_im2col) { + packed_weight = prepack_quantized_linear_weight( + graph, weight_quant_config, weight_data); + } else { + packed_weight = prepack_quantized_conv2d_weight( + graph, + weight_quant_config, + weight_data, + packed_int8_input, + packed_int8_output, + groups, + kernel_size); + } + + ValueRef packed_weight_sums = prepack_standard( + graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked); + + ValueRef packed_weight_scales = prepack_standard( + graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + + // See quantized_conv2d_impl for why this is needed + TmpTensor dummy_bias( + &graph, + {}, + graph.dtype_of(weight_scales_data), + utils::kBuffer, + utils::kWidthPacked); + + ValueRef packed_bias = dummy_bias.vref; + if (graph.val_is_not_none(bias_data)) { + packed_bias = + prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked); + } + + // Depthwise conv path + if (is_depthwise) { + add_conv2d_dw_q8ta_q8csw_q8to_node( + graph, + packed_int8_input, + input_scale, + input_zp, + packed_weight, + packed_weight_sums, + packed_weight_scales, + output_scale, + output_zp, + bias_data, + packed_bias, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output); + return; + } + + std::vector input_im2col_sizes = + calculate_packed_int8_input_im2col_sizes( + &graph, packed_int8_input, packed_int8_output, kernel_size, groups); + + ValueRef packed_int8_input_im2col = packed_int8_input; + if (use_im2col && !is_optimizable_pw) { + TmpTensor packed_int8_input_im2col_tensor( + &graph, + input_im2col_sizes, + vkapi::kInt8x4, + utils::kBuffer, + utils::kPackedInt8_4W4C); + + packed_int8_input_im2col = packed_int8_input_im2col_tensor.vref; + + add_input_im2col_packed_int8_node( + graph, + packed_int8_input, + input_scale, + input_zp, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output, + packed_int8_input_im2col); + } + + add_conv2d_q8ta_q8csw_q8to_node( + graph, + packed_int8_input, + packed_int8_input_im2col, + input_scale, + input_zp, + packed_weight, + packed_weight_sums, + packed_weight_scales, + output_scale, + output_zp, + bias_data, + packed_bias, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output); +} + +void conv2d_q8ta_q8csw_q8to( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef packed_int8_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef kernel_size = args.at(idx++); + const ValueRef stride = args.at(idx++); + const ValueRef padding = args.at(idx++); + const ValueRef dilation = args.at(idx++); + const ValueRef groups = args.at(idx++); + const ValueRef packed_int8_output = args.at(idx++); + + QuantizationConfig input_quant_config(8, kPerTensor, {}); + QuantizationConfig weight_quant_config(8, kPerChannel, {}); + QuantizationConfig output_quant_config(8, kPerTensor, {}); + + static_quantized_conv2d_impl( + graph, + input_quant_config, + weight_quant_config, + output_quant_config, + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output); +} + // // Quantize and dequantize operators // +void quantize_q8ta_for_conv2d( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef scale = args.at(idx++); + const ValueRef zero_point = args.at(idx++); + const ValueRef packed_int8_input = args.at(idx++); + + add_quantize_and_pack_q8ta_conv2d_input_node( + graph, fp_input, scale, zero_point, packed_int8_input); +} + +void dequantize_q8to_from_conv2d( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef packed_int8_output = args.at(idx++); + const ValueRef scale = args.at(idx++); + const ValueRef zero_point = args.at(idx++); + const ValueRef fp_output = args.at(idx++); + + add_unpack_and_dequantize_q8ta_conv2d_output_node( + graph, packed_int8_output, scale, zero_point, fp_output); +} + void qdq8ta_conv2d_input( ComputeGraph& graph, const std::vector& args) { @@ -832,10 +1661,82 @@ void qdq8ta_conv2d_input( graph, packed_int8_input, scale, zero_point, fp_output); } +// +// Test operators +// + +void conv2d_q8ta_q8csw_q8to_test( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef kernel_size = args.at(idx++); + const ValueRef stride = args.at(idx++); + const ValueRef padding = args.at(idx++); + const ValueRef dilation = args.at(idx++); + const ValueRef groups = args.at(idx++); + const ValueRef fp_output = args.at(idx++); + + TmpTensor packed_int8_input( + &graph, + graph.sizes_of(fp_input), + vkapi::kInt8x4, + utils::kBuffer, + utils::kPackedInt8_4W4C); + + TmpTensor packed_int8_output( + &graph, + graph.sizes_of(fp_output), + vkapi::kInt8x4, + utils::kBuffer, + utils::kPackedInt8_4W4C); + + add_quantize_and_pack_q8ta_conv2d_input_node( + graph, fp_input, input_scale, input_zp, packed_int8_input); + + std::vector conv2d_args = { + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output}; + + conv2d_q8ta_q8csw_q8to(graph, conv2d_args); + + add_unpack_and_dequantize_q8ta_conv2d_output_node( + graph, packed_int8_output, output_scale, output_zp, fp_output); +} + REGISTER_OPERATORS { VK_REGISTER_OP(et_vk.conv2d_q8ta_q8csw.default, conv2d_q8ta_q8csw); VK_REGISTER_OP(et_vk.conv2d_q8csw.default, conv2d_q8csw); VK_REGISTER_OP(etvk.qdq8ta_conv2d_input.default, qdq8ta_conv2d_input); + VK_REGISTER_OP(etvk.conv2d_q8ta_q8csw_q8to.test, conv2d_q8ta_q8csw_q8to_test); + VK_REGISTER_OP( + et_vk.quantize_q8ta_for_conv2d.default, quantize_q8ta_for_conv2d); + VK_REGISTER_OP( + et_vk.dequantize_q8to_from_conv2d.default, dequantize_q8to_from_conv2d); + VK_REGISTER_OP(et_vk.conv2d_q8ta_q8csw_q8to.default, conv2d_q8ta_q8csw_q8to); + VK_REGISTER_OP( + et_vk.conv2d_q8ta_q8csw_q8to_dw.default, conv2d_q8ta_q8csw_q8to); } } // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/conv2d_utils.cpp b/backends/vulkan/test/custom_ops/conv2d_utils.cpp new file mode 100644 index 00000000000..74c26cef5a1 --- /dev/null +++ b/backends/vulkan/test/custom_ops/conv2d_utils.cpp @@ -0,0 +1,10 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "conv2d_utils.h" + +// Implementation file for conv2d utilities. +// Currently all functionality is implemented inline in the header. diff --git a/backends/vulkan/test/custom_ops/conv2d_utils.h b/backends/vulkan/test/custom_ops/conv2d_utils.h new file mode 100644 index 00000000000..cad52219062 --- /dev/null +++ b/backends/vulkan/test/custom_ops/conv2d_utils.h @@ -0,0 +1,88 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include + +namespace executorch { +namespace vulkan { +namespace prototyping { + +// Component structs for better readability +struct KernelSize { + int32_t h; + int32_t w; + + KernelSize(int32_t height, int32_t width) : h(height), w(width) {} +}; + +struct Stride { + int32_t h; + int32_t w; + + Stride(int32_t height, int32_t width) : h(height), w(width) {} +}; + +struct Padding { + int32_t h; + int32_t w; + + Padding(int32_t height, int32_t width) : h(height), w(width) {} +}; + +struct Dilation { + int32_t h; + int32_t w; + + Dilation(int32_t height = 1, int32_t width = 1) : h(height), w(width) {} +}; + +struct OutInChannels { + int32_t out; + int32_t in; + + OutInChannels(int32_t out_channels, int32_t in_channels) + : out(out_channels), in(in_channels) {} +}; + +struct InputSize2D { + int32_t h; + int32_t w; + + InputSize2D(int32_t height, int32_t width) : h(height), w(width) {} +}; + +// Conv2d configuration struct +struct Conv2dConfig { + OutInChannels channels; + InputSize2D input_size; + KernelSize kernel; + Stride stride; + Padding padding; + Dilation dilation; + int32_t groups; // Number of groups for grouped convolution + std::string test_case_name = "placeholder"; + std::string op_name = "conv2d"; + + // Calculate output dimensions + int64_t get_output_height() const { + return (input_size.h + 2 * padding.h - dilation.h * (kernel.h - 1) - 1) / + stride.h + + 1; + } + + int64_t get_output_width() const { + return (input_size.w + 2 * padding.w - dilation.w * (kernel.w - 1) - 1) / + stride.w + + 1; + } +}; + +} // namespace prototyping +} // namespace vulkan +} // namespace executorch diff --git a/backends/vulkan/test/custom_ops/q8csw_conv2d.cpp b/backends/vulkan/test/custom_ops/q8csw_conv2d.cpp index d566e5b2646..219bccb04c3 100644 --- a/backends/vulkan/test/custom_ops/q8csw_conv2d.cpp +++ b/backends/vulkan/test/custom_ops/q8csw_conv2d.cpp @@ -8,6 +8,7 @@ #include #include #include +#include "conv2d_utils.h" #include "utils.h" #include @@ -18,76 +19,6 @@ using namespace vkcompute; static constexpr int64_t kRefDimSizeLimit = 100; -// Component structs for better readability -struct KernelSize { - int32_t h; - int32_t w; - - KernelSize(int32_t height, int32_t width) : h(height), w(width) {} -}; - -struct Stride { - int32_t h; - int32_t w; - - Stride(int32_t height, int32_t width) : h(height), w(width) {} -}; - -struct Padding { - int32_t h; - int32_t w; - - Padding(int32_t height, int32_t width) : h(height), w(width) {} -}; - -struct Dilation { - int32_t h; - int32_t w; - - Dilation(int32_t height = 1, int32_t width = 1) : h(height), w(width) {} -}; - -struct OutInChannels { - int32_t out; - int32_t in; - - OutInChannels(int32_t out_channels, int32_t in_channels) - : out(out_channels), in(in_channels) {} -}; - -struct InputSize2D { - int32_t h; - int32_t w; - - InputSize2D(int32_t height, int32_t width) : h(height), w(width) {} -}; - -// Conv2d configuration struct -struct Conv2dConfig { - OutInChannels channels; - InputSize2D input_size; - KernelSize kernel; - Stride stride; - Padding padding; - Dilation dilation; - int32_t groups; // Number of groups for grouped convolution - std::string test_case_name = "placeholder"; - std::string op_name = "conv2d_q8ta_q8csw"; - - // Calculate output dimensions - int64_t get_output_height() const { - return (input_size.h + 2 * padding.h - dilation.h * (kernel.h - 1) - 1) / - stride.h + - 1; - } - - int64_t get_output_width() const { - return (input_size.w + 2 * padding.w - dilation.w * (kernel.w - 1) - 1) / - stride.w + - 1; - } -}; - // Utility function to create a test case from a Conv2dConfig TestCase create_test_case_from_config( const Conv2dConfig& config, @@ -366,13 +297,20 @@ std::vector generate_quantized_conv2d_test_cases() { Stride(1, 1), Padding(1, 1), Dilation(1, 1), - 8}, + 1}, {OutInChannels(128, 64), InputSize2D(128, 128), KernelSize(3, 3), Stride(1, 1), Padding(1, 1), Dilation(1, 1), + 1}, + {OutInChannels(128, 1024), + InputSize2D(128, 128), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), 1}}; // Test with different storage types and data types @@ -394,6 +332,7 @@ std::vector generate_quantized_conv2d_test_cases() { std::to_string(config.kernel.h) + "/" + std::to_string(config.kernel.w); + config.op_name = "conv2d_q8ta_q8csw"; config.test_case_name = prefix + suffix; // The default operator tested is activation + weight quantized conv2d; // however, only test this if the int8 dot product extension is supported @@ -763,7 +702,7 @@ int64_t quantized_conv2d_flop_calculator(const TestCase& test_case) { int main(int argc, char* argv[]) { set_debugging(false); set_print_output(false); - set_print_latencies(false); + set_print_latencies(true); set_use_gpu_timestamps(true); print_performance_header(); diff --git a/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp b/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp new file mode 100644 index 00000000000..8762fe4c0d1 --- /dev/null +++ b/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp @@ -0,0 +1,628 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include "conv2d_utils.h" +#include "utils.h" + +#include + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 100; + +// Utility function to create a test case from a Conv2dConfig +TestCase create_test_case_from_config( + const Conv2dConfig& config, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype) { + TestCase test_case; + + // Create a descriptive name for the test case + std::string storage_str = + (storage_type == utils::kTexture3D) ? "Texture3D" : "Buffer"; + std::string dtype_str = (input_dtype == vkapi::kFloat) ? "Float" : "Half"; + + std::string test_name = + config.test_case_name + "_" + storage_str + "_" + dtype_str; + test_case.set_name(test_name); + + // Set the operator name for the test case + std::string operator_name = "etvk." + config.op_name + ".test"; + test_case.set_operator_name(operator_name); + + // Calculate output dimensions + int64_t H_out = config.get_output_height(); + int64_t W_out = config.get_output_width(); + + // Input tensor (float/half) - [1, C_in, H_in, W_in] (batch size always 1) + std::vector input_size = { + 1, config.channels.in, config.input_size.h, config.input_size.w}; + + ValueSpec input_tensor( + input_size, + input_dtype, + storage_type, + utils::kChannelsPacked, + DataGenType::RANDOM); + + if (debugging()) { + print_valuespec_data(input_tensor, "input_tensor"); + } + + float input_scale_val = 0.008123; + ValueSpec input_scale(input_scale_val); + + int32_t input_zero_point_val = 2; + ValueSpec input_zero_point(input_zero_point_val); + + // Quantized weight tensor (int8) - [C_out, C_in_per_group * K_h * K_w] + // Memory layout: height, width, then channels - in_c is innermost (stride 1) + // in the second dimension + const int64_t in_channels_per_group = config.channels.in / config.groups; + const int64_t in_features = utils::align_up_4( + in_channels_per_group * config.kernel.h * config.kernel.w); + std::vector weight_size = {config.channels.out, in_features}; + ValueSpec quantized_weight( + weight_size, + vkapi::kChar, // int8 for quantized weights + storage_type, + utils::kWidthPacked, + DataGenType::RANDINT8); + quantized_weight.set_constant(true); + + if (debugging()) { + print_valuespec_data(quantized_weight, "weight_tensor"); + } + + const int64_t aligned_out_channels = utils::align_up_4(config.channels.out); + + // Weight quantization scales (float/half, per-channel) + ValueSpec weight_scales( + {aligned_out_channels}, // Per output channel + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM_SCALES); + weight_scales.set_constant(true); + + ValueSpec weight_sums( + {aligned_out_channels}, // Per output channel + vkapi::kInt, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + weight_sums.set_constant(true); + + // Compute weight_sums data based on quantized weights + compute_weight_sums( + weight_sums, quantized_weight, config.channels.out, in_features); + + // Bias (optional, float/half) - [C_out] + ValueSpec bias( + {aligned_out_channels}, // Per output channel + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + bias.set_constant(true); + + // Output quantization parameters + // float output_scale_val = 0.01432; + float output_scale_val = 0.05314; + ValueSpec output_scale(output_scale_val); + + int32_t output_zero_point_val = -1; + ValueSpec output_zero_point(output_zero_point_val); + + // Stride and padding parameters + ValueSpec stride({config.stride.h, config.stride.w}); + ValueSpec padding({config.padding.h, config.padding.w}); + + // Dilation and groups parameters + ValueSpec dilation({config.dilation.h, config.dilation.w}); + ValueSpec groups(config.groups); + + // Kernel size parameters + ValueSpec kernel_size({config.kernel.h, config.kernel.w}); + + // Output tensor (float/half) - [1, C_out, H_out, W_out] (batch size always 1) + ValueSpec output( + {1, config.channels.out, H_out, W_out}, + input_dtype, + storage_type, + utils::kChannelsPacked, + DataGenType::ZEROS); + + // Add all specs to test case for q8ta_q8csw_q8to operation + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zero_point); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_sums); + test_case.add_input_spec(weight_scales); + test_case.add_input_spec(output_scale); + test_case.add_input_spec(output_zero_point); + test_case.add_input_spec(bias); + test_case.add_input_spec(kernel_size); + test_case.add_input_spec(stride); + test_case.add_input_spec(padding); + test_case.add_input_spec(dilation); + test_case.add_input_spec(groups); + + test_case.add_output_spec(output); + + test_case.set_abs_tolerance(output_scale_val + 1e-4f); + + return test_case; +} + +// Generate easy test cases for quantized conv2d operation (for debugging) +std::vector generate_quantized_conv2d_easy_cases() { + std::vector test_cases; + + // Single simple configuration for debugging + Conv2dConfig config = { + OutInChannels(16, 8), // channels (out, in) + InputSize2D(21, 17), // input_size (h, w) + KernelSize(3, 3), // kernel + Stride(1, 1), // stride + Padding(1, 1), // padding + Dilation(1, 1), // dilation + 2, // groups + }; + config.op_name = "conv2d_q8ta_q8csw_q8to"; + + // Test with both storage types and data types for completeness + std::vector storage_types = {utils::kTexture3D}; + std::vector float_types = {vkapi::kFloat}; + + // Generate test cases for each combination + for (const auto& storage_type : storage_types) { + for (const auto& input_dtype : float_types) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, input_dtype)); + } + } + + return test_cases; +} + +// Generate test cases for quantized conv2d operation +std::vector generate_quantized_conv2d_test_cases() { + std::vector test_cases; + + std::vector configs = { + // Pointwise convolutions: kernel size 1x1 + {OutInChannels(32, 3), + InputSize2D(64, 64), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + {OutInChannels(64, 32), + InputSize2D(32, 32), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + {OutInChannels(96, 64), + InputSize2D(16, 16), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + {OutInChannels(13, 7), + InputSize2D(57, 33), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + // General 2D convolutions + {OutInChannels(32, 3), + InputSize2D(64, 64), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(32, 3), + InputSize2D(64, 64), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(64, 32), + InputSize2D(8, 8), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(64, 32), + InputSize2D(64, 64), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(64, 32), + InputSize2D(64, 64), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(16, 32), + InputSize2D(77, 77), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + // Grouped convolutions + {OutInChannels(64, 32), + InputSize2D(64, 64), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 2}, + {OutInChannels(96, 96), + InputSize2D(81, 81), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 3}, + {OutInChannels(96, 96), + InputSize2D(64, 64), + KernelSize(5, 5), + Stride(2, 2), + Padding(2, 2), + Dilation(1, 1), + 4}, + // Performance cases (pointwise) + {OutInChannels(128, 128), + InputSize2D(128, 128), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + {OutInChannels(128, 128), + InputSize2D(128, 128), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + // Performance cases (general 2d convs) + {OutInChannels(32, 3), + InputSize2D(256, 256), + KernelSize(3, 3), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + {OutInChannels(64, 32), + InputSize2D(128, 128), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(64, 64), + InputSize2D(128, 128), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(128, 128), + InputSize2D(128, 128), + KernelSize(5, 5), + Stride(2, 2), + Padding(2, 2), + Dilation(1, 1), + 4}}; + + // Test with different storage types and data types + std::vector storage_types = {utils::kTexture3D}; + + // Generate test cases for each combination + for (auto& config : configs) { + for (const auto& storage_type : storage_types) { + // Generate test case name programmatically + bool is_performance = config.channels.out > kRefDimSizeLimit || + config.channels.in > kRefDimSizeLimit || + config.input_size.h > kRefDimSizeLimit || + config.input_size.w > kRefDimSizeLimit; + std::string prefix = is_performance ? "performance_" : "correctness_"; + std::string suffix = std::to_string(config.channels.out) + "/" + + std::to_string(config.channels.in) + "_" + + std::to_string(config.input_size.h) + "/" + + std::to_string(config.input_size.w) + "_" + + std::to_string(config.kernel.h) + "/" + + std::to_string(config.kernel.w); + + config.op_name = "conv2d_q8ta_q8csw_q8to"; + config.test_case_name = prefix + suffix; + + // Only test q8ta_q8csw_q8to if the int8 dot product extension is + // supported + if (vkcompute::api::context() + ->adapter_ptr() + ->supports_int8_dot_product()) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, vkapi::kFloat)); + } + } + } + + return test_cases; +} + +// Reference implementation for activation, weight, and output quantized conv2d +void conv2d_q8ta_q8csw_q8to_reference_impl(TestCase& test_case) { + // Extract input specifications + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_sums_spec = test_case.inputs()[idx++]; + (void)weight_sums_spec; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& output_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& output_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + const ValueSpec& kernel_size_spec = test_case.inputs()[idx++]; + const ValueSpec& stride_spec = test_case.inputs()[idx++]; + const ValueSpec& padding_spec = test_case.inputs()[idx++]; + const ValueSpec& dilation_spec = test_case.inputs()[idx++]; + const ValueSpec& groups_spec = test_case.inputs()[idx++]; + + // Extract output specification (mutable reference) + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions + auto input_sizes = input_spec.get_tensor_sizes(); // [N, C_in, H_in, W_in] + auto weight_sizes = + weight_spec.get_tensor_sizes(); // [C_out, C_in_per_group * K_h * K_w] + auto output_sizes = + output_spec.get_tensor_sizes(); // [N, C_out, H_out, W_out] + + int64_t N = input_sizes[0]; + int64_t C_in = input_sizes[1]; + int64_t H_in = input_sizes[2]; + int64_t W_in = input_sizes[3]; + int64_t C_out = output_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + // Get kernel dimensions from kernel_size ValueSpec + auto kernel_size_data = kernel_size_spec.get_int32_data(); + int64_t K_h = kernel_size_data[0]; + int64_t K_w = kernel_size_data[1]; + + // Get stride, padding, dilation, and groups + auto stride_data = stride_spec.get_int32_data(); + auto padding_data = padding_spec.get_int32_data(); + auto dilation_data = dilation_spec.get_int32_data(); + int64_t stride_h = stride_data[0]; + int64_t stride_w = stride_data[1]; + int64_t pad_h = padding_data[0]; + int64_t pad_w = padding_data[1]; + int64_t dilation_h = dilation_data[0]; + int64_t dilation_w = dilation_data[1]; + int64_t groups = groups_spec.get_int_value(); + + // Skip for large tensors since computation time will be extremely slow + if (N > kRefDimSizeLimit || C_in > kRefDimSizeLimit || + H_in > kRefDimSizeLimit || W_in > kRefDimSizeLimit || + C_out > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + const float input_scale = input_scale_spec.get_float_value(); + const int32_t input_zero_point = input_zeros_spec.get_int_value(); + + auto& weight_data = weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + const float output_scale = output_scale_spec.get_float_value(); + const int32_t output_zero_point = output_zeros_spec.get_int_value(); + + // Calculate channels per group for grouped convolution + int64_t C_in_per_group = C_in / groups; + int64_t C_out_per_group = C_out / groups; + + // Calculate number of output elements + int64_t num_output_elements = N * C_out * H_out * W_out; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + const int in_features = utils::align_up_4(C_in_per_group * K_h * K_w); + + // Perform activation, weight, and output quantized conv2d operation + for (int64_t n = 0; n < N; ++n) { + for (int64_t out_c = 0; out_c < C_out; ++out_c) { + for (int64_t out_h = 0; out_h < H_out; ++out_h) { + for (int64_t out_w = 0; out_w < W_out; ++out_w) { + int32_t int_sum = 0; + int32_t weight_sum = 0; // Track weight sum on the fly + + // Determine which group this output channel belongs to + int64_t group_idx = out_c / C_out_per_group; + int64_t in_c_start = group_idx * C_in_per_group; + int64_t in_c_end = (group_idx + 1) * C_in_per_group; + + // Convolution operation with integer accumulation + for (int64_t in_c = in_c_start; in_c < in_c_end; ++in_c) { + for (int64_t kh = 0; kh < K_h; ++kh) { + for (int64_t kw = 0; kw < K_w; ++kw) { + // Calculate input position with dilation + int64_t in_h = out_h * stride_h - pad_h + kh * dilation_h; + int64_t in_w = out_w * stride_w - pad_w + kw * dilation_w; + + // Check bounds (zero padding) + if (in_h >= 0 && in_h < H_in && in_w >= 0 && in_w < W_in) { + // Get input value and quantize to int8 + int64_t input_idx = n * (C_in * H_in * W_in) + + in_c * (H_in * W_in) + in_h * W_in + in_w; + + float quant_input_f = + std::round(input_data[input_idx] / input_scale) + + input_zero_point; + quant_input_f = + std::min(std::max(quant_input_f, -128.0f), 127.0f); + int8_t quantized_input = static_cast(quant_input_f); + + // Get quantized weight (already int8) + // Weight layout: [C_out, C_in_per_group * K_h * K_w] + int64_t weight_idx = out_c * in_features + + (kh * (K_w * C_in_per_group) + kw * C_in_per_group + + (in_c % C_in_per_group)); + int8_t quantized_weight = weight_data[weight_idx]; + + // Integer multiplication and accumulation + int_sum += static_cast(quantized_input) * + static_cast(quantized_weight); + + // Track weight sum for this output channel on the fly + weight_sum += static_cast(quantized_weight); + } else { + // For zero padding, we still need to account for the weight + // in weight_sum when input is effectively 0 (but quantized 0 + // is input_zero_point) + int64_t weight_idx = out_c * in_features + + (kh * (K_w * C_in_per_group) + kw * C_in_per_group + + (in_c % C_in_per_group)); + int8_t quantized_weight = weight_data[weight_idx]; + + // Add contribution from zero-padded input (quantized zero = + // input_zero_point) + int_sum += static_cast(input_zero_point) * + static_cast(quantized_weight); + + // Track weight sum for this output channel on the fly + weight_sum += static_cast(quantized_weight); + } + } + } + } + + // Convert accumulated integer result to float and apply scales + // Final result = (int_sum - zero_point_correction) * input_scale * + // weight_scale + bias zero_point_correction = input_zero_point * + // sum_of_weights_for_this_output_channel + int32_t zero_point_correction = input_zero_point * weight_sum; + int32_t accum_adjusted = int_sum - zero_point_correction; + float float_result = + accum_adjusted * input_scale * weight_scales_data[out_c]; + + // Add bias and store result + float_result += bias_data[out_c]; + + // Quantize the output to int8 + float quant_output_f = + std::round(float_result / output_scale) + output_zero_point; + quant_output_f = std::min(std::max(quant_output_f, -128.0f), 127.0f); + int8_t quantized_output = static_cast(quant_output_f); + + // Dequantize back to float + float dequant_output = + (static_cast(quantized_output) - output_zero_point) * + output_scale; + + int64_t output_idx = n * (C_out * H_out * W_out) + + out_c * (H_out * W_out) + out_h * W_out + out_w; + ref_data[output_idx] = dequant_output; + } + } + } + } +} + +void reference_impl(TestCase& test_case) { + conv2d_q8ta_q8csw_q8to_reference_impl(test_case); +} + +// Custom FLOP calculator for quantized conv2d operation +int64_t quantized_conv2d_flop_calculator(const TestCase& test_case) { + int kernel_idx = 9; // kernel_size is at index 9 for q8ta_q8csw_q8to + + // Get input and weight dimensions + const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); + const auto& output_sizes = test_case.outputs()[0].get_tensor_sizes(); + + const auto& kernel_sizes = test_case.inputs()[kernel_idx].get_int32_data(); + + int64_t N = input_sizes[0]; + int64_t C_in = input_sizes[1]; + int64_t C_out = output_sizes[1]; + int64_t K_h = kernel_sizes[0]; + int64_t K_w = kernel_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + // Calculate FLOPs for quantized conv2d operation + // Each output element requires: + // - C_in * K_h * K_w multiply-accumulate operations + // - Additional operations for quantization/dequantization + int64_t output_elements = N * C_out * H_out * W_out; + int64_t ops_per_output = C_in * K_h * K_w; + + int64_t flop = output_elements * (ops_per_output); + + return flop; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout + << "Quantized Conv2d Operation with Output Quantization Prototyping Framework" + << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = reference_impl; + + // Execute test cases using the new framework with custom FLOP calculator + auto results = execute_test_cases( + generate_quantized_conv2d_test_cases, + quantized_conv2d_flop_calculator, + "QuantizedConv2dQ8ToQ8To", + 0, + 10, + ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d_dw.cpp b/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d_dw.cpp new file mode 100644 index 00000000000..c259b45de06 --- /dev/null +++ b/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d_dw.cpp @@ -0,0 +1,592 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include "conv2d_utils.h" +#include "utils.h" + +#include + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 100; + +// Utility function to create a test case from a Conv2dConfig for depthwise +// convolution +TestCase create_test_case_from_config( + const Conv2dConfig& config, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype) { + TestCase test_case; + + // Create a descriptive name for the test case + std::string storage_str = + (storage_type == utils::kTexture3D) ? "Texture3D" : "Buffer"; + std::string dtype_str = (input_dtype == vkapi::kFloat) ? "Float" : "Half"; + + std::string test_name = + config.test_case_name + "_" + storage_str + "_" + dtype_str; + test_case.set_name(test_name); + + // Set the operator name for the test case + std::string operator_name = "etvk." + config.op_name + ".test"; + test_case.set_operator_name(operator_name); + + // Calculate output dimensions + int64_t H_out = config.get_output_height(); + int64_t W_out = config.get_output_width(); + + // Input tensor (float/half) - [1, C_in, H_in, W_in] (batch size always 1) + std::vector input_size = { + 1, config.channels.in, config.input_size.h, config.input_size.w}; + + ValueSpec input_tensor( + input_size, + input_dtype, + storage_type, + utils::kChannelsPacked, + DataGenType::RANDOM); + + if (debugging()) { + print_valuespec_data(input_tensor, "input_tensor", false, 64); + } + + float input_scale_val = 0.008123; + ValueSpec input_scale(input_scale_val); + + int32_t input_zero_point_val = 2; + ValueSpec input_zero_point(input_zero_point_val); + + // Quantized weight tensor (int8) for depthwise convolution + // Memory layout: [K_h, K_w, OC] + // For depthwise conv: groups = channels.out, in_channels_per_group = 1 + std::vector weight_size = { + config.kernel.h, config.kernel.w, config.channels.out}; + ValueSpec quantized_weight( + weight_size, + vkapi::kChar, // int8 for quantized weights + storage_type, + utils::kWidthPacked, + DataGenType::RANDINT8); + quantized_weight.set_constant(true); + + if (debugging()) { + print_valuespec_data(quantized_weight, "weight_tensor", false, 64); + } + + // Weight quantization scales (float/half, per-channel) + ValueSpec weight_scales( + {config.channels.out}, // Per output channel + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM_SCALES); + weight_scales.set_constant(true); + + ValueSpec weight_sums( + {config.channels.out}, // Per output channel + vkapi::kInt, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + weight_sums.set_constant(true); + + // Compute weight_sums data based on quantized weights for depthwise layout + // For depthwise conv: each output channel has K_h * K_w weights + // Custom computation for depthwise layout [K_h, K_w, OC] + auto& weight_sums_data = weight_sums.get_int32_data(); + auto& quantized_weight_data = quantized_weight.get_int8_data(); + + weight_sums_data.resize(config.channels.out); + + for (int64_t out_c = 0; out_c < config.channels.out; ++out_c) { + int32_t sum = 0; + for (int64_t kh = 0; kh < config.kernel.h; ++kh) { + for (int64_t kw = 0; kw < config.kernel.w; ++kw) { + // Weight indexing for depthwise layout [K_h, K_w, OC] + int64_t weight_idx = kh * (config.kernel.w * config.channels.out) + + kw * config.channels.out + out_c; + sum += static_cast(quantized_weight_data[weight_idx]); + } + } + weight_sums_data[out_c] = sum; + } + + // Bias (optional, float/half) - [C_out] + ValueSpec bias( + {config.channels.out}, // Per output channel + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM); + bias.set_constant(true); + + // Output quantization parameters + float output_scale_val = 0.05314; + ValueSpec output_scale(output_scale_val); + + int32_t output_zero_point_val = -1; + ValueSpec output_zero_point(output_zero_point_val); + + // Stride and padding parameters + ValueSpec stride({config.stride.h, config.stride.w}); + ValueSpec padding({config.padding.h, config.padding.w}); + + // Dilation and groups parameters + ValueSpec dilation({config.dilation.h, config.dilation.w}); + ValueSpec groups(config.groups); + + // Kernel size parameters + ValueSpec kernel_size({config.kernel.h, config.kernel.w}); + + // Output tensor (float/half) - [1, C_out, H_out, W_out] (batch size always 1) + ValueSpec output( + {1, config.channels.out, H_out, W_out}, + input_dtype, + storage_type, + utils::kChannelsPacked, + DataGenType::ZEROS); + + // Add all specs to test case for q8ta_q8csw_q8to operation + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zero_point); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_sums); + test_case.add_input_spec(weight_scales); + test_case.add_input_spec(output_scale); + test_case.add_input_spec(output_zero_point); + test_case.add_input_spec(bias); + test_case.add_input_spec(kernel_size); + test_case.add_input_spec(stride); + test_case.add_input_spec(padding); + test_case.add_input_spec(dilation); + test_case.add_input_spec(groups); + + test_case.add_output_spec(output); + + test_case.set_abs_tolerance(output_scale_val + 1e-4f); + + return test_case; +} + +// Generate easy test cases for quantized depthwise conv2d operation (for +// debugging) +std::vector generate_quantized_conv2d_dw_easy_cases() { + std::vector test_cases; + + // Single simple configuration for debugging - depthwise convolution + Conv2dConfig config = { + OutInChannels(8, 8), // channels (out, in) - equal for depthwise + InputSize2D(8, 8), // input_size (h, w) + KernelSize(3, 3), // kernel + Stride(2, 2), // stride + Padding(1, 1), // padding + Dilation(1, 1), // dilation + 8, // groups = channels.out for depthwise + }; + config.op_name = "conv2d_q8ta_q8csw_q8to"; + + // Test with both storage types and data types for completeness + std::vector storage_types = {utils::kTexture3D}; + std::vector float_types = {vkapi::kFloat}; + + // Generate test cases for each combination + for (const auto& storage_type : storage_types) { + for (const auto& input_dtype : float_types) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, input_dtype)); + } + } + + return test_cases; +} + +// Generate test cases for quantized depthwise conv2d operation +std::vector generate_quantized_conv2d_dw_test_cases() { + std::vector test_cases; + + std::vector configs = { + // Depthwise convolutions: groups = channels.out, channels.in = + // channels.out + {OutInChannels(32, 32), + InputSize2D(64, 64), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 32}, + {OutInChannels(64, 64), + InputSize2D(32, 32), + KernelSize(3, 3), + Stride(2, 2), + Padding(2, 2), + Dilation(1, 1), + 64}, + {OutInChannels(64, 64), + InputSize2D(32, 32), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 64}, + {OutInChannels(80, 80), + InputSize2D(16, 16), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 80}, + {OutInChannels(16, 16), + InputSize2D(57, 33), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 16}, + // Different kernel sizes for depthwise + {OutInChannels(32, 32), + InputSize2D(64, 64), + KernelSize(5, 5), + Stride(1, 1), + Padding(2, 2), + Dilation(1, 1), + 32}, + {OutInChannels(96, 96), + InputSize2D(64, 64), + KernelSize(5, 5), + Stride(2, 2), + Padding(2, 2), + Dilation(1, 1), + 96}, + // Performance cases + {OutInChannels(128, 128), + InputSize2D(128, 128), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 128}, + {OutInChannels(64, 64), + InputSize2D(256, 256), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 64}, + {OutInChannels(288, 288), + InputSize2D(16, 16), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 288}, + {OutInChannels(32, 32), + InputSize2D(128, 128), + KernelSize(3, 3), + Stride(1, 1), + Padding(2, 2), + Dilation(1, 1), + 32}}; + + // Test with different storage types and data types + std::vector storage_types = {utils::kTexture3D}; + + // Generate test cases for each combination + for (auto& config : configs) { + for (const auto& storage_type : storage_types) { + // Generate test case name programmatically + bool is_performance = config.channels.out > kRefDimSizeLimit || + config.channels.in > kRefDimSizeLimit || + config.input_size.h > kRefDimSizeLimit || + config.input_size.w > kRefDimSizeLimit; + std::string prefix = + is_performance ? "performance_dw_" : "correctness_dw_"; + std::string suffix = std::to_string(config.channels.out) + "/" + + std::to_string(config.channels.in) + "_" + + std::to_string(config.input_size.h) + "/" + + std::to_string(config.input_size.w) + "_" + + std::to_string(config.kernel.h) + "/" + + std::to_string(config.kernel.w); + + config.op_name = "conv2d_q8ta_q8csw_q8to"; + config.test_case_name = prefix + suffix; + + // Only test q8ta_q8csw_q8to if the int8 dot product extension is + // supported + if (vkcompute::api::context() + ->adapter_ptr() + ->supports_int8_dot_product()) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, vkapi::kFloat)); + } + } + } + + return test_cases; +} + +// Reference implementation for activation, weight, and output quantized +// depthwise conv2d +void conv2d_q8ta_q8csw_q8to_dw_reference_impl(TestCase& test_case) { + // Extract input specifications + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_sums_spec = test_case.inputs()[idx++]; + (void)weight_sums_spec; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& output_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& output_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + const ValueSpec& kernel_size_spec = test_case.inputs()[idx++]; + const ValueSpec& stride_spec = test_case.inputs()[idx++]; + const ValueSpec& padding_spec = test_case.inputs()[idx++]; + const ValueSpec& dilation_spec = test_case.inputs()[idx++]; + const ValueSpec& groups_spec = test_case.inputs()[idx++]; + + // Extract output specification (mutable reference) + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions + auto input_sizes = input_spec.get_tensor_sizes(); // [N, C_in, H_in, W_in] + auto weight_sizes = + weight_spec.get_tensor_sizes(); // [K_h, align_up_4(K_w), OC] + auto output_sizes = + output_spec.get_tensor_sizes(); // [N, C_out, H_out, W_out] + + int64_t N = input_sizes[0]; + int64_t C_in = input_sizes[1]; + int64_t H_in = input_sizes[2]; + int64_t W_in = input_sizes[3]; + int64_t C_out = output_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + // Get kernel dimensions from kernel_size ValueSpec + auto kernel_size_data = kernel_size_spec.get_int32_data(); + int64_t K_h = kernel_size_data[0]; + int64_t K_w = kernel_size_data[1]; + + // Get stride, padding, dilation, and groups + auto stride_data = stride_spec.get_int32_data(); + auto padding_data = padding_spec.get_int32_data(); + auto dilation_data = dilation_spec.get_int32_data(); + int64_t stride_h = stride_data[0]; + int64_t stride_w = stride_data[1]; + int64_t pad_h = padding_data[0]; + int64_t pad_w = padding_data[1]; + int64_t dilation_h = dilation_data[0]; + int64_t dilation_w = dilation_data[1]; + int64_t groups = groups_spec.get_int_value(); + + // Skip for large tensors since computation time will be extremely slow + if (N > kRefDimSizeLimit || C_in > kRefDimSizeLimit || + H_in > kRefDimSizeLimit || W_in > kRefDimSizeLimit || + C_out > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Verify this is a depthwise convolution + if (groups != C_out || C_in != C_out) { + throw std::invalid_argument( + "This is not a depthwise convolution configuration"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + const float input_scale = input_scale_spec.get_float_value(); + const int32_t input_zero_point = input_zeros_spec.get_int_value(); + + auto& weight_data = weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + const float output_scale = output_scale_spec.get_float_value(); + const int32_t output_zero_point = output_zeros_spec.get_int_value(); + + // Calculate number of output elements + int64_t num_output_elements = N * C_out * H_out * W_out; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + // Perform activation, weight, and output quantized depthwise conv2d operation + for (int64_t n = 0; n < N; ++n) { + for (int64_t out_c = 0; out_c < C_out; ++out_c) { + for (int64_t out_h = 0; out_h < H_out; ++out_h) { + for (int64_t out_w = 0; out_w < W_out; ++out_w) { + int32_t int_sum = 0; + int32_t weight_sum = 0; // Track weight sum on the fly + + // For depthwise convolution, each output channel corresponds to one + // input channel + int64_t in_c = out_c; + + // Convolution operation with integer accumulation + for (int64_t kh = 0; kh < K_h; ++kh) { + for (int64_t kw = 0; kw < K_w; ++kw) { + // Calculate input position with dilation + int64_t in_h = out_h * stride_h - pad_h + kh * dilation_h; + int64_t in_w = out_w * stride_w - pad_w + kw * dilation_w; + + // Check bounds (zero padding) + if (in_h >= 0 && in_h < H_in && in_w >= 0 && in_w < W_in) { + // Get input value and quantize to int8 + int64_t input_idx = n * (C_in * H_in * W_in) + + in_c * (H_in * W_in) + in_h * W_in + in_w; + + float quant_input_f = + std::round(input_data[input_idx] / input_scale) + + input_zero_point; + quant_input_f = + std::min(std::max(quant_input_f, -128.0f), 127.0f); + int8_t quantized_input = static_cast(quant_input_f); + + // Get quantized weight using depthwise layout [K_h, K_w, OC] + int64_t weight_idx = kh * (K_w * C_out) + kw * C_out + out_c; + int8_t quantized_weight = weight_data[weight_idx]; + + if (false && in_w == 0 && in_h == 0 && out_c == 0) { + std::cout << "input: " << input_data[input_idx] << std::endl; + std::cout << "quantized_input: " << (int)quantized_input + << std::endl; + std::cout << "quantized_weight: " << (int)quantized_weight + << std::endl; + } + // Integer multiplication and accumulation + int_sum += static_cast(quantized_input) * + static_cast(quantized_weight); + + // Track weight sum for this output channel on the fly + weight_sum += static_cast(quantized_weight); + } else { + // For zero padding, we still need to account for the weight + // in weight_sum when input is effectively 0 (but quantized 0 + // is input_zero_point) + int64_t weight_idx = kh * (K_w * C_out) + kw * C_out + out_c; + int8_t quantized_weight = weight_data[weight_idx]; + + // Add contribution from zero-padded input (quantized zero = + // input_zero_point) + int_sum += static_cast(input_zero_point) * + static_cast(quantized_weight); + + // Track weight sum for this output channel on the fly + weight_sum += static_cast(quantized_weight); + } + } + } + + // Convert accumulated integer result to float and apply scales + // Final result = (int_sum - zero_point_correction) * input_scale * + // weight_scale + bias zero_point_correction = input_zero_point * + // sum_of_weights_for_this_output_channel + int32_t zero_point_correction = input_zero_point * weight_sum; + int32_t accum_adjusted = int_sum - zero_point_correction; + float float_result = + accum_adjusted * input_scale * weight_scales_data[out_c]; + + // Add bias and store result + float_result += bias_data[out_c]; + + // Quantize the output to int8 + float quant_output_f = + std::round(float_result / output_scale) + output_zero_point; + quant_output_f = std::min(std::max(quant_output_f, -128.0f), 127.0f); + int8_t quantized_output = static_cast(quant_output_f); + + if (false && out_c < 4 && out_h < 1 && out_w < 4) { + std::cout << "int_sum[" << out_c << ", " << out_h << ", " << out_w + << "] = " << int_sum << ", " << float_result << ", " + << output_scale << ", " << quant_output_f << std::endl; + } + + // Dequantize back to float + float dequant_output = + (static_cast(quantized_output) - output_zero_point) * + output_scale; + + int64_t output_idx = n * (C_out * H_out * W_out) + + out_c * (H_out * W_out) + out_h * W_out + out_w; + ref_data[output_idx] = dequant_output; + } + } + } + } +} + +void reference_impl(TestCase& test_case) { + conv2d_q8ta_q8csw_q8to_dw_reference_impl(test_case); +} + +// Custom FLOP calculator for quantized depthwise conv2d operation +int64_t quantized_conv2d_dw_flop_calculator(const TestCase& test_case) { + int kernel_idx = 9; // kernel_size is at index 9 for q8ta_q8csw_q8to + + // Get input and weight dimensions + const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); + const auto& output_sizes = test_case.outputs()[0].get_tensor_sizes(); + + const auto& kernel_sizes = test_case.inputs()[kernel_idx].get_int32_data(); + + int64_t N = input_sizes[0]; + int64_t C_out = output_sizes[1]; + int64_t K_h = kernel_sizes[0]; + int64_t K_w = kernel_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + // Calculate FLOPs for quantized depthwise conv2d operation + // Each output element requires: + // - K_h * K_w multiply-accumulate operations (only one input channel per + // output channel) + // - Additional operations for quantization/dequantization + int64_t output_elements = N * C_out * H_out * W_out; + int64_t ops_per_output = K_h * K_w; + + int64_t flop = output_elements * ops_per_output; + + return flop; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout + << "Quantized Depthwise Conv2d Operation with Output Quantization Prototyping Framework" + << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = reference_impl; + + // Execute test cases using the new framework with custom FLOP calculator + auto results = execute_test_cases( + generate_quantized_conv2d_dw_test_cases, + quantized_conv2d_dw_flop_calculator, + "QuantizedDepthwiseInt8Conv2d", + 0, + 1, + ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl index 1d1b1fe79bd..959e013981c 100644 --- a/backends/vulkan/test/custom_ops/targets.bzl +++ b/backends/vulkan/test/custom_ops/targets.bzl @@ -60,9 +60,11 @@ def define_common_targets(is_fbcode = False): ], headers = [ "utils.h", + "conv2d_utils.h", ], exported_headers = [ "utils.h", + "conv2d_utils.h", ], platforms = get_platforms(), deps = [ @@ -98,3 +100,5 @@ def define_common_targets(is_fbcode = False): define_custom_op_test_binary("choose_qparams_per_row") define_custom_op_test_binary("q4gsw_linear") define_custom_op_test_binary("qdq8ta_conv2d_activations") + define_custom_op_test_binary("q8ta_q8csw_q8to_conv2d") + define_custom_op_test_binary("q8ta_q8csw_q8to_conv2d_dw") diff --git a/backends/vulkan/test/custom_ops/utils.cpp b/backends/vulkan/test/custom_ops/utils.cpp index 2aa827a4d5a..4de6c32ac25 100644 --- a/backends/vulkan/test/custom_ops/utils.cpp +++ b/backends/vulkan/test/custom_ops/utils.cpp @@ -661,7 +661,12 @@ float collect_gpu_timing_us(ComputeGraph& graph) { float total_duration_us = 0.0f; for (const auto& shader_result : results) { if (shader_result.kernel_name.find("nchw_to") == std::string::npos && - shader_result.kernel_name.find("to_nchw") == std::string::npos) { + shader_result.kernel_name.find("to_nchw") == std::string::npos && + shader_result.kernel_name.find( + "quantize_and_pack_q8ta_conv2d_input") == std::string::npos && + shader_result.kernel_name.find( + "unpack_and_dequantize_q8ta_conv2d_output") == + std::string::npos) { // Calculate duration from start and end times, convert from ns to μs uint64_t duration_ns = shader_result.end_time_ns - shader_result.start_time_ns; @@ -1715,6 +1720,41 @@ void compute_weight_sums( } } +// Compute weight sums for 4D quantized conv2d operations +// Weight layout: [C_out, K_h, K_w, align_up_4(C_in_per_group)] +void compute_weight_sums_4d( + ValueSpec& weight_sums, + const ValueSpec& quantized_weight, + int64_t out_channels, + int64_t kernel_h, + int64_t kernel_w, + int64_t aligned_in_channels) { + auto& weight_sums_data = weight_sums.get_int32_data(); + auto& quantized_weight_data = quantized_weight.get_int8_data(); + + weight_sums_data.resize(out_channels); + + // For each output channel, compute the sum of quantized weights + for (int64_t out_c = 0; out_c < out_channels; ++out_c) { + int32_t sum = 0; + + for (int64_t kh = 0; kh < kernel_h; ++kh) { + for (int64_t kw = 0; kw < kernel_w; ++kw) { + for (int64_t in_c = 0; in_c < aligned_in_channels; ++in_c) { + // Weight indexing: [out_c, kh, kw, in_c] + int64_t weight_idx = + out_c * (kernel_h * kernel_w * aligned_in_channels) + + kh * (kernel_w * aligned_in_channels) + kw * aligned_in_channels + + in_c; + sum += static_cast(quantized_weight_data[weight_idx]); + } + } + } + + weight_sums_data[out_c] = sum; + } +} + // Helper function to unpack 4-bit values from uint8 (same as in // q4gsw_linear.cpp) std::pair unpack_4bit_utils(uint8_t packed) { diff --git a/backends/vulkan/test/custom_ops/utils.h b/backends/vulkan/test/custom_ops/utils.h index f1736f1d144..b80f28639e8 100644 --- a/backends/vulkan/test/custom_ops/utils.h +++ b/backends/vulkan/test/custom_ops/utils.h @@ -653,6 +653,16 @@ void compute_weight_sums( int64_t out_features, int64_t elements_per_output_feature); +// Compute weight sums for 4D quantized conv2d operations +// Weight layout: [C_out, K_h, K_w, align_up_4(C_in_per_group)] +void compute_weight_sums_4d( + ValueSpec& weight_sums, + const ValueSpec& quantized_weight, + int64_t out_channels, + int64_t kernel_h, + int64_t kernel_w, + int64_t aligned_in_channels); + // Compute weight sums for 4-bit group symmetric quantized weights void compute_weight_sums_4bit_grouped( ValueSpec& weight_sums, From 584b6fd295cdc5cd8ea56dd0d02008b2f1695fb4 Mon Sep 17 00:00:00 2001 From: ssjia Date: Sun, 28 Sep 2025 13:13:38 -0700 Subject: [PATCH 2/2] Update on "[ET-VK] Statically quantized convolutions" ## Changes This diff adds implementations for quantized convolution under the following quantization conditions: * activations statically quantized to 8-bit with per tensor scale and zero point * weights quantized to 8-bit with per channel scales * outputs statically quantized to 8-bit with per tensor scale and zero point 3 different implementations are added, which are selected between based on the input conditions. The first is an direct convolution shader which uses the quantized int8 input directly. The second is an im2col variant, which computes the convolution via a gemm like algorithm by first applying an im2col tranformation on the input tensor. Finally, a specialized implementation is added for depthwise convolutions. Differential Revision: [D83437827](https://our.internmc.facebook.com/intern/diff/D83437827/) [ghstack-poisoned] --- backends/vulkan/test/custom_ops/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backends/vulkan/test/custom_ops/CMakeLists.txt b/backends/vulkan/test/custom_ops/CMakeLists.txt index fe36de3047e..348eeded962 100644 --- a/backends/vulkan/test/custom_ops/CMakeLists.txt +++ b/backends/vulkan/test/custom_ops/CMakeLists.txt @@ -96,4 +96,6 @@ if(TARGET vulkan_backend) add_operator_prototype(q4gsw_linear) add_operator_prototype(choose_qparams_per_row) add_operator_prototype(qdq8ta_conv2d_activations) + add_operator_prototype(q8ta_q8csw_q8to_conv2d) + add_operator_prototype(q8ta_q8csw_q8to_conv2d_dw) endif()