From 501f89cb45705375c805ea30fcaeff69c0cc2ca0 Mon Sep 17 00:00:00 2001 From: ssjia Date: Fri, 29 Aug 2025 07:42:53 -0700 Subject: [PATCH] [ET-VK] Quantized Int8 Convolution + Linear Title says it all! This PR adds implementations for int8 quantized convolution and linear layers. Convolution is implemented as matrix multiplication under the hood by using the im2col procedure. For both linear and convolution, two versions are implemented: 1. `q8ta_q8csw` variant, which quantizes the input tensor and then performs integer accumulation via the int8 dot product extension 2. `q8csw` variant, which dequantizes the weight tensor in-shader and performs floating point accumulation. The second variant is needed to provide an alternative path for executing quantized models when the target GPU does not support the int8 dot product extension. These new ops are tested via the custom op testing + benchmarking framework introduced in the previous diff. Differential Revision: [D81323424](https://our.internmc.facebook.com/intern/diff/D81323424/) [ghstack-poisoned]
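To make the difference between the two variants concrete, below is a minimal NumPy reference sketch of the math they compute for a per-channel-quantized linear layer (the convolution path applies the same math to the im2col matrix). This sketch is illustrative only and is not part of the patch; the function names, shapes, and the `input_scale` / `input_zp` / `weight_sums` naming are assumptions that only loosely mirror the shader interfaces.

```python
import numpy as np

def q8csw_reference(x, qw, weight_scales, bias):
    # Weight-only quantized path: dequantize the int8 weights and accumulate
    # in fp32. This mirrors the fallback used when the int8 dot product
    # extension is unavailable.
    w = qw.astype(np.float32) * weight_scales[:, None]  # [N, K]
    return x @ w.T + bias                                # [M, N]

def q8ta_q8csw_reference(x, input_scale, input_zp, qw, weight_scales, bias):
    # Quantized-activation path: quantize the input to int8, accumulate in
    # int32 (the role played by dotPacked4x8AccSatEXT in the shaders), then
    # remove the input zero point using per-output-channel weight sums.
    qx = np.clip(np.round(x / input_scale) + input_zp, -127, 127).astype(np.int32)
    acc = qx @ qw.astype(np.int32).T                     # int32 accumulation
    weight_sums = qw.astype(np.int32).sum(axis=1)        # one sum per output channel
    return (acc - input_zp * weight_sums) * (input_scale * weight_scales) + bias

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    M, K, N = 4, 32, 8
    x = rng.standard_normal((M, K)).astype(np.float32)
    qw = rng.integers(-127, 128, size=(N, K), dtype=np.int8)
    weight_scales = rng.uniform(0.01, 0.02, size=N).astype(np.float32)
    bias = rng.standard_normal(N).astype(np.float32)
    ref = q8csw_reference(x, qw, weight_scales, bias)
    out = q8ta_q8csw_reference(x, np.float32(0.05), 3, qw, weight_scales, bias)
    # The two paths agree up to input-quantization error.
    print("max abs difference:", np.max(np.abs(out - ref)))
```

The key detail is the `input_zp * weight_sums` correction term in the integer path, which is why per-output-channel weight sums are precomputed and passed to the `q8ta_q8csw` shaders; the two paths agree up to the quantization error of the input.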
--- .../vulkan/runtime/graph/ops/glsl/col2im.glsl | 94 +++ .../vulkan/runtime/graph/ops/glsl/col2im.yaml | 19 + .../runtime/graph/ops/glsl/common.glslh | 40 ++ .../graph/ops/glsl/conv2d_common.glslh | 51 ++ .../ops/glsl/conv2d_fp_im2col_block.glslh | 96 +++ .../glsl/conv2d_fp_im2col_block_load.glslh | 167 +++++ .../glsl/conv2d_fp_im2col_block_store.glslh | 68 ++ .../ops/glsl/conv2d_q8csw_linear_tiled.glsl | 123 ++++ .../ops/glsl/conv2d_q8csw_linear_tiled.yaml | 22 + .../glsl/conv2d_q8ta_q8csw_linear_tiled.glsl | 124 ++++ .../glsl/conv2d_q8ta_q8csw_linear_tiled.yaml | 22 + .../vulkan/runtime/graph/ops/glsl/im2col.glsl | 110 +++ .../vulkan/runtime/graph/ops/glsl/im2col.yaml | 18 + .../graph/ops/glsl/linear_bias_load.glslh | 30 + .../graph/ops/glsl/linear_common.glslh | 41 ++ .../graph/ops/glsl/linear_fp_input_tile.glslh | 43 ++ .../ops/glsl/linear_fp_input_tile_load.glslh | 91 +++ .../ops/glsl/linear_fp_output_tile.glslh | 60 ++ .../linear_fp_output_tile_fp_compute.glslh | 96 +++ .../linear_fp_output_tile_int8_compute.glslh | 124 ++++ .../glsl/linear_fp_output_tile_store.glslh | 114 ++++ .../ops/glsl/linear_fp_weight_tile.glslh | 100 +++ .../ops/glsl/linear_int8_input_block.glslh | 77 +++ .../ops/glsl/linear_int8_input_tile.glslh | 93 +++ .../glsl/linear_int8_input_tile_load.glslh | 75 ++ .../ops/glsl/linear_int8_weight_block.glslh | 140 ++++ .../ops/glsl/linear_int8_weight_tile.glslh | 45 ++ .../glsl/linear_int8_weight_tile_load.glslh | 75 ++ .../graph/ops/glsl/linear_q8csw_tiled.glsl | 117 ++++ .../graph/ops/glsl/linear_q8csw_tiled.yaml | 30 + .../ops/glsl/linear_q8ta_q8csw_tiled.glsl | 117 ++++ .../ops/glsl/linear_q8ta_q8csw_tiled.yaml | 30 + .../graph/ops/glsl/linear_scales_load.glslh | 30 + .../ops/glsl/linear_weight_sums_load.glslh | 30 + .../graph/ops/glsl/pack_q8_linear_weight.glsl | 62 ++ .../graph/ops/glsl/pack_q8_linear_weight.yaml | 14 + .../ops/glsl/quantize_and_pack_im2col.glsl | 89 +++ .../ops/glsl/quantize_and_pack_im2col.yaml | 18 + .../glsl/quantize_and_pack_linear_input.glsl | 79 +++ .../glsl/quantize_and_pack_linear_input.yaml | 24 + .../graph/ops/impl/QuantizedConvolution.cpp | 645 ++++++++++++++++++ .../graph/ops/impl/QuantizedLinear.cpp | 548 +++++++++++++++ .../runtime/graph/ops/impl/QuantizedLinear.h | 35 + .../vulkan/test/custom_ops/CMakeLists.txt | 3 + backends/vulkan/test/custom_ops/conv2d.cpp | 320 +++++++++ .../test/custom_ops/quantized_conv2d.cpp | 601 ++++++++++++++++ .../test/custom_ops/quantized_linear.cpp | 352 ++++++++++ backends/vulkan/test/custom_ops/targets.bzl | 3 + 48 files changed, 5305 insertions(+) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/col2im.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/col2im.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/common.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_store.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_q8csw_linear_tiled.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_q8csw_linear_tiled.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_linear_tiled.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_linear_tiled.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/im2col.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/im2col.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_bias_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_compute.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_block.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_scales_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_weight_sums_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_im2col.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_im2col.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml create mode 100644
backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp create mode 100644 backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp create mode 100644 backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.h create mode 100644 backends/vulkan/test/custom_ops/conv2d.cpp create mode 100644 backends/vulkan/test/custom_ops/quantized_conv2d.cpp create mode 100644 backends/vulkan/test/custom_ops/quantized_linear.cpp diff --git a/backends/vulkan/runtime/graph/ops/glsl/col2im.glsl b/backends/vulkan/runtime/graph/ops/glsl/col2im.glsl new file mode 100644 index 00000000000..c105ef18719 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/col2im.glsl @@ -0,0 +1,94 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, OUTPUT_STORAGE)} +#define T ${texel_load_component_type(DTYPE, OUTPUT_STORAGE)} + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER + +#define TILE_M4 1 +#define TILE_N4 1 +#define TILE_K4 1 + +#define TILE_M 4 +#define TILE_N 4 +#define TILE_K 4 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)} + +// Sizes of the convolution output image +${layout_declare_ubo(B, "ivec4", "output_sizes")} +// Sizes of the convolution input image +${layout_declare_ubo(B, "ivec4", "input_sizes")} +// Sizes of the im2col matrix of the convolution output +${layout_declare_ubo(B, "ivec4", "matrix_sizes")} + +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "conv2d_fp_im2col_block_store.glslh" + +#ifdef INPUT_BUFFER + +void load_matrix_tile( + out FPOutTile tile, + const int n4, + const int m_start, + const int N4) { + [[unroll]] for (int m = 0; m < TILE_M; m++) { + tile.data[m][0] = t_input[(m_start + m) * N4 + n4]; + } +} + +#else // INPUT_TEXTURE + +void load_matrix_tile( + out FPOutTile tile, + const int n4, + const int m_start, + const int N4) { + [[unroll]] for (int m = 0; m < TILE_M; m++) { + tile.data[m][0] = texelFetch( + t_input, ivec3(n4, m_start + m, 0), 0); + } +} + +#endif // INPUT_BUFFER + +void main() { + // Each thread loads and writes a 4 wide x 4 high block of the matrix + const int n4 = int(gl_GlobalInvocationID.x); + const int m4 = int(gl_GlobalInvocationID.y); + + const int n = mul_4(n4); + const int m = mul_4(m4); + + if (n >= matrix_sizes.x || m >= matrix_sizes.y) { + return; + } + + FPOutTile tile; + + const int N4 = div_4(matrix_sizes.x); + load_matrix_tile(tile, n4, m, N4); + write_im2col_tile_as_image(tile, n4, m); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/col2im.yaml b/backends/vulkan/runtime/graph/ops/glsl/col2im.yaml new file mode 100644 index 00000000000..b6d0972271a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/col2im.yaml @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +col2im: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: texture3d + INPUT_STORAGE: buffer + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: col2im_texture3d_buffer + - NAME: col2im_texture3d_texture3d + INPUT_STORAGE: texture3d diff --git a/backends/vulkan/runtime/graph/ops/glsl/common.glslh b/backends/vulkan/runtime/graph/ops/glsl/common.glslh new file mode 100644 index 00000000000..c96392792b2 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/common.glslh @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef COMMON_GLSLH +#define COMMON_GLSLH + +#define align_up_4(x) ((x + 3) & -4) + +#define div_up_4(x) (((x) + 3) >> 2) + +#define mul_4(x) ((x) << 2) +#define div_4(x) ((x) >> 2) + +#define mod_4(x) ((x) & 3) + +struct TensorIndex4D { + ivec4 data; +}; + +#ifdef DEBUG_MODE + +#extension GL_EXT_debug_printf : require + +void printTensorIndex4D(const TensorIndex4D index) { + debugPrintfEXT( + "tensor_idx: %d, %d, %d, %d\\n", + index.data.x, + index.data.y, + index.data.z, + index.data.w); +} + +#endif // DEBUG_MODE + +#endif // COMMON_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh new file mode 100644 index 00000000000..41825cba867 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh @@ -0,0 +1,51 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CONV2D_COMMON_GLSLH +#define CONV2D_COMMON_GLSLH + +#include "common.glslh" + +struct Conv2DParams { + ivec2 kernel_size; + ivec2 stride; + ivec2 padding; + ivec2 dilation; + int groups; + int out_channels_per_group; + int in_channels_per_group; + int logical_K_per_group; + int K_per_group; + int K4_per_group; + int logical_K; + int K; + int K4; +}; + +#ifdef DEBUG_MODE + +void printConv2DParams(const Conv2DParams params) { + debugPrintfEXT("Conv2DParams: \\n"); + debugPrintfEXT( + " kernel_size: %d, %d\\n", params.kernel_size.x, params.kernel_size.y); + debugPrintfEXT(" stride: %d, %d\\n", params.stride.x, params.stride.y); + debugPrintfEXT(" padding: %d, %d\\n", params.padding.x, params.padding.y); + debugPrintfEXT(" dilation: %d, %d\\n", params.dilation.x, params.dilation.y); + debugPrintfEXT(" groups: %d\\n", params.groups); + debugPrintfEXT( + " out_channels_per_group: %d\\n", params.out_channels_per_group); + debugPrintfEXT( + " in_channels_per_group: %d\\n", params.in_channels_per_group); + debugPrintfEXT(" logical_K_per_group: %d\\n", params.logical_K_per_group); + debugPrintfEXT(" K_per_group: %d\\n", params.K_per_group); + debugPrintfEXT(" K4_per_group: %d\\n", params.K4_per_group); +} + +#endif // DEBUG_MODE + +#endif // CONV2D_COMMON_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block.glslh new file mode 100644 index 00000000000..7add8c4cd16 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block.glslh @@ -0,0 +1,96 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CONV2D_FP_IM2COL_BLOCK +#define CONV2D_FP_IM2COL_BLOCK + +/* + * Defines utilities to convert between (col, row) indices of an im2col matrix + * and 4-dimension tensor indices of image tensors. + * + * Requires: + * - output_sizes to be defined in the shader layout, corresponding to the sizes + * of the output image of the convolution op. + * - image_sizes to be defined in the shader layout, corresponding to the sizes + * of the input image of the convolution op. + * - conv2d_params to be defined in the shader layout + */ + +#extension GL_EXT_control_flow_attributes : require + +#include "common.glslh" +#include "conv2d_common.glslh" + +struct Im2ColMatrixIdx { + int row; + int col; + // Relevant for grouped convolution. This indicates the column index relative + // to the first column in the group. + int col_idx_in_group; + int group_idx; +}; + +void unwrap_m(out TensorIndex4D out_tidx_base, const int m) { + out_tidx_base.data[3] = m / (output_sizes.y * output_sizes.x); + out_tidx_base.data[1] = (m / output_sizes.x) % output_sizes.y; + out_tidx_base.data[0] = m % output_sizes.x; + + // Initialize channels to 0; assume it will be set later on + out_tidx_base.data[2] = 0; +} + +void im2col_tidx_to_output_tidx( + out TensorIndex4D output_tidx, + const Im2ColMatrixIdx im2col_tidx) { + unwrap_m(output_tidx, im2col_tidx.row); + // Set channels + output_tidx.data.z = im2col_tidx.col; +} + +/* + * Converts im2col matrix position to corresponding 4D tensor index, accounting + * for grouped convolutions. The conversion should ensure that all data within + * the same group occupy a contiguous block in memory. + */ +void im2col_idx_to_input_tidx( + out TensorIndex4D input_tidx, + const Im2ColMatrixIdx im2col_idx) { + TensorIndex4D output_tidx; + unwrap_m(output_tidx, im2col_idx.row); + + const int in_channels_per_group = conv2d_params.in_channels_per_group; + // Determine the corresponding position within the convolution window based + // on the col index (more specifically, the col index within the group) + const int channel_within_group = + im2col_idx.col_idx_in_group % in_channels_per_group; + const int kernel_x = (im2col_idx.col_idx_in_group / in_channels_per_group) % + conv2d_params.kernel_size.x; + const int kernel_y = im2col_idx.col_idx_in_group / + (in_channels_per_group * conv2d_params.kernel_size.x); + + // Calculate the actual input channel index + const int channel_idx = + im2col_idx.group_idx * conv2d_params.in_channels_per_group + + channel_within_group; + + // Calculate corresponding input coordinates based on output position + // associated with the row index. 
+ const int input_y = int(output_tidx.data.y * conv2d_params.stride.y) - + int(conv2d_params.padding.y) + int(kernel_y * conv2d_params.dilation.y); + const int input_x = int(output_tidx.data.x * conv2d_params.stride.x) - + int(conv2d_params.padding.x) + int(kernel_x * conv2d_params.dilation.x); + + input_tidx.data = ivec4(input_x, input_y, channel_idx, output_tidx.data.w); +} + +// 4x4 block of the im2col matrix +struct FPIm2ColBlock { + VEC4_T data[4]; +}; + +#endif // CONV2D_FP_IM2COL_BLOCK diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_load.glslh new file mode 100644 index 00000000000..71fff169375 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_load.glslh @@ -0,0 +1,167 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CONV2D_FP_IM2COL_BLOCK_LOAD +#define CONV2D_FP_IM2COL_BLOCK_LOAD + +/* + * Defines utilities to load data for a 4x4 im2col matrix block from an + * input image and store the data as a FPInputTile. + * + * Requires: + * - t_input to be defined in the shader layout, representing the texture of the + * source image + * - conv2d_params to be defined in the shader layout + */ + +#extension GL_EXT_control_flow_attributes : require + +#include "common.glslh" +#include "conv2d_common.glslh" +#include "conv2d_fp_im2col_block.glslh" +#include "linear_fp_input_tile.glslh" + +VEC4_T load_input_texel(const TensorIndex4D tidx) { + // Assumes batch size is 1 and channels packing + return texelFetch( + t_input, ivec3(tidx.data.x, tidx.data.y, div_4(tidx.data.z)), 0); +} + +T load_input_texel_element(const TensorIndex4D tidx) { + const int channels_texel_idx = div_4(tidx.data.z); + const int texel_comp = mod_4(tidx.data.z); + // Assumes batch size is 1 and channels packing + return texelFetch( + t_input, + ivec3(tidx.data.x, tidx.data.y, channels_texel_idx), + 0)[texel_comp]; +} + +// k4 -> group of 4 input channels idx +// m -> flattened batch, output width, output height dim idx +/* + * Fast impl for when the input image's channels per group is a multiple of 4. + * In this case, it is guaranteed that a texel loaded from the input can be + * stored directly to the output without any additional filtering. + */ +void load_im2col_block_fast( + out FPIm2ColBlock block, + const int k4, + const int m4, + const int logical_K, + const int M) { + Im2ColMatrixIdx im2col_idx; + im2col_idx.col = mul_4(k4); // k + im2col_idx.row = mul_4(m4); // m + + // Due to the assumption that in_channels_per_group % 4 == 0, it is + // guaranteed that the next 4 columns (including this one) is part of the + // same group. + im2col_idx.group_idx = im2col_idx.col / conv2d_params.K_per_group; + im2col_idx.col_idx_in_group = im2col_idx.col % conv2d_params.K_per_group; + + [[unroll]] for (int m_off = 0; m_off < 4; ++m_off) { + if (im2col_idx.row >= M) { + block.data[m_off] = VEC4_T(0); + continue; + } + + TensorIndex4D input_tidx; + im2col_idx_to_input_tidx(input_tidx, im2col_idx); + + // Load the texel + block.data[m_off] = load_input_texel(input_tidx); + + im2col_idx.row++; + } +} + +/* + * If input image channels is not a multiple of 4, then it is likely that for + * some matrix texels, the source data is split between different texels of the + * source image. 
In this case it's better to retrieve each element individually. + */ +void load_im2col_block_slow( + out FPIm2ColBlock block, + const int k4, + const int m4, + const int logical_K, + const int M) { + Im2ColMatrixIdx im2col_idx_base; + im2col_idx_base.col = mul_4(k4); + im2col_idx_base.row = mul_4(m4); + + im2col_idx_base.group_idx = im2col_idx_base.col / conv2d_params.K_per_group; + im2col_idx_base.col_idx_in_group = + im2col_idx_base.col % conv2d_params.K_per_group; + + [[unroll]] for (int m_off = 0; m_off < 4; ++m_off) { + [[unroll]] for (int k_off = 0; k_off < 4; ++k_off) { + Im2ColMatrixIdx im2col_idx = im2col_idx_base; + im2col_idx.row += m_off; + im2col_idx.col_idx_in_group += k_off; + + // bounds checking + if (im2col_idx.col >= conv2d_params.logical_K_per_group || + im2col_idx.row >= M) { + block.data[m_off][k_off] = T(0); + continue; + } + + TensorIndex4D input_tidx; + im2col_idx_to_input_tidx(input_tidx, im2col_idx); + + block.data[m_off][k_off] = load_input_texel_element(input_tidx); + } + } +} + +void load_im2col_block( + out FPIm2ColBlock block, + const int k4, + const int m4, + const int logical_K, + const int M) { + if (mod_4(conv2d_params.in_channels_per_group) == 0) { + load_im2col_block_fast(block, k4, m4, logical_K, M); + } else { + load_im2col_block_slow(block, k4, m4, logical_K, M); + } +} + +void load_input_im2col_tile( + out FPInputTile tile, + const int k4_start, + const int m4_start, + const int logical_K, + const int M) { + FPIm2ColBlock block; +#if TILE_K4 == 1 + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + load_im2col_block(block, k4_start, m4_start + m4, logical_K, M); + for (int row = 0; row < 4; ++row) { + const int m = mul_4(m4) + row; + tile.data[m][0] = block.data[row]; + } + } + +#else + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + load_im2col_block(block, k4_start + k4, m4_start + m4, logical_K, M); + for (int row = 0; row < 4; ++row) { + const int m = mul_4(m4) + row; + tile.data[m][k4] = block.data[row]; + } + } + } + +#endif +} + +#endif // CONV2D_FP_IM2COL_BLOCK_LOAD diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_store.glslh new file mode 100644 index 00000000000..2171d75c628 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_store.glslh @@ -0,0 +1,68 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CONV2D_FP_IM2COL_BLOCK_STORE +#define CONV2D_FP_IM2COL_BLOCK_STORE + +/* + * Defines utilities to store data for a 4x4 im2col output matrix block computed + * from matrix multiplication to an output image.
+ * + * Requires: + * - t_output to be defined in the shader layout, representing the texture of + * the output image + */ + +#extension GL_EXT_control_flow_attributes : require + +#include "common.glslh" +#include "conv2d_common.glslh" +#include "conv2d_fp_im2col_block.glslh" +#include "linear_fp_output_tile.glslh" + +// TODO: implement buffer support +void write_output_texel(const VEC4_T out_texel, const TensorIndex4D tidx) { + // Assume batch size is 1 + imageStore( + t_output, ivec3(tidx.data.x, tidx.data.y, div_4(tidx.data.z)), out_texel); +} + +void write_im2col_tile_as_image( + const FPOutTile tile, + const int n4_start, + const int m_start) { + Im2ColMatrixIdx im2col_tidx; + im2col_tidx.col = mul_4(n4_start); + im2col_tidx.row = m_start; +#if TILE_K4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + TensorIndex4D output_tidx; + im2col_tidx_to_output_tidx(output_tidx, im2col_tidx); + + if (any(greaterThanEqual(output_tidx.data, output_sizes))) { + continue; + } + write_output_texel(tile.data[m][0], output_tidx); + im2col_tidx.row++; + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + TensorIndex4D output_tidx; + im2col_tidx_to_output_tidx(output_tidx, im2col_tidx); + + write_output_texel(tile.data[m][k4], output_tidx); + im2col_tidx.row++; + } + } + +#endif +} + +#endif // CONV2D_FP_IM2COL_BLOCK_STORE diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8csw_linear_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8csw_linear_tiled.glsl new file mode 100644 index 00000000000..a54d4fa7466 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8csw_linear_tiled.glsl @@ -0,0 +1,123 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, OUTPUT_STORAGE)} +#define T ${texel_load_component_type(DTYPE, OUTPUT_STORAGE)} + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define TILE_M4 ${TILE_M4} +#define TILE_N4 ${TILE_N4} +#define TILE_K4 ${TILE_K4} + +#define TILE_M ${TILE_M4 * 4} +#define TILE_N ${TILE_N4 * 4} +#define TILE_K ${TILE_K4 * 4} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_qmat2", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "uint", "apply_bias", "1")} + +#include "linear_fp_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_weight_tile.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" +#include "linear_scales_load.glslh" +#include "linear_bias_load.glslh" +#include "conv2d_fp_im2col_block_store.glslh" + +void main() { + // Each thread writes out a 4 wide x 4 high tile of output values + const uint out_tile_x = gl_GlobalInvocationID.x; + const uint out_tile_y = gl_GlobalInvocationID.y; + + const int n = int(out_tile_x * TILE_N); + const int m = int(out_tile_y * TILE_M); + + const int n4 = div_4(n); + const int m4 = div_4(m); + + // M = flattened output width, height, batches dims + const int M = output_sizes.x * output_sizes.y * output_sizes.w; + // N = output channels + const int N = output_sizes.z; + + if (n >= N || m >= M) { + return; + } + + const int group_idx = n / conv2d_params.out_channels_per_group; + const int input_k4_offset = conv2d_params.K4_per_group * group_idx; + + const int K4 = conv2d_params.K4; + const int N4 = div_up_4(N); + + FPOutTile out_tile; + initialize(out_tile); + + FPInputTile in_tile; + Int8WeightTile weight_tile; + FPWeightTile fp_weight_tile; + + const bool dont_check_bounds = (M - m) >= TILE_M; + + if (dont_check_bounds) { + for (int k4 = 0; k4 < conv2d_params.K4_per_group; k4++) { + load_input_tile_no_checks(in_tile, k4 + input_k4_offset, m, K4, M); + load_weight_tile(weight_tile, n4, k4, N4); + unpack(fp_weight_tile, weight_tile); + update(out_tile, in_tile, fp_weight_tile); + } + } else { + for (int k4 = 0; k4 < conv2d_params.K4_per_group; k4++) { + load_input_tile_with_checks(in_tile, k4 + input_k4_offset, m, K4, M); + load_weight_tile(weight_tile, n4, k4, N4); + unpack(fp_weight_tile, weight_tile); + update(out_tile, in_tile, fp_weight_tile); + } + } + + FPPerOutChannelParams scales_tile; + load_scales_tile(scales_tile, n4); + + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, n4); + + apply_scales_and_biases(out_tile, scales_tile, bias_tile); + } + else { + apply_scales(out_tile, scales_tile); + } + + write_im2col_tile_as_image(out_tile, n4, m); +} diff 
--git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8csw_linear_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8csw_linear_tiled.yaml new file mode 100644 index 00000000000..9b3b5aa2c0a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8csw_linear_tiled.yaml @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv2d_q8csw_linear_tiled: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: texture3d + INPUT_STORAGE: buffer + WEIGHT_STORAGE: texture2d + TILE_M4: 1 + TILE_N4: 1 + TILE_K4: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: conv2d_q8csw_linear_tiled_texture3d_buffer_texture2d + - NAME: conv2d_q8csw_linear_tiled_texture3d_buffer_buffer + WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_linear_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_linear_tiled.glsl new file mode 100644 index 00000000000..61d34736277 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_linear_tiled.glsl @@ -0,0 +1,124 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, OUTPUT_STORAGE)} +#define T ${texel_load_component_type(DTYPE, OUTPUT_STORAGE)} + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define TILE_M4 ${TILE_M4} +#define TILE_N4 ${TILE_N4} +#define TILE_K4 ${TILE_K4} + +#define TILE_M ${TILE_M4 * 4} +#define TILE_N ${TILE_N4 * 4} +#define TILE_K ${TILE_K4 * 4} + +${define_required_extensions(DTYPE)} + +#extension GL_EXT_integer_dot_product : require + +layout(std430) buffer; + +#define DEBUG_MODE +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", "int", INPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_qmat2", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "float", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "uint", "apply_bias", "1")} + +#include "linear_int8_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_output_tile_int8_compute.glslh" +#include "linear_scales_load.glslh" +#include "linear_weight_sums_load.glslh" +#include "linear_bias_load.glslh" +#include "conv2d_fp_im2col_block_store.glslh" + +void main() { + // Each thread writes out a 4 wide x 4 high tile of output values + const uint 
out_tile_x = gl_GlobalInvocationID.x; + const uint out_tile_y = gl_GlobalInvocationID.y; + + const int n = int(out_tile_x * TILE_N); + const int m = int(out_tile_y * TILE_M); + + const int n4 = div_4(n); + const int m4 = div_4(m); + + // M = flattened output width, height, batches dims + const int M = output_sizes.x * output_sizes.y * output_sizes.w; + // N = output channels + const int N = output_sizes.z; + + if (n >= N || m >= M) { + return; + } + + const int group_idx = n / conv2d_params.out_channels_per_group; + const int input_k4_offset = conv2d_params.K4_per_group * group_idx; + + const int K4 = conv2d_params.K4; + const int N4 = div_up_4(N); + + Int8OutAccum out_accum; + initialize(out_accum); + + Int8InputTile in_tile; + Int8WeightTile weight_tile; + + for (int k4 = 0; k4 < conv2d_params.K4_per_group; k4++) { + load_input_tile(in_tile, k4 + input_k4_offset, m4, K4); + load_weight_tile(weight_tile, n4, k4, N4); + + accumulate(out_accum, in_tile, weight_tile); + } + + FPPerOutChannelParams scales_tile; + load_scales_tile(scales_tile, n4); + + FPPerOutChannelParams sums_tile; + load_sums_tile(sums_tile, n4); + + FPOutTile out_tile; + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, uint(n4)); + + compute(out_tile, out_accum, sums_tile, scales_tile, bias_tile); + } + else { + compute(out_tile, out_accum, sums_tile, scales_tile); + } + + write_im2col_tile_as_image(out_tile, n4, m); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_linear_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_linear_tiled.yaml new file mode 100644 index 00000000000..4b630a85143 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_linear_tiled.yaml @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv2d_q8ta_q8csw_linear_tiled: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: texture3d + INPUT_STORAGE: buffer + WEIGHT_STORAGE: texture2d + TILE_M4: 1 + TILE_N4: 1 + TILE_K4: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: conv2d_q8ta_q8csw_linear_tiled_texture3d_buffer_texture2d + - NAME: conv2d_q8ta_q8csw_linear_tiled_texture3d_buffer_buffer + WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/im2col.glsl b/backends/vulkan/runtime/graph/ops/glsl/im2col.glsl new file mode 100644 index 00000000000..f045d4e9702 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/im2col.glsl @@ -0,0 +1,110 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#extension GL_EXT_debug_printf : enable + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)} +#define T ${texel_load_component_type(DTYPE, INPUT_STORAGE)} + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER + +#define TILE_M4 1 +#define TILE_N4 1 +#define TILE_K4 1 + +#define TILE_M 4 +#define TILE_N 4 +#define TILE_K 4 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)} + +// Sizes of the im2col matrix of the convolution input +${layout_declare_ubo(B, "ivec4", "matrix_sizes")} +// Sizes of the input image +${layout_declare_ubo(B, "ivec4", "input_sizes")} +// Sizes of the output image +${layout_declare_ubo(B, "ivec4", "output_sizes")} + +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "conv2d_fp_im2col_block_load.glslh" + +#ifdef OUTPUT_BUFFER + +void write_tile( + const FPInputTile in_tile, + const int k4, + const int m_start, + const int K4) { + [[unroll]] for (int m = 0; m < TILE_M; m++) { + t_output[(m_start + m) * K4 + k4] = in_tile.data[m][0]; + } +} + +#else // OUTPUT_TEXTURE + +void write_tile( + const FPInputTile in_tile, + const int k4, + const int m_start, + const int K4) { + [[unroll]] for (int m = 0; m < TILE_M; m++) { + imageStore(t_output, ivec3(k4, m_start + m, 0), vec4(in_tile.data[m][0])); + } +} + +#endif // OUTPUT_BUFFER + +void main() { + // Each thread writes out a 4 wide x 4 high block of the output matrix. The + // thread position corresponds to the block index. + const int k4 = int(gl_GlobalInvocationID.x); + const int m4 = int(gl_GlobalInvocationID.y); + + // Convert block idx to tensor idx + const int k = mul_4(k4); + const int m = mul_4(m4); + + const int in_channels_per_group = input_sizes.z / conv2d_params.groups; + + // Logical K dim size (unpadded) + const int logical_K = conv2d_params.logical_K; + // Physical K dim, which contains padding elements + const int K = matrix_sizes.x; + + // M dim, which represents the number of flattened output width, height, + // batches. Unlike K, there is no difference between the physical and logical + // sizes. + const int M = matrix_sizes.y; + + if (k >= K || m >= M) { + return; + } + + FPInputTile in_tile; + load_input_im2col_tile(in_tile, k4, m4, logical_K, M); + + // Number of texels in the x dim of the output matrix + const int K4 = div_4(K); + write_tile(in_tile, k4, m, K4); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/im2col.yaml b/backends/vulkan/runtime/graph/ops/glsl/im2col.yaml new file mode 100644 index 00000000000..dd486b0e1a6 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/im2col.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +im2col: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: buffer + INPUT_STORAGE: texture3d + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: im2col_buffer_texture3d + - NAME: im2col_texture3d_texture3d + OUTPUT_STORAGE: texture3d diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_bias_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_bias_load.glslh new file mode 100644 index 00000000000..346ed2b0a87 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_bias_load.glslh @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_BIAS_LOAD_GLSLH +#define LINEAR_BIAS_LOAD_GLSLH + +#include "linear_common.glslh" + +VEC4_T load_bias_x4(const uint n4) { + return t_bias[n4]; +} + +void load_bias_tile(out FPPerOutChannelParams bias, const uint n4_start) { +#if TILE_N4 == 1 + bias.data[0] = load_bias_x4(n4_start); + +#else + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + bias.data[n4] = load_bias_x4(n4_start + n4); + } + +#endif +} + +#endif // LINEAR_BIAS_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh new file mode 100644 index 00000000000..e1717bc5e18 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh @@ -0,0 +1,41 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines common functions and structs to be used across matrix multiplication + * operators. + */ + +#ifndef LINEAR_COMMON_GLSLH +#define LINEAR_COMMON_GLSLH + +#include "common.glslh" + +// Represents floating point parameter tensors where each element is associated +// with an output channel, such as weight scales, biases, etc. +struct FPPerOutChannelParams { + VEC4_T data[TILE_N4]; +}; + +#ifdef DEBUG_MODE + +void printFPPerOutChannelParams(const FPPerOutChannelParams params) { + debugPrintfEXT("per_out_channel_params: \\n"); + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT( + " %f, %f, %f, %f, \\n", + params.data[n4].x, + params.data[n4].y, + params.data[n4].z, + params.data[n4].w); + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_COMMON_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh new file mode 100644 index 00000000000..492dab8239d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh @@ -0,0 +1,43 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_FP_INPUT_TILE_GLSLH +#define LINEAR_FP_INPUT_TILE_GLSLH + +/* + * Defines the FPInputTile struct, which is used to represent a tile of the + * input matrix of a matrix multiplication operation.
+ * + * Settings: + * - TILE_M: number of rows in the tile + * - TILE_K4: number of (groups of 4) columns in the tile + */ + +struct FPInputTile { + VEC4_T data[TILE_M][TILE_K4]; +}; + +#ifdef DEBUG_MODE + +void printFPInputTile(const FPInputTile in_tile) { + debugPrintfEXT("input_tile: \\n"); + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + debugPrintfEXT( + " %f, %f, %f, %f, \\n", + in_tile.data[m][k4].x, + in_tile.data[m][k4].y, + in_tile.data[m][k4].z, + in_tile.data[m][k4].w); + } + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_FP_INPUT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh new file mode 100644 index 00000000000..a98f07b042a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh @@ -0,0 +1,91 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to load a FPInputTile from input buffer/texture. + * + * Requires: + * - t_input to be declared in the shader layout (input buffer/texture) + * + * Settings: + * - INPUT_BUFFER to indicate input resource is a buffer, otherwise texture is + * assumed. + */ + +#ifndef LINEAR_FP_INPUT_TILE_LOAD_GLSLH +#define LINEAR_FP_INPUT_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_input_tile.glslh" + +#ifdef INPUT_BUFFER + +VEC4_T load_input_x4(const uint k4, const uint m, const uint ntexels_k) { + return t_input[(m * ntexels_k) + k4]; +} + +#else + +VEC4_T load_input_x4(const uint k4, const uint m, const uint ntexels_k) { + return texelFetch(t_input, ivec3(k4, m, 0), 0); +} + +#endif // INPUT_BUFFER + +// To be used if (M - m_start >= TILE_M) || (K4 - k4_start >= TILE_K4) +void load_input_tile_no_checks( + out FPInputTile in_tile, + const uint k4_start, + const uint m_start, + const uint K4, + const uint M) { +#if TILE_K4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + in_tile.data[m][0] = load_input_x4(k4_start, m_start + m, K4); + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + in_tile.data[m][k4] = load_input_x4(k4_start + k4, m_start + m, K4); + } + } +#endif +} + +// To be used if near tensor boundaries +void load_input_tile_with_checks( + out FPInputTile in_tile, + const uint k4_start, + const uint m_start, + const uint K4, + const uint M) { +#if TILE_K4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + if (m_start + m < M) { + in_tile.data[m][0] = load_input_x4(k4_start, m_start + m, K4); + } else { + in_tile.data[m][0] = VEC4_T(0.0); + } + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + if (m_start + m < M && k4_start + k4 < K4) { + in_tile.data[m][k4] = load_input_x4(k4_start + k4, m_start + m, K4); + } else { + in_tile.data[m][k4] = VEC4_T(0.0); + } + } + } +#endif +} + +#endif // LINEAR_FP_INPUT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh new file mode 100644 index 00000000000..c4571315bdd --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh @@ -0,0 +1,60 @@ +/* + * Copyright (c) Meta 
Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines the FPOutTile struct, which is used to represent a tile of the output + * matrix of a matrix multiplication operation. + * + * Settings: + * - TILE_M: number of rows in the output tile + * - TILE_N4: number of (groups of 4) columns in the output tile + */ + +#ifndef LINEAR_FP_OUTPUT_TILE_GLSLH +#define LINEAR_FP_OUTPUT_TILE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +struct FPOutTile { + VEC4_T data[TILE_M][TILE_N4]; +}; + +void initialize(out FPOutTile out_tile) { +#if TILE_M > 1 && TILE_N4 == 1 + [[unroll]] for (int y = 0; y < TILE_M; ++y) { + out_tile.data[y][0] = VEC4_T(0); + } + +#else + [[unroll]] for (int y = 0; y < TILE_M; ++y) { + [[unroll]] for (int x4 = 0; x4 < TILE_K4; ++x4) { + out_tile.data[y][x4] = VEC4_T(0); + } + } +#endif +} + +#ifdef DEBUG_MODE + +void printFPOutputTile(const FPOutTile tile) { + debugPrintfEXT("output_tile: \\n"); + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT( + " %f, %f, %f, %f, \\n", + tile.data[m][n4].x, + tile.data[m][n4].y, + tile.data[m][n4].z, + tile.data[m][n4].w); + } + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_FP_OUTPUT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh new file mode 100644 index 00000000000..470db8b529a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh @@ -0,0 +1,96 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to compute a FPOutTile using fp input and weight tiles. + */ + +#ifndef LINEAR_FP_OUTPUT_TILE_FP_COMPUTE_GLSLH +#define LINEAR_FP_OUTPUT_TILE_FP_COMPUTE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_common.glslh" +#include "linear_fp_input_tile.glslh" +#include "linear_fp_output_tile.glslh" +#include "linear_fp_weight_tile.glslh" + +/* + * Accumulates floating point input tile and floating point weight tile into + * floating point output tile. + */ +void update(inout FPOutTile accum, FPInputTile in_tile, FPWeightTile w_tile) { + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + const int n = mul_4(n4); + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + accum.data[m][n4] = + fma(VEC4_T(in_tile.data[m][k4][0]), + w_tile.data[k4][n4], + accum.data[m][n4]); + + accum.data[m][n4] = + fma(VEC4_T(in_tile.data[m][k4][1]), + w_tile.data[k4 + 1][n4], + accum.data[m][n4]); + + accum.data[m][n4] = + fma(VEC4_T(in_tile.data[m][k4][2]), + w_tile.data[k4 + 2][n4], + accum.data[m][n4]); + + accum.data[m][n4] = + fma(VEC4_T(in_tile.data[m][k4][3]), + w_tile.data[k4 + 3][n4], + accum.data[m][n4]); + } + } + } +} + +/* + * Applies per output channel weight scales to the output tile. 
+ */ +void apply_scales(inout FPOutTile tile, const FPPerOutChannelParams scales) { +#if TILE_M > 1 && TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + tile.data[m][0] = tile.data[m][0] * scales.data[0]; + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + tile.data[m][n4] = tile.data[m][n4] * scales.data[n4]; + } + } +#endif +} + +/* + * Applies per output channel weight scales and per output channel biases to the + * output tile. + */ +void apply_scales_and_biases( + inout FPOutTile tile, + const FPPerOutChannelParams scales, + const FPPerOutChannelParams bias) { +#if TILE_M > 1 && TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + tile.data[m][0] = tile.data[m][0] * scales.data[0] + bias.data[0]; + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + tile.data[m][n4] = tile.data[m][n4] * scales.data[n4] + bias.data[n4]; + } + } +#endif +} + +#endif // LINEAR_FP_OUTPUT_TILE_FP_COMPUTE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_compute.glslh new file mode 100644 index 00000000000..58fd9086266 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_compute.glslh @@ -0,0 +1,124 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to compute a FPOutTile using int8 input and weight tiles. + * + * Settings: + * - TILE_M: The number of rows in the output tile. + * - TILE_N4: The number of (groups of 4) columns in the output tile. + */ + +#ifndef LINEAR_FP_OUTPUT_TILE_INT8_COMPUTE_GLSLH +#define LINEAR_FP_OUTPUT_TILE_INT8_COMPUTE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_common.glslh" +#include "linear_fp_output_tile.glslh" +#include "linear_int8_input_tile.glslh" +#include "linear_int8_weight_tile.glslh" + +// Stores integer accumulators for an output tile. 
+struct Int8OutAccum { + ivec4 data[TILE_M][TILE_N4]; +}; + +// Initialize values to 0 +void initialize(out Int8OutAccum out_accum) { +#if TILE_M > 1 && TILE_N4 == 1 + [[unroll]] for (int y = 0; y < TILE_M; ++y) { + out_accum.data[y][0] = ivec4(0); + } + +#else + [[unroll]] for (int y = 0; y < TILE_M; ++y) { + [[unroll]] for (int x4 = 0; x4 < TILE_K4; ++x4) { + out_accum.data[y][x4] = ivec4(0); + } + } +#endif +} + +// Accumulate int8 input and weight tiles into accumulator tile +void accumulate( + inout Int8OutAccum accum, + Int8InputTile in_tile, + Int8WeightTile w_tile) { + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + const int m4 = div_4(m); + const int m4i = mod_4(m); + [[unroll]] for (int n = 0; n < TILE_N; ++n) { + const int n4 = div_4(n); + const int n4i = mod_4(n); + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + accum.data[m][n4][n4i] = dotPacked4x8AccSatEXT( + in_tile.data[m4][k4][m4i], + w_tile.data[k4][n4][n4i], + accum.data[m][n4][n4i]); + } + } + } +} + +/* + * Computes final weight matrix output tile using: + * - int8 accumulator tile + * - per output channel weight sums + * - per output channel scales + */ +void compute( + out FPOutTile out_tile, + const Int8OutAccum out_accum, + const FPPerOutChannelParams sums, + const FPPerOutChannelParams scales) { +#if TILE_M > 1 && TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + out_tile.data[m][0] = + (VEC4_T(out_accum.data[m][0]) - input_zp * sums.data[0]) * + scales.data[0] * input_scale; + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + out_tile.data[m][n4] = + (VEC4_T(out_accum.data[m][n4]) - input_zp * sums.data[n4]) * + scales.data[n4] * input_scale; + } + } +#endif +} + +void compute( + out FPOutTile out_tile, + const Int8OutAccum out_accum, + const FPPerOutChannelParams sums, + const FPPerOutChannelParams scales, + const FPPerOutChannelParams bias) { +#if TILE_M > 1 && TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + out_tile.data[m][0] = + (VEC4_T(out_accum.data[m][0]) - input_zp * sums.data[0]) * + scales.data[0] * input_scale + + bias.data[0]; + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + out_tile.data[m][n4] = + (VEC4_T(out_accum.data[m][n4]) - input_zp * sums.data[n4]) * + scales.data[n4] * input_scale + + bias.data[n4]; + } + } +#endif +} + +#endif // LINEAR_FP_OUTPUT_TILE_INT8_COMPUTE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh new file mode 100644 index 00000000000..d40a0fe98cc --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh @@ -0,0 +1,114 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions store a FpOutTile to output buffer/texture. + * + * Requires: + * - t_output to be declared in the shader layout + * + * Settings: + * - OUTPUT_BUFFER to indicate t_output is a vec4 buffer, otherwise texture + * storage is assumed. 
+ */ + +#ifndef LINEAR_FP_OUTPUT_TILE_STORE_GLSLH +#define LINEAR_FP_OUTPUT_TILE_STORE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_output_tile.glslh" + +#ifdef OUTPUT_BUFFER + +void write_output_x4( + const VEC4_T out_texel, + const uint n4, + const uint m, + const uint N4) { + t_output[m * N4 + n4] = out_texel; +} + +#else + +void write_output_x4( + const VEC4_T out_texel, + const uint n4, + const uint m, + const uint N4) { + imageStore(t_output, ivec3(n4, m, 0), out_texel); +} + +#endif // OUTPUT_BUFFER + +void write_output_tile( + const FPOutTile out_tile, + const uint n4_start, + const uint m_start, + const uint N4) { +#if TILE_K4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + write_output_x4(out_tile.data[m][0], n4_start, m_start + m, N4); + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + write_output_x4(out_tile.data[m][n4], n4_start + n4, m_start + m, N4); + } + } +#endif +} + +// To be used if M - m >= TILE_M && N4 - n4 >= TILE_N4 +void write_output_tile_no_checks( + const FPOutTile out_tile, + const uint n4_start, + const uint m_start, + const uint N4, + const uint M) { +#if TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + write_output_x4(out_tile.data[m][0], n4_start, m_start + m, N4); + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + write_output_x4(out_tile.data[m][n4], n4_start + n4, m_start + m, N4); + } + } +#endif +} + +// To be used if close to tensor boundaries +void write_output_tile_with_checks( + const FPOutTile out_tile, + const uint n4_start, + const uint m_start, + const uint N4, + const uint M) { +#if TILE_N4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + if (m_start + m < M) { + write_output_x4(out_tile.data[m][0], n4_start, m_start + m, N4); + } + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + if (m_start + m < M && n4_start + n4 < N4) { + write_output_x4(out_tile.data[m][n4], n4_start + n4, m_start + m, N4); + } + } + } +#endif +} + +#endif // LINEAR_FP_OUTPUT_TILE_STORE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh new file mode 100644 index 00000000000..fb50911fb98 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh @@ -0,0 +1,100 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines the FPWeightTile struct, which is used to represent a fp tile of a + * weight matrix in matrix multiplication. 
+ * + * Settings: + * - TILE_K: number of rows in the weight tile + * - TILE_N4: number of (groups of 4) columns in the weight tile + */ + +#ifndef LINEAR_FP_WEIGHT_TILE_GLSLH +#define LINEAR_FP_WEIGHT_TILE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +struct FPWeightTile { + VEC4_T data[TILE_K][TILE_N4]; +}; + +#ifdef LINEAR_INT8_WEIGHT_TILE_GLSLH + +int sign_extend(const int val) { + if ((val & 0x80) != 0) { + return val | (~0xFF); + } + return val; +} + +T extract_8bit_value(const Int8WeightTile w_tile, const uint k, const uint n) { +#if TILE_K4 == 1 && TILE_N4 == 1 + const uint k4i = k; + const uint n4i = n; + ivec4 block = w_tile.data[0][0]; + +#else + const uint k4 = div_4(k); + const uint k4i = mod_4(k); + + const uint n4 = div_4(n); + const uint n4i = mod_4(n); + + ivec4 block = w_tile.data[k4][n4]; +#endif + + int col = block[n4i]; + int val = (col >> ((3 - k4i) * 8)) & 0xFF; + + return T(sign_extend(val)); +} + +void unpack(out FPWeightTile fp_w_tile, const Int8WeightTile w_tile) { +#if TILE_K > 1 && TILE_N4 == 1 + [[unroll]] for (int k = 0; k < TILE_K; ++k) { + fp_w_tile.data[k][0][0] = extract_8bit_value(w_tile, k, 0); + fp_w_tile.data[k][0][1] = extract_8bit_value(w_tile, k, 1); + fp_w_tile.data[k][0][2] = extract_8bit_value(w_tile, k, 2); + fp_w_tile.data[k][0][3] = extract_8bit_value(w_tile, k, 3); + } + +#else + [[unroll]] for (int k = 0; k < TILE_K; ++k) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + const uint n = mul_4(n4); + fp_w_tile.data[k][n4][0] = extract_8bit_value(w_tile, k, n); + fp_w_tile.data[k][n4][1] = extract_8bit_value(w_tile, k, n + 1); + fp_w_tile.data[k][n4][2] = extract_8bit_value(w_tile, k, n + 2); + fp_w_tile.data[k][n4][3] = extract_8bit_value(w_tile, k, n + 3); + } + } +#endif +} + +#endif // LINEAR_INT8_WEIGHT_TILE_GLSLH + +#ifdef DEBUG_MODE + +void printFPWeightTile(const FPWeightTile tile) { + debugPrintfEXT("weight_tile: \\n"); + [[unroll]] for (int k = 0; k < TILE_K; ++k) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT( + " %f, %f, %f, %f, \\n", + tile.data[k][n4].x, + tile.data[k][n4].y, + tile.data[k][n4].z, + tile.data[k][n4].w); + } + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_FP_WEIGHT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh new file mode 100644 index 00000000000..5b3a86b77d7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh @@ -0,0 +1,77 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * This file defines utilities to perform int8 quantization and block packing of + * matrix multiplication inputs. It also defines utilities to store packed block + * data to an output buffer or texture. + * + * Requires: + * - t_output to be defined in shader layout (output buffer/texture) + * + * Settings: + * - OUTPUT_BUFFER to indicate if output resource is a buffer. Otherwise texture + * is assumed.
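+ * + * As a rough illustration (sizes here are chosen for explanation, not taken from any shader): each + * Int8InputBlock covers a 4 (M) x 4 (K) tile of the input matrix; row m of the tile is quantized to + * four int8 values and packed into component m of the block's ivec4, with the first K element placed + * in the most significant byte, mirroring pack_into_int32 below.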
+ */ + +#ifndef LINEAR_INT8_INPUT_BLOCK_GLSLH +#define LINEAR_INT8_INPUT_BLOCK_GLSLH + +#define TILE_M 4 +#define TILE_K4 1 + +#include "linear_fp_input_tile.glslh" + +struct Int8InputBlock { + ivec4 data; +}; + +ivec4 quantize(const VEC4_T val) { + vec4 quantized = round(vec4(val) * inv_scale) + zp; + + // hard-code 8 bit quantization range + return clamp(ivec4(quantized), -127, 127); +} + +int pack_into_int32(const ivec4 quant_vals) { + int packed = ((quant_vals[3] & 0xFF) << 0) | ((quant_vals[2] & 0xFF) << 8) | + ((quant_vals[1] & 0xFF) << 16) | ((quant_vals[0] & 0xFF) << 24); + + return packed; +} + +void quantize_and_pack(out Int8InputBlock packed, const FPInputTile in_block) { + for (int row = 0; row < 4; ++row) { + ivec4 quantized_inputs = quantize(in_block.data[row][0]); + packed.data[row] = pack_into_int32(quantized_inputs); + } +} + +#ifdef OUTPUT_BUFFER + +void write_block( + const Int8InputBlock block, + const uint block_x, + const uint block_y, + const uint nblocks_x) { + t_output[block_y * nblocks_x + block_x] = block.data; +} + +#else // OUTPUT_TEXTURE + +void write_block( + const Int8InputBlock block, + const uint block_x, + const uint block_y, + const uint nblocks_x) { + imageStore(t_output, ivec3(block_x, block_y, 0), block.data); +} + +#endif // OUTPUT_BUFFER + +#endif // LINEAR_INT8_INPUT_BLOCK_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh new file mode 100644 index 00000000000..21e8ba031c5 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh @@ -0,0 +1,93 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines the Int8InputTile struct, which is used to represent a tile of the + * quantized int8 input matrix of a quantized matrix multiplication operation. 
+ * + * Settings: + * - TILE_M4: number of (groups of 4) rows in the tile + * - TILE_K4: number of (groups of 4) columns in the tile + */ + +#ifndef LINEAR_INT8_INPUT_TILE_GLSLH +#define LINEAR_INT8_INPUT_TILE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +struct Int8InputTile { + ivec4 data[TILE_M4][TILE_K4]; +}; + +#ifdef DEBUG_MODE + +int extract_8bit_from_packed_int_le(const int packed, const uint i) { + // account for little endian, extract 8-bit value at position i + int byte = int(uint(packed) >> (8 * i) & 255u); + // convert unsigned byte to signed byte + if (byte > 127) { + byte = byte - 256; + } + return byte; +} + +void printInt8InputTile(const Int8InputTile tile) { + debugPrintfEXT( + "Int8InputTile [TILE_M4=%d][TILE_K4=%d]:\\n", TILE_M4, TILE_K4); + + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + debugPrintfEXT(" tile[%d][%d] (ivec4): ", m4, k4); + + // Each ivec4 contains 4 packed integers, each integer contains 4 8-bit + // values + [[unroll]] for (int vec_idx = 0; vec_idx < 4; ++vec_idx) { + int packed_int = tile.data[m4][k4][vec_idx]; + debugPrintfEXT("packed_int[%d]=%d -> [", vec_idx, packed_int); + + // Extract 4 8-bit values from this packed integer + [[unroll]] for (int byte_idx = 0; byte_idx < 4; ++byte_idx) { + int val = extract_8bit_from_packed_int_le(packed_int, byte_idx); + if (byte_idx < 3) { + debugPrintfEXT("%d, ", val); + } else { + debugPrintfEXT("%d] ", val); + } + } + } + debugPrintfEXT("\\n"); + } + } +} + +void printInt8InputTileCompact(const Int8InputTile tile) { + debugPrintfEXT( + "Int8InputTile [%dx%d] (showing extracted 8-bit values):\\n", + TILE_M4 * 4, + TILE_K4 * 4); + + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + // Print 4 rows at a time (since each m4 represents 4 rows) + [[unroll]] for (int row_in_m4 = 0; row_in_m4 < 4; ++row_in_m4) { + debugPrintfEXT(" row %d: ", m4 * 4 + row_in_m4); + + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + [[unroll]] for (int vec_idx = 0; vec_idx < 4; ++vec_idx) { + int packed_int = tile.data[m4][k4][vec_idx]; + int val = extract_8bit_from_packed_int_le(packed_int, row_in_m4); + debugPrintfEXT("%4d ", val); + } + } + debugPrintfEXT("\\n"); + } + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_INT8_INPUT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile_load.glslh new file mode 100644 index 00000000000..ea302ab4f40 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile_load.glslh @@ -0,0 +1,75 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to load a Int8InputTile from input buffer/texture. + * + * Requires: + * - t_input to be declared in the shader layout (input buffer/texture) + * + * Settings: + * - INPUT_BUFFER to indicate resource is a buffer, otherwise texture storage is + * assumed. 
+ */ + +#ifndef LINEAR_INT8_INPUT_TILE_LOAD_GLSLH +#define LINEAR_INT8_INPUT_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_int8_input_tile.glslh" + +#ifdef INPUT_BUFFER + +ivec4 load_input_block( + const uint block_x, + const uint block_y, + const uint nblocks_x) { + return t_input[(block_y * nblocks_x) + block_x]; +} + +#else + +ivec4 load_input_block( + const uint block_x, + const uint block_y, + const uint nblocks_x) { + return texelFetch(t_input, ivec3(block_x, block_y, 0), 0); +} + +#endif // INPUT_BUFFER + +void load_input_tile( + out Int8InputTile in_tile, + const uint block_x, + const uint block_y, + const uint nblocks_x) { +#if TILE_M4 == 1 && TILE_K4 == 1 + in_tile.data[0][0] = load_input_block(block_x, block_y, nblocks_x); + +#elif TILE_M4 == 1 && TILE_K4 > 1 + [[unroll]] for (int x = 0; x < TILE_K4; ++x) { + in_tile.data[0][x] = load_input_block(block_x + x, block_y, nblocks_x); + } + +#elif TILE_M4 > 1 && TILE_K4 == 1 + [[unroll]] for (int y = 0; y < TILE_M4; ++y) { + in_tile.data[y][0] = load_input_block(block_x, block_y + y, nblocks_x); + } + +#else + [[unroll]] for (int y = 0; y < TILE_M4; ++y) { + [[unroll]] for (int x = 0; x < TILE_K4; ++x) { + in_tile.data[y][x] = + load_input_block(block_x + x, block_y + y, nblocks_x); + } + } +#endif +} + +#endif // LINEAR_INT8_INPUT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_block.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_block.glslh new file mode 100644 index 00000000000..c7a2022730b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_block.glslh @@ -0,0 +1,140 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_INT8_WEIGHT_BLOCK_GLSLH +#define LINEAR_INT8_WEIGHT_BLOCK_GLSLH + +/* + * This file defines utilties to perform weight prepacking of quantized int8 + * matrix multiplation weights. It also defines utilities to load source + * weight data from inputbuffer, and write out a packed weight block to output + * texture/buffer. + * + * Requires: + * - t_qmat2 to be defined in shader layout (output texture/buffer) + * - t_input to be defined in shader layout (input buffer) + * + * Settings: + * - USING_BUFFER to indicate if output resource is a buffer. Otherwise texture + * is assumed. + */ + +#extension GL_EXT_control_flow_attributes : require + +// Represents source data for a 4x4 block of the weight matrix read from the +// input buffer. +struct Int8WeightBlockSourceData { + int data[4]; +}; + +// Represents data for a packed 4x4 block of the weight matrix to be written out +// to output texture/buffer. 
+struct Int8WeightBlockPacked { + ivec4 data; +}; + +// To be used if K - k_start >= 4 +void load_block_source_data_no_checks( + out Int8WeightBlockSourceData src_data, + const uint n4, + const uint k_start, + const uint ntexels_N, + const uint K) { + [[unroll]] for (int k = 0; k < 4; ++k) { + src_data.data[k] = t_input[(k_start + k) * ntexels_N + n4]; + } +} + +// To be used if K - k_start < 4 +void load_block_source_data_with_checks( + out Int8WeightBlockSourceData src_data, + const uint n4, + const uint k_start, + const uint ntexels_N, + const uint K) { + [[unroll]] for (int k = 0; k < 4; ++k) { + if (k_start + k < K) { + src_data.data[k] = t_input[(k_start + k) * ntexels_N + n4]; + } else { + src_data.data[k] = 0; + } + } +} + +int extract_8bit_from_packed_uint_le(const uint packed, const uint i) { + // account for little endian + int byte = int(packed >> (8 * i) & 255); + return byte; +} + +int pack_4x8bit_signed_into_int( + const int val0, + const int val1, + const int val2, + const int val3) { + return int( + ((val0 & 0xFF) << 24) | ((val1 & 0xFF) << 16) | ((val2 & 0xFF) << 8) | + ((val3 & 0xFF))); +} + +void create_packed_block( + out Int8WeightBlockPacked block, + const Int8WeightBlockSourceData src_data) { + [[unroll]] for (int col = 0; col < 4; ++col) { + block.data[col] = pack_4x8bit_signed_into_int( + extract_8bit_from_packed_uint_le(src_data.data[0], col), + extract_8bit_from_packed_uint_le(src_data.data[1], col), + extract_8bit_from_packed_uint_le(src_data.data[2], col), + extract_8bit_from_packed_uint_le(src_data.data[3], col)); + } +} + +#ifdef USING_BUFFER + +void write_packed_block( + const Int8WeightBlockPacked block, + const uint block_x, + const uint block_y, + const uint nblocks_x) { + t_qmat2[block_y * nblocks_x + block_x] = block.data; +} + +#else // USING_TEXTURE + +void write_packed_block( + const Int8WeightBlockPacked block, + const uint block_x, + const uint block_y, + const uint nblocks_w) { + imageStore(t_qmat2, ivec2(block_x, block_y), block.data); +} + +#endif // USING_BUFFER + +#ifdef DEBUG_MODE + +void printInt8WeightBlockSourceData(const Int8WeightBlockSourceData src_data) { + debugPrintfEXT("int8_weight_block_source_data: \\n"); + [[unroll]] for (int row = 0; row < 4; ++row) { + debugPrintfEXT("row %i: %u \\n", row, src_data.data[row]); + } +} + +void printInt8WeightBlockPacked(const Int8WeightBlockPacked block) { + debugPrintfEXT("int8_weight_block_packed: \\n"); + debugPrintfEXT( + "%i %i %i %i \\n", + block.data[0], + block.data[1], + block.data[2], + block.data[3]); +} + +#endif // DEBUG_MODE + +#endif // LINEAR_INT8_WEIGHT_BLOCK_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile.glslh new file mode 100644 index 00000000000..2711f1d3174 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile.glslh @@ -0,0 +1,45 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_INT8_WEIGHT_TILE_GLSLH +#define LINEAR_INT8_WEIGHT_TILE_GLSLH + +/* + * Defines the Int8WeightTile struct, which is used to represent a tile of the + * quantized int8 weight matrix of a quantized matrix multiplication operation. 
+ * + * Settings: + * - TILE_K4: number of (groups of 4) rows in the weight tile + * - TILE_N4: number of (groups of 4) columns in the weight tile + */ + +#extension GL_EXT_control_flow_attributes : require + +struct Int8WeightTile { + ivec4 data[TILE_K4][TILE_N4]; +}; + +#ifdef DEBUG_MODE + +void printInt8WeightTile(const Int8WeightTile tile) { + debugPrintfEXT("int8_weight_tile: \\n"); + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT( + "%i %i %i %i \\n", + tile.data[k4][n4][0], + tile.data[k4][n4][1], + tile.data[k4][n4][2], + tile.data[k4][n4][3]); + } + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_INT8_WEIGHT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile_load.glslh new file mode 100644 index 00000000000..2b9baa84356 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile_load.glslh @@ -0,0 +1,75 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_INT8_WEIGHT_TILE_LOAD_GLSLH +#define LINEAR_INT8_WEIGHT_TILE_LOAD_GLSLH + +/* + * Defines functions to load a Int8WeightTile from input buffer/texture. + * + * Requires: + * - t_qmat2 to be declared in the shader layout (input buffer/texture) + * + * Settings: + * - WEIGHT_BUFFER to indicate t_qmat2 is a buffer, otherwise texture storage is + * assumed. + */ + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_int8_weight_tile.glslh" + +#ifdef WEIGHT_BUFFER + +ivec4 load_weight_block( + const uint block_x, + const uint block_y, + const uint nblocks_x) { + return t_qmat2[(block_y * nblocks_x) + block_x]; +} + +#else // WEIGHT_TEXTURE + +ivec4 load_weight_block( + const uint block_x, + const uint block_y, + const uint nblocks_x) { + return texelFetch(t_qmat2, ivec2(block_x, block_y), 0); +} + +#endif // WEIGHT_BUFFER + +void load_weight_tile( + out Int8WeightTile weight_tile, + const uint block_x, + const uint block_y, + const uint nblocks_x) { +#if TILE_K4 == 1 && TILE_N4 == 1 + weight_tile.data[0][0] = load_weight_block(block_x, block_y, nblocks_x); + +#elif TILE_K4 == 1 && TILE_N4 > 1 + [[unroll]] for (int x = 0; x < TILE_N4; ++x) { + weight_tile.data[0][x] = load_weight_block(block_x + x, block_y, nblocks_x); + } + +#elif TILE_K4 > 1 && TILE_N4 == 1 + [[unroll]] for (int y = 0; y < TILE_M4; ++y) { + weight_tile.data[y][0] = load_weight_block(block_x, block_y + y, nblocks_x); + } + +#else + [[unroll]] for (int y = 0; y < TILE_K4; ++y) { + [[unroll]] for (int x = 0; x < TILE_N4; ++x) { + weight_tile.data[y][x] = + load_weight_block(block_x + x, block_y + y, nblocks_x); + } + } +#endif +} + +#endif // LINEAR_INT8_WEIGHT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.glsl new file mode 100644 index 00000000000..49d880f732f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.glsl @@ -0,0 +1,117 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, OUTPUT_STORAGE)} +#define T ${texel_load_component_type(DTYPE, OUTPUT_STORAGE)} + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define TILE_M4 ${TILE_M4} +#define TILE_N4 ${TILE_N4} +#define TILE_K4 ${TILE_K4} + +#define TILE_M ${TILE_M4 * 4} +#define TILE_N ${TILE_N4 * 4} +#define TILE_K ${TILE_K4 * 4} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_qmat2", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "uint", "apply_bias", "0")} + +#include "linear_fp_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_weight_tile.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" +#include "linear_fp_output_tile_store.glslh" +#include "linear_scales_load.glslh" +#include "linear_bias_load.glslh" + +void main() { + // Each thread writes out a 4 wide x 4 high tile of output values + const uint out_tile_x = gl_GlobalInvocationID.x; + const uint out_tile_y = gl_GlobalInvocationID.y; + + const uint n = out_tile_x * TILE_N; + const uint m = out_tile_y * TILE_M; + + const uint n4 = div_4(n); + const uint m4 = div_4(m); + + if (n >= output_sizes.x || m >= output_sizes.y) { + return; + } + + const uint M = uint(input_sizes.y); + const uint K4 = div_up_4(input_sizes.x); + const uint N4 = div_up_4(output_sizes.x); + + FPOutTile out_tile; + initialize(out_tile); + + FPInputTile in_tile; + Int8WeightTile weight_tile; + FPWeightTile fp_weight_tile; + + const bool dont_check_bounds = (M - m) >= TILE_M; + + if (dont_check_bounds) { + for (int k4 = 0; k4 < K4; k4++) { + load_input_tile_no_checks(in_tile, k4, m, K4, M); + load_weight_tile(weight_tile, n4, k4, N4); + unpack(fp_weight_tile, weight_tile); + update(out_tile, in_tile, fp_weight_tile); + } + } else { + for (int k4 = 0; k4 < K4; k4++) { + load_input_tile_with_checks(in_tile, k4, m, K4, M); + load_weight_tile(weight_tile, n4, k4, N4); + unpack(fp_weight_tile, weight_tile); + update(out_tile, in_tile, fp_weight_tile); + } + } + + FPPerOutChannelParams scales_tile; + load_scales_tile(scales_tile, n4); + + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, n4); + + apply_scales_and_biases(out_tile, scales_tile, bias_tile); + } + else { + apply_scales(out_tile, scales_tile); + } + + if (dont_check_bounds) { + write_output_tile_no_checks(out_tile, n4, m, N4, M); + } else { + write_output_tile_with_checks(out_tile, n4, m, N4, M); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.yaml new file mode 100644 index 00000000000..2356fcdb251 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.yaml @@ -0,0 +1,30 @@ +# Copyright (c) Meta 
Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +linear_q8csw_tiled: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: texture3d + INPUT_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + TILE_M4: 1 + TILE_N4: 1 + TILE_K4: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: linear_q8csw_tiled_texture3d_texture3d_texture2d + - NAME: linear_q8csw_tiled_texture3d_texture3d_buffer + WEIGHT_STORAGE: buffer + - NAME: linear_q8csw_tiled_buffer_buffer_texture2d + OUTPUT_STORAGE: buffer + INPUT_STORAGE: buffer + WEIGHT_STORAGE: texture2d + - NAME: linear_q8csw_tiled_buffer_buffer_buffer + OUTPUT_STORAGE: buffer + INPUT_STORAGE: buffer + WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.glsl new file mode 100644 index 00000000000..a4bd4b4a115 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.glsl @@ -0,0 +1,117 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, OUTPUT_STORAGE)} +#define T int + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define TILE_M4 ${TILE_M4} +#define TILE_N4 ${TILE_N4} +#define TILE_K4 ${TILE_K4} + +#define TILE_M ${TILE_M4 * 4} +#define TILE_N ${TILE_N4 * 4} +#define TILE_K ${TILE_K4 * 4} + +${define_required_extensions(DTYPE)} + +#extension GL_EXT_integer_dot_product : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", "int", INPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_qmat2", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "float", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_spec_const(C, "uint", "apply_bias", "0")} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "linear_int8_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_output_tile_int8_compute.glslh" +#include "linear_fp_output_tile_store.glslh" +#include "linear_scales_load.glslh" +#include "linear_weight_sums_load.glslh" +#include "linear_bias_load.glslh" + +void main() { + // Each thread writes out a 4 wide x 4 high tile of output values + const uint out_tile_x = gl_GlobalInvocationID.x; + const uint out_tile_y = gl_GlobalInvocationID.y; + + const uint n = out_tile_x * TILE_N; + const uint m = out_tile_y * TILE_M; + + const uint n4 = div_4(n); + const uint m4 = div_4(m); + + if (n >= output_sizes.x || m >= output_sizes.y) { + 
return; + } + + const uint M = output_sizes.y; + const uint K4 = div_up_4(input_sizes.x); + const uint N4 = div_up_4(output_sizes.x); + + Int8OutAccum out_accum; + initialize(out_accum); + + Int8InputTile in_tile; + Int8WeightTile weight_tile; + + for (int k4 = 0; k4 < K4; k4++) { + load_input_tile(in_tile, k4, m4, K4); + load_weight_tile(weight_tile, n4, k4, N4); + + accumulate(out_accum, in_tile, weight_tile); + } + + FPPerOutChannelParams scales_tile; + load_scales_tile(scales_tile, n4); + + FPPerOutChannelParams sums_tile; + load_sums_tile(sums_tile, n4); + + FPOutTile out_tile; + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, uint(n4)); + + compute(out_tile, out_accum, sums_tile, scales_tile, bias_tile); + } + else { + compute(out_tile, out_accum, sums_tile, scales_tile); + } + + if (M - m >= TILE_M) { + write_output_tile_no_checks(out_tile, n4, m, N4, M); + } else { + write_output_tile_with_checks(out_tile, n4, m, N4, M); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml new file mode 100644 index 00000000000..dfaa839e02e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +linear_q8ta_q8csw_tiled: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: texture3d + INPUT_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + TILE_M4: 1 + TILE_N4: 1 + TILE_K4: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: linear_q8ta_q8csw_tiled_texture3d_texture3d_texture2d + - NAME: linear_q8ta_q8csw_tiled_texture3d_texture3d_buffer + WEIGHT_STORAGE: buffer + - NAME: linear_q8ta_q8csw_tiled_buffer_buffer_texture2d + OUTPUT_STORAGE: buffer + INPUT_STORAGE: buffer + WEIGHT_STORAGE: texture2d + - NAME: linear_q8ta_q8csw_tiled_buffer_buffer_buffer + OUTPUT_STORAGE: buffer + INPUT_STORAGE: buffer + WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_scales_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_scales_load.glslh new file mode 100644 index 00000000000..47f6d318008 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_scales_load.glslh @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#ifndef LINEAR_SCALES_LOAD_GLSLH +#define LINEAR_SCALES_LOAD_GLSLH + +#include "linear_common.glslh" + +VEC4_T load_scale_x4(const uint n4) { + return t_weight_scales[n4]; +} + +void load_scales_tile(out FPPerOutChannelParams scales, const uint n4_start) { +#if TILE_N4 == 1 + scales.data[0] = load_scale_x4(n4_start); + +#else + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + scales.data[n4] = load_scale_x4(n4_start + n4); + } + +#endif +} + +#endif // LINEAR_SCALES_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_weight_sums_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_weight_sums_load.glslh new file mode 100644 index 00000000000..8c13315d50d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_weight_sums_load.glslh @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_WEIGHT_SUMS_LOAD_GLSLH +#define LINEAR_WEIGHT_SUMS_LOAD_GLSLH + +#include "linear_common.glslh" + +VEC4_T load_sum_x4(const uint n4) { + return VEC4_T(t_weight_sums[n4]); +} + +void load_sums_tile(out FPPerOutChannelParams sums, const uint n4_start) { +#if TILE_N4 == 1 + sums.data[0] = load_sum_x4(n4_start); + +#else + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + sums.data[n4] = load_sum_x4(n4_start + n4); + } + +#endif +} + +#endif // LINEAR_WEIGHT_SUMS_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.glsl b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.glsl new file mode 100644 index 00000000000..e731aa596a7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.glsl @@ -0,0 +1,62 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +${define_active_storage_type(STORAGE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_qmat2", "int", STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", "int", "buffer")} + +layout(push_constant) uniform restrict Block { + ivec4 qmat2_sizes; + ivec2 orig_sizes; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "common.glslh" +#include "linear_int8_weight_block.glslh" + +void main() { + uint block_x = gl_GlobalInvocationID.x; + uint block_y = gl_GlobalInvocationID.y; + + const int N = orig_sizes.y; + const int K = orig_sizes.x; + + // Each group of 4 8-bit values is packed into one uint in the input tensor.
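+ // Illustrative example (hypothetical sizes, not derived from this shader's inputs): for K = 8 and + // N = 8, the source buffer holds 8 rows of N4 = 2 uints, each uint carrying 4 int8 values along N, + // and this shader emits K4 x N4 = 2 x 2 packed blocks, each an ivec4 covering a 4 (K) x 4 (N) tile + // of the original weight matrix.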
+ const int N4 = div_up_4(N); + const int K4 = div_up_4(K); + + // Check bounds + if (block_x >= N4 || block_y >= K4) { + return; + } + + Int8WeightBlockSourceData src_data; + const uint k = mul_4(block_y); + if (K - k >= 4) { + load_block_source_data_no_checks(src_data, block_x, mul_4(block_y), N4, K); + } else { + load_block_source_data_with_checks(src_data, block_x, mul_4(block_y), N4, K); + } + + Int8WeightBlockPacked packed_block; + create_packed_block(packed_block, src_data); + + write_packed_block( + packed_block, + block_x, + block_y, + N4); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.yaml b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.yaml new file mode 100644 index 00000000000..13e6d43b2c5 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_linear_weight.yaml @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +pack_q8_linear_weight: + parameter_names_with_default_values: + STORAGE: buffer + shader_variants: + - NAME: pack_q8_linear_weight_buffer + STORAGE: buffer + - NAME: pack_q8_linear_weight_texture2d + STORAGE: texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_im2col.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_im2col.glsl new file mode 100644 index 00000000000..fafad45d92e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_im2col.glsl @@ -0,0 +1,89 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)} +#define T ${texel_load_component_type(DTYPE, INPUT_STORAGE)} + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER + +#define TILE_M4 1 +#define TILE_N4 1 +#define TILE_K4 1 + +#define TILE_M 4 +#define TILE_N 4 +#define TILE_K 4 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_output", "int", OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)} + +// Sizes of the im2col matrix of the convolution input +${layout_declare_ubo(B, "ivec4", "matrix_sizes")} +// Sizes of the input image +${layout_declare_ubo(B, "ivec4", "input_sizes")} +// Sizes of the output image +${layout_declare_ubo(B, "ivec4", "output_sizes")} + +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float inv_scale; + int zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "conv2d_fp_im2col_block_load.glslh" +#include "linear_int8_input_block.glslh" + +void main() { + // The quantized and packed im2col matrix can be conceptualized as a 2D matrix + // with K/4 columns and M/4 rows. Each element of the matrix is a ivec4 which + // contains packed data for a 4 wide x 4 high block of the original im2col + // matrix. Each shader invocation works on writing out one ivec4, i.e. one + // block of the quantized and packed matrix. 
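+ // + // For illustration only (hypothetical sizes): with logical_K = 27 and logical_M = 64, the packed + // matrix has div_up(27, 4) = 7 blocks along K and 64 / 4 = 16 blocks along M; the invocation with + // (k4, m4) = (6, 0) covers columns 24..26 of rows 0..3 of the im2col matrix, with the fourth column + // of that block falling outside logical_K.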
+ + // Thread id corresponds to the block index + const int k4 = int(gl_GlobalInvocationID.x); + const int m4 = int(gl_GlobalInvocationID.y); + + // Convert block idx to tensor idx + const int k = mul_4(k4); + const int m = mul_4(m4); + + const int logical_K = conv2d_params.logical_K; + // Similarly, compute the logical size of the M dim. + const int logical_M = output_sizes.x * output_sizes.y * output_sizes.w; + + // Check if tensor indices are out of bounds + if (k >= logical_K || m >= logical_M) { + return; + } + + FPInputTile in_tile; + load_input_im2col_tile(in_tile, k4, m4, logical_K, logical_M); + + Int8InputBlock packed_block; + quantize_and_pack(packed_block, in_tile); + + // Number of texels in the x dim of the output matrix + const int K4 = div_4(matrix_sizes.x); + write_block(packed_block, k4, m4, K4); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_im2col.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_im2col.yaml new file mode 100644 index 00000000000..93f8269d607 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_im2col.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +quantize_and_pack_im2col: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: buffer + INPUT_STORAGE: texture3d + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: quantize_and_pack_im2col_buffer_texture3d + - NAME: quantize_and_pack_im2col_texture3d_texture3d + OUTPUT_STORAGE: texture3d diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.glsl new file mode 100644 index 00000000000..5a9b9f30ce4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.glsl @@ -0,0 +1,79 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)} +#define T ${texel_load_component_type(DTYPE, INPUT_STORAGE)} + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_output", "int", OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)} + +$if GRANULARITY == "per_channel": + ${layout_declare_tensor(B, "r", "t_scale", DTYPE, "buffer")} + +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(push_constant) uniform restrict Block { + float inv_scale; + int zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "linear_int8_input_block.glslh" +#include "linear_fp_input_tile_load.glslh" + +void main() { + // Each input block contains 4x4 int8 quantized values, which are packed into + // a ivec4. k4 and m4 represent the "block index" of the current block being + // processed. 
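+ // + // As a rough sketch of the quantization applied by quantize() in linear_int8_input_block.glslh + // (the values below are hypothetical): an input x maps to clamp(round(x * inv_scale) + zp, -127, 127), + // so with scale = 0.05 (inv_scale = 20.0) and zp = 0, x = 1.27 quantizes to 25 and x = -10.0 + // saturates to -127.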
+ uint k4 = gl_GlobalInvocationID.x; + uint m4 = gl_GlobalInvocationID.y; + + const int K = input_sizes.x; + const int M = input_sizes.y; + + // K4 and M4 represent the number of blocks in each dimension. + const int K4 = div_up_4(K); + const int M4 = div_up_4(M); + + if (k4 >= K4 || m4 >= M4) { + return; + } + + // row of the input tensor to start loading from. Note the input tensor is + // interpreted as a t + const uint m = mul_4(m4); + + const bool dont_check_bounds = (M - m) >= 4; + + FPInputTile in_tile; + if (dont_check_bounds) { + load_input_tile_no_checks(in_tile, k4, m, K4, M); + } else { + load_input_tile_with_checks(in_tile, k4, m, K4, M); + } + + Int8InputBlock packed_block; + quantize_and_pack(packed_block, in_tile); + + write_block(packed_block, k4, m4, K4); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml new file mode 100644 index 00000000000..37721db1ba8 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +quantize_and_pack_linear_input: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: texture3d + INPUT_STORAGE: texture3d + STORAGE: texture3d + GRANULARITY: per_tensor + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: quantize_and_pack_linear_input_per_tensor_texture3d_texture3d + - NAME: quantize_and_pack_linear_input_per_tensor_buffer_texture3d + OUTPUT_STORAGE: buffer + - NAME: quantize_and_pack_linear_input_per_tensor_buffer_buffer + OUTPUT_STORAGE: buffer + INPUT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp new file mode 100644 index 00000000000..e93b98f0f92 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp @@ -0,0 +1,645 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include +#include +#include + +namespace vkcompute { + +struct Conv2DParams { + utils::ivec2 kernel_size; + utils::ivec2 stride; + utils::ivec2 padding; + utils::ivec2 dilation; + int32_t groups; + int32_t out_channels_per_group; + int32_t in_channels_per_group; + int32_t logical_K_per_group; + int32_t K_per_group; + int32_t K4_per_group; + int32_t logical_K; + int32_t K; + int32_t K4; +}; + +Conv2DParams create_conv2d_params( + ComputeGraph& graph, + const ValueRef& conv_input, + const ValueRef& conv_output, + const ValueRef& kernel_size, + const ValueRef& stride, + const ValueRef& padding, + const ValueRef& dilation, + const ValueRef& groups) { + const auto kernel_size_list = graph.get_int_list(kernel_size); + const auto stride_list = graph.get_int_list(stride); + const auto padding_list = graph.get_int_list(padding); + const auto dilation_list = graph.get_int_list(dilation); + const int32_t groups_val = graph.get_int(groups); + + // Pre-compute input and output channels per group + + std::vector out_sizes = graph.sizes_of(conv_output); + const int32_t out_channels = utils::val_at(-3, out_sizes); + const int32_t out_channels_per_group = out_channels / groups_val; + + std::vector in_sizes = graph.sizes_of(conv_input); + const int32_t in_channels = utils::val_at(-3, in_sizes); + const int32_t in_channels_per_group = in_channels / groups_val; + + // Pre-compute the number of elements along the K dimension per group. This + // quantity is aligned to the next multiple of 4 to ensure data loads are + // aligned to texel boundaries. + + const int32_t logical_K_per_group = + kernel_size_list->at(0) * kernel_size_list->at(1) * in_channels_per_group; + const int32_t K_per_group = utils::align_up_4(logical_K_per_group); + const int32_t K4_per_group = K_per_group / 4; + + // Pre-compute the "theoretical" size of the K dim of the input im2col matrix, + // which represents the flattened convolution window used to compute an output + // element. This is used for bounds checking. 
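+ // + // Hypothetical example for illustration: a 3x3 kernel with in_channels = 16 and groups = 2 gives + // in_channels_per_group = 8 and logical_K_per_group = 3 * 3 * 8 = 72 (already a multiple of 4, so + // K_per_group = 72 and K4_per_group = 18), while logical_K = 3 * 3 * 16 = 144, K = 72 * 2 = 144, and + // K4 = 36.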
+ + const int32_t logical_K = + kernel_size_list->at(0) * kernel_size_list->at(1) * in_channels; + + const int32_t K = K_per_group * groups_val; + // Used for texel stride calculations + const int32_t K4 = K / 4; + + return Conv2DParams{ + // Swap the order from HW to WH + utils::make_ivec2({kernel_size_list->at(1), kernel_size_list->at(0)}), + utils::make_ivec2({stride_list->at(1), stride_list->at(0)}), + utils::make_ivec2({padding_list->at(1), padding_list->at(0)}), + utils::make_ivec2({dilation_list->at(1), dilation_list->at(0)}), + groups_val, + out_channels_per_group, + in_channels_per_group, + logical_K_per_group, + K_per_group, + K4_per_group, + logical_K, + K, + K4, + }; +} + +std::vector calculate_input_im2col_sizes( + ComputeGraph* graph, + const ValueRef& input, + const ValueRef& output, + const ValueRef& kernel_size, + const ValueRef& groups) { + std::vector in_sizes = graph->sizes_of(input); + const int64_t in_channels = utils::val_at(-3, in_sizes); + + std::vector out_sizes = graph->sizes_of(output); + const int64_t batches = utils::val_at(-4, out_sizes); + const int64_t out_height = utils::val_at(-2, out_sizes); + const int64_t out_width = utils::val_at(-1, out_sizes); + + // Represents the number of channel groups + const int64_t groups_val = graph->extract_scalar(groups); + // No need to div_up because in_channels % groups_val = 0 + const int64_t in_channels_per_group = in_channels / groups_val; + + const auto kernel_size_list = graph->get_int_list(kernel_size); + + // Align to the next multiple of 4 to ensure that data loads align nicely with + // texel boundaries. We want to ensure that the first data element of each + // group is at the start of its texel. + const int64_t flattened_kernel_len = utils::align_up_4( + in_channels_per_group * kernel_size_list->at(0) * + kernel_size_list->at(1)); + + // K -> flattened convolution window (adjusted) + const int64_t K = flattened_kernel_len * groups_val; + // M -> number of elements in 2D output plane + const int64_t M = out_height * out_width * batches; + + return {M, K}; +} + +utils::uvec3 im2col_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef input = args.at(1).refs.at(0); + const ValueRef output = resize_args.at(0); + const ValueRef kernel_size = resize_args.at(1); + const ValueRef groups = resize_args.at(2); + + std::vector im2col_sizes = + calculate_input_im2col_sizes(graph, input, output, kernel_size, groups); + const uint32_t K = utils::safe_downcast(im2col_sizes[1]); + const uint32_t M = utils::safe_downcast(im2col_sizes[0]); + + // 1 output tile is 4x4 elements + const uint32_t K4 = utils::div_up(K, 4u); + const uint32_t M4 = utils::div_up(M, 4u); + + return {K4, M4, 1}; +} + +std::vector calculate_output_im2col_sizes( + ComputeGraph* graph, + const ValueRef& output) { + std::vector out_sizes = graph->sizes_of(output); + const int64_t batches = utils::val_at(-4, out_sizes); + const int64_t out_channels = utils::val_at(-3, out_sizes); + const int64_t out_height = utils::val_at(-2, out_sizes); + const int64_t out_width = utils::val_at(-1, out_sizes); + + // N -> output channels + const int64_t N = out_channels; + // M -> number of elements in 2D output plane + const int64_t M = out_height * out_width * batches; + + return {M, N}; +} + +utils::uvec3 col2im_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef output = 
args.at(0).refs.at(0); + + std::vector im2col_sizes = + calculate_output_im2col_sizes(graph, output); + const uint32_t N = utils::safe_downcast(im2col_sizes[1]); + const uint32_t M = utils::safe_downcast(im2col_sizes[0]); + + // 1 output tile is 4x4 elements + const uint32_t N4 = utils::div_up(N, 4u); + const uint32_t M4 = utils::div_up(M, 4u); + + return {N4, M4, 1}; +} + +void add_input_im2col_node( + ComputeGraph& graph, + const std::vector& args) { + // Extract arguments + int32_t idx = 0; + const ValueRef input = args.at(idx++); + const ValueRef kernel_size = args.at(idx++); + const ValueRef stride = args.at(idx++); + const ValueRef padding = args.at(idx++); + const ValueRef dilation = args.at(idx++); + const ValueRef groups = args.at(idx++); + const ValueRef output = args.at(idx++); + const ValueRef im2col_matrix = args.at(idx++); + + Conv2DParams conv_params = create_conv2d_params( + graph, input, output, kernel_size, stride, padding, dilation, groups); + + // Get shader for quantized conv2d linear tiled + std::string kernel_name = "im2col"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(im2col_matrix)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(im2col_matrix), + graph.sizes_ubo(input), + graph.sizes_ubo(output), + graph.create_params_buffer(conv_params)}; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + im2col_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{im2col_matrix, vkapi::kWrite}, {input, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + {}, + // Specialization Constants + {}, + // Resize args + {output, kernel_size, groups}, + // Resizing Logic + nullptr)); +} + +void add_quantize_and_pack_im2col_node( + ComputeGraph& graph, + const std::vector& args) { + // Extract arguments + int32_t idx = 0; + const ValueRef input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef kernel_size = args.at(idx++); + const ValueRef stride = args.at(idx++); + const ValueRef padding = args.at(idx++); + const ValueRef dilation = args.at(idx++); + const ValueRef groups = args.at(idx++); + const ValueRef output = args.at(idx++); + const ValueRef quantized_im2col_matrix = args.at(idx++); + + Conv2DParams conv_params = create_conv2d_params( + graph, input, output, kernel_size, stride, padding, dilation, groups); + + float inv_scale = 1.0f / graph.extract_scalar(input_scale); + int32_t zp = graph.extract_scalar(input_zp); + + // Get shader for quantized conv2d linear tiled + std::string kernel_name = "quantize_and_pack_im2col"; + add_storage_type_suffix( + kernel_name, graph.storage_type_of(quantized_im2col_matrix)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(quantized_im2col_matrix), + graph.sizes_ubo(input), + graph.sizes_ubo(output), + graph.create_params_buffer(conv_params)}; + + std::vector push_constants = { + PushConstantDataInfo(&inv_scale, sizeof(inv_scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + 
graph, + VK_KERNEL_FROM_STR(kernel_name), + im2col_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{quantized_im2col_matrix, vkapi::kWrite}, {input, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize args + {output, kernel_size, groups}, + // Resizing Logic + nullptr)); +} + +void add_conv2d_q8csw_linear_tiled_node( + ComputeGraph& graph, + const std::vector& args) { + // Extract arguments + int32_t idx = 0; + const ValueRef input_im2col = args.at(idx++); + const ValueRef input = args.at(idx++); + const ValueRef packed_weight = args.at(idx++); + const ValueRef packed_weight_scales = args.at(idx++); + const ValueRef bias = args.at(idx++); + const ValueRef packed_bias = args.at(idx++); + const ValueRef kernel_size = args.at(idx++); + const ValueRef stride = args.at(idx++); + const ValueRef padding = args.at(idx++); + const ValueRef dilation = args.at(idx++); + const ValueRef groups = args.at(idx++); + const ValueRef output = args.at(idx++); + const ValueRef original_weight = args.at(idx++); // For resize args + + Conv2DParams conv_params = create_conv2d_params( + graph, input, output, kernel_size, stride, padding, dilation, groups); + + // One limitation of the current implementation is that for grouped convs, + // the number of output channels per group must be a multiple of 4. One loaded + // 4x4 weight tile must all belong to the same group. + if (conv_params.groups > 1) { + VK_CHECK_COND(conv_params.out_channels_per_group % 4 == 0); + } + + std::string kernel_name = "conv2d_q8csw_linear_tiled"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(output)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input_im2col)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(output), + graph.sizes_ubo(input), + graph.create_params_buffer(conv_params)}; + + uint32_t apply_bias = 1; + if (graph.val_is_none(bias)) { + apply_bias = 0; + } + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + col2im_global_wg_size, + quantized_linear_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {{input_im2col, packed_weight, packed_weight_scales, packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + {}, + // Specialization Constants + {apply_bias}, + // Resize args + {original_weight}, + // Resizing Logic + nullptr)); +} + +void add_conv2d_q8ta_q8csw_linear_tiled_node( + ComputeGraph& graph, + const std::vector& args) { + // Extract arguments + int32_t idx = 0; + const ValueRef quantized_input_im2col = args.at(idx++); + const ValueRef input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef packed_weight = args.at(idx++); + const ValueRef packed_weight_sums = args.at(idx++); + const ValueRef packed_weight_scales = args.at(idx++); + const ValueRef bias = args.at(idx++); + const ValueRef packed_bias = args.at(idx++); + const ValueRef kernel_size = args.at(idx++); + const ValueRef stride = args.at(idx++); + const ValueRef padding = args.at(idx++); + const ValueRef dilation = args.at(idx++); + const ValueRef groups = args.at(idx++); + const ValueRef output = args.at(idx++); + const 
ValueRef original_weight = args.at(idx++); // For resize args + + Conv2DParams conv_params = create_conv2d_params( + graph, input, output, kernel_size, stride, padding, dilation, groups); + + // One limitation of the current implementation is that for grouped convs, + // the number of output channels per group must be a multiple of 4. One loaded + // 4x4 weight tile must all belong to the same group. + if (conv_params.groups > 1) { + VK_CHECK_COND(conv_params.out_channels_per_group % 4 == 0); + } + + float scale = graph.extract_scalar(input_scale); + int32_t zp = graph.extract_scalar(input_zp); + + std::string kernel_name = "conv2d_q8ta_q8csw_linear_tiled"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(output)); + add_storage_type_suffix( + kernel_name, graph.storage_type_of(quantized_input_im2col)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(output), + graph.sizes_ubo(input), + graph.create_params_buffer(conv_params)}; + + std::vector push_constants = { + PushConstantDataInfo(&scale, sizeof(scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + uint32_t apply_bias = 1; + if (graph.val_is_none(bias)) { + apply_bias = 0; + } + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + col2im_global_wg_size, + quantized_linear_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {{quantized_input_im2col, + packed_weight, + packed_weight_sums, + packed_weight_scales, + packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {apply_bias}, + // Resize args + {original_weight}, + // Resizing Logic + nullptr)); +} + +/* + * Computes weight only quantized conv2d with the conv2d_q8csw_linear_tiled + * shader. The input image will first be converted to matrix form using the + * im2col procedure. The convolution is performed via matrix multiplication, but + * the output is written directly as image format which circumvents the need for + * a separate step to convert the output matrix back to image format. This + * implementation will be used when accelerated int8 dot product is not + * available on a particular device, in which case there is no benefit from + * quantizing the input tensor. + */ +void conv2d_q8csw_linear_tiled_impl( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + (void)input_scale; + const ValueRef input_zp = args.at(idx++); + (void)input_zp; + const ValueRef weight = args.at(idx++); + const ValueRef weight_sums = args.at(idx++); + (void)weight_sums; + const ValueRef weight_scales = args.at(idx++); + const ValueRef bias = args.at(idx++); + const ValueRef kernel_size = args.at(idx++); + const ValueRef stride = args.at(idx++); + const ValueRef padding = args.at(idx++); + const ValueRef dilation = args.at(idx++); + const ValueRef groups = args.at(idx++); + const ValueRef output = args.at(idx++); + + const ValueRef packed_weight = prepack_q8_linear_weight(graph, weight); + ValueRef packed_weight_scales = prepack_standard( + graph, weight_scales, utils::kBuffer, utils::kWidthPacked); + + // Create a dummy tensor to fill the binding slot of the bias tensor if it is + // not provided. 
This helps simplify dispatch logic and makes it so that + // fewer shdaer variants need to be generated. + TmpTensor dummy_bias( + &graph, {}, graph.dtype_of(output), utils::kBuffer, utils::kWidthPacked); + + ValueRef packed_bias = dummy_bias.vref; + if (!graph.val_is_none(bias)) { + packed_bias = + prepack_standard(graph, bias, utils::kBuffer, utils::kWidthPacked); + } + + std::vector input_im2col_sizes = + calculate_input_im2col_sizes(&graph, input, output, kernel_size, groups); + + TmpTensor input_im2col_matrix( + &graph, + input_im2col_sizes, + vkapi::kFloat, + utils::kBuffer, + utils::kWidthPacked); + + std::vector im2col_args = { + input, + kernel_size, + stride, + padding, + dilation, + groups, + output, + input_im2col_matrix}; + + add_input_im2col_node(graph, im2col_args); + + std::vector conv2d_linear_args = { + input_im2col_matrix, + input, + packed_weight, + packed_weight_scales, + bias, + packed_bias, + kernel_size, + stride, + padding, + dilation, + groups, + output, + weight}; + + add_conv2d_q8csw_linear_tiled_node(graph, conv2d_linear_args); +} + +void conv2d_q8ta_q8csw_linear_tiled_impl( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight = args.at(idx++); + const ValueRef weight_sums = args.at(idx++); + const ValueRef weight_scales = args.at(idx++); + const ValueRef bias = args.at(idx++); + const ValueRef kernel_size = args.at(idx++); + const ValueRef stride = args.at(idx++); + const ValueRef padding = args.at(idx++); + const ValueRef dilation = args.at(idx++); + const ValueRef groups = args.at(idx++); + const ValueRef output = args.at(idx++); + + const ValueRef packed_weight = prepack_q8_linear_weight(graph, weight); + ValueRef packed_weight_scales = prepack_standard( + graph, weight_scales, utils::kBuffer, utils::kWidthPacked); + ValueRef packed_weight_sums = + prepack_standard(graph, weight_sums, utils::kBuffer, utils::kWidthPacked); + + // Create a dummy tensor to fill the binding slot of the bias tensor if it is + // not provided. This helps simplify dispatch logic and makes it so that + // fewer shdaer variants need to be generated. 
+ TmpTensor dummy_bias( + &graph, {}, graph.dtype_of(output), utils::kBuffer, utils::kWidthPacked); + + ValueRef packed_bias = dummy_bias.vref; + if (!graph.val_is_none(bias)) { + packed_bias = + prepack_standard(graph, bias, utils::kBuffer, utils::kWidthPacked); + } + + std::vector input_im2col_sizes = + calculate_input_im2col_sizes(&graph, input, output, kernel_size, groups); + + const int64_t num_blocks_M = utils::div_up_4(input_im2col_sizes.at(0)); + const int64_t num_blocks_K = utils::div_up_4(input_im2col_sizes.at(1)); + + TmpTensor quantized_input_im2col_matrix( + &graph, + {num_blocks_M, num_blocks_K * 4}, + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked); + + std::vector quantize_and_pack_im2col_args = { + input, + input_scale, + input_zp, + kernel_size, + stride, + padding, + dilation, + groups, + output, + quantized_input_im2col_matrix}; + + add_quantize_and_pack_im2col_node(graph, quantize_and_pack_im2col_args); + + std::vector conv2d_linear_args = { + quantized_input_im2col_matrix, + input, + input_scale, + input_zp, + packed_weight, + packed_weight_sums, + packed_weight_scales, + bias, + packed_bias, + kernel_size, + stride, + padding, + dilation, + groups, + output, + weight}; + + add_conv2d_q8ta_q8csw_linear_tiled_node(graph, conv2d_linear_args); +} + +void conv2d_q8ta_q8csw(ComputeGraph& graph, const std::vector& args) { + // If accelerated int8 dot product is available, quantize the input tensor + // to allow for faster arithmetic throughput. + if (graph.can_use_int8_dot_product()) { + conv2d_q8ta_q8csw_linear_tiled_impl(graph, args); + } + // Otherwise, dequantize the weight tensor and do math in fp32. + else { + conv2d_q8csw_linear_tiled_impl(graph, args); + } +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.conv2d_q8ta_q8csw.default, conv2d_q8ta_q8csw); + VK_REGISTER_OP( + et_vk.conv2d_q8ta_q8csw.conv2d_q8csw_linear_tiled, + conv2d_q8csw_linear_tiled_impl); + VK_REGISTER_OP( + et_vk.conv2d_q8ta_q8csw.conv2d_q8ta_q8csw_linear_tiled, + conv2d_q8ta_q8csw_linear_tiled_impl); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp new file mode 100644 index 00000000000..e41b732b9bc --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -0,0 +1,548 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include +#include + +namespace vkcompute { + +utils::uvec3 quantized_linear_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + + std::vector out_sizes = graph->sizes_of(out); + // height + const uint32_t M = utils::val_at(-2, out_sizes); + // width + const uint32_t N = utils::val_at(-1, out_sizes); + + // 1 output tile is 4x4 elements + const uint32_t M4 = utils::div_up(M, 4u); + const uint32_t N4 = utils::div_up(N, 4u); + + return {N4, M4, 1}; +} + +utils::uvec3 quantized_linear_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)args; + (void)resize_args; + + // Optimize local workgroup size for linear operations + uint32_t local_wg_size_x = 1; + uint32_t local_wg_size_y = 1; + + if (global_workgroup_size[1] % 8 == 0) { + local_wg_size_y = 8; + } else if (global_workgroup_size[1] % 4 == 0) { + local_wg_size_y = 4; + } else if (global_workgroup_size[1] % 2 == 0) { + local_wg_size_y = 2; + } + + // Adjust x dimension to maintain reasonable total workgroup size + local_wg_size_x = std::min(64u / local_wg_size_y, global_workgroup_size[0]); + + return {local_wg_size_x, local_wg_size_y, 1}; +} + +ValueRef prepack_q8_linear_weight( + ComputeGraph& graph, + const ValueRef qmat2_data) { + std::vector qmat2_orig_sizes = graph.sizes_of(qmat2_data); + const int64_t ndim = graph.dim_of(qmat2_data); + + // Input is [K, N] + const int64_t K = qmat2_orig_sizes.at(ndim - 2); + const int64_t N = qmat2_orig_sizes.at(ndim - 1); + + // N must be a multiple of 4 so data data loads are aligned nicely with texel + // boundaries. + VK_CHECK_COND(N % 4 == 0); + + // This packing format partitions the weight tensor into 4 wide x 4 high + // blocks. To figure out the size of the output tensor, determine the number + // of blocks along the width and height dims. + const int64_t num_blocks_K = utils::div_up(K, int64_t(4)); + const int64_t num_blocks_N = utils::div_up(N, int64_t(4)); + + // Each transposed block is 4 wide x 4 high. To maximize memory loading + // efficiency, the packed weight tensor will use a base data type of uint32_t; + // in terms of uint32_t, each block is 1 wide x 4 high. However, each block is + // also flattened as it is stored, so that the whole block can be loaded at + // once. As a result, the stored block will be 4 wide x 1 high. 
+ const int64_t output_height = num_blocks_K; + const int64_t output_width = num_blocks_N * 4; + + // Store the original sizes of the tensor to pass to the shader + utils::ivec2 orig_sizes{ + utils::safe_downcast(K), utils::safe_downcast(N)}; + + std::vector qmat2_sizes{output_height, output_width}; + + utils::StorageType storage_type = utils::kTexture2D; + uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); + if (output_width > max_extent * 4 || output_height > max_extent) { + storage_type = utils::kBuffer; + } + + ValueRef qmat2 = graph.add_tensor( + qmat2_sizes, vkcompute::vkapi::kInt, storage_type, utils::kWidthPacked); + + // Global workgroup size: each thread writes out two adjacent blocks + utils::uvec3 global_wg_size{ + utils::safe_downcast(num_blocks_N), + utils::safe_downcast(num_blocks_K), + 1u}; + + std::string kernel_name = "pack_q8_linear_weight"; + add_storage_type_suffix(kernel_name, storage_type); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + graph.create_local_wg_size(global_wg_size), + // Inputs and Outputs + qmat2_data, + qmat2, + // UBOs + {}, + // Specialization Constants + {}, + // Push Constants + {graph.sizes_pc_of(qmat2), + PushConstantDataInfo(&orig_sizes, sizeof(utils::ivec2))})); + + return qmat2; +} + +struct InputQuantConstants { + alignas(16) float inv_scale; + alignas(16) int32_t zp; +}; + +std::tuple get_quantized_input_num_blocks( + ComputeGraph& graph, + const ValueRef input) { + std::vector input_sizes = graph.sizes_of(input); + const int64_t ndim = graph.dim_of(input); + + const int64_t M = input_sizes.at(ndim - 2); + const int64_t K = input_sizes.at(ndim - 1); + + const int64_t num_blocks_M = utils::div_up(M, int64_t(4)); + const int64_t num_blocks_K = utils::div_up(K, int64_t(4)); + + return std::make_tuple(num_blocks_M, num_blocks_K); +} + +utils::uvec3 quant_pack_input_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef input = args.at(1).refs.at(0); + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, num_blocks_K) = + get_quantized_input_num_blocks(*graph, input); + + return { + utils::safe_downcast(num_blocks_K), + utils::safe_downcast(num_blocks_M), + 1u}; +} + +DynamicDispatchNode make_quantize_and_pack_linear_input_node( + ComputeGraph& graph, + const ValueRef input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef quantized_input) { + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, num_blocks_K) = + get_quantized_input_num_blocks(graph, input); + + bool is_per_channel = graph.val_is_tensor(input_scale); + + float inv_scale = 1.0f; + int32_t zp = 0; + if (!is_per_channel) { + inv_scale = 1.0f / graph.extract_scalar(input_scale); + zp = graph.extract_scalar(input_zp); + } + + std::string shader_name = "quantize_and_pack_linear_input"; + if (is_per_channel) { + shader_name += "_per_channel"; + } else { + shader_name += "_per_tensor"; + } + add_storage_type_suffix(shader_name, graph.storage_type_of(quantized_input)); + add_storage_type_suffix(shader_name, graph.storage_type_of(input)); + add_dtype_suffix(shader_name, graph.dtype_of(input)); + + vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(input)}; + + std::vector push_constants = { + PushConstantDataInfo(&inv_scale, sizeof(inv_scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + return DynamicDispatchNode( + graph, + 
VK_KERNEL_FROM_STR(shader_name), + quant_pack_input_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{quantized_input, vkapi::kWrite}, {input, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize args + {}); +} + +DynamicDispatchNode make_linear_q8ta_q8csw_tiled_node( + ComputeGraph& graph, + const std::vector& args) { + // Extract arguments + int32_t idx = 0; + const ValueRef input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef packed_weight = args.at(idx++); + const ValueRef packed_weight_sums = args.at(idx++); + const ValueRef packed_weight_scales = args.at(idx++); + const ValueRef bias = args.at(idx++); + const ValueRef packed_bias = args.at(idx++); + const ValueRef output = args.at(idx++); + const ValueRef original_weight = args.at(idx++); // For resize args + + bool is_per_channel = graph.val_is_tensor(input_scale); + + float scale = 1.0f; + int32_t zp = 0; + if (!is_per_channel) { + scale = graph.extract_scalar(input_scale); + zp = graph.extract_scalar(input_zp); + } + + // Get shader for quantized linear + std::string kernel_name = "linear_q8ta_q8csw_tiled"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(output)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(output), graph.sizes_ubo(input)}; + + std::vector push_constants = { + PushConstantDataInfo(&scale, sizeof(scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + uint32_t apply_bias = 0; + if (!graph.val_is_none(bias)) { + apply_bias = 1; + } + + // Add the compute node + return DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + quantized_linear_global_wg_size, + quantized_linear_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {{input, + packed_weight, + packed_weight_sums, + packed_weight_scales, + packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {apply_bias}, + // Resize args + {original_weight}, + // Resizing Logic + nullptr); +} + +DynamicDispatchNode make_linear_q8csw_node( + ComputeGraph& graph, + const std::vector& args) { + // Extract arguments + int32_t idx = 0; + const ValueRef input = args.at(idx++); + const ValueRef packed_weight = args.at(idx++); + const ValueRef packed_weight_scales = args.at(idx++); + const ValueRef bias = args.at(idx++); + const ValueRef packed_bias = args.at(idx++); + const ValueRef output = args.at(idx++); + const ValueRef original_weight = args.at(idx++); // For resize args + + // Get shader for quantized linear + std::string kernel_name = "linear_q8csw_tiled"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(output)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(output), graph.sizes_ubo(input)}; + + uint32_t apply_bias = 0; + if (!graph.val_is_none(bias)) { + apply_bias = 1; + } + 
+ return DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + quantized_linear_global_wg_size, + quantized_linear_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {{input, packed_weight, packed_weight_scales, packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + {}, + // Specialization Constants + {apply_bias}, + // Resize args + {original_weight}, + // Resizing Logic + nullptr); +} + +/* + * Allows orchestration of two compute shader dispatch paths: + * 1. quantize & pack input to int8, execute linear_q8ta_q8csw + * 2. execute linear_q8csw with fp inputs + * + * The reason for this split is twofold: + * - Some devices may not support accelerated int8 dot product. In that case, + * there is no benefit to quantizing the input tensor. In that case + * linear_q8csw is required. + * - For LLMs, which switch between GEMM and GEMV input conditions when going + * from prefill to decode. GEMM is typically a compute bound operation, which + * will benefit from accelerated int8 accumulation. On the other hand, GEMV + * is usually memory bound, which means it may actually suffer from the extra + * cost of having to quantize and pack the input tensor. Therefore, + * linear_q8ta_q8csw is preferred fro GEMM and linear_q8csw is preferred for + * GEMV. + * + * Note that dynamic shape is currently not supported, so switching paths + * when input conditions go between GEMM -> GEMV is currently not implemented. + * This will be implemented at a later date. + */ +struct QuantizedLinearNode : public ExecuteNode { + friend class ComputeGraph; + + bool can_use_int8_dot_product = false; + DynamicDispatchNode quantize_and_pack_input_node; + DynamicDispatchNode linear_q8ta_q8csw_tiled_node; + DynamicDispatchNode linear_q8csw_node; + + explicit QuantizedLinearNode( + ComputeGraph& graph, + const std::vector& args, + DynamicDispatchNode&& quant_pack_input, + DynamicDispatchNode&& qaqw_tiled_linear, + DynamicDispatchNode&& linear_q8csw, + bool int8_dot_product_enabled) + : ExecuteNode(), + quantize_and_pack_input_node(std::move(quant_pack_input)), + linear_q8ta_q8csw_tiled_node(std::move(qaqw_tiled_linear)), + linear_q8csw_node(std::move(linear_q8csw)) { + if (int8_dot_product_enabled) { + can_use_int8_dot_product = graph.can_use_int8_dot_product(); + } + } + + void prepare_pipelines(ComputeGraph* graph) override { + if (can_use_int8_dot_product) { + quantize_and_pack_input_node.prepare_pipelines(graph); + linear_q8ta_q8csw_tiled_node.prepare_pipelines(graph); + } + linear_q8csw_node.prepare_pipelines(graph); + } + + void encode(ComputeGraph* graph) override { + if (can_use_int8_dot_product) { + quantize_and_pack_input_node.encode(graph); + linear_q8ta_q8csw_tiled_node.encode(graph); + } else { + linear_q8csw_node.encode(graph); + } + } +}; + +/* + * Implements activation and weight quantized linear. 
Currently, only the + * following quantization configurations are supported: + * - activation quantized to int8 with per tensor quant params + * - weight quantized to int8 with per channel quant params + */ +void linear_q8ta_q8csw_impl( + ComputeGraph& graph, + const std::vector& args, + const bool use_int8_dot_product = true) { + int32_t idx = 0; + const ValueRef input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight = args.at(idx++); + const ValueRef weight_sums = args.at(idx++); + const ValueRef weight_scales = args.at(idx++); + const ValueRef bias = args.at(idx++); + const ValueRef output = args.at(idx++); + + bool is_per_channel = graph.val_is_tensor(input_scale); + + // Input validation + std::vector input_sizes = graph.sizes_of(input); + std::vector weight_sizes = graph.sizes_of(weight); + + const int64_t K = utils::val_at(-1, input_sizes); + // K (input channels) must be a multiple of 4 to ensure that reading a group + // of 4 input channels from the input tensor will be aligned on a texel + // boundary. + VK_CHECK_COND(K % 4 == 0); + + const int64_t N = utils::val_at(-1, input_sizes); + // N (output channels) must be a multiple of 4 to ensure that reading a group + // of 4 output channels from the weight/output tensor will be aligned on a + // texel boundary. + VK_CHECK_COND(N % 4 == 0); + + // Prepacking + const ValueRef packed_weight = prepack_q8_linear_weight(graph, weight); + ValueRef packed_weight_scales = prepack_standard( + graph, weight_scales, utils::kBuffer, utils::kWidthPacked); + ValueRef packed_weight_sums = + prepack_standard(graph, weight_sums, utils::kBuffer, utils::kWidthPacked); + + ValueRef packed_input_scale = input_scale; + ValueRef packed_input_zp = input_zp; + if (is_per_channel) { + packed_input_scale = prepack_standard( + graph, input_scale, utils::kBuffer, utils::kWidthPacked); + packed_input_zp = + prepack_standard(graph, input_zp, utils::kBuffer, utils::kWidthPacked); + } + + // Create a dummy tensor to fill the binding slot of the bias tensor if it is + // not provided. This helps simplify dispatch logic and makes it so that + // fewer shdaer variants need to be generated. 
+ TmpTensor dummy_bias( + &graph, {}, graph.dtype_of(output), utils::kBuffer, utils::kWidthPacked); + + ValueRef packed_bias = dummy_bias.vref; + if (graph.val_is_none(packed_bias)) { + packed_bias = + prepack_standard(graph, bias, utils::kBuffer, utils::kWidthPacked); + } + + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, num_blocks_K) = + get_quantized_input_num_blocks(graph, input); + + const int64_t quantized_input_height = num_blocks_M; + const int64_t quantized_input_width = num_blocks_K * 4; + + TmpTensor quantized_packed_input( + &graph, + {quantized_input_height, quantized_input_width}, + vkapi::kInt, + graph.storage_type_of(input), + utils::kWidthPacked); + + DynamicDispatchNode quantize_and_pack_linear_node( + make_quantize_and_pack_linear_input_node( + graph, + input, + packed_input_scale, + packed_input_zp, + quantized_packed_input)); + + std::vector linear_args = { + quantized_packed_input, + packed_input_scale, + packed_input_zp, + packed_weight, + packed_weight_sums, + packed_weight_scales, + bias, + packed_bias, + output, + weight}; + + DynamicDispatchNode linear_q8ta_q8csw_tiled_node( + make_linear_q8ta_q8csw_tiled_node(graph, linear_args)); + + linear_args = { + input, + packed_weight, + packed_weight_scales, + bias, + packed_bias, + output, + weight}; + + DynamicDispatchNode linear_q8csw_node( + make_linear_q8csw_node(graph, linear_args)); + + graph.execute_nodes().emplace_back(new QuantizedLinearNode( + graph, + linear_args, + std::move(quantize_and_pack_linear_node), + std::move(linear_q8ta_q8csw_tiled_node), + std::move(linear_q8csw_node), + use_int8_dot_product)); +} + +void linear_q8ta_q8csw(ComputeGraph& graph, const std::vector& args) { + linear_q8ta_q8csw_impl(graph, args, true); +} + +void linear_q8ta_q8csw_no_int8( + ComputeGraph& graph, + const std::vector& args) { + linear_q8ta_q8csw_impl(graph, args, false); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.linear_q8ta_q8csw.default, linear_q8ta_q8csw); + VK_REGISTER_OP(et_vk.linear_q8ta_q8csw.noint8, linear_q8ta_q8csw_no_int8); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.h b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.h new file mode 100644 index 00000000000..11af0b4a0f5 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace vkcompute { + +utils::uvec3 quantized_linear_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args); + +ValueRef prepack_q8_linear_weight( + ComputeGraph& graph, + const ValueRef qmat2_data); + +DynamicDispatchNode make_linear_q8ta_q8csw_tiled_node( + ComputeGraph& graph, + const std::vector& args); + +DynamicDispatchNode make_linear_q8csw_node( + ComputeGraph& graph, + const std::vector& args); + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/CMakeLists.txt b/backends/vulkan/test/custom_ops/CMakeLists.txt index bb8554e1a91..3e844274383 100644 --- a/backends/vulkan/test/custom_ops/CMakeLists.txt +++ b/backends/vulkan/test/custom_ops/CMakeLists.txt @@ -93,4 +93,7 @@ if(TARGET vulkan_backend) # Define operator prototypes add_operator_prototype(add) + add_operator_prototype(quantized_linear) + add_operator_prototype(quantized_conv2d) + add_operator_prototype(conv2d) endif() diff --git a/backends/vulkan/test/custom_ops/conv2d.cpp b/backends/vulkan/test/custom_ops/conv2d.cpp new file mode 100644 index 00000000000..66feebbe9e4 --- /dev/null +++ b/backends/vulkan/test/custom_ops/conv2d.cpp @@ -0,0 +1,320 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include "utils.h" + +#include + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +// Component structs for better readability +struct KernelSize { + int32_t h; + int32_t w; + + KernelSize(int32_t height, int32_t width) : h(height), w(width) {} +}; + +struct Stride { + int32_t h; + int32_t w; + + Stride(int32_t height, int32_t width) : h(height), w(width) {} +}; + +struct Padding { + int32_t h; + int32_t w; + + Padding(int32_t height, int32_t width) : h(height), w(width) {} +}; + +struct Dilation { + int32_t h; + int32_t w; + + Dilation(int32_t height = 1, int32_t width = 1) : h(height), w(width) {} +}; + +struct OutInChannels { + int32_t out; + int32_t in; + + OutInChannels(int32_t out_channels, int32_t in_channels) + : out(out_channels), in(in_channels) {} +}; + +struct InputSize2D { + int32_t h; + int32_t w; + + InputSize2D(int32_t height, int32_t width) : h(height), w(width) {} +}; + +// Conv2d configuration struct +struct Conv2dConfig { + OutInChannels channels; + InputSize2D input_size; + KernelSize kernel; + Stride stride; + Padding padding; + Dilation dilation; + int32_t groups; // Number of groups for grouped convolution + std::string name_suffix; + std::string shader_variant_name = "default"; + + // Calculate output dimensions + int64_t get_output_height() const { + return (input_size.h + 2 * padding.h - dilation.h * (kernel.h - 1) - 1) / + stride.h + + 1; + } + + int64_t get_output_width() const { + return (input_size.w + 2 * padding.w - dilation.w * (kernel.w - 1) - 1) / + stride.w + + 1; + } +}; + +// Utility function to create a test case from a Conv2dConfig +TestCase create_test_case_from_config( + const Conv2dConfig& config, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype) { + TestCase test_case; + + // Create a descriptive name for the test case + std::string storage_str = + (storage_type == utils::kTexture3D) ? 
"Texture3D" : "Buffer"; + std::string dtype_str = (input_dtype == vkapi::kFloat) ? "Float" : "Half"; + + std::string test_name = + "Conv2d_" + config.name_suffix + "_" + storage_str + "_" + dtype_str; + test_case.set_name(test_name); + + // Set the operator name for the test case + std::string operator_name = "aten.convolution."; + operator_name += config.shader_variant_name; + test_case.set_operator_name(operator_name); + + // Calculate output dimensions + int64_t H_out = config.get_output_height(); + int64_t W_out = config.get_output_width(); + + // Input tensor (float/half) - [1, C_in, H_in, W_in] (batch size always 1) + std::vector input_size = { + 1, config.channels.in, config.input_size.h, config.input_size.w}; + + ValueSpec input_tensor( + input_size, + input_dtype, + storage_type, + utils::kChannelsPacked, + DataGenType::RANDINT); + + if (debugging()) { + print_valuespec_data(input_tensor, "input_tensor"); + } + + // Weight tensor (float/half) - [C_out, C_in, K_h, K_w] + std::vector weight_size = { + config.channels.out, + config.channels.in, + config.kernel.h, + config.kernel.w}; + ValueSpec weight( + weight_size, + input_dtype, + storage_type, + utils::kChannelsPacked, + DataGenType::RANDOM); + weight.set_constant(true); + + if (debugging()) { + print_valuespec_data(weight, "weight_tensor"); + } + + // Bias (optional, float/half) - [C_out] + ValueSpec bias( + {config.channels.out}, // Per output channel + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + bias.set_constant(true); + + // Stride and padding parameters + ValueSpec stride({config.stride.h, config.stride.w}); + ValueSpec padding({config.padding.h, config.padding.w}); + + // Dilation and groups parameters + ValueSpec dilation({config.dilation.h, config.dilation.w}); + ValueSpec transposed{false}; + ValueSpec output_padding({0, 0}); + ValueSpec groups(config.groups); + ValueSpec out_min{-1000.0f}; + ValueSpec out_max{-1000.0f}; + + // Output tensor (float/half) - [1, C_out, H_out, W_out] (batch size always 1) + ValueSpec output( + {1, config.channels.out, H_out, W_out}, + input_dtype, + storage_type, + utils::kChannelsPacked, + DataGenType::ZEROS); + + // Add all specs to test case + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(weight); + test_case.add_input_spec(bias); + test_case.add_input_spec(stride); + test_case.add_input_spec(padding); + test_case.add_input_spec(dilation); + test_case.add_input_spec(transposed); + test_case.add_input_spec(output_padding); + test_case.add_input_spec(groups); + test_case.add_input_spec(out_min); + test_case.add_input_spec(out_max); + + test_case.add_output_spec(output); + + return test_case; +} + +// Generate easy test cases for conv2d operation (for debugging) +std::vector generate_conv2d_easy_cases() { + std::vector test_cases; + + // Single simple configuration for debugging + Conv2dConfig config = { + OutInChannels(32, 3), // channels (out, in) + InputSize2D(64, 64), // input_size (h, w) + KernelSize(3, 3), // kernel + Stride(2, 2), // stride + Padding(1, 1), // padding + Dilation(1, 1), // dilation + 1, // groups + "simple" // descriptive name + }; + + // Test with both storage types and data types for completeness + std::vector storage_types = {utils::kTexture3D}; + std::vector float_types = {vkapi::kFloat}; + + // Generate test cases for each combination + for (const auto& storage_type : storage_types) { + for (const auto& input_dtype : float_types) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, 
input_dtype)); + } + } + + return test_cases; +} + +// Generate test cases for conv2d operation +std::vector generate_conv2d_test_cases() { + std::vector test_cases; + + std::vector configs = {// Performance test cases + {OutInChannels(128, 64), + InputSize2D(128, 128), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1, + "perf"}, + {OutInChannels(256, 128), + InputSize2D(128, 128), + KernelSize(1, 1), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 8, + "pw_perf"}}; + + // Test with different storage types and data types + std::vector storage_types = {utils::kTexture3D}; + + // Generate test cases for each combination + for (const auto& config : configs) { + for (const auto& storage_type : storage_types) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, vkapi::kFloat)); + test_cases.push_back( + create_test_case_from_config(config, storage_type, vkapi::kHalf)); + } + } + + return test_cases; +} + +// Custom FLOP calculator for conv2d operation +int64_t conv2d_flop_calculator(const TestCase& test_case) { + if (test_case.num_inputs() < 7 || test_case.num_outputs() < 1) { + return 0; + } + + // Get input and weight dimensions + const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); + const auto& weight_sizes = test_case.inputs()[1].get_tensor_sizes(); + const auto& output_sizes = test_case.outputs()[0].get_tensor_sizes(); + + int64_t N = input_sizes[0]; + int64_t C_in = input_sizes[1]; + int64_t C_out = weight_sizes[0]; + int64_t K_h = weight_sizes[2]; + int64_t K_w = weight_sizes[3]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + // Calculate FLOPs for conv2d operation + // Each output element requires: + // - C_in * K_h * K_w multiply-accumulate operations + // - 1 bias addition + int64_t output_elements = N * C_out * H_out * W_out; + int64_t ops_per_output = C_in * K_h * K_w; + + // Add bias operation + int64_t bias_ops = 1; + + int64_t flop = output_elements * (ops_per_output + bias_ops); + + return flop; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout << "Conv2d Operation Prototyping Framework" << std::endl; + print_separator(); + + // No reference function needed since fp32 convolutions are tested elsewhere + ReferenceComputeFunc ref_fn = nullptr; + + // Execute test cases using the new framework with custom FLOP calculator + auto results = execute_test_cases( + generate_conv2d_test_cases, + conv2d_flop_calculator, + "Conv2d", + 0, + 1, + ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/custom_ops/quantized_conv2d.cpp b/backends/vulkan/test/custom_ops/quantized_conv2d.cpp new file mode 100644 index 00000000000..795e26b2ca4 --- /dev/null +++ b/backends/vulkan/test/custom_ops/quantized_conv2d.cpp @@ -0,0 +1,601 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include +#include +#include +#include +#include "utils.h" + +#include + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +// Component structs for better readability +struct KernelSize { + int32_t h; + int32_t w; + + KernelSize(int32_t height, int32_t width) : h(height), w(width) {} +}; + +struct Stride { + int32_t h; + int32_t w; + + Stride(int32_t height, int32_t width) : h(height), w(width) {} +}; + +struct Padding { + int32_t h; + int32_t w; + + Padding(int32_t height, int32_t width) : h(height), w(width) {} +}; + +struct Dilation { + int32_t h; + int32_t w; + + Dilation(int32_t height = 1, int32_t width = 1) : h(height), w(width) {} +}; + +struct OutInChannels { + int32_t out; + int32_t in; + + OutInChannels(int32_t out_channels, int32_t in_channels) + : out(out_channels), in(in_channels) {} +}; + +struct InputSize2D { + int32_t h; + int32_t w; + + InputSize2D(int32_t height, int32_t width) : h(height), w(width) {} +}; + +// Conv2d configuration struct +struct Conv2dConfig { + OutInChannels channels; + InputSize2D input_size; + KernelSize kernel; + Stride stride; + Padding padding; + Dilation dilation; + int32_t groups; // Number of groups for grouped convolution + std::string name_suffix; + std::string shader_variant_name = "conv2d_q8ta_q8csw_linear_tiled"; + + // Calculate output dimensions + int64_t get_output_height() const { + return (input_size.h + 2 * padding.h - dilation.h * (kernel.h - 1) - 1) / + stride.h + + 1; + } + + int64_t get_output_width() const { + return (input_size.w + 2 * padding.w - dilation.w * (kernel.w - 1) - 1) / + stride.w + + 1; + } +}; + +// Utility function to create a test case from a Conv2dConfig +TestCase create_test_case_from_config( + const Conv2dConfig& config, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype) { + TestCase test_case; + + // Create a descriptive name for the test case + std::string storage_str = + (storage_type == utils::kTexture3D) ? "Texture3D" : "Buffer"; + std::string dtype_str = (input_dtype == vkapi::kFloat) ? 
"Float" : "Half"; + + std::string test_name = "QuantizedConv2d_" + config.name_suffix + "_" + + config.shader_variant_name + "_" + storage_str + "_" + dtype_str; + test_case.set_name(test_name); + + // Set the operator name for the test case + std::string operator_name = "et_vk.conv2d_q8ta_q8csw."; + operator_name += config.shader_variant_name; + test_case.set_operator_name(operator_name); + + // Calculate output dimensions + int64_t H_out = config.get_output_height(); + int64_t W_out = config.get_output_width(); + + // Input tensor (float/half) - [1, C_in, H_in, W_in] (batch size always 1) + std::vector input_size = { + 1, config.channels.in, config.input_size.h, config.input_size.w}; + + ValueSpec input_tensor( + input_size, + input_dtype, + storage_type, + utils::kChannelsPacked, + DataGenType::RANDINT); + + if (debugging()) { + print_valuespec_data(input_tensor, "input_tensor"); + } + + float input_scale_val = 0.1f; + ValueSpec input_scale(input_scale_val); + + int32_t input_zero_point_val = -2; + ValueSpec input_zero_point(input_zero_point_val); + + // Quantized weight tensor (int8) - [C_in * K_h * K_w, C_out] (transposed for + // matrix multiplication) Memory layout: height, width, then channels - in_c + // is innermost (stride 1) in the first dimension + const int64_t in_channels_per_group = config.channels.in / config.groups; + std::vector weight_size = { + in_channels_per_group * config.kernel.h * config.kernel.w, + config.channels.out}; + ValueSpec quantized_weight( + weight_size, + vkapi::kChar, // int8 for quantized weights + storage_type, + utils::kWidthPacked, + DataGenType::RANDINT8); + quantized_weight.set_constant(true); + + if (debugging()) { + print_valuespec_data(quantized_weight, "weight_tensor"); + } + + // Weight quantization scales (float/half, per-channel) + ValueSpec weight_scales( + {config.channels.out}, // Per output channel + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM_SCALES); + weight_scales.set_constant(true); + + ValueSpec weight_sums( + {config.channels.out}, // Per output channel + vkapi::kFloat, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + weight_sums.set_constant(true); + + // Compute weight_sums data based on quantized weights + compute_weight_sums( + weight_sums, + quantized_weight, + config.channels.out, + in_channels_per_group * config.kernel.h * config.kernel.w); + + // Bias (optional, float/half) - [C_out] + ValueSpec bias( + {config.channels.out}, // Per output channel + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::ONES); + bias.set_constant(true); + + // Stride and padding parameters + ValueSpec stride({config.stride.h, config.stride.w}); + ValueSpec padding({config.padding.h, config.padding.w}); + + // Dilation and groups parameters + ValueSpec dilation({config.dilation.h, config.dilation.w}); + ValueSpec groups(config.groups); + + // Kernel size parameters + ValueSpec kernel_size({config.kernel.h, config.kernel.w}); + + // Output tensor (float/half) - [1, C_out, H_out, W_out] (batch size always 1) + ValueSpec output( + {1, config.channels.out, H_out, W_out}, + input_dtype, + storage_type, + utils::kChannelsPacked, + DataGenType::ZEROS); + + // Add all specs to test case + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zero_point); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_sums); + test_case.add_input_spec(weight_scales); + test_case.add_input_spec(bias); + 
test_case.add_input_spec(kernel_size); + test_case.add_input_spec(stride); + test_case.add_input_spec(padding); + test_case.add_input_spec(dilation); + test_case.add_input_spec(groups); + + test_case.add_output_spec(output); + + return test_case; +} + +// Generate easy test cases for quantized conv2d operation (for debugging) +std::vector generate_quantized_conv2d_easy_cases() { + std::vector test_cases; + + // Single simple configuration for debugging + Conv2dConfig config = { + OutInChannels(8, 8), // channels (out, in) + InputSize2D(8, 8), // input_size (h, w) + KernelSize(3, 3), // kernel + Stride(1, 1), // stride + Padding(0, 0), // padding + Dilation(1, 1), // dilation + 2, // groups + "simple_groups" // descriptive name + }; + + // Test with both storage types and data types for completeness + std::vector storage_types = {utils::kTexture3D}; + std::vector float_types = {vkapi::kFloat}; + + // Generate test cases for each combination + for (const auto& storage_type : storage_types) { + for (const auto& input_dtype : float_types) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, input_dtype)); + } + } + + return test_cases; +} + +// Generate test cases for quantized conv2d operation +std::vector generate_quantized_conv2d_test_cases() { + std::vector test_cases; + + std::vector configs = {// Small conv2d layers + {OutInChannels(32, 3), + InputSize2D(64, 64), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1, + "3x32x32_to_16x32x32"}, + {OutInChannels(32, 16), + InputSize2D(32, 32), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1, + "16x32x32_to_32x32x32"}, + {OutInChannels(64, 32), + InputSize2D(16, 16), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1, + "32x16x16_to_64x16x16"}, + + // Stride 2 convolutions + {OutInChannels(32, 3), + InputSize2D(64, 64), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1, + "3x64x64_to_32x32x32_s2"}, + {OutInChannels(64, 32), + InputSize2D(32, 32), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1, + "32x32x32_to_64x16x16_s2"}, + // Different kernel sizes + {OutInChannels(32, 16), + InputSize2D(28, 28), + KernelSize(5, 5), + Stride(1, 1), + Padding(2, 2), + Dilation(1, 1), + 1, + "16x28x28_to_32x28x28_k5"}, + {OutInChannels(64, 32), + InputSize2D(14, 14), + KernelSize(7, 7), + Stride(1, 1), + Padding(3, 3), + Dilation(1, 1), + 1, + "32x14x14_to_64x14x14_k7"}, + + // Dilated convolutions + {OutInChannels(32, 16), + InputSize2D(32, 32), + KernelSize(3, 3), + Stride(1, 1), + Padding(2, 2), + Dilation(2, 2), + 1, + "16x32x32_to_32x32x32_d2"}, + {OutInChannels(64, 32), + InputSize2D(16, 16), + KernelSize(3, 3), + Stride(1, 1), + Padding(3, 3), + Dilation(3, 3), + 1, + "32x16x16_to_64x16x16_d3"}, + + // Grouped convolutions + {OutInChannels(32, 32), + InputSize2D(32, 32), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 4, + "32x32x32_to_32x32x32_g4"}, + {OutInChannels(64, 64), + InputSize2D(16, 16), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 8, + "64x16x16_to_64x16x16_g8"}, + // Performance test cases + {OutInChannels(256, 128), + InputSize2D(128, 128), + KernelSize(1, 1), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 8, + "64x16x16_to_64x16x16_g8"}, + {OutInChannels(128, 64), + InputSize2D(128, 128), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1, + "perf"}}; + + // Test with different storage types and data types 
+ std::vector storage_types = {utils::kTexture3D}; + + // Generate test cases for each combination + for (const auto& config : configs) { + for (const auto& storage_type : storage_types) { + // Test both with and without shader int8 dot product + test_cases.push_back( + create_test_case_from_config(config, storage_type, vkapi::kFloat)); + + Conv2dConfig config2 = config; + config2.shader_variant_name = "conv2d_q8csw_linear_tiled"; + + test_cases.push_back( + create_test_case_from_config(config2, storage_type, vkapi::kFloat)); + } + } + + return test_cases; +} + +// Reference implementation for quantized conv2d operation +void quantized_conv2d_reference_impl(TestCase& test_case) { + static constexpr int64_t kRefDimSizeLimit = 100; + + // Extract input specifications + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_sums_spec = test_case.inputs()[idx++]; + (void)weight_sums_spec; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + const ValueSpec& kernel_size_spec = test_case.inputs()[idx++]; + const ValueSpec& stride_spec = test_case.inputs()[idx++]; + const ValueSpec& padding_spec = test_case.inputs()[idx++]; + const ValueSpec& dilation_spec = test_case.inputs()[idx++]; + const ValueSpec& groups_spec = test_case.inputs()[idx++]; + + // Extract output specification (mutable reference) + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions + auto input_sizes = input_spec.get_tensor_sizes(); // [N, C_in, H_in, W_in] + auto weight_sizes = + weight_spec.get_tensor_sizes(); // [C_in * K_h * K_w, C_out] + auto output_sizes = + output_spec.get_tensor_sizes(); // [N, C_out, H_out, W_out] + + int64_t N = input_sizes[0]; + int64_t C_in = input_sizes[1]; + int64_t H_in = input_sizes[2]; + int64_t W_in = input_sizes[3]; + int64_t C_out = weight_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + // Get kernel dimensions from kernel_size ValueSpec + auto kernel_size_data = kernel_size_spec.get_int32_data(); + int64_t K_h = kernel_size_data[0]; + int64_t K_w = kernel_size_data[1]; + + // Get stride, padding, dilation, and groups + auto stride_data = stride_spec.get_int32_data(); + auto padding_data = padding_spec.get_int32_data(); + auto dilation_data = dilation_spec.get_int32_data(); + int64_t stride_h = stride_data[0]; + int64_t stride_w = stride_data[1]; + int64_t pad_h = padding_data[0]; + int64_t pad_w = padding_data[1]; + int64_t dilation_h = dilation_data[0]; + int64_t dilation_w = dilation_data[1]; + int64_t groups = groups_spec.get_int_value(); + + // Skip for large tensors since computation time will be extremely slow + if (N > kRefDimSizeLimit || C_in > kRefDimSizeLimit || + H_in > kRefDimSizeLimit || W_in > kRefDimSizeLimit || + C_out > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + const float input_scale = input_scale_spec.get_float_value(); + const int32_t input_zero_point = input_zeros_spec.get_int_value(); + + auto& weight_data = 
weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + // Calculate channels per group for grouped convolution + int64_t C_in_per_group = C_in / groups; + int64_t C_out_per_group = C_out / groups; + + // Calculate number of output elements + int64_t num_output_elements = N * C_out * H_out * W_out; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + // Perform quantized conv2d operation + for (int64_t n = 0; n < N; ++n) { + for (int64_t out_c = 0; out_c < C_out; ++out_c) { + for (int64_t out_h = 0; out_h < H_out; ++out_h) { + for (int64_t out_w = 0; out_w < W_out; ++out_w) { + float sum = 0.0f; + + // Determine which group this output channel belongs to + int64_t group_idx = out_c / C_out_per_group; + int64_t in_c_start = group_idx * C_in_per_group; + int64_t in_c_end = (group_idx + 1) * C_in_per_group; + + // Convolution operation with dilation support and grouped convolution + for (int64_t in_c = in_c_start; in_c < in_c_end; ++in_c) { + for (int64_t kh = 0; kh < K_h; ++kh) { + for (int64_t kw = 0; kw < K_w; ++kw) { + // Calculate input position with dilation + int64_t in_h = out_h * stride_h - pad_h + kh * dilation_h; + int64_t in_w = out_w * stride_w - pad_w + kw * dilation_w; + + // Check bounds (zero padding) + if (in_h >= 0 && in_h < H_in && in_w >= 0 && in_w < W_in) { + // Get input value and quantize/dequantize + int64_t input_idx = n * (C_in * H_in * W_in) + + in_c * (H_in * W_in) + in_h * W_in + in_w; + + float quant_input = + std::round(input_data[input_idx] / input_scale) + + input_zero_point; + quant_input = + std::min(std::max(quant_input, -128.0f), 127.0f); + float dequant_input = + (quant_input - input_zero_point) * input_scale; + + // Get weight value and dequantize + // Weight layout: [C_in_per_group * K_h * K_w, C_out] + // (transposed) with memory layout: height, width, then + // channels - in_c is innermost (stride 1) in the first + // dimension + int64_t weight_idx = + (kh * (K_w * C_in_per_group) + kw * C_in_per_group + + (in_c % C_in_per_group)) * + C_out + + out_c; + float dequant_weight = + (static_cast(weight_data[weight_idx])) * + weight_scales_data[out_c]; + + sum += dequant_input * dequant_weight; + } + } + } + } + + // Add bias and store result + sum += bias_data[out_c]; + int64_t output_idx = n * (C_out * H_out * W_out) + + out_c * (H_out * W_out) + out_h * W_out + out_w; + ref_data[output_idx] = sum; + } + } + } + } +} + +// Custom FLOP calculator for quantized conv2d operation +int64_t quantized_conv2d_flop_calculator(const TestCase& test_case) { + if (test_case.num_inputs() < 11 || test_case.num_outputs() < 1) { + return 0; + } + + // Get input and weight dimensions + const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); + const auto& output_sizes = test_case.outputs()[0].get_tensor_sizes(); + + const auto& kernel_sizes = test_case.inputs()[7].get_int32_data(); + + int64_t N = input_sizes[0]; + int64_t C_in = input_sizes[1]; + int64_t C_out = output_sizes[1]; + int64_t K_h = kernel_sizes[0]; + int64_t K_w = kernel_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + // Calculate FLOPs for quantized conv2d operation + // Each output element requires: + // - C_in * K_h * K_w multiply-accumulate operations + // - Additional operations for quantization/dequantization + int64_t output_elements = N * C_out * H_out * W_out; + int64_t ops_per_output = C_in * K_h * K_w; + + int64_t 
flop = output_elements * (ops_per_output); + + return flop; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout << "Quantized Conv2d Operation Prototyping Framework" << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = quantized_conv2d_reference_impl; + // ref_fn = nullptr; + + // Execute test cases using the new framework with custom FLOP calculator + auto results = execute_test_cases( + generate_quantized_conv2d_test_cases, + quantized_conv2d_flop_calculator, + "QuantizedConv2d", + 0, + 1, + ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/custom_ops/quantized_linear.cpp b/backends/vulkan/test/custom_ops/quantized_linear.cpp new file mode 100644 index 00000000000..d081f3b621c --- /dev/null +++ b/backends/vulkan/test/custom_ops/quantized_linear.cpp @@ -0,0 +1,352 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include "utils.h" + +#include + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +// Linear configuration struct +struct LinearConfig { + int64_t M; // Batch size / number of rows in input + int64_t K; // Input features / columns in input, rows in weight + int64_t N; // Output features / columns in weight + std::string name_suffix; + std::string shader_variant_name = "default"; +}; + +// Utility function to create a test case from a LinearConfig +TestCase create_test_case_from_config( + const LinearConfig& config, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype) { + TestCase test_case; + + // Create a descriptive name for the test case + std::string storage_str = + (storage_type == utils::kTexture3D) ? "Texture3D" : "Buffer"; + std::string dtype_str = (input_dtype == vkapi::kFloat) ? 
"Float" : "Half"; + + std::string test_name = "QuantizedLinear_" + config.name_suffix + "_" + + storage_str + "_" + dtype_str; + test_case.set_name(test_name); + + // Set the operator name for the test case + std::string operator_name = "et_vk.linear_q8ta_q8csw."; + operator_name += config.shader_variant_name; + test_case.set_operator_name(operator_name); + + // Derive sizes from M, K, N + std::vector input_size = {config.M, config.K}; + std::vector weight_size = {config.K, config.N}; + + // Input tensor (float/half) - [M, K] + ValueSpec input_tensor( + input_size, + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDINT); + + if (debugging()) { + print_valuespec_data(input_tensor, "input_tensor"); + } + + float input_scale_val = 0.5f; + ValueSpec input_scale(input_scale_val); + + int32_t input_zero_point_val = -4; + ValueSpec input_zero_point(input_zero_point_val); + + // Quantized weight tensor (int8) - [K, N] + ValueSpec quantized_weight( + weight_size, + vkapi::kChar, // int8 for quantized weights + storage_type, + utils::kWidthPacked, + DataGenType::RANDINT8); + quantized_weight.set_constant(true); + + if (debugging()) { + print_valuespec_data(quantized_weight, "weight_tensor"); + } + + // Weight quantization scales (float/half, per-channel) + ValueSpec weight_scales( + {config.N}, // Per output feature + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM_SCALES); + weight_scales.set_constant(true); + + ValueSpec weight_sums( + {config.N}, // Per output features + vkapi::kFloat, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + weight_sums.set_constant(true); + + // Compute weight_sums data based on quantized weights + int64_t in_features = config.K; + int64_t out_features = config.N; + compute_weight_sums(weight_sums, quantized_weight, out_features, in_features); + + // Bias (optional, float/half) - [N] + ValueSpec bias( + {config.N}, // Per output feature + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + bias.set_constant(true); + + // Output tensor (float/half) - [M, N] + ValueSpec output( + {config.M, config.N}, + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + + // Add all specs to test case + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zero_point); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_sums); + test_case.add_input_spec(weight_scales); + test_case.add_input_spec(bias); + + test_case.add_output_spec(output); + + return test_case; +} + +// Generate easy test cases for quantized linear operation (for debugging) +std::vector generate_quantized_linear_easy_cases() { + std::vector test_cases; + + // Single simple configuration for debugging + int M = 16; + int K = 128; + int N = 64; + + LinearConfig config = { + M, // Batch size + K, // Input features + N, // Output features + "simple" // descriptive name + }; + + // Test with both storage types and data types for completeness + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; + std::vector float_types = {vkapi::kFloat}; + + // Generate test cases for each combination + for (const auto& storage_type : storage_types) { + for (const auto& input_dtype : float_types) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, input_dtype)); + } + } + + return test_cases; +} + +// Generate test cases for quantized linear operation +std::vector 
generate_quantized_linear_test_cases() { + std::vector test_cases; + + std::vector configs = {// Small linear layers + {1, 64, 32, "64to32_single"}, + {1, 128, 64, "128to64_single"}, + {1, 256, 128, "256to128_single"}, + + // Larger batch sizes + {32, 64, 32, "64to32_batch32"}, + {32, 128, 64, "128to64_batch32"}, + {32, 256, 128, "256to128_batch32"}, + + // Performance test cases + {128, 2048, 2048, "perf_K2048"}, + {16384, 576, 128, "perf_conv"} + + }; + + // Test with different storage types and data types + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; + + // Generate test cases for each combination + for (const auto& config : configs) { + for (const auto& storage_type : storage_types) { + // Test both with and without shader int8 dot product + test_cases.push_back( + create_test_case_from_config(config, storage_type, vkapi::kFloat)); + + LinearConfig no_int_config = config; + no_int_config.name_suffix = config.name_suffix + "_noint8"; + no_int_config.shader_variant_name = "noint8"; + + test_cases.push_back(create_test_case_from_config( + no_int_config, storage_type, vkapi::kFloat)); + } + } + + return test_cases; +} + +// Reference implementation for quantized linear operation +void quantized_linear_reference_impl(TestCase& test_case) { + static constexpr int64_t kRefDimSizeLimit = 300; + // Extract input specifications + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_sums_spec = test_case.inputs()[idx++]; + (void)weight_sums_spec; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + + // Extract output specification (mutable reference) + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions + auto input_sizes = input_spec.get_tensor_sizes(); // [batch_size, in_features] + auto weight_sizes = + weight_spec.get_tensor_sizes(); // [out_features, in_features] + auto output_sizes = + output_spec.get_tensor_sizes(); // [batch_size, out_features] + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = weight_sizes[1]; + + // Skip for large tensors since computation time will be extremely slow + if (batch_size > kRefDimSizeLimit || in_features > kRefDimSizeLimit || + out_features > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions (batch_size, in_features, out_features) exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + const float input_scale = input_scale_spec.get_float_value(); + const int32_t input_zero_point = input_zeros_spec.get_int_value(); + + auto& weight_data = weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + // Calculate number of output elements + int64_t num_output_elements = batch_size * out_features; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + // Perform quantized linear transformation (matrix multiplication) + for (int64_t b = 0; b < batch_size; ++b) { + for (int64_t out_f = 0; out_f < out_features; 
++out_f) { + float sum = 0.0f; + + // Matrix multiplication: output[b][out_f] = sum(input[b][in_f] * + // weight[out_f][in_f]) + for (int64_t in_f = 0; in_f < in_features; ++in_f) { + // Get input value and dequantize + int64_t input_idx = b * in_features + in_f; + + float quant_input = + std::round(input_data[input_idx] / input_scale) + input_zero_point; + quant_input = std::min(std::max(quant_input, -128.0f), 127.0f); + float dequant_input = (quant_input - input_zero_point) * input_scale; + + // Get weight value and dequantize + int64_t weight_idx = in_f * out_features + out_f; + float dequant_weight = (static_cast(weight_data[weight_idx])) * + weight_scales_data[out_f]; + + sum += dequant_input * dequant_weight; + } + + // Add bias and store result + sum += bias_data[out_f]; + int64_t output_idx = b * out_features + out_f; + ref_data[output_idx] = sum; + } + } +} + +// Custom FLOP calculator for quantized linear operation +int64_t quantized_linear_flop_calculator(const TestCase& test_case) { + if (test_case.num_inputs() < 5 || test_case.num_outputs() < 1) { + return 0; + } + + // Get input and weight dimensions + const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); + const auto& weight_sizes = test_case.inputs()[3].get_tensor_sizes(); + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = weight_sizes[0]; + + // Calculate FLOPs for quantized linear operation + // Each output element requires: + // - in_features multiply-accumulate operations + // - Additional operations for quantization/dequantization + int64_t output_elements = batch_size * out_features; + int64_t ops_per_output = in_features; + + // Add quantization overhead (approximate) + // - Dequantize input: 1 op per input element used + // - Dequantize weight: 1 op per weight element used + // - Add bias: 1 op per output element + int64_t quantization_ops = ops_per_output + 1; // Simplified estimate + + int64_t flop = output_elements * (ops_per_output + quantization_ops); + + return flop; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout << "Quantized Linear Operation Prototyping Framework" << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = quantized_linear_reference_impl; + + // Execute easy test cases using the new framework with custom FLOP calculator + auto results = execute_test_cases( + generate_quantized_linear_test_cases, + quantized_linear_flop_calculator, + "QuantizedLinear", + 0, + 10, + ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl index 2ddf49834e1..9dd297cbb20 100644 --- a/backends/vulkan/test/custom_ops/targets.bzl +++ b/backends/vulkan/test/custom_ops/targets.bzl @@ -97,3 +97,6 @@ def define_common_targets(is_fbcode = False): ) define_custom_op_test_binary("add") + define_custom_op_test_binary("conv2d") + define_custom_op_test_binary("quantized_conv2d") + define_custom_op_test_binary("quantized_linear")